Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
330 changes: 189 additions & 141 deletions google/genai/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8759,169 +8759,217 @@ async def generate_content_stream(
# The image shows a flat lay arrangement of freshly baked blueberry
# scones.
"""
if getattr(
self._api_client, 'vertexai', False
) and _extra_utils.has_agent_platform_mcp_servers(config):
raise NotImplementedError(
'MCP servers are not yet supported for streaming in the Agent'
' Platform API.'
)
if not config:
parsed_config = None
elif isinstance(config, dict):
parsed_config = types.GenerateContentConfig(**config)
else:
parsed_config = config.model_copy(deep=True)

# Find the indexes of tools that are not compatible with automatic function calling (AFC).
incompatible_tools_indexes = (
_extra_utils.find_afc_incompatible_tool_indexes(
config,
is_agent_platform=getattr(self._api_client, 'vertexai', False),
)
)
# Retrieve and cache any MCP sessions if provided.
parsed_config, mcp_to_genai_tool_adapters = (
await _extra_utils.parse_config_for_mcp_sessions(
config,
parsed_config,
is_agent_platform=getattr(self._api_client, 'vertexai', False),
)
)
if _extra_utils.should_disable_afc(parsed_config):
response = await self._generate_content_stream(
model=model, contents=contents, config=parsed_config
)

async def base_async_generator(model, contents, config): # type: ignore[no-untyped-def]
async for chunk in response: # type: ignore[attr-defined]
yield chunk

return base_async_generator(model, contents, parsed_config) # type: ignore[no-untyped-call, no-any-return]
async def stream_generator():
# Use AsyncExitStack to keep MCP connections alive across the entire stream
async with contextlib.AsyncExitStack() as stack:
current_config = parsed_config

if incompatible_tools_indexes:
original_tools_length = 0
if isinstance(config, types.GenerateContentConfig):
if config.tools:
original_tools_length = len(config.tools)
elif isinstance(config, dict):
tools = config.get('tools', [])
if tools:
original_tools_length = len(tools)
if len(incompatible_tools_indexes) != original_tools_length:
indices_str = ', '.join(map(str, incompatible_tools_indexes))
logger.warning(
'Tools at indices [%s] are not compatible with automatic function '
'calling (AFC). AFC is disabled. If AFC is intended, please '
'include python callables in the tool list, and do not include '
'function declaration and MCP server in the tool list.',
indices_str,
if (
self._api_client.vertexai
and _extra_utils.has_agent_platform_mcp_servers(current_config)
and current_config is not None
):
new_tools: list[Any] = []
if current_config.tools:
for tool in current_config.tools:
if isinstance(tool, types.Tool) and tool.mcp_servers:
if (
tool.function_declarations
or tool.google_search
or tool.retrieval
or tool.google_search_retrieval
or tool.code_execution
):
tool_copy = tool.model_copy(update={'mcp_servers': None})
new_tools.append(tool_copy)

for server in tool.mcp_servers:
if (
getattr(server, 'streamable_http_transport', None)
is not None
):
raise ValueError(
"The 'streamable_http_transport' parameter is only"
' supported in Gemini Developer API mode, not in Gemini'
' Enterprise Agent Platform mode.'
)

if server.name is not None:
session = await stack.enter_async_context(
_mcp_utils._connect_agent_platform_mcp(
self._api_client, server.name
)
)
new_tools.append(session)
else:
raise ValueError(
"Agent Platform MCP servers require a 'name' field."
)
else:
new_tools.append(tool)
current_config.tools = new_tools

final_parsed_config, mcp_to_genai_tool_adapters = (
await _extra_utils.parse_config_for_mcp_sessions(
current_config,
is_agent_platform=getattr(self._api_client, 'vertexai', False),
)
)
response = await self._generate_content_stream(
model=model, contents=contents, config=parsed_config
)

async def base_async_generator(model, contents, config): # type: ignore[no-untyped-def]
async for chunk in response: # type: ignore[attr-defined]
yield chunk

return base_async_generator(model, contents, parsed_config) # type: ignore[no-untyped-call, no-any-return]
if _extra_utils.should_disable_afc(final_parsed_config):
response = await self._generate_content_stream(
model=model, contents=contents, config=final_parsed_config
)
async for chunk in response:
yield chunk
return

if incompatible_tools_indexes:
original_tools_length = 0
if isinstance(config, types.GenerateContentConfig):
if config.tools:
original_tools_length = len(config.tools)
elif isinstance(config, dict):
tools = config.get('tools', [])
if tools:
original_tools_length = len(tools)
if len(incompatible_tools_indexes) != original_tools_length:
indices_str = ', '.join(map(str, incompatible_tools_indexes))
logger.warning(
'Tools at indices [%s] are not compatible with automatic'
' function calling (AFC). AFC is disabled. If AFC is intended,'
' please include python callables in the tool list, and do not'
' include function declaration and MCP server in the tool'
' list.',
indices_str,
)
response = await self._generate_content_stream(
model=model, contents=contents, config=final_parsed_config
)
async for chunk in response:
yield chunk
return

# With tool compatibility confirmed, validate that the configuration options
# are compatible with each other and raise an error if they are not.
_extra_utils.raise_error_for_afc_incompatible_config(parsed_config)
_extra_utils.raise_error_for_afc_incompatible_config(
final_parsed_config
)

async def async_generator(model, contents, config): # type: ignore[no-untyped-def]
remaining_remote_calls_afc = _extra_utils.get_max_remote_calls_afc(config)
logger.info(
f'AFC is enabled with max remote calls: {remaining_remote_calls_afc}.'
)
automatic_function_calling_history: list[types.Content] = []
func_response_parts = None
chunk = None
i = 0
while remaining_remote_calls_afc > 0:
i += 1
response = await self._generate_content_stream(
model=model, contents=contents, config=config
remaining_remote_calls_afc = _extra_utils.get_max_remote_calls_afc(
final_parsed_config
)
# TODO: b/453739108 - make AFC logic more robust like the other 3 methods.
if i > 1:
logger.info(f'AFC remote call {i} is done.')
remaining_remote_calls_afc -= 1
if i > 1 and remaining_remote_calls_afc == 0:
logger.info(
'Reached max remote calls for automatic function calling.'
logger.info(
'AFC is enabled with max remote calls:'
f' {remaining_remote_calls_afc}.'
)
automatic_function_calling_history: list[types.Content] = []
func_response_parts = None
chunk = None
i = 0
loop_contents = contents

while remaining_remote_calls_afc > 0:
i += 1
response = await self._generate_content_stream(
model=model, contents=loop_contents, config=final_parsed_config
)
if i > 1:
logger.info(f'AFC remote call {i} is done.')
remaining_remote_calls_afc -= 1
if i > 1 and remaining_remote_calls_afc == 0:
logger.info(
'Reached max remote calls for automatic function calling.'
)

function_map = _extra_utils.get_function_map(
config, mcp_to_genai_tool_adapters, is_caller_method_async=True
)
function_map = _extra_utils.get_function_map(
final_parsed_config,
mcp_to_genai_tool_adapters,
is_caller_method_async=True,
)

if i == 1:
# First request gets a function call.
# Then get function response parts.
# Yield chunks only if there are no function response parts.
async for chunk in response: # type: ignore[attr-defined]
if not function_map:
contents = _extra_utils.append_chunk_contents(contents, chunk)
yield chunk
else:
if (
not chunk.candidates
or not chunk.candidates[0].content
or not chunk.candidates[0].content.parts
):
break
func_response_parts = (
await _extra_utils.get_function_response_parts_async(
chunk, function_map
if i == 1:
async for chunk in response:
if not function_map:
loop_contents = _extra_utils.append_chunk_contents(
loop_contents, chunk
)
yield chunk
else:
if (
not chunk.candidates
or not chunk.candidates[0].content
or not chunk.candidates[0].content.parts
):
break
func_response_parts = (
await _extra_utils.get_function_response_parts_async(
chunk, function_map
)
)
if not func_response_parts:
loop_contents = _extra_utils.append_chunk_contents(
loop_contents, chunk
)
yield chunk
else:
async for chunk in response:
if _extra_utils.should_append_afc_history(final_parsed_config):
chunk.automatic_function_calling_history = (
automatic_function_calling_history
)
loop_contents = _extra_utils.append_chunk_contents(
loop_contents, chunk
)
if not func_response_parts:
contents = _extra_utils.append_chunk_contents(contents, chunk)
yield chunk

else:
# Second request and beyond, yield chunks.
async for chunk in response: # type: ignore[attr-defined]
yield chunk
if (
chunk is None
or not chunk.candidates
or not chunk.candidates[0].content
or not chunk.candidates[0].content.parts
):
break
func_response_parts = (
await _extra_utils.get_function_response_parts_async(
chunk, function_map
)
)

if _extra_utils.should_append_afc_history(config):
chunk.automatic_function_calling_history = (
automatic_function_calling_history
)
contents = _extra_utils.append_chunk_contents(contents, chunk)
yield chunk
if (
chunk is None
or not chunk.candidates
or not chunk.candidates[0].content
or not chunk.candidates[0].content.parts
):
if not function_map or not func_response_parts:
break
func_response_parts = (
await _extra_utils.get_function_response_parts_async(
chunk, function_map
)
)
if not function_map:
break

if not func_response_parts:
break

if chunk is None:
continue
# Append function response parts to contents for the next request.
func_call_content = chunk.candidates[0].content
func_response_content = types.Content(
role='user',
parts=func_response_parts,
)
contents = t.t_contents(contents)
if not automatic_function_calling_history:
automatic_function_calling_history.extend(contents)
if isinstance(contents, list) and func_call_content is not None:
contents.append(func_call_content)
contents.append(func_response_content)
if func_call_content is not None:
automatic_function_calling_history.append(func_call_content)
automatic_function_calling_history.append(func_response_content)
if chunk is None:
continue

return async_generator(model, contents, parsed_config) # type: ignore[no-untyped-call, no-any-return]
# Append function response parts to contents for the next request.
func_call_content = chunk.candidates[0].content
func_response_content = types.Content(
role='user',
parts=func_response_parts,
)
loop_contents = t.t_contents(loop_contents)
if not automatic_function_calling_history:
automatic_function_calling_history.extend(loop_contents)
if isinstance(loop_contents, list) and func_call_content is not None:
loop_contents.append(func_call_content)
loop_contents.append(func_response_content)
if func_call_content is not None:
automatic_function_calling_history.append(func_call_content)
automatic_function_calling_history.append(func_response_content)

return stream_generator()

async def edit_image(
self,
Expand Down
Loading