diff --git a/docs/ai/usage/index.md b/docs/ai/usage/index.md
new file mode 100644
index 0000000000..d80d7bb5ba
--- /dev/null
+++ b/docs/ai/usage/index.md
@@ -0,0 +1,440 @@
+---
+title: AI Logic
+description: Installation and getting started with Firebase AI Logic.
+icon: //static.invertase.io/assets/social/firebase-logo.png
+next: /analytics/usage
+previous: /remote-config/usage
+---
+
+# Installation
+
+This module requires that the `@react-native-firebase/app` module is already set up and installed. To install the "app" module, view the
+[Getting Started](/) documentation.
+
+```bash
+# Install & setup the app module
+yarn add @react-native-firebase/app
+
+# Install the ai module
+yarn add @react-native-firebase/ai
+```
+
+# What does it do
+
+Firebase AI Logic gives you access to the latest generative AI models from Google.
+
+If you need to call the Gemini API directly from your mobile or web app — rather than server-side — you can use the Firebase AI Logic client SDKs. These client SDKs are built specifically for use with mobile and web apps, offering security options against unauthorized clients as well as integrations with other Firebase services.
+
+# Usage
+
+## Generate text from text-only input
+
+You can call the Gemini API with input that includes only text. For these calls, you need to use a model that supports text-only prompts (like Gemini 1.5 Pro).
+
+Use `generateContent()`, which waits for the entire response before returning.
+
+```js
+import React from 'react';
+import { AppRegistry, Button, Text, View } from 'react-native';
+import { getApp } from '@react-native-firebase/app';
+import { getAI, getGenerativeModel } from '@react-native-firebase/ai';
+
+function App() {
+  return (
+    <View>
+      <Button
+        title="Generate text"
+        onPress={async () => {
+ const app = getApp();
+ const ai = getAI(app);
+ const model = getGenerativeModel(ai, { model: 'gemini-1.5-flash' });
+
+ const result = await model.generateContent('What is 2 + 2?');
+
+ console.log(result.response.text());
+        }}
+      />
+    </View>
+  );
+}
+```
+
+Use `generateContentStream()` if you wish to stream the response.
+
+```js
+import React from 'react';
+import { AppRegistry, Button, Text, View } from 'react-native';
+import { getApp } from '@react-native-firebase/app';
+import { getAI, getGenerativeModel } from '@react-native-firebase/ai';
+
+function App() {
+  return (
+    <View>
+      <Button
+        title="Stream text"
+        onPress={async () => {
+ const app = getApp();
+ const ai = getAI(app);
+ const model = getGenerativeModel(ai, { model: 'gemini-1.5-flash' });
+
+ const result = await model.generateContentStream('Write a short poem');
+
+ let text = '';
+ for await (const chunk of result.stream) {
+ const chunkText = chunk.text();
+ text += chunkText;
+ }
+
+ console.log(text);
+        }}
+      />
+    </View>
+  );
+}
+```
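+
+Alongside `stream`, the result also exposes an aggregated `response` promise (see the `GenerateContentStreamResult` type) that resolves once streaming completes. A minimal sketch, assuming the `model` from the example above:
+
+```js
+const result = await model.generateContentStream('Write a short poem');
+
+for await (const chunk of result.stream) {
+  console.log(chunk.text()); // render each chunk as it arrives
+}
+
+// `response` resolves with the complete aggregated response once the stream ends.
+const finalResponse = await result.response;
+console.log(finalResponse.text());
+```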
+
+## Generate text from multi-modal input
+
+You can pass in different input types to generate text responses. **Important**: React Native does not have built-in support for the `Blob` and `Buffer` types, which are often used to supply multi-modal inputs, so you may need a third-party library for this functionality.
+
+```js
+import React from 'react';
+import { AppRegistry, Button, Text, View } from 'react-native';
+import { getApp } from '@react-native-firebase/app';
+import { getAI, getGenerativeModel } from '@react-native-firebase/ai';
+
+function App() {
+  return (
+    <View>
+      <Button
+        title="Describe image"
+        onPress={async () => {
+ const app = getApp();
+ const ai = getAI(app);
+ const model = getGenerativeModel(ai, { model: 'gemini-1.5-flash' });
+ const prompt = 'What can you see?';
+ const base64Emoji =
+ 'iVBORw0KGgoAAAANSUhEUgAAABgAAAAYCAYAAADgdz34AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAAApgAAAKYB3X3/OAAAABl0RVh0U29mdHdhcmUAd3d3Lmlua3NjYXBlLm9yZ5vuPBoAAANCSURBVEiJtZZPbBtFFMZ/M7ubXdtdb1xSFyeilBapySVU8h8OoFaooFSqiihIVIpQBKci6KEg9Q6H9kovIHoCIVQJJCKE1ENFjnAgcaSGC6rEnxBwA04Tx43t2FnvDAfjkNibxgHxnWb2e/u992bee7tCa00YFsffekFY+nUzFtjW0LrvjRXrCDIAaPLlW0nHL0SsZtVoaF98mLrx3pdhOqLtYPHChahZcYYO7KvPFxvRl5XPp1sN3adWiD1ZAqD6XYK1b/dvE5IWryTt2udLFedwc1+9kLp+vbbpoDh+6TklxBeAi9TL0taeWpdmZzQDry0AcO+jQ12RyohqqoYoo8RDwJrU+qXkjWtfi8Xxt58BdQuwQs9qC/afLwCw8tnQbqYAPsgxE1S6F3EAIXux2oQFKm0ihMsOF71dHYx+f3NND68ghCu1YIoePPQN1pGRABkJ6Bus96CutRZMydTl+TvuiRW1m3n0eDl0vRPcEysqdXn+jsQPsrHMquGeXEaY4Yk4wxWcY5V/9scqOMOVUFthatyTy8QyqwZ+kDURKoMWxNKr2EeqVKcTNOajqKoBgOE28U4tdQl5p5bwCw7BWquaZSzAPlwjlithJtp3pTImSqQRrb2Z8PHGigD4RZuNX6JYj6wj7O4TFLbCO/Mn/m8R+h6rYSUb3ekokRY6f/YukArN979jcW+V/S8g0eT/N3VN3kTqWbQ428m9/8k0P/1aIhF36PccEl6EhOcAUCrXKZXXWS3XKd2vc/TRBG9O5ELC17MmWubD2nKhUKZa26Ba2+D3P+4/MNCFwg59oWVeYhkzgN/JDR8deKBoD7Y+ljEjGZ0sosXVTvbc6RHirr2reNy1OXd6pJsQ+gqjk8VWFYmHrwBzW/n+uMPFiRwHB2I7ih8ciHFxIkd/3Omk5tCDV1t+2nNu5sxxpDFNx+huNhVT3/zMDz8usXC3ddaHBj1GHj/As08fwTS7Kt1HBTmyN29vdwAw+/wbwLVOJ3uAD1wi/dUH7Qei66PfyuRj4Ik9is+hglfbkbfR3cnZm7chlUWLdwmprtCohX4HUtlOcQjLYCu+fzGJH2QRKvP3UNz8bWk1qMxjGTOMThZ3kvgLI5AzFfo379UAAAAASUVORK5CYII=';
+
+ const response = await model.generateContentStream([
+ prompt,
+ { inlineData: { mimeType: 'image/png', data: base64Emoji } },
+ ]);
+
+ let text = '';
+ for await (const chunk of response.stream) {
+ text += chunk.text();
+ }
+
+ console.log(text);
+        }}
+      />
+    </View>
+  );
+}
+```
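+
+The example above inlines a hardcoded base64 string. To read real image data from the device, one option is a third-party file-system library. A sketch assuming `react-native-fs` (any library that can read a file as base64 works):
+
+```js
+import RNFS from 'react-native-fs';
+
+// Read a local image file into a base64 string and wrap it as inline data.
+async function imageFileToInlinePart(path, mimeType) {
+  const data = await RNFS.readFile(path, 'base64');
+  return { inlineData: { mimeType, data } };
+}
+```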
+
+## Generate structured output (e.g. JSON)
+
+The Firebase AI Logic SDK returns responses as unstructured text by default. However, some use cases require structured text, like JSON. For example, you might be using the response for other downstream tasks that require an established data schema.
+
+```js
+import React from 'react';
+import { AppRegistry, Button, Text, View } from 'react-native';
+import { getApp } from '@react-native-firebase/app';
+import { getAI, getGenerativeModel, Schema } from '@react-native-firebase/ai';
+
+function App() {
+  return (
+    <View>
+      <Button
+        title="Generate JSON"
+        onPress={async () => {
+ const app = getApp();
+ const ai = getAI(app);
+ const jsonSchema = Schema.object({
+ properties: {
+ characters: Schema.array({
+ items: Schema.object({
+ properties: {
+ name: Schema.string(),
+ accessory: Schema.string(),
+ age: Schema.number(),
+ species: Schema.string(),
+ },
+ optionalProperties: ['accessory'],
+ }),
+ }),
+ },
+ });
+ const model = getGenerativeModel(ai, {
+ model: 'gemini-1.5-flash',
+ generationConfig: {
+ responseMimeType: 'application/json',
+ responseSchema: jsonSchema,
+ },
+ });
+
+ let prompt = "For use in a children's card game, generate 10 animal-based characters.";
+
+ let result = await model.generateContent(prompt);
+ console.log(result.response.text());
+        }}
+      />
+    </View>
+  );
+}
+```
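+
+Because `responseMimeType` is set to `application/json`, the response text can be parsed directly. A small sketch, assuming the `result` from the example above:
+
+```js
+const { characters } = JSON.parse(result.response.text());
+for (const character of characters) {
+  // `accessory` is optional per the schema above, so it may be undefined.
+  console.log(`${character.name} the ${character.species}, age ${character.age}`);
+}
+```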
+
+## Multi-turn conversations
+
+You can build freeform conversations across multiple turns. The Firebase AI Logic SDK simplifies the process by managing the state of the conversation, so unlike with `generateContentStream()` or `generateContent()`, you don't have to store the conversation history yourself.
+
+```js
+import React from 'react';
+import { AppRegistry, Button, Text, View } from 'react-native';
+import { getApp } from '@react-native-firebase/app';
+import { getAI, getGenerativeModel } from '@react-native-firebase/ai';
+
+function App() {
+  return (
+    <View>
+      <Button
+        title="Send chat message"
+        onPress={async () => {
+ const app = getApp();
+ const ai = getAI(app);
+ const model = getGenerativeModel(ai, { model: 'gemini-1.5-flash' });
+
+ const chat = model.startChat({
+ history: [
+ {
+ role: 'user',
+ parts: [{ text: 'Hello, I have 2 dogs in my house.' }],
+ },
+ {
+ role: 'model',
+ parts: [{ text: 'Great to meet you. What would you like to know?' }],
+ },
+ ],
+ generationConfig: {
+ maxOutputTokens: 100,
+ },
+ });
+
+ const msg = 'How many paws are in my house?';
+ const result = await chat.sendMessageStream(msg);
+
+ let text = '';
+ for await (const chunk of result.stream) {
+ const chunkText = chunk.text();
+ text += chunkText;
+ }
+ console.log(text);
+
+ // When you want to see the history of the chat
+ const history = await chat.getHistory();
+ console.log(history);
+        }}
+      />
+    </View>
+  );
+}
+```
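+
+If you don't need streaming, `chat.sendMessage()` (used in the function-calling example below) returns the full response in a single call:
+
+```js
+const result = await chat.sendMessage('How many paws are in my house?');
+console.log(result.response.text());
+```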
+
+## Function calling
+
+Generative models are powerful at solving many types of problems. However, they are constrained by limitations like:
+
+- They are frozen after training, leading to stale knowledge.
+- They can't query or modify external data.
+
+Function calling can help you overcome some of these limitations. Function calling is sometimes referred to as tool use because it allows a model to use external tools such as APIs and functions to generate its final response.
+
+```js
+import React from 'react';
+import { AppRegistry, Button, Text, View } from 'react-native';
+import { getApp } from '@react-native-firebase/app';
+import { getAI, getGenerativeModel, Schema } from '@react-native-firebase/ai';
+
+function App() {
+  return (
+    <View>
+      <Button
+        title="Function calling"
+        onPress={async () => {
+ // This function calls a hypothetical external API that returns
+ // a collection of weather information for a given location on a given date.
+ // `location` is an object of the form { city: string, state: string }
+ async function fetchWeather({ location, date }) {
+ // For demo purposes, this hypothetical response is hardcoded here in the expected format.
+ return {
+ temperature: 38,
+ chancePrecipitation: '56%',
+ cloudConditions: 'partlyCloudy',
+ };
+ }
+ const fetchWeatherTool = {
+ functionDeclarations: [
+ {
+ name: 'fetchWeather',
+ description: 'Get the weather conditions for a specific city on a specific date',
+ parameters: Schema.object({
+ properties: {
+ location: Schema.object({
+ description:
+ 'The name of the city and its state for which to get ' +
+ 'the weather. Only cities in the USA are supported.',
+ properties: {
+ city: Schema.string({
+ description: 'The city of the location.',
+ }),
+ state: Schema.string({
+ description: 'The US state of the location.',
+ }),
+ },
+ }),
+ date: Schema.string({
+ description:
+ 'The date for which to get the weather. Date must be in the' +
+ ' format: YYYY-MM-DD.',
+ }),
+ },
+ }),
+ },
+ ],
+ };
+ const app = getApp();
+ const ai = getAI(app);
+ const model = getGenerativeModel(ai, {
+ model: 'gemini-1.5-flash',
+ tools: fetchWeatherTool,
+ });
+
+ const chat = model.startChat();
+ const prompt = 'What was the weather in Boston on October 17, 2024?';
+
+ // Send the user's question (the prompt) to the model using multi-turn chat.
+ let result = await chat.sendMessage(prompt);
+ const functionCalls = result.response.functionCalls();
+ let functionCall;
+ let functionResult;
+ // When the model responds with one or more function calls, invoke the function(s).
+ if (functionCalls.length > 0) {
+ for (const call of functionCalls) {
+ if (call.name === 'fetchWeather') {
+ // Forward the structured input data prepared by the model
+ // to the hypothetical external API.
+ functionResult = await fetchWeather(call.args);
+ functionCall = call;
+ }
+ }
+ }
+ result = await chat.sendMessage([
+ {
+ functionResponse: {
+ name: functionCall.name, // "fetchWeather"
+ response: functionResult,
+ },
+ },
+ ]);
+ console.log(result.response.text());
+        }}
+      />
+    </View>
+  );
+}
+```
+
+## Count tokens & billable characters
+
+Generative AI models break down data into units called tokens for processing. Each Gemini model has a [maximum number of tokens](https://firebase.google.com/docs/ai-logic/models) that it can handle in a prompt and response.
+
+The example below shows how to get an estimate of the token count and the number of billable characters for a request.
+
+```js
+import React from 'react';
+import { AppRegistry, Button, Text, View } from 'react-native';
+import { getApp } from '@react-native-firebase/app';
+import { getAI, getGenerativeModel } from '@react-native-firebase/ai';
+
+function App() {
+  return (
+    <View>
+      <Button
+        title="Count tokens"
+        onPress={async () => {
+ const app = getApp();
+ const ai = getAI(app);
+ const model = getGenerativeModel(ai, { model: 'gemini-1.5-flash' });
+          // Count tokens & billable characters for text input
+ const { totalTokens, totalBillableCharacters } = await model.countTokens(
+ 'Write a story about a magic backpack.',
+ );
+ console.log(
+ `Total tokens: ${totalTokens}, total billable characters: ${totalBillableCharacters}`,
+ );
+
+          // Count tokens & billable characters for multi-modal input
+ const prompt = "What's in this picture?";
+ const imageAsBase64 = '...base64 string image';
+ const imagePart = { inlineData: { mimeType: 'image/jpeg', data: imageAsBase64 } };
+
+          const { totalTokens: imageTokens, totalBillableCharacters: imageBillableCharacters } =
+            await model.countTokens([prompt, imagePart]);
+          console.log(
+            `Total tokens: ${imageTokens}, total billable characters: ${imageBillableCharacters}`,
+          );
+        }}
+      />
+    </View>
+  );
+}
+```
+
+## Getting ready for production
+
+For mobile and web apps, you need to protect the Gemini API and your project resources (like tuned models) from abuse by unauthorized clients. You can use Firebase App Check to verify that all API calls are from your actual app. See the [Firebase documentation](https://firebase.google.com/docs/ai-logic/app-check) for further information.
+
+- Ensure you have set up [App Check for React Native Firebase](/app-check/usage/index)
+- Pass an App Check instance to Firebase AI Logic, which will call `appCheck.getToken()` under the hood and attach the token to Firebase AI Logic API requests to the server.
+
+```js
+import React from 'react';
+import { AppRegistry, Button, Text, View } from 'react-native';
+import { getApp } from '@react-native-firebase/app';
+import auth from '@react-native-firebase/auth';
+import appCheck from '@react-native-firebase/app-check';
+import { getAI, getGenerativeModel, GoogleAIBackend } from '@react-native-firebase/ai';
+
+function App() {
+  return (
+    <View>
+      <Button
+        title="Generate with App Check"
+        onPress={async () => {
+ const app = getApp();
+          // You can also pass an Auth instance, which attaches an auth token if a user is signed in.
+ const authInstance = auth(app);
+ const appCheckInstance = appCheck(app);
+          // Configure the App Check instance as described in the App Check docs...
+ const options = {
+ appCheck: appCheckInstance,
+ auth: authInstance,
+ backend: new GoogleAIBackend(),
+ };
+
+ const ai = getAI(app, options);
+ const model = getGenerativeModel(ai, { model: 'gemini-1.5-flash' });
+
+ const result = await model.generateContent('What is 2 + 2?');
+
+ console.log('result', result.response.text());
+        }}
+      />
+    </View>
+  );
+}
+```
diff --git a/eslint.config.mjs b/eslint.config.mjs
index e0df72a64f..bbd12abd45 100644
--- a/eslint.config.mjs
+++ b/eslint.config.mjs
@@ -37,8 +37,9 @@ export default [
'**/app.playground.js',
'**/type-test.ts',
'packages/**/modular/dist/**/*',
- 'packages/vertexai/__tests__/test-utils',
+ 'packages/ai/__tests__/test-utils',
'packages/vertexai/dist',
+ 'packages/ai/dist',
],
},
...compat
diff --git a/package.json b/package.json
index 00756f8d9f..2237071261 100644
--- a/package.json
+++ b/package.json
@@ -18,7 +18,7 @@
"lint:spellcheck": "spellchecker --quiet --files=\"docs/**/*.md\" --dictionaries=\"./.spellcheck.dict.txt\" --reports=\"spelling.json\" --plugins spell indefinite-article repeated-words syntax-mentions syntax-urls frontmatter",
"tsc:compile": "tsc --project .",
"lint:all": "yarn lint && yarn lint:markdown && yarn lint:spellcheck && yarn tsc:compile",
- "tests:vertex:mocks": "./scripts/vertex_mock_responses.sh && yarn ts-node ./packages/vertexai/__tests__/test-utils/convert-mocks.ts",
+ "tests:ai:mocks": "./scripts/ai_mock_responses.sh && yarn ts-node ./packages/ai/__tests__/test-utils/convert-mocks.ts",
"tests:jest": "jest",
"tests:jest-watch": "jest --watch",
"tests:jest-coverage": "jest --coverage",
diff --git a/packages/ai/CHANGELOG.md b/packages/ai/CHANGELOG.md
new file mode 100644
index 0000000000..b47a3c0062
--- /dev/null
+++ b/packages/ai/CHANGELOG.md
@@ -0,0 +1,12 @@
+# Change Log
+
+All notable changes to this project will be documented in this file.
+See [Conventional Commits](https://conventionalcommits.org) for commit guidelines.
+
+## Feature
+
+Initial release of the Firebase AI Logic SDK (`FirebaseAI`). This SDK *replaces* the previous Vertex AI in Firebase SDK (`FirebaseVertexAI`) to accommodate the evolving set of supported features and services.
+The new Firebase AI Logic SDK provides **preview** support for the Gemini Developer API, including its free tier offering.
+Using the Firebase AI Logic SDK with the Vertex AI Gemini API is still generally available (GA).
+
+To start using the new SDK, import the `@react-native-firebase/ai` package and use the modular method `getAI()` to initialize. See details in the [migration guide](https://firebase.google.com/docs/vertex-ai/migrate-to-latest-sdk).
\ No newline at end of file
diff --git a/packages/ai/LICENSE b/packages/ai/LICENSE
new file mode 100644
index 0000000000..ef3ed44f06
--- /dev/null
+++ b/packages/ai/LICENSE
@@ -0,0 +1,32 @@
+Apache-2.0 License
+------------------
+
+Copyright (c) 2016-present Invertase Limited & Contributors
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this library except in compliance with the License.
+
+You may obtain a copy of the Apache-2.0 License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+
+
+Creative Commons Attribution 3.0 License
+----------------------------------------
+
+Copyright (c) 2016-present Invertase Limited & Contributors
+
+Documentation and other instructional materials provided for this project
+(including on a separate documentation repository or its documentation website) are
+licensed under the Creative Commons Attribution 3.0 License. Code samples/blocks
+contained therein are licensed under the Apache License, Version 2.0 (the "License"), as above.
+
+You may obtain a copy of the Creative Commons Attribution 3.0 License at
+
+ https://creativecommons.org/licenses/by/3.0/
diff --git a/packages/ai/README.md b/packages/ai/README.md
new file mode 100644
index 0000000000..78adab6a98
--- /dev/null
+++ b/packages/ai/README.md
@@ -0,0 +1,55 @@
+# React Native Firebase - AI Logic
+---
+
+Firebase AI Logic gives you access to the latest generative AI models from Google.
+
+If you need to call the Gemini API directly from your mobile or web app — rather than server-side — you can use the Firebase AI Logic client SDKs. These client SDKs are built specifically for use with mobile and web apps, offering security options against unauthorized clients as well as integrations with other Firebase services.
+
+[> Learn More](https://firebase.google.com/docs/ai-logic/)
+
+## Installation
+
+Requires `@react-native-firebase/app` to be installed.
+
+```bash
+yarn add @react-native-firebase/ai
+```
+
+## Documentation
+
+- [Quick Start](https://rnfirebase.io/ai/usage)
+
+## License
+
+- See [LICENSE](/LICENSE)
+
+---
+
+Built and maintained with 💛 by Invertase.
+
+---
diff --git a/packages/ai/__tests__/api.test.ts b/packages/ai/__tests__/api.test.ts
new file mode 100644
index 0000000000..1847089736
--- /dev/null
+++ b/packages/ai/__tests__/api.test.ts
@@ -0,0 +1,112 @@
+/**
+ * @license
+ * Copyright 2024 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+import { describe, expect, it } from '@jest/globals';
+import { type ReactNativeFirebase } from '../../app/lib';
+
+import { ModelParams, AIErrorCode } from '../lib/types';
+import { AIError } from '../lib/errors';
+import { getGenerativeModel } from '../lib/index';
+
+import { AI } from '../lib/public-types';
+import { GenerativeModel } from '../lib/models/generative-model';
+
+import { AI_TYPE } from '../lib/constants';
+import { VertexAIBackend } from '../lib/backend';
+
+const fakeAI: AI = {
+ app: {
+ name: 'DEFAULT',
+ options: {
+ apiKey: 'key',
+ appId: 'appId',
+ projectId: 'my-project',
+ },
+ } as ReactNativeFirebase.FirebaseApp,
+ backend: new VertexAIBackend('us-central1'),
+ location: 'us-central1',
+};
+
+describe('getGenerativeModel()', () => {
+ it('should throw an error if no model is provided', () => {
+ try {
+ getGenerativeModel(fakeAI, {} as ModelParams);
+ } catch (e) {
+ expect((e as AIError).code).toContain(AIErrorCode.NO_MODEL);
+ expect((e as AIError).message).toContain(
+ `AI: Must provide a model name. Example: ` +
+ `getGenerativeModel({ model: 'my-model-name' }) (${AI_TYPE}/${AIErrorCode.NO_MODEL})`,
+ );
+ }
+ });
+
+ it('getGenerativeModel throws if no apiKey is provided', () => {
+ const fakeVertexNoApiKey = {
+ ...fakeAI,
+ app: { options: { projectId: 'my-project', appId: 'my-appid' } },
+ } as AI;
+ try {
+ getGenerativeModel(fakeVertexNoApiKey, { model: 'my-model' });
+ } catch (e) {
+ expect((e as AIError).code).toContain(AIErrorCode.NO_API_KEY);
+ expect((e as AIError).message).toBe(
+ `AI: The "apiKey" field is empty in the local ` +
+ `Firebase config. Firebase AI requires this field to` +
+ ` contain a valid API key. (${AI_TYPE}/${AIErrorCode.NO_API_KEY})`,
+ );
+ }
+ });
+
+ it('should throw an error if no projectId is provided', () => {
+ const fakeVertexNoProject = {
+ ...fakeAI,
+ app: { options: { apiKey: 'my-key' } },
+ } as AI;
+ try {
+ getGenerativeModel(fakeVertexNoProject, { model: 'my-model' });
+ } catch (e) {
+ expect((e as AIError).code).toContain(AIErrorCode.NO_PROJECT_ID);
+ expect((e as AIError).message).toBe(
+ `AI: The "projectId" field is empty in the local` +
+ ` Firebase config. Firebase AI requires this field ` +
+ `to contain a valid project ID. (${AI_TYPE}/${AIErrorCode.NO_PROJECT_ID})`,
+ );
+ }
+ });
+
+ it('should throw an error if no appId is provided', () => {
+ const fakeVertexNoProject = {
+ ...fakeAI,
+ app: { options: { apiKey: 'my-key', projectId: 'my-projectid' } },
+ } as AI;
+ try {
+ getGenerativeModel(fakeVertexNoProject, { model: 'my-model' });
+ } catch (e) {
+ expect((e as AIError).code).toContain(AIErrorCode.NO_APP_ID);
+ expect((e as AIError).message).toBe(
+ `AI: The "appId" field is empty in the local` +
+ ` Firebase config. Firebase AI requires this field ` +
+ `to contain a valid app ID. (${AI_TYPE}/${AIErrorCode.NO_APP_ID})`,
+ );
+ }
+ });
+
+ it('should return an instance of GenerativeModel', () => {
+ const genModel = getGenerativeModel(fakeAI, { model: 'my-model' });
+ expect(genModel).toBeInstanceOf(GenerativeModel);
+ expect(genModel.model).toBe('publishers/google/models/my-model');
+ });
+});
diff --git a/packages/ai/__tests__/backend.test.ts b/packages/ai/__tests__/backend.test.ts
new file mode 100644
index 0000000000..88712f91d2
--- /dev/null
+++ b/packages/ai/__tests__/backend.test.ts
@@ -0,0 +1,55 @@
+/**
+ * @license
+ * Copyright 2024 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+import { describe, it, expect } from '@jest/globals';
+import { GoogleAIBackend, VertexAIBackend } from '../lib/backend';
+import { BackendType } from '../lib/public-types';
+import { DEFAULT_LOCATION } from '../lib/constants';
+
+describe('Backend', () => {
+ describe('GoogleAIBackend', () => {
+ it('should set backendType to GOOGLE_AI', () => {
+ const backend = new GoogleAIBackend();
+ expect(backend.backendType).toBe(BackendType.GOOGLE_AI);
+ });
+ });
+
+ describe('VertexAIBackend', () => {
+ it('should set backendType to VERTEX_AI', () => {
+ const backend = new VertexAIBackend();
+ expect(backend.backendType).toBe(BackendType.VERTEX_AI);
+ expect(backend.location).toBe(DEFAULT_LOCATION);
+ });
+
+ it('should set a custom location', () => {
+ const backend = new VertexAIBackend('test-location');
+ expect(backend.backendType).toBe(BackendType.VERTEX_AI);
+ expect(backend.location).toBe('test-location');
+ });
+
+ it('should use a default location if location is empty string', () => {
+ const backend = new VertexAIBackend('');
+ expect(backend.backendType).toBe(BackendType.VERTEX_AI);
+ expect(backend.location).toBe(DEFAULT_LOCATION);
+ });
+
+ it('uses default location if location is null', () => {
+ const backend = new VertexAIBackend(null as any);
+ expect(backend.backendType).toBe(BackendType.VERTEX_AI);
+ expect(backend.location).toBe(DEFAULT_LOCATION);
+ });
+ });
+});
diff --git a/packages/vertexai/__tests__/chat-session-helpers.test.ts b/packages/ai/__tests__/chat-session-helpers.test.ts
similarity index 100%
rename from packages/vertexai/__tests__/chat-session-helpers.test.ts
rename to packages/ai/__tests__/chat-session-helpers.test.ts
diff --git a/packages/vertexai/__tests__/chat-session.test.ts b/packages/ai/__tests__/chat-session.test.ts
similarity index 97%
rename from packages/vertexai/__tests__/chat-session.test.ts
rename to packages/ai/__tests__/chat-session.test.ts
index cd96aa32e6..10b025c62a 100644
--- a/packages/vertexai/__tests__/chat-session.test.ts
+++ b/packages/ai/__tests__/chat-session.test.ts
@@ -21,11 +21,14 @@ import { GenerateContentStreamResult } from '../lib/types';
import { ChatSession } from '../lib/methods/chat-session';
import { ApiSettings } from '../lib/types/internal';
import { RequestOptions } from '../lib/types/requests';
+import { VertexAIBackend } from '../lib/backend';
const fakeApiSettings: ApiSettings = {
apiKey: 'key',
project: 'my-project',
+ appId: 'my-appid',
location: 'us-central1',
+ backend: new VertexAIBackend(),
};
const requestOptions: RequestOptions = {
diff --git a/packages/ai/__tests__/count-tokens.test.ts b/packages/ai/__tests__/count-tokens.test.ts
new file mode 100644
index 0000000000..7b30261283
--- /dev/null
+++ b/packages/ai/__tests__/count-tokens.test.ts
@@ -0,0 +1,167 @@
+/**
+ * @license
+ * Copyright 2024 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+import { describe, expect, it, afterEach, jest, beforeEach } from '@jest/globals';
+import { BackendName, getMockResponse } from './test-utils/mock-response';
+import * as request from '../lib/requests/request';
+import { countTokens } from '../lib/methods/count-tokens';
+import { CountTokensRequest, RequestOptions } from '../lib/types';
+import { ApiSettings } from '../lib/types/internal';
+import { Task } from '../lib/requests/request';
+import { GoogleAIBackend } from '../lib/backend';
+import { SpiedFunction } from 'jest-mock';
+import { mapCountTokensRequest } from '../lib/googleai-mappers';
+
+const fakeApiSettings: ApiSettings = {
+ apiKey: 'key',
+ project: 'my-project',
+ location: 'us-central1',
+ appId: '',
+ backend: new GoogleAIBackend(),
+};
+
+const fakeGoogleAIApiSettings: ApiSettings = {
+ apiKey: 'key',
+ project: 'my-project',
+ appId: 'my-appid',
+ location: '',
+ backend: new GoogleAIBackend(),
+};
+
+const fakeRequestParams: CountTokensRequest = {
+ contents: [{ parts: [{ text: 'hello' }], role: 'user' }],
+};
+
+describe('countTokens()', () => {
+ afterEach(() => {
+ jest.restoreAllMocks();
+ });
+
+ it('total tokens', async () => {
+ const mockResponse = getMockResponse(BackendName.VertexAI, 'unary-success-total-tokens.json');
+ const makeRequestStub = jest
+ .spyOn(request, 'makeRequest')
+ .mockResolvedValue(mockResponse as Response);
+ const result = await countTokens(fakeApiSettings, 'model', fakeRequestParams);
+ expect(result.totalTokens).toBe(6);
+ expect(result.totalBillableCharacters).toBe(16);
+ expect(makeRequestStub).toHaveBeenCalledWith(
+ 'model',
+ Task.COUNT_TOKENS,
+ fakeApiSettings,
+ false,
+ expect.stringContaining('contents'),
+ undefined,
+ );
+ });
+
+ it('total tokens with modality details', async () => {
+ const mockResponse = getMockResponse(
+ BackendName.VertexAI,
+ 'unary-success-detailed-token-response.json',
+ );
+ const makeRequestStub = jest
+ .spyOn(request, 'makeRequest')
+ .mockResolvedValue(mockResponse as Response);
+ const result = await countTokens(fakeApiSettings, 'model', fakeRequestParams);
+
+ expect(result.totalTokens).toBe(1837);
+ expect(result.totalBillableCharacters).toBe(117);
+ expect(result.promptTokensDetails?.[0]?.modality).toBe('IMAGE');
+ expect(result.promptTokensDetails?.[0]?.tokenCount).toBe(1806);
+ expect(makeRequestStub).toHaveBeenCalledWith(
+ 'model',
+ Task.COUNT_TOKENS,
+ fakeApiSettings,
+ false,
+ expect.stringContaining('contents'),
+ undefined,
+ );
+ });
+
+ it('total tokens no billable characters', async () => {
+ const mockResponse = getMockResponse(
+ BackendName.VertexAI,
+ 'unary-success-no-billable-characters.json',
+ );
+ const makeRequestStub = jest
+ .spyOn(request, 'makeRequest')
+ .mockResolvedValue(mockResponse as Response);
+ const result = await countTokens(fakeApiSettings, 'model', fakeRequestParams);
+ expect(result.totalTokens).toBe(258);
+ expect(result).not.toHaveProperty('totalBillableCharacters');
+ expect(makeRequestStub).toHaveBeenCalledWith(
+ 'model',
+ Task.COUNT_TOKENS,
+ fakeApiSettings,
+ false,
+ expect.stringContaining('contents'),
+ undefined,
+ );
+ });
+
+ it('model not found', async () => {
+ const mockResponse = getMockResponse(
+ BackendName.VertexAI,
+ 'unary-failure-model-not-found.json',
+ );
+ const mockFetch = jest.spyOn(globalThis, 'fetch').mockResolvedValue({
+ ok: false,
+ status: 404,
+ json: mockResponse.json,
+ } as Response);
+ await expect(countTokens(fakeApiSettings, 'model', fakeRequestParams)).rejects.toThrow(
+ /404.*not found/,
+ );
+ expect(mockFetch).toHaveBeenCalled();
+ });
+
+ describe('googleAI', () => {
+ let makeRequestStub: SpiedFunction<
+ (
+ model: string,
+ task: Task,
+ apiSettings: ApiSettings,
+ stream: boolean,
+ body: string,
+ requestOptions?: RequestOptions,
+      ) => Promise<Response>
+ >;
+
+ beforeEach(() => {
+ makeRequestStub = jest.spyOn(request, 'makeRequest');
+ });
+
+ afterEach(() => {
+ jest.restoreAllMocks();
+ });
+
+ it('maps request to GoogleAI format', async () => {
+ makeRequestStub.mockResolvedValue({ ok: true, json: () => {} } as Response);
+
+ await countTokens(fakeGoogleAIApiSettings, 'model', fakeRequestParams);
+
+ expect(makeRequestStub).toHaveBeenCalledWith(
+ 'model',
+ Task.COUNT_TOKENS,
+ fakeGoogleAIApiSettings,
+ false,
+ JSON.stringify(mapCountTokensRequest(fakeRequestParams, 'model')),
+ undefined,
+ );
+ });
+ });
+});
diff --git a/packages/ai/__tests__/exported-types.test.ts b/packages/ai/__tests__/exported-types.test.ts
new file mode 100644
index 0000000000..f93f9d16bc
--- /dev/null
+++ b/packages/ai/__tests__/exported-types.test.ts
@@ -0,0 +1,455 @@
+import { describe, expect, it } from '@jest/globals';
+import {
+ // Runtime values (classes, functions, constants)
+ BackendType,
+ POSSIBLE_ROLES,
+ AIError,
+ GenerativeModel,
+ AIModel,
+ getAI,
+ getGenerativeModel,
+ ChatSession,
+ GoogleAIBackend,
+ VertexAIBackend,
+ // Types that exist - imported for type checking
+ Part,
+ ResponseModality,
+ Role,
+ Tool,
+ TypedSchema,
+ AI,
+ AIOptions,
+ BaseParams,
+ Citation,
+ CitationMetadata,
+ Content,
+ CountTokensRequest,
+ CountTokensResponse,
+ CustomErrorData,
+ EnhancedGenerateContentResponse,
+ ErrorDetails,
+ FileData,
+ FileDataPart,
+ FunctionCall,
+ FunctionCallingConfig,
+ FunctionCallPart,
+ FunctionDeclaration,
+ FunctionDeclarationsTool,
+ FunctionResponse,
+ FunctionResponsePart,
+ GenerateContentCandidate,
+ GenerateContentRequest,
+ GenerateContentResponse,
+ GenerateContentResult,
+ GenerateContentStreamResult,
+ GenerationConfig,
+ GenerativeContentBlob,
+ GroundingAttribution,
+ GroundingMetadata,
+ InlineDataPart,
+ ModalityTokenCount,
+ ModelParams,
+ ObjectSchemaInterface,
+ PromptFeedback,
+ RequestOptions,
+ RetrievedContextAttribution,
+ SafetyRating,
+ SafetySetting,
+ SchemaInterface,
+ SchemaParams,
+ SchemaRequest,
+ SchemaShared,
+ Segment,
+ StartChatParams,
+ TextPart,
+ ToolConfig,
+ UsageMetadata,
+ VideoMetadata,
+ WebAttribution,
+ // Enums
+ AIErrorCode,
+ BlockReason,
+ FinishReason,
+ FunctionCallingMode,
+ HarmBlockMethod,
+ HarmBlockThreshold,
+ HarmCategory,
+ HarmProbability,
+ HarmSeverity,
+ Modality,
+ SchemaType,
+} from '../lib';
+
+describe('AI', function () {
+ describe('modular', function () {
+ // Runtime value exports (constants, classes, functions)
+ it('`BackendType` constant is properly exposed to end user', function () {
+ expect(BackendType).toBeDefined();
+ expect(BackendType.VERTEX_AI).toBeDefined();
+ expect(BackendType.GOOGLE_AI).toBeDefined();
+ });
+
+ it('`POSSIBLE_ROLES` constant is properly exposed to end user', function () {
+ expect(POSSIBLE_ROLES).toBeDefined();
+ });
+
+ it('`AIError` class is properly exposed to end user', function () {
+ expect(AIError).toBeDefined();
+ });
+
+ it('`GenerativeModel` class is properly exposed to end user', function () {
+ expect(GenerativeModel).toBeDefined();
+ });
+
+ it('`AIModel` class is properly exposed to end user', function () {
+ expect(AIModel).toBeDefined();
+ });
+
+ it('`getAI` function is properly exposed to end user', function () {
+ expect(getAI).toBeDefined();
+ });
+
+ it('`getGenerativeModel` function is properly exposed to end user', function () {
+ expect(getGenerativeModel).toBeDefined();
+ });
+
+ it('`ChatSession` class is properly exposed to end user', function () {
+ expect(ChatSession).toBeDefined();
+ });
+
+ it('`GoogleAIBackend` class is properly exposed to end user', function () {
+ expect(GoogleAIBackend).toBeDefined();
+ });
+
+ it('`VertexAIBackend` class is properly exposed to end user', function () {
+ expect(VertexAIBackend).toBeDefined();
+ });
+
+ // Type exports - test that they can be used as types
+ it('`Part` type is properly exposed to end user', function () {
+ const _typeCheck: Part = { text: 'test' };
+ expect(typeof _typeCheck).toBe('object');
+ });
+
+ it('`ResponseModality` type is properly exposed to end user', function () {
+ const _typeCheck: ResponseModality = {} as ResponseModality;
+ expect(typeof _typeCheck).toBeDefined();
+ });
+
+ it('`Role` type is properly exposed to end user', function () {
+ const _typeCheck: Role = 'user';
+ expect(typeof _typeCheck).toBe('string');
+ });
+
+ it('`Tool` type is properly exposed to end user', function () {
+ const _typeCheck: Tool = {} as Tool;
+ expect(typeof _typeCheck).toBeDefined();
+ });
+
+ it('`TypedSchema` type is properly exposed to end user', function () {
+ const _typeCheck: TypedSchema = {} as TypedSchema;
+ expect(typeof _typeCheck).toBeDefined();
+ });
+
+ it('`AI` type is properly exposed to end user', function () {
+ const _typeCheck: AI = {} as AI;
+ expect(typeof _typeCheck).toBeDefined();
+ });
+
+ it('`AIOptions` type is properly exposed to end user', function () {
+ const _typeCheck: AIOptions = {} as AIOptions;
+ expect(typeof _typeCheck).toBeDefined();
+ });
+
+ it('`BaseParams` type is properly exposed to end user', function () {
+ const _typeCheck: BaseParams = {} as BaseParams;
+ expect(typeof _typeCheck).toBeDefined();
+ });
+
+ it('`Citation` type is properly exposed to end user', function () {
+ const _typeCheck: Citation = {} as Citation;
+ expect(typeof _typeCheck).toBeDefined();
+ });
+
+ it('`CitationMetadata` type is properly exposed to end user', function () {
+ const _typeCheck: CitationMetadata = {} as CitationMetadata;
+ expect(typeof _typeCheck).toBeDefined();
+ });
+
+ it('`Content` type is properly exposed to end user', function () {
+ const _typeCheck: Content = {} as Content;
+ expect(typeof _typeCheck).toBeDefined();
+ });
+
+ it('`CountTokensRequest` type is properly exposed to end user', function () {
+ const _typeCheck: CountTokensRequest = {} as CountTokensRequest;
+ expect(typeof _typeCheck).toBeDefined();
+ });
+
+ it('`CountTokensResponse` type is properly exposed to end user', function () {
+ const _typeCheck: CountTokensResponse = {} as CountTokensResponse;
+ expect(typeof _typeCheck).toBeDefined();
+ });
+
+ it('`CustomErrorData` type is properly exposed to end user', function () {
+ const _typeCheck: CustomErrorData = {} as CustomErrorData;
+ expect(typeof _typeCheck).toBeDefined();
+ });
+
+ it('`EnhancedGenerateContentResponse` type is properly exposed to end user', function () {
+ const _typeCheck: EnhancedGenerateContentResponse = {} as EnhancedGenerateContentResponse;
+ expect(typeof _typeCheck).toBeDefined();
+ });
+
+ it('`ErrorDetails` type is properly exposed to end user', function () {
+ const _typeCheck: ErrorDetails = {} as ErrorDetails;
+ expect(typeof _typeCheck).toBeDefined();
+ });
+
+ it('`FileData` type is properly exposed to end user', function () {
+ const _typeCheck: FileData = {} as FileData;
+ expect(typeof _typeCheck).toBeDefined();
+ });
+
+ it('`FileDataPart` type is properly exposed to end user', function () {
+ const _typeCheck: FileDataPart = {} as FileDataPart;
+ expect(typeof _typeCheck).toBeDefined();
+ });
+
+ it('`FunctionCall` type is properly exposed to end user', function () {
+ const _typeCheck: FunctionCall = {} as FunctionCall;
+ expect(typeof _typeCheck).toBeDefined();
+ });
+
+ it('`FunctionCallingConfig` type is properly exposed to end user', function () {
+ const _typeCheck: FunctionCallingConfig = {} as FunctionCallingConfig;
+ expect(typeof _typeCheck).toBeDefined();
+ });
+
+ it('`FunctionCallPart` type is properly exposed to end user', function () {
+ const _typeCheck: FunctionCallPart = {} as FunctionCallPart;
+ expect(typeof _typeCheck).toBeDefined();
+ });
+
+ it('`FunctionDeclaration` type is properly exposed to end user', function () {
+ const _typeCheck: FunctionDeclaration = {} as FunctionDeclaration;
+ expect(typeof _typeCheck).toBeDefined();
+ });
+
+ it('`FunctionDeclarationsTool` type is properly exposed to end user', function () {
+ const _typeCheck: FunctionDeclarationsTool = {} as FunctionDeclarationsTool;
+ expect(typeof _typeCheck).toBeDefined();
+ });
+
+ it('`FunctionResponse` type is properly exposed to end user', function () {
+ const _typeCheck: FunctionResponse = {} as FunctionResponse;
+ expect(typeof _typeCheck).toBeDefined();
+ });
+
+ it('`FunctionResponsePart` type is properly exposed to end user', function () {
+ const _typeCheck: FunctionResponsePart = {} as FunctionResponsePart;
+ expect(typeof _typeCheck).toBeDefined();
+ });
+
+ it('`GenerateContentCandidate` type is properly exposed to end user', function () {
+ const _typeCheck: GenerateContentCandidate = {} as GenerateContentCandidate;
+ expect(typeof _typeCheck).toBeDefined();
+ });
+
+ it('`GenerateContentRequest` type is properly exposed to end user', function () {
+ const _typeCheck: GenerateContentRequest = {} as GenerateContentRequest;
+ expect(typeof _typeCheck).toBeDefined();
+ });
+
+ it('`GenerateContentResponse` type is properly exposed to end user', function () {
+ const _typeCheck: GenerateContentResponse = {} as GenerateContentResponse;
+ expect(typeof _typeCheck).toBeDefined();
+ });
+
+ it('`GenerateContentResult` type is properly exposed to end user', function () {
+ const _typeCheck: GenerateContentResult = {} as GenerateContentResult;
+ expect(typeof _typeCheck).toBeDefined();
+ });
+
+ it('`GenerateContentStreamResult` type is properly exposed to end user', function () {
+ const _typeCheck: GenerateContentStreamResult = {} as GenerateContentStreamResult;
+ expect(typeof _typeCheck).toBeDefined();
+ });
+
+ it('`GenerationConfig` type is properly exposed to end user', function () {
+ const _typeCheck: GenerationConfig = {} as GenerationConfig;
+ expect(typeof _typeCheck).toBeDefined();
+ });
+
+ it('`GenerativeContentBlob` type is properly exposed to end user', function () {
+ const _typeCheck: GenerativeContentBlob = {} as GenerativeContentBlob;
+ expect(typeof _typeCheck).toBeDefined();
+ });
+
+ it('`GroundingAttribution` type is properly exposed to end user', function () {
+ const _typeCheck: GroundingAttribution = {} as GroundingAttribution;
+ expect(typeof _typeCheck).toBeDefined();
+ });
+
+ it('`GroundingMetadata` type is properly exposed to end user', function () {
+ const _typeCheck: GroundingMetadata = {} as GroundingMetadata;
+ expect(typeof _typeCheck).toBeDefined();
+ });
+
+ it('`InlineDataPart` type is properly exposed to end user', function () {
+ const _typeCheck: InlineDataPart = {} as InlineDataPart;
+ expect(typeof _typeCheck).toBeDefined();
+ });
+
+ it('`ModalityTokenCount` type is properly exposed to end user', function () {
+ const _typeCheck: ModalityTokenCount = {} as ModalityTokenCount;
+ expect(typeof _typeCheck).toBeDefined();
+ });
+
+ it('`ModelParams` type is properly exposed to end user', function () {
+ const _typeCheck: ModelParams = {} as ModelParams;
+ expect(typeof _typeCheck).toBeDefined();
+ });
+
+ it('`ObjectSchemaInterface` type is properly exposed to end user', function () {
+ const _typeCheck: ObjectSchemaInterface = {} as ObjectSchemaInterface;
+ expect(typeof _typeCheck).toBeDefined();
+ });
+
+ it('`PromptFeedback` type is properly exposed to end user', function () {
+ const _typeCheck: PromptFeedback = {} as PromptFeedback;
+ expect(typeof _typeCheck).toBeDefined();
+ });
+
+ it('`RequestOptions` type is properly exposed to end user', function () {
+ const _typeCheck: RequestOptions = {} as RequestOptions;
+ expect(typeof _typeCheck).toBeDefined();
+ });
+
+ it('`RetrievedContextAttribution` type is properly exposed to end user', function () {
+ const _typeCheck: RetrievedContextAttribution = {} as RetrievedContextAttribution;
+ expect(typeof _typeCheck).toBeDefined();
+ });
+
+ it('`SafetyRating` type is properly exposed to end user', function () {
+ const _typeCheck: SafetyRating = {} as SafetyRating;
+ expect(typeof _typeCheck).toBeDefined();
+ });
+
+ it('`SafetySetting` type is properly exposed to end user', function () {
+ const _typeCheck: SafetySetting = {} as SafetySetting;
+ expect(typeof _typeCheck).toBeDefined();
+ });
+
+ it('`SchemaInterface` type is properly exposed to end user', function () {
+ const _typeCheck: SchemaInterface = {} as SchemaInterface;
+ expect(typeof _typeCheck).toBeDefined();
+ });
+
+ it('`SchemaParams` type is properly exposed to end user', function () {
+ const _typeCheck: SchemaParams = {} as SchemaParams;
+ expect(typeof _typeCheck).toBeDefined();
+ });
+
+ it('`SchemaRequest` type is properly exposed to end user', function () {
+ const _typeCheck: SchemaRequest = {} as SchemaRequest;
+ expect(typeof _typeCheck).toBeDefined();
+ });
+
+ it('`SchemaShared` type is properly exposed to end user', function () {
+ const _typeCheck: SchemaShared = {} as SchemaShared;
+ expect(typeof _typeCheck).toBeDefined();
+ });
+
+ it('`Segment` type is properly exposed to end user', function () {
+ const _typeCheck: Segment = {} as Segment;
+ expect(typeof _typeCheck).toBeDefined();
+ });
+
+ it('`StartChatParams` type is properly exposed to end user', function () {
+ const _typeCheck: StartChatParams = {} as StartChatParams;
+ expect(typeof _typeCheck).toBeDefined();
+ });
+
+ it('`TextPart` type is properly exposed to end user', function () {
+ const _typeCheck: TextPart = {} as TextPart;
+ expect(typeof _typeCheck).toBeDefined();
+ });
+
+ it('`ToolConfig` type is properly exposed to end user', function () {
+ const _typeCheck: ToolConfig = {} as ToolConfig;
+ expect(typeof _typeCheck).toBeDefined();
+ });
+
+ it('`UsageMetadata` type is properly exposed to end user', function () {
+ const _typeCheck: UsageMetadata = {} as UsageMetadata;
+ expect(typeof _typeCheck).toBeDefined();
+ });
+
+ it('`VideoMetadata` type is properly exposed to end user', function () {
+ const _typeCheck: VideoMetadata = {} as VideoMetadata;
+ expect(typeof _typeCheck).toBeDefined();
+ });
+
+ it('`WebAttribution` type is properly exposed to end user', function () {
+ const _typeCheck: WebAttribution = {} as WebAttribution;
+ expect(typeof _typeCheck).toBeDefined();
+ });
+
+ // Enum exports - test as values since enums are runtime objects
+ it('`AIErrorCode` enum is properly exposed to end user', function () {
+ // Const enum - test by accessing a property
+ expect(AIErrorCode.NO_MODEL).toBeDefined();
+ });
+
+ it('`BlockReason` enum is properly exposed to end user', function () {
+ expect(typeof BlockReason).toBe('object');
+ expect(Object.keys(BlockReason).length).toBeGreaterThan(0);
+ });
+
+ it('`FinishReason` enum is properly exposed to end user', function () {
+ expect(typeof FinishReason).toBe('object');
+ expect(Object.keys(FinishReason).length).toBeGreaterThan(0);
+ });
+
+ it('`FunctionCallingMode` enum is properly exposed to end user', function () {
+ expect(typeof FunctionCallingMode).toBe('object');
+ expect(Object.keys(FunctionCallingMode).length).toBeGreaterThan(0);
+ });
+
+ it('`HarmBlockMethod` enum is properly exposed to end user', function () {
+ expect(typeof HarmBlockMethod).toBe('object');
+ expect(Object.keys(HarmBlockMethod).length).toBeGreaterThan(0);
+ });
+
+ it('`HarmBlockThreshold` enum is properly exposed to end user', function () {
+ expect(typeof HarmBlockThreshold).toBe('object');
+ expect(Object.keys(HarmBlockThreshold).length).toBeGreaterThan(0);
+ });
+
+ it('`HarmCategory` enum is properly exposed to end user', function () {
+ expect(typeof HarmCategory).toBe('object');
+ expect(Object.keys(HarmCategory).length).toBeGreaterThan(0);
+ });
+
+ it('`HarmProbability` enum is properly exposed to end user', function () {
+ expect(typeof HarmProbability).toBe('object');
+ expect(Object.keys(HarmProbability).length).toBeGreaterThan(0);
+ });
+
+ it('`HarmSeverity` enum is properly exposed to end user', function () {
+ expect(typeof HarmSeverity).toBe('object');
+ expect(Object.keys(HarmSeverity).length).toBeGreaterThan(0);
+ });
+
+ it('`Modality` enum is properly exposed to end user', function () {
+ expect(typeof Modality).toBe('object');
+ expect(Object.keys(Modality).length).toBeGreaterThan(0);
+ });
+
+ it('`SchemaType` enum is properly exposed to end user', function () {
+ // Const enum - test by accessing a property
+ expect(SchemaType.STRING).toBeDefined();
+ });
+ });
+});
diff --git a/packages/vertexai/__tests__/generate-content.test.ts b/packages/ai/__tests__/generate-content.test.ts
similarity index 55%
rename from packages/vertexai/__tests__/generate-content.test.ts
rename to packages/ai/__tests__/generate-content.test.ts
index 3bc733e370..48c2b19970 100644
--- a/packages/vertexai/__tests__/generate-content.test.ts
+++ b/packages/ai/__tests__/generate-content.test.ts
@@ -14,8 +14,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-import { describe, expect, it, afterEach, jest } from '@jest/globals';
-import { getMockResponse } from './test-utils/mock-response';
+import { describe, expect, it, afterEach, jest, beforeEach } from '@jest/globals';
+import { BackendName, getMockResponse } from './test-utils/mock-response';
import * as request from '../lib/requests/request';
import { generateContent } from '../lib/methods/generate-content';
import {
@@ -27,11 +27,25 @@ import {
} from '../lib/types';
import { ApiSettings } from '../lib/types/internal';
import { Task } from '../lib/requests/request';
+import { GoogleAIBackend, VertexAIBackend } from '../lib/backend';
+import { SpiedFunction } from 'jest-mock';
+import { AIError } from '../lib/errors';
+import { mapGenerateContentRequest } from '../lib/googleai-mappers';
const fakeApiSettings: ApiSettings = {
apiKey: 'key',
project: 'my-project',
+ appId: 'my-appid',
location: 'us-central1',
+ backend: new VertexAIBackend(),
+};
+
+const fakeGoogleAIApiSettings: ApiSettings = {
+ apiKey: 'key',
+ project: 'my-project',
+ appId: 'my-appid',
+ location: 'us-central1',
+ backend: new GoogleAIBackend(),
};
const fakeRequestParams: GenerateContentRequest = {
@@ -48,13 +62,29 @@ const fakeRequestParams: GenerateContentRequest = {
],
};
+const fakeGoogleAIRequestParams: GenerateContentRequest = {
+ contents: [{ parts: [{ text: 'hello' }], role: 'user' }],
+ generationConfig: {
+ topK: 16,
+ },
+ safetySettings: [
+ {
+ category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
+ threshold: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
+ },
+ ],
+};
+
describe('generateContent()', () => {
afterEach(() => {
jest.restoreAllMocks();
});
it('short response', async () => {
- const mockResponse = getMockResponse('unary-success-basic-reply-short.json');
+ const mockResponse = getMockResponse(
+ BackendName.VertexAI,
+ 'unary-success-basic-reply-short.json',
+ );
const makeRequestStub = jest
.spyOn(request, 'makeRequest')
.mockResolvedValue(mockResponse as Response);
@@ -71,7 +101,10 @@ describe('generateContent()', () => {
});
it('long response', async () => {
- const mockResponse = getMockResponse('unary-success-basic-reply-long.json');
+ const mockResponse = getMockResponse(
+ BackendName.VertexAI,
+ 'unary-success-basic-reply-long.json',
+ );
const makeRequestStub = jest
.spyOn(request, 'makeRequest')
.mockResolvedValue(mockResponse as Response);
@@ -88,8 +121,33 @@ describe('generateContent()', () => {
);
});
+ it('long response with token details', async () => {
+ const mockResponse = getMockResponse(
+ BackendName.VertexAI,
+ 'unary-success-basic-response-long-usage-metadata.json',
+ );
+ const makeRequestStub = jest
+ .spyOn(request, 'makeRequest')
+ .mockResolvedValue(mockResponse as Response);
+ const result = await generateContent(fakeApiSettings, 'model', fakeRequestParams);
+ expect(result.response.usageMetadata?.totalTokenCount).toEqual(1913);
+ expect(result.response.usageMetadata?.candidatesTokenCount).toEqual(76);
+ expect(result.response.usageMetadata?.promptTokensDetails?.[0]?.modality).toEqual('IMAGE');
+ expect(result.response.usageMetadata?.promptTokensDetails?.[0]?.tokenCount).toEqual(1806);
+ expect(result.response.usageMetadata?.candidatesTokensDetails?.[0]?.modality).toEqual('TEXT');
+ expect(result.response.usageMetadata?.candidatesTokensDetails?.[0]?.tokenCount).toEqual(76);
+ expect(makeRequestStub).toHaveBeenCalledWith(
+ 'model',
+ Task.GENERATE_CONTENT,
+ fakeApiSettings,
+ false,
+ expect.anything(),
+ undefined,
+ );
+ });
+
it('citations', async () => {
- const mockResponse = getMockResponse('unary-success-citations.json');
+ const mockResponse = getMockResponse(BackendName.VertexAI, 'unary-success-citations.json');
const makeRequestStub = jest
.spyOn(request, 'makeRequest')
.mockResolvedValue(mockResponse as Response);
@@ -107,7 +165,10 @@ describe('generateContent()', () => {
});
it('blocked prompt', async () => {
- const mockResponse = getMockResponse('unary-failure-prompt-blocked-safety.json');
+ const mockResponse = getMockResponse(
+ BackendName.VertexAI,
+ 'unary-failure-prompt-blocked-safety.json',
+ );
const makeRequestStub = jest
.spyOn(request, 'makeRequest')
.mockResolvedValue(mockResponse as Response);
@@ -126,7 +187,10 @@ describe('generateContent()', () => {
});
it('finishReason safety', async () => {
- const mockResponse = getMockResponse('unary-failure-finish-reason-safety.json');
+ const mockResponse = getMockResponse(
+ BackendName.VertexAI,
+ 'unary-failure-finish-reason-safety.json',
+ );
const makeRequestStub = jest
.spyOn(request, 'makeRequest')
.mockResolvedValue(mockResponse as Response);
@@ -143,7 +207,7 @@ describe('generateContent()', () => {
});
it('empty content', async () => {
- const mockResponse = getMockResponse('unary-failure-empty-content.json');
+ const mockResponse = getMockResponse(BackendName.VertexAI, 'unary-failure-empty-content.json');
const makeRequestStub = jest
.spyOn(request, 'makeRequest')
.mockResolvedValue(mockResponse as Response);
@@ -160,7 +224,10 @@ describe('generateContent()', () => {
});
it('unknown enum - should ignore', async () => {
- const mockResponse = getMockResponse('unary-success-unknown-enum-safety-ratings.json');
+ const mockResponse = getMockResponse(
+ BackendName.VertexAI,
+ 'unary-success-unknown-enum-safety-ratings.json',
+ );
const makeRequestStub = jest
.spyOn(request, 'makeRequest')
.mockResolvedValue(mockResponse as Response);
@@ -177,7 +244,7 @@ describe('generateContent()', () => {
});
it('image rejected (400)', async () => {
- const mockResponse = getMockResponse('unary-failure-image-rejected.json');
+ const mockResponse = getMockResponse(BackendName.VertexAI, 'unary-failure-image-rejected.json');
const mockFetch = jest.spyOn(globalThis, 'fetch').mockResolvedValue({
ok: false,
status: 400,
@@ -190,7 +257,10 @@ describe('generateContent()', () => {
});
it('api not enabled (403)', async () => {
- const mockResponse = getMockResponse('unary-failure-firebasevertexai-api-not-enabled.json');
+ const mockResponse = getMockResponse(
+ BackendName.VertexAI,
+ 'unary-failure-firebasevertexai-api-not-enabled.json',
+ );
const mockFetch = jest.spyOn(globalThis, 'fetch').mockResolvedValue({
ok: false,
status: 403,
@@ -201,4 +271,60 @@ describe('generateContent()', () => {
);
expect(mockFetch).toHaveBeenCalled();
});
+
+ describe('googleAI', () => {
+    let makeRequestStub: SpiedFunction<typeof request.makeRequest>;
+
+ beforeEach(() => {
+ makeRequestStub = jest.spyOn(request, 'makeRequest');
+ });
+
+ afterEach(() => {
+ jest.restoreAllMocks();
+ });
+
+ it('throws error when method is defined', async () => {
+ const mockResponse = getMockResponse(
+ BackendName.GoogleAI,
+ 'unary-success-basic-reply-short.txt',
+ );
+ makeRequestStub.mockResolvedValue(mockResponse as Response);
+
+ const requestParamsWithMethod: GenerateContentRequest = {
+ contents: [{ parts: [{ text: 'hello' }], role: 'user' }],
+ safetySettings: [
+ {
+ category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
+ threshold: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
+ method: HarmBlockMethod.SEVERITY, // Unsupported in Google AI.
+ },
+ ],
+ };
+
+ // Expect generateContent to throw a AIError that method is not supported.
+ await expect(
+ generateContent(fakeGoogleAIApiSettings, 'model', requestParamsWithMethod),
+ ).rejects.toThrow(AIError);
+ expect(makeRequestStub).not.toHaveBeenCalled();
+ });
+
+ it('maps request to GoogleAI format', async () => {
+ const mockResponse = getMockResponse(
+ BackendName.GoogleAI,
+ 'unary-success-basic-reply-short.txt',
+ );
+ makeRequestStub.mockResolvedValue(mockResponse as Response);
+
+ await generateContent(fakeGoogleAIApiSettings, 'model', fakeGoogleAIRequestParams);
+
+ expect(makeRequestStub).toHaveBeenCalledWith(
+ 'model',
+ Task.GENERATE_CONTENT,
+ fakeGoogleAIApiSettings,
+ false,
+ JSON.stringify(mapGenerateContentRequest(fakeGoogleAIRequestParams)),
+ undefined,
+ );
+ });
+ });
});
diff --git a/packages/vertexai/__tests__/generative-model.test.ts b/packages/ai/__tests__/generative-model.test.ts
similarity index 77%
rename from packages/vertexai/__tests__/generative-model.test.ts
rename to packages/ai/__tests__/generative-model.test.ts
index e62862b6aa..7d29c501c8 100644
--- a/packages/vertexai/__tests__/generative-model.test.ts
+++ b/packages/ai/__tests__/generative-model.test.ts
@@ -17,50 +17,28 @@
import { describe, expect, it, jest } from '@jest/globals';
import { type ReactNativeFirebase } from '@react-native-firebase/app';
import { GenerativeModel } from '../lib/models/generative-model';
-import { FunctionCallingMode, VertexAI } from '../lib/public-types';
+import { AI, FunctionCallingMode } from '../lib/public-types';
import * as request from '../lib/requests/request';
-import { getMockResponse } from './test-utils/mock-response';
+import { BackendName, getMockResponse } from './test-utils/mock-response';
+import { VertexAIBackend } from '../lib/backend';
-const fakeVertexAI: VertexAI = {
+const fakeAI: AI = {
app: {
name: 'DEFAULT',
+ automaticDataCollectionEnabled: true,
options: {
apiKey: 'key',
projectId: 'my-project',
+ appId: 'my-appid',
},
} as ReactNativeFirebase.FirebaseApp,
+ backend: new VertexAIBackend('us-central1'),
location: 'us-central1',
};
describe('GenerativeModel', () => {
- it('handles plain model name', () => {
- const genModel = new GenerativeModel(fakeVertexAI, { model: 'my-model' });
- expect(genModel.model).toBe('publishers/google/models/my-model');
- });
-
- it('handles models/ prefixed model name', () => {
- const genModel = new GenerativeModel(fakeVertexAI, {
- model: 'models/my-model',
- });
- expect(genModel.model).toBe('publishers/google/models/my-model');
- });
-
- it('handles full model name', () => {
- const genModel = new GenerativeModel(fakeVertexAI, {
- model: 'publishers/google/models/my-model',
- });
- expect(genModel.model).toBe('publishers/google/models/my-model');
- });
-
- it('handles prefixed tuned model name', () => {
- const genModel = new GenerativeModel(fakeVertexAI, {
- model: 'tunedModels/my-model',
- });
- expect(genModel.model).toBe('tunedModels/my-model');
- });
-
it('passes params through to generateContent', async () => {
- const genModel = new GenerativeModel(fakeVertexAI, {
+ const genModel = new GenerativeModel(fakeAI, {
model: 'my-model',
tools: [
{
@@ -78,7 +56,10 @@ describe('GenerativeModel', () => {
expect(genModel.tools?.length).toBe(1);
expect(genModel.toolConfig?.functionCallingConfig?.mode).toBe(FunctionCallingMode.NONE);
expect(genModel.systemInstruction?.parts[0]!.text).toBe('be friendly');
- const mockResponse = getMockResponse('unary-success-basic-reply-short.json');
+ const mockResponse = getMockResponse(
+ BackendName.VertexAI,
+ 'unary-success-basic-reply-short.json',
+ );
const makeRequestStub = jest
.spyOn(request, 'makeRequest')
.mockResolvedValue(mockResponse as Response);
@@ -95,12 +76,15 @@ describe('GenerativeModel', () => {
});
it('passes text-only systemInstruction through to generateContent', async () => {
- const genModel = new GenerativeModel(fakeVertexAI, {
+ const genModel = new GenerativeModel(fakeAI, {
model: 'my-model',
systemInstruction: 'be friendly',
});
expect(genModel.systemInstruction?.parts[0]!.text).toBe('be friendly');
- const mockResponse = getMockResponse('unary-success-basic-reply-short.json');
+ const mockResponse = getMockResponse(
+ BackendName.VertexAI,
+ 'unary-success-basic-reply-short.json',
+ );
const makeRequestStub = jest
.spyOn(request, 'makeRequest')
.mockResolvedValue(mockResponse as Response);
@@ -117,7 +101,7 @@ describe('GenerativeModel', () => {
});
it('generateContent overrides model values', async () => {
- const genModel = new GenerativeModel(fakeVertexAI, {
+ const genModel = new GenerativeModel(fakeAI, {
model: 'my-model',
tools: [
{
@@ -135,7 +119,10 @@ describe('GenerativeModel', () => {
expect(genModel.tools?.length).toBe(1);
expect(genModel.toolConfig?.functionCallingConfig?.mode).toBe(FunctionCallingMode.NONE);
expect(genModel.systemInstruction?.parts[0]!.text).toBe('be friendly');
- const mockResponse = getMockResponse('unary-success-basic-reply-short.json');
+ const mockResponse = getMockResponse(
+ BackendName.VertexAI,
+ 'unary-success-basic-reply-short.json',
+ );
const makeRequestStub = jest
.spyOn(request, 'makeRequest')
.mockResolvedValue(mockResponse as Response);
@@ -160,8 +147,38 @@ describe('GenerativeModel', () => {
makeRequestStub.mockRestore();
});
+ it('passes base model params through to ChatSession when there are no startChatParams', async () => {
+ const genModel = new GenerativeModel(fakeAI, {
+ model: 'my-model',
+ generationConfig: {
+ topK: 1,
+ },
+ });
+ const chatSession = genModel.startChat();
+ expect(chatSession.params?.generationConfig).toEqual({
+ topK: 1,
+ });
+ });
+
+ it('overrides base model params with startChatParams', () => {
+ const genModel = new GenerativeModel(fakeAI, {
+ model: 'my-model',
+ generationConfig: {
+ topK: 1,
+ },
+ });
+ const chatSession = genModel.startChat({
+ generationConfig: {
+ topK: 2,
+ },
+ });
+ expect(chatSession.params?.generationConfig).toEqual({
+ topK: 2,
+ });
+ });
+
it('passes params through to chat.sendMessage', async () => {
- const genModel = new GenerativeModel(fakeVertexAI, {
+ const genModel = new GenerativeModel(fakeAI, {
model: 'my-model',
tools: [{ functionDeclarations: [{ name: 'myfunc', description: 'mydesc' }] }],
toolConfig: { functionCallingConfig: { mode: FunctionCallingMode.NONE } },
@@ -170,7 +187,10 @@ describe('GenerativeModel', () => {
expect(genModel.tools?.length).toBe(1);
expect(genModel.toolConfig?.functionCallingConfig?.mode).toBe(FunctionCallingMode.NONE);
expect(genModel.systemInstruction?.parts[0]!.text).toBe('be friendly');
- const mockResponse = getMockResponse('unary-success-basic-reply-short.json');
+ const mockResponse = getMockResponse(
+ BackendName.VertexAI,
+ 'unary-success-basic-reply-short.json',
+ );
const makeRequestStub = jest
.spyOn(request, 'makeRequest')
.mockResolvedValue(mockResponse as Response);
@@ -187,12 +207,15 @@ describe('GenerativeModel', () => {
});
it('passes text-only systemInstruction through to chat.sendMessage', async () => {
- const genModel = new GenerativeModel(fakeVertexAI, {
+ const genModel = new GenerativeModel(fakeAI, {
model: 'my-model',
systemInstruction: 'be friendly',
});
expect(genModel.systemInstruction?.parts[0]!.text).toBe('be friendly');
- const mockResponse = getMockResponse('unary-success-basic-reply-short.json');
+ const mockResponse = getMockResponse(
+ BackendName.VertexAI,
+ 'unary-success-basic-reply-short.json',
+ );
const makeRequestStub = jest
.spyOn(request, 'makeRequest')
.mockResolvedValue(mockResponse as Response);
@@ -209,7 +232,7 @@ describe('GenerativeModel', () => {
});
it('startChat overrides model values', async () => {
- const genModel = new GenerativeModel(fakeVertexAI, {
+ const genModel = new GenerativeModel(fakeAI, {
model: 'my-model',
tools: [{ functionDeclarations: [{ name: 'myfunc', description: 'mydesc' }] }],
toolConfig: { functionCallingConfig: { mode: FunctionCallingMode.NONE } },
@@ -218,7 +241,10 @@ describe('GenerativeModel', () => {
expect(genModel.tools?.length).toBe(1);
expect(genModel.toolConfig?.functionCallingConfig?.mode).toBe(FunctionCallingMode.NONE);
expect(genModel.systemInstruction?.parts[0]!.text).toBe('be friendly');
- const mockResponse = getMockResponse('unary-success-basic-reply-short.json');
+ const mockResponse = getMockResponse(
+ BackendName.VertexAI,
+ 'unary-success-basic-reply-short.json',
+ );
const makeRequestStub = jest
.spyOn(request, 'makeRequest')
.mockResolvedValue(mockResponse as Response);
@@ -247,8 +273,8 @@ describe('GenerativeModel', () => {
});
it('calls countTokens', async () => {
- const genModel = new GenerativeModel(fakeVertexAI, { model: 'my-model' });
- const mockResponse = getMockResponse('unary-success-total-tokens.json');
+ const genModel = new GenerativeModel(fakeAI, { model: 'my-model' });
+ const mockResponse = getMockResponse(BackendName.VertexAI, 'unary-success-total-tokens.json');
const makeRequestStub = jest
.spyOn(request, 'makeRequest')
.mockResolvedValue(mockResponse as Response);
diff --git a/packages/ai/__tests__/googleai-mappers.test.ts b/packages/ai/__tests__/googleai-mappers.test.ts
new file mode 100644
index 0000000000..00ecd39635
--- /dev/null
+++ b/packages/ai/__tests__/googleai-mappers.test.ts
@@ -0,0 +1,365 @@
+/**
+ * @license
+ * Copyright 2024 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+import { describe, it, expect, beforeEach, afterEach, jest } from '@jest/globals';
+import {
+ mapCountTokensRequest,
+ mapGenerateContentCandidates,
+ mapGenerateContentRequest,
+ mapGenerateContentResponse,
+ mapPromptFeedback,
+} from '../lib/googleai-mappers';
+import {
+ AIErrorCode,
+ BlockReason,
+ CountTokensRequest,
+ Content,
+ FinishReason,
+ GenerateContentRequest,
+ GoogleAICountTokensRequest,
+ GoogleAIGenerateContentCandidate,
+ GoogleAIGenerateContentResponse,
+ HarmBlockMethod,
+ HarmBlockThreshold,
+ HarmCategory,
+ HarmProbability,
+ HarmSeverity,
+ PromptFeedback,
+ SafetyRating,
+} from '../lib/public-types';
+import { BackendName, getMockResponse } from './test-utils/mock-response';
+import { SpiedFunction } from 'jest-mock';
+
+const fakeModel = 'models/gemini-pro';
+
+const fakeContents: Content[] = [{ role: 'user', parts: [{ text: 'hello' }] }];
+
+describe('Google AI Mappers', () => {
+ let loggerWarnSpy: SpiedFunction<{
+ (message?: any, ...optionalParams: any[]): void;
+ (message?: any, ...optionalParams: any[]): void;
+ }>;
+
+ beforeEach(() => {
+ loggerWarnSpy = jest.spyOn(console, 'warn').mockImplementation(() => {});
+ });
+
+ afterEach(() => {
+ jest.restoreAllMocks();
+ });
+
+ describe('mapGenerateContentRequest', () => {
+ it('should throw if safetySettings contain method', () => {
+ const request: GenerateContentRequest = {
+ contents: fakeContents,
+ safetySettings: [
+ {
+ category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
+ threshold: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
+ method: HarmBlockMethod.SEVERITY,
+ },
+ ],
+ };
+
+ expect(() => mapGenerateContentRequest(request)).toThrowError(
+ expect.objectContaining({
+ code: AIErrorCode.UNSUPPORTED,
+ message:
+ 'AI: SafetySetting.method is not supported in the Gemini Developer API. Please remove this property. (AI/unsupported)',
+ }),
+ );
+ });
+
+ it('should warn and round topK if present', () => {
+ const request: GenerateContentRequest = {
+ contents: fakeContents,
+ generationConfig: {
+ topK: 15.7,
+ },
+ };
+ const mappedRequest = mapGenerateContentRequest(request);
+ expect(loggerWarnSpy).toHaveBeenCalledWith(
+ expect.any(String), // First argument (timestamp)
+ expect.stringContaining('topK in GenerationConfig has been rounded to the nearest integer'),
+ );
+ expect(mappedRequest.generationConfig?.topK).toBe(16);
+ });
+
+ it('should not modify topK if it is already an integer', () => {
+ const request: GenerateContentRequest = {
+ contents: fakeContents,
+ generationConfig: {
+ topK: 16,
+ },
+ };
+ const mappedRequest = mapGenerateContentRequest(request);
+ expect(loggerWarnSpy).not.toHaveBeenCalled();
+ expect(mappedRequest.generationConfig?.topK).toBe(16);
+ });
+
+ it('should return the request mostly unchanged if valid', () => {
+ const request: GenerateContentRequest = {
+ contents: fakeContents,
+ safetySettings: [
+ {
+ category: HarmCategory.HARM_CATEGORY_HATE_SPEECH,
+ threshold: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
+ },
+ ],
+ generationConfig: {
+ temperature: 0.5,
+ },
+ };
+ const mappedRequest = mapGenerateContentRequest({ ...request });
+ expect(mappedRequest).toEqual(request);
+ expect(loggerWarnSpy).not.toHaveBeenCalled();
+ });
+ });
+
+ describe('mapGenerateContentResponse', () => {
+ it('should map a full Google AI response', async () => {
+ const googleAIMockResponse: GoogleAIGenerateContentResponse = await (
+ getMockResponse(BackendName.GoogleAI, 'unary-success-citations.json') as Response
+ ).json();
+ const mappedResponse = mapGenerateContentResponse(googleAIMockResponse);
+
+ expect(mappedResponse.candidates).toBeDefined();
+ expect(mappedResponse.candidates?.[0]?.content.parts[0]?.text).toContain('quantum mechanics');
+
+ // Mapped citations
+ expect(mappedResponse.candidates?.[0]?.citationMetadata?.citations[0]?.startIndex).toBe(
+ googleAIMockResponse.candidates?.[0]?.citationMetadata?.citationSources[0]?.startIndex,
+ );
+ expect(mappedResponse.candidates?.[0]?.citationMetadata?.citations[0]?.endIndex).toBe(
+ googleAIMockResponse.candidates?.[0]?.citationMetadata?.citationSources[0]?.endIndex,
+ );
+
+ // Mapped safety ratings
+ expect(mappedResponse.candidates?.[0]?.safetyRatings?.[0]?.probabilityScore).toBe(0);
+ expect(mappedResponse.candidates?.[0]?.safetyRatings?.[0]?.severityScore).toBe(0);
+ expect(mappedResponse.candidates?.[0]?.safetyRatings?.[0]?.severity).toBe(
+ HarmSeverity.HARM_SEVERITY_UNSUPPORTED,
+ );
+
+ expect(mappedResponse.candidates?.[0]?.finishReason).toBe(FinishReason.STOP);
+
+ // Check usage metadata passthrough
+ expect(mappedResponse.usageMetadata).toEqual(googleAIMockResponse.usageMetadata);
+ });
+
+ it('should handle missing candidates and promptFeedback', () => {
+ const googleAIResponse: GoogleAIGenerateContentResponse = {
+ // No candidates
+ // No promptFeedback
+ usageMetadata: {
+ promptTokenCount: 5,
+ candidatesTokenCount: 0,
+ totalTokenCount: 5,
+ },
+ };
+ const mappedResponse = mapGenerateContentResponse(googleAIResponse);
+ expect(mappedResponse.candidates).toBeUndefined();
+ expect(mappedResponse.promptFeedback).toBeUndefined(); // Mapped to undefined
+ expect(mappedResponse.usageMetadata).toEqual(googleAIResponse.usageMetadata);
+ });
+
+ it('should handle empty candidates array', () => {
+ const googleAIResponse: GoogleAIGenerateContentResponse = {
+ candidates: [],
+ usageMetadata: {
+ promptTokenCount: 5,
+ candidatesTokenCount: 0,
+ totalTokenCount: 5,
+ },
+ };
+ const mappedResponse = mapGenerateContentResponse(googleAIResponse);
+ expect(mappedResponse.candidates).toEqual([]);
+ expect(mappedResponse.promptFeedback).toBeUndefined();
+ expect(mappedResponse.usageMetadata).toEqual(googleAIResponse.usageMetadata);
+ });
+ });
+
+ describe('mapCountTokensRequest', () => {
+ it('should map a Vertex AI CountTokensRequest to Google AI format', () => {
+ const vertexRequest: CountTokensRequest = {
+ contents: fakeContents,
+ systemInstruction: { role: 'system', parts: [{ text: 'Be nice' }] },
+ tools: [{ functionDeclarations: [{ name: 'foo', description: 'bar' }] }],
+ generationConfig: { temperature: 0.8 },
+ };
+
+ const expectedGoogleAIRequest: GoogleAICountTokensRequest = {
+ generateContentRequest: {
+ model: fakeModel,
+ contents: vertexRequest.contents,
+ systemInstruction: vertexRequest.systemInstruction,
+ tools: vertexRequest.tools,
+ generationConfig: vertexRequest.generationConfig,
+ },
+ };
+
+ const mappedRequest = mapCountTokensRequest(vertexRequest, fakeModel);
+ expect(mappedRequest).toEqual(expectedGoogleAIRequest);
+ });
+
+ it('should map a minimal Vertex AI CountTokensRequest', () => {
+ const vertexRequest: CountTokensRequest = {
+ contents: fakeContents,
+ systemInstruction: { role: 'system', parts: [{ text: 'Be nice' }] },
+ generationConfig: { temperature: 0.8 },
+ };
+
+ const expectedGoogleAIRequest: GoogleAICountTokensRequest = {
+ generateContentRequest: {
+ model: fakeModel,
+ contents: vertexRequest.contents,
+ systemInstruction: { role: 'system', parts: [{ text: 'Be nice' }] },
+ generationConfig: { temperature: 0.8 },
+ },
+ };
+
+ const mappedRequest = mapCountTokensRequest(vertexRequest, fakeModel);
+ expect(mappedRequest).toEqual(expectedGoogleAIRequest);
+ });
+ });
+
+ describe('mapGenerateContentCandidates', () => {
+ it('should map citationSources to citationMetadata.citations', () => {
+ const candidates: GoogleAIGenerateContentCandidate[] = [
+ {
+ index: 0,
+ content: { role: 'model', parts: [{ text: 'Cited text' }] },
+ citationMetadata: {
+ citationSources: [
+ { startIndex: 0, endIndex: 5, uri: 'uri1', license: 'MIT' },
+ { startIndex: 6, endIndex: 10, uri: 'uri2' },
+ ],
+ },
+ },
+ ];
+ const mapped = mapGenerateContentCandidates(candidates);
+ expect(mapped[0]?.citationMetadata).toBeDefined();
+ expect(mapped[0]?.citationMetadata?.citations).toEqual(
+ candidates[0]?.citationMetadata?.citationSources,
+ );
+ expect(mapped[0]?.citationMetadata?.citations[0]?.title).toBeUndefined(); // Not in Google AI
+ expect(mapped[0]?.citationMetadata?.citations[0]?.publicationDate).toBeUndefined(); // Not in Google AI
+ });
+
+ it('should add default safety rating properties', () => {
+ const candidates: GoogleAIGenerateContentCandidate[] = [
+ {
+ index: 0,
+ content: { role: 'model', parts: [{ text: 'Maybe unsafe' }] },
+ safetyRatings: [
+ {
+ category: HarmCategory.HARM_CATEGORY_HARASSMENT,
+ probability: HarmProbability.MEDIUM,
+ blocked: false,
+ // Missing severity, probabilityScore, severityScore
+ } as any,
+ ],
+ },
+ ];
+ const mapped = mapGenerateContentCandidates(candidates);
+ expect(mapped[0]?.safetyRatings).toBeDefined();
+ const safetyRating = mapped[0]?.safetyRatings?.[0] as SafetyRating; // Type assertion
+ expect(safetyRating.severity).toBe(HarmSeverity.HARM_SEVERITY_UNSUPPORTED);
+ expect(safetyRating.probabilityScore).toBe(0);
+ expect(safetyRating.severityScore).toBe(0);
+ // Existing properties should be preserved
+ expect(safetyRating.category).toBe(HarmCategory.HARM_CATEGORY_HARASSMENT);
+ expect(safetyRating.probability).toBe(HarmProbability.MEDIUM);
+ expect(safetyRating.blocked).toBe(false);
+ });
+
+ it('should throw if videoMetadata is present in parts', () => {
+ const candidates: GoogleAIGenerateContentCandidate[] = [
+ {
+ index: 0,
+ content: {
+ role: 'model',
+ parts: [
+ {
+ inlineData: { mimeType: 'video/mp4', data: 'base64==' },
+ videoMetadata: { startOffset: '0s', endOffset: '5s' }, // Unsupported
+ },
+ ],
+ },
+ },
+ ];
+ expect(() => mapGenerateContentCandidates(candidates)).toThrowError(
+ expect.objectContaining({
+ code: AIErrorCode.UNSUPPORTED,
+ message:
+ 'AI: Part.videoMetadata is not supported in the Gemini Developer API. Please remove this property. (AI/unsupported)',
+ }),
+ );
+ });
+
+ it('should handle candidates without citation or safety ratings', () => {
+ const candidates: GoogleAIGenerateContentCandidate[] = [
+ {
+ index: 0,
+ content: { role: 'model', parts: [{ text: 'Simple text' }] },
+ finishReason: FinishReason.STOP,
+ },
+ ];
+ const mapped = mapGenerateContentCandidates(candidates);
+ expect(mapped[0]?.citationMetadata).toBeUndefined();
+ expect(mapped[0]?.safetyRatings).toBeUndefined();
+ expect(mapped[0]?.content?.parts[0]?.text).toBe('Simple text');
+ expect(loggerWarnSpy).not.toHaveBeenCalled();
+ });
+
+ it('should handle empty candidate array', () => {
+ const candidates: GoogleAIGenerateContentCandidate[] = [];
+ const mapped = mapGenerateContentCandidates(candidates);
+ expect(mapped).toEqual([]);
+ expect(loggerWarnSpy).not.toHaveBeenCalled();
+ });
+ });
+
+ describe('mapPromptFeedback', () => {
+ it('should add default safety rating properties', () => {
+ const feedback: PromptFeedback = {
+ blockReason: BlockReason.OTHER,
+ safetyRatings: [
+ {
+ category: HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
+ probability: HarmProbability.HIGH,
+ blocked: true,
+ // Missing severity, probabilityScore, severityScore
+ } as any,
+ ],
+ // Missing blockReasonMessage
+ };
+ const mapped = mapPromptFeedback(feedback);
+ expect(mapped.safetyRatings).toBeDefined();
+ const safetyRating = mapped.safetyRatings[0] as SafetyRating; // Type assertion
+ expect(safetyRating.severity).toBe(HarmSeverity.HARM_SEVERITY_UNSUPPORTED);
+ expect(safetyRating.probabilityScore).toBe(0);
+ expect(safetyRating.severityScore).toBe(0);
+ // Existing properties should be preserved
+ expect(safetyRating.category).toBe(HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT);
+ expect(safetyRating.probability).toBe(HarmProbability.HIGH);
+ expect(safetyRating.blocked).toBe(true);
+ // Other properties
+ expect(mapped.blockReason).toBe(BlockReason.OTHER);
+ expect(mapped.blockReasonMessage).toBeUndefined(); // Not present in input
+ });
+ });
+});
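
For orientation, here is the response mapping these tests exercise, distilled into a standalone sketch (the relative import paths assume a file in `__tests__/`; field values are illustrative):

```ts
import { mapGenerateContentResponse } from '../lib/googleai-mappers';
import { GoogleAIGenerateContentResponse } from '../lib/public-types';

const googleAIResponse: GoogleAIGenerateContentResponse = {
  candidates: [
    {
      index: 0,
      content: { role: 'model', parts: [{ text: 'Answer text' }] },
      citationMetadata: {
        citationSources: [{ startIndex: 0, endIndex: 6, uri: 'https://example.com' }],
      },
    },
  ],
};

// citationSources is renamed to citations; finishReason and usageMetadata
// pass through untouched, and missing safety-rating scores get defaults.
const mapped = mapGenerateContentResponse(googleAIResponse);
console.log(mapped.candidates?.[0]?.citationMetadata?.citations);
```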
diff --git a/packages/vertexai/__tests__/request-helpers.test.ts b/packages/ai/__tests__/request-helpers.test.ts
similarity index 100%
rename from packages/vertexai/__tests__/request-helpers.test.ts
rename to packages/ai/__tests__/request-helpers.test.ts
diff --git a/packages/vertexai/__tests__/request.test.ts b/packages/ai/__tests__/request.test.ts
similarity index 81%
rename from packages/vertexai/__tests__/request.test.ts
rename to packages/ai/__tests__/request.test.ts
index c992b062e9..3e5e58e415 100644
--- a/packages/vertexai/__tests__/request.test.ts
+++ b/packages/ai/__tests__/request.test.ts
@@ -18,14 +18,17 @@ import { describe, expect, it, jest, afterEach } from '@jest/globals';
import { RequestUrl, Task, getHeaders, makeRequest } from '../lib/requests/request';
import { ApiSettings } from '../lib/types/internal';
import { DEFAULT_API_VERSION } from '../lib/constants';
-import { VertexAIErrorCode } from '../lib/types';
-import { VertexAIError } from '../lib/errors';
-import { getMockResponse } from './test-utils/mock-response';
+import { AIErrorCode } from '../lib/types';
+import { AIError } from '../lib/errors';
+import { BackendName, getMockResponse } from './test-utils/mock-response';
+import { VertexAIBackend } from '../lib/backend';
const fakeApiSettings: ApiSettings = {
apiKey: 'key',
project: 'my-project',
+ appId: 'my-appid',
location: 'us-central1',
+ backend: new VertexAIBackend(),
};
describe('request methods', () => {
@@ -106,7 +109,9 @@ describe('request methods', () => {
const fakeApiSettings: ApiSettings = {
apiKey: 'key',
project: 'myproject',
+ appId: 'my-appid',
location: 'moon',
+ backend: new VertexAIBackend(),
getAuthToken: () => Promise.resolve('authtoken'),
getAppCheckToken: () => Promise.resolve({ token: 'appchecktoken' }),
};
@@ -140,7 +145,9 @@ describe('request methods', () => {
{
apiKey: 'key',
project: 'myproject',
+ appId: 'my-appid',
location: 'moon',
+ backend: new VertexAIBackend(),
},
true,
{},
@@ -176,6 +183,8 @@ describe('request methods', () => {
project: 'myproject',
location: 'moon',
getAppCheckToken: () => Promise.reject(new Error('oops')),
+ backend: new VertexAIBackend(),
+ appId: 'my-appid',
},
true,
{},
@@ -187,7 +196,7 @@ describe('request methods', () => {
// See: https://github.com/firebase/firebase-js-sdk/blob/main/packages/vertexai/src/requests/request.test.ts#L172
// expect(headers.get('X-Firebase-AppCheck')).toBe('dummytoken');
expect(warnSpy).toHaveBeenCalledWith(
- expect.stringMatching(/vertexai/),
+ expect.stringMatching(/firebase\/ai/),
expect.stringMatching(/App Check.*oops/),
);
});
@@ -204,7 +213,9 @@ describe('request methods', () => {
{
apiKey: 'key',
project: 'myproject',
+ appId: 'my-appid',
location: 'moon',
+ backend: new VertexAIBackend(),
},
true,
{},
@@ -260,10 +271,10 @@ describe('request methods', () => {
timeout: 180000,
});
} catch (e) {
- expect((e as VertexAIError).code).toBe(VertexAIErrorCode.FETCH_ERROR);
- expect((e as VertexAIError).customErrorData?.status).toBe(500);
- expect((e as VertexAIError).customErrorData?.statusText).toBe('AbortError');
- expect((e as VertexAIError).message).toContain('500 AbortError');
+ expect((e as AIError).code).toBe(AIErrorCode.FETCH_ERROR);
+ expect((e as AIError).customErrorData?.status).toBe(500);
+ expect((e as AIError).customErrorData?.statusText).toBe('AbortError');
+ expect((e as AIError).message).toContain('500 AbortError');
}
expect(fetchMock).toHaveBeenCalledTimes(1);
@@ -278,10 +289,10 @@ describe('request methods', () => {
try {
await makeRequest('models/model-name', Task.GENERATE_CONTENT, fakeApiSettings, false, '');
} catch (e) {
- expect((e as VertexAIError).code).toBe(VertexAIErrorCode.FETCH_ERROR);
- expect((e as VertexAIError).customErrorData?.status).toBe(500);
- expect((e as VertexAIError).customErrorData?.statusText).toBe('Server Error');
- expect((e as VertexAIError).message).toContain('500 Server Error');
+ expect((e as AIError).code).toBe(AIErrorCode.FETCH_ERROR);
+ expect((e as AIError).customErrorData?.status).toBe(500);
+ expect((e as AIError).customErrorData?.statusText).toBe('Server Error');
+ expect((e as AIError).message).toContain('500 Server Error');
}
expect(fetchMock).toHaveBeenCalledTimes(1);
});
@@ -296,11 +307,11 @@ describe('request methods', () => {
try {
await makeRequest('models/model-name', Task.GENERATE_CONTENT, fakeApiSettings, false, '');
} catch (e) {
- expect((e as VertexAIError).code).toBe(VertexAIErrorCode.FETCH_ERROR);
- expect((e as VertexAIError).customErrorData?.status).toBe(500);
- expect((e as VertexAIError).customErrorData?.statusText).toBe('Server Error');
- expect((e as VertexAIError).message).toContain('500 Server Error');
- expect((e as VertexAIError).message).toContain('extra info');
+ expect((e as AIError).code).toBe(AIErrorCode.FETCH_ERROR);
+ expect((e as AIError).customErrorData?.status).toBe(500);
+ expect((e as AIError).customErrorData?.statusText).toBe('Server Error');
+ expect((e as AIError).message).toContain('500 Server Error');
+ expect((e as AIError).message).toContain('extra info');
}
expect(fetchMock).toHaveBeenCalledTimes(1);
});
@@ -327,26 +338,29 @@ describe('request methods', () => {
try {
await makeRequest('models/model-name', Task.GENERATE_CONTENT, fakeApiSettings, false, '');
} catch (e) {
- expect((e as VertexAIError).code).toBe(VertexAIErrorCode.FETCH_ERROR);
- expect((e as VertexAIError).customErrorData?.status).toBe(500);
- expect((e as VertexAIError).customErrorData?.statusText).toBe('Server Error');
- expect((e as VertexAIError).message).toContain('500 Server Error');
- expect((e as VertexAIError).message).toContain('extra info');
- expect((e as VertexAIError).message).toContain('generic::invalid_argument');
+ expect((e as AIError).code).toBe(AIErrorCode.FETCH_ERROR);
+ expect((e as AIError).customErrorData?.status).toBe(500);
+ expect((e as AIError).customErrorData?.statusText).toBe('Server Error');
+ expect((e as AIError).message).toContain('500 Server Error');
+ expect((e as AIError).message).toContain('extra info');
+ expect((e as AIError).message).toContain('generic::invalid_argument');
}
expect(fetchMock).toHaveBeenCalledTimes(1);
});
});
it('Network error, API not enabled', async () => {
- const mockResponse = getMockResponse('unary-failure-firebasevertexai-api-not-enabled.json');
+ const mockResponse = getMockResponse(
+ BackendName.VertexAI,
+ 'unary-failure-firebasevertexai-api-not-enabled.json',
+ );
const fetchMock = jest.spyOn(globalThis, 'fetch').mockResolvedValue(mockResponse as Response);
try {
await makeRequest('models/model-name', Task.GENERATE_CONTENT, fakeApiSettings, false, '');
} catch (e) {
- expect((e as VertexAIError).code).toBe(VertexAIErrorCode.API_NOT_ENABLED);
- expect((e as VertexAIError).message).toContain('my-project');
- expect((e as VertexAIError).message).toContain('googleapis.com');
+ expect((e as AIError).code).toBe(AIErrorCode.API_NOT_ENABLED);
+ expect((e as AIError).message).toContain('my-project');
+ expect((e as AIError).message).toContain('googleapis.com');
}
expect(fetchMock).toHaveBeenCalledTimes(1);
});
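
From the caller's side, the contract these assertions encode is that `AIError` exposes a typed `code` plus the raw HTTP details on `customErrorData`. A sketch of the handling pattern, assuming a `model` obtained from `getGenerativeModel`:

```ts
import { AIError, AIErrorCode, GenerativeModel } from '@react-native-firebase/ai';

async function callWithDiagnostics(model: GenerativeModel) {
  try {
    return await model.generateContent('hello');
  } catch (e) {
    if (e instanceof AIError && e.code === AIErrorCode.FETCH_ERROR) {
      // Mirrors the failed fetch, as asserted in the tests above.
      console.warn(e.customErrorData?.status, e.customErrorData?.statusText);
    }
    throw e;
  }
}
```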
diff --git a/packages/vertexai/__tests__/response-helpers.test.ts b/packages/ai/__tests__/response-helpers.test.ts
similarity index 100%
rename from packages/vertexai/__tests__/response-helpers.test.ts
rename to packages/ai/__tests__/response-helpers.test.ts
diff --git a/packages/vertexai/__tests__/schema-builder.test.ts b/packages/ai/__tests__/schema-builder.test.ts
similarity index 98%
rename from packages/vertexai/__tests__/schema-builder.test.ts
rename to packages/ai/__tests__/schema-builder.test.ts
index bec1f6a8d2..738bd17a21 100644
--- a/packages/vertexai/__tests__/schema-builder.test.ts
+++ b/packages/ai/__tests__/schema-builder.test.ts
@@ -16,7 +16,7 @@
*/
import { describe, expect, it } from '@jest/globals';
import { Schema } from '../lib/requests/schema-builder';
-import { VertexAIErrorCode } from '../lib/types';
+import { AIErrorCode } from '../lib/types';
describe('Schema builder', () => {
it('builds integer schema', () => {
@@ -252,7 +252,7 @@ describe('Schema builder', () => {
},
optionalProperties: ['cat'],
});
- expect(() => schema.toJSON()).toThrow(VertexAIErrorCode.INVALID_SCHEMA);
+ expect(() => schema.toJSON()).toThrow(AIErrorCode.INVALID_SCHEMA);
});
});
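
The invariant behind that last assertion, sketched with the public builder (import path assumed relative to `__tests__/`):

```ts
import { Schema } from '../lib/requests/schema-builder';

// Every name in optionalProperties must exist in properties; otherwise
// toJSON() throws an AIError carrying AIErrorCode.INVALID_SCHEMA.
const schema = Schema.object({
  properties: { dog: Schema.string() },
  optionalProperties: ['cat'], // no 'cat' property, so toJSON() throws
});
schema.toJSON();
```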
diff --git a/packages/vertexai/__tests__/service.test.ts b/packages/ai/__tests__/service.test.ts
similarity index 81%
rename from packages/vertexai/__tests__/service.test.ts
rename to packages/ai/__tests__/service.test.ts
index 9f9503f2c9..1de537df17 100644
--- a/packages/vertexai/__tests__/service.test.ts
+++ b/packages/ai/__tests__/service.test.ts
@@ -17,7 +17,8 @@
import { describe, expect, it } from '@jest/globals';
import { type ReactNativeFirebase } from '@react-native-firebase/app';
import { DEFAULT_LOCATION } from '../lib/constants';
-import { VertexAIService } from '../lib/service';
+import { AIService } from '../lib/service';
+import { VertexAIBackend } from '../lib/backend';
const fakeApp = {
name: 'DEFAULT',
@@ -27,18 +28,17 @@ const fakeApp = {
},
} as ReactNativeFirebase.FirebaseApp;
-describe('VertexAIService', () => {
+describe('AIService', () => {
it('uses default location if not specified', () => {
- const vertexAI = new VertexAIService(fakeApp);
+ const vertexAI = new AIService(fakeApp, new VertexAIBackend());
expect(vertexAI.location).toBe(DEFAULT_LOCATION);
});
it('uses custom location if specified', () => {
- const vertexAI = new VertexAIService(
+ const vertexAI = new AIService(
fakeApp,
- /* authProvider */ undefined,
+ new VertexAIBackend('somewhere'),
/* appCheckProvider */ undefined,
- { location: 'somewhere' },
);
expect(vertexAI.location).toBe('somewhere');
});
diff --git a/packages/vertexai/__tests__/stream-reader.test.ts b/packages/ai/__tests__/stream-reader.test.ts
similarity index 88%
rename from packages/vertexai/__tests__/stream-reader.test.ts
rename to packages/ai/__tests__/stream-reader.test.ts
index 4a5ae8aef5..345e32cc6e 100644
--- a/packages/vertexai/__tests__/stream-reader.test.ts
+++ b/packages/ai/__tests__/stream-reader.test.ts
@@ -22,7 +22,11 @@ import {
processStream,
} from '../lib/requests/stream-reader';
-import { getChunkedStream, getMockResponseStreaming } from './test-utils/mock-response';
+import {
+ BackendName,
+ getChunkedStream,
+ getMockResponseStreaming,
+} from './test-utils/mock-response';
import {
BlockReason,
FinishReason,
@@ -30,9 +34,19 @@ import {
HarmCategory,
HarmProbability,
SafetyRating,
- VertexAIErrorCode,
+ AIErrorCode,
} from '../lib/types';
-import { VertexAIError } from '../lib/errors';
+import { AIError } from '../lib/errors';
+import { ApiSettings } from '../lib/types/internal';
+import { VertexAIBackend } from '../lib/backend';
+
+const fakeApiSettings: ApiSettings = {
+ apiKey: 'key',
+ project: 'my-project',
+ appId: 'my-appid',
+ location: 'us-central1',
+ backend: new VertexAIBackend(),
+};
describe('stream-reader', () => {
describe('getResponseStream', () => {
@@ -85,8 +99,11 @@ describe('stream-reader', () => {
});
it('streaming response - short', async () => {
- const fakeResponse = getMockResponseStreaming('streaming-success-basic-reply-short.txt');
- const result = processStream(fakeResponse as Response);
+ const fakeResponse = getMockResponseStreaming(
+ BackendName.VertexAI,
+ 'streaming-success-basic-reply-short.txt',
+ );
+ const result = processStream(fakeResponse as Response, fakeApiSettings);
for await (const response of result.stream) {
expect(response.text()).not.toBe('');
}
@@ -95,8 +112,11 @@ describe('stream-reader', () => {
});
it('streaming response - functioncall', async () => {
- const fakeResponse = getMockResponseStreaming('streaming-success-function-call-short.txt');
- const result = processStream(fakeResponse as Response);
+ const fakeResponse = getMockResponseStreaming(
+ BackendName.VertexAI,
+ 'streaming-success-function-call-short.txt',
+ );
+ const result = processStream(fakeResponse as Response, fakeApiSettings);
for await (const response of result.stream) {
expect(response.text()).toBe('');
expect(response.functionCalls()).toEqual([
@@ -117,8 +137,11 @@ describe('stream-reader', () => {
});
it('handles citations', async () => {
- const fakeResponse = getMockResponseStreaming('streaming-success-citations.txt');
- const result = processStream(fakeResponse as Response);
+ const fakeResponse = getMockResponseStreaming(
+ BackendName.VertexAI,
+ 'streaming-success-citations.txt',
+ );
+ const result = processStream(fakeResponse as Response, fakeApiSettings);
const aggregatedResponse = await result.response;
expect(aggregatedResponse.text()).toContain('Quantum mechanics is');
expect(aggregatedResponse.candidates?.[0]!.citationMetadata?.citations.length).toBe(3);
@@ -133,8 +156,11 @@ describe('stream-reader', () => {
});
it('removes empty text parts', async () => {
- const fakeResponse = getMockResponseStreaming('streaming-success-empty-text-part.txt');
- const result = processStream(fakeResponse as Response);
+ const fakeResponse = getMockResponseStreaming(
+ BackendName.VertexAI,
+ 'streaming-success-empty-text-part.txt',
+ );
+ const result = processStream(fakeResponse as Response, fakeApiSettings);
const aggregatedResponse = await result.response;
expect(aggregatedResponse.text()).toBe('1');
expect(aggregatedResponse.candidates?.length).toBe(1);
@@ -358,8 +384,8 @@ describe('stream-reader', () => {
try {
aggregateResponses(responsesToAggregate);
} catch (e) {
- expect((e as VertexAIError).code).toBe(VertexAIErrorCode.INVALID_CONTENT);
- expect((e as VertexAIError).message).toContain(
+ expect((e as AIError).code).toBe(AIErrorCode.INVALID_CONTENT);
+ expect((e as AIError).message).toContain(
'Part should have at least one property, but there are none. This is likely caused ' +
'by a malformed response from the backend.',
);
diff --git a/packages/vertexai/__tests__/test-utils/convert-mocks.ts b/packages/ai/__tests__/test-utils/convert-mocks.ts
similarity index 69%
rename from packages/vertexai/__tests__/test-utils/convert-mocks.ts
rename to packages/ai/__tests__/test-utils/convert-mocks.ts
index 97a5ed75df..87e18a478e 100644
--- a/packages/vertexai/__tests__/test-utils/convert-mocks.ts
+++ b/packages/ai/__tests__/test-utils/convert-mocks.ts
@@ -20,7 +20,7 @@ const fs = require('fs');
// eslint-disable-next-line @typescript-eslint/no-require-imports
const { join } = require('path');
-function findMockResponseDir(): string {
+function findMockResponseDir(backend: string): string {
const directories = fs
.readdirSync(__dirname, { withFileTypes: true })
.filter(
@@ -36,22 +36,27 @@ function findMockResponseDir(): string {
throw new Error('Multiple directories starting with "vertexai-sdk-test-data*" found');
}
- return join(__dirname, directories[0], 'mock-responses', 'vertexai');
+ return join(__dirname, directories[0], 'mock-responses', backend);
}
 async function main(): Promise<void> {
- const mockResponseDir = findMockResponseDir();
- const list = fs.readdirSync(mockResponseDir);
- const lookup: Record<string, string> = {};
- // eslint-disable-next-line guard-for-in
- for (const fileName of list) {
- console.log(`attempting to read ${mockResponseDir}/${fileName}`)
- const fullText = fs.readFileSync(join(mockResponseDir, fileName), 'utf-8');
- lookup[fileName] = fullText;
+ const backendNames = ['googleai', 'vertexai'];
+ const lookup: Record<string, Record<string, string>> = {};
+
+ for (const backend of backendNames) {
+ const mockResponseDir = findMockResponseDir(backend);
+ const list = fs.readdirSync(mockResponseDir);
+ lookup[backend] = {};
+ const backendLookup = lookup[backend];
+ for (const fileName of list) {
+ const fullText = fs.readFileSync(join(mockResponseDir, fileName), 'utf-8');
+ backendLookup[fileName] = fullText;
+ }
}
- let fileText = `// Generated from mocks text files.`;
+ let fileText = `// Generated from mocks text files. Do not edit.`;
fileText += '\n\n';
+ fileText += `// @ts-nocheck\n`;
 fileText += `export const mocksLookup: Record<string, string> = ${JSON.stringify(
lookup,
null,
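
For reference, a sketch of the module this generator now emits (contents abbreviated; the flat `Record<string, string>` annotation no longer matches the nested value, which is presumably why `// @ts-nocheck` is written into the output):

```ts
// Generated from mocks text files. Do not edit.

// @ts-nocheck
export const mocksLookup: Record<string, string> = {
  googleai: {
    'unary-success-basic-reply-short.txt': '…',
  },
  vertexai: {
    'unary-success-basic-reply-short.json': '…',
  },
};
```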
diff --git a/packages/vertexai/__tests__/test-utils/mock-response.ts b/packages/ai/__tests__/test-utils/mock-response.ts
similarity index 71%
rename from packages/vertexai/__tests__/test-utils/mock-response.ts
rename to packages/ai/__tests__/test-utils/mock-response.ts
index 52eb0eb04e..6e021428d5 100644
--- a/packages/vertexai/__tests__/test-utils/mock-response.ts
+++ b/packages/ai/__tests__/test-utils/mock-response.ts
@@ -17,6 +17,11 @@
import { ReadableStream } from 'web-streams-polyfill';
import { mocksLookup } from './mocks-lookup';
+export enum BackendName {
+ VertexAI = 'vertexai',
+ GoogleAI = 'googleai',
+}
+
/**
* Mock native Response.body
* Streams contents of json file in 20 character chunks
@@ -40,13 +45,18 @@ export function getChunkedStream(input: string, chunkLength = 20): ReadableStream
return stream;
}
export function getMockResponseStreaming(
+ backendName: BackendName,
filename: string,
chunkLength: number = 20,
 ): Partial<Response> {
- const fullText = mocksLookup[filename];
+ // @ts-ignore
+ const backendMocksLookup: Record<string, string> = mocksLookup[backendName];
+ if (!backendMocksLookup[filename]) {
+ throw Error(`${backendName} mock response file '${filename}' not found.`);
+ }
+ const fullText = backendMocksLookup[filename] as string;
return {
-
// Really tangled typescript error here from our transitive dependencies.
// Ignoring it now, but uncomment and run `yarn lerna:prepare` in top-level
// of the repo to see if you get it or if it has gone away.
@@ -60,10 +70,16 @@ export function getMockResponseStreaming(
};
}
-export function getMockResponse(filename: string): Partial {
- const fullText = mocksLookup[filename];
+export function getMockResponse(backendName: BackendName, filename: string): Partial<Response> {
+ // @ts-ignore
+ const backendMocksLookup: Record<string, string> = mocksLookup[backendName];
+ if (!backendMocksLookup[filename]) {
+ throw Error(`${backendName} mock response file '${filename}' not found.`);
+ }
+ const fullText = backendMocksLookup[filename] as string;
+
return {
ok: true,
- json: () => Promise.resolve(JSON.parse(fullText!)),
+ json: () => Promise.resolve(JSON.parse(fullText)),
};
}
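
Call sites look like this in a test (a sketch; the helper now throws up front if the file name is missing for that backend, instead of failing later on `undefined`):

```ts
import { expect, it } from '@jest/globals';
import { BackendName, getMockResponse } from './test-utils/mock-response';

it('parses a canned unary reply', async () => {
  const mockResponse = getMockResponse(
    BackendName.VertexAI,
    'unary-success-basic-reply-short.json',
  );
  const payload = await (mockResponse as Response).json();
  expect(payload.candidates).toBeDefined();
});
```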
diff --git a/packages/vertexai/e2e/fetch.e2e.js b/packages/ai/e2e/fetch.e2e.js
similarity index 100%
rename from packages/vertexai/e2e/fetch.e2e.js
rename to packages/ai/e2e/fetch.e2e.js
diff --git a/packages/ai/lib/backend.ts b/packages/ai/lib/backend.ts
new file mode 100644
index 0000000000..7209828122
--- /dev/null
+++ b/packages/ai/lib/backend.ts
@@ -0,0 +1,92 @@
+/**
+ * @license
+ * Copyright 2025 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import { DEFAULT_LOCATION } from './constants';
+import { BackendType } from './public-types';
+
+/**
+ * Abstract base class representing the configuration for an AI service backend.
+ * This class should not be instantiated directly. Use its subclasses: {@link GoogleAIBackend} for
+ * the Gemini Developer API (via {@link https://ai.google/ | Google AI}), and
+ * {@link VertexAIBackend} for the Vertex AI Gemini API.
+ *
+ * @public
+ */
+export abstract class Backend {
+ /**
+ * Specifies the backend type.
+ */
+ readonly backendType: BackendType;
+
+ /**
+ * Protected constructor for use by subclasses.
+ * @param type - The backend type.
+ */
+ protected constructor(type: BackendType) {
+ this.backendType = type;
+ }
+}
+
+/**
+ * Configuration class for the Gemini Developer API.
+ *
+ * Use this with {@link AIOptions} when initializing the AI service via
+ * {@link getAI | getAI()} to specify the Gemini Developer API as the backend.
+ *
+ * @public
+ */
+export class GoogleAIBackend extends Backend {
+ /**
+ * Creates a configuration object for the Gemini Developer API backend.
+ */
+ constructor() {
+ super(BackendType.GOOGLE_AI);
+ }
+}
+
+/**
+ * Configuration class for the Vertex AI Gemini API.
+ *
+ * Use this with {@link AIOptions} when initializing the AI service via
+ * {@link getAI | getAI()} to specify the Vertex AI Gemini API as the backend.
+ *
+ * @public
+ */
+export class VertexAIBackend extends Backend {
+ /**
+ * The region identifier.
+ * See {@link https://firebase.google.com/docs/vertex-ai/locations#available-locations | Vertex AI locations}
+ * for a list of supported locations.
+ */
+ readonly location: string;
+
+ /**
+ * Creates a configuration object for the Vertex AI backend.
+ *
+ * @param location - The region identifier, defaulting to `us-central1`;
+ * see {@link https://firebase.google.com/docs/vertex-ai/locations#available-locations | Vertex AI locations}
+ * for a list of supported locations.
+ */
+ constructor(location: string = DEFAULT_LOCATION) {
+ super(BackendType.VERTEX_AI);
+ if (!location) {
+ this.location = DEFAULT_LOCATION;
+ } else {
+ this.location = location;
+ }
+ }
+}
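
A sketch of backend selection at initialization (the region string is illustrative; an empty or omitted location falls back to `us-central1`):

```ts
import { getApp } from '@react-native-firebase/app';
import { getAI, GoogleAIBackend, VertexAIBackend } from '@react-native-firebase/ai';

// Gemini Developer API: no location concept.
const googleAI = getAI(getApp(), { backend: new GoogleAIBackend() });

// Vertex AI Gemini API: optional region, defaulting to DEFAULT_LOCATION.
const vertexAI = getAI(getApp(), { backend: new VertexAIBackend('europe-west1') });
```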
diff --git a/packages/vertexai/lib/constants.ts b/packages/ai/lib/constants.ts
similarity index 96%
rename from packages/vertexai/lib/constants.ts
rename to packages/ai/lib/constants.ts
index 816f5194a2..a0cffa49ad 100644
--- a/packages/vertexai/lib/constants.ts
+++ b/packages/ai/lib/constants.ts
@@ -17,7 +17,7 @@
import { version } from './version';
-export const VERTEX_TYPE = 'vertexAI';
+export const AI_TYPE = 'AI';
export const DEFAULT_LOCATION = 'us-central1';
diff --git a/packages/vertexai/lib/errors.ts b/packages/ai/lib/errors.ts
similarity index 69%
rename from packages/vertexai/lib/errors.ts
rename to packages/ai/lib/errors.ts
index 370c19aeb0..3a7e18ec3a 100644
--- a/packages/vertexai/lib/errors.ts
+++ b/packages/ai/lib/errors.ts
@@ -16,35 +16,34 @@
*/
import { FirebaseError } from '@firebase/util';
-import { VertexAIErrorCode, CustomErrorData } from './types';
-import { VERTEX_TYPE } from './constants';
+import { AIErrorCode, CustomErrorData } from './types';
+import { AI_TYPE } from './constants';
/**
* Error class for the Vertex AI in Firebase SDK.
*
* @public
*/
-export class VertexAIError extends FirebaseError {
+export class AIError extends FirebaseError {
/**
- * Constructs a new instance of the `VertexAIError` class.
+ * Constructs a new instance of the `AIError` class.
*
- * @param code - The error code from {@link VertexAIErrorCode}
.
+ * @param code - The error code from {@link AIErrorCode}
.
* @param message - A human-readable message describing the error.
* @param customErrorData - Optional error data.
*/
constructor(
- readonly code: VertexAIErrorCode,
+ readonly code: AIErrorCode,
message: string,
readonly customErrorData?: CustomErrorData,
) {
// Match error format used by FirebaseError from ErrorFactory
- const service = VERTEX_TYPE;
- const serviceName = 'VertexAI';
+ const service = AI_TYPE;
const fullCode = `${service}/${code}`;
- const fullMessage = `${serviceName}: ${message} (${fullCode})`;
+ const fullMessage = `${service}: ${message} (${fullCode})`;
super(code, fullMessage);
- Object.setPrototypeOf(this, VertexAIError.prototype);
+ Object.setPrototypeOf(this, AIError.prototype);
// Since Error is an interface, we don't inherit toString and so we define it ourselves.
this.toString = () => fullMessage;
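
The resulting message format, as a sketch (assuming `AIErrorCode.FETCH_ERROR` has the string value `'fetch-error'`):

```ts
import { AIError, AIErrorCode } from '@react-native-firebase/ai';

const err = new AIError(AIErrorCode.FETCH_ERROR, 'something went wrong');
console.log(err.message); // "AI: something went wrong (AI/fetch-error)"
console.log(err.code);    // "fetch-error"
```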
diff --git a/packages/ai/lib/googleai-mappers.ts b/packages/ai/lib/googleai-mappers.ts
new file mode 100644
index 0000000000..2f6724b8d8
--- /dev/null
+++ b/packages/ai/lib/googleai-mappers.ts
@@ -0,0 +1,218 @@
+/**
+ * @license
+ * Copyright 2025 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import { AIError } from './errors';
+import { logger } from './logger';
+import {
+ CitationMetadata,
+ CountTokensRequest,
+ GenerateContentCandidate,
+ GenerateContentRequest,
+ GenerateContentResponse,
+ HarmSeverity,
+ InlineDataPart,
+ PromptFeedback,
+ SafetyRating,
+ AIErrorCode,
+} from './types';
+import {
+ GoogleAIGenerateContentResponse,
+ GoogleAIGenerateContentCandidate,
+ GoogleAICountTokensRequest,
+} from './types/googleai';
+
+/**
+ * This SDK supports both the Vertex AI Gemini API and the Gemini Developer API (using Google AI).
+ * The public API prioritizes the format used by the Vertex AI Gemini API.
+ * We avoid having two sets of types by translating requests and responses between the two API formats.
+ * This translation allows developers to switch between the Vertex AI Gemini API and the Gemini Developer API
+ * with minimal code changes.
+ *
+ * This file contains the functions that map requests and responses between the two API formats.
+ * Requests in the Vertex AI format are mapped to the Google AI format before being sent.
+ * Responses from the Google AI backend are mapped back to the Vertex AI format before being returned to the user.
+ */
+
+/**
+ * Maps a Vertex AI {@link GenerateContentRequest} to a format that can be sent to Google AI.
+ *
+ * @param generateContentRequest The {@link GenerateContentRequest} to map.
+ * @returns A {@link GenerateContentRequest} that conforms to the Google AI format.
+ *
+ * @throws If the request contains properties that are unsupported by Google AI.
+ *
+ * @internal
+ */
+export function mapGenerateContentRequest(
+ generateContentRequest: GenerateContentRequest,
+): GenerateContentRequest {
+ generateContentRequest.safetySettings?.forEach(safetySetting => {
+ if (safetySetting.method) {
+ throw new AIError(
+ AIErrorCode.UNSUPPORTED,
+ 'SafetySetting.method is not supported in the Gemini Developer API. Please remove this property.',
+ );
+ }
+ });
+
+ if (generateContentRequest.generationConfig?.topK) {
+ const roundedTopK = Math.round(generateContentRequest.generationConfig.topK);
+
+ if (roundedTopK !== generateContentRequest.generationConfig.topK) {
+ logger.warn(
+ 'topK in GenerationConfig has been rounded to the nearest integer to match the format for requests to the Gemini Developer API.',
+ );
+ generateContentRequest.generationConfig.topK = roundedTopK;
+ }
+ }
+
+ return generateContentRequest;
+}
+
+/**
+ * Maps a {@link GoogleAIGenerateContentResponse} from Google AI to the Vertex AI
+ * {@link GenerateContentResponse} format that is exposed in the public API.
+ *
+ * @param googleAIResponse The {@link GoogleAIGenerateContentResponse} from Google AI.
+ * @returns A {@link GenerateContentResponse} that conforms to the public API's format.
+ *
+ * @internal
+ */
+export function mapGenerateContentResponse(
+ googleAIResponse: GoogleAIGenerateContentResponse,
+): GenerateContentResponse {
+ const generateContentResponse = {
+ candidates: googleAIResponse.candidates
+ ? mapGenerateContentCandidates(googleAIResponse.candidates)
+ : undefined,
+ promptFeedback: googleAIResponse.promptFeedback
+ ? mapPromptFeedback(googleAIResponse.promptFeedback)
+ : undefined,
+ usageMetadata: googleAIResponse.usageMetadata,
+ };
+
+ return generateContentResponse;
+}
+
+/**
+ * Maps a Vertex AI {@link CountTokensRequest} to a format that can be sent to Google AI.
+ *
+ * @param countTokensRequest The {@link CountTokensRequest} to map.
+ * @param model The model to count tokens with.
+ * @returns A {@link GoogleAICountTokensRequest} that conforms to the Google AI format.
+ *
+ * @internal
+ */
+export function mapCountTokensRequest(
+ countTokensRequest: CountTokensRequest,
+ model: string,
+): GoogleAICountTokensRequest {
+ const mappedCountTokensRequest: GoogleAICountTokensRequest = {
+ generateContentRequest: {
+ model,
+ ...countTokensRequest,
+ },
+ };
+
+ return mappedCountTokensRequest;
+}
+
+/**
+ * Maps a Google AI {@link GoogleAIGenerateContentCandidate} to a format that conforms
+ * to the Vertex AI API format.
+ *
+ * @param candidates The {@link GoogleAIGenerateContentCandidate} to map.
+ * @returns A {@link GenerateContentCandidate} that conforms to the Vertex AI format.
+ *
+ * @throws If any {@link Part} in the candidates has a `videoMetadata` property.
+ *
+ * @internal
+ */
+export function mapGenerateContentCandidates(
+ candidates: GoogleAIGenerateContentCandidate[],
+): GenerateContentCandidate[] {
+ const mappedCandidates: GenerateContentCandidate[] = [];
+ if (candidates) {
+ candidates.forEach(candidate => {
+ // Declared per candidate so ratings from one candidate never leak into the next.
+ let mappedSafetyRatings: SafetyRating[] | undefined;
+ // Map citationSources to citations.
+ let citationMetadata: CitationMetadata | undefined;
+ if (candidate.citationMetadata) {
+ citationMetadata = {
+ citations: candidate.citationMetadata.citationSources,
+ };
+ }
+
+ // Assign missing candidate SafetyRatings properties to their defaults if undefined.
+ if (candidate.safetyRatings) {
+ mappedSafetyRatings = candidate.safetyRatings.map(safetyRating => {
+ return {
+ ...safetyRating,
+ severity: safetyRating.severity ?? HarmSeverity.HARM_SEVERITY_UNSUPPORTED,
+ probabilityScore: safetyRating.probabilityScore ?? 0,
+ severityScore: safetyRating.severityScore ?? 0,
+ };
+ });
+ }
+
+ // videoMetadata is not supported.
+ // Throw early since developers may send a long video as input and only expect to pay
+ // for inference on a small portion of the video.
+ if (candidate.content?.parts.some(part => (part as InlineDataPart)?.videoMetadata)) {
+ throw new AIError(
+ AIErrorCode.UNSUPPORTED,
+ 'Part.videoMetadata is not supported in the Gemini Developer API. Please remove this property.',
+ );
+ }
+
+ const mappedCandidate = {
+ index: candidate.index,
+ content: candidate.content,
+ finishReason: candidate.finishReason,
+ finishMessage: candidate.finishMessage,
+ safetyRatings: mappedSafetyRatings,
+ citationMetadata,
+ groundingMetadata: candidate.groundingMetadata,
+ };
+ mappedCandidates.push(mappedCandidate);
+ });
+ }
+
+ return mappedCandidates;
+}
+
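+/**
+ * Maps a Google AI {@link PromptFeedback} to the Vertex AI format by assigning
+ * defaults for any missing {@link SafetyRating} properties.
+ *
+ * @internal
+ */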
+export function mapPromptFeedback(promptFeedback: PromptFeedback): PromptFeedback {
+ // Assign missing SafetyRating properties to their defaults if undefined.
+ const mappedSafetyRatings: SafetyRating[] = [];
+ promptFeedback.safetyRatings.forEach(safetyRating => {
+ mappedSafetyRatings.push({
+ category: safetyRating.category,
+ probability: safetyRating.probability,
+ severity: safetyRating.severity ?? HarmSeverity.HARM_SEVERITY_UNSUPPORTED,
+ probabilityScore: safetyRating.probabilityScore ?? 0,
+ severityScore: safetyRating.severityScore ?? 0,
+ blocked: safetyRating.blocked,
+ });
+ });
+
+ const mappedPromptFeedback: PromptFeedback = {
+ blockReason: promptFeedback.blockReason,
+ safetyRatings: mappedSafetyRatings,
+ blockReasonMessage: promptFeedback.blockReasonMessage,
+ };
+ return mappedPromptFeedback;
+}
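
One behavior worth calling out from `mapGenerateContentRequest`: a fractional `topK` is rounded, with a logged warning, rather than rejected. A sketch (import path assumes a caller inside `lib/`):

```ts
import { mapGenerateContentRequest } from './googleai-mappers';

const mapped = mapGenerateContentRequest({
  contents: [{ role: 'user', parts: [{ text: 'hi' }] }],
  generationConfig: { topK: 15.7 },
});
console.log(mapped.generationConfig?.topK); // 16, after a logged warning
```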
diff --git a/packages/ai/lib/index.ts b/packages/ai/lib/index.ts
new file mode 100644
index 0000000000..1905b6ba84
--- /dev/null
+++ b/packages/ai/lib/index.ts
@@ -0,0 +1,92 @@
+/**
+ * @license
+ * Copyright 2024 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import './polyfills';
+import { getApp, ReactNativeFirebase } from '@react-native-firebase/app';
+import { GoogleAIBackend, VertexAIBackend } from './backend';
+import { AIErrorCode, ModelParams, RequestOptions } from './types';
+import { AI, AIOptions } from './public-types';
+import { AIError } from './errors';
+import { GenerativeModel } from './models/generative-model';
+import { AIModel } from './models/ai-model';
+
+export * from './public-types';
+export { ChatSession } from './methods/chat-session';
+export * from './requests/schema-builder';
+export { GoogleAIBackend, VertexAIBackend } from './backend';
+export { GenerativeModel, AIError, AIModel };
+
+/**
+ * Returns the default {@link AI} instance that is associated with the provided
+ * {@link @firebase/app#FirebaseApp}. If no instance exists, initializes a new instance with the
+ * default settings.
+ *
+ * @example
+ * ```javascript
+ * const ai = getAI(app);
+ * ```
+ *
+ * @example
+ * ```javascript
+ * // Get an AI instance configured to use the Gemini Developer API (via Google AI).
+ * const ai = getAI(app, { backend: new GoogleAIBackend() });
+ * ```
+ *
+ * @example
+ * ```javascript
+ * // Get an AI instance configured to use the Vertex AI Gemini API.
+ * const ai = getAI(app, { backend: new VertexAIBackend() });
+ * ```
+ *
+ * @param app - The {@link @firebase/app#FirebaseApp} to use.
+ * @param options - {@link AIOptions} that configure the AI instance.
+ * @returns The default {@link AI} instance for the given {@link @firebase/app#FirebaseApp}.
+ *
+ * @public
+ */
+export function getAI(
+ app: ReactNativeFirebase.FirebaseApp = getApp(),
+ options: AIOptions = { backend: new GoogleAIBackend() },
+): AI {
+ return {
+ app,
+ backend: options.backend,
+ location: (options.backend as VertexAIBackend)?.location || '',
+ appCheck: options.appCheck || null,
+ auth: options.auth || null,
+ } as AI;
+}
+
+/**
+ * Returns a {@link GenerativeModel} class with methods for inference
+ * and other functionality.
+ *
+ * @public
+ */
+export function getGenerativeModel(
+ ai: AI,
+ modelParams: ModelParams,
+ requestOptions?: RequestOptions,
+): GenerativeModel {
+ if (!modelParams.model) {
+ throw new AIError(
+ AIErrorCode.NO_MODEL,
+ `Must provide a model name. Example: getGenerativeModel({ model: 'my-model-name' })`,
+ );
+ }
+ return new GenerativeModel(ai, modelParams, requestOptions);
+}
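
`getGenerativeModel` is the one guarded entry point here; a sketch of the failure mode:

```ts
import { getApp } from '@react-native-firebase/app';
import { getAI, getGenerativeModel, AIError } from '@react-native-firebase/ai';

const ai = getAI(getApp()); // defaults to the GoogleAIBackend

try {
  // An empty model name throws AIError (AIErrorCode.NO_MODEL) synchronously,
  // before any request is constructed.
  getGenerativeModel(ai, { model: '' });
} catch (e) {
  console.warn((e as AIError).message);
}
```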
diff --git a/packages/vertexai/lib/logger.ts b/packages/ai/lib/logger.ts
similarity index 92%
rename from packages/vertexai/lib/logger.ts
rename to packages/ai/lib/logger.ts
index dbc3e84059..55c6e4658a 100644
--- a/packages/vertexai/lib/logger.ts
+++ b/packages/ai/lib/logger.ts
@@ -17,4 +17,4 @@
// @ts-ignore
import { Logger } from '@react-native-firebase/app/lib/internal/logger';
-export const logger = new Logger('@firebase/vertexai');
+export const logger = new Logger('@firebase/ai');
diff --git a/packages/vertexai/lib/methods/chat-session-helpers.ts b/packages/ai/lib/methods/chat-session-helpers.ts
similarity index 80%
rename from packages/vertexai/lib/methods/chat-session-helpers.ts
rename to packages/ai/lib/methods/chat-session-helpers.ts
index 4b9bb56db0..ea8cd826b9 100644
--- a/packages/vertexai/lib/methods/chat-session-helpers.ts
+++ b/packages/ai/lib/methods/chat-session-helpers.ts
@@ -15,8 +15,8 @@
* limitations under the License.
*/
-import { Content, POSSIBLE_ROLES, Part, Role, VertexAIErrorCode } from '../types';
-import { VertexAIError } from '../errors';
+import { Content, POSSIBLE_ROLES, Part, Role, AIErrorCode } from '../types';
+import { AIError } from '../errors';
// https://ai.google.dev/api/rest/v1beta/Content#part
@@ -48,14 +48,14 @@ export function validateChatHistory(history: Content[]): void {
for (const currContent of history) {
const { role, parts } = currContent;
if (!prevContent && role !== 'user') {
- throw new VertexAIError(
- VertexAIErrorCode.INVALID_CONTENT,
+ throw new AIError(
+ AIErrorCode.INVALID_CONTENT,
`First Content should be with role 'user', got ${role}`,
);
}
if (!POSSIBLE_ROLES.includes(role)) {
- throw new VertexAIError(
- VertexAIErrorCode.INVALID_CONTENT,
+ throw new AIError(
+ AIErrorCode.INVALID_CONTENT,
`Each item should include role field. Got ${role} but valid roles are: ${JSON.stringify(
POSSIBLE_ROLES,
)}`,
@@ -63,17 +63,14 @@ export function validateChatHistory(history: Content[]): void {
}
if (!Array.isArray(parts)) {
- throw new VertexAIError(
- VertexAIErrorCode.INVALID_CONTENT,
+ throw new AIError(
+ AIErrorCode.INVALID_CONTENT,
 `Content should have 'parts' property with an array of Parts`,
);
}
if (parts.length === 0) {
- throw new VertexAIError(
- VertexAIErrorCode.INVALID_CONTENT,
- `Each Content should have at least one part`,
- );
+ throw new AIError(AIErrorCode.INVALID_CONTENT, `Each Content should have at least one part`);
}
 const countFields: Record<string, number> = {
@@ -93,8 +90,8 @@ export function validateChatHistory(history: Content[]): void {
const validParts = VALID_PARTS_PER_ROLE[role];
for (const key of VALID_PART_FIELDS) {
if (!validParts.includes(key) && countFields[key] > 0) {
- throw new VertexAIError(
- VertexAIErrorCode.INVALID_CONTENT,
+ throw new AIError(
+ AIErrorCode.INVALID_CONTENT,
`Content with role '${role}' can't contain '${key}' part`,
);
}
@@ -103,9 +100,9 @@ export function validateChatHistory(history: Content[]): void {
if (prevContent) {
const validPreviousContentRoles = VALID_PREVIOUS_CONTENT_ROLES[role];
if (!validPreviousContentRoles.includes(prevContent.role)) {
- throw new VertexAIError(
- VertexAIErrorCode.INVALID_CONTENT,
- `Content with role '${role} can't follow '${
+ throw new AIError(
+ AIErrorCode.INVALID_CONTENT,
+ `Content with role '${role}' can't follow '${
prevContent.role
}'. Valid previous roles: ${JSON.stringify(VALID_PREVIOUS_CONTENT_ROLES)}`,
);
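A chat history that passes `validateChatHistory` starts with a `'user'` turn, gives every entry a non-empty `parts` array, and only uses roles that may follow their predecessor. A sketch (assuming `Content` is re-exported from the package root):

```ts
import { Content } from '@react-native-firebase/ai';

const history: Content[] = [
  { role: 'user', parts: [{ text: 'Hello' }] }, // first entry must be 'user'
  { role: 'model', parts: [{ text: 'Hi, how can I help?' }] }, // 'model' may follow 'user'
];
```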
diff --git a/packages/vertexai/lib/methods/chat-session.ts b/packages/ai/lib/methods/chat-session.ts
similarity index 97%
rename from packages/vertexai/lib/methods/chat-session.ts
rename to packages/ai/lib/methods/chat-session.ts
index e3e9cf905f..6bbb6f526c 100644
--- a/packages/vertexai/lib/methods/chat-session.ts
+++ b/packages/ai/lib/methods/chat-session.ts
@@ -73,7 +73,7 @@ export class ChatSession {
/**
* Sends a chat message and receives a non-streaming
- * {@link GenerateContentResult}
+ * {@link GenerateContentResult}
*/
async sendMessage(request: string | Array<string | Part>): Promise<GenerateContentResult> {
await this._sendPromise;
@@ -117,7 +117,7 @@ export class ChatSession {
/**
* Sends a chat message and receives the response as a
- * {@link GenerateContentStreamResult}
containing an iterable stream
+ * {@link GenerateContentStreamResult} containing an iterable stream
* and a response promise.
*/
async sendMessageStream(
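A minimal usage sketch for the renamed `ChatSession` (inside an async function, with `model` obtained from `getGenerativeModel`):

```ts
const chat = model.startChat();

// Non-streaming: resolves once the full response is available.
const result = await chat.sendMessage('What is the capital of France?');
console.log(result.response.text());

// Streaming: chunks arrive as an async iterable.
const streamResult = await chat.sendMessageStream('And of Germany?');
for await (const chunk of streamResult.stream) {
  console.log(chunk.text());
}
```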
diff --git a/packages/vertexai/lib/methods/count-tokens.ts b/packages/ai/lib/methods/count-tokens.ts
similarity index 71%
rename from packages/vertexai/lib/methods/count-tokens.ts
rename to packages/ai/lib/methods/count-tokens.ts
index 10d41cffa8..7d1e21098a 100644
--- a/packages/vertexai/lib/methods/count-tokens.ts
+++ b/packages/ai/lib/methods/count-tokens.ts
@@ -18,6 +18,8 @@
import { CountTokensRequest, CountTokensResponse, RequestOptions } from '../types';
import { Task, makeRequest } from '../requests/request';
import { ApiSettings } from '../types/internal';
+import { BackendType } from '../public-types';
+import * as GoogleAIMapper from '../googleai-mappers';
export async function countTokens(
apiSettings: ApiSettings,
@@ -25,12 +27,24 @@ export async function countTokens(
params: CountTokensRequest,
requestOptions?: RequestOptions,
): Promise<CountTokensResponse> {
+ let body: string = '';
+ switch (apiSettings.backend.backendType) {
+ case BackendType.GOOGLE_AI:
+ const mappedParams = GoogleAIMapper.mapCountTokensRequest(params, model);
+ body = JSON.stringify(mappedParams);
+ break;
+ case BackendType.VERTEX_AI:
+ default:
+ body = JSON.stringify(params);
+ break;
+ }
+
const response = await makeRequest(
model,
Task.COUNT_TOKENS,
apiSettings,
false,
- JSON.stringify(params),
+ body,
requestOptions,
);
return response.json();
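The switch above changes only the wire format, not the public API. Roughly, the same `CountTokensRequest` is serialized two ways (shapes per the `GoogleAICountTokensRequest` type added later in this diff):

```ts
// Vertex AI backend — the params object is sent as-is:
//   { "contents": [...] }
//
// Google AI backend — mapCountTokensRequest wraps the params and embeds the model:
//   { "generateContentRequest": { "model": "models/gemini-1.5-flash", "contents": [...] } }
```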
diff --git a/packages/vertexai/lib/methods/generate-content.ts b/packages/ai/lib/methods/generate-content.ts
similarity index 65%
rename from packages/vertexai/lib/methods/generate-content.ts
rename to packages/ai/lib/methods/generate-content.ts
index 6d1a6ecb27..72a081cc95 100644
--- a/packages/vertexai/lib/methods/generate-content.ts
+++ b/packages/ai/lib/methods/generate-content.ts
@@ -26,6 +26,8 @@ import { Task, makeRequest } from '../requests/request';
import { createEnhancedContentResponse } from '../requests/response-helpers';
import { processStream } from '../requests/stream-reader';
import { ApiSettings } from '../types/internal';
+import { BackendType } from '../public-types';
+import * as GoogleAIMapper from '../googleai-mappers';
export async function generateContentStream(
apiSettings: ApiSettings,
@@ -33,6 +35,9 @@ export async function generateContentStream(
params: GenerateContentRequest,
requestOptions?: RequestOptions,
): Promise<GenerateContentStreamResult> {
+ if (apiSettings.backend.backendType === BackendType.GOOGLE_AI) {
+ params = GoogleAIMapper.mapGenerateContentRequest(params);
+ }
const response = await makeRequest(
model,
Task.STREAM_GENERATE_CONTENT,
@@ -41,7 +46,7 @@ export async function generateContentStream(
JSON.stringify(params),
requestOptions,
);
- return processStream(response);
+ return processStream(response, apiSettings);
}
export async function generateContent(
@@ -50,6 +55,9 @@ export async function generateContent(
params: GenerateContentRequest,
requestOptions?: RequestOptions,
): Promise<GenerateContentResult> {
+ if (apiSettings.backend.backendType === BackendType.GOOGLE_AI) {
+ params = GoogleAIMapper.mapGenerateContentRequest(params);
+ }
const response = await makeRequest(
model,
Task.GENERATE_CONTENT,
@@ -58,9 +66,21 @@ export async function generateContent(
JSON.stringify(params),
requestOptions,
);
- const responseJson: GenerateContentResponse = await response.json();
- const enhancedResponse = createEnhancedContentResponse(responseJson);
+ const generateContentResponse = await processGenerateContentResponse(response, apiSettings);
+ const enhancedResponse = createEnhancedContentResponse(generateContentResponse);
return {
response: enhancedResponse,
};
}
+
+async function processGenerateContentResponse(
+ response: Response,
+ apiSettings: ApiSettings,
+): Promise<GenerateContentResponse> {
+ const responseJson = await response.json();
+ if (apiSettings.backend.backendType === BackendType.GOOGLE_AI) {
+ return GoogleAIMapper.mapGenerateContentResponse(responseJson);
+ } else {
+ return responseJson;
+ }
+}
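Callers are unaffected by the new backend branch; the request/response mapping is internal. A usage sketch (inside an async function, with `model` from `getGenerativeModel`):

```ts
// Non-streaming call.
const result = await model.generateContent('Write a haiku about the sea.');
console.log(result.response.text());

// Streaming call; each chunk is an EnhancedGenerateContentResponse.
const { stream } = await model.generateContentStream('Tell me a story.');
for await (const chunk of stream) {
  console.log(chunk.text());
}
```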
diff --git a/packages/ai/lib/models/ai-model.ts b/packages/ai/lib/models/ai-model.ts
new file mode 100644
index 0000000000..5cba58bfc8
--- /dev/null
+++ b/packages/ai/lib/models/ai-model.ts
@@ -0,0 +1,125 @@
+import { ApiSettings } from '../types/internal';
+import { AIError } from '../errors';
+import { AIErrorCode } from '../types';
+import { AI, BackendType } from '../public-types';
+import { AIService } from '../service';
+
+/**
+ * Base class for Firebase AI model APIs.
+ *
+ * Instances of this class are associated with a specific Firebase AI {@link Backend}
+ * and provide methods for interacting with the configured generative model.
+ *
+ * @public
+ */
+export abstract class AIModel {
+ /**
+ * The fully qualified model resource name used for requests to this model
+ * (for example, `publishers/google/models/imagen-3.0-generate-002`).
+ */
+ readonly model: string;
+
+ /**
+ * @internal
+ */
+ protected _apiSettings: ApiSettings;
+
+ /**
+ * Constructs a new instance of the {@link AIModel} class.
+ *
+ * This constructor should only be called from subclasses that provide
+ * a model API.
+ *
+ * @param ai - an {@link AI} instance.
+ * @param modelName - The name of the model being used. It can be in one of the following formats:
+ * - `my-model` (short name, will resolve to `publishers/google/models/my-model`)
+ * - `models/my-model` (will resolve to `publishers/google/models/my-model`)
+ * - `publishers/my-publisher/models/my-model` (fully qualified model name)
+ *
+ * @throws If the `apiKey` or `projectId` fields are missing in your
+ * Firebase config.
+ *
+ * @internal
+ */
+ protected constructor(ai: AI, modelName: string) {
+ if (!ai.app?.options?.apiKey) {
+ throw new AIError(
+ AIErrorCode.NO_API_KEY,
+ `The "apiKey" field is empty in the local Firebase config. Firebase AI requires this field to contain a valid API key.`,
+ );
+ } else if (!ai.app?.options?.projectId) {
+ throw new AIError(
+ AIErrorCode.NO_PROJECT_ID,
+ `The "projectId" field is empty in the local Firebase config. Firebase AI requires this field to contain a valid project ID.`,
+ );
+ } else if (!ai.app?.options?.appId) {
+ throw new AIError(
+ AIErrorCode.NO_APP_ID,
+ `The "appId" field is empty in the local Firebase config. Firebase AI requires this field to contain a valid app ID.`,
+ );
+ } else {
+ this._apiSettings = {
+ apiKey: ai.app.options.apiKey,
+ project: ai.app.options.projectId,
+ appId: ai.app.options.appId,
+ automaticDataCollectionEnabled: ai.app.automaticDataCollectionEnabled,
+ location: ai.location,
+ backend: ai.backend,
+ };
+
+ if ((ai as AIService).appCheck) {
+ this._apiSettings.getAppCheckToken = () => (ai as AIService).appCheck!.getToken();
+ }
+
+ if ((ai as AIService).auth?.currentUser) {
+ this._apiSettings.getAuthToken = () => (ai as AIService).auth!.currentUser!.getIdToken();
+ }
+
+ this.model = AIModel.normalizeModelName(modelName, this._apiSettings.backend.backendType);
+ }
+ }
+
+ /**
+ * Normalizes the given model name to a fully qualified model resource name.
+ *
+ * @param modelName - The model name to normalize.
+ * @returns The fully qualified model resource name.
+ *
+ * @internal
+ */
+ static normalizeModelName(modelName: string, backendType: BackendType): string {
+ if (backendType === BackendType.GOOGLE_AI) {
+ return AIModel.normalizeGoogleAIModelName(modelName);
+ } else {
+ return AIModel.normalizeVertexAIModelName(modelName);
+ }
+ }
+
+ /**
+ * @internal
+ */
+ private static normalizeGoogleAIModelName(modelName: string): string {
+ return `models/${modelName}`;
+ }
+
+ /**
+ * @internal
+ */
+ private static normalizeVertexAIModelName(modelName: string): string {
+ let model: string;
+ if (modelName.includes('/')) {
+ if (modelName.startsWith('models/')) {
+ // Add 'publishers/google' if the user is only passing in 'models/model-name'.
+ model = `publishers/google/${modelName}`;
+ } else {
+ // Any other custom format (e.g. tuned models) must be passed in correctly.
+ model = modelName;
+ }
+ } else {
+ // If path is not included, assume it's a non-tuned model.
+ model = `publishers/google/models/${modelName}`;
+ }
+
+ return model;
+ }
+}
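The normalization rules above, illustrated (derived directly from the code in this file; the publisher path in the third line is a hypothetical custom model):

```ts
// Vertex AI backend:
//   'gemini-1.5-flash'           -> 'publishers/google/models/gemini-1.5-flash'
//   'models/gemini-1.5-flash'    -> 'publishers/google/models/gemini-1.5-flash'
//   'publishers/acme/models/foo' -> 'publishers/acme/models/foo' (passed through)
//
// Google AI backend:
//   'gemini-1.5-flash'           -> 'models/gemini-1.5-flash'
```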
diff --git a/packages/vertexai/lib/models/generative-model.ts b/packages/ai/lib/models/generative-model.ts
similarity index 64%
rename from packages/vertexai/lib/models/generative-model.ts
rename to packages/ai/lib/models/generative-model.ts
index 111cefa427..c3bba041e0 100644
--- a/packages/vertexai/lib/models/generative-model.ts
+++ b/packages/ai/lib/models/generative-model.ts
@@ -31,23 +31,18 @@ import {
StartChatParams,
Tool,
ToolConfig,
- VertexAIErrorCode,
} from '../types';
-import { VertexAIError } from '../errors';
import { ChatSession } from '../methods/chat-session';
import { countTokens } from '../methods/count-tokens';
import { formatGenerateContentInput, formatSystemInstruction } from '../requests/request-helpers';
-import { VertexAI } from '../public-types';
-import { ApiSettings } from '../types/internal';
-import { VertexAIService } from '../service';
+import { AIModel } from './ai-model';
+import { AI } from '../public-types';
/**
* Class for generative model APIs.
* @public
*/
-export class GenerativeModel {
- private _apiSettings: ApiSettings;
- model: string;
+export class GenerativeModel extends AIModel {
generationConfig: GenerationConfig;
safetySettings: SafetySetting[];
requestOptions?: RequestOptions;
@@ -55,45 +50,8 @@ export class GenerativeModel {
toolConfig?: ToolConfig;
systemInstruction?: Content;
- constructor(vertexAI: VertexAI, modelParams: ModelParams, requestOptions?: RequestOptions) {
- if (!vertexAI.app?.options?.apiKey) {
- throw new VertexAIError(
- VertexAIErrorCode.NO_API_KEY,
- `The "apiKey" field is empty in the local Firebase config. Firebase VertexAI requires this field to contain a valid API key.`,
- );
- } else if (!vertexAI.app?.options?.projectId) {
- throw new VertexAIError(
- VertexAIErrorCode.NO_PROJECT_ID,
- `The "projectId" field is empty in the local Firebase config. Firebase VertexAI requires this field to contain a valid project ID.`,
- );
- } else {
- this._apiSettings = {
- apiKey: vertexAI.app.options.apiKey,
- project: vertexAI.app.options.projectId,
- location: vertexAI.location,
- };
- if ((vertexAI as VertexAIService).appCheck) {
- this._apiSettings.getAppCheckToken = () =>
- (vertexAI as VertexAIService).appCheck!.getToken();
- }
-
- if ((vertexAI as VertexAIService).auth?.currentUser) {
- this._apiSettings.getAuthToken = () =>
- (vertexAI as VertexAIService).auth!.currentUser!.getIdToken();
- }
- }
- if (modelParams.model.includes('/')) {
- if (modelParams.model.startsWith('models/')) {
- // Add "publishers/google" if the user is only passing in 'models/model-name'.
- this.model = `publishers/google/${modelParams.model}`;
- } else {
- // Any other custom format (e.g. tuned models) must be passed in correctly.
- this.model = modelParams.model;
- }
- } else {
- // If path is not included, assume it's a non-tuned model.
- this.model = `publishers/google/models/${modelParams.model}`;
- }
+ constructor(ai: AI, modelParams: ModelParams, requestOptions?: RequestOptions) {
+ super(ai, modelParams.model);
this.generationConfig = modelParams.generationConfig || {};
this.safetySettings = modelParams.safetySettings || [];
this.tools = modelParams.tools;
@@ -104,7 +62,7 @@ export class GenerativeModel {
/**
* Makes a single non-streaming call to the model
- * and returns an object containing a single {@link GenerateContentResponse}
.
+ * and returns an object containing a single {@link GenerateContentResponse}.
*/
async generateContent(
request: GenerateContentRequest | string | Array<string | Part>,
@@ -151,7 +109,7 @@ export class GenerativeModel {
}
/**
- * Gets a new {@link ChatSession}
instance which can be used for
+ * Gets a new {@link ChatSession} instance which can be used for
* multi-turn chats.
*/
startChat(startChatParams?: StartChatParams): ChatSession {
@@ -162,6 +120,13 @@ export class GenerativeModel {
tools: this.tools,
toolConfig: this.toolConfig,
systemInstruction: this.systemInstruction,
+ generationConfig: this.generationConfig,
+ safetySettings: this.safetySettings,
+ /**
+ * Overrides params inherited from GenerativeModel with those explicitly set in the
+ * StartChatParams. For example, if startChatParams.generationConfig is set, it'll override
+ * this.generationConfig.
+ */
...startChatParams,
},
this.requestOptions,
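Because the model's own `generationConfig` and `safetySettings` now seed the chat params and `startChatParams` is spread last, per-chat settings win. A sketch (with `ai` from `getAI`):

```ts
const model = getGenerativeModel(ai, {
  model: 'gemini-1.5-flash',
  generationConfig: { temperature: 0.9 },
});

// This chat runs at temperature 0; every other model-level param is inherited.
const chat = model.startChat({
  generationConfig: { temperature: 0 },
});
```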
diff --git a/packages/ai/lib/models/index.ts b/packages/ai/lib/models/index.ts
new file mode 100644
index 0000000000..fcfba15507
--- /dev/null
+++ b/packages/ai/lib/models/index.ts
@@ -0,0 +1,2 @@
+export * from './ai-model';
+export * from './generative-model';
diff --git a/packages/vertexai/lib/polyfills.ts b/packages/ai/lib/polyfills.ts
similarity index 100%
rename from packages/vertexai/lib/polyfills.ts
rename to packages/ai/lib/polyfills.ts
diff --git a/packages/ai/lib/public-types.ts b/packages/ai/lib/public-types.ts
new file mode 100644
index 0000000000..f82f899511
--- /dev/null
+++ b/packages/ai/lib/public-types.ts
@@ -0,0 +1,137 @@
+/**
+ * @license
+ * Copyright 2024 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import { ReactNativeFirebase } from '@react-native-firebase/app';
+import { FirebaseAuthTypes } from '@react-native-firebase/auth';
+import { FirebaseAppCheckTypes } from '@react-native-firebase/app-check';
+
+export * from './types';
+
+/**
+ * Options for initializing the AI service using {@link getAI | getAI()}.
+ * This allows specifying which backend to use (Vertex AI Gemini API or Gemini Developer API)
+ * and configuring its specific options (like location for Vertex AI).
+ *
+ * @public
+ */
+export interface AIOptions {
+ /**
+ * The backend configuration to use for the AI service instance.
+ */
+ backend: Backend;
+ appCheck?: FirebaseAppCheckTypes.Module | null;
+ auth?: FirebaseAuthTypes.Module | null;
+}
+
+/**
+ * Abstract base class representing the configuration for an AI service backend.
+ * This class should not be instantiated directly. Use its subclasses: {@link GoogleAIBackend} for
+ * the Gemini Developer API (via {@link https://ai.google/ | Google AI}), and
+ * {@link VertexAIBackend} for the Vertex AI Gemini API.
+ *
+ * @public
+ */
+export abstract class Backend {
+ /**
+ * Specifies the backend type.
+ */
+ readonly backendType: BackendType;
+
+ /**
+ * Protected constructor for use by subclasses.
+ * @param type - The backend type.
+ */
+ protected constructor(type: BackendType) {
+ this.backendType = type;
+ }
+}
+
+/**
+ * An enum-like object containing constants that represent the supported backends
+ * for the Firebase AI SDK.
+ * This determines which backend service (Vertex AI Gemini API or Gemini Developer API)
+ * the SDK will communicate with.
+ *
+ * These values are assigned to the `backendType` property within the specific backend
+ * configuration objects ({@link GoogleAIBackend} or {@link VertexAIBackend}) to identify
+ * which service to target.
+ *
+ * @public
+ */
+export const BackendType = {
+ /**
+ * Identifies the backend service for the Vertex AI Gemini API provided through Google Cloud.
+ * Use this constant when creating a {@link VertexAIBackend} configuration.
+ */
+ VERTEX_AI: 'VERTEX_AI',
+
+ /**
+ * Identifies the backend service for the Gemini Developer API ({@link https://ai.google/ | Google AI}).
+ * Use this constant when creating a {@link GoogleAIBackend} configuration.
+ */
+ GOOGLE_AI: 'GOOGLE_AI',
+} as const; // Using 'as const' makes the string values literal types
+
+/**
+ * Type alias representing valid backend types.
+ * It can be either `'VERTEX_AI'` or `'GOOGLE_AI'`.
+ *
+ * @public
+ */
+export type BackendType = (typeof BackendType)[keyof typeof BackendType];
+
+/**
+ * An instance of the Firebase AI SDK.
+ *
+ * Do not create this instance directly. Instead, use {@link getAI | getAI()}.
+ *
+ * @public
+ */
+export interface AI {
+ /**
+ * The {@link @firebase/app#FirebaseApp} this {@link AI} instance is associated with.
+ */
+ app: ReactNativeFirebase.FirebaseApp;
+ appCheck?: FirebaseAppCheckTypes.Module | null;
+ auth?: FirebaseAuthTypes.Module | null;
+ /**
+ * A {@link Backend} instance that specifies the configuration for the target backend,
+ * either the Gemini Developer API (using {@link GoogleAIBackend}) or the
+ * Vertex AI Gemini API (using {@link VertexAIBackend}).
+ */
+ backend: Backend;
+ /**
+ * @deprecated use `AI.backend.location` instead.
+ *
+ * The location configured for this AI service instance, relevant for Vertex AI backends.
+ */
+ location: string;
+}
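A sketch of initializing against each backend (assuming the `GoogleAIBackend` and `VertexAIBackend` classes from `backend.ts`, not shown in this hunk, take these constructor arguments):

```ts
import { getApp } from '@react-native-firebase/app';
import { getAI, GoogleAIBackend, VertexAIBackend } from '@react-native-firebase/ai';

// Gemini Developer API:
const googleAI = getAI(getApp(), { backend: new GoogleAIBackend() });

// Vertex AI Gemini API, pinned to a region:
const vertexAI = getAI(getApp(), { backend: new VertexAIBackend('us-central1') });
```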
diff --git a/packages/vertexai/lib/requests/request-helpers.ts b/packages/ai/lib/requests/request-helpers.ts
similarity index 92%
rename from packages/vertexai/lib/requests/request-helpers.ts
rename to packages/ai/lib/requests/request-helpers.ts
index 9de045a4ee..6d468f4023 100644
--- a/packages/vertexai/lib/requests/request-helpers.ts
+++ b/packages/ai/lib/requests/request-helpers.ts
@@ -15,8 +15,8 @@
* limitations under the License.
*/
-import { Content, GenerateContentRequest, Part, VertexAIErrorCode } from '../types';
-import { VertexAIError } from '../errors';
+import { Content, GenerateContentRequest, Part, AIErrorCode } from '../types';
+import { AIError } from '../errors';
export function formatSystemInstruction(input?: string | Part | Content): Content | undefined {
if (input == null) {
@@ -32,7 +32,6 @@ export function formatSystemInstruction(input?: string | Part | Content): Conten
return input as Content;
}
}
-
return undefined;
}
@@ -76,15 +75,15 @@ function assignRoleToPartsAndValidateSendMessageRequest(parts: Part[]): Content
}
if (hasUserContent && hasFunctionContent) {
- throw new VertexAIError(
- VertexAIErrorCode.INVALID_CONTENT,
+ throw new AIError(
+ AIErrorCode.INVALID_CONTENT,
'Within a single message, FunctionResponse cannot be mixed with other type of Part in the request for sending chat message.',
);
}
if (!hasUserContent && !hasFunctionContent) {
- throw new VertexAIError(
- VertexAIErrorCode.INVALID_CONTENT,
+ throw new AIError(
+ AIErrorCode.INVALID_CONTENT,
'No Content is provided for sending chat message.',
);
}
diff --git a/packages/vertexai/lib/requests/request.ts b/packages/ai/lib/requests/request.ts
similarity index 76%
rename from packages/vertexai/lib/requests/request.ts
rename to packages/ai/lib/requests/request.ts
index e055094f90..325e2cf444 100644
--- a/packages/vertexai/lib/requests/request.ts
+++ b/packages/ai/lib/requests/request.ts
@@ -15,8 +15,8 @@
* limitations under the License.
*/
import { Platform } from 'react-native';
-import { ErrorDetails, RequestOptions, VertexAIErrorCode } from '../types';
-import { VertexAIError } from '../errors';
+import { AIErrorCode, ErrorDetails, RequestOptions } from '../types';
+import { AIError } from '../errors';
import { ApiSettings } from '../types/internal';
import {
DEFAULT_API_VERSION,
@@ -26,11 +26,13 @@ import {
PACKAGE_VERSION,
} from '../constants';
import { logger } from '../logger';
+import { GoogleAIBackend, VertexAIBackend } from '../backend';
export enum Task {
GENERATE_CONTENT = 'generateContent',
STREAM_GENERATE_CONTENT = 'streamGenerateContent',
COUNT_TOKENS = 'countTokens',
+ PREDICT = 'predict',
}
export class RequestUrl {
@@ -59,28 +61,47 @@ export class RequestUrl {
return emulatorUrl;
}
- const apiVersion = DEFAULT_API_VERSION;
- const baseUrl = this.requestOptions?.baseUrl || DEFAULT_BASE_URL;
- let url = `${baseUrl}/${apiVersion}`;
- url += `/projects/${this.apiSettings.project}`;
- url += `/locations/${this.apiSettings.location}`;
- url += `/${this.model}`;
- url += `:${this.task}`;
- if (this.stream) {
- url += '?alt=sse';
+ // Manually construct URL to avoid React Native URL API issues
+ let baseUrl = this.baseUrl;
+ // Remove trailing slash if present
+ if (baseUrl.endsWith('/')) {
+ baseUrl = baseUrl.slice(0, -1);
}
- return url;
+
+ const pathname = `/${this.apiVersion}/${this.modelPath}:${this.task}`;
+ const queryString = this.queryParams;
+
+ return `${baseUrl}${pathname}${queryString ? `?${queryString}` : ''}`;
+ }
+
+ private get baseUrl(): string {
+ return this.requestOptions?.baseUrl || DEFAULT_BASE_URL;
}
- /**
- * If the model needs to be passed to the backend, it needs to
- * include project and location path.
- */
- get fullModelString(): string {
- let modelString = `projects/${this.apiSettings.project}`;
- modelString += `/locations/${this.apiSettings.location}`;
- modelString += `/${this.model}`;
- return modelString;
+ private get apiVersion(): string {
+ return DEFAULT_API_VERSION;
+ }
+
+ private get modelPath(): string {
+ if (this.apiSettings.backend instanceof GoogleAIBackend) {
+ return `projects/${this.apiSettings.project}/${this.model}`;
+ } else if (this.apiSettings.backend instanceof VertexAIBackend) {
+ return `projects/${this.apiSettings.project}/locations/${this.apiSettings.backend.location}/${this.model}`;
+ } else {
+ throw new AIError(
+ AIErrorCode.ERROR,
+ `Invalid backend: ${JSON.stringify(this.apiSettings.backend)}`,
+ );
+ }
+ }
+
+ private get queryParams(): string {
+ let params = '';
+ if (this.stream) {
+ params += 'alt=sse';
+ }
+
+ return params;
}
}
@@ -99,6 +120,9 @@ export async function getHeaders(url: RequestUrl): Promise<Headers> {
headers.append('Content-Type', 'application/json');
headers.append('x-goog-api-client', getClientHeaders());
headers.append('x-goog-api-key', url.apiSettings.apiKey);
+ if (url.apiSettings.automaticDataCollectionEnabled) {
+ headers.append('X-Firebase-Appid', url.apiSettings.appId);
+ }
if (url.apiSettings.getAppCheckToken) {
let appCheckToken;
@@ -154,6 +178,7 @@ export async function makeRequest(
let fetchTimeoutId: string | number | NodeJS.Timeout | undefined;
try {
const request = await constructRequest(model, task, apiSettings, stream, body, requestOptions);
+
const timeoutMillis =
requestOptions?.timeout != null && requestOptions.timeout >= 0
? requestOptions.timeout
@@ -192,9 +217,9 @@ export async function makeRequest(
),
)
) {
- throw new VertexAIError(
- VertexAIErrorCode.API_NOT_ENABLED,
- `The Vertex AI in Firebase SDK requires the Vertex AI in Firebase ` +
+ throw new AIError(
+ AIErrorCode.API_NOT_ENABLED,
+ `The Firebase AI SDK requires the Firebase AI ` +
`API ('firebasevertexai.googleapis.com') to be enabled in your ` +
`Firebase project. Enable this API by visiting the Firebase Console ` +
`at https://console.firebase.google.com/project/${url.apiSettings.project}/genai/ ` +
@@ -208,8 +233,8 @@ export async function makeRequest(
},
);
}
- throw new VertexAIError(
- VertexAIErrorCode.FETCH_ERROR,
+ throw new AIError(
+ AIErrorCode.FETCH_ERROR,
`Error fetching from ${url}: [${response.status} ${response.statusText}] ${message}`,
{
status: response.status,
@@ -221,14 +246,11 @@ export async function makeRequest(
} catch (e) {
let err = e as Error;
if (
- (e as VertexAIError).code !== VertexAIErrorCode.FETCH_ERROR &&
- (e as VertexAIError).code !== VertexAIErrorCode.API_NOT_ENABLED &&
+ (e as AIError).code !== AIErrorCode.FETCH_ERROR &&
+ (e as AIError).code !== AIErrorCode.API_NOT_ENABLED &&
e instanceof Error
) {
- err = new VertexAIError(
- VertexAIErrorCode.ERROR,
- `Error fetching from ${url.toString()}: ${e.message}`,
- );
+ err = new AIError(AIErrorCode.ERROR, `Error fetching from ${url.toString()}: ${e.message}`);
err.stack = e.stack;
}
diff --git a/packages/vertexai/lib/requests/response-helpers.ts b/packages/ai/lib/requests/response-helpers.ts
similarity index 73%
rename from packages/vertexai/lib/requests/response-helpers.ts
rename to packages/ai/lib/requests/response-helpers.ts
index c7abc9d923..4fdb2362bd 100644
--- a/packages/vertexai/lib/requests/response-helpers.ts
+++ b/packages/ai/lib/requests/response-helpers.ts
@@ -21,9 +21,10 @@ import {
FunctionCall,
GenerateContentCandidate,
GenerateContentResponse,
- VertexAIErrorCode,
+ AIErrorCode,
+ InlineDataPart,
} from '../types';
-import { VertexAIError } from '../errors';
+import { AIError } from '../errors';
import { logger } from '../logger';
/**
@@ -62,8 +63,8 @@ export function addHelpers(response: GenerateContentResponse): EnhancedGenerateC
);
}
if (hadBadFinishReason(response.candidates[0]!)) {
- throw new VertexAIError(
- VertexAIErrorCode.RESPONSE_ERROR,
+ throw new AIError(
+ AIErrorCode.RESPONSE_ERROR,
`Response error: ${formatBlockErrorMessage(
response,
)}. Response body stored in error.response`,
@@ -74,8 +75,8 @@ export function addHelpers(response: GenerateContentResponse): EnhancedGenerateC
}
return getText(response);
} else if (response.promptFeedback) {
- throw new VertexAIError(
- VertexAIErrorCode.RESPONSE_ERROR,
+ throw new AIError(
+ AIErrorCode.RESPONSE_ERROR,
`Text not available. ${formatBlockErrorMessage(response)}`,
{
response,
@@ -84,6 +85,40 @@ export function addHelpers(response: GenerateContentResponse): EnhancedGenerateC
}
return '';
};
+ (response as EnhancedGenerateContentResponse).inlineDataParts = ():
+ | InlineDataPart[]
+ | undefined => {
+ if (response.candidates && response.candidates.length > 0) {
+ if (response.candidates.length > 1) {
+ logger.warn(
+ `This response had ${response.candidates.length} ` +
+ `candidates. Returning data from the first candidate only. ` +
+ `Access response.candidates directly to use the other candidates.`,
+ );
+ }
+ if (hadBadFinishReason(response.candidates[0]!)) {
+ throw new AIError(
+ AIErrorCode.RESPONSE_ERROR,
+ `Response error: ${formatBlockErrorMessage(
+ response,
+ )}. Response body stored in error.response`,
+ {
+ response,
+ },
+ );
+ }
+ return getInlineDataParts(response);
+ } else if (response.promptFeedback) {
+ throw new AIError(
+ AIErrorCode.RESPONSE_ERROR,
+ `Data not available. ${formatBlockErrorMessage(response)}`,
+ {
+ response,
+ },
+ );
+ }
+ return undefined;
+ };
(response as EnhancedGenerateContentResponse).functionCalls = () => {
if (response.candidates && response.candidates.length > 0) {
if (response.candidates.length > 1) {
@@ -94,8 +129,8 @@ export function addHelpers(response: GenerateContentResponse): EnhancedGenerateC
);
}
if (hadBadFinishReason(response.candidates[0]!)) {
- throw new VertexAIError(
- VertexAIErrorCode.RESPONSE_ERROR,
+ throw new AIError(
+ AIErrorCode.RESPONSE_ERROR,
`Response error: ${formatBlockErrorMessage(
response,
)}. Response body stored in error.response`,
@@ -106,8 +141,8 @@ export function addHelpers(response: GenerateContentResponse): EnhancedGenerateC
}
return getFunctionCalls(response);
} else if (response.promptFeedback) {
- throw new VertexAIError(
- VertexAIErrorCode.RESPONSE_ERROR,
+ throw new AIError(
+ AIErrorCode.RESPONSE_ERROR,
`Function call not available. ${formatBlockErrorMessage(response)}`,
{
response,
@@ -125,7 +160,7 @@ export function addHelpers(response: GenerateContentResponse): EnhancedGenerateC
export function getText(response: GenerateContentResponse): string {
const textStrings = [];
if (response.candidates?.[0]?.content?.parts) {
- for (const part of response.candidates?.[0].content?.parts) {
+ for (const part of response.candidates?.[0]?.content?.parts) {
if (part.text) {
textStrings.push(part.text);
}
@@ -139,7 +174,7 @@ export function getText(response: GenerateContentResponse): string {
}
/**
- * Returns {@link FunctionCall}
s associated with first candidate.
+ * Returns {@link FunctionCall}s associated with first candidate.
*/
export function getFunctionCalls(response: GenerateContentResponse): FunctionCall[] | undefined {
const functionCalls: FunctionCall[] = [];
@@ -157,6 +192,31 @@ export function getFunctionCalls(response: GenerateContentResponse): FunctionCal
}
}
+/**
+ * Returns {@link InlineDataPart}s in the first candidate if present.
+ *
+ * @internal
+ */
+export function getInlineDataParts(
+ response: GenerateContentResponse,
+): InlineDataPart[] | undefined {
+ const data: InlineDataPart[] = [];
+
+ if (response.candidates?.[0]?.content?.parts) {
+ for (const part of response.candidates?.[0]?.content?.parts) {
+ if (part.inlineData) {
+ data.push(part);
+ }
+ }
+ }
+
+ if (data.length > 0) {
+ return data;
+ } else {
+ return undefined;
+ }
+}
+
const badFinishReasons = [FinishReason.RECITATION, FinishReason.SAFETY];
function hadBadFinishReason(candidate: GenerateContentCandidate): boolean {
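A sketch of consuming the new `inlineDataParts()` helper (assuming an image-capable model, and that `InlineDataPart.inlineData` carries a `mimeType` and a base64 `data` string):

```ts
const result = await model.generateContent('Generate a picture of a lighthouse.');

const parts = result.response.inlineDataParts();
if (parts) {
  for (const part of parts) {
    console.log(part.inlineData.mimeType, `${part.inlineData.data.length} base64 chars`);
  }
}
```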
diff --git a/packages/vertexai/lib/requests/schema-builder.ts b/packages/ai/lib/requests/schema-builder.ts
similarity index 94%
rename from packages/vertexai/lib/requests/schema-builder.ts
rename to packages/ai/lib/requests/schema-builder.ts
index 92003a0950..21c5605cb7 100644
--- a/packages/vertexai/lib/requests/schema-builder.ts
+++ b/packages/ai/lib/requests/schema-builder.ts
@@ -15,8 +15,8 @@
* limitations under the License.
*/
-import { VertexAIError } from '../errors';
-import { VertexAIErrorCode } from '../types';
+import { AIError } from '../errors';
+import { AIErrorCode } from '../types';
import {
SchemaInterface,
SchemaType,
@@ -49,6 +49,12 @@ export abstract class Schema implements SchemaInterface {
format?: string;
/** Optional. The description of the property. */
description?: string;
+ /** Optional. The items of the property. */
+ items?: SchemaInterface;
+ /** The minimum number of items (elements) in a schema of type {@link SchemaType.ARRAY}. */
+ minItems?: number;
+ /** The maximum number of items (elements) in a schema of type {@link SchemaType.ARRAY}. */
+ maxItems?: number;
/** Optional. Whether the property is nullable. Defaults to false. */
nullable: boolean;
/** Optional. The example of the property. */
@@ -257,8 +263,8 @@ export class ObjectSchema extends Schema {
if (this.optionalProperties) {
for (const propertyKey of this.optionalProperties) {
if (!this.properties.hasOwnProperty(propertyKey)) {
- throw new VertexAIError(
- VertexAIErrorCode.INVALID_SCHEMA,
+ throw new AIError(
+ AIErrorCode.INVALID_SCHEMA,
`Property "${propertyKey}" specified in "optionalProperties" does not exist.`,
);
}
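The new `items`/`minItems`/`maxItems` fields target array schemas. A sketch using the static builders on `Schema` (builder names assumed from this module's exports):

```ts
import { Schema } from '@react-native-firebase/ai';

// An array of one to five strings.
const schema = Schema.array({
  items: Schema.string(),
  minItems: 1,
  maxItems: 5,
});
```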
diff --git a/packages/vertexai/lib/requests/stream-reader.ts b/packages/ai/lib/requests/stream-reader.ts
similarity index 80%
rename from packages/vertexai/lib/requests/stream-reader.ts
rename to packages/ai/lib/requests/stream-reader.ts
index d24f6d44bf..6fea165c26 100644
--- a/packages/vertexai/lib/requests/stream-reader.ts
+++ b/packages/ai/lib/requests/stream-reader.ts
@@ -22,10 +22,14 @@ import {
GenerateContentResponse,
GenerateContentStreamResult,
Part,
- VertexAIErrorCode,
+ AIErrorCode,
} from '../types';
-import { VertexAIError } from '../errors';
+import { AIError } from '../errors';
import { createEnhancedContentResponse } from './response-helpers';
+import { ApiSettings } from '../types/internal';
+import { BackendType } from '../public-types';
+import * as GoogleAIMapper from '../googleai-mappers';
+import { GoogleAIGenerateContentResponse } from '../types/googleai';
const responseLineRE = /^data\: (.*)(?:\n\n|\r\r|\r\n\r\n)/;
@@ -37,7 +41,10 @@ const responseLineRE = /^data\: (.*)(?:\n\n|\r\r|\r\n\r\n)/;
*
* @param response - Response from a fetch call
*/
-export function processStream(response: Response): GenerateContentStreamResult {
+export function processStream(
+ response: Response,
+ apiSettings: ApiSettings,
+): GenerateContentStreamResult {
const inputStream = new ReadableStream({
async start(controller) {
const reader = response.body!.getReader();
@@ -56,28 +63,36 @@ export function processStream(response: Response): GenerateContentStreamResult {
const responseStream = getResponseStream(inputStream);
const [stream1, stream2] = responseStream.tee();
return {
- stream: generateResponseSequence(stream1),
- response: getResponsePromise(stream2),
+ stream: generateResponseSequence(stream1, apiSettings),
+ response: getResponsePromise(stream2, apiSettings),
};
}
async function getResponsePromise(
stream: ReadableStream<GenerateContentResponse>,
+ apiSettings: ApiSettings,
): Promise {
const allResponses: GenerateContentResponse[] = [];
const reader = stream.getReader();
while (true) {
const { done, value } = await reader.read();
if (done) {
- const enhancedResponse = createEnhancedContentResponse(aggregateResponses(allResponses));
- return enhancedResponse;
+ let generateContentResponse = aggregateResponses(allResponses);
+ if (apiSettings.backend.backendType === BackendType.GOOGLE_AI) {
+ generateContentResponse = GoogleAIMapper.mapGenerateContentResponse(
+ generateContentResponse as GoogleAIGenerateContentResponse,
+ );
+ }
+ return createEnhancedContentResponse(generateContentResponse);
}
+
allResponses.push(value);
}
}
async function* generateResponseSequence(
stream: ReadableStream<GenerateContentResponse>,
+ apiSettings: ApiSettings,
): AsyncGenerator<EnhancedGenerateContentResponse> {
const reader = stream.getReader();
while (true) {
@@ -86,7 +101,15 @@ async function* generateResponseSequence(
break;
}
- const enhancedResponse = createEnhancedContentResponse(value);
+ let enhancedResponse: EnhancedGenerateContentResponse;
+ if (apiSettings.backend.backendType === BackendType.GOOGLE_AI) {
+ enhancedResponse = createEnhancedContentResponse(
+ GoogleAIMapper.mapGenerateContentResponse(value as GoogleAIGenerateContentResponse),
+ );
+ } else {
+ enhancedResponse = createEnhancedContentResponse(value);
+ }
+
yield enhancedResponse;
}
}
@@ -106,9 +129,7 @@ export function getResponseStream<T>(inputStream: ReadableStream<string>): ReadableStream<T> {
return reader.read().then(({ value, done }) => {
if (done) {
if (currentText.trim()) {
- controller.error(
- new VertexAIError(VertexAIErrorCode.PARSE_FAILED, 'Failed to parse stream'),
- );
+ controller.error(new AIError(AIErrorCode.PARSE_FAILED, 'Failed to parse stream'));
return;
}
controller.close();
@@ -123,10 +144,7 @@ export function getResponseStream(inputStream: ReadableStream): Reada
parsedResponse = JSON.parse(match[1]!);
} catch (_) {
controller.error(
- new VertexAIError(
- VertexAIErrorCode.PARSE_FAILED,
- `Error parsing JSON response: "${match[1]}`,
- ),
+ new AIError(AIErrorCode.PARSE_FAILED, `Error parsing JSON response: "${match[1]}"`),
);
return;
}
@@ -197,8 +215,8 @@ export function aggregateResponses(responses: GenerateContentResponse[]): Genera
newPart.functionCall = part.functionCall;
}
if (Object.keys(newPart).length === 0) {
- throw new VertexAIError(
- VertexAIErrorCode.INVALID_CONTENT,
+ throw new AIError(
+ AIErrorCode.INVALID_CONTENT,
'Part should have at least one property, but there are none. This is likely caused ' +
'by a malformed response from the backend.',
);
diff --git a/packages/vertexai/lib/service.ts b/packages/ai/lib/service.ts
similarity index 79%
rename from packages/vertexai/lib/service.ts
rename to packages/ai/lib/service.ts
index e90ffa9668..79bf741303 100644
--- a/packages/vertexai/lib/service.ts
+++ b/packages/ai/lib/service.ts
@@ -16,24 +16,28 @@
*/
import { ReactNativeFirebase } from '@react-native-firebase/app';
-import { VertexAI, VertexAIOptions } from './public-types';
-import { DEFAULT_LOCATION } from './constants';
+import { AI, Backend } from './public-types';
import { FirebaseAuthTypes } from '@react-native-firebase/auth';
import { FirebaseAppCheckTypes } from '@react-native-firebase/app-check';
+import { VertexAIBackend } from './backend';
-export class VertexAIService implements VertexAI {
+export class AIService implements AI {
auth: FirebaseAuthTypes.Module | null;
appCheck: FirebaseAppCheckTypes.Module | null;
location: string;
constructor(
public app: ReactNativeFirebase.FirebaseApp,
+ public backend: Backend,
auth?: FirebaseAuthTypes.Module,
appCheck?: FirebaseAppCheckTypes.Module,
- public options?: VertexAIOptions,
) {
this.auth = auth || null;
this.appCheck = appCheck || null;
- this.location = this.options?.location || DEFAULT_LOCATION;
+ if (backend instanceof VertexAIBackend) {
+ this.location = backend.location;
+ } else {
+ this.location = '';
+ }
}
}
diff --git a/packages/vertexai/lib/types/content.ts b/packages/ai/lib/types/content.ts
similarity index 100%
rename from packages/vertexai/lib/types/content.ts
rename to packages/ai/lib/types/content.ts
diff --git a/packages/ai/lib/types/enums.ts b/packages/ai/lib/types/enums.ts
new file mode 100644
index 0000000000..035d26703e
--- /dev/null
+++ b/packages/ai/lib/types/enums.ts
@@ -0,0 +1,281 @@
+/**
+ * @license
+ * Copyright 2024 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Role is the producer of the content.
+ * @public
+ */
+export type Role = (typeof POSSIBLE_ROLES)[number];
+
+/**
+ * Possible roles.
+ * @public
+ */
+export const POSSIBLE_ROLES = ['user', 'model', 'function', 'system'] as const;
+
+/**
+ * Harm categories that would cause prompts or candidates to be blocked.
+ * @public
+ */
+export enum HarmCategory {
+ HARM_CATEGORY_HATE_SPEECH = 'HARM_CATEGORY_HATE_SPEECH',
+ HARM_CATEGORY_SEXUALLY_EXPLICIT = 'HARM_CATEGORY_SEXUALLY_EXPLICIT',
+ HARM_CATEGORY_HARASSMENT = 'HARM_CATEGORY_HARASSMENT',
+ HARM_CATEGORY_DANGEROUS_CONTENT = 'HARM_CATEGORY_DANGEROUS_CONTENT',
+}
+
+/**
+ * Threshold above which a prompt or candidate will be blocked.
+ * @public
+ */
+export enum HarmBlockThreshold {
+ /**
+ * Content with `NEGLIGIBLE` will be allowed.
+ */
+ BLOCK_LOW_AND_ABOVE = 'BLOCK_LOW_AND_ABOVE',
+ /**
+ * Content with `NEGLIGIBLE` and `LOW` will be allowed.
+ */
+ BLOCK_MEDIUM_AND_ABOVE = 'BLOCK_MEDIUM_AND_ABOVE',
+ /**
+ * Content with `NEGLIGIBLE`, `LOW`, and `MEDIUM` will be allowed.
+ */
+ BLOCK_ONLY_HIGH = 'BLOCK_ONLY_HIGH',
+ /**
+ * All content will be allowed.
+ */
+ BLOCK_NONE = 'BLOCK_NONE',
+ /**
+ * All content will be allowed. This is the same as `BLOCK_NONE`, but the metadata corresponding
+ * to the {@link HarmCategory} will not be present in the response.
+ */
+ OFF = 'OFF',
+}
+
+/**
+ * This property is not supported in the Gemini Developer API ({@link GoogleAIBackend}).
+ *
+ * @public
+ */
+export enum HarmBlockMethod {
+ /**
+ * The harm block method uses both probability and severity scores.
+ */
+ SEVERITY = 'SEVERITY',
+ /**
+ * The harm block method uses the probability score.
+ */
+ PROBABILITY = 'PROBABILITY',
+}
+
+/**
+ * Probability that a prompt or candidate matches a harm category.
+ * @public
+ */
+export enum HarmProbability {
+ /**
+ * Content has a negligible chance of being unsafe.
+ */
+ NEGLIGIBLE = 'NEGLIGIBLE',
+ /**
+ * Content has a low chance of being unsafe.
+ */
+ LOW = 'LOW',
+ /**
+ * Content has a medium chance of being unsafe.
+ */
+ MEDIUM = 'MEDIUM',
+ /**
+ * Content has a high chance of being unsafe.
+ */
+ HIGH = 'HIGH',
+}
+
+/**
+ * Harm severity levels.
+ * @public
+ */
+export enum HarmSeverity {
+ /**
+ * Negligible level of harm severity.
+ */
+ HARM_SEVERITY_NEGLIGIBLE = 'HARM_SEVERITY_NEGLIGIBLE',
+ /**
+ * Low level of harm severity.
+ */
+ HARM_SEVERITY_LOW = 'HARM_SEVERITY_LOW',
+ /**
+ * Medium level of harm severity.
+ */
+ HARM_SEVERITY_MEDIUM = 'HARM_SEVERITY_MEDIUM',
+ /**
+ * High level of harm severity.
+ */
+ HARM_SEVERITY_HIGH = 'HARM_SEVERITY_HIGH',
+ /**
+ * Harm severity is not supported.
+ *
+ * @remarks
+ * The GoogleAI backend does not support `HarmSeverity`, so this value is used as a fallback.
+ */
+ HARM_SEVERITY_UNSUPPORTED = 'HARM_SEVERITY_UNSUPPORTED',
+}
+
+/**
+ * Reason that a prompt was blocked.
+ * @public
+ */
+export enum BlockReason {
+ /**
+ * Content was blocked by safety settings.
+ */
+ SAFETY = 'SAFETY',
+ /**
+ * Content was blocked, but the reason is uncategorized.
+ */
+ OTHER = 'OTHER',
+ /**
+ * Content was blocked because it contained terms from the terminology blocklist.
+ */
+ BLOCKLIST = 'BLOCKLIST',
+ /**
+ * Content was blocked due to prohibited content.
+ */
+ PROHIBITED_CONTENT = 'PROHIBITED_CONTENT',
+}
+
+/**
+ * Reason that a candidate finished.
+ * @public
+ */
+export enum FinishReason {
+ /**
+ * Natural stop point of the model or provided stop sequence.
+ */
+ STOP = 'STOP',
+ /**
+ * The maximum number of tokens as specified in the request was reached.
+ */
+ MAX_TOKENS = 'MAX_TOKENS',
+ /**
+ * The candidate content was flagged for safety reasons.
+ */
+ SAFETY = 'SAFETY',
+ /**
+ * The candidate content was flagged for recitation reasons.
+ */
+ RECITATION = 'RECITATION',
+ /**
+ * Unknown reason.
+ */
+ OTHER = 'OTHER',
+ /**
+ * The candidate content contained forbidden terms.
+ */
+ BLOCKLIST = 'BLOCKLIST',
+ /**
+ * The candidate content potentially contained prohibited content.
+ */
+ PROHIBITED_CONTENT = 'PROHIBITED_CONTENT',
+ /**
+ * The candidate content potentially contained Sensitive Personally Identifiable Information (SPII).
+ */
+ SPII = 'SPII',
+ /**
+ * The function call generated by the model was invalid.
+ */
+ MALFORMED_FUNCTION_CALL = 'MALFORMED_FUNCTION_CALL',
+}
+
+/**
+ * @public
+ */
+export enum FunctionCallingMode {
+ /**
+ * Default model behavior; model decides to predict either a function call
+ * or a natural language response.
+ */
+ AUTO = 'AUTO',
+ /**
+ * Model is constrained to always predicting a function call only.
+ * If `allowed_function_names` is set, the predicted function call will be
+ * limited to any one of `allowed_function_names`, else the predicted
+ * function call will be any one of the provided `function_declarations`.
+ */
+ ANY = 'ANY',
+ /**
+ * Model will not predict any function call. Model behavior is the same as when
+ * not passing any function declarations.
+ */
+ NONE = 'NONE',
+}
+
+/**
+ * Content part modality.
+ * @public
+ */
+export enum Modality {
+ /**
+ * Unspecified modality.
+ */
+ MODALITY_UNSPECIFIED = 'MODALITY_UNSPECIFIED',
+ /**
+ * Plain text.
+ */
+ TEXT = 'TEXT',
+ /**
+ * Image.
+ */
+ IMAGE = 'IMAGE',
+ /**
+ * Video.
+ */
+ VIDEO = 'VIDEO',
+ /**
+ * Audio.
+ */
+ AUDIO = 'AUDIO',
+ /**
+ * Document (for example, PDF).
+ */
+ DOCUMENT = 'DOCUMENT',
+}
+
+/**
+ * Generation modalities to be returned in generation responses.
+ *
+ * @beta
+ */
+export const ResponseModality = {
+ /**
+ * Text.
+ * @beta
+ */
+ TEXT: 'TEXT',
+ /**
+ * Image.
+ * @beta
+ */
+ IMAGE: 'IMAGE',
+} as const;
+
+/**
+ * Generation modalities to be returned in generation responses.
+ *
+ * @beta
+ */
+export type ResponseModality = (typeof ResponseModality)[keyof typeof ResponseModality];
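A sketch wiring `ResponseModality` into a request via the `responseModalities` field added to `GenerationConfig` later in this diff (model name illustrative; only some Gemini models support image output, and `ai` comes from `getAI`):

```ts
import { getGenerativeModel, ResponseModality } from '@react-native-firebase/ai';

const model = getGenerativeModel(ai, {
  model: 'gemini-2.0-flash-exp', // illustrative model name
  generationConfig: {
    responseModalities: [ResponseModality.TEXT, ResponseModality.IMAGE],
  },
});
```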
diff --git a/packages/vertexai/lib/types/error.ts b/packages/ai/lib/types/error.ts
similarity index 87%
rename from packages/vertexai/lib/types/error.ts
rename to packages/ai/lib/types/error.ts
index c65e09c55f..4fcc1ac483 100644
--- a/packages/vertexai/lib/types/error.ts
+++ b/packages/ai/lib/types/error.ts
@@ -50,7 +50,7 @@ export interface CustomErrorData {
/** HTTP status text of the error response. */
statusText?: string;
- /** Response from a {@link GenerateContentRequest}
*/
+ /** Response from a {@link GenerateContentRequest} */
response?: GenerateContentResponse;
/** Optional additional details about the error. */
@@ -58,11 +58,11 @@ export interface CustomErrorData {
}
/**
- * Standardized error codes that {@link VertexAIError}
can have.
+ * Standardized error codes that {@link AIError} can have.
*
* @public
*/
-export const enum VertexAIErrorCode {
+export const enum AIErrorCode {
/** A generic error occurred. */
ERROR = 'error',
@@ -87,6 +87,9 @@ export const enum VertexAIErrorCode {
/** An error occurred due to a missing Firebase API key. */
NO_API_KEY = 'no-api-key',
+ /** An error occurred due to a missing Firebase app ID. */
+ NO_APP_ID = 'no-app-id',
+
/** An error occurred due to a model name not being specified during initialization. */
NO_MODEL = 'no-model',
@@ -95,4 +98,7 @@ export const enum VertexAIErrorCode {
/** An error occurred while parsing. */
PARSE_FAILED = 'parse-failed',
+
+ /** An error occurred due to an attempt to use an unsupported feature. */
+ UNSUPPORTED = 'unsupported',
}
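Handling the renamed error type in app code (a sketch; `AIError` is the class defined in `errors.ts`, and its `code` property is compared against `AIErrorCode` values exactly as `makeRequest` does above):

```ts
import { AIError, AIErrorCode } from '@react-native-firebase/ai';

try {
  await model.generateContent('Hello');
} catch (e) {
  if (e instanceof AIError && e.code === AIErrorCode.FETCH_ERROR) {
    // Network-level failure; a retry may succeed.
  } else {
    throw e;
  }
}
```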
diff --git a/packages/ai/lib/types/googleai.ts b/packages/ai/lib/types/googleai.ts
new file mode 100644
index 0000000000..4c7dfe30bb
--- /dev/null
+++ b/packages/ai/lib/types/googleai.ts
@@ -0,0 +1,70 @@
+/**
+ * @license
+ * Copyright 2025 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import {
+ Tool,
+ GenerationConfig,
+ Citation,
+ FinishReason,
+ GroundingMetadata,
+ PromptFeedback,
+ SafetyRating,
+ UsageMetadata,
+} from '../public-types';
+import { Content, Part } from './content';
+
+/**
+ * @internal
+ */
+export interface GoogleAICountTokensRequest {
+ generateContentRequest: {
+ model: string; // 'models/model-name'
+ contents: Content[];
+ systemInstruction?: string | Part | Content;
+ tools?: Tool[];
+ generationConfig?: GenerationConfig;
+ };
+}
+
+/**
+ * @internal
+ */
+export interface GoogleAIGenerateContentResponse {
+ candidates?: GoogleAIGenerateContentCandidate[];
+ promptFeedback?: PromptFeedback;
+ usageMetadata?: UsageMetadata;
+}
+
+/**
+ * @internal
+ */
+export interface GoogleAIGenerateContentCandidate {
+ index: number;
+ content: Content;
+ finishReason?: FinishReason;
+ finishMessage?: string;
+ safetyRatings?: SafetyRating[];
+ citationMetadata?: GoogleAICitationMetadata;
+ groundingMetadata?: GroundingMetadata;
+}
+
+/**
+ * @internal
+ */
+export interface GoogleAICitationMetadata {
+ citationSources: Citation[]; // Maps to `citations`
+}
diff --git a/packages/vertexai/lib/types/index.ts b/packages/ai/lib/types/index.ts
similarity index 96%
rename from packages/vertexai/lib/types/index.ts
rename to packages/ai/lib/types/index.ts
index 85133aa07c..6d77a4a935 100644
--- a/packages/vertexai/lib/types/index.ts
+++ b/packages/ai/lib/types/index.ts
@@ -21,3 +21,4 @@ export * from './requests';
export * from './responses';
export * from './error';
export * from './schema';
+export * from './googleai';
diff --git a/packages/vertexai/lib/types/internal.ts b/packages/ai/lib/types/internal.ts
similarity index 82%
rename from packages/vertexai/lib/types/internal.ts
rename to packages/ai/lib/types/internal.ts
index ee60d476c9..8b51e8c846 100644
--- a/packages/vertexai/lib/types/internal.ts
+++ b/packages/ai/lib/types/internal.ts
@@ -15,11 +15,18 @@
* limitations under the License.
*/
import { FirebaseAppCheckTypes } from '@react-native-firebase/app-check';
+import { Backend } from '../public-types';
export interface ApiSettings {
apiKey: string;
+ appId: string;
project: string;
+ /**
+ * @deprecated Use `backend.location` instead.
+ */
location: string;
+ automaticDataCollectionEnabled?: boolean;
+ backend: Backend;
getAuthToken?: () => Promise<string>;
getAppCheckToken?: () => Promise<FirebaseAppCheckTypes.AppCheckTokenResult>;
}
diff --git a/packages/vertexai/lib/types/polyfills.d.ts b/packages/ai/lib/types/polyfills.d.ts
similarity index 100%
rename from packages/vertexai/lib/types/polyfills.d.ts
rename to packages/ai/lib/types/polyfills.d.ts
diff --git a/packages/vertexai/lib/types/requests.ts b/packages/ai/lib/types/requests.ts
similarity index 78%
rename from packages/vertexai/lib/types/requests.ts
rename to packages/ai/lib/types/requests.ts
index 708a55a11c..53b35b1196 100644
--- a/packages/vertexai/lib/types/requests.ts
+++ b/packages/ai/lib/types/requests.ts
@@ -17,7 +17,13 @@
import { TypedSchema } from '../requests/schema-builder';
import { Content, Part } from './content';
-import { FunctionCallingMode, HarmBlockMethod, HarmBlockThreshold, HarmCategory } from './enums';
+import {
+ FunctionCallingMode,
+ HarmBlockMethod,
+ HarmBlockThreshold,
+ HarmCategory,
+ ResponseModality,
+} from './enums';
import { ObjectSchemaInterface, SchemaRequest } from './schema';
/**
@@ -30,7 +36,7 @@ export interface BaseParams {
}
/**
- * Params passed to {@link getGenerativeModel}
.
+ * Params passed to {@link getGenerativeModel}.
* @public
*/
export interface ModelParams extends BaseParams {
@@ -58,6 +64,13 @@ export interface GenerateContentRequest extends BaseParams {
export interface SafetySetting {
category: HarmCategory;
threshold: HarmBlockThreshold;
+ /**
+ * The harm block method.
+ *
+ * This property is only supported in the Vertex AI Gemini API ({@link VertexAIBackend}).
+ * When using the Gemini Developer API ({@link GoogleAIBackend}), an {@link AIError} will be
+ * thrown if this property is defined.
+ */
method?: HarmBlockMethod;
}
@@ -83,13 +96,23 @@ export interface GenerationConfig {
responseMimeType?: string;
/**
* Output response schema of the generated candidate text. This
- * value can be a class generated with a {@link Schema}
static method
+ * value can be a class generated with a {@link Schema} static method
* like `Schema.string()` or `Schema.object()` or it can be a plain
- * JS object matching the {@link SchemaRequest}
interface.
+ * JS object matching the {@link SchemaRequest} interface.
* Note: This only applies when the specified `responseMIMEType` supports a schema; currently
* this is limited to `application/json` and `text/x.enum`.
*/
responseSchema?: TypedSchema | SchemaRequest;
+ /**
+ * Generation modalities to be returned in generation responses.
+ *
+ * @remarks
+ * - Multimodal response generation is only supported by some Gemini models and versions; see {@link https://firebase.google.com/docs/vertex-ai/models | model versions}.
+ * - Only image generation (`ResponseModality.IMAGE`) is supported.
+ *
+ * @beta
+ */
+ responseModalities?: ResponseModality[];
}
/**
@@ -109,10 +132,22 @@ export interface StartChatParams extends BaseParams {
*/
export interface CountTokensRequest {
contents: Content[];
+ /**
+ * Instructions that direct the model to behave a certain way.
+ */
+ systemInstruction?: string | Part | Content;
+ /**
+ * {@link Tool} configuration.
+ */
+ tools?: Tool[];
+ /**
+ * Configuration options that control how the model generates a response.
+ */
+ generationConfig?: GenerationConfig;
}
/**
- * Params passed to {@link getGenerativeModel}
.
+ * Params passed to {@link getGenerativeModel}.
* @public
*/
export interface RequestOptions {
@@ -172,8 +207,8 @@ export declare interface FunctionDeclarationsTool {
* Optional. One or more function declarations
* to be passed to the model along with the current user query. Model may
* decide to call a subset of these functions by populating
- * {@link FunctionCall}
in the response. User should
- * provide a {@link FunctionResponse}
for each
+ * {@link FunctionCall} in the response. User should
+ * provide a {@link FunctionResponse} for each
* function call in the next turn. Based on the function responses, the model will
* generate the final response back to the user. Maximum 64 function
* declarations can be provided.
diff --git a/packages/vertexai/lib/types/responses.ts b/packages/ai/lib/types/responses.ts
similarity index 59%
rename from packages/vertexai/lib/types/responses.ts
rename to packages/ai/lib/types/responses.ts
index 013391e98b..450a388992 100644
--- a/packages/vertexai/lib/types/responses.ts
+++ b/packages/ai/lib/types/responses.ts
@@ -15,8 +15,15 @@
* limitations under the License.
*/
-import { Content, FunctionCall } from './content';
-import { BlockReason, FinishReason, HarmCategory, HarmProbability, HarmSeverity } from './enums';
+import { Content, FunctionCall, InlineDataPart } from './content';
+import {
+ BlockReason,
+ FinishReason,
+ HarmCategory,
+ HarmProbability,
+ HarmSeverity,
+ Modality,
+} from './enums';
/**
* Result object returned from {@link GenerativeModel.generateContent} call.
@@ -51,6 +58,15 @@ export interface EnhancedGenerateContentResponse extends GenerateContentResponse
* Throws if the prompt or candidate was blocked.
*/
text: () => string;
+ /**
+ * Aggregates and returns all {@link InlineDataPart}s from the {@link GenerateContentResponse}'s
+ * first candidate.
+ *
+ * @returns An array of {@link InlineDataPart}s containing data from the response, if available.
+ *
+ * @throws If the prompt or candidate was blocked.
+ */
+ inlineDataParts: () => InlineDataPart[] | undefined;
functionCalls: () => FunctionCall[] | undefined;
}
@@ -68,7 +84,7 @@ export interface GenerateContentResponse {
}
/**
- * Usage metadata about a {@link GenerateContentResponse}
.
+ * Usage metadata about a {@link GenerateContentResponse}.
*
* @public
*/
@@ -76,6 +92,20 @@ export interface UsageMetadata {
promptTokenCount: number;
candidatesTokenCount: number;
totalTokenCount: number;
+ promptTokensDetails?: ModalityTokenCount[];
+ candidatesTokensDetails?: ModalityTokenCount[];
+}
+
+/**
+ * Represents token counting info for a single modality.
+ *
+ * @public
+ */
+export interface ModalityTokenCount {
+ /** The modality associated with this token count. */
+ modality: Modality;
+ /** The number of tokens counted. */
+ tokenCount: number;
}
/**
@@ -86,11 +116,16 @@ export interface UsageMetadata {
export interface PromptFeedback {
blockReason?: BlockReason;
safetyRatings: SafetyRating[];
+ /**
+ * A human-readable description of the `blockReason`.
+ *
+ * This property is only supported in the Vertex AI Gemini API ({@link VertexAIBackend}).
+ */
blockReasonMessage?: string;
}
/**
- * A candidate returned as part of a {@link GenerateContentResponse}
.
+ * A candidate returned as part of a {@link GenerateContentResponse}.
* @public
*/
export interface GenerateContentCandidate {
@@ -104,7 +139,7 @@ export interface GenerateContentCandidate {
}
/**
- * Citation metadata that may be found on a {@link GenerateContentCandidate}
.
+ * Citation metadata that may be found on a {@link GenerateContentCandidate}.
* @public
*/
export interface CitationMetadata {
@@ -120,7 +155,17 @@ export interface Citation {
endIndex?: number;
uri?: string;
license?: string;
+ /**
+ * The title of the cited source, if available.
+ *
+ * This property is only supported in the Vertex AI Gemini API ({@link VertexAIBackend}).
+ */
title?: string;
+ /**
+ * The publication date of the cited source, if available.
+ *
+ * This property is only supported in the Vertex AI Gemini API ({@link VertexAIBackend}).
+ */
publicationDate?: Date;
}
@@ -131,10 +176,14 @@ export interface Citation {
export interface GroundingMetadata {
webSearchQueries?: string[];
retrievalQueries?: string[];
+ /**
+ * @deprecated
+ */
groundingAttributions: GroundingAttribution[];
}
/**
+ * @deprecated
* @public
*/
export interface GroundingAttribution {
@@ -180,14 +229,32 @@ export interface Date {
}
/**
- * A safety rating associated with a {@link GenerateContentCandidate}
+ * A safety rating associated with a {@link GenerateContentCandidate}
* @public
*/
export interface SafetyRating {
category: HarmCategory;
probability: HarmProbability;
+ /**
+ * The harm severity level.
+ *
+ * This property is only supported when using the Vertex AI Gemini API ({@link VertexAIBackend}).
+ * When using the Gemini Developer API ({@link GoogleAIBackend}), this property is not supported and will default to `HarmSeverity.UNSUPPORTED`.
+ */
severity: HarmSeverity;
+ /**
+ * The probability score of the harm category.
+ *
+ * This property is only supported when using the Vertex AI Gemini API ({@link VertexAIBackend}).
+ * When using the Gemini Developer API ({@link GoogleAIBackend}), this property is not supported and will default to 0.
+ */
probabilityScore: number;
+ /**
+ * The severity score of the harm category.
+ *
+ * This property is only supported when using the Vertex AI Gemini API ({@link VertexAIBackend}).
+ * When using the Gemini Developer API ({@link GoogleAIBackend}), this property is not supported and will default to 0.
+ */
severityScore: number;
blocked: boolean;
}
@@ -204,6 +271,13 @@ export interface CountTokensResponse {
/**
* The total number of billable characters counted across all instances
* from the request.
+ *
+ * This property is only supported when using the Vertex AI Gemini API ({@link VertexAIBackend}).
+ * When using the Gemini Developer API ({@link GoogleAIBackend}), this property is not supported and will default to 0.
*/
totalBillableCharacters?: number;
+ /**
+ * The breakdown, by modality, of how many tokens are consumed by the prompt.
+ */
+ promptTokensDetails?: ModalityTokenCount[];
}
diff --git a/packages/vertexai/lib/types/schema.ts b/packages/ai/lib/types/schema.ts
similarity index 65%
rename from packages/vertexai/lib/types/schema.ts
rename to packages/ai/lib/types/schema.ts
index c1376b9aa1..60a23a2d56 100644
--- a/packages/vertexai/lib/types/schema.ts
+++ b/packages/ai/lib/types/schema.ts
@@ -14,6 +14,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
+
/**
* Contains the list of OpenAPI data types
* as defined by the
@@ -36,40 +37,59 @@ export enum SchemaType {
}
/**
- * Basic {@link Schema}
-properties shared across several Schema-related
+ * Basic {@link Schema} properties shared across several Schema-related
* types.
* @public
*/
export interface SchemaShared<T> {
- /** Optional. The format of the property. */
+ /** Optional. The format of the property.
+ * When using the Gemini Developer API ({@link GoogleAIBackend}), this must be either `'enum'` or
+ * `'date-time'`, otherwise requests will fail.
+ */
format?: string;
/** Optional. The description of the property. */
description?: string;
+ /**
+ * The title of the property. This helps document the schema's purpose but does not typically
+ * constrain the generated value. It can subtly guide the model by clarifying the intent of a
+ * field.
+ */
+ title?: string;
/** Optional. The items of the property. */
items?: T;
+ /** The minimum number of items (elements) in a schema of type {@link SchemaType.ARRAY}. */
+ minItems?: number;
+ /** The maximum number of items (elements) in a schema of type {@link SchemaType.ARRAY}. */
+ maxItems?: number;
/** Optional. Map of `Schema` objects. */
properties?: {
[k: string]: T;
};
+ /** A hint suggesting the order in which the keys should appear in the generated JSON string. */
+ propertyOrdering?: string[];
/** Optional. The enum of the property. */
enum?: string[];
/** Optional. The example of the property. */
example?: unknown;
/** Optional. Whether the property is nullable. */
nullable?: boolean;
+ /** The minimum value of a numeric type. */
+ minimum?: number;
+ /** The maximum value of a numeric type. */
+ maximum?: number;
[key: string]: unknown;
}
/**
- * Params passed to {@link Schema}
-static methods to create specific
- * {@link Schema}
-classes.
+ * Params passed to {@link Schema} static methods to create specific
+ * {@link Schema} classes.
* @public
*/
// eslint-disable-next-line @typescript-eslint/no-empty-object-type
export interface SchemaParams extends SchemaShared<SchemaInterface> {}
/**
- * Final format for {@link Schema}
-params passed to backend requests.
+ * Final format for {@link Schema} params passed to backend requests.
* @public
*/
export interface SchemaRequest extends SchemaShared<SchemaRequest> {
@@ -83,7 +103,7 @@ export interface SchemaRequest extends SchemaShared {
}
/**
- * Interface for {@link Schema}
-class.
+ * Interface for {@link Schema} class.
* @public
*/
export interface SchemaInterface extends SchemaShared<SchemaInterface> {
@@ -95,7 +115,7 @@ export interface SchemaInterface extends SchemaShared {
}
/**
- * Interface for {@link ObjectSchema}
-class.
+ * Interface for {@link ObjectSchema} class.
* @public
*/
export interface ObjectSchemaInterface extends SchemaInterface {
diff --git a/packages/ai/package.json b/packages/ai/package.json
new file mode 100644
index 0000000000..dedb3e2d0e
--- /dev/null
+++ b/packages/ai/package.json
@@ -0,0 +1,89 @@
+{
+ "name": "@react-native-firebase/ai",
+ "version": "22.2.0",
+ "author": "Invertase (http://invertase.io)",
+ "description": "React Native Firebase - Firebase AI is a fully-managed, unified AI development platform for building and using generative AI",
+ "main": "./dist/commonjs/index.js",
+ "module": "./dist/module/index.js",
+ "types": "./dist/typescript/module/lib/index.d.ts",
+ "scripts": {
+ "build": "genversion --esm --semi lib/version.ts",
+ "build:clean": "rimraf dist",
+ "compile": "bob build",
+ "prepare": "yarn tests:ai:mocks && yarn run build && yarn compile"
+ },
+ "repository": {
+ "type": "git",
+ "url": "https://github.com/invertase/react-native-firebase/tree/main/packages/ai"
+ },
+ "license": "Apache-2.0",
+ "keywords": [
+ "react",
+ "react-native",
+ "firebase",
+ "firebase-ai",
+ "ai",
+ "gemini",
+ "generative-ai"
+ ],
+ "peerDependencies": {
+ "@react-native-firebase/app": "22.2.0"
+ },
+ "publishConfig": {
+ "access": "public",
+ "provenance": true
+ },
+ "devDependencies": {
+ "@types/text-encoding": "^0.0.40",
+ "react-native-builder-bob": "^0.40.6",
+ "typescript": "^5.8.3"
+ },
+ "source": "./lib/index.ts",
+ "exports": {
+ ".": {
+ "import": {
+ "types": "./dist/typescript/module/lib/index.d.ts",
+ "default": "./dist/module/index.js"
+ },
+ "require": {
+ "types": "./dist/typescript/commonjs/lib/index.d.ts",
+ "default": "./dist/commonjs/index.js"
+ }
+ }
+ },
+ "files": [
+ "lib",
+ "dist",
+ "!**/__tests__",
+ "!**/__fixtures__",
+ "!**/__mocks__"
+ ],
+ "react-native-builder-bob": {
+ "source": "lib",
+ "output": "dist",
+ "targets": [
+ [
+ "commonjs",
+ {
+ "esm": true
+ }
+ ],
+ [
+ "module",
+ {
+ "esm": true
+ }
+ ],
+ "typescript"
+ ]
+ },
+ "eslintIgnore": [
+ "node_modules/",
+ "dist/"
+ ],
+ "dependencies": {
+ "react-native-fetch-api": "^3.0.0",
+ "text-encoding": "^0.7.0",
+ "web-streams-polyfill": "^4.1.0"
+ }
+}
diff --git a/packages/ai/tsconfig.json b/packages/ai/tsconfig.json
new file mode 100644
index 0000000000..f1d9865812
--- /dev/null
+++ b/packages/ai/tsconfig.json
@@ -0,0 +1,32 @@
+{
+ "compilerOptions": {
+ "rootDir": ".",
+ "allowUnreachableCode": false,
+ "allowUnusedLabels": false,
+ "esModuleInterop": true,
+ "forceConsistentCasingInFileNames": true,
+ "jsx": "react-jsx",
+ "lib": [
+ "ESNext"
+ ],
+ "module": "ESNext",
+ "target": "ESNext",
+ "moduleResolution": "Bundler",
+ "noFallthroughCasesInSwitch": true,
+ "noImplicitReturns": true,
+ "noImplicitUseStrict": false,
+ "noStrictGenericChecks": false,
+ "noUncheckedIndexedAccess": true,
+ "noUnusedLocals": true,
+ "noUnusedParameters": true,
+ "resolveJsonModule": true,
+ "skipLibCheck": true,
+ "strict": true,
+ "baseUrl": ".",
+ "paths": {
+ "@react-native-firebase/app": ["../app/lib"],
+ "@react-native-firebase/auth": ["../auth/lib"],
+ "@react-native-firebase/app-check": ["../app-check/lib"],
+ }
+ }
+}
diff --git a/packages/app/lib/index.d.ts b/packages/app/lib/index.d.ts
index ec668a2a02..f5f90db6aa 100644
--- a/packages/app/lib/index.d.ts
+++ b/packages/app/lib/index.d.ts
@@ -150,6 +150,11 @@ export namespace ReactNativeFirebase {
*/
readonly options: FirebaseAppOptions;
+ /**
+ * The settable config flag for GDPR opt-in/opt-out
+ */
+ automaticDataCollectionEnabled: boolean;
+
/**
* Make this app unusable and free up resources.
*/
diff --git a/packages/app/lib/modular/index.d.ts b/packages/app/lib/modular/index.d.ts
index ec90f4f3d2..49a8337d82 100644
--- a/packages/app/lib/modular/index.d.ts
+++ b/packages/app/lib/modular/index.d.ts
@@ -3,6 +3,7 @@ import { ReactNativeFirebase } from '..';
import FirebaseApp = ReactNativeFirebase.FirebaseApp;
import FirebaseAppOptions = ReactNativeFirebase.FirebaseAppOptions;
import LogLevelString = ReactNativeFirebase.LogLevelString;
+import FirebaseAppConfig = ReactNativeFirebase.FirebaseAppConfig;
/**
* Renders this app unusable and frees the resources of all associated services.
@@ -57,6 +58,16 @@ export function getApps(): FirebaseApp[];
*/
export function initializeApp(options: FirebaseAppOptions, name?: string): Promise<FirebaseApp>;
+/**
+ * Initializes a Firebase app with the provided options and config.
+ * @param options - Options to configure the services used in the app.
+ * @param config - The optional config for your firebase app.
+ * @returns Promise - The initialized Firebase app.
+ */
+export function initializeApp(
+ options: FirebaseAppOptions,
+ config?: FirebaseAppConfig,
+): Promise<FirebaseApp>;
/**
* Retrieves an instance of a Firebase app.
* @param name - The optional name of the app to return ('[DEFAULT]' if omitted).
diff --git a/packages/app/lib/modular/index.js b/packages/app/lib/modular/index.js
index bc4b0b1951..2cb4a6c4ff 100644
--- a/packages/app/lib/modular/index.js
+++ b/packages/app/lib/modular/index.js
@@ -60,11 +60,11 @@ export function getApps() {
/**
* Initializes a Firebase app with the provided options and name.
* @param {FirebaseAppOptions} options - Options to configure the services used in the app.
- * @param {string} [name] - The optional name of the app to initialize ('[DEFAULT]' if omitted).
+ * @param {string | FirebaseAppConfig} [configOrName] - The optional name of the app to initialize ('[DEFAULT]' if omitted), or a FirebaseAppConfig object for the default app.
* @returns {FirebaseApp} - The initialized Firebase app.
*/
-export function initializeApp(options, name) {
- return initializeAppCompat.call(null, options, name, MODULAR_DEPRECATION_ARG);
+export function initializeApp(options, configOrName) {
+ return initializeAppCompat.call(null, options, configOrName, MODULAR_DEPRECATION_ARG);
}
/**
diff --git a/packages/storage/lib/web/RNFBStorageModule.js b/packages/storage/lib/web/RNFBStorageModule.js
index 033f9436a4..bcc575995f 100644
--- a/packages/storage/lib/web/RNFBStorageModule.js
+++ b/packages/storage/lib/web/RNFBStorageModule.js
@@ -353,9 +353,7 @@ export default {
break;
}
- const encoder = new TextEncoder();
-
- const arrayBuffer = encoder.encode(decodedString).buffer;
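+ // Map each UTF-16 code unit straight to a byte: decodedString holds
+ // base64-decoded binary data, which TextEncoder's UTF-8 encoding would
+ // corrupt for any byte above 0x7f.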
+ const arrayBuffer = new Uint8Array([...decodedString].map(c => c.charCodeAt(0)));
const task = uploadBytesResumable(ref, arrayBuffer, {
...makeSettableMetadata(metadata),
diff --git a/packages/vertexai/__tests__/api.test.ts b/packages/vertexai/__tests__/api.test.ts
deleted file mode 100644
index 3199157e76..0000000000
--- a/packages/vertexai/__tests__/api.test.ts
+++ /dev/null
@@ -1,103 +0,0 @@
-/**
- * @license
- * Copyright 2024 Google LLC
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-import { describe, expect, it } from '@jest/globals';
-import { getApp, type ReactNativeFirebase } from '../../app/lib';
-
-import { ModelParams, VertexAIErrorCode } from '../lib/types';
-import { VertexAIError } from '../lib/errors';
-import { getGenerativeModel, getVertexAI } from '../lib/index';
-
-import { VertexAI } from '../lib/public-types';
-import { GenerativeModel } from '../lib/models/generative-model';
-
-import '../../auth/lib';
-import '../../app-check/lib';
-import { getAuth } from '../../auth/lib';
-
-const fakeVertexAI: VertexAI = {
- app: {
- name: 'DEFAULT',
- options: {
- apiKey: 'key',
- appId: 'appId',
- projectId: 'my-project',
- },
- } as ReactNativeFirebase.FirebaseApp,
- location: 'us-central1',
-};
-
-describe('Top level API', () => {
- it('should allow auth and app check instances to be passed in', () => {
- const app = getApp();
- const auth = getAuth();
- const appCheck = app.appCheck();
-
- getVertexAI(app, { appCheck, auth });
- });
-
- it('getGenerativeModel throws if no model is provided', () => {
- try {
- getGenerativeModel(fakeVertexAI, {} as ModelParams);
- } catch (e) {
- expect((e as VertexAIError).code).toContain(VertexAIErrorCode.NO_MODEL);
- expect((e as VertexAIError).message).toContain(
- `VertexAI: Must provide a model name. Example: ` +
- `getGenerativeModel({ model: 'my-model-name' }) (vertexAI/${VertexAIErrorCode.NO_MODEL})`,
- );
- }
- });
-
- it('getGenerativeModel throws if no apiKey is provided', () => {
- const fakeVertexNoApiKey = {
- ...fakeVertexAI,
- app: { options: { projectId: 'my-project' } },
- } as VertexAI;
- try {
- getGenerativeModel(fakeVertexNoApiKey, { model: 'my-model' });
- } catch (e) {
- expect((e as VertexAIError).code).toContain(VertexAIErrorCode.NO_API_KEY);
- expect((e as VertexAIError).message).toBe(
- `VertexAI: The "apiKey" field is empty in the local ` +
- `Firebase config. Firebase VertexAI requires this field to` +
- ` contain a valid API key. (vertexAI/${VertexAIErrorCode.NO_API_KEY})`,
- );
- }
- });
-
- it('getGenerativeModel throws if no projectId is provided', () => {
- const fakeVertexNoProject = {
- ...fakeVertexAI,
- app: { options: { apiKey: 'my-key' } },
- } as VertexAI;
- try {
- getGenerativeModel(fakeVertexNoProject, { model: 'my-model' });
- } catch (e) {
- expect((e as VertexAIError).code).toContain(VertexAIErrorCode.NO_PROJECT_ID);
- expect((e as VertexAIError).message).toBe(
- `VertexAI: The "projectId" field is empty in the local` +
- ` Firebase config. Firebase VertexAI requires this field ` +
- `to contain a valid project ID. (vertexAI/${VertexAIErrorCode.NO_PROJECT_ID})`,
- );
- }
- });
-
- it('getGenerativeModel gets a GenerativeModel', () => {
- const genModel = getGenerativeModel(fakeVertexAI, { model: 'my-model' });
- expect(genModel).toBeInstanceOf(GenerativeModel);
- expect(genModel.model).toBe('publishers/google/models/my-model');
- });
-});
diff --git a/packages/vertexai/__tests__/backwards-compatbility.test.ts b/packages/vertexai/__tests__/backwards-compatbility.test.ts
new file mode 100644
index 0000000000..be4641f8c2
--- /dev/null
+++ b/packages/vertexai/__tests__/backwards-compatbility.test.ts
@@ -0,0 +1,83 @@
+/**
+ * @license
+ * Copyright 2025 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+import { describe, expect, it } from '@jest/globals';
+import {
+ AIError,
+ AIModel,
+ GenerativeModel,
+ VertexAIError,
+ VertexAIErrorCode,
+ VertexAIModel,
+ VertexAI,
+ getGenerativeModel,
+} from '../lib/index';
+import { AI, AIErrorCode } from '@react-native-firebase/ai';
+import { VertexAIBackend } from '@react-native-firebase/ai';
+import { ReactNativeFirebase } from '@react-native-firebase/app';
+
+function assertAssignable<T, _U extends T>(): void {}
+
+const fakeAI: AI = {
+ app: {
+ name: 'DEFAULT',
+ automaticDataCollectionEnabled: true,
+ options: {
+ apiKey: 'key',
+ projectId: 'my-project',
+ appId: 'app-id',
+ },
+ } as ReactNativeFirebase.FirebaseApp,
+ backend: new VertexAIBackend('us-central1'),
+ location: 'us-central1',
+};
+
+const fakeVertexAI: VertexAI = fakeAI;
+
+describe('VertexAI backward compatibility', function () {
+ it('should allow VertexAI to be assignable to AI', function () {
+ assertAssignable<AI, VertexAI>();
+ });
+
+ it('should allow VertexAIError to extend AIError', function () {
+ assertAssignable<typeof AIError, typeof VertexAIError>();
+ const err = new VertexAIError(VertexAIErrorCode.ERROR, '');
+ expect(err).toBeInstanceOf(AIError);
+ expect(err).toBeInstanceOf(VertexAIError);
+ });
+
+ it('should allow VertexAIErrorCode to be assignable to AIErrorCode', () => {
+ assertAssignable<AIErrorCode, VertexAIErrorCode>();
+ const errCode = AIErrorCode.ERROR;
+ expect(errCode).toBe(VertexAIErrorCode.ERROR);
+ });
+
+ it('should allow VertexAIModel to extend AIModel', () => {
+ assertAssignable<typeof AIModel, typeof VertexAIModel>();
+
+ const model = new GenerativeModel(fakeAI, { model: 'model-name' });
+ expect(model).toBeInstanceOf(AIModel);
+ expect(model).toBeInstanceOf(VertexAIModel);
+ });
+
+ describe('VertexAI functions', () => {
+ it('should return a VertexAIModel assignable to AIModel from getGenerativeModel()', () => {
+ const model = getGenerativeModel(fakeVertexAI, { model: 'model-name' });
+ expect(model).toBeInstanceOf(AIModel);
+ expect(model).toBeInstanceOf(VertexAIModel);
+ });
+ });
+});
diff --git a/packages/vertexai/__tests__/count-tokens.test.ts b/packages/vertexai/__tests__/count-tokens.test.ts
deleted file mode 100644
index 3cd7b78970..0000000000
--- a/packages/vertexai/__tests__/count-tokens.test.ts
+++ /dev/null
@@ -1,88 +0,0 @@
-/**
- * @license
- * Copyright 2024 Google LLC
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-import { describe, expect, it, afterEach, jest } from '@jest/globals';
-import { getMockResponse } from './test-utils/mock-response';
-import * as request from '../lib/requests/request';
-import { countTokens } from '../lib/methods/count-tokens';
-import { CountTokensRequest } from '../lib/types';
-import { ApiSettings } from '../lib/types/internal';
-import { Task } from '../lib/requests/request';
-
-const fakeApiSettings: ApiSettings = {
- apiKey: 'key',
- project: 'my-project',
- location: 'us-central1',
-};
-
-const fakeRequestParams: CountTokensRequest = {
- contents: [{ parts: [{ text: 'hello' }], role: 'user' }],
-};
-
-describe('countTokens()', () => {
- afterEach(() => {
- jest.restoreAllMocks();
- });
-
- it('total tokens', async () => {
- const mockResponse = getMockResponse('unary-success-total-tokens.json');
- const makeRequestStub = jest
- .spyOn(request, 'makeRequest')
- .mockResolvedValue(mockResponse as Response);
- const result = await countTokens(fakeApiSettings, 'model', fakeRequestParams);
- expect(result.totalTokens).toBe(6);
- expect(result.totalBillableCharacters).toBe(16);
- expect(makeRequestStub).toHaveBeenCalledWith(
- 'model',
- Task.COUNT_TOKENS,
- fakeApiSettings,
- false,
- expect.stringContaining('contents'),
- undefined,
- );
- });
-
- it('total tokens no billable characters', async () => {
- const mockResponse = getMockResponse('unary-success-no-billable-characters.json');
- const makeRequestStub = jest
- .spyOn(request, 'makeRequest')
- .mockResolvedValue(mockResponse as Response);
- const result = await countTokens(fakeApiSettings, 'model', fakeRequestParams);
- expect(result.totalTokens).toBe(258);
- expect(result).not.toHaveProperty('totalBillableCharacters');
- expect(makeRequestStub).toHaveBeenCalledWith(
- 'model',
- Task.COUNT_TOKENS,
- fakeApiSettings,
- false,
- expect.stringContaining('contents'),
- undefined,
- );
- });
-
- it('model not found', async () => {
- const mockResponse = getMockResponse('unary-failure-model-not-found.json');
- const mockFetch = jest.spyOn(globalThis, 'fetch').mockResolvedValue({
- ok: false,
- status: 404,
- json: mockResponse.json,
- } as Response);
- await expect(countTokens(fakeApiSettings, 'model', fakeRequestParams)).rejects.toThrow(
- /404.*not found/,
- );
- expect(mockFetch).toHaveBeenCalled();
- });
-});
diff --git a/packages/vertexai/lib/index.ts b/packages/vertexai/lib/index.ts
index 580f1bc86b..f5de5f0836 100644
--- a/packages/vertexai/lib/index.ts
+++ b/packages/vertexai/lib/index.ts
@@ -15,22 +15,21 @@
* limitations under the License.
*/
-import './polyfills';
import { getApp, ReactNativeFirebase } from '@react-native-firebase/app';
-import { ModelParams, RequestOptions, VertexAIErrorCode } from './types';
-import { DEFAULT_LOCATION } from './constants';
-import { VertexAI, VertexAIOptions } from './public-types';
-import { VertexAIError } from './errors';
-import { GenerativeModel } from './models/generative-model';
-import { VertexAIService } from './service';
-export { ChatSession } from './methods/chat-session';
-export * from './requests/schema-builder';
+import { VertexAIBackend, AIModel, AIError, AIErrorCode } from '@react-native-firebase/ai';
+import { VertexAIOptions, VertexAI } from './public-types';
-export { GenerativeModel };
-
-export { VertexAIError };
+const DEFAULT_LOCATION = 'us-central1';
/**
+ * @deprecated Use the new {@link getAI | getAI()} instead. The Vertex AI in Firebase SDK has been
+ * replaced with the Firebase AI SDK to accommodate the evolving set of supported features and
+ * services. For migration details, see the {@link https://firebase.google.com/docs/vertex-ai/migrate-to-latest-sdk | migration guide}.
+ *
+ * Returns a {@link VertexAI} instance for the given app. This instance will be
+ * configured to use the Vertex AI Gemini API.
+ *
- * Returns a {@link VertexAI}
-instance for the given app.
*
* @public
@@ -49,25 +48,34 @@ export function getVertexAI(
location: options?.location || DEFAULT_LOCATION,
appCheck: options?.appCheck || null,
auth: options?.auth || null,
- } as VertexAIService;
+ backend: new VertexAIBackend(options?.location || DEFAULT_LOCATION),
+ };
}
/**
- * Returns a {@link GenerativeModel}
-class with methods for inference
- * and other functionality.
+ * @deprecated Use the new {@link AIModel} instead. The Vertex AI in Firebase SDK has been
+ * replaced with the Firebase AI SDK to accommodate the evolving set of supported features and
+ * services. For migration details, see the {@link https://firebase.google.com/docs/vertex-ai/migrate-to-latest-sdk | migration guide}.
+ *
+ * Base class for Firebase AI model APIs.
*
* @public
*/
-export function getGenerativeModel(
- vertexAI: VertexAI,
- modelParams: ModelParams,
- requestOptions?: RequestOptions,
-): GenerativeModel {
- if (!modelParams.model) {
- throw new VertexAIError(
- VertexAIErrorCode.NO_MODEL,
- `Must provide a model name. Example: getGenerativeModel({ model: 'my-model-name' })`,
- );
- }
- return new GenerativeModel(vertexAI, modelParams, requestOptions);
-}
+export const VertexAIModel = AIModel;
+
+/**
+ * @deprecated Use the new {@link AIError} instead. The Vertex AI in Firebase SDK has been
+ * replaced with the Firebase AI SDK to accommodate the evolving set of supported features and
+ * services. For migration details, see the {@link https://firebase.google.com/docs/vertex-ai/migrate-to-latest-sdk | migration guide}.
+ *
+ * Error class for the Firebase AI SDK.
+ *
+ * @public
+ */
+export const VertexAIError = AIError;
+
+export { AIErrorCode as VertexAIErrorCode };
+export { VertexAIBackend, AIModel, AIError };
+
+export * from './public-types';
+export * from '@react-native-firebase/ai';
diff --git a/packages/vertexai/lib/public-types.ts b/packages/vertexai/lib/public-types.ts
index 24c6be6efa..1ce05549ae 100644
--- a/packages/vertexai/lib/public-types.ts
+++ b/packages/vertexai/lib/public-types.ts
@@ -14,26 +14,18 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-
-import { ReactNativeFirebase } from '@react-native-firebase/app';
+import { AI, AIErrorCode } from '@react-native-firebase/ai';
import { FirebaseAuthTypes } from '@react-native-firebase/auth';
import { FirebaseAppCheckTypes } from '@react-native-firebase/app-check';
-export * from './types';
-
/**
+ * @deprecated Use the new {@link AI | AI} instead. The Vertex AI in Firebase SDK has been
+ * replaced with the Firebase AI SDK to accommodate the evolving set of supported features and
+ * services. For migration details, see the {@link https://firebase.google.com/docs/vertex-ai/migrate-to-latest-sdk | migration guide}.
* An instance of the Vertex AI in Firebase SDK.
* @public
*/
-export interface VertexAI {
- /**
- * The {@link @firebase/app#FirebaseApp} this {@link VertexAI}
-instance is associated with.
- */
- app: ReactNativeFirebase.FirebaseApp;
- location: string;
- appCheck?: FirebaseAppCheckTypes.Module | null;
- auth?: FirebaseAuthTypes.Module | null;
-}
+export type VertexAI = AI;
/**
* Options when initializing the Vertex AI in Firebase SDK.
@@ -44,3 +36,7 @@ export interface VertexAIOptions {
appCheck?: FirebaseAppCheckTypes.Module | null;
auth?: FirebaseAuthTypes.Module | null;
}
+
+export type VertexAIErrorCode = AIErrorCode;
+
+export const VERTEX_TYPE = 'vertexAI';
diff --git a/packages/vertexai/lib/types/enums.ts b/packages/vertexai/lib/types/enums.ts
deleted file mode 100644
index 010aff903a..0000000000
--- a/packages/vertexai/lib/types/enums.ts
+++ /dev/null
@@ -1,149 +0,0 @@
-/**
- * @license
- * Copyright 2024 Google LLC
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-/**
- * Role is the producer of the content.
- * @public
- */
-export type Role = (typeof POSSIBLE_ROLES)[number];
-
-/**
- * Possible roles.
- * @public
- */
-export const POSSIBLE_ROLES = ['user', 'model', 'function', 'system'] as const;
-
-/**
- * Harm categories that would cause prompts or candidates to be blocked.
- * @public
- */
-export enum HarmCategory {
- HARM_CATEGORY_HATE_SPEECH = 'HARM_CATEGORY_HATE_SPEECH',
- HARM_CATEGORY_SEXUALLY_EXPLICIT = 'HARM_CATEGORY_SEXUALLY_EXPLICIT',
- HARM_CATEGORY_HARASSMENT = 'HARM_CATEGORY_HARASSMENT',
- HARM_CATEGORY_DANGEROUS_CONTENT = 'HARM_CATEGORY_DANGEROUS_CONTENT',
-}
-
-/**
- * Threshold above which a prompt or candidate will be blocked.
- * @public
- */
-export enum HarmBlockThreshold {
- // Content with NEGLIGIBLE will be allowed.
- BLOCK_LOW_AND_ABOVE = 'BLOCK_LOW_AND_ABOVE',
- // Content with NEGLIGIBLE and LOW will be allowed.
- BLOCK_MEDIUM_AND_ABOVE = 'BLOCK_MEDIUM_AND_ABOVE',
- // Content with NEGLIGIBLE, LOW, and MEDIUM will be allowed.
- BLOCK_ONLY_HIGH = 'BLOCK_ONLY_HIGH',
- // All content will be allowed.
- BLOCK_NONE = 'BLOCK_NONE',
-}
-
-/**
- * @public
- */
-export enum HarmBlockMethod {
- // The harm block method uses both probability and severity scores.
- SEVERITY = 'SEVERITY',
- // The harm block method uses the probability score.
- PROBABILITY = 'PROBABILITY',
-}
-
-/**
- * Probability that a prompt or candidate matches a harm category.
- * @public
- */
-export enum HarmProbability {
- // Content has a negligible chance of being unsafe.
- NEGLIGIBLE = 'NEGLIGIBLE',
- // Content has a low chance of being unsafe.
- LOW = 'LOW',
- // Content has a medium chance of being unsafe.
- MEDIUM = 'MEDIUM',
- // Content has a high chance of being unsafe.
- HIGH = 'HIGH',
-}
-
-/**
- * Harm severity levels.
- * @public
- */
-export enum HarmSeverity {
- // Negligible level of harm severity.
- HARM_SEVERITY_NEGLIGIBLE = 'HARM_SEVERITY_NEGLIGIBLE',
- // Low level of harm severity.
- HARM_SEVERITY_LOW = 'HARM_SEVERITY_LOW',
- // Medium level of harm severity.
- HARM_SEVERITY_MEDIUM = 'HARM_SEVERITY_MEDIUM',
- // High level of harm severity.
- HARM_SEVERITY_HIGH = 'HARM_SEVERITY_HIGH',
-}
-
-/**
- * Reason that a prompt was blocked.
- * @public
- */
-export enum BlockReason {
- // The prompt was blocked because it contained terms from the terminology blocklist.
- BLOCKLIST = 'BLOCKLIST',
- // The prompt was blocked due to prohibited content.
- PROHIBITED_CONTENT = 'PROHIBITED_CONTENT',
- // Content was blocked by safety settings.
- SAFETY = 'SAFETY',
- // Content was blocked, but the reason is uncategorized.
- OTHER = 'OTHER',
-}
-
-/**
- * Reason that a candidate finished.
- * @public
- */
-export enum FinishReason {
- // Token generation was stopped because the response contained forbidden terms.
- BLOCKLIST = 'BLOCKLIST',
- // Token generation was stopped because the response contained potentially prohibited content.
- PROHIBITED_CONTENT = 'PROHIBITED_CONTENT',
- // Token generation was stopped because of Sensitive Personally Identifiable Information (SPII).
- SPII = 'SPII',
- // Natural stop point of the model or provided stop sequence.
- STOP = 'STOP',
- // The maximum number of tokens as specified in the request was reached.
- MAX_TOKENS = 'MAX_TOKENS',
- // The candidate content was flagged for safety reasons.
- SAFETY = 'SAFETY',
- // The candidate content was flagged for recitation reasons.
- RECITATION = 'RECITATION',
- // Unknown reason.
- OTHER = 'OTHER',
-}
-
-/**
- * @public
- */
-export enum FunctionCallingMode {
- // Default model behavior, model decides to predict either a function call
- // or a natural language response.
- AUTO = 'AUTO',
- // Model is constrained to always predicting a function call only.
- // If "allowed_function_names" is set, the predicted function call will be
- // limited to any one of "allowed_function_names", else the predicted
- // function call will be any one of the provided "function_declarations".
- ANY = 'ANY',
- // Model will not predict any function call. Model behavior is same as when
- // not passing any function declarations.
- NONE = 'NONE',
-}
diff --git a/packages/vertexai/package.json b/packages/vertexai/package.json
index 9b53b05cad..64fa61d307 100644
--- a/packages/vertexai/package.json
+++ b/packages/vertexai/package.json
@@ -10,7 +10,7 @@
"build": "genversion --esm --semi lib/version.ts",
"build:clean": "rimraf dist",
"compile": "bob build",
- "prepare": "yarn tests:vertex:mocks && yarn run build && yarn compile"
+ "prepare": "yarn run build && yarn compile"
},
"repository": {
"type": "git",
@@ -25,6 +25,12 @@
"gemini",
"generative-ai"
],
+ "dependencies": {
+ "@react-native-firebase/ai": "22.2.0",
+ "react-native-fetch-api": "^3.0.0",
+ "text-encoding": "^0.7.0",
+ "web-streams-polyfill": "^4.1.0"
+ },
"peerDependencies": {
"@react-native-firebase/app": "22.4.0"
},
@@ -79,10 +85,5 @@
"eslintIgnore": [
"node_modules/",
"dist/"
- ],
- "dependencies": {
- "react-native-fetch-api": "^3.0.0",
- "text-encoding": "^0.7.0",
- "web-streams-polyfill": "^4.1.0"
- }
+ ]
}
diff --git a/packages/vertexai/tsconfig.json b/packages/vertexai/tsconfig.json
index f1d9865812..f371c6edc9 100644
--- a/packages/vertexai/tsconfig.json
+++ b/packages/vertexai/tsconfig.json
@@ -26,7 +26,7 @@
"paths": {
"@react-native-firebase/app": ["../app/lib"],
"@react-native-firebase/auth": ["../auth/lib"],
- "@react-native-firebase/app-check": ["../app-check/lib"],
+ "@react-native-firebase/app-check": ["../app-check/lib"]
}
}
}
diff --git a/scripts/vertex_mock_responses.sh b/scripts/ai_mock_responses.sh
similarity index 95%
rename from scripts/vertex_mock_responses.sh
rename to scripts/ai_mock_responses.sh
index 98930546d0..c32b324af6 100755
--- a/scripts/vertex_mock_responses.sh
+++ b/scripts/ai_mock_responses.sh
@@ -29,7 +29,7 @@ tail -n1)
# Define the directory name using REPO_NAME and LATEST_TAG.
CLONE_DIR="${REPO_NAME}_${LATEST_TAG//./_}"
-cd "$(dirname "$0")/../packages/vertexai/__tests__/test-utils" || exit
+cd "$(dirname "$0")/../packages/ai/__tests__/test-utils" || exit
# Remove any directories that start with REPO_NAME but are not CLONE_DIR
for dir in "${REPO_NAME}"*; do
diff --git a/tests/test-app/examples/ai/ai.js b/tests/test-app/examples/ai/ai.js
new file mode 100644
index 0000000000..6a9d9a428e
--- /dev/null
+++ b/tests/test-app/examples/ai/ai.js
@@ -0,0 +1,328 @@
+import React, { useState } from 'react';
+import { AppRegistry, Button, View, Text, Pressable } from 'react-native';
+
+import { getApp } from '@react-native-firebase/app';
+import { getAI, getGenerativeModel, Schema } from '@react-native-firebase/ai';
+import {
+ PDF_BASE_64,
+ POEM_BASE_64,
+ VIDEO_BASE_64,
+ IMAGE_BASE_64,
+ EMOJI_BASE_64,
+} from '../vertexai/base-64-media';
+
+// eslint-disable-next-line react/prop-types
+function OptionSelector({ selectedOption, setSelectedOption }) {
+ const options = ['image', 'pdf', 'video', 'audio', 'emoji'];
+
+ return (
+ <View style={{ flexDirection: 'row', flexWrap: 'wrap' }}>
+ {options.map(option => {
+ const isSelected = selectedOption === option;
+ return (
+ <Pressable
+ key={option}
+ onPress={() => setSelectedOption(option)}
+ style={{
+ paddingVertical: 10,
+ paddingHorizontal: 15,
+ margin: 5,
+ borderRadius: 8,
+ borderWidth: 1,
+ borderColor: isSelected ? '#007bff' : '#ccc',
+ backgroundColor: isSelected ? '#007bff' : '#fff',
+ }}
+ >
+ <Text>
+ {option.toUpperCase()}
+ </Text>
+ </Pressable>
+ );
+ })}
+ </View>
+ );
+}
+
+function App() {
+ const [selectedOption, setSelectedOption] = useState('image');
+ const getMediaDetails = option => {
+ switch (option) {
+ case 'image':
+ return { data: IMAGE_BASE_64.trim(), mimeType: 'image/jpeg', prompt: 'What can you see?' };
+ case 'pdf':
+ return {
+ data: PDF_BASE_64.trim(),
+ mimeType: 'application/pdf',
+ prompt: 'What can you see?',
+ };
+ case 'video':
+ return { data: VIDEO_BASE_64.trim(), mimeType: 'video/mp4', prompt: 'What can you see?' };
+ case 'audio':
+ return { data: POEM_BASE_64.trim(), mimeType: 'audio/mp3', prompt: 'What can you hear?' };
+ case 'emoji':
+ return { data: EMOJI_BASE_64.trim(), mimeType: 'image/png', prompt: 'What can you see?' };
+ default:
+ console.error('Invalid option selected');
+ return null;
+ }
+ };
+
+ return (
+ <View>
+ <Button
+ title="Generate text"
+ onPress={async () => {
+ try {
+ const app = getApp();
+ const ai = getAI(app);
+ const model = getGenerativeModel(ai, { model: 'gemini-1.5-flash' });
+
+ const result = await model.generateContent('What is 2 + 2?');
+
+ console.log('result', result.response.text());
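+ // usageMetadata reports token counts for the call (see UsageMetadata in
+ // the response types); it can be absent, so read it defensively.
+ console.log('usageMetadata', result.response.usageMetadata);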
+ } catch (e) {
+ console.error(e);
+ }
+ }}
+ />
+ <Button
+ title="Generate text (streaming)"
+ onPress={async () => {
+ try {
+ const app = getApp();
+ const ai = getAI(app);
+ const model = getGenerativeModel(ai, { model: 'gemini-1.5-flash' });
+
+ const result = await model.generateContentStream('Write me a short, funny rap');
+
+ let text = '';
+ for await (const chunk of result.stream) {
+ const chunkText = chunk.text();
+ console.log(chunkText);
+
+ text += chunkText;
+ }
+
+ console.log('result', text);
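+ // A sketch of reading the aggregated result: the stream object also
+ // exposes a response promise that resolves once streaming finishes.
+ const aggregated = await result.response;
+ console.log('aggregated', aggregated.text());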
+ } catch (e) {
+ console.error(e);
+ }
+ }}
+ />
+ <Text>Select a File Type for multi-modal input:</Text>
+ <OptionSelector selectedOption={selectedOption} setSelectedOption={setSelectedOption} />
+ <Button
+ title="Multi-modal input"
+ onPress={async () => {
+ try {
+ const app = getApp();
+ const ai = getAI(app);
+ const model = getGenerativeModel(ai, { model: 'gemini-1.5-flash' });
+ const mediaDetails = getMediaDetails(selectedOption);
+ if (!mediaDetails) return;
+
+ const { data, mimeType, prompt } = mediaDetails;
+
+ // Call generateContentStream with the text and images
+ const response = await model.generateContentStream([
+ prompt,
+ { inlineData: { mimeType, data } },
+ ]);
+
+ let text = '';
+ for await (const chunk of response.stream) {
+ text += chunk.text();
+ }
+
+ console.log('Generated text:', text);
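+ // The aggregated response is also available here; inlineDataParts()
+ // surfaces any inline media the model returned, and is typically
+ // undefined for a plain text answer.
+ const fullResponse = await response.response;
+ console.log('inlineDataParts', fullResponse.inlineDataParts());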
+ } catch (e) {
+ console.error(e);
+ }
+ }}
+ />
+ <Button
+ title="Generate JSON"
+ onPress={async () => {
+ try {
+ const app = getApp();
+ const ai = getAI(app);
+ const jsonSchema = Schema.object({
+ properties: {
+ characters: Schema.array({
+ items: Schema.object({
+ properties: {
+ name: Schema.string(),
+ accessory: Schema.string(),
+ age: Schema.number(),
+ species: Schema.string(),
+ },
+ optionalProperties: ['accessory'],
+ }),
+ }),
+ },
+ });
+ const model = getGenerativeModel(ai, {
+ model: 'gemini-1.5-flash',
+ generationConfig: {
+ responseMimeType: 'application/json',
+ responseSchema: jsonSchema,
+ },
+ });
+
+ let prompt = "For use in a children's card game, generate 10 animal-based characters.";
+
+ let result = await model.generateContent(prompt);
+ console.log(result.response.text());
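+ // Because responseMimeType is 'application/json', the reply should
+ // parse into the shape described by jsonSchema (model permitting).
+ const parsed = JSON.parse(result.response.text());
+ console.log('characters generated:', parsed.characters?.length);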
+ } catch (e) {
+ console.error(e);
+ }
+ }}
+ />
+ <Button
+ title="Chat (streaming)"
+ onPress={async () => {
+ try {
+ const app = getApp();
+ const ai = getAI(app);
+ const model = getGenerativeModel(ai, { model: 'gemini-1.5-flash' });
+
+ const chat = model.startChat({
+ history: [
+ {
+ role: 'user',
+ parts: [{ text: 'Hello, I have 2 dogs in my house.' }],
+ },
+ {
+ role: 'model',
+ parts: [{ text: 'Great to meet you. What would you like to know?' }],
+ },
+ ],
+ generationConfig: {
+ maxOutputTokens: 100,
+ },
+ });
+
+ const msg = 'How many paws are in my house?';
+ const result = await chat.sendMessageStream(msg);
+
+ let text = '';
+ for await (const chunk of result.stream) {
+ const chunkText = chunk.text();
+ text += chunkText;
+ }
+ console.log(text);
+ // getHistory() resolves with the accumulated multi-turn history,
+ // including the reply that was just streamed.
+ const history = await chat.getHistory();
+ console.log('history turns', history.length);
+ } catch (e) {
+ console.error(e);
+ }
+ }}
+ />
+ <Button
+ title="Count tokens"
+ onPress={async () => {
+ try {
+ const app = getApp();
+ const ai = getAI(app);
+ const model = getGenerativeModel(ai, { model: 'gemini-1.5-flash' });
+
+ const result = await model.countTokens('What is 2 + 2?');
+
+ console.log('totalBillableCharacters', result.totalBillableCharacters);
+ console.log('totalTokens', result.totalTokens);
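+ // promptTokensDetails (optional) breaks the prompt count down by
+ // modality; availability can vary by backend.
+ console.log('promptTokensDetails', result.promptTokensDetails);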
+ } catch (e) {
+ console.error(e);
+ }
+ }}
+ />
+ <Button
+ title="Function calling"
+ onPress={async () => {
+ // This function calls a hypothetical external API that returns
+ // a collection of weather information for a given location on a given date.
+ // `location` is an object of the form { city: string, state: string }
+ async function fetchWeather({ location, date }) {
+ // For demo purposes, this hypothetical response is hardcoded here in the expected format.
+ return {
+ temperature: 38,
+ chancePrecipitation: '56%',
+ cloudConditions: 'partlyCloudy',
+ };
+ }
+ const fetchWeatherTool = {
+ functionDeclarations: [
+ {
+ name: 'fetchWeather',
+ description: 'Get the weather conditions for a specific city on a specific date',
+ parameters: Schema.object({
+ properties: {
+ location: Schema.object({
+ description:
+ 'The name of the city and its state for which to get ' +
+ 'the weather. Only cities in the USA are supported.',
+ properties: {
+ city: Schema.string({
+ description: 'The city of the location.',
+ }),
+ state: Schema.string({
+ description: 'The US state of the location.',
+ }),
+ },
+ }),
+ date: Schema.string({
+ description:
+ 'The date for which to get the weather. Date must be in the' +
+ ' format: YYYY-MM-DD.',
+ }),
+ },
+ }),
+ },
+ ],
+ };
+ try {
+ const app = getApp();
+ const ai = getAI(app);
+ const model = getGenerativeModel(ai, {
+ model: 'gemini-1.5-flash',
+ tools: fetchWeatherTool,
+ });
+
+ const chat = model.startChat();
+ const prompt = 'What was the weather in Boston on October 17, 2024?';
+
+ // Send the user's question (the prompt) to the model using multi-turn chat.
+ let result = await chat.sendMessage(prompt);
+ const functionCalls = result.response.functionCalls();
+ let functionCall;
+ let functionResult;
+ // When the model responds with one or more function calls, invoke the function(s).
+ if (functionCalls && functionCalls.length > 0) {
+ for (const call of functionCalls) {
+ if (call.name === 'fetchWeather') {
+ // Forward the structured input data prepared by the model
+ // to the hypothetical external API.
+ functionResult = await fetchWeather(call.args);
+ functionCall = call;
+ }
+ }
+ }
+ // Guard against the model answering directly without a function call,
+ // in which case functionCall is still undefined.
+ if (functionCall) {
+ result = await chat.sendMessage([
+ {
+ functionResponse: {
+ name: functionCall.name, // "fetchWeather"
+ response: functionResult,
+ },
+ },
+ ]);
+ }
+ console.log(result.response.text());
+ } catch (e) {
+ console.error(e);
+ }
+ }}
+ />
+ </View>
+ );
+}
+
+AppRegistry.registerComponent('testing', () => App);
diff --git a/tsconfig-jest.json b/tsconfig-jest.json
index 0149111b06..8a42f66917 100644
--- a/tsconfig-jest.json
+++ b/tsconfig-jest.json
@@ -6,6 +6,7 @@
"@react-native-firebase/app": ["packages/app/lib"],
"@react-native-firebase/auth": ["packages/auth/lib"],
"@react-native-firebase/app-check": ["packages/app-check/lib"],
+ "@react-native-firebase/ai": ["packages/ai/lib"],
}
}
}
diff --git a/yarn.lock b/yarn.lock
index b04e329a5b..bdb35bcb00 100644
--- a/yarn.lock
+++ b/yarn.lock
@@ -5332,6 +5332,21 @@ __metadata:
languageName: node
linkType: hard
+"@react-native-firebase/ai@npm:22.2.0, @react-native-firebase/ai@workspace:packages/ai":
+ version: 0.0.0-use.local
+ resolution: "@react-native-firebase/ai@workspace:packages/ai"
+ dependencies:
+ "@types/text-encoding": "npm:^0.0.40"
+ react-native-builder-bob: "npm:^0.40.6"
+ react-native-fetch-api: "npm:^3.0.0"
+ text-encoding: "npm:^0.7.0"
+ typescript: "npm:^5.8.3"
+ web-streams-polyfill: "npm:^4.1.0"
+ peerDependencies:
+ "@react-native-firebase/app": 22.2.0
+ languageName: unknown
+ linkType: soft
+
"@react-native-firebase/analytics@npm:22.4.0, @react-native-firebase/analytics@workspace:packages/analytics":
version: 0.0.0-use.local
resolution: "@react-native-firebase/analytics@workspace:packages/analytics"
@@ -5548,6 +5563,7 @@ __metadata:
version: 0.0.0-use.local
resolution: "@react-native-firebase/vertexai@workspace:packages/vertexai"
dependencies:
+ "@react-native-firebase/ai": "npm:22.2.0"
"@types/text-encoding": "npm:^0.0.40"
react-native-builder-bob: "npm:^0.40.6"
react-native-fetch-api: "npm:^3.0.0"