Skip to content

Commit db2785b

Browse files
authored
bug fixes (#69)
1 parent ed4409e commit db2785b

File tree

6 files changed

+36
-22
lines changed

6 files changed

+36
-22
lines changed

eval_protocol/benchmarks/registry.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ def test_aime_pointwise(row: EvaluationRow) -> EvaluationRow:
2929
import os
3030
from typing import Any, Callable, Dict, List, Optional
3131

32-
3332
# Global registry: name -> callable runner
3433
_BENCHMARK_REGISTRY: Dict[str, Callable[..., Any]] = {}
3534

@@ -61,9 +60,7 @@ def export_benchmark(name: str) -> Callable[[Callable[..., Any]], Callable[...,
6160
def _decorator(test_wrapper: Callable[..., Any]) -> Callable[..., Any]:
6261
# Pull through metadata attached by evaluation_test
6362
ep_config: Dict[str, Any] = getattr(test_wrapper, "__ep_config", {})
64-
original_test_func: Optional[Callable[..., Any]] = getattr(
65-
test_wrapper, "__ep_original_test_func", None
66-
)
63+
original_test_func: Optional[Callable[..., Any]] = getattr(test_wrapper, "__ep_original_test_func", None)
6764

6865
def _runner(
6966
*,
@@ -87,13 +84,15 @@ def _runner(
8784
# Fireworks OpenAI-compatible endpoint expects extra_body.reasoning_effort, not nested reasoning dict
8885
merged.setdefault("extra_body", {})["reasoning_effort"] = str(reasoning_effort)
8986
if input_params_override:
87+
9088
def _deep_update(base: Dict[str, Any], over: Dict[str, Any]) -> Dict[str, Any]:
9189
for k, v in over.items():
9290
if isinstance(v, dict) and isinstance(base.get(k), dict):
9391
_deep_update(base[k], v)
9492
else:
9593
base[k] = v
9694
return base
95+
9796
merged = _deep_update(merged, dict(input_params_override))
9897
if merged:
9998
os.environ["EP_INPUT_PARAMS_JSON"] = json.dumps(merged)
@@ -108,15 +107,14 @@ def _deep_update(base: Dict[str, Any], over: Dict[str, Any]) -> Dict[str, Any]:
108107
models: List[str] = ep_config.get("model") or []
109108
model_to_use = model or (models[0] if models else None)
110109
if not model_to_use:
111-
raise ValueError(
112-
f"No model provided and none captured from evaluation_test for benchmark '{name}'"
113-
)
110+
raise ValueError(f"No model provided and none captured from evaluation_test for benchmark '{name}'")
114111

115112
input_messages = ep_config.get("input_messages")
116113
input_dataset = ep_config.get("input_dataset")
117114
dataset_adapter = ep_config.get("dataset_adapter")
118115
rollout_input_params_list = ep_config.get("rollout_input_params")
119116
rollout_processor = ep_config.get("rollout_processor")
117+
rollout_processor_kwargs = ep_config.get("rollout_processor_kwargs")
120118
aggregation_method = ep_config.get("aggregation_method")
121119
threshold = ep_config.get("threshold_of_success")
122120
default_num_runs = ep_config.get("num_runs")
@@ -149,6 +147,7 @@ def _deep_update(base: Dict[str, Any], over: Dict[str, Any]) -> Dict[str, Any]:
149147
dataset_adapter=dataset_adapter,
150148
rollout_input_params=rollout_params,
151149
rollout_processor=rollout_processor,
150+
rollout_processor_kwargs=rollout_processor_kwargs,
152151
aggregation_method=aggregation_method,
153152
threshold_of_success=threshold,
154153
num_runs=(num_runs if num_runs is not None else default_num_runs),
@@ -170,5 +169,3 @@ def _deep_update(base: Dict[str, Any], over: Dict[str, Any]) -> Dict[str, Any]:
170169
return test_wrapper
171170

172171
return _decorator
173-
174-

eval_protocol/benchmarks/suites/tau_bench_retail.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ def tau_bench_retail_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Evalu
6969
model=["fireworks_ai/accounts/fireworks/models/gpt-oss-120b"],
7070
rollout_input_params=[{"temperature": 0.8, "extra_body": {"reasoning_effort": "medium"}}],
7171
rollout_processor=default_mcp_gym_rollout_processor,
72+
rollout_processor_kwargs={"domain": "retail"},
7273
num_runs=8,
7374
mode="pointwise",
7475
max_concurrent_rollouts=50,

eval_protocol/pytest/default_mcp_gym_rollout_processor.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,10 @@ class MCPServerManager:
2020
_active_servers = []
2121
_cleanup_registered = False
2222

23-
def __init__(self, server_script: str, port: int = 8000, domain: str = "airline"):
23+
def __init__(self, server_script: str, port: int = 8000, **kwargs):
2424
self.server_script = server_script
2525
self.port = port
26-
self.domain = domain
26+
self.domain = str(kwargs.get("domain", "airline"))
2727
self.process: Optional[subprocess.Popen] = None
2828
self.base_dir = Path(".").resolve()
2929
self._log_file = None
@@ -58,7 +58,7 @@ def start(self) -> None:
5858
env["PORT"] = str(self.port)
5959

6060
# Start server process, passing the configured domain through to the server script
61-
cmd = ["python", self.server_script, "--port", str(self.port)]
61+
cmd = ["python", self.server_script, "--port", str(self.port), "--domain", self.domain]
6262

6363
# Setup log file with cleanup
6464
log_file_path = os.path.join(self.base_dir, f"server_output_{self.domain}_{self.port}.log")
@@ -213,7 +213,7 @@ async def default_mcp_gym_rollout_processor(
213213
"""
214214
if config.server_script_path is None:
215215
raise ValueError("server_script_path is required for default_mcp_gym_rollout_processor")
216-
server = MCPServerManager(config.server_script_path, port=9700)
216+
server = MCPServerManager(config.server_script_path, port=9700, **(config.kwargs or {}))
217217

218218
try:
219219
server.start()

eval_protocol/pytest/evaluation_test.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
RolloutInputParam,
3131
RolloutProcessor,
3232
RolloutProcessorConfig,
33+
RolloutProcessorInputParam,
3334
TestFunction,
3435
)
3536
from eval_protocol.pytest.utils import (
@@ -53,6 +54,7 @@ def evaluation_test( # noqa: C901
5354
rollout_input_params: Optional[List[RolloutInputParam]] = None,
5455
rollout_processor: RolloutProcessor = default_no_op_rollout_processor,
5556
evaluation_test_kwargs: Optional[List[EvaluationInputParam]] = None,
57+
rollout_processor_kwargs: Optional[RolloutProcessorInputParam] = None,
5658
aggregation_method: AggregationMethod = "mean",
5759
passed_threshold: Optional[Union[EvaluationThreshold, float]] = None,
5860
num_runs: int = 1,
@@ -114,6 +116,7 @@ def evaluation_test( # noqa: C901
114116
rollout_input_params: Generation parameters for the rollout.
115117
rollout_processor: Function used to perform the rollout.
116118
evaluation_test_kwargs: Kwargs for the evaluation function.
119+
rollout_processor_kwargs: Kwargs for the rollout processor.
117120
aggregation_method: How to aggregate scores across rows.
118121
passed_threshold: Threshold configuration for test success.
119122
Success rate must be above success, and if set, standard deviation must be below standard_deviation.
@@ -399,6 +402,7 @@ def _log_eval_error(
399402
server_script_path=server_script_path,
400403
steps=steps,
401404
logger=active_logger,
405+
kwargs=rollout_processor_kwargs,
402406
)
403407

404408
for i in range(num_runs):
@@ -765,6 +769,7 @@ def dual_mode_wrapper(*args, **kwargs):
765769
"rollout_input_params": rollout_input_params,
766770
"rollout_processor": rollout_processor,
767771
"evaluation_test_kwargs": evaluation_test_kwargs,
772+
"rollout_processor_kwargs": rollout_processor_kwargs,
768773
"aggregation_method": aggregation_method,
769774
"passed_threshold": passed_threshold,
770775
"num_runs": num_runs,
@@ -832,6 +837,7 @@ def run_evaluation_test_direct(
832837
dataset_adapter: Callable[[List[Dict[str, Any]]], Dataset] = default_dataset_adapter,
833838
rollout_input_params: Optional[RolloutInputParam] = None,
834839
rollout_processor: RolloutProcessor = default_no_op_rollout_processor,
840+
rollout_processor_kwargs: Optional[RolloutProcessorInputParam] = None,
835841
aggregation_method: AggregationMethod = "mean",
836842
threshold_of_success: Optional[float] = None,
837843
num_runs: int = 1,
@@ -941,6 +947,7 @@ def _deep_update_dict(base: dict, override: dict) -> dict:
941947
max_concurrent_rollouts=max_concurrent_rollouts,
942948
server_script_path=server_script_path,
943949
steps=steps,
950+
kwargs=rollout_processor_kwargs,
944951
)
945952

946953
all_results: List[EvaluationRow] = []
@@ -1022,8 +1029,8 @@ def _deep_update_dict(base: dict, override: dict) -> dict:
10221029
if summary_path:
10231030
import json as _json
10241031
import pathlib as _pathlib
1025-
import time as _time
10261032
import re as _re
1033+
import time as _time
10271034

10281035
def _sanitize_filename(text: str) -> str:
10291036
safe = _re.sub(r"[^A-Za-z0-9._-]+", "-", text.strip())
@@ -1039,7 +1046,11 @@ def _extract_effort_tag(params: dict) -> str | None:
10391046
return str(eb["reasoning"]["effort"]).lower()
10401047
if "reasoning_effort" in eb:
10411048
return str(eb["reasoning_effort"]).lower()
1042-
if "reasoning" in params and isinstance(params["reasoning"], dict) and "effort" in params["reasoning"]:
1049+
if (
1050+
"reasoning" in params
1051+
and isinstance(params["reasoning"], dict)
1052+
and "effort" in params["reasoning"]
1053+
):
10431054
return str(params["reasoning"]["effort"]).lower()
10441055
except Exception:
10451056
return None
@@ -1069,17 +1080,17 @@ def _extract_effort_tag(params: dict) -> str | None:
10691080
pass
10701081

10711082
if threshold_of_success is not None and not passed:
1072-
assert agg_score >= threshold_of_success, (
1073-
f"Aggregated score {agg_score:.3f} below threshold {threshold_of_success}"
1074-
)
1083+
assert (
1084+
agg_score >= threshold_of_success
1085+
), f"Aggregated score {agg_score:.3f} below threshold {threshold_of_success}"
10751086

10761087
return {"summary": summary_obj, "results": all_results}
10771088
except Exception:
10781089
# Mark errors on rows
10791090
if eval_metadata is not None:
10801091
eval_metadata.status = "error"
10811092
eval_metadata.passed = False
1082-
for r in (data or []):
1093+
for r in data or []:
10831094
if r.eval_metadata is not None:
10841095
r.eval_metadata.status = "error"
10851096
r.eval_metadata.passed = False

eval_protocol/pytest/types.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
Parameter types
33
"""
44

5-
from dataclasses import dataclass
5+
from dataclasses import dataclass, field
66
from typing import Any, Callable, Dict, List, Literal, Optional
77

88
from eval_protocol.dataset_logger import default_logger
@@ -15,6 +15,7 @@
1515
RolloutInputParam = Dict[str, Any]
1616
InputMessagesParam = List[Message]
1717
EvaluationInputParam = Dict[str, Any]
18+
RolloutProcessorInputParam = Dict[str, Any]
1819

1920
Dataset = List[EvaluationRow]
2021

@@ -49,6 +50,7 @@ class RolloutProcessorConfig:
4950
max_concurrent_rollouts: int = 8 # maximum number of concurrent rollouts
5051
steps: int = 30 # max number of rollout steps
5152
logger: DatasetLogger = default_logger # logger to use during rollout for mid-rollout logs
53+
kwargs: Dict[str, Any] = field(default_factory=dict) # any additional kwargs to pass to the rollout processor
5254

5355

5456
RolloutProcessor = Callable[[List[EvaluationRow], RolloutProcessorConfig], List[EvaluationRow]]

examples/tau2_mcp/retail_environment/retail_environment.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,14 @@ class RetailEnvironment:
3030

3131
def __init__(self, config: Optional[Dict[str, Any]] = None):
3232
self.config = config or {}
33-
self.db = RetailDB.load(RETAIL_DB_PATH)
34-
self.retail_tools = RetailTools(self.db)
33+
self.db = None
34+
self.retail_tools = None
3535

3636
def reset(self, seed: Optional[int] = None) -> Tuple[Dict[str, Any], Dict[str, Any]]:
3737
"""Reset the environment to initial state"""
38+
self.db = RetailDB.load(RETAIL_DB_PATH)
39+
self.retail_tools = RetailTools(self.db)
40+
3841
return {}, {}
3942

4043
def step(self, action: Dict[str, Any]) -> Tuple[Dict[str, Any], float, bool, bool, Dict[str, Any]]:

0 commit comments

Comments
 (0)