From 683dbf13d5d630dc9933039dee2fedc39cc3d35c Mon Sep 17 00:00:00 2001 From: Yinghan Ma Date: Mon, 11 Aug 2025 20:33:48 +0000 Subject: [PATCH 1/7] let policy decide end of loop --- eval_protocol/mcp/execution/base_policy.py | 28 +++++++++------ eval_protocol/mcp/execution/manager.py | 40 ++++++++-------------- eval_protocol/mcp/execution/policy.py | 3 +- eval_protocol/playback_policy.py | 2 +- eval_protocol/types/types.py | 23 +++++++++++-- eval_protocol/utils/static_policy.py | 4 +-- 6 files changed, 57 insertions(+), 43 deletions(-) diff --git a/eval_protocol/mcp/execution/base_policy.py b/eval_protocol/mcp/execution/base_policy.py index 17f67da3..22bad57e 100644 --- a/eval_protocol/mcp/execution/base_policy.py +++ b/eval_protocol/mcp/execution/base_policy.py @@ -151,7 +151,7 @@ async def _generate_live_tool_calls( tool_schemas: List[Dict], env_index: int, conversation_history: List[Dict[str, Any]], - ) -> Tuple[List[MCPToolCall], CompletionUsage]: + ) -> Tuple[List[MCPToolCall], CompletionUsage, str]: """ Generate tool calls using conversation history for proper OpenAI trajectories. 
@@ -161,7 +161,7 @@ async def _generate_live_tool_calls( user_prompt: Current user prompt with observation Returns: - List of MCPToolCall objects + List of MCPToolCall objects, LLM usage stats, and finish reason """ # Convert MCP tools to LLM format llm_tools = self._convert_mcp_tools_to_llm_format(tool_schemas) @@ -190,6 +190,8 @@ async def _generate_live_tool_calls( total_tokens=response["usage"]["total_tokens"], ) + finish_reason = response["choices"][0]["finish_reason"] + # Extract tool call from response message = response["choices"][0]["message"] logger.debug(f"Environment {env_index} - Response message: {message}") @@ -217,15 +219,19 @@ async def _generate_live_tool_calls( if self.max_tools_per_turn: mcp_tool_calls = mcp_tool_calls[: self.max_tools_per_turn] - return mcp_tool_calls, usage_stats + return mcp_tool_calls, usage_stats, finish_reason else: # No tool calls in response - this is normal when episode ends or LLM provides only text logger.debug(f"No tool calls in response for env {env_index}, message content: {message.get('content')}") - return [ - MCPToolCall( - tool_name="_no_tool_call", - arguments={ - "reason": "no_tool_call_generated", - }, - ) - ], usage_stats + return ( + [ + MCPToolCall( + tool_name="_no_tool_call", + arguments={ + "reason": "no_tool_call_generated", + }, + ) + ], + usage_stats, + finish_reason, + ) diff --git a/eval_protocol/mcp/execution/manager.py b/eval_protocol/mcp/execution/manager.py index dbf197fc..14639222 100644 --- a/eval_protocol/mcp/execution/manager.py +++ b/eval_protocol/mcp/execution/manager.py @@ -266,7 +266,7 @@ async def _execute_rollout( # Run rollout loop for this specific environment step = 0 - rollout_end = False + env_end = False # if the env indicates the rollout reaches the goal while step < steps and not trajectory.terminated: turn_completed = False @@ -297,7 +297,9 @@ async def _execute_rollout( # In each turn: keep looping until assistant is ready to provide final response while not turn_completed 
and not trajectory.terminated: - tool_calls, usage_stats = await policy(tool_schema, rollout_idx, conversation_history) + tool_calls, usage_stats, finish_reason = await policy( + tool_schema, rollout_idx, conversation_history + ) # calc llm usage stats happened in this turn if there is aany if usage_stats: @@ -311,17 +313,17 @@ async def _execute_rollout( if tool_calls[0].tool_name == "_no_tool_call" and user_simulator: turn_completed = True break - # If there's no user simulator, no tool call means policy failed and we should terminate the rollout + # If there's no user simulator, then it marks the end of the episode, as the LLM thinks no tool call is needed. elif tool_calls[0].tool_name in ["_playback_terminate", "_no_tool_call"]: trajectory.terminated = True - trajectory.termination_reason = TerminationReason.INTERRUPTED + trajectory.termination_reason = TerminationReason.from_str(finish_reason) break # Execute each tool call sequentially for tool_call in tool_calls: # Execute tool call for this environment - observation, reward, rollout_end, info = await envs.step(rollout_idx, tool_call) + observation, reward, env_end, info = await envs.step(rollout_idx, tool_call) tool_response = envs.format_tool_response(observation) @@ -331,7 +333,7 @@ async def _execute_rollout( tool_response, conversation_history, reward, - rollout_end, + env_end, info, ) @@ -354,7 +356,7 @@ async def _execute_rollout( control_plane_step = { "step": step - 1, "reward": reward, - "terminated": rollout_end, + "terminated": env_end, "info": info.get("control_plane", {}), "tool_calls": [f"{tool_call.tool_name}({tool_call.arguments})"], "num_tool_calls": 1, @@ -367,11 +369,7 @@ async def _execute_rollout( if recording_mode: policy.log_conversation_state_for_playback(rollout_idx, step - 1, conversation_history) - if rollout_end: - trajectory.terminated = True - trajectory.termination_reason = TerminationReason.CONTROL_PLANE_SIGNAL - break - elif step >= steps: + if step >= steps:
trajectory.terminated = True trajectory.termination_reason = TerminationReason.MAX_STEPS break @@ -392,7 +390,7 @@ async def _execute_rollout( control_plane_step = { "step": step - 1, "reward": reward, - "terminated": rollout_end, + "terminated": env_end, "info": info.get("control_plane", {}), "tool_calls": tool_calls_summary, "num_tool_calls": len(tool_calls), @@ -404,18 +402,8 @@ async def _execute_rollout( if recording_mode: policy.log_conversation_state_for_playback(rollout_idx, step - 1, conversation_history) - # Use control plane information for termination decision - if rollout_end: - trajectory.terminated = True - trajectory.termination_reason = TerminationReason.CONTROL_PLANE_SIGNAL - - # tool indicates rollout should be terminated, call policy one last time to get the final response - _, usage_stats = await policy(tool_schema, rollout_idx, conversation_history) - if usage_stats: - trajectory.usage["prompt_tokens"] += usage_stats.prompt_tokens - trajectory.usage["completion_tokens"] += usage_stats.completion_tokens - trajectory.usage["total_tokens"] += usage_stats.total_tokens - + # update control plane summary if the env marks end + if env_end: # Add final control plane summary trajectory.control_plane_summary.update( { @@ -445,7 +433,7 @@ async def _execute_rollout( ) logger.info( - f"🏁 Rollout {rollout_idx} terminated at step {step} (reward: {trajectory.total_reward}) in thread {threading.current_thread().name}" + f"🏁 Environment indicates rollout {rollout_idx} terminated at step {step} (reward: {trajectory.total_reward}) in thread {threading.current_thread().name}" ) break diff --git a/eval_protocol/mcp/execution/policy.py b/eval_protocol/mcp/execution/policy.py index 06233c4b..f529a21d 100644 --- a/eval_protocol/mcp/execution/policy.py +++ b/eval_protocol/mcp/execution/policy.py @@ -213,7 +213,8 @@ async def _make_llm_call(self, messages: List[Dict], tools: List[Dict]) -> Dict: if response.choices[0].message.tool_calls else [] ), - } + },
"finish_reason": response.choices[0].finish_reason, } ], "usage": { diff --git a/eval_protocol/playback_policy.py b/eval_protocol/playback_policy.py index ed4b5dfd..587553be 100644 --- a/eval_protocol/playback_policy.py +++ b/eval_protocol/playback_policy.py @@ -253,7 +253,7 @@ async def __call__( ] # Return the recorded tool call - return self._extract_tool_call_from_messages(messages, env_index), None + return self._extract_tool_call_from_messages(messages, env_index), None, None else: # Live mode - generate tool call using provided conversation history return await self._generate_live_tool_calls(tool_schemas, env_index, conversation_history) diff --git a/eval_protocol/types/types.py b/eval_protocol/types/types.py index 85bdf5e9..5adf6013 100644 --- a/eval_protocol/types/types.py +++ b/eval_protocol/types/types.py @@ -11,15 +11,34 @@ class TerminationReason(str, Enum): MAX_STEPS: Trajectory ends because we hit the step limit CONTROL_PLANE_SIGNAL: Trajectory ends because the control plane signals termination (e.g. 
env goal reached or failure condition) USER_STOP: Trajectory ends because the simulated user signals to stop - INTERRUPTED: Trajectory ends unexpectedly, for example, expecting tool call but there is no tool call ERROR: Trajectory ends because of an error + STOP: Trajectory ends by the policy (mapped to llm response stop reason "stop") + LENGTH: Trajectory ends by the policy (mapped to llm response stop reason "length") """ MAX_STEPS = "max_steps" CONTROL_PLANE_SIGNAL = "control_plane_signal" USER_STOP = "user_stop" - INTERRUPTED = "interrupted" ERROR = "error" + STOP = "stop" + LENGTH = "length" + + @classmethod + def from_str(cls, value: str) -> "TerminationReason": + if value == "stop": + return cls.STOP + elif value == "length": + return cls.LENGTH + elif value == "max_steps": + return cls.MAX_STEPS + elif value == "control_plane_signal": + return cls.CONTROL_PLANE_SIGNAL + elif value == "user_stop": + return cls.USER_STOP + elif value == "error": + return cls.ERROR + else: + raise ValueError(f"Invalid termination reason: {value}") @dataclass diff --git a/eval_protocol/utils/static_policy.py b/eval_protocol/utils/static_policy.py index cb4ff1c3..6ee8a0d7 100644 --- a/eval_protocol/utils/static_policy.py +++ b/eval_protocol/utils/static_policy.py @@ -106,7 +106,7 @@ async def _generate_live_tool_calls( logger.debug(f"🎮 Env {env_index} step {step_count}: {action}") usage_stats = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0) - return [tool_call], usage_stats + return [tool_call], usage_stats, None def add_tool_response( self, @@ -241,7 +241,7 @@ async def _generate_live_tool_calls( logger.debug(f"🎲 Env {env_index}: {action}") usage_stats = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0) - return [tool_call], usage_stats + return [tool_call], usage_stats, None def add_tool_response( self, From 3b6d292f586240ea53288e0152f203f90c4b1021 Mon Sep 17 00:00:00 2001 From: Yinghan Ma Date: Mon, 11 Aug 2025 13:55:00 -0700 
Subject: [PATCH 2/7] add --- eval_protocol/mcp/execution/manager.py | 17 +++++------------ eval_protocol/models.py | 6 +++--- 2 files changed, 8 insertions(+), 15 deletions(-) diff --git a/eval_protocol/mcp/execution/manager.py b/eval_protocol/mcp/execution/manager.py index 14639222..92376314 100644 --- a/eval_protocol/mcp/execution/manager.py +++ b/eval_protocol/mcp/execution/manager.py @@ -169,21 +169,14 @@ async def _execute_with_semaphore(idx): max_tool_calls=getattr(policy, "max_tools_per_turn", None), ) if trajectory.terminated: - if trajectory.termination_reason in { - TerminationReason.CONTROL_PLANE_SIGNAL, - TerminationReason.USER_STOP, - }: - evaluation_rows[idx].rollout_status.status = "finished" - elif trajectory.termination_reason in {TerminationReason.MAX_STEPS, TerminationReason.INTERRUPTED}: - evaluation_rows[idx].rollout_status.status = "stopped" - evaluation_rows[idx].rollout_status.error_message = trajectory.control_plane_summary.get( - "termination_reason", trajectory.termination_reason - ) - else: + if trajectory.termination_reason == TerminationReason.ERROR: evaluation_rows[idx].rollout_status.status = "error" - evaluation_rows[idx].rollout_status.error_message = trajectory.control_plane_summary.get( + evaluation_rows[idx].rollout_status.reason= trajectory.control_plane_summary.get( "error_message", None ) + else: + evaluation_rows[idx].rollout_status.status = "finished" + evaluation_rows[idx].rollout_status.reason = trajectory.termination_reason else: evaluation_rows[idx].rollout_status.status = "running" diff --git a/eval_protocol/models.py b/eval_protocol/models.py index 79c4490d..53e7202b 100644 --- a/eval_protocol/models.py +++ b/eval_protocol/models.py @@ -270,10 +270,10 @@ class RolloutStatus(BaseModel): error: Rollout failed. stopped: Rollout terminated unexpectedly (e.g. max step, control plane signal, user stop). 
""" - status: Literal["running", "finished", "error", "stopped"] = Field( - "finished", description="Status of the rollout." + status: Literal["running", "finished", "error"] = Field( + "running", description="Status of the rollout." ) - error_message: Optional[str] = Field(None, description="Error message if the rollout failed.") + reason: Optional[str] = Field("", description="reason of the rollout status, mapped to values in TerminationReason") class EvaluationRow(BaseModel): From 5e461f65965ca1589d02f2d3a62246f988926d01 Mon Sep 17 00:00:00 2001 From: Yinghan Ma Date: Mon, 11 Aug 2025 13:57:07 -0700 Subject: [PATCH 3/7] fix linter --- eval_protocol/playback_policy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/eval_protocol/playback_policy.py b/eval_protocol/playback_policy.py index 587553be..876e419f 100644 --- a/eval_protocol/playback_policy.py +++ b/eval_protocol/playback_policy.py @@ -207,7 +207,7 @@ async def _generate_live_tool_calls( tool_schemas: List[Dict], env_index: int, conversation_history: List[Dict[str, Any]], - ) -> Tuple[List["MCPToolCall"], CompletionUsage]: + ) -> Tuple[List["MCPToolCall"], CompletionUsage, str]: """ Generate tool calls in live mode. Concrete classes must implement this. 
From 66dc2fc64f8491ab92d743c7874a23423fa04efb Mon Sep 17 00:00:00 2001 From: Yinghan Ma Date: Mon, 11 Aug 2025 14:24:11 -0700 Subject: [PATCH 4/7] fix lint --- eval_protocol/utils/static_policy.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/eval_protocol/utils/static_policy.py b/eval_protocol/utils/static_policy.py index 6ee8a0d7..c8b31792 100644 --- a/eval_protocol/utils/static_policy.py +++ b/eval_protocol/utils/static_policy.py @@ -73,7 +73,7 @@ async def _generate_live_tool_calls( tool_schemas: List[Dict], env_index: int, conversation_history: List[Dict[str, Any]], - ) -> Tuple[List[MCPToolCall], CompletionUsage]: + ) -> Tuple[List[MCPToolCall], CompletionUsage, str]: """ Generate tool calls in live mode using the static action sequence. @@ -220,7 +220,7 @@ async def _generate_live_tool_calls( tool_schemas: List[Dict], env_index: int, conversation_history: List[Dict[str, Any]], - ) -> Tuple[List[MCPToolCall], CompletionUsage]: + ) -> Tuple[List[MCPToolCall], CompletionUsage, str]: """ Generate random tool calls in live mode. 
From d9eea891c7422824c899af8b0ebbbcb5f4687643 Mon Sep 17 00:00:00 2001 From: Yinghan Ma Date: Mon, 11 Aug 2025 14:51:28 -0700 Subject: [PATCH 5/7] fix test --- tests/test_rollout_control_plane_integration.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tests/test_rollout_control_plane_integration.py b/tests/test_rollout_control_plane_integration.py index 6745f6cc..dcaac0e9 100644 --- a/tests/test_rollout_control_plane_integration.py +++ b/tests/test_rollout_control_plane_integration.py @@ -50,9 +50,13 @@ async def __call__(self, tool_schema, env_index, conversation_history): tool_calls = [] tool_call = MCPToolCall(tool_name="lake_move", arguments={"action": action}) tool_calls.append(tool_call) + if self.step_count == 3: + self.step_count += 1 + no_tool_call = MCPToolCall(tool_name="_no_tool_call", arguments={}) + return [no_tool_call], None, "stop" self.step_count += 1 - return tool_calls, None + return tool_calls, None, None def add_tool_response( self, @@ -285,11 +289,11 @@ def mock_step_side_effect(env_index, tool_call): final_cp_step = final_msg.control_plane_step assert final_cp_step["terminated"] == True, "Final step should be terminated" assert final_cp_step["reward"] == 1.0, "Final step should have correct reward" - assert final_cp_step["termination_reason"] == "control_plane_signal", "Should terminate via control plane" + assert final_cp_step["termination_reason"] == "stop", "Should terminate via policy finish_reason 'stop'" assert final_cp_step["step"] == 2, "Should record final step" # Validate policy interaction - assert policy.step_count == 4, "Policy should have been called 3 times" + assert policy.step_count == 4, "Policy should have been called 4 times" @pytest.mark.asyncio async def test_rollout_trajectory_recording_with_control_plane(self): From 360af4c9141f7c7c787d11216291b071eb85fdde Mon Sep 17 00:00:00 2001 From: Yinghan Ma Date: Tue, 12 Aug 2025 00:08:27 -0700 Subject: [PATCH 6/7] update --- 
eval_protocol/mcp/execution/manager.py | 19 ++++++++++++++++--- eval_protocol/types/types.py | 7 ++++++- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/eval_protocol/mcp/execution/manager.py b/eval_protocol/mcp/execution/manager.py index 92376314..3660da12 100644 --- a/eval_protocol/mcp/execution/manager.py +++ b/eval_protocol/mcp/execution/manager.py @@ -171,7 +171,7 @@ async def _execute_with_semaphore(idx): if trajectory.terminated: if trajectory.termination_reason == TerminationReason.ERROR: evaluation_rows[idx].rollout_status.status = "error" - evaluation_rows[idx].rollout_status.reason= trajectory.control_plane_summary.get( + evaluation_rows[idx].rollout_status.reason = trajectory.control_plane_summary.get( "error_message", None ) else: @@ -362,6 +362,12 @@ async def _execute_rollout( if recording_mode: policy.log_conversation_state_for_playback(rollout_idx, step - 1, conversation_history) + if env_end: + # if the env marks the end of the rollout, break the tool call loop + # but set the termination reason later after the final policy call + trajectory.terminated = True + break + if step >= steps: trajectory.terminated = True trajectory.termination_reason = TerminationReason.MAX_STEPS @@ -395,9 +401,16 @@ async def _execute_rollout( if recording_mode: policy.log_conversation_state_for_playback(rollout_idx, step - 1, conversation_history) - # update control plane summary if the env marks end + # if the env marks end, update control plane summary and do one last policy call, then break the agent loop + # this is to ensure each turn ends with an assistant message, which will align with the actual agentic llm behavior if env_end: - # Add final control plane summary + _, usage_stats, finish_reason = await policy(tool_schema, rollout_idx, conversation_history) + if usage_stats: + trajectory.usage["prompt_tokens"] += usage_stats.prompt_tokens + trajectory.usage["completion_tokens"] += usage_stats.completion_tokens + trajectory.usage["total_tokens"] 
+= usage_stats.total_tokens + trajectory.terminated = True + trajectory.termination_reason = TerminationReason.from_str(finish_reason) trajectory.control_plane_summary.update( { "total_reward": trajectory.total_reward, diff --git a/eval_protocol/types/types.py b/eval_protocol/types/types.py index 5adf6013..a94675c4 100644 --- a/eval_protocol/types/types.py +++ b/eval_protocol/types/types.py @@ -1,8 +1,9 @@ +from contextlib import AsyncExitStack from dataclasses import dataclass, field from enum import Enum from typing import Any, Dict, List, Optional + from mcp.client.session import ClientSession -from contextlib import AsyncExitStack class TerminationReason(str, Enum): @@ -14,6 +15,7 @@ class TerminationReason(str, Enum): ERROR: Trajectory ends because of an error STOP: Trajectory ends by the policy (mapped to llm response stop reason "stop") LENGTH: Trajectory ends by the policy (mapped to llm response stop reason "length") + TOOL_CALLS: Trajectory ends by the policy with a hanging tool call response (mapped to llm response stop reason "tool_calls") """ MAX_STEPS = "max_steps" @@ -22,6 +24,7 @@ class TerminationReason(str, Enum): ERROR = "error" STOP = "stop" LENGTH = "length" + TOOL_CALLS = "tool_calls" @classmethod def from_str(cls, value: str) -> "TerminationReason": @@ -37,6 +40,8 @@ def from_str(cls, value: str) -> "TerminationReason": return cls.USER_STOP elif value == "error": return cls.ERROR + elif value == "tool_calls": + return cls.TOOL_CALLS else: raise ValueError(f"Invalid termination reason: {value}") From 70dd00650c0f0ca44bd1d4cc48ea68c961e56c32 Mon Sep 17 00:00:00 2001 From: Yinghan Ma Date: Tue, 12 Aug 2025 00:12:56 -0700 Subject: [PATCH 7/7] rename RolloutStatus reason to termination_reason --- eval_protocol/mcp/execution/manager.py | 4 ++-- eval_protocol/models.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/eval_protocol/mcp/execution/manager.py b/eval_protocol/mcp/execution/manager.py index 3660da12..aec867fc 
100644 --- a/eval_protocol/mcp/execution/manager.py +++ b/eval_protocol/mcp/execution/manager.py @@ -171,12 +171,12 @@ async def _execute_with_semaphore(idx): if trajectory.terminated: if trajectory.termination_reason == TerminationReason.ERROR: evaluation_rows[idx].rollout_status.status = "error" - evaluation_rows[idx].rollout_status.reason = trajectory.control_plane_summary.get( + evaluation_rows[idx].rollout_status.termination_reason = trajectory.control_plane_summary.get( "error_message", None ) else: evaluation_rows[idx].rollout_status.status = "finished" - evaluation_rows[idx].rollout_status.reason = trajectory.termination_reason + evaluation_rows[idx].rollout_status.termination_reason = trajectory.termination_reason else: evaluation_rows[idx].rollout_status.status = "running" diff --git a/eval_protocol/models.py b/eval_protocol/models.py index 53e7202b..77707c23 100644 --- a/eval_protocol/models.py +++ b/eval_protocol/models.py @@ -270,10 +270,10 @@ class RolloutStatus(BaseModel): error: Rollout failed. stopped: Rollout terminated unexpectedly (e.g. max step, control plane signal, user stop). """ - status: Literal["running", "finished", "error"] = Field( - "running", description="Status of the rollout." + status: Literal["running", "finished", "error"] = Field("running", description="Status of the rollout.") + termination_reason: Optional[str] = Field( + "", description="reason of the rollout status, mapped to values in TerminationReason" ) - reason: Optional[str] = Field("", description="reason of the rollout status, mapped to values in TerminationReason") class EvaluationRow(BaseModel):