From b6f6b7fda5db7e7083ee093cfd5ade6107a433e3 Mon Sep 17 00:00:00 2001 From: Yinghan Ma Date: Fri, 8 Aug 2025 12:23:28 -0700 Subject: [PATCH 1/3] keep intermediate llm usage stats even for failure trajectories --- eval_protocol/mcp/execution/manager.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/eval_protocol/mcp/execution/manager.py b/eval_protocol/mcp/execution/manager.py index b26f2f2b..feefdbab 100644 --- a/eval_protocol/mcp/execution/manager.py +++ b/eval_protocol/mcp/execution/manager.py @@ -260,8 +260,6 @@ async def _execute_rollout( {"role": "user", "content": user_prompt}, ] - usage_stats_list: List[CompletionUsage] = [] - logger.info(f"🎯 Starting rollout {rollout_idx} in thread {threading.current_thread().name}") # Run rollout loop for this specific environment @@ -375,7 +373,9 @@ async def _execute_rollout( # calc llm usage stats happened in this turn if there is aany if usage_stats: - usage_stats_list.append(usage_stats) + trajectory.usage["prompt_tokens"] += usage_stats.prompt_tokens + trajectory.usage["completion_tokens"] += usage_stats.completion_tokens + trajectory.usage["total_tokens"] += usage_stats.total_tokens # With user simulator, increment step after an entire conversation step if user_simulator is not None: @@ -409,7 +409,9 @@ async def _execute_rollout( # tool indicates rollout should be terminated, call policy one last time to get the final response _, usage_stats = await policy(tool_schema, rollout_idx, conversation_history) if usage_stats: - usage_stats_list.append(usage_stats) + trajectory.usage["prompt_tokens"] += usage_stats.prompt_tokens + trajectory.usage["completion_tokens"] += usage_stats.completion_tokens + trajectory.usage["total_tokens"] += usage_stats.total_tokens # Add final control plane summary trajectory.control_plane_summary.update( @@ -460,11 +462,6 @@ async def _execute_rollout( msg["control_plane_step"]["termination_reason"] = trajectory.termination_reason break - for usage_stats 
in usage_stats_list: - trajectory.usage["prompt_tokens"] += usage_stats.prompt_tokens - trajectory.usage["completion_tokens"] += usage_stats.completion_tokens - trajectory.usage["total_tokens"] += usage_stats.total_tokens - logger.info( f"✅ Rollout {rollout_idx} completed: {trajectory.steps} steps, reward: {trajectory.total_reward:.2f}, termination: {trajectory.termination_reason}, in thread {threading.current_thread().name}" ) From d727d0f84483c134c49328873aa913601e01d193 Mon Sep 17 00:00:00 2001 From: Yinghan Ma Date: Sat, 9 Aug 2025 01:07:10 +0000 Subject: [PATCH 2/3] set termination reason and error message --- eval_protocol/mcp/execution/manager.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/eval_protocol/mcp/execution/manager.py b/eval_protocol/mcp/execution/manager.py index feefdbab..f9be338c 100644 --- a/eval_protocol/mcp/execution/manager.py +++ b/eval_protocol/mcp/execution/manager.py @@ -163,7 +163,7 @@ async def _execute_with_semaphore(idx): evaluation_rows[idx].input_metadata.row_id = envs.dataset_rows[idx].id evaluation_rows[idx].input_metadata.dataset_info = asdict(envs.dataset_rows[idx]) evaluation_rows[idx].tools = shared_tool_schema - evaluation_rows[idx].usage = trajectory.usage + evaluation_rows[idx].usage = CompletionUsage(**trajectory.usage) evaluation_rows[idx].input_metadata.completion_params = CompletionParams( model=policy.model_id, temperature=getattr(policy, "temperature", None), @@ -306,6 +306,8 @@ async def _execute_rollout( # If there's no user simulator, no tool call means policy failed and we should terminate the rollout elif tool_calls[0].tool_name in ["_playback_terminate", "_no_tool_call"]: trajectory.terminated = True + trajectory.termination_reason = TerminationReason.ERROR + trajectory.control_plane_summary.update({"error_message": "No expected tool call"}) break # Execute each tool call sequentially From 3f3ad1d9b186e42bbf9f901b9d0aa775d9c26d9d Mon Sep 17 00:00:00 2001 From: Yinghan Ma Date: Sat, 9 
Aug 2025 01:17:18 +0000 Subject: [PATCH 3/3] accumulate llm usage stats right after each policy call --- eval_protocol/mcp/execution/manager.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/eval_protocol/mcp/execution/manager.py b/eval_protocol/mcp/execution/manager.py index f9be338c..749a8d1f 100644 --- a/eval_protocol/mcp/execution/manager.py +++ b/eval_protocol/mcp/execution/manager.py @@ -297,6 +297,12 @@ async def _execute_rollout( while not turn_completed and not trajectory.terminated: tool_calls, usage_stats = await policy(tool_schema, rollout_idx, conversation_history) + # accumulate llm usage stats for this turn, if there are any + if usage_stats: + trajectory.usage["prompt_tokens"] += usage_stats.prompt_tokens + trajectory.usage["completion_tokens"] += usage_stats.completion_tokens + trajectory.usage["total_tokens"] += usage_stats.total_tokens + # If no tool call is generated, turn is finished if len(tool_calls) == 1: # If there's a user simulator, no tool call means the policy is ready to provide final response on this turn @@ -373,12 +379,6 @@ async def _execute_rollout( if observation is not None: current_observation = observation - # calc llm usage stats happened in this turn if there is aany - if usage_stats: - trajectory.usage["prompt_tokens"] += usage_stats.prompt_tokens - trajectory.usage["completion_tokens"] += usage_stats.completion_tokens - trajectory.usage["total_tokens"] += usage_stats.total_tokens - # With user simulator, increment step after an entire conversation step if user_simulator is not None: step += 1