From b4a70d6da62c4421c63c0e1e6b974b61324ddf57 Mon Sep 17 00:00:00 2001 From: Sara Robinson Date: Mon, 8 Jun 2026 12:59:38 -0700 Subject: [PATCH] feat: Add Agent Platform MCP support to async generate_content_stream PiperOrigin-RevId: 928724878 --- google/genai/models.py | 330 +++++++++++++++++++++++------------------ 1 file changed, 189 insertions(+), 141 deletions(-) diff --git a/google/genai/models.py b/google/genai/models.py index 547132546..daae9b804 100644 --- a/google/genai/models.py +++ b/google/genai/models.py @@ -8759,169 +8759,217 @@ async def generate_content_stream( # The image shows a flat lay arrangement of freshly baked blueberry # scones. """ - if getattr( - self._api_client, 'vertexai', False - ) and _extra_utils.has_agent_platform_mcp_servers(config): - raise NotImplementedError( - 'MCP servers are not yet supported for streaming in the Agent' - ' Platform API.' - ) + if not config: + parsed_config = None + elif isinstance(config, dict): + parsed_config = types.GenerateContentConfig(**config) + else: + parsed_config = config.model_copy(deep=True) - # Retrieve and cache any MCP sessions if provided. incompatible_tools_indexes = ( _extra_utils.find_afc_incompatible_tool_indexes( - config, - is_agent_platform=getattr(self._api_client, 'vertexai', False), - ) - ) - # Retrieve and cache any MCP sessions if provided. - parsed_config, mcp_to_genai_tool_adapters = ( - await _extra_utils.parse_config_for_mcp_sessions( - config, + parsed_config, is_agent_platform=getattr(self._api_client, 'vertexai', False), ) ) - if _extra_utils.should_disable_afc(parsed_config): - response = await self._generate_content_stream( - model=model, contents=contents, config=parsed_config - ) - async def base_async_generator(model, contents, config): # type: ignore[no-untyped-def] - async for chunk in response: # type: ignore[attr-defined] - yield chunk - - return base_async_generator(model, contents, parsed_config) # type: ignore[no-untyped-call, no-any-return] + async def stream_generator(): + # Use AsyncExitStack to keep MCP connections alive across the entire stream + async with contextlib.AsyncExitStack() as stack: + current_config = parsed_config - if incompatible_tools_indexes: - original_tools_length = 0 - if isinstance(config, types.GenerateContentConfig): - if config.tools: - original_tools_length = len(config.tools) - elif isinstance(config, dict): - tools = config.get('tools', []) - if tools: - original_tools_length = len(tools) - if len(incompatible_tools_indexes) != original_tools_length: - indices_str = ', '.join(map(str, incompatible_tools_indexes)) - logger.warning( - 'Tools at indices [%s] are not compatible with automatic function ' - 'calling (AFC). AFC is disabled. If AFC is intended, please ' - 'include python callables in the tool list, and do not include ' - 'function declaration and MCP server in the tool list.', - indices_str, + if ( + self._api_client.vertexai + and _extra_utils.has_agent_platform_mcp_servers(current_config) + and current_config is not None + ): + new_tools: list[Any] = [] + if current_config.tools: + for tool in current_config.tools: + if isinstance(tool, types.Tool) and tool.mcp_servers: + if ( + tool.function_declarations + or tool.google_search + or tool.retrieval + or tool.google_search_retrieval + or tool.code_execution + ): + tool_copy = tool.model_copy(update={'mcp_servers': None}) + new_tools.append(tool_copy) + + for server in tool.mcp_servers: + if ( + getattr(server, 'streamable_http_transport', None) + is not None + ): + raise ValueError( + "The 'streamable_http_transport' parameter is only" + ' supported in Gemini Developer API mode, not in Gemini' + ' Enterprise Agent Platform mode.' + ) + + if server.name is not None: + session = await stack.enter_async_context( + _mcp_utils._connect_agent_platform_mcp( + self._api_client, server.name + ) + ) + new_tools.append(session) + else: + raise ValueError( + "Agent Platform MCP servers require a 'name' field." + ) + else: + new_tools.append(tool) + current_config.tools = new_tools + + final_parsed_config, mcp_to_genai_tool_adapters = ( + await _extra_utils.parse_config_for_mcp_sessions( + current_config, + is_agent_platform=getattr(self._api_client, 'vertexai', False), + ) ) - response = await self._generate_content_stream( - model=model, contents=contents, config=parsed_config - ) - async def base_async_generator(model, contents, config): # type: ignore[no-untyped-def] - async for chunk in response: # type: ignore[attr-defined] - yield chunk - - return base_async_generator(model, contents, parsed_config) # type: ignore[no-untyped-call, no-any-return] + if _extra_utils.should_disable_afc(final_parsed_config): + response = await self._generate_content_stream( + model=model, contents=contents, config=final_parsed_config + ) + async for chunk in response: + yield chunk + return + + if incompatible_tools_indexes: + original_tools_length = 0 + if isinstance(config, types.GenerateContentConfig): + if config.tools: + original_tools_length = len(config.tools) + elif isinstance(config, dict): + tools = config.get('tools', []) + if tools: + original_tools_length = len(tools) + if len(incompatible_tools_indexes) != original_tools_length: + indices_str = ', '.join(map(str, incompatible_tools_indexes)) + logger.warning( + 'Tools at indices [%s] are not compatible with automatic' + ' function calling (AFC). AFC is disabled. If AFC is intended,' + ' please include python callables in the tool list, and do not' + ' include function declaration and MCP server in the tool' + ' list.', + indices_str, + ) + response = await self._generate_content_stream( + model=model, contents=contents, config=final_parsed_config + ) + async for chunk in response: + yield chunk + return - # With tool compatibility confirmed, validate that the configuration are - # compatible with each other and raise an error if invalid. - _extra_utils.raise_error_for_afc_incompatible_config(parsed_config) + _extra_utils.raise_error_for_afc_incompatible_config( + final_parsed_config + ) - async def async_generator(model, contents, config): # type: ignore[no-untyped-def] - remaining_remote_calls_afc = _extra_utils.get_max_remote_calls_afc(config) - logger.info( - f'AFC is enabled with max remote calls: {remaining_remote_calls_afc}.' - ) - automatic_function_calling_history: list[types.Content] = [] - func_response_parts = None - chunk = None - i = 0 - while remaining_remote_calls_afc > 0: - i += 1 - response = await self._generate_content_stream( - model=model, contents=contents, config=config + remaining_remote_calls_afc = _extra_utils.get_max_remote_calls_afc( + final_parsed_config ) - # TODO: b/453739108 - make AFC logic more robust like the other 3 methods. - if i > 1: - logger.info(f'AFC remote call {i} is done.') - remaining_remote_calls_afc -= 1 - if i > 1 and remaining_remote_calls_afc == 0: - logger.info( - 'Reached max remote calls for automatic function calling.' + logger.info( + 'AFC is enabled with max remote calls:' + f' {remaining_remote_calls_afc}.' + ) + automatic_function_calling_history: list[types.Content] = [] + func_response_parts = None + chunk = None + i = 0 + loop_contents = contents + + while remaining_remote_calls_afc > 0: + i += 1 + response = await self._generate_content_stream( + model=model, contents=loop_contents, config=final_parsed_config ) + if i > 1: + logger.info(f'AFC remote call {i} is done.') + remaining_remote_calls_afc -= 1 + if i > 1 and remaining_remote_calls_afc == 0: + logger.info( + 'Reached max remote calls for automatic function calling.' + ) - function_map = _extra_utils.get_function_map( - config, mcp_to_genai_tool_adapters, is_caller_method_async=True - ) + function_map = _extra_utils.get_function_map( + final_parsed_config, + mcp_to_genai_tool_adapters, + is_caller_method_async=True, + ) - if i == 1: - # First request gets a function call. - # Then get function response parts. - # Yield chunks only if there's no function response parts. - async for chunk in response: # type: ignore[attr-defined] - if not function_map: - contents = _extra_utils.append_chunk_contents(contents, chunk) - yield chunk - else: - if ( - not chunk.candidates - or not chunk.candidates[0].content - or not chunk.candidates[0].content.parts - ): - break - func_response_parts = ( - await _extra_utils.get_function_response_parts_async( - chunk, function_map + if i == 1: + async for chunk in response: + if not function_map: + loop_contents = _extra_utils.append_chunk_contents( + loop_contents, chunk + ) + yield chunk + else: + if ( + not chunk.candidates + or not chunk.candidates[0].content + or not chunk.candidates[0].content.parts + ): + break + func_response_parts = ( + await _extra_utils.get_function_response_parts_async( + chunk, function_map + ) + ) + if not func_response_parts: + loop_contents = _extra_utils.append_chunk_contents( + loop_contents, chunk ) + yield chunk + else: + async for chunk in response: + if _extra_utils.should_append_afc_history(final_parsed_config): + chunk.automatic_function_calling_history = ( + automatic_function_calling_history + ) + loop_contents = _extra_utils.append_chunk_contents( + loop_contents, chunk ) - if not func_response_parts: - contents = _extra_utils.append_chunk_contents(contents, chunk) - yield chunk - - else: - # Second request and beyond, yield chunks. - async for chunk in response: # type: ignore[attr-defined] + yield chunk + if ( + chunk is None + or not chunk.candidates + or not chunk.candidates[0].content + or not chunk.candidates[0].content.parts + ): + break + func_response_parts = ( + await _extra_utils.get_function_response_parts_async( + chunk, function_map + ) + ) - if _extra_utils.should_append_afc_history(config): - chunk.automatic_function_calling_history = ( - automatic_function_calling_history - ) - contents = _extra_utils.append_chunk_contents(contents, chunk) - yield chunk - if ( - chunk is None - or not chunk.candidates - or not chunk.candidates[0].content - or not chunk.candidates[0].content.parts - ): + if not function_map or not func_response_parts: break - func_response_parts = ( - await _extra_utils.get_function_response_parts_async( - chunk, function_map - ) - ) - if not function_map: - break - - if not func_response_parts: - break - if chunk is None: - continue - # Append function response parts to contents for the next request. - func_call_content = chunk.candidates[0].content - func_response_content = types.Content( - role='user', - parts=func_response_parts, - ) - contents = t.t_contents(contents) - if not automatic_function_calling_history: - automatic_function_calling_history.extend(contents) - if isinstance(contents, list) and func_call_content is not None: - contents.append(func_call_content) - contents.append(func_response_content) - if func_call_content is not None: - automatic_function_calling_history.append(func_call_content) - automatic_function_calling_history.append(func_response_content) + if chunk is None: + continue - return async_generator(model, contents, parsed_config) # type: ignore[no-untyped-call, no-any-return] + # Append function response parts to contents for the next request. + func_call_content = chunk.candidates[0].content + func_response_content = types.Content( + role='user', + parts=func_response_parts, + ) + loop_contents = t.t_contents(loop_contents) + if not automatic_function_calling_history: + automatic_function_calling_history.extend(loop_contents) + if isinstance(loop_contents, list) and func_call_content is not None: + loop_contents.append(func_call_content) + loop_contents.append(func_response_content) + if func_call_content is not None: + automatic_function_calling_history.append(func_call_content) + automatic_function_calling_history.append(func_response_content) + + return stream_generator() async def edit_image( self,