diff --git a/ai-logic/firebase-ai/CHANGELOG.md b/ai-logic/firebase-ai/CHANGELOG.md index efd9beeb4ca..ad7be4a8513 100644 --- a/ai-logic/firebase-ai/CHANGELOG.md +++ b/ai-logic/firebase-ai/CHANGELOG.md @@ -1,5 +1,6 @@ # Unreleased +- [feature] Added support for Chat interactions using server prompt templates (#7986) - [fixed] Fixed an issue causing network timeouts to throw the incorrect exception type, instead of `RequestTimeoutException` (#7966) - [fixed] Fixed missing `toString()` implemenation for `InferenceSource` (#7970) diff --git a/ai-logic/firebase-ai/api.txt b/ai-logic/firebase-ai/api.txt index 60d5e75fcc0..188a9178b75 100644 --- a/ai-logic/firebase-ai/api.txt +++ b/ai-logic/firebase-ai/api.txt @@ -137,9 +137,19 @@ package com.google.firebase.ai { public static final class OnDeviceConfig.Companion { } + @com.google.firebase.ai.type.PublicPreviewAPI public final class TemplateChat { + method public java.util.List getHistory(); + method public suspend Object? sendMessage(com.google.firebase.ai.type.Content prompt, kotlin.coroutines.Continuation); + method public suspend Object? sendMessage(String prompt, kotlin.coroutines.Continuation); + method public kotlinx.coroutines.flow.Flow sendMessageStream(com.google.firebase.ai.type.Content prompt); + method public kotlinx.coroutines.flow.Flow sendMessageStream(String prompt); + property public final java.util.List history; + } + @com.google.firebase.ai.type.PublicPreviewAPI public final class TemplateGenerativeModel { method public suspend Object? 
generateContent(String templateId, java.util.Map inputs, kotlin.coroutines.Continuation); method public kotlinx.coroutines.flow.Flow generateContentStream(String templateId, java.util.Map inputs); + method @com.google.firebase.ai.type.PublicPreviewAPI public com.google.firebase.ai.TemplateChat startChat(String templateId, java.util.Map inputs, java.util.List history = emptyList()); } @com.google.firebase.ai.type.PublicPreviewAPI public final class TemplateImagenModel { diff --git a/ai-logic/firebase-ai/gradle.properties b/ai-logic/firebase-ai/gradle.properties index 47a250ed3b9..9c7e2f88876 100644 --- a/ai-logic/firebase-ai/gradle.properties +++ b/ai-logic/firebase-ai/gradle.properties @@ -12,5 +12,5 @@ # See the License for the specific language governing permissions and # limitations under the License. -version=17.10.2 +version=17.11.0 latestReleasedVersion=17.10.1 diff --git a/ai-logic/firebase-ai/src/main/kotlin/com/google/firebase/ai/TemplateChat.kt b/ai-logic/firebase-ai/src/main/kotlin/com/google/firebase/ai/TemplateChat.kt new file mode 100644 index 00000000000..a0af12cbf58 --- /dev/null +++ b/ai-logic/firebase-ai/src/main/kotlin/com/google/firebase/ai/TemplateChat.kt @@ -0,0 +1,122 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.google.firebase.ai + +import com.google.firebase.ai.type.Content +import com.google.firebase.ai.type.GenerateContentResponse +import com.google.firebase.ai.type.InvalidStateException +import com.google.firebase.ai.type.Part +import com.google.firebase.ai.type.PublicPreviewAPI +import com.google.firebase.ai.type.content +import java.util.concurrent.Semaphore +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.onCompletion +import kotlinx.coroutines.flow.onEach + +/** Representation of a multi-turn interaction with a server template model. */ +@PublicPreviewAPI +public class TemplateChat +internal constructor( + private val model: TemplateGenerativeModel, + private val templateId: String, + private val inputs: Map<String, Any>, + public val history: MutableList<Content> = ArrayList() +) { + private var lock = Semaphore(1) + + /** + * Sends a message using the provided [prompt]; automatically providing the existing [history] as + * context. + * + * @param prompt The input that, together with the history, will be given to the model as the + * prompt. + */ + public suspend fun sendMessage(prompt: Content): GenerateContentResponse { + prompt.assertComesFromUser() + attemptLock() + try { + return model.generateContentWithHistory(templateId, inputs, history + prompt).also { resp -> + history.add(prompt) + history.add(resp.candidates.first().content) + } + } finally { + lock.release() + } + } + + /** + * Sends a message using the provided text [prompt]; automatically providing the existing + * [history] as context. + */ + public suspend fun sendMessage(prompt: String): GenerateContentResponse { + val content = content { text(prompt) } + return sendMessage(content) + } + + /** + * Sends a message using the provided [prompt]; automatically providing the existing [history] as + * context. Returns a flow. 
+ */ + public fun sendMessageStream(prompt: Content): Flow<GenerateContentResponse> { + prompt.assertComesFromUser() + attemptLock() + + val fullPrompt = history + prompt + val flow = model.generateContentWithHistoryStream(templateId, inputs, fullPrompt) + val tempHistory = mutableListOf<Content>() + val responseParts = mutableListOf<Part>() + + return flow + .onEach { response -> + response.candidates.first().content.parts.let { responseParts.addAll(it) } + } + .onCompletion { + lock.release() + if (it == null) { + tempHistory.add(prompt) + tempHistory.add( + content("model") { responseParts.forEach { part -> this.parts.add(part) } } + ) + history.addAll(tempHistory) + } + } + } + + /** + * Sends a message using the provided text [prompt]; automatically providing the existing + * [history] as context. Returns a flow. + */ + public fun sendMessageStream(prompt: String): Flow<GenerateContentResponse> { + val content = content { text(prompt) } + return sendMessageStream(content) + } + + private fun Content.assertComesFromUser() { + if (role !in listOf("user", "function")) { + throw InvalidStateException("Chat prompts should come from the 'user' or 'function' role.") + } + } + + private fun attemptLock() { + if (!lock.tryAcquire()) { + throw InvalidStateException( + "This chat instance currently has an ongoing request, please wait for it to complete " + + "before sending more messages" + ) + } + } +} diff --git a/ai-logic/firebase-ai/src/main/kotlin/com/google/firebase/ai/TemplateGenerativeModel.kt b/ai-logic/firebase-ai/src/main/kotlin/com/google/firebase/ai/TemplateGenerativeModel.kt index 8672813fff3..81107d41ee4 100644 --- a/ai-logic/firebase-ai/src/main/kotlin/com/google/firebase/ai/TemplateGenerativeModel.kt +++ b/ai-logic/firebase-ai/src/main/kotlin/com/google/firebase/ai/TemplateGenerativeModel.kt @@ -86,10 +86,27 @@ internal constructor( public suspend fun generateContent( templateId: String, inputs: Map<String, Any>, + ): GenerateContentResponse = generateContentWithHistory(templateId, inputs, null) + + /** + * Generates content from 
a prompt template and inputs. + * + * @param templateId The ID of the prompt template to use. + * @param inputs A map of variables to substitute into the template. + * @param history Prior history in the conversation. + * @return The content generated by the model. + * @throws [FirebaseAIException] if the request failed. + * @see [FirebaseAIException] for types of errors. + */ + @PublicPreviewAPI + internal suspend fun generateContentWithHistory( + templateId: String, + inputs: Map<String, Any>, + history: List<Content>? ): GenerateContentResponse = try { controller - .templateGenerateContent("$templateUri$templateId", constructRequest(inputs)) + .templateGenerateContent("$templateUri$templateId", constructRequest(inputs, history)) .toPublic() .validate() } catch (e: Throwable) { @@ -108,12 +125,44 @@ internal constructor( public fun generateContentStream( templateId: String, inputs: Map<String, Any> + ): Flow<GenerateContentResponse> = generateContentWithHistoryStream(templateId, inputs, null) + + /** + * Generates content as a stream from a prompt template, inputs, and history. + * + * @param templateId The ID of the prompt template to use. + * @param inputs A map of variables to substitute into the template. + * @param history Prior history in the conversation. + * @return A [Flow] which will emit responses as they are returned by the model. + * @throws [FirebaseAIException] if the request failed. + * @see [FirebaseAIException] for types of errors. + */ + @PublicPreviewAPI + internal fun generateContentWithHistoryStream( + templateId: String, + inputs: Map<String, Any>, + history: List<Content>? ): Flow<GenerateContentResponse> = controller - .templateGenerateContentStream("$templateUri$templateId", constructRequest(inputs)) + .templateGenerateContentStream("$templateUri$templateId", constructRequest(inputs, history)) .catch { throw FirebaseAIException.from(it) } .map { it.toPublic().validate() } + /** + * Creates a [TemplateChat] instance using this model with the optionally provided history. + * + * @param templateId The ID of the prompt template to use. 
+ * @param inputs A map of variables to substitute into the template for the session. + * @param history Prior history in the conversation. + * @return The initialized [TemplateChat] instance. + */ + @PublicPreviewAPI + public fun startChat( + templateId: String, + inputs: Map<String, Any>, + history: List<Content> = emptyList() + ): TemplateChat = TemplateChat(this, templateId, inputs, history.toMutableList()) + internal fun constructRequest( inputs: Map<String, Any>, history: List<Content>? = null diff --git a/ai-logic/firebase-ai/src/test/java/com/google/firebase/ai/TemplateChatTests.kt b/ai-logic/firebase-ai/src/test/java/com/google/firebase/ai/TemplateChatTests.kt new file mode 100644 index 00000000000..ec5352e955c --- /dev/null +++ b/ai-logic/firebase-ai/src/test/java/com/google/firebase/ai/TemplateChatTests.kt @@ -0,0 +1,108 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.google.firebase.ai + +import com.google.firebase.ai.type.Candidate +import com.google.firebase.ai.type.Content +import com.google.firebase.ai.type.FinishReason +import com.google.firebase.ai.type.GenerateContentResponse +import com.google.firebase.ai.type.Part +import com.google.firebase.ai.type.PublicPreviewAPI +import com.google.firebase.ai.type.TextPart +import com.google.firebase.ai.type.content +import io.kotest.matchers.collections.shouldHaveSize +import io.kotest.matchers.shouldBe +import io.kotest.matchers.types.shouldBeInstanceOf +import io.mockk.coEvery +import io.mockk.every +import io.mockk.mockk +import kotlinx.coroutines.flow.flowOf +import kotlinx.coroutines.flow.toList +import kotlinx.coroutines.test.runTest +import org.junit.Before +import org.junit.Test +import org.junit.runner.RunWith +import org.robolectric.RobolectricTestRunner + +@OptIn(PublicPreviewAPI::class) +@RunWith(RobolectricTestRunner::class) +class TemplateChatTests { + private val model = mockk<TemplateGenerativeModel>() + private val templateId = "test-template" + private val inputs = mapOf("key" to "value") + + private lateinit var chat: TemplateChat + + @Before + fun setup() { + chat = TemplateChat(model, templateId, inputs) + } + + @Test + fun `sendMessage(Content) adds prompt and response to history`() = runTest { + val prompt = content("user") { text("hello") } + val responseContent = content("model") { text("hi") } + val response = createResponse(responseContent) + + coEvery { model.generateContentWithHistory(templateId, inputs, any()) } returns response + + chat.sendMessage(prompt) + + chat.history shouldHaveSize 2 + chat.history[0] shouldBeEquivalentTo prompt + chat.history[1] shouldBeEquivalentTo responseContent + } + + @Test + fun `sendMessageStream(Content) adds prompt and aggregated responses to history`() = runTest { + val prompt = content("user") { text("hello") } + val response1 = createResponse(content("model") { text("hi ") }) + val response2 = 
createResponse(content("model") { text("there") }) + + every { model.generateContentWithHistoryStream(templateId, inputs, any()) } returns + flowOf(response1, response2) + + val flow = chat.sendMessageStream(prompt) + flow.toList() + + chat.history shouldHaveSize 2 + chat.history[0] shouldBeEquivalentTo prompt + chat.history[1].parts shouldHaveSize 2 + chat.history[1].parts[0].shouldBeInstanceOf<TextPart>().text shouldBe "hi " + chat.history[1].parts[1].shouldBeInstanceOf<TextPart>().text shouldBe "there" + } + + private fun createResponse(content: Content): GenerateContentResponse { + return GenerateContentResponse.Internal( + listOf(Candidate.Internal(content.toInternal(), finishReason = FinishReason.Internal.STOP)) + ) + .toPublic() + } + + private infix fun Content.shouldBeEquivalentTo(other: Content) { + this.role shouldBe other.role + this.parts shouldHaveSize other.parts.size + this.parts.zip(other.parts).forEach { (a, b) -> a.shouldBeEquivalentTo(b) } + } + + private fun Part.shouldBeEquivalentTo(other: Part) { + this::class shouldBe other::class + if (this is TextPart && other is TextPart) { + this.text shouldBe other.text + } + } +} diff --git a/ai-logic/firebase-ai/src/test/java/com/google/firebase/ai/TemplateGenerativeModelTests.kt b/ai-logic/firebase-ai/src/test/java/com/google/firebase/ai/TemplateGenerativeModelTests.kt new file mode 100644 index 00000000000..1fc90120d4b --- /dev/null +++ b/ai-logic/firebase-ai/src/test/java/com/google/firebase/ai/TemplateGenerativeModelTests.kt @@ -0,0 +1,106 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.firebase.ai + +import com.google.firebase.ai.common.APIController +import com.google.firebase.ai.type.Candidate +import com.google.firebase.ai.type.Content +import com.google.firebase.ai.type.FinishReason +import com.google.firebase.ai.type.GenerateContentResponse +import com.google.firebase.ai.type.Part +import com.google.firebase.ai.type.PublicPreviewAPI +import com.google.firebase.ai.type.TextPart +import com.google.firebase.ai.type.content +import io.kotest.matchers.collections.shouldHaveSize +import io.kotest.matchers.shouldBe +import io.kotest.matchers.types.shouldBeInstanceOf +import io.mockk.coEvery +import io.mockk.every +import io.mockk.mockk +import kotlinx.coroutines.flow.flowOf +import kotlinx.coroutines.flow.toList +import kotlinx.coroutines.test.runTest +import org.junit.Before +import org.junit.Test +import org.junit.runner.RunWith +import org.robolectric.RobolectricTestRunner + +@OptIn(PublicPreviewAPI::class, kotlinx.serialization.ExperimentalSerializationApi::class) +@RunWith(RobolectricTestRunner::class) +class TemplateGenerativeModelTests { + private val controller = mockk<APIController>() + private val templateUri = "https://example.com/templates/" + private lateinit var model: TemplateGenerativeModel + private val templateId = "test-template" + private val inputs = mapOf("key" to "value") + + @Before + fun setup() { + model = TemplateGenerativeModel(templateUri, controller) + } + + @Test + fun `generateContentWithHistory calls controller correctly`() = runTest { + val history = listOf(content("user") { 
text("hello") }) + val responseContent = content("model") { text("hi") } + val responseInternal = createResponseInternal(responseContent) + + coEvery { controller.templateGenerateContent(any(), any()) } returns responseInternal + + val actualResponse = model.generateContentWithHistory(templateId, inputs, history) + + actualResponse.candidates.first().content shouldBeEquivalentTo responseContent + } + + @Test + fun `generateContentWithHistoryStream calls controller correctly`() = runTest { + val history = listOf(content("user") { text("hello") }) + val response1 = createResponseInternal(content("model") { text("hi ") }) + val response2 = createResponseInternal(content("model") { text("there") }) + + every { controller.templateGenerateContentStream(any(), any()) } returns + flowOf(response1, response2) + + val flow = model.generateContentWithHistoryStream(templateId, inputs, history) + val responses = flow.toList() + + responses shouldHaveSize 2 + responses[0].candidates.first().content.parts[0].shouldBeInstanceOf<TextPart>().text shouldBe + "hi " + responses[1].candidates.first().content.parts[0].shouldBeInstanceOf<TextPart>().text shouldBe + "there" + } + + private fun createResponseInternal(content: Content): GenerateContentResponse.Internal { + return GenerateContentResponse.Internal( + listOf(Candidate.Internal(content.toInternal(), finishReason = FinishReason.Internal.STOP)) + ) + } + + private infix fun Content.shouldBeEquivalentTo(other: Content) { + this.role shouldBe other.role + this.parts shouldHaveSize other.parts.size + this.parts.zip(other.parts).forEach { (a, b) -> a.shouldBeEquivalentTo(b) } + } + + private fun Part.shouldBeEquivalentTo(other: Part) { + this::class shouldBe other::class + if (this is TextPart && other is TextPart) { + this.text shouldBe other.text + } + } +}