Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 11 additions & 39 deletions eval_protocol/mcp/client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from mcp.client.streamable_http import streamablehttp_client

from ...types import MCPSession
from mcp.types import Implementation

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -50,19 +51,16 @@ async def initialize_session(self, session: MCPSession) -> None:

exit_stack = AsyncExitStack()

client_info = None
if session.seed is not None or (session.dataset_row and session.dataset_row.environment_context):
from mcp.types import Implementation

client_info = Implementation(name="reward-kit", version="1.0.0", _extra={})
if session.seed is not None:
client_info._extra["seed"] = session.seed
if session.dataset_row and session.dataset_row.environment_context:
client_info._extra["config"] = session.dataset_row.environment_context
if session.dataset_row and session.dataset_row.id:
client_info._extra["dataset_row_id"] = session.dataset_row.id
if session.model_id:
client_info._extra["model_id"] = session.model_id
client_info = Implementation(name="reward-kit", version="1.0.0", _extra={})
client_info._extra["session_id"] = session.session_id
if session.seed is not None:
client_info._extra["seed"] = session.seed
if session.dataset_row and session.dataset_row.environment_context:
client_info._extra["config"] = session.dataset_row.environment_context
if session.dataset_row and session.dataset_row.id:
client_info._extra["dataset_row_id"] = session.dataset_row.id
if session.model_id:
client_info._extra["model_id"] = session.model_id

read_stream, write_stream, _ = await exit_stack.enter_async_context(
streamablehttp_client(session.base_url, terminate_on_close=True)
Expand All @@ -77,32 +75,6 @@ async def initialize_session(self, session: MCPSession) -> None:
session._mcp_session = mcp_session
session._exit_stack = exit_stack

# Update session ID to match server's calculation (for control plane sync)
if client_info and hasattr(client_info, "_extra"):
extra_data = client_info._extra
if extra_data and isinstance(extra_data, dict):

seed_value = extra_data.get("seed")
config_value = extra_data.get("config", {})
dataset_row_id_value = extra_data.get("dataset_row_id")
model_id_value = extra_data.get("model_id")

stable_data = {
"seed": seed_value,
"config": config_value,
"dataset_row_id": dataset_row_id_value,
"model_id": model_id_value,
"name": client_info.name,
"version": client_info.version,
}

stable_str = json.dumps(stable_data, sort_keys=True)
server_session_id = hashlib.md5(stable_str.encode()).hexdigest()

# Update the session ID to match what the server generated
session.session_id = server_session_id
logger.info(f"Updated session ID to match server: {server_session_id}")

# PRE-WARM: Discover and cache tools immediately after session initialization
# This prevents concurrent list_tools() calls later
await self._prewarm_tools_cache(session)
Expand Down
469 changes: 249 additions & 220 deletions eval_protocol/mcp/execution/manager.py

Large diffs are not rendered by default.

7 changes: 6 additions & 1 deletion eval_protocol/mcp/mcpgym.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,12 @@ def _get_session_id(self, ctx: Context) -> str:
print(f"🔍 _get_session_id: extra_data type: {type(extra_data)}")

if extra_data and isinstance(extra_data, dict):
# Create a stable session ID based on seed and other config
# use the client generated session id
if "session_id" in extra_data:
print(f"🔍 _get_session_id: using client generated session_id: {extra_data['session_id']}")
return extra_data["session_id"]

# fallback to create a stable session ID based on seed and other config
seed_value = extra_data.get("seed")
config_value = extra_data.get("config", {})
dataset_row_id_value = extra_data.get("dataset_row_id")
Expand Down
3 changes: 3 additions & 0 deletions eval_protocol/mcp/session/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,9 @@ def _default_formatter(self, template: str, obs: Any, context: Dict) -> Union[st

async def close(self):
    """Closes all MCP sessions.

    First asks the server to reset every session (best-effort), then closes
    the client-side connections. The reset pass uses
    ``return_exceptions=True`` so that a single failed reset cannot abort
    cleanup and leak the remaining open sessions.
    """
    print(f"🧹 Resetting {self.n} MCP sessions in MCP server...")
    cleanup_tasks = [self.connection_manager.reset_session(session) for session in self.sessions]
    # return_exceptions=True: a failing reset must not prevent the
    # close_session pass below from running for every session.
    reset_results = await asyncio.gather(*cleanup_tasks, return_exceptions=True)
    for result in reset_results:
        if isinstance(result, Exception):
            print(f"⚠️ Failed to reset an MCP session: {result}")
    print(f"🧹 Closing {self.n} MCP sessions...")
    tasks = [self.connection_manager.close_session(session) for session in self.sessions]
    await asyncio.gather(*tasks)
Expand Down
31 changes: 29 additions & 2 deletions eval_protocol/mcp_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,34 @@
from .mcp.session.manager import GeneralMCPVectorEnv
from .models import EvaluationRow
from .types import DatasetRow, MCPSession, MCPToolCall
import asyncio
import hashlib
import json

logger = logging.getLogger(__name__)


def gen_session_id(dataset_row: DatasetRow, model_id: str) -> str:
    """Derive a deterministic session ID for *dataset_row* under *model_id*.

    The ID is the MD5 hex digest of a canonical (``sort_keys=True``) JSON
    encoding of the row's seed, environment context, row id and the model id,
    so identical inputs always map to the same session ID.
    """
    stable_data = {
        "seed": dataset_row.seed,
        "config": dataset_row.environment_context,
        "dataset_row_id": dataset_row.id,
        "model_id": model_id,
    }
    return hashlib.md5(json.dumps(stable_data, sort_keys=True).encode()).hexdigest()


async def reset_mcp_sessions(envs: GeneralMCPVectorEnv):
"""
Reset mcp server sessions
Expand Down Expand Up @@ -162,9 +186,10 @@ async def make(

dataset_rows.append(dataset_row)

session_id = gen_session_id(dataset_row, model_id)
# Create MCP session
session = MCPSession(
session_id=dataset_row.id,
session_id=session_id,
base_url=base_url,
seed=dataset_row.seed,
model_id=model_id,
Expand Down Expand Up @@ -198,9 +223,11 @@ async def make(
)
dataset_rows.append(dataset_row)

session_id = gen_session_id(dataset_row, model_id)

# Create MCP session
session = MCPSession(
session_id=f"session_{i}",
session_id=session_id,
base_url=base_url,
seed=seeds[i],
model_id=model_id,
Expand Down
20 changes: 20 additions & 0 deletions eval_protocol/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,21 @@ class EvalMetadata(BaseModel):
passed: Optional[bool] = Field(None, description="Whether the evaluation passed based on the threshold")


class RolloutStatus(BaseModel):
    """Status of the rollout.

    Values:
        running: Unfinished rollout which is still in progress.
        finished: Rollout finished successfully.
        error: Rollout failed.
        stopped: Rollout terminated unexpectedly (e.g. max step, control
            plane signal, user stop).
    """

    # Defaults to "finished" so rows created without an explicit status
    # parse as successfully completed rollouts.
    status: Literal["running", "finished", "error", "stopped"] = Field(
        "finished", description="Status of the rollout."
    )
    error_message: Optional[str] = Field(None, description="Error message if the rollout failed.")


class EvaluationRow(BaseModel):
"""
Unified data structure for a single evaluation unit that contains messages,
Expand All @@ -244,6 +259,11 @@ class EvaluationRow(BaseModel):
description="Metadata related to the input (dataset info, model config, session data, etc.).",
)

rollout_status: RolloutStatus = Field(
default_factory=RolloutStatus,
description="The status of the rollout.",
)

# Ground truth reference (moved from EvaluateResult to top level)
ground_truth: Optional[str] = Field(
default=None, description="Optional ground truth reference for this evaluation."
Expand Down
1 change: 1 addition & 0 deletions eval_protocol/types/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class TerminationReason(str, Enum):
MAX_STEPS = "max_steps"
CONTROL_PLANE_SIGNAL = "control_plane_signal"
USER_STOP = "user_stop"
ERROR = "error"


@dataclass
Expand Down
20 changes: 10 additions & 10 deletions tests/pytest/test_frozen_lake.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
similar to the test_frozen_lake_e2e test but integrated with the pytest evaluation system.
"""


from typing import Any, Dict, List

from eval_protocol.models import EvaluateResult, EvaluationRow, Message, InputMetadata, CompletionParams, MetricResult
Expand All @@ -18,7 +17,7 @@ def frozen_lake_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Evaluation
Convert entries from frozen lake dataset to EvaluationRow objects.
"""
rows = []

for row in data:
eval_row = EvaluationRow(
messages=[Message(role="system", content=row["system_prompt"])],
Expand All @@ -27,14 +26,15 @@ def frozen_lake_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Evaluation
dataset_info={
"environment_context": row["environment_context"],
"user_prompt_template": row["user_prompt_template"],
}
)
},
),
)

rows.append(eval_row)

return rows


@evaluation_test(
input_dataset=["tests/pytest/data/frozen_lake_dataset.jsonl"],
dataset_adapter=frozen_lake_to_evaluation_row,
Expand All @@ -50,13 +50,13 @@ def frozen_lake_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Evaluation
def test_frozen_lake_evaluation(row: EvaluationRow) -> EvaluationRow:
"""
Test frozen lake evaluation using the pytest framework.

This test evaluates how well the model can navigate the FrozenLake environment
by checking if it successfully reaches the goal while avoiding holes.

Args:
row: EvaluationRow object from frozen lake dataset

Returns:
EvaluationRow object with evaluation results
"""
Expand All @@ -71,5 +71,5 @@ def test_frozen_lake_evaluation(row: EvaluationRow) -> EvaluationRow:
score=score,
reason=reason,
)

return row
2 changes: 1 addition & 1 deletion tests/test_rollout_control_plane_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ def mock_step_side_effect(env_index, tool_call):
assert final_cp_step["step"] == 2, "Should record final step"

# Validate policy interaction
assert policy.step_count == 3, "Policy should have been called 3 times"
assert policy.step_count == 4, "Policy should have been called 4 times"

@pytest.mark.asyncio
async def test_rollout_trajectory_recording_with_control_plane(self):
Expand Down