Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 11 additions & 39 deletions eval_protocol/mcp/client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from mcp.client.streamable_http import streamablehttp_client

from ...types import MCPSession
from mcp.types import Implementation

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -50,19 +51,16 @@ async def initialize_session(self, session: MCPSession) -> None:

exit_stack = AsyncExitStack()

client_info = None
if session.seed is not None or (session.dataset_row and session.dataset_row.environment_context):
from mcp.types import Implementation

client_info = Implementation(name="reward-kit", version="1.0.0", _extra={})
if session.seed is not None:
client_info._extra["seed"] = session.seed
if session.dataset_row and session.dataset_row.environment_context:
client_info._extra["config"] = session.dataset_row.environment_context
if session.dataset_row and session.dataset_row.id:
client_info._extra["dataset_row_id"] = session.dataset_row.id
if session.model_id:
client_info._extra["model_id"] = session.model_id
client_info = Implementation(name="reward-kit", version="1.0.0", _extra={})
client_info._extra["session_id"] = session.session_id
if session.seed is not None:
client_info._extra["seed"] = session.seed
if session.dataset_row and session.dataset_row.environment_context:
client_info._extra["config"] = session.dataset_row.environment_context
if session.dataset_row and session.dataset_row.id:
client_info._extra["dataset_row_id"] = session.dataset_row.id
if session.model_id:
client_info._extra["model_id"] = session.model_id

read_stream, write_stream, _ = await exit_stack.enter_async_context(
streamablehttp_client(session.base_url, terminate_on_close=True)
Expand All @@ -77,32 +75,6 @@ async def initialize_session(self, session: MCPSession) -> None:
session._mcp_session = mcp_session
session._exit_stack = exit_stack

# Update session ID to match server's calculation (for control plane sync)
if client_info and hasattr(client_info, "_extra"):
extra_data = client_info._extra
if extra_data and isinstance(extra_data, dict):

seed_value = extra_data.get("seed")
config_value = extra_data.get("config", {})
dataset_row_id_value = extra_data.get("dataset_row_id")
model_id_value = extra_data.get("model_id")

stable_data = {
"seed": seed_value,
"config": config_value,
"dataset_row_id": dataset_row_id_value,
"model_id": model_id_value,
"name": client_info.name,
"version": client_info.version,
}

stable_str = json.dumps(stable_data, sort_keys=True)
server_session_id = hashlib.md5(stable_str.encode()).hexdigest()

# Update the session ID to match what the server generated
session.session_id = server_session_id
logger.info(f"Updated session ID to match server: {server_session_id}")

# PRE-WARM: Discover and cache tools immediately after session initialization
# This prevents concurrent list_tools() calls later
await self._prewarm_tools_cache(session)
Expand Down
10 changes: 9 additions & 1 deletion eval_protocol/mcp/execution/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ async def _execute_with_semaphore(idx):
)

tasks = [_execute_with_semaphore(i) for i in range(envs.n)]
trajectories = await asyncio.gather(*tasks)
trajectories = await asyncio.gather(*tasks, return_exceptions=True)

# Calculate durations
total_duration = time.time() - start_time
Expand Down Expand Up @@ -141,6 +141,14 @@ async def _execute_with_semaphore(idx):
evaluation_rows = [EvaluationRow(messages=[], input_metadata=InputMetadata()) for _ in trajectories]

for idx, trajectory in enumerate(trajectories):
evaluation_rows[idx].input_metadata.row_id = envs.dataset_rows[idx].id
evaluation_rows[idx].input_metadata.dataset_info = asdict(envs.dataset_rows[idx])
if isinstance(trajectory, Exception):
logger.error(f"🚨 Error in rollout id {envs.dataset_rows[idx].id}: {trajectory}")
evaluation_rows[idx].input_metadata.session_data["error"] = True
evaluation_rows[idx].input_metadata.session_data["error_message"] = str(trajectory)
continue

# Handle multimodal content by extracting text from complex content structures
messages = []
for msg in trajectory.conversation_history:
Expand Down
7 changes: 6 additions & 1 deletion eval_protocol/mcp/mcpgym.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,12 @@ def _get_session_id(self, ctx: Context) -> str:
print(f"🔍 _get_session_id: extra_data type: {type(extra_data)}")

if extra_data and isinstance(extra_data, dict):
# Create a stable session ID based on seed and other config
# Use the client-generated session ID when one was supplied.
if "session_id" in extra_data:
print(f"🔍 _get_session_id: using client generated session_id: {extra_data['session_id']}")
return extra_data["session_id"]

# Fall back to deriving a stable session ID from the seed and other config.
seed_value = extra_data.get("seed")
config_value = extra_data.get("config", {})
dataset_row_id_value = extra_data.get("dataset_row_id")
Expand Down
3 changes: 3 additions & 0 deletions eval_protocol/mcp/session/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,9 @@ def _default_formatter(self, template: str, obs: Any, context: Dict) -> Union[st

async def close(self):
    """Reset all server-side MCP sessions, then close the local connections.

    Runs in two concurrent phases, each fanned out with ``asyncio.gather``:
      1. ``reset_session`` for every session — tells the MCP server to tear
         down its per-session state before we disconnect.
      2. ``close_session`` for every session — releases the local
         connection resources.

    NOTE(review): ``asyncio.gather`` without ``return_exceptions=True``
    propagates the first exception, so a failed reset will skip the close
    phase entirely — confirm that is the intended behavior.
    """
    # Phase 1: ask the server to reset each session's state.
    print(f"🧹 Resetting {self.n} MCP sessions in MCP server...")
    cleanup_tasks = [self.connection_manager.reset_session(session) for session in self.sessions]
    await asyncio.gather(*cleanup_tasks)
    # Phase 2: close the local connections.
    print(f"🧹 Closing {self.n} MCP sessions...")
    tasks = [self.connection_manager.close_session(session) for session in self.sessions]
    await asyncio.gather(*tasks)
Expand Down
31 changes: 29 additions & 2 deletions eval_protocol/mcp_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,34 @@
from .mcp.session.manager import GeneralMCPVectorEnv
from .models import EvaluationRow
from .types import DatasetRow, MCPSession, MCPToolCall
import asyncio
import hashlib
import json

logger = logging.getLogger(__name__)


def gen_session_id(dataset_row: DatasetRow, model_id: str) -> str:
    """
    Derive a deterministic session ID for a dataset row / model pairing.

    The ID is the MD5 hex digest of a canonical (``sort_keys=True``) JSON
    encoding of the row's seed, environment context, row id, and the model
    id — so identical inputs always produce the same session ID, which lets
    client and server agree on it independently.

    Args:
        dataset_row: Row supplying ``seed``, ``environment_context`` and ``id``.
        model_id: Identifier of the model this session is created for.

    Returns:
        A 32-character hexadecimal MD5 digest string.
    """
    # Canonical payload. Key order here is irrelevant: sort_keys below fixes
    # the serialized ordering, which is what keeps the digest stable.
    stable_data = {
        "seed": dataset_row.seed,
        "config": dataset_row.environment_context,
        "dataset_row_id": dataset_row.id,
        "model_id": model_id,
    }
    stable_str = json.dumps(stable_data, sort_keys=True)
    # MD5 is used purely as a compact, stable fingerprint — not for security.
    return hashlib.md5(stable_str.encode()).hexdigest()


async def reset_mcp_sessions(envs: GeneralMCPVectorEnv):
"""
    Reset the MCP server-side state for every session in the vector environment.
Expand Down Expand Up @@ -162,9 +186,10 @@ async def make(

dataset_rows.append(dataset_row)

session_id = gen_session_id(dataset_row, model_id)
# Create MCP session
session = MCPSession(
session_id=dataset_row.id,
session_id=session_id,
base_url=base_url,
seed=dataset_row.seed,
model_id=model_id,
Expand Down Expand Up @@ -198,9 +223,11 @@ async def make(
)
dataset_rows.append(dataset_row)

session_id = gen_session_id(dataset_row, model_id)

# Create MCP session
session = MCPSession(
session_id=f"session_{i}",
session_id=session_id,
base_url=base_url,
seed=seeds[i],
model_id=model_id,
Expand Down
20 changes: 10 additions & 10 deletions tests/pytest/test_frozen_lake.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
similar to the test_frozen_lake_e2e test but integrated with the pytest evaluation system.
"""


from typing import Any, Dict, List

from eval_protocol.models import EvaluateResult, EvaluationRow, Message, InputMetadata, CompletionParams, MetricResult
Expand All @@ -18,7 +17,7 @@ def frozen_lake_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Evaluation
Convert entries from frozen lake dataset to EvaluationRow objects.
"""
rows = []

for row in data:
eval_row = EvaluationRow(
messages=[Message(role="system", content=row["system_prompt"])],
Expand All @@ -27,14 +26,15 @@ def frozen_lake_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Evaluation
dataset_info={
"environment_context": row["environment_context"],
"user_prompt_template": row["user_prompt_template"],
}
)
},
),
)

rows.append(eval_row)

return rows


@evaluation_test(
input_dataset=["tests/pytest/data/frozen_lake_dataset.jsonl"],
dataset_adapter=frozen_lake_to_evaluation_row,
Expand All @@ -50,13 +50,13 @@ def frozen_lake_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Evaluation
def test_frozen_lake_evaluation(row: EvaluationRow) -> EvaluationRow:
"""
Test frozen lake evaluation using the pytest framework.

This test evaluates how well the model can navigate the FrozenLake environment
by checking if it successfully reaches the goal while avoiding holes.

Args:
row: EvaluationRow object from frozen lake dataset

Returns:
EvaluationRow object with evaluation results
"""
Expand All @@ -71,5 +71,5 @@ def test_frozen_lake_evaluation(row: EvaluationRow) -> EvaluationRow:
score=score,
reason=reason,
)

return row
Loading