diff --git a/eval_protocol/mcp/execution/manager.py b/eval_protocol/mcp/execution/manager.py
index b26f2f2b..749a8d1f 100644
--- a/eval_protocol/mcp/execution/manager.py
+++ b/eval_protocol/mcp/execution/manager.py
@@ -163,7 +163,7 @@ async def _execute_with_semaphore(idx):
         evaluation_rows[idx].input_metadata.row_id = envs.dataset_rows[idx].id
         evaluation_rows[idx].input_metadata.dataset_info = asdict(envs.dataset_rows[idx])
         evaluation_rows[idx].tools = shared_tool_schema
-        evaluation_rows[idx].usage = trajectory.usage
+        evaluation_rows[idx].usage = CompletionUsage(**trajectory.usage)
         evaluation_rows[idx].input_metadata.completion_params = CompletionParams(
             model=policy.model_id,
             temperature=getattr(policy, "temperature", None),
@@ -260,8 +260,6 @@ async def _execute_rollout(
            {"role": "user", "content": user_prompt},
        ]

-        usage_stats_list: List[CompletionUsage] = []
-
        logger.info(f"🎯 Starting rollout {rollout_idx} in thread {threading.current_thread().name}")

        # Run rollout loop for this specific environment
@@ -299,6 +297,12 @@ async def _execute_rollout(
            while not turn_completed and not trajectory.terminated:
                tool_calls, usage_stats = await policy(tool_schema, rollout_idx, conversation_history)

+                # Accumulate LLM usage stats from this turn, if any
+                if usage_stats:
+                    trajectory.usage["prompt_tokens"] += usage_stats.prompt_tokens
+                    trajectory.usage["completion_tokens"] += usage_stats.completion_tokens
+                    trajectory.usage["total_tokens"] += usage_stats.total_tokens
+
                # If no tool call is generated, turn is finished
                if len(tool_calls) == 1:
                    # If there's a user simulator, no tool call means the policy is ready to provide final response on this turn
@@ -308,6 +312,8 @@ async def _execute_rollout(
                    # If there's no user simulator, no tool call means policy failed and we should terminate the rollout
                    elif tool_calls[0].tool_name in ["_playback_terminate", "_no_tool_call"]:
                        trajectory.terminated = True
+                        trajectory.termination_reason = TerminationReason.ERROR
+                        trajectory.control_plane_summary.update({"error_message": "No expected tool call"})
                        break

                # Execute each tool call sequentially
@@ -373,10 +379,6 @@ async def _execute_rollout(
                if observation is not None:
                    current_observation = observation

-                # calc llm usage stats happened in this turn if there is aany
-                if usage_stats:
-                    usage_stats_list.append(usage_stats)
-
                # With user simulator, increment step after an entire conversation step
                if user_simulator is not None:
                    step += 1
@@ -409,7 +411,9 @@ async def _execute_rollout(
            # tool indicates rollout should be terminated, call policy one last time to get the final response
            _, usage_stats = await policy(tool_schema, rollout_idx, conversation_history)
            if usage_stats:
-                usage_stats_list.append(usage_stats)
+                trajectory.usage["prompt_tokens"] += usage_stats.prompt_tokens
+                trajectory.usage["completion_tokens"] += usage_stats.completion_tokens
+                trajectory.usage["total_tokens"] += usage_stats.total_tokens

            # Add final control plane summary
            trajectory.control_plane_summary.update(
@@ -460,11 +464,6 @@ async def _execute_rollout(
                            msg["control_plane_step"]["termination_reason"] = trajectory.termination_reason
                    break

-        for usage_stats in usage_stats_list:
-            trajectory.usage["prompt_tokens"] += usage_stats.prompt_tokens
-            trajectory.usage["completion_tokens"] += usage_stats.completion_tokens
-            trajectory.usage["total_tokens"] += usage_stats.total_tokens
-
        logger.info(
            f"✅ Rollout {rollout_idx} completed: {trajectory.steps} steps, reward: {trajectory.total_reward:.2f}, termination: {trajectory.termination_reason}, in thread {threading.current_thread().name}"
        )