diff --git a/eval_protocol/mcp/client/connection.py b/eval_protocol/mcp/client/connection.py index 64f80352..97943a23 100644 --- a/eval_protocol/mcp/client/connection.py +++ b/eval_protocol/mcp/client/connection.py @@ -16,6 +16,7 @@ from mcp.client.streamable_http import streamablehttp_client from ...types import MCPSession +from mcp.types import Implementation logger = logging.getLogger(__name__) @@ -50,19 +51,16 @@ async def initialize_session(self, session: MCPSession) -> None: exit_stack = AsyncExitStack() - client_info = None - if session.seed is not None or (session.dataset_row and session.dataset_row.environment_context): - from mcp.types import Implementation - - client_info = Implementation(name="reward-kit", version="1.0.0", _extra={}) - if session.seed is not None: - client_info._extra["seed"] = session.seed - if session.dataset_row and session.dataset_row.environment_context: - client_info._extra["config"] = session.dataset_row.environment_context - if session.dataset_row and session.dataset_row.id: - client_info._extra["dataset_row_id"] = session.dataset_row.id - if session.model_id: - client_info._extra["model_id"] = session.model_id + client_info = Implementation(name="reward-kit", version="1.0.0", _extra={}) + client_info._extra["session_id"] = session.session_id + if session.seed is not None: + client_info._extra["seed"] = session.seed + if session.dataset_row and session.dataset_row.environment_context: + client_info._extra["config"] = session.dataset_row.environment_context + if session.dataset_row and session.dataset_row.id: + client_info._extra["dataset_row_id"] = session.dataset_row.id + if session.model_id: + client_info._extra["model_id"] = session.model_id read_stream, write_stream, _ = await exit_stack.enter_async_context( streamablehttp_client(session.base_url, terminate_on_close=True) @@ -77,32 +75,6 @@ async def initialize_session(self, session: MCPSession) -> None: session._mcp_session = mcp_session session._exit_stack = exit_stack - # 
Update session ID to match server's calculation (for control plane sync) - if client_info and hasattr(client_info, "_extra"): - extra_data = client_info._extra - if extra_data and isinstance(extra_data, dict): - - seed_value = extra_data.get("seed") - config_value = extra_data.get("config", {}) - dataset_row_id_value = extra_data.get("dataset_row_id") - model_id_value = extra_data.get("model_id") - - stable_data = { - "seed": seed_value, - "config": config_value, - "dataset_row_id": dataset_row_id_value, - "model_id": model_id_value, - "name": client_info.name, - "version": client_info.version, - } - - stable_str = json.dumps(stable_data, sort_keys=True) - server_session_id = hashlib.md5(stable_str.encode()).hexdigest() - - # Update the session ID to match what the server generated - session.session_id = server_session_id - logger.info(f"Updated session ID to match server: {server_session_id}") - # PRE-WARM: Discover and cache tools immediately after session initialization # This prevents concurrent list_tools() calls later await self._prewarm_tools_cache(session) diff --git a/eval_protocol/mcp/execution/manager.py b/eval_protocol/mcp/execution/manager.py index ab461f8a..a5f7d706 100644 --- a/eval_protocol/mcp/execution/manager.py +++ b/eval_protocol/mcp/execution/manager.py @@ -103,7 +103,7 @@ async def _execute_with_semaphore(idx): ) tasks = [_execute_with_semaphore(i) for i in range(envs.n)] - trajectories = await asyncio.gather(*tasks) + trajectories = await asyncio.gather(*tasks, return_exceptions=True) # Calculate durations total_duration = time.time() - start_time @@ -141,6 +141,14 @@ async def _execute_with_semaphore(idx): evaluation_rows = [EvaluationRow(messages=[], input_metadata=InputMetadata()) for _ in trajectories] for idx, trajectory in enumerate(trajectories): + evaluation_rows[idx].input_metadata.row_id = envs.dataset_rows[idx].id + evaluation_rows[idx].input_metadata.dataset_info = asdict(envs.dataset_rows[idx]) + if isinstance(trajectory, 
Exception): + logger.error(f"๐Ÿšจ Error in rollout id {envs.dataset_rows[idx].id}: {trajectory}") + evaluation_rows[idx].input_metadata.session_data["error"] = True + evaluation_rows[idx].input_metadata.session_data["error_message"] = str(trajectory) + continue + # Handle multimodal content by extracting text from complex content structures messages = [] for msg in trajectory.conversation_history: diff --git a/eval_protocol/mcp/mcpgym.py b/eval_protocol/mcp/mcpgym.py index 2bfa7d7b..f1df1328 100644 --- a/eval_protocol/mcp/mcpgym.py +++ b/eval_protocol/mcp/mcpgym.py @@ -146,7 +146,12 @@ def _get_session_id(self, ctx: Context) -> str: print(f"๐Ÿ” _get_session_id: extra_data type: {type(extra_data)}") if extra_data and isinstance(extra_data, dict): - # Create a stable session ID based on seed and other config + # use the client generated session id + if "session_id" in extra_data: + print(f"๐Ÿ” _get_session_id: using client generated session_id: {extra_data['session_id']}") + return extra_data["session_id"] + + # fallback to create a stable session ID based on seed and other config seed_value = extra_data.get("seed") config_value = extra_data.get("config", {}) dataset_row_id_value = extra_data.get("dataset_row_id") diff --git a/eval_protocol/mcp/session/manager.py b/eval_protocol/mcp/session/manager.py index be413b46..5bd36e5a 100644 --- a/eval_protocol/mcp/session/manager.py +++ b/eval_protocol/mcp/session/manager.py @@ -219,6 +219,9 @@ def _default_formatter(self, template: str, obs: Any, context: Dict) -> Union[st async def close(self): """Closes all MCP sessions.""" + print(f"๐Ÿงน Resetting {self.n} MCP sessions in MCP server...") + cleanup_tasks = [self.connection_manager.reset_session(session) for session in self.sessions] + await asyncio.gather(*cleanup_tasks) print(f"๐Ÿงน Closing {self.n} MCP sessions...") tasks = [self.connection_manager.close_session(session) for session in self.sessions] await asyncio.gather(*tasks) diff --git a/eval_protocol/mcp_env.py 
def gen_session_id(dataset_row: "DatasetRow", model_id: str) -> str:
    """
    Generate a deterministic session ID for a dataset row / model pairing.

    The ID is the MD5 hex digest of a canonical JSON encoding of the fields
    that uniquely identify a rollout (seed, environment config, dataset row
    id, and model id), so identical inputs always map to the same session ID
    across processes. MD5 is used purely as a stable, non-cryptographic
    fingerprint here, not for security.

    Args:
        dataset_row: Row providing ``seed``, ``environment_context`` and ``id``.
        model_id: Identifier of the model used for this rollout.

    Returns:
        A 32-character lowercase hex string, stable across runs.
    """
    stable_data = {
        "seed": dataset_row.seed,
        "config": dataset_row.environment_context,
        "dataset_row_id": dataset_row.id,
        "model_id": model_id,
    }

    # sort_keys gives a canonical encoding; default=str prevents a TypeError
    # when the environment config carries non-JSON-serializable values
    # (serializable inputs hash exactly as before).
    stable_str = json.dumps(stable_data, sort_keys=True, default=str)

    return hashlib.md5(stable_str.encode()).hexdigest()
""" - from typing import Any, Dict, List from eval_protocol.models import EvaluateResult, EvaluationRow, Message, InputMetadata, CompletionParams, MetricResult @@ -18,7 +17,7 @@ def frozen_lake_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Evaluation Convert entries from frozen lake dataset to EvaluationRow objects. """ rows = [] - + for row in data: eval_row = EvaluationRow( messages=[Message(role="system", content=row["system_prompt"])], @@ -27,14 +26,15 @@ def frozen_lake_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Evaluation dataset_info={ "environment_context": row["environment_context"], "user_prompt_template": row["user_prompt_template"], - } - ) + }, + ), ) - + rows.append(eval_row) - + return rows + @evaluation_test( input_dataset=["tests/pytest/data/frozen_lake_dataset.jsonl"], dataset_adapter=frozen_lake_to_evaluation_row, @@ -50,13 +50,13 @@ def frozen_lake_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Evaluation def test_frozen_lake_evaluation(row: EvaluationRow) -> EvaluationRow: """ Test frozen lake evaluation using the pytest framework. - + This test evaluates how well the model can navigate the FrozenLake environment by checking if it successfully reaches the goal while avoiding holes. - + Args: row: EvaluationRow object from frozen lake dataset - + Returns: EvaluationRow object with evaluation results """ @@ -71,5 +71,5 @@ def test_frozen_lake_evaluation(row: EvaluationRow) -> EvaluationRow: score=score, reason=reason, ) - + return row