diff --git a/eval_protocol/mcp/execution/base_policy.py b/eval_protocol/mcp/execution/base_policy.py index 17f67da3..22bad57e 100644 --- a/eval_protocol/mcp/execution/base_policy.py +++ b/eval_protocol/mcp/execution/base_policy.py @@ -151,7 +151,7 @@ async def _generate_live_tool_calls( tool_schemas: List[Dict], env_index: int, conversation_history: List[Dict[str, Any]], - ) -> Tuple[List[MCPToolCall], CompletionUsage]: + ) -> Tuple[List[MCPToolCall], CompletionUsage, str]: """ Generate tool calls using conversation history for proper OpenAI trajectories. @@ -161,7 +161,7 @@ async def _generate_live_tool_calls( user_prompt: Current user prompt with observation Returns: - List of MCPToolCall objects + List of MCPToolCall objects, LLM usage stats, and finish reason """ # Convert MCP tools to LLM format llm_tools = self._convert_mcp_tools_to_llm_format(tool_schemas) @@ -190,6 +190,8 @@ async def _generate_live_tool_calls( total_tokens=response["usage"]["total_tokens"], ) + finish_reason = response["choices"][0]["finish_reason"] + # Extract tool call from response message = response["choices"][0]["message"] logger.debug(f"Environment {env_index} - Response message: {message}") @@ -217,15 +219,19 @@ async def _generate_live_tool_calls( if self.max_tools_per_turn: mcp_tool_calls = mcp_tool_calls[: self.max_tools_per_turn] - return mcp_tool_calls, usage_stats + return mcp_tool_calls, usage_stats, finish_reason else: # No tool calls in response - this is normal when episode ends or LLM provides only text logger.debug(f"No tool calls in response for env {env_index}, message content: {message.get('content')}") - return [ - MCPToolCall( - tool_name="_no_tool_call", - arguments={ - "reason": "no_tool_call_generated", - }, - ) - ], usage_stats + return ( + [ + MCPToolCall( + tool_name="_no_tool_call", + arguments={ + "reason": "no_tool_call_generated", + }, + ) + ], + usage_stats, + finish_reason, + ) diff --git a/eval_protocol/mcp/execution/manager.py 
b/eval_protocol/mcp/execution/manager.py index dbf197fc..aec867fc 100644 --- a/eval_protocol/mcp/execution/manager.py +++ b/eval_protocol/mcp/execution/manager.py @@ -169,21 +169,14 @@ async def _execute_with_semaphore(idx): max_tool_calls=getattr(policy, "max_tools_per_turn", None), ) if trajectory.terminated: - if trajectory.termination_reason in { - TerminationReason.CONTROL_PLANE_SIGNAL, - TerminationReason.USER_STOP, - }: - evaluation_rows[idx].rollout_status.status = "finished" - elif trajectory.termination_reason in {TerminationReason.MAX_STEPS, TerminationReason.INTERRUPTED}: - evaluation_rows[idx].rollout_status.status = "stopped" - evaluation_rows[idx].rollout_status.error_message = trajectory.control_plane_summary.get( - "termination_reason", trajectory.termination_reason - ) - else: + if trajectory.termination_reason == TerminationReason.ERROR: evaluation_rows[idx].rollout_status.status = "error" - evaluation_rows[idx].rollout_status.error_message = trajectory.control_plane_summary.get( + evaluation_rows[idx].rollout_status.termination_reason = trajectory.control_plane_summary.get( "error_message", None ) + else: + evaluation_rows[idx].rollout_status.status = "finished" + evaluation_rows[idx].rollout_status.termination_reason = trajectory.termination_reason else: evaluation_rows[idx].rollout_status.status = "running" @@ -266,7 +259,7 @@ async def _execute_rollout( # Run rollout loop for this specific environment step = 0 - rollout_end = False + env_end = False # if the env indicates the rollout reaches the goal while step < steps and not trajectory.terminated: turn_completed = False @@ -297,7 +290,9 @@ async def _execute_rollout( # In each turn: keep looping until assistant is ready to provide final response while not turn_completed and not trajectory.terminated: - tool_calls, usage_stats = await policy(tool_schema, rollout_idx, conversation_history) + tool_calls, usage_stats, finish_reason = await policy( + tool_schema, rollout_idx, 
conversation_history + ) # calc llm usage stats happened in this turn if there is aany if usage_stats: @@ -311,17 +306,17 @@ if tool_calls[0].tool_name == "_no_tool_call" and user_simulator: turn_completed = True break - # If there's no user simulator, no tool call means policy failed and we should terminate the rollout + # If there's no user simulator, then it marks the end of the episode, as the LLM thinks no tool call is needed. elif tool_calls[0].tool_name in ["_playback_terminate", "_no_tool_call"]: trajectory.terminated = True - trajectory.termination_reason = TerminationReason.INTERRUPTED + trajectory.termination_reason = TerminationReason.from_str(finish_reason) break # Execute each tool call sequentially for tool_call in tool_calls: # Execute tool call for this environment - observation, reward, rollout_end, info = await envs.step(rollout_idx, tool_call) + observation, reward, env_end, info = await envs.step(rollout_idx, tool_call) tool_response = envs.format_tool_response(observation) @@ -331,7 +326,7 @@ tool_response, conversation_history, reward, - rollout_end, + env_end, info, ) @@ -354,7 +349,7 @@ control_plane_step = { "step": step - 1, "reward": reward, - "terminated": rollout_end, + "terminated": env_end, "info": info.get("control_plane", {}), "tool_calls": [f"{tool_call.tool_name}({tool_call.arguments})"], "num_tool_calls": 1, @@ -367,11 +362,13 @@ if recording_mode: policy.log_conversation_state_for_playback(rollout_idx, step - 1, conversation_history) - if rollout_end: + if env_end: + # if the env marks the end of the rollout, break the tool call loop + # but set the termination reason later after the final policy call trajectory.terminated = True - trajectory.termination_reason = TerminationReason.CONTROL_PLANE_SIGNAL break - elif step >= steps: + + if step >= steps: trajectory.terminated = True trajectory.termination_reason = 
TerminationReason.MAX_STEPS break @@ -392,7 +389,7 @@ control_plane_step = { "step": step - 1, "reward": reward, - "terminated": rollout_end, + "terminated": env_end, "info": info.get("control_plane", {}), "tool_calls": tool_calls_summary, "num_tool_calls": len(tool_calls), @@ -404,19 +401,16 @@ if recording_mode: policy.log_conversation_state_for_playback(rollout_idx, step - 1, conversation_history) - # Use control plane information for termination decision - if rollout_end: - trajectory.terminated = True - trajectory.termination_reason = TerminationReason.CONTROL_PLANE_SIGNAL - - # tool indicates rollout should be terminated, call policy one last time to get the final response - _, usage_stats = await policy(tool_schema, rollout_idx, conversation_history) + # if the env marks end, update control plane summary and do one last policy call, then break the agent loop + # this is to ensure each turn ends with an assistant message, which will align with the actual agentic LLM behavior + if env_end: + _, usage_stats, finish_reason = await policy(tool_schema, rollout_idx, conversation_history) if usage_stats: trajectory.usage["prompt_tokens"] += usage_stats.prompt_tokens trajectory.usage["completion_tokens"] += usage_stats.completion_tokens trajectory.usage["total_tokens"] += usage_stats.total_tokens - - # Add final control plane summary + trajectory.terminated = True + trajectory.termination_reason = TerminationReason.from_str(finish_reason) trajectory.control_plane_summary.update( { "total_reward": trajectory.total_reward, @@ -445,7 +439,7 @@ ) logger.info( - f"🏁 Rollout {rollout_idx} terminated at step {step} (reward: {trajectory.total_reward}) in thread {threading.current_thread().name}" + f"🏁 Environment indicates rollout {rollout_idx} terminated at step {step} (reward: {trajectory.total_reward}) in thread {threading.current_thread().name}" ) break diff --git 
a/eval_protocol/mcp/execution/policy.py b/eval_protocol/mcp/execution/policy.py index 06233c4b..f529a21d 100644 --- a/eval_protocol/mcp/execution/policy.py +++ b/eval_protocol/mcp/execution/policy.py @@ -213,7 +213,8 @@ async def _make_llm_call(self, messages: List[Dict], tools: List[Dict]) -> Dict: if response.choices[0].message.tool_calls else [] ), - } + }, + "finish_reason": response.choices[0].finish_reason, } ], "usage": { diff --git a/eval_protocol/models.py b/eval_protocol/models.py index 79c4490d..77707c23 100644 --- a/eval_protocol/models.py +++ b/eval_protocol/models.py @@ -270,10 +270,10 @@ class RolloutStatus(BaseModel): error: Rollout failed. stopped: Rollout terminated unexpectedly (e.g. max step, control plane signal, user stop). """ - status: Literal["running", "finished", "error", "stopped"] = Field( - "finished", description="Status of the rollout." + status: Literal["running", "finished", "error"] = Field("running", description="Status of the rollout.") + termination_reason: Optional[str] = Field( + "", description="reason of the rollout status, mapped to values in TerminationReason" ) - error_message: Optional[str] = Field(None, description="Error message if the rollout failed.") class EvaluationRow(BaseModel): diff --git a/eval_protocol/playback_policy.py b/eval_protocol/playback_policy.py index ed4b5dfd..876e419f 100644 --- a/eval_protocol/playback_policy.py +++ b/eval_protocol/playback_policy.py @@ -207,7 +207,7 @@ async def _generate_live_tool_calls( tool_schemas: List[Dict], env_index: int, conversation_history: List[Dict[str, Any]], - ) -> Tuple[List["MCPToolCall"], CompletionUsage]: + ) -> Tuple[List["MCPToolCall"], CompletionUsage, str]: """ Generate tool calls in live mode. Concrete classes must implement this. 
@@ -253,7 +253,7 @@ async def __call__( ] # Return the recorded tool call - return self._extract_tool_call_from_messages(messages, env_index), None + return self._extract_tool_call_from_messages(messages, env_index), None, None else: # Live mode - generate tool call using provided conversation history return await self._generate_live_tool_calls(tool_schemas, env_index, conversation_history) diff --git a/eval_protocol/types/types.py b/eval_protocol/types/types.py index 85bdf5e9..a94675c4 100644 --- a/eval_protocol/types/types.py +++ b/eval_protocol/types/types.py @@ -1,8 +1,9 @@ +from contextlib import AsyncExitStack from dataclasses import dataclass, field from enum import Enum from typing import Any, Dict, List, Optional + from mcp.client.session import ClientSession -from contextlib import AsyncExitStack class TerminationReason(str, Enum): @@ -11,15 +12,38 @@ class TerminationReason(str, Enum): MAX_STEPS: Trajectory ends because we hit the step limit CONTROL_PLANE_SIGNAL: Trajectory ends because the control plane signals termination (e.g. 
env goal reached or failure condition) USER_STOP: Trajectory ends because the simulated user signals to stop - INTERRUPTED: Trajectory ends unexpectedly, for example, expecting tool call but there is no tool call ERROR: Trajectory ends because of an error + STOP: Trajectory ends by the policy (mapped to llm response stop reason "stop") + LENGTH: Trajectory ends by the policy (mapped to llm response stop reason "length") + TOOL_CALLS: Trajectory ends by the policy with a hanging tool call response (mapped to llm response stop reason "tool_calls") """ MAX_STEPS = "max_steps" CONTROL_PLANE_SIGNAL = "control_plane_signal" USER_STOP = "user_stop" - INTERRUPTED = "interrupted" ERROR = "error" + STOP = "stop" + LENGTH = "length" + TOOL_CALLS = "tool_calls" + + @classmethod + def from_str(cls, value: str) -> "TerminationReason": + if value == "stop": + return cls.STOP + elif value == "length": + return cls.LENGTH + elif value == "max_steps": + return cls.MAX_STEPS + elif value == "control_plane_signal": + return cls.CONTROL_PLANE_SIGNAL + elif value == "user_stop": + return cls.USER_STOP + elif value == "error": + return cls.ERROR + elif value == "tool_calls": + return cls.TOOL_CALLS + else: + raise ValueError(f"Invalid termination reason: {value}") @dataclass diff --git a/eval_protocol/utils/static_policy.py b/eval_protocol/utils/static_policy.py index cb4ff1c3..c8b31792 100644 --- a/eval_protocol/utils/static_policy.py +++ b/eval_protocol/utils/static_policy.py @@ -73,7 +73,7 @@ async def _generate_live_tool_calls( tool_schemas: List[Dict], env_index: int, conversation_history: List[Dict[str, Any]], - ) -> Tuple[List[MCPToolCall], CompletionUsage]: + ) -> Tuple[List[MCPToolCall], CompletionUsage, str]: """ Generate tool calls in live mode using the static action sequence. 
@@ -106,7 +106,7 @@ async def _generate_live_tool_calls( logger.debug(f"🎮 Env {env_index} step {step_count}: {action}") usage_stats = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0) - return [tool_call], usage_stats + return [tool_call], usage_stats, None def add_tool_response( self, @@ -220,7 +220,7 @@ async def _generate_live_tool_calls( tool_schemas: List[Dict], env_index: int, conversation_history: List[Dict[str, Any]], - ) -> Tuple[List[MCPToolCall], CompletionUsage]: + ) -> Tuple[List[MCPToolCall], CompletionUsage, str]: """ Generate random tool calls in live mode. @@ -241,7 +241,7 @@ async def _generate_live_tool_calls( logger.debug(f"🎲 Env {env_index}: {action}") usage_stats = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0) - return [tool_call], usage_stats + return [tool_call], usage_stats, None def add_tool_response( self, diff --git a/tests/test_rollout_control_plane_integration.py b/tests/test_rollout_control_plane_integration.py index 6745f6cc..dcaac0e9 100644 --- a/tests/test_rollout_control_plane_integration.py +++ b/tests/test_rollout_control_plane_integration.py @@ -50,9 +50,13 @@ async def __call__(self, tool_schema, env_index, conversation_history): tool_calls = [] tool_call = MCPToolCall(tool_name="lake_move", arguments={"action": action}) tool_calls.append(tool_call) + if self.step_count == 3: + self.step_count += 1 + no_tool_call = MCPToolCall(tool_name="_no_tool_call", arguments={}) + return [no_tool_call], None, "stop" self.step_count += 1 - return tool_calls, None + return tool_calls, None, None def add_tool_response( self, @@ -285,11 +289,11 @@ def mock_step_side_effect(env_index, tool_call): final_cp_step = final_msg.control_plane_step assert final_cp_step["terminated"] == True, "Final step should be terminated" assert final_cp_step["reward"] == 1.0, "Final step should have correct reward" - assert final_cp_step["termination_reason"] == "control_plane_signal", "Should terminate via control 
plane" + assert final_cp_step["termination_reason"] == "stop", "Should terminate via policy finish reason 'stop'" assert final_cp_step["step"] == 2, "Should record final step" # Validate policy interaction - assert policy.step_count == 4, "Policy should have been called 3 times" + assert policy.step_count == 4, "Policy should have been called 4 times" @pytest.mark.asyncio async def test_rollout_trajectory_recording_with_control_plane(self):