Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 17 additions & 11 deletions eval_protocol/mcp/execution/base_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ async def _generate_live_tool_calls(
tool_schemas: List[Dict],
env_index: int,
conversation_history: List[Dict[str, Any]],
) -> Tuple[List[MCPToolCall], CompletionUsage]:
) -> Tuple[List[MCPToolCall], CompletionUsage, str]:
"""
Generate tool calls using conversation history for proper OpenAI trajectories.

Expand All @@ -161,7 +161,7 @@ async def _generate_live_tool_calls(
user_prompt: Current user prompt with observation

Returns:
List of MCPToolCall objects
List of MCPToolCall objects, LLM usage stats, and finish reason
"""
# Convert MCP tools to LLM format
llm_tools = self._convert_mcp_tools_to_llm_format(tool_schemas)
Expand Down Expand Up @@ -190,6 +190,8 @@ async def _generate_live_tool_calls(
total_tokens=response["usage"]["total_tokens"],
)

finish_reason = response["choices"][0]["finish_reason"]

# Extract tool call from response
message = response["choices"][0]["message"]
logger.debug(f"Environment {env_index} - Response message: {message}")
Expand Down Expand Up @@ -217,15 +219,19 @@ async def _generate_live_tool_calls(
if self.max_tools_per_turn:
mcp_tool_calls = mcp_tool_calls[: self.max_tools_per_turn]

return mcp_tool_calls, usage_stats
return mcp_tool_calls, usage_stats, finish_reason
else:
# No tool calls in response - this is normal when episode ends or LLM provides only text
logger.debug(f"No tool calls in response for env {env_index}, message content: {message.get('content')}")
return [
MCPToolCall(
tool_name="_no_tool_call",
arguments={
"reason": "no_tool_call_generated",
},
)
], usage_stats
return (
[
MCPToolCall(
tool_name="_no_tool_call",
arguments={
"reason": "no_tool_call_generated",
},
)
],
usage_stats,
finish_reason,
)
60 changes: 27 additions & 33 deletions eval_protocol/mcp/execution/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,21 +169,14 @@ async def _execute_with_semaphore(idx):
max_tool_calls=getattr(policy, "max_tools_per_turn", None),
)
if trajectory.terminated:
if trajectory.termination_reason in {
TerminationReason.CONTROL_PLANE_SIGNAL,
TerminationReason.USER_STOP,
}:
evaluation_rows[idx].rollout_status.status = "finished"
elif trajectory.termination_reason in {TerminationReason.MAX_STEPS, TerminationReason.INTERRUPTED}:
evaluation_rows[idx].rollout_status.status = "stopped"
evaluation_rows[idx].rollout_status.error_message = trajectory.control_plane_summary.get(
"termination_reason", trajectory.termination_reason
)
else:
if trajectory.termination_reason == TerminationReason.ERROR:
evaluation_rows[idx].rollout_status.status = "error"
evaluation_rows[idx].rollout_status.error_message = trajectory.control_plane_summary.get(
evaluation_rows[idx].rollout_status.termination_reason = trajectory.control_plane_summary.get(
"error_message", None
)
else:
evaluation_rows[idx].rollout_status.status = "finished"
evaluation_rows[idx].rollout_status.termination_reason = trajectory.termination_reason
else:
evaluation_rows[idx].rollout_status.status = "running"

Expand Down Expand Up @@ -266,7 +259,7 @@ async def _execute_rollout(

# Run rollout loop for this specific environment
step = 0
rollout_end = False
env_end = False # if the env indicates the rollout reaches the goal

while step < steps and not trajectory.terminated:
turn_completed = False
Expand Down Expand Up @@ -297,7 +290,9 @@ async def _execute_rollout(

# In each turn: keep looping until assistant is ready to provide final response
while not turn_completed and not trajectory.terminated:
tool_calls, usage_stats = await policy(tool_schema, rollout_idx, conversation_history)
tool_calls, usage_stats, finish_reason = await policy(
tool_schema, rollout_idx, conversation_history
)

# Calculate the LLM usage stats accrued in this turn, if any
if usage_stats:
Expand All @@ -311,17 +306,17 @@ async def _execute_rollout(
if tool_calls[0].tool_name == "_no_tool_call" and user_simulator:
turn_completed = True
break
# If there's no user simulator, no tool call means policy failed and we should terminate the rollout
# If there's no user simulator, this marks the end of the episode, since the LLM decided no tool call is needed.
elif tool_calls[0].tool_name in ["_playback_terminate", "_no_tool_call"]:
trajectory.terminated = True
trajectory.termination_reason = TerminationReason.INTERRUPTED
trajectory.termination_reason = TerminationReason.from_str(finish_reason)
break

# Execute each tool call sequentially
for tool_call in tool_calls:

# Execute tool call for this environment
observation, reward, rollout_end, info = await envs.step(rollout_idx, tool_call)
observation, reward, env_end, info = await envs.step(rollout_idx, tool_call)

tool_response = envs.format_tool_response(observation)

Expand All @@ -331,7 +326,7 @@ async def _execute_rollout(
tool_response,
conversation_history,
reward,
rollout_end,
env_end,
info,
)

Expand All @@ -354,7 +349,7 @@ async def _execute_rollout(
control_plane_step = {
"step": step - 1,
"reward": reward,
"terminated": rollout_end,
"terminated": env_end,
"info": info.get("control_plane", {}),
"tool_calls": [f"{tool_call.tool_name}({tool_call.arguments})"],
"num_tool_calls": 1,
Expand All @@ -367,11 +362,13 @@ async def _execute_rollout(
if recording_mode:
policy.log_conversation_state_for_playback(rollout_idx, step - 1, conversation_history)

if rollout_end:
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @xzrderek so basically the change is remove this if branch

if env_end:
# if the env marks the end of the rollout, break the tool call loop
# but set the termination reason later after the final policy call
trajectory.terminated = True
trajectory.termination_reason = TerminationReason.CONTROL_PLANE_SIGNAL
break
elif step >= steps:

if step >= steps:
trajectory.terminated = True
trajectory.termination_reason = TerminationReason.MAX_STEPS
break
Expand All @@ -392,7 +389,7 @@ async def _execute_rollout(
control_plane_step = {
"step": step - 1,
"reward": reward,
"terminated": rollout_end,
"terminated": env_end,
"info": info.get("control_plane", {}),
"tool_calls": tool_calls_summary,
"num_tool_calls": len(tool_calls),
Expand All @@ -404,19 +401,16 @@ async def _execute_rollout(
if recording_mode:
policy.log_conversation_state_for_playback(rollout_idx, step - 1, conversation_history)

# Use control plane information for termination decision
if rollout_end:
trajectory.terminated = True
trajectory.termination_reason = TerminationReason.CONTROL_PLANE_SIGNAL

# tool indicates rollout should be terminated, call policy one last time to get the final response
_, usage_stats = await policy(tool_schema, rollout_idx, conversation_history)
# if the env marks end, update control plane summary and do one last policy call, then break the agent loop
# this is to ensure each turn ends with an assistant message, which will align with the actual agentic llm behavior
if env_end:
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@xzrderek but we still save the information in the control plane summary

_, usage_stats, finish_reason = await policy(tool_schema, rollout_idx, conversation_history)
if usage_stats:
trajectory.usage["prompt_tokens"] += usage_stats.prompt_tokens
trajectory.usage["completion_tokens"] += usage_stats.completion_tokens
trajectory.usage["total_tokens"] += usage_stats.total_tokens

# Add final control plane summary
trajectory.terminated = True
trajectory.termination_reason = TerminationReason.from_str(finish_reason)
trajectory.control_plane_summary.update(
{
"total_reward": trajectory.total_reward,
Expand Down Expand Up @@ -445,7 +439,7 @@ async def _execute_rollout(
)

logger.info(
f"🏁 Rollout {rollout_idx} terminated at step {step} (reward: {trajectory.total_reward}) in thread {threading.current_thread().name}"
f"🏁 Environment indicates rollout {rollout_idx} terminated at step {step} (reward: {trajectory.total_reward}) in thread {threading.current_thread().name}"
)
break

Expand Down
3 changes: 2 additions & 1 deletion eval_protocol/mcp/execution/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,8 @@ async def _make_llm_call(self, messages: List[Dict], tools: List[Dict]) -> Dict:
if response.choices[0].message.tool_calls
else []
),
}
},
"finish_reason": response.choices[0].finish_reason,
}
],
"usage": {
Expand Down
6 changes: 3 additions & 3 deletions eval_protocol/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,10 +270,10 @@ class RolloutStatus(BaseModel):
error: Rollout failed.
Details of why a rollout ended (e.g. max step, control plane signal, user stop) are recorded in termination_reason.
"""
status: Literal["running", "finished", "error", "stopped"] = Field(
"finished", description="Status of the rollout."
status: Literal["running", "finished", "error"] = Field("running", description="Status of the rollout.")
termination_reason: Optional[str] = Field(
"", description="Reason for the rollout status, mapped to values in TerminationReason."
)
error_message: Optional[str] = Field(None, description="Error message if the rollout failed.")


class EvaluationRow(BaseModel):
Expand Down
4 changes: 2 additions & 2 deletions eval_protocol/playback_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ async def _generate_live_tool_calls(
tool_schemas: List[Dict],
env_index: int,
conversation_history: List[Dict[str, Any]],
) -> Tuple[List["MCPToolCall"], CompletionUsage]:
) -> Tuple[List["MCPToolCall"], CompletionUsage, str]:
"""
Generate tool calls in live mode. Concrete classes must implement this.

Expand Down Expand Up @@ -253,7 +253,7 @@ async def __call__(
]

# Return the recorded tool call
return self._extract_tool_call_from_messages(messages, env_index), None
return self._extract_tool_call_from_messages(messages, env_index), None, None
else:
# Live mode - generate tool call using provided conversation history
return await self._generate_live_tool_calls(tool_schemas, env_index, conversation_history)
Expand Down
30 changes: 27 additions & 3 deletions eval_protocol/types/types.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from contextlib import AsyncExitStack
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional

from mcp.client.session import ClientSession
from contextlib import AsyncExitStack


class TerminationReason(str, Enum):
Expand All @@ -11,15 +12,38 @@ class TerminationReason(str, Enum):
MAX_STEPS: Trajectory ends because we hit the step limit
CONTROL_PLANE_SIGNAL: Trajectory ends because the control plane signals termination (e.g. env goal reached or failure condition)
USER_STOP: Trajectory ends because the simulated user signals to stop
INTERRUPTED: Trajectory ends unexpectedly, for example, expecting tool call but there is no tool call
ERROR: Trajectory ends because of an error
STOP: Trajectory ends by the policy (mapped to llm response stop reason "stop")
LENGTH: Trajectory ends by the policy (mapped to llm response stop reason "length")
TOOL_CALLS: Trajectory ends by the policy with a hanging tool call response (mapped to llm response stop reason "tool_calls")
"""

MAX_STEPS = "max_steps"
CONTROL_PLANE_SIGNAL = "control_plane_signal"
USER_STOP = "user_stop"
INTERRUPTED = "interrupted"
ERROR = "error"
STOP = "stop"
LENGTH = "length"
TOOL_CALLS = "tool_calls"

@classmethod
def from_str(cls, value: str) -> "TerminationReason":
if value == "stop":
return cls.STOP
elif value == "length":
return cls.LENGTH
elif value == "max_steps":
return cls.MAX_STEPS
elif value == "control_plane_signal":
return cls.CONTROL_PLANE_SIGNAL
elif value == "user_stop":
return cls.USER_STOP
elif value == "error":
return cls.ERROR
elif value == "tool_calls":
return cls.TOOL_CALLS
else:
raise ValueError(f"Invalid termination reason: {value}")


@dataclass
Expand Down
8 changes: 4 additions & 4 deletions eval_protocol/utils/static_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ async def _generate_live_tool_calls(
tool_schemas: List[Dict],
env_index: int,
conversation_history: List[Dict[str, Any]],
) -> Tuple[List[MCPToolCall], CompletionUsage]:
) -> Tuple[List[MCPToolCall], CompletionUsage, str]:
"""
Generate tool calls in live mode using the static action sequence.

Expand Down Expand Up @@ -106,7 +106,7 @@ async def _generate_live_tool_calls(
logger.debug(f"🎮 Env {env_index} step {step_count}: {action}")

usage_stats = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
return [tool_call], usage_stats
return [tool_call], usage_stats, None

def add_tool_response(
self,
Expand Down Expand Up @@ -220,7 +220,7 @@ async def _generate_live_tool_calls(
tool_schemas: List[Dict],
env_index: int,
conversation_history: List[Dict[str, Any]],
) -> Tuple[List[MCPToolCall], CompletionUsage]:
) -> Tuple[List[MCPToolCall], CompletionUsage, str]:
"""
Generate random tool calls in live mode.

Expand All @@ -241,7 +241,7 @@ async def _generate_live_tool_calls(
logger.debug(f"🎲 Env {env_index}: {action}")

usage_stats = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
return [tool_call], usage_stats
return [tool_call], usage_stats, None

def add_tool_response(
self,
Expand Down
10 changes: 7 additions & 3 deletions tests/test_rollout_control_plane_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,13 @@ async def __call__(self, tool_schema, env_index, conversation_history):
tool_calls = []
tool_call = MCPToolCall(tool_name="lake_move", arguments={"action": action})
tool_calls.append(tool_call)
if self.step_count == 3:
self.step_count += 1
no_tool_call = MCPToolCall(tool_name="_no_tool_call", arguments={})
return [no_tool_call], None, "stop"

self.step_count += 1
return tool_calls, None
return tool_calls, None, None

def add_tool_response(
self,
Expand Down Expand Up @@ -285,11 +289,11 @@ def mock_step_side_effect(env_index, tool_call):
final_cp_step = final_msg.control_plane_step
assert final_cp_step["terminated"] == True, "Final step should be terminated"
assert final_cp_step["reward"] == 1.0, "Final step should have correct reward"
assert final_cp_step["termination_reason"] == "control_plane_signal", "Should terminate via control plane"
assert final_cp_step["termination_reason"] == "stop", "Should terminate via the LLM finish reason 'stop'"
assert final_cp_step["step"] == 2, "Should record final step"

# Validate policy interaction
assert policy.step_count == 4, "Policy should have been called 3 times"
assert policy.step_count == 4, "Policy should have been called 4 times"

@pytest.mark.asyncio
async def test_rollout_trajectory_recording_with_control_plane(self):
Expand Down