Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 12 additions & 13 deletions eval_protocol/mcp/execution/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ async def _execute_with_semaphore(idx):
evaluation_rows[idx].input_metadata.row_id = envs.dataset_rows[idx].id
evaluation_rows[idx].input_metadata.dataset_info = asdict(envs.dataset_rows[idx])
evaluation_rows[idx].tools = shared_tool_schema
evaluation_rows[idx].usage = trajectory.usage
evaluation_rows[idx].usage = CompletionUsage(**trajectory.usage)
evaluation_rows[idx].input_metadata.completion_params = CompletionParams(
model=policy.model_id,
temperature=getattr(policy, "temperature", None),
Expand Down Expand Up @@ -260,8 +260,6 @@ async def _execute_rollout(
{"role": "user", "content": user_prompt},
]

usage_stats_list: List[CompletionUsage] = []

logger.info(f"🎯 Starting rollout {rollout_idx} in thread {threading.current_thread().name}")

# Run rollout loop for this specific environment
Expand Down Expand Up @@ -299,6 +297,12 @@ async def _execute_rollout(
while not turn_completed and not trajectory.terminated:
tool_calls, usage_stats = await policy(tool_schema, rollout_idx, conversation_history)

# calc llm usage stats happened in this turn if there is aany
if usage_stats:
trajectory.usage["prompt_tokens"] += usage_stats.prompt_tokens
trajectory.usage["completion_tokens"] += usage_stats.completion_tokens
trajectory.usage["total_tokens"] += usage_stats.total_tokens

# If no tool call is generated, turn is finished
if len(tool_calls) == 1:
# If there's a user simulator, no tool call means the policy is ready to provide final response on this turn
Expand All @@ -308,6 +312,8 @@ async def _execute_rollout(
# If there's no user simulator, no tool call means policy failed and we should terminate the rollout
elif tool_calls[0].tool_name in ["_playback_terminate", "_no_tool_call"]:
trajectory.terminated = True
trajectory.termination_reason = TerminationReason.ERROR
trajectory.control_plane_summary.update({"error_message": "No expected tool call"})
break

# Execute each tool call sequentially
Expand Down Expand Up @@ -373,10 +379,6 @@ async def _execute_rollout(
if observation is not None:
current_observation = observation

# calc llm usage stats happened in this turn if there is aany
if usage_stats:
usage_stats_list.append(usage_stats)

# With user simulator, increment step after an entire conversation step
if user_simulator is not None:
step += 1
Expand Down Expand Up @@ -409,7 +411,9 @@ async def _execute_rollout(
# tool indicates rollout should be terminated, call policy one last time to get the final response
_, usage_stats = await policy(tool_schema, rollout_idx, conversation_history)
if usage_stats:
usage_stats_list.append(usage_stats)
trajectory.usage["prompt_tokens"] += usage_stats.prompt_tokens
trajectory.usage["completion_tokens"] += usage_stats.completion_tokens
trajectory.usage["total_tokens"] += usage_stats.total_tokens

# Add final control plane summary
trajectory.control_plane_summary.update(
Expand Down Expand Up @@ -460,11 +464,6 @@ async def _execute_rollout(
msg["control_plane_step"]["termination_reason"] = trajectory.termination_reason
break

for usage_stats in usage_stats_list:
trajectory.usage["prompt_tokens"] += usage_stats.prompt_tokens
trajectory.usage["completion_tokens"] += usage_stats.completion_tokens
trajectory.usage["total_tokens"] += usage_stats.total_tokens

logger.info(
f"✅ Rollout {rollout_idx} completed: {trajectory.steps} steps, reward: {trajectory.total_reward:.2f}, termination: {trajectory.termination_reason}, in thread {threading.current_thread().name}"
)
Expand Down
Loading