From a31c1d2a5fcb5df132dd07c70ba6991261307ff1 Mon Sep 17 00:00:00 2001 From: Nicolas Mota Date: Fri, 5 Jun 2026 19:23:39 -0300 Subject: [PATCH] feat(run_config): add model_input_context for transient context in LLM requests This update introduces a new attribute, model_input_context, to the RunConfig class, allowing callers to provide transient context for each invocation without altering the conversation history. Additionally, the LLM request processing has been updated to incorporate this context appropriately. Unit tests have been added to verify the correct behavior of this feature. --- src/google/adk/agents/run_config.py | 8 + src/google/adk/flows/llm_flows/contents.py | 30 ++++ .../agents/test_llm_agent_include_contents.py | 148 ++++++++++++++++++ tests/unittests/agents/test_run_config.py | 8 + 4 files changed, 194 insertions(+) diff --git a/src/google/adk/agents/run_config.py b/src/google/adk/agents/run_config.py index e059cd957d..7e4120b43d 100644 --- a/src/google/adk/agents/run_config.py +++ b/src/google/adk/agents/run_config.py @@ -344,6 +344,14 @@ class RunConfig(BaseModel): ) """ + model_input_context: list[types.Content] | None = None + """Transient context to include in the model input for this invocation. + + The Runner does not persist these contents to the session. They are only + added to the LLM request assembled for the current invocation, which lets + callers provide per-turn context without changing the conversation history. + """ + @model_validator(mode='before') @classmethod def check_for_deprecated_save_live_audio(cls, data: Any) -> Any: diff --git a/src/google/adk/flows/llm_flows/contents.py b/src/google/adk/flows/llm_flows/contents.py index feeb8ef972..56e9f5aba9 100644 --- a/src/google/adk/flows/llm_flows/contents.py +++ b/src/google/adk/flows/llm_flows/contents.py @@ -85,6 +85,16 @@ async def run_async( preserve_function_call_ids=preserve_function_call_ids, ) + if ( + invocation_context.run_config + and invocation_context.run_config.model_input_context + ): + _add_model_input_context_to_user_content( + invocation_context, + llm_request, + copy.deepcopy(invocation_context.run_config.model_input_context), + ) + # Add instruction-related contents to proper position in conversation await _add_instructions_to_user_content( invocation_context, llm_request, instruction_related_contents @@ -845,6 +855,26 @@ def _content_contains_function_response(content: types.Content) -> bool: return False +def _add_model_input_context_to_user_content( + invocation_context: InvocationContext, + llm_request: LlmRequest, + model_input_context: list[types.Content], +) -> None: + """Insert transient model input context before the invocation user content.""" + if not model_input_context: + return + + insert_index = 0 + user_content = invocation_context.user_content + if user_content: + for i in range(len(llm_request.contents) - 1, -1, -1): + if llm_request.contents[i] == user_content: + insert_index = i + break + + llm_request.contents[insert_index:insert_index] = model_input_context + + async def _add_instructions_to_user_content( invocation_context: InvocationContext, llm_request: LlmRequest, diff --git a/tests/unittests/agents/test_llm_agent_include_contents.py b/tests/unittests/agents/test_llm_agent_include_contents.py index a196f93553..c93701b743 100644 --- a/tests/unittests/agents/test_llm_agent_include_contents.py +++ b/tests/unittests/agents/test_llm_agent_include_contents.py @@ -15,6 +15,7 @@ """Unit tests for LlmAgent include_contents field behavior.""" from google.adk.agents.llm_agent import LlmAgent +from google.adk.agents.run_config import RunConfig from google.adk.agents.sequential_agent import SequentialAgent from google.genai import types import pytest @@ -189,6 +190,153 @@ def simple_tool(message: str) -> dict: assert len(mock_model.requests[0].config.tools) > 0 +def test_model_input_context_is_sent_to_model_without_persisting_to_session(): + mock_model = testing_utils.MockModel.create(responses=["Answer"]) + agent = LlmAgent(name="test_agent", model=mock_model) + runner = testing_utils.InMemoryRunner(agent) + session = runner.session + + list( + runner.runner.run( + user_id=session.user_id, + session_id=session.id, + new_message=testing_utils.get_user_content("Question"), + run_config=RunConfig( + model_input_context=[ + types.UserContent("Relevant context for this turn") + ] + ), + ) + ) + + assert testing_utils.simplify_contents(mock_model.requests[0].contents) == [ + ("user", "Relevant context for this turn"), + ("user", "Question"), + ] + assert testing_utils.simplify_events(runner.session.events) == [ + ("user", "Question"), + ("test_agent", "Answer"), + ] + + +def test_model_input_context_stays_before_user_message_after_tool_call(): + def simple_tool(message: str) -> dict: + return {"result": f"Tool processed: {message}"} + + mock_model = testing_utils.MockModel.create( + responses=[ + types.Part.from_function_call( + name="simple_tool", args={"message": "payload"} + ), + "Answer", + ] + ) + agent = LlmAgent(name="test_agent", model=mock_model, tools=[simple_tool]) + runner = testing_utils.InMemoryRunner(agent) + session = runner.session + + list( + runner.runner.run( + user_id=session.user_id, + session_id=session.id, + new_message=testing_utils.get_user_content("Question"), + run_config=RunConfig( + model_input_context=[ + types.UserContent("Relevant context for this turn") + ] + ), + ) + ) + + assert testing_utils.simplify_contents(mock_model.requests[0].contents) == [ + ("user", "Relevant context for this turn"), + ("user", "Question"), + ] + assert testing_utils.simplify_contents(mock_model.requests[1].contents) == [ + ("user", "Relevant context for this turn"), + ("user", "Question"), + ( + "model", + types.Part.from_function_call( + name="simple_tool", args={"message": "payload"} + ), + ), + ( + "user", + types.Part.from_function_response( + name="simple_tool", + response={"result": "Tool processed: payload"}, + ), + ), + ] + assert testing_utils.simplify_events(runner.session.events) == [ + ("user", "Question"), + ( + "test_agent", + types.Part.from_function_call( + name="simple_tool", args={"message": "payload"} + ), + ), + ( + "test_agent", + types.Part.from_function_response( + name="simple_tool", + response={"result": "Tool processed: payload"}, + ), + ), + ("test_agent", "Answer"), + ] + + +def test_model_input_context_with_include_contents_none_sub_agent(): + agent1_model = testing_utils.MockModel.create( + responses=["Agent1 response: XYZ"] + ) + agent1 = LlmAgent(name="agent1", model=agent1_model) + + agent2_model = testing_utils.MockModel.create( + responses=["Agent2 final response"] + ) + agent2 = LlmAgent( + name="agent2", + model=agent2_model, + include_contents="none", + ) + sequential_agent = SequentialAgent( + name="sequential_test_agent", sub_agents=[agent1, agent2] + ) + runner = testing_utils.InMemoryRunner(sequential_agent) + session = runner.session + + list( + runner.runner.run( + user_id=session.user_id, + session_id=session.id, + new_message=testing_utils.get_user_content("Original user request"), + run_config=RunConfig( + model_input_context=[ + types.UserContent("Relevant context for this turn") + ] + ), + ) + ) + + assert testing_utils.simplify_contents(agent1_model.requests[0].contents) == [ + ("user", "Relevant context for this turn"), + ("user", "Original user request"), + ] + assert testing_utils.simplify_contents(agent2_model.requests[0].contents) == [ + ("user", "Relevant context for this turn"), + ( + "user", + [ + types.Part(text="For context:"), + types.Part(text="[agent1] said: Agent1 response: XYZ"), + ], + ), + ] + + @pytest.mark.asyncio async def test_include_contents_none_sequential_agents(): """Test include_contents='none' with sequential agents.""" diff --git a/tests/unittests/agents/test_run_config.py b/tests/unittests/agents/test_run_config.py index cbb82af019..c08a1a52c3 100644 --- a/tests/unittests/agents/test_run_config.py +++ b/tests/unittests/agents/test_run_config.py @@ -97,3 +97,11 @@ def test_avatar_config_with_name(): assert run_config.avatar_config == avatar_config assert run_config.avatar_config.avatar_name == "test_avatar" assert run_config.avatar_config.customized_avatar is None + + +def test_model_input_context_accepts_transient_contents(): + context_content = types.UserContent("Relevant context for this turn") + + run_config = RunConfig(model_input_context=[context_content]) + + assert run_config.model_input_context == [context_content]