36 changes: 14 additions & 22 deletions eval_protocol/mcp/execution/manager.py
@@ -42,6 +42,7 @@ async def execute_rollouts(
steps: int = 512,
openai_format_log_file: Optional[str] = None,
max_concurrent_rollouts: int = 8,
evaluation_rows: Optional[List[EvaluationRow]] = None,
) -> List[EvaluationRow]:
"""
Execute general rollouts using tool calling interface with automatic record/playback.
@@ -135,9 +136,11 @@ async def _execute_with_semaphore(idx):
# Add note about control plane separation
logger.info(f"🎛️ Trajectories include control plane separation")

# Convert trajectories to unified EvaluationRow format
evaluation_rows = []
for trajectory in trajectories:
# Convert trajectories to unified EvaluationRow format. If no evaluation_rows are provided, create empty ones for backwards compatibility.
if evaluation_rows is None:
evaluation_rows = [EvaluationRow(messages=[], input_metadata=InputMetadata()) for _ in trajectories]

for idx, trajectory in enumerate(trajectories):
# Handle multimodal content by extracting text from complex content structures
messages = []
for msg in trajectory.conversation_history:
@@ -155,26 +158,15 @@ async def _execute_with_semaphore(idx):

messages.append(Message.model_validate(msg_dict))

input_metadata = InputMetadata(
row_id=trajectory.session.dataset_row.id if trajectory.session.dataset_row else None,
dataset_info=asdict(trajectory.session.dataset_row) if trajectory.session.dataset_row else {},
completion_params=CompletionParams(
model=policy.model_id,
temperature=getattr(policy, "temperature", None),
max_tokens=getattr(policy, "max_tokens", None),
max_tool_calls=getattr(policy, "max_tools_per_turn", None),
),
session_data={
"timestamp": time.time(),
},
)
evaluation_row = EvaluationRow(
messages=messages,
tools=shared_tool_schema,
input_metadata=input_metadata,
usage=trajectory.usage,
evaluation_rows[idx].messages = messages
evaluation_rows[idx].tools = shared_tool_schema
evaluation_rows[idx].usage = trajectory.usage
evaluation_rows[idx].input_metadata.completion_params = CompletionParams(
model=policy.model_id,
temperature=getattr(policy, "temperature", None),
max_tokens=getattr(policy, "max_tokens", None),
max_tool_calls=getattr(policy, "max_tools_per_turn", None),
)
evaluation_rows.append(evaluation_row)

return evaluation_rows

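A minimal sketch of the new calling pattern for `execute_rollouts`, assuming `manager`, `envs`, and `policy` are constructed as elsewhere in this PR, and that the caller passes one pre-built row per environment/trajectory:

```python
# Minimal sketch; `manager`, `envs`, and `policy` are assumed from the surrounding framework.
from typing import List

from eval_protocol.models import EvaluationRow, InputMetadata, Message


async def rollout_with_preseeded_rows(manager, envs, policy, num_envs: int) -> List[EvaluationRow]:
    # One caller-built row per environment; the adapter would normally attach dataset_info here.
    rows = [
        EvaluationRow(
            messages=[Message(role="system", content="You are a helpful airline agent.")],
            input_metadata=InputMetadata(dataset_info={"task_id": f"example-{i}"}),
        )
        for i in range(num_envs)
    ]
    # execute_rollouts fills messages, tools, usage, and completion_params in place,
    # leaving caller-provided metadata (row_id, dataset_info, session_data) untouched.
    return await manager.execute_rollouts(
        envs,
        policy,
        steps=512,
        max_concurrent_rollouts=8,
        evaluation_rows=rows,
    )
```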
11 changes: 11 additions & 0 deletions eval_protocol/mcp/execution/policy.py
@@ -64,6 +64,9 @@ def __init__(
self.num_retries = num_retries
self.retry_strategy = retry_strategy

# Store additional API parameters from kwargs
self.additional_params = kwargs

# Only initialize LiteLLM in live mode (not in playback mode)
if not self._is_playback:
self._setup_litellm_caching(use_caching, cache_type, redis_url)
@@ -166,6 +169,14 @@ async def _make_llm_call(self, messages: List[Dict], tools: List[Dict]) -> Dict:
"base_url": self.base_url,
}

# Add additional parameters from kwargs (like reasoning_effort)
if self.additional_params:
request_params.update(self.additional_params)

# Tell LiteLLM to allow reasoning_effort if it's present
if "reasoning_effort" in self.additional_params:
request_params["allowed_openai_params"] = ["reasoning_effort"]

# Add tools if provided
if tools:
request_params["tools"] = tools
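A hedged sketch of how an extra keyword such as `reasoning_effort` threads through: `__init__` captures it in `self.additional_params`, and `_make_llm_call` merges it into the request along with `allowed_openai_params` so LiteLLM does not reject it. The constructor arguments shown are assumed from this diff, not a full API reference:

```python
# Hedged sketch: constructor arguments are assumed from this diff, not a complete reference.
from eval_protocol.mcp.execution.policy import LiteLLMPolicy

policy = LiteLLMPolicy(
    model_id="fireworks_ai/accounts/fireworks/models/gpt-oss-120b",
    temperature=0.8,
    max_tokens=4096,
    reasoning_effort="high",  # extra kwargs land in self.additional_params
)

# Inside _make_llm_call the request then carries the extra key plus the allow-list:
# request_params = {
#     "model": policy.model_id,
#     "messages": [...],
#     "base_url": policy.base_url,
#     "reasoning_effort": "high",
#     "allowed_openai_params": ["reasoning_effort"],
# }
```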
7 changes: 4 additions & 3 deletions eval_protocol/mcp_env.py
@@ -40,18 +40,19 @@
- Resources provide static/configuration data, tools provide dynamic actions
"""

import asyncio

# For legacy compatibility - import the facade functions
import logging
import random
from typing import Any, Callable, Dict, List, Optional, Union

# Import all functionality from the new modular components
from .mcp.execution.manager import ExecutionManager
from .mcp.execution.policy import AnthropicPolicy, FireworksPolicy, LLMBasePolicy, OpenAIPolicy, LiteLLMPolicy
from .mcp.execution.policy import AnthropicPolicy, FireworksPolicy, LiteLLMPolicy, LLMBasePolicy, OpenAIPolicy
from .mcp.session.manager import GeneralMCPVectorEnv
from .models import EvaluationRow
from .types import DatasetRow, MCPSession, MCPToolCall
import asyncio

logger = logging.getLogger(__name__)

@@ -288,7 +289,7 @@ async def rollout(
execution_manager = ExecutionManager()

return await execution_manager.execute_rollouts(
envs, policy, steps, openai_format_log_file, max_concurrent_rollouts
envs, policy, steps, openai_format_log_file, max_concurrent_rollouts, evaluation_rows
)


17 changes: 12 additions & 5 deletions eval_protocol/pytest/default_mcp_gym_rollout_processor.py
@@ -1,18 +1,17 @@
import asyncio
import atexit
import os
import signal
import socket
import subprocess
import time
import socket
from pathlib import Path
from typing import List, Optional

import eval_protocol as ep
from eval_protocol.models import EvaluationRow, Message
from eval_protocol.pytest.types import RolloutProcessorConfig

import atexit
import signal


class MCPServerManager:
"""Manages MCP server lifecycle for testing."""
@@ -188,13 +187,16 @@ async def default_mcp_gym_rollout_processor(
"""
Rollout processor for tau bench environments.


This processor starts an MCP server, creates tau bench environments, and runs rollouts
using the eval_protocol framework, following the pattern from test_tau2_e2e.py.


Args:
rows: List of EvaluationRow objects containing messages and dataset info in input_metadata
config: RolloutProcessorConfig with model and other parameters


Returns:
List of EvaluationRow objects with completed conversations
"""
@@ -207,6 +209,7 @@ async def default_mcp_gym_rollout_processor(
model_id=config.model,
temperature=config.input_params.get("temperature", 0.0),
max_tokens=config.input_params.get("max_tokens", 4096),
reasoning_effort=config.input_params.get("reasoning_effort", None),
)

# Create MCP environments directly from evaluation_rows
@@ -218,7 +221,11 @@

# Run rollout with environments and policy
evaluation_rows = await ep.rollout(
envs, policy=policy, steps=config.steps, max_concurrent_rollouts=config.max_concurrent_rollouts
envs,
policy=policy,
evaluation_rows=rows,
steps=config.steps,
max_concurrent_rollouts=config.max_concurrent_rollouts,
)

return evaluation_rows
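The reason for threading `rows` into `ep.rollout` is that metadata attached by the dataset adapter (for example `evaluation_criteria` under `dataset_info`) now survives the rollout and can be read back in the evaluation function. A small illustrative sketch; the wrapper function and the loop are assumptions, while the call signature and field accesses mirror this diff and the test below:

```python
# Sketch of the end-to-end flow this PR enables: adapter-built rows go in,
# rollout output comes back with the same dataset_info still attached.
import eval_protocol as ep


async def run(envs, policy, rows, config):
    completed = await ep.rollout(
        envs,
        policy=policy,
        evaluation_rows=rows,
        steps=config.steps,
        max_concurrent_rollouts=config.max_concurrent_rollouts,
    )
    for row in completed:
        # dataset_info was set when the adapter built the row and is still present here.
        criteria = row.input_metadata.dataset_info.get("evaluation_criteria", {})
        print(criteria.get("actions", []))
    return completed
```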
43 changes: 20 additions & 23 deletions tests/pytest/test_tau_bench_airline.py
@@ -10,10 +10,9 @@
from pathlib import Path
from typing import Any, Dict, List

from eval_protocol.models import EvaluateResult, EvaluationRow, Message, InputMetadata, CompletionParams
from eval_protocol.models import CompletionParams, EvaluateResult, EvaluationRow, InputMetadata, Message
from eval_protocol.pytest import evaluation_test
from eval_protocol.pytest.default_mcp_gym_rollout_processor import default_mcp_gym_rollout_processor

from vendor.tau2.data_model.message import (
AssistantMessage,
SystemMessage,
@@ -28,20 +27,21 @@
from vendor.tau2.evaluator.evaluator_nl_assertions import NLAssertionsEvaluator
from vendor.tau2.registry import registry


def tau_bench_airline_to_evaluation_row(data: List[Dict[str, Any]]) -> List[EvaluationRow]:
"""
Convert entries from airline dataset to EvaluationRow objects.
"""
rows = []
test_dir = Path(__file__).parent.parent.parent / "examples" / "tau2_mcp" / "tests"

# Load system prompt from file so we can change it in one place
domain = data[0]["environment_context"]["domain"]
prompt_file = test_dir / f"system_prompts/{domain}_agent_system_prompt.md"

with open(prompt_file, "r") as f:
system_prompt = f.read().strip()

for row in data:
eval_row = EvaluationRow(
messages=[Message(role="system", content=system_prompt)],
@@ -52,47 +52,46 @@ def tau_bench_airline_to_evaluation_row(data: List[Dict[str, Any]]) -> List[EvaluationRow]:
"user_simulation": row["user_simulation"],
"evaluation_criteria": row["evaluation_criteria"],
"user_prompt_template": row["user_prompt_template"],
}
},
),
)

rows.append(eval_row)

return rows


@evaluation_test(
input_dataset=["tests/pytest/data/airline_dataset.jsonl"],
dataset_adapter=tau_bench_airline_to_evaluation_row,
model=["fireworks_ai/accounts/fireworks/models/kimi-k2-instruct"],
rollout_input_params=[{"temperature": 0.0, "max_tokens": 4096}],
model=["fireworks_ai/accounts/fireworks/models/gpt-oss-120b"],
rollout_input_params=[{"temperature": 0.8, "max_tokens": 4096, "reasoning_effort": "high"}],
rollout_processor=default_mcp_gym_rollout_processor,
threshold_of_success=0.4,
num_runs=1,
mode="pointwise",
max_concurrent_rollouts=32,
max_concurrent_rollouts=16,
server_script_path="examples/tau2_mcp/server.py",
)
def test_tau_bench_airline_evaluation(row: EvaluationRow) -> EvaluationRow:
"""
Test tau bench airline evaluation using the pytest framework.

This test now uses the tau_bench_airline_reward function which automatically
extracts evaluation criteria from dataset entries. No wrapper needed!

Args:
input_dataset: List of EvaluationRow objects from tau bench airline dataset
input_params: Model parameters (temperature, max_tokens, etc.)
model: Model identifier

row: EvaluationRow object from tau bench airline dataset after rollout

Returns:
List of evaluated EvaluationRow objects with scores and feedback
EvaluationRow with tau2 evaluation results
"""
messages = row.messages

# Get evaluation criteria and user_simulation from input_metadata.dataset_info
dataset_info = row.input_metadata.dataset_info if row.input_metadata else {}
evaluation_criteria = dataset_info.get("evaluation_criteria", {})

nl_assertions = evaluation_criteria.get("nl_assertions", [])
communicate_info = evaluation_criteria.get("communicate_info", [])
actions = evaluation_criteria.get("actions", [])
@@ -131,10 +130,8 @@ def test_tau_bench_airline_evaluation(row: EvaluationRow) -> EvaluationRow:
communicate_info=communicate_info,
actions=actions,
reward_basis=[
RewardType.NL_ASSERTION,
RewardType.DB,
RewardType.COMMUNICATE,
RewardType.ACTION,
],
)

@@ -230,4 +227,4 @@ def test_tau_bench_airline_evaluation(row: EvaluationRow) -> EvaluationRow:
reason=reason,
metrics={},
)
return row
return row
1 change: 1 addition & 0 deletions tests/test_rollout_control_plane_integration.py
@@ -519,6 +519,7 @@ async def test_rollout_creates_envs_from_url(self):
5,
None,
8,
None,
)

assert result == ["ok"]
6 changes: 2 additions & 4 deletions uv.lock

Some generated files are not rendered by default.
