From cb4b2b7b169f7e555b4f5787b5d92a3ccde681da Mon Sep 17 00:00:00 2001 From: Yufeng He <40085740+he-yufeng@users.noreply.github.com> Date: Mon, 8 Jun 2026 12:13:45 +0800 Subject: [PATCH] fix: merge streamed model history turns --- google/genai/chats.py | 19 ++- google/genai/tests/chats/test_get_history.py | 131 ++++++++++++++++++- 2 files changed, 144 insertions(+), 6 deletions(-) diff --git a/google/genai/chats.py b/google/genai/chats.py index 3d5e181e9..722461ffe 100644 --- a/google/genai/chats.py +++ b/google/genai/chats.py @@ -96,12 +96,27 @@ def _extract_curated_history( is_valid = False i += 1 if is_valid: - curated_history.extend(current_output) + curated_history.extend(_merge_model_outputs(current_output)) elif curated_history: curated_history.pop() return curated_history +def _merge_model_outputs(contents: list[Content]) -> list[Content]: + """Merge adjacent model chunks into one turn for request history.""" + merged: list[Content] = [] + for content in contents: + if ( + content.role == "model" + and merged + and merged[-1].role == "model" + ): + merged[-1].parts = (merged[-1].parts or []) + (content.parts or []) + else: + merged.append(content.model_copy()) + return merged + + class _BaseChat: """Base chat session.""" @@ -166,7 +181,7 @@ def record_history( self._comprehensive_history.extend(output_contents) if is_valid: self._curated_history.extend(input_contents) - self._curated_history.extend(output_contents) + self._curated_history.extend(_merge_model_outputs(output_contents)) def get_history(self, curated: bool = False) -> list[Content]: """Returns the chat history. diff --git a/google/genai/tests/chats/test_get_history.py b/google/genai/tests/chats/test_get_history.py index e29862a77..c9b2acef1 100644 --- a/google/genai/tests/chats/test_get_history.py +++ b/google/genai/tests/chats/test_get_history.py @@ -156,6 +156,43 @@ def mock_generate_content_stream_afc_history(): yield mock_generate_content +@pytest.fixture +def mock_generate_content_stream_with_split_model_output(): + with mock.patch.object( + models.Models, 'generate_content_stream' + ) as mock_generate_content: + mock_generate_content.return_value = [ + types.GenerateContentResponse( + candidates=[ + types.Candidate( + content=types.Content( + role='model', + parts=[types.Part.from_text(text='reasoning')], + ), + ) + ] + ), + types.GenerateContentResponse( + candidates=[ + types.Candidate( + content=types.Content( + role='model', + parts=[ + types.Part( + function_call=types.FunctionCall( + name='foo', args={'bar': 'baz'} + ) + ) + ], + ), + finish_reason=types.FinishReason.STOP, + ) + ] + ), + ] + yield mock_generate_content + + def test_history_start_with_valid_model_content(): history = [ types.Content( @@ -251,7 +288,15 @@ def test_history_with_consecutive_valid_model_outputs(): chat = chats_module.create(model='gemini-2.5-flash', history=history) assert chat.get_history() == history - assert chat.get_history(curated=True) == history + assert chat.get_history(curated=True) == [ + types.Content( + role='model', + parts=[ + types.Part.from_text(text='model output 1'), + types.Part.from_text(text='model output 2'), + ], + ), + ] def test_history_with_valid_and_invalid_model_output(): @@ -343,7 +388,19 @@ def test_sync_chat_create(): chat = chats_module.create(model='gemini-2.5-flash', history=history) assert chat.get_history() == history - assert chat.get_history(curated=True) == history + assert chat.get_history(curated=True) == [ + types.Content( + role='user', parts=[types.Part.from_text(text='user input turn 1')] + ), + types.Content( + role='model', + parts=[ + types.Part.from_text(text='model output turn 1'), + types.Part.from_text(text='model output turn 1'), + types.Part.from_text(text='user input turn 2'), + ], + ), + ] def test_async_chat_create(): @@ -374,7 +431,20 @@ def test_async_chat_create(): chat = chats_module.create(model='gemini-2.5-flash', history=history) assert chat.get_history() == history - assert chat.get_history(curated=True) == history + assert chat.get_history(curated=True) == [ + types.Content( + role='user', parts=[types.Part.from_text(text='user input turn 1')] + ), + types.Content( + role='model', + parts=[ + types.Part.from_text(text='model output turn 1'), + types.Part.from_text(text='model output turn 1'), + types.Part.from_text(text='user input turn 2'), + types.Part.from_text(text='model output turn 2'), + ], + ), + ] def test_sync_chat_create_with_history_dict(): @@ -470,7 +540,15 @@ def test_history_with_invalid_turns(): comprehensive_history.append(invalid_output) curated_history = [] curated_history.append(valid_input) - curated_history.extend(valid_output) + curated_history.append( + types.Content( + role='model', + parts=[ + valid_output[0].parts[0], + valid_output[1].parts[0], + ], + ) + ) models_module = models.Models(mock_api_client) chats_module = chats.Chats(modules=models_module) @@ -596,3 +674,48 @@ def test_chat_stream_with_afc_history(mock_generate_content_stream_afc_history): ] assert chat.get_history() == expected_history assert chat.get_history(curated=True) == expected_history + + +def test_chat_stream_merges_model_chunks_for_curated_history( + mock_generate_content_stream_with_split_model_output, +): + models_module = models.Models(mock_api_client) + chats_module = chats.Chats(modules=models_module) + chat = chats_module.create(model='gemini-2.5-flash') + + for _ in chat.send_message_stream('Hello'): + pass + + expected_comprehensive_history = [ + types.UserContent(parts=[types.Part.from_text(text='Hello')]), + types.Content( + role='model', + parts=[types.Part.from_text(text='reasoning')], + ), + types.Content( + role='model', + parts=[ + types.Part( + function_call=types.FunctionCall( + name='foo', args={'bar': 'baz'} + ) + ) + ], + ), + ] + expected_curated_history = [ + types.UserContent(parts=[types.Part.from_text(text='Hello')]), + types.Content( + role='model', + parts=[ + types.Part.from_text(text='reasoning'), + types.Part( + function_call=types.FunctionCall( + name='foo', args={'bar': 'baz'} + ) + ), + ], + ), + ] + assert chat.get_history() == expected_comprehensive_history + assert chat.get_history(curated=True) == expected_curated_history