Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 2 additions & 12 deletions src/google/adk/auth/auth_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,6 @@
from .auth_tool import AuthConfig
from .auth_tool import AuthToolArguments

# Prefix used by toolset auth credential IDs.
# Auth requests with this prefix are for toolset authentication (before tool
# listing) and don't require resuming a function call.
TOOLSET_AUTH_CREDENTIAL_ID_PREFIX = '_adk_toolset_auth_'


async def _store_auth_and_collect_resume_targets(
events: list[Event],
Expand All @@ -50,7 +45,7 @@ async def _store_auth_and_collect_resume_targets(
``AuthToolArguments`` args, merges ``credential_key`` into the
corresponding auth response, stores credentials via ``AuthHandler``,
and returns the set of original function call IDs that should be
re-executed (excluding toolset auth).
re-executed.
Args:
events: Session events to scan.
Expand Down Expand Up @@ -96,8 +91,7 @@ async def _store_auth_and_collect_resume_targets(
state=state
)

# Step 3: Collect original function call IDs to resume, skipping
# toolset auth entries which don't map to a resumable function call.
# Step 3: Collect original function call IDs to resume.
tools_to_resume: set[str] = set()
for fc_id in auth_fc_ids:
requested_auth_config = requested_auth_config_by_id.get(fc_id)
Expand All @@ -115,10 +109,6 @@ async def _store_auth_and_collect_resume_targets(
and function_call.name == REQUEST_EUC_FUNCTION_CALL_NAME
):
args = AuthToolArguments.model_validate(function_call.args)
if args.function_call_id.startswith(
TOOLSET_AUTH_CREDENTIAL_ID_PREFIX
):
continue
tools_to_resume.add(args.function_call_id)

return tools_to_resume
Expand Down
65 changes: 11 additions & 54 deletions src/google/adk/flows/llm_flows/base_llm_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@
from ...agents.live_request_queue import LiveRequestQueue
from ...agents.readonly_context import ReadonlyContext
from ...agents.run_config import StreamingMode
from ...auth.auth_handler import AuthHandler
from ...auth.auth_tool import AuthConfig
from ...auth.credential_manager import CredentialManager
from ...events.event import Event
from ...models.base_llm_connection import BaseLlmConnection
Expand All @@ -51,10 +49,6 @@
from ...tools.tool_context import ToolContext
from ...utils.context_utils import Aclosing
from .audio_cache_manager import AudioCacheManager
from .functions import build_auth_request_event

# Prefix used by toolset auth credential IDs
TOOLSET_AUTH_CREDENTIAL_ID_PREFIX = '_adk_toolset_auth_'

if TYPE_CHECKING:
from ...agents.llm_agent import LlmAgent
Expand Down Expand Up @@ -115,24 +109,24 @@ def _finalize_model_response_event(
async def _resolve_toolset_auth(
invocation_context: InvocationContext,
agent: LlmAgent,
) -> AsyncGenerator[Event, None]:
) -> None:
"""Resolves authentication for toolsets before tool listing.

For each toolset with auth configured via get_auth_config():
- If credential is available, populate auth_config.exchanged_auth_credential
- If credential is not available, yield auth request event and interrupt
- If credential is not available, log and continue — auth will be handled
on demand by ToolAuthHandler when a tool is actually invoked.

This avoids triggering OAuth redirects on every agent invocation,
including messages that don't require any tool calls.

Args:
invocation_context: The invocation context.
agent: The LLM agent.

Yields:
Auth request events if any toolset needs authentication.
"""
if not agent.tools:
return

pending_auth_requests: dict[str, AuthConfig] = {}
callback_context = CallbackContext(invocation_context)

for tool_union in agent.tools:
Expand Down Expand Up @@ -161,30 +155,11 @@ async def _resolve_toolset_auth(
# Populate in-place for toolset to use in get_tools()
auth_config.exchanged_auth_credential = credential
else:
# Need auth - will interrupt
toolset_id = (
f'{TOOLSET_AUTH_CREDENTIAL_ID_PREFIX}{type(tool_union).__name__}'
logger.debug(
'No credential found for toolset %s; deferring auth to tool'
' invocation.',
type(tool_union).__name__,
)
pending_auth_requests[toolset_id] = auth_config

if not pending_auth_requests:
return

# Build auth requests dict with generated auth requests
auth_requests = {
credential_id: AuthHandler(auth_config).generate_auth_request()
for credential_id, auth_config in pending_auth_requests.items()
}

# Yield event with auth requests using the shared helper
yield build_auth_request_event(
invocation_context,
auth_requests,
author=agent.name,
)

# Interrupt invocation
invocation_context.end_invocation = True


async def _handle_before_model_callback(
Expand Down Expand Up @@ -916,14 +891,7 @@ async def _preprocess_async(

# Resolve toolset authentication before tool listing.
# This ensures credentials are ready before get_tools() is called.
async with Aclosing(
self._resolve_toolset_auth(invocation_context, agent)
) as agen:
async for event in agen:
yield event

if invocation_context.end_invocation:
return
await _resolve_toolset_auth(invocation_context, agent)

# Run processors for tools.
await _process_agent_tools(invocation_context, llm_request)
Expand Down Expand Up @@ -1273,17 +1241,6 @@ def _finalize_model_response_event(
llm_request, llm_response, model_response_event
)

async def _resolve_toolset_auth(
self,
invocation_context: InvocationContext,
agent: LlmAgent,
) -> AsyncGenerator[Event, None]:
async with Aclosing(
_resolve_toolset_auth(invocation_context, agent)
) as agen:
async for event in agen:
yield event

async def _handle_before_model_callback(
self,
invocation_context: InvocationContext,
Expand Down
149 changes: 55 additions & 94 deletions tests/unittests/auth/test_toolset_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

from typing import Optional
from unittest.mock import AsyncMock
from unittest.mock import MagicMock
from unittest.mock import Mock
from unittest.mock import patch

Expand All @@ -28,12 +27,8 @@
from google.adk.auth.auth_credential import AuthCredential
from google.adk.auth.auth_credential import AuthCredentialTypes
from google.adk.auth.auth_credential import OAuth2Auth
from google.adk.auth.auth_preprocessor import TOOLSET_AUTH_CREDENTIAL_ID_PREFIX
from google.adk.auth.auth_tool import AuthConfig
from google.adk.auth.auth_tool import AuthToolArguments
from google.adk.flows.llm_flows.base_llm_flow import _resolve_toolset_auth
from google.adk.flows.llm_flows.base_llm_flow import BaseLlmFlow
from google.adk.flows.llm_flows.base_llm_flow import TOOLSET_AUTH_CREDENTIAL_ID_PREFIX as FLOW_PREFIX
from google.adk.flows.llm_flows.functions import build_auth_request_event
from google.adk.flows.llm_flows.functions import REQUEST_EUC_FUNCTION_CALL_NAME
from google.adk.tools.base_tool import BaseTool
Expand Down Expand Up @@ -85,15 +80,6 @@ def create_oauth2_auth_config() -> AuthConfig:
)


class TestToolsetAuthPrefixConstant:
"""Test that prefix constants are consistent."""

def test_prefix_constants_match(self):
"""Ensure auth_preprocessor and base_llm_flow use the same prefix."""
assert TOOLSET_AUTH_CREDENTIAL_ID_PREFIX == FLOW_PREFIX
assert TOOLSET_AUTH_CREDENTIAL_ID_PREFIX == "_adk_toolset_auth_"


class TestResolveToolsetAuth:
"""Tests for _resolve_toolset_auth method in BaseLlmFlow."""

Expand Down Expand Up @@ -121,19 +107,12 @@ def mock_agent(self):
return agent

@pytest.mark.asyncio
async def test_no_tools_returns_no_events(
self, mock_invocation_context, mock_agent
):
"""Test that no events are yielded when agent has no tools."""
async def test_no_tools_completes(self, mock_invocation_context, mock_agent):
"""Test that resolve completes without side effects when agent has no tools."""
mock_agent.tools = []

events = []
async for event in _resolve_toolset_auth(
mock_invocation_context, mock_agent
):
events.append(event)
await _resolve_toolset_auth(mock_invocation_context, mock_agent)

assert len(events) == 0
assert mock_invocation_context.end_invocation is False

@pytest.mark.asyncio
Expand All @@ -144,13 +123,8 @@ async def test_toolset_without_auth_config_skipped(
toolset = MockToolset(auth_config=None)
mock_agent.tools = [toolset]

events = []
async for event in _resolve_toolset_auth(
mock_invocation_context, mock_agent
):
events.append(event)
await _resolve_toolset_auth(mock_invocation_context, mock_agent)

assert len(events) == 0
assert mock_invocation_context.end_invocation is False

@pytest.mark.asyncio
Expand All @@ -162,7 +136,6 @@ async def test_toolset_with_credential_available_populates_config(
toolset = MockToolset(auth_config=auth_config)
mock_agent.tools = [toolset]

# Mock CredentialManager to return a credential
mock_credential = AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2,
oauth2=OAuth2Auth(access_token="test-token"),
Expand All @@ -175,23 +148,21 @@ async def test_toolset_with_credential_available_populates_config(
mock_manager.get_auth_credential = AsyncMock(return_value=mock_credential)
MockCredentialManager.return_value = mock_manager

events = []
async for event in _resolve_toolset_auth(
mock_invocation_context, mock_agent
):
events.append(event)
await _resolve_toolset_auth(mock_invocation_context, mock_agent)

# No auth request events - credential was available
assert len(events) == 0
assert mock_invocation_context.end_invocation is False
# Credential should be populated in auth_config
assert auth_config.exchanged_auth_credential == mock_credential

@pytest.mark.asyncio
async def test_toolset_without_credential_yields_auth_event(
async def test_toolset_without_credential_defers_auth(
self, mock_invocation_context, mock_agent
):
"""Test that auth request event is yielded when credential not available."""
"""Test that auth is deferred when credential is not available.

When no credential is found, _resolve_toolset_auth should not interrupt
the invocation. Auth will be handled on demand by ToolAuthHandler when
a tool is actually invoked.
"""
auth_config = create_oauth2_auth_config()
toolset = MockToolset(auth_config=auth_config)
mock_agent.tools = [toolset]
Expand All @@ -203,37 +174,16 @@ async def test_toolset_without_credential_yields_auth_event(
mock_manager.get_auth_credential = AsyncMock(return_value=None)
MockCredentialManager.return_value = mock_manager

events = []
async for event in _resolve_toolset_auth(
mock_invocation_context, mock_agent
):
events.append(event)

# Should yield one auth request event
assert len(events) == 1
assert mock_invocation_context.end_invocation is True
await _resolve_toolset_auth(mock_invocation_context, mock_agent)

# Check event structure
event = events[0]
assert event.invocation_id == "test-invocation-id"
assert event.author == "test-agent"
assert event.content is not None
assert len(event.content.parts) == 1

# Check function call
fc = event.content.parts[0].function_call
assert fc.name == REQUEST_EUC_FUNCTION_CALL_NAME
# The args use camelCase aliases from the pydantic model
assert fc.args["functionCallId"].startswith(
TOOLSET_AUTH_CREDENTIAL_ID_PREFIX
)
assert "MockToolset" in fc.args["functionCallId"]
assert mock_invocation_context.end_invocation is False
assert auth_config.exchanged_auth_credential is None

@pytest.mark.asyncio
async def test_multiple_toolsets_needing_auth(
async def test_multiple_toolsets_without_credentials_defers_auth(
self, mock_invocation_context, mock_agent
):
"""Test that multiple toolsets needing auth yield multiple function calls."""
"""Test that multiple toolsets without credentials do not interrupt."""
auth_config1 = create_oauth2_auth_config()
auth_config2 = create_oauth2_auth_config()
toolset1 = MockToolset(auth_config=auth_config1)
Expand All @@ -247,40 +197,51 @@ async def test_multiple_toolsets_needing_auth(
mock_manager.get_auth_credential = AsyncMock(return_value=None)
MockCredentialManager.return_value = mock_manager

events = []
async for event in _resolve_toolset_auth(
mock_invocation_context, mock_agent
):
events.append(event)
await _resolve_toolset_auth(mock_invocation_context, mock_agent)

assert mock_invocation_context.end_invocation is False

# Should yield one event with multiple function calls
# But since both toolsets have same class name, they'll have same ID
# and only one will be in pending_auth_requests (dict overwrites)
assert len(events) == 1
assert mock_invocation_context.end_invocation is True
@pytest.mark.asyncio
async def test_mixed_toolsets_populates_available_credentials(
self, mock_invocation_context, mock_agent
):
"""Test that credentials are populated when available, without interrupt.

When one toolset has credentials and another does not, the available
credential should be populated while the missing one is deferred.
"""
auth_config_with_cred = create_oauth2_auth_config()
auth_config_without_cred = create_oauth2_auth_config()
toolset_with_cred = MockToolset(auth_config=auth_config_with_cred)
toolset_without_cred = MockToolset(auth_config=auth_config_without_cred)
mock_agent.tools = [toolset_with_cred, toolset_without_cred]

class TestAuthPreprocessorToolsetAuthSkip:
"""Tests for auth preprocessor skipping toolset auth."""
mock_credential = AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2,
oauth2=OAuth2Auth(access_token="test-token"),
)

def test_toolset_auth_prefix_skipped(self):
"""Test that function calls with toolset auth prefix are skipped."""
from google.adk.auth.auth_preprocessor import TOOLSET_AUTH_CREDENTIAL_ID_PREFIX
call_count = 0

# Verify the prefix is correct
assert TOOLSET_AUTH_CREDENTIAL_ID_PREFIX == "_adk_toolset_auth_"
async def side_effect(*args, **kwargs):
nonlocal call_count
call_count += 1
if call_count == 1:
return mock_credential
return None

# Test that a function_call_id starting with this prefix would be skipped
toolset_function_call_id = f"{TOOLSET_AUTH_CREDENTIAL_ID_PREFIX}McpToolset"
assert toolset_function_call_id.startswith(
TOOLSET_AUTH_CREDENTIAL_ID_PREFIX
)
with patch(
"google.adk.flows.llm_flows.base_llm_flow.CredentialManager"
) as MockCredentialManager:
mock_manager = AsyncMock()
mock_manager.get_auth_credential = AsyncMock(side_effect=side_effect)
MockCredentialManager.return_value = mock_manager

# Regular tool auth function_call_id should NOT start with prefix
regular_function_call_id = "call_123"
assert not regular_function_call_id.startswith(
TOOLSET_AUTH_CREDENTIAL_ID_PREFIX
)
await _resolve_toolset_auth(mock_invocation_context, mock_agent)

assert mock_invocation_context.end_invocation is False
assert auth_config_with_cred.exchanged_auth_credential == mock_credential
assert auth_config_without_cred.exchanged_auth_credential is None


class TestCallbackContextGetAuthResponse:
Expand Down