3030 RolloutInputParam ,
3131 RolloutProcessor ,
3232 RolloutProcessorConfig ,
33+ RolloutProcessorInputParam ,
3334 TestFunction ,
3435)
3536from eval_protocol .pytest .utils import (
@@ -53,6 +54,7 @@ def evaluation_test( # noqa: C901
5354 rollout_input_params : Optional [List [RolloutInputParam ]] = None ,
5455 rollout_processor : RolloutProcessor = default_no_op_rollout_processor ,
5556 evaluation_test_kwargs : Optional [List [EvaluationInputParam ]] = None ,
57+ rollout_processor_kwargs : Optional [RolloutProcessorInputParam ] = None ,
5658 aggregation_method : AggregationMethod = "mean" ,
5759 passed_threshold : Optional [Union [EvaluationThreshold , float ]] = None ,
5860 num_runs : int = 1 ,
@@ -114,6 +116,7 @@ def evaluation_test( # noqa: C901
114116 rollout_input_params: Generation parameters for the rollout.
115117 rollout_processor: Function used to perform the rollout.
116118 evaluation_test_kwargs: Kwargs for the evaluation function.
119+ rollout_processor_kwargs: Kwargs for the rollout processor.
117120 aggregation_method: How to aggregate scores across rows.
118121 passed_threshold: Threshold configuration for test success.
119122 Success rate must be above success, and if set, standard deviation must be below standard_deviation.
@@ -399,6 +402,7 @@ def _log_eval_error(
399402 server_script_path = server_script_path ,
400403 steps = steps ,
401404 logger = active_logger ,
405+ kwargs = rollout_processor_kwargs ,
402406 )
403407
404408 for i in range (num_runs ):
@@ -765,6 +769,7 @@ def dual_mode_wrapper(*args, **kwargs):
765769 "rollout_input_params" : rollout_input_params ,
766770 "rollout_processor" : rollout_processor ,
767771 "evaluation_test_kwargs" : evaluation_test_kwargs ,
772+ "rollout_processor_kwargs" : rollout_processor_kwargs ,
768773 "aggregation_method" : aggregation_method ,
769774 "passed_threshold" : passed_threshold ,
770775 "num_runs" : num_runs ,
@@ -832,6 +837,7 @@ def run_evaluation_test_direct(
832837 dataset_adapter : Callable [[List [Dict [str , Any ]]], Dataset ] = default_dataset_adapter ,
833838 rollout_input_params : Optional [RolloutInputParam ] = None ,
834839 rollout_processor : RolloutProcessor = default_no_op_rollout_processor ,
840+ rollout_processor_kwargs : Optional [RolloutProcessorInputParam ] = None ,
835841 aggregation_method : AggregationMethod = "mean" ,
836842 threshold_of_success : Optional [float ] = None ,
837843 num_runs : int = 1 ,
@@ -941,6 +947,7 @@ def _deep_update_dict(base: dict, override: dict) -> dict:
941947 max_concurrent_rollouts = max_concurrent_rollouts ,
942948 server_script_path = server_script_path ,
943949 steps = steps ,
950+ kwargs = rollout_processor_kwargs ,
944951 )
945952
946953 all_results : List [EvaluationRow ] = []
@@ -1022,8 +1029,8 @@ def _deep_update_dict(base: dict, override: dict) -> dict:
10221029 if summary_path :
10231030 import json as _json
10241031 import pathlib as _pathlib
1025- import time as _time
10261032 import re as _re
1033+ import time as _time
10271034
10281035 def _sanitize_filename (text : str ) -> str :
10291036 safe = _re .sub (r"[^A-Za-z0-9._-]+" , "-" , text .strip ())
@@ -1039,7 +1046,11 @@ def _extract_effort_tag(params: dict) -> str | None:
10391046 return str (eb ["reasoning" ]["effort" ]).lower ()
10401047 if "reasoning_effort" in eb :
10411048 return str (eb ["reasoning_effort" ]).lower ()
1042- if "reasoning" in params and isinstance (params ["reasoning" ], dict ) and "effort" in params ["reasoning" ]:
1049+ if (
1050+ "reasoning" in params
1051+ and isinstance (params ["reasoning" ], dict )
1052+ and "effort" in params ["reasoning" ]
1053+ ):
10431054 return str (params ["reasoning" ]["effort" ]).lower ()
10441055 except Exception :
10451056 return None
@@ -1069,17 +1080,17 @@ def _extract_effort_tag(params: dict) -> str | None:
10691080 pass
10701081
10711082 if threshold_of_success is not None and not passed :
1072- assert agg_score >= threshold_of_success , (
1073- f"Aggregated score { agg_score :.3f } below threshold { threshold_of_success } "
1074- )
1083+ assert (
1084+ agg_score >= threshold_of_success
1085+ ), f"Aggregated score { agg_score :.3f } below threshold { threshold_of_success } "
10751086
10761087 return {"summary" : summary_obj , "results" : all_results }
10771088 except Exception :
10781089 # Mark errors on rows
10791090 if eval_metadata is not None :
10801091 eval_metadata .status = "error"
10811092 eval_metadata .passed = False
1082- for r in ( data or []) :
1093+ for r in data or []:
10831094 if r .eval_metadata is not None :
10841095 r .eval_metadata .status = "error"
10851096 r .eval_metadata .passed = False
0 commit comments