diff --git a/src/google/adk/agents/invocation_context.py b/src/google/adk/agents/invocation_context.py index b2032c5325..3ad031f2fe 100644 --- a/src/google/adk/agents/invocation_context.py +++ b/src/google/adk/agents/invocation_context.py @@ -396,6 +396,33 @@ def should_pause_invocation(self, event: Event) -> bool: return False + def has_unresolved_long_running_tool_calls(self, events: list[Event]) -> bool: + """Returns whether any long-running tool call in events is unresolved.""" + if not self.is_resumable or not events: + return False + + function_response_ids = { + function_response.id + for event in events + for function_response in event.get_function_responses() + if function_response.id and event.author == "user" + } + + for event in reversed(events): + if not self.should_pause_invocation(event): + continue + + paused_function_call_ids = { + function_call.id + for function_call in event.get_function_calls() + if event.long_running_tool_ids + and function_call.id in event.long_running_tool_ids + } + if paused_function_call_ids - function_response_ids: + return True + + return False + # TODO: Move this method from invocation_context to a dedicated module. def _find_matching_function_call( self, function_response_event: Event diff --git a/src/google/adk/agents/llm_agent.py b/src/google/adk/agents/llm_agent.py index 96e5043f72..ab6e4ab523 100644 --- a/src/google/adk/agents/llm_agent.py +++ b/src/google/adk/agents/llm_agent.py @@ -496,7 +496,7 @@ async def _run_async_impl( if ctx.is_resumable: events = ctx._get_events(current_invocation=True, current_branch=True) - if events and any(ctx.should_pause_invocation(e) for e in events[-2:]): + if ctx.has_unresolved_long_running_tool_calls(events): return # Only yield an end state if the last event is no longer a long-running # tool call. diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index 4c253014a9..1235630dc3 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -835,19 +835,7 @@ async def _run_one_step_async( # Long running tool calls should have been handled before this point. # If there are still long running tool calls, it means the agent is paused # before, and its branch hasn't been resumed yet. - if ( - invocation_context.is_resumable - and events - and len(events) > 1 - # TODO: here we are using the last 2 events to decide whether to pause - # the invocation. But this is just being optimistic, we should find a - # way to pause when the long running tool call is followed by more than - # one text responses. - and ( - invocation_context.should_pause_invocation(events[-1]) - or invocation_context.should_pause_invocation(events[-2]) - ) - ): + if invocation_context.has_unresolved_long_running_tool_calls(events): return if ( diff --git a/tests/unittests/agents/test_invocation_context.py b/tests/unittests/agents/test_invocation_context.py index 87f78b2869..8a0a38a82d 100644 --- a/tests/unittests/agents/test_invocation_context.py +++ b/tests/unittests/agents/test_invocation_context.py @@ -24,6 +24,7 @@ from google.adk.sessions.session import Session from google.genai.types import Content from google.genai.types import FunctionCall +from google.genai.types import FunctionResponse from google.genai.types import Part import pytest @@ -210,6 +211,82 @@ def test_should_not_pause_invocation_with_no_function_calls( nonpausable_event ) + def test_has_unresolved_long_running_tool_calls_with_matching_response(self): + """Tests that matching function responses resolve the pause.""" + invocation_context = self._create_test_invocation_context( + ResumabilityConfig(is_resumable=True) + ) + function_call = FunctionCall( + id='tool_call_id_1', + name='long_running_function_call', + args={}, + ) + paused_event = Event( + invocation_id='inv_1', + author='agent', + content=testing_utils.ModelContent([Part(function_call=function_call)]), + long_running_tool_ids={function_call.id}, + ) + resolved_event = Event( + invocation_id='inv_1', + author='user', + content=Content( + role='user', + parts=[ + Part( + function_response=FunctionResponse( + name='long_running_function_call', + response={'result': 'done'}, + id=function_call.id, + ) + ) + ], + ), + ) + + assert not invocation_context.has_unresolved_long_running_tool_calls( + [paused_event, resolved_event] + ) + + def test_has_unresolved_long_running_tool_calls_without_matching_response( + self, + ): + """Tests that unmatched long-running calls still pause the invocation.""" + invocation_context = self._create_test_invocation_context( + ResumabilityConfig(is_resumable=True) + ) + function_call = FunctionCall( + id='tool_call_id_1', + name='long_running_function_call', + args={}, + ) + paused_event = Event( + invocation_id='inv_1', + author='agent', + content=testing_utils.ModelContent([Part(function_call=function_call)]), + long_running_tool_ids={function_call.id}, + ) + unrelated_response_event = Event( + invocation_id='inv_1', + author='user', + content=Content( + role='user', + parts=[ + Part( + function_response=FunctionResponse( + name='long_running_function_call', + response={'result': 'done'}, + id='different_tool_call_id', + ) + ) + ], + ), + ) + + assert invocation_context.has_unresolved_long_running_tool_calls( + [paused_event, unrelated_response_event] + ) + def test_is_resumable_true(self): """Tests that is_resumable is True when resumability is enabled.""" invocation_context = self._create_test_invocation_context(