1- import inspect
2- import os
31import copy
2+ import inspect
43import math
4+ import os
55import statistics
6- from typing import Any , Callable , Dict , List , Optional
6+ from typing import Any , Callable , Dict , List , Literal , Optional
77
88import pytest
99
1010from eval_protocol .dataset_logger import default_logger
11+ from eval_protocol .human_id import generate_id
1112from eval_protocol .models import CompletionParams , EvalMetadata , EvaluationRow , InputMetadata , Message
1213from eval_protocol .pytest .default_dataset_adapter import default_dataset_adapter
1314from eval_protocol .pytest .default_no_op_rollout_process import default_no_op_rollout_processor
2829 aggregate ,
2930 create_dynamically_parameterized_wrapper ,
3031 execute_function ,
32+ log_eval_status_and_rows ,
3133)
34+ from eval_protocol .stats .confidence_intervals import compute_fixed_set_mu_ci
3235
3336from ..common_utils import load_jsonl
34- from eval_protocol .stats .confidence_intervals import compute_fixed_set_mu_ci
3537
3638
37- def evaluation_test (
39+ def evaluation_test ( # noqa: C901
3840 * ,
3941 model : List [ModelParam ],
4042 input_messages : Optional [List [InputMessagesParam ]] = None ,
@@ -59,6 +61,37 @@ def evaluation_test(
5961]:
6062 """Decorator to create pytest-based evaluation tests.
6163
64+ Here are some key concepts to understand the terminology in EP:
65+
66+ - "invocation" is a single execution of a test function. An invocation can
67+ generate 1 or more cohorts. Grouping by invocation might be useful to
68+ aggregate eval scores across multiple invocations when you want to aggregate
69+ scores across multiple datasets.
70+ - "cohort" is a group of runs for a combination of parameters. A single
71+ cohort will have multiple runs if num_runs > 1.
72+ 1. If your evaluation_test has combinations of parameters, it will generate
73+ multiple cohorts per combination of parameters.
74+ 2. A new execution of a test function will generate a new cohort.
75+ - "run" is a group of rollouts. For multiple num_runs > 1, there will be
76+ multiple "run_id"s.
77+ - "rollout" is the execution/process that produces a "trajectory". You
78+ "execute" multiple rollouts to generate a dataset of trajectories.
79+ - "trajectory" is the result produced by a rollout — a list of OpenAI Chat
80+ Completion messages (e.g. the "messages" field in EvaluationRow).
81+ - "row" is both the input and output of an evaluation. For example, in
82+ tau-bench, a row is a task within the dataset that can be identified as
83+ "airline_task_0" or "airline_task_1" etc. The "row_id" can be populated from
84+ the dataset itself to identify a particular task you want to evaluate. If
85+ not provided, EP will generate a "row_id" for each row whenever you call the
86+ evaluation test.
87+ - "dataset" is a collection of rows (e.g. List[EvaluationRow])
88+ - "eval" is a rubric implemented in the body of an @evaluation_test
89+ decorated test. It simply produces a score from 0 to 1 and attaches it
90+ to the row as the "evaluation_result" field.
91+
92+ "invocation", "cohort", "run", "rollout", and "row" each have a unique ID
93+ which can be used to easily group and identify rows in your dataset.
94+
6295 Args:
6396 model: Model identifiers to query.
6497 input_messages: Messages to send to the model. This is useful if you
@@ -75,7 +108,7 @@ def evaluation_test(
75108 aggregation_method: How to aggregate scores across rows.
76109 threshold_of_success: If set, fail the test if the aggregated score is
77110 below this threshold.
78- num_runs: Number of times to repeat the evaluation .
111+ num_runs: Number of times to repeat the rollout and evaluations .
79112 max_dataset_rows: Limit dataset to the first N rows.
80113 mcp_config_path: Path to MCP config file that follows MCPMultiClientConfiguration schema
81114 max_concurrent_rollouts: Maximum number of concurrent rollouts to run in parallel.
@@ -119,15 +152,15 @@ def decorator(
119152
120153 def execute_with_params (
121154 test_func : TestFunction ,
122- row : EvaluationRow | None = None ,
123- input_dataset : List [EvaluationRow ] | None = None ,
155+ processed_row : EvaluationRow | None = None ,
156+ processed_dataset : List [EvaluationRow ] | None = None ,
124157 evaluation_test_kwargs : Optional [EvaluationInputParam ] = None ,
125158 ):
126159 kwargs = {}
127- if input_dataset is not None :
128- kwargs ["rows" ] = input_dataset
129- if row is not None :
130- kwargs ["row" ] = row
160+ if processed_dataset is not None :
161+ kwargs ["rows" ] = processed_dataset
162+ if processed_row is not None :
163+ kwargs ["row" ] = processed_row
131164 if evaluation_test_kwargs is not None :
132165 if "row" in evaluation_test_kwargs :
133166 raise ValueError ("'row' is a reserved parameter for the evaluation function" )
@@ -176,7 +209,7 @@ def generate_combinations():
176209 datasets = [[input_dataset ]] # type: ignore
177210 else :
178211 datasets = [None ]
179- params : List [Optional [RolloutInputParam ]] = rollout_input_params if rollout_input_params is not None else [None ] # type: ignore
212+ rips : List [Optional [RolloutInputParam ]] = rollout_input_params if rollout_input_params is not None else [None ] # type: ignore
180213 # Apply EP_MAX_DATASET_ROWS to input_messages, but do NOT parameterize over
181214 # each row. Instead, pass the entire sliced list through in a single test run
182215 # so summaries aggregate all rows together (AIME-style behavior).
@@ -195,15 +228,15 @@ def generate_combinations():
195228 # Generate all combinations
196229 for m in model :
197230 for ds in datasets :
198- for ip in params :
231+ for rip in rips :
199232 for im in messages :
200233 for etk in kwargs :
201234 # if no dataset and no messages, raise an error
202235 if ds is None and im is None :
203236 raise ValueError (
204237 "No dataset or messages provided. Please provide at least one of input_dataset or input_messages."
205238 )
206- combinations .append ((m , ds , ip , im , etk ))
239+ combinations .append ((m , ds , rip , im , etk ))
207240
208241 return combinations
209242
@@ -216,12 +249,12 @@ def generate_combinations():
216249 # Create parameter tuples for pytest.mark.parametrize
217250 param_tuples = []
218251 for combo in combinations :
219- model_name , dataset , params , messages , etk = combo
252+ model_name , dataset , rip , messages , etk = combo
220253 param_tuple = [model_name ]
221254 if input_dataset is not None :
222255 param_tuple .append (dataset )
223256 if rollout_input_params is not None :
224- param_tuple .append (params )
257+ param_tuple .append (rip )
225258 if input_messages is not None :
226259 param_tuple .append (messages )
227260 if evaluation_test_kwargs is not None :
@@ -242,11 +275,20 @@ def generate_combinations():
242275 # Create wrapper function with exact signature that pytest expects
243276 def create_wrapper_with_signature () -> Callable :
244277 # Create the function body that will be used
278+ invocation_id = generate_id ()
279+
245280 def wrapper_body (** kwargs ):
246281 model_name = kwargs ["model" ]
247282 eval_metadata = None
248283 all_results : List [EvaluationRow ] = []
249284
285+ cohort_id = generate_id ()
286+
287+ def _log_eval_error (
288+ status : Literal ["finished" , "error" ], rows : Optional [List [EvaluationRow ]] | None , passed : bool
289+ ) -> None :
290+ log_eval_status_and_rows (eval_metadata , rows , status , passed , default_logger )
291+
250292 try :
251293 # Handle dataset loading
252294 data : List [EvaluationRow ] = []
@@ -283,6 +325,7 @@ def wrapper_body(**kwargs):
283325 # into input_params (e.g., '{"temperature":0,"extra_body":{"reasoning":{"effort":"low"}}}').
284326 try :
285327 import json as _json
328+
286329 _env_override = os .getenv ("EP_INPUT_PARAMS_JSON" )
287330 if _env_override :
288331 override_obj = _json .loads (_env_override )
@@ -320,6 +363,8 @@ def wrapper_body(**kwargs):
320363 row .input_metadata .session_data ["mode" ] = mode
321364 # Initialize eval_metadata for each row
322365 row .eval_metadata = eval_metadata
366+ row .cohort_id = cohort_id
367+ row .invocation_id = invocation_id
323368
324369 # has to be done in the pytest main process since it's
325370 # used to determine whether this eval has stopped
@@ -339,14 +384,25 @@ def wrapper_body(**kwargs):
339384 for _ in range (num_runs ):
340385 # Regenerate outputs each run by deep-copying the pristine dataset
341386 # so model responses are not reused across runs.
342- fresh_rows = [copy .deepcopy (r ) for r in data ]
343- input_dataset = execute_function (rollout_processor , rows = fresh_rows , config = config )
387+ run_id = generate_id ()
388+ fresh_dataset = [copy .deepcopy (r ) for r in data ]
389+
390+ # apply new run_id to fresh_dataset
391+ for row in fresh_dataset :
392+ row .run_id = run_id
393+
394+ # generate new rollout_id for each row
395+ for row in fresh_dataset :
396+ row .rollout_id = generate_id ()
397+
398+ processed_dataset = execute_function (rollout_processor , rows = fresh_dataset , config = config )
399+
344400 if mode == "pointwise" :
345401 # Pointwise mode: apply the evaluator function to each row
346- for row in input_dataset :
402+ for row in processed_dataset :
347403 result = execute_with_params (
348404 test_func ,
349- row = row ,
405+ processed_row = row ,
350406 evaluation_test_kwargs = kwargs .get ("evaluation_test_kwargs" ) or {},
351407 )
352408 if result is None or not isinstance (result , EvaluationRow ):
@@ -358,7 +414,7 @@ def wrapper_body(**kwargs):
358414 # Batch mode: call the test function with the full dataset
359415 results = execute_with_params (
360416 test_func ,
361- input_dataset = input_dataset ,
417+ processed_dataset = processed_dataset ,
362418 evaluation_test_kwargs = kwargs .get ("evaluation_test_kwargs" ) or {},
363419 )
364420 if results is None :
@@ -430,6 +486,7 @@ def wrapper_body(**kwargs):
430486 # Aggregate per-metric mean and 95% CI when available
431487 metrics_summary : Dict [str , Dict [str , float ]] = {}
432488 from collections import defaultdict
489+
433490 metric_scores : Dict [str , list ] = defaultdict (list )
434491 for r in all_results :
435492 if r .evaluation_result and r .evaluation_result .metrics :
@@ -470,7 +527,10 @@ def wrapper_body(**kwargs):
470527 )
471528 # As per project convention, avoid printing per-metric CI lines to reduce noise
472529 if summary_path :
473- import json , pathlib , time , re
530+ import json
531+ import pathlib
532+ import re
533+ import time
474534
475535 def _sanitize_filename (text : str ) -> str :
476536 safe = re .sub (r"[^A-Za-z0-9._-]+" , "-" , text .strip ())
@@ -487,7 +547,11 @@ def _extract_effort_tag(params: dict) -> str | None:
487547 return str (eb ["reasoning" ]["effort" ]).lower ()
488548 if "reasoning_effort" in eb :
489549 return str (eb ["reasoning_effort" ]).lower ()
490- if "reasoning" in params and isinstance (params ["reasoning" ], dict ) and "effort" in params ["reasoning" ]:
550+ if (
551+ "reasoning" in params
552+ and isinstance (params ["reasoning" ], dict )
553+ and "effort" in params ["reasoning" ]
554+ ):
491555 return str (params ["reasoning" ]["effort" ]).lower ()
492556 except Exception :
493557 return None
@@ -529,25 +593,11 @@ def _extract_effort_tag(params: dict) -> str | None:
529593 agg_score >= threshold_of_success
530594 ), f"Aggregated score { agg_score :.3f} below threshold { threshold_of_success } "
531595
596+ except AssertionError :
597+ _log_eval_error ("finished" , data if "data" in locals () else None , passed = False )
598+ raise
532599 except Exception :
533- # Update eval metadata status to error and log it
534- if eval_metadata is not None :
535- eval_metadata .status = "error"
536- eval_metadata .passed = False
537-
538- # Create a minimal result row to log the error if we don't have any results yet
539- if not data :
540- error_row = EvaluationRow (messages = [], eval_metadata = eval_metadata , evaluation_result = None )
541- default_logger .log (error_row )
542- else :
543- # Update existing results with error status
544- for r in data :
545- if r .eval_metadata is not None :
546- r .eval_metadata .status = "error"
547- r .eval_metadata .passed = False
548- default_logger .log (r )
549-
550- # Re-raise the exception to maintain pytest behavior
600+ _log_eval_error ("error" , data if "data" in locals () else None , passed = False )
551601 raise
552602
553603 return create_dynamically_parameterized_wrapper (test_func , wrapper_body , test_param_names )
0 commit comments