From 6745bb7b53828d5651027fbbc252dedd07031117 Mon Sep 17 00:00:00 2001 From: Maciej Szwaja Date: Thu, 5 Feb 2026 10:27:46 -0800 Subject: [PATCH] fix: emit multiple LlmResponses in GeminiLlmConnection A single LiveServerMessage is now converted to a series of LlmResponse messages, each corresponding to a different part of the LiveServerMessage. Notably, the UsageMetadata field is now converted to a GenerateContentResponseUsageMetadata and emitted downstream. PiperOrigin-RevId: 866010045 --- .../adk/models/GeminiLlmConnection.java | 139 ++++++++++++----- .../com/google/adk/models/GeminiUtil.java | 20 +++ .../adk/models/GeminiLlmConnectionTest.java | 144 ++++++++++++++++-- 3 files changed, 248 insertions(+), 55 deletions(-) diff --git a/core/src/main/java/com/google/adk/models/GeminiLlmConnection.java b/core/src/main/java/com/google/adk/models/GeminiLlmConnection.java index e8ae485d7..2e1229d0b 100644 --- a/core/src/main/java/com/google/adk/models/GeminiLlmConnection.java +++ b/core/src/main/java/com/google/adk/models/GeminiLlmConnection.java @@ -34,8 +34,11 @@ import com.google.genai.types.LiveServerMessage; import com.google.genai.types.LiveServerToolCall; import com.google.genai.types.Part; +import com.google.genai.types.UsageMetadata; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Flowable; +import io.reactivex.rxjava3.core.Observable; +import io.reactivex.rxjava3.disposables.CompositeDisposable; import io.reactivex.rxjava3.processors.PublishProcessor; import java.net.SocketException; import java.util.List; @@ -65,6 +68,7 @@ public final class GeminiLlmConnection implements BaseLlmConnection { private final CompletableFuture sessionFuture; private final PublishProcessor responseProcessor = PublishProcessor.create(); private final Flowable responseFlowable = responseProcessor.serialize(); + private final CompositeDisposable disposables = new CompositeDisposable(); private final AtomicBoolean closed = new AtomicBoolean(false); /** @@ -120,53
+124,104 @@ private void handleServerMessage(LiveServerMessage message) { logger.debug("Received server message: {}", message.toJson()); - Optional llmResponse = convertToServerResponse(message); - llmResponse.ifPresent(responseProcessor::onNext); + Observable llmResponse = convertToServerResponse(message); + if (!disposables.add( + llmResponse.subscribe(responseProcessor::onNext, responseProcessor::onError))) { + logger.warn( + "disposables container already disposed, the subscription will be disposed immediately"); + } } /** Converts a server message into the standardized LlmResponse format. */ - static Optional convertToServerResponse(LiveServerMessage message) { + static Observable convertToServerResponse(LiveServerMessage message) { + return Observable.create( + emitter -> { + // AtomicBoolean is used to modify state from within lambdas, which + // require captured variables to be effectively final. + final AtomicBoolean handled = new AtomicBoolean(false); + message + .serverContent() + .ifPresent( + serverContent -> { + emitter.onNext(createServerContentResponse(serverContent)); + handled.set(true); + }); + message + .toolCall() + .ifPresent( + toolCall -> { + emitter.onNext(createToolCallResponse(toolCall)); + handled.set(true); + }); + message + .usageMetadata() + .ifPresent( + usageMetadata -> { + logger.debug("Received usage metadata: {}", usageMetadata); + emitter.onNext(createUsageMetadataResponse(usageMetadata)); + handled.set(true); + }); + message + .toolCallCancellation() + .ifPresent( + toolCallCancellation -> { + logger.debug("Received tool call cancellation: {}", toolCallCancellation); + // TODO: implement proper CFC and thus tool call cancellation handling. 
+ handled.set(true); + }); + message + .setupComplete() + .ifPresent( + setupComplete -> { + logger.debug("Received setup complete."); + handled.set(true); + }); + + if (!handled.get()) { + logger.warn("Received unknown or empty server message: {}", message.toJson()); + emitter.onNext(createUnknownMessageResponse()); + } + emitter.onComplete(); + }); + } + + private static LlmResponse createServerContentResponse(LiveServerContent serverContent) { LlmResponse.Builder builder = LlmResponse.builder(); + serverContent.modelTurn().ifPresent(builder::content); + return builder + .partial(serverContent.turnComplete().map(completed -> !completed).orElse(false)) + .turnComplete(serverContent.turnComplete().orElse(false)) + .interrupted(serverContent.interrupted()) + .build(); + } - if (message.serverContent().isPresent()) { - LiveServerContent serverContent = message.serverContent().get(); - serverContent.modelTurn().ifPresent(builder::content); - builder - .partial(serverContent.turnComplete().map(completed -> !completed).orElse(false)) - .turnComplete(serverContent.turnComplete().orElse(false)) - .interrupted(serverContent.interrupted()); - } else if (message.toolCall().isPresent()) { - LiveServerToolCall toolCall = message.toolCall().get(); - toolCall - .functionCalls() - .ifPresent( - calls -> { - for (FunctionCall call : calls) { - builder.content( - Content.builder() - .parts(ImmutableList.of(Part.builder().functionCall(call).build())) - .build()); - } - }); - builder.partial(false).turnComplete(false); - } else if (message.usageMetadata().isPresent()) { - logger.debug("Received usage metadata: {}", message.usageMetadata().get()); - return Optional.empty(); - } else if (message.toolCallCancellation().isPresent()) { - logger.debug("Received tool call cancellation: {}", message.toolCallCancellation().get()); - // TODO: implement proper CFC and thus tool call cancellation handling. 
- return Optional.empty(); - } else if (message.setupComplete().isPresent()) { - logger.debug("Received setup complete."); - return Optional.empty(); - } else { - logger.warn("Received unknown or empty server message: {}", message.toJson()); - builder - .errorCode(new FinishReason("Unknown server message.")) - .errorMessage("Received unknown server message."); - } + private static LlmResponse createToolCallResponse(LiveServerToolCall toolCall) { + LlmResponse.Builder builder = LlmResponse.builder(); + toolCall + .functionCalls() + .ifPresent( + calls -> { + for (FunctionCall call : calls) { + builder.content( + Content.builder() + .parts(ImmutableList.of(Part.builder().functionCall(call).build())) + .build()); + } + }); + return builder.partial(false).turnComplete(false).build(); + } - return Optional.of(builder.build()); + private static LlmResponse createUsageMetadataResponse(UsageMetadata usageMetadata) { + return LlmResponse.builder() + .usageMetadata(GeminiUtil.toGenerateContentResponseUsageMetadata(usageMetadata)) + .build(); + } + + private static LlmResponse createUnknownMessageResponse() { + return LlmResponse.builder() + .errorCode(new FinishReason("Unknown server message.")) + .errorMessage("Received unknown server message.") + .build(); } /** Handles errors that occur *during* the initial connection attempt. 
*/ @@ -281,6 +336,8 @@ private void closeInternal(Throwable throwable) { } else { sessionFuture.cancel(false); } + + disposables.dispose(); } } diff --git a/core/src/main/java/com/google/adk/models/GeminiUtil.java b/core/src/main/java/com/google/adk/models/GeminiUtil.java index 2b95c0ab2..319226d69 100644 --- a/core/src/main/java/com/google/adk/models/GeminiUtil.java +++ b/core/src/main/java/com/google/adk/models/GeminiUtil.java @@ -24,7 +24,9 @@ import com.google.genai.types.Blob; import com.google.genai.types.Content; import com.google.genai.types.FileData; +import com.google.genai.types.GenerateContentResponseUsageMetadata; import com.google.genai.types.Part; +import com.google.genai.types.UsageMetadata; import java.util.List; import java.util.Optional; import java.util.stream.Stream; @@ -224,4 +226,22 @@ public static List stripThoughts(List originalContents) { }) .collect(toImmutableList()); } + + public static GenerateContentResponseUsageMetadata toGenerateContentResponseUsageMetadata( + UsageMetadata usageMetadata) { + GenerateContentResponseUsageMetadata.Builder builder = + GenerateContentResponseUsageMetadata.builder(); + usageMetadata.promptTokenCount().ifPresent(builder::promptTokenCount); + usageMetadata.cachedContentTokenCount().ifPresent(builder::cachedContentTokenCount); + usageMetadata.responseTokenCount().ifPresent(builder::candidatesTokenCount); + usageMetadata.toolUsePromptTokenCount().ifPresent(builder::toolUsePromptTokenCount); + usageMetadata.thoughtsTokenCount().ifPresent(builder::thoughtsTokenCount); + usageMetadata.totalTokenCount().ifPresent(builder::totalTokenCount); + usageMetadata.promptTokensDetails().ifPresent(builder::promptTokensDetails); + usageMetadata.cacheTokensDetails().ifPresent(builder::cacheTokensDetails); + usageMetadata.responseTokensDetails().ifPresent(builder::candidatesTokensDetails); + usageMetadata.toolUsePromptTokensDetails().ifPresent(builder::toolUsePromptTokensDetails); + 
usageMetadata.trafficType().ifPresent(builder::trafficType); + return builder.build(); + } } diff --git a/core/src/test/java/com/google/adk/models/GeminiLlmConnectionTest.java b/core/src/test/java/com/google/adk/models/GeminiLlmConnectionTest.java index 5d70cc449..a3ac09fe5 100644 --- a/core/src/test/java/com/google/adk/models/GeminiLlmConnectionTest.java +++ b/core/src/test/java/com/google/adk/models/GeminiLlmConnectionTest.java @@ -21,6 +21,7 @@ import com.google.common.collect.ImmutableList; import com.google.genai.types.Content; import com.google.genai.types.FunctionCall; +import com.google.genai.types.GenerateContentResponseUsageMetadata; import com.google.genai.types.LiveServerContent; import com.google.genai.types.LiveServerMessage; import com.google.genai.types.LiveServerSetupComplete; @@ -28,6 +29,8 @@ import com.google.genai.types.LiveServerToolCallCancellation; import com.google.genai.types.Part; import com.google.genai.types.UsageMetadata; +import io.reactivex.rxjava3.observers.TestObserver; +import java.util.List; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -45,8 +48,13 @@ public void convertToServerResponse_withInterruptedTrue_mapsInterruptedField() { .build(); LiveServerMessage message = LiveServerMessage.builder().serverContent(serverContent).build(); + TestObserver testObserver = new TestObserver<>(); - LlmResponse response = GeminiLlmConnection.convertToServerResponse(message).get(); + GeminiLlmConnection.convertToServerResponse(message).subscribe(testObserver); + + testObserver.assertValueCount(1); + testObserver.assertComplete(); + LlmResponse response = testObserver.values().get(0); assertThat(response.content()).isPresent(); assertThat(response.content().get().text()).isEqualTo("Model response"); @@ -66,8 +74,13 @@ public void convertToServerResponse_withInterruptedFalse_mapsInterruptedField() LiveServerMessage message = LiveServerMessage.builder().serverContent(serverContent).build(); - 
LlmResponse response = GeminiLlmConnection.convertToServerResponse(message).get(); + TestObserver testObserver = new TestObserver<>(); + + GeminiLlmConnection.convertToServerResponse(message).subscribe(testObserver); + testObserver.assertValueCount(1); + testObserver.assertComplete(); + LlmResponse response = testObserver.values().get(0); assertThat(response.interrupted()).hasValue(false); assertThat(response.turnComplete()).hasValue(false); } @@ -82,8 +95,13 @@ public void convertToServerResponse_withoutInterruptedField_mapsEmptyOptional() LiveServerMessage message = LiveServerMessage.builder().serverContent(serverContent).build(); - LlmResponse response = GeminiLlmConnection.convertToServerResponse(message).get(); + TestObserver testObserver = new TestObserver<>(); + + GeminiLlmConnection.convertToServerResponse(message).subscribe(testObserver); + testObserver.assertValueCount(1); + testObserver.assertComplete(); + LlmResponse response = testObserver.values().get(0); assertThat(response.interrupted()).isEmpty(); assertThat(response.turnComplete()).hasValue(true); } @@ -98,8 +116,13 @@ public void convertToServerResponse_withTurnCompleteTrue_mapsPartialFalse() { LiveServerMessage message = LiveServerMessage.builder().serverContent(serverContent).build(); - LlmResponse response = GeminiLlmConnection.convertToServerResponse(message).get(); + TestObserver testObserver = new TestObserver<>(); + GeminiLlmConnection.convertToServerResponse(message).subscribe(testObserver); + + testObserver.assertValueCount(1); + testObserver.assertComplete(); + LlmResponse response = testObserver.values().get(0); assertThat(response.partial()).hasValue(false); assertThat(response.turnComplete()).hasValue(true); } @@ -114,8 +137,13 @@ public void convertToServerResponse_withTurnCompleteFalse_mapsPartialTrue() { LiveServerMessage message = LiveServerMessage.builder().serverContent(serverContent).build(); - LlmResponse response = GeminiLlmConnection.convertToServerResponse(message).get(); 
+ TestObserver testObserver = new TestObserver<>(); + + GeminiLlmConnection.convertToServerResponse(message).subscribe(testObserver); + testObserver.assertValueCount(1); + testObserver.assertComplete(); + LlmResponse response = testObserver.values().get(0); assertThat(response.partial()).hasValue(true); assertThat(response.turnComplete()).hasValue(false); } @@ -128,8 +156,13 @@ public void convertToServerResponse_withToolCall_mapsContentWithFunctionCall() { LiveServerMessage message = LiveServerMessage.builder().toolCall(toolCall).build(); - LlmResponse response = GeminiLlmConnection.convertToServerResponse(message).get(); + TestObserver testObserver = new TestObserver<>(); + GeminiLlmConnection.convertToServerResponse(message).subscribe(testObserver); + + testObserver.assertValueCount(1); + testObserver.assertComplete(); + LlmResponse response = testObserver.values().get(0); assertThat(response.content()).isPresent(); assertThat(response.content().get().parts()).isPresent(); assertThat(response.content().get().parts().get()).hasSize(1); @@ -139,40 +172,123 @@ public void convertToServerResponse_withToolCall_mapsContentWithFunctionCall() { } @Test - public void convertToServerResponse_withUsageMetadata_returnsEmptyOptional() { + public void convertToServerResponse_withUsageMetadata_mapsGenerateResponseUsageMetadata() { LiveServerMessage message = - LiveServerMessage.builder().usageMetadata(UsageMetadata.builder().build()).build(); + LiveServerMessage.builder() + .usageMetadata( + UsageMetadata.builder() + .promptTokenCount(10) + .responseTokenCount(20) + .totalTokenCount(30) + .build()) + .build(); - assertThat(GeminiLlmConnection.convertToServerResponse(message)).isEmpty(); + TestObserver testObserver = new TestObserver<>(); + + GeminiLlmConnection.convertToServerResponse(message).subscribe(testObserver); + testObserver.assertValueCount(1); + testObserver.assertComplete(); + LlmResponse response = testObserver.values().get(0); + 
assertThat(response.usageMetadata()).isPresent(); + GenerateContentResponseUsageMetadata expectedUsageMetadata = + GenerateContentResponseUsageMetadata.builder() + .promptTokenCount(10) + .candidatesTokenCount(20) + .totalTokenCount(30) + .build(); + assertThat(response.usageMetadata()).hasValue(expectedUsageMetadata); } @Test - public void convertToServerResponse_withToolCallCancellation_returnsEmptyOptional() { + public void convertToServerResponse_withToolCallCancellation_returnsNoValues() { LiveServerMessage message = LiveServerMessage.builder() .toolCallCancellation(LiveServerToolCallCancellation.builder().build()) .build(); - assertThat(GeminiLlmConnection.convertToServerResponse(message)).isEmpty(); + TestObserver testObserver = new TestObserver<>(); + + GeminiLlmConnection.convertToServerResponse(message).subscribe(testObserver); + testObserver.assertNoValues(); + testObserver.assertComplete(); } @Test - public void convertToServerResponse_withSetupComplete_returnsEmptyOptional() { + public void convertToServerResponse_withSetupComplete_returnsNoValues() { LiveServerMessage message = LiveServerMessage.builder() .setupComplete(LiveServerSetupComplete.builder().build()) .build(); - assertThat(GeminiLlmConnection.convertToServerResponse(message)).isEmpty(); + TestObserver testObserver = new TestObserver<>(); + + GeminiLlmConnection.convertToServerResponse(message).subscribe(testObserver); + + testObserver.assertNoValues(); + testObserver.assertComplete(); } @Test public void convertToServerResponse_withUnknownMessage_returnsErrorResponse() { LiveServerMessage message = LiveServerMessage.builder().build(); - LlmResponse response = GeminiLlmConnection.convertToServerResponse(message).get(); + TestObserver testObserver = new TestObserver<>(); + + GeminiLlmConnection.convertToServerResponse(message).subscribe(testObserver); + testObserver.assertValueCount(1); + testObserver.assertComplete(); + LlmResponse response = testObserver.values().get(0); 
assertThat(response.errorCode()).isPresent(); assertThat(response.errorMessage()).hasValue("Received unknown server message."); } + + @Test + public void convertToServerResponse_withContentAndUsageMetadata_emitsMultiple() { + LiveServerContent serverContent = + LiveServerContent.builder() + .modelTurn(Content.fromParts(Part.fromText("Model response"))) + .turnComplete(true) + .build(); + + UsageMetadata usageMetadata = + UsageMetadata.builder() + .promptTokenCount(10) + .responseTokenCount(20) + .totalTokenCount(30) + .build(); + + LiveServerMessage message = + LiveServerMessage.builder() + .serverContent(serverContent) + .usageMetadata(usageMetadata) + .build(); + + TestObserver testObserver = new TestObserver<>(); + + GeminiLlmConnection.convertToServerResponse(message).subscribe(testObserver); + + testObserver.assertValueCount(2); + testObserver.assertComplete(); + + List responses = testObserver.values(); + + // Check for ServerContent response + LlmResponse contentResponse = responses.get(0); + assertThat(contentResponse.content()).isPresent(); + assertThat(contentResponse.content().get().text()).isEqualTo("Model response"); + assertThat(contentResponse.usageMetadata()).isEmpty(); + + // Check for UsageMetadata response + LlmResponse usageResponse = responses.get(1); + assertThat(usageResponse.content()).isEmpty(); + assertThat(usageResponse.usageMetadata()).isPresent(); + GenerateContentResponseUsageMetadata expectedUsageMetadata = + GenerateContentResponseUsageMetadata.builder() + .promptTokenCount(10) + .candidatesTokenCount(20) + .totalTokenCount(30) + .build(); + assertThat(usageResponse.usageMetadata()).hasValue(expectedUsageMetadata); + } }