Skip to content
143 changes: 62 additions & 81 deletions eval_protocol/mcp/execution/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import threading
import time
from dataclasses import asdict
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, AsyncIterator, Callable, Dict, List, Optional, Union

import anyio
from openai.types import CompletionUsage
Expand Down Expand Up @@ -43,7 +43,7 @@ async def execute_rollouts(
openai_format_log_file: Optional[str] = None,
max_concurrent_rollouts: int = 8,
evaluation_rows: Optional[List[EvaluationRow]] = None,
) -> List[EvaluationRow]:
) -> AsyncIterator[EvaluationRow]:
"""
Execute general rollouts using tool calling interface with automatic record/playback.

Expand All @@ -66,7 +66,7 @@ async def execute_rollouts(
- Set and file exists: Playback mode (uses recorded data)

Returns:
List of EvaluationRow objects with unified evaluation data format
AsyncIterator of EvaluationRow objects with unified evaluation data format
"""
start_time = time.time()

Expand All @@ -92,96 +92,77 @@ async def execute_rollouts(

logger.info(f"🧵 Starting {envs.n} rollouts with max {max_concurrent_rollouts} concurrent threads...")

results = {}
if evaluation_rows is None:
evaluation_rows = [EvaluationRow(messages=[], input_metadata=InputMetadata()) for _ in range(envs.n)]

shared_tool_schema = envs.tool_schemas

semaphore = asyncio.Semaphore(max_concurrent_rollouts)

async def _execute_with_semaphore(idx):
async with semaphore:
result = await self._execute_rollout(
trajectory = await self._execute_rollout(
envs, policy, idx, steps, openai_logger, recording_mode, playback_mode, start_time
)

return result

tasks = [_execute_with_semaphore(i) for i in range(envs.n)]
# exceptions will be try catched inside single _execute_rollout
trajectories = await asyncio.gather(*tasks)

# Calculate durations
total_duration = time.time() - start_time
for trajectory in trajectories:
trajectory.duration = total_duration

shared_tool_schema = envs.tool_schemas

# Enhanced reporting with control plane info
successful = sum(1 for traj in trajectories if traj.total_reward > 0)
terminated_by_control_plane = sum(
1
for traj in trajectories
if traj.control_plane_summary.get("termination_reason") == "control_plane_signal"
)
# Convert trajectory to EvaluationRow immediately
evaluation_row = evaluation_rows[idx]

# Handle multimodal content by extracting text from complex content structures
messages = []
for msg in trajectory.conversation_history:
# Create a copy to avoid modifying the original
msg_dict = dict(msg)

# Handle multimodal content (list of content blocks) by extracting text
if isinstance(msg_dict.get("content"), list):
text_content = None
for content_block in msg_dict["content"]:
if isinstance(content_block, dict) and content_block.get("type") == "text":
text_content = content_block.get("text")
break
msg_dict["content"] = text_content or ""

messages.append(Message.model_validate(msg_dict))

evaluation_row.messages = messages
evaluation_row.tools = shared_tool_schema
evaluation_row.usage = CompletionUsage(**trajectory.usage)
evaluation_row.input_metadata.completion_params = CompletionParams(
model=policy.model_id,
temperature=getattr(policy, "temperature", None),
max_tokens=getattr(policy, "max_tokens", None),
max_tool_calls=getattr(policy, "max_tools_per_turn", None),
)

logger.info(f"📊 Rollout complete: {successful}/{len(trajectories)} reached goal")
logger.info(f"🎛️ Control plane terminations: {terminated_by_control_plane}/{len(trajectories)}")
logger.info(f"⏱️ Total duration: {total_duration:.2f}s")
logger.info(f"🧵 Used {max_concurrent_rollouts} concurrent threads")
if trajectory.terminated:
if trajectory.termination_reason == TerminationReason.ERROR:
evaluation_row.rollout_status.status = "error"
evaluation_row.rollout_status.error_message = trajectory.control_plane_summary.get(
"error_message", None
)
else:
evaluation_row.rollout_status.status = "finished"
evaluation_row.rollout_status.termination_reason = trajectory.termination_reason
else:
evaluation_row.rollout_status.status = "running"

# Print log file locations if created
if openai_format_log_file:
logger.info(f"💬 OpenAI format log: {openai_format_log_file}")
if recording_mode:
logger.info(f"📝 Recorded trajectory: {playback_file}")
# Add note about control plane separation
logger.info(f"🎛️ Trajectories include control plane separation")
return evaluation_row

# Convert trajectories to unified EvaluationRow format. If no evaluation_rows are provided, create empty ones for backwards compatibility.
if evaluation_rows is None:
evaluation_rows = [EvaluationRow(messages=[], input_metadata=InputMetadata()) for _ in trajectories]

for idx, trajectory in enumerate(trajectories):
# Handle multimodal content by extracting text from complex content structures
messages = []
for msg in trajectory.conversation_history:
# Create a copy to avoid modifying the original
msg_dict = dict(msg)

# Handle multimodal content (list of content blocks) by extracting text
if isinstance(msg_dict.get("content"), list):
text_content = None
for content_block in msg_dict["content"]:
if isinstance(content_block, dict) and content_block.get("type") == "text":
text_content = content_block.get("text")
break
msg_dict["content"] = text_content or ""

messages.append(Message.model_validate(msg_dict))

evaluation_rows[idx].messages = messages
# evaluation_rows[idx].input_metadata.row_id = envs.dataset_rows[idx].id
# evaluation_rows[idx].input_metadata.dataset_info = asdict(envs.dataset_rows[idx])
evaluation_rows[idx].tools = shared_tool_schema
evaluation_rows[idx].usage = CompletionUsage(**trajectory.usage)
evaluation_rows[idx].input_metadata.completion_params = CompletionParams(
model=policy.model_id,
temperature=getattr(policy, "temperature", None),
max_tokens=getattr(policy, "max_tokens", None),
max_tool_calls=getattr(policy, "max_tools_per_turn", None),
)
if trajectory.terminated:
if trajectory.termination_reason == TerminationReason.ERROR:
evaluation_rows[idx].rollout_status.status = "error"
evaluation_rows[idx].rollout_status.termination_reason = trajectory.control_plane_summary.get(
"error_message", None
)
else:
evaluation_rows[idx].rollout_status.status = "finished"
evaluation_rows[idx].rollout_status.termination_reason = trajectory.termination_reason
else:
evaluation_rows[idx].rollout_status.status = "running"
# Create all tasks
tasks = [asyncio.create_task(_execute_with_semaphore(i)) for i in range(envs.n)]

return evaluation_rows
# Yield results as they complete (note that they're not necessarily in original order)
try:
for task in asyncio.as_completed(tasks):
try:
yield await task
except Exception:
logger.exception("Error processing rollout")
finally:
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)

async def _execute_rollout(
self,
Expand Down
14 changes: 7 additions & 7 deletions eval_protocol/mcp_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,21 +41,20 @@
"""

import asyncio
import hashlib
import json

# For legacy compatibility - import the facade functions
import logging
import random
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, AsyncIterator, Callable, Dict, List, Optional, Union

# Import all functionality from the new modular components
from .mcp.execution.manager import ExecutionManager
from .mcp.execution.policy import AnthropicPolicy, FireworksPolicy, LiteLLMPolicy, LLMBasePolicy, OpenAIPolicy
from .mcp.session.manager import GeneralMCPVectorEnv
from .models import EvaluationRow
from .types import DatasetRow, MCPSession, MCPToolCall
import asyncio
import hashlib
import json

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -247,7 +246,7 @@ async def rollout(
steps: int = 512,
openai_format_log_file: Optional[str] = None,
max_concurrent_rollouts: int = 8,
) -> List[EvaluationRow]:
) -> AsyncIterator[EvaluationRow]:
"""
Execute general rollouts using tool calling interface with automatic record/playback.

Expand Down Expand Up @@ -307,9 +306,10 @@ async def rollout(
# Use the new ExecutionManager for execution
execution_manager = ExecutionManager()

return await execution_manager.execute_rollouts(
async for evaluation_row in execution_manager.execute_rollouts(
envs, policy, steps, openai_format_log_file, max_concurrent_rollouts, evaluation_rows
)
):
yield evaluation_row


async def test_mcp(base_url: str, seeds: List[int]) -> Dict[str, Any]:
Expand Down
4 changes: 3 additions & 1 deletion eval_protocol/pytest/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from .default_agent_rollout_processor import default_agent_rollout_processor
from .default_dataset_adapter import default_dataset_adapter
from .default_mcp_gym_rollout_processor import default_mcp_gym_rollout_processor
from .default_no_op_rollout_process import default_no_op_rollout_processor
from .default_single_turn_rollout_process import default_single_turn_rollout_processor
from .evaluation_test import evaluation_test
from .types import RolloutProcessor, RolloutProcessorConfig
from .default_dataset_adapter import default_dataset_adapter

__all__ = [
"default_agent_rollout_processor",
"default_mcp_gym_rollout_processor",
"default_no_op_rollout_processor",
"default_single_turn_rollout_processor",
"default_dataset_adapter",
Expand Down
52 changes: 42 additions & 10 deletions eval_protocol/pytest/default_agent_rollout_processor.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import asyncio
import json
import logging
import os
from typing import Any, List, Optional, Union
from typing import Any, AsyncIterator, List, Optional, Union

from mcp.types import CallToolResult, TextContent
from openai import NOT_GIVEN, NotGiven
Expand All @@ -14,6 +15,8 @@
from eval_protocol.models import EvaluationRow, Message
from eval_protocol.pytest.types import Dataset, RolloutProcessorConfig

logger = logging.getLogger(__name__)


class Agent:
"""
Expand Down Expand Up @@ -114,13 +117,42 @@ def _get_content_from_tool_result(self, tool_result: CallToolResult) -> List[Tex

async def default_agent_rollout_processor(
    rows: List[EvaluationRow], config: RolloutProcessorConfig
) -> AsyncIterator[EvaluationRow]:
    """Run one agent rollout per row with bounded concurrency, yielding rows as they finish.

    Args:
        rows: Rows to process; a separate ``Agent`` is created for each one.
        config: Supplies ``model``, ``mcp_config_path`` and ``logger`` for the
            agent. ``max_concurrent_rollouts`` is read if present; a missing
            or falsy value falls back to 8.

    Yields:
        The agent's ``evaluation_row`` for each input row, in completion order
        (not necessarily input order). If a rollout raises, the error is
        logged and the input row object is yielded in its place.
    """
    max_concurrent = getattr(config, "max_concurrent_rollouts", 8) or 8
    semaphore = asyncio.Semaphore(max_concurrent)

    async def process_row(row: EvaluationRow) -> EvaluationRow:
        """Set up an agent for ``row``, run it, and always clean up its MCP client."""
        agent = Agent(model=config.model, row=row, config_path=config.mcp_config_path, logger=config.logger)
        try:
            await agent.setup()
            await agent.call_agent()
            return agent.evaluation_row
        finally:
            # Runs on success, failure and cancellation so the client is not leaked.
            if agent.mcp_client:
                await agent.mcp_client.cleanup()

    async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
        """Run ``process_row`` under the semaphore; on failure log and return the input row."""
        async with semaphore:
            try:
                return await process_row(r)
            except Exception as e:
                logger.exception(f"Error processing row {r.input_metadata.row_id}: {e}")
                return r

    # Schedule everything up front; the semaphore bounds how many run at once.
    tasks = [asyncio.create_task(_sem_wrapper(row)) for row in rows]

    # Yield results as they complete (note that they're not necessarily in original order)
    try:
        for next_done in asyncio.as_completed(tasks):
            try:
                finished_row = await next_done
            except Exception:
                logger.exception("Error processing row")
                continue
            # Yield outside the try so only a failing task is logged and skipped,
            # never an exception raised into the generator at the yield point.
            yield finished_row
    finally:
        # If the consumer stops early, cancel whatever is still pending and
        # wait for the tasks to settle before returning.
        for t in tasks:
            t.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)
18 changes: 7 additions & 11 deletions eval_protocol/pytest/default_mcp_gym_rollout_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import subprocess
import time
from pathlib import Path
from typing import List, Optional
from typing import AsyncIterator, List, Optional

import eval_protocol as ep
from eval_protocol.models import EvaluationRow, Message
Expand Down Expand Up @@ -194,22 +194,19 @@ def __exit__(self, exc_type, exc_val, exc_tb):

async def default_mcp_gym_rollout_processor(
rows: List[EvaluationRow], config: RolloutProcessorConfig
) -> List[EvaluationRow]:
) -> AsyncIterator[EvaluationRow]:
"""
Rollout processor for tau bench environments.


This processor starts an MCP server, creates tau bench environments, and runs rollouts
using the eval_protocol framework, following the pattern from test_tau2_e2e.py.

using the eval_protocol framework, yielding results as they complete.

Args:
rows: List of EvaluationRow objects containing messages and dataset info in input_metadata
config: RolloutProcessorConfig with model and other parameters


Returns:
List of EvaluationRow objects with completed conversations
AsyncIterator of EvaluationRow objects with completed conversations
"""
if config.server_script_path is None:
raise ValueError("server_script_path is required for default_mcp_gym_rollout_processor")
Expand All @@ -233,15 +230,14 @@ async def default_mcp_gym_rollout_processor(
)

# Run rollout with environments and policy
evaluation_rows = await ep.rollout(
async for evaluation_row in ep.rollout(
envs,
policy=policy,
evaluation_rows=rows,
steps=config.steps,
max_concurrent_rollouts=config.max_concurrent_rollouts,
)

return evaluation_rows
):
yield evaluation_row

finally:
# Always clean up the server
Expand Down
9 changes: 6 additions & 3 deletions eval_protocol/pytest/default_no_op_rollout_process.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from typing import List
from typing import AsyncIterator, List

from eval_protocol.models import EvaluationRow
from eval_protocol.pytest.types import RolloutProcessorConfig


async def default_no_op_rollout_processor(
    rows: List[EvaluationRow], config: RolloutProcessorConfig
) -> AsyncIterator[EvaluationRow]:
    """
    Simply passes input dataset through to the test function. This can be useful
    if you want to run the rollout yourself.

    Args:
        rows: Rows to pass through unchanged.
        config: Accepted for signature compatibility with the other rollout
            processors; it is not read here.

    Yields:
        Each element of ``rows``, in input order.
    """
    for row in rows:
        yield row
Loading
Loading