15 changes: 6 additions & 9 deletions eval_protocol/benchmarks/registry.py
@@ -29,7 +29,6 @@ def test_aime_pointwise(row: EvaluationRow) -> EvaluationRow:
 import os
 from typing import Any, Callable, Dict, List, Optional
 
-
 # Global registry: name -> callable runner
 _BENCHMARK_REGISTRY: Dict[str, Callable[..., Any]] = {}
 
@@ -61,9 +60,7 @@ def export_benchmark(name: str) -> Callable[[Callable[..., Any]], Callable[...,
     def _decorator(test_wrapper: Callable[..., Any]) -> Callable[..., Any]:
         # Pull through metadata attached by evaluation_test
         ep_config: Dict[str, Any] = getattr(test_wrapper, "__ep_config", {})
-        original_test_func: Optional[Callable[..., Any]] = getattr(
-            test_wrapper, "__ep_original_test_func", None
-        )
+        original_test_func: Optional[Callable[..., Any]] = getattr(test_wrapper, "__ep_original_test_func", None)
 
         def _runner(
             *,
@@ -87,13 +84,15 @@ def _runner(
                 # Fireworks OpenAI-compatible endpoint expects extra_body.reasoning_effort, not nested reasoning dict
                 merged.setdefault("extra_body", {})["reasoning_effort"] = str(reasoning_effort)
             if input_params_override:
+
                 def _deep_update(base: Dict[str, Any], over: Dict[str, Any]) -> Dict[str, Any]:
                     for k, v in over.items():
                         if isinstance(v, dict) and isinstance(base.get(k), dict):
                             _deep_update(base[k], v)
                         else:
                             base[k] = v
                     return base
+
                 merged = _deep_update(merged, dict(input_params_override))
             if merged:
                 os.environ["EP_INPUT_PARAMS_JSON"] = json.dumps(merged)
@@ -108,15 +107,14 @@ def _deep_update(base: Dict[str, Any], over: Dict[str, Any]) -> Dict[str, Any]:
             models: List[str] = ep_config.get("model") or []
             model_to_use = model or (models[0] if models else None)
             if not model_to_use:
-                raise ValueError(
-                    f"No model provided and none captured from evaluation_test for benchmark '{name}'"
-                )
+                raise ValueError(f"No model provided and none captured from evaluation_test for benchmark '{name}'")
 
             input_messages = ep_config.get("input_messages")
             input_dataset = ep_config.get("input_dataset")
             dataset_adapter = ep_config.get("dataset_adapter")
             rollout_input_params_list = ep_config.get("rollout_input_params")
             rollout_processor = ep_config.get("rollout_processor")
+            rollout_processor_kwargs = ep_config.get("rollout_processor_kwargs")
             aggregation_method = ep_config.get("aggregation_method")
             threshold = ep_config.get("threshold_of_success")
             default_num_runs = ep_config.get("num_runs")
@@ -149,6 +147,7 @@ def _deep_update(base: Dict[str, Any], over: Dict[str, Any]) -> Dict[str, Any]:
                 dataset_adapter=dataset_adapter,
                 rollout_input_params=rollout_params,
                 rollout_processor=rollout_processor,
+                rollout_processor_kwargs=rollout_processor_kwargs,
                 aggregation_method=aggregation_method,
                 threshold_of_success=threshold,
                 num_runs=(num_runs if num_runs is not None else default_num_runs),
@@ -170,5 +169,3 @@ def _deep_update(base: Dict[str, Any], over: Dict[str, Any]) -> Dict[str, Any]:
         return test_wrapper
 
     return _decorator
-
-
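
The _deep_update helper above deep-merges overrides into the captured rollout params before they are serialized into EP_INPUT_PARAMS_JSON. A standalone sketch of the merge semantics (example values are illustrative, not taken from any suite):

    from typing import Any, Dict

    def _deep_update(base: Dict[str, Any], over: Dict[str, Any]) -> Dict[str, Any]:
        # Nested dicts merge key-by-key; any other value overwrites.
        for k, v in over.items():
            if isinstance(v, dict) and isinstance(base.get(k), dict):
                _deep_update(base[k], v)
            else:
                base[k] = v
        return base

    params = {"temperature": 0.8, "extra_body": {"reasoning_effort": "medium"}}
    override = {"extra_body": {"reasoning_effort": "high"}, "max_tokens": 1024}
    assert _deep_update(params, override) == {
        "temperature": 0.8,
        "extra_body": {"reasoning_effort": "high"},
        "max_tokens": 1024,
    }

Note the merge mutates base in place; the caller passes dict(input_params_override), so the override dict is copied while merged itself is updated.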
1 change: 1 addition & 0 deletions eval_protocol/benchmarks/suites/tau_bench_retail.py
@@ -69,6 +69,7 @@ def tau_bench_retail_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Evalu
     model=["fireworks_ai/accounts/fireworks/models/gpt-oss-120b"],
     rollout_input_params=[{"temperature": 0.8, "extra_body": {"reasoning_effort": "medium"}}],
     rollout_processor=default_mcp_gym_rollout_processor,
+    rollout_processor_kwargs={"domain": "retail"},
     num_runs=8,
     mode="pointwise",
     max_concurrent_rollouts=50,
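
Taken together with the types.py and evaluation_test.py changes below, the new keyword flows decorator -> RolloutProcessorConfig.kwargs -> processor. A minimal sketch of that hand-off (dataclass trimmed to the two fields relevant here; field names match the PR):

    from dataclasses import dataclass, field
    from typing import Any, Dict, Optional

    @dataclass
    class RolloutProcessorConfig:  # trimmed for illustration
        server_script_path: Optional[str] = None
        kwargs: Dict[str, Any] = field(default_factory=dict)

    # evaluation_test(..., rollout_processor_kwargs={"domain": "retail"}) ends up as:
    config = RolloutProcessorConfig(server_script_path="server.py", kwargs={"domain": "retail"})

    # and default_mcp_gym_rollout_processor splats it into the server manager:
    server_kwargs = config.kwargs or {}  # guards against an explicit kwargs=None
    assert server_kwargs["domain"] == "retail"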
8 changes: 4 additions & 4 deletions eval_protocol/pytest/default_mcp_gym_rollout_processor.py
@@ -20,10 +20,10 @@ class MCPServerManager:
     _active_servers = []
     _cleanup_registered = False
 
-    def __init__(self, server_script: str, port: int = 8000, domain: str = "airline"):
+    def __init__(self, server_script: str, port: int = 8000, **kwargs):
         self.server_script = server_script
         self.port = port
-        self.domain = domain
+        self.domain = str(kwargs.get("domain", "airline"))
         self.process: Optional[subprocess.Popen] = None
         self.base_dir = Path(".").resolve()
         self._log_file = None
@@ -58,7 +58,7 @@ def start(self) -> None:
         env["PORT"] = str(self.port)
 
         # Start server process (no domain argument needed for tau2_mcp server)
-        cmd = ["python", self.server_script, "--port", str(self.port)]
+        cmd = ["python", self.server_script, "--port", str(self.port), "--domain", self.domain]
 
         # Setup log file with cleanup
         log_file_path = os.path.join(self.base_dir, f"server_output_{self.domain}_{self.port}.log")
@@ -213,7 +213,7 @@ async def default_mcp_gym_rollout_processor(
     """
     if config.server_script_path is None:
         raise ValueError("server_script_path is required for default_mcp_gym_rollout_processor")
-    server = MCPServerManager(config.server_script_path, port=9700)
+    server = MCPServerManager(config.server_script_path, port=9700, **(config.kwargs or {}))
 
     try:
         server.start()
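
With __init__ now taking **kwargs, the domain is read via kwargs.get and coerced to str, so unknown keys are silently ignored and the old default survives. A quick usage sketch (script path illustrative):

    from eval_protocol.pytest.default_mcp_gym_rollout_processor import MCPServerManager

    retail = MCPServerManager("examples/tau2_mcp/server.py", port=9700, domain="retail")
    assert retail.domain == "retail"     # launched as: python server.py --port 9700 --domain retail

    default = MCPServerManager("examples/tau2_mcp/server.py")
    assert default.domain == "airline"   # kwargs.get("domain", "airline") fallback

Two things worth noting: the context comment above ("no domain argument needed for tau2_mcp server") predates this change and is now stale, and the **kwargs signature means a misspelled key (e.g. domian=) no longer raises a TypeError but silently falls back to "airline".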
23 changes: 17 additions & 6 deletions eval_protocol/pytest/evaluation_test.py
@@ -30,6 +30,7 @@
     RolloutInputParam,
     RolloutProcessor,
     RolloutProcessorConfig,
+    RolloutProcessorInputParam,
     TestFunction,
 )
 from eval_protocol.pytest.utils import (
@@ -53,6 +54,7 @@ def evaluation_test(  # noqa: C901
     rollout_input_params: Optional[List[RolloutInputParam]] = None,
     rollout_processor: RolloutProcessor = default_no_op_rollout_processor,
     evaluation_test_kwargs: Optional[List[EvaluationInputParam]] = None,
+    rollout_processor_kwargs: Optional[RolloutProcessorInputParam] = None,
     aggregation_method: AggregationMethod = "mean",
     passed_threshold: Optional[Union[EvaluationThreshold, float]] = None,
     num_runs: int = 1,
@@ -114,6 +116,7 @@
         rollout_input_params: Generation parameters for the rollout.
         rollout_processor: Function used to perform the rollout.
         evaluation_test_kwargs: Kwargs for the evaluation function.
+        rollout_processor_kwargs: Kwargs for the rollout processor.
         aggregation_method: How to aggregate scores across rows.
         passed_threshold: Threshold configuration for test success.
             Success rate must be above success, and if set, standard deviation must be below standard_deviation.
@@ -399,6 +402,7 @@ def _log_eval_error(
             server_script_path=server_script_path,
             steps=steps,
             logger=active_logger,
+            kwargs=rollout_processor_kwargs,
         )
 
         for i in range(num_runs):
@@ -765,6 +769,7 @@ def dual_mode_wrapper(*args, **kwargs):
         "rollout_input_params": rollout_input_params,
         "rollout_processor": rollout_processor,
         "evaluation_test_kwargs": evaluation_test_kwargs,
+        "rollout_processor_kwargs": rollout_processor_kwargs,
         "aggregation_method": aggregation_method,
         "passed_threshold": passed_threshold,
         "num_runs": num_runs,
@@ -832,6 +837,7 @@ def run_evaluation_test_direct(
     dataset_adapter: Callable[[List[Dict[str, Any]]], Dataset] = default_dataset_adapter,
     rollout_input_params: Optional[RolloutInputParam] = None,
     rollout_processor: RolloutProcessor = default_no_op_rollout_processor,
+    rollout_processor_kwargs: Optional[RolloutProcessorInputParam] = None,
     aggregation_method: AggregationMethod = "mean",
     threshold_of_success: Optional[float] = None,
     num_runs: int = 1,
@@ -941,6 +947,7 @@ def _deep_update_dict(base: dict, override: dict) -> dict:
         max_concurrent_rollouts=max_concurrent_rollouts,
         server_script_path=server_script_path,
         steps=steps,
+        kwargs=rollout_processor_kwargs,
     )
 
     all_results: List[EvaluationRow] = []
@@ -1022,8 +1029,8 @@ def _deep_update_dict(base: dict, override: dict) -> dict:
     if summary_path:
         import json as _json
         import pathlib as _pathlib
-        import time as _time
         import re as _re
+        import time as _time
 
         def _sanitize_filename(text: str) -> str:
             safe = _re.sub(r"[^A-Za-z0-9._-]+", "-", text.strip())
@@ -1039,7 +1046,11 @@ def _extract_effort_tag(params: dict) -> str | None:
                     return str(eb["reasoning"]["effort"]).lower()
                 if "reasoning_effort" in eb:
                     return str(eb["reasoning_effort"]).lower()
-            if "reasoning" in params and isinstance(params["reasoning"], dict) and "effort" in params["reasoning"]:
+            if (
+                "reasoning" in params
+                and isinstance(params["reasoning"], dict)
+                and "effort" in params["reasoning"]
+            ):
                 return str(params["reasoning"]["effort"]).lower()
         except Exception:
             return None
@@ -1069,17 +1080,17 @@ def _extract_effort_tag(params: dict) -> str | None:
                 pass
 
         if threshold_of_success is not None and not passed:
-            assert agg_score >= threshold_of_success, (
-                f"Aggregated score {agg_score:.3f} below threshold {threshold_of_success}"
-            )
+            assert (
+                agg_score >= threshold_of_success
+            ), f"Aggregated score {agg_score:.3f} below threshold {threshold_of_success}"
 
         return {"summary": summary_obj, "results": all_results}
     except Exception:
         # Mark errors on rows
         if eval_metadata is not None:
             eval_metadata.status = "error"
             eval_metadata.passed = False
-        for r in (data or []):
+        for r in data or []:
             if r.eval_metadata is not None:
                 r.eval_metadata.status = "error"
                 r.eval_metadata.passed = False
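
_extract_effort_tag accepts the effort setting in three shapes. The opening lines of the helper sit outside the hunk, so the sketch below fills them in as an approximation (the eb = params.get("extra_body") lookup is an assumption) to show the accepted inputs:

    from typing import Any, Dict, Optional

    def extract_effort_tag(params: Dict[str, Any]) -> Optional[str]:
        # Approximation of the helper above: check extra_body first, then a top-level reasoning dict.
        try:
            eb = params.get("extra_body") or {}
            if isinstance(eb, dict):
                if isinstance(eb.get("reasoning"), dict) and "effort" in eb["reasoning"]:
                    return str(eb["reasoning"]["effort"]).lower()
                if "reasoning_effort" in eb:
                    return str(eb["reasoning_effort"]).lower()
            if isinstance(params.get("reasoning"), dict) and "effort" in params["reasoning"]:
                return str(params["reasoning"]["effort"]).lower()
        except Exception:
            return None
        return None

    assert extract_effort_tag({"extra_body": {"reasoning_effort": "Medium"}}) == "medium"
    assert extract_effort_tag({"extra_body": {"reasoning": {"effort": "LOW"}}}) == "low"
    assert extract_effort_tag({"reasoning": {"effort": "HIGH"}}) == "high"
    assert extract_effort_tag({"temperature": 0.8}) is None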
4 changes: 3 additions & 1 deletion eval_protocol/pytest/types.py
@@ -2,7 +2,7 @@
 Parameter types
 """
 
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Any, Callable, Dict, List, Literal, Optional
 
 from eval_protocol.dataset_logger import default_logger
@@ -15,6 +15,7 @@
 RolloutInputParam = Dict[str, Any]
 InputMessagesParam = List[Message]
 EvaluationInputParam = Dict[str, Any]
+RolloutProcessorInputParam = Dict[str, Any]
 
 Dataset = List[EvaluationRow]
 
@@ -49,6 +50,7 @@ class RolloutProcessorConfig:
     max_concurrent_rollouts: int = 8  # maximum number of concurrent rollouts
     steps: int = 30  # max number of rollout steps
     logger: DatasetLogger = default_logger  # logger to use during rollout for mid-rollout logs
+    kwargs: Dict[str, Any] = field(default_factory=dict)  # any additional kwargs to pass to the rollout processor
 
 
 RolloutProcessor = Callable[[List[EvaluationRow], RolloutProcessorConfig], List[EvaluationRow]]
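
The new field import is needed because dataclasses reject plain mutable defaults; default_factory also keeps instances from sharing one dict. A minimal illustration:

    from dataclasses import dataclass, field
    from typing import Any, Dict

    @dataclass
    class Cfg:
        kwargs: Dict[str, Any] = field(default_factory=dict)  # fresh dict per instance

    a, b = Cfg(), Cfg()
    a.kwargs["domain"] = "retail"
    assert b.kwargs == {}  # no shared state between instances

    # By contrast, `kwargs: Dict[str, Any] = {}` fails at class creation:
    # ValueError: mutable default <class 'dict'> for field kwargs is not allowed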
7 changes: 5 additions & 2 deletions examples/tau2_mcp/retail_environment/retail_environment.py
@@ -30,11 +30,14 @@ class RetailEnvironment:
 
     def __init__(self, config: Optional[Dict[str, Any]] = None):
         self.config = config or {}
-        self.db = RetailDB.load(RETAIL_DB_PATH)
-        self.retail_tools = RetailTools(self.db)
+        self.db = None
+        self.retail_tools = None
 
     def reset(self, seed: Optional[int] = None) -> Tuple[Dict[str, Any], Dict[str, Any]]:
         """Reset the environment to initial state"""
+        self.db = RetailDB.load(RETAIL_DB_PATH)
+        self.retail_tools = RetailTools(self.db)
+
         return {}, {}
 
     def step(self, action: Dict[str, Any]) -> Tuple[Dict[str, Any], float, bool, bool, Dict[str, Any]]:
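
Loading the DB in reset() rather than __init__ means every episode starts from a pristine database instead of carrying over mutations from earlier rollouts. The intended lifecycle, sketched from the class above:

    env = RetailEnvironment()   # cheap: no DB I/O at construction
    obs, info = env.reset()     # RetailDB.load(RETAIL_DB_PATH) gives a fresh copy
    # ... episode mutates the DB through tool calls ...
    obs, info = env.reset()     # discards mutations; reloads pristine state

The corollary is that step() must not be called before reset(), since self.db and self.retail_tools start out as None.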