diff --git a/core/src/main/java/com/google/adk/agents/InvocationContext.java b/core/src/main/java/com/google/adk/agents/InvocationContext.java index d197dbaa5..3b460b073 100644 --- a/core/src/main/java/com/google/adk/agents/InvocationContext.java +++ b/core/src/main/java/com/google/adk/agents/InvocationContext.java @@ -31,6 +31,7 @@ import com.google.errorprone.annotations.InlineMe; import com.google.genai.types.Content; import com.google.genai.types.FunctionCall; +import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Optional; @@ -369,6 +370,31 @@ public boolean isResumable() { return resumabilityConfig.isResumable(); } + /** Returns ResumabilityConfig for this invocation. */ + public ResumabilityConfig resumabilityConfig() { + return resumabilityConfig; + } + + /** + * Populates agentStates and endOfAgents maps by reading session events for this invocation id. + */ + public void populateAgentStates(List events) { + events.stream() + .filter(event -> invocationId().equals(event.invocationId())) + .forEach( + event -> { + if (event.actions() != null) { + if (event.actions().agentState() != null + && !event.actions().agentState().isEmpty()) { + agentStates.putAll(event.actions().agentState()); + } + if (event.actions().endOfAgent()) { + endOfAgents.put(event.author(), true); + } + } + }); + } + /** Returns the events compaction configuration for the current agent run. */ public Optional eventsCompactionConfig() { return Optional.ofNullable(eventsCompactionConfig); diff --git a/core/src/main/java/com/google/adk/events/Event.java b/core/src/main/java/com/google/adk/events/Event.java index 9e05918be..d968efa53 100644 --- a/core/src/main/java/com/google/adk/events/Event.java +++ b/core/src/main/java/com/google/adk/events/Event.java @@ -294,8 +294,7 @@ public final boolean hasTrailingCodeExecutionResult() { /** Returns true if this is a final response. */ @JsonIgnore public final boolean finalResponse() { - if (actions().skipSummarization().orElse(false) - || (longRunningToolIds().isPresent() && !longRunningToolIds().get().isEmpty())) { + if (actions().skipSummarization().orElse(false)) { return true; } return functionCalls().isEmpty() diff --git a/core/src/main/java/com/google/adk/events/EventActions.java b/core/src/main/java/com/google/adk/events/EventActions.java index 493fa4b27..6543ec823 100644 --- a/core/src/main/java/com/google/adk/events/EventActions.java +++ b/core/src/main/java/com/google/adk/events/EventActions.java @@ -18,12 +18,15 @@ import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.databind.annotation.JsonDeserialize; +import com.google.adk.JsonBaseModel; import com.google.adk.agents.BaseAgentState; import com.google.adk.sessions.State; import com.google.errorprone.annotations.CanIgnoreReturnValue; import com.google.genai.types.Part; +import java.util.HashSet; import java.util.Objects; import java.util.Optional; +import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import javax.annotation.Nullable; @@ -31,11 +34,12 @@ /** Represents the actions attached to an event. */ // TODO - b/414081262 make json wire camelCase @JsonDeserialize(builder = EventActions.Builder.class) -public class EventActions { +public class EventActions extends JsonBaseModel { private Optional skipSummarization; private ConcurrentMap stateDelta; private ConcurrentMap artifactDelta; + private Set deletedArtifactIds; private Optional transferToAgent; private Optional escalate; private ConcurrentMap> requestedAuthConfigs; @@ -51,6 +55,7 @@ public EventActions() { this.skipSummarization = Optional.empty(); this.stateDelta = new ConcurrentHashMap<>(); this.artifactDelta = new ConcurrentHashMap<>(); + this.deletedArtifactIds = new HashSet<>(); this.transferToAgent = Optional.empty(); this.escalate = Optional.empty(); this.requestedAuthConfigs = new ConcurrentHashMap<>(); @@ -66,6 +71,7 @@ private EventActions(Builder builder) { this.skipSummarization = builder.skipSummarization; this.stateDelta = builder.stateDelta; this.artifactDelta = builder.artifactDelta; + this.deletedArtifactIds = builder.deletedArtifactIds; this.transferToAgent = builder.transferToAgent; this.escalate = builder.escalate; this.requestedAuthConfigs = builder.requestedAuthConfigs; @@ -122,6 +128,16 @@ public void setArtifactDelta(ConcurrentMap artifactDelta) { this.artifactDelta = artifactDelta; } + @JsonProperty("deletedArtifactIds") + @JsonInclude(JsonInclude.Include.NON_EMPTY) + public Set deletedArtifactIds() { + return deletedArtifactIds; + } + + public void setDeletedArtifactIds(Set deletedArtifactIds) { + this.deletedArtifactIds = deletedArtifactIds; + } + @JsonProperty("transferToAgent") public Optional transferToAgent() { return transferToAgent; @@ -238,6 +254,7 @@ public boolean equals(Object o) { return Objects.equals(skipSummarization, that.skipSummarization) && Objects.equals(stateDelta, that.stateDelta) && Objects.equals(artifactDelta, that.artifactDelta) + && Objects.equals(deletedArtifactIds, that.deletedArtifactIds) && Objects.equals(transferToAgent, that.transferToAgent) && Objects.equals(escalate, that.escalate) && Objects.equals(requestedAuthConfigs, that.requestedAuthConfigs) @@ -255,6 +272,7 @@ public int hashCode() { skipSummarization, stateDelta, artifactDelta, + deletedArtifactIds, transferToAgent, escalate, requestedAuthConfigs, @@ -271,6 +289,7 @@ public static class Builder { private Optional skipSummarization; private ConcurrentMap stateDelta; private ConcurrentMap artifactDelta; + private Set deletedArtifactIds; private Optional transferToAgent; private Optional escalate; private ConcurrentMap> requestedAuthConfigs; @@ -285,6 +304,7 @@ public Builder() { this.skipSummarization = Optional.empty(); this.stateDelta = new ConcurrentHashMap<>(); this.artifactDelta = new ConcurrentHashMap<>(); + this.deletedArtifactIds = new HashSet<>(); this.transferToAgent = Optional.empty(); this.escalate = Optional.empty(); this.requestedAuthConfigs = new ConcurrentHashMap<>(); @@ -299,6 +319,7 @@ private Builder(EventActions eventActions) { this.skipSummarization = eventActions.skipSummarization(); this.stateDelta = new ConcurrentHashMap<>(eventActions.stateDelta()); this.artifactDelta = new ConcurrentHashMap<>(eventActions.artifactDelta()); + this.deletedArtifactIds = new HashSet<>(eventActions.deletedArtifactIds()); this.transferToAgent = eventActions.transferToAgent(); this.escalate = eventActions.escalate(); this.requestedAuthConfigs = new ConcurrentHashMap<>(eventActions.requestedAuthConfigs()); @@ -332,6 +353,13 @@ public Builder artifactDelta(ConcurrentMap value) { return this; } + @CanIgnoreReturnValue + @JsonProperty("deletedArtifactIds") + public Builder deletedArtifactIds(Set value) { + this.deletedArtifactIds = value; + return this; + } + @CanIgnoreReturnValue @JsonProperty("transferToAgent") public Builder transferToAgent(String agentId) { @@ -401,6 +429,7 @@ public Builder merge(EventActions other) { other.skipSummarization().ifPresent(this::skipSummarization); this.stateDelta.putAll(other.stateDelta()); this.artifactDelta.putAll(other.artifactDelta()); + this.deletedArtifactIds.addAll(other.deletedArtifactIds()); other.transferToAgent().ifPresent(this::transferToAgent); other.escalate().ifPresent(this::escalate); this.requestedAuthConfigs.putAll(other.requestedAuthConfigs()); diff --git a/core/src/test/java/com/google/adk/agents/InvocationContextTest.java b/core/src/test/java/com/google/adk/agents/InvocationContextTest.java index 61135c78e..c1cb30180 100644 --- a/core/src/test/java/com/google/adk/agents/InvocationContextTest.java +++ b/core/src/test/java/com/google/adk/agents/InvocationContextTest.java @@ -22,6 +22,7 @@ import com.google.adk.apps.ResumabilityConfig; import com.google.adk.artifacts.BaseArtifactService; import com.google.adk.events.Event; +import com.google.adk.events.EventActions; import com.google.adk.memory.BaseMemoryService; import com.google.adk.models.LlmCallsLimitExceededException; import com.google.adk.plugins.PluginManager; @@ -150,7 +151,7 @@ public void testBuildWithLiveRequestQueue() { } @Test - public void testCopyOf() { + public void testToBuilder() { InvocationContext originalContext = InvocationContext.builder() .sessionService(mockSessionService) @@ -933,4 +934,56 @@ public void testDeprecatedConstructor_11params() { assertThat(context.runConfig()).isEqualTo(runConfig); assertThat(context.endInvocation()).isTrue(); } + + @Test + public void populateAgentStates_populatesAgentStatesAndEndOfAgents() { + InvocationContext context = + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .agent(mockAgent) + .session(session) + .invocationId(testInvocationId) + .build(); + + BaseAgentState agent1State = mock(BaseAgentState.class); + ConcurrentHashMap agent1StateMap = new ConcurrentHashMap<>(); + agent1StateMap.put("agent1", agent1State); + Event event1 = + Event.builder() + .invocationId(testInvocationId) + .author("agent1") + .actions(EventActions.builder().agentState(agent1StateMap).endOfAgent(true).build()) + .build(); + Event event2 = + Event.builder() + .invocationId("other-invocation-id") + .author("agent2") + .actions(EventActions.builder().endOfAgent(true).build()) + .build(); + Event event3 = + Event.builder() + .invocationId(testInvocationId) + .author("agent3") + .actions(EventActions.builder().endOfAgent(false).build()) + .build(); + BaseAgentState agent4State = mock(BaseAgentState.class); + ConcurrentHashMap agent4StateMap = new ConcurrentHashMap<>(); + agent4StateMap.put("agent4", agent4State); + Event event4 = + Event.builder() + .invocationId(testInvocationId) + .author("agent4") + .actions(EventActions.builder().agentState(agent4StateMap).endOfAgent(false).build()) + .build(); + Event event5 = Event.builder().invocationId(testInvocationId).author("agent5").build(); + + context.populateAgentStates(ImmutableList.of(event1, event2, event3, event4, event5)); + + assertThat(context.agentStates()).hasSize(2); + assertThat(context.agentStates()).containsEntry("agent1", agent1State); + assertThat(context.agentStates()).containsEntry("agent4", agent4State); + assertThat(context.endOfAgents()).hasSize(1); + assertThat(context.endOfAgents()).containsEntry("agent1", true); + } } diff --git a/core/src/test/java/com/google/adk/events/EventActionsTest.java b/core/src/test/java/com/google/adk/events/EventActionsTest.java index 18870ad44..9ea88b40a 100644 --- a/core/src/test/java/com/google/adk/events/EventActionsTest.java +++ b/core/src/test/java/com/google/adk/events/EventActionsTest.java @@ -20,6 +20,7 @@ import com.google.adk.sessions.State; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import com.google.genai.types.Content; import com.google.genai.types.Part; import java.util.concurrent.ConcurrentHashMap; @@ -44,7 +45,11 @@ public final class EventActionsTest { @Test public void toBuilder_createsBuilderWithSameValues() { EventActions eventActionsWithSkipSummarization = - EventActions.builder().skipSummarization(true).compaction(COMPACTION).build(); + EventActions.builder() + .skipSummarization(true) + .compaction(COMPACTION) + .deletedArtifactIds(ImmutableSet.of("d1")) + .build(); EventActions eventActionsAfterRebuild = eventActionsWithSkipSummarization.toBuilder().build(); @@ -59,6 +64,7 @@ public void merge_mergesAllFields() { .skipSummarization(true) .stateDelta(new ConcurrentHashMap<>(ImmutableMap.of("key1", "value1"))) .artifactDelta(new ConcurrentHashMap<>(ImmutableMap.of("artifact1", PART))) + .deletedArtifactIds(ImmutableSet.of("deleted1")) .requestedAuthConfigs( new ConcurrentHashMap<>( ImmutableMap.of("config1", new ConcurrentHashMap<>(ImmutableMap.of("k", "v"))))) @@ -70,6 +76,7 @@ public void merge_mergesAllFields() { EventActions.builder() .stateDelta(new ConcurrentHashMap<>(ImmutableMap.of("key2", "value2"))) .artifactDelta(new ConcurrentHashMap<>(ImmutableMap.of("artifact2", PART))) + .deletedArtifactIds(ImmutableSet.of("deleted2")) .transferToAgent("agentId") .escalate(true) .requestedAuthConfigs( @@ -85,6 +92,7 @@ public void merge_mergesAllFields() { assertThat(merged.skipSummarization()).hasValue(true); assertThat(merged.stateDelta()).containsExactly("key1", "value1", "key2", "value2"); assertThat(merged.artifactDelta()).containsExactly("artifact1", PART, "artifact2", PART); + assertThat(merged.deletedArtifactIds()).containsExactly("deleted1", "deleted2"); assertThat(merged.transferToAgent()).hasValue("agentId"); assertThat(merged.escalate()).hasValue(true); assertThat(merged.requestedAuthConfigs()) @@ -107,4 +115,19 @@ public void removeStateByKey_marksKeyAsRemoved() { assertThat(eventActions.stateDelta()).containsExactly("key1", State.REMOVED); } + + @Test + public void jsonSerialization_works() throws Exception { + EventActions eventActions = + EventActions.builder() + .deletedArtifactIds(ImmutableSet.of("d1", "d2")) + .stateDelta(new ConcurrentHashMap<>(ImmutableMap.of("k", "v"))) + .build(); + + String json = eventActions.toJson(); + EventActions deserialized = EventActions.fromJsonString(json, EventActions.class); + + assertThat(deserialized).isEqualTo(eventActions); + assertThat(deserialized.deletedArtifactIds()).containsExactly("d1", "d2"); + } } diff --git a/core/src/test/java/com/google/adk/events/EventTest.java b/core/src/test/java/com/google/adk/events/EventTest.java index f443abee5..d6de97f7f 100644 --- a/core/src/test/java/com/google/adk/events/EventTest.java +++ b/core/src/test/java/com/google/adk/events/EventTest.java @@ -191,4 +191,67 @@ public void event_json_serialization_works() throws Exception { Event deserializedEvent = Event.fromJson(json); assertThat(deserializedEvent).isEqualTo(EVENT); } + + @Test + public void finalResponse_returnsTrueIfNoToolCalls() { + Event event = + Event.builder() + .id("e1") + .invocationId("i1") + .author("agent") + .content(Content.fromParts(Part.fromText("hello"))) + .build(); + assertThat(event.finalResponse()).isTrue(); + } + + @Test + public void finalResponse_returnsFalseIfToolCalls() { + Event event = + Event.builder() + .id("e1") + .invocationId("i1") + .author("agent") + .content(Content.fromParts(Part.fromFunctionCall("tool", ImmutableMap.of("k", "v")))) + .build(); + assertThat(event.finalResponse()).isFalse(); + } + + @Test + public void finalResponse_isTrueForEventWithTextContent() { + Event event = + Event.builder() + .id("e1") + .invocationId("i1") + .author("agent") + .content(Content.fromParts(Part.fromText("hello"))) + .longRunningToolIds(ImmutableSet.of("tool1")) + .build(); + assertThat(event.finalResponse()).isTrue(); + } + + @Test + public void finalResponse_isFalseForEventWithToolCallAndLongRunningToolId() { + Event event = + Event.builder() + .id("e1") + .invocationId("i1") + .author("agent") + .content(Content.fromParts(Part.fromFunctionCall("tool", ImmutableMap.of("k", "v")))) + .longRunningToolIds(ImmutableSet.of("tool1")) + .build(); + assertThat(event.finalResponse()).isFalse(); + } + + @Test + public void finalResponse_returnsTrueIfSkipSummarization() { + Event event = + Event.builder() + .id("e1") + .invocationId("i1") + .author("agent") + .content(Content.fromParts(Part.fromFunctionCall("tool", ImmutableMap.of("k", "v")))) + .actions(EventActions.builder().skipSummarization(true).build()) + .build(); + assertThat(event.finalResponse()).isTrue(); + } }