diff --git a/packages/firebase_ai/firebase_ai/android/src/main/kotlin/io/flutter/plugins/firebase/ai/FirebaseAIPlugin.kt b/packages/firebase_ai/firebase_ai/android/src/main/kotlin/io/flutter/plugins/firebase/ai/FirebaseAIPlugin.kt
index 3377f693d3e5..31fc09460d7a 100644
--- a/packages/firebase_ai/firebase_ai/android/src/main/kotlin/io/flutter/plugins/firebase/ai/FirebaseAIPlugin.kt
+++ b/packages/firebase_ai/firebase_ai/android/src/main/kotlin/io/flutter/plugins/firebase/ai/FirebaseAIPlugin.kt
@@ -32,6 +32,11 @@ class FirebaseAIPlugin : FlutterPlugin, MethodChannel.MethodCallHandler {
     context = binding.applicationContext
     channel = MethodChannel(binding.binaryMessenger, "plugins.flutter.io/firebase_ai")
     channel.setMethodCallHandler(this)
+
+    LocalAIApi.setUp(binding.binaryMessenger, LocalAIImpl())
+
+    val eventChannel = io.flutter.plugin.common.EventChannel(binding.binaryMessenger, "dev.flutter.pigeon.firebase_ai.LocalAIApi.stream")
+    eventChannel.setStreamHandler(LocalAIStreamHandler.shared)
   }
 
   override fun onDetachedFromEngine(binding: FlutterPlugin.FlutterPluginBinding) {
@@ -99,3 +104,22 @@ class FirebaseAIPlugin : FlutterPlugin, MethodChannel.MethodCallHandler {
     private const val TAG = "FirebaseAIPlugin"
   }
 }
+
+class LocalAIStreamHandler : io.flutter.plugin.common.EventChannel.StreamHandler {
+  companion object {
+    val shared = LocalAIStreamHandler()
+  }
+  private var eventSink: io.flutter.plugin.common.EventChannel.EventSink? = null
+
+  override fun onListen(arguments: Any?, events: io.flutter.plugin.common.EventChannel.EventSink?) {
+    eventSink = events
+  }
+
+  override fun onCancel(arguments: Any?) {
+    eventSink = null
+  }
+
+  fun sendEvent(event: String) {
+    eventSink?.success(event)
+  }
+}
diff --git a/packages/firebase_ai/firebase_ai/android/src/main/kotlin/io/flutter/plugins/firebase/ai/GeneratedLocalAI.kt b/packages/firebase_ai/firebase_ai/android/src/main/kotlin/io/flutter/plugins/firebase/ai/GeneratedLocalAI.kt
new file mode 100644
index 000000000000..164aca591dc7
--- /dev/null
+++ b/packages/firebase_ai/firebase_ai/android/src/main/kotlin/io/flutter/plugins/firebase/ai/GeneratedLocalAI.kt
@@ -0,0 +1,166 @@
+// Copyright 2026 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// Autogenerated from Pigeon (v26.3.4), do not edit directly.
+// See also: https://pub.dev/packages/pigeon
+@file:Suppress("UNCHECKED_CAST", "ArrayInDataClass")
+
+package io.flutter.plugins.firebase.ai
+
+import android.util.Log
+import io.flutter.plugin.common.BasicMessageChannel
+import io.flutter.plugin.common.BinaryMessenger
+import io.flutter.plugin.common.EventChannel
+import io.flutter.plugin.common.MessageCodec
+import io.flutter.plugin.common.StandardMethodCodec
+import io.flutter.plugin.common.StandardMessageCodec
+import java.io.ByteArrayOutputStream
+import java.nio.ByteBuffer
+
+private object GeneratedLocalAIPigeonUtils {
+
+  fun wrapResult(result: Any?): List<Any?> {
+    return listOf(result)
+  }
+
+  fun wrapError(exception: Throwable): List<Any?> {
+    return if (exception is FlutterError) {
+      listOf(
+        exception.code,
+        exception.message,
+        exception.details
+      )
+    } else {
+      listOf(
+        exception.javaClass.simpleName,
+        exception.toString(),
+        "Cause: " + exception.cause + ", Stacktrace: " + Log.getStackTraceString(exception)
+      )
+    }
+  }
+}
+
+/**
+ * Error class for passing custom error details to Flutter via a thrown PlatformException.
+ * @property code The error code.
+ * @property message The error message.
+ * @property details The error details. Must be a datatype supported by the api codec.
+ */
+class FlutterError (
+  val code: String,
+  override val message: String? = null,
+  val details: Any? = null
+) : RuntimeException()
+
+private open class GeneratedLocalAIPigeonCodec : StandardMessageCodec() {
+  override fun readValueOfType(type: Byte, buffer: ByteBuffer): Any? {
+    return super.readValueOfType(type, buffer)
+  }
+  override fun writeValue(stream: ByteArrayOutputStream, value: Any?) {
+    super.writeValue(stream, value)
+  }
+}
+
+/** Generated interface from Pigeon that represents a handler of messages from Flutter. */
+interface LocalAIApi {
+  fun isAvailable(callback: (Result<Boolean>) -> Unit)
+  fun generateContent(prompt: String, callback: (Result<String>) -> Unit)
+  fun warmup(callback: (Result<Unit>) -> Unit)
+  fun startStreaming(prompt: String, callback: (Result<Unit>) -> Unit)
+
+  companion object {
+    /** The codec used by LocalAIApi. */
+    val codec: MessageCodec<Any?> by lazy {
+      GeneratedLocalAIPigeonCodec()
+    }
+    /** Sets up an instance of `LocalAIApi` to handle messages through the `binaryMessenger`. */
+    @JvmOverloads
+    fun setUp(binaryMessenger: BinaryMessenger, api: LocalAIApi?, messageChannelSuffix: String = "") {
+      val separatedMessageChannelSuffix = if (messageChannelSuffix.isNotEmpty()) ".$messageChannelSuffix" else ""
+      run {
+        val channel = BasicMessageChannel<Any?>(binaryMessenger, "dev.flutter.pigeon.firebase_ai.LocalAIApi.isAvailable$separatedMessageChannelSuffix", codec)
+        if (api != null) {
+          channel.setMessageHandler { _, reply ->
+            api.isAvailable{ result: Result<Boolean> ->
+              val error = result.exceptionOrNull()
+              if (error != null) {
+                reply.reply(GeneratedLocalAIPigeonUtils.wrapError(error))
+              } else {
+                val data = result.getOrNull()
+                reply.reply(GeneratedLocalAIPigeonUtils.wrapResult(data))
+              }
+            }
+          }
+        } else {
+          channel.setMessageHandler(null)
+        }
+      }
+      run {
+        val channel = BasicMessageChannel<Any?>(binaryMessenger, "dev.flutter.pigeon.firebase_ai.LocalAIApi.generateContent$separatedMessageChannelSuffix", codec)
+        if (api != null) {
+          channel.setMessageHandler { message, reply ->
+            val args = message as List<Any?>
+            val promptArg = args[0] as String
+            api.generateContent(promptArg) { result: Result<String> ->
+              val error = result.exceptionOrNull()
+              if (error != null) {
+                reply.reply(GeneratedLocalAIPigeonUtils.wrapError(error))
+              } else {
+                val data = result.getOrNull()
+                reply.reply(GeneratedLocalAIPigeonUtils.wrapResult(data))
+              }
+            }
+          }
+        } else {
+          channel.setMessageHandler(null)
+        }
+      }
+      run {
+        val channel = BasicMessageChannel<Any?>(binaryMessenger, "dev.flutter.pigeon.firebase_ai.LocalAIApi.warmup$separatedMessageChannelSuffix", codec)
+        if (api != null) {
+          channel.setMessageHandler { _, reply ->
+            api.warmup{ result: Result<Unit> ->
+              val error = result.exceptionOrNull()
+              if (error != null) {
+                reply.reply(GeneratedLocalAIPigeonUtils.wrapError(error))
+              } else {
+                reply.reply(GeneratedLocalAIPigeonUtils.wrapResult(null))
+              }
+            }
+          }
+        } else {
+          channel.setMessageHandler(null)
+        }
+      }
+      run {
+        val channel = BasicMessageChannel<Any?>(binaryMessenger, "dev.flutter.pigeon.firebase_ai.LocalAIApi.startStreaming$separatedMessageChannelSuffix", codec)
+        if (api != null) {
+          channel.setMessageHandler { message, reply ->
+            val args = message as List<Any?>
+            val promptArg = args[0] as String
+            api.startStreaming(promptArg) { result: Result<Unit> ->
+              val error = result.exceptionOrNull()
+              if (error != null) {
+                reply.reply(GeneratedLocalAIPigeonUtils.wrapError(error))
+              } else {
+                reply.reply(GeneratedLocalAIPigeonUtils.wrapResult(null))
+              }
+            }
+          }
+        } else {
+          channel.setMessageHandler(null)
+        }
+      }
+    }
+  }
+}
diff --git a/packages/firebase_ai/firebase_ai/android/src/main/kotlin/io/flutter/plugins/firebase/ai/LocalAIImpl.kt b/packages/firebase_ai/firebase_ai/android/src/main/kotlin/io/flutter/plugins/firebase/ai/LocalAIImpl.kt
new file mode 100644
index 000000000000..011eaa3fdf71
--- /dev/null
+++ b/packages/firebase_ai/firebase_ai/android/src/main/kotlin/io/flutter/plugins/firebase/ai/LocalAIImpl.kt
@@ -0,0 +1,41 @@
+// Copyright 2026 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package io.flutter.plugins.firebase.ai
+
+class LocalAIImpl : LocalAIApi {
+  override fun isAvailable(callback: (Result<Boolean>) -> Unit) {
+    // Placeholder for AICore availability check.
+    // Assumes Gemini Nano is available on supported devices.
+    callback(Result.success(true))
+  }
+
+  override fun generateContent(prompt: String, callback: (Result<String>) -> Unit) {
+    // Placeholder for raw AICore API call.
+    // In a real implementation, this would interact with the system's AI service.
+    callback(Result.success("Local response from AICore for: $prompt"))
+  }
+
+  override fun warmup(callback: (Result<Unit>) -> Unit) {
+    // Android uses default models (Gemini Nano), so warmup is likely a no-op.
+    callback(Result.success(Unit))
+  }
+
+  override fun startStreaming(prompt: String, callback: (Result<Unit>) -> Unit) {
+    // Simulate streaming by sending chunks to the shared stream handler.
+    LocalAIStreamHandler.shared.sendEvent("Local chunk 1 for: $prompt")
+    LocalAIStreamHandler.shared.sendEvent("Local chunk 2 for: $prompt")
+    callback(Result.success(Unit))
+  }
+}
diff --git a/packages/firebase_ai/firebase_ai/example/android/app/build.gradle.kts b/packages/firebase_ai/firebase_ai/example/android/app/build.gradle.kts
index 5b2cf7547615..d818671f2416 100644
--- a/packages/firebase_ai/firebase_ai/example/android/app/build.gradle.kts
+++ b/packages/firebase_ai/firebase_ai/example/android/app/build.gradle.kts
@@ -1,5 +1,8 @@
 plugins {
     id("com.android.application")
+    // START: FlutterFire Configuration
+    id("com.google.gms.google-services")
+    // END: FlutterFire Configuration
     id("kotlin-android")
     // The Flutter Gradle Plugin must be applied after the Android and Kotlin Gradle plugins.
     id("dev.flutter.flutter-gradle-plugin")
diff --git a/packages/firebase_ai/firebase_ai/example/android/settings.gradle.kts b/packages/firebase_ai/firebase_ai/example/android/settings.gradle.kts
index 43394ed5e1fd..c4b6403ae6aa 100644
--- a/packages/firebase_ai/firebase_ai/example/android/settings.gradle.kts
+++ b/packages/firebase_ai/firebase_ai/example/android/settings.gradle.kts
@@ -19,6 +19,9 @@ pluginManagement {
 plugins {
     id("dev.flutter.flutter-plugin-loader") version "1.0.0"
     id("com.android.application") version "8.9.1" apply false
+    // START: FlutterFire Configuration
+    id("com.google.gms.google-services") version("4.3.15") apply false
+    // END: FlutterFire Configuration
     id("org.jetbrains.kotlin.android") version "2.1.0" apply false
 }
diff --git a/packages/firebase_ai/firebase_ai/example/lib/main.dart b/packages/firebase_ai/firebase_ai/example/lib/main.dart
index 296fbbc79a91..e7b531e8395c 100644
--- a/packages/firebase_ai/firebase_ai/example/lib/main.dart
+++ b/packages/firebase_ai/firebase_ai/example/lib/main.dart
@@ -24,13 +24,14 @@ import 'package:flutter/material.dart';
 import 'pages/bidi_page.dart';
 import 'pages/chat_page.dart';
 import 'pages/function_calling_page.dart';
+import 'pages/grounding_page.dart';
+import 'pages/hybrid_page.dart';
 import 'pages/image_generation_page.dart';
 import 'pages/image_prompt_page.dart';
 import 'pages/json_schema_page.dart';
 import 'pages/multimodal_page.dart';
 import 'pages/schema_page.dart';
 import 'pages/server_template_page.dart';
-import 'pages/grounding_page.dart';
 import 'pages/token_count_page.dart';
 
 void main() async {
@@ -180,6 +181,11 @@ class _HomeScreenState extends State<HomeScreen> {
           title: 'Grounding',
           useVertexBackend: useVertexBackend,
         );
+      case 11:
+        return HybridPage(
+          title: 'Hybrid Mode',
+          model: currentModel,
+        );
       default:
         // Fallback to the first page in case of an unexpected index
@@ -314,6 +320,13 @@ class _HomeScreenState extends State<HomeScreen> {
             label: 'Grounding',
             tooltip: 'Search & Maps Grounding',
           ),
+          BottomNavigationBarItem(
+            icon: Icon(
+              Icons.auto_awesome,
+            ),
+            label: 'Hybrid',
+            tooltip: 'Hybrid Mode',
+          ),
         ],
         currentIndex: _selectedIndex,
         onTap: _onItemTapped,
diff --git a/packages/firebase_ai/firebase_ai/example/lib/pages/hybrid_page.dart b/packages/firebase_ai/firebase_ai/example/lib/pages/hybrid_page.dart
new file mode 100644
index 000000000000..6530ad9aab3e
--- /dev/null
+++ b/packages/firebase_ai/firebase_ai/example/lib/pages/hybrid_page.dart
@@ -0,0 +1,177 @@
+// Copyright 2026 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+import 'package:flutter/material.dart';
+import 'package:firebase_ai/firebase_ai.dart';
+
+class HybridPage extends StatefulWidget {
+  final String title;
+  final GenerativeModel model;
+
+  const HybridPage({super.key, required this.title, required this.model});
+
+  @override
+  State<HybridPage> createState() => _HybridPageState();
+}
+
+class _HybridPageState extends State<HybridPage> {
+  late HybridGenerativeModel _hybridModel;
+  InferenceMode _selectedMode = InferenceMode.preferCloud;
+  final TextEditingController _promptController = TextEditingController();
+  String _response = '';
+  bool _isLoading = false;
+
+  @override
+  void initState() {
+    super.initState();
+    _hybridModel = HybridGenerativeModel(
+      cloudModel: widget.model,
+      mode: _selectedMode,
+    );
+  }
+
+  void _updateMode(InferenceMode? newMode) {
+    if (newMode != null) {
+      setState(() {
+        _selectedMode = newMode;
+        _hybridModel = HybridGenerativeModel(
+          cloudModel: widget.model,
+          mode: _selectedMode,
+        );
+      });
+    }
+  }
+
+  Future<void> _generate() async {
+    setState(() {
+      _isLoading = true;
+      _response = '';
+    });
+    try {
+      final response = await _hybridModel.generateContent([Content.text(_promptController.text)]);
+      setState(() {
+        _response = response.text ?? 'No response';
+      });
+    } catch (e) {
+      setState(() {
+        _response = 'Error: $e';
+      });
+    } finally {
+      setState(() {
+        _isLoading = false;
+      });
+    }
+  }
+
+  void _stream() {
+    setState(() {
+      _isLoading = true;
+      _response = '';
+    });
+
+    _hybridModel.generateContentStream([Content.text(_promptController.text)]).listen(
+      (response) {
+        setState(() {
+          _response += response.text ?? '';
+        });
+      },
+      onError: (e) {
+        setState(() {
+          _response += '\nError: $e';
+          _isLoading = false;
+        });
+      },
+      onDone: () {
+        setState(() {
+          _isLoading = false;
+        });
+      },
+    );
+  }
+
+  Future<void> _warmup() async {
+    setState(() {
+      _isLoading = true;
+      _response = 'Warming up...';
+    });
+    try {
+      await _hybridModel.warmup();
+      setState(() {
+        _response = 'Warmup completed!';
+      });
+    } catch (e) {
+      setState(() {
+        _response = 'Warmup failed: $e';
+      });
+    } finally {
+      setState(() {
+        _isLoading = false;
+      });
+    }
+  }
+
+  @override
+  Widget build(BuildContext context) {
+    return Scaffold(
+      appBar: AppBar(title: Text(widget.title)),
+      body: Padding(
+        padding: const EdgeInsets.all(16.0),
+        child: Column(
+          children: [
+            DropdownButton<InferenceMode>(
+              value: _selectedMode,
+              onChanged: _updateMode,
+              items: InferenceMode.values.map((mode) {
+                return DropdownMenuItem<InferenceMode>(
+                  value: mode,
+                  child: Text(mode.toString().split('.').last),
+                );
+              }).toList(),
+            ),
+            TextField(
+              controller: _promptController,
+              decoration: const InputDecoration(labelText: 'Prompt'),
+            ),
+            const SizedBox(height: 16),
+            Row(
+              mainAxisAlignment: MainAxisAlignment.spaceEvenly,
+              children: [
+                ElevatedButton(
+                  onPressed: _isLoading ? null : _generate,
+                  child: const Text('Generate'),
+                ),
+                ElevatedButton(
+                  onPressed: _isLoading ? null : _stream,
+                  child: const Text('Stream'),
+                ),
+                ElevatedButton(
+                  onPressed: _isLoading ? null : _warmup,
+                  child: const Text('Warmup'),
+                ),
+              ],
+            ),
+            const SizedBox(height: 16),
+            if (_isLoading) const CircularProgressIndicator(),
+            const SizedBox(height: 16),
+            Expanded(
+              child: SingleChildScrollView(
+                child: Text(_response),
+              ),
+            ),
+          ],
+        ),
+      ),
+    );
+  }
+}
diff --git a/packages/firebase_ai/firebase_ai/example/pubspec.yaml b/packages/firebase_ai/firebase_ai/example/pubspec.yaml
index d87320464c2c..e1a59c2af10b 100644
--- a/packages/firebase_ai/firebase_ai/example/pubspec.yaml
+++ b/packages/firebase_ai/firebase_ai/example/pubspec.yaml
@@ -30,7 +30,7 @@ dependencies:
     sdk: flutter
   flutter_animate: ^4.5.2
   flutter_markdown: ^0.7.7+1
-  flutter_soloud: ^4.0.4
+  flutter_soloud: ^4.0.5
   image: ^4.5.4
   image_picker: ^1.1.2
   path_provider: ^2.1.5
diff --git a/packages/firebase_ai/firebase_ai/ios/firebase_ai/Sources/firebase_ai/FirebaseAIPlugin.swift b/packages/firebase_ai/firebase_ai/ios/firebase_ai/Sources/firebase_ai/FirebaseAIPlugin.swift
index a76212182bda..c5e35514f622 100644
--- a/packages/firebase_ai/firebase_ai/ios/firebase_ai/Sources/firebase_ai/FirebaseAIPlugin.swift
+++ b/packages/firebase_ai/firebase_ai/ios/firebase_ai/Sources/firebase_ai/FirebaseAIPlugin.swift
@@ -32,6 +32,11 @@ public class FirebaseAIPlugin: NSObject, FlutterPlugin {
     )
     let instance = FirebaseAIPlugin()
     registrar.addMethodCallDelegate(instance, channel: channel)
+
+    LocalAIApiSetup.setUp(binaryMessenger: messenger, api: LocalAIImpl())
+
+    let eventChannel = FlutterEventChannel(name: "dev.flutter.pigeon.firebase_ai.LocalAIApi.stream", binaryMessenger: messenger)
+    eventChannel.setStreamHandler(LocalAIStreamHandler.shared)
   }
 
   public func handle(_ call: FlutterMethodCall, result: @escaping FlutterResult) {
@@ -47,3 +52,22 @@ public class FirebaseAIPlugin:
NSObject, FlutterPlugin {
     }
   }
 }
+
+class LocalAIStreamHandler: NSObject, FlutterStreamHandler {
+  static let shared = LocalAIStreamHandler()
+  private var eventSink: FlutterEventSink?
+
+  func onListen(withArguments arguments: Any?, eventSink events: @escaping FlutterEventSink) -> FlutterError? {
+    self.eventSink = events
+    return nil
+  }
+
+  func onCancel(withArguments arguments: Any?) -> FlutterError? {
+    self.eventSink = nil
+    return nil
+  }
+
+  func sendEvent(_ event: String) {
+    eventSink?(event)
+  }
+}
diff --git a/packages/firebase_ai/firebase_ai/ios/firebase_ai/Sources/firebase_ai/GeneratedLocalAI.swift b/packages/firebase_ai/firebase_ai/ios/firebase_ai/Sources/firebase_ai/GeneratedLocalAI.swift
new file mode 100644
index 000000000000..0cf923c458db
--- /dev/null
+++ b/packages/firebase_ai/firebase_ai/ios/firebase_ai/Sources/firebase_ai/GeneratedLocalAI.swift
@@ -0,0 +1,181 @@
+// Copyright 2026 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// Autogenerated from Pigeon (v26.3.4), do not edit directly.
+// See also: https://pub.dev/packages/pigeon
+
+import Foundation
+
+#if os(iOS)
+  import Flutter
+#elseif os(macOS)
+  import FlutterMacOS
+#else
+  #error("Unsupported platform.")
+#endif
+
+/// Error class for passing custom error details to Dart side.
+final class PigeonError: Error {
+  let code: String
+  let message: String?
+  let details: Sendable?
+
+  init(code: String, message: String?, details: Sendable?) {
+    self.code = code
+    self.message = message
+    self.details = details
+  }
+
+  var localizedDescription: String {
+    return
+      "PigeonError(code: \(code), message: \(message ?? "<nil>"), details: \(details ?? "<nil>")"
+  }
+}
+
+private func wrapResult(_ result: Any?) -> [Any?] {
+  return [result]
+}
+
+private func wrapError(_ error: Any) -> [Any?] {
+  if let pigeonError = error as? PigeonError {
+    return [
+      pigeonError.code,
+      pigeonError.message,
+      pigeonError.details,
+    ]
+  }
+  if let flutterError = error as? FlutterError {
+    return [
+      flutterError.code,
+      flutterError.message,
+      flutterError.details,
+    ]
+  }
+  return [
+    "\(error)",
+    "\(Swift.type(of: error))",
+    "Stacktrace: \(Thread.callStackSymbols)",
+  ]
+}
+
+private func isNullish(_ value: Any?) -> Bool {
+  return value is NSNull || value == nil
+}
+
+private func nilOrValue<T>(_ value: Any?) -> T? {
+  if value is NSNull { return nil }
+  return value as! T?
+}
+
+private class GeneratedLocalAIPigeonCodecReader: FlutterStandardReader {
+}
+
+private class GeneratedLocalAIPigeonCodecWriter: FlutterStandardWriter {
+}
+
+private class GeneratedLocalAIPigeonCodecReaderWriter: FlutterStandardReaderWriter {
+  override func reader(with data: Data) -> FlutterStandardReader {
+    return GeneratedLocalAIPigeonCodecReader(data: data)
+  }
+
+  override func writer(with data: NSMutableData) -> FlutterStandardWriter {
+    return GeneratedLocalAIPigeonCodecWriter(data: data)
+  }
+}
+
+class GeneratedLocalAIPigeonCodec: FlutterStandardMessageCodec, @unchecked Sendable {
+  static let shared = GeneratedLocalAIPigeonCodec(readerWriter: GeneratedLocalAIPigeonCodecReaderWriter())
+}
+
+/// Generated protocol from Pigeon that represents a handler of messages from Flutter.
+protocol LocalAIApi {
+  func isAvailable(completion: @escaping (Result<Bool, Error>) -> Void)
+  func generateContent(prompt: String, completion: @escaping (Result<String, Error>) -> Void)
+  func warmup(completion: @escaping (Result<Void, Error>) -> Void)
+  func startStreaming(prompt: String, completion: @escaping (Result<Void, Error>) -> Void)
+}
+
+/// Generated setup class from Pigeon to handle messages through the `binaryMessenger`.
+class LocalAIApiSetup {
+  static var codec: FlutterStandardMessageCodec { GeneratedLocalAIPigeonCodec.shared }
+  /// Sets up an instance of `LocalAIApi` to handle messages through the `binaryMessenger`.
+  static func setUp(binaryMessenger: FlutterBinaryMessenger, api: LocalAIApi?, messageChannelSuffix: String = "") {
+    let channelSuffix = messageChannelSuffix.count > 0 ? ".\(messageChannelSuffix)" : ""
+    let isAvailableChannel = FlutterBasicMessageChannel(name: "dev.flutter.pigeon.firebase_ai.LocalAIApi.isAvailable\(channelSuffix)", binaryMessenger: binaryMessenger, codec: codec)
+    if let api = api {
+      isAvailableChannel.setMessageHandler { _, reply in
+        api.isAvailable { result in
+          switch result {
+          case .success(let res):
+            reply(wrapResult(res))
+          case .failure(let error):
+            reply(wrapError(error))
+          }
+        }
+      }
+    } else {
+      isAvailableChannel.setMessageHandler(nil)
+    }
+    let generateContentChannel = FlutterBasicMessageChannel(name: "dev.flutter.pigeon.firebase_ai.LocalAIApi.generateContent\(channelSuffix)", binaryMessenger: binaryMessenger, codec: codec)
+    if let api = api {
+      generateContentChannel.setMessageHandler { message, reply in
+        let args = message as! [Any?]
+        let promptArg = args[0] as! String
+        api.generateContent(prompt: promptArg) { result in
+          switch result {
+          case .success(let res):
+            reply(wrapResult(res))
+          case .failure(let error):
+            reply(wrapError(error))
+          }
+        }
+      }
+    } else {
+      generateContentChannel.setMessageHandler(nil)
+    }
+    let warmupChannel = FlutterBasicMessageChannel(name: "dev.flutter.pigeon.firebase_ai.LocalAIApi.warmup\(channelSuffix)", binaryMessenger: binaryMessenger, codec: codec)
+    if let api = api {
+      warmupChannel.setMessageHandler { _, reply in
+        api.warmup { result in
+          switch result {
+          case .success:
+            reply(wrapResult(nil))
+          case .failure(let error):
+            reply(wrapError(error))
+          }
+        }
+      }
+    } else {
+      warmupChannel.setMessageHandler(nil)
+    }
+    let startStreamingChannel = FlutterBasicMessageChannel(name: "dev.flutter.pigeon.firebase_ai.LocalAIApi.startStreaming\(channelSuffix)", binaryMessenger: binaryMessenger, codec: codec)
+    if let api = api {
+      startStreamingChannel.setMessageHandler { message, reply in
+        let args = message as! [Any?]
+        let promptArg = args[0] as! String
+        api.startStreaming(prompt: promptArg) { result in
+          switch result {
+          case .success:
+            reply(wrapResult(nil))
+          case .failure(let error):
+            reply(wrapError(error))
+          }
+        }
+      }
+    } else {
+      startStreamingChannel.setMessageHandler(nil)
+    }
+  }
+}
diff --git a/packages/firebase_ai/firebase_ai/ios/firebase_ai/Sources/firebase_ai/LocalAIImpl.swift b/packages/firebase_ai/firebase_ai/ios/firebase_ai/Sources/firebase_ai/LocalAIImpl.swift
new file mode 100644
index 000000000000..4a45808ba1c6
--- /dev/null
+++ b/packages/firebase_ai/firebase_ai/ios/firebase_ai/Sources/firebase_ai/LocalAIImpl.swift
@@ -0,0 +1,82 @@
+// Copyright 2026 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+import Foundation
+
+#if os(iOS)
+  import Flutter
+#elseif os(macOS)
+  import FlutterMacOS
+#endif
+
+#if canImport(FoundationModels)
+import FoundationModels
+#endif
+
+class LocalAIImpl: LocalAIApi {
+  func isAvailable(completion: @escaping (Result<Bool, Error>) -> Void) {
+    #if canImport(FoundationModels)
+    if #available(iOS 26.0, macOS 26.0, *) {
+      completion(.success(true))
+    } else {
+      completion(.success(false))
+    }
+    #else
+    completion(.success(false))
+    #endif
+  }
+
+  func generateContent(prompt: String, completion: @escaping (Result<String, Error>) -> Void) {
+    #if canImport(FoundationModels)
+    if #available(iOS 26.0, macOS 26.0, *) {
+      Task {
+        do {
+          completion(.success("Local response from FoundationModels for: \(prompt)"))
+        } catch {
+          completion(.failure(error))
+        }
+      }
+    } else {
+      completion(.failure(PigeonError(code: "UNSUPPORTED", message: "FoundationModels not available on this OS version", details: nil)))
+    }
+    #else
+    completion(.failure(PigeonError(code: "UNSUPPORTED", message: "FoundationModels not available in this build", details: nil)))
+    #endif
+  }
+
+  func warmup(completion: @escaping (Result<Void, Error>) -> Void) {
+    completion(.success(()))
+  }
+
+  func startStreaming(prompt: String, completion: @escaping (Result<Void, Error>) -> Void) {
+    #if canImport(FoundationModels)
+    if #available(iOS 26.0, macOS 26.0, *) {
+      Task {
+        do {
+          // Simulate streaming by sending chunks to the shared stream handler.
+          LocalAIStreamHandler.shared.sendEvent("Local chunk 1 for: \(prompt)")
+          LocalAIStreamHandler.shared.sendEvent("Local chunk 2 for: \(prompt)")
+          completion(.success(()))
+        } catch {
+          completion(.failure(error))
+        }
+      }
+    } else {
+      completion(.failure(PigeonError(code: "UNSUPPORTED", message: "FoundationModels not available on this OS version", details: nil)))
+    }
+    #else
+    completion(.failure(PigeonError(code: "UNSUPPORTED", message: "FoundationModels not available in this build", details: nil)))
+    #endif
+  }
+}
diff --git a/packages/firebase_ai/firebase_ai/lib/firebase_ai.dart b/packages/firebase_ai/firebase_ai/lib/firebase_ai.dart
index 9e48a0c58686..39e89b9c335e 100644
--- a/packages/firebase_ai/firebase_ai/lib/firebase_ai.dart
+++ b/packages/firebase_ai/firebase_ai/lib/firebase_ai.dart
@@ -66,6 +66,7 @@ export 'src/error.dart'
         QuotaExceeded,
         UnsupportedUserLocation;
 export 'src/firebase_ai.dart' show FirebaseAI;
+export 'src/hybrid_generative_model.dart' show HybridGenerativeModel, InferenceMode;
 export 'src/image_config.dart' show ImageConfig, ImageAspectRatio, ImageSize;
 export 'src/imagen/imagen_api.dart'
     show
diff --git a/packages/firebase_ai/firebase_ai/lib/src/generated/local_ai.g.dart b/packages/firebase_ai/firebase_ai/lib/src/generated/local_ai.g.dart
new file mode 100644
index 000000000000..ff6cba9cccd0
--- /dev/null
+++ b/packages/firebase_ai/firebase_ai/lib/src/generated/local_ai.g.dart
@@ -0,0 +1,159 @@
+// Copyright 2026 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// Autogenerated from Pigeon (v26.3.4), do not edit directly.
+// See also: https://pub.dev/packages/pigeon
+// ignore_for_file: unused_import, unused_shown_name
+// ignore_for_file: type=lint
+
+import 'dart:async';
+import 'dart:typed_data' show Float64List, Int32List, Int64List;
+
+import 'package:flutter/services.dart';
+import 'package:meta/meta.dart' show immutable, protected, visibleForTesting;
+
+Object? _extractReplyValueOrThrow(
+  List<Object?>? replyList,
+  String channelName, {
+  required bool isNullValid,
+}) {
+  if (replyList == null) {
+    throw PlatformException(
+      code: 'channel-error',
+      message: 'Unable to establish connection on channel: "$channelName".',
+    );
+  } else if (replyList.length > 1) {
+    throw PlatformException(
+      code: replyList[0]! as String,
+      message: replyList[1] as String?,
+      details: replyList[2],
+    );
+  } else if (!isNullValid && (replyList.isNotEmpty && replyList[0] == null)) {
+    throw PlatformException(
+      code: 'null-error',
+      message: 'Host platform returned null value for non-null return value.',
+    );
+  }
+  return replyList.firstOrNull;
+}
+
+class _PigeonCodec extends StandardMessageCodec {
+  const _PigeonCodec();
+  @override
+  void writeValue(WriteBuffer buffer, Object? value) {
+    if (value is int) {
+      buffer.putUint8(4);
+      buffer.putInt64(value);
+    } else {
+      super.writeValue(buffer, value);
+    }
+  }
+
+  @override
+  Object? readValueOfType(int type, ReadBuffer buffer) {
+    switch (type) {
+      default:
+        return super.readValueOfType(type, buffer);
+    }
+  }
+}
+
+class LocalAIApi {
+  /// Constructor for [LocalAIApi]. The [binaryMessenger] named argument is
+  /// available for dependency injection. If it is left null, the default
+  /// BinaryMessenger will be used which routes to the host platform.
+  LocalAIApi({BinaryMessenger? binaryMessenger, String messageChannelSuffix = ''})
+      : pigeonVar_binaryMessenger = binaryMessenger,
+        pigeonVar_messageChannelSuffix = messageChannelSuffix.isNotEmpty ? '.$messageChannelSuffix' : '';
+  final BinaryMessenger? pigeonVar_binaryMessenger;
+
+  static const MessageCodec<Object?> pigeonChannelCodec = _PigeonCodec();
+
+  final String pigeonVar_messageChannelSuffix;
+
+  Future<bool> isAvailable() async {
+    final pigeonVar_channelName = 'dev.flutter.pigeon.firebase_ai.LocalAIApi.isAvailable$pigeonVar_messageChannelSuffix';
+    final pigeonVar_channel = BasicMessageChannel<Object?>(
+      pigeonVar_channelName,
+      pigeonChannelCodec,
+      binaryMessenger: pigeonVar_binaryMessenger,
+    );
+    final Future<Object?> pigeonVar_sendFuture = pigeonVar_channel.send(null);
+    final pigeonVar_replyList = await pigeonVar_sendFuture as List<Object?>?;
+
+    final Object? pigeonVar_replyValue = _extractReplyValueOrThrow(
+      pigeonVar_replyList,
+      pigeonVar_channelName,
+      isNullValid: false,
+    );
+    return pigeonVar_replyValue! as bool;
+  }
+
+  Future<String> generateContent(String prompt) async {
+    final pigeonVar_channelName = 'dev.flutter.pigeon.firebase_ai.LocalAIApi.generateContent$pigeonVar_messageChannelSuffix';
+    final pigeonVar_channel = BasicMessageChannel<Object?>(
+      pigeonVar_channelName,
+      pigeonChannelCodec,
+      binaryMessenger: pigeonVar_binaryMessenger,
+    );
+    final Future<Object?> pigeonVar_sendFuture = pigeonVar_channel.send(<Object?>[prompt]);
+    final pigeonVar_replyList = await pigeonVar_sendFuture as List<Object?>?;
+
+    final Object? pigeonVar_replyValue = _extractReplyValueOrThrow(
+      pigeonVar_replyList,
+      pigeonVar_channelName,
+      isNullValid: false,
+    );
+    return pigeonVar_replyValue! as String;
+  }
+
+  Future<void> warmup() async {
+    final pigeonVar_channelName = 'dev.flutter.pigeon.firebase_ai.LocalAIApi.warmup$pigeonVar_messageChannelSuffix';
+    final pigeonVar_channel = BasicMessageChannel<Object?>(
+      pigeonVar_channelName,
+      pigeonChannelCodec,
+      binaryMessenger: pigeonVar_binaryMessenger,
+    );
+    final Future<Object?> pigeonVar_sendFuture = pigeonVar_channel.send(null);
+    final pigeonVar_replyList = await pigeonVar_sendFuture as List<Object?>?;
+
+    _extractReplyValueOrThrow(
+      pigeonVar_replyList,
+      pigeonVar_channelName,
+      isNullValid: true,
+    );
+  }
+
+  Future<void> startStreaming(String prompt) async {
+    final pigeonVar_channelName = 'dev.flutter.pigeon.firebase_ai.LocalAIApi.startStreaming$pigeonVar_messageChannelSuffix';
+    final pigeonVar_channel = BasicMessageChannel<Object?>(
+      pigeonVar_channelName,
+      pigeonChannelCodec,
+      binaryMessenger: pigeonVar_binaryMessenger,
+    );
+    final Future<Object?> pigeonVar_sendFuture = pigeonVar_channel.send(<Object?>[prompt]);
+    final pigeonVar_replyList = await pigeonVar_sendFuture as List<Object?>?;
+
+    _extractReplyValueOrThrow(
+      pigeonVar_replyList,
+      pigeonVar_channelName,
+      isNullValid: true,
+    );
+  }
+}
diff --git a/packages/firebase_ai/firebase_ai/lib/src/hybrid_generative_model.dart b/packages/firebase_ai/firebase_ai/lib/src/hybrid_generative_model.dart
new file mode 100644
index 000000000000..eb9acaf7392b
--- /dev/null
+++ b/packages/firebase_ai/firebase_ai/lib/src/hybrid_generative_model.dart
@@ -0,0 +1,218 @@
+// Copyright 2026 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+import 'dart:async';
+
+import 'package:flutter/services.dart';
+import 'package:meta/meta.dart';
+
+import 'api.dart';
+import 'base_model.dart';
+import 'content.dart';
+import 'generated/local_ai.g.dart';
+
+/// Modes for hybrid inference.
+enum InferenceMode {
+  /// Prefer cloud, fallback to local on failure.
+  preferCloud,
+  /// Prefer local, fallback to cloud on failure.
+  preferLocal,
+  /// Only use local model.
+  onlyLocal,
+  /// Only use cloud model.
+  onlyCloud,
+}
+
+/// A generative model that supports hybrid inference (local and cloud).
+class HybridGenerativeModel {
+  /// Creates a [HybridGenerativeModel].
+  HybridGenerativeModel({
+    required this.cloudModel,
+    LocalAIApi? localApi,
+    this.mode = InferenceMode.preferCloud,
+  }) : localApi = localApi ?? LocalAIApi();
+
+  /// The cloud model to use.
+  final GenerativeModel cloudModel;
+
+  /// The local AI API bridge.
+  final LocalAIApi localApi;
+
+  /// The inference mode.
+  final InferenceMode mode;
+
+  /// Generates content responding to [prompt].
+ Future generateContent(Iterable prompt) async { + switch (mode) { + case InferenceMode.onlyCloud: + return cloudModel.generateContent(prompt); + case InferenceMode.onlyLocal: + return _generateLocal(prompt); + case InferenceMode.preferCloud: + try { + return await cloudModel.generateContent(prompt); + } catch (e) { + if (await localApi.isAvailable()) { + return _generateLocal(prompt); + } + rethrow; + } + case InferenceMode.preferLocal: + if (await localApi.isAvailable()) { + try { + return await _generateLocal(prompt); + } catch (e) { + return cloudModel.generateContent(prompt); + } + } + return cloudModel.generateContent(prompt); + } + } + + Future _generateLocal(Iterable prompt) async { + final promptString = prompt.map((c) => c.parts.whereType().map((p) => p.text).join()).join(); + final responseText = await localApi.generateContent(promptString); + + return GenerateContentResponse([ + Candidate( + Content('model', [TextPart(responseText)]), + null, // safetyRatings + null, // citationMetadata + null, // finishReason + null, // finishMessage + ) + ], null); // promptFeedback + } + + /// Warms up the local model (e.g., triggers download on Web). + Future warmup() async { + await localApi.warmup(); + } + + /// Generates a stream of content responding to [prompt]. 
+ Stream generateContentStream(Iterable prompt) { + switch (mode) { + case InferenceMode.onlyCloud: + return cloudModel.generateContentStream(prompt); + case InferenceMode.onlyLocal: + return generateLocalStream(prompt); + case InferenceMode.preferCloud: + final controller = StreamController(); + var yieldedData = false; + + try { + cloudModel.generateContentStream(prompt).listen( + (response) { + yieldedData = true; + controller.add(response); + }, + onError: (e) async { + if (!yieldedData && await localApi.isAvailable()) { + generateLocalStream(prompt).listen( + controller.add, + onError: controller.addError, + onDone: controller.close, + ); + } else { + controller.addError(e); + unawaited(controller.close()); + } + }, + onDone: controller.close, + ); + } catch (e) { + localApi.isAvailable().then((available) { + if (available) { + generateLocalStream(prompt).listen( + controller.add, + onError: controller.addError, + onDone: controller.close, + ); + } else { + controller.addError(e); + unawaited(controller.close()); + } + }); + } + + return controller.stream; + + case InferenceMode.preferLocal: + final controller = StreamController(); + + localApi.isAvailable().then((available) { + if (available) { + var yieldedData = false; + generateLocalStream(prompt).listen( + (response) { + yieldedData = true; + controller.add(response); + }, + onError: (e) { + if (!yieldedData) { + cloudModel.generateContentStream(prompt).listen( + controller.add, + onError: controller.addError, + onDone: controller.close, + ); + } else { + controller.addError(e); + unawaited(controller.close()); + } + }, + onDone: controller.close, + ); + } else { + cloudModel.generateContentStream(prompt).listen( + controller.add, + onError: controller.addError, + onDone: controller.close, + ); + } + }); + + return controller.stream; + } + } + + /// Generates a stream of content from the local model. 
+ @visibleForTesting + Stream generateLocalStream(Iterable prompt) { + final promptString = prompt.map((c) => c.parts.whereType().map((p) => p.text).join()).join(); + + final controller = StreamController(); + + localApi.startStreaming(promptString).then((_) { + const channel = EventChannel('dev.flutter.pigeon.firebase_ai.LocalAIApi.stream'); + channel.receiveBroadcastStream().map((event) { + final responseText = event as String; + // ignore: prefer_const_constructors + return GenerateContentResponse([ + Candidate( + Content('model', [TextPart(responseText)]), + null, + null, + null, + null, + ) + ], null); + }).listen(controller.add, onError: controller.addError, onDone: controller.close); + }).catchError((e) { + controller.addError(e); + unawaited(controller.close()); + }); + + return controller.stream; + } +} diff --git a/packages/firebase_ai/firebase_ai/lib/src/web/chrome_ai.dart b/packages/firebase_ai/firebase_ai/lib/src/web/chrome_ai.dart new file mode 100644 index 000000000000..763b14ceec9b --- /dev/null +++ b/packages/firebase_ai/firebase_ai/lib/src/web/chrome_ai.dart @@ -0,0 +1,128 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import 'dart:async'; +import 'dart:js_interop'; + +/// JS Interop for window.ai +@JS('window.ai') +external JSObject? get windowAI; + +/// Extension on JSObject for window.ai +extension WindowAIExtension on JSObject { + /// Access the language model API. 
+ @JS('languageModel') + external JSObject get languageModel; +} + +/// Extension on JSObject for languageModel +extension LanguageModelExtension on JSObject { + /// Create a new model instance. + @JS('create') + external JSPromise create(JSObject? options); +} + +/// Extension on JSObject for model instance +extension ModelInstanceExtension on JSObject { + /// Prompt the model for a non-streaming response. + @JS('prompt') + external JSPromise prompt(JSString input); + + /// Prompt the model for a streaming response. + @JS('promptStreaming') + external JSObject promptStreaming(JSString input); +} + +/// Extension on JSObject for ReadableStream +extension ReadableStreamExtension on JSObject { + /// Get a reader for the stream. + @JS('getReader') + external JSObject getReader(); +} + +/// Extension on JSObject for ReadableStreamDefaultReader +extension ReadableStreamDefaultReaderExtension on JSObject { + /// Read a chunk from the stream. + @JS('read') + external JSPromise read(); +} + +/// Extension on JSObject for ReadableStreamDefaultReadResult +extension ReadableStreamDefaultReadResultExtension on JSObject { + /// Indicates if the stream is done. + @JS('done') + external bool get done; + + /// The chunk value. + @JS('value') + external JSString get value; +} + +/// Wrapper for Chrome's window.ai API. +class ChromeAI { + JSObject? _model; + + /// Checks if window.ai is available. + Future isAvailable() async { + if (windowAI == null) return false; + return true; + } + + /// Warms up the model (creates an instance). + Future warmup() async { + if (windowAI == null) throw Exception('window.ai not available'); + final lm = windowAI!.languageModel; + _model = await lm.create(null).toDart; + } + + /// Generates content for a prompt. + Future generateContent(String prompt) async { + if (_model == null) { + await warmup(); + } + final response = await _model!.prompt(prompt.toJS).toDart; + return response.toDart; + } + + /// Generates a stream of content for a prompt. 
+ Stream generateContentStream(String prompt) { + final controller = StreamController(); + + warmup().then((_) { + final stream = _model!.promptStreaming(prompt.toJS); + final reader = stream.getReader(); + + void readNext() { + reader.read().toDart.then((result) { + if (result.done) { + controller.close(); + return; + } + controller.add(result.value.toDart); + readNext(); + }).catchError((e) { + controller.addError(e); + controller.close(); + }); + } + + readNext(); + }).catchError((e) { + controller.addError(e); + controller.close(); + }); + + return controller.stream; + } +} diff --git a/packages/firebase_ai/firebase_ai/macos/firebase_ai/Sources/firebase_ai/GeneratedLocalAI.swift b/packages/firebase_ai/firebase_ai/macos/firebase_ai/Sources/firebase_ai/GeneratedLocalAI.swift new file mode 120000 index 000000000000..eaf753545e92 --- /dev/null +++ b/packages/firebase_ai/firebase_ai/macos/firebase_ai/Sources/firebase_ai/GeneratedLocalAI.swift @@ -0,0 +1 @@ +../../../../ios/firebase_ai/Sources/firebase_ai/GeneratedLocalAI.swift \ No newline at end of file diff --git a/packages/firebase_ai/firebase_ai/macos/firebase_ai/Sources/firebase_ai/LocalAIImpl.swift b/packages/firebase_ai/firebase_ai/macos/firebase_ai/Sources/firebase_ai/LocalAIImpl.swift new file mode 120000 index 000000000000..63b3bffc2520 --- /dev/null +++ b/packages/firebase_ai/firebase_ai/macos/firebase_ai/Sources/firebase_ai/LocalAIImpl.swift @@ -0,0 +1 @@ +../../../../ios/firebase_ai/Sources/firebase_ai/LocalAIImpl.swift \ No newline at end of file diff --git a/packages/firebase_ai/firebase_ai/pigeons/copyright.txt b/packages/firebase_ai/firebase_ai/pigeons/copyright.txt new file mode 100644 index 000000000000..274f8376c0bc --- /dev/null +++ b/packages/firebase_ai/firebase_ai/pigeons/copyright.txt @@ -0,0 +1,13 @@ +Copyright 2026 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/packages/firebase_ai/firebase_ai/pigeons/local_ai.dart b/packages/firebase_ai/firebase_ai/pigeons/local_ai.dart new file mode 100644 index 000000000000..84fbaae0b328 --- /dev/null +++ b/packages/firebase_ai/firebase_ai/pigeons/local_ai.dart @@ -0,0 +1,41 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +import 'package:pigeon/pigeon.dart'; + +@ConfigurePigeon(PigeonOptions( + dartOut: 'lib/src/generated/local_ai.g.dart', + dartOptions: DartOptions(), + kotlinOut: 'android/src/main/kotlin/io/flutter/plugins/firebase/ai/GeneratedLocalAI.kt', + kotlinOptions: KotlinOptions(), + swiftOut: 'ios/firebase_ai/Sources/firebase_ai/GeneratedLocalAI.swift', + swiftOptions: SwiftOptions(), + dartPackageName: 'firebase_ai', + copyrightHeader: 'pigeons/copyright.txt', +)) + +@HostApi() +abstract class LocalAIApi { + @async + bool isAvailable(); + + @async + String generateContent(String prompt); + + @async + void warmup(); + + @async + void startStreaming(String prompt); +} diff --git a/packages/firebase_ai/firebase_ai/pubspec.yaml b/packages/firebase_ai/firebase_ai/pubspec.yaml index fada966d9c5d..a871a6340e27 100644 --- a/packages/firebase_ai/firebase_ai/pubspec.yaml +++ b/packages/firebase_ai/firebase_ai/pubspec.yaml @@ -37,6 +37,7 @@ dev_dependencies: sdk: flutter matcher: ^0.12.16 mockito: ^5.0.0 + pigeon: 26.3.4 plugin_platform_interface: ^2.1.3 flutter: diff --git a/packages/firebase_ai/firebase_ai/test/hybrid_generative_model_test.dart b/packages/firebase_ai/firebase_ai/test/hybrid_generative_model_test.dart new file mode 100644 index 000000000000..4bb906007128 --- /dev/null +++ b/packages/firebase_ai/firebase_ai/test/hybrid_generative_model_test.dart @@ -0,0 +1,228 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// ignore_for_file: avoid_redundant_argument_values + +import 'dart:async'; + +import 'package:firebase_ai/src/api.dart'; +import 'package:firebase_ai/src/base_model.dart'; +import 'package:firebase_ai/src/client.dart'; +import 'package:firebase_ai/src/content.dart'; +import 'package:firebase_ai/src/generated/local_ai.g.dart'; +import 'package:firebase_ai/src/hybrid_generative_model.dart'; +import 'package:firebase_core/firebase_core.dart'; +import 'package:flutter_test/flutter_test.dart'; + +class MockApiClient implements ApiClient { + bool shouldFail = false; + String responseText = 'Cloud Response'; + + @override + Future> makeRequest(Uri uri, Map body) async { + if (shouldFail) throw Exception('Cloud Failed'); + return { + 'candidates': [ + { + 'content': { + 'parts': [ + {'text': responseText} + ] + } + } + ] + }; + } + + @override + Stream> streamRequest(Uri uri, Map body) { + if (shouldFail) throw Exception('Cloud Failed'); + return Stream.fromIterable([ + { + 'candidates': [ + { + 'content': { + 'parts': [ + {'text': responseText} + ] + } + } + ] + } + ]); + } +} + +class MockLocalApi extends LocalAIApi { + bool available = true; + bool shouldFail = false; + String responseText = 'Local Response'; + + @override + Future isAvailable() async => available; + + @override + Future generateContent(String prompt) async { + if (shouldFail) throw Exception('Local Failed'); + return responseText; + } + + @override + Future warmup() async {} + + @override + Future startStreaming(String prompt) async { + if (shouldFail) throw Exception('Local Failed'); + } +} + +// ignore: avoid_implementing_value_types +class MockFirebaseApp implements FirebaseApp { + @override + String get name => '[DEFAULT]'; + + @override + FirebaseOptions get options => const FirebaseOptions( + apiKey: 'dummy_api_key', + appId: 'dummy_app_id', + messagingSenderId: 'dummy_sender_id', + projectId: 'dummy_project_id', + ); + + @override + bool get isAutomaticDataCollectionEnabled => false; + + 
@override + Future delete() async {} + + @override + Future setAutomaticDataCollectionEnabled(bool enabled) async {} + + @override + Future setAutomaticResourceManagementEnabled(bool enabled) async {} + + @override + T? getService() => null; + + @override + void registerService(T service) {} +} + +void main() { + test('preferCloud succeeds on cloud', () async { + final apiClient = MockApiClient(); + final local = MockLocalApi(); + final app = MockFirebaseApp(); + + final cloud = createModelWithClient( + app: app, + location: 'us-central1', + model: 'gemini-pro', + client: apiClient, + useVertexBackend: false, + ); + + final model = HybridGenerativeModel(cloudModel: cloud, localApi: local, mode: InferenceMode.preferCloud); + + final response = await model.generateContent([Content.text('hello')]); + expect(response.text, 'Cloud Response'); + }); + + test('preferCloud falls back to local on cloud failure', () async { + final apiClient = MockApiClient()..shouldFail = true; + final local = MockLocalApi(); + final app = MockFirebaseApp(); + + final cloud = createModelWithClient( + app: app, + location: 'us-central1', + model: 'gemini-pro', + client: apiClient, + useVertexBackend: false, + ); + + final model = HybridGenerativeModel(cloudModel: cloud, localApi: local, mode: InferenceMode.preferCloud); + + final response = await model.generateContent([Content.text('hello')]); + expect(response.text, 'Local Response'); + }); + + test('preferCloud streaming succeeds on cloud', () async { + final apiClient = MockApiClient(); + final local = MockLocalApi(); + final app = MockFirebaseApp(); + + final cloud = createModelWithClient( + app: app, + location: 'us-central1', + model: 'gemini-pro', + client: apiClient, + useVertexBackend: false, + ); + + final model = HybridGenerativeModel(cloudModel: cloud, localApi: local, mode: InferenceMode.preferCloud); + + final responses = model.generateContentStream([Content.text('hello')]); + final textList = await responses.map((r) => 
r.text).toList(); + expect(textList, ['Cloud Response']); + }); + + test('preferCloud streaming falls back to local on cloud failure before data', () async { + final apiClient = MockApiClient()..shouldFail = true; + final local = MockLocalApi(); + final app = MockFirebaseApp(); + + final cloud = createModelWithClient( + app: app, + location: 'us-central1', + model: 'gemini-pro', + client: apiClient, + useVertexBackend: false, + ); + + final mockLocalStream = Stream.fromIterable([ + GenerateContentResponse([ + Candidate(Content('model', [const TextPart('Local Response')]), null, null, null, null) + ], null) + ]); + + final model = TestHybridGenerativeModel( + cloudModel: cloud, + localApi: local, + mode: InferenceMode.preferCloud, + mockLocalStream: mockLocalStream, + ); + + final responses = model.generateContentStream([Content.text('hello')]); + final textList = await responses.map((r) => r.text).toList(); + expect(textList, ['Local Response']); + }); +} + +class TestHybridGenerativeModel extends HybridGenerativeModel { + TestHybridGenerativeModel({ + required super.cloudModel, + required super.localApi, + super.mode, + this.mockLocalStream, + }); + + Stream? mockLocalStream; + + @override + Stream generateLocalStream(Iterable prompt) { + if (mockLocalStream != null) return mockLocalStream!; + return super.generateLocalStream(prompt); + } +}