diff --git a/eval_protocol/benchmarks/registry.py b/eval_protocol/benchmarks/registry.py
index 1e3b3e7b..98065b82 100644
--- a/eval_protocol/benchmarks/registry.py
+++ b/eval_protocol/benchmarks/registry.py
@@ -29,7 +29,6 @@ def test_aime_pointwise(row: EvaluationRow) -> EvaluationRow:
 import os
 from typing import Any, Callable, Dict, List, Optional
 
-
 # Global registry: name -> callable runner
 _BENCHMARK_REGISTRY: Dict[str, Callable[..., Any]] = {}
 
@@ -61,9 +60,7 @@ def export_benchmark(name: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
     def _decorator(test_wrapper: Callable[..., Any]) -> Callable[..., Any]:
         # Pull through metadata attached by evaluation_test
         ep_config: Dict[str, Any] = getattr(test_wrapper, "__ep_config", {})
-        original_test_func: Optional[Callable[..., Any]] = getattr(
-            test_wrapper, "__ep_original_test_func", None
-        )
+        original_test_func: Optional[Callable[..., Any]] = getattr(test_wrapper, "__ep_original_test_func", None)
 
         def _runner(
             *,
@@ -87,6 +84,7 @@ def _runner(
                 # Fireworks OpenAI-compatible endpoint expects extra_body.reasoning_effort, not nested reasoning dict
                 merged.setdefault("extra_body", {})["reasoning_effort"] = str(reasoning_effort)
             if input_params_override:
+
                 def _deep_update(base: Dict[str, Any], over: Dict[str, Any]) -> Dict[str, Any]:
                     for k, v in over.items():
                         if isinstance(v, dict) and isinstance(base.get(k), dict):
@@ -94,6 +92,7 @@ def _deep_update(base: Dict[str, Any], over: Dict[str, Any]) -> Dict[str, Any]:
                         else:
                             base[k] = v
                     return base
+
                 merged = _deep_update(merged, dict(input_params_override))
             if merged:
                 os.environ["EP_INPUT_PARAMS_JSON"] = json.dumps(merged)
@@ -108,15 +107,14 @@ def _deep_update(base: Dict[str, Any], over: Dict[str, Any]) -> Dict[str, Any]:
             models: List[str] = ep_config.get("model") or []
             model_to_use = model or (models[0] if models else None)
             if not model_to_use:
-                raise ValueError(
-                    f"No model provided and none captured from evaluation_test for benchmark '{name}'"
-                )
+                raise ValueError(f"No model provided and none captured from evaluation_test for benchmark '{name}'")
 
             input_messages = ep_config.get("input_messages")
             input_dataset = ep_config.get("input_dataset")
             dataset_adapter = ep_config.get("dataset_adapter")
             rollout_input_params_list = ep_config.get("rollout_input_params")
             rollout_processor = ep_config.get("rollout_processor")
+            rollout_processor_kwargs = ep_config.get("rollout_processor_kwargs")
             aggregation_method = ep_config.get("aggregation_method")
             threshold = ep_config.get("threshold_of_success")
             default_num_runs = ep_config.get("num_runs")
@@ -149,6 +147,7 @@ def _deep_update(base: Dict[str, Any], over: Dict[str, Any]) -> Dict[str, Any]:
                 dataset_adapter=dataset_adapter,
                 rollout_input_params=rollout_params,
                 rollout_processor=rollout_processor,
+                rollout_processor_kwargs=rollout_processor_kwargs,
                 aggregation_method=aggregation_method,
                 threshold_of_success=threshold,
                 num_runs=(num_runs if num_runs is not None else default_num_runs),
@@ -170,5 +169,3 @@ def _deep_update(base: Dict[str, Any], over: Dict[str, Any]) -> Dict[str, Any]:
         return test_wrapper
 
     return _decorator
-
-
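For reference, the `_deep_update` helper added to the registry runner merges override dicts recursively, so nested keys such as `extra_body` are merged rather than replaced wholesale. A minimal standalone sketch of that behavior (the recursive-call line falls between the hunks above and is reconstructed here; type hints dropped):

```python
# Recursive merge, as in the registry's _deep_update: nested dicts are
# merged key-by-key; any other value in `over` overwrites `base`.
def _deep_update(base, over):
    for k, v in over.items():
        if isinstance(v, dict) and isinstance(base.get(k), dict):
            base[k] = _deep_update(base[k], v)
        else:
            base[k] = v
    return base

merged = {"temperature": 0.8, "extra_body": {"reasoning_effort": "medium"}}
merged = _deep_update(merged, {"extra_body": {"reasoning_effort": "high"}})
# temperature survives; only the nested reasoning_effort is overridden
assert merged == {"temperature": 0.8, "extra_body": {"reasoning_effort": "high"}}
```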
rollout_input_params=[{"temperature": 0.8, "extra_body": {"reasoning_effort": "medium"}}], rollout_processor=default_mcp_gym_rollout_processor, + rollout_processor_kwargs={"domain": "retail"}, num_runs=8, mode="pointwise", max_concurrent_rollouts=50, diff --git a/eval_protocol/pytest/default_mcp_gym_rollout_processor.py b/eval_protocol/pytest/default_mcp_gym_rollout_processor.py index 0adbbea0..5037cbad 100644 --- a/eval_protocol/pytest/default_mcp_gym_rollout_processor.py +++ b/eval_protocol/pytest/default_mcp_gym_rollout_processor.py @@ -20,10 +20,10 @@ class MCPServerManager: _active_servers = [] _cleanup_registered = False - def __init__(self, server_script: str, port: int = 8000, domain: str = "airline"): + def __init__(self, server_script: str, port: int = 8000, **kwargs): self.server_script = server_script self.port = port - self.domain = domain + self.domain = str(kwargs.get("domain", "airline")) self.process: Optional[subprocess.Popen] = None self.base_dir = Path(".").resolve() self._log_file = None @@ -58,7 +58,7 @@ def start(self) -> None: env["PORT"] = str(self.port) # Start server process (no domain argument needed for tau2_mcp server) - cmd = ["python", self.server_script, "--port", str(self.port)] + cmd = ["python", self.server_script, "--port", str(self.port), "--domain", self.domain] # Setup log file with cleanup log_file_path = os.path.join(self.base_dir, f"server_output_{self.domain}_{self.port}.log") @@ -213,7 +213,7 @@ async def default_mcp_gym_rollout_processor( """ if config.server_script_path is None: raise ValueError("server_script_path is required for default_mcp_gym_rollout_processor") - server = MCPServerManager(config.server_script_path, port=9700) + server = MCPServerManager(config.server_script_path, port=9700, **(config.kwargs or {})) try: server.start() diff --git a/eval_protocol/pytest/evaluation_test.py b/eval_protocol/pytest/evaluation_test.py index 7557ae3d..f1d9af50 100644 --- a/eval_protocol/pytest/evaluation_test.py +++ b/eval_protocol/pytest/evaluation_test.py @@ -30,6 +30,7 @@ RolloutInputParam, RolloutProcessor, RolloutProcessorConfig, + RolloutProcessorInputParam, TestFunction, ) from eval_protocol.pytest.utils import ( @@ -53,6 +54,7 @@ def evaluation_test( # noqa: C901 rollout_input_params: Optional[List[RolloutInputParam]] = None, rollout_processor: RolloutProcessor = default_no_op_rollout_processor, evaluation_test_kwargs: Optional[List[EvaluationInputParam]] = None, + rollout_processor_kwargs: Optional[RolloutProcessorInputParam] = None, aggregation_method: AggregationMethod = "mean", passed_threshold: Optional[Union[EvaluationThreshold, float]] = None, num_runs: int = 1, @@ -114,6 +116,7 @@ def evaluation_test( # noqa: C901 rollout_input_params: Generation parameters for the rollout. rollout_processor: Function used to perform the rollout. evaluation_test_kwargs: Kwargs for the evaluation function. + rollout_processor_kwargs: Kwargs for the rollout processor. aggregation_method: How to aggregate scores across rows. passed_threshold: Threshold configuration for test success. Success rate must be above success, and if set, standard deviation must be below standard_deviation. 
diff --git a/eval_protocol/pytest/evaluation_test.py b/eval_protocol/pytest/evaluation_test.py
index 7557ae3d..f1d9af50 100644
--- a/eval_protocol/pytest/evaluation_test.py
+++ b/eval_protocol/pytest/evaluation_test.py
@@ -30,6 +30,7 @@
     RolloutInputParam,
     RolloutProcessor,
     RolloutProcessorConfig,
+    RolloutProcessorInputParam,
     TestFunction,
 )
 from eval_protocol.pytest.utils import (
@@ -53,6 +54,7 @@ def evaluation_test(  # noqa: C901
     rollout_input_params: Optional[List[RolloutInputParam]] = None,
     rollout_processor: RolloutProcessor = default_no_op_rollout_processor,
     evaluation_test_kwargs: Optional[List[EvaluationInputParam]] = None,
+    rollout_processor_kwargs: Optional[RolloutProcessorInputParam] = None,
    aggregation_method: AggregationMethod = "mean",
    passed_threshold: Optional[Union[EvaluationThreshold, float]] = None,
    num_runs: int = 1,
@@ -114,6 +116,7 @@ def evaluation_test(  # noqa: C901
        rollout_input_params: Generation parameters for the rollout.
        rollout_processor: Function used to perform the rollout.
        evaluation_test_kwargs: Kwargs for the evaluation function.
+        rollout_processor_kwargs: Kwargs for the rollout processor.
        aggregation_method: How to aggregate scores across rows.
        passed_threshold: Threshold configuration for test success. Success rate must be above success, and if set,
            standard deviation must be below standard_deviation.
@@ -399,6 +402,7 @@ def _log_eval_error(
        server_script_path=server_script_path,
        steps=steps,
        logger=active_logger,
+        kwargs=rollout_processor_kwargs,
    )

    for i in range(num_runs):
@@ -765,6 +769,7 @@ def dual_mode_wrapper(*args, **kwargs):
        "rollout_input_params": rollout_input_params,
        "rollout_processor": rollout_processor,
        "evaluation_test_kwargs": evaluation_test_kwargs,
+        "rollout_processor_kwargs": rollout_processor_kwargs,
        "aggregation_method": aggregation_method,
        "passed_threshold": passed_threshold,
        "num_runs": num_runs,
@@ -832,6 +837,7 @@ def run_evaluation_test_direct(
    dataset_adapter: Callable[[List[Dict[str, Any]]], Dataset] = default_dataset_adapter,
    rollout_input_params: Optional[RolloutInputParam] = None,
    rollout_processor: RolloutProcessor = default_no_op_rollout_processor,
+    rollout_processor_kwargs: Optional[RolloutProcessorInputParam] = None,
    aggregation_method: AggregationMethod = "mean",
    threshold_of_success: Optional[float] = None,
    num_runs: int = 1,
@@ -941,6 +947,7 @@ def _deep_update_dict(base: dict, override: dict) -> dict:
        max_concurrent_rollouts=max_concurrent_rollouts,
        server_script_path=server_script_path,
        steps=steps,
+        kwargs=rollout_processor_kwargs,
    )

    all_results: List[EvaluationRow] = []
@@ -1022,8 +1029,8 @@ def _deep_update_dict(base: dict, override: dict) -> dict:
        if summary_path:
            import json as _json
            import pathlib as _pathlib
-            import time as _time
            import re as _re
+            import time as _time

            def _sanitize_filename(text: str) -> str:
                safe = _re.sub(r"[^A-Za-z0-9._-]+", "-", text.strip())
@@ -1039,7 +1046,11 @@ def _extract_effort_tag(params: dict) -> str | None:
                        return str(eb["reasoning"]["effort"]).lower()
                    if "reasoning_effort" in eb:
                        return str(eb["reasoning_effort"]).lower()
-                if "reasoning" in params and isinstance(params["reasoning"], dict) and "effort" in params["reasoning"]:
+                if (
+                    "reasoning" in params
+                    and isinstance(params["reasoning"], dict)
+                    and "effort" in params["reasoning"]
+                ):
                    return str(params["reasoning"]["effort"]).lower()
            except Exception:
                return None
@@ -1069,9 +1080,9 @@ def _extract_effort_tag(params: dict) -> str | None:
                pass

        if threshold_of_success is not None and not passed:
-            assert agg_score >= threshold_of_success, (
-                f"Aggregated score {agg_score:.3f} below threshold {threshold_of_success}"
-            )
+            assert (
+                agg_score >= threshold_of_success
+            ), f"Aggregated score {agg_score:.3f} below threshold {threshold_of_success}"

        return {"summary": summary_obj, "results": all_results}
    except Exception:
@@ -1079,7 +1090,7 @@ def _extract_effort_tag(params: dict) -> str | None:
        if eval_metadata is not None:
            eval_metadata.status = "error"
            eval_metadata.passed = False
-        for r in (data or []):
+        for r in data or []:
            if r.eval_metadata is not None:
                r.eval_metadata.status = "error"
                r.eval_metadata.passed = False
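Taken together: a suite passes `rollout_processor_kwargs` to `@evaluation_test`, which stores it on `RolloutProcessorConfig.kwargs`, and any custom rollout processor can read it back. A sketch under those assumptions (the processor body is illustrative; real processors such as `default_mcp_gym_rollout_processor` are async):

```python
from eval_protocol.pytest.types import RolloutProcessorConfig

def my_rollout_processor(rows, config: RolloutProcessorConfig):
    # config.kwargs carries whatever the suite declared,
    # e.g. {"domain": "retail"} from tau_bench_retail.py above.
    domain = config.kwargs.get("domain", "airline")
    # ... perform the rollouts for the chosen domain ...
    return rows
```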
diff --git a/eval_protocol/pytest/types.py b/eval_protocol/pytest/types.py
index 42fb3d56..c6de681e 100644
--- a/eval_protocol/pytest/types.py
+++ b/eval_protocol/pytest/types.py
@@ -2,7 +2,7 @@
 Parameter types
 """
 
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Any, Callable, Dict, List, Literal, Optional
 
 from eval_protocol.dataset_logger import default_logger
@@ -15,6 +15,7 @@
 RolloutInputParam = Dict[str, Any]
 InputMessagesParam = List[Message]
 EvaluationInputParam = Dict[str, Any]
+RolloutProcessorInputParam = Dict[str, Any]
 
 Dataset = List[EvaluationRow]
 
@@ -49,6 +50,7 @@ class RolloutProcessorConfig:
     max_concurrent_rollouts: int = 8  # maximum number of concurrent rollouts
     steps: int = 30  # max number of rollout steps
     logger: DatasetLogger = default_logger  # logger to use during rollout for mid-rollout logs
+    kwargs: Dict[str, Any] = field(default_factory=dict)  # any additional kwargs to pass to the rollout processor
 
 
 RolloutProcessor = Callable[[List[EvaluationRow], RolloutProcessorConfig], List[EvaluationRow]]
diff --git a/examples/tau2_mcp/retail_environment/retail_environment.py b/examples/tau2_mcp/retail_environment/retail_environment.py
index 122fc92e..425ef785 100644
--- a/examples/tau2_mcp/retail_environment/retail_environment.py
+++ b/examples/tau2_mcp/retail_environment/retail_environment.py
@@ -30,11 +30,14 @@ class RetailEnvironment:
 
     def __init__(self, config: Optional[Dict[str, Any]] = None):
         self.config = config or {}
-        self.db = RetailDB.load(RETAIL_DB_PATH)
-        self.retail_tools = RetailTools(self.db)
+        self.db = None
+        self.retail_tools = None
 
     def reset(self, seed: Optional[int] = None) -> Tuple[Dict[str, Any], Dict[str, Any]]:
         """Reset the environment to initial state"""
+        self.db = RetailDB.load(RETAIL_DB_PATH)
+        self.retail_tools = RetailTools(self.db)
+
         return {}, {}
 
     def step(self, action: Dict[str, Any]) -> Tuple[Dict[str, Any], float, bool, bool, Dict[str, Any]]:
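The retail environment change defers database loading from construction to `reset()`, so building the environment no longer touches disk and every episode starts from a pristine `RetailDB`. In effect:

```python
env = RetailEnvironment()      # cheap: db and tools stay None until reset
obs, info = env.reset(seed=0)  # RetailDB loaded and RetailTools rebuilt here
obs, info = env.reset()        # reload: mutations from the last episode are discarded
```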