diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py index 8c1568ccc7..8229a58a52 100644 --- a/src/google/adk/models/lite_llm.py +++ b/src/google/adk/models/lite_llm.py @@ -143,6 +143,46 @@ "before a response was recorded)." ) +_LITELLM_THOUGHT_SIGNATURE_SEPARATOR = "__thought__" + + +def _decode_litellm_tool_call_id( + tool_call_id: str, +) -> tuple[str, Optional[bytes]]: + """Extracts thought_signature bytes from a LiteLLM tool call id.""" + if not tool_call_id: + return tool_call_id, None + + base_id, separator, encoded_signature = tool_call_id.partition( + _LITELLM_THOUGHT_SIGNATURE_SEPARATOR + ) + if not separator or not encoded_signature: + return base_id, None + + try: + return base_id, base64.b64decode(encoded_signature) + except (ValueError, TypeError) as err: + logger.warning( + "Failed to decode thought_signature from tool call id %r: %s", + tool_call_id, + err, + ) + return base_id, None + + +def _encode_litellm_tool_call_id( + tool_call_id: Optional[str], thought_signature: Optional[bytes] +) -> Optional[str]: + """Embeds thought_signature bytes in a LiteLLM-compatible tool call id.""" + if not tool_call_id or not thought_signature: + return tool_call_id + + encoded_signature = base64.b64encode(thought_signature).decode("utf-8") + return ( + f"{tool_call_id}{_LITELLM_THOUGHT_SIGNATURE_SEPARATOR}{encoded_signature}" + ) + + _LITELLM_IMPORTED = False _LITELLM_GLOBAL_SYMBOLS = ( "ChatCompletionAssistantMessage", @@ -673,7 +713,10 @@ async def _content_to_message_param( tool_calls.append( ChatCompletionAssistantToolCall( type="function", - id=part.function_call.id, + id=_encode_litellm_tool_call_id( + part.function_call.id, + part.thought_signature, + ), function=Function( name=part.function_call.name, arguments=_safe_json_serialize(part.function_call.args), @@ -1490,7 +1533,12 @@ def _message_to_generate_content_response( name=tool_call.function.name, args=json.loads(tool_call.function.arguments or "{}"), ) - 
def test_message_to_generate_content_response_tool_call_with_thought_signature():
  """A tool call id carrying `__thought__<b64>` yields the decoded signature."""
  raw_signature = b"gemini_signature"
  suffix = base64.b64encode(raw_signature).decode("utf-8")
  tool_call = ChatCompletionMessageToolCall(
      type="function",
      id="test_tool_call_id__thought__" + suffix,
      function=Function(
          name="test_function",
          arguments='{"test_arg": "test_value"}',
      ),
  )
  message = ChatCompletionAssistantMessage(
      role="assistant",
      content=None,
      tool_calls=[tool_call],
  )

  response = _message_to_generate_content_response(message)

  part = response.content.parts[0]
  assert response.content.role == "model"
  assert part.function_call.name == "test_function"
  assert part.function_call.args == {"test_arg": "test_value"}
  # The signature suffix is stripped from the id and surfaced as bytes.
  assert part.function_call.id == "test_tool_call_id"
  assert part.thought_signature == raw_signature


@pytest.mark.asyncio
async def test_content_to_message_param_embeds_thought_signature_in_tool_call():
  """A part with a thought_signature encodes it into the LiteLLM id."""
  signature = b"gemini_signature"
  part = types.Part.from_function_call(
      name="test_function",
      args={"test_arg": "test_value"},
  )
  part.function_call.id = "test_tool_call_id"
  part.thought_signature = signature
  content = types.Content(role="model", parts=[part])

  message = await _content_to_message_param(content)

  tool_calls = message["tool_calls"]
  assert tool_calls is not None
  assert len(tool_calls) == 1
  expected_id = "test_tool_call_id__thought__" + base64.b64encode(
      signature
  ).decode("utf-8")
  assert tool_calls[0]["id"] == expected_id