From cc38838d9a6f0ca50affb1e3038bbdfcd207f1a3 Mon Sep 17 00:00:00 2001 From: Yinghan Ma Date: Thu, 7 Aug 2025 14:11:18 -0700 Subject: [PATCH 1/8] gen session_id from client --- eval_protocol/mcp/client/connection.py | 50 ++++++-------------------- eval_protocol/mcp/mcpgym.py | 7 +++- eval_protocol/mcp/session/manager.py | 3 ++ eval_protocol/mcp_env.py | 31 ++++++++++++++-- tests/pytest/test_frozen_lake.py | 20 +++++------ 5 files changed, 59 insertions(+), 52 deletions(-) diff --git a/eval_protocol/mcp/client/connection.py b/eval_protocol/mcp/client/connection.py index 64f80352..97943a23 100644 --- a/eval_protocol/mcp/client/connection.py +++ b/eval_protocol/mcp/client/connection.py @@ -16,6 +16,7 @@ from mcp.client.streamable_http import streamablehttp_client from ...types import MCPSession +from mcp.types import Implementation logger = logging.getLogger(__name__) @@ -50,19 +51,16 @@ async def initialize_session(self, session: MCPSession) -> None: exit_stack = AsyncExitStack() - client_info = None - if session.seed is not None or (session.dataset_row and session.dataset_row.environment_context): - from mcp.types import Implementation - - client_info = Implementation(name="reward-kit", version="1.0.0", _extra={}) - if session.seed is not None: - client_info._extra["seed"] = session.seed - if session.dataset_row and session.dataset_row.environment_context: - client_info._extra["config"] = session.dataset_row.environment_context - if session.dataset_row and session.dataset_row.id: - client_info._extra["dataset_row_id"] = session.dataset_row.id - if session.model_id: - client_info._extra["model_id"] = session.model_id + client_info = Implementation(name="reward-kit", version="1.0.0", _extra={}) + client_info._extra["session_id"] = session.session_id + if session.seed is not None: + client_info._extra["seed"] = session.seed + if session.dataset_row and session.dataset_row.environment_context: + client_info._extra["config"] = session.dataset_row.environment_context + if 
session.dataset_row and session.dataset_row.id: + client_info._extra["dataset_row_id"] = session.dataset_row.id + if session.model_id: + client_info._extra["model_id"] = session.model_id read_stream, write_stream, _ = await exit_stack.enter_async_context( streamablehttp_client(session.base_url, terminate_on_close=True) @@ -77,32 +75,6 @@ async def initialize_session(self, session: MCPSession) -> None: session._mcp_session = mcp_session session._exit_stack = exit_stack - # Update session ID to match server's calculation (for control plane sync) - if client_info and hasattr(client_info, "_extra"): - extra_data = client_info._extra - if extra_data and isinstance(extra_data, dict): - - seed_value = extra_data.get("seed") - config_value = extra_data.get("config", {}) - dataset_row_id_value = extra_data.get("dataset_row_id") - model_id_value = extra_data.get("model_id") - - stable_data = { - "seed": seed_value, - "config": config_value, - "dataset_row_id": dataset_row_id_value, - "model_id": model_id_value, - "name": client_info.name, - "version": client_info.version, - } - - stable_str = json.dumps(stable_data, sort_keys=True) - server_session_id = hashlib.md5(stable_str.encode()).hexdigest() - - # Update the session ID to match what the server generated - session.session_id = server_session_id - logger.info(f"Updated session ID to match server: {server_session_id}") - # PRE-WARM: Discover and cache tools immediately after session initialization # This prevents concurrent list_tools() calls later await self._prewarm_tools_cache(session) diff --git a/eval_protocol/mcp/mcpgym.py b/eval_protocol/mcp/mcpgym.py index 2bfa7d7b..f1df1328 100644 --- a/eval_protocol/mcp/mcpgym.py +++ b/eval_protocol/mcp/mcpgym.py @@ -146,7 +146,12 @@ def _get_session_id(self, ctx: Context) -> str: print(f"๐Ÿ” _get_session_id: extra_data type: {type(extra_data)}") if extra_data and isinstance(extra_data, dict): - # Create a stable session ID based on seed and other config + # use the client 
generated session id + if "session_id" in extra_data: + print(f"๐Ÿ” _get_session_id: using client generated session_id: {extra_data['session_id']}") + return extra_data["session_id"] + + # fallback to create a stable session ID based on seed and other config seed_value = extra_data.get("seed") config_value = extra_data.get("config", {}) dataset_row_id_value = extra_data.get("dataset_row_id") diff --git a/eval_protocol/mcp/session/manager.py b/eval_protocol/mcp/session/manager.py index be413b46..5bd36e5a 100644 --- a/eval_protocol/mcp/session/manager.py +++ b/eval_protocol/mcp/session/manager.py @@ -219,6 +219,9 @@ def _default_formatter(self, template: str, obs: Any, context: Dict) -> Union[st async def close(self): """Closes all MCP sessions.""" + print(f"๐Ÿงน Resetting {self.n} MCP sessions in MCP server...") + cleanup_tasks = [self.connection_manager.reset_session(session) for session in self.sessions] + await asyncio.gather(*cleanup_tasks) print(f"๐Ÿงน Closing {self.n} MCP sessions...") tasks = [self.connection_manager.close_session(session) for session in self.sessions] await asyncio.gather(*tasks) diff --git a/eval_protocol/mcp_env.py b/eval_protocol/mcp_env.py index 1d330994..05ea8414 100644 --- a/eval_protocol/mcp_env.py +++ b/eval_protocol/mcp_env.py @@ -53,10 +53,34 @@ from .mcp.session.manager import GeneralMCPVectorEnv from .models import EvaluationRow from .types import DatasetRow, MCPSession, MCPToolCall +import asyncio +import hashlib +import json logger = logging.getLogger(__name__) +def gen_session_id(dataset_row: DatasetRow, model_id: str) -> str: + """ + Generate a session ID for a dataset row + """ + seed_value = dataset_row.seed + config_value = dataset_row.environment_context + dataset_row_id_value = dataset_row.id + model_id_value = model_id + + stable_data = { + "seed": seed_value, + "config": config_value, + "dataset_row_id": dataset_row_id_value, + "model_id": model_id_value, + } + + stable_str = json.dumps(stable_data, sort_keys=True) + 
+ return hashlib.md5(stable_str.encode()).hexdigest() + + async def reset_mcp_sessions(envs: GeneralMCPVectorEnv): """ Reset mcp server sessions @@ -162,9 +186,10 @@ async def make( dataset_rows.append(dataset_row) + session_id = gen_session_id(dataset_row, model_id) # Create MCP session session = MCPSession( - session_id=dataset_row.id, + session_id=session_id, base_url=base_url, seed=dataset_row.seed, model_id=model_id, @@ -198,9 +223,11 @@ async def make( ) dataset_rows.append(dataset_row) + session_id = gen_session_id(dataset_row, model_id) + # Create MCP session session = MCPSession( - session_id=f"session_{i}", + session_id=session_id, base_url=base_url, seed=seeds[i], model_id=model_id, diff --git a/tests/pytest/test_frozen_lake.py b/tests/pytest/test_frozen_lake.py index 69f0c400..76551920 100644 --- a/tests/pytest/test_frozen_lake.py +++ b/tests/pytest/test_frozen_lake.py @@ -5,7 +5,6 @@ similar to the test_frozen_lake_e2e test but integrated with the pytest evaluation system. """ - from typing import Any, Dict, List from eval_protocol.models import EvaluateResult, EvaluationRow, Message, InputMetadata, CompletionParams, MetricResult @@ -18,7 +17,7 @@ def frozen_lake_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Evaluation Convert entries from frozen lake dataset to EvaluationRow objects. 
""" rows = [] - + for row in data: eval_row = EvaluationRow( messages=[Message(role="system", content=row["system_prompt"])], @@ -27,14 +26,15 @@ def frozen_lake_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Evaluation dataset_info={ "environment_context": row["environment_context"], "user_prompt_template": row["user_prompt_template"], - } - ) + }, + ), ) - + rows.append(eval_row) - + return rows + @evaluation_test( input_dataset=["tests/pytest/data/frozen_lake_dataset.jsonl"], dataset_adapter=frozen_lake_to_evaluation_row, @@ -50,13 +50,13 @@ def frozen_lake_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Evaluation def test_frozen_lake_evaluation(row: EvaluationRow) -> EvaluationRow: """ Test frozen lake evaluation using the pytest framework. - + This test evaluates how well the model can navigate the FrozenLake environment by checking if it successfully reaches the goal while avoiding holes. - + Args: row: EvaluationRow object from frozen lake dataset - + Returns: EvaluationRow object with evaluation results """ @@ -71,5 +71,5 @@ def test_frozen_lake_evaluation(row: EvaluationRow) -> EvaluationRow: score=score, reason=reason, ) - + return row From b5d242c1a95821eef81970423bf35b8210f1d20e Mon Sep 17 00:00:00 2001 From: Yinghan Ma Date: Thu, 7 Aug 2025 22:12:43 -0700 Subject: [PATCH 2/8] catch error catch error add --- eval_protocol/mcp/execution/manager.py | 449 +++++++++++++------------ eval_protocol/types/types.py | 1 + 2 files changed, 230 insertions(+), 220 deletions(-) diff --git a/eval_protocol/mcp/execution/manager.py b/eval_protocol/mcp/execution/manager.py index ab461f8a..f9ec6c20 100644 --- a/eval_protocol/mcp/execution/manager.py +++ b/eval_protocol/mcp/execution/manager.py @@ -103,6 +103,7 @@ async def _execute_with_semaphore(idx): ) tasks = [_execute_with_semaphore(i) for i in range(envs.n)] + # exceptions should be try catched inside single _execute_rollout trajectories = await asyncio.gather(*tasks) # Calculate durations @@ -159,6 
+160,8 @@ async def _execute_with_semaphore(idx): messages.append(Message.model_validate(msg_dict)) evaluation_rows[idx].messages = messages + evaluation_rows[idx].input_metadata.row_id = envs.dataset_rows[idx].id + evaluation_rows[idx].input_metadata.dataset_info = asdict(envs.dataset_rows[idx]) evaluation_rows[idx].tools = shared_tool_schema evaluation_rows[idx].usage = trajectory.usage evaluation_rows[idx].input_metadata.completion_params = CompletionParams( @@ -209,241 +212,247 @@ async def _execute_rollout( "total_tokens": 0, }, ) + try: + current_observation, tool_schema = await envs.reset(session) + system_prompt = dataset_row.system_prompt + + # Record initial observation + trajectory.observations.append(current_observation) + + # Create user simulator for this rollout if configured in dataset + user_simulator = None + user_simulator_state = None + + # If user simulation is enabled, initial message is from the simulated user + if dataset_row.user_simulation and dataset_row.user_simulation.get("enabled", False): + user_simulator = UserSimulator( + instructions=dataset_row.user_simulation.get("system_prompt"), + llm=dataset_row.user_simulation.get("llm", "gpt-4.1"), + llm_args=dataset_row.user_simulation.get("llm_args", {"temperature": 0.0}), + ) - current_observation, tool_schema = await envs.reset(session) - system_prompt = dataset_row.system_prompt + # Get initial messages in tau2-bench format for user simulator + user_simulator_state = user_simulator.get_init_state() + user_message, user_simulator_state = user_simulator.generate_next_message( + AssistantMessage(role="assistant", content="Hi! 
How can I help you today?"), + user_simulator_state, + ) + current_observation = user_message.content if user_message.content else "" + + user_prompt = envs.format_user_prompt(rollout_idx, current_observation) + conversation_history = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ] + + usage_stats_list: List[CompletionUsage] = [] + + logger.info(f"๐ŸŽฏ Starting rollout {rollout_idx} in thread {threading.current_thread().name}") + + # Run rollout loop for this specific environment + step = 0 + rollout_end = False + + while step < steps and not trajectory.terminated: + turn_completed = False + info = {} + reward = 0.0 + observation = current_observation + tool_calls = [] + + if user_simulator and user_simulator_state: + # Get user simulator messages and find the last assistant message + user_simulator_messages = self._get_user_simulator_messages(conversation_history) + + # Last message was agent, simulated user response + if user_simulator_messages and isinstance(user_simulator_messages[-1], AssistantMessage): + # Generate user response using the simulator + user_message, user_simulator_state = user_simulator.generate_next_message( + user_simulator_messages[-1], user_simulator_state + ) + user_content = user_message.content if user_message.content else "" - # Record initial observation - trajectory.observations.append(current_observation) + user_prompt = envs.format_user_prompt(rollout_idx, user_content) + conversation_history.append({"role": "user", "content": user_prompt}) - # Create user simulator for this rollout if configured in dataset - user_simulator = None - user_simulator_state = None + # Check if user simulator signaled termination + if UserSimulator.is_stop(user_message): + trajectory.terminated = True + trajectory.termination_reason = TerminationReason.USER_STOP - # If user simulation is enabled, initial message is from the simulated user - if dataset_row.user_simulation and 
dataset_row.user_simulation.get("enabled", False): - user_simulator = UserSimulator( - instructions=dataset_row.user_simulation.get("system_prompt"), - llm=dataset_row.user_simulation.get("llm", "gpt-4.1"), - llm_args=dataset_row.user_simulation.get("llm_args", {"temperature": 0.0}), - ) + # In each turn: keep looping until assistant is ready to provide final response + while not turn_completed and not trajectory.terminated: + tool_calls, usage_stats = await policy(tool_schema, rollout_idx, conversation_history) - # Get initial messages in tau2-bench format for user simulator - user_simulator_state = user_simulator.get_init_state() - user_message, user_simulator_state = user_simulator.generate_next_message( - AssistantMessage(role="assistant", content="Hi! How can I help you today?"), - user_simulator_state, - ) - current_observation = user_message.content if user_message.content else "" - - user_prompt = envs.format_user_prompt(rollout_idx, current_observation) - conversation_history = [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": user_prompt}, - ] - - usage_stats_list: List[CompletionUsage] = [] - - logger.info(f"๐ŸŽฏ Starting rollout {rollout_idx} in thread {threading.current_thread().name}") - - # Run rollout loop for this specific environment - step = 0 - rollout_end = False - - while step < steps and not trajectory.terminated: - turn_completed = False - info = {} - reward = 0.0 - observation = current_observation - tool_calls = [] - - if user_simulator and user_simulator_state: - # Get user simulator messages and find the last assistant message - user_simulator_messages = self._get_user_simulator_messages(conversation_history) - - # Last message was agent, simulated user response - if user_simulator_messages and isinstance(user_simulator_messages[-1], AssistantMessage): - # Generate user response using the simulator - user_message, user_simulator_state = user_simulator.generate_next_message( - user_simulator_messages[-1], 
user_simulator_state - ) - user_content = user_message.content if user_message.content else "" - - user_prompt = envs.format_user_prompt(rollout_idx, user_content) - conversation_history.append({"role": "user", "content": user_prompt}) - - # Check if user simulator signaled termination - if UserSimulator.is_stop(user_message): - trajectory.terminated = True - trajectory.termination_reason = TerminationReason.USER_STOP - - # In each turn: keep looping until assistant is ready to provide final response - while not turn_completed and not trajectory.terminated: - tool_calls, usage_stats = await policy(tool_schema, rollout_idx, conversation_history) - - # If no tool call is generated, turn is finished - if len(tool_calls) == 1: - # If there's a user simulator, no tool call means the policy is ready to provide final response on this turn - if tool_calls[0].tool_name == "_no_tool_call" and user_simulator: - turn_completed = True - break - # If there's no user simulator, no tool call means policy failed and we should terminate the rollout - elif tool_calls[0].tool_name in ["_playback_terminate", "_no_tool_call"]: - trajectory.terminated = True - break - - # Execute each tool call sequentially - for tool_call in tool_calls: - - # Execute tool call for this environment - observation, reward, rollout_end, info = await envs.step(rollout_idx, tool_call) - - tool_response = envs.format_tool_response(observation) - - policy.add_tool_response( - rollout_idx, - tool_call, - tool_response, - conversation_history, - reward, - rollout_end, - info, - ) + # If no tool call is generated, turn is finished + if len(tool_calls) == 1: + # If there's a user simulator, no tool call means the policy is ready to provide final response on this turn + if tool_calls[0].tool_name == "_no_tool_call" and user_simulator: + turn_completed = True + break + # If there's no user simulator, no tool call means policy failed and we should terminate the rollout + elif tool_calls[0].tool_name in 
["_playback_terminate", "_no_tool_call"]: + trajectory.terminated = True + break - # Update trajectory with both data and control plane information - trajectory.observations.append(observation) - - # Record action (tool call) - action_str = f"{tool_call.tool_name}({tool_call.arguments})" - trajectory.actions.append(action_str) - - # Record control plane (reward/termination) - trajectory.rewards.append(reward) - trajectory.total_reward += reward - - # Non-user simulator step counter: each tool call is a step - if user_simulator is None: - step += 1 - trajectory.steps = step - - control_plane_step = { - "step": step - 1, - "reward": reward, - "terminated": rollout_end, - "info": info.get("control_plane", {}), - "tool_calls": [f"{tool_call.tool_name}({tool_call.arguments})"], - "num_tool_calls": 1, - } - conversation_history[-1]["control_plane_step"] = control_plane_step - trajectory.control_plane_steps.append(control_plane_step) - - # Log conversation state for playback if in recording mode - if recording_mode: - policy.log_conversation_state_for_playback(rollout_idx, step - 1, conversation_history) - - if rollout_end: - trajectory.terminated = True - trajectory.termination_reason = TerminationReason.CONTROL_PLANE_SIGNAL - break - elif step >= steps: - trajectory.terminated = True - trajectory.termination_reason = TerminationReason.MAX_STEPS - break - - # Update current observation for potential next turn - if observation is not None: - current_observation = observation - - # calc llm usage stats happened in this turn if there is aany - if usage_stats: - usage_stats_list.append(usage_stats) - - # With user simulator, increment step after an entire conversation step - if user_simulator is not None: - step += 1 - trajectory.steps = step - - # Enhanced trajectory recording with control plane info - # Create summary of all tool calls executed in this step - tool_calls_summary = [f"{tc.tool_name}({tc.arguments})" for tc in tool_calls] - - control_plane_step = { - "step": 
step - 1, - "reward": reward, - "terminated": rollout_end, - "info": info.get("control_plane", {}), - "tool_calls": tool_calls_summary, - "num_tool_calls": len(tool_calls), - } - conversation_history[-1]["control_plane_step"] = control_plane_step - trajectory.control_plane_steps.append(control_plane_step) - - # Log conversation state for playback if in recording mode - if recording_mode: - policy.log_conversation_state_for_playback(rollout_idx, step - 1, conversation_history) - - # Use control plane information for termination decision - if rollout_end: - trajectory.terminated = True - trajectory.termination_reason = TerminationReason.CONTROL_PLANE_SIGNAL - - # Add final control plane summary - trajectory.control_plane_summary.update( - { - "total_reward": trajectory.total_reward, - "termination_reason": trajectory.termination_reason, - "final_step": step - 1, - "control_plane_source": info.get("control_plane", {}), - } - ) + # Execute each tool call sequentially + for tool_call in tool_calls: - # Log final OpenAI conversation for terminated trajectories only - if openai_logger: - if conversation_history and len(conversation_history) > 0: - openai_logger( - { - "messages": conversation_history, - "metadata": { - "session_id": session.session_id, - "seed": session.seed, - "total_steps": trajectory.steps, - "total_reward": trajectory.total_reward, - "terminated": True, - "success": reward > 0, - "control_plane_summary": trajectory.control_plane_summary, - }, - } + # Execute tool call for this environment + observation, reward, rollout_end, info = await envs.step(rollout_idx, tool_call) + + tool_response = envs.format_tool_response(observation) + + policy.add_tool_response( + rollout_idx, + tool_call, + tool_response, + conversation_history, + reward, + rollout_end, + info, ) - logger.info( - f"๐Ÿ Rollout {rollout_idx} terminated at step {step} (reward: {trajectory.total_reward}) in thread {threading.current_thread().name}" - ) - break + # Update trajectory with 
both data and control plane information + trajectory.observations.append(observation) + + # Record action (tool call) + action_str = f"{tool_call.tool_name}({tool_call.arguments})" + trajectory.actions.append(action_str) + + # Record control plane (reward/termination) + trajectory.rewards.append(reward) + trajectory.total_reward += reward + + # Non-user simulator step counter: each tool call is a step + if user_simulator is None: + step += 1 + trajectory.steps = step + + control_plane_step = { + "step": step - 1, + "reward": reward, + "terminated": rollout_end, + "info": info.get("control_plane", {}), + "tool_calls": [f"{tool_call.tool_name}({tool_call.arguments})"], + "num_tool_calls": 1, + } + conversation_history[-1]["control_plane_step"] = control_plane_step + trajectory.control_plane_steps.append(control_plane_step) - # Progress logging - if step % 10 == 0: - logger.debug(f"Rollout {rollout_idx} step {step}, reward: {trajectory.total_reward:.2f}") + # Log conversation state for playback if in recording mode + if recording_mode: + policy.log_conversation_state_for_playback(rollout_idx, step - 1, conversation_history) - # Set termination reason if not already set (e.g., due to step limit) - if not trajectory.termination_reason and step >= steps: - trajectory.termination_reason = TerminationReason.MAX_STEPS + if rollout_end: + trajectory.terminated = True + trajectory.termination_reason = TerminationReason.CONTROL_PLANE_SIGNAL + break + elif step >= steps: + trajectory.terminated = True + trajectory.termination_reason = TerminationReason.MAX_STEPS + break - trajectory.conversation_history = conversation_history + # Update current observation for potential next turn + if observation is not None: + current_observation = observation + + # calc llm usage stats happened in this turn if there is aany + if usage_stats: + usage_stats_list.append(usage_stats) + + # With user simulator, increment step after an entire conversation step + if user_simulator is not None: + 
step += 1 + trajectory.steps = step + + # Enhanced trajectory recording with control plane info + # Create summary of all tool calls executed in this step + tool_calls_summary = [f"{tc.tool_name}({tc.arguments})" for tc in tool_calls] + + control_plane_step = { + "step": step - 1, + "reward": reward, + "terminated": rollout_end, + "info": info.get("control_plane", {}), + "tool_calls": tool_calls_summary, + "num_tool_calls": len(tool_calls), + } + conversation_history[-1]["control_plane_step"] = control_plane_step + trajectory.control_plane_steps.append(control_plane_step) + + # Log conversation state for playback if in recording mode + if recording_mode: + policy.log_conversation_state_for_playback(rollout_idx, step - 1, conversation_history) + + # Use control plane information for termination decision + if rollout_end: + trajectory.terminated = True + trajectory.termination_reason = TerminationReason.CONTROL_PLANE_SIGNAL + + # Add final control plane summary + trajectory.control_plane_summary.update( + { + "total_reward": trajectory.total_reward, + "termination_reason": trajectory.termination_reason, + "final_step": step - 1, + "control_plane_source": info.get("control_plane", {}), + } + ) - # Add termination_reason to the final control_plane_step - for msg in reversed(trajectory.conversation_history): - if msg.get("control_plane_step"): - msg["control_plane_step"]["termination_reason"] = trajectory.termination_reason - break + # Log final OpenAI conversation for terminated trajectories only + if openai_logger: + if conversation_history and len(conversation_history) > 0: + openai_logger( + { + "messages": conversation_history, + "metadata": { + "session_id": session.session_id, + "seed": session.seed, + "total_steps": trajectory.steps, + "total_reward": trajectory.total_reward, + "terminated": True, + "success": reward > 0, + "control_plane_summary": trajectory.control_plane_summary, + }, + } + ) + + logger.info( + f"๐Ÿ Rollout {rollout_idx} terminated at step 
{step} (reward: {trajectory.total_reward}) in thread {threading.current_thread().name}" + ) + break - for usage_stats in usage_stats_list: - trajectory.usage["prompt_tokens"] += usage_stats.prompt_tokens - trajectory.usage["completion_tokens"] += usage_stats.completion_tokens - trajectory.usage["total_tokens"] += usage_stats.total_tokens + # Progress logging + if step % 10 == 0: + logger.debug(f"Rollout {rollout_idx} step {step}, reward: {trajectory.total_reward:.2f}") - logger.info( - f"โœ… Rollout {rollout_idx} completed: {trajectory.steps} steps, reward: {trajectory.total_reward:.2f}, termination: {trajectory.termination_reason}, in thread {threading.current_thread().name}" - ) + # Set termination reason if not already set (e.g., due to step limit) + if not trajectory.termination_reason and step >= steps: + trajectory.termination_reason = TerminationReason.MAX_STEPS + + trajectory.conversation_history = conversation_history + + # Add termination_reason to the final control_plane_step + for msg in reversed(trajectory.conversation_history): + if msg.get("control_plane_step"): + msg["control_plane_step"]["termination_reason"] = trajectory.termination_reason + break + + for usage_stats in usage_stats_list: + trajectory.usage["prompt_tokens"] += usage_stats.prompt_tokens + trajectory.usage["completion_tokens"] += usage_stats.completion_tokens + trajectory.usage["total_tokens"] += usage_stats.total_tokens + + logger.info( + f"โœ… Rollout {rollout_idx} completed: {trajectory.steps} steps, reward: {trajectory.total_reward:.2f}, termination: {trajectory.termination_reason}, in thread {threading.current_thread().name}" + ) + except Exception as e: + logger.error(f"๐Ÿšจ Error in rollout {rollout_idx}: {e}", exc_info=True) + trajectory.terminated = True + trajectory.termination_reason = TerminationReason.ERROR + trajectory.input_metadata.session_data["error"] = True + trajectory.input_metadata.session_data["error_message"] = str(e) return trajectory async def 
_get_control_plane_status(self, session) -> Optional[Dict[str, Any]]: diff --git a/eval_protocol/types/types.py b/eval_protocol/types/types.py index 7c0184f0..b9d4a19e 100644 --- a/eval_protocol/types/types.py +++ b/eval_protocol/types/types.py @@ -16,6 +16,7 @@ class TerminationReason(str, Enum): MAX_STEPS = "max_steps" CONTROL_PLANE_SIGNAL = "control_plane_signal" USER_STOP = "user_stop" + ERROR = "error" @dataclass From b486ebc1e00c5a5bc4eaea20a6ada7ac79fd036b Mon Sep 17 00:00:00 2001 From: Yinghan Ma Date: Thu, 7 Aug 2025 22:12:43 -0700 Subject: [PATCH 3/8] catch error catch error add --- eval_protocol/mcp/execution/manager.py | 449 +++++++++++++------------ eval_protocol/types/types.py | 1 + 2 files changed, 230 insertions(+), 220 deletions(-) diff --git a/eval_protocol/mcp/execution/manager.py b/eval_protocol/mcp/execution/manager.py index ab461f8a..f9ec6c20 100644 --- a/eval_protocol/mcp/execution/manager.py +++ b/eval_protocol/mcp/execution/manager.py @@ -103,6 +103,7 @@ async def _execute_with_semaphore(idx): ) tasks = [_execute_with_semaphore(i) for i in range(envs.n)] + # exceptions should be try catched inside single _execute_rollout trajectories = await asyncio.gather(*tasks) # Calculate durations @@ -159,6 +160,8 @@ async def _execute_with_semaphore(idx): messages.append(Message.model_validate(msg_dict)) evaluation_rows[idx].messages = messages + evaluation_rows[idx].input_metadata.row_id = envs.dataset_rows[idx].id + evaluation_rows[idx].input_metadata.dataset_info = asdict(envs.dataset_rows[idx]) evaluation_rows[idx].tools = shared_tool_schema evaluation_rows[idx].usage = trajectory.usage evaluation_rows[idx].input_metadata.completion_params = CompletionParams( @@ -209,241 +212,247 @@ async def _execute_rollout( "total_tokens": 0, }, ) + try: + current_observation, tool_schema = await envs.reset(session) + system_prompt = dataset_row.system_prompt + + # Record initial observation + trajectory.observations.append(current_observation) + + # Create 
user simulator for this rollout if configured in dataset + user_simulator = None + user_simulator_state = None + + # If user simulation is enabled, initial message is from the simulated user + if dataset_row.user_simulation and dataset_row.user_simulation.get("enabled", False): + user_simulator = UserSimulator( + instructions=dataset_row.user_simulation.get("system_prompt"), + llm=dataset_row.user_simulation.get("llm", "gpt-4.1"), + llm_args=dataset_row.user_simulation.get("llm_args", {"temperature": 0.0}), + ) - current_observation, tool_schema = await envs.reset(session) - system_prompt = dataset_row.system_prompt + # Get initial messages in tau2-bench format for user simulator + user_simulator_state = user_simulator.get_init_state() + user_message, user_simulator_state = user_simulator.generate_next_message( + AssistantMessage(role="assistant", content="Hi! How can I help you today?"), + user_simulator_state, + ) + current_observation = user_message.content if user_message.content else "" + + user_prompt = envs.format_user_prompt(rollout_idx, current_observation) + conversation_history = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ] + + usage_stats_list: List[CompletionUsage] = [] + + logger.info(f"๐ŸŽฏ Starting rollout {rollout_idx} in thread {threading.current_thread().name}") + + # Run rollout loop for this specific environment + step = 0 + rollout_end = False + + while step < steps and not trajectory.terminated: + turn_completed = False + info = {} + reward = 0.0 + observation = current_observation + tool_calls = [] + + if user_simulator and user_simulator_state: + # Get user simulator messages and find the last assistant message + user_simulator_messages = self._get_user_simulator_messages(conversation_history) + + # Last message was agent, simulated user response + if user_simulator_messages and isinstance(user_simulator_messages[-1], AssistantMessage): + # Generate user response using the simulator + 
user_message, user_simulator_state = user_simulator.generate_next_message( + user_simulator_messages[-1], user_simulator_state + ) + user_content = user_message.content if user_message.content else "" - # Record initial observation - trajectory.observations.append(current_observation) + user_prompt = envs.format_user_prompt(rollout_idx, user_content) + conversation_history.append({"role": "user", "content": user_prompt}) - # Create user simulator for this rollout if configured in dataset - user_simulator = None - user_simulator_state = None + # Check if user simulator signaled termination + if UserSimulator.is_stop(user_message): + trajectory.terminated = True + trajectory.termination_reason = TerminationReason.USER_STOP - # If user simulation is enabled, initial message is from the simulated user - if dataset_row.user_simulation and dataset_row.user_simulation.get("enabled", False): - user_simulator = UserSimulator( - instructions=dataset_row.user_simulation.get("system_prompt"), - llm=dataset_row.user_simulation.get("llm", "gpt-4.1"), - llm_args=dataset_row.user_simulation.get("llm_args", {"temperature": 0.0}), - ) + # In each turn: keep looping until assistant is ready to provide final response + while not turn_completed and not trajectory.terminated: + tool_calls, usage_stats = await policy(tool_schema, rollout_idx, conversation_history) - # Get initial messages in tau2-bench format for user simulator - user_simulator_state = user_simulator.get_init_state() - user_message, user_simulator_state = user_simulator.generate_next_message( - AssistantMessage(role="assistant", content="Hi! 
How can I help you today?"), - user_simulator_state, - ) - current_observation = user_message.content if user_message.content else "" - - user_prompt = envs.format_user_prompt(rollout_idx, current_observation) - conversation_history = [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": user_prompt}, - ] - - usage_stats_list: List[CompletionUsage] = [] - - logger.info(f"๐ŸŽฏ Starting rollout {rollout_idx} in thread {threading.current_thread().name}") - - # Run rollout loop for this specific environment - step = 0 - rollout_end = False - - while step < steps and not trajectory.terminated: - turn_completed = False - info = {} - reward = 0.0 - observation = current_observation - tool_calls = [] - - if user_simulator and user_simulator_state: - # Get user simulator messages and find the last assistant message - user_simulator_messages = self._get_user_simulator_messages(conversation_history) - - # Last message was agent, simulated user response - if user_simulator_messages and isinstance(user_simulator_messages[-1], AssistantMessage): - # Generate user response using the simulator - user_message, user_simulator_state = user_simulator.generate_next_message( - user_simulator_messages[-1], user_simulator_state - ) - user_content = user_message.content if user_message.content else "" - - user_prompt = envs.format_user_prompt(rollout_idx, user_content) - conversation_history.append({"role": "user", "content": user_prompt}) - - # Check if user simulator signaled termination - if UserSimulator.is_stop(user_message): - trajectory.terminated = True - trajectory.termination_reason = TerminationReason.USER_STOP - - # In each turn: keep looping until assistant is ready to provide final response - while not turn_completed and not trajectory.terminated: - tool_calls, usage_stats = await policy(tool_schema, rollout_idx, conversation_history) - - # If no tool call is generated, turn is finished - if len(tool_calls) == 1: - # If there's a user simulator, no 
tool call means the policy is ready to provide final response on this turn - if tool_calls[0].tool_name == "_no_tool_call" and user_simulator: - turn_completed = True - break - # If there's no user simulator, no tool call means policy failed and we should terminate the rollout - elif tool_calls[0].tool_name in ["_playback_terminate", "_no_tool_call"]: - trajectory.terminated = True - break - - # Execute each tool call sequentially - for tool_call in tool_calls: - - # Execute tool call for this environment - observation, reward, rollout_end, info = await envs.step(rollout_idx, tool_call) - - tool_response = envs.format_tool_response(observation) - - policy.add_tool_response( - rollout_idx, - tool_call, - tool_response, - conversation_history, - reward, - rollout_end, - info, - ) + # If no tool call is generated, turn is finished + if len(tool_calls) == 1: + # If there's a user simulator, no tool call means the policy is ready to provide final response on this turn + if tool_calls[0].tool_name == "_no_tool_call" and user_simulator: + turn_completed = True + break + # If there's no user simulator, no tool call means policy failed and we should terminate the rollout + elif tool_calls[0].tool_name in ["_playback_terminate", "_no_tool_call"]: + trajectory.terminated = True + break - # Update trajectory with both data and control plane information - trajectory.observations.append(observation) - - # Record action (tool call) - action_str = f"{tool_call.tool_name}({tool_call.arguments})" - trajectory.actions.append(action_str) - - # Record control plane (reward/termination) - trajectory.rewards.append(reward) - trajectory.total_reward += reward - - # Non-user simulator step counter: each tool call is a step - if user_simulator is None: - step += 1 - trajectory.steps = step - - control_plane_step = { - "step": step - 1, - "reward": reward, - "terminated": rollout_end, - "info": info.get("control_plane", {}), - "tool_calls": [f"{tool_call.tool_name}({tool_call.arguments})"], 
- "num_tool_calls": 1, - } - conversation_history[-1]["control_plane_step"] = control_plane_step - trajectory.control_plane_steps.append(control_plane_step) - - # Log conversation state for playback if in recording mode - if recording_mode: - policy.log_conversation_state_for_playback(rollout_idx, step - 1, conversation_history) - - if rollout_end: - trajectory.terminated = True - trajectory.termination_reason = TerminationReason.CONTROL_PLANE_SIGNAL - break - elif step >= steps: - trajectory.terminated = True - trajectory.termination_reason = TerminationReason.MAX_STEPS - break - - # Update current observation for potential next turn - if observation is not None: - current_observation = observation - - # calc llm usage stats happened in this turn if there is aany - if usage_stats: - usage_stats_list.append(usage_stats) - - # With user simulator, increment step after an entire conversation step - if user_simulator is not None: - step += 1 - trajectory.steps = step - - # Enhanced trajectory recording with control plane info - # Create summary of all tool calls executed in this step - tool_calls_summary = [f"{tc.tool_name}({tc.arguments})" for tc in tool_calls] - - control_plane_step = { - "step": step - 1, - "reward": reward, - "terminated": rollout_end, - "info": info.get("control_plane", {}), - "tool_calls": tool_calls_summary, - "num_tool_calls": len(tool_calls), - } - conversation_history[-1]["control_plane_step"] = control_plane_step - trajectory.control_plane_steps.append(control_plane_step) - - # Log conversation state for playback if in recording mode - if recording_mode: - policy.log_conversation_state_for_playback(rollout_idx, step - 1, conversation_history) - - # Use control plane information for termination decision - if rollout_end: - trajectory.terminated = True - trajectory.termination_reason = TerminationReason.CONTROL_PLANE_SIGNAL - - # Add final control plane summary - trajectory.control_plane_summary.update( - { - "total_reward": 
trajectory.total_reward, - "termination_reason": trajectory.termination_reason, - "final_step": step - 1, - "control_plane_source": info.get("control_plane", {}), - } - ) + # Execute each tool call sequentially + for tool_call in tool_calls: - # Log final OpenAI conversation for terminated trajectories only - if openai_logger: - if conversation_history and len(conversation_history) > 0: - openai_logger( - { - "messages": conversation_history, - "metadata": { - "session_id": session.session_id, - "seed": session.seed, - "total_steps": trajectory.steps, - "total_reward": trajectory.total_reward, - "terminated": True, - "success": reward > 0, - "control_plane_summary": trajectory.control_plane_summary, - }, - } + # Execute tool call for this environment + observation, reward, rollout_end, info = await envs.step(rollout_idx, tool_call) + + tool_response = envs.format_tool_response(observation) + + policy.add_tool_response( + rollout_idx, + tool_call, + tool_response, + conversation_history, + reward, + rollout_end, + info, ) - logger.info( - f"๐Ÿ Rollout {rollout_idx} terminated at step {step} (reward: {trajectory.total_reward}) in thread {threading.current_thread().name}" - ) - break + # Update trajectory with both data and control plane information + trajectory.observations.append(observation) + + # Record action (tool call) + action_str = f"{tool_call.tool_name}({tool_call.arguments})" + trajectory.actions.append(action_str) + + # Record control plane (reward/termination) + trajectory.rewards.append(reward) + trajectory.total_reward += reward + + # Non-user simulator step counter: each tool call is a step + if user_simulator is None: + step += 1 + trajectory.steps = step + + control_plane_step = { + "step": step - 1, + "reward": reward, + "terminated": rollout_end, + "info": info.get("control_plane", {}), + "tool_calls": [f"{tool_call.tool_name}({tool_call.arguments})"], + "num_tool_calls": 1, + } + conversation_history[-1]["control_plane_step"] = 
control_plane_step + trajectory.control_plane_steps.append(control_plane_step) - # Progress logging - if step % 10 == 0: - logger.debug(f"Rollout {rollout_idx} step {step}, reward: {trajectory.total_reward:.2f}") + # Log conversation state for playback if in recording mode + if recording_mode: + policy.log_conversation_state_for_playback(rollout_idx, step - 1, conversation_history) - # Set termination reason if not already set (e.g., due to step limit) - if not trajectory.termination_reason and step >= steps: - trajectory.termination_reason = TerminationReason.MAX_STEPS + if rollout_end: + trajectory.terminated = True + trajectory.termination_reason = TerminationReason.CONTROL_PLANE_SIGNAL + break + elif step >= steps: + trajectory.terminated = True + trajectory.termination_reason = TerminationReason.MAX_STEPS + break - trajectory.conversation_history = conversation_history + # Update current observation for potential next turn + if observation is not None: + current_observation = observation + + # calc llm usage stats happened in this turn if there is aany + if usage_stats: + usage_stats_list.append(usage_stats) + + # With user simulator, increment step after an entire conversation step + if user_simulator is not None: + step += 1 + trajectory.steps = step + + # Enhanced trajectory recording with control plane info + # Create summary of all tool calls executed in this step + tool_calls_summary = [f"{tc.tool_name}({tc.arguments})" for tc in tool_calls] + + control_plane_step = { + "step": step - 1, + "reward": reward, + "terminated": rollout_end, + "info": info.get("control_plane", {}), + "tool_calls": tool_calls_summary, + "num_tool_calls": len(tool_calls), + } + conversation_history[-1]["control_plane_step"] = control_plane_step + trajectory.control_plane_steps.append(control_plane_step) + + # Log conversation state for playback if in recording mode + if recording_mode: + policy.log_conversation_state_for_playback(rollout_idx, step - 1, conversation_history) + + 
# Use control plane information for termination decision + if rollout_end: + trajectory.terminated = True + trajectory.termination_reason = TerminationReason.CONTROL_PLANE_SIGNAL + + # Add final control plane summary + trajectory.control_plane_summary.update( + { + "total_reward": trajectory.total_reward, + "termination_reason": trajectory.termination_reason, + "final_step": step - 1, + "control_plane_source": info.get("control_plane", {}), + } + ) - # Add termination_reason to the final control_plane_step - for msg in reversed(trajectory.conversation_history): - if msg.get("control_plane_step"): - msg["control_plane_step"]["termination_reason"] = trajectory.termination_reason - break + # Log final OpenAI conversation for terminated trajectories only + if openai_logger: + if conversation_history and len(conversation_history) > 0: + openai_logger( + { + "messages": conversation_history, + "metadata": { + "session_id": session.session_id, + "seed": session.seed, + "total_steps": trajectory.steps, + "total_reward": trajectory.total_reward, + "terminated": True, + "success": reward > 0, + "control_plane_summary": trajectory.control_plane_summary, + }, + } + ) + + logger.info( + f"๐Ÿ Rollout {rollout_idx} terminated at step {step} (reward: {trajectory.total_reward}) in thread {threading.current_thread().name}" + ) + break - for usage_stats in usage_stats_list: - trajectory.usage["prompt_tokens"] += usage_stats.prompt_tokens - trajectory.usage["completion_tokens"] += usage_stats.completion_tokens - trajectory.usage["total_tokens"] += usage_stats.total_tokens + # Progress logging + if step % 10 == 0: + logger.debug(f"Rollout {rollout_idx} step {step}, reward: {trajectory.total_reward:.2f}") - logger.info( - f"โœ… Rollout {rollout_idx} completed: {trajectory.steps} steps, reward: {trajectory.total_reward:.2f}, termination: {trajectory.termination_reason}, in thread {threading.current_thread().name}" - ) + # Set termination reason if not already set (e.g., due to step 
limit) + if not trajectory.termination_reason and step >= steps: + trajectory.termination_reason = TerminationReason.MAX_STEPS + + trajectory.conversation_history = conversation_history + + # Add termination_reason to the final control_plane_step + for msg in reversed(trajectory.conversation_history): + if msg.get("control_plane_step"): + msg["control_plane_step"]["termination_reason"] = trajectory.termination_reason + break + + for usage_stats in usage_stats_list: + trajectory.usage["prompt_tokens"] += usage_stats.prompt_tokens + trajectory.usage["completion_tokens"] += usage_stats.completion_tokens + trajectory.usage["total_tokens"] += usage_stats.total_tokens + + logger.info( + f"โœ… Rollout {rollout_idx} completed: {trajectory.steps} steps, reward: {trajectory.total_reward:.2f}, termination: {trajectory.termination_reason}, in thread {threading.current_thread().name}" + ) + except Exception as e: + logger.error(f"๐Ÿšจ Error in rollout {rollout_idx}: {e}", exc_info=True) + trajectory.terminated = True + trajectory.termination_reason = TerminationReason.ERROR + trajectory.input_metadata.session_data["error"] = True + trajectory.input_metadata.session_data["error_message"] = str(e) return trajectory async def _get_control_plane_status(self, session) -> Optional[Dict[str, Any]]: diff --git a/eval_protocol/types/types.py b/eval_protocol/types/types.py index 7c0184f0..b9d4a19e 100644 --- a/eval_protocol/types/types.py +++ b/eval_protocol/types/types.py @@ -16,6 +16,7 @@ class TerminationReason(str, Enum): MAX_STEPS = "max_steps" CONTROL_PLANE_SIGNAL = "control_plane_signal" USER_STOP = "user_stop" + ERROR = "error" @dataclass From a80c6e9587b86c95ef0ba8d393775f5b2b51692e Mon Sep 17 00:00:00 2001 From: Yinghan Ma Date: Thu, 7 Aug 2025 23:46:25 -0700 Subject: [PATCH 4/8] add final assistant response --- eval_protocol/mcp/execution/manager.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/eval_protocol/mcp/execution/manager.py 
b/eval_protocol/mcp/execution/manager.py index f9ec6c20..b698cd41 100644 --- a/eval_protocol/mcp/execution/manager.py +++ b/eval_protocol/mcp/execution/manager.py @@ -337,6 +337,7 @@ async def _execute_rollout( "tool_calls": [f"{tool_call.tool_name}({tool_call.arguments})"], "num_tool_calls": 1, } + print(f"๐Ÿ” control_plane_step: {control_plane_step}") conversation_history[-1]["control_plane_step"] = control_plane_step trajectory.control_plane_steps.append(control_plane_step) @@ -344,6 +345,7 @@ async def _execute_rollout( if recording_mode: policy.log_conversation_state_for_playback(rollout_idx, step - 1, conversation_history) + # tool indicates rollout should be terminated, call policy one last time to get the final response if rollout_end: trajectory.terminated = True trajectory.termination_reason = TerminationReason.CONTROL_PLANE_SIGNAL @@ -390,6 +392,9 @@ async def _execute_rollout( trajectory.terminated = True trajectory.termination_reason = TerminationReason.CONTROL_PLANE_SIGNAL + _, usage_stats = await policy(tool_schema, rollout_idx, conversation_history) + usage_stats_list.append(usage_stats) + # Add final control plane summary trajectory.control_plane_summary.update( { From e69766cb9d64682d6e18aa0341d6071266078c98 Mon Sep 17 00:00:00 2001 From: Yinghan Ma Date: Thu, 7 Aug 2025 23:53:27 -0700 Subject: [PATCH 5/8] update --- eval_protocol/mcp/execution/manager.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/eval_protocol/mcp/execution/manager.py b/eval_protocol/mcp/execution/manager.py index c609a790..008fffba 100644 --- a/eval_protocol/mcp/execution/manager.py +++ b/eval_protocol/mcp/execution/manager.py @@ -397,7 +397,8 @@ async def _execute_rollout( trajectory.termination_reason = TerminationReason.CONTROL_PLANE_SIGNAL _, usage_stats = await policy(tool_schema, rollout_idx, conversation_history) - usage_stats_list.append(usage_stats) + if usage_stats: + usage_stats_list.append(usage_stats) # Add final control plane summary 
trajectory.control_plane_summary.update( From d274bceccb7ce4686f0d473acfca5e358d32caff Mon Sep 17 00:00:00 2001 From: Yinghan Ma Date: Thu, 7 Aug 2025 23:54:58 -0700 Subject: [PATCH 6/8] add --- eval_protocol/mcp/execution/manager.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/eval_protocol/mcp/execution/manager.py b/eval_protocol/mcp/execution/manager.py index 008fffba..0c596dd4 100644 --- a/eval_protocol/mcp/execution/manager.py +++ b/eval_protocol/mcp/execution/manager.py @@ -346,7 +346,6 @@ async def _execute_rollout( if recording_mode: policy.log_conversation_state_for_playback(rollout_idx, step - 1, conversation_history) - # tool indicates rollout should be terminated, call policy one last time to get the final response if rollout_end: trajectory.terminated = True trajectory.termination_reason = TerminationReason.CONTROL_PLANE_SIGNAL @@ -387,15 +386,13 @@ async def _execute_rollout( # Log conversation state for playback if in recording mode if recording_mode: policy.log_conversation_state_for_playback(rollout_idx, step - 1, conversation_history) - # Log conversation state for playback if in recording mode - if recording_mode: - policy.log_conversation_state_for_playback(rollout_idx, step - 1, conversation_history) # Use control plane information for termination decision if rollout_end: trajectory.terminated = True trajectory.termination_reason = TerminationReason.CONTROL_PLANE_SIGNAL + # tool indicates rollout should be terminated, call policy one last time to get the final response _, usage_stats = await policy(tool_schema, rollout_idx, conversation_history) if usage_stats: usage_stats_list.append(usage_stats) From 86294ed77d9dea20380205a7c294820508441b7b Mon Sep 17 00:00:00 2001 From: Yinghan Ma Date: Fri, 8 Aug 2025 00:39:29 -0700 Subject: [PATCH 7/8] record rollout status --- eval_protocol/mcp/execution/manager.py | 21 +++++++++++++++++---- eval_protocol/models.py | 20 ++++++++++++++++++++ 2 files changed, 37 insertions(+), 4 
deletions(-) diff --git a/eval_protocol/mcp/execution/manager.py b/eval_protocol/mcp/execution/manager.py index 0c596dd4..b26f2f2b 100644 --- a/eval_protocol/mcp/execution/manager.py +++ b/eval_protocol/mcp/execution/manager.py @@ -103,8 +103,7 @@ async def _execute_with_semaphore(idx): ) tasks = [_execute_with_semaphore(i) for i in range(envs.n)] - # exceptions should be try catched inside single _execute_rollout - # exceptions should be try catched inside single _execute_rollout + # exceptions will be caught inside each single _execute_rollout trajectories = await asyncio.gather(*tasks) # Calculate durations @@ -171,6 +170,21 @@ async def _execute_with_semaphore(idx): max_tokens=getattr(policy, "max_tokens", None), max_tool_calls=getattr(policy, "max_tools_per_turn", None), ) + if trajectory.terminated: + if trajectory.termination_reason in { + TerminationReason.CONTROL_PLANE_SIGNAL, + TerminationReason.USER_STOP, + }: + evaluation_rows[idx].rollout_status.status = "finished" + elif trajectory.termination_reason == TerminationReason.MAX_STEPS: + evaluation_rows[idx].rollout_status.status = "stopped" + else: + evaluation_rows[idx].rollout_status.status = "error" + evaluation_rows[idx].rollout_status.error_message = trajectory.control_plane_summary.get( + "error_message", None + ) + else: + evaluation_rows[idx].rollout_status.status = "running" return evaluation_rows @@ -458,8 +472,7 @@ async def _execute_rollout( logger.error(f"🚨 Error in rollout {rollout_idx}: {e}", exc_info=True) trajectory.terminated = True trajectory.termination_reason = TerminationReason.ERROR - trajectory.input_metadata.session_data["error"] = True - trajectory.input_metadata.session_data["error_message"] = str(e) + trajectory.control_plane_summary.update({"error_message": str(e)}) return trajectory async def _get_control_plane_status(self, session) -> Optional[Dict[str, Any]]: diff --git a/eval_protocol/models.py b/eval_protocol/models.py index 60f75975..583985b4 100644 ---
a/eval_protocol/models.py +++ b/eval_protocol/models.py @@ -220,6 +220,21 @@ class EvalMetadata(BaseModel): passed: Optional[bool] = Field(None, description="Whether the evaluation passed based on the threshold") +class RolloutStatus(BaseModel): + """Status of the rollout.""" + + """ + running: Unfinished rollout which is still in progress. + finished: Rollout finished successfully (e.g. control plane signal, user stop). + error: Rollout failed. + stopped: Rollout stopped before finishing (e.g. max steps reached). + """ + status: Literal["running", "finished", "error", "stopped"] = Field( + "finished", description="Status of the rollout." + ) + error_message: Optional[str] = Field(None, description="Error message if the rollout failed.") + + class EvaluationRow(BaseModel): """ Unified data structure for a single evaluation unit that contains messages, @@ -244,6 +259,11 @@ class EvaluationRow(BaseModel): description="Metadata related to the input (dataset info, model config, session data, etc.).", ) + rollout_status: RolloutStatus = Field( + default_factory=RolloutStatus, + description="The status of the rollout.", + ) + # Ground truth reference (moved from EvaluateResult to top level) ground_truth: Optional[str] = Field( default=None, description="Optional ground truth reference for this evaluation." 
From 98b7bc9bb3bee0deefa7c98421224b2b8b548b17 Mon Sep 17 00:00:00 2001 From: Yinghan Ma Date: Fri, 8 Aug 2025 00:42:26 -0700 Subject: [PATCH 8/8] fix ut --- tests/test_rollout_control_plane_integration.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_rollout_control_plane_integration.py b/tests/test_rollout_control_plane_integration.py index 9be29b81..6745f6cc 100644 --- a/tests/test_rollout_control_plane_integration.py +++ b/tests/test_rollout_control_plane_integration.py @@ -289,7 +289,7 @@ def mock_step_side_effect(env_index, tool_call): assert final_cp_step["step"] == 2, "Should record final step" # Validate policy interaction - assert policy.step_count == 3, "Policy should have been called 3 times" + assert policy.step_count == 4, "Policy should have been called 4 times" @pytest.mark.asyncio async def test_rollout_trajectory_recording_with_control_plane(self):