diff --git a/eval_protocol/mcp/client/connection.py b/eval_protocol/mcp/client/connection.py index 350b9b2a..64f80352 100644 --- a/eval_protocol/mcp/client/connection.py +++ b/eval_protocol/mcp/client/connection.py @@ -101,7 +101,7 @@ async def initialize_session(self, session: MCPSession) -> None: # Update the session ID to match what the server generated session.session_id = server_session_id - logger.debug(f"Updated session ID to match server: {server_session_id}") + logger.info(f"Updated session ID to match server: {server_session_id}") # PRE-WARM: Discover and cache tools immediately after session initialization # This prevents concurrent list_tools() calls later @@ -133,6 +133,24 @@ async def _prewarm_tools_cache(self, session: MCPSession) -> None: self._tools_cache[cache_key] = tool_schemas logger.debug(f"โœ… PRE-WARMED {len(tool_schemas)} tools for{cache_key}") + async def reset_session(self, session: MCPSession) -> None: + """ + Clean session data in remote mcp server for the given session + """ + import httpx + + base_url = session.base_url.rstrip("/").removesuffix("/mcp") + url = f"{base_url}/control/reset_session" + + headers = {"mcp-session-id": session.session_id} + body = {"seed": session.seed} + + timeout = httpx.Timeout(3.0) + async with httpx.AsyncClient(timeout=timeout) as client: + resp = await client.post(url, headers=headers, json=body) + resp.raise_for_status() + logger.debug(f"Session {session.session_id}: reset_session -> {resp.json()}") + async def discover_tools(self, session: MCPSession) -> List[Dict]: """ Discover available tools from an MCP session. diff --git a/eval_protocol/mcp/execution/manager.py b/eval_protocol/mcp/execution/manager.py index 8f52d323..21b27eb6 100644 --- a/eval_protocol/mcp/execution/manager.py +++ b/eval_protocol/mcp/execution/manager.py @@ -22,7 +22,6 @@ from ...models import CompletionParams, EvaluationRow, InputMetadata, Message from ...types import MCPSession, MCPToolCall, TerminationReason, Trajectory -from ..client.connection import MCPConnectionManager if TYPE_CHECKING: from ..session.manager import GeneralMCPVectorEnv @@ -33,43 +32,9 @@ class ExecutionManager: """ - Unified manager that handles both MCP session lifecycle and rollout execution. - - Combines the functionality of SessionManager and RolloutManager for better - organization and reduced complexity. + Manage rollout for MCP environments. """ - def __init__(self): - """Initialize the execution manager.""" - self.connection_manager = MCPConnectionManager() - - async def initialize_sessions(self, sessions: List[MCPSession]) -> None: - """ - Initialize multiple MCP sessions in parallel. - - Args: - sessions: List of MCPSessions to initialize - """ - tasks = [self.connection_manager.initialize_session(session) for session in sessions] - await asyncio.gather(*tasks) - - async def close_sessions(self, sessions: List[MCPSession]) -> None: - """ - Close multiple MCP sessions in parallel. - - Args: - sessions: List of MCPSessions to close - """ - tasks = [asyncio.create_task(self.connection_manager.close_session(session)) for session in sessions] - - if tasks: - try: - # Wait for all close operations to complete - await asyncio.gather(*tasks, return_exceptions=True) - except asyncio.CancelledError: - # Handle cancellation gracefully (especially important for Python 3.12) - logger.debug("Close operation was cancelled, but sessions are marked as closed") - async def execute_rollouts( self, envs: "GeneralMCPVectorEnv", @@ -178,7 +143,7 @@ async def _execute_with_semaphore(idx): for msg in trajectory.conversation_history: # Create a copy to avoid modifying the original msg_dict = dict(msg) - + # Handle multimodal content (list of content blocks) by extracting text if isinstance(msg_dict.get("content"), list): text_content = None @@ -187,7 +152,7 @@ async def _execute_with_semaphore(idx): text_content = content_block.get("text") break msg_dict["content"] = text_content or "" - + messages.append(Message.model_validate(msg_dict)) input_metadata = InputMetadata( diff --git a/eval_protocol/mcp/mcpgym.py b/eval_protocol/mcp/mcpgym.py index edc1a244..2bfa7d7b 100644 --- a/eval_protocol/mcp/mcpgym.py +++ b/eval_protocol/mcp/mcpgym.py @@ -116,6 +116,7 @@ def __init__(self, server_name: str, adapter: EnvironmentAdapter, seed: Optional # Register tools and control plane endpoints self._register_tools() self._discover_and_register_control_plane_endpoints() + self._register_session_reset_endpoint() def _get_session_id(self, ctx: Context) -> str: """ @@ -227,6 +228,28 @@ def _get_or_create_session(self, ctx: Context) -> Dict[str, Any]: return self.sessions[session_id] + def _register_session_reset_endpoint(self): + + @self.mcp.custom_route("/control/reset_session", methods=["POST"]) + async def reset_session_endpoint(request: Request) -> JSONResponse: + session_id = request.headers.get("mcp-session-id") + body = await request.json() + seed = body.get("seed", None) + print(f"๐Ÿ” _register_session_reset_endpoint: Resetting session, session_id: {session_id}, seed: {seed}") + if not session_id: + return JSONResponse({"error": "Missing mcp-session-id header"}, status_code=400) + with self.session_lock: + if session_id in self.sessions: + env, obs, _ = self._new_env(seed=seed) + self.sessions[session_id] = { + "env": env, + "obs": obs, + "session_data": {}, + "session_id": session_id, + } + print(f"๐Ÿ” _register_session_reset_endpoint: Finished reset session, session_id: {session_id}") + return JSONResponse({"message": "Session reset successfully"}) + def _discover_and_register_control_plane_endpoints(self): """ Discover and register control plane endpoints on the subclass instance. @@ -323,7 +346,7 @@ def _update_control_plane(self, reward: float, terminated: bool, truncated: bool # Log control plane update (for debugging) print( - f"๐ŸŽ›๏ธ Control plane updated: reward={reward}, terminated={terminated}, step={self.control_plane_state['step_count']}" + f"๐ŸŽ›๏ธ Control plane updated: reward={reward}, terminated={terminated}, step={self.control_plane_state['step_count']}, total_reward={self.control_plane_state['total_reward']}" ) def _get_or_create_session_control_plane(self, session_id: str) -> Dict[str, Any]: @@ -365,7 +388,7 @@ def _update_session_control_plane( # Log control plane update print( - f"๐ŸŽ›๏ธ Session {session_id[:16]}... control plane: reward={reward}, terminated={terminated}, step={control_plane['step_count']}" + f"๐ŸŽ›๏ธ Session {session_id[:16]}... control plane: reward={reward}, terminated={terminated}, step={control_plane['step_count']}, total_reward={control_plane['total_reward']}" ) def get_control_plane_state(self, session_id: str) -> Optional[Dict[str, Any]]: diff --git a/eval_protocol/mcp/session/manager.py b/eval_protocol/mcp/session/manager.py index 0a55d387..be413b46 100644 --- a/eval_protocol/mcp/session/manager.py +++ b/eval_protocol/mcp/session/manager.py @@ -11,7 +11,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union from ...types import DatasetRow, MCPSession, MCPToolCall -from ..execution.manager import ExecutionManager +from ..client.connection import MCPConnectionManager logger = logging.getLogger(__name__) @@ -44,7 +44,7 @@ def __init__( self.user_prompt_formatter = user_prompt_formatter or self._default_formatter self.n = len(sessions) self.tool_schemas = [] # Discovered from MCP servers - self.execution_manager = ExecutionManager() + self.connection_manager = MCPConnectionManager() self.usage_stats = {} # llm usage stats for monitoring if len(sessions) != len(dataset_rows): @@ -58,17 +58,14 @@ async def reset(self, session: MCPSession) -> Tuple[Any, List[Dict]]: This is thread-safe and can be called from worker threads. """ - # Establish a persistent session for each environment. - await self.execution_manager.connection_manager.initialize_session(session) - # Get available tools from MCP server - tool_schemas = await self.execution_manager.connection_manager.discover_tools(session) + tool_schemas = await self.connection_manager.discover_tools(session) if not self.tool_schemas: self.tool_schemas = tool_schemas # PROPER MCP PATTERN: Get initial state from resources during session establishment - initial_observation = await self.execution_manager.connection_manager.get_initial_state(session) + initial_observation = await self.connection_manager.get_initial_state(session) # Update session state session.terminated = False @@ -119,7 +116,7 @@ async def step(self, env_index: int, tool_call: MCPToolCall) -> Tuple[Any, float ) # Execute the tool call via MCP protocol - observation, reward, done, info = await self.execution_manager.connection_manager.call_tool( + observation, reward, done, info = await self.connection_manager.call_tool( session, tool_call.tool_name, tool_call.arguments ) @@ -223,5 +220,6 @@ def _default_formatter(self, template: str, obs: Any, context: Dict) -> Union[st async def close(self): """Closes all MCP sessions.""" print(f"๐Ÿงน Closing {self.n} MCP sessions...") - await self.execution_manager.close_sessions(self.sessions) + tasks = [self.connection_manager.close_session(session) for session in self.sessions] + await asyncio.gather(*tasks) print(f"โœ… All MCP sessions closed.") diff --git a/eval_protocol/mcp_env.py b/eval_protocol/mcp_env.py index f52b1793..2a03e931 100644 --- a/eval_protocol/mcp_env.py +++ b/eval_protocol/mcp_env.py @@ -17,7 +17,7 @@ policy = ep.FireworksPolicy(model_id="accounts/fireworks/models/qwen3-235b-a22b") # Create environments with evaluation_rows configuration - envs = ep.make("http://localhost:8000/mcp", evaluation_rows=evaluation_rows) + envs = await ep.make("http://localhost:8000/mcp", evaluation_rows=evaluation_rows) # Execute tool-calling rollouts evaluation_rows = await ep.rollout(envs, policy=policy, steps=512) @@ -51,11 +51,20 @@ from .mcp.session.manager import GeneralMCPVectorEnv from .models import EvaluationRow from .types import DatasetRow, MCPSession, MCPToolCall +import asyncio logger = logging.getLogger(__name__) -def make( +async def reset_mcp_sessions(envs: GeneralMCPVectorEnv): + """ + Reset mcp server sessions + """ + tasks = [envs.connection_manager.reset_session(session) for session in envs.sessions] + await asyncio.gather(*tasks) + + +async def make( env_spec: str, evaluation_rows: Optional[List[EvaluationRow]] = None, dataset: Optional[List[Dict]] = None, @@ -63,6 +72,7 @@ def make( seeds: Optional[List[int]] = None, model_id: str = "unknown", user_prompt_formatter: Optional[Callable] = None, + reset_sessions: bool = False, ) -> GeneralMCPVectorEnv: """ Create general MCP environments driven by evaluation_rows configuration. @@ -75,19 +85,20 @@ def make( seeds: List of seeds (for backward compatibility) model_id: Model identifier user_prompt_formatter: Optional callback for formatting user prompts + reset_sessions: Whether to reset sessions before returning the environment Returns: General MCP environment that works with any MCP server Example: # EvaluationRow approach (preferred) - envs = ep.make("http://localhost:8000/mcp", evaluation_rows=evaluation_rows) + envs = await ep.make("http://localhost:8000/mcp", evaluation_rows=evaluation_rows) # Dataset approach (backward compatibility) - envs = ep.make("http://localhost:8000/mcp", dataset=dataset) + envs = await ep.make("http://localhost:8000/mcp", dataset=dataset) # Legacy approach (backward compatibility) - envs = ep.make("http://localhost:8000/mcp", n=10, seeds=seeds) + envs = await ep.make("http://localhost:8000/mcp", n=10, seeds=seeds) """ # Parse environment specification - make sure URL format is correct base_url = env_spec @@ -160,8 +171,6 @@ def make( ) sessions.append(session) - return GeneralMCPVectorEnv(sessions, dataset_rows, user_prompt_formatter) - else: # Legacy approach for backward compatibility if n is None: @@ -198,7 +207,14 @@ def make( ) sessions.append(session) - return GeneralMCPVectorEnv(sessions, dataset_rows, user_prompt_formatter) + mcp_envs = GeneralMCPVectorEnv(sessions, dataset_rows, user_prompt_formatter) + tasks = [mcp_envs.connection_manager.initialize_session(session) for session in sessions] + await asyncio.gather(*tasks) + + if reset_sessions: + await reset_mcp_sessions(mcp_envs) + + return mcp_envs async def rollout( @@ -266,7 +282,7 @@ async def rollout( raise ValueError("Either 'evaluation_rows' or 'dataset' must be provided when envs is a URL") auto_model_id = model_id or getattr(policy, "model_id", "unknown") - envs = make(envs, evaluation_rows=evaluation_rows, dataset=dataset, model_id=auto_model_id) + envs = await make(envs, evaluation_rows=evaluation_rows, dataset=dataset, model_id=auto_model_id) # Use the new ExecutionManager for execution execution_manager = ExecutionManager() diff --git a/eval_protocol/pytest/default_mcp_gym_rollout_processor.py b/eval_protocol/pytest/default_mcp_gym_rollout_processor.py index bc68d346..a803dd43 100644 --- a/eval_protocol/pytest/default_mcp_gym_rollout_processor.py +++ b/eval_protocol/pytest/default_mcp_gym_rollout_processor.py @@ -182,49 +182,47 @@ def __exit__(self, exc_type, exc_val, exc_tb): return False # Don't suppress exceptions - -async def default_mcp_gym_rollout_processor(rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[EvaluationRow]: +async def default_mcp_gym_rollout_processor( + rows: List[EvaluationRow], config: RolloutProcessorConfig +) -> List[EvaluationRow]: """ Rollout processor for tau bench environments. - + This processor starts an MCP server, creates tau bench environments, and runs rollouts using the eval_protocol framework, following the pattern from test_tau2_e2e.py. - + Args: rows: List of EvaluationRow objects containing messages and dataset info in input_metadata config: RolloutProcessorConfig with model and other parameters - + Returns: List of EvaluationRow objects with completed conversations """ server = MCPServerManager(config.server_script_path, port=9700) - + try: server.start() - + policy = ep.LiteLLMPolicy( model_id=config.model, - temperature=config.input_params.get('temperature', 0.0), - max_tokens=config.input_params.get('max_tokens', 4096), + temperature=config.input_params.get("temperature", 0.0), + max_tokens=config.input_params.get("max_tokens", 4096), ) - + # Create MCP environments directly from evaluation_rows - envs = ep.make( - 'http://localhost:9700/mcp/', + envs = await ep.make( + "http://localhost:9700/mcp/", evaluation_rows=rows, model_id=policy.model_id, ) - + # Run rollout with environments and policy evaluation_rows = await ep.rollout( - envs, - policy=policy, - steps=config.steps, - max_concurrent_rollouts=config.max_concurrent_rollouts + envs, policy=policy, steps=config.steps, max_concurrent_rollouts=config.max_concurrent_rollouts ) - + return evaluation_rows - + finally: # Always clean up the server server.stop() diff --git a/eval_protocol/types/types.py b/eval_protocol/types/types.py index 953f6aa6..7c0184f0 100644 --- a/eval_protocol/types/types.py +++ b/eval_protocol/types/types.py @@ -1,6 +1,8 @@ from dataclasses import dataclass, field from enum import Enum from typing import Any, Dict, List, Optional +from mcp.client.session import ClientSession +from contextlib import AsyncExitStack class TerminationReason(str, Enum): @@ -50,8 +52,8 @@ class MCPSession: last_observation: Any = None # Persistent MCP connection components - _exit_stack: Optional[Any] = None - _mcp_session: Optional[Any] = None + _exit_stack: Optional[AsyncExitStack] = None + _mcp_session: Optional[ClientSession] = None @dataclass diff --git a/examples/blackjack_mcp/tests/test_record_and_replay_e2e.py b/examples/blackjack_mcp/tests/test_record_and_replay_e2e.py index 69552c53..b77b7daa 100644 --- a/examples/blackjack_mcp/tests/test_record_and_replay_e2e.py +++ b/examples/blackjack_mcp/tests/test_record_and_replay_e2e.py @@ -215,7 +215,7 @@ async def test_production_server_record_and_replay(production_server, blackjack_ assert playback_policy.is_playback_mode(), "Should be in playback mode in CI" # Create environments for playback - playback_envs = ep.make( + playback_envs = await ep.make( "http://localhost:9500/mcp/", dataset=blackjack_dataset, model_id=playback_policy.model_id, @@ -250,7 +250,7 @@ async def test_production_server_record_and_replay(production_server, blackjack_ assert not policy.is_playback_mode(), "Should be in recording mode initially" # Create environments - envs = ep.make( + envs = await ep.make( "http://localhost:9500/mcp/", dataset=blackjack_dataset, model_id=policy.model_id, @@ -310,7 +310,7 @@ async def test_production_server_record_and_replay(production_server, blackjack_ assert playback_policy.is_playback_mode(), "Should be in playback mode" # Create new environments for playback - playback_envs = ep.make( + playback_envs = await ep.make( "http://localhost:9500/mcp/", dataset=blackjack_dataset, model_id=playback_policy.model_id, @@ -462,7 +462,7 @@ async def test_blackjack_step_by_step(conda_isolation_recording_file): ] # Create environment pointing to conda-isolated server - envs = ep.make( + envs = await ep.make( f"http://localhost:{port}/mcp/", dataset=test_dataset, model_id=policy.model_id, @@ -570,7 +570,7 @@ async def test_multi_environment_sessions(multi_env_dataset, multi_env_recording policy = create_blackjack_static_policy(action_sequence=["HIT", "HIT", "STICK"]) # Create multiple environments - envs = ep.make( + envs = await ep.make( f"http://localhost:{server.port}/mcp/", dataset=multi_env_dataset, model_id=policy.model_id, @@ -992,7 +992,7 @@ async def test_fireworks_multi_environment_sessions(multi_env_dataset, fireworks assert playback_policy.is_playback_mode(), "Should be in playback mode in CI" # Create environments for playback - playback_envs = ep.make( + playback_envs = await ep.make( "http://localhost:9500/mcp/", dataset=multi_env_dataset, model_id=playback_policy.model_id, @@ -1033,7 +1033,7 @@ async def test_fireworks_multi_environment_sessions(multi_env_dataset, fireworks assert not policy.is_playback_mode(), "Should be in recording mode initially" # Create multiple environments - envs = ep.make( + envs = await ep.make( f"http://localhost:{server.port}/mcp/", dataset=multi_env_dataset, model_id=policy.model_id, @@ -1149,7 +1149,7 @@ async def test_control_plane_state_querying(multi_env_dataset): policy = create_blackjack_static_policy(action_sequence=["HIT", "STAND"]) # Create environments - envs = ep.make( + envs = await ep.make( f"http://localhost:{server.port}/mcp/", dataset=multi_env_dataset[:2], # Use only 2 environments for faster testing model_id=policy.model_id, diff --git a/examples/cliff_walking_mcp/tests/test_cliff_walking_e2e.py b/examples/cliff_walking_mcp/tests/test_cliff_walking_e2e.py index fc327f62..277a3457 100644 --- a/examples/cliff_walking_mcp/tests/test_cliff_walking_e2e.py +++ b/examples/cliff_walking_mcp/tests/test_cliff_walking_e2e.py @@ -224,7 +224,7 @@ async def test_production_server_record_and_replay( assert playback_policy.is_playback_mode(), "Should be in playback mode in CI" # Create environments for playback - playback_envs = ep.make( + playback_envs = await ep.make( "http://localhost:9500/mcp/", dataset=cliff_walking_dataset, model_id=playback_policy.model_id, @@ -259,7 +259,7 @@ async def test_production_server_record_and_replay( assert not policy.is_playback_mode(), "Should be in recording mode initially" # Create environments - envs = ep.make( + envs = await ep.make( "http://localhost:9500/mcp/", dataset=cliff_walking_dataset, model_id=policy.model_id, @@ -318,7 +318,7 @@ async def test_production_server_record_and_replay( assert playback_policy.is_playback_mode(), "Should be in playback mode" # Create new environments for playback - playback_envs = ep.make( + playback_envs = await ep.make( "http://localhost:9500/mcp/", dataset=cliff_walking_dataset, model_id=playback_policy.model_id, @@ -471,7 +471,7 @@ async def test_cliff_walking_step_by_step(conda_isolation_recording_file): ] # Create environment pointing to conda-isolated server - envs = ep.make( + envs = await ep.make( f"http://localhost:{port}/mcp/", dataset=test_dataset, model_id=policy.model_id, @@ -589,7 +589,7 @@ async def test_multi_environment_sessions(multi_env_dataset, multi_env_recording ) # Create multiple environments - envs = ep.make( + envs = await ep.make( f"http://localhost:{server.port}/mcp/", dataset=multi_env_dataset, model_id=policy.model_id, @@ -1018,7 +1018,7 @@ async def test_fireworks_multi_environment_sessions(multi_env_dataset, fireworks assert playback_policy.is_playback_mode(), "Should be in playback mode in CI" # Create environments for playback - playback_envs = ep.make( + playback_envs = await ep.make( "http://localhost:9500/mcp/", dataset=multi_env_dataset, model_id=playback_policy.model_id, @@ -1059,7 +1059,7 @@ async def test_fireworks_multi_environment_sessions(multi_env_dataset, fireworks assert not policy.is_playback_mode(), "Should be in recording mode initially" # Create multiple environments - envs = ep.make( + envs = await ep.make( f"http://localhost:{server.port}/mcp/", dataset=multi_env_dataset, model_id=policy.model_id, @@ -1178,7 +1178,7 @@ async def test_control_plane_state_querying(multi_env_dataset): policy = create_cliff_walking_static_policy(action_sequence=["UP", "UP"]) # Create environments - envs = ep.make( + envs = await ep.make( f"http://localhost:{server.port}/mcp/", dataset=multi_env_dataset[:2], # Use only 2 environments for faster testing model_id=policy.model_id, diff --git a/examples/frozen_lake_mcp/test_basic_functionality.py b/examples/frozen_lake_mcp/test_basic_functionality.py index 296d611f..a4a310ad 100644 --- a/examples/frozen_lake_mcp/test_basic_functionality.py +++ b/examples/frozen_lake_mcp/test_basic_functionality.py @@ -46,7 +46,7 @@ async def test_basic_server_functionality(): policy = ep.FireworksPolicy(model_id="accounts/fireworks/models/qwen3-235b-a22b", temperature=0.2) # Create environment pointing to local server - envs = ep.make("http://localhost:8000/mcp/", dataset=test_dataset, model_id=policy.model_id) + envs = await ep.make("http://localhost:8000/mcp/", dataset=test_dataset, model_id=policy.model_id) print("โœ… Successfully connected to MCP server") # Test 2: Try to make tool calls (we'll simulate this for now) diff --git a/examples/frozen_lake_mcp/test_multi_session.py b/examples/frozen_lake_mcp/test_multi_session.py index 08140af8..529f543f 100644 --- a/examples/frozen_lake_mcp/test_multi_session.py +++ b/examples/frozen_lake_mcp/test_multi_session.py @@ -60,7 +60,7 @@ async def test_multi_session(): try: # Create environments (assumes server is running on localhost:8000) - envs = ep.make( + envs = await ep.make( "http://localhost:8000/mcp/", dataset=test_dataset, model_id=policy.model_id, diff --git a/examples/frozen_lake_mcp/test_seed_logging.py b/examples/frozen_lake_mcp/test_seed_logging.py index edb1b272..248c004b 100644 --- a/examples/frozen_lake_mcp/test_seed_logging.py +++ b/examples/frozen_lake_mcp/test_seed_logging.py @@ -30,7 +30,7 @@ async def test_seed_logging(): try: # Create environment pointing to our server print("๐Ÿ”Œ Connecting to server...") - envs = ep.make("http://localhost:9600/mcp/", dataset=dataset, model_id="test") + envs = await ep.make("http://localhost:9600/mcp/", dataset=dataset, model_id="test") print(f"โœ… Created envs: {len(envs.sessions)} sessions") # Reset environments to trigger session creation diff --git a/examples/frozen_lake_mcp/tests/test_frozen_lake_e2e.py b/examples/frozen_lake_mcp/tests/test_frozen_lake_e2e.py index e2c4c78e..3ce71f3e 100644 --- a/examples/frozen_lake_mcp/tests/test_frozen_lake_e2e.py +++ b/examples/frozen_lake_mcp/tests/test_frozen_lake_e2e.py @@ -232,7 +232,7 @@ async def test_production_server_record_and_replay(production_server, frozen_lak assert playback_policy.is_playback_mode(), "Should be in playback mode in CI" # Create environments for playback - playback_envs = ep.make( + playback_envs = await ep.make( "http://localhost:9500/mcp/", dataset=frozen_lake_dataset, model_id=playback_policy.model_id, @@ -268,7 +268,7 @@ async def test_production_server_record_and_replay(production_server, frozen_lak assert not policy.is_playback_mode(), "Should be in recording mode initially" # Create environments - envs = ep.make( + envs = await ep.make( "http://localhost:9500/mcp/", dataset=frozen_lake_dataset, model_id=policy.model_id, @@ -335,7 +335,7 @@ async def test_production_server_record_and_replay(production_server, frozen_lak assert playback_policy.is_playback_mode(), "Should be in playback mode" # Create new environments for playback - playback_envs = ep.make( + playback_envs = await ep.make( "http://localhost:9500/mcp/", dataset=frozen_lake_dataset, model_id=playback_policy.model_id, @@ -488,7 +488,7 @@ async def test_frozen_lake_step_by_step(conda_isolation_recording_file): ] # Create environment pointing to conda-isolated server - envs = ep.make( + envs = await ep.make( f"http://localhost:{port}/mcp/", dataset=test_dataset, model_id=policy.model_id, @@ -593,7 +593,7 @@ async def test_multi_environment_sessions(multi_env_dataset, multi_env_recording policy = create_frozen_lake_static_policy(action_sequence=["RIGHT", "RIGHT", "RIGHT", "DOWN", "DOWN", "DOWN"]) # Create multiple environments - envs = ep.make( + envs = await ep.make( f"http://localhost:{server.port}/mcp/", dataset=multi_env_dataset, model_id=policy.model_id, @@ -1071,7 +1071,7 @@ async def test_fireworks_multi_environment_sessions(multi_env_dataset, fireworks assert playback_policy.is_playback_mode(), "Should be in playback mode in CI" # Create environments for playback - playback_envs = ep.make( + playback_envs = await ep.make( "http://localhost:9500/mcp/", dataset=multi_env_dataset, model_id=playback_policy.model_id, @@ -1113,7 +1113,7 @@ async def test_fireworks_multi_environment_sessions(multi_env_dataset, fireworks assert not policy.is_playback_mode(), "Should be in recording mode initially" # Create multiple environments - envs = ep.make( + envs = await ep.make( f"http://localhost:{server.port}/mcp/", dataset=multi_env_dataset, model_id=policy.model_id, @@ -1232,7 +1232,7 @@ async def test_control_plane_state_querying(multi_env_dataset): policy = create_frozen_lake_static_policy(action_sequence=["RIGHT", "DOWN"]) # Create environments - envs = ep.make( + envs = await ep.make( f"http://localhost:{server.port}/mcp/", dataset=multi_env_dataset[:2], # Use only 2 environments for faster testing model_id=policy.model_id, @@ -1283,7 +1283,7 @@ async def _run_playback_only(recording_file: str, dataset: List[Dict], server_ur assert playback_policy.is_playback_mode(), "Should be in playback mode in CI" # Create environments for playback - playback_envs = ep.make( + playback_envs = await ep.make( server_url, dataset=dataset, model_id=playback_policy.model_id, diff --git a/examples/lunar_lander_mcp/test_lunar_lander_conda.py b/examples/lunar_lander_mcp/test_lunar_lander_conda.py index 98d3c491..9e88f92b 100644 --- a/examples/lunar_lander_mcp/test_lunar_lander_conda.py +++ b/examples/lunar_lander_mcp/test_lunar_lander_conda.py @@ -119,7 +119,7 @@ async def test_lunar_lander_with_conda_isolation(): ] # Configure for MCP environment - envs = ep.make("http://localhost:9004/mcp", dataset=dataset) + envs = await ep.make("http://localhost:9004/mcp", dataset=dataset) # Simple policy that takes random actions class RandomLunarLanderPolicy: diff --git a/examples/lunar_lander_mcp/tests/test_lunar_lander_e2e.py b/examples/lunar_lander_mcp/tests/test_lunar_lander_e2e.py index 723b68bb..7f187cac 100644 --- a/examples/lunar_lander_mcp/tests/test_lunar_lander_e2e.py +++ b/examples/lunar_lander_mcp/tests/test_lunar_lander_e2e.py @@ -235,7 +235,7 @@ async def test_production_server_record_and_replay(production_server, lunar_land assert playback_policy.is_playback_mode(), "Should be in playback mode in CI" # Create environments for playback - playback_envs = ep.make( + playback_envs = await ep.make( "http://localhost:9500/mcp/", dataset=lunar_lander_dataset, model_id=playback_policy.model_id, @@ -271,7 +271,7 @@ async def test_production_server_record_and_replay(production_server, lunar_land assert not policy.is_playback_mode(), "Should be in recording mode initially" # Create environments - envs = ep.make( + envs = await ep.make( "http://localhost:9500/mcp/", dataset=lunar_lander_dataset, model_id=policy.model_id, @@ -332,7 +332,7 @@ async def test_production_server_record_and_replay(production_server, lunar_land assert playback_policy.is_playback_mode(), "Should be in playback mode" # Create new environments for playback - playback_envs = ep.make( + playback_envs = await ep.make( "http://localhost:9500/mcp/", dataset=lunar_lander_dataset, model_id=playback_policy.model_id, @@ -487,7 +487,7 @@ async def test_lunar_lander_step_by_step(conda_isolation_recording_file): ] # Create environment pointing to conda-isolated server - envs = ep.make( + envs = await ep.make( f"http://localhost:{port}/mcp/", dataset=test_dataset, model_id=policy.model_id, @@ -626,7 +626,7 @@ async def test_multi_environment_sessions(multi_env_dataset, multi_env_recording policy = create_lunar_lander_static_policy() # Create multiple environments - envs = ep.make( + envs = await ep.make( f"http://localhost:{server.port}/mcp/", dataset=multi_env_dataset, model_id=policy.model_id, @@ -1076,7 +1076,7 @@ async def test_fireworks_multi_environment_sessions(multi_env_dataset, fireworks assert playback_policy.is_playback_mode(), "Should be in playback mode in CI" # Create environments for playback - playback_envs = ep.make( + playback_envs = await ep.make( "http://localhost:9500/mcp/", dataset=multi_env_dataset, model_id=playback_policy.model_id, @@ -1118,7 +1118,7 @@ async def test_fireworks_multi_environment_sessions(multi_env_dataset, fireworks assert not policy.is_playback_mode(), "Should be in recording mode initially" # Create multiple environments - envs = ep.make( + envs = await ep.make( f"http://localhost:{server.port}/mcp/", dataset=multi_env_dataset, model_id=policy.model_id, @@ -1228,7 +1228,7 @@ async def test_control_plane_state_querying(multi_env_dataset): policy = create_lunar_lander_static_policy(action_sequence=["FIRE_MAIN", "FIRE_LEFT"]) # Create environments - envs = ep.make( + envs = await ep.make( f"http://localhost:{server.port}/mcp/", dataset=multi_env_dataset[:2], # Use only 2 environments for faster testing model_id=policy.model_id, diff --git a/examples/tau2_mcp/tests/test_tau2_e2e.py b/examples/tau2_mcp/tests/test_tau2_e2e.py index e33dd788..f31584dd 100644 --- a/examples/tau2_mcp/tests/test_tau2_e2e.py +++ b/examples/tau2_mcp/tests/test_tau2_e2e.py @@ -105,9 +105,9 @@ def start(self) -> None: # Set environment for server env = os.environ.copy() env["PORT"] = str(self.port) - if 'PYTHONPATH' not in env: - env['PYTHONPATH'] = '' - env['PYTHONPATH'] += os.pathsep + str(self.base_dir) + if "PYTHONPATH" not in env: + env["PYTHONPATH"] = "" + env["PYTHONPATH"] += os.pathsep + str(self.base_dir) # Start server process (no domain argument needed for tau2_mcp server) cmd = ["python", self.server_script, "--port", str(self.port)] @@ -886,7 +886,7 @@ async def test_fireworks_multi_airline_environment_sessions( assert playback_policy.is_playback_mode(), "Should be in playback mode in CI" # Create environments for playback - playback_envs = ep.make( + playback_envs = await ep.make( "http://localhost:9500/mcp/", dataset=multi_env_airline_dataset, model_id=playback_policy.model_id, @@ -928,7 +928,7 @@ async def test_fireworks_multi_airline_environment_sessions( assert not policy.is_playback_mode(), "Should be in recording mode initially" # Create multiple environments - envs = ep.make( + envs = await ep.make( f"http://localhost:{server.port}/mcp/", dataset=multi_env_airline_dataset, model_id=policy.model_id, @@ -1029,7 +1029,7 @@ async def test_entire_airline_dataset(multi_env_airline_full_dataset, fireworks_ assert playback_policy.is_playback_mode(), "Should be in playback mode in CI" # Create environments for playback - playback_envs = ep.make( + playback_envs = await ep.make( "http://localhost:9500/mcp/", dataset=multi_env_airline_full_dataset, model_id=playback_policy.model_id, @@ -1076,7 +1076,7 @@ async def test_entire_airline_dataset(multi_env_airline_full_dataset, fireworks_ assert not policy.is_playback_mode(), "Should be in recording mode initially" # Create multiple environments - envs = ep.make( + envs = await ep.make( f"http://localhost:{server.port}/mcp/", dataset=multi_env_airline_full_dataset, model_id=policy.model_id, @@ -1425,7 +1425,7 @@ async def test_fireworks_multi_mock_environment_sessions( server = _create_test_server(8021, domain="mock") # Use unique port for mock try: - envs = ep.make( + envs = await ep.make( f"http://localhost:{server.port}/mcp/", dataset=multi_env_mock_dataset, model_id=playback_policy.model_id, @@ -1469,7 +1469,7 @@ async def test_fireworks_multi_mock_environment_sessions( assert not policy.is_playback_mode(), "Should be in recording mode initially" # Create multiple environments - envs = ep.make( + envs = await ep.make( f"http://localhost:{server.port}/mcp/", dataset=multi_env_mock_dataset, model_id=policy.model_id, @@ -1559,7 +1559,7 @@ async def test_fireworks_multi_retail_environment_sessions( server = _create_test_server(8022, domain="retail") # Use unique port for retail try: - envs = ep.make( + envs = await ep.make( f"http://localhost:{server.port}/mcp/", dataset=multi_env_retail_dataset, model_id=playback_policy.model_id, @@ -1603,7 +1603,7 @@ async def test_fireworks_multi_retail_environment_sessions( assert not policy.is_playback_mode(), "Should be in recording mode initially" # Create multiple environments - envs = ep.make( + envs = await ep.make( f"http://localhost:{server.port}/mcp/", dataset=multi_env_retail_dataset, model_id=policy.model_id, diff --git a/examples/taxi_mcp_complete/local_testing/test_north_star.py b/examples/taxi_mcp_complete/local_testing/test_north_star.py index 3721b4e7..b5d84006 100644 --- a/examples/taxi_mcp_complete/local_testing/test_north_star.py +++ b/examples/taxi_mcp_complete/local_testing/test_north_star.py @@ -57,7 +57,7 @@ async def test_north_star_interface(): print(f"โœ… Policy created in {'playback' if policy.is_playback_mode() else 'live'} mode") # Create environments - envs = ep.make("http://localhost:8000/mcp/", dataset=dataset, model_id=policy.model_id) + envs = await ep.make("http://localhost:8000/mcp/", dataset=dataset, model_id=policy.model_id) print("โœ… MCP environments created successfully") # Run rollout - same API for both modes! diff --git a/examples/taxi_mcp_complete/tests/test_taxi_e2e.py b/examples/taxi_mcp_complete/tests/test_taxi_e2e.py index 5fd24fcc..337f9fd1 100644 --- a/examples/taxi_mcp_complete/tests/test_taxi_e2e.py +++ b/examples/taxi_mcp_complete/tests/test_taxi_e2e.py @@ -165,7 +165,7 @@ async def test_production_server_record_and_replay(production_server, taxi_datas assert not policy.is_playback_mode(), "Should be in recording mode initially" # Create environments - envs = ep.make("http://localhost:9500/mcp/", dataset=taxi_dataset, model_id=policy.model_id) + envs = await ep.make("http://localhost:9500/mcp/", dataset=taxi_dataset, model_id=policy.model_id) # Record evaluation rows (Taxi typically needs more steps) start_time = time.time() @@ -196,7 +196,7 @@ async def test_production_server_record_and_replay(production_server, taxi_datas assert playback_policy.is_playback_mode(), "Should be in playback mode" # Create new environments for playback - playback_envs = ep.make( + playback_envs = await ep.make( "http://localhost:9500/mcp/", dataset=taxi_dataset, model_id=playback_policy.model_id, @@ -242,7 +242,7 @@ async def test_simulation_server_record_and_replay(simulation_server, taxi_datas ) # Create environments pointing to simulation server - envs = ep.make("http://localhost:9501/mcp/", dataset=taxi_dataset, model_id=policy.model_id) + envs = await ep.make("http://localhost:9501/mcp/", dataset=taxi_dataset, model_id=policy.model_id) # Record evaluation rows start_time = time.time() @@ -266,7 +266,7 @@ async def test_simulation_server_record_and_replay(simulation_server, taxi_datas ) # Create new environments for playback - playback_envs = ep.make( + playback_envs = await ep.make( "http://localhost:9501/mcp/", dataset=taxi_dataset, model_id=playback_policy.model_id, diff --git a/local_evals/model_comparison_eval.ipynb b/local_evals/model_comparison_eval.ipynb index e19b554d..e36dbe6c 100644 --- a/local_evals/model_comparison_eval.ipynb +++ b/local_evals/model_comparison_eval.ipynb @@ -401,7 +401,7 @@ " with MCPServerManager(\"../examples/tau2_mcp/server.py\", port=8000, domain=\"airline\") as server:\n", " policy = model_info[\"policy\"]\n", " \n", - " envs = rk.make(\n", + " envs = await rk.make(\n", " \"http://localhost:8000/mcp/\",\n", " dataset=dataset, \n", " model_id=policy.model_id,\n", diff --git a/tests/test_parallel_rollouts.py b/tests/test_parallel_rollouts.py index ef5c83a6..ae12a9c0 100644 --- a/tests/test_parallel_rollouts.py +++ b/tests/test_parallel_rollouts.py @@ -138,7 +138,7 @@ async def _test_seed_handling_and_type_compatibility_impl(): ) # 3. Test that environments are created with proper seed isolation - envs = ep.make("http://127.0.0.1:8001/mcp/", dataset=dataset) + envs = await ep.make("http://127.0.0.1:8001/mcp/", dataset=dataset) # Verify we have the right number of environments assert len(envs.sessions) == len(test_seeds), f"Expected {len(test_seeds)} sessions, got {len(envs.sessions)}" @@ -273,7 +273,7 @@ async def _run_simplified_compatibility_test(): ) # This should work even without a server (just creates session objects) - envs = ep.make("http://127.0.0.1:8001/mcp/", dataset=dataset) + envs = await ep.make("http://127.0.0.1:8001/mcp/", dataset=dataset) assert len(envs.sessions) == len(test_seeds) print("โœ… Environment creation works") diff --git a/tests/test_rollout_control_plane_integration.py b/tests/test_rollout_control_plane_integration.py index 2f35bc2b..667b74cb 100644 --- a/tests/test_rollout_control_plane_integration.py +++ b/tests/test_rollout_control_plane_integration.py @@ -489,7 +489,7 @@ async def test_rollout_creates_envs_from_url(self): policy = MockPolicy(["right"]) with ( - patch("eval_protocol.mcp_env.make") as mock_make, + patch("eval_protocol.mcp_env.make", new_callable=AsyncMock) as mock_make, patch("eval_protocol.mcp_env.ExecutionManager") as MockManager, ): mock_env = MagicMock() @@ -512,7 +512,15 @@ async def test_rollout_creates_envs_from_url(self): dataset=dataset, model_id="test_model", ) - manager_instance.execute_rollouts.assert_called_once_with(mock_env, policy, 5, None, 8) + + manager_instance.execute_rollouts.assert_called_once_with( + mock_make.return_value, + policy, + 5, + None, + 8, + ) + assert result == ["ok"] def test_control_plane_trajectory_serialization(self): diff --git a/tests/test_url_handling.py b/tests/test_url_handling.py index 21d661f0..fbd71b28 100644 --- a/tests/test_url_handling.py +++ b/tests/test_url_handling.py @@ -1,5 +1,4 @@ -import asyncio - +from unittest.mock import AsyncMock, patch import httpx import pytest from werkzeug.wrappers import Response @@ -7,31 +6,46 @@ import eval_protocol as ep -# Sync tests for the ep.make() function -def test_mcp_env_make_appends_trailing_slash(): +# Sync tests for the await ep.make() function +@pytest.mark.asyncio +async def test_mcp_env_make_appends_trailing_slash(): """ - Verify that ep.make() appends a trailing slash to the MCP server URL if it's missing. + Verify that await ep.make() appends a trailing slash to the MCP server URL if it's missing. This prevents 307 redirects that can break HTTP clients. """ base_url = "http://localhost:8000/mcp" corrected_url = "http://localhost:8000/mcp/" - # Use n and seeds to avoid needing a full dataset - envs = ep.make(base_url, n=1, seeds=[42]) + with patch( + "eval_protocol.mcp.client.connection.MCPConnectionManager.initialize_session", + new_callable=AsyncMock, + ) as mock_init: + mock_init.return_value = None + + envs = await ep.make(base_url, n=1, seeds=[42]) + + mock_init.assert_awaited_once() assert len(envs.sessions) == 1 - # The session's base_url should have the trailing slash assert envs.sessions[0].base_url == corrected_url -def test_mcp_env_make_keeps_existing_trailing_slash(): +@pytest.mark.asyncio +async def test_mcp_env_make_keeps_existing_trailing_slash(): """ - Verify that ep.make() does not add an extra slash if one is already present. + Verify that await ep.make() does not add an extra slash if one is already present. """ base_url = "http://localhost:8000/mcp/" - # Use n and seeds to avoid needing a full dataset - envs = ep.make(base_url, n=1, seeds=[42]) + with patch( + "eval_protocol.mcp.client.connection.MCPConnectionManager.initialize_session", + new_callable=AsyncMock, + ) as mock_init: + mock_init.return_value = None + + envs = await ep.make(base_url, n=1, seeds=[42]) + + mock_init.assert_awaited_once() assert len(envs.sessions) == 1 # The session's base_url should remain unchanged