Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion eval_protocol/mcp/client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ async def initialize_session(self, session: MCPSession) -> None:

# Update the session ID to match what the server generated
session.session_id = server_session_id
logger.debug(f"Updated session ID to match server: {server_session_id}")
logger.info(f"Updated session ID to match server: {server_session_id}")

# PRE-WARM: Discover and cache tools immediately after session initialization
# This prevents concurrent list_tools() calls later
Expand Down Expand Up @@ -133,6 +133,24 @@ async def _prewarm_tools_cache(self, session: MCPSession) -> None:
self._tools_cache[cache_key] = tool_schemas
logger.debug(f"✅ PRE-WARMED {len(tool_schemas)} tools for{cache_key}")

async def reset_session(self, session: MCPSession, *, timeout: float = 3.0) -> None:
    """
    Ask the remote MCP server to discard its state for the given session.

    Sends ``POST <base>/control/reset_session`` where ``<base>`` is
    ``session.base_url`` with any trailing slash and trailing ``/mcp``
    segment removed. The session is identified by the ``mcp-session-id``
    header and ``session.seed`` is passed in the JSON body.

    Args:
        session: Session whose ``base_url``, ``session_id`` and ``seed`` are
            used to build the request.
        timeout: Timeout in seconds handed to ``httpx.Timeout``. Defaults to
            the previously hard-coded 3.0; raise it when many sessions are
            reset at once and the server handles them one at a time.

    Raises:
        httpx.HTTPStatusError: If the server answers with an error status.
        httpx.HTTPError: On transport failures such as timeouts.
    """
    import httpx

    base_url = session.base_url.rstrip("/").removesuffix("/mcp")
    url = f"{base_url}/control/reset_session"

    headers = {"mcp-session-id": session.session_id}
    body = {"seed": session.seed}

    async with httpx.AsyncClient(timeout=httpx.Timeout(timeout)) as client:
        resp = await client.post(url, headers=headers, json=body)
        resp.raise_for_status()
        # Log the raw text lazily: calling resp.json() here would raise on a
        # non-JSON body and turn a successful reset into a failure.
        logger.debug("Session %s: reset_session -> %s", session.session_id, resp.text)

async def discover_tools(self, session: MCPSession) -> List[Dict]:
"""
Discover available tools from an MCP session.
Expand Down
41 changes: 3 additions & 38 deletions eval_protocol/mcp/execution/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

from ...models import CompletionParams, EvaluationRow, InputMetadata, Message
from ...types import MCPSession, MCPToolCall, TerminationReason, Trajectory
from ..client.connection import MCPConnectionManager

if TYPE_CHECKING:
from ..session.manager import GeneralMCPVectorEnv
Expand All @@ -33,43 +32,9 @@

class ExecutionManager:
"""
Unified manager that handles both MCP session lifecycle and rollout execution.

Combines the functionality of SessionManager and RolloutManager for better
organization and reduced complexity.
Manage rollout for MCP environments.
"""

def __init__(self):
"""Initialize the execution manager."""
self.connection_manager = MCPConnectionManager()

async def initialize_sessions(self, sessions: List[MCPSession]) -> None:
"""
Initialize multiple MCP sessions in parallel.

Args:
sessions: List of MCPSessions to initialize
"""
tasks = [self.connection_manager.initialize_session(session) for session in sessions]
await asyncio.gather(*tasks)

async def close_sessions(self, sessions: List[MCPSession]) -> None:
"""
Close multiple MCP sessions in parallel.

Args:
sessions: List of MCPSessions to close
"""
tasks = [asyncio.create_task(self.connection_manager.close_session(session)) for session in sessions]

if tasks:
try:
# Wait for all close operations to complete
await asyncio.gather(*tasks, return_exceptions=True)
except asyncio.CancelledError:
# Handle cancellation gracefully (especially important for Python 3.12)
logger.debug("Close operation was cancelled, but sessions are marked as closed")

async def execute_rollouts(
self,
envs: "GeneralMCPVectorEnv",
Expand Down Expand Up @@ -178,7 +143,7 @@ async def _execute_with_semaphore(idx):
for msg in trajectory.conversation_history:
# Create a copy to avoid modifying the original
msg_dict = dict(msg)

# Handle multimodal content (list of content blocks) by extracting text
if isinstance(msg_dict.get("content"), list):
text_content = None
Expand All @@ -187,7 +152,7 @@ async def _execute_with_semaphore(idx):
text_content = content_block.get("text")
break
msg_dict["content"] = text_content or ""

messages.append(Message.model_validate(msg_dict))

input_metadata = InputMetadata(
Expand Down
27 changes: 25 additions & 2 deletions eval_protocol/mcp/mcpgym.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ def __init__(self, server_name: str, adapter: EnvironmentAdapter, seed: Optional
# Register tools and control plane endpoints
self._register_tools()
self._discover_and_register_control_plane_endpoints()
self._register_session_reset_endpoint()

def _get_session_id(self, ctx: Context) -> str:
"""
Expand Down Expand Up @@ -227,6 +228,28 @@ def _get_or_create_session(self, ctx: Context) -> Dict[str, Any]:

return self.sessions[session_id]

def _register_session_reset_endpoint(self):
    """
    Register the ``POST /control/reset_session`` route on ``self.mcp``.

    The handler reads the target session from the ``mcp-session-id`` header
    and an optional ``seed`` from the JSON body. If that session exists in
    ``self.sessions`` its entry is replaced with a fresh environment built
    by ``self._new_env(seed=seed)``; an unknown session id is left untouched.

    Responses:
        400 if the header is missing or the body is not a JSON object.
        200 with a confirmation message otherwise.
    """

    @self.mcp.custom_route("/control/reset_session", methods=["POST"])
    async def reset_session_endpoint(request: Request) -> JSONResponse:
        # Validate the header first so a missing id is reported as such even
        # when the body is absent or malformed.
        session_id = request.headers.get("mcp-session-id")
        if not session_id:
            return JSONResponse({"error": "Missing mcp-session-id header"}, status_code=400)

        # A body that is not valid JSON raises a ValueError subclass; answer
        # with 400 instead of letting it surface as a 500.
        try:
            body = await request.json()
        except ValueError:
            return JSONResponse({"error": "Request body must be a JSON object"}, status_code=400)
        if not isinstance(body, dict):
            return JSONResponse({"error": "Request body must be a JSON object"}, status_code=400)

        seed = body.get("seed")
        print(f"🔍 _register_session_reset_endpoint: Resetting session, session_id: {session_id}, seed: {seed}")

        with self.session_lock:
            if session_id in self.sessions:
                env, obs, _ = self._new_env(seed=seed)
                self.sessions[session_id] = {
                    "env": env,
                    "obs": obs,
                    "session_data": {},
                    "session_id": session_id,
                }
                print(f"🔍 _register_session_reset_endpoint: Finished reset session, session_id: {session_id}")

        return JSONResponse({"message": "Session reset successfully"})

def _discover_and_register_control_plane_endpoints(self):
"""
Discover and register control plane endpoints on the subclass instance.
Expand Down Expand Up @@ -323,7 +346,7 @@ def _update_control_plane(self, reward: float, terminated: bool, truncated: bool

# Log control plane update (for debugging)
print(
f"🎛️ Control plane updated: reward={reward}, terminated={terminated}, step={self.control_plane_state['step_count']}"
f"🎛️ Control plane updated: reward={reward}, terminated={terminated}, step={self.control_plane_state['step_count']}, total_reward={self.control_plane_state['total_reward']}"
)

def _get_or_create_session_control_plane(self, session_id: str) -> Dict[str, Any]:
Expand Down Expand Up @@ -365,7 +388,7 @@ def _update_session_control_plane(

# Log control plane update
print(
f"🎛️ Session {session_id[:16]}... control plane: reward={reward}, terminated={terminated}, step={control_plane['step_count']}"
f"🎛️ Session {session_id[:16]}... control plane: reward={reward}, terminated={terminated}, step={control_plane['step_count']}, total_reward={control_plane['total_reward']}"
)

def get_control_plane_state(self, session_id: str) -> Optional[Dict[str, Any]]:
Expand Down
16 changes: 7 additions & 9 deletions eval_protocol/mcp/session/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from ...types import DatasetRow, MCPSession, MCPToolCall
from ..execution.manager import ExecutionManager
from ..client.connection import MCPConnectionManager

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -44,7 +44,7 @@ def __init__(
self.user_prompt_formatter = user_prompt_formatter or self._default_formatter
self.n = len(sessions)
self.tool_schemas = [] # Discovered from MCP servers
self.execution_manager = ExecutionManager()
self.connection_manager = MCPConnectionManager()
self.usage_stats = {} # llm usage stats for monitoring

if len(sessions) != len(dataset_rows):
Expand All @@ -58,17 +58,14 @@ async def reset(self, session: MCPSession) -> Tuple[Any, List[Dict]]:

This is thread-safe and can be called from worker threads.
"""
# Establish a persistent session for each environment.
await self.execution_manager.connection_manager.initialize_session(session)

# Get available tools from MCP server
tool_schemas = await self.execution_manager.connection_manager.discover_tools(session)
tool_schemas = await self.connection_manager.discover_tools(session)

if not self.tool_schemas:
self.tool_schemas = tool_schemas

# PROPER MCP PATTERN: Get initial state from resources during session establishment
initial_observation = await self.execution_manager.connection_manager.get_initial_state(session)
initial_observation = await self.connection_manager.get_initial_state(session)

# Update session state
session.terminated = False
Expand Down Expand Up @@ -119,7 +116,7 @@ async def step(self, env_index: int, tool_call: MCPToolCall) -> Tuple[Any, float
)

# Execute the tool call via MCP protocol
observation, reward, done, info = await self.execution_manager.connection_manager.call_tool(
observation, reward, done, info = await self.connection_manager.call_tool(
session, tool_call.tool_name, tool_call.arguments
)

Expand Down Expand Up @@ -223,5 +220,6 @@ def _default_formatter(self, template: str, obs: Any, context: Dict) -> Union[st
async def close(self):
    """
    Close all MCP sessions held by this environment.

    Every session is closed concurrently through ``self.connection_manager``.
    Closing is best-effort: a failure on one session is logged and does not
    prevent the remaining sessions from being closed.
    """
    print(f"🧹 Closing {self.n} MCP sessions...")
    tasks = [self.connection_manager.close_session(session) for session in self.sessions]
    # return_exceptions=True keeps one failing close from raising out of
    # gather() while the other close operations are still in flight.
    results = await asyncio.gather(*tasks, return_exceptions=True)
    for session, result in zip(self.sessions, results):
        if isinstance(result, BaseException):
            logger.warning("Failed to close MCP session %s: %r", session.session_id, result)
    print("✅ All MCP sessions closed.")
34 changes: 25 additions & 9 deletions eval_protocol/mcp_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
policy = ep.FireworksPolicy(model_id="accounts/fireworks/models/qwen3-235b-a22b")

# Create environments with evaluation_rows configuration
envs = ep.make("http://localhost:8000/mcp", evaluation_rows=evaluation_rows)
envs = await ep.make("http://localhost:8000/mcp", evaluation_rows=evaluation_rows)

# Execute tool-calling rollouts
evaluation_rows = await ep.rollout(envs, policy=policy, steps=512)
Expand Down Expand Up @@ -51,18 +51,28 @@
from .mcp.session.manager import GeneralMCPVectorEnv
from .models import EvaluationRow
from .types import DatasetRow, MCPSession, MCPToolCall
import asyncio

logger = logging.getLogger(__name__)


def make(
async def reset_mcp_sessions(envs: GeneralMCPVectorEnv):
    """
    Reset the server-side state of every session in ``envs``.

    One ``reset_session`` call per session is created through the
    environment's connection manager and all of them are awaited together.
    """
    await asyncio.gather(
        *(envs.connection_manager.reset_session(mcp_session) for mcp_session in envs.sessions)
    )


async def make(
env_spec: str,
evaluation_rows: Optional[List[EvaluationRow]] = None,
dataset: Optional[List[Dict]] = None,
n: Optional[int] = None,
seeds: Optional[List[int]] = None,
model_id: str = "unknown",
user_prompt_formatter: Optional[Callable] = None,
reset_sessions: bool = False,
) -> GeneralMCPVectorEnv:
"""
Create general MCP environments driven by evaluation_rows configuration.
Expand All @@ -75,19 +85,20 @@ def make(
seeds: List of seeds (for backward compatibility)
model_id: Model identifier
user_prompt_formatter: Optional callback for formatting user prompts
reset_sessions: Whether to reset sessions before returning the environment

Returns:
General MCP environment that works with any MCP server

Example:
# EvaluationRow approach (preferred)
envs = ep.make("http://localhost:8000/mcp", evaluation_rows=evaluation_rows)
envs = await ep.make("http://localhost:8000/mcp", evaluation_rows=evaluation_rows)

# Dataset approach (backward compatibility)
envs = ep.make("http://localhost:8000/mcp", dataset=dataset)
envs = await ep.make("http://localhost:8000/mcp", dataset=dataset)

# Legacy approach (backward compatibility)
envs = ep.make("http://localhost:8000/mcp", n=10, seeds=seeds)
envs = await ep.make("http://localhost:8000/mcp", n=10, seeds=seeds)
"""
# Parse environment specification - make sure URL format is correct
base_url = env_spec
Expand Down Expand Up @@ -160,8 +171,6 @@ def make(
)
sessions.append(session)

return GeneralMCPVectorEnv(sessions, dataset_rows, user_prompt_formatter)

else:
# Legacy approach for backward compatibility
if n is None:
Expand Down Expand Up @@ -198,7 +207,14 @@ def make(
)
sessions.append(session)

return GeneralMCPVectorEnv(sessions, dataset_rows, user_prompt_formatter)
mcp_envs = GeneralMCPVectorEnv(sessions, dataset_rows, user_prompt_formatter)
tasks = [mcp_envs.connection_manager.initialize_session(session) for session in sessions]
await asyncio.gather(*tasks)

if reset_sessions:
await reset_mcp_sessions(mcp_envs)

return mcp_envs


async def rollout(
Expand Down Expand Up @@ -266,7 +282,7 @@ async def rollout(
raise ValueError("Either 'evaluation_rows' or 'dataset' must be provided when envs is a URL")

auto_model_id = model_id or getattr(policy, "model_id", "unknown")
envs = make(envs, evaluation_rows=evaluation_rows, dataset=dataset, model_id=auto_model_id)
envs = await make(envs, evaluation_rows=evaluation_rows, dataset=dataset, model_id=auto_model_id)

# Use the new ExecutionManager for execution
execution_manager = ExecutionManager()
Expand Down
36 changes: 17 additions & 19 deletions eval_protocol/pytest/default_mcp_gym_rollout_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,49 +182,47 @@ def __exit__(self, exc_type, exc_val, exc_tb):
return False # Don't suppress exceptions



async def default_mcp_gym_rollout_processor(
    rows: List[EvaluationRow], config: RolloutProcessorConfig
) -> List[EvaluationRow]:
    """
    Run MCP gym rollouts for ``rows`` against a locally started MCP server.

    Starts the server script named by ``config.server_script_path``, builds a
    LiteLLM policy from ``config``, creates MCP environments from the rows and
    runs the rollout. The server is stopped on every exit path. Originally
    written for tau bench environments, following test_tau2_e2e.py.

    Args:
        rows: List of EvaluationRow objects containing messages and dataset info in input_metadata
        config: RolloutProcessorConfig with model and other parameters

    Returns:
        List of EvaluationRow objects with completed conversations
    """
    # Single source of truth for the port so the server and the client URL
    # cannot drift apart.
    port = 9700
    server = MCPServerManager(config.server_script_path, port=port)

    try:
        server.start()

        policy = ep.LiteLLMPolicy(
            model_id=config.model,
            temperature=config.input_params.get("temperature", 0.0),
            max_tokens=config.input_params.get("max_tokens", 4096),
        )

        # ep.make is a coroutine function and must be awaited; it returns the
        # environments only after the MCP sessions have been initialized.
        envs = await ep.make(
            f"http://localhost:{port}/mcp/",
            evaluation_rows=rows,
            model_id=policy.model_id,
        )

        # Run rollout with environments and policy
        evaluation_rows = await ep.rollout(
            envs, policy=policy, steps=config.steps, max_concurrent_rollouts=config.max_concurrent_rollouts
        )

        return evaluation_rows

    finally:
        # Always clean up the server
        server.stop()
6 changes: 4 additions & 2 deletions eval_protocol/types/types.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional
from mcp.client.session import ClientSession
from contextlib import AsyncExitStack


class TerminationReason(str, Enum):
Expand Down Expand Up @@ -50,8 +52,8 @@ class MCPSession:
last_observation: Any = None

# Persistent MCP connection components
_exit_stack: Optional[Any] = None
_mcp_session: Optional[Any] = None
_exit_stack: Optional[AsyncExitStack] = None
_mcp_session: Optional[ClientSession] = None


@dataclass
Expand Down
Loading
Loading