Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ai-logic/firebase-ai/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

- [feature] Added the `retrievalConfig` argument to `TemplateToolConfig` (#8107)
- [fixed] Fixed citation indices to be native UTF-16 instead of UTF-8. (#8056)
- [feature] Added automatic function calling support for `LiveGenerativeModel`. (#8223)

# 17.12.0

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@ import com.google.firebase.FirebaseApp
import com.google.firebase.ai.common.APIController
import com.google.firebase.ai.common.AppCheckHeaderProvider
import com.google.firebase.ai.common.JSON
import com.google.firebase.ai.type.AutoFunctionDeclaration
import com.google.firebase.ai.type.Content
import com.google.firebase.ai.type.FirebaseAutoFunctionException
import com.google.firebase.ai.type.FunctionCallPart
import com.google.firebase.ai.type.FunctionResponsePart
import com.google.firebase.ai.type.GenerativeBackend
import com.google.firebase.ai.type.LiveClientSetupMessage
import com.google.firebase.ai.type.LiveGenerationConfig
Expand All @@ -40,8 +44,11 @@ import io.ktor.websocket.readBytes
import kotlin.coroutines.CoroutineContext
import kotlinx.coroutines.channels.ClosedReceiveChannelException
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.InternalSerializationApi
import kotlinx.serialization.encodeToString
import kotlinx.serialization.json.JsonObject
import kotlinx.serialization.json.JsonPrimitive
import kotlinx.serialization.json.jsonObject

/**
* Represents a multimodal model (like Gemini) capable of real-time content generation based on
Expand Down Expand Up @@ -140,7 +147,9 @@ internal constructor(
session = webSession,
blockingDispatcher = blockingDispatcher,
firebaseApp = firebaseApp,
connectionFactory = connectFactory
connectionFactory = connectFactory,
hasFunction = ::hasFunction,
executeFunction = ::executeFunction
)
} catch (e: ClosedReceiveChannelException) {
val reason = webSession?.closeReason?.await()
Expand All @@ -150,6 +159,55 @@ internal constructor(
}
}

internal fun hasFunction(call: FunctionCallPart): Boolean {
return tools
.flatMap { it.autoFunctionDeclarations ?: emptyList() }
.firstOrNull { it.name == call.name && it.functionReference != null } != null
}

@OptIn(InternalSerializationApi::class)
internal suspend fun executeFunction(call: FunctionCallPart): FunctionResponsePart {
if (tools.isEmpty()) {
throw RuntimeException("No registered tools")
}
val tool = tools.flatMap { it.autoFunctionDeclarations ?: emptyList() }
val declaration =
tool.firstOrNull { it.name == call.name }
?: throw RuntimeException("No registered function named ${call.name}")
return executeFunction<Any, Any>(
call,
declaration as AutoFunctionDeclaration<Any, Any>,
JsonObject(call.args).toString()
)
}

@OptIn(InternalSerializationApi::class)
internal suspend fun <I : Any, O : Any> executeFunction(
functionCall: FunctionCallPart,
functionDeclaration: AutoFunctionDeclaration<I, O>,
parameter: String
): FunctionResponsePart {
val inputDeserializer = functionDeclaration.inputSchema.getSerializer()
val input = JSON.decodeFromString(inputDeserializer, parameter)
val functionReference =
functionDeclaration.functionReference
?: throw RuntimeException("Function reference for ${functionDeclaration.name} is missing")
try {
val output = functionReference.invoke(input)
val outputSerializer = functionDeclaration.outputSchema?.getSerializer()
if (outputSerializer != null) {
return FunctionResponsePart.from(
JSON.encodeToJsonElement(outputSerializer, output).jsonObject
)
.normalizeAgainstCall(functionCall)
}
return (output as FunctionResponsePart).normalizeAgainstCall(functionCall)
} catch (e: FirebaseAutoFunctionException) {
return FunctionResponsePart.from(JsonObject(mapOf("error" to JsonPrimitive(e.message))))
.normalizeAgainstCall(functionCall)
}
}

private companion object {
private val TAG = LiveGenerativeModel::class.java.simpleName
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,9 @@ internal constructor(
private val firebaseApp: FirebaseApp,
private val connectionFactory:
(suspend (SessionResumptionConfig?) -> DefaultClientWebSocketSession)? =
null
null,
private val hasFunction: ((FunctionCallPart) -> Boolean)? = null,
private val executeFunction: (suspend (FunctionCallPart) -> FunctionResponsePart)? = null
) {
/**
* Coroutine scope that we batch data on for network related behavior.
Expand Down Expand Up @@ -580,6 +582,14 @@ internal constructor(
// It's fine to suspend here since you can't have a function call running concurrently
// with an audio response
sendFunctionResponse(it.functionCalls.map(functionCallHandler).toList())
} else if (
hasFunction != null &&
executeFunction != null &&
it.functionCalls.all { f -> hasFunction.invoke(f) }
) {
sendFunctionResponse(
it.functionCalls.map { call -> executeFunction.invoke(call) }.toList()
)
} else {
Log.w(
TAG,
Expand Down
Loading