core/src/main/java/com/google/adk/models/GeminiLlmConnection.java: 139 changes (98 additions, 41 deletions)
@@ -34,8 +34,11 @@
import com.google.genai.types.LiveServerMessage;
import com.google.genai.types.LiveServerToolCall;
import com.google.genai.types.Part;
import com.google.genai.types.UsageMetadata;
import io.reactivex.rxjava3.core.Completable;
import io.reactivex.rxjava3.core.Flowable;
import io.reactivex.rxjava3.core.Observable;
import io.reactivex.rxjava3.disposables.CompositeDisposable;
import io.reactivex.rxjava3.processors.PublishProcessor;
import java.net.SocketException;
import java.util.List;
@@ -65,6 +68,7 @@ public final class GeminiLlmConnection implements BaseLlmConnection {
private final CompletableFuture<AsyncSession> sessionFuture;
private final PublishProcessor<LlmResponse> responseProcessor = PublishProcessor.create();
private final Flowable<LlmResponse> responseFlowable = responseProcessor.serialize();
private final CompositeDisposable disposables = new CompositeDisposable();
private final AtomicBoolean closed = new AtomicBoolean(false);

/**
Expand Down Expand Up @@ -120,53 +124,104 @@ private void handleServerMessage(LiveServerMessage message) {

logger.debug("Received server message: {}", message.toJson());

-    Optional<LlmResponse> llmResponse = convertToServerResponse(message);
-    llmResponse.ifPresent(responseProcessor::onNext);
Observable<LlmResponse> llmResponse = convertToServerResponse(message);
if (!disposables.add(
llmResponse.subscribe(responseProcessor::onNext, responseProcessor::onError))) {
logger.warn(
"disposables container already disposed, the subscription will be disposed immediately");
}
}

/** Converts a server message into the standardized LlmResponse format. */
-  static Optional<LlmResponse> convertToServerResponse(LiveServerMessage message) {
static Observable<LlmResponse> convertToServerResponse(LiveServerMessage message) {
return Observable.create(
emitter -> {
// AtomicBoolean is used to modify state from within lambdas, which
// require captured variables to be effectively final.
final AtomicBoolean handled = new AtomicBoolean(false);
message
.serverContent()
.ifPresent(
serverContent -> {
emitter.onNext(createServerContentResponse(serverContent));
handled.set(true);
});
message
.toolCall()
.ifPresent(
toolCall -> {
emitter.onNext(createToolCallResponse(toolCall));
handled.set(true);
});
message
.usageMetadata()
.ifPresent(
usageMetadata -> {
logger.debug("Received usage metadata: {}", usageMetadata);
emitter.onNext(createUsageMetadataResponse(usageMetadata));
handled.set(true);
});
message
.toolCallCancellation()
.ifPresent(
toolCallCancellation -> {
logger.debug("Received tool call cancellation: {}", toolCallCancellation);
// TODO: implement proper CFC and thus tool call cancellation handling.
handled.set(true);
});
message
.setupComplete()
.ifPresent(
setupComplete -> {
logger.debug("Received setup complete.");
handled.set(true);
});

if (!handled.get()) {
logger.warn("Received unknown or empty server message: {}", message.toJson());
emitter.onNext(createUnknownMessageResponse());
}
emitter.onComplete();
});
}

private static LlmResponse createServerContentResponse(LiveServerContent serverContent) {
LlmResponse.Builder builder = LlmResponse.builder();
serverContent.modelTurn().ifPresent(builder::content);
return builder
.partial(serverContent.turnComplete().map(completed -> !completed).orElse(false))
.turnComplete(serverContent.turnComplete().orElse(false))
.interrupted(serverContent.interrupted())
.build();
}

-    if (message.serverContent().isPresent()) {
-      LiveServerContent serverContent = message.serverContent().get();
-      serverContent.modelTurn().ifPresent(builder::content);
-      builder
-          .partial(serverContent.turnComplete().map(completed -> !completed).orElse(false))
-          .turnComplete(serverContent.turnComplete().orElse(false))
-          .interrupted(serverContent.interrupted());
-    } else if (message.toolCall().isPresent()) {
-      LiveServerToolCall toolCall = message.toolCall().get();
-      toolCall
-          .functionCalls()
-          .ifPresent(
-              calls -> {
-                for (FunctionCall call : calls) {
-                  builder.content(
-                      Content.builder()
-                          .parts(ImmutableList.of(Part.builder().functionCall(call).build()))
-                          .build());
-                }
-              });
-      builder.partial(false).turnComplete(false);
-    } else if (message.usageMetadata().isPresent()) {
-      logger.debug("Received usage metadata: {}", message.usageMetadata().get());
-      return Optional.empty();
-    } else if (message.toolCallCancellation().isPresent()) {
-      logger.debug("Received tool call cancellation: {}", message.toolCallCancellation().get());
-      // TODO: implement proper CFC and thus tool call cancellation handling.
-      return Optional.empty();
-    } else if (message.setupComplete().isPresent()) {
-      logger.debug("Received setup complete.");
-      return Optional.empty();
-    } else {
-      logger.warn("Received unknown or empty server message: {}", message.toJson());
-      builder
-          .errorCode(new FinishReason("Unknown server message."))
-          .errorMessage("Received unknown server message.");
-    }
private static LlmResponse createToolCallResponse(LiveServerToolCall toolCall) {
LlmResponse.Builder builder = LlmResponse.builder();
toolCall
.functionCalls()
.ifPresent(
calls -> {
for (FunctionCall call : calls) {
builder.content(
Content.builder()
.parts(ImmutableList.of(Part.builder().functionCall(call).build()))
.build());
}
});
return builder.partial(false).turnComplete(false).build();
}

-    return Optional.of(builder.build());
private static LlmResponse createUsageMetadataResponse(UsageMetadata usageMetadata) {
return LlmResponse.builder()
.usageMetadata(GeminiUtil.toGenerateContentResponseUsageMetadata(usageMetadata))
.build();
}

private static LlmResponse createUnknownMessageResponse() {
return LlmResponse.builder()
.errorCode(new FinishReason("Unknown server message."))
.errorMessage("Received unknown server message.")
.build();
}

/** Handles errors that occur *during* the initial connection attempt. */
@@ -281,6 +336,8 @@ private void closeInternal(Throwable throwable) {
} else {
sessionFuture.cancel(false);
}

disposables.dispose();
}
}

core/src/main/java/com/google/adk/models/GeminiUtil.java: 20 changes (20 additions, 0 deletions)
@@ -24,7 +24,9 @@
import com.google.genai.types.Blob;
import com.google.genai.types.Content;
import com.google.genai.types.FileData;
import com.google.genai.types.GenerateContentResponseUsageMetadata;
import com.google.genai.types.Part;
import com.google.genai.types.UsageMetadata;
import java.util.List;
import java.util.Optional;
import java.util.stream.Stream;
@@ -224,4 +226,22 @@ public static List<Content> stripThoughts(List<Content> originalContents) {
})
.collect(toImmutableList());
}

public static GenerateContentResponseUsageMetadata toGenerateContentResponseUsageMetadata(
UsageMetadata usageMetadata) {
GenerateContentResponseUsageMetadata.Builder builder =
GenerateContentResponseUsageMetadata.builder();
usageMetadata.promptTokenCount().ifPresent(builder::promptTokenCount);
usageMetadata.cachedContentTokenCount().ifPresent(builder::cachedContentTokenCount);
usageMetadata.responseTokenCount().ifPresent(builder::candidatesTokenCount);
usageMetadata.toolUsePromptTokenCount().ifPresent(builder::toolUsePromptTokenCount);
usageMetadata.thoughtsTokenCount().ifPresent(builder::thoughtsTokenCount);
usageMetadata.totalTokenCount().ifPresent(builder::totalTokenCount);
usageMetadata.promptTokensDetails().ifPresent(builder::promptTokensDetails);
usageMetadata.cacheTokensDetails().ifPresent(builder::cacheTokensDetails);
usageMetadata.responseTokensDetails().ifPresent(builder::candidatesTokensDetails);
usageMetadata.toolUsePromptTokensDetails().ifPresent(builder::toolUsePromptTokensDetails);
usageMetadata.trafficType().ifPresent(builder::trafficType);
return builder.build();
}
}