Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions google/genai/chats.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,27 @@ def _extract_curated_history(
is_valid = False
i += 1
if is_valid:
curated_history.extend(current_output)
curated_history.extend(_merge_model_outputs(current_output))
elif curated_history:
curated_history.pop()
return curated_history


def _merge_model_outputs(contents: list[Content]) -> list[Content]:
"""Merge adjacent model chunks into one turn for request history."""
merged: list[Content] = []
for content in contents:
if (
content.role == "model"
and merged
and merged[-1].role == "model"
):
merged[-1].parts = (merged[-1].parts or []) + (content.parts or [])
else:
merged.append(content.model_copy())
return merged


class _BaseChat:
"""Base chat session."""

Expand Down Expand Up @@ -166,7 +181,7 @@ def record_history(
self._comprehensive_history.extend(output_contents)
if is_valid:
self._curated_history.extend(input_contents)
self._curated_history.extend(output_contents)
self._curated_history.extend(_merge_model_outputs(output_contents))

def get_history(self, curated: bool = False) -> list[Content]:
"""Returns the chat history.
Expand Down
131 changes: 127 additions & 4 deletions google/genai/tests/chats/test_get_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,43 @@ def mock_generate_content_stream_afc_history():
yield mock_generate_content


@pytest.fixture
def mock_generate_content_stream_with_split_model_output():
with mock.patch.object(
models.Models, 'generate_content_stream'
) as mock_generate_content:
mock_generate_content.return_value = [
types.GenerateContentResponse(
candidates=[
types.Candidate(
content=types.Content(
role='model',
parts=[types.Part.from_text(text='reasoning')],
),
)
]
),
types.GenerateContentResponse(
candidates=[
types.Candidate(
content=types.Content(
role='model',
parts=[
types.Part(
function_call=types.FunctionCall(
name='foo', args={'bar': 'baz'}
)
)
],
),
finish_reason=types.FinishReason.STOP,
)
]
),
]
yield mock_generate_content


def test_history_start_with_valid_model_content():
history = [
types.Content(
Expand Down Expand Up @@ -251,7 +288,15 @@ def test_history_with_consecutive_valid_model_outputs():
chat = chats_module.create(model='gemini-2.5-flash', history=history)

assert chat.get_history() == history
assert chat.get_history(curated=True) == history
assert chat.get_history(curated=True) == [
types.Content(
role='model',
parts=[
types.Part.from_text(text='model output 1'),
types.Part.from_text(text='model output 2'),
],
),
]


def test_history_with_valid_and_invalid_model_output():
Expand Down Expand Up @@ -343,7 +388,19 @@ def test_sync_chat_create():
chat = chats_module.create(model='gemini-2.5-flash', history=history)

assert chat.get_history() == history
assert chat.get_history(curated=True) == history
assert chat.get_history(curated=True) == [
types.Content(
role='user', parts=[types.Part.from_text(text='user input turn 1')]
),
types.Content(
role='model',
parts=[
types.Part.from_text(text='model output turn 1'),
types.Part.from_text(text='model output turn 1'),
types.Part.from_text(text='user input turn 2'),
],
),
]


def test_async_chat_create():
Expand Down Expand Up @@ -374,7 +431,20 @@ def test_async_chat_create():
chat = chats_module.create(model='gemini-2.5-flash', history=history)

assert chat.get_history() == history
assert chat.get_history(curated=True) == history
assert chat.get_history(curated=True) == [
types.Content(
role='user', parts=[types.Part.from_text(text='user input turn 1')]
),
types.Content(
role='model',
parts=[
types.Part.from_text(text='model output turn 1'),
types.Part.from_text(text='model output turn 1'),
types.Part.from_text(text='user input turn 2'),
types.Part.from_text(text='model output turn 2'),
],
),
]


def test_sync_chat_create_with_history_dict():
Expand Down Expand Up @@ -470,7 +540,15 @@ def test_history_with_invalid_turns():
comprehensive_history.append(invalid_output)
curated_history = []
curated_history.append(valid_input)
curated_history.extend(valid_output)
curated_history.append(
types.Content(
role='model',
parts=[
valid_output[0].parts[0],
valid_output[1].parts[0],
],
)
)

models_module = models.Models(mock_api_client)
chats_module = chats.Chats(modules=models_module)
Expand Down Expand Up @@ -596,3 +674,48 @@ def test_chat_stream_with_afc_history(mock_generate_content_stream_afc_history):
]
assert chat.get_history() == expected_history
assert chat.get_history(curated=True) == expected_history


def test_chat_stream_merges_model_chunks_for_curated_history(
mock_generate_content_stream_with_split_model_output,
):
models_module = models.Models(mock_api_client)
chats_module = chats.Chats(modules=models_module)
chat = chats_module.create(model='gemini-2.5-flash')

for _ in chat.send_message_stream('Hello'):
pass

expected_comprehensive_history = [
types.UserContent(parts=[types.Part.from_text(text='Hello')]),
types.Content(
role='model',
parts=[types.Part.from_text(text='reasoning')],
),
types.Content(
role='model',
parts=[
types.Part(
function_call=types.FunctionCall(
name='foo', args={'bar': 'baz'}
)
)
],
),
]
expected_curated_history = [
types.UserContent(parts=[types.Part.from_text(text='Hello')]),
types.Content(
role='model',
parts=[
types.Part.from_text(text='reasoning'),
types.Part(
function_call=types.FunctionCall(
name='foo', args={'bar': 'baz'}
)
),
],
),
]
assert chat.get_history() == expected_comprehensive_history
assert chat.get_history(curated=True) == expected_curated_history
Loading