7 changes: 3 additions & 4 deletions eval_protocol/dataset_logger/sqlite_dataset_logger_adapter.py
@@ -22,18 +22,17 @@ def __init__(self, db_path: Optional[str] = None, store: Optional[SqliteEvaluati
self._store = SqliteEvaluationRowStore(self.db_path)

def log(self, row: "EvaluationRow") -> None:
row_id = row.input_metadata.row_id
data = row.model_dump(exclude_none=True, mode="json")
self._store.upsert_row(row_id=row_id, data=data)
self._store.upsert_row(data=data)
try:
event_bus.emit(LOG_EVENT_TYPE, EvaluationRow(**data))
except Exception as e:
# Avoid breaking storage due to event emission issues
logger.error(f"Failed to emit row_upserted event: {e}")
pass

def read(self, row_id: Optional[str] = None) -> List["EvaluationRow"]:
def read(self, rollout_id: Optional[str] = None) -> List["EvaluationRow"]:
from eval_protocol.models import EvaluationRow

results = self._store.read_rows(row_id=row_id)
results = self._store.read_rows(rollout_id=rollout_id)
return [EvaluationRow(**data) for data in results]
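For reference, a minimal usage sketch of the adapter after this change. It assumes `EvaluationRow` populates `rollout_id` automatically (the tests below read it back without ever setting it) and that the module path matches the file shown above; the db path is a placeholder.

```python
# Minimal sketch, not part of this PR: log a row and read it back by rollout_id.
from eval_protocol.dataset_logger.sqlite_dataset_logger_adapter import SqliteDatasetLoggerAdapter
from eval_protocol.models import EvaluationRow, InputMetadata, Message

logger = SqliteDatasetLoggerAdapter(db_path="rows.db")  # placeholder path
row = EvaluationRow(
    input_metadata=InputMetadata(row_id="1"),
    messages=[Message(role="user", content="Hello")],
)

logger.log(row)                       # upserted under row.rollout_id
saved = logger.read(row.rollout_id)   # only this rollout; read() with no argument returns everything
assert saved[0].messages == row.messages
```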
25 changes: 14 additions & 11 deletions eval_protocol/dataset_logger/sqlite_evaluation_row_store.py
@@ -11,7 +11,7 @@ class SqliteEvaluationRowStore:
"""
Lightweight reusable SQLite store for evaluation rows.

Stores arbitrary row data as JSON keyed by a unique string `row_id`.
Stores arbitrary row data as JSON keyed by a unique string `rollout_id`.
"""

def __init__(self, db_path: str):
@@ -24,7 +24,7 @@ class Meta:
database = self._db

class EvaluationRow(BaseModel): # type: ignore
row_id = CharField(unique=True)
rollout_id = CharField(unique=True)
data = JSONField()

self._EvaluationRow = EvaluationRow
@@ -36,22 +36,25 @@ class EvaluationRow(BaseModel): # type: ignore
def db_path(self) -> str:
return self._db_path

def upsert_row(self, row_id: str, data: dict) -> None:
if self._EvaluationRow.select().where(self._EvaluationRow.row_id == row_id).exists():
self._EvaluationRow.update(data=data).where(self._EvaluationRow.row_id == row_id).execute()
def upsert_row(self, data: dict) -> None:
if "rollout_id" not in data:
raise ValueError("rollout_id is required to upsert a row")
rollout_id = data["rollout_id"]
if self._EvaluationRow.select().where(self._EvaluationRow.rollout_id == rollout_id).exists():
self._EvaluationRow.update(data=data).where(self._EvaluationRow.rollout_id == rollout_id).execute()
else:
self._EvaluationRow.create(row_id=row_id, data=data)
self._EvaluationRow.create(rollout_id=rollout_id, data=data)

def read_rows(self, row_id: Optional[str] = None) -> List[dict]:
if row_id is None:
def read_rows(self, rollout_id: Optional[str] = None) -> List[dict]:
if rollout_id is None:
query = self._EvaluationRow.select().dicts()
else:
query = self._EvaluationRow.select().dicts().where(self._EvaluationRow.row_id == row_id)
query = self._EvaluationRow.select().dicts().where(self._EvaluationRow.rollout_id == rollout_id)
results = list(query)
return [result["data"] for result in results]

def delete_row(self, row_id: str) -> int:
return self._EvaluationRow.delete().where(self._EvaluationRow.row_id == row_id).execute()
def delete_row(self, rollout_id: str) -> int:
return self._EvaluationRow.delete().where(self._EvaluationRow.rollout_id == rollout_id).execute()

def delete_all_rows(self) -> int:
return self._EvaluationRow.delete().execute()
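A sketch of the store-level contract after this change: the key now travels inside the JSON payload rather than as a separate argument. The import path mirrors the file above; the id value is a placeholder.

```python
# Sketch of the new store contract; "abc123" is a made-up rollout_id.
from eval_protocol.dataset_logger.sqlite_evaluation_row_store import SqliteEvaluationRowStore

store = SqliteEvaluationRowStore("rows.db")

store.upsert_row(data={"rollout_id": "abc123", "pid": 1})   # create
store.upsert_row(data={"rollout_id": "abc123", "pid": 2})   # update the same row
# store.upsert_row(data={"pid": 3})  # would raise ValueError: rollout_id is required

assert store.read_rows(rollout_id="abc123")[0]["pid"] == 2
assert len(store.read_rows()) >= 1      # no filter returns every stored payload
store.delete_row(rollout_id="abc123")   # returns the number of deleted rows
```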
9 changes: 5 additions & 4 deletions eval_protocol/pytest/default_agent_rollout_processor.py
@@ -8,7 +8,7 @@
from openai.types.chat import ChatCompletionContentPartTextParam, ChatCompletionMessage, ChatCompletionToolParam
from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam

from eval_protocol.dataset_logger import default_logger
from eval_protocol.dataset_logger.dataset_logger import DatasetLogger
from eval_protocol.mcp.execution.policy import LiteLLMPolicy
from eval_protocol.mcp.mcp_multi_client import MCPMultiClient
from eval_protocol.models import EvaluationRow, Message
@@ -20,12 +20,13 @@ class Agent:
A really simple agent that calls the model until no more tool calls are needed.
"""

def __init__(self, model: str, row: EvaluationRow, config_path: str):
def __init__(self, model: str, row: EvaluationRow, config_path: str, logger: DatasetLogger):
self.model = model
self.evaluation_row: EvaluationRow = row
self._policy = LiteLLMPolicy(model_id=model)
self.mcp_client = MCPMultiClient(config_path=config_path) if config_path else None
self.tools: Union[List[ChatCompletionToolParam], NotGiven] = NOT_GIVEN
self.logger: DatasetLogger = logger

async def setup(self):
if self.mcp_client:
@@ -42,7 +43,7 @@ def messages(self) -> list[Message]:

def append_message_and_log(self, message: Message):
self.messages.append(message)
default_logger.log(self.evaluation_row)
self.logger.log(self.evaluation_row)

async def call_agent(self) -> str:
"""
@@ -116,7 +117,7 @@ async def default_agent_rollout_processor(
) -> List[EvaluationRow]:
dataset: Dataset = []
for row in rows:
agent = Agent(model=config.model, row=row, config_path=config.mcp_config_path)
agent = Agent(model=config.model, row=row, config_path=config.mcp_config_path, logger=config.logger)
await agent.setup()
await agent.call_agent()
dataset.append(agent.evaluation_row)
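Below is a hypothetical direct construction of `Agent`, mirroring the call in `default_agent_rollout_processor`. The model id, MCP config path, and the use of `default_logger` are illustrative only, and the import path is inferred from the file path in this diff.

```python
import asyncio

from eval_protocol.dataset_logger import default_logger
from eval_protocol.models import EvaluationRow, InputMetadata, Message
from eval_protocol.pytest.default_agent_rollout_processor import Agent  # path inferred from this diff


async def run_one(row: EvaluationRow) -> EvaluationRow:
    agent = Agent(
        model="gpt-4o-mini",            # placeholder model id
        row=row,
        config_path="mcp_config.json",  # placeholder MCP config
        logger=default_logger,          # any DatasetLogger; tests could pass an in-memory one
    )
    await agent.setup()
    await agent.call_agent()
    return agent.evaluation_row


row = EvaluationRow(
    input_metadata=InputMetadata(row_id="1"),
    messages=[Message(role="user", content="List the files in the workspace")],
)
# asyncio.run(run_one(row))  # needs a reachable model and MCP server to actually run
```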
9 changes: 4 additions & 5 deletions eval_protocol/pytest/default_single_turn_rollout_process.py
@@ -1,11 +1,9 @@
import asyncio
from typing import List

import logging
import os
from typing import List

from eval_protocol.dataset_logger import default_logger
from eval_protocol.models import EvaluationRow, Message, ChatCompletionMessageToolCall
from eval_protocol.models import ChatCompletionMessageToolCall, EvaluationRow, Message
from eval_protocol.pytest.types import RolloutProcessorConfig


@@ -49,6 +47,7 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:

# Dynamic import to avoid static dependency/lint errors if LiteLLM isn't installed yet
import importlib

_litellm = importlib.import_module("litellm")
acompletion = getattr(_litellm, "acompletion")
response = await acompletion(**request_params)
@@ -79,7 +78,7 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
]

row.messages = messages
default_logger.log(row)
config.logger.log(row)
return row

# Process rows with bounded concurrency if configured
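Since both rollout processors now log through `config.logger`, any `DatasetLogger` implementation can capture mid-rollout rows. A minimal in-memory sketch, assuming the base class only requires the `log`/`read` methods exercised in this diff:

```python
from typing import Dict, List, Optional

from eval_protocol.dataset_logger.dataset_logger import DatasetLogger
from eval_protocol.models import EvaluationRow


class InMemoryLogger(DatasetLogger):
    """Keeps the latest version of each row, keyed by rollout_id (interface assumed from this diff)."""

    def __init__(self) -> None:
        self.rows: Dict[str, EvaluationRow] = {}

    def log(self, row: EvaluationRow) -> None:
        self.rows[row.rollout_id] = row

    def read(self, rollout_id: Optional[str] = None) -> List[EvaluationRow]:
        if rollout_id is None:
            return list(self.rows.values())
        return [r for r in self.rows.values() if r.rollout_id == rollout_id]
```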
32 changes: 29 additions & 3 deletions eval_protocol/pytest/evaluation_test.py
@@ -8,6 +8,7 @@
import pytest

from eval_protocol.dataset_logger import default_logger
from eval_protocol.dataset_logger.dataset_logger import DatasetLogger
from eval_protocol.human_id import generate_id
from eval_protocol.models import CompletionParams, EvalMetadata, EvaluationRow, InputMetadata, Message
from eval_protocol.pytest.default_dataset_adapter import default_dataset_adapter
@@ -55,6 +56,7 @@ def evaluation_test( # noqa: C901
steps: int = 30,
mode: EvaluationTestMode = "batch",
combine_datasets: bool = True,
logger: Optional[DatasetLogger] = None,
) -> Callable[
[TestFunction],
TestFunction,
@@ -117,8 +119,11 @@
mode: Evaluation mode. "batch" (default) expects the test function to handle the
full dataset. "pointwise" applies the test function to each row. If your evaluation requires
the full rollout of all rows to compute the score, use "batch".
logger: DatasetLogger to use for logging. If not provided, a default logger will be used.
"""

active_logger: DatasetLogger = logger if logger else default_logger

def decorator(
test_func: TestFunction,
):
@@ -287,7 +292,7 @@ def wrapper_body(**kwargs):
def _log_eval_error(
status: Literal["finished", "error"], rows: Optional[List[EvaluationRow]] | None, passed: bool
) -> None:
log_eval_status_and_rows(eval_metadata, rows, status, passed, default_logger)
log_eval_status_and_rows(eval_metadata, rows, status, passed, active_logger)

try:
# Handle dataset loading
@@ -369,7 +374,6 @@ def _log_eval_error(
# has to be done in the pytest main process since it's
# used to determine whether this eval has stopped
row.pid = os.getpid()
default_logger.log(row)

# Prepare rollout processor config once; we will generate fresh outputs per run
config = RolloutProcessorConfig(
@@ -379,6 +383,7 @@ def _log_eval_error(
max_concurrent_rollouts=max_concurrent_rollouts,
server_script_path=server_script_path,
steps=steps,
logger=active_logger,
)

for _ in range(num_runs):
@@ -395,6 +400,10 @@ def _log_eval_error(
for row in fresh_dataset:
row.rollout_id = generate_id()

# Log the fresh dataset so each row is recorded with its new rollout_id before rollout
for row in fresh_dataset:
active_logger.log(row)

processed_dataset = execute_function(rollout_processor, rows=fresh_dataset, config=config)

if mode == "pointwise":
@@ -463,7 +472,7 @@ def _log_eval_error(
if r.eval_metadata is not None:
r.eval_metadata.status = "finished"
r.eval_metadata.passed = passed
default_logger.log(r)
active_logger.log(r)

# Optional: print and/or persist a summary artifact for CI
try:
@@ -587,6 +596,23 @@ def _extract_effort_tag(params: dict) -> str | None:
# Do not fail evaluation if summary writing fails
pass

# # Write all rows from active_logger.read() to a JSONL file in the same directory as the summary
# try:
# if active_logger is not None:
# rows = active_logger.read()
# # Write to a .jsonl file alongside the summary file
# jsonl_path = "logs.jsonl"
# import json

# with open(jsonl_path, "w", encoding="utf-8") as f_jsonl:
# for row in rows:
# json.dump(row.model_dump(exclude_none=True, mode="json"), f_jsonl)
# f_jsonl.write("\n")
# except Exception as e:
# # Do not fail evaluation if log writing fails
# print(e)
# pass

# Check threshold after logging
if threshold_of_success is not None and not passed:
assert (
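A hypothetical test showing the new `logger` kwarg on the decorator; the db path, model id, and prompt are placeholders, and the overall shape follows the flaky test added at the bottom of this PR.

```python
from eval_protocol.dataset_logger.sqlite_dataset_logger_adapter import SqliteDatasetLoggerAdapter
from eval_protocol.models import EvaluateResult, EvaluationRow, Message
from eval_protocol.pytest import default_no_op_rollout_processor, evaluation_test

capture = SqliteDatasetLoggerAdapter(db_path="eval_logs.db")  # any DatasetLogger works here


@evaluation_test(
    input_messages=[[Message(role="user", content="ping")]],
    model=["dummy/local-model"],
    rollout_processor=default_no_op_rollout_processor,
    mode="pointwise",
    logger=capture,  # rows are logged here instead of default_logger
)
def test_with_custom_logger(row: EvaluationRow) -> EvaluationRow:
    row.evaluation_result = EvaluateResult(score=1.0, reason="ok")
    return row
```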
10 changes: 8 additions & 2 deletions eval_protocol/pytest/types.py
@@ -5,6 +5,9 @@
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Literal, Optional

from eval_protocol.dataset_logger import default_logger
from eval_protocol.dataset_logger.dataset_logger import DatasetLogger

from ..models import EvaluationRow, Message

ModelParam = str # gpt-4o, gpt-4o-mini, accounts/fireworks/models/llama-3.1-8b-instruct
@@ -39,10 +42,13 @@
class RolloutProcessorConfig:
model: ModelParam
input_params: RolloutInputParam # optional input parameters for inference
mcp_config_path: str
server_script_path: Optional[str] = None # TODO: change from server_script_path to mcp_config_path for agent rollout processor
mcp_config_path: str
server_script_path: Optional[str] = (
None # TODO: change from server_script_path to mcp_config_path for agent rollout processor
)
max_concurrent_rollouts: int = 8 # maximum number of concurrent rollouts
steps: int = 30 # max number of rollout steps
logger: DatasetLogger = default_logger # logger to use during rollout for mid-rollout logs


RolloutProcessor = Callable[[List[EvaluationRow], RolloutProcessorConfig], List[EvaluationRow]]
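For completeness, a sketch of constructing the config directly: when `logger` is omitted it falls back to `default_logger`, so existing rollout processors keep working unchanged. The model id and config path are placeholders, and `input_params` is assumed to accept a plain dict here.

```python
from eval_protocol.dataset_logger import default_logger
from eval_protocol.pytest.types import RolloutProcessorConfig

config = RolloutProcessorConfig(
    model="dummy/local-model",          # placeholder
    input_params={},
    mcp_config_path="mcp_config.json",  # placeholder
)
assert config.logger is default_logger  # module-level fallback when no logger is passed
assert config.steps == 30 and config.max_concurrent_rollouts == 8
```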
8 changes: 4 additions & 4 deletions tests/dataset_logger/test_sqlite_dataset_logger_adapter.py
@@ -18,13 +18,13 @@ def test_update_log_and_read():
messages = [Message(role="user", content="Hello")]
input_metadata = InputMetadata(row_id="1")
row = EvaluationRow(input_metadata=input_metadata, messages=messages)
store.upsert_row(row_id="1", data=row.model_dump(exclude_none=True, mode="json"))
store.upsert_row(data=row.model_dump(exclude_none=True, mode="json"))

row.messages.append(Message(role="assistant", content="Hello"))

logger = SqliteDatasetLoggerAdapter()
logger = SqliteDatasetLoggerAdapter(store=store)
logger.log(row)
saved = logger.read(row_id="1")[0]
saved = logger.read(row.rollout_id)[0]
assert row.messages == saved.messages
assert row.input_metadata == saved.input_metadata

@@ -42,7 +42,7 @@ def test_create_log_and_read():
row = EvaluationRow(input_metadata=input_metadata, messages=messages)

logger.log(row)
saved = logger.read(row_id="1")[0]
saved = logger.read(rollout_id=row.rollout_id)[0]
assert row.messages == saved.messages
assert row.input_metadata == saved.input_metadata

31 changes: 31 additions & 0 deletions tests/pytest/test_pytest_flaky_sometimes.py
@@ -0,0 +1,31 @@
import os
import random
from typing import List

import pytest

from eval_protocol.models import EvaluateResult, EvaluationRow, Message
from eval_protocol.pytest import default_no_op_rollout_processor, evaluation_test


# Skip in CI since it will intentionally fail; useful for local generation of logs.
@pytest.mark.skipif(os.getenv("CI") == "true", reason="Skipping flaky test in CI")
@evaluation_test(
input_messages=[[Message(role="user", content="Return HEADS or TAILS at random.")]],
model=["dummy/local-model"],
rollout_processor=default_no_op_rollout_processor,
mode="pointwise",
num_runs=5,
)
def test_flaky_passes_sometimes(row: EvaluationRow) -> EvaluationRow:
"""
A deliberately flaky evaluation that only passes occasionally.

With num_runs=5 and a success probability of ~0.3 per run, the aggregated mean
will clear the threshold (0.8) only rarely. Uses the no-op rollout to avoid any
actual model calls.
"""
# Stochastic score: 1.0 with 30% probability, else 0.0
score = 1.0 if random.random() < 0.3 else 0.0
row.evaluation_result = EvaluateResult(score=score, reason=f"stochastic={score}")
return row