Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/google/adk/events/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ def is_final_response(self) -> bool:
if self.actions.skip_summarization or self.long_running_tool_ids:
return True
return (
not self.get_function_calls()
self.turn_complete is not False
and not self.get_function_calls()
and not self.get_function_responses()
and not self.partial
and not self.has_trailing_code_execution_result()
Expand Down
54 changes: 41 additions & 13 deletions src/google/adk/models/interactions_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,33 @@
_NEW_LINE = '\n'


def _merge_text_with_overlap(existing: str, incoming: str) -> str:
"""Merge streamed text fragments while avoiding overlap duplication."""
if not existing:
return incoming
if not incoming:
return existing

max_overlap = min(len(existing), len(incoming))
for size in range(max_overlap, 0, -1):
if existing.endswith(incoming[:size]):
return existing + incoming[size:]
return existing + incoming


def _append_delta_text_part(aggregated_parts: list[types.Part], text: str):
  """Add a streamed text fragment to the aggregated parts list in place.

  If the most recent aggregated part is a text part, the fragment is merged
  into it (with overlap de-duplication); otherwise a new text part is
  appended. Empty fragments are ignored.

  Args:
    aggregated_parts: Mutable list of parts aggregated so far; mutated.
    text: Incoming text fragment to record.
  """
  if not text:
    return

  last = aggregated_parts[-1] if aggregated_parts else None
  if last is not None and last.text is not None:
    # Fold into the trailing text part rather than growing the list.
    combined = _merge_text_with_overlap(last.text, text)
    aggregated_parts[-1] = types.Part.from_text(text=combined)
  else:
    aggregated_parts.append(types.Part.from_text(text=text))


def convert_part_to_interaction_content(part: types.Part) -> Optional[dict]:
"""Convert a types.Part to an interaction content dict.

Expand Down Expand Up @@ -487,8 +514,8 @@ def convert_interaction_event_to_llm_response(
if delta_type == 'text':
text = delta.text or ''
if text:
_append_delta_text_part(aggregated_parts, text)
part = types.Part.from_text(text=text)
aggregated_parts.append(part)
return LlmResponse(
content=types.Content(role='model', parts=[part]),
partial=True,
Expand Down Expand Up @@ -539,18 +566,15 @@ def convert_interaction_event_to_llm_response(
)

elif event_type == 'content.stop':
# Content streaming finished, return aggregated content
if aggregated_parts:
return LlmResponse(
content=types.Content(role='model', parts=list(aggregated_parts)),
partial=False,
turn_complete=False,
interaction_id=interaction_id,
)
# Content.stop is a stream boundary marker.
# Final content emission happens at interaction.status_update or stream end
# to avoid duplicate final responses.
return None

elif event_type == 'interaction':
# Final interaction event with complete data
return convert_interaction_to_llm_response(event)
# We intentionally do not emit from this event in streaming mode because
# interaction outputs can duplicate already aggregated content deltas.
return None

elif event_type == 'interaction.status_update':
status = getattr(event, 'status', None)
Expand Down Expand Up @@ -992,6 +1016,7 @@ async def generate_content_via_interactions(
)

aggregated_parts: list[types.Part] = []
has_emitted_turn_complete = False
async for event in responses:
# Log the streaming event
logger.debug(build_interactions_event_log(event))
Expand All @@ -1003,10 +1028,13 @@ async def generate_content_via_interactions(
event, aggregated_parts, current_interaction_id
)
if llm_response:
if llm_response.turn_complete:
has_emitted_turn_complete = True
yield llm_response

# Final aggregated response
if aggregated_parts:
# Final aggregated response fallback if the stream never emitted a
# completion event (e.g., missing interaction.status_update).
if aggregated_parts and not has_emitted_turn_complete:
yield LlmResponse(
content=types.Content(role='model', parts=aggregated_parts),
partial=False,
Expand Down
40 changes: 40 additions & 0 deletions tests/unittests/events_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for Event model helpers."""

from google.adk.events.event import Event
from google.genai import types


def test_is_final_response_false_when_turn_incomplete():
  """An event with turn_complete explicitly False must not be final."""
  content = types.Content(role='model', parts=[types.Part(text='partial')])
  event = Event(author='agent', turn_complete=False, content=content)

  assert not event.is_final_response()


def test_is_final_response_true_when_turn_complete():
  """A plain text event with turn_complete=True is a final response."""
  content = types.Content(role='model', parts=[types.Part(text='done')])
  event = Event(author='agent', turn_complete=True, content=content)

  assert event.is_final_response()
101 changes: 101 additions & 0 deletions tests/unittests/models/test_interactions_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@

"""Tests for interactions_utils.py conversion functions."""

import asyncio
import json
from unittest.mock import AsyncMock
from unittest.mock import MagicMock

from google.adk.models import interactions_utils
Expand Down Expand Up @@ -759,3 +761,102 @@ def test_full_conversation(self):
assert len(result) == 2
assert result[0].parts[0].text == 'Great'
assert result[1].parts[0].text == 'Tell me more'


class TestGenerateContentViaInteractionsStreaming:
  """Streaming-mode tests for generate_content_via_interactions."""

  @staticmethod
  def _text_delta(interaction_id: str, text: str) -> MagicMock:
    """Build a mock `content.delta` streaming event carrying text."""
    event = MagicMock()
    event.event_type = 'content.delta'
    event.id = interaction_id
    event.delta = MagicMock(type='text', text=text)
    return event

  @staticmethod
  def _run_stream(events) -> list:
    """Feed `events` through the streaming generator; return all responses."""

    async def _stream_events():
      for event in events:
        yield event

    api_client = MagicMock()
    api_client.aio.interactions.create = AsyncMock(
        return_value=_stream_events()
    )

    llm_request = LlmRequest(
        model='gemini-2.5-flash',
        contents=[types.Content(role='user', parts=[types.Part(text='hi')])],
    )

    async def _collect_responses():
      return [
          response
          async for response in
          interactions_utils.generate_content_via_interactions(
              api_client=api_client, llm_request=llm_request, stream=True
          )
      ]

    return asyncio.run(_collect_responses())

  def test_emits_single_final_response_with_status_update(self):
    """Ensures no duplicate final response is emitted in streaming mode."""
    status_update = MagicMock()
    status_update.event_type = 'interaction.status_update'
    status_update.id = 'interaction_1'
    status_update.status = 'completed'

    responses = self._run_stream([
        self._text_delta('interaction_1', 'Hello '),
        self._text_delta('interaction_1', 'world'),
        status_update,
    ])

    # Two partial deltas plus exactly one completion — no duplicate final.
    assert len(responses) == 3
    assert responses[0].partial is True
    assert responses[1].partial is True
    assert responses[2].turn_complete is True
    assert responses[2].content.parts[0].text == 'Hello world'

  def test_merges_overlapping_text_deltas_in_final_response(self):
    """Ensures overlapping text chunks are merged without duplication."""
    content_stop = MagicMock()
    content_stop.event_type = 'content.stop'
    content_stop.id = 'interaction_2'

    responses = self._run_stream([
        self._text_delta('interaction_2', 'Hello wor'),
        self._text_delta('interaction_2', 'world'),
        content_stop,
    ])

    # The shared "wor" overlap must appear only once in the final text.
    assert len(responses) == 3
    final_response = responses[-1]
    assert final_response.turn_complete is True
    assert final_response.content.parts[0].text == 'Hello world'

Loading