Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions src/google/adk/agents/invocation_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,33 @@ def should_pause_invocation(self, event: Event) -> bool:

return False

def has_unresolved_long_running_tool_calls(self, events: list[Event]) -> bool:
"""Returns whether any long-running tool call in events is unresolved."""
if not self.is_resumable or not events:
return False

function_response_ids = {
function_response.id
for event in events
for function_response in event.get_function_responses()
if function_response.id and event.author == "user"
}

for event in reversed(events):
if not self.should_pause_invocation(event):
continue

paused_function_call_ids = {
function_call.id
for function_call in event.get_function_calls()
if event.long_running_tool_ids
and function_call.id in event.long_running_tool_ids
}
if paused_function_call_ids - function_response_ids:
return True

return False

# TODO: Move this method from invocation_context to a dedicated module.
def _find_matching_function_call(
self, function_response_event: Event
Expand Down
2 changes: 1 addition & 1 deletion src/google/adk/agents/llm_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,7 @@ async def _run_async_impl(

if ctx.is_resumable:
events = ctx._get_events(current_invocation=True, current_branch=True)
if events and any(ctx.should_pause_invocation(e) for e in events[-2:]):
if ctx.has_unresolved_long_running_tool_calls(events):
return
# Only yield an end state if the last event is no longer a long-running
# tool call.
Expand Down
14 changes: 1 addition & 13 deletions src/google/adk/flows/llm_flows/base_llm_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -835,19 +835,7 @@ async def _run_one_step_async(
# Long running tool calls should have been handled before this point.
# If there are still long running tool calls, it means the agent is paused
# before, and its branch hasn't been resumed yet.
if (
invocation_context.is_resumable
and events
and len(events) > 1
# TODO: here we are using the last 2 events to decide whether to pause
# the invocation. But this is just being optimistic, we should find a
# way to pause when the long running tool call is followed by more than
# one text responses.
and (
invocation_context.should_pause_invocation(events[-1])
or invocation_context.should_pause_invocation(events[-2])
)
):
if invocation_context.has_unresolved_long_running_tool_calls(events):
return

if (
Expand Down
77 changes: 77 additions & 0 deletions tests/unittests/agents/test_invocation_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from google.adk.sessions.session import Session
from google.genai.types import Content
from google.genai.types import FunctionCall
from google.genai.types import FunctionResponse
from google.genai.types import Part
import pytest

Expand Down Expand Up @@ -210,6 +211,82 @@ def test_should_not_pause_invocation_with_no_function_calls(
nonpausable_event
)

def test_has_unresolved_long_running_tool_calls_with_matching_response(self):
"""Tests that matching function responses resolve the pause."""
invocation_context = self._create_test_invocation_context(
ResumabilityConfig(is_resumable=True)
)
function_call = FunctionCall(
id='tool_call_id_1',
name='long_running_function_call',
args={},
)
paused_event = Event(
invocation_id='inv_1',
author='agent',
content=testing_utils.ModelContent([Part(function_call=function_call)]),
long_running_tool_ids={function_call.id},
)
resolved_event = Event(
invocation_id='inv_1',
author='user',
content=Content(
role='user',
parts=[
Part(
function_response=FunctionResponse(
name='long_running_function_call',
response={'result': 'done'},
id=function_call.id,
)
)
],
),
)

assert not invocation_context.has_unresolved_long_running_tool_calls(
[paused_event, resolved_event]
)

def test_has_unresolved_long_running_tool_calls_without_matching_response(
self,
):
"""Tests that unmatched long-running calls still pause the invocation."""
invocation_context = self._create_test_invocation_context(
ResumabilityConfig(is_resumable=True)
)
function_call = FunctionCall(
id='tool_call_id_1',
name='long_running_function_call',
args={},
)
paused_event = Event(
invocation_id='inv_1',
author='agent',
content=testing_utils.ModelContent([Part(function_call=function_call)]),
long_running_tool_ids={function_call.id},
)
unrelated_response_event = Event(
invocation_id='inv_1',
author='user',
content=Content(
role='user',
parts=[
Part(
function_response=FunctionResponse(
name='long_running_function_call',
response={'result': 'done'},
id='different_tool_call_id',
)
)
],
),
)

assert invocation_context.has_unresolved_long_running_tool_calls(
[paused_event, unrelated_response_event]
)

def test_is_resumable_true(self):
"""Tests that is_resumable is True when resumability is enabled."""
invocation_context = self._create_test_invocation_context(
Expand Down