33import math
44import os
55import statistics
6- from typing import Any , Callable , Dict , List , Literal , Optional
6+ from typing import Any , Callable , Dict , List , Literal , Optional , Union
77
88import pytest
99
1010from eval_protocol .dataset_logger import default_logger
1111from eval_protocol .dataset_logger .dataset_logger import DatasetLogger
1212from eval_protocol .human_id import generate_id
13- from eval_protocol .models import CompletionParams , EvalMetadata , EvaluationRow , InputMetadata , Message
13+ from eval_protocol .models import (
14+ CompletionParams ,
15+ EvalMetadata ,
16+ EvaluationRow ,
17+ EvaluationThreshold ,
18+ InputMetadata ,
19+ Message ,
20+ )
1421from eval_protocol .pytest .default_dataset_adapter import default_dataset_adapter
1522from eval_protocol .pytest .default_no_op_rollout_process import default_no_op_rollout_processor
1623from eval_protocol .pytest .types import (
@@ -47,7 +54,7 @@ def evaluation_test( # noqa: C901
4754 rollout_processor : RolloutProcessor = default_no_op_rollout_processor ,
4855 evaluation_test_kwargs : Optional [List [EvaluationInputParam ]] = None ,
4956 aggregation_method : AggregationMethod = "mean" ,
50- threshold_of_success : Optional [float ] = None ,
57+ passed_threshold : Optional [Union [ EvaluationThreshold , float ] ] = None ,
5158 num_runs : int = 1 ,
5259 max_dataset_rows : Optional [int ] = None ,
5360 mcp_config_path : Optional [str ] = None ,
@@ -108,8 +115,8 @@ def evaluation_test( # noqa: C901
108115 rollout_processor: Function used to perform the rollout.
109116 evaluation_test_kwargs: Kwargs for the evaluation function.
110117 aggregation_method: How to aggregate scores across rows.
111- threshold_of_success: If set, fail the test if the aggregated score is
112- below this threshold .
118+ passed_threshold: Threshold configuration for test success.
119+ Success rate must be above success, and if set, standard deviation must be below standard_deviation .
113120 num_runs: Number of times to repeat the rollout and evaluations.
114121 max_dataset_rows: Limit dataset to the first N rows.
115122 mcp_config_path: Path to MCP config file that follows MCPMultiClientConfiguration schema
@@ -127,6 +134,14 @@ def evaluation_test( # noqa: C901
127134 def decorator (
128135 test_func : TestFunction ,
129136 ):
137+ if passed_threshold is not None :
138+ if isinstance (passed_threshold , float ):
139+ threshold = EvaluationThreshold (success = passed_threshold )
140+ else :
141+ threshold = EvaluationThreshold (** passed_threshold )
142+ else :
143+ threshold = None
144+
130145 sig = inspect .signature (test_func )
131146
132147 # For pointwise/rowwise mode, we expect a different signature
@@ -285,7 +300,7 @@ def create_wrapper_with_signature() -> Callable:
285300 def wrapper_body (** kwargs ):
286301 model_name = kwargs ["model" ]
287302 eval_metadata = None
288- all_results : List [EvaluationRow ] = []
303+ all_results : List [List [ EvaluationRow ]] = [[] for _ in range ( num_runs ) ]
289304
290305 cohort_id = generate_id ()
291306
@@ -346,7 +361,7 @@ def _log_eval_error(
346361 status = "running" ,
347362 num_runs = num_runs ,
348363 aggregation_method = aggregation_method ,
349- threshold_of_success = threshold_of_success ,
364+ passed_threshold = threshold ,
350365 passed = None ,
351366 )
352367
@@ -386,11 +401,11 @@ def _log_eval_error(
386401 logger = active_logger ,
387402 )
388403
389- for _ in range (num_runs ):
404+ for i in range (num_runs ):
390405 # Regenerate outputs each run by deep-copying the pristine dataset
391406 # so model responses are not reused across runs.
392407 run_id = generate_id ()
393- fresh_dataset = [copy . deepcopy ( r ) for r in data ]
408+ fresh_dataset = [r . model_copy ( deep = True ) for r in data ]
394409
395410 # apply new run_id to fresh_dataset
396411 for row in fresh_dataset :
@@ -418,7 +433,7 @@ def _log_eval_error(
418433 raise ValueError (
419434 f"Test function { test_func .__name__ } did not return an EvaluationRow instance. You must return an EvaluationRow instance from your test function decorated with @evaluation_test."
420435 )
421- all_results .append (result )
436+ all_results [ i ] .append (result )
422437 else :
423438 # Batch mode: call the test function with the full dataset
424439 results = execute_with_params (
@@ -442,17 +457,21 @@ def _log_eval_error(
442457 raise ValueError (
443458 f"Test function { test_func .__name__ } returned a list containing non-EvaluationRow instances. You must return a list of EvaluationRow instances from your test function decorated with @evaluation_test."
444459 )
445- all_results . extend ( results )
460+ all_results [ i ] = results
446461
447- scores = [r .evaluation_result .score for r in all_results if r .evaluation_result ]
462+ scores = [
463+ sum ([r .evaluation_result .score for r in result if r .evaluation_result ]) / len (result )
464+ for result in all_results
465+ ]
448466 agg_score = aggregate (scores , aggregation_method )
467+ score_std = statistics .stdev (scores ) if len (scores ) > 1 else 0.0
449468
450469 # Compute 95% confidence interval for the fixed-set mean μ (by-question, using repeats)
451470 ci_low : float | None = None
452471 ci_high : float | None = None
453472 if aggregation_method == "mean" :
454473 try :
455- result_ci = compute_fixed_set_mu_ci (all_results )
474+ result_ci = compute_fixed_set_mu_ci ([ item for sublist in all_results for item in sublist ] )
456475 mu_ci_low , mu_ci_high = result_ci [1 ], result_ci [2 ]
457476 if mu_ci_low is not None and mu_ci_high is not None :
458477 ci_low = float (mu_ci_low )
@@ -464,23 +483,32 @@ def _log_eval_error(
464483
465484 # Determine if the evaluation passed based on threshold
466485 passed = None
467- if threshold_of_success is not None :
468- passed = agg_score >= threshold_of_success
486+
487+ if threshold is not None :
488+ success_passed , std_passed = True , True
489+
490+ success_passed = agg_score >= threshold .success
491+
492+ if threshold .standard_deviation is not None :
493+ std_passed = score_std <= threshold .standard_deviation
494+
495+ passed = success_passed and std_passed
469496
470497 # Update eval metadata status and passed field for all results
471- for r in all_results :
472- if r .eval_metadata is not None :
473- r .eval_metadata .status = "finished"
474- r .eval_metadata .passed = passed
475- active_logger .log (r )
498+ for result in all_results :
499+ for r in result :
500+ if r .eval_metadata is not None :
501+ r .eval_metadata .status = "finished"
502+ r .eval_metadata .passed = passed
503+ default_logger .log (r )
476504
477505 # Optional: print and/or persist a summary artifact for CI
478506 try :
479507 should_print = os .getenv ("EP_PRINT_SUMMARY" ) == "1"
480508 summary_path = os .getenv ("EP_SUMMARY_JSON" )
481509 suite_name = test_func .__name__
482510 model_used = model_name
483- total_rows = len (all_results )
511+ total_rows = len ([ item for sublist in all_results for item in sublist ] )
484512 summary_obj = {
485513 "suite" : suite_name ,
486514 "model" : model_used ,
@@ -497,7 +525,7 @@ def _log_eval_error(
497525 from collections import defaultdict
498526
499527 metric_scores : Dict [str , list ] = defaultdict (list )
500- for r in all_results :
528+ for r in [ item for sublist in all_results for item in sublist ] :
501529 if r .evaluation_result and r .evaluation_result .metrics :
502530 for m_name , m_res in r .evaluation_result .metrics .items ():
503531 if m_res is not None and getattr (m_res , "score" , None ) is not None :
@@ -614,10 +642,14 @@ def _extract_effort_tag(params: dict) -> str | None:
614642 # pass
615643
616644 # Check threshold after logging
617- if threshold_of_success is not None and not passed :
645+ if threshold is not None and not passed :
618646 assert (
619- agg_score >= threshold_of_success
620- ), f"Aggregated score { agg_score :.3f} below threshold { threshold_of_success } "
647+ agg_score >= threshold .success
648+ ), f"Aggregated score { agg_score :.3f} below threshold { threshold .success } "
649+ if threshold .standard_deviation is not None :
650+ assert (
651+ score_std <= threshold .standard_deviation
652+ ), f"Standard deviation { score_std :.3f} above threshold { threshold .standard_deviation } "
621653
622654 except AssertionError :
623655 _log_eval_error ("finished" , data if "data" in locals () else None , passed = False )
0 commit comments