Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions src/google/adk/agents/invocation_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import asyncio
from typing import Any
from typing import Callable
from typing import cast
from typing import Optional

Expand Down Expand Up @@ -51,6 +52,16 @@ class LlmCallsLimitExceededError(Exception):
"""Error thrown when the number of LLM calls exceed the limit."""


ToolProgressHandler = Callable[[str, str | None, Any], Any]
"""Callback for UI-only tool progress updates.

Args:
tool_name: The name of the tool reporting progress.
function_call_id: The function call id if available.
data: The tool-defined progress payload.
"""


class RealtimeCacheEntry(BaseModel):
"""Store audio data chunks for caching before flushing."""

Expand Down Expand Up @@ -207,6 +218,11 @@ class InvocationContext(BaseModel):
live_request_queue: Optional[LiveRequestQueue] = None
"""The queue to receive live requests."""

tool_progress_handler: Optional[ToolProgressHandler] = Field(
default=None, exclude=True
)
"""Runtime callback for tool progress updates that should not reach the LLM."""

active_streaming_tools: Optional[dict[str, ActiveStreamingTool]] = None
"""The running streaming tools of this invocation."""

Expand Down
4 changes: 4 additions & 0 deletions src/google/adk/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,9 @@ def __init__(
self.session_service = session_service
self.memory_service = memory_service
self.credential_service = credential_service
self.on_tool_progress: Optional[Callable[[str, str | None, Any], Any]] = (
None
)
self.plugin_manager = PluginManager(
plugins=app.plugins, close_timeout=plugin_close_timeout
)
Expand Down Expand Up @@ -2068,6 +2071,7 @@ def _new_invocation_context(
session=session,
user_content=new_message,
live_request_queue=live_request_queue,
tool_progress_handler=self.on_tool_progress,
run_config=run_config,
resumability_config=self.resumability_config,
)
Expand Down
27 changes: 26 additions & 1 deletion src/google/adk/tools/function_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,11 @@ def __init__(
self.func = func
# Detect context parameter by type annotation, fallback to 'tool_context' name
self._context_param_name = find_context_parameter(func) or 'tool_context'
self._ignore_params = [self._context_param_name, 'input_stream']
self._ignore_params = [
self._context_param_name,
'input_stream',
'progress_callback',
]
self._require_confirmation = require_confirmation

@override
Expand Down Expand Up @@ -221,6 +225,10 @@ async def run_async(
valid_params = {param for param in signature.parameters}
if self._context_param_name in valid_params:
args_to_call[self._context_param_name] = tool_context
if 'progress_callback' in valid_params:
args_to_call['progress_callback'] = self._make_progress_callback(
tool_context
)

# Filter args_to_call to only include valid parameters for the function
args_to_call = {k: v for k, v in args_to_call.items() if k in valid_params}
Expand Down Expand Up @@ -297,6 +305,19 @@ async def _invoke_callable(
else:
return target(**args_to_call)

def _make_progress_callback(self, tool_context: ToolContext) -> Callable:
"""Returns a tool-bound progress callback for UI-only status updates."""

async def progress_callback(data: Any) -> None:
handler = tool_context._invocation_context.tool_progress_handler
if handler is None:
return
result = handler(self.name, tool_context.function_call_id, data)
if inspect.isawaitable(result):
await result

return progress_callback

# TODO(hangfei): fix call live for function stream.
async def _call_live(
self,
Expand All @@ -319,6 +340,10 @@ async def _call_live(
].stream
if self._context_param_name in signature.parameters:
args_to_call[self._context_param_name] = tool_context
if 'progress_callback' in signature.parameters:
args_to_call['progress_callback'] = self._make_progress_callback(
tool_context
)

# TODO: support tool confirmation for live mode.
async with Aclosing(self.func(**args_to_call)) as agen:
Expand Down
116 changes: 111 additions & 5 deletions tests/unittests/tools/test_function_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ def my_tool(query: str, ctx: Context) -> str:

tool = FunctionTool(my_tool)
assert tool._context_param_name == "ctx"
assert tool._ignore_params == ["ctx", "input_stream"]
assert tool._ignore_params == ["ctx", "input_stream", "progress_callback"]


def test_context_param_detection_with_tool_context_type():
Expand All @@ -466,7 +466,11 @@ def my_tool(query: str, tool_context: ToolContext) -> str:

tool = FunctionTool(my_tool)
assert tool._context_param_name == "tool_context"
assert tool._ignore_params == ["tool_context", "input_stream"]
assert tool._ignore_params == [
"tool_context",
"input_stream",
"progress_callback",
]


def test_context_param_detection_with_custom_name():
Expand All @@ -477,7 +481,11 @@ def my_tool(query: str, my_custom_context: Context) -> str:

tool = FunctionTool(my_tool)
assert tool._context_param_name == "my_custom_context"
assert tool._ignore_params == ["my_custom_context", "input_stream"]
assert tool._ignore_params == [
"my_custom_context",
"input_stream",
"progress_callback",
]


def test_context_param_detection_fallback_to_name():
Expand All @@ -488,7 +496,11 @@ def my_tool(query: str, tool_context) -> str:

tool = FunctionTool(my_tool)
assert tool._context_param_name == "tool_context"
assert tool._ignore_params == ["tool_context", "input_stream"]
assert tool._ignore_params == [
"tool_context",
"input_stream",
"progress_callback",
]


def test_context_param_detection_no_context():
Expand All @@ -499,7 +511,11 @@ def my_tool(query: str, count: int) -> str:

tool = FunctionTool(my_tool)
assert tool._context_param_name == "tool_context"
assert tool._ignore_params == ["tool_context", "input_stream"]
assert tool._ignore_params == [
"tool_context",
"input_stream",
"progress_callback",
]


@pytest.mark.asyncio
Expand All @@ -518,6 +534,96 @@ def my_tool(query: str, ctx: Context) -> dict:
assert result == {"query": "test", "has_context": True}


@pytest.mark.asyncio
async def test_run_async_injects_progress_callback(mock_tool_context):
"""Test that run_async injects a UI-only progress callback when declared."""
progress_events = []

async def progress_handler(tool_name, function_call_id, data):
progress_events.append((tool_name, function_call_id, data))

async def my_tool(query: str, progress_callback) -> dict:
await progress_callback({"step": 1, "message": "working"})
return {"query": query}

mock_tool_context.function_call_id = "call-123"
mock_tool_context._invocation_context.tool_progress_handler = progress_handler

tool = FunctionTool(my_tool)
result = await tool.run_async(
args={"query": "test"},
tool_context=mock_tool_context,
)

assert result == {"query": "test"}
assert progress_events == [
("my_tool", "call-123", {"step": 1, "message": "working"})
]


@pytest.mark.asyncio
async def test_run_async_progress_callback_no_handler_is_noop(
mock_tool_context,
):
"""Test that an injected progress callback is a no-op without a handler."""

async def my_tool(progress_callback) -> dict:
await progress_callback({"step": 1})
return {"ok": True}

mock_tool_context._invocation_context.tool_progress_handler = None

tool = FunctionTool(my_tool)
result = await tool.run_async(args={}, tool_context=mock_tool_context)

assert result == {"ok": True}


def test_progress_callback_is_hidden_from_declaration():
"""Test that progress_callback is not exposed in the model-facing schema."""

def my_tool(query: str, progress_callback) -> str:
"""Search with UI-only progress."""
return query

declaration = FunctionTool(my_tool)._get_declaration()

assert declaration.parameters_json_schema is not None
properties = declaration.parameters_json_schema["properties"]
assert "query" in properties
assert "progress_callback" not in properties


@pytest.mark.asyncio
async def test_call_live_injects_progress_callback(mock_tool_context):
"""Test that live streaming tools receive the progress callback."""
progress_events = []

def progress_handler(tool_name, function_call_id, data):
progress_events.append((tool_name, function_call_id, data))

async def my_tool(progress_callback):
await progress_callback({"step": "start"})
yield {"status": "done"}

mock_tool_context.function_call_id = "live-call-123"
mock_tool_context._invocation_context.tool_progress_handler = progress_handler
mock_tool_context._invocation_context.active_streaming_tools = {}

tool = FunctionTool(my_tool)
results = [
item
async for item in tool._call_live(
args={},
tool_context=mock_tool_context,
invocation_context=mock_tool_context._invocation_context,
)
]

assert results == [{"status": "done"}]
assert progress_events == [("my_tool", "live-call-123", {"step": "start"})]


@pytest.mark.asyncio
async def test_run_async_with_context_type_annotation(mock_tool_context):
"""Test that run_async works with Context type annotation."""
Expand Down