Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions core/src/main/java/com/google/adk/agents/InvocationContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import com.google.errorprone.annotations.InlineMe;
import com.google.genai.types.Content;
import com.google.genai.types.FunctionCall;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
Expand Down Expand Up @@ -369,6 +370,31 @@ public boolean isResumable() {
return resumabilityConfig.isResumable();
}

/** Returns ResumabilityConfig for this invocation. */
public ResumabilityConfig resumabilityConfig() {
return resumabilityConfig;
}

/**
* Populates agentStates and endOfAgents maps by reading session events for this invocation id.
*/
public void populateAgentStates(List<Event> events) {
events.stream()
.filter(event -> invocationId().equals(event.invocationId()))
.forEach(
event -> {
if (event.actions() != null) {
if (event.actions().agentState() != null
&& !event.actions().agentState().isEmpty()) {
agentStates.putAll(event.actions().agentState());
}
if (event.actions().endOfAgent()) {
endOfAgents.put(event.author(), true);
}
}
});
}

/** Returns the events compaction configuration for the current agent run. */
public Optional<EventsCompactionConfig> eventsCompactionConfig() {
return Optional.ofNullable(eventsCompactionConfig);
Expand Down
3 changes: 1 addition & 2 deletions core/src/main/java/com/google/adk/events/Event.java
Original file line number Diff line number Diff line change
Expand Up @@ -294,8 +294,7 @@ public final boolean hasTrailingCodeExecutionResult() {
/** Returns true if this is a final response. */
@JsonIgnore
public final boolean finalResponse() {
if (actions().skipSummarization().orElse(false)
|| (longRunningToolIds().isPresent() && !longRunningToolIds().get().isEmpty())) {
if (actions().skipSummarization().orElse(false)) {
return true;
}
return functionCalls().isEmpty()
Expand Down
31 changes: 30 additions & 1 deletion core/src/main/java/com/google/adk/events/EventActions.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,24 +18,28 @@
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.google.adk.JsonBaseModel;
import com.google.adk.agents.BaseAgentState;
import com.google.adk.sessions.State;
import com.google.errorprone.annotations.CanIgnoreReturnValue;
import com.google.genai.types.Part;
import java.util.HashSet;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import javax.annotation.Nullable;

/** Represents the actions attached to an event. */
// TODO - b/414081262 make json wire camelCase
@JsonDeserialize(builder = EventActions.Builder.class)
public class EventActions {
public class EventActions extends JsonBaseModel {

private Optional<Boolean> skipSummarization;
private ConcurrentMap<String, Object> stateDelta;
private ConcurrentMap<String, Part> artifactDelta;
private Set<String> deletedArtifactIds;
private Optional<String> transferToAgent;
private Optional<Boolean> escalate;
private ConcurrentMap<String, ConcurrentMap<String, Object>> requestedAuthConfigs;
Expand All @@ -51,6 +55,7 @@ public EventActions() {
this.skipSummarization = Optional.empty();
this.stateDelta = new ConcurrentHashMap<>();
this.artifactDelta = new ConcurrentHashMap<>();
this.deletedArtifactIds = new HashSet<>();
this.transferToAgent = Optional.empty();
this.escalate = Optional.empty();
this.requestedAuthConfigs = new ConcurrentHashMap<>();
Expand All @@ -66,6 +71,7 @@ private EventActions(Builder builder) {
this.skipSummarization = builder.skipSummarization;
this.stateDelta = builder.stateDelta;
this.artifactDelta = builder.artifactDelta;
this.deletedArtifactIds = builder.deletedArtifactIds;
this.transferToAgent = builder.transferToAgent;
this.escalate = builder.escalate;
this.requestedAuthConfigs = builder.requestedAuthConfigs;
Expand Down Expand Up @@ -122,6 +128,16 @@ public void setArtifactDelta(ConcurrentMap<String, Part> artifactDelta) {
this.artifactDelta = artifactDelta;
}

@JsonProperty("deletedArtifactIds")
@JsonInclude(JsonInclude.Include.NON_EMPTY)
public Set<String> deletedArtifactIds() {
return deletedArtifactIds;
}

public void setDeletedArtifactIds(Set<String> deletedArtifactIds) {
this.deletedArtifactIds = deletedArtifactIds;
}

@JsonProperty("transferToAgent")
public Optional<String> transferToAgent() {
return transferToAgent;
Expand Down Expand Up @@ -238,6 +254,7 @@ public boolean equals(Object o) {
return Objects.equals(skipSummarization, that.skipSummarization)
&& Objects.equals(stateDelta, that.stateDelta)
&& Objects.equals(artifactDelta, that.artifactDelta)
&& Objects.equals(deletedArtifactIds, that.deletedArtifactIds)
&& Objects.equals(transferToAgent, that.transferToAgent)
&& Objects.equals(escalate, that.escalate)
&& Objects.equals(requestedAuthConfigs, that.requestedAuthConfigs)
Expand All @@ -255,6 +272,7 @@ public int hashCode() {
skipSummarization,
stateDelta,
artifactDelta,
deletedArtifactIds,
transferToAgent,
escalate,
requestedAuthConfigs,
Expand All @@ -271,6 +289,7 @@ public static class Builder {
private Optional<Boolean> skipSummarization;
private ConcurrentMap<String, Object> stateDelta;
private ConcurrentMap<String, Part> artifactDelta;
private Set<String> deletedArtifactIds;
private Optional<String> transferToAgent;
private Optional<Boolean> escalate;
private ConcurrentMap<String, ConcurrentMap<String, Object>> requestedAuthConfigs;
Expand All @@ -285,6 +304,7 @@ public Builder() {
this.skipSummarization = Optional.empty();
this.stateDelta = new ConcurrentHashMap<>();
this.artifactDelta = new ConcurrentHashMap<>();
this.deletedArtifactIds = new HashSet<>();
this.transferToAgent = Optional.empty();
this.escalate = Optional.empty();
this.requestedAuthConfigs = new ConcurrentHashMap<>();
Expand All @@ -299,6 +319,7 @@ private Builder(EventActions eventActions) {
this.skipSummarization = eventActions.skipSummarization();
this.stateDelta = new ConcurrentHashMap<>(eventActions.stateDelta());
this.artifactDelta = new ConcurrentHashMap<>(eventActions.artifactDelta());
this.deletedArtifactIds = new HashSet<>(eventActions.deletedArtifactIds());
this.transferToAgent = eventActions.transferToAgent();
this.escalate = eventActions.escalate();
this.requestedAuthConfigs = new ConcurrentHashMap<>(eventActions.requestedAuthConfigs());
Expand Down Expand Up @@ -332,6 +353,13 @@ public Builder artifactDelta(ConcurrentMap<String, Part> value) {
return this;
}

@CanIgnoreReturnValue
@JsonProperty("deletedArtifactIds")
public Builder deletedArtifactIds(Set<String> value) {
this.deletedArtifactIds = value;
return this;
}

@CanIgnoreReturnValue
@JsonProperty("transferToAgent")
public Builder transferToAgent(String agentId) {
Expand Down Expand Up @@ -401,6 +429,7 @@ public Builder merge(EventActions other) {
other.skipSummarization().ifPresent(this::skipSummarization);
this.stateDelta.putAll(other.stateDelta());
this.artifactDelta.putAll(other.artifactDelta());
this.deletedArtifactIds.addAll(other.deletedArtifactIds());
other.transferToAgent().ifPresent(this::transferToAgent);
other.escalate().ifPresent(this::escalate);
this.requestedAuthConfigs.putAll(other.requestedAuthConfigs());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import com.google.adk.apps.ResumabilityConfig;
import com.google.adk.artifacts.BaseArtifactService;
import com.google.adk.events.Event;
import com.google.adk.events.EventActions;
import com.google.adk.memory.BaseMemoryService;
import com.google.adk.models.LlmCallsLimitExceededException;
import com.google.adk.plugins.PluginManager;
Expand Down Expand Up @@ -150,7 +151,7 @@ public void testBuildWithLiveRequestQueue() {
}

@Test
public void testCopyOf() {
public void testToBuilder() {
InvocationContext originalContext =
InvocationContext.builder()
.sessionService(mockSessionService)
Expand Down Expand Up @@ -933,4 +934,56 @@ public void testDeprecatedConstructor_11params() {
assertThat(context.runConfig()).isEqualTo(runConfig);
assertThat(context.endInvocation()).isTrue();
}

@Test
public void populateAgentStates_populatesAgentStatesAndEndOfAgents() {
InvocationContext context =
InvocationContext.builder()
.sessionService(mockSessionService)
.artifactService(mockArtifactService)
.agent(mockAgent)
.session(session)
.invocationId(testInvocationId)
.build();

BaseAgentState agent1State = mock(BaseAgentState.class);
ConcurrentHashMap<String, BaseAgentState> agent1StateMap = new ConcurrentHashMap<>();
agent1StateMap.put("agent1", agent1State);
Event event1 =
Event.builder()
.invocationId(testInvocationId)
.author("agent1")
.actions(EventActions.builder().agentState(agent1StateMap).endOfAgent(true).build())
.build();
Event event2 =
Event.builder()
.invocationId("other-invocation-id")
.author("agent2")
.actions(EventActions.builder().endOfAgent(true).build())
.build();
Event event3 =
Event.builder()
.invocationId(testInvocationId)
.author("agent3")
.actions(EventActions.builder().endOfAgent(false).build())
.build();
BaseAgentState agent4State = mock(BaseAgentState.class);
ConcurrentHashMap<String, BaseAgentState> agent4StateMap = new ConcurrentHashMap<>();
agent4StateMap.put("agent4", agent4State);
Event event4 =
Event.builder()
.invocationId(testInvocationId)
.author("agent4")
.actions(EventActions.builder().agentState(agent4StateMap).endOfAgent(false).build())
.build();
Event event5 = Event.builder().invocationId(testInvocationId).author("agent5").build();

context.populateAgentStates(ImmutableList.of(event1, event2, event3, event4, event5));

assertThat(context.agentStates()).hasSize(2);
assertThat(context.agentStates()).containsEntry("agent1", agent1State);
assertThat(context.agentStates()).containsEntry("agent4", agent4State);
assertThat(context.endOfAgents()).hasSize(1);
assertThat(context.endOfAgents()).containsEntry("agent1", true);
}
}
25 changes: 24 additions & 1 deletion core/src/test/java/com/google/adk/events/EventActionsTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import com.google.adk.sessions.State;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.genai.types.Content;
import com.google.genai.types.Part;
import java.util.concurrent.ConcurrentHashMap;
Expand All @@ -44,7 +45,11 @@ public final class EventActionsTest {
@Test
public void toBuilder_createsBuilderWithSameValues() {
EventActions eventActionsWithSkipSummarization =
EventActions.builder().skipSummarization(true).compaction(COMPACTION).build();
EventActions.builder()
.skipSummarization(true)
.compaction(COMPACTION)
.deletedArtifactIds(ImmutableSet.of("d1"))
.build();

EventActions eventActionsAfterRebuild = eventActionsWithSkipSummarization.toBuilder().build();

Expand All @@ -59,6 +64,7 @@ public void merge_mergesAllFields() {
.skipSummarization(true)
.stateDelta(new ConcurrentHashMap<>(ImmutableMap.of("key1", "value1")))
.artifactDelta(new ConcurrentHashMap<>(ImmutableMap.of("artifact1", PART)))
.deletedArtifactIds(ImmutableSet.of("deleted1"))
.requestedAuthConfigs(
new ConcurrentHashMap<>(
ImmutableMap.of("config1", new ConcurrentHashMap<>(ImmutableMap.of("k", "v")))))
Expand All @@ -70,6 +76,7 @@ public void merge_mergesAllFields() {
EventActions.builder()
.stateDelta(new ConcurrentHashMap<>(ImmutableMap.of("key2", "value2")))
.artifactDelta(new ConcurrentHashMap<>(ImmutableMap.of("artifact2", PART)))
.deletedArtifactIds(ImmutableSet.of("deleted2"))
.transferToAgent("agentId")
.escalate(true)
.requestedAuthConfigs(
Expand All @@ -85,6 +92,7 @@ public void merge_mergesAllFields() {
assertThat(merged.skipSummarization()).hasValue(true);
assertThat(merged.stateDelta()).containsExactly("key1", "value1", "key2", "value2");
assertThat(merged.artifactDelta()).containsExactly("artifact1", PART, "artifact2", PART);
assertThat(merged.deletedArtifactIds()).containsExactly("deleted1", "deleted2");
assertThat(merged.transferToAgent()).hasValue("agentId");
assertThat(merged.escalate()).hasValue(true);
assertThat(merged.requestedAuthConfigs())
Expand All @@ -107,4 +115,19 @@ public void removeStateByKey_marksKeyAsRemoved() {

assertThat(eventActions.stateDelta()).containsExactly("key1", State.REMOVED);
}

@Test
public void jsonSerialization_works() throws Exception {
EventActions eventActions =
EventActions.builder()
.deletedArtifactIds(ImmutableSet.of("d1", "d2"))
.stateDelta(new ConcurrentHashMap<>(ImmutableMap.of("k", "v")))
.build();

String json = eventActions.toJson();
EventActions deserialized = EventActions.fromJsonString(json, EventActions.class);

assertThat(deserialized).isEqualTo(eventActions);
assertThat(deserialized.deletedArtifactIds()).containsExactly("d1", "d2");
}
}
63 changes: 63 additions & 0 deletions core/src/test/java/com/google/adk/events/EventTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -191,4 +191,67 @@ public void event_json_serialization_works() throws Exception {
Event deserializedEvent = Event.fromJson(json);
assertThat(deserializedEvent).isEqualTo(EVENT);
}

@Test
public void finalResponse_returnsTrueIfNoToolCalls() {
Event event =
Event.builder()
.id("e1")
.invocationId("i1")
.author("agent")
.content(Content.fromParts(Part.fromText("hello")))
.build();
assertThat(event.finalResponse()).isTrue();
}

@Test
public void finalResponse_returnsFalseIfToolCalls() {
Event event =
Event.builder()
.id("e1")
.invocationId("i1")
.author("agent")
.content(Content.fromParts(Part.fromFunctionCall("tool", ImmutableMap.of("k", "v"))))
.build();
assertThat(event.finalResponse()).isFalse();
}

@Test
public void finalResponse_isTrueForEventWithTextContent() {
Event event =
Event.builder()
.id("e1")
.invocationId("i1")
.author("agent")
.content(Content.fromParts(Part.fromText("hello")))
.longRunningToolIds(ImmutableSet.of("tool1"))
.build();
assertThat(event.finalResponse()).isTrue();
}

@Test
public void finalResponse_isFalseForEventWithToolCallAndLongRunningToolId() {
Event event =
Event.builder()
.id("e1")
.invocationId("i1")
.author("agent")
.content(Content.fromParts(Part.fromFunctionCall("tool", ImmutableMap.of("k", "v"))))
.longRunningToolIds(ImmutableSet.of("tool1"))
.build();
assertThat(event.finalResponse()).isFalse();
}

@Test
public void finalResponse_returnsTrueIfSkipSummarization() {
Event event =
Event.builder()
.id("e1")
.invocationId("i1")
.author("agent")
.content(Content.fromParts(Part.fromFunctionCall("tool", ImmutableMap.of("k", "v"))))
.actions(EventActions.builder().skipSummarization(true).build())
.build();
assertThat(event.finalResponse()).isTrue();
}
}