diff --git a/core/src/main/java/com/google/adk/agents/InvocationContext.java b/core/src/main/java/com/google/adk/agents/InvocationContext.java index ed9b2106..d197dbaa 100644 --- a/core/src/main/java/com/google/adk/agents/InvocationContext.java +++ b/core/src/main/java/com/google/adk/agents/InvocationContext.java @@ -56,6 +56,7 @@ public class InvocationContext { private final ResumabilityConfig resumabilityConfig; @Nullable private final EventsCompactionConfig eventsCompactionConfig; private final InvocationCostManager invocationCostManager; + private final Map callbackContextData; private Optional branch; private BaseAgent agent; @@ -80,6 +81,7 @@ protected InvocationContext(Builder builder) { this.resumabilityConfig = builder.resumabilityConfig; this.eventsCompactionConfig = builder.eventsCompactionConfig; this.invocationCostManager = builder.invocationCostManager; + this.callbackContextData = builder.callbackContextData; } /** @@ -306,6 +308,14 @@ public RunConfig runConfig() { return runConfig; } + /** + * Returns a map for storing temporary context data that can be shared between different parts of + * the invocation (e.g., before/on/after model callbacks). + */ + public Map callbackContextData() { + return callbackContextData; + } + /** Returns agent-specific state saved within this invocation. */ public Map agentStates() { return agentStates; @@ -437,6 +447,7 @@ private Builder(InvocationContext context) { this.resumabilityConfig = context.resumabilityConfig; this.eventsCompactionConfig = context.eventsCompactionConfig; this.invocationCostManager = context.invocationCostManager; + this.callbackContextData = context.callbackContextData; } private BaseSessionService sessionService; @@ -457,6 +468,7 @@ private Builder(InvocationContext context) { private ResumabilityConfig resumabilityConfig = new ResumabilityConfig(); @Nullable private EventsCompactionConfig eventsCompactionConfig; private InvocationCostManager invocationCostManager = new InvocationCostManager(); + private Map callbackContextData = new ConcurrentHashMap<>(); /** * Sets the session service for managing session state. @@ -692,6 +704,18 @@ public Builder eventsCompactionConfig(@Nullable EventsCompactionConfig eventsCom return this; } + /** + * Sets the callback context data for the invocation. + * + * @param callbackContextData the callback context data. + * @return this builder instance for chaining. + */ + @CanIgnoreReturnValue + public Builder callbackContextData(Map callbackContextData) { + this.callbackContextData = callbackContextData; + return this; + } + /** * Builds the {@link InvocationContext} instance. * @@ -728,7 +752,8 @@ public boolean equals(Object o) { && Objects.equals(endOfAgents, that.endOfAgents) && Objects.equals(resumabilityConfig, that.resumabilityConfig) && Objects.equals(eventsCompactionConfig, that.eventsCompactionConfig) - && Objects.equals(invocationCostManager, that.invocationCostManager); + && Objects.equals(invocationCostManager, that.invocationCostManager) + && Objects.equals(callbackContextData, that.callbackContextData); } @Override @@ -751,6 +776,7 @@ public int hashCode() { endOfAgents, resumabilityConfig, eventsCompactionConfig, - invocationCostManager); + invocationCostManager, + callbackContextData); } } diff --git a/core/src/test/java/com/google/adk/agents/InvocationContextTest.java b/core/src/test/java/com/google/adk/agents/InvocationContextTest.java index 64d2f5bf..61135c78 100644 --- a/core/src/test/java/com/google/adk/agents/InvocationContextTest.java +++ b/core/src/test/java/com/google/adk/agents/InvocationContextTest.java @@ -23,10 +23,13 @@ import com.google.adk.artifacts.BaseArtifactService; import com.google.adk.events.Event; import com.google.adk.memory.BaseMemoryService; +import com.google.adk.models.LlmCallsLimitExceededException; import com.google.adk.plugins.PluginManager; import com.google.adk.sessions.BaseSessionService; import com.google.adk.sessions.Session; +import com.google.adk.summarizer.EventsCompactionConfig; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.genai.types.Content; import com.google.genai.types.FunctionCall; @@ -34,6 +37,8 @@ import java.util.HashMap; import java.util.Map; import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; +import org.junit.Assert; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -178,6 +183,25 @@ public void testCopyOf() { assertThat(copiedContext.endInvocation()).isEqualTo(originalContext.endInvocation()); assertThat(copiedContext.activeStreamingTools()) .isEqualTo(originalContext.activeStreamingTools()); + assertThat(copiedContext.callbackContextData()) + .isSameInstanceAs(originalContext.callbackContextData()); + } + + @Test + public void testBuildWithCallbackContextData() { + Map data = new ConcurrentHashMap<>(); + data.put("key", "value"); + InvocationContext context = + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .agent(mockAgent) + .session(session) + .callbackContextData(data) + .build(); + + assertThat(context.callbackContextData()).isEqualTo(data); + assertThat(context.callbackContextData()).isSameInstanceAs(data); } @Test @@ -404,6 +428,22 @@ public void testEquals_differentValues() { assertThat(context.equals(contextWithDiffAgent)).isFalse(); assertThat(context.equals(contextWithUserContentEmpty)).isFalse(); assertThat(context.equals(contextWithLiveQueuePresent)).isFalse(); + + InvocationContext contextWithDiffCallbackContextData = + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .memoryService(mockMemoryService) + .pluginManager(pluginManager) + .invocationId(testInvocationId) + .agent(mockAgent) + .session(session) + .userContent(userContent) + .runConfig(runConfig) + .endInvocation(false) + .callbackContextData(ImmutableMap.of("key", "value")) + .build(); + assertThat(context.equals(contextWithDiffCallbackContextData)).isFalse(); } @Test @@ -453,6 +493,22 @@ public void testHashCode_differentValues() { assertThat(context).isNotEqualTo(contextWithDiffSessionService); assertThat(context).isNotEqualTo(contextWithDiffInvocationId); + + InvocationContext contextWithDiffCallbackContextData = + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .memoryService(mockMemoryService) + .pluginManager(pluginManager) + .invocationId(testInvocationId) + .agent(mockAgent) + .session(session) + .userContent(userContent) + .runConfig(runConfig) + .endInvocation(false) + .callbackContextData(ImmutableMap.of("key", "value")) + .build(); + assertThat(context.hashCode()).isNotEqualTo(contextWithDiffCallbackContextData.hashCode()); } @Test @@ -604,4 +660,277 @@ public void shouldPauseInvocation_whenResumableAndMatchingFunctionCallId_isTrue( .build(); assertThat(context.shouldPauseInvocation(event)).isTrue(); } + + @Test + public void incrementLlmCallsCount_whenLimitNotExceeded_doesNotThrow() throws Exception { + InvocationContext context = + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .agent(mockAgent) + .session(session) + .runConfig(RunConfig.builder().setMaxLlmCalls(2).build()) + .build(); + + context.incrementLlmCallsCount(); + context.incrementLlmCallsCount(); + // No exception thrown + } + + @Test + public void incrementLlmCallsCount_whenLimitExceeded_throwsException() throws Exception { + InvocationContext context = + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .agent(mockAgent) + .session(session) + .runConfig(RunConfig.builder().setMaxLlmCalls(1).build()) + .build(); + + context.incrementLlmCallsCount(); + LlmCallsLimitExceededException thrown = + Assert.assertThrows( + LlmCallsLimitExceededException.class, () -> context.incrementLlmCallsCount()); + assertThat(thrown).hasMessageThat().contains("limit of 1 exceeded"); + } + + @Test + public void incrementLlmCallsCount_whenNoLimit_doesNotThrow() throws Exception { + InvocationContext context = + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .agent(mockAgent) + .session(session) + .runConfig(RunConfig.builder().setMaxLlmCalls(0).build()) + .build(); + + for (int i = 0; i < 100; i++) { + context.incrementLlmCallsCount(); + } + } + + @Test + public void testSessionGetters() { + Session sessionWithDetails = + Session.builder("test-id").appName("test-app").userId("test-user").build(); + InvocationContext context = + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .agent(mockAgent) + .session(sessionWithDetails) + .build(); + + assertThat(context.appName()).isEqualTo("test-app"); + assertThat(context.userId()).isEqualTo("test-user"); + } + + @Test + public void testAgentStatesAndEndOfAgents() { + BaseAgentState mockState = mock(BaseAgentState.class); + ImmutableMap states = ImmutableMap.of("agent1", mockState); + ImmutableMap endOfAgents = ImmutableMap.of("agent1", true); + + InvocationContext context = + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .agent(mockAgent) + .session(session) + .agentStates(states) + .endOfAgents(endOfAgents) + .build(); + + assertThat(context.agentStates()).isEqualTo(states); + assertThat(context.endOfAgents()).isEqualTo(endOfAgents); + } + + @Test + public void testSetEndInvocation() { + InvocationContext context = + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .agent(mockAgent) + .session(session) + .build(); + + assertThat(context.endInvocation()).isFalse(); + context.setEndInvocation(true); + assertThat(context.endInvocation()).isTrue(); + } + + @Test + @SuppressWarnings("deprecation") // Testing deprecated methods. + public void testBranch() { + InvocationContext context = + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .agent(mockAgent) + .session(session) + .branch("test-branch") + .build(); + + assertThat(context.branch()).hasValue("test-branch"); + + context.branch("new-branch"); + assertThat(context.branch()).hasValue("new-branch"); + + context.branch(null); + assertThat(context.branch()).isEmpty(); + } + + @Test + @SuppressWarnings("deprecation") // Testing deprecated methods. + public void testDeprecatedCreateMethods() { + InvocationContext context1 = + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .invocationId(testInvocationId) + .agent(mockAgent) + .session(session) + .userContent(Optional.ofNullable(userContent)) + .runConfig(runConfig) + .build(); + + assertThat(context1.sessionService()).isEqualTo(mockSessionService); + assertThat(context1.artifactService()).isEqualTo(mockArtifactService); + assertThat(context1.invocationId()).isEqualTo(testInvocationId); + assertThat(context1.agent()).isEqualTo(mockAgent); + assertThat(context1.session()).isEqualTo(session); + assertThat(context1.userContent()).hasValue(userContent); + assertThat(context1.runConfig()).isEqualTo(runConfig); + + InvocationContext context2 = + InvocationContext.create( + mockSessionService, + mockArtifactService, + mockAgent, + session, + liveRequestQueue, + runConfig); + + assertThat(context2.sessionService()).isEqualTo(mockSessionService); + assertThat(context2.artifactService()).isEqualTo(mockArtifactService); + assertThat(context2.agent()).isEqualTo(mockAgent); + assertThat(context2.session()).isEqualTo(session); + assertThat(context2.liveRequestQueue()).hasValue(liveRequestQueue); + assertThat(context2.runConfig()).isEqualTo(runConfig); + } + + @Test + public void testActiveStreamingTools() { + InvocationContext context = + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .agent(mockAgent) + .session(session) + .build(); + + assertThat(context.activeStreamingTools()).isEmpty(); + ActiveStreamingTool tool = new ActiveStreamingTool(new LiveRequestQueue()); + context.activeStreamingTools().put("tool1", tool); + assertThat(context.activeStreamingTools()).containsEntry("tool1", tool); + } + + @Test + public void testEventsCompactionConfig() { + EventsCompactionConfig config = new EventsCompactionConfig(5, 2); + InvocationContext context = + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .agent(mockAgent) + .session(session) + .eventsCompactionConfig(config) + .build(); + + assertThat(context.eventsCompactionConfig()).hasValue(config); + } + + @Test + @SuppressWarnings("deprecation") // Testing deprecated methods. + public void testBuilderOptionalParameters() { + InvocationContext context = + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .agent(mockAgent) + .session(session) + .liveRequestQueue(Optional.of(liveRequestQueue)) + .branch(Optional.of("test-branch")) + .userContent(Optional.of(userContent)) + .build(); + + assertThat(context.liveRequestQueue()).hasValue(liveRequestQueue); + assertThat(context.branch()).hasValue("test-branch"); + assertThat(context.userContent()).hasValue(userContent); + } + + @Test + @SuppressWarnings("deprecation") // Testing deprecated methods. + public void testDeprecatedConstructor() { + InvocationContext context = + new InvocationContext( + mockSessionService, + mockArtifactService, + mockMemoryService, + pluginManager, + Optional.of(liveRequestQueue), + Optional.of("test-branch"), + testInvocationId, + mockAgent, + session, + Optional.of(userContent), + runConfig, + true); + + assertThat(context.sessionService()).isEqualTo(mockSessionService); + assertThat(context.artifactService()).isEqualTo(mockArtifactService); + assertThat(context.memoryService()).isEqualTo(mockMemoryService); + assertThat(context.pluginManager()).isEqualTo(pluginManager); + assertThat(context.liveRequestQueue()).hasValue(liveRequestQueue); + assertThat(context.branch()).hasValue("test-branch"); + assertThat(context.invocationId()).isEqualTo(testInvocationId); + assertThat(context.agent()).isEqualTo(mockAgent); + assertThat(context.session()).isEqualTo(session); + assertThat(context.userContent()).hasValue(userContent); + assertThat(context.runConfig()).isEqualTo(runConfig); + assertThat(context.endInvocation()).isTrue(); + } + + @Test + @SuppressWarnings("deprecation") // Testing deprecated methods. + public void testDeprecatedConstructor_11params() { + InvocationContext context = + new InvocationContext( + mockSessionService, + mockArtifactService, + mockMemoryService, + Optional.of(liveRequestQueue), + Optional.of("test-branch"), + testInvocationId, + mockAgent, + session, + Optional.of(userContent), + runConfig, + true); + + assertThat(context.sessionService()).isEqualTo(mockSessionService); + assertThat(context.artifactService()).isEqualTo(mockArtifactService); + assertThat(context.memoryService()).isEqualTo(mockMemoryService); + assertThat(context.liveRequestQueue()).hasValue(liveRequestQueue); + assertThat(context.branch()).hasValue("test-branch"); + assertThat(context.invocationId()).isEqualTo(testInvocationId); + assertThat(context.agent()).isEqualTo(mockAgent); + assertThat(context.session()).isEqualTo(session); + assertThat(context.userContent()).hasValue(userContent); + assertThat(context.runConfig()).isEqualTo(runConfig); + assertThat(context.endInvocation()).isTrue(); + } } diff --git a/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java b/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java index 5f4932a8..657d1c67 100644 --- a/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java +++ b/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java @@ -25,6 +25,7 @@ import static com.google.common.collect.Iterables.getOnlyElement; import static com.google.common.truth.Truth.assertThat; +import com.google.adk.agents.Callbacks; import com.google.adk.agents.InvocationContext; import com.google.adk.events.Event; import com.google.adk.flows.llmflows.RequestProcessor.RequestProcessingResult; @@ -42,6 +43,7 @@ import com.google.genai.types.GenerateContentResponseUsageMetadata; import com.google.genai.types.Part; import io.reactivex.rxjava3.core.Flowable; +import io.reactivex.rxjava3.core.Maybe; import io.reactivex.rxjava3.core.Single; import java.util.List; import java.util.Map; @@ -414,6 +416,82 @@ public void run_requestProcessorsAreCalledExactlyOnce() { assertThat(processor2CallCount.get()).isEqualTo(1); } + @Test + public void run_sharingcallbackContextDataBetweenCallbacks() { + Content content = Content.fromParts(Part.fromText("LLM response")); + TestLlm testLlm = createTestLlm(createLlmResponse(content)); + + Callbacks.BeforeModelCallback beforeCallback = + (ctx, req) -> { + ctx.invocationContext().callbackContextData().put("key", "value_from_before"); + return Maybe.empty(); + }; + + Callbacks.AfterModelCallback afterCallback = + (ctx, resp) -> { + String value = (String) ctx.invocationContext().callbackContextData().get("key"); + LlmResponse modifiedResp = + resp.toBuilder().content(Content.fromParts(Part.fromText("Saw: " + value))).build(); + return Maybe.just(modifiedResp); + }; + + InvocationContext invocationContext = + createInvocationContext( + createTestAgentBuilder(testLlm) + .beforeModelCallback(beforeCallback) + .afterModelCallback(afterCallback) + .build()); + + BaseLlmFlow baseLlmFlow = createBaseLlmFlowWithoutProcessors(); + + List events = baseLlmFlow.run(invocationContext).toList().blockingGet(); + + assertThat(events).hasSize(1); + assertThat(events.get(0).stringifyContent()).isEqualTo("Saw: value_from_before"); + } + + @Test + public void run_sharingcallbackContextDataAcrossContextCopies() { + Content content = Content.fromParts(Part.fromText("LLM response")); + TestLlm testLlm = createTestLlm(createLlmResponse(content)); + + Callbacks.BeforeModelCallback beforeCallback = + (ctx, req) -> { + ctx.invocationContext().callbackContextData().put("key", "value_from_before"); + return Maybe.empty(); + }; + + Callbacks.AfterModelCallback afterCallback = + (ctx, resp) -> { + String value = (String) ctx.invocationContext().callbackContextData().get("key"); + LlmResponse modifiedResp = + resp.toBuilder().content(Content.fromParts(Part.fromText("Saw: " + value))).build(); + return Maybe.just(modifiedResp); + }; + + InvocationContext invocationContext = + createInvocationContext( + createTestAgentBuilder(testLlm) + .beforeModelCallback(beforeCallback) + .afterModelCallback(afterCallback) + .build()); + + BaseLlmFlow baseLlmFlow = + new BaseLlmFlow(ImmutableList.of(), ImmutableList.of()) { + @Override + public Flowable run(InvocationContext context) { + // Force a context copy + InvocationContext copiedContext = context.toBuilder().build(); + return super.run(copiedContext); + } + }; + + List events = baseLlmFlow.run(invocationContext).toList().blockingGet(); + + assertThat(events).hasSize(1); + assertThat(events.get(0).stringifyContent()).isEqualTo("Saw: value_from_before"); + } + private static BaseLlmFlow createBaseLlmFlowWithoutProcessors() { return createBaseLlmFlow(ImmutableList.of(), ImmutableList.of()); }