-
Notifications
You must be signed in to change notification settings - Fork 16
Use LLM finish reason as the termination reason #56
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
683dbf1
let policy decide end of loop
mayinghan 3b6d292
add
mayinghan 5e461f6
fix linter
mayinghan 66dc2fc
fix lint
mayinghan d9eea89
fix test
mayinghan 360af4c
update
mayinghan 70dd006
rename RolloutStatus reason to termination_reason
mayinghan File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -169,21 +169,14 @@ async def _execute_with_semaphore(idx): | |
| max_tool_calls=getattr(policy, "max_tools_per_turn", None), | ||
| ) | ||
| if trajectory.terminated: | ||
| if trajectory.termination_reason in { | ||
| TerminationReason.CONTROL_PLANE_SIGNAL, | ||
| TerminationReason.USER_STOP, | ||
| }: | ||
| evaluation_rows[idx].rollout_status.status = "finished" | ||
| elif trajectory.termination_reason in {TerminationReason.MAX_STEPS, TerminationReason.INTERRUPTED}: | ||
| evaluation_rows[idx].rollout_status.status = "stopped" | ||
| evaluation_rows[idx].rollout_status.error_message = trajectory.control_plane_summary.get( | ||
| "termination_reason", trajectory.termination_reason | ||
| ) | ||
| else: | ||
| if trajectory.termination_reason == TerminationReason.ERROR: | ||
| evaluation_rows[idx].rollout_status.status = "error" | ||
| evaluation_rows[idx].rollout_status.error_message = trajectory.control_plane_summary.get( | ||
| evaluation_rows[idx].rollout_status.termination_reason = trajectory.control_plane_summary.get( | ||
| "error_message", None | ||
| ) | ||
| else: | ||
| evaluation_rows[idx].rollout_status.status = "finished" | ||
| evaluation_rows[idx].rollout_status.termination_reason = trajectory.termination_reason | ||
| else: | ||
| evaluation_rows[idx].rollout_status.status = "running" | ||
|
|
||
|
|
@@ -266,7 +259,7 @@ async def _execute_rollout( | |
|
|
||
| # Run rollout loop for this specific environment | ||
| step = 0 | ||
| rollout_end = False | ||
| env_end = False # if the env indicates the rollout reaches the goal | ||
|
|
||
| while step < steps and not trajectory.terminated: | ||
| turn_completed = False | ||
|
|
@@ -297,7 +290,9 @@ async def _execute_rollout( | |
|
|
||
| # In each turn: keep looping until assistant is ready to provide final response | ||
| while not turn_completed and not trajectory.terminated: | ||
| tool_calls, usage_stats = await policy(tool_schema, rollout_idx, conversation_history) | ||
| tool_calls, usage_stats, finish_reason = await policy( | ||
| tool_schema, rollout_idx, conversation_history | ||
| ) | ||
|
|
||
| # calc llm usage stats happened in this turn if there is aany | ||
| if usage_stats: | ||
|
|
@@ -311,17 +306,17 @@ async def _execute_rollout( | |
| if tool_calls[0].tool_name == "_no_tool_call" and user_simulator: | ||
| turn_completed = True | ||
| break | ||
| # If there's no user simulator, no tool call means policy failed and we should terminate the rollout | ||
| # If there's no user simulator, then it marks the end of the episode as LLM think there is no tool call needed. | ||
| elif tool_calls[0].tool_name in ["_playback_terminate", "_no_tool_call"]: | ||
| trajectory.terminated = True | ||
| trajectory.termination_reason = TerminationReason.INTERRUPTED | ||
| trajectory.termination_reason = TerminationReason.from_str(finish_reason) | ||
| break | ||
|
|
||
| # Execute each tool call sequentially | ||
| for tool_call in tool_calls: | ||
|
|
||
| # Execute tool call for this environment | ||
| observation, reward, rollout_end, info = await envs.step(rollout_idx, tool_call) | ||
| observation, reward, env_end, info = await envs.step(rollout_idx, tool_call) | ||
|
|
||
| tool_response = envs.format_tool_response(observation) | ||
|
|
||
|
|
@@ -331,7 +326,7 @@ async def _execute_rollout( | |
| tool_response, | ||
| conversation_history, | ||
| reward, | ||
| rollout_end, | ||
| env_end, | ||
| info, | ||
| ) | ||
|
|
||
|
|
@@ -354,7 +349,7 @@ async def _execute_rollout( | |
| control_plane_step = { | ||
| "step": step - 1, | ||
| "reward": reward, | ||
| "terminated": rollout_end, | ||
| "terminated": env_end, | ||
| "info": info.get("control_plane", {}), | ||
| "tool_calls": [f"{tool_call.tool_name}({tool_call.arguments})"], | ||
| "num_tool_calls": 1, | ||
|
|
@@ -367,11 +362,13 @@ async def _execute_rollout( | |
| if recording_mode: | ||
| policy.log_conversation_state_for_playback(rollout_idx, step - 1, conversation_history) | ||
|
|
||
| if rollout_end: | ||
| if env_end: | ||
| # if the env marks the end of the rollout, break the tool call loop | ||
| # but set the termination reason later after the final policy call | ||
| trajectory.terminated = True | ||
| trajectory.termination_reason = TerminationReason.CONTROL_PLANE_SIGNAL | ||
| break | ||
| elif step >= steps: | ||
|
|
||
| if step >= steps: | ||
| trajectory.terminated = True | ||
| trajectory.termination_reason = TerminationReason.MAX_STEPS | ||
| break | ||
|
|
@@ -392,7 +389,7 @@ async def _execute_rollout( | |
| control_plane_step = { | ||
| "step": step - 1, | ||
| "reward": reward, | ||
| "terminated": rollout_end, | ||
| "terminated": env_end, | ||
| "info": info.get("control_plane", {}), | ||
| "tool_calls": tool_calls_summary, | ||
| "num_tool_calls": len(tool_calls), | ||
|
|
@@ -404,19 +401,16 @@ async def _execute_rollout( | |
| if recording_mode: | ||
| policy.log_conversation_state_for_playback(rollout_idx, step - 1, conversation_history) | ||
|
|
||
| # Use control plane information for termination decision | ||
| if rollout_end: | ||
| trajectory.terminated = True | ||
| trajectory.termination_reason = TerminationReason.CONTROL_PLANE_SIGNAL | ||
|
|
||
| # tool indicates rollout should be terminated, call policy one last time to get the final response | ||
| _, usage_stats = await policy(tool_schema, rollout_idx, conversation_history) | ||
| # if the env marks end, update control plane summary and do one last policy call, then break the agent loop | ||
| # this is to ensure each turn ends with an assistant message, which will align with the actual agentic llm behavior | ||
| if env_end: | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @xzrderek but we still save the information in the control plane summary |
||
| _, usage_stats, finish_reason = await policy(tool_schema, rollout_idx, conversation_history) | ||
| if usage_stats: | ||
| trajectory.usage["prompt_tokens"] += usage_stats.prompt_tokens | ||
| trajectory.usage["completion_tokens"] += usage_stats.completion_tokens | ||
| trajectory.usage["total_tokens"] += usage_stats.total_tokens | ||
|
|
||
| # Add final control plane summary | ||
| trajectory.terminated = True | ||
| trajectory.termination_reason = TerminationReason.from_str(finish_reason) | ||
| trajectory.control_plane_summary.update( | ||
| { | ||
| "total_reward": trajectory.total_reward, | ||
|
|
@@ -445,7 +439,7 @@ async def _execute_rollout( | |
| ) | ||
|
|
||
| logger.info( | ||
| f"🏁 Rollout {rollout_idx} terminated at step {step} (reward: {trajectory.total_reward}) in thread {threading.current_thread().name}" | ||
| f"🏁 Environmnet indicates rollout {rollout_idx} terminated at step {step} (reward: {trajectory.total_reward}) in thread {threading.current_thread().name}" | ||
| ) | ||
| break | ||
|
|
||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cc @xzrderek so basically the change is remove this if branch