Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 72 additions & 2 deletions core/src/main/java/com/google/adk/tools/LoadArtifactsTool.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@
import com.google.adk.models.LlmRequest;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.genai.types.Blob;
import com.google.genai.types.Content;
import com.google.genai.types.FunctionDeclaration;
import com.google.genai.types.FunctionResponse;
Expand All @@ -31,7 +33,9 @@
import io.reactivex.rxjava3.core.Completable;
import io.reactivex.rxjava3.core.Observable;
import io.reactivex.rxjava3.core.Single;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
Expand All @@ -55,6 +59,12 @@
*/
public final class LoadArtifactsTool extends BaseTool {
public static final LoadArtifactsTool INSTANCE = new LoadArtifactsTool();
private static final ImmutableList<String> GEMINI_SUPPORTED_INLINE_MIME_PREFIXES =
ImmutableList.of("image/", "audio/", "video/");
private static final ImmutableSet<String> GEMINI_SUPPORTED_INLINE_MIME_TYPES =
ImmutableSet.of("application/pdf");
private static final ImmutableSet<String> TEXT_LIKE_MIME_TYPES =
ImmutableSet.of("application/csv", "application/json", "application/xml");

public LoadArtifactsTool() {
super("load_artifacts", "Loads the artifacts and adds them to the session.");
Expand Down Expand Up @@ -177,15 +187,75 @@ private Completable loadAndAppendIndividualArtifact(
appendArtifactToLlmRequest(
llmRequestBuilder,
"Artifact " + artifactName + " is:",
artifactName,
actualArtifact)));
}

private void appendArtifactToLlmRequest(
LlmRequest.Builder llmRequestBuilder, String prefix, Part artifact) {
LlmRequest.Builder llmRequestBuilder, String prefix, String artifactName, Part artifact) {
llmRequestBuilder.contents(
ImmutableList.<Content>builder()
.addAll(llmRequestBuilder.build().contents())
.add(Content.fromParts(Part.fromText(prefix), artifact))
.add(Content.fromParts(Part.fromText(prefix), asSafePartForLlm(artifact, artifactName)))
.build());
}

private static String normalizeMimeType(String mimeType) {
if (mimeType == null) {
return "";
}
int separatorIndex = mimeType.indexOf(';');
if (separatorIndex >= 0) {
mimeType = mimeType.substring(0, separatorIndex);
}
return mimeType.trim();
}

private static boolean isInlineMimeTypeSupported(String mimeType) {
String normalized = normalizeMimeType(mimeType);
if (normalized.isEmpty()) {
return false;
}
if (GEMINI_SUPPORTED_INLINE_MIME_TYPES.contains(normalized)) {
return true;
}
return GEMINI_SUPPORTED_INLINE_MIME_PREFIXES.stream().anyMatch(normalized::startsWith);
}

private static Part asSafePartForLlm(Part artifact, String artifactName) {
Optional<Blob> inlineData = artifact.inlineData();
if (inlineData.isEmpty()) {
return artifact;
}

Blob blob = inlineData.get();
if (isInlineMimeTypeSupported(blob.mimeType().orElse(null))) {
return artifact;
}

String mimeType = normalizeMimeType(blob.mimeType().orElse(null));
if (mimeType.isEmpty()) {
mimeType = "application/octet-stream";
}

Optional<byte[]> data = blob.data();
if (data.isEmpty()) {
return Part.fromText(
String.format(
"[Artifact: %s, type: %s. No inline data was provided.]", artifactName, mimeType));
}

if (mimeType.startsWith("text/") || TEXT_LIKE_MIME_TYPES.contains(mimeType)) {
return Part.fromText(new String(data.get(), StandardCharsets.UTF_8));
}

double sizeKb = data.get().length / 1024.0;
return Part.fromText(
String.format(
Locale.US,
"[Binary artifact: %s, type: %s, size: %.1f KB. Content cannot be displayed inline.]",
artifactName,
mimeType,
sizeKb));
}
}
116 changes: 116 additions & 0 deletions core/src/test/java/com/google/adk/tools/LoadArtifactsToolTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.nullable;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.spy;
Expand All @@ -17,13 +18,15 @@
import com.google.adk.sessions.Session;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.genai.types.Blob;
import com.google.genai.types.Content;
import com.google.genai.types.FunctionDeclaration;
import com.google.genai.types.FunctionResponse;
import com.google.genai.types.Part;
import com.google.genai.types.Schema;
import io.reactivex.rxjava3.core.Maybe;
import io.reactivex.rxjava3.core.Single;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Map;
import java.util.Optional;
Expand Down Expand Up @@ -218,4 +221,117 @@ public void processLlmRequest_artifactsInContext_withOtherFunctionCall_doesNotLo
.loadArtifact(anyString(), anyString(), anyString(), anyString(), anyInt());
assertThat(finalRequest.contents()).containsExactly(functionCallContent);
}

@Test
public void processLlmRequest_unsupportedTextLikeMime_convertsToText() {
String artifactName = "data.csv";
String csvContent = "col1,col2\n1,2\n";
Part artifactPart =
processLoadArtifactRequest(
artifactName,
Part.fromBytes(
csvContent.getBytes(StandardCharsets.UTF_8), "application/csv; charset=utf-8"));

assertThat(artifactPart.inlineData()).isEmpty();
assertThat(artifactPart.text()).hasValue(csvContent);
}

@Test
public void processLlmRequest_supportedMime_keepsInlineData() {
String artifactName = "file.pdf";
byte[] pdfBytes = "%PDF-1.4".getBytes(StandardCharsets.UTF_8);
Part artifactPart =
processLoadArtifactRequest(artifactName, Part.fromBytes(pdfBytes, "application/pdf"));

assertThat(artifactPart.inlineData()).isPresent();
assertThat(artifactPart.inlineData().get().mimeType()).hasValue("application/pdf");
assertThat(artifactPart.inlineData().get().data().get()).isEqualTo(pdfBytes);
}

@Test
public void processLlmRequest_unsupportedBinaryMime_convertsToPlaceholder() {
String artifactName = "slides.pptx";
Part artifactPart =
processLoadArtifactRequest(
artifactName,
Part.fromBytes(
new byte[] {1, 2, 3},
"application/vnd.openxmlformats-officedocument.presentationml.presentation"));

assertThat(artifactPart.inlineData()).isEmpty();
assertThat(artifactPart.text())
.hasValue(
"[Binary artifact: slides.pptx, type:"
+ " application/vnd.openxmlformats-officedocument.presentationml.presentation,"
+ " size: 0.0 KB. Content cannot be displayed inline.]");
}

@Test
public void processLlmRequest_unsupportedMimeWithoutInlineData_convertsToNoDataPlaceholder() {
String artifactName = "empty.bin";
Part artifactPart =
processLoadArtifactRequest(
artifactName,
Part.builder()
.inlineData(Blob.builder().mimeType("application/octet-stream").build())
.build());

assertThat(artifactPart.inlineData()).isEmpty();
assertThat(artifactPart.text())
.hasValue(
"[Artifact: empty.bin, type: application/octet-stream."
+ " No inline data was provided.]");
}

@Test
public void processLlmRequest_emptyMime_defaultsToOctetStream() {
String artifactName = "unknown";
Part artifactPart =
processLoadArtifactRequest(
artifactName,
Part.fromBytes(new byte[] {(byte) 0xDE, (byte) 0xAD, (byte) 0xBE, (byte) 0xEF}, ""));

assertThat(artifactPart.inlineData()).isEmpty();
assertThat(artifactPart.text())
.hasValue(
"[Binary artifact: unknown, type: application/octet-stream,"
+ " size: 0.0 KB. Content cannot be displayed inline.]");
}

private Part processLoadArtifactRequest(String artifactName, Part loadedArtifactPart) {
ImmutableList<String> availableArtifacts = ImmutableList.of(artifactName);
ImmutableList<String> artifactsToLoad = ImmutableList.of(artifactName);

FunctionResponse functionResponse =
FunctionResponse.builder()
.name("load_artifacts")
.response(ImmutableMap.of("artifact_names", artifactsToLoad))
.build();
Content functionCallContent =
Content.builder()
.role("model")
.parts(
ImmutableList.of(
Part.fromFunctionResponse(
functionResponse.name().get(), functionResponse.response().get())))
.build();
llmRequestBuilder.contents(ImmutableList.of(functionCallContent));

ToolContext spiedToolContext = spy(ToolContext.builder(mockInvocationContext).build());
doReturn(Single.just(availableArtifacts)).when(spiedToolContext).listArtifacts();
doReturn(Maybe.just(loadedArtifactPart)).when(spiedToolContext).loadArtifact(artifactName);

loadArtifactsTool.processLlmRequest(llmRequestBuilder, spiedToolContext).blockingAwait();
verify(spiedToolContext).loadArtifact(artifactName);

LlmRequest finalRequest = llmRequestBuilder.build();
assertThat(finalRequest.contents()).hasSize(2);
Content appendedContent = finalRequest.contents().get(1);
assertThat(appendedContent.role()).hasValue("user");
assertThat(appendedContent.parts()).isPresent();
assertThat(appendedContent.parts().get()).hasSize(2);
assertThat(appendedContent.parts().get().get(0).text())
.hasValue("Artifact " + artifactName + " is:");
return appendedContent.parts().get().get(1);
}
}