24 changes: 22 additions & 2 deletions eval_protocol/models.py
@@ -245,8 +245,8 @@ class EvaluationRow(BaseModel):
supporting both row-wise batch evaluation and trajectory-based RL evaluation.
"""

# Core conversation data
messages: List[Message] = Field(description="List of messages in the conversation/trajectory.")
# Core OpenAI ChatCompletion compatible conversation data
messages: List[Message] = Field(description="List of messages in the conversation. Also known as a trajectory.")

# Tool and function call information
tools: Optional[List[Dict[str, Any]]] = Field(
@@ -264,6 +264,26 @@ class EvaluationRow(BaseModel):
description="The status of the rollout.",
)

invocation_id: Optional[str] = Field(
default_factory=generate_id,
description="The ID of the invocation that this row belongs to.",
)

cohort_id: Optional[str] = Field(
default_factory=generate_id,
description="The ID of the cohort that this row belongs to.",
)

rollout_id: Optional[str] = Field(
default_factory=generate_id,
description="The ID of the rollout that this row belongs to.",
)

run_id: Optional[str] = Field(
None,
description=("The ID of the run that this row belongs to."),
)

# Ground truth reference (moved from EvaluateResult to top level)
ground_truth: Optional[str] = Field(
default=None, description="Optional ground truth reference for this evaluation."
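The three factory-backed IDs are populated as soon as a row is constructed, while run_id stays unset until the test harness assigns it per run. A minimal sketch of this behavior, assuming the Message model accepts OpenAI-style role and content arguments (an assumption, not shown in this diff):

```python
from eval_protocol.models import EvaluationRow, Message

# Construct a bare row; invocation_id, cohort_id and rollout_id are filled in
# by their generate_id default factories, while run_id defaults to None until
# a run assigns it.
row = EvaluationRow(messages=[Message(role="user", content="What is 2 + 2?")])

print(row.invocation_id, row.cohort_id, row.rollout_id)  # three generated IDs
assert row.run_id is None
```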
2 changes: 2 additions & 0 deletions eval_protocol/pytest/default_mcp_gym_rollout_processor.py
@@ -200,6 +200,8 @@ async def default_mcp_gym_rollout_processor(
Returns:
List of EvaluationRow objects with completed conversations
"""
if config.server_script_path is None:
raise ValueError("server_script_path is required for default_mcp_gym_rollout_processor")
server = MCPServerManager(config.server_script_path, port=9700)

try:
134 changes: 92 additions & 42 deletions eval_protocol/pytest/evaluation_test.py
@@ -1,13 +1,14 @@
import inspect
import os
import copy
import inspect
import math
import os
import statistics
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Dict, List, Literal, Optional

import pytest

from eval_protocol.dataset_logger import default_logger
from eval_protocol.human_id import generate_id
from eval_protocol.models import CompletionParams, EvalMetadata, EvaluationRow, InputMetadata, Message
from eval_protocol.pytest.default_dataset_adapter import default_dataset_adapter
from eval_protocol.pytest.default_no_op_rollout_process import default_no_op_rollout_processor
@@ -28,13 +29,14 @@
aggregate,
create_dynamically_parameterized_wrapper,
execute_function,
log_eval_status_and_rows,
)
from eval_protocol.stats.confidence_intervals import compute_fixed_set_mu_ci

from ..common_utils import load_jsonl
from eval_protocol.stats.confidence_intervals import compute_fixed_set_mu_ci


def evaluation_test(
def evaluation_test( # noqa: C901
*,
model: List[ModelParam],
input_messages: Optional[List[InputMessagesParam]] = None,
@@ -59,6 +61,37 @@ def evaluation_test(
]:
"""Decorator to create pytest-based evaluation tests.

Here are some key concepts to understand the terminology in EP:

- "invocation" is a single execution of a test function. An invocation can
generate 1 or more cohorts. Grouping by invocation might be useful to
aggregate eval scores across multiple invocations when you want to aggregate
scores across multiple datasets.
- "cohort" is a group of runs with for a combination of parameters. A single
cohort will have multiple runs if num_runs > 1.
1. If your evaluation_test has combinations of parameters, it will generate
multiple cohorts per combination of parameters.
2. A new execution of a test function will generate a new cohort.
- "run" is a group of rollouts. For multiple num_runs > 1, there will be
multiple "run_id"s.
- "rollout" is the execution/process that produces a "trajectory". You
"execute" multiple rollouts to generate a dataset of trajectories.
- "trajectory" is the result produced by a rollout — a list of OpenAI Chat
Completion messages (e.g. the "messages" field in EvaluationRow).
- "row" both the input and output of an evaluation. For example, in
tau-bench, a row is a task within the dataset that can be identified as
"airline_task_0" or "airline_task_1" etc. The "row_id" can be populated from
the dataset itself to identify a particular task you want to evaluate. If
not provided, EP will generate a "row_id" for each row whenever you call the
evaluation test.
- "dataset" is a collection of rows (e.g. List[EvauluationRow])
- "eval" is a rubric implemented in the body of an @evaluation_test
decorated test. It simply produces a score from 0 to 1 and attached it
to the row as the "evaluation_result" field.

"invocation", "cohort", "run", "rollout", and "row" each have a unique ID
which can be used to easily group and identify your dataset by.

Args:
model: Model identifiers to query.
input_messages: Messages to send to the model. This is useful if you
@@ -75,7 +108,7 @@ def evaluation_test(
aggregation_method: How to aggregate scores across rows.
threshold_of_success: If set, fail the test if the aggregated score is
below this threshold.
num_runs: Number of times to repeat the evaluation.
num_runs: Number of times to repeat the rollouts and evaluations.
max_dataset_rows: Limit dataset to the first N rows.
mcp_config_path: Path to MCP config file that follows MCPMultiClientConfiguration schema
max_concurrent_rollouts: Maximum number of concurrent rollouts to run in parallel.
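The ID hierarchy described in the docstring above maps directly onto grouping operations over logged rows. A minimal sketch of per-run aggregation, assuming evaluation_result exposes a numeric score field (an assumption, not confirmed by this diff):

```python
from collections import defaultdict
from statistics import mean
from typing import Dict, List

from eval_protocol.models import EvaluationRow


def mean_score_per_run(rows: List[EvaluationRow]) -> Dict[str, float]:
    """Average the eval score of every rollout that shares a run_id."""
    by_run: Dict[str, List[float]] = defaultdict(list)
    for row in rows:
        if row.run_id is not None and row.evaluation_result is not None:
            by_run[row.run_id].append(row.evaluation_result.score)
    return {run_id: mean(scores) for run_id, scores in by_run.items()}
```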
@@ -119,15 +152,15 @@ def decorator(

def execute_with_params(
test_func: TestFunction,
row: EvaluationRow | None = None,
input_dataset: List[EvaluationRow] | None = None,
processed_row: EvaluationRow | None = None,
processed_dataset: List[EvaluationRow] | None = None,
evaluation_test_kwargs: Optional[EvaluationInputParam] = None,
):
kwargs = {}
if input_dataset is not None:
kwargs["rows"] = input_dataset
if row is not None:
kwargs["row"] = row
if processed_dataset is not None:
kwargs["rows"] = processed_dataset
if processed_row is not None:
kwargs["row"] = processed_row
if evaluation_test_kwargs is not None:
if "row" in evaluation_test_kwargs:
raise ValueError("'row' is a reserved parameter for the evaluation function")
@@ -176,7 +209,7 @@ def generate_combinations():
datasets = [[input_dataset]] # type: ignore
else:
datasets = [None]
params: List[Optional[RolloutInputParam]] = rollout_input_params if rollout_input_params is not None else [None] # type: ignore
rips: List[Optional[RolloutInputParam]] = rollout_input_params if rollout_input_params is not None else [None] # type: ignore
# Apply EP_MAX_DATASET_ROWS to input_messages, but do NOT parameterize over
# each row. Instead, pass the entire sliced list through in a single test run
# so summaries aggregate all rows together (AIME-style behavior).
@@ -195,15 +228,15 @@ def generate_combinations():
# Generate all combinations
for m in model:
for ds in datasets:
for ip in params:
for rip in rips:
for im in messages:
for etk in kwargs:
# if no dataset and no messages, raise an error
if ds is None and im is None:
raise ValueError(
"No dataset or messages provided. Please provide at least one of input_dataset or input_messages."
)
combinations.append((m, ds, ip, im, etk))
combinations.append((m, ds, rip, im, etk))

return combinations

@@ -216,12 +249,12 @@ def generate_combinations():
# Create parameter tuples for pytest.mark.parametrize
param_tuples = []
for combo in combinations:
model_name, dataset, params, messages, etk = combo
model_name, dataset, rip, messages, etk = combo
param_tuple = [model_name]
if input_dataset is not None:
param_tuple.append(dataset)
if rollout_input_params is not None:
param_tuple.append(params)
param_tuple.append(rip)
if input_messages is not None:
param_tuple.append(messages)
if evaluation_test_kwargs is not None:
@@ -242,11 +275,20 @@ def generate_combinations():
# Create wrapper function with exact signature that pytest expects
def create_wrapper_with_signature() -> Callable:
# Create the function body that will be used
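# The invocation_id is generated once, when the wrapper is created, so every
# cohort and run produced by this test function shares it.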
invocation_id = generate_id()

def wrapper_body(**kwargs):
model_name = kwargs["model"]
eval_metadata = None
all_results: List[EvaluationRow] = []

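# Each execution of the test body (i.e. each parameter combination) gets its own cohort_id.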
cohort_id = generate_id()

def _log_eval_error(
status: Literal["finished", "error"], rows: Optional[List[EvaluationRow]] | None, passed: bool
) -> None:
log_eval_status_and_rows(eval_metadata, rows, status, passed, default_logger)

try:
# Handle dataset loading
data: List[EvaluationRow] = []
@@ -283,6 +325,7 @@ def wrapper_body(**kwargs):
# into input_params (e.g., '{"temperature":0,"extra_body":{"reasoning":{"effort":"low"}}}').
try:
import json as _json

_env_override = os.getenv("EP_INPUT_PARAMS_JSON")
if _env_override:
override_obj = _json.loads(_env_override)
@@ -320,6 +363,8 @@ def wrapper_body(**kwargs):
row.input_metadata.session_data["mode"] = mode
# Initialize eval_metadata for each row
row.eval_metadata = eval_metadata
row.cohort_id = cohort_id
row.invocation_id = invocation_id

# has to be done in the pytest main process since it's
# used to determine whether this eval has stopped
@@ -339,14 +384,25 @@
for _ in range(num_runs):
# Regenerate outputs each run by deep-copying the pristine dataset
# so model responses are not reused across runs.
fresh_rows = [copy.deepcopy(r) for r in data]
input_dataset = execute_function(rollout_processor, rows=fresh_rows, config=config)
run_id = generate_id()
fresh_dataset = [copy.deepcopy(r) for r in data]

# apply new run_id to fresh_dataset
for row in fresh_dataset:
row.run_id = run_id

# generate new rollout_id for each row
for row in fresh_dataset:
row.rollout_id = generate_id()

processed_dataset = execute_function(rollout_processor, rows=fresh_dataset, config=config)

if mode == "pointwise":
# Pointwise mode: apply the evaluator function to each row
for row in input_dataset:
for row in processed_dataset:
result = execute_with_params(
test_func,
row=row,
processed_row=row,
evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {},
)
if result is None or not isinstance(result, EvaluationRow):
@@ -358,7 +414,7 @@ def wrapper_body(**kwargs):
# Batch mode: call the test function with the full dataset
results = execute_with_params(
test_func,
input_dataset=input_dataset,
processed_dataset=processed_dataset,
evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {},
)
if results is None:
@@ -430,6 +486,7 @@ def wrapper_body(**kwargs):
# Aggregate per-metric mean and 95% CI when available
metrics_summary: Dict[str, Dict[str, float]] = {}
from collections import defaultdict

metric_scores: Dict[str, list] = defaultdict(list)
for r in all_results:
if r.evaluation_result and r.evaluation_result.metrics:
@@ -470,7 +527,10 @@ def wrapper_body(**kwargs):
)
# As per project convention, avoid printing per-metric CI lines to reduce noise
if summary_path:
import json, pathlib, time, re
import json
import pathlib
import re
import time

def _sanitize_filename(text: str) -> str:
safe = re.sub(r"[^A-Za-z0-9._-]+", "-", text.strip())
@@ -487,7 +547,11 @@ def _extract_effort_tag(params: dict) -> str | None:
return str(eb["reasoning"]["effort"]).lower()
if "reasoning_effort" in eb:
return str(eb["reasoning_effort"]).lower()
if "reasoning" in params and isinstance(params["reasoning"], dict) and "effort" in params["reasoning"]:
if (
"reasoning" in params
and isinstance(params["reasoning"], dict)
and "effort" in params["reasoning"]
):
return str(params["reasoning"]["effort"]).lower()
except Exception:
return None
@@ -529,25 +593,11 @@ def _extract_effort_tag(params: dict) -> str | None:
agg_score >= threshold_of_success
), f"Aggregated score {agg_score:.3f} below threshold {threshold_of_success}"

except AssertionError:
_log_eval_error("finished", data if "data" in locals() else None, passed=False)
raise
except Exception:
# Update eval metadata status to error and log it
if eval_metadata is not None:
eval_metadata.status = "error"
eval_metadata.passed = False

# Create a minimal result row to log the error if we don't have any results yet
if not data:
error_row = EvaluationRow(messages=[], eval_metadata=eval_metadata, evaluation_result=None)
default_logger.log(error_row)
else:
# Update existing results with error status
for r in data:
if r.eval_metadata is not None:
r.eval_metadata.status = "error"
r.eval_metadata.passed = False
default_logger.log(r)

# Re-raise the exception to maintain pytest behavior
_log_eval_error("error", data if "data" in locals() else None, passed=False)
raise

return create_dynamically_parameterized_wrapper(test_func, wrapper_body, test_param_names)
34 changes: 33 additions & 1 deletion eval_protocol/pytest/utils.py
@@ -1,6 +1,9 @@
import asyncio
import inspect
from typing import Any, Callable, List, Literal
from typing import Any, Callable, List, Literal, Optional

from eval_protocol.dataset_logger.dataset_logger import DatasetLogger
from eval_protocol.models import EvalMetadata, EvaluationRow


def execute_function(func: Callable, **kwargs) -> Any:
@@ -92,3 +95,32 @@ def wrapper(**kwargs):
wrapper.__signature__ = inspect.Signature(parameters)

return wrapper


def log_eval_status_and_rows(
eval_metadata: Optional[EvalMetadata],
rows: Optional[List[EvaluationRow]],
status: Literal["finished", "error"],
passed: bool,
logger: DatasetLogger,
) -> None:
"""Update eval status and emit rows to the given logger.

If no rows are provided, emits a minimal placeholder row so downstream
consumers still observe a terminal status.
"""
if eval_metadata is None:
return

eval_metadata.status = status
eval_metadata.passed = passed

rows_to_log: List[EvaluationRow] = rows or []
if not rows_to_log:
error_row = EvaluationRow(messages=[], eval_metadata=eval_metadata, evaluation_result=None)
logger.log(error_row)
else:
for r in rows_to_log:
if r.eval_metadata is not None:
r.eval_metadata.status = status
r.eval_metadata.passed = passed
logger.log(r)
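A short usage sketch mirroring how evaluation_test.py invokes this helper; rows_from_run stands in for a hypothetical list of already-evaluated EvaluationRow objects:

```python
from eval_protocol.dataset_logger import default_logger
from eval_protocol.pytest.utils import log_eval_status_and_rows

# rows_from_run: hypothetical list of already-evaluated EvaluationRow objects.
eval_metadata = rows_from_run[0].eval_metadata if rows_from_run else None

# Mark the rows as finished and emit each one through the default logger.
log_eval_status_and_rows(eval_metadata, rows_from_run, "finished", passed=True, logger=default_logger)
```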