From a340b9d20b48dbc81dfe8ea2507c55831cc0dda4 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Tue, 19 May 2026 07:59:40 -0700 Subject: [PATCH] feat: remove special handling for Gemini 3 function response ordering PiperOrigin-RevId: 917837583 --- .../adk/flows/llmflows/BaseLlmFlow.java | 3 +- .../google/adk/flows/llmflows/Contents.java | 65 +-- .../java/com/google/adk/models/Gemini.java | 389 +++++++++++----- .../adk/flows/llmflows/BaseLlmFlowTest.java | 27 ++ .../adk/flows/llmflows/ContentsTest.java | 60 ++- .../com/google/adk/models/GeminiTest.java | 424 +++++++++++++++++- .../google/adk/tutorials/CityTimeWeather.java | 2 +- 7 files changed, 790 insertions(+), 180 deletions(-) diff --git a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java index fffeab698..67c29ae77 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java @@ -685,7 +685,8 @@ private Flowable buildPostprocessingEvents( Event modelResponseEvent = buildModelResponseEvent(baseEventForLlmResponse, llmRequest, updatedResponse); - if (modelResponseEvent.functionCalls().isEmpty()) { + if (modelResponseEvent.functionCalls().isEmpty() + || modelResponseEvent.partial().orElse(false)) { return processorEvents.concatWith(Flowable.just(modelResponseEvent)); } diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Contents.java b/core/src/main/java/com/google/adk/flows/llmflows/Contents.java index 876f3a206..36e93014e 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/Contents.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/Contents.java @@ -57,13 +57,6 @@ public Single processRequest( } LlmAgent llmAgent = (LlmAgent) context.agent(); - String modelName; - try { - modelName = llmAgent.resolvedModel().modelName().orElse(""); - } catch (IllegalStateException e) { - modelName = ""; - } - ImmutableList sessionEvents; synchronized (context.session().events()) { sessionEvents = ImmutableList.copyOf(context.session().events()); @@ -75,17 +68,13 @@ public Single processRequest( request.toBuilder() .contents( getCurrentTurnContents( - context.branch().orElse(null), - sessionEvents, - context.agent().name(), - modelName)) + context.branch().orElse(null), sessionEvents, context.agent().name())) .build(), ImmutableList.of())); } ImmutableList contents = - getContents( - context.branch().orElse(null), sessionEvents, context.agent().name(), modelName); + getContents(context.branch().orElse(null), sessionEvents, context.agent().name()); return Single.just( RequestProcessor.RequestProcessingResult.create( @@ -94,19 +83,19 @@ public Single processRequest( /** Gets contents for the current turn only (no conversation history). */ private ImmutableList getCurrentTurnContents( - @Nullable String currentBranch, List events, String agentName, String modelName) { + @Nullable String currentBranch, List events, String agentName) { // Find the latest event that starts the current turn and process from there. for (int i = events.size() - 1; i >= 0; i--) { Event event = events.get(i); if (event.author().equals("user") || isOtherAgentReply(agentName, event)) { - return getContents(currentBranch, events.subList(i, events.size()), agentName, modelName); + return getContents(currentBranch, events.subList(i, events.size()), agentName); } } return ImmutableList.of(); } private ImmutableList getContents( - @Nullable String currentBranch, List events, String agentName, String modelName) { + @Nullable String currentBranch, List events, String agentName) { List filteredEvents = new ArrayList<>(); boolean hasCompactEvent = false; @@ -148,7 +137,7 @@ private ImmutableList getContents( } List resultEvents = rearrangeEventsForLatestFunctionResponse(filteredEvents); - resultEvents = rearrangeEventsForAsyncFunctionResponsesInHistory(resultEvents, modelName); + resultEvents = rearrangeEventsForAsyncFunctionResponsesInHistory(resultEvents); return resultEvents.stream() .map(Event::content) @@ -564,8 +553,7 @@ private static List rearrangeEventsForLatestFunctionResponse(List return resultEvents; } - private static List rearrangeEventsForAsyncFunctionResponsesInHistory( - List events, String modelName) { + private static List rearrangeEventsForAsyncFunctionResponsesInHistory(List events) { Map functionCallIdToResponseEventIndex = new HashMap<>(); for (int i = 0; i < events.size(); i++) { final int index = i; @@ -592,11 +580,6 @@ private static List rearrangeEventsForAsyncFunctionResponsesInHistory( List resultEvents = new ArrayList<>(); // Keep track of response events already added to avoid duplicates when merging Set processedResponseIndices = new HashSet<>(); - List responseEventsBuffer = new ArrayList<>(); - - // Gemini 3 requires function calls to be grouped first and only then function responses: - // FC1 FC2 FR1 FR2 - boolean shouldBufferResponseEvents = modelName.contains("gemini-3"); for (int i = 0; i < events.size(); i++) { Event event = events.get(i); @@ -641,47 +624,21 @@ private static List rearrangeEventsForAsyncFunctionResponsesInHistory( for (int index : sortedIndices) { if (processedResponseIndices.add(index)) { // Add index and check if it was newly added - responseEventsBuffer.add(events.get(index)); responseEventsToAdd.add(events.get(index)); } } - if (!shouldBufferResponseEvents) { - if (responseEventsToAdd.size() == 1) { - resultEvents.add(responseEventsToAdd.get(0)); - } else if (responseEventsToAdd.size() > 1) { - resultEvents.add(mergeFunctionResponseEvents(responseEventsToAdd)); - } + if (responseEventsToAdd.size() == 1) { + resultEvents.add(responseEventsToAdd.get(0)); + } else if (responseEventsToAdd.size() > 1) { + resultEvents.add(mergeFunctionResponseEvents(responseEventsToAdd)); } } } else { - // gemini-3 specific part: buffer response events - if (shouldBufferResponseEvents) { - if (!responseEventsBuffer.isEmpty()) { - if (responseEventsBuffer.size() == 1) { - resultEvents.add(responseEventsBuffer.get(0)); - } else { - resultEvents.add(mergeFunctionResponseEvents(responseEventsBuffer)); - } - responseEventsBuffer.clear(); - } - } resultEvents.add(event); } } - // gemini-3 specific part: buffer response events - if (shouldBufferResponseEvents) { - if (!responseEventsBuffer.isEmpty()) { - if (responseEventsBuffer.size() == 1) { - resultEvents.add(responseEventsBuffer.get(0)); - } else { - resultEvents.add(mergeFunctionResponseEvents(responseEventsBuffer)); - } - responseEventsBuffer.clear(); - } - } - return resultEvents; } diff --git a/core/src/main/java/com/google/adk/models/Gemini.java b/core/src/main/java/com/google/adk/models/Gemini.java index 6f145e1de..d4067dff0 100644 --- a/core/src/main/java/com/google/adk/models/Gemini.java +++ b/core/src/main/java/com/google/adk/models/Gemini.java @@ -19,11 +19,11 @@ import static com.google.common.base.StandardSystemProperty.JAVA_VERSION; import com.google.adk.Version; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.errorprone.annotations.CanIgnoreReturnValue; import com.google.genai.Client; import com.google.genai.ResponseStream; -import com.google.genai.types.Candidate; import com.google.genai.types.Content; import com.google.genai.types.FinishReason; import com.google.genai.types.GenerateContentConfig; @@ -35,7 +35,6 @@ import java.util.ArrayList; import java.util.List; import java.util.Objects; -import java.util.Optional; import java.util.concurrent.CompletableFuture; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -226,21 +225,7 @@ public Flowable generateContent(LlmRequest llmRequest, boolean stre () -> processRawResponses( Flowable.fromFuture(streamFuture).flatMapIterable(iterable -> iterable))) - .filter( - llmResponse -> - llmResponse - .content() - .flatMap(Content::parts) - .map( - parts -> - !parts.isEmpty() - && parts.stream() - .anyMatch( - p -> - p.functionCall().isPresent() - || p.functionResponse().isPresent() - || p.text().isPresent())) - .orElse(false)); + .filter(Gemini::shouldEmit); } else { logger.debug("Sending generateContent request to model {}", effectiveModelName); return Flowable.fromFuture( @@ -253,93 +238,7 @@ public Flowable generateContent(LlmRequest llmRequest, boolean stre } static Flowable processRawResponses(Flowable rawResponses) { - final StringBuilder accumulatedText = new StringBuilder(); - final StringBuilder accumulatedThoughtText = new StringBuilder(); - // Array to bypass final local variable reassignment in lambda. - final GenerateContentResponse[] lastRawResponseHolder = {null}; - return rawResponses - .concatMap( - rawResponse -> { - lastRawResponseHolder[0] = rawResponse; - logger.trace("Raw streaming response: {}", rawResponse); - - List responsesToEmit = new ArrayList<>(); - LlmResponse currentProcessedLlmResponse = LlmResponse.create(rawResponse); - Optional part = GeminiUtil.getPart0FromLlmResponse(currentProcessedLlmResponse); - String currentTextChunk = part.flatMap(Part::text).orElse(""); - - if (!currentTextChunk.isBlank()) { - if (part.get().thought().orElse(false)) { - accumulatedThoughtText.append(currentTextChunk); - responsesToEmit.add( - thinkingResponseFromText(currentTextChunk).toBuilder() - .usageMetadata(currentProcessedLlmResponse.usageMetadata().orElse(null)) - .partial(true) - .build()); - } else { - accumulatedText.append(currentTextChunk); - responsesToEmit.add( - responseFromText(currentTextChunk).toBuilder() - .usageMetadata(currentProcessedLlmResponse.usageMetadata().orElse(null)) - .partial(true) - .build()); - } - } else { - if (accumulatedThoughtText.length() > 0 - && GeminiUtil.shouldEmitAccumulatedText(currentProcessedLlmResponse)) { - LlmResponse aggregatedThoughtResponse = - thinkingResponseFromText(accumulatedThoughtText.toString()); - responsesToEmit.add(aggregatedThoughtResponse); - accumulatedThoughtText.setLength(0); - } - if (accumulatedText.length() > 0 - && GeminiUtil.shouldEmitAccumulatedText(currentProcessedLlmResponse)) { - LlmResponse aggregatedTextResponse = responseFromText(accumulatedText.toString()); - responsesToEmit.add(aggregatedTextResponse); - accumulatedText.setLength(0); - } - responsesToEmit.add(currentProcessedLlmResponse); - } - logger.debug("Responses to emit: {}", responsesToEmit); - return Flowable.fromIterable(responsesToEmit); - }) - .concatWith( - Flowable.defer( - () -> { - GenerateContentResponse finalRawResp = lastRawResponseHolder[0]; - if (finalRawResp == null) { - return Flowable.empty(); - } - boolean isStop = - finalRawResp - .candidates() - .flatMap(candidates -> candidates.stream().findFirst()) - .flatMap(Candidate::finishReason) - .map(finishReason -> finishReason.knownEnum() == FinishReason.Known.STOP) - .orElse(false); - - if (isStop) { - List finalResponses = new ArrayList<>(); - if (accumulatedThoughtText.length() > 0) { - finalResponses.add( - thinkingResponseFromText(accumulatedThoughtText.toString()).toBuilder() - .usageMetadata( - accumulatedText.length() > 0 - ? null - : finalRawResp.usageMetadata().orElse(null)) - .build()); - } - if (accumulatedText.length() > 0) { - finalResponses.add( - responseFromText(accumulatedText.toString()).toBuilder() - .usageMetadata(finalRawResp.usageMetadata().orElse(null)) - .build()); - } - - return Flowable.fromIterable(finalResponses); - } - return Flowable.empty(); - })); + return Flowable.defer(() -> new StreamingResponseAggregator().process(rawResponses)); } private static LlmResponse responseFromText(String accumulatedText) { @@ -358,6 +257,67 @@ private static LlmResponse thinkingResponseFromText(String accumulatedThoughtTex .build(); } + /** + * Returns true if {@code response} should be emitted downstream by the streaming pipeline. + * + *

Drops chunks that carry neither semantic content (i.e. they are an empty-text-only response + * per {@link #isEmptyTextOnlyResponse}) nor any useful metadata (per {@link #hasUsefulMetadata}). + * + *

Package-private for testing. + */ + static boolean shouldEmit(LlmResponse response) { + return !isEmptyTextOnlyResponse(response) || hasUsefulMetadata(response); + } + + /** + * Returns true if {@code response} carries any non-content metadata that should be propagated + * downstream (e.g. {@code usageMetadata}, {@code finishReason}, transcriptions, grounding or + * error info). Inspects only top-level {@link LlmResponse} fields; the response's content/parts + * are intentionally not considered here. + */ + private static boolean hasUsefulMetadata(LlmResponse response) { + return response.usageMetadata().isPresent() + || response.finishReason().isPresent() + || response.errorCode().isPresent() + || response.groundingMetadata().isPresent() + || response.inputTranscription().isPresent() + || response.outputTranscription().isPresent(); + } + + /** + * Returns true if {@code response} consists of exactly one {@link Part} whose only meaningful + * payload is an empty text string (i.e. {@code parts:[{text:""}]}). Such a chunk can be safely + * dropped from the streaming aggregator because it carries no semantic content for the agent + * pipeline. A part is considered to carry semantic content if any of its non-text payloads + * ({@code functionCall}, {@code functionResponse}, {@code inlineData}, {@code executableCode}, + * {@code codeExecutionResult}, {@code fileData}, {@code thoughtSignature}, {@code videoMetadata}, + * {@code toolCall}, {@code toolResponse}) is present. + */ + private static boolean isEmptyTextOnlyResponse(LlmResponse response) { + return response + .content() + .flatMap(Content::parts) + .map( + parts -> { + if (parts.size() != 1) { + return false; + } + Part part = parts.get(0); + return part.text().map(String::isEmpty).orElse(false) + && part.functionCall().isEmpty() + && part.functionResponse().isEmpty() + && part.inlineData().isEmpty() + && part.executableCode().isEmpty() + && part.codeExecutionResult().isEmpty() + && part.fileData().isEmpty() + && part.thoughtSignature().isEmpty() + && part.videoMetadata().isEmpty() + && part.toolCall().isEmpty() + && part.toolResponse().isEmpty(); + }) + .orElse(false); + } + @Override public BaseLlmConnection connect(LlmRequest llmRequest) { if (!apiClient.vertexAI()) { @@ -372,4 +332,225 @@ public BaseLlmConnection connect(LlmRequest llmRequest) { return new GeminiLlmConnection(apiClient, effectiveModelName, liveConnectConfig); } + + private static final class StreamingResponseAggregator { + private final StringBuilder accumulatedText = new StringBuilder(); + private final StringBuilder accumulatedThoughtText = new StringBuilder(); + private final List accumulatedFunctionCalls = new ArrayList<>(); + private GenerateContentResponse lastRawResponse = null; + + /** + * Processes a stream of raw responses, emitting partial and aggregated {@link LlmResponse}s. + */ + private Flowable process(Flowable rawResponses) { + return rawResponses + .concatMap(this::processRawResponse) + .concatWith(Flowable.defer(this::processFinalResponse)); + } + + /** + * Processes a single raw streaming chunk, accumulating parts and emitting intermediate + * responses. + */ + private Flowable processRawResponse(GenerateContentResponse rawResponse) { + lastRawResponse = rawResponse; + logger.trace("Raw streaming response: {}", rawResponse); + + LlmResponse currentProcessedLlmResponse = LlmResponse.create(rawResponse); + List parts = + currentProcessedLlmResponse.content().flatMap(Content::parts).orElse(ImmutableList.of()); + + boolean hasText = accumulateParts(parts); + boolean hasFunctionCall = parts.stream().anyMatch(part -> part.functionCall().isPresent()); + + List responsesToEmit = new ArrayList<>(); + + if (hasText) { + // Text is actively streaming; emit the current partial response (carrying text and any + // function calls). + responsesToEmit.add(currentProcessedLlmResponse.toBuilder().partial(true).build()); + } else { + // Text streaming has paused or ended; flush any previously accumulated text buffers. + flushAccumulatedTextBuffers(currentProcessedLlmResponse, responsesToEmit); + + // Determine how to emit or merge the current non-text chunk. + handleNonTextChunk(currentProcessedLlmResponse, hasFunctionCall, responsesToEmit); + } + + logger.info("Responses to emit: {}", responsesToEmit); + return Flowable.fromIterable(responsesToEmit); + } + + /** + * Accumulates text and function calls from incoming parts. + * + * @return true if any text was present, false otherwise. + */ + private boolean accumulateParts(List parts) { + boolean hasText = false; + for (Part part : parts) { + String text = part.text().orElse(""); + if (!text.isEmpty()) { + hasText = true; + if (part.thought().orElse(false)) { + accumulatedThoughtText.append(text); + } else { + accumulatedText.append(text); + } + } + if (part.functionCall().isPresent()) { + accumulatedFunctionCalls.add(part); + } + } + return hasText; + } + + /** + * Flushes any previously accumulated text or thought text buffers when a non-text chunk + * arrives. + */ + private void flushAccumulatedTextBuffers( + LlmResponse currentResponse, List responsesToEmit) { + if (accumulatedThoughtText.length() > 0 + && GeminiUtil.shouldEmitAccumulatedText(currentResponse)) { + responsesToEmit.add(thinkingResponseFromText(accumulatedThoughtText.toString())); + accumulatedThoughtText.setLength(0); + } + if (accumulatedText.length() > 0 && GeminiUtil.shouldEmitAccumulatedText(currentResponse)) { + responsesToEmit.add(responseFromText(accumulatedText.toString())); + accumulatedText.setLength(0); + } + } + + /** + * Determines how to emit or merge the current non-text chunk (e.g., function calls or + * metadata). + */ + private void handleNonTextChunk( + LlmResponse currentResponse, boolean hasFunctionCall, List responsesToEmit) { + if (hasFunctionCall) { + responsesToEmit.add(currentResponse.toBuilder().partial(true).build()); + } else if (!responsesToEmit.isEmpty()) { + LlmResponse lastResponse = responsesToEmit.get(responsesToEmit.size() - 1); + responsesToEmit.set(responsesToEmit.size() - 1, merge(lastResponse, currentResponse)); + } else if (!accumulatedFunctionCalls.isEmpty()) { + // Suppress the empty STOP chunk because processFinalResponse() will immediately emit + // the final aggregated response carrying the final metadata. + } else { + responsesToEmit.add(currentResponse); + } + } + + /** + * Emits final aggregated, non-partial responses (carrying complete accumulated text or function + * calls) when the stream completes. + */ + private Flowable processFinalResponse() { + if (lastRawResponse == null) { + return Flowable.empty(); + } + LlmResponse currentResponse = LlmResponse.create(lastRawResponse); + boolean isStop = + currentResponse + .finishReason() + .map(reason -> reason.knownEnum() == FinishReason.Known.STOP) + .orElse(false); + + if (!isStop) { + return Flowable.empty(); + } + + List finalResponses = new ArrayList<>(); + if (accumulatedThoughtText.length() > 0) { + finalResponses.add(thinkingResponseFromText(accumulatedThoughtText.toString())); + } + if (accumulatedText.length() > 0) { + finalResponses.add(responseFromText(accumulatedText.toString())); + } + if (!accumulatedFunctionCalls.isEmpty()) { + finalResponses.add( + LlmResponse.builder() + .content(Content.builder().role("model").parts(accumulatedFunctionCalls).build()) + .partial(false) + .build()); + } + + if (!finalResponses.isEmpty()) { + // Merge top-level metadata (finishReason, usageMetadata, etc.) into the LAST response. + LlmResponse lastResponse = finalResponses.get(finalResponses.size() - 1); + finalResponses.set( + finalResponses.size() - 1, mergeMetadataOnly(lastResponse, currentResponse)); + + // Merge thoughtSignature into the THOUGHT response (which is always at index 0 if present), + // or into the last response if no thought response exists. + int thoughtIndex = accumulatedThoughtText.length() > 0 ? 0 : finalResponses.size() - 1; + LlmResponse thoughtTarget = finalResponses.get(thoughtIndex); + finalResponses.set(thoughtIndex, mergeThoughtSignatureOnly(thoughtTarget, currentResponse)); + } + return Flowable.fromIterable(finalResponses); + } + + /** + * Merges top-level metadata and thought signatures from the current response into the last + * emitted response. + */ + private static LlmResponse merge(LlmResponse lastResponse, LlmResponse currentResponse) { + return mergeThoughtSignatureOnly( + mergeMetadataOnly(lastResponse, currentResponse), currentResponse); + } + + /** + * Merges top-level metadata fields (usage, finish reason, grounding, transcriptions) into the + * target response. + */ + private static LlmResponse mergeMetadataOnly( + LlmResponse lastResponse, LlmResponse currentResponse) { + return lastResponse.toBuilder() + .usageMetadata(currentResponse.usageMetadata().orElse(null)) + .finishReason(currentResponse.finishReason().orElse(null)) + .modelVersion(currentResponse.modelVersion().orElse(null)) + .errorCode(currentResponse.errorCode().orElse(null)) + .groundingMetadata(currentResponse.groundingMetadata().orElse(null)) + .inputTranscription(currentResponse.inputTranscription().orElse(null)) + .outputTranscription(currentResponse.outputTranscription().orElse(null)) + .build(); + } + + /** + * Merges thought signatures from the current response into the target thought response part. + */ + private static LlmResponse mergeThoughtSignatureOnly( + LlmResponse lastResponse, LlmResponse currentResponse) { + LlmResponse.Builder mergedBuilder = lastResponse.toBuilder(); + GeminiUtil.getPart0FromLlmResponse(currentResponse) + .flatMap(Part::thoughtSignature) + .ifPresent( + signature -> { + lastResponse + .content() + .filter(content -> content.parts().isPresent()) + .ifPresent( + lastContent -> { + List parts = lastContent.parts().get(); + ImmutableList updatedParts = + parts.isEmpty() + ? ImmutableList.of( + Part.builder() + .thought(true) + .thoughtSignature(signature) + .build()) + : ImmutableList.builder() + .add( + parts.get(0).toBuilder() + .thoughtSignature(signature) + .build()) + .addAll(parts.subList(1, parts.size())) + .build(); + mergedBuilder.content( + lastContent.toBuilder().parts(updatedParts).build()); + }); + }); + return mergedBuilder.build(); + } + } } diff --git a/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java b/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java index 2a06c1f0a..70706d7fe 100644 --- a/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java +++ b/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java @@ -216,6 +216,33 @@ public void run_withLongRunningFunctionCall_returnsCorrectEventsWithLongRunningT assertThat(events.get(2).content()).hasValue(secondContent); } + @Test + public void run_withPartialFunctionCall_doesNotExecuteTool() { + Content partialContent = + Content.fromParts(Part.fromFunctionCall("my_function", ImmutableMap.of("arg1", "value1"))); + LlmResponse partialResponse = + LlmResponse.builder().content(partialContent).partial(true).build(); + TestLlm testLlm = createTestLlm(partialResponse); + ImmutableMap testResponse = + ImmutableMap.of("response", "response for my_function"); + InvocationContext invocationContext = + createInvocationContext( + createTestAgentBuilder(testLlm) + .tools(ImmutableList.of(new TestTool("my_function", testResponse))) + .build()); + BaseLlmFlow baseLlmFlow = + createBaseLlmFlow( + /* requestProcessors= */ ImmutableList.of(), + /* responseProcessors= */ ImmutableList.of(), + /* maxSteps= */ Optional.of(1)); + + List events = baseLlmFlow.run(invocationContext).toList().blockingGet(); + + assertThat(events).hasSize(1); + assertThat(events.get(0).partial()).hasValue(true); + assertThat(events.get(0).functionCalls()).hasSize(1); + } + @Test public void run_withRequestProcessor_doesNotModifyRequest() { Content content = Content.fromParts(Part.fromText("LLM response")); diff --git a/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java b/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java index 1e6267dde..4313b14ca 100644 --- a/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java +++ b/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java @@ -561,16 +561,66 @@ public void rearrangeHistory_gemini3interleavedFCFR_groupsFcAndFr() { List result = runContentsProcessorWithModelName(inputEvents, "gemini-3-flash-exp"); - assertThat(result).hasSize(4); - assertThat(result.get(0)).isEqualTo(u1.content().get()); - assertThat(result.get(1)).isEqualTo(fc1.content().get()); - assertThat(result.get(2)).isEqualTo(fc2.content().get()); - Content mergedContent = result.get(3); + assertThat(result).isEqualTo(eventsToContents(inputEvents)); + } + + @Test + public void rearrangeHistory_sequentialCalls_preservesInterleavedOrder() { + Event u1 = createUserEvent("u1", "Query"); + Event fc1 = createFunctionCallEvent("fc1", "tool1", "call1"); + Event fr1 = createFunctionResponseEvent("fr1", "tool1", "call1"); + Event fc2 = createFunctionCallEvent("fc2", "tool2", "call2"); + Event fr2 = createFunctionResponseEvent("fr2", "tool2", "call2"); + + ImmutableList inputEvents = ImmutableList.of(u1, fc1, fr1, fc2, fr2); + + List result = runContentsProcessor(inputEvents); + + assertThat(result).isEqualTo(eventsToContents(inputEvents)); + } + + @Test + public void rearrangeHistory_parallelCallsSeparateResponseEvents_mergesResponses() { + Event fcEvent = createParallelFunctionCallEvent("fc1", "tool1", "call1", "tool2", "call2"); + Event frEvent1 = createFunctionResponseEvent("fr1", "tool1", "call1"); + Event frEvent2 = createFunctionResponseEvent("fr2", "tool2", "call2"); + ImmutableList inputEvents = + ImmutableList.of(createUserEvent("u1", "Query"), fcEvent, frEvent1, frEvent2); + + List result = runContentsProcessorWithModelName(inputEvents, "gemini-3-flash-exp"); + + assertThat(result).hasSize(3); // u1, fc1, merged_fr + assertThat(result.get(0)).isEqualTo(inputEvents.get(0).content().get()); + assertThat(result.get(1)).isEqualTo(inputEvents.get(1).content().get()); + Content mergedContent = result.get(2); + assertThat(mergedContent.parts().get()).hasSize(2); + assertThat(mergedContent.parts().get().get(0).functionResponse().get().name()) + .hasValue("tool1"); + assertThat(mergedContent.parts().get().get(1).functionResponse().get().name()) + .hasValue("tool2"); + } + + @Test + public void rearrangeHistory_parallelCallsSeparateResponseEventsInHistory_mergesResponses() { + Event fcEvent = createParallelFunctionCallEvent("fc1", "tool1", "call1", "tool2", "call2"); + Event frEvent1 = createFunctionResponseEvent("fr1", "tool1", "call1"); + Event frEvent2 = createFunctionResponseEvent("fr2", "tool2", "call2"); + Event u2 = createUserEvent("u2", "Second Query"); + ImmutableList inputEvents = + ImmutableList.of(createUserEvent("u1", "Query"), fcEvent, frEvent1, frEvent2, u2); + + List result = runContentsProcessor(inputEvents); + + assertThat(result).hasSize(4); // u1, fc1, merged_fr, u2 + assertThat(result.get(0)).isEqualTo(inputEvents.get(0).content().get()); + assertThat(result.get(1)).isEqualTo(inputEvents.get(1).content().get()); + Content mergedContent = result.get(2); assertThat(mergedContent.parts().get()).hasSize(2); assertThat(mergedContent.parts().get().get(0).functionResponse().get().name()) .hasValue("tool1"); assertThat(mergedContent.parts().get().get(1).functionResponse().get().name()) .hasValue("tool2"); + assertThat(result.get(3)).isEqualTo(inputEvents.get(4).content().get()); } @Test diff --git a/core/src/test/java/com/google/adk/models/GeminiTest.java b/core/src/test/java/com/google/adk/models/GeminiTest.java index c230f5f68..6ed1c776a 100644 --- a/core/src/test/java/com/google/adk/models/GeminiTest.java +++ b/core/src/test/java/com/google/adk/models/GeminiTest.java @@ -17,6 +17,7 @@ import static com.google.common.truth.Truth.assertThat; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.genai.types.Candidate; import com.google.genai.types.Content; @@ -27,6 +28,7 @@ import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.functions.Predicate; import io.reactivex.rxjava3.subscribers.TestSubscriber; +import java.nio.charset.StandardCharsets; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -63,6 +65,52 @@ public void processRawResponses_withTextChunks_emitsPartialResponses() { isFunctionCallResponse()); } + @Test + public void processRawResponses_chunkWithBothTextAndFunctionCall_emitsPartialWithBoth() { + GenerateContentResponse chunkWithBoth = + GenerateContentResponse.builder() + .candidates( + Candidate.builder() + .content( + Content.builder() + .parts( + Part.fromText("Here is the call:"), + Part.fromFunctionCall("my_tool", ImmutableMap.of())) + .build()) + .build()) + .build(); + + Flowable rawResponses = Flowable.just(chunkWithBoth); + + Flowable llmResponses = Gemini.processRawResponses(rawResponses); + + assertLlmResponses( + llmResponses, isPartialTextAndFunctionCallResponse("Here is the call:", "my_tool")); + } + + @Test + public void processRawResponses_streamingFunctionCallsAndStop_emitsPartialsThenFinalAggregated() { + Part fc1 = Part.fromFunctionCall("tool1", ImmutableMap.of("arg1", "val1")); + Part fc2 = Part.fromFunctionCall("tool2", ImmutableMap.of("arg2", "val2")); + GenerateContentResponse fc2WithStop = + GenerateContentResponse.builder() + .candidates( + Candidate.builder() + .content(Content.builder().parts(fc2).build()) + .finishReason(new FinishReason(FinishReason.Known.STOP)) + .build()) + .build(); + Flowable rawResponses = Flowable.just(toResponse(fc1), fc2WithStop); + + Flowable llmResponses = Gemini.processRawResponses(rawResponses); + + assertLlmResponses( + llmResponses, + isPartialFunctionCallResponse("tool1"), + isPartialFunctionCallResponse("tool2"), + isFinalAggregatedFunctionCallResponse("tool1", "tool2")); + } + @Test public void processRawResponses_textAndStopReason_emitsPartialThenFinalText() { Flowable rawResponses = @@ -111,17 +159,14 @@ public void processRawResponses_finishReasonNotStop_doesNotEmitFinalAccumulatedT } @Test - public void processRawResponses_textThenEmpty_emitsPartialTextThenFullTextAndEmpty() { + public void processRawResponses_textThenEmpty_emitsPartialTextThenFullText() { Flowable rawResponses = Flowable.just(toResponseWithText("Thinking..."), GenerateContentResponse.builder().build()); Flowable llmResponses = Gemini.processRawResponses(rawResponses); assertLlmResponses( - llmResponses, - isPartialTextResponse("Thinking..."), - isFinalTextResponse("Thinking..."), - isEmptyResponse()); + llmResponses, isPartialTextResponse("Thinking..."), isFinalTextResponse("Thinking...")); } @Test @@ -157,6 +202,27 @@ public void processRawResponses_textAndStopReason_finalResponseIncludesUsageMeta isFinalTextResponseWithUsageMetadata("Hello world", metadata)); } + @Test + public void + processRawResponses_textThenEmptyStopWithUsageMetadata_finalResponseIncludesUsageMetadata() { + GenerateContentResponseUsageMetadata metadata = createUsageMetadata(10, 20, 30); + GenerateContentResponse stopResponse = + GenerateContentResponse.builder() + .candidates( + Candidate.builder().finishReason(new FinishReason(FinishReason.Known.STOP)).build()) + .usageMetadata(metadata) + .build(); + Flowable rawResponses = + Flowable.just(toResponseWithText("Hello"), stopResponse); + + Flowable llmResponses = Gemini.processRawResponses(rawResponses); + + assertLlmResponses( + llmResponses, + isPartialTextResponse("Hello"), + isFinalTextResponseWithUsageMetadata("Hello", metadata)); + } + @Test public void processRawResponses_thoughtChunksAndStop_includeUsageMetadata() { GenerateContentResponseUsageMetadata metadata1 = createUsageMetadata(5, 10, 15); @@ -194,8 +260,256 @@ public void processRawResponses_thoughtAndTextWithStop_onlyFinalTextIncludesUsag isFinalTextResponseWithUsageMetadata("Answer", metadata2)); } - // Helper methods for assertions + @Test + public void + processRawResponses_textAndFunctionCallWithStop_onlyFinalFunctionCallIncludesUsageMetadata() { + GenerateContentResponseUsageMetadata metadata1 = createUsageMetadata(5, 5, 10); + GenerateContentResponseUsageMetadata metadata2 = createUsageMetadata(10, 20, 30); + Part fcPart = Part.fromFunctionCall("my_tool", ImmutableMap.of()); + GenerateContentResponse stopResponse = + GenerateContentResponse.builder() + .candidates( + Candidate.builder() + .content(Content.builder().parts(fcPart).build()) + .finishReason(new FinishReason(FinishReason.Known.STOP)) + .build()) + .usageMetadata(metadata2) + .build(); + Flowable rawResponses = + Flowable.just(toResponseWithText("Answer", metadata1), stopResponse); + + Flowable llmResponses = Gemini.processRawResponses(rawResponses); + + assertLlmResponses( + llmResponses, + isPartialTextResponseWithUsageMetadata("Answer", metadata1), + isFinalTextResponseWithNoUsageMetadata("Answer"), + isPartialFunctionCallResponse("my_tool"), + isFinalAggregatedFunctionCallResponseWithUsageMetadata(metadata2, "my_tool")); + } + + @Test + public void + processRawResponses_thoughtThenEmptyWithSignatureAndStop_flushesThoughtWithSignature() { + GenerateContentResponseUsageMetadata metadata1 = createUsageMetadata(5, 10, 15); + GenerateContentResponseUsageMetadata metadata2 = createUsageMetadata(5, 20, 25); + GenerateContentResponse chunk1 = toResponseWithThoughtText("Thinking", metadata1); + GenerateContentResponse chunk2 = + GenerateContentResponse.builder() + .candidates( + Candidate.builder() + .content( + Content.builder() + .parts( + Part.builder() + .thought(true) + .thoughtSignature("sig".getBytes(StandardCharsets.UTF_8)) + .build()) + .build()) + .finishReason(new FinishReason(FinishReason.Known.STOP)) + .build()) + .usageMetadata(metadata2) + .build(); + Flowable rawResponses = Flowable.just(chunk1, chunk2); + Flowable llmResponses = Gemini.processRawResponses(rawResponses); + + assertLlmResponses( + llmResponses, + isPartialThoughtResponseWithUsageMetadata("Thinking", metadata1), + isFinalThoughtResponseWithUsageMetadataAndSignature("Thinking", metadata2, "sig")); + } + + @Test + public void processRawResponses_emptyPartsThenSignature_doesNotThrowException() { + GenerateContentResponseUsageMetadata metadata = createUsageMetadata(5, 10, 15); + GenerateContentResponse chunk1 = + GenerateContentResponse.builder() + .candidates( + Candidate.builder() + .content(Content.builder().parts(ImmutableList.of()).build()) + .build()) + .build(); + GenerateContentResponse chunk2 = + GenerateContentResponse.builder() + .candidates( + Candidate.builder() + .content( + Content.builder() + .parts( + Part.builder() + .thought(true) + .thoughtSignature("sig".getBytes(StandardCharsets.UTF_8)) + .build()) + .build()) + .finishReason(new FinishReason(FinishReason.Known.STOP)) + .build()) + .usageMetadata(metadata) + .build(); + Flowable rawResponses = Flowable.just(chunk1, chunk2); + + Flowable llmResponses = Gemini.processRawResponses(rawResponses); + + assertLlmResponses( + llmResponses, + isEmptyResponse(), + isFinalThoughtResponseWithUsageMetadataAndSignature("", metadata, "sig")); + } + + @Test + public void functionCallThenEmptyTextWithStop_emitsPartialThenFinalAggregatedFunctionCall() { + Flowable rawResponses = + Flowable.just( + toResponse(Part.fromFunctionCall("test_function", ImmutableMap.of())), + toResponseWithText("", FinishReason.Known.STOP)); + + Flowable llmResponses = + Gemini.processRawResponses(rawResponses).filter(Gemini::shouldEmit); + + assertLlmResponses( + llmResponses, + isPartialFunctionCallResponse("test_function"), + isFinalAggregatedFunctionCallResponse("test_function")); + } + + @Test + public void functionCallThenEmptyTextWithUsageMetadata_emitsFinalAggregatedWithUsageMetadata() { + GenerateContentResponseUsageMetadata metadata = createUsageMetadata(5, 10, 15); + Flowable rawResponses = + Flowable.just( + toResponse(Part.fromFunctionCall("test_function", ImmutableMap.of())), + toResponseWithText("", FinishReason.Known.STOP, metadata)); + + Flowable llmResponses = + Gemini.processRawResponses(rawResponses).filter(Gemini::shouldEmit); + + assertLlmResponses( + llmResponses, + isPartialFunctionCallResponse("test_function"), + isFinalAggregatedFunctionCallResponseWithUsageMetadata(metadata, "test_function")); + } + + @Test + public void functionCallThenEmptyText_doesNotEmitExtraEmptyResponse() { + Flowable rawResponses = + Flowable.just( + toResponse(Part.fromFunctionCall("test_function", ImmutableMap.of())), + toResponseWithText("")); + + Flowable llmResponses = + Gemini.processRawResponses(rawResponses).filter(Gemini::shouldEmit); + + assertLlmResponses(llmResponses, isPartialFunctionCallResponse("test_function")); + } + + @Test + public void textThenFunctionCallThenEmptyTextWithStop_emitsTextThenFunctionCalls() { + Flowable rawResponses = + Flowable.just( + toResponseWithText("Thinking..."), + toResponse(Part.fromFunctionCall("test_function", ImmutableMap.of())), + toResponseWithText("", FinishReason.Known.STOP)); + + Flowable llmResponses = + Gemini.processRawResponses(rawResponses).filter(Gemini::shouldEmit); + + assertLlmResponses( + llmResponses, + isPartialTextResponse("Thinking..."), + isFinalTextResponse("Thinking..."), + isPartialFunctionCallResponse("test_function"), + isFinalAggregatedFunctionCallResponse("test_function")); + } + + // Test cases for the shouldEmit filter applied by generateContent after processRawResponses. + // shouldEmit drops chunks that are empty-text-only unless they carry final metadata (usage + // metadata or finish reason); everything else is forwarded. + // processRawResponses normally already strips empty-text-only chunks, so shouldEmit + // is defense-in-depth, but it must still behave correctly when fed any LlmResponse directly. + + @Test + public void shouldEmit_emptyTextOnlyResponseWithNoMetadata_returnsFalse() { + LlmResponse response = + LlmResponse.builder() + .content(Content.builder().role("model").parts(Part.fromText("")).build()) + .build(); + + assertThat(Gemini.shouldEmit(response)).isFalse(); + } + + @Test + public void shouldEmit_emptyTextOnlyResponseWithFinishReason_returnsTrue() { + LlmResponse response = + LlmResponse.builder() + .content(Content.builder().role("model").parts(Part.fromText("")).build()) + .finishReason(new FinishReason(FinishReason.Known.STOP)) + .build(); + + assertThat(Gemini.shouldEmit(response)).isTrue(); + } + + @Test + public void shouldEmit_emptyTextOnlyResponseWithUsageMetadata_returnsTrue() { + LlmResponse response = + LlmResponse.builder() + .content(Content.builder().role("model").parts(Part.fromText("")).build()) + .usageMetadata(createUsageMetadata(5, 10, 15)) + .build(); + + assertThat(Gemini.shouldEmit(response)).isTrue(); + } + + @Test + public void shouldEmit_nonEmptyTextResponse_returnsTrue() { + LlmResponse response = + LlmResponse.builder() + .content(Content.builder().role("model").parts(Part.fromText("hello")).build()) + .build(); + + assertThat(Gemini.shouldEmit(response)).isTrue(); + } + + @Test + public void shouldEmit_functionCallResponse_returnsTrue() { + LlmResponse response = + LlmResponse.builder() + .content( + Content.builder() + .role("model") + .parts(Part.fromFunctionCall("test_function", ImmutableMap.of())) + .build()) + .build(); + + assertThat(Gemini.shouldEmit(response)).isTrue(); + } + + @Test + public void shouldEmit_contentlessResponse_returnsTrue() { + // A response with no content at all is not an empty-text-only response, so it should pass + // through regardless of metadata. This is the shape emitted by processRawResponses after it + // strips empty-text content while preserving metadata. + LlmResponse response = LlmResponse.builder().build(); + + assertThat(Gemini.shouldEmit(response)).isTrue(); + } + + @Test + public void shouldEmit_multiPartResponseWithEmptyTextPart_returnsTrue() { + // Only single-part empty-text responses are considered "empty-text-only". A multi-part response + // is treated as carrying semantic content and must always pass through. + LlmResponse response = + LlmResponse.builder() + .content( + Content.builder() + .role("model") + .parts(Part.fromText(""), Part.fromText("hello")) + .build()) + .build(); + + assertThat(Gemini.shouldEmit(response)).isTrue(); + } + + // Helper methods for assertions private void assertLlmResponses( Flowable llmResponses, Predicate... predicates) { TestSubscriber testSubscriber = llmResponses.test(); @@ -232,6 +546,66 @@ private static Predicate isFunctionCallResponse() { }; } + private static Predicate isPartialFunctionCallResponse(String expectedToolName) { + return response -> { + assertThat(response.partial()).hasValue(true); + assertThat(response.content().get().parts().get()).hasSize(1); + assertThat(response.content().get().parts().get().get(0).functionCall().get().name()) + .hasValue(expectedToolName); + return true; + }; + } + + private static Predicate isPartialTextAndFunctionCallResponse( + String expectedText, String expectedToolName) { + return response -> { + assertThat(response.partial()).hasValue(true); + assertThat(response.content().get().parts().get()).hasSize(2); + assertThat(response.content().get().parts().get().get(0).text()).hasValue(expectedText); + assertThat(response.content().get().parts().get().get(1).functionCall().get().name()) + .hasValue(expectedToolName); + return true; + }; + } + + private static Predicate isFinalAggregatedFunctionCallResponse( + String... expectedToolNames) { + return response -> { + assertThat(response.partial()).hasValue(false); + assertThat(response.content().get().parts().get()).hasSize(expectedToolNames.length); + for (int i = 0; i < expectedToolNames.length; i++) { + assertThat(response.content().get().parts().get().get(i).functionCall().get().name()) + .hasValue(expectedToolNames[i]); + } + return true; + }; + } + + private static Predicate isFinalAggregatedFunctionCallResponseWithUsageMetadata( + GenerateContentResponseUsageMetadata expectedMetadata, String... expectedToolNames) { + return response -> { + assertThat(response.partial()).hasValue(false); + assertThat(response.content().get().parts().get()).hasSize(expectedToolNames.length); + for (int i = 0; i < expectedToolNames.length; i++) { + assertThat(response.content().get().parts().get().get(i).functionCall().get().name()) + .hasValue(expectedToolNames[i]); + } + assertThat(response.usageMetadata()).hasValue(expectedMetadata); + return true; + }; + } + + private static Predicate isFinalTextResponseWithNoUsageMetadata( + String expectedText) { + return response -> { + assertThat(response.partial()).isEmpty(); + assertThat(GeminiUtil.getPart0FromLlmResponse(response).flatMap(Part::text).orElse("")) + .isEqualTo(expectedText); + assertThat(response.usageMetadata()).isEmpty(); + return true; + }; + } + private static Predicate isEmptyResponse() { return response -> { assertThat(response.partial()).isEmpty(); @@ -302,8 +676,28 @@ private static Predicate isFinalThoughtResponseWithNoUsageMetadata( }; } - // Helper methods to create responses for testing + private static Predicate isFinalThoughtResponseWithUsageMetadataAndSignature( + String expectedText, + GenerateContentResponseUsageMetadata expectedMetadata, + String expectedSignature) { + return response -> { + assertThat(response.partial()).isEmpty(); + assertThat(GeminiUtil.getPart0FromLlmResponse(response).flatMap(Part::text).orElse("")) + .isEqualTo(expectedText); + assertThat(GeminiUtil.getPart0FromLlmResponse(response).flatMap(Part::thought).orElse(false)) + .isTrue(); + assertThat( + GeminiUtil.getPart0FromLlmResponse(response) + .flatMap(Part::thoughtSignature) + .orElse(new byte[0])) + .isEqualTo(expectedSignature.getBytes(StandardCharsets.UTF_8)); + + assertThat(response.usageMetadata()).hasValue(expectedMetadata); + return true; + }; + } + // Helper methods to create responses for testing private GenerateContentResponse toResponseWithText(String text) { return toResponse(Part.fromText(text)); } @@ -316,14 +710,6 @@ private GenerateContentResponse toResponseWithText(String text, FinishReason.Kno .build()); } - private GenerateContentResponse toResponse(Part part) { - return toResponse(Candidate.builder().content(Content.builder().parts(part).build()).build()); - } - - private GenerateContentResponse toResponse(Candidate candidate) { - return GenerateContentResponse.builder().candidates(candidate).build(); - } - private GenerateContentResponse toResponseWithText( String text, GenerateContentResponseUsageMetadata usageMetadata) { return GenerateContentResponse.builder() @@ -349,6 +735,14 @@ private GenerateContentResponse toResponseWithText( .build(); } + private GenerateContentResponse toResponse(Part part) { + return toResponse(Candidate.builder().content(Content.builder().parts(part).build()).build()); + } + + private GenerateContentResponse toResponse(Candidate candidate) { + return GenerateContentResponse.builder().candidates(candidate).build(); + } + private GenerateContentResponse toResponseWithThoughtText( String text, GenerateContentResponseUsageMetadata usageMetadata) { Part thoughtPart = Part.fromText(text).toBuilder().thought(true).build(); diff --git a/tutorials/city-time-weather/src/main/java/com/google/adk/tutorials/CityTimeWeather.java b/tutorials/city-time-weather/src/main/java/com/google/adk/tutorials/CityTimeWeather.java index 18c8f8786..85347c964 100644 --- a/tutorials/city-time-weather/src/main/java/com/google/adk/tutorials/CityTimeWeather.java +++ b/tutorials/city-time-weather/src/main/java/com/google/adk/tutorials/CityTimeWeather.java @@ -31,7 +31,7 @@ public class CityTimeWeather { public static final BaseAgent ROOT_AGENT = LlmAgent.builder() .name("multi_tool_agent") - .model("gemini-2.0-flash-lite") + .model("gemini-3.1-flash-lite") .description("Agent to answer questions about the time and weather in a city.") .instruction( "You are a helpful agent who can answer user questions about the time and weather in"