From 777a2f61a4a8c4054535d129257181d18610dec7 Mon Sep 17 00:00:00 2001 From: Rosario Fernandes Date: Thu, 28 May 2026 11:29:51 +0100 Subject: [PATCH 1/2] feat(ai-logic): add automatic function calling to LiveGenerativeModel --- .../google/firebase/ai/LiveGenerativeModel.kt | 60 ++++++++++++++++++- .../google/firebase/ai/type/LiveSession.kt | 12 +++- 2 files changed, 70 insertions(+), 2 deletions(-) diff --git a/ai-logic/firebase-ai/src/main/kotlin/com/google/firebase/ai/LiveGenerativeModel.kt b/ai-logic/firebase-ai/src/main/kotlin/com/google/firebase/ai/LiveGenerativeModel.kt index 4b248d01f40..5fa381dd16a 100644 --- a/ai-logic/firebase-ai/src/main/kotlin/com/google/firebase/ai/LiveGenerativeModel.kt +++ b/ai-logic/firebase-ai/src/main/kotlin/com/google/firebase/ai/LiveGenerativeModel.kt @@ -20,7 +20,11 @@ import com.google.firebase.FirebaseApp import com.google.firebase.ai.common.APIController import com.google.firebase.ai.common.AppCheckHeaderProvider import com.google.firebase.ai.common.JSON +import com.google.firebase.ai.type.AutoFunctionDeclaration import com.google.firebase.ai.type.Content +import com.google.firebase.ai.type.FirebaseAutoFunctionException +import com.google.firebase.ai.type.FunctionCallPart +import com.google.firebase.ai.type.FunctionResponsePart import com.google.firebase.ai.type.GenerativeBackend import com.google.firebase.ai.type.LiveClientSetupMessage import com.google.firebase.ai.type.LiveGenerationConfig @@ -40,8 +44,11 @@ import io.ktor.websocket.readBytes import kotlin.coroutines.CoroutineContext import kotlinx.coroutines.channels.ClosedReceiveChannelException import kotlinx.serialization.ExperimentalSerializationApi +import kotlinx.serialization.InternalSerializationApi import kotlinx.serialization.encodeToString import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.JsonPrimitive +import kotlinx.serialization.json.jsonObject /** * Represents a multimodal model (like Gemini) capable of real-time content generation based on @@ -140,7 +147,9 @@ internal constructor( session = webSession, blockingDispatcher = blockingDispatcher, firebaseApp = firebaseApp, - connectionFactory = connectFactory + connectionFactory = connectFactory, + hasFunction = ::hasFunction, + executeFunction = ::executeFunction ) } catch (e: ClosedReceiveChannelException) { val reason = webSession?.closeReason?.await() @@ -150,6 +159,55 @@ internal constructor( } } + internal fun hasFunction(call: FunctionCallPart): Boolean { + return tools + .flatMap { it.autoFunctionDeclarations ?: emptyList() } + .firstOrNull { it.name == call.name && it.functionReference != null } != null + } + + @OptIn(InternalSerializationApi::class) + internal suspend fun executeFunction(call: FunctionCallPart): FunctionResponsePart { + if (tools.isEmpty()) { + throw RuntimeException("No registered tools") + } + val tool = tools.flatMap { it.autoFunctionDeclarations ?: emptyList() } + val declaration = + tool.firstOrNull { it.name == call.name } + ?: throw RuntimeException("No registered function named ${call.name}") + return executeFunction( + call, + declaration as AutoFunctionDeclaration, + JsonObject(call.args).toString() + ) + } + + @OptIn(InternalSerializationApi::class) + internal suspend fun executeFunction( + functionCall: FunctionCallPart, + functionDeclaration: AutoFunctionDeclaration, + parameter: String + ): FunctionResponsePart { + val inputDeserializer = functionDeclaration.inputSchema.getSerializer() + val input = JSON.decodeFromString(inputDeserializer, parameter) + val functionReference = + functionDeclaration.functionReference + ?: throw RuntimeException("Function reference for ${functionDeclaration.name} is missing") + try { + val output = functionReference.invoke(input) + val outputSerializer = functionDeclaration.outputSchema?.getSerializer() + if (outputSerializer != null) { + return FunctionResponsePart.from( + JSON.encodeToJsonElement(outputSerializer, output).jsonObject + ) + .normalizeAgainstCall(functionCall) + } + return (output as FunctionResponsePart).normalizeAgainstCall(functionCall) + } catch (e: FirebaseAutoFunctionException) { + return FunctionResponsePart.from(JsonObject(mapOf("error" to JsonPrimitive(e.message)))) + .normalizeAgainstCall(functionCall) + } + } + private companion object { private val TAG = LiveGenerativeModel::class.java.simpleName } diff --git a/ai-logic/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/LiveSession.kt b/ai-logic/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/LiveSession.kt index 7aac4b73da0..23a549607a1 100644 --- a/ai-logic/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/LiveSession.kt +++ b/ai-logic/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/LiveSession.kt @@ -75,7 +75,9 @@ internal constructor( private val firebaseApp: FirebaseApp, private val connectionFactory: (suspend (SessionResumptionConfig?) -> DefaultClientWebSocketSession)? = - null + null, + private val hasFunction: ((FunctionCallPart) -> Boolean)? = null, + private val executeFunction: (suspend (FunctionCallPart) -> FunctionResponsePart)? = null ) { /** * Coroutine scope that we batch data on for network related behavior. @@ -580,6 +582,14 @@ internal constructor( // It's fine to suspend here since you can't have a function call running concurrently // with an audio response sendFunctionResponse(it.functionCalls.map(functionCallHandler).toList()) + } else if ( + hasFunction != null && + executeFunction != null && + it.functionCalls.all { f -> hasFunction.invoke(f) } + ) { + sendFunctionResponse( + it.functionCalls.map { call -> executeFunction.invoke(call) }.toList() + ) } else { Log.w( TAG, From e5259b298267d7fb5bb77a0133f11f8ee30bc3a4 Mon Sep 17 00:00:00 2001 From: Rosario Fernandes Date: Thu, 28 May 2026 12:03:45 +0100 Subject: [PATCH 2/2] add changelog entry --- ai-logic/firebase-ai/CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/ai-logic/firebase-ai/CHANGELOG.md b/ai-logic/firebase-ai/CHANGELOG.md index 14d8c925401..c35533c0830 100644 --- a/ai-logic/firebase-ai/CHANGELOG.md +++ b/ai-logic/firebase-ai/CHANGELOG.md @@ -2,6 +2,7 @@ - [feature] Added the `retrievalConfig` argument to `TemplateToolConfig` (#8107) - [fixed] Fixed citation indices to be native UTF-16 instead of UTF-8. (#8056) +- [feature] Added automatic function calling support for `LiveGenerativeModel`. (#8223) # 17.12.0