Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 4 additions & 9 deletions eval_protocol/mcp/client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,16 @@
import hashlib
import json
import logging
import time
from contextlib import AsyncExitStack
from typing import Any, Dict, List, Optional, Tuple

import httpx
from mcp.client.session import ClientSession
from mcp.client.streamable_http import streamablehttp_client
from mcp.types import Implementation

from ...types import MCPSession
from mcp.types import Implementation

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -109,15 +111,13 @@ async def reset_session(self, session: MCPSession) -> None:
"""
Clean session data in remote mcp server for the given session
"""
import httpx

base_url = session.base_url.rstrip("/").removesuffix("/mcp")
url = f"{base_url}/control/reset_session"

headers = {"mcp-session-id": session.session_id}
body = {"seed": session.seed}

timeout = httpx.Timeout(3.0)
timeout = httpx.Timeout(15.0)
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @mayinghan we should come up with a better solution to this timeout. for complex environments like tau, a reset can definitely take a long time. e.g. it takes ~12 seconds to reset all the environments for airline (loads a large json, and we're doing it on a thread pool, so not truly concurrent)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is this sleep for? I thought we fixed it with health check?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it’s not a sleep but a timeout, so if the env reset takes more than 15s, it’ll time out. this reset is called on cleanup when the rollouts end.
can we just remove the timeout amount since it’s possible for env reset to take more than 15s or is that dangerous?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe we can also consider deleting that session completely from the mcp server? but then the server will never be able to persist any state after one single run

Copy link
Copy Markdown
Contributor Author

@xzrderek xzrderek Aug 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but then the server will never be able to persistent any state after one single run

i don't quite get what this means. i believe you added reset_session recently, and it's triggered at the end of the rollout. so aren't we already not persisting state after a run?

regardless, i'm gonna merge in first and we can talk more later. i'm just calling out that the 15s timeout is likely not a viable long-term solution, but it's fine for now.

async with httpx.AsyncClient(timeout=timeout) as client:
resp = await client.post(url, headers=headers, json=body)
resp.raise_for_status()
Expand Down Expand Up @@ -202,8 +202,6 @@ async def get_initial_state(self, session: MCPSession) -> Any:
initial_observation = None

try:
import httpx

# Extract base URL and session ID from the MCP session
base_url = session.base_url.rstrip("/").removesuffix("/mcp")
session_id = session.session_id
Expand Down Expand Up @@ -459,9 +457,6 @@ async def call_tool(self, session: MCPSession, tool_name: str, arguments: Dict)
control_plane_info = {}

try:
# Query control plane endpoints following the new architecture
import httpx

# Extract base URL and session ID from the MCP session
base_url = session.base_url.rstrip("/").removesuffix("/mcp")
# Use the session ID from the established MCP session
Expand Down
7 changes: 3 additions & 4 deletions eval_protocol/mcp/execution/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@
import os
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import asdict, dataclass
from dataclasses import asdict
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union

from openai.types import CompletionUsage
Expand Down Expand Up @@ -248,7 +247,7 @@ async def _execute_rollout(

# Get initial messages in tau2-bench format for user simulator
user_simulator_state = user_simulator.get_init_state()
user_message, user_simulator_state = user_simulator.generate_next_message(
user_message, user_simulator_state = await user_simulator.generate_next_message(
AssistantMessage(role="assistant", content="Hi! How can I help you today?"),
user_simulator_state,
)
Expand Down Expand Up @@ -280,7 +279,7 @@ async def _execute_rollout(
# Last message was agent, simulated user response
if user_simulator_messages and isinstance(user_simulator_messages[-1], AssistantMessage):
# Generate user response using the simulator
user_message, user_simulator_state = user_simulator.generate_next_message(
user_message, user_simulator_state = await user_simulator.generate_next_message(
user_simulator_messages[-1], user_simulator_state
)
user_content = user_message.content if user_message.content else ""
Expand Down
162 changes: 61 additions & 101 deletions eval_protocol/mcp/mcpgym.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,22 @@
"""

import asyncio
import dataclasses
import hashlib
import inspect
import json
import logging
import os
import threading
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from datetime import date, datetime
from enum import Enum
from typing import Any, Callable, Dict, Optional, Tuple

import uvicorn
from mcp.server.fastmcp import Context, FastMCP
from pydantic import BaseModel
from starlette.requests import Request
from starlette.responses import JSONResponse
from uvicorn.middleware.proxy_headers import ProxyHeadersMiddleware
Expand Down Expand Up @@ -75,14 +80,23 @@ class McpGym(ABC):
- Environment Implementation: Single-process MCP server per environment
"""

def __init__(self, server_name: str, adapter: EnvironmentAdapter, seed: Optional[int] = None):
def __init__(
self,
server_name: str,
adapter: EnvironmentAdapter,
seed: Optional[int] = None,
max_workers: Optional[int] = None,
):
"""
Initialize the MCP-Gym environment.

Args:
server_name: Name for the MCP server
adapter: Environment adapter instance
seed: Optional seed for reproducible environments
max_workers: Optional maximum number of worker threads for ThreadPoolExecutor.
If None, uses ThreadPoolExecutor default (min(32, (os.cpu_count() or 1) + 4))

"""
self.adapter = adapter

Expand Down Expand Up @@ -110,6 +124,8 @@ def __init__(self, server_name: str, adapter: EnvironmentAdapter, seed: Optional
"total_reward": 0.0,
}

self.pool = ThreadPoolExecutor(max_workers=max_workers)

# Reset with seed if provided
self.env, self.obs, _info = self._new_env(seed=seed)

Expand Down Expand Up @@ -189,49 +205,7 @@ def _get_or_create_session(self, ctx: Context) -> Dict[str, Any]:
"""
session_id = self._get_session_id(ctx)
print(f"🔍 _get_or_create_session: session_id: {session_id}")

with self.session_lock:
if session_id not in self.sessions:
print(f"🔍 _get_or_create_session: Creating new session for {session_id}")
# Extract seed from context using proper FastMCP pattern
seed = None
config = self._get_default_config()
print(f"🔍 _get_or_create_session: default_config: {config}")

if hasattr(ctx, "session") and hasattr(ctx.session, "client_params"):
client_params = ctx.session.client_params
if hasattr(client_params, "clientInfo"):
client_info = client_params.clientInfo
if client_info and hasattr(client_info, "_extra"):
extra_data = client_info._extra
print(f"🔍 _get_or_create_session: extra_data in session creation: {extra_data}")
if extra_data and isinstance(extra_data, dict):
# Extract seed from client info
seed = extra_data.get("seed")
print(f"🌱 Extracted seed from client_info: {seed} (type: {type(seed)})")
# Update config with any additional options
if "config" in extra_data:
config.update(extra_data["config"])
print(f"🔍 _get_or_create_session: updated config: {config}")

print(f"🔍 _get_or_create_session: About to create environment with seed: {seed}")

env, obs, info = self._new_env(seed=seed)
print(f"🔍 _get_or_create_session: environment created with obs: {obs}, info: {info}")

# Initialize session state
self.sessions[session_id] = {
"env": env,
"obs": obs,
"session_data": {}, # Subclasses can store additional data here
"session_id": session_id,
}

print(f"🎮 Created new session {session_id[:16]}... with seed {seed}, initial obs: {obs}")
else:
print(f"🔍 _get_or_create_session: Returning existing session {session_id}")

return self.sessions[session_id]
return self.sessions[session_id]

def _register_session_reset_endpoint(self):

Expand All @@ -243,16 +217,17 @@ async def reset_session_endpoint(request: Request) -> JSONResponse:
print(f"🔍 _register_session_reset_endpoint: Resetting session, session_id: {session_id}, seed: {seed}")
if not session_id:
return JSONResponse({"error": "Missing mcp-session-id header"}, status_code=400)
with self.session_lock:
if session_id in self.sessions:
env, obs, _ = self._new_env(seed=seed)
if session_id in self.sessions:
loop = asyncio.get_running_loop()
env, obs, info = await loop.run_in_executor(self.pool, self._new_env, seed)
with self.session_lock:
self.sessions[session_id] = {
"env": env,
"obs": obs,
"session_data": {},
"session_id": session_id,
}
print(f"🔍 _register_session_reset_endpoint: Finished reset session, session_id: {session_id}")
print(f"🔍 _register_session_reset_endpoint: Finished reset session, session_id: {session_id}")
return JSONResponse({"message": "Session reset successfully"})

def _discover_and_register_control_plane_endpoints(self):
Expand Down Expand Up @@ -286,29 +261,27 @@ async def endpoint_handler(request: Request) -> JSONResponse:
)

# Get or create session data
session_data = self.sessions.get(session_id)
if not session_data:
if func.__name__ != "get_initial_state_endpoint":
return JSONResponse(
{"error": f"Session {session_id} not found"},
status_code=404,
)

loop = asyncio.get_running_loop()
env, obs, info = await loop.run_in_executor(self.pool, self._new_env, None)

# Initialize session state with extracted seed from session ID
session_data = {
"env": env,
"obs": obs,
"session_data": {}, # Subclasses can store additional data here
"session_id": session_id,
}
with self.session_lock:
session_data = self.sessions.get(session_id)
if not session_data:
# For initial state endpoint, we need to create the session
# based on the session ID and available information
if func.__name__ == "get_initial_state_endpoint":
env, obs, info = self._new_env(seed=None)
# Initialize session state with extracted seed from session ID
session_data = {
"env": env,
"obs": obs,
"session_data": {}, # Subclasses can store additional data here
"session_id": session_id,
}
# Store the session
self.sessions[session_id] = session_data
else:
return JSONResponse(
{"error": f"Session {session_id} not found"},
status_code=404,
)

# Call the endpoint function with session data
self.sessions[session_id] = session_data

if inspect.iscoroutinefunction(func):
result = await func(session_data=session_data)
else:
Expand Down Expand Up @@ -356,22 +329,21 @@ def _update_control_plane(self, reward: float, terminated: bool, truncated: bool

def _get_or_create_session_control_plane(self, session_id: str) -> Dict[str, Any]:
"""Get or create control plane state for a specific session."""
with self.session_lock:
if session_id not in self.sessions:
return {}

session_data = self.sessions[session_id]
if "control_plane" not in session_data["session_data"]:
session_data["session_data"]["control_plane"] = {
"reward": 0.0,
"terminated": False,
"truncated": False,
"info": {},
"step_count": 0,
"total_reward": 0.0,
}
if session_id not in self.sessions:
raise Exception(f"Session {session_id} not found")

session_data = self.sessions[session_id]
if "control_plane" not in session_data["session_data"]:
session_data["session_data"]["control_plane"] = {
"reward": 0.0,
"terminated": False,
"truncated": False,
"info": {},
"step_count": 0,
"total_reward": 0.0,
}

return session_data["session_data"]["control_plane"]
return session_data["session_data"]["control_plane"]

def _update_session_control_plane(
self,
Expand All @@ -396,13 +368,6 @@ def _update_session_control_plane(
f"🎛️ Session {session_id[:16]}... control plane: reward={reward}, terminated={terminated}, step={control_plane['step_count']}, total_reward={control_plane['total_reward']}"
)

def get_control_plane_state(self, session_id: str) -> Optional[Dict[str, Any]]:
"""Get control plane state for a specific session (for rollout system)."""
with self.session_lock:
if session_id in self.sessions:
return self._get_or_create_session_control_plane(session_id).copy()
return None

def _execute_environment_step(self, action_int: int) -> Dict[str, Any]:
"""
Execute environment step and update control plane (single session).
Expand Down Expand Up @@ -510,11 +475,11 @@ def get_info_endpoint(self, session_data: Dict[str, Any]) -> Dict[str, Any]:
return control_plane.get("info", {})

@control_plane_endpoint("/control/initial_state")
def get_initial_state_endpoint(self, session_data: Dict[str, Any]) -> Dict[str, Any]:
async def get_initial_state_endpoint(self, session_data: Dict[str, Any]) -> Dict[str, Any]:
"""Get initial state for this session."""
session_id = session_data.get("session_id", "unknown")
env = session_data.get("env")
obs = session_data.get("obs")

if env and obs is not None:
try:
formatted_obs = self.format_observation(obs, env)
Expand Down Expand Up @@ -604,8 +569,8 @@ async def run_with_high_concurrency():
proxy_headers=True,
forwarded_allow_ips="*",
# HIGH CONCURRENCY SETTINGS
limit_concurrency=200, # Increase for HTTP endpoints + MCP
limit_max_requests=100000, # Higher request limit
limit_concurrency=None, # Increase for HTTP endpoints + MCP
limit_max_requests=None, # Higher request limit
timeout_keep_alive=120, # Longer keep-alive for control plane
timeout_notify=180,
h11_max_incomplete_event_size=4 * 1024 * 1024, # Handle larger events
Expand All @@ -624,11 +589,6 @@ def _to_json_serializable(self, obj: Any) -> Any:
Handles Pydantic models, dataclasses, lists, dicts, and primitive types.
This is a utility method that can be used by format_observation implementations.
"""
import dataclasses
from datetime import date, datetime
from enum import Enum

from pydantic import BaseModel

# Handle None and primitive types
if obj is None or isinstance(obj, (str, int, float, bool)):
Expand Down
11 changes: 11 additions & 0 deletions eval_protocol/pytest/default_mcp_gym_rollout_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,17 @@ def start(self) -> None:
if self.process:
return

try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.settimeout(1)
result = s.connect_ex(("localhost", self.port))
if result == 0:
raise RuntimeError(
f"Port {self.port} is already in use! Please use a different port or kill the process using it."
)
except socket.error:
pass

# Set environment for server
env = os.environ.copy()
env["PORT"] = str(self.port)
Expand Down
4 changes: 2 additions & 2 deletions examples/blackjack_mcp/blackjack_mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@ class BlackjackMcp(McpGym):
- Multi-session support with session-based control plane state
"""

def __init__(self, seed: Optional[int] = None):
def __init__(self, seed: Optional[int] = None, **kwargs):
"""Initialize Blackjack MCP-Gym environment."""
adapter = BlackjackAdapter()
super().__init__("Blackjack-v1", adapter, seed)
super().__init__("Blackjack-v1", adapter, seed, **kwargs)

# Multi-session support is now handled by the base class

Expand Down
4 changes: 2 additions & 2 deletions examples/cliff_walking_mcp/cliff_walking_mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ class CliffWalkingMcp(McpGym):
- Multi-session support with session-based control plane state
"""

def __init__(self, seed: Optional[int] = None):
def __init__(self, seed: Optional[int] = None, **kwargs):
"""Initialize Cliff Walking MCP-Gym environment."""
adapter = CliffWalkingAdapter()
super().__init__("CliffWalking-v1", adapter, seed)
super().__init__("CliffWalking-v1", adapter, seed, **kwargs)

# Multi-session support is now handled by the base class

Expand Down
Loading
Loading