@@ -97,10 +97,12 @@ async def execute_rollouts(
9797
9898 async def _execute_with_semaphore (idx ):
9999 async with semaphore :
100- return await self ._execute_rollout (
100+ result = await self ._execute_rollout (
101101 envs , policy , idx , steps , openai_logger , recording_mode , playback_mode , start_time
102102 )
103103
104+ return result
105+
104106 tasks = [_execute_with_semaphore (i ) for i in range (envs .n )]
105107 # exceptions are caught inside each individual _execute_rollout call
106108 trajectories = await asyncio .gather (* tasks )
@@ -112,9 +114,6 @@ async def _execute_with_semaphore(idx):
112114
113115 shared_tool_schema = envs .tool_schemas
114116
115- # Clean up
116- await envs .close ()
117-
118117 # Enhanced reporting with control plane info
119118 successful = sum (1 for traj in trajectories if traj .total_reward > 0 )
120119 terminated_by_control_plane = sum (
@@ -175,8 +174,11 @@ async def _execute_with_semaphore(idx):
175174 TerminationReason .USER_STOP ,
176175 }:
177176 evaluation_rows [idx ].rollout_status .status = "finished"
178- elif trajectory .termination_reason == TerminationReason .MAX_STEPS :
177+ elif trajectory .termination_reason in { TerminationReason .MAX_STEPS , TerminationReason . INTERRUPTED } :
179178 evaluation_rows [idx ].rollout_status .status = "stopped"
179+ evaluation_rows [idx ].rollout_status .error_message = trajectory .control_plane_summary .get (
180+ "termination_reason" , trajectory .termination_reason
181+ )
180182 else :
181183 evaluation_rows [idx ].rollout_status .status = "error"
182184 evaluation_rows [idx ].rollout_status .error_message = trajectory .control_plane_summary .get (
@@ -226,6 +228,7 @@ async def _execute_rollout(
226228 "total_tokens" : 0 ,
227229 },
228230 )
231+ failure_reason = None
229232 try :
230233 current_observation , tool_schema = await envs .reset (session )
231234 system_prompt = dataset_row .system_prompt
@@ -311,8 +314,7 @@ async def _execute_rollout(
311314 # If there's no user simulator, no tool call means policy failed and we should terminate the rollout
312315 elif tool_calls [0 ].tool_name in ["_playback_terminate" , "_no_tool_call" ]:
313316 trajectory .terminated = True
314- trajectory .termination_reason = TerminationReason .ERROR
315- trajectory .control_plane_summary .update ({"error_message" : "No expected tool call" })
317+ trajectory .termination_reason = TerminationReason .INTERRUPTED
316318 break
317319
318320 # Execute each tool call sequentially
@@ -466,11 +468,26 @@ async def _execute_rollout(
466468 logger .info (
467469 f"✅ Rollout { rollout_idx } completed: { trajectory .steps } steps, reward: { trajectory .total_reward :.2f} , termination: { trajectory .termination_reason } , in thread { threading .current_thread ().name } "
468470 )
471+
472+ except asyncio .CancelledError :
473+ logger .error (f"🚨 AsyncIO Cancel Error in roll out { rollout_idx } " , exc_info = True )
474+ failure_reason = "asyncio context cancelled"
469475 except Exception as e :
470476 logger .error (f"🚨 Error in rollout { rollout_idx } : { e } " , exc_info = True )
471- trajectory .terminated = True
472- trajectory .termination_reason = TerminationReason .ERROR
473- trajectory .control_plane_summary .update ({"error_message" : str (e )})
477+ failure_reason = str (e )
478+ finally :
479+ if failure_reason :
480+ trajectory .terminated = True
481+ trajectory .termination_reason = TerminationReason .ERROR
482+ trajectory .control_plane_summary .update ({"error_message" : f"{ failure_reason } " })
483+ try :
484+ await envs .connection_manager .reset_session (session )
485+ except :
486+ logger .error (f"Error resetting session { session .session_id } " )
487+ try :
488+ await envs .connection_manager .close_session (session )
489+ except :
490+ logger .error (f"Error closing session { session .session_id } " )
474491 return trajectory
475492
476493 async def _get_control_plane_status (self , session ) -> Optional [Dict [str , Any ]]:
0 commit comments