From 6d009430f8bed8c7f90fb28240da7046e0a2278e Mon Sep 17 00:00:00 2001 From: Zeel Date: Mon, 30 Mar 2026 16:54:34 -0400 Subject: [PATCH 1/5] fix(flows): resume long-running tools after matching responses --- src/google/adk/agents/invocation_context.py | 28 +++++++ src/google/adk/agents/llm_agent.py | 2 +- .../adk/flows/llm_flows/base_llm_flow.py | 17 +--- src/google/adk/flows/llm_flows/functions.py | 29 +++++++ .../agents/test_invocation_context.py | 77 +++++++++++++++++++ .../flows/llm_flows/test_base_llm_flow.py | 42 ++++++++++ 6 files changed, 181 insertions(+), 14 deletions(-) diff --git a/src/google/adk/agents/invocation_context.py b/src/google/adk/agents/invocation_context.py index b2032c5325..10eee3e9a0 100644 --- a/src/google/adk/agents/invocation_context.py +++ b/src/google/adk/agents/invocation_context.py @@ -396,6 +396,34 @@ def should_pause_invocation(self, event: Event) -> bool: return False + def has_unresolved_long_running_tool_calls( + self, events: list[Event] + ) -> bool: + """Returns whether any long-running tool call in events is unresolved.""" + if not self.is_resumable or not events: + return False + + function_response_ids = { + function_response.id + for event in events + for function_response in event.get_function_responses() + if function_response.id + } + + for event in reversed(events): + if not self.should_pause_invocation(event): + continue + + paused_function_call_ids = { + function_call.id + for function_call in event.get_function_calls() + if function_call.id in event.long_running_tool_ids + } + if paused_function_call_ids - function_response_ids: + return True + + return False + # TODO: Move this method from invocation_context to a dedicated module. def _find_matching_function_call( self, function_response_event: Event diff --git a/src/google/adk/agents/llm_agent.py b/src/google/adk/agents/llm_agent.py index 96e5043f72..ab6e4ab523 100644 --- a/src/google/adk/agents/llm_agent.py +++ b/src/google/adk/agents/llm_agent.py @@ -496,7 +496,7 @@ async def _run_async_impl( if ctx.is_resumable: events = ctx._get_events(current_invocation=True, current_branch=True) - if events and any(ctx.should_pause_invocation(e) for e in events[-2:]): + if ctx.has_unresolved_long_running_tool_calls(events): return # Only yield an end state if the last event is no longer a long-running # tool call. diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index bd0037bdcb..1a660771c8 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -99,6 +99,9 @@ def _finalize_model_response_event( if finalized_event.content: function_calls = finalized_event.get_function_calls() if function_calls: + functions.preserve_existing_function_call_ids( + model_response_event, finalized_event + ) functions.populate_client_function_call_id(finalized_event) finalized_event.long_running_tool_ids = ( functions.get_long_running_function_calls( @@ -785,19 +788,7 @@ async def _run_one_step_async( # Long running tool calls should have been handled before this point. # If there are still long running tool calls, it means the agent is paused # before, and its branch hasn't been resumed yet. - if ( - invocation_context.is_resumable - and events - and len(events) > 1 - # TODO: here we are using the last 2 events to decide whether to pause - # the invocation. But this is just being optimistic, we should find a - # way to pause when the long running tool call is followed by more than - # one text responses. - and ( - invocation_context.should_pause_invocation(events[-1]) - or invocation_context.should_pause_invocation(events[-2]) - ) - ): + if invocation_context.has_unresolved_long_running_tool_calls(events): return if ( diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py index 1f85bee3a8..99a9765f73 100644 --- a/src/google/adk/flows/llm_flows/functions.py +++ b/src/google/adk/flows/llm_flows/functions.py @@ -189,6 +189,35 @@ def populate_client_function_call_id(model_response_event: Event) -> None: function_call.id = generate_client_function_call_id() +def preserve_existing_function_call_ids( + previous_event: Event, model_response_event: Event +) -> None: + """Carries forward function call IDs from a previous streaming event. + + Streaming responses may emit partial and final events for the same function + call sequence. The partial event is sent to clients first, while only the + final event is persisted. Preserving IDs across those events keeps + functionResponse routing stable when the client resumes a long-running tool. + + Args: + previous_event: The in-flight model response event from an earlier chunk. + model_response_event: The newly finalized event for the current chunk. + """ + previous_function_calls = previous_event.get_function_calls() + current_function_calls = model_response_event.get_function_calls() + if not previous_function_calls or not current_function_calls: + return + + for previous_function_call, current_function_call in zip( + previous_function_calls, current_function_calls + ): + if current_function_call.id: + continue + if previous_function_call.name != current_function_call.name: + continue + current_function_call.id = previous_function_call.id + + def remove_client_function_call_id(content: Optional[types.Content]) -> None: """Removes ADK-generated function call IDs from content before sending to LLM. diff --git a/tests/unittests/agents/test_invocation_context.py b/tests/unittests/agents/test_invocation_context.py index 87f78b2869..8a0a38a82d 100644 --- a/tests/unittests/agents/test_invocation_context.py +++ b/tests/unittests/agents/test_invocation_context.py @@ -24,6 +24,7 @@ from google.adk.sessions.session import Session from google.genai.types import Content from google.genai.types import FunctionCall +from google.genai.types import FunctionResponse from google.genai.types import Part import pytest @@ -210,6 +211,82 @@ def test_should_not_pause_invocation_with_no_function_calls( nonpausable_event ) + def test_has_unresolved_long_running_tool_calls_with_matching_response(self): + """Tests that matching function responses resolve the pause.""" + invocation_context = self._create_test_invocation_context( + ResumabilityConfig(is_resumable=True) + ) + function_call = FunctionCall( + id='tool_call_id_1', + name='long_running_function_call', + args={}, + ) + paused_event = Event( + invocation_id='inv_1', + author='agent', + content=testing_utils.ModelContent([Part(function_call=function_call)]), + long_running_tool_ids={function_call.id}, + ) + resolved_event = Event( + invocation_id='inv_1', + author='user', + content=Content( + role='user', + parts=[ + Part( + function_response=FunctionResponse( + name='long_running_function_call', + response={'result': 'done'}, + id=function_call.id, + ) + ) + ], + ), + ) + + assert not invocation_context.has_unresolved_long_running_tool_calls( + [paused_event, resolved_event] + ) + + def test_has_unresolved_long_running_tool_calls_without_matching_response( + self, + ): + """Tests that unmatched long-running calls still pause the invocation.""" + invocation_context = self._create_test_invocation_context( + ResumabilityConfig(is_resumable=True) + ) + function_call = FunctionCall( + id='tool_call_id_1', + name='long_running_function_call', + args={}, + ) + paused_event = Event( + invocation_id='inv_1', + author='agent', + content=testing_utils.ModelContent([Part(function_call=function_call)]), + long_running_tool_ids={function_call.id}, + ) + unrelated_response_event = Event( + invocation_id='inv_1', + author='user', + content=Content( + role='user', + parts=[ + Part( + function_response=FunctionResponse( + name='long_running_function_call', + response={'result': 'done'}, + id='different_tool_call_id', + ) + ) + ], + ), + ) + + assert invocation_context.has_unresolved_long_running_tool_calls( + [paused_event, unrelated_response_event] + ) + def test_is_resumable_true(self): """Tests that is_resumable is True when resumability is enabled.""" invocation_context = self._create_test_invocation_context( diff --git a/tests/unittests/flows/llm_flows/test_base_llm_flow.py b/tests/unittests/flows/llm_flows/test_base_llm_flow.py index 3dfadbcabf..8fca948550 100644 --- a/tests/unittests/flows/llm_flows/test_base_llm_flow.py +++ b/tests/unittests/flows/llm_flows/test_base_llm_flow.py @@ -19,6 +19,7 @@ from google.adk.agents.llm_agent import Agent from google.adk.events.event import Event +from google.adk.flows.llm_flows.base_llm_flow import _finalize_model_response_event from google.adk.flows.llm_flows.base_llm_flow import _handle_after_model_callback from google.adk.flows.llm_flows.base_llm_flow import BaseLlmFlow from google.adk.models.google_llm import Gemini @@ -41,6 +42,47 @@ class BaseLlmFlowForTesting(BaseLlmFlow): pass +def test_finalize_model_response_event_preserves_function_call_ids(): + """Test that streaming finalization keeps function call IDs stable.""" + previous_event = Event( + id=Event.new_id(), + invocation_id='test_invocation', + author='test_agent', + content=types.Content( + role='model', + parts=[ + types.Part( + function_call=types.FunctionCall( + name='track_execution', + args={'call_id': 'partial'}, + id='adk-existing-id', + ) + ) + ], + ), + partial=True, + ) + llm_response = LlmResponse( + content=types.Content( + role='model', + parts=[ + types.Part.from_function_call( + name='track_execution', args={'call_id': 'final'} + ) + ], + ), + partial=False, + ) + + finalized_event = _finalize_model_response_event( + LlmRequest(), llm_response, previous_event + ) + + function_calls = finalized_event.get_function_calls() + assert len(function_calls) == 1 + assert function_calls[0].id == 'adk-existing-id' + + @pytest.mark.asyncio async def test_preprocess_calls_toolset_process_llm_request(): """Test that _preprocess_async calls process_llm_request on toolsets.""" From d361c12b02383d77ee4bd6a6f414ce24008304ee Mon Sep 17 00:00:00 2001 From: Jordan Date: Wed, 1 Apr 2026 22:12:04 +0000 Subject: [PATCH 2/5] fix: 3 fixes for has_unresolved_long_running_tool_calls (PR #5072) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. pyink formatting: collapse method signature to single line 2. Only count author='user' function_responses as resolutions — agent- generated auto-responses from LongRunningFunctionTool should not resolve the pause, only actual user resume responses should 3. Add null guard on event.long_running_tool_ids to fix mypy type error All 5158 unit tests pass with these changes. --- src/google/adk/agents/invocation_context.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/google/adk/agents/invocation_context.py b/src/google/adk/agents/invocation_context.py index 10eee3e9a0..3287664d29 100644 --- a/src/google/adk/agents/invocation_context.py +++ b/src/google/adk/agents/invocation_context.py @@ -396,9 +396,7 @@ def should_pause_invocation(self, event: Event) -> bool: return False - def has_unresolved_long_running_tool_calls( - self, events: list[Event] - ) -> bool: + def has_unresolved_long_running_tool_calls(self, events: list[Event]) -> bool: """Returns whether any long-running tool call in events is unresolved.""" if not self.is_resumable or not events: return False @@ -407,7 +405,7 @@ def has_unresolved_long_running_tool_calls( function_response.id for event in events for function_response in event.get_function_responses() - if function_response.id + if function_response.id and event.author == 'user' } for event in reversed(events): @@ -417,7 +415,7 @@ def has_unresolved_long_running_tool_calls( paused_function_call_ids = { function_call.id for function_call in event.get_function_calls() - if function_call.id in event.long_running_tool_ids + if event.long_running_tool_ids and function_call.id in event.long_running_tool_ids } if paused_function_call_ids - function_response_ids: return True From 3a891e8e3621cab4352a0e9f47028d8ff9a1cf1c Mon Sep 17 00:00:00 2001 From: Zeel Date: Mon, 13 Apr 2026 20:03:35 -0400 Subject: [PATCH 3/5] chore(flows): rerun PR checks From 31b43e56e16affe05855f6ee6648de37db620cac Mon Sep 17 00:00:00 2001 From: Zeel Date: Tue, 21 Apr 2026 12:22:21 -0400 Subject: [PATCH 4/5] fix(flows): remove duplicate streaming id changes --- src/google/adk/agents/invocation_context.py | 5 ++- .../adk/flows/llm_flows/base_llm_flow.py | 3 -- src/google/adk/flows/llm_flows/functions.py | 29 ------------- .../flows/llm_flows/test_base_llm_flow.py | 43 ------------------- 4 files changed, 3 insertions(+), 77 deletions(-) diff --git a/src/google/adk/agents/invocation_context.py b/src/google/adk/agents/invocation_context.py index 3287664d29..3ad031f2fe 100644 --- a/src/google/adk/agents/invocation_context.py +++ b/src/google/adk/agents/invocation_context.py @@ -405,7 +405,7 @@ def has_unresolved_long_running_tool_calls(self, events: list[Event]) -> bool: function_response.id for event in events for function_response in event.get_function_responses() - if function_response.id and event.author == 'user' + if function_response.id and event.author == "user" } for event in reversed(events): @@ -415,7 +415,8 @@ def has_unresolved_long_running_tool_calls(self, events: list[Event]) -> bool: paused_function_call_ids = { function_call.id for function_call in event.get_function_calls() - if event.long_running_tool_ids and function_call.id in event.long_running_tool_ids + if event.long_running_tool_ids + and function_call.id in event.long_running_tool_ids } if paused_function_call_ids - function_response_ids: return True diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index 2170df5077..1235630dc3 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -102,9 +102,6 @@ def _finalize_model_response_event( if finalized_event.content: function_calls = finalized_event.get_function_calls() if function_calls: - functions.preserve_existing_function_call_ids( - model_response_event, finalized_event - ) functions.populate_client_function_call_id(finalized_event) finalized_event.long_running_tool_ids = ( functions.get_long_running_function_calls( diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py index 26e80c4da7..eda8474c01 100644 --- a/src/google/adk/flows/llm_flows/functions.py +++ b/src/google/adk/flows/llm_flows/functions.py @@ -195,35 +195,6 @@ def populate_client_function_call_id(model_response_event: Event) -> None: function_call.id = generate_client_function_call_id() -def preserve_existing_function_call_ids( - previous_event: Event, model_response_event: Event -) -> None: - """Carries forward function call IDs from a previous streaming event. - - Streaming responses may emit partial and final events for the same function - call sequence. The partial event is sent to clients first, while only the - final event is persisted. Preserving IDs across those events keeps - functionResponse routing stable when the client resumes a long-running tool. - - Args: - previous_event: The in-flight model response event from an earlier chunk. - model_response_event: The newly finalized event for the current chunk. - """ - previous_function_calls = previous_event.get_function_calls() - current_function_calls = model_response_event.get_function_calls() - if not previous_function_calls or not current_function_calls: - return - - for previous_function_call, current_function_call in zip( - previous_function_calls, current_function_calls - ): - if current_function_call.id: - continue - if previous_function_call.name != current_function_call.name: - continue - current_function_call.id = previous_function_call.id - - def remove_client_function_call_id(content: Optional[types.Content]) -> None: """Removes ADK-generated function call IDs from content before sending to LLM. diff --git a/tests/unittests/flows/llm_flows/test_base_llm_flow.py b/tests/unittests/flows/llm_flows/test_base_llm_flow.py index 90b02b7050..fc64ad3b84 100644 --- a/tests/unittests/flows/llm_flows/test_base_llm_flow.py +++ b/tests/unittests/flows/llm_flows/test_base_llm_flow.py @@ -19,11 +19,9 @@ from google.adk.agents.llm_agent import Agent from google.adk.events.event import Event -from google.adk.flows.llm_flows.base_llm_flow import _finalize_model_response_event from google.adk.flows.llm_flows.base_llm_flow import _handle_after_model_callback from google.adk.flows.llm_flows.base_llm_flow import BaseLlmFlow from google.adk.models.google_llm import Gemini -from google.adk.models.llm_request import LlmRequest from google.adk.models.llm_response import LlmResponse from google.adk.plugins.base_plugin import BasePlugin from google.adk.tools.base_toolset import BaseToolset @@ -42,47 +40,6 @@ class BaseLlmFlowForTesting(BaseLlmFlow): pass -def test_finalize_model_response_event_preserves_function_call_ids(): - """Test that streaming finalization keeps function call IDs stable.""" - previous_event = Event( - id=Event.new_id(), - invocation_id='test_invocation', - author='test_agent', - content=types.Content( - role='model', - parts=[ - types.Part( - function_call=types.FunctionCall( - name='track_execution', - args={'call_id': 'partial'}, - id='adk-existing-id', - ) - ) - ], - ), - partial=True, - ) - llm_response = LlmResponse( - content=types.Content( - role='model', - parts=[ - types.Part.from_function_call( - name='track_execution', args={'call_id': 'final'} - ) - ], - ), - partial=False, - ) - - finalized_event = _finalize_model_response_event( - LlmRequest(), llm_response, previous_event - ) - - function_calls = finalized_event.get_function_calls() - assert len(function_calls) == 1 - assert function_calls[0].id == 'adk-existing-id' - - @pytest.mark.asyncio async def test_preprocess_calls_toolset_process_llm_request(): """Test that _preprocess_async calls process_llm_request on toolsets.""" From ffa3801eaa5d8118e2697fcc22fbe00bd660b8c8 Mon Sep 17 00:00:00 2001 From: Zeel Date: Tue, 21 Apr 2026 12:26:08 -0400 Subject: [PATCH 5/5] test(flows): restore llm request import --- tests/unittests/flows/llm_flows/test_base_llm_flow.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/unittests/flows/llm_flows/test_base_llm_flow.py b/tests/unittests/flows/llm_flows/test_base_llm_flow.py index fc64ad3b84..793ebb83cd 100644 --- a/tests/unittests/flows/llm_flows/test_base_llm_flow.py +++ b/tests/unittests/flows/llm_flows/test_base_llm_flow.py @@ -22,6 +22,7 @@ from google.adk.flows.llm_flows.base_llm_flow import _handle_after_model_callback from google.adk.flows.llm_flows.base_llm_flow import BaseLlmFlow from google.adk.models.google_llm import Gemini +from google.adk.models.llm_request import LlmRequest from google.adk.models.llm_response import LlmResponse from google.adk.plugins.base_plugin import BasePlugin from google.adk.tools.base_toolset import BaseToolset