From e5c90f8d8aa3d5839316d7447ec2b13be337d642 Mon Sep 17 00:00:00 2001 From: Rodrigo Lazo Paz Date: Tue, 31 Mar 2026 00:28:51 -0400 Subject: [PATCH] [AI] Add integration tests for template models Introduced integration tests for `TemplateGenerativeModel` to validate its functionalities, with Chat features being tested in the `ChatTemplateIntegrationTests` file Updated `AIModels` to include `vertexAITemplateModel` and `googleAITemplateModel` for use in these tests. --- .../kotlin/com/google/firebase/ai/AIModels.kt | 16 +++ .../ai/ChatTemplateIntegrationTests.kt | 113 ++++++++++++++++++ .../firebase/ai/TemplateIntegrationTests.kt | 82 +++++++++++++ 3 files changed, 211 insertions(+) create mode 100644 ai-logic/firebase-ai/src/androidTest/kotlin/com/google/firebase/ai/ChatTemplateIntegrationTests.kt create mode 100644 ai-logic/firebase-ai/src/androidTest/kotlin/com/google/firebase/ai/TemplateIntegrationTests.kt diff --git a/ai-logic/firebase-ai/src/androidTest/kotlin/com/google/firebase/ai/AIModels.kt b/ai-logic/firebase-ai/src/androidTest/kotlin/com/google/firebase/ai/AIModels.kt index b2627e2b8e4..4e6bb018dac 100644 --- a/ai-logic/firebase-ai/src/androidTest/kotlin/com/google/firebase/ai/AIModels.kt +++ b/ai-logic/firebase-ai/src/androidTest/kotlin/com/google/firebase/ai/AIModels.kt @@ -18,7 +18,9 @@ package com.google.firebase.ai import androidx.test.platform.app.InstrumentationRegistry import com.google.firebase.FirebaseApp import com.google.firebase.ai.type.GenerativeBackend +import com.google.firebase.ai.type.PublicPreviewAPI +@OptIn(PublicPreviewAPI::class) class AIModels { companion object { @@ -31,6 +33,8 @@ class AIModels { lateinit var vertexAIFlashLiteModel: GenerativeModel lateinit var googleAIFlashModel: GenerativeModel lateinit var googleAIFlashLiteModel: GenerativeModel + lateinit var vertexAITemplateModel: TemplateGenerativeModel + lateinit var googleAITemplateModel: TemplateGenerativeModel /** Returns a list of general purpose models to test */ fun getModels(): List { @@ -45,6 +49,14 @@ class AIModels { ) } + /** Returns a list of template models to test */ + fun getTemplateModels(): List { + if (app == null) { + setup() + } + return listOf(vertexAITemplateModel, googleAITemplateModel) + } + fun app(): FirebaseApp { if (app == null) { setup() @@ -75,6 +87,10 @@ class AIModels { .generativeModel( modelName = "gemini-2.5-flash-lite", ) + vertexAITemplateModel = + FirebaseAI.getInstance(app!!, GenerativeBackend.vertexAI()).templateGenerativeModel() + googleAITemplateModel = + FirebaseAI.getInstance(app!!, GenerativeBackend.googleAI()).templateGenerativeModel() } } } diff --git a/ai-logic/firebase-ai/src/androidTest/kotlin/com/google/firebase/ai/ChatTemplateIntegrationTests.kt b/ai-logic/firebase-ai/src/androidTest/kotlin/com/google/firebase/ai/ChatTemplateIntegrationTests.kt new file mode 100644 index 00000000000..2971b5fc32c --- /dev/null +++ b/ai-logic/firebase-ai/src/androidTest/kotlin/com/google/firebase/ai/ChatTemplateIntegrationTests.kt @@ -0,0 +1,113 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.firebase.ai + +import com.google.firebase.ai.AIModels.Companion.getTemplateModels +import com.google.firebase.ai.type.PublicPreviewAPI +import com.google.firebase.ai.type.content +import io.kotest.matchers.shouldBe +import io.kotest.matchers.string.shouldContainIgnoringCase +import kotlinx.coroutines.flow.toList +import kotlinx.coroutines.runBlocking +import org.junit.Test + +@OptIn(PublicPreviewAPI::class) +class ChatTemplateIntegrationTests { + /* + * Template used in these tests + * + * model: "" + * input: + * schema: + * customerName: string, the name of the customer + * topic: string, problem to solve + * config: + * temperature: 0.1 + * topK: 10 + * topP: 0.8 + * + * ----- + * {{role "system"}} + * You're a customer service agent, but you've been trained to only respond in very short + * sentences. Be succinct, and to the point. No more than 5 words per response + * + * {{role "user"}} + * + * Hello, {{customerName}} + * + * Let's talk about {{topic}} + * {{history}} + */ + private val templateId = "chat-test-template" + + private val customerName = "John Doe" + + private val topic = "Firebase" + private val inputs = mapOf("customerName" to customerName, "topic" to topic) + + @Test + fun testTemplateChat_sendMessage() { + for (model in getTemplateModels()) { + runBlocking { + val chat = model.startChat(templateId, inputs) + val response = chat.sendMessage("which number is higher, one or ten?") + + response.candidates.isNotEmpty() shouldBe true + response.text shouldContainIgnoringCase "ten" + + chat.history.size shouldBe 2 + } + } + } + + @Test + fun testTemplateChat_sendMessageStream() { + for (model in getTemplateModels()) { + runBlocking { + val chat = model.startChat(templateId, inputs) + val responses = chat.sendMessageStream("which number is higher, one or ten?").toList() + responses.isNotEmpty() shouldBe true + responses.joinToString { it.text ?: "" } shouldContainIgnoringCase "ten" + chat.history.size shouldBe 2 + } + } + } + + @Test + fun testTemplateChat_withHistory() { + for (model in getTemplateModels()) { + runBlocking { + val history = + listOf( + content("user") { text("which number is higher, one or ten?") }, + content("model") { text("Ten.") } + ) + val chat = model.startChat(templateId, inputs, history) + chat.history.size shouldBe 2 + val response = + chat.sendMessage( + "Please concatenate them both, first the smaller one, then the bigger one." + ) + + response.candidates.isNotEmpty() shouldBe true + response.text shouldContainIgnoringCase "oneten" + + chat.history.size shouldBe 4 + } + } + } +} diff --git a/ai-logic/firebase-ai/src/androidTest/kotlin/com/google/firebase/ai/TemplateIntegrationTests.kt b/ai-logic/firebase-ai/src/androidTest/kotlin/com/google/firebase/ai/TemplateIntegrationTests.kt new file mode 100644 index 00000000000..6f9b595aeb9 --- /dev/null +++ b/ai-logic/firebase-ai/src/androidTest/kotlin/com/google/firebase/ai/TemplateIntegrationTests.kt @@ -0,0 +1,82 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.firebase.ai + +import com.google.firebase.ai.AIModels.Companion.getTemplateModels +import com.google.firebase.ai.type.PublicPreviewAPI +import io.kotest.matchers.collections.shouldNotBeEmpty +import io.kotest.matchers.string.shouldContainIgnoringCase +import io.kotest.matchers.string.shouldNotBeEmpty +import kotlinx.coroutines.flow.toList +import kotlinx.coroutines.runBlocking +import org.junit.Test + +@OptIn(PublicPreviewAPI::class) +class TemplateIntegrationTests { + /* + * Template used in these tests + * + * model: "" + * input: + * schema: + * customerName: string, the name of the customer + * topic: string, problem to solve + * config: + * temperature: 0.1 + * topK: 10 + * topP: 0.8 + * + * ----- + * + * Repeat back with "{{customerName}} - {{topic}}" and just that + */ + private val templateId = "test-template" + + private val customerName = "John Doe" + + private val topic = "Firebase" + private val inputs = mapOf("customerName" to customerName, "topic" to topic) + + @Test + fun testTemplateGenerateContent() { + for (model in getTemplateModels()) { + runBlocking { + val response = model.generateContent(templateId, inputs) + + response.candidates.shouldNotBeEmpty() + response.text shouldContainIgnoringCase customerName + response.text shouldContainIgnoringCase topic + } + } + } + + @Test + fun testTemplateGenerateContentStream() { + for (model in getTemplateModels()) { + runBlocking { + val responses = model.generateContentStream(templateId, inputs).toList() + responses + .joinToString { it.text ?: "" } + .lowercase() + .let { + it shouldContainIgnoringCase customerName + it shouldContainIgnoringCase topic + } + } + } + } +}