Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/google/adk/auth/auth_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ async def run_async(
agent = invocation_context.agent
if not hasattr(agent, 'canonical_tools'):
return
events = invocation_context.session.events
events = invocation_context._get_events(current_branch=True)
if not events:
return

Expand Down
40 changes: 36 additions & 4 deletions tests/unittests/auth/test_auth_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def mock_invocation_context(self, mock_llm_agent, mock_session):
context = Mock(spec=InvocationContext)
context.agent = mock_llm_agent
context.session = mock_session
context._get_events.side_effect = lambda **_: context.session.events
return context

@pytest.fixture
Expand Down Expand Up @@ -165,8 +166,7 @@ async def test_non_llm_agent_returns_early(
):
"""Test that non-LLM agents return early."""
mock_context = Mock(spec=InvocationContext)
mock_context.agent = Mock()
mock_context.agent.__class__.__name__ = 'BaseAgent'
mock_context.agent = object()
mock_context.session = mock_session

result = []
Expand Down Expand Up @@ -273,6 +273,38 @@ async def test_last_event_no_auth_responses_returns_early(

assert result == []

@pytest.mark.asyncio
@patch('google.adk.auth.auth_preprocessor.AuthHandler')
@patch('google.adk.auth.auth_tool.AuthConfig.model_validate')
async def test_ignores_auth_responses_outside_current_branch(
self,
mock_auth_config_validate,
mock_auth_handler_class,
processor,
mock_invocation_context,
mock_llm_request,
mock_user_event_with_auth_response,
):
"""Test auth responses hidden by branch filtering are ignored."""
mock_invocation_context.session.events = [
mock_user_event_with_auth_response
]
mock_invocation_context._get_events.side_effect = None
mock_invocation_context._get_events.return_value = []

result = []
async for event in processor.run_async(
mock_invocation_context, mock_llm_request
):
result.append(event)

mock_invocation_context._get_events.assert_called_once_with(
current_branch=True
)
mock_auth_config_validate.assert_not_called()
mock_auth_handler_class.assert_not_called()
assert result == []

@pytest.mark.asyncio
@patch('google.adk.auth.auth_preprocessor.AuthHandler')
@patch('google.adk.auth.auth_tool.AuthConfig.model_validate')
Expand Down Expand Up @@ -534,9 +566,9 @@ async def test_isinstance_check_for_llm_agent(
"""Test that isinstance check works correctly for LlmAgent."""
# This test ensures the isinstance check work as expected

# Create a mock that fails isinstance check
# Create an object that does not expose canonical_tools.
mock_context = Mock(spec=InvocationContext)
mock_context.agent = Mock() # This will fail isinstance(agent, LlmAgent)
mock_context.agent = object()
mock_context.session = mock_session

result = []
Expand Down