From f286222db99b6151c4e28c3a534cc048cec9430e Mon Sep 17 00:00:00 2001 From: Yufeng He <40085740+he-yufeng@users.noreply.github.com> Date: Sat, 6 Jun 2026 01:18:25 +0800 Subject: [PATCH] Expose Gemini Live API client override --- src/google/adk/models/google_llm.py | 39 +++++++++++- tests/unittests/models/test_google_llm.py | 78 ++++++++++++++++++----- 2 files changed, 99 insertions(+), 18 deletions(-) diff --git a/src/google/adk/models/google_llm.py b/src/google/adk/models/google_llm.py index d5923ffd25..35153ee8c5 100644 --- a/src/google/adk/models/google_llm.py +++ b/src/google/adk/models/google_llm.py @@ -110,6 +110,19 @@ def api_client(self) -> Client: Use ``@property`` instead of ``@cached_property`` if you hit asyncio lock contention in multithreaded code. + + Customizing the Live API Client: + The Live API path uses its own client. To set Live API-only options, + subclass ``Gemini`` and override the ``live_api_client`` property:: + + from functools import cached_property + from google.adk.models import Gemini + from google.genai import Client + + class RegionalLiveGemini(Gemini): + @cached_property + def live_api_client(self) -> Client: + return Client(vertexai=True, location="europe-central2") """ model: str = 'gemini-2.5-flash' @@ -376,8 +389,7 @@ def _live_api_version(self) -> str: # use v1alpha for using API KEY from Google AI Studio return 'v1alpha' - @cached_property - def _live_api_client(self) -> Client: + def _build_live_api_client(self) -> Client: from google.genai import Client base_url, _ = self._base_url_and_api_version @@ -394,6 +406,27 @@ def _live_api_client(self) -> Client: return Client(**kwargs) + def _uses_legacy_live_api_client_override(self) -> bool: + for cls in type(self).__mro__: + if '_live_api_client' in cls.__dict__: + return cls is not Gemini + return False + + @cached_property + def live_api_client(self) -> Client: + """Provides the Live API client. + + Subclasses can override this property to customize the ``Client`` used by + Live API connections. + """ + if self._uses_legacy_live_api_client_override(): + return self._live_api_client + return self._build_live_api_client() + + @cached_property + def _live_api_client(self) -> Client: + return self.live_api_client + @contextlib.asynccontextmanager async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection: """Connects to the Gemini model and returns an llm connection. @@ -455,7 +488,7 @@ async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection: llm_request.live_connect_config.tools = llm_request.config.tools logger.debug('Connecting to live with llm_request:%s', llm_request) logger.debug('Live connect config: %s', llm_request.live_connect_config) - async with self._live_api_client.aio.live.connect( + async with self.live_api_client.aio.live.connect( model=llm_request.model, config=llm_request.live_connect_config ) as live_session: yield GeminiLlmConnection( diff --git a/tests/unittests/models/test_google_llm.py b/tests/unittests/models/test_google_llm.py index 7f7ce39895..d970fd5993 100644 --- a/tests/unittests/models/test_google_llm.py +++ b/tests/unittests/models/test_google_llm.py @@ -174,10 +174,10 @@ def test_gemini_live_api_client_creation_with_projects_prefix(): model="projects/test-project/locations/test-location/publishers/google/models/gemini-2.5-pro" ) with mock.patch("google.genai.Client", autospec=True) as mock_client: - _ = model._live_api_client + _ = model.live_api_client assert mock_client.call_count == 2 - # Second call is for _live_api_client + # Second call is for live_api_client. _, kwargs = mock_client.call_args_list[1] assert kwargs["vertexai"] is True @@ -732,25 +732,40 @@ def test_live_api_version_gemini_api(gemini_llm): assert gemini_llm._live_api_version == "v1alpha" -def test_live_api_client_uses_api_version_from_google_base_url(): +@pytest.mark.parametrize( + "base_url, expected_base_url", + [ + ( + "https://generativelanguage.googleapis.com/v1alpha", + "https://generativelanguage.googleapis.com/", + ), + ( + "https://generativelanguage.mtls.googleapis.com/v1alpha", + "https://generativelanguage.mtls.googleapis.com/", + ), + ], +) +def test_live_api_client_uses_api_version_from_google_base_url( + base_url, expected_base_url +): gemini_llm = Gemini( model="gemini-2.5-flash", - base_url="https://generativelanguage.googleapis.com/v1alpha", + base_url=base_url, ) - client = gemini_llm._live_api_client + client = gemini_llm.live_api_client http_options = client._api_client._http_options - assert http_options.base_url == "https://generativelanguage.googleapis.com/" + assert http_options.base_url == expected_base_url assert http_options.api_version == "v1alpha" def test_live_api_client_properties(gemini_llm): - """Test that _live_api_client is properly configured with tracking headers and API version.""" + """Test that live_api_client is properly configured with tracking headers and API version.""" with mock.patch.object( gemini_llm, "_api_backend", GoogleLLMVariant.VERTEX_AI ): - client = gemini_llm._live_api_client + client = gemini_llm.live_api_client # Verify that the client has the correct headers and API version http_options = client._api_client._http_options @@ -763,6 +778,39 @@ def test_live_api_client_properties(gemini_llm): assert value in http_options.headers[key] +def test_live_api_client_private_alias(gemini_llm): + assert gemini_llm._live_api_client is gemini_llm.live_api_client + + +def test_live_api_client_public_override(): + custom_client = mock.MagicMock() + + class CustomGemini(Gemini): + + @property + def live_api_client(self): + return custom_client + + gemini_llm = CustomGemini(model="gemini-2.5-flash") + + assert gemini_llm.live_api_client is custom_client + assert gemini_llm._live_api_client is custom_client + + +def test_live_api_client_legacy_private_override(): + custom_client = mock.MagicMock() + + class CustomGemini(Gemini): + + @property + def _live_api_client(self): + return custom_client + + gemini_llm = CustomGemini(model="gemini-2.5-flash") + + assert gemini_llm.live_api_client is custom_client + + @pytest.mark.asyncio async def test_connect_with_custom_headers(gemini_llm, llm_request): """Test that connect method updates tracking headers and API version when custom headers are provided.""" @@ -774,8 +822,8 @@ async def test_connect_with_custom_headers(gemini_llm, llm_request): mock_live_session = mock.AsyncMock() - # Mock the _live_api_client to return a mock client - with mock.patch.object(gemini_llm, "_live_api_client") as mock_live_client: + # Mock the live_api_client to return a mock client + with mock.patch.object(gemini_llm, "live_api_client") as mock_live_client: # Create a mock context manager class MockLiveConnect: @@ -817,7 +865,7 @@ async def test_connect_without_custom_headers(gemini_llm, llm_request): mock_live_session = mock.AsyncMock() - with mock.patch.object(gemini_llm, "_live_api_client") as mock_live_client: + with mock.patch.object(gemini_llm, "live_api_client") as mock_live_client: class MockLiveConnect: @@ -2099,7 +2147,7 @@ async def test_connect_uses_gemini_speech_config_when_request_is_none( mock_live_session = mock.AsyncMock() - with mock.patch.object(gemini_llm, "_live_api_client") as mock_live_client: + with mock.patch.object(gemini_llm, "live_api_client") as mock_live_client: class MockLiveConnect: @@ -2147,7 +2195,7 @@ async def test_connect_uses_request_speech_config_when_gemini_is_none( mock_live_session = mock.AsyncMock() - with mock.patch.object(gemini_llm, "_live_api_client") as mock_live_client: + with mock.patch.object(gemini_llm, "live_api_client") as mock_live_client: class MockLiveConnect: @@ -2201,7 +2249,7 @@ async def test_connect_request_gemini_config_overrides_speech_config( mock_live_session = mock.AsyncMock() - with mock.patch.object(gemini_llm, "_live_api_client") as mock_live_client: + with mock.patch.object(gemini_llm, "live_api_client") as mock_live_client: class MockLiveConnect: @@ -2242,7 +2290,7 @@ async def test_connect_speech_config_remains_none_when_both_are_none( mock_live_session = mock.AsyncMock() - with mock.patch.object(gemini_llm, "_live_api_client") as mock_live_client: + with mock.patch.object(gemini_llm, "live_api_client") as mock_live_client: class MockLiveConnect: