From c17a2c910459a73a3f0b5f113ca20b058e10a0ca Mon Sep 17 00:00:00 2001 From: russellwheatley Date: Mon, 2 Jun 2025 13:42:31 +0100 Subject: [PATCH 01/85] chore: initial cp over to ai directory --- packages/ai/CHANGELOG.md | 46 +++ packages/ai/LICENSE | 32 ++ packages/ai/README.md | 54 +++ packages/ai/__tests__/api.test.ts | 103 +++++ .../ai/__tests__/chat-session-helpers.test.ts | 154 +++++++ packages/ai/__tests__/chat-session.test.ts | 100 +++++ packages/ai/__tests__/count-tokens.test.ts | 88 ++++ .../ai/__tests__/generate-content.test.ts | 204 +++++++++ .../ai/__tests__/generative-model.test.ts | 266 ++++++++++++ packages/ai/__tests__/request-helpers.test.ts | 204 +++++++++ packages/ai/__tests__/request.test.ts | 353 ++++++++++++++++ .../ai/__tests__/response-helpers.test.ts | 236 +++++++++++ packages/ai/__tests__/schema-builder.test.ts | 389 ++++++++++++++++++ packages/ai/__tests__/service.test.ts | 45 ++ packages/ai/__tests__/stream-reader.test.ts | 370 +++++++++++++++++ .../ai/__tests__/test-utils/convert-mocks.ts | 67 +++ .../ai/__tests__/test-utils/mock-response.ts | 69 ++++ packages/ai/e2e/fetch.e2e.js | 67 +++ packages/ai/lib/constants.ts | 35 ++ packages/ai/lib/errors.ts | 52 +++ packages/ai/lib/index.ts | 73 ++++ packages/ai/lib/logger.ts | 20 + .../ai/lib/methods/chat-session-helpers.ts | 116 ++++++ packages/ai/lib/methods/chat-session.ts | 182 ++++++++ packages/ai/lib/methods/count-tokens.ts | 37 ++ packages/ai/lib/methods/generate-content.ts | 66 +++ packages/ai/lib/models/generative-model.ts | 180 ++++++++ packages/ai/lib/polyfills.ts | 35 ++ packages/ai/lib/public-types.ts | 46 +++ packages/ai/lib/requests/request-helpers.ts | 116 ++++++ packages/ai/lib/requests/request.ts | 242 +++++++++++ packages/ai/lib/requests/response-helpers.ts | 186 +++++++++ packages/ai/lib/requests/schema-builder.ts | 281 +++++++++++++ packages/ai/lib/requests/stream-reader.ts | 213 ++++++++++ packages/ai/lib/service.ts | 39 ++ packages/ai/lib/types/content.ts | 162 ++++++++ packages/ai/lib/types/enums.ts | 149 +++++++ packages/ai/lib/types/error.ts | 98 +++++ packages/ai/lib/types/index.ts | 23 ++ packages/ai/lib/types/internal.ts | 25 ++ packages/ai/lib/types/polyfills.d.ts | 15 + packages/ai/lib/types/requests.ts | 198 +++++++++ packages/ai/lib/types/responses.ts | 209 ++++++++++ packages/ai/lib/types/schema.ts | 104 +++++ packages/ai/package.json | 88 ++++ packages/ai/tsconfig.json | 32 ++ 46 files changed, 5869 insertions(+) create mode 100644 packages/ai/CHANGELOG.md create mode 100644 packages/ai/LICENSE create mode 100644 packages/ai/README.md create mode 100644 packages/ai/__tests__/api.test.ts create mode 100644 packages/ai/__tests__/chat-session-helpers.test.ts create mode 100644 packages/ai/__tests__/chat-session.test.ts create mode 100644 packages/ai/__tests__/count-tokens.test.ts create mode 100644 packages/ai/__tests__/generate-content.test.ts create mode 100644 packages/ai/__tests__/generative-model.test.ts create mode 100644 packages/ai/__tests__/request-helpers.test.ts create mode 100644 packages/ai/__tests__/request.test.ts create mode 100644 packages/ai/__tests__/response-helpers.test.ts create mode 100644 packages/ai/__tests__/schema-builder.test.ts create mode 100644 packages/ai/__tests__/service.test.ts create mode 100644 packages/ai/__tests__/stream-reader.test.ts create mode 100644 packages/ai/__tests__/test-utils/convert-mocks.ts create mode 100644 packages/ai/__tests__/test-utils/mock-response.ts create mode 100644 packages/ai/e2e/fetch.e2e.js create mode 100644 packages/ai/lib/constants.ts create mode 100644 packages/ai/lib/errors.ts create mode 100644 packages/ai/lib/index.ts create mode 100644 packages/ai/lib/logger.ts create mode 100644 packages/ai/lib/methods/chat-session-helpers.ts create mode 100644 packages/ai/lib/methods/chat-session.ts create mode 100644 packages/ai/lib/methods/count-tokens.ts create mode 100644 packages/ai/lib/methods/generate-content.ts create mode 100644 packages/ai/lib/models/generative-model.ts create mode 100644 packages/ai/lib/polyfills.ts create mode 100644 packages/ai/lib/public-types.ts create mode 100644 packages/ai/lib/requests/request-helpers.ts create mode 100644 packages/ai/lib/requests/request.ts create mode 100644 packages/ai/lib/requests/response-helpers.ts create mode 100644 packages/ai/lib/requests/schema-builder.ts create mode 100644 packages/ai/lib/requests/stream-reader.ts create mode 100644 packages/ai/lib/service.ts create mode 100644 packages/ai/lib/types/content.ts create mode 100644 packages/ai/lib/types/enums.ts create mode 100644 packages/ai/lib/types/error.ts create mode 100644 packages/ai/lib/types/index.ts create mode 100644 packages/ai/lib/types/internal.ts create mode 100644 packages/ai/lib/types/polyfills.d.ts create mode 100644 packages/ai/lib/types/requests.ts create mode 100644 packages/ai/lib/types/responses.ts create mode 100644 packages/ai/lib/types/schema.ts create mode 100644 packages/ai/package.json create mode 100644 packages/ai/tsconfig.json diff --git a/packages/ai/CHANGELOG.md b/packages/ai/CHANGELOG.md new file mode 100644 index 0000000000..1ac68b1cef --- /dev/null +++ b/packages/ai/CHANGELOG.md @@ -0,0 +1,46 @@ +# Change Log + +All notable changes to this project will be documented in this file. +See [Conventional Commits](https://conventionalcommits.org) for commit guidelines. + +## [22.2.0](https://github.com/invertase/react-native-firebase/compare/v22.1.0...v22.2.0) (2025-05-12) + +**Note:** Version bump only for package @react-native-firebase/vertexai + +## [22.1.0](https://github.com/invertase/react-native-firebase/compare/v22.0.0...v22.1.0) (2025-04-30) + +### Bug Fixes + +- **vertexai:** package.json main needs updating to commonjs path ([abe6c19](https://github.com/invertase/react-native-firebase/commit/abe6c190e6a22676fc58a4c5c7740ddeba2efd93)) + +## [22.0.0](https://github.com/invertase/react-native-firebase/compare/v21.14.0...v22.0.0) (2025-04-25) + +### Bug Fixes + +- enable provenance signing during publish ([4535f0d](https://github.com/invertase/react-native-firebase/commit/4535f0d5756c89aeb8f8e772348c71d8176348be)) + +## [21.14.0](https://github.com/invertase/react-native-firebase/compare/v21.13.0...v21.14.0) (2025-04-14) + +**Note:** Version bump only for package @react-native-firebase/vertexai + +## [21.13.0](https://github.com/invertase/react-native-firebase/compare/v21.12.3...v21.13.0) (2025-03-31) + +**Note:** Version bump only for package @react-native-firebase/vertexai + +## [21.12.3](https://github.com/invertase/react-native-firebase/compare/v21.12.2...v21.12.3) (2025-03-26) + +**Note:** Version bump only for package @react-native-firebase/vertexai + +## [21.12.2](https://github.com/invertase/react-native-firebase/compare/v21.12.1...v21.12.2) (2025-03-23) + +**Note:** Version bump only for package @react-native-firebase/vertexai + +## [21.12.1](https://github.com/invertase/react-native-firebase/compare/v21.12.0...v21.12.1) (2025-03-22) + +**Note:** Version bump only for package @react-native-firebase/vertexai + +## [21.12.0](https://github.com/invertase/react-native-firebase/compare/v21.11.0...v21.12.0) (2025-03-03) + +### Features + +- vertexAI package support ([#8236](https://github.com/invertase/react-native-firebase/issues/8236)) ([a1d1361](https://github.com/invertase/react-native-firebase/commit/a1d13610f443a96a7195b3f769f77d9676c0e577)) diff --git a/packages/ai/LICENSE b/packages/ai/LICENSE new file mode 100644 index 0000000000..ef3ed44f06 --- /dev/null +++ b/packages/ai/LICENSE @@ -0,0 +1,32 @@ +Apache-2.0 License +------------------ + +Copyright (c) 2016-present Invertase Limited & Contributors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this library except in compliance with the License. + +You may obtain a copy of the Apache-2.0 License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + + +Creative Commons Attribution 3.0 License +---------------------------------------- + +Copyright (c) 2016-present Invertase Limited & Contributors + +Documentation and other instructional materials provided for this project +(including on a separate documentation repository or it's documentation website) are +licensed under the Creative Commons Attribution 3.0 License. Code samples/blocks +contained therein are licensed under the Apache License, Version 2.0 (the "License"), as above. + +You may obtain a copy of the Creative Commons Attribution 3.0 License at + + https://creativecommons.org/licenses/by/3.0/ diff --git a/packages/ai/README.md b/packages/ai/README.md new file mode 100644 index 0000000000..3ba225ba15 --- /dev/null +++ b/packages/ai/README.md @@ -0,0 +1,54 @@ +

+ +
+
+

React Native Firebase - Vertex AI

+

+ +

+ Coverage + NPM downloads + NPM version + License + Maintained with Lerna +

+ +

+ Chat on Discord + Follow on Twitter + Follow on Facebook +

+ +--- + +Vertex AI is a fully-managed, unified AI development platform for building and using generative AI. Access and utilize Vertex AI Studio, Agent Builder, and 150+ foundation models including Gemini 1.5 Pro and Gemini 1.5 Flash. + +[> Learn More](https://firebase.google.com/docs/vertex-ai/) + +## Installation + +Requires `@react-native-firebase/app` to be installed. + +```bash +yarn add @react-native-firebase/vertexai +``` + +## Documentation + +- [Quick Start](https://rnfirebase.io/vertexai/usage) +- [Reference](https://rnfirebase.io/reference/vertexai) + +## License + +- See [LICENSE](/LICENSE) + +--- + +

+ +

+ Built and maintained with 💛 by Invertase. +

+

+ +--- diff --git a/packages/ai/__tests__/api.test.ts b/packages/ai/__tests__/api.test.ts new file mode 100644 index 0000000000..3199157e76 --- /dev/null +++ b/packages/ai/__tests__/api.test.ts @@ -0,0 +1,103 @@ +/** + * @license + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import { describe, expect, it } from '@jest/globals'; +import { getApp, type ReactNativeFirebase } from '../../app/lib'; + +import { ModelParams, VertexAIErrorCode } from '../lib/types'; +import { VertexAIError } from '../lib/errors'; +import { getGenerativeModel, getVertexAI } from '../lib/index'; + +import { VertexAI } from '../lib/public-types'; +import { GenerativeModel } from '../lib/models/generative-model'; + +import '../../auth/lib'; +import '../../app-check/lib'; +import { getAuth } from '../../auth/lib'; + +const fakeVertexAI: VertexAI = { + app: { + name: 'DEFAULT', + options: { + apiKey: 'key', + appId: 'appId', + projectId: 'my-project', + }, + } as ReactNativeFirebase.FirebaseApp, + location: 'us-central1', +}; + +describe('Top level API', () => { + it('should allow auth and app check instances to be passed in', () => { + const app = getApp(); + const auth = getAuth(); + const appCheck = app.appCheck(); + + getVertexAI(app, { appCheck, auth }); + }); + + it('getGenerativeModel throws if no model is provided', () => { + try { + getGenerativeModel(fakeVertexAI, {} as ModelParams); + } catch (e) { + expect((e as VertexAIError).code).toContain(VertexAIErrorCode.NO_MODEL); + expect((e as VertexAIError).message).toContain( + `VertexAI: Must provide a model name. Example: ` + + `getGenerativeModel({ model: 'my-model-name' }) (vertexAI/${VertexAIErrorCode.NO_MODEL})`, + ); + } + }); + + it('getGenerativeModel throws if no apiKey is provided', () => { + const fakeVertexNoApiKey = { + ...fakeVertexAI, + app: { options: { projectId: 'my-project' } }, + } as VertexAI; + try { + getGenerativeModel(fakeVertexNoApiKey, { model: 'my-model' }); + } catch (e) { + expect((e as VertexAIError).code).toContain(VertexAIErrorCode.NO_API_KEY); + expect((e as VertexAIError).message).toBe( + `VertexAI: The "apiKey" field is empty in the local ` + + `Firebase config. Firebase VertexAI requires this field to` + + ` contain a valid API key. (vertexAI/${VertexAIErrorCode.NO_API_KEY})`, + ); + } + }); + + it('getGenerativeModel throws if no projectId is provided', () => { + const fakeVertexNoProject = { + ...fakeVertexAI, + app: { options: { apiKey: 'my-key' } }, + } as VertexAI; + try { + getGenerativeModel(fakeVertexNoProject, { model: 'my-model' }); + } catch (e) { + expect((e as VertexAIError).code).toContain(VertexAIErrorCode.NO_PROJECT_ID); + expect((e as VertexAIError).message).toBe( + `VertexAI: The "projectId" field is empty in the local` + + ` Firebase config. Firebase VertexAI requires this field ` + + `to contain a valid project ID. (vertexAI/${VertexAIErrorCode.NO_PROJECT_ID})`, + ); + } + }); + + it('getGenerativeModel gets a GenerativeModel', () => { + const genModel = getGenerativeModel(fakeVertexAI, { model: 'my-model' }); + expect(genModel).toBeInstanceOf(GenerativeModel); + expect(genModel.model).toBe('publishers/google/models/my-model'); + }); +}); diff --git a/packages/ai/__tests__/chat-session-helpers.test.ts b/packages/ai/__tests__/chat-session-helpers.test.ts new file mode 100644 index 0000000000..8bc81f4eab --- /dev/null +++ b/packages/ai/__tests__/chat-session-helpers.test.ts @@ -0,0 +1,154 @@ +/** + * @license + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import { describe, expect, it } from '@jest/globals'; +import { validateChatHistory } from '../lib/methods/chat-session-helpers'; +import { Content } from '../lib/types'; +import { FirebaseError } from '@firebase/util'; + +describe('chat-session-helpers', () => { + describe('validateChatHistory', () => { + it('check chat history', () => { + const TCS: Array<{ history: Content[]; isValid: boolean }> = [ + { + history: [{ role: 'user', parts: [{ text: 'hi' }] }], + isValid: true, + }, + { + history: [ + { + role: 'user', + parts: [{ text: 'hi' }, { inlineData: { mimeType: 'image/jpeg', data: 'base64==' } }], + }, + ], + isValid: true, + }, + { + history: [ + { role: 'user', parts: [{ text: 'hi' }] }, + { role: 'model', parts: [{ text: 'hi' }, { text: 'hi' }] }, + ], + isValid: true, + }, + { + history: [ + { role: 'user', parts: [{ text: 'hi' }] }, + { + role: 'model', + parts: [{ functionCall: { name: 'greet', args: { name: 'user' } } }], + }, + ], + isValid: true, + }, + { + history: [ + { role: 'user', parts: [{ text: 'hi' }] }, + { + role: 'model', + parts: [{ functionCall: { name: 'greet', args: { name: 'user' } } }], + }, + { + role: 'function', + parts: [ + { + functionResponse: { name: 'greet', response: { name: 'user' } }, + }, + ], + }, + ], + isValid: true, + }, + { + history: [ + { role: 'user', parts: [{ text: 'hi' }] }, + { + role: 'model', + parts: [{ functionCall: { name: 'greet', args: { name: 'user' } } }], + }, + { + role: 'function', + parts: [ + { + functionResponse: { name: 'greet', response: { name: 'user' } }, + }, + ], + }, + { + role: 'model', + parts: [{ text: 'hi name' }], + }, + ], + isValid: true, + }, + { + //@ts-expect-error + history: [{ role: 'user', parts: '' }], + isValid: false, + }, + { + //@ts-expect-error + history: [{ role: 'user' }], + isValid: false, + }, + { + history: [{ role: 'user', parts: [] }], + isValid: false, + }, + { + history: [{ role: 'model', parts: [{ text: 'hi' }] }], + isValid: false, + }, + { + history: [ + { + role: 'function', + parts: [ + { + functionResponse: { name: 'greet', response: { name: 'user' } }, + }, + ], + }, + ], + isValid: false, + }, + { + history: [ + { role: 'user', parts: [{ text: 'hi' }] }, + { role: 'user', parts: [{ text: 'hi' }] }, + ], + isValid: false, + }, + { + history: [ + { role: 'user', parts: [{ text: 'hi' }] }, + { role: 'model', parts: [{ text: 'hi' }] }, + { role: 'model', parts: [{ text: 'hi' }] }, + ], + isValid: false, + }, + ]; + + TCS.forEach(tc => { + const fn = (): void => validateChatHistory(tc.history); + if (tc.isValid) { + expect(fn).not.toThrow(); + } else { + expect(fn).toThrow(FirebaseError); + } + }); + }); + }); +}); diff --git a/packages/ai/__tests__/chat-session.test.ts b/packages/ai/__tests__/chat-session.test.ts new file mode 100644 index 0000000000..cd96aa32e6 --- /dev/null +++ b/packages/ai/__tests__/chat-session.test.ts @@ -0,0 +1,100 @@ +/** + * @license + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import { describe, expect, it, afterEach, jest } from '@jest/globals'; + +import * as generateContentMethods from '../lib/methods/generate-content'; +import { GenerateContentStreamResult } from '../lib/types'; +import { ChatSession } from '../lib/methods/chat-session'; +import { ApiSettings } from '../lib/types/internal'; +import { RequestOptions } from '../lib/types/requests'; + +const fakeApiSettings: ApiSettings = { + apiKey: 'key', + project: 'my-project', + location: 'us-central1', +}; + +const requestOptions: RequestOptions = { + timeout: 1000, +}; + +describe('ChatSession', () => { + afterEach(() => { + jest.restoreAllMocks(); + }); + + describe('sendMessage()', () => { + it('generateContent errors should be catchable', async () => { + const generateContentStub = jest + .spyOn(generateContentMethods, 'generateContent') + .mockRejectedValue('generateContent failed'); + + const chatSession = new ChatSession(fakeApiSettings, 'a-model', {}, requestOptions); + + await expect(chatSession.sendMessage('hello')).rejects.toMatch(/generateContent failed/); + + expect(generateContentStub).toHaveBeenCalledWith( + fakeApiSettings, + 'a-model', + expect.anything(), + requestOptions, + ); + }); + }); + + describe('sendMessageStream()', () => { + it('generateContentStream errors should be catchable', async () => { + jest.useFakeTimers(); + const consoleStub = jest.spyOn(console, 'error').mockImplementation(() => {}); + const generateContentStreamStub = jest + .spyOn(generateContentMethods, 'generateContentStream') + .mockRejectedValue('generateContentStream failed'); + const chatSession = new ChatSession(fakeApiSettings, 'a-model', {}, requestOptions); + await expect(chatSession.sendMessageStream('hello')).rejects.toMatch( + /generateContentStream failed/, + ); + expect(generateContentStreamStub).toHaveBeenCalledWith( + fakeApiSettings, + 'a-model', + expect.anything(), + requestOptions, + ); + jest.runAllTimers(); + expect(consoleStub).not.toHaveBeenCalled(); + jest.useRealTimers(); + }); + + it('downstream sendPromise errors should log but not throw', async () => { + const consoleStub = jest.spyOn(console, 'error').mockImplementation(() => {}); + // make response undefined so that response.candidates errors + const generateContentStreamStub = jest + .spyOn(generateContentMethods, 'generateContentStream') + .mockResolvedValue({} as unknown as GenerateContentStreamResult); + const chatSession = new ChatSession(fakeApiSettings, 'a-model', {}, requestOptions); + await chatSession.sendMessageStream('hello'); + expect(generateContentStreamStub).toHaveBeenCalledWith( + fakeApiSettings, + 'a-model', + expect.anything(), + requestOptions, + ); + // wait for the console.error to be called, due to number of promises in the chain + await new Promise(resolve => setTimeout(resolve, 100)); + expect(consoleStub).toHaveBeenCalledTimes(1); + }); + }); +}); diff --git a/packages/ai/__tests__/count-tokens.test.ts b/packages/ai/__tests__/count-tokens.test.ts new file mode 100644 index 0000000000..3cd7b78970 --- /dev/null +++ b/packages/ai/__tests__/count-tokens.test.ts @@ -0,0 +1,88 @@ +/** + * @license + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import { describe, expect, it, afterEach, jest } from '@jest/globals'; +import { getMockResponse } from './test-utils/mock-response'; +import * as request from '../lib/requests/request'; +import { countTokens } from '../lib/methods/count-tokens'; +import { CountTokensRequest } from '../lib/types'; +import { ApiSettings } from '../lib/types/internal'; +import { Task } from '../lib/requests/request'; + +const fakeApiSettings: ApiSettings = { + apiKey: 'key', + project: 'my-project', + location: 'us-central1', +}; + +const fakeRequestParams: CountTokensRequest = { + contents: [{ parts: [{ text: 'hello' }], role: 'user' }], +}; + +describe('countTokens()', () => { + afterEach(() => { + jest.restoreAllMocks(); + }); + + it('total tokens', async () => { + const mockResponse = getMockResponse('unary-success-total-tokens.json'); + const makeRequestStub = jest + .spyOn(request, 'makeRequest') + .mockResolvedValue(mockResponse as Response); + const result = await countTokens(fakeApiSettings, 'model', fakeRequestParams); + expect(result.totalTokens).toBe(6); + expect(result.totalBillableCharacters).toBe(16); + expect(makeRequestStub).toHaveBeenCalledWith( + 'model', + Task.COUNT_TOKENS, + fakeApiSettings, + false, + expect.stringContaining('contents'), + undefined, + ); + }); + + it('total tokens no billable characters', async () => { + const mockResponse = getMockResponse('unary-success-no-billable-characters.json'); + const makeRequestStub = jest + .spyOn(request, 'makeRequest') + .mockResolvedValue(mockResponse as Response); + const result = await countTokens(fakeApiSettings, 'model', fakeRequestParams); + expect(result.totalTokens).toBe(258); + expect(result).not.toHaveProperty('totalBillableCharacters'); + expect(makeRequestStub).toHaveBeenCalledWith( + 'model', + Task.COUNT_TOKENS, + fakeApiSettings, + false, + expect.stringContaining('contents'), + undefined, + ); + }); + + it('model not found', async () => { + const mockResponse = getMockResponse('unary-failure-model-not-found.json'); + const mockFetch = jest.spyOn(globalThis, 'fetch').mockResolvedValue({ + ok: false, + status: 404, + json: mockResponse.json, + } as Response); + await expect(countTokens(fakeApiSettings, 'model', fakeRequestParams)).rejects.toThrow( + /404.*not found/, + ); + expect(mockFetch).toHaveBeenCalled(); + }); +}); diff --git a/packages/ai/__tests__/generate-content.test.ts b/packages/ai/__tests__/generate-content.test.ts new file mode 100644 index 0000000000..3bc733e370 --- /dev/null +++ b/packages/ai/__tests__/generate-content.test.ts @@ -0,0 +1,204 @@ +/** + * @license + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import { describe, expect, it, afterEach, jest } from '@jest/globals'; +import { getMockResponse } from './test-utils/mock-response'; +import * as request from '../lib/requests/request'; +import { generateContent } from '../lib/methods/generate-content'; +import { + GenerateContentRequest, + HarmBlockMethod, + HarmBlockThreshold, + HarmCategory, + // RequestOptions, +} from '../lib/types'; +import { ApiSettings } from '../lib/types/internal'; +import { Task } from '../lib/requests/request'; + +const fakeApiSettings: ApiSettings = { + apiKey: 'key', + project: 'my-project', + location: 'us-central1', +}; + +const fakeRequestParams: GenerateContentRequest = { + contents: [{ parts: [{ text: 'hello' }], role: 'user' }], + generationConfig: { + topK: 16, + }, + safetySettings: [ + { + category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + threshold: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + method: HarmBlockMethod.SEVERITY, + }, + ], +}; + +describe('generateContent()', () => { + afterEach(() => { + jest.restoreAllMocks(); + }); + + it('short response', async () => { + const mockResponse = getMockResponse('unary-success-basic-reply-short.json'); + const makeRequestStub = jest + .spyOn(request, 'makeRequest') + .mockResolvedValue(mockResponse as Response); + const result = await generateContent(fakeApiSettings, 'model', fakeRequestParams); + expect(await result.response.text()).toContain('Mountain View, California'); + expect(makeRequestStub).toHaveBeenCalledWith( + 'model', + Task.GENERATE_CONTENT, + fakeApiSettings, + false, + expect.stringContaining('contents'), + undefined, + ); + }); + + it('long response', async () => { + const mockResponse = getMockResponse('unary-success-basic-reply-long.json'); + const makeRequestStub = jest + .spyOn(request, 'makeRequest') + .mockResolvedValue(mockResponse as Response); + const result = await generateContent(fakeApiSettings, 'model', fakeRequestParams); + expect(result.response.text()).toContain('Use Freshly Ground Coffee'); + expect(result.response.text()).toContain('30 minutes of brewing'); + expect(makeRequestStub).toHaveBeenCalledWith( + 'model', + Task.GENERATE_CONTENT, + fakeApiSettings, + false, + expect.anything(), + undefined, + ); + }); + + it('citations', async () => { + const mockResponse = getMockResponse('unary-success-citations.json'); + const makeRequestStub = jest + .spyOn(request, 'makeRequest') + .mockResolvedValue(mockResponse as Response); + const result = await generateContent(fakeApiSettings, 'model', fakeRequestParams); + expect(result.response.text()).toContain('Some information cited from an external source'); + expect(result.response.candidates?.[0]!.citationMetadata?.citations.length).toBe(3); + expect(makeRequestStub).toHaveBeenCalledWith( + 'model', + Task.GENERATE_CONTENT, + fakeApiSettings, + false, + expect.anything(), + undefined, + ); + }); + + it('blocked prompt', async () => { + const mockResponse = getMockResponse('unary-failure-prompt-blocked-safety.json'); + const makeRequestStub = jest + .spyOn(request, 'makeRequest') + .mockResolvedValue(mockResponse as Response); + + const result = await generateContent(fakeApiSettings, 'model', fakeRequestParams); + + expect(() => result.response.text()).toThrowError('SAFETY'); + expect(makeRequestStub).toHaveBeenCalledWith( + 'model', + Task.GENERATE_CONTENT, + fakeApiSettings, + false, + expect.anything(), + undefined, + ); + }); + + it('finishReason safety', async () => { + const mockResponse = getMockResponse('unary-failure-finish-reason-safety.json'); + const makeRequestStub = jest + .spyOn(request, 'makeRequest') + .mockResolvedValue(mockResponse as Response); + const result = await generateContent(fakeApiSettings, 'model', fakeRequestParams); + expect(() => result.response.text()).toThrow('SAFETY'); + expect(makeRequestStub).toHaveBeenCalledWith( + 'model', + Task.GENERATE_CONTENT, + fakeApiSettings, + false, + expect.anything(), + undefined, + ); + }); + + it('empty content', async () => { + const mockResponse = getMockResponse('unary-failure-empty-content.json'); + const makeRequestStub = jest + .spyOn(request, 'makeRequest') + .mockResolvedValue(mockResponse as Response); + const result = await generateContent(fakeApiSettings, 'model', fakeRequestParams); + expect(result.response.text()).toBe(''); + expect(makeRequestStub).toHaveBeenCalledWith( + 'model', + Task.GENERATE_CONTENT, + fakeApiSettings, + false, + expect.anything(), + undefined, + ); + }); + + it('unknown enum - should ignore', async () => { + const mockResponse = getMockResponse('unary-success-unknown-enum-safety-ratings.json'); + const makeRequestStub = jest + .spyOn(request, 'makeRequest') + .mockResolvedValue(mockResponse as Response); + const result = await generateContent(fakeApiSettings, 'model', fakeRequestParams); + expect(result.response.text()).toContain('Some text'); + expect(makeRequestStub).toHaveBeenCalledWith( + 'model', + Task.GENERATE_CONTENT, + fakeApiSettings, + false, + expect.anything(), + undefined, + ); + }); + + it('image rejected (400)', async () => { + const mockResponse = getMockResponse('unary-failure-image-rejected.json'); + const mockFetch = jest.spyOn(globalThis, 'fetch').mockResolvedValue({ + ok: false, + status: 400, + json: mockResponse.json, + } as Response); + await expect(generateContent(fakeApiSettings, 'model', fakeRequestParams)).rejects.toThrow( + /400.*invalid argument/, + ); + expect(mockFetch).toHaveBeenCalled(); + }); + + it('api not enabled (403)', async () => { + const mockResponse = getMockResponse('unary-failure-firebasevertexai-api-not-enabled.json'); + const mockFetch = jest.spyOn(globalThis, 'fetch').mockResolvedValue({ + ok: false, + status: 403, + json: mockResponse.json, + } as Response); + await expect(generateContent(fakeApiSettings, 'model', fakeRequestParams)).rejects.toThrow( + /firebasevertexai\.googleapis[\s\S]*my-project[\s\S]*api-not-enabled/, + ); + expect(mockFetch).toHaveBeenCalled(); + }); +}); diff --git a/packages/ai/__tests__/generative-model.test.ts b/packages/ai/__tests__/generative-model.test.ts new file mode 100644 index 0000000000..e62862b6aa --- /dev/null +++ b/packages/ai/__tests__/generative-model.test.ts @@ -0,0 +1,266 @@ +/** + * @license + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import { describe, expect, it, jest } from '@jest/globals'; +import { type ReactNativeFirebase } from '@react-native-firebase/app'; +import { GenerativeModel } from '../lib/models/generative-model'; +import { FunctionCallingMode, VertexAI } from '../lib/public-types'; +import * as request from '../lib/requests/request'; +import { getMockResponse } from './test-utils/mock-response'; + +const fakeVertexAI: VertexAI = { + app: { + name: 'DEFAULT', + options: { + apiKey: 'key', + projectId: 'my-project', + }, + } as ReactNativeFirebase.FirebaseApp, + location: 'us-central1', +}; + +describe('GenerativeModel', () => { + it('handles plain model name', () => { + const genModel = new GenerativeModel(fakeVertexAI, { model: 'my-model' }); + expect(genModel.model).toBe('publishers/google/models/my-model'); + }); + + it('handles models/ prefixed model name', () => { + const genModel = new GenerativeModel(fakeVertexAI, { + model: 'models/my-model', + }); + expect(genModel.model).toBe('publishers/google/models/my-model'); + }); + + it('handles full model name', () => { + const genModel = new GenerativeModel(fakeVertexAI, { + model: 'publishers/google/models/my-model', + }); + expect(genModel.model).toBe('publishers/google/models/my-model'); + }); + + it('handles prefixed tuned model name', () => { + const genModel = new GenerativeModel(fakeVertexAI, { + model: 'tunedModels/my-model', + }); + expect(genModel.model).toBe('tunedModels/my-model'); + }); + + it('passes params through to generateContent', async () => { + const genModel = new GenerativeModel(fakeVertexAI, { + model: 'my-model', + tools: [ + { + functionDeclarations: [ + { + name: 'myfunc', + description: 'mydesc', + }, + ], + }, + ], + toolConfig: { functionCallingConfig: { mode: FunctionCallingMode.NONE } }, + systemInstruction: { role: 'system', parts: [{ text: 'be friendly' }] }, + }); + expect(genModel.tools?.length).toBe(1); + expect(genModel.toolConfig?.functionCallingConfig?.mode).toBe(FunctionCallingMode.NONE); + expect(genModel.systemInstruction?.parts[0]!.text).toBe('be friendly'); + const mockResponse = getMockResponse('unary-success-basic-reply-short.json'); + const makeRequestStub = jest + .spyOn(request, 'makeRequest') + .mockResolvedValue(mockResponse as Response); + await genModel.generateContent('hello'); + expect(makeRequestStub).toHaveBeenCalledWith( + 'publishers/google/models/my-model', + request.Task.GENERATE_CONTENT, + expect.anything(), + false, + expect.stringMatching(new RegExp(`myfunc|be friendly|${FunctionCallingMode.NONE}`)), + {}, + ); + makeRequestStub.mockRestore(); + }); + + it('passes text-only systemInstruction through to generateContent', async () => { + const genModel = new GenerativeModel(fakeVertexAI, { + model: 'my-model', + systemInstruction: 'be friendly', + }); + expect(genModel.systemInstruction?.parts[0]!.text).toBe('be friendly'); + const mockResponse = getMockResponse('unary-success-basic-reply-short.json'); + const makeRequestStub = jest + .spyOn(request, 'makeRequest') + .mockResolvedValue(mockResponse as Response); + await genModel.generateContent('hello'); + expect(makeRequestStub).toHaveBeenCalledWith( + 'publishers/google/models/my-model', + request.Task.GENERATE_CONTENT, + expect.anything(), + false, + expect.stringContaining('be friendly'), + {}, + ); + makeRequestStub.mockRestore(); + }); + + it('generateContent overrides model values', async () => { + const genModel = new GenerativeModel(fakeVertexAI, { + model: 'my-model', + tools: [ + { + functionDeclarations: [ + { + name: 'myfunc', + description: 'mydesc', + }, + ], + }, + ], + toolConfig: { functionCallingConfig: { mode: FunctionCallingMode.NONE } }, + systemInstruction: { role: 'system', parts: [{ text: 'be friendly' }] }, + }); + expect(genModel.tools?.length).toBe(1); + expect(genModel.toolConfig?.functionCallingConfig?.mode).toBe(FunctionCallingMode.NONE); + expect(genModel.systemInstruction?.parts[0]!.text).toBe('be friendly'); + const mockResponse = getMockResponse('unary-success-basic-reply-short.json'); + const makeRequestStub = jest + .spyOn(request, 'makeRequest') + .mockResolvedValue(mockResponse as Response); + await genModel.generateContent({ + contents: [{ role: 'user', parts: [{ text: 'hello' }] }], + tools: [ + { + functionDeclarations: [{ name: 'otherfunc', description: 'otherdesc' }], + }, + ], + toolConfig: { functionCallingConfig: { mode: FunctionCallingMode.AUTO } }, + systemInstruction: { role: 'system', parts: [{ text: 'be formal' }] }, + }); + expect(makeRequestStub).toHaveBeenCalledWith( + 'publishers/google/models/my-model', + request.Task.GENERATE_CONTENT, + expect.anything(), + false, + expect.stringMatching(new RegExp(`be formal|otherfunc|${FunctionCallingMode.AUTO}`)), + {}, + ); + makeRequestStub.mockRestore(); + }); + + it('passes params through to chat.sendMessage', async () => { + const genModel = new GenerativeModel(fakeVertexAI, { + model: 'my-model', + tools: [{ functionDeclarations: [{ name: 'myfunc', description: 'mydesc' }] }], + toolConfig: { functionCallingConfig: { mode: FunctionCallingMode.NONE } }, + systemInstruction: { role: 'system', parts: [{ text: 'be friendly' }] }, + }); + expect(genModel.tools?.length).toBe(1); + expect(genModel.toolConfig?.functionCallingConfig?.mode).toBe(FunctionCallingMode.NONE); + expect(genModel.systemInstruction?.parts[0]!.text).toBe('be friendly'); + const mockResponse = getMockResponse('unary-success-basic-reply-short.json'); + const makeRequestStub = jest + .spyOn(request, 'makeRequest') + .mockResolvedValue(mockResponse as Response); + await genModel.startChat().sendMessage('hello'); + expect(makeRequestStub).toHaveBeenCalledWith( + 'publishers/google/models/my-model', + request.Task.GENERATE_CONTENT, + expect.anything(), + false, + expect.stringMatching(new RegExp(`myfunc|be friendly|${FunctionCallingMode.NONE}`)), + {}, + ); + makeRequestStub.mockRestore(); + }); + + it('passes text-only systemInstruction through to chat.sendMessage', async () => { + const genModel = new GenerativeModel(fakeVertexAI, { + model: 'my-model', + systemInstruction: 'be friendly', + }); + expect(genModel.systemInstruction?.parts[0]!.text).toBe('be friendly'); + const mockResponse = getMockResponse('unary-success-basic-reply-short.json'); + const makeRequestStub = jest + .spyOn(request, 'makeRequest') + .mockResolvedValue(mockResponse as Response); + await genModel.startChat().sendMessage('hello'); + expect(makeRequestStub).toHaveBeenCalledWith( + 'publishers/google/models/my-model', + request.Task.GENERATE_CONTENT, + expect.anything(), + false, + expect.stringContaining('be friendly'), + {}, + ); + makeRequestStub.mockRestore(); + }); + + it('startChat overrides model values', async () => { + const genModel = new GenerativeModel(fakeVertexAI, { + model: 'my-model', + tools: [{ functionDeclarations: [{ name: 'myfunc', description: 'mydesc' }] }], + toolConfig: { functionCallingConfig: { mode: FunctionCallingMode.NONE } }, + systemInstruction: { role: 'system', parts: [{ text: 'be friendly' }] }, + }); + expect(genModel.tools?.length).toBe(1); + expect(genModel.toolConfig?.functionCallingConfig?.mode).toBe(FunctionCallingMode.NONE); + expect(genModel.systemInstruction?.parts[0]!.text).toBe('be friendly'); + const mockResponse = getMockResponse('unary-success-basic-reply-short.json'); + const makeRequestStub = jest + .spyOn(request, 'makeRequest') + .mockResolvedValue(mockResponse as Response); + await genModel + .startChat({ + tools: [ + { + functionDeclarations: [{ name: 'otherfunc', description: 'otherdesc' }], + }, + ], + toolConfig: { + functionCallingConfig: { mode: FunctionCallingMode.AUTO }, + }, + systemInstruction: { role: 'system', parts: [{ text: 'be formal' }] }, + }) + .sendMessage('hello'); + expect(makeRequestStub).toHaveBeenCalledWith( + 'publishers/google/models/my-model', + request.Task.GENERATE_CONTENT, + expect.anything(), + false, + expect.stringMatching(new RegExp(`otherfunc|be formal|${FunctionCallingMode.AUTO}`)), + {}, + ); + makeRequestStub.mockRestore(); + }); + + it('calls countTokens', async () => { + const genModel = new GenerativeModel(fakeVertexAI, { model: 'my-model' }); + const mockResponse = getMockResponse('unary-success-total-tokens.json'); + const makeRequestStub = jest + .spyOn(request, 'makeRequest') + .mockResolvedValue(mockResponse as Response); + await genModel.countTokens('hello'); + expect(makeRequestStub).toHaveBeenCalledWith( + 'publishers/google/models/my-model', + request.Task.COUNT_TOKENS, + expect.anything(), + false, + expect.stringContaining('hello'), + undefined, + ); + makeRequestStub.mockRestore(); + }); +}); diff --git a/packages/ai/__tests__/request-helpers.test.ts b/packages/ai/__tests__/request-helpers.test.ts new file mode 100644 index 0000000000..05433f5ba2 --- /dev/null +++ b/packages/ai/__tests__/request-helpers.test.ts @@ -0,0 +1,204 @@ +/** + * @license + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import { describe, expect, it } from '@jest/globals'; +import { Content } from '../lib/types'; +import { formatGenerateContentInput } from '../lib/requests/request-helpers'; + +describe('request formatting methods', () => { + describe('formatGenerateContentInput', () => { + it('formats a text string into a request', () => { + const result = formatGenerateContentInput('some text content'); + expect(result).toEqual({ + contents: [ + { + role: 'user', + parts: [{ text: 'some text content' }], + }, + ], + }); + }); + + it('formats an array of strings into a request', () => { + const result = formatGenerateContentInput(['txt1', 'txt2']); + expect(result).toEqual({ + contents: [ + { + role: 'user', + parts: [{ text: 'txt1' }, { text: 'txt2' }], + }, + ], + }); + }); + + it('formats an array of Parts into a request', () => { + const result = formatGenerateContentInput([{ text: 'txt1' }, { text: 'txtB' }]); + expect(result).toEqual({ + contents: [ + { + role: 'user', + parts: [{ text: 'txt1' }, { text: 'txtB' }], + }, + ], + }); + }); + + it('formats a mixed array into a request', () => { + const result = formatGenerateContentInput(['txtA', { text: 'txtB' }]); + expect(result).toEqual({ + contents: [ + { + role: 'user', + parts: [{ text: 'txtA' }, { text: 'txtB' }], + }, + ], + }); + }); + + it('preserves other properties of request', () => { + const result = formatGenerateContentInput({ + contents: [ + { + role: 'user', + parts: [{ text: 'txtA' }], + }, + ], + generationConfig: { topK: 100 }, + }); + expect(result).toEqual({ + contents: [ + { + role: 'user', + parts: [{ text: 'txtA' }], + }, + ], + generationConfig: { topK: 100 }, + }); + }); + + it('formats systemInstructions if provided as text', () => { + const result = formatGenerateContentInput({ + contents: [ + { + role: 'user', + parts: [{ text: 'txtA' }], + }, + ], + systemInstruction: 'be excited', + }); + expect(result).toEqual({ + contents: [ + { + role: 'user', + parts: [{ text: 'txtA' }], + }, + ], + systemInstruction: { role: 'system', parts: [{ text: 'be excited' }] }, + }); + }); + + it('formats systemInstructions if provided as Part', () => { + const result = formatGenerateContentInput({ + contents: [ + { + role: 'user', + parts: [{ text: 'txtA' }], + }, + ], + systemInstruction: { text: 'be excited' }, + }); + expect(result).toEqual({ + contents: [ + { + role: 'user', + parts: [{ text: 'txtA' }], + }, + ], + systemInstruction: { role: 'system', parts: [{ text: 'be excited' }] }, + }); + }); + + it('formats systemInstructions if provided as Content (no role)', () => { + const result = formatGenerateContentInput({ + contents: [ + { + role: 'user', + parts: [{ text: 'txtA' }], + }, + ], + systemInstruction: { parts: [{ text: 'be excited' }] } as Content, + }); + expect(result).toEqual({ + contents: [ + { + role: 'user', + parts: [{ text: 'txtA' }], + }, + ], + systemInstruction: { role: 'system', parts: [{ text: 'be excited' }] }, + }); + }); + + it('passes thru systemInstructions if provided as Content', () => { + const result = formatGenerateContentInput({ + contents: [ + { + role: 'user', + parts: [{ text: 'txtA' }], + }, + ], + systemInstruction: { role: 'system', parts: [{ text: 'be excited' }] }, + }); + expect(result).toEqual({ + contents: [ + { + role: 'user', + parts: [{ text: 'txtA' }], + }, + ], + systemInstruction: { role: 'system', parts: [{ text: 'be excited' }] }, + }); + }); + + it('formats fileData as part if provided as part', () => { + const result = formatGenerateContentInput([ + 'What is this?', + { + fileData: { + mimeType: 'image/jpeg', + fileUri: 'gs://sample.appspot.com/image.jpeg', + }, + }, + ]); + expect(result).toEqual({ + contents: [ + { + role: 'user', + parts: [ + { text: 'What is this?' }, + { + fileData: { + mimeType: 'image/jpeg', + fileUri: 'gs://sample.appspot.com/image.jpeg', + }, + }, + ], + }, + ], + }); + }); + }); +}); diff --git a/packages/ai/__tests__/request.test.ts b/packages/ai/__tests__/request.test.ts new file mode 100644 index 0000000000..c992b062e9 --- /dev/null +++ b/packages/ai/__tests__/request.test.ts @@ -0,0 +1,353 @@ +/** + * @license + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import { describe, expect, it, jest, afterEach } from '@jest/globals'; +import { RequestUrl, Task, getHeaders, makeRequest } from '../lib/requests/request'; +import { ApiSettings } from '../lib/types/internal'; +import { DEFAULT_API_VERSION } from '../lib/constants'; +import { VertexAIErrorCode } from '../lib/types'; +import { VertexAIError } from '../lib/errors'; +import { getMockResponse } from './test-utils/mock-response'; + +const fakeApiSettings: ApiSettings = { + apiKey: 'key', + project: 'my-project', + location: 'us-central1', +}; + +describe('request methods', () => { + afterEach(() => { + jest.restoreAllMocks(); // Use Jest's restoreAllMocks + }); + + describe('RequestUrl', () => { + it('stream', async () => { + const url = new RequestUrl( + 'models/model-name', + Task.GENERATE_CONTENT, + fakeApiSettings, + true, + {}, + ); + const urlStr = url.toString(); + expect(urlStr).toContain('models/model-name:generateContent'); + expect(urlStr).toContain(fakeApiSettings.project); + expect(urlStr).toContain(fakeApiSettings.location); + expect(urlStr).toContain('alt=sse'); + }); + + it('non-stream', async () => { + const url = new RequestUrl( + 'models/model-name', + Task.GENERATE_CONTENT, + fakeApiSettings, + false, + {}, + ); + const urlStr = url.toString(); + expect(urlStr).toContain('models/model-name:generateContent'); + expect(urlStr).toContain(fakeApiSettings.project); + expect(urlStr).toContain(fakeApiSettings.location); + expect(urlStr).not.toContain('alt=sse'); + }); + + it('default apiVersion', async () => { + const url = new RequestUrl( + 'models/model-name', + Task.GENERATE_CONTENT, + fakeApiSettings, + false, + {}, + ); + expect(url.toString()).toContain(DEFAULT_API_VERSION); + }); + + it('custom baseUrl', async () => { + const url = new RequestUrl( + 'models/model-name', + Task.GENERATE_CONTENT, + fakeApiSettings, + false, + { baseUrl: 'https://my.special.endpoint' }, + ); + expect(url.toString()).toContain('https://my.special.endpoint'); + }); + + it('non-stream - tunedModels/', async () => { + const url = new RequestUrl( + 'tunedModels/model-name', + Task.GENERATE_CONTENT, + fakeApiSettings, + false, + {}, + ); + const urlStr = url.toString(); + expect(urlStr).toContain('tunedModels/model-name:generateContent'); + expect(urlStr).toContain(fakeApiSettings.location); + expect(urlStr).toContain(fakeApiSettings.project); + expect(urlStr).not.toContain('alt=sse'); + }); + }); + + describe('getHeaders', () => { + const fakeApiSettings: ApiSettings = { + apiKey: 'key', + project: 'myproject', + location: 'moon', + getAuthToken: () => Promise.resolve('authtoken'), + getAppCheckToken: () => Promise.resolve({ token: 'appchecktoken' }), + }; + const fakeUrl = new RequestUrl( + 'models/model-name', + Task.GENERATE_CONTENT, + fakeApiSettings, + true, + {}, + ); + + it('adds client headers', async () => { + const headers = await getHeaders(fakeUrl); + expect(headers.get('x-goog-api-client')).toMatch(/gl-rn\/[0-9\.]+ fire\/[0-9\.]+/); + }); + + it('adds api key', async () => { + const headers = await getHeaders(fakeUrl); + expect(headers.get('x-goog-api-key')).toBe('key'); + }); + + it('adds app check token if it exists', async () => { + const headers = await getHeaders(fakeUrl); + expect(headers.get('X-Firebase-AppCheck')).toBe('appchecktoken'); + }); + + it('ignores app check token header if no appcheck service', async () => { + const fakeUrl = new RequestUrl( + 'models/model-name', + Task.GENERATE_CONTENT, + { + apiKey: 'key', + project: 'myproject', + location: 'moon', + }, + true, + {}, + ); + const headers = await getHeaders(fakeUrl); + expect(headers.has('X-Firebase-AppCheck')).toBe(false); + }); + + it('ignores app check token header if returned token was undefined', async () => { + const fakeUrl = new RequestUrl( + 'models/model-name', + Task.GENERATE_CONTENT, + { + apiKey: 'key', + project: 'myproject', + location: 'moon', + //@ts-ignore + getAppCheckToken: () => Promise.resolve(), + }, + true, + {}, + ); + const headers = await getHeaders(fakeUrl); + expect(headers.has('X-Firebase-AppCheck')).toBe(false); + }); + + it('ignores app check token header if returned token had error', async () => { + const fakeUrl = new RequestUrl( + 'models/model-name', + Task.GENERATE_CONTENT, + { + apiKey: 'key', + project: 'myproject', + location: 'moon', + getAppCheckToken: () => Promise.reject(new Error('oops')), + }, + true, + {}, + ); + + const warnSpy = jest.spyOn(console, 'warn').mockImplementation(() => {}); + await getHeaders(fakeUrl); + // NOTE - no app check header if there is no token, this is different to firebase-js-sdk + // See: https://github.com/firebase/firebase-js-sdk/blob/main/packages/vertexai/src/requests/request.test.ts#L172 + // expect(headers.get('X-Firebase-AppCheck')).toBe('dummytoken'); + expect(warnSpy).toHaveBeenCalledWith( + expect.stringMatching(/vertexai/), + expect.stringMatching(/App Check.*oops/), + ); + }); + + it('adds auth token if it exists', async () => { + const headers = await getHeaders(fakeUrl); + expect(headers.get('Authorization')).toBe('Firebase authtoken'); + }); + + it('ignores auth token header if no auth service', async () => { + const fakeUrl = new RequestUrl( + 'models/model-name', + Task.GENERATE_CONTENT, + { + apiKey: 'key', + project: 'myproject', + location: 'moon', + }, + true, + {}, + ); + const headers = await getHeaders(fakeUrl); + expect(headers.has('Authorization')).toBe(false); + }); + + it('ignores auth token header if returned token was undefined', async () => { + const fakeUrl = new RequestUrl( + 'models/model-name', + Task.GENERATE_CONTENT, + { + apiKey: 'key', + project: 'myproject', + location: 'moon', + //@ts-ignore + getAppCheckToken: () => Promise.resolve(), + }, + true, + {}, + ); + const headers = await getHeaders(fakeUrl); + expect(headers.has('Authorization')).toBe(false); + }); + }); + + describe('makeRequest', () => { + it('no error', async () => { + const fetchMock = jest.spyOn(globalThis, 'fetch').mockResolvedValue({ + ok: true, + } as Response); + const response = await makeRequest( + 'models/model-name', + Task.GENERATE_CONTENT, + fakeApiSettings, + false, + '', + ); + expect(fetchMock).toHaveBeenCalledTimes(1); + expect(response.ok).toBe(true); + }); + + it('error with timeout', async () => { + const fetchMock = jest.spyOn(globalThis, 'fetch').mockResolvedValue({ + ok: false, + status: 500, + statusText: 'AbortError', + } as Response); + + try { + await makeRequest('models/model-name', Task.GENERATE_CONTENT, fakeApiSettings, false, '', { + timeout: 180000, + }); + } catch (e) { + expect((e as VertexAIError).code).toBe(VertexAIErrorCode.FETCH_ERROR); + expect((e as VertexAIError).customErrorData?.status).toBe(500); + expect((e as VertexAIError).customErrorData?.statusText).toBe('AbortError'); + expect((e as VertexAIError).message).toContain('500 AbortError'); + } + + expect(fetchMock).toHaveBeenCalledTimes(1); + }); + + it('Network error, no response.json()', async () => { + const fetchMock = jest.spyOn(globalThis, 'fetch').mockResolvedValue({ + ok: false, + status: 500, + statusText: 'Server Error', + } as Response); + try { + await makeRequest('models/model-name', Task.GENERATE_CONTENT, fakeApiSettings, false, ''); + } catch (e) { + expect((e as VertexAIError).code).toBe(VertexAIErrorCode.FETCH_ERROR); + expect((e as VertexAIError).customErrorData?.status).toBe(500); + expect((e as VertexAIError).customErrorData?.statusText).toBe('Server Error'); + expect((e as VertexAIError).message).toContain('500 Server Error'); + } + expect(fetchMock).toHaveBeenCalledTimes(1); + }); + + it('Network error, includes response.json()', async () => { + const fetchMock = jest.spyOn(globalThis, 'fetch').mockResolvedValue({ + ok: false, + status: 500, + statusText: 'Server Error', + json: () => Promise.resolve({ error: { message: 'extra info' } }), + } as Response); + try { + await makeRequest('models/model-name', Task.GENERATE_CONTENT, fakeApiSettings, false, ''); + } catch (e) { + expect((e as VertexAIError).code).toBe(VertexAIErrorCode.FETCH_ERROR); + expect((e as VertexAIError).customErrorData?.status).toBe(500); + expect((e as VertexAIError).customErrorData?.statusText).toBe('Server Error'); + expect((e as VertexAIError).message).toContain('500 Server Error'); + expect((e as VertexAIError).message).toContain('extra info'); + } + expect(fetchMock).toHaveBeenCalledTimes(1); + }); + + it('Network error, includes response.json() and details', async () => { + const fetchMock = jest.spyOn(globalThis, 'fetch').mockResolvedValue({ + ok: false, + status: 500, + statusText: 'Server Error', + json: () => + Promise.resolve({ + error: { + message: 'extra info', + details: [ + { + '@type': 'type.googleapis.com/google.rpc.DebugInfo', + detail: + '[ORIGINAL ERROR] generic::invalid_argument: invalid status photos.thumbnailer.Status.Code::5: Source image 0 too short', + }, + ], + }, + }), + } as Response); + try { + await makeRequest('models/model-name', Task.GENERATE_CONTENT, fakeApiSettings, false, ''); + } catch (e) { + expect((e as VertexAIError).code).toBe(VertexAIErrorCode.FETCH_ERROR); + expect((e as VertexAIError).customErrorData?.status).toBe(500); + expect((e as VertexAIError).customErrorData?.statusText).toBe('Server Error'); + expect((e as VertexAIError).message).toContain('500 Server Error'); + expect((e as VertexAIError).message).toContain('extra info'); + expect((e as VertexAIError).message).toContain('generic::invalid_argument'); + } + expect(fetchMock).toHaveBeenCalledTimes(1); + }); + }); + + it('Network error, API not enabled', async () => { + const mockResponse = getMockResponse('unary-failure-firebasevertexai-api-not-enabled.json'); + const fetchMock = jest.spyOn(globalThis, 'fetch').mockResolvedValue(mockResponse as Response); + try { + await makeRequest('models/model-name', Task.GENERATE_CONTENT, fakeApiSettings, false, ''); + } catch (e) { + expect((e as VertexAIError).code).toBe(VertexAIErrorCode.API_NOT_ENABLED); + expect((e as VertexAIError).message).toContain('my-project'); + expect((e as VertexAIError).message).toContain('googleapis.com'); + } + expect(fetchMock).toHaveBeenCalledTimes(1); + }); +}); diff --git a/packages/ai/__tests__/response-helpers.test.ts b/packages/ai/__tests__/response-helpers.test.ts new file mode 100644 index 0000000000..cc0fddc658 --- /dev/null +++ b/packages/ai/__tests__/response-helpers.test.ts @@ -0,0 +1,236 @@ +/** + * @license + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import { describe, expect, it, jest, afterEach } from '@jest/globals'; +import { addHelpers, formatBlockErrorMessage } from '../lib/requests/response-helpers'; + +import { BlockReason, Content, FinishReason, GenerateContentResponse } from '../lib/types'; + +const fakeResponseText: GenerateContentResponse = { + candidates: [ + { + index: 0, + content: { + role: 'model', + parts: [{ text: 'Some text' }, { text: ' and some more text' }], + }, + }, + ], +}; + +const functionCallPart1 = { + functionCall: { + name: 'find_theaters', + args: { + location: 'Mountain View, CA', + movie: 'Barbie', + }, + }, +}; + +const functionCallPart2 = { + functionCall: { + name: 'find_times', + args: { + location: 'Mountain View, CA', + movie: 'Barbie', + time: '20:00', + }, + }, +}; + +const fakeResponseFunctionCall: GenerateContentResponse = { + candidates: [ + { + index: 0, + content: { + role: 'model', + parts: [functionCallPart1], + }, + }, + ], +}; + +const fakeResponseFunctionCalls: GenerateContentResponse = { + candidates: [ + { + index: 0, + content: { + role: 'model', + parts: [functionCallPart1, functionCallPart2], + }, + }, + ], +}; + +const fakeResponseMixed1: GenerateContentResponse = { + candidates: [ + { + index: 0, + content: { + role: 'model', + parts: [{ text: 'some text' }, functionCallPart2], + }, + }, + ], +}; + +const fakeResponseMixed2: GenerateContentResponse = { + candidates: [ + { + index: 0, + content: { + role: 'model', + parts: [functionCallPart1, { text: 'some text' }], + }, + }, + ], +}; + +const fakeResponseMixed3: GenerateContentResponse = { + candidates: [ + { + index: 0, + content: { + role: 'model', + parts: [{ text: 'some text' }, functionCallPart1, { text: ' and more text' }], + }, + }, + ], +}; + +const badFakeResponse: GenerateContentResponse = { + promptFeedback: { + blockReason: BlockReason.SAFETY, + safetyRatings: [], + }, +}; + +describe('response-helpers methods', () => { + afterEach(() => { + jest.restoreAllMocks(); // Use Jest's restore function + }); + + describe('addHelpers', () => { + it('good response text', () => { + const enhancedResponse = addHelpers(fakeResponseText); + expect(enhancedResponse.text()).toBe('Some text and some more text'); + expect(enhancedResponse.functionCalls()).toBeUndefined(); + }); + + it('good response functionCall', () => { + const enhancedResponse = addHelpers(fakeResponseFunctionCall); + expect(enhancedResponse.text()).toBe(''); + expect(enhancedResponse.functionCalls()).toEqual([functionCallPart1.functionCall]); + }); + + it('good response functionCalls', () => { + const enhancedResponse = addHelpers(fakeResponseFunctionCalls); + expect(enhancedResponse.text()).toBe(''); + expect(enhancedResponse.functionCalls()).toEqual([ + functionCallPart1.functionCall, + functionCallPart2.functionCall, + ]); + }); + + it('good response text/functionCall', () => { + const enhancedResponse = addHelpers(fakeResponseMixed1); + expect(enhancedResponse.functionCalls()).toEqual([functionCallPart2.functionCall]); + expect(enhancedResponse.text()).toBe('some text'); + }); + + it('good response functionCall/text', () => { + const enhancedResponse = addHelpers(fakeResponseMixed2); + expect(enhancedResponse.functionCalls()).toEqual([functionCallPart1.functionCall]); + expect(enhancedResponse.text()).toBe('some text'); + }); + + it('good response text/functionCall/text', () => { + const enhancedResponse = addHelpers(fakeResponseMixed3); + expect(enhancedResponse.functionCalls()).toEqual([functionCallPart1.functionCall]); + expect(enhancedResponse.text()).toBe('some text and more text'); + }); + + it('bad response safety', () => { + const enhancedResponse = addHelpers(badFakeResponse); + expect(() => enhancedResponse.text()).toThrow('SAFETY'); + }); + }); + + describe('getBlockString', () => { + it('has no promptFeedback or bad finishReason', () => { + const message = formatBlockErrorMessage({ + candidates: [ + { + index: 0, + finishReason: FinishReason.STOP, + finishMessage: 'this was fine', + content: {} as Content, + }, + ], + }); + expect(message).toBe(''); + }); + + it('has promptFeedback and blockReason only', () => { + const message = formatBlockErrorMessage({ + promptFeedback: { + blockReason: BlockReason.SAFETY, + safetyRatings: [], + }, + }); + expect(message).toContain('Response was blocked due to SAFETY'); + }); + + it('has promptFeedback with blockReason and blockMessage', () => { + const message = formatBlockErrorMessage({ + promptFeedback: { + blockReason: BlockReason.SAFETY, + blockReasonMessage: 'safety reasons', + safetyRatings: [], + }, + }); + expect(message).toContain('Response was blocked due to SAFETY: safety reasons'); + }); + + it('has bad finishReason only', () => { + const message = formatBlockErrorMessage({ + candidates: [ + { + index: 0, + finishReason: FinishReason.SAFETY, + content: {} as Content, + }, + ], + }); + expect(message).toContain('Candidate was blocked due to SAFETY'); + }); + + it('has finishReason and finishMessage', () => { + const message = formatBlockErrorMessage({ + candidates: [ + { + index: 0, + finishReason: FinishReason.SAFETY, + finishMessage: 'unsafe candidate', + content: {} as Content, + }, + ], + }); + expect(message).toContain('Candidate was blocked due to SAFETY: unsafe candidate'); + }); + }); +}); diff --git a/packages/ai/__tests__/schema-builder.test.ts b/packages/ai/__tests__/schema-builder.test.ts new file mode 100644 index 0000000000..bec1f6a8d2 --- /dev/null +++ b/packages/ai/__tests__/schema-builder.test.ts @@ -0,0 +1,389 @@ +/** + * @license + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import { describe, expect, it } from '@jest/globals'; +import { Schema } from '../lib/requests/schema-builder'; +import { VertexAIErrorCode } from '../lib/types'; + +describe('Schema builder', () => { + it('builds integer schema', () => { + const schema = Schema.integer(); + expect(schema.toJSON()).toEqual({ + type: 'integer', + nullable: false, + }); + }); + + it('builds integer schema with options and overrides', () => { + const schema = Schema.integer({ nullable: true, format: 'int32' }); + expect(schema.toJSON()).toEqual({ + type: 'integer', + format: 'int32', + nullable: true, + }); + }); + + it('builds number schema', () => { + const schema = Schema.number(); + expect(schema.toJSON()).toEqual({ + type: 'number', + nullable: false, + }); + }); + + it('builds number schema with options and unknown options', () => { + const schema = Schema.number({ format: 'float', futureOption: 'test' }); + expect(schema.toJSON()).toEqual({ + type: 'number', + format: 'float', + futureOption: 'test', + nullable: false, + }); + }); + + it('builds boolean schema', () => { + const schema = Schema.boolean(); + expect(schema.toJSON()).toEqual({ + type: 'boolean', + nullable: false, + }); + }); + + it('builds string schema', () => { + const schema = Schema.string({ description: 'hey' }); + expect(schema.toJSON()).toEqual({ + type: 'string', + description: 'hey', + nullable: false, + }); + }); + + it('builds enumString schema', () => { + const schema = Schema.enumString({ + example: 'east', + enum: ['east', 'west'], + }); + expect(schema.toJSON()).toEqual({ + type: 'string', + example: 'east', + enum: ['east', 'west'], + nullable: false, + }); + }); + + it('builds an object schema', () => { + const schema = Schema.object({ + properties: { + someInput: Schema.string(), + }, + }); + + expect(schema.toJSON()).toEqual({ + type: 'object', + nullable: false, + properties: { + someInput: { + type: 'string', + nullable: false, + }, + }, + required: ['someInput'], + }); + }); + + it('builds an object schema with optional properties', () => { + const schema = Schema.object({ + properties: { + someInput: Schema.string(), + someBool: Schema.boolean(), + }, + optionalProperties: ['someBool'], + }); + expect(schema.toJSON()).toEqual({ + type: 'object', + nullable: false, + properties: { + someInput: { + type: 'string', + nullable: false, + }, + someBool: { + type: 'boolean', + nullable: false, + }, + }, + required: ['someInput'], + }); + }); + + it('builds layered schema - partially filled out', () => { + const schema = Schema.array({ + items: Schema.object({ + properties: { + country: Schema.string({ + description: 'A country name', + }), + population: Schema.integer(), + coordinates: Schema.object({ + properties: { + latitude: Schema.number({ format: 'float' }), + longitude: Schema.number({ format: 'double' }), + }, + }), + hemisphere: Schema.object({ + properties: { + latitudinal: Schema.enumString({ enum: ['N', 'S'] }), + longitudinal: Schema.enumString({ enum: ['E', 'W'] }), + }, + }), + isCapital: Schema.boolean(), + }, + }), + }); + expect(schema.toJSON()).toEqual(layeredSchemaOutputPartial); + }); + + it('builds layered schema - fully filled out', () => { + const schema = Schema.array({ + items: Schema.object({ + description: 'A country profile', + nullable: false, + properties: { + country: Schema.string({ + nullable: false, + description: 'Country name', + format: undefined, + }), + population: Schema.integer({ + nullable: false, + description: 'Number of people in country', + format: 'int64', + }), + coordinates: Schema.object({ + nullable: false, + description: 'Latitude and longitude', + properties: { + latitude: Schema.number({ + nullable: false, + description: 'Latitude of capital', + format: 'float', + }), + longitude: Schema.number({ + nullable: false, + description: 'Longitude of capital', + format: 'double', + }), + }, + }), + hemisphere: Schema.object({ + nullable: false, + description: 'Hemisphere(s) country is in', + properties: { + latitudinal: Schema.enumString({ enum: ['N', 'S'] }), + longitudinal: Schema.enumString({ enum: ['E', 'W'] }), + }, + }), + isCapital: Schema.boolean({ + nullable: false, + description: "This doesn't make a lot of sense but it's a demo", + }), + elevation: Schema.integer({ + nullable: false, + description: 'Average elevation', + format: 'float', + }), + }, + optionalProperties: [], + }), + }); + + expect(schema.toJSON()).toEqual(layeredSchemaOutput); + }); + + it('can override "nullable" and set optional properties', () => { + const schema = Schema.object({ + properties: { + country: Schema.string(), + elevation: Schema.number(), + population: Schema.integer({ nullable: true }), + }, + optionalProperties: ['elevation'], + }); + expect(schema.toJSON()).toEqual({ + type: 'object', + nullable: false, + properties: { + country: { + type: 'string', + nullable: false, + }, + elevation: { + type: 'number', + nullable: false, + }, + population: { + type: 'integer', + nullable: true, + }, + }, + required: ['country', 'population'], + }); + }); + + it('throws if an optionalProperties item does not exist', () => { + const schema = Schema.object({ + properties: { + country: Schema.string(), + elevation: Schema.number(), + population: Schema.integer({ nullable: true }), + }, + optionalProperties: ['cat'], + }); + expect(() => schema.toJSON()).toThrow(VertexAIErrorCode.INVALID_SCHEMA); + }); +}); + +const layeredSchemaOutputPartial = { + type: 'array', + nullable: false, + items: { + type: 'object', + nullable: false, + properties: { + country: { + type: 'string', + description: 'A country name', + nullable: false, + }, + population: { + type: 'integer', + nullable: false, + }, + coordinates: { + type: 'object', + nullable: false, + properties: { + latitude: { + type: 'number', + format: 'float', + nullable: false, + }, + longitude: { + type: 'number', + format: 'double', + nullable: false, + }, + }, + required: ['latitude', 'longitude'], + }, + hemisphere: { + type: 'object', + nullable: false, + properties: { + latitudinal: { + type: 'string', + nullable: false, + enum: ['N', 'S'], + }, + longitudinal: { + type: 'string', + nullable: false, + enum: ['E', 'W'], + }, + }, + required: ['latitudinal', 'longitudinal'], + }, + isCapital: { + type: 'boolean', + nullable: false, + }, + }, + required: ['country', 'population', 'coordinates', 'hemisphere', 'isCapital'], + }, +}; + +const layeredSchemaOutput = { + type: 'array', + nullable: false, + items: { + type: 'object', + description: 'A country profile', + nullable: false, + required: ['country', 'population', 'coordinates', 'hemisphere', 'isCapital', 'elevation'], + properties: { + country: { + type: 'string', + description: 'Country name', + nullable: false, + }, + population: { + type: 'integer', + format: 'int64', + description: 'Number of people in country', + nullable: false, + }, + coordinates: { + type: 'object', + description: 'Latitude and longitude', + nullable: false, + required: ['latitude', 'longitude'], + properties: { + latitude: { + type: 'number', + format: 'float', + description: 'Latitude of capital', + nullable: false, + }, + longitude: { + type: 'number', + format: 'double', + description: 'Longitude of capital', + nullable: false, + }, + }, + }, + hemisphere: { + type: 'object', + description: 'Hemisphere(s) country is in', + nullable: false, + required: ['latitudinal', 'longitudinal'], + properties: { + latitudinal: { + type: 'string', + nullable: false, + enum: ['N', 'S'], + }, + longitudinal: { + type: 'string', + nullable: false, + enum: ['E', 'W'], + }, + }, + }, + isCapital: { + type: 'boolean', + description: "This doesn't make a lot of sense but it's a demo", + nullable: false, + }, + elevation: { + type: 'integer', + format: 'float', + description: 'Average elevation', + nullable: false, + }, + }, + }, +}; diff --git a/packages/ai/__tests__/service.test.ts b/packages/ai/__tests__/service.test.ts new file mode 100644 index 0000000000..9f9503f2c9 --- /dev/null +++ b/packages/ai/__tests__/service.test.ts @@ -0,0 +1,45 @@ +/** + * @license + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import { describe, expect, it } from '@jest/globals'; +import { type ReactNativeFirebase } from '@react-native-firebase/app'; +import { DEFAULT_LOCATION } from '../lib/constants'; +import { VertexAIService } from '../lib/service'; + +const fakeApp = { + name: 'DEFAULT', + options: { + apiKey: 'key', + projectId: 'my-project', + }, +} as ReactNativeFirebase.FirebaseApp; + +describe('VertexAIService', () => { + it('uses default location if not specified', () => { + const vertexAI = new VertexAIService(fakeApp); + expect(vertexAI.location).toBe(DEFAULT_LOCATION); + }); + + it('uses custom location if specified', () => { + const vertexAI = new VertexAIService( + fakeApp, + /* authProvider */ undefined, + /* appCheckProvider */ undefined, + { location: 'somewhere' }, + ); + expect(vertexAI.location).toBe('somewhere'); + }); +}); diff --git a/packages/ai/__tests__/stream-reader.test.ts b/packages/ai/__tests__/stream-reader.test.ts new file mode 100644 index 0000000000..4a5ae8aef5 --- /dev/null +++ b/packages/ai/__tests__/stream-reader.test.ts @@ -0,0 +1,370 @@ +/** + * @license + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import { describe, expect, it, jest, afterEach, beforeAll } from '@jest/globals'; +import { ReadableStream } from 'web-streams-polyfill'; +import { + aggregateResponses, + getResponseStream, + processStream, +} from '../lib/requests/stream-reader'; + +import { getChunkedStream, getMockResponseStreaming } from './test-utils/mock-response'; +import { + BlockReason, + FinishReason, + GenerateContentResponse, + HarmCategory, + HarmProbability, + SafetyRating, + VertexAIErrorCode, +} from '../lib/types'; +import { VertexAIError } from '../lib/errors'; + +describe('stream-reader', () => { + describe('getResponseStream', () => { + afterEach(() => { + jest.restoreAllMocks(); + }); + + it('two lines', async () => { + const src = [{ text: 'A' }, { text: 'B' }]; + const inputStream = getChunkedStream( + src + .map(v => JSON.stringify(v)) + .map(v => 'data: ' + v + '\r\n\r\n') + .join(''), + ); + + const decodeStream = new ReadableStream({ + async start(controller) { + const reader = inputStream.getReader(); + const decoder = new TextDecoder('utf-8'); + while (true) { + const { done, value } = await reader.read(); + if (done) { + controller.close(); + break; + } + const decodedValue = decoder.decode(value, { stream: true }); + controller.enqueue(decodedValue); + } + }, + }); + + const responseStream = getResponseStream<{ text: string }>(decodeStream); + const reader = responseStream.getReader(); + const responses: Array<{ text: string }> = []; + while (true) { + const { done, value } = await reader.read(); + if (done) { + break; + } + responses.push(value); + } + expect(responses).toEqual(src); + }); + }); + + describe('processStream', () => { + afterEach(() => { + jest.restoreAllMocks(); + }); + + it('streaming response - short', async () => { + const fakeResponse = getMockResponseStreaming('streaming-success-basic-reply-short.txt'); + const result = processStream(fakeResponse as Response); + for await (const response of result.stream) { + expect(response.text()).not.toBe(''); + } + const aggregatedResponse = await result.response; + expect(aggregatedResponse.text()).toContain('Cheyenne'); + }); + + it('streaming response - functioncall', async () => { + const fakeResponse = getMockResponseStreaming('streaming-success-function-call-short.txt'); + const result = processStream(fakeResponse as Response); + for await (const response of result.stream) { + expect(response.text()).toBe(''); + expect(response.functionCalls()).toEqual([ + { + name: 'getTemperature', + args: { city: 'San Jose' }, + }, + ]); + } + const aggregatedResponse = await result.response; + expect(aggregatedResponse.text()).toBe(''); + expect(aggregatedResponse.functionCalls()).toEqual([ + { + name: 'getTemperature', + args: { city: 'San Jose' }, + }, + ]); + }); + + it('handles citations', async () => { + const fakeResponse = getMockResponseStreaming('streaming-success-citations.txt'); + const result = processStream(fakeResponse as Response); + const aggregatedResponse = await result.response; + expect(aggregatedResponse.text()).toContain('Quantum mechanics is'); + expect(aggregatedResponse.candidates?.[0]!.citationMetadata?.citations.length).toBe(3); + let foundCitationMetadata = false; + for await (const response of result.stream) { + expect(response.text()).not.toBe(''); + if (response.candidates?.[0]?.citationMetadata) { + foundCitationMetadata = true; + } + } + expect(foundCitationMetadata).toBe(true); + }); + + it('removes empty text parts', async () => { + const fakeResponse = getMockResponseStreaming('streaming-success-empty-text-part.txt'); + const result = processStream(fakeResponse as Response); + const aggregatedResponse = await result.response; + expect(aggregatedResponse.text()).toBe('1'); + expect(aggregatedResponse.candidates?.length).toBe(1); + expect(aggregatedResponse.candidates?.[0]?.content.parts.length).toBe(1); + + // The chunk with the empty text part will still go through the stream + let numChunks = 0; + for await (const _ of result.stream) { + numChunks++; + } + expect(numChunks).toBe(2); + }); + }); + + describe('aggregateResponses', () => { + it('handles no candidates, and promptFeedback', () => { + const responsesToAggregate: GenerateContentResponse[] = [ + { + promptFeedback: { + blockReason: BlockReason.SAFETY, + safetyRatings: [ + { + category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + probability: HarmProbability.LOW, + } as SafetyRating, + ], + }, + }, + ]; + const response = aggregateResponses(responsesToAggregate); + expect(response.candidates).toBeUndefined(); + expect(response.promptFeedback?.blockReason).toBe(BlockReason.SAFETY); + }); + + describe('multiple responses, has candidates', () => { + let response: GenerateContentResponse; + beforeAll(() => { + const responsesToAggregate: GenerateContentResponse[] = [ + { + candidates: [ + { + index: 0, + content: { + role: 'user', + parts: [{ text: 'hello.' }], + }, + finishReason: FinishReason.STOP, + finishMessage: 'something', + safetyRatings: [ + { + category: HarmCategory.HARM_CATEGORY_HARASSMENT, + probability: HarmProbability.NEGLIGIBLE, + } as SafetyRating, + ], + }, + ], + promptFeedback: { + blockReason: BlockReason.SAFETY, + safetyRatings: [ + { + category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + probability: HarmProbability.LOW, + } as SafetyRating, + ], + }, + }, + { + candidates: [ + { + index: 0, + content: { + role: 'user', + parts: [{ text: 'angry stuff' }], + }, + finishReason: FinishReason.STOP, + finishMessage: 'something', + safetyRatings: [ + { + category: HarmCategory.HARM_CATEGORY_HARASSMENT, + probability: HarmProbability.NEGLIGIBLE, + } as SafetyRating, + ], + citationMetadata: { + citations: [ + { + startIndex: 0, + endIndex: 20, + uri: 'sourceurl', + license: '', + }, + ], + }, + }, + ], + promptFeedback: { + blockReason: BlockReason.OTHER, + safetyRatings: [ + { + category: HarmCategory.HARM_CATEGORY_HATE_SPEECH, + probability: HarmProbability.HIGH, + } as SafetyRating, + ], + }, + }, + { + candidates: [ + { + index: 0, + content: { + role: 'user', + parts: [{ text: '...more stuff' }], + }, + finishReason: FinishReason.MAX_TOKENS, + finishMessage: 'too many tokens', + safetyRatings: [ + { + category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + probability: HarmProbability.MEDIUM, + } as SafetyRating, + ], + citationMetadata: { + citations: [ + { + startIndex: 0, + endIndex: 20, + uri: 'sourceurl', + license: '', + }, + { + startIndex: 150, + endIndex: 155, + uri: 'sourceurl', + license: '', + }, + ], + }, + }, + ], + promptFeedback: { + blockReason: BlockReason.OTHER, + safetyRatings: [ + { + category: HarmCategory.HARM_CATEGORY_HATE_SPEECH, + probability: HarmProbability.HIGH, + } as SafetyRating, + ], + }, + }, + ]; + response = aggregateResponses(responsesToAggregate); + }); + + it('aggregates text across responses', () => { + expect(response.candidates?.length).toBe(1); + expect(response.candidates?.[0]!.content.parts.map(({ text }) => text)).toEqual([ + 'hello.', + 'angry stuff', + '...more stuff', + ]); + }); + + it("takes the last response's promptFeedback", () => { + expect(response.promptFeedback?.blockReason).toBe(BlockReason.OTHER); + }); + + it("takes the last response's finishReason", () => { + expect(response.candidates?.[0]!.finishReason).toBe(FinishReason.MAX_TOKENS); + }); + + it("takes the last response's finishMessage", () => { + expect(response.candidates?.[0]!.finishMessage).toBe('too many tokens'); + }); + + it("takes the last response's candidate safetyRatings", () => { + expect(response.candidates?.[0]!.safetyRatings?.[0]!.category).toBe( + HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + ); + expect(response.candidates?.[0]!.safetyRatings?.[0]!.probability).toBe( + HarmProbability.MEDIUM, + ); + }); + + it('collects all citations into one array', () => { + expect(response.candidates?.[0]!.citationMetadata?.citations.length).toBe(2); + expect(response.candidates?.[0]!.citationMetadata?.citations[0]!.startIndex).toBe(0); + expect(response.candidates?.[0]!.citationMetadata?.citations[1]!.startIndex).toBe(150); + }); + + it('throws if a part has no properties', () => { + const responsesToAggregate: GenerateContentResponse[] = [ + { + candidates: [ + { + index: 0, + content: { + role: 'user', + parts: [{} as any], // Empty + }, + finishReason: FinishReason.STOP, + finishMessage: 'something', + safetyRatings: [ + { + category: HarmCategory.HARM_CATEGORY_HARASSMENT, + probability: HarmProbability.NEGLIGIBLE, + } as SafetyRating, + ], + }, + ], + promptFeedback: { + blockReason: BlockReason.SAFETY, + safetyRatings: [ + { + category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + probability: HarmProbability.LOW, + } as SafetyRating, + ], + }, + }, + ]; + + try { + aggregateResponses(responsesToAggregate); + } catch (e) { + expect((e as VertexAIError).code).toBe(VertexAIErrorCode.INVALID_CONTENT); + expect((e as VertexAIError).message).toContain( + 'Part should have at least one property, but there are none. This is likely caused ' + + 'by a malformed response from the backend.', + ); + } + }); + }); + }); +}); diff --git a/packages/ai/__tests__/test-utils/convert-mocks.ts b/packages/ai/__tests__/test-utils/convert-mocks.ts new file mode 100644 index 0000000000..97a5ed75df --- /dev/null +++ b/packages/ai/__tests__/test-utils/convert-mocks.ts @@ -0,0 +1,67 @@ +/** + * @license + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// eslint-disable-next-line @typescript-eslint/no-require-imports +const fs = require('fs'); +// eslint-disable-next-line @typescript-eslint/no-require-imports +const { join } = require('path'); + +function findMockResponseDir(): string { + const directories = fs + .readdirSync(__dirname, { withFileTypes: true }) + .filter( + (dirent: any) => dirent.isDirectory() && dirent.name.startsWith('vertexai-sdk-test-data'), + ) + .map((dirent: any) => dirent.name); + + if (directories.length === 0) { + throw new Error('No directory starting with "vertexai-sdk-test-data*" found.'); + } + + if (directories.length > 1) { + throw new Error('Multiple directories starting with "vertexai-sdk-test-data*" found'); + } + + return join(__dirname, directories[0], 'mock-responses', 'vertexai'); +} + +async function main(): Promise { + const mockResponseDir = findMockResponseDir(); + const list = fs.readdirSync(mockResponseDir); + const lookup: Record = {}; + // eslint-disable-next-line guard-for-in + for (const fileName of list) { + console.log(`attempting to read ${mockResponseDir}/${fileName}`) + const fullText = fs.readFileSync(join(mockResponseDir, fileName), 'utf-8'); + lookup[fileName] = fullText; + } + let fileText = `// Generated from mocks text files.`; + + fileText += '\n\n'; + fileText += `export const mocksLookup: Record = ${JSON.stringify( + lookup, + null, + 2, + )}`; + fileText += ';\n'; + fs.writeFileSync(join(__dirname, 'mocks-lookup.ts'), fileText, 'utf-8'); +} + +main().catch(e => { + console.error(e); + process.exit(1); +}); diff --git a/packages/ai/__tests__/test-utils/mock-response.ts b/packages/ai/__tests__/test-utils/mock-response.ts new file mode 100644 index 0000000000..52eb0eb04e --- /dev/null +++ b/packages/ai/__tests__/test-utils/mock-response.ts @@ -0,0 +1,69 @@ +/** + * @license + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import { ReadableStream } from 'web-streams-polyfill'; +import { mocksLookup } from './mocks-lookup'; + +/** + * Mock native Response.body + * Streams contents of json file in 20 character chunks + */ +export function getChunkedStream(input: string, chunkLength = 20): ReadableStream { + const encoder = new TextEncoder(); + let currentChunkStart = 0; + + const stream = new ReadableStream({ + start(controller) { + while (currentChunkStart < input.length) { + const substring = input.slice(currentChunkStart, currentChunkStart + chunkLength); + currentChunkStart += chunkLength; + const chunk = encoder.encode(substring); + controller.enqueue(chunk); + } + controller.close(); + }, + }); + + return stream; +} +export function getMockResponseStreaming( + filename: string, + chunkLength: number = 20, +): Partial { + const fullText = mocksLookup[filename]; + + return { + + // Really tangled typescript error here from our transitive dependencies. + // Ignoring it now, but uncomment and run `yarn lerna:prepare` in top-level + // of the repo to see if you get it or if it has gone away. + // + // last stack frame of the error is from node_modules/undici-types/fetch.d.ts + // + // > Property 'value' is optional in type 'ReadableStreamReadDoneResult' but required in type '{ done: true; value: T | undefined; }'. + // + // @ts-ignore + body: getChunkedStream(fullText!, chunkLength), + }; +} + +export function getMockResponse(filename: string): Partial { + const fullText = mocksLookup[filename]; + return { + ok: true, + json: () => Promise.resolve(JSON.parse(fullText!)), + }; +} diff --git a/packages/ai/e2e/fetch.e2e.js b/packages/ai/e2e/fetch.e2e.js new file mode 100644 index 0000000000..de8832ac56 --- /dev/null +++ b/packages/ai/e2e/fetch.e2e.js @@ -0,0 +1,67 @@ +/* + * Copyright (c) 2016-present Invertase Limited & Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this library except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ +import { getGenerativeModel } from '../lib/index'; + +const fakeVertexAI = { + app: { + name: 'DEFAULT', + options: { + appId: 'appId', + projectId: 'my-project', + apiKey: 'key', + }, + }, + location: 'us-central1', +}; +// See emulator setup: packages/vertexai/lib/requests/request.ts +globalThis.RNFB_VERTEXAI_EMULATOR_URL = true; + +// It calls firebase functions emulator that mimics responses from VertexAI server +describe('fetch requests()', function () { + it('should fetch', async function () { + const model = getGenerativeModel(fakeVertexAI, { model: 'gemini-1.5-flash' }); + const result = await model.generateContent("What is google's mission statement?"); + const text = result.response.text(); + // See vertexAI function emulator for response + text.should.containEql( + 'Google\'s mission is to "organize the world\'s information and make it universally accessible and useful."', + ); + }); + + it('should fetch stream', async function () { + const model = getGenerativeModel(fakeVertexAI, { model: 'gemini-1.5-flash' }); + // See vertexAI function emulator for response + const poem = [ + 'The wind whispers secrets through the trees,', + 'Rustling leaves in a gentle breeze.', + 'Sunlight dances on the grass,', + 'A fleeting moment, sure to pass.', + 'Birdsong fills the air so bright,', + 'A symphony of pure delight.', + 'Time stands still, a peaceful pause,', + "In nature's beauty, no flaws.", + ]; + const result = await model.generateContentStream('Write me a short poem'); + + const text = []; + for await (const chunk of result.stream) { + const chunkText = chunk.text(); + text.push(chunkText); + } + text.should.deepEqual(poem); + }); +}); diff --git a/packages/ai/lib/constants.ts b/packages/ai/lib/constants.ts new file mode 100644 index 0000000000..816f5194a2 --- /dev/null +++ b/packages/ai/lib/constants.ts @@ -0,0 +1,35 @@ +/** + * @license + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { version } from './version'; + +export const VERTEX_TYPE = 'vertexAI'; + +export const DEFAULT_LOCATION = 'us-central1'; + +export const DEFAULT_BASE_URL = 'https://firebasevertexai.googleapis.com'; + +// This is the default API version for the VertexAI API. At some point, should be able to change when the feature becomes available. +// `v1beta` & `stable` available: https://cloud.google.com/vertex-ai/docs/reference#versions +export const DEFAULT_API_VERSION = 'v1beta'; + +export const PACKAGE_VERSION = version; + +export const LANGUAGE_TAG = 'gl-rn'; + +// Timeout is 180s by default +export const DEFAULT_FETCH_TIMEOUT_MS = 180 * 1000; diff --git a/packages/ai/lib/errors.ts b/packages/ai/lib/errors.ts new file mode 100644 index 0000000000..370c19aeb0 --- /dev/null +++ b/packages/ai/lib/errors.ts @@ -0,0 +1,52 @@ +/** + * @license + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { FirebaseError } from '@firebase/util'; +import { VertexAIErrorCode, CustomErrorData } from './types'; +import { VERTEX_TYPE } from './constants'; + +/** + * Error class for the Vertex AI in Firebase SDK. + * + * @public + */ +export class VertexAIError extends FirebaseError { + /** + * Constructs a new instance of the `VertexAIError` class. + * + * @param code - The error code from {@link VertexAIErrorCode}. + * @param message - A human-readable message describing the error. + * @param customErrorData - Optional error data. + */ + constructor( + readonly code: VertexAIErrorCode, + message: string, + readonly customErrorData?: CustomErrorData, + ) { + // Match error format used by FirebaseError from ErrorFactory + const service = VERTEX_TYPE; + const serviceName = 'VertexAI'; + const fullCode = `${service}/${code}`; + const fullMessage = `${serviceName}: ${message} (${fullCode})`; + super(code, fullMessage); + + Object.setPrototypeOf(this, VertexAIError.prototype); + + // Since Error is an interface, we don't inherit toString and so we define it ourselves. + this.toString = () => fullMessage; + } +} diff --git a/packages/ai/lib/index.ts b/packages/ai/lib/index.ts new file mode 100644 index 0000000000..580f1bc86b --- /dev/null +++ b/packages/ai/lib/index.ts @@ -0,0 +1,73 @@ +/** + * @license + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import './polyfills'; +import { getApp, ReactNativeFirebase } from '@react-native-firebase/app'; +import { ModelParams, RequestOptions, VertexAIErrorCode } from './types'; +import { DEFAULT_LOCATION } from './constants'; +import { VertexAI, VertexAIOptions } from './public-types'; +import { VertexAIError } from './errors'; +import { GenerativeModel } from './models/generative-model'; +import { VertexAIService } from './service'; +export { ChatSession } from './methods/chat-session'; +export * from './requests/schema-builder'; + +export { GenerativeModel }; + +export { VertexAIError }; + +/** + * Returns a {@link VertexAI} instance for the given app. + * + * @public + * + * @param app - The {@link @FirebaseApp} to use. + * @param options - The {@link VertexAIOptions} to use. + * @param appCheck - The {@link @AppCheck} to use. + * @param auth - The {@link @Auth} to use. + */ +export function getVertexAI( + app: ReactNativeFirebase.FirebaseApp = getApp(), + options?: VertexAIOptions, +): VertexAI { + return { + app, + location: options?.location || DEFAULT_LOCATION, + appCheck: options?.appCheck || null, + auth: options?.auth || null, + } as VertexAIService; +} + +/** + * Returns a {@link GenerativeModel} class with methods for inference + * and other functionality. + * + * @public + */ +export function getGenerativeModel( + vertexAI: VertexAI, + modelParams: ModelParams, + requestOptions?: RequestOptions, +): GenerativeModel { + if (!modelParams.model) { + throw new VertexAIError( + VertexAIErrorCode.NO_MODEL, + `Must provide a model name. Example: getGenerativeModel({ model: 'my-model-name' })`, + ); + } + return new GenerativeModel(vertexAI, modelParams, requestOptions); +} diff --git a/packages/ai/lib/logger.ts b/packages/ai/lib/logger.ts new file mode 100644 index 0000000000..dbc3e84059 --- /dev/null +++ b/packages/ai/lib/logger.ts @@ -0,0 +1,20 @@ +/** + * @license + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +// @ts-ignore +import { Logger } from '@react-native-firebase/app/lib/internal/logger'; + +export const logger = new Logger('@firebase/vertexai'); diff --git a/packages/ai/lib/methods/chat-session-helpers.ts b/packages/ai/lib/methods/chat-session-helpers.ts new file mode 100644 index 0000000000..4b9bb56db0 --- /dev/null +++ b/packages/ai/lib/methods/chat-session-helpers.ts @@ -0,0 +1,116 @@ +/** + * @license + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { Content, POSSIBLE_ROLES, Part, Role, VertexAIErrorCode } from '../types'; +import { VertexAIError } from '../errors'; + +// https://ai.google.dev/api/rest/v1beta/Content#part + +const VALID_PART_FIELDS: Array = [ + 'text', + 'inlineData', + 'functionCall', + 'functionResponse', +]; + +const VALID_PARTS_PER_ROLE: { [key in Role]: Array } = { + user: ['text', 'inlineData'], + function: ['functionResponse'], + model: ['text', 'functionCall'], + // System instructions shouldn't be in history anyway. + system: ['text'], +}; + +const VALID_PREVIOUS_CONTENT_ROLES: { [key in Role]: Role[] } = { + user: ['model'], + function: ['model'], + model: ['user', 'function'], + // System instructions shouldn't be in history. + system: [], +}; + +export function validateChatHistory(history: Content[]): void { + let prevContent: Content | null = null; + for (const currContent of history) { + const { role, parts } = currContent; + if (!prevContent && role !== 'user') { + throw new VertexAIError( + VertexAIErrorCode.INVALID_CONTENT, + `First Content should be with role 'user', got ${role}`, + ); + } + if (!POSSIBLE_ROLES.includes(role)) { + throw new VertexAIError( + VertexAIErrorCode.INVALID_CONTENT, + `Each item should include role field. Got ${role} but valid roles are: ${JSON.stringify( + POSSIBLE_ROLES, + )}`, + ); + } + + if (!Array.isArray(parts)) { + throw new VertexAIError( + VertexAIErrorCode.INVALID_CONTENT, + `Content should have 'parts' but property with an array of Parts`, + ); + } + + if (parts.length === 0) { + throw new VertexAIError( + VertexAIErrorCode.INVALID_CONTENT, + `Each Content should have at least one part`, + ); + } + + const countFields: Record = { + text: 0, + inlineData: 0, + functionCall: 0, + functionResponse: 0, + }; + + for (const part of parts) { + for (const key of VALID_PART_FIELDS) { + if (key in part) { + countFields[key] += 1; + } + } + } + const validParts = VALID_PARTS_PER_ROLE[role]; + for (const key of VALID_PART_FIELDS) { + if (!validParts.includes(key) && countFields[key] > 0) { + throw new VertexAIError( + VertexAIErrorCode.INVALID_CONTENT, + `Content with role '${role}' can't contain '${key}' part`, + ); + } + } + + if (prevContent) { + const validPreviousContentRoles = VALID_PREVIOUS_CONTENT_ROLES[role]; + if (!validPreviousContentRoles.includes(prevContent.role)) { + throw new VertexAIError( + VertexAIErrorCode.INVALID_CONTENT, + `Content with role '${role} can't follow '${ + prevContent.role + }'. Valid previous roles: ${JSON.stringify(VALID_PREVIOUS_CONTENT_ROLES)}`, + ); + } + } + prevContent = currContent; + } +} diff --git a/packages/ai/lib/methods/chat-session.ts b/packages/ai/lib/methods/chat-session.ts new file mode 100644 index 0000000000..e3e9cf905f --- /dev/null +++ b/packages/ai/lib/methods/chat-session.ts @@ -0,0 +1,182 @@ +/** + * @license + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { + Content, + GenerateContentRequest, + GenerateContentResult, + GenerateContentStreamResult, + Part, + RequestOptions, + StartChatParams, + EnhancedGenerateContentResponse, +} from '../types'; +import { formatNewContent } from '../requests/request-helpers'; +import { formatBlockErrorMessage } from '../requests/response-helpers'; +import { validateChatHistory } from './chat-session-helpers'; +import { generateContent, generateContentStream } from './generate-content'; +import { ApiSettings } from '../types/internal'; +import { logger } from '../logger'; + +/** + * Do not log a message for this error. + */ +const SILENT_ERROR = 'SILENT_ERROR'; + +/** + * ChatSession class that enables sending chat messages and stores + * history of sent and received messages so far. + * + * @public + */ +export class ChatSession { + private _apiSettings: ApiSettings; + private _history: Content[] = []; + private _sendPromise: Promise = Promise.resolve(); + + constructor( + apiSettings: ApiSettings, + public model: string, + public params?: StartChatParams, + public requestOptions?: RequestOptions, + ) { + this._apiSettings = apiSettings; + if (params?.history) { + validateChatHistory(params.history); + this._history = params.history; + } + } + + /** + * Gets the chat history so far. Blocked prompts are not added to history. + * Neither blocked candidates nor the prompts that generated them are added + * to history. + */ + async getHistory(): Promise { + await this._sendPromise; + return this._history; + } + + /** + * Sends a chat message and receives a non-streaming + * {@link GenerateContentResult} + */ + async sendMessage(request: string | Array): Promise { + await this._sendPromise; + const newContent = formatNewContent(request); + const generateContentRequest: GenerateContentRequest = { + safetySettings: this.params?.safetySettings, + generationConfig: this.params?.generationConfig, + tools: this.params?.tools, + toolConfig: this.params?.toolConfig, + systemInstruction: this.params?.systemInstruction, + contents: [...this._history, newContent], + }; + let finalResult = {} as GenerateContentResult; + // Add onto the chain. + this._sendPromise = this._sendPromise + .then(() => + generateContent(this._apiSettings, this.model, generateContentRequest, this.requestOptions), + ) + .then((result: GenerateContentResult) => { + if (result.response.candidates && result.response.candidates.length > 0) { + this._history.push(newContent); + const responseContent: Content = { + parts: result.response.candidates?.[0]?.content.parts || [], + // Response seems to come back without a role set. + role: result.response.candidates?.[0]?.content.role || 'model', + }; + this._history.push(responseContent); + } else { + const blockErrorMessage = formatBlockErrorMessage(result.response); + if (blockErrorMessage) { + logger.warn( + `sendMessage() was unsuccessful. ${blockErrorMessage}. Inspect response object for details.`, + ); + } + } + finalResult = result; + }); + await this._sendPromise; + return finalResult; + } + + /** + * Sends a chat message and receives the response as a + * {@link GenerateContentStreamResult} containing an iterable stream + * and a response promise. + */ + async sendMessageStream( + request: string | Array, + ): Promise { + await this._sendPromise; + const newContent = formatNewContent(request); + const generateContentRequest: GenerateContentRequest = { + safetySettings: this.params?.safetySettings, + generationConfig: this.params?.generationConfig, + tools: this.params?.tools, + toolConfig: this.params?.toolConfig, + systemInstruction: this.params?.systemInstruction, + contents: [...this._history, newContent], + }; + const streamPromise = generateContentStream( + this._apiSettings, + this.model, + generateContentRequest, + this.requestOptions, + ); + + // Add onto the chain. + this._sendPromise = this._sendPromise + .then(() => streamPromise) + // This must be handled to avoid unhandled rejection, but jump + // to the final catch block with a label to not log this error. + .catch(_ignored => { + throw new Error(SILENT_ERROR); + }) + .then(streamResult => streamResult.response) + .then((response: EnhancedGenerateContentResponse) => { + if (response.candidates && response.candidates.length > 0) { + this._history.push(newContent); + const responseContent = { ...response.candidates[0]?.content }; + // Response seems to come back without a role set. + if (!responseContent.role) { + responseContent.role = 'model'; + } + this._history.push(responseContent as Content); + } else { + const blockErrorMessage = formatBlockErrorMessage(response); + if (blockErrorMessage) { + logger.warn( + `sendMessageStream() was unsuccessful. ${blockErrorMessage}. Inspect response object for details.`, + ); + } + } + }) + .catch(e => { + // Errors in streamPromise are already catchable by the user as + // streamPromise is returned. + // Avoid duplicating the error message in logs. + if (e.message !== SILENT_ERROR) { + // Users do not have access to _sendPromise to catch errors + // downstream from streamPromise, so they should not throw. + logger.error(e); + } + }); + return streamPromise; + } +} diff --git a/packages/ai/lib/methods/count-tokens.ts b/packages/ai/lib/methods/count-tokens.ts new file mode 100644 index 0000000000..10d41cffa8 --- /dev/null +++ b/packages/ai/lib/methods/count-tokens.ts @@ -0,0 +1,37 @@ +/** + * @license + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { CountTokensRequest, CountTokensResponse, RequestOptions } from '../types'; +import { Task, makeRequest } from '../requests/request'; +import { ApiSettings } from '../types/internal'; + +export async function countTokens( + apiSettings: ApiSettings, + model: string, + params: CountTokensRequest, + requestOptions?: RequestOptions, +): Promise { + const response = await makeRequest( + model, + Task.COUNT_TOKENS, + apiSettings, + false, + JSON.stringify(params), + requestOptions, + ); + return response.json(); +} diff --git a/packages/ai/lib/methods/generate-content.ts b/packages/ai/lib/methods/generate-content.ts new file mode 100644 index 0000000000..6d1a6ecb27 --- /dev/null +++ b/packages/ai/lib/methods/generate-content.ts @@ -0,0 +1,66 @@ +/** + * @license + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { + GenerateContentRequest, + GenerateContentResponse, + GenerateContentResult, + GenerateContentStreamResult, + RequestOptions, +} from '../types'; +import { Task, makeRequest } from '../requests/request'; +import { createEnhancedContentResponse } from '../requests/response-helpers'; +import { processStream } from '../requests/stream-reader'; +import { ApiSettings } from '../types/internal'; + +export async function generateContentStream( + apiSettings: ApiSettings, + model: string, + params: GenerateContentRequest, + requestOptions?: RequestOptions, +): Promise { + const response = await makeRequest( + model, + Task.STREAM_GENERATE_CONTENT, + apiSettings, + /* stream */ true, + JSON.stringify(params), + requestOptions, + ); + return processStream(response); +} + +export async function generateContent( + apiSettings: ApiSettings, + model: string, + params: GenerateContentRequest, + requestOptions?: RequestOptions, +): Promise { + const response = await makeRequest( + model, + Task.GENERATE_CONTENT, + apiSettings, + /* stream */ false, + JSON.stringify(params), + requestOptions, + ); + const responseJson: GenerateContentResponse = await response.json(); + const enhancedResponse = createEnhancedContentResponse(responseJson); + return { + response: enhancedResponse, + }; +} diff --git a/packages/ai/lib/models/generative-model.ts b/packages/ai/lib/models/generative-model.ts new file mode 100644 index 0000000000..111cefa427 --- /dev/null +++ b/packages/ai/lib/models/generative-model.ts @@ -0,0 +1,180 @@ +/** + * @license + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { generateContent, generateContentStream } from '../methods/generate-content'; +import { + Content, + CountTokensRequest, + CountTokensResponse, + GenerateContentRequest, + GenerateContentResult, + GenerateContentStreamResult, + GenerationConfig, + ModelParams, + Part, + RequestOptions, + SafetySetting, + StartChatParams, + Tool, + ToolConfig, + VertexAIErrorCode, +} from '../types'; +import { VertexAIError } from '../errors'; +import { ChatSession } from '../methods/chat-session'; +import { countTokens } from '../methods/count-tokens'; +import { formatGenerateContentInput, formatSystemInstruction } from '../requests/request-helpers'; +import { VertexAI } from '../public-types'; +import { ApiSettings } from '../types/internal'; +import { VertexAIService } from '../service'; + +/** + * Class for generative model APIs. + * @public + */ +export class GenerativeModel { + private _apiSettings: ApiSettings; + model: string; + generationConfig: GenerationConfig; + safetySettings: SafetySetting[]; + requestOptions?: RequestOptions; + tools?: Tool[]; + toolConfig?: ToolConfig; + systemInstruction?: Content; + + constructor(vertexAI: VertexAI, modelParams: ModelParams, requestOptions?: RequestOptions) { + if (!vertexAI.app?.options?.apiKey) { + throw new VertexAIError( + VertexAIErrorCode.NO_API_KEY, + `The "apiKey" field is empty in the local Firebase config. Firebase VertexAI requires this field to contain a valid API key.`, + ); + } else if (!vertexAI.app?.options?.projectId) { + throw new VertexAIError( + VertexAIErrorCode.NO_PROJECT_ID, + `The "projectId" field is empty in the local Firebase config. Firebase VertexAI requires this field to contain a valid project ID.`, + ); + } else { + this._apiSettings = { + apiKey: vertexAI.app.options.apiKey, + project: vertexAI.app.options.projectId, + location: vertexAI.location, + }; + if ((vertexAI as VertexAIService).appCheck) { + this._apiSettings.getAppCheckToken = () => + (vertexAI as VertexAIService).appCheck!.getToken(); + } + + if ((vertexAI as VertexAIService).auth?.currentUser) { + this._apiSettings.getAuthToken = () => + (vertexAI as VertexAIService).auth!.currentUser!.getIdToken(); + } + } + if (modelParams.model.includes('/')) { + if (modelParams.model.startsWith('models/')) { + // Add "publishers/google" if the user is only passing in 'models/model-name'. + this.model = `publishers/google/${modelParams.model}`; + } else { + // Any other custom format (e.g. tuned models) must be passed in correctly. + this.model = modelParams.model; + } + } else { + // If path is not included, assume it's a non-tuned model. + this.model = `publishers/google/models/${modelParams.model}`; + } + this.generationConfig = modelParams.generationConfig || {}; + this.safetySettings = modelParams.safetySettings || []; + this.tools = modelParams.tools; + this.toolConfig = modelParams.toolConfig; + this.systemInstruction = formatSystemInstruction(modelParams.systemInstruction); + this.requestOptions = requestOptions || {}; + } + + /** + * Makes a single non-streaming call to the model + * and returns an object containing a single {@link GenerateContentResponse}. + */ + async generateContent( + request: GenerateContentRequest | string | Array, + ): Promise { + const formattedParams = formatGenerateContentInput(request); + return generateContent( + this._apiSettings, + this.model, + { + generationConfig: this.generationConfig, + safetySettings: this.safetySettings, + tools: this.tools, + toolConfig: this.toolConfig, + systemInstruction: this.systemInstruction, + ...formattedParams, + }, + this.requestOptions, + ); + } + + /** + * Makes a single streaming call to the model + * and returns an object containing an iterable stream that iterates + * over all chunks in the streaming response as well as + * a promise that returns the final aggregated response. + */ + async generateContentStream( + request: GenerateContentRequest | string | Array, + ): Promise { + const formattedParams = formatGenerateContentInput(request); + return generateContentStream( + this._apiSettings, + this.model, + { + generationConfig: this.generationConfig, + safetySettings: this.safetySettings, + tools: this.tools, + toolConfig: this.toolConfig, + systemInstruction: this.systemInstruction, + ...formattedParams, + }, + this.requestOptions, + ); + } + + /** + * Gets a new {@link ChatSession} instance which can be used for + * multi-turn chats. + */ + startChat(startChatParams?: StartChatParams): ChatSession { + return new ChatSession( + this._apiSettings, + this.model, + { + tools: this.tools, + toolConfig: this.toolConfig, + systemInstruction: this.systemInstruction, + ...startChatParams, + }, + this.requestOptions, + ); + } + + /** + * Counts the tokens in the provided request. + */ + async countTokens( + request: CountTokensRequest | string | Array, + ): Promise { + const formattedParams = formatGenerateContentInput(request); + return countTokens(this._apiSettings, this.model, formattedParams); + } +} diff --git a/packages/ai/lib/polyfills.ts b/packages/ai/lib/polyfills.ts new file mode 100644 index 0000000000..cbe2cfecb0 --- /dev/null +++ b/packages/ai/lib/polyfills.ts @@ -0,0 +1,35 @@ +/** + * @license + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +// @ts-ignore +import { polyfillGlobal } from 'react-native/Libraries/Utilities/PolyfillFunctions'; +// @ts-ignore +import { ReadableStream } from 'web-streams-polyfill/dist/ponyfill'; +// @ts-ignore +import { fetch, Headers, Request, Response } from 'react-native-fetch-api'; + +polyfillGlobal( + 'fetch', + () => + (...args: any[]) => + fetch(args[0], { ...args[1], reactNative: { textStreaming: true } }), +); +polyfillGlobal('Headers', () => Headers); +polyfillGlobal('Request', () => Request); +polyfillGlobal('Response', () => Response); +polyfillGlobal('ReadableStream', () => ReadableStream); + +import 'text-encoding'; diff --git a/packages/ai/lib/public-types.ts b/packages/ai/lib/public-types.ts new file mode 100644 index 0000000000..24c6be6efa --- /dev/null +++ b/packages/ai/lib/public-types.ts @@ -0,0 +1,46 @@ +/** + * @license + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { ReactNativeFirebase } from '@react-native-firebase/app'; +import { FirebaseAuthTypes } from '@react-native-firebase/auth'; +import { FirebaseAppCheckTypes } from '@react-native-firebase/app-check'; + +export * from './types'; + +/** + * An instance of the Vertex AI in Firebase SDK. + * @public + */ +export interface VertexAI { + /** + * The {@link @firebase/app#FirebaseApp} this {@link VertexAI} instance is associated with. + */ + app: ReactNativeFirebase.FirebaseApp; + location: string; + appCheck?: FirebaseAppCheckTypes.Module | null; + auth?: FirebaseAuthTypes.Module | null; +} + +/** + * Options when initializing the Vertex AI in Firebase SDK. + * @public + */ +export interface VertexAIOptions { + location?: string; + appCheck?: FirebaseAppCheckTypes.Module | null; + auth?: FirebaseAuthTypes.Module | null; +} diff --git a/packages/ai/lib/requests/request-helpers.ts b/packages/ai/lib/requests/request-helpers.ts new file mode 100644 index 0000000000..9de045a4ee --- /dev/null +++ b/packages/ai/lib/requests/request-helpers.ts @@ -0,0 +1,116 @@ +/** + * @license + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { Content, GenerateContentRequest, Part, VertexAIErrorCode } from '../types'; +import { VertexAIError } from '../errors'; + +export function formatSystemInstruction(input?: string | Part | Content): Content | undefined { + if (input == null) { + return undefined; + } else if (typeof input === 'string') { + return { role: 'system', parts: [{ text: input }] } as Content; + } else if ((input as Part).text) { + return { role: 'system', parts: [input as Part] }; + } else if ((input as Content).parts) { + if (!(input as Content).role) { + return { role: 'system', parts: (input as Content).parts }; + } else { + return input as Content; + } + } + + return undefined; +} + +export function formatNewContent(request: string | Array): Content { + let newParts: Part[] = []; + if (typeof request === 'string') { + newParts = [{ text: request }]; + } else { + for (const partOrString of request) { + if (typeof partOrString === 'string') { + newParts.push({ text: partOrString }); + } else { + newParts.push(partOrString); + } + } + } + return assignRoleToPartsAndValidateSendMessageRequest(newParts); +} + +/** + * When multiple Part types (i.e. FunctionResponsePart and TextPart) are + * passed in a single Part array, we may need to assign different roles to each + * part. Currently only FunctionResponsePart requires a role other than 'user'. + * @private + * @param parts Array of parts to pass to the model + * @returns Array of content items + */ +function assignRoleToPartsAndValidateSendMessageRequest(parts: Part[]): Content { + const userContent: Content = { role: 'user', parts: [] }; + const functionContent: Content = { role: 'function', parts: [] }; + let hasUserContent = false; + let hasFunctionContent = false; + for (const part of parts) { + if ('functionResponse' in part) { + functionContent.parts.push(part); + hasFunctionContent = true; + } else { + userContent.parts.push(part); + hasUserContent = true; + } + } + + if (hasUserContent && hasFunctionContent) { + throw new VertexAIError( + VertexAIErrorCode.INVALID_CONTENT, + 'Within a single message, FunctionResponse cannot be mixed with other type of Part in the request for sending chat message.', + ); + } + + if (!hasUserContent && !hasFunctionContent) { + throw new VertexAIError( + VertexAIErrorCode.INVALID_CONTENT, + 'No Content is provided for sending chat message.', + ); + } + + if (hasUserContent) { + return userContent; + } + + return functionContent; +} + +export function formatGenerateContentInput( + params: GenerateContentRequest | string | Array, +): GenerateContentRequest { + let formattedRequest: GenerateContentRequest; + if ((params as GenerateContentRequest).contents) { + formattedRequest = params as GenerateContentRequest; + } else { + // Array or string + const content = formatNewContent(params as string | Array); + formattedRequest = { contents: [content] }; + } + if ((params as GenerateContentRequest).systemInstruction) { + formattedRequest.systemInstruction = formatSystemInstruction( + (params as GenerateContentRequest).systemInstruction, + ); + } + return formattedRequest; +} diff --git a/packages/ai/lib/requests/request.ts b/packages/ai/lib/requests/request.ts new file mode 100644 index 0000000000..e055094f90 --- /dev/null +++ b/packages/ai/lib/requests/request.ts @@ -0,0 +1,242 @@ +/** + * @license + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import { Platform } from 'react-native'; +import { ErrorDetails, RequestOptions, VertexAIErrorCode } from '../types'; +import { VertexAIError } from '../errors'; +import { ApiSettings } from '../types/internal'; +import { + DEFAULT_API_VERSION, + DEFAULT_BASE_URL, + DEFAULT_FETCH_TIMEOUT_MS, + LANGUAGE_TAG, + PACKAGE_VERSION, +} from '../constants'; +import { logger } from '../logger'; + +export enum Task { + GENERATE_CONTENT = 'generateContent', + STREAM_GENERATE_CONTENT = 'streamGenerateContent', + COUNT_TOKENS = 'countTokens', +} + +export class RequestUrl { + constructor( + public model: string, + public task: Task, + public apiSettings: ApiSettings, + public stream: boolean, + public requestOptions?: RequestOptions, + ) {} + toString(): string { + // @ts-ignore + const isTestEnvironment = globalThis.RNFB_VERTEXAI_EMULATOR_URL; + if (isTestEnvironment) { + let emulatorUrl; + logger.info( + 'Running VertexAI in test environment, pointing to Firebase Functions emulator URL', + ); + const isAndroid = Platform.OS === 'android'; + + if (this.stream) { + emulatorUrl = `http://${isAndroid ? '10.0.2.2' : '127.0.0.1'}:5001/react-native-firebase-testing/us-central1/testFetchStream`; + } else { + emulatorUrl = `http://${isAndroid ? '10.0.2.2' : '127.0.0.1'}:5001/react-native-firebase-testing/us-central1/testFetch`; + } + return emulatorUrl; + } + + const apiVersion = DEFAULT_API_VERSION; + const baseUrl = this.requestOptions?.baseUrl || DEFAULT_BASE_URL; + let url = `${baseUrl}/${apiVersion}`; + url += `/projects/${this.apiSettings.project}`; + url += `/locations/${this.apiSettings.location}`; + url += `/${this.model}`; + url += `:${this.task}`; + if (this.stream) { + url += '?alt=sse'; + } + return url; + } + + /** + * If the model needs to be passed to the backend, it needs to + * include project and location path. + */ + get fullModelString(): string { + let modelString = `projects/${this.apiSettings.project}`; + modelString += `/locations/${this.apiSettings.location}`; + modelString += `/${this.model}`; + return modelString; + } +} + +/** + * Log language and "fire/version" to x-goog-api-client + */ +function getClientHeaders(): string { + const loggingTags = []; + loggingTags.push(`${LANGUAGE_TAG}/${PACKAGE_VERSION}`); + loggingTags.push(`fire/${PACKAGE_VERSION}`); + return loggingTags.join(' '); +} + +export async function getHeaders(url: RequestUrl): Promise { + const headers = new Headers(); + headers.append('Content-Type', 'application/json'); + headers.append('x-goog-api-client', getClientHeaders()); + headers.append('x-goog-api-key', url.apiSettings.apiKey); + if (url.apiSettings.getAppCheckToken) { + let appCheckToken; + + try { + appCheckToken = await url.apiSettings.getAppCheckToken(); + } catch (e) { + logger.warn(`Unable to obtain a valid App Check token: ${e}`); + } + if (appCheckToken) { + headers.append('X-Firebase-AppCheck', appCheckToken.token); + } + } + + if (url.apiSettings.getAuthToken) { + const authToken = await url.apiSettings.getAuthToken(); + if (authToken) { + headers.append('Authorization', `Firebase ${authToken}`); + } + } + + return headers; +} + +export async function constructRequest( + model: string, + task: Task, + apiSettings: ApiSettings, + stream: boolean, + body: string, + requestOptions?: RequestOptions, +): Promise<{ url: string; fetchOptions: RequestInit }> { + const url = new RequestUrl(model, task, apiSettings, stream, requestOptions); + return { + url: url.toString(), + fetchOptions: { + method: 'POST', + headers: await getHeaders(url), + body, + }, + }; +} + +export async function makeRequest( + model: string, + task: Task, + apiSettings: ApiSettings, + stream: boolean, + body: string, + requestOptions?: RequestOptions, +): Promise { + const url = new RequestUrl(model, task, apiSettings, stream, requestOptions); + let response; + let fetchTimeoutId: string | number | NodeJS.Timeout | undefined; + try { + const request = await constructRequest(model, task, apiSettings, stream, body, requestOptions); + const timeoutMillis = + requestOptions?.timeout != null && requestOptions.timeout >= 0 + ? requestOptions.timeout + : DEFAULT_FETCH_TIMEOUT_MS; + const abortController = new AbortController(); + fetchTimeoutId = setTimeout(() => abortController.abort(), timeoutMillis); + request.fetchOptions.signal = abortController.signal; + const fetchOptions = stream + ? { + ...request.fetchOptions, + reactNative: { + textStreaming: true, + }, + } + : request.fetchOptions; + response = await fetch(request.url, fetchOptions); + if (!response.ok) { + let message = ''; + let errorDetails; + try { + const json = await response.json(); + message = json.error.message; + if (json.error.details) { + message += ` ${JSON.stringify(json.error.details)}`; + errorDetails = json.error.details; + } + } catch (_) { + // ignored + } + if ( + response.status === 403 && + errorDetails.some((detail: ErrorDetails) => detail.reason === 'SERVICE_DISABLED') && + errorDetails.some((detail: ErrorDetails) => + (detail.links as Array>)?.[0]?.description?.includes( + 'Google developers console API activation', + ), + ) + ) { + throw new VertexAIError( + VertexAIErrorCode.API_NOT_ENABLED, + `The Vertex AI in Firebase SDK requires the Vertex AI in Firebase ` + + `API ('firebasevertexai.googleapis.com') to be enabled in your ` + + `Firebase project. Enable this API by visiting the Firebase Console ` + + `at https://console.firebase.google.com/project/${url.apiSettings.project}/genai/ ` + + `and clicking "Get started". If you enabled this API recently, ` + + `wait a few minutes for the action to propagate to our systems and ` + + `then retry.`, + { + status: response.status, + statusText: response.statusText, + errorDetails, + }, + ); + } + throw new VertexAIError( + VertexAIErrorCode.FETCH_ERROR, + `Error fetching from ${url}: [${response.status} ${response.statusText}] ${message}`, + { + status: response.status, + statusText: response.statusText, + errorDetails, + }, + ); + } + } catch (e) { + let err = e as Error; + if ( + (e as VertexAIError).code !== VertexAIErrorCode.FETCH_ERROR && + (e as VertexAIError).code !== VertexAIErrorCode.API_NOT_ENABLED && + e instanceof Error + ) { + err = new VertexAIError( + VertexAIErrorCode.ERROR, + `Error fetching from ${url.toString()}: ${e.message}`, + ); + err.stack = e.stack; + } + + throw err; + } finally { + if (fetchTimeoutId) { + clearTimeout(fetchTimeoutId); + } + } + return response; +} diff --git a/packages/ai/lib/requests/response-helpers.ts b/packages/ai/lib/requests/response-helpers.ts new file mode 100644 index 0000000000..c7abc9d923 --- /dev/null +++ b/packages/ai/lib/requests/response-helpers.ts @@ -0,0 +1,186 @@ +/** + * @license + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { + EnhancedGenerateContentResponse, + FinishReason, + FunctionCall, + GenerateContentCandidate, + GenerateContentResponse, + VertexAIErrorCode, +} from '../types'; +import { VertexAIError } from '../errors'; +import { logger } from '../logger'; + +/** + * Creates an EnhancedGenerateContentResponse object that has helper functions and + * other modifications that improve usability. + */ +export function createEnhancedContentResponse( + response: GenerateContentResponse, +): EnhancedGenerateContentResponse { + /** + * The Vertex AI backend omits default values. + * This causes the `index` property to be omitted from the first candidate in the + * response, since it has index 0, and 0 is a default value. + * See: https://github.com/firebase/firebase-js-sdk/issues/8566 + */ + if (response.candidates && !response.candidates[0]?.hasOwnProperty('index')) { + response.candidates[0]!.index = 0; + } + + const responseWithHelpers = addHelpers(response); + return responseWithHelpers; +} + +/** + * Adds convenience helper methods to a response object, including stream + * chunks (as long as each chunk is a complete GenerateContentResponse JSON). + */ +export function addHelpers(response: GenerateContentResponse): EnhancedGenerateContentResponse { + (response as EnhancedGenerateContentResponse).text = () => { + if (response.candidates && response.candidates.length > 0) { + if (response.candidates.length > 1) { + logger.warn( + `This response had ${response.candidates.length} ` + + `candidates. Returning text from the first candidate only. ` + + `Access response.candidates directly to use the other candidates.`, + ); + } + if (hadBadFinishReason(response.candidates[0]!)) { + throw new VertexAIError( + VertexAIErrorCode.RESPONSE_ERROR, + `Response error: ${formatBlockErrorMessage( + response, + )}. Response body stored in error.response`, + { + response, + }, + ); + } + return getText(response); + } else if (response.promptFeedback) { + throw new VertexAIError( + VertexAIErrorCode.RESPONSE_ERROR, + `Text not available. ${formatBlockErrorMessage(response)}`, + { + response, + }, + ); + } + return ''; + }; + (response as EnhancedGenerateContentResponse).functionCalls = () => { + if (response.candidates && response.candidates.length > 0) { + if (response.candidates.length > 1) { + logger.warn( + `This response had ${response.candidates.length} ` + + `candidates. Returning function calls from the first candidate only. ` + + `Access response.candidates directly to use the other candidates.`, + ); + } + if (hadBadFinishReason(response.candidates[0]!)) { + throw new VertexAIError( + VertexAIErrorCode.RESPONSE_ERROR, + `Response error: ${formatBlockErrorMessage( + response, + )}. Response body stored in error.response`, + { + response, + }, + ); + } + return getFunctionCalls(response); + } else if (response.promptFeedback) { + throw new VertexAIError( + VertexAIErrorCode.RESPONSE_ERROR, + `Function call not available. ${formatBlockErrorMessage(response)}`, + { + response, + }, + ); + } + return undefined; + }; + return response as EnhancedGenerateContentResponse; +} + +/** + * Returns all text found in all parts of first candidate. + */ +export function getText(response: GenerateContentResponse): string { + const textStrings = []; + if (response.candidates?.[0]?.content?.parts) { + for (const part of response.candidates?.[0].content?.parts) { + if (part.text) { + textStrings.push(part.text); + } + } + } + if (textStrings.length > 0) { + return textStrings.join(''); + } else { + return ''; + } +} + +/** + * Returns {@link FunctionCall}s associated with first candidate. + */ +export function getFunctionCalls(response: GenerateContentResponse): FunctionCall[] | undefined { + const functionCalls: FunctionCall[] = []; + if (response.candidates?.[0]?.content?.parts) { + for (const part of response.candidates?.[0].content?.parts) { + if (part.functionCall) { + functionCalls.push(part.functionCall); + } + } + } + if (functionCalls.length > 0) { + return functionCalls; + } else { + return undefined; + } +} + +const badFinishReasons = [FinishReason.RECITATION, FinishReason.SAFETY]; + +function hadBadFinishReason(candidate: GenerateContentCandidate): boolean { + return !!candidate.finishReason && badFinishReasons.includes(candidate.finishReason); +} + +export function formatBlockErrorMessage(response: GenerateContentResponse): string { + let message = ''; + if ((!response.candidates || response.candidates.length === 0) && response.promptFeedback) { + message += 'Response was blocked'; + if (response.promptFeedback?.blockReason) { + message += ` due to ${response.promptFeedback.blockReason}`; + } + if (response.promptFeedback?.blockReasonMessage) { + message += `: ${response.promptFeedback.blockReasonMessage}`; + } + } else if (response.candidates?.[0]) { + const firstCandidate = response.candidates[0]; + if (hadBadFinishReason(firstCandidate)) { + message += `Candidate was blocked due to ${firstCandidate.finishReason}`; + if (firstCandidate.finishMessage) { + message += `: ${firstCandidate.finishMessage}`; + } + } + } + return message; +} diff --git a/packages/ai/lib/requests/schema-builder.ts b/packages/ai/lib/requests/schema-builder.ts new file mode 100644 index 0000000000..92003a0950 --- /dev/null +++ b/packages/ai/lib/requests/schema-builder.ts @@ -0,0 +1,281 @@ +/** + * @license + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { VertexAIError } from '../errors'; +import { VertexAIErrorCode } from '../types'; +import { + SchemaInterface, + SchemaType, + SchemaParams, + SchemaRequest, + ObjectSchemaInterface, +} from '../types/schema'; + +/** + * Parent class encompassing all Schema types, with static methods that + * allow building specific Schema types. This class can be converted with + * `JSON.stringify()` into a JSON string accepted by Vertex AI REST endpoints. + * (This string conversion is automatically done when calling SDK methods.) + * @public + */ +export abstract class Schema implements SchemaInterface { + /** + * Optional. The type of the property. {@link + * SchemaType}. + */ + type: SchemaType; + /** Optional. The format of the property. + * Supported formats:
+ *
    + *
  • for NUMBER type: "float", "double"
  • + *
  • for INTEGER type: "int32", "int64"
  • + *
  • for STRING type: "email", "byte", etc
  • + *
+ */ + format?: string; + /** Optional. The description of the property. */ + description?: string; + /** Optional. Whether the property is nullable. Defaults to false. */ + nullable: boolean; + /** Optional. The example of the property. */ + example?: unknown; + /** + * Allows user to add other schema properties that have not yet + * been officially added to the SDK. + */ + [key: string]: unknown; + + constructor(schemaParams: SchemaInterface) { + for (const paramKey in schemaParams) { + this[paramKey] = schemaParams[paramKey]; + } + // Ensure these are explicitly set to avoid TS errors. + this.type = schemaParams.type; + this.nullable = schemaParams.hasOwnProperty('nullable') ? !!schemaParams.nullable : false; + } + + /** + * Defines how this Schema should be serialized as JSON. + * See https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/JSON/stringify#tojson_behavior + * @internal + */ + toJSON(): SchemaRequest { + const obj: { type: SchemaType; [key: string]: unknown } = { + type: this.type, + }; + for (const prop in this) { + if (this.hasOwnProperty(prop) && this[prop] !== undefined) { + if (prop !== 'required' || this.type === SchemaType.OBJECT) { + obj[prop] = this[prop]; + } + } + } + return obj as SchemaRequest; + } + + static array(arrayParams: SchemaParams & { items: Schema }): ArraySchema { + return new ArraySchema(arrayParams, arrayParams.items); + } + + static object( + objectParams: SchemaParams & { + properties: { + [k: string]: Schema; + }; + optionalProperties?: string[]; + }, + ): ObjectSchema { + return new ObjectSchema(objectParams, objectParams.properties, objectParams.optionalProperties); + } + + static string(stringParams?: SchemaParams): StringSchema { + return new StringSchema(stringParams); + } + + static enumString(stringParams: SchemaParams & { enum: string[] }): StringSchema { + return new StringSchema(stringParams, stringParams.enum); + } + + static integer(integerParams?: SchemaParams): IntegerSchema { + return new IntegerSchema(integerParams); + } + + static number(numberParams?: SchemaParams): NumberSchema { + return new NumberSchema(numberParams); + } + + static boolean(booleanParams?: SchemaParams): BooleanSchema { + return new BooleanSchema(booleanParams); + } +} + +/** + * A type that includes all specific Schema types. + * @public + */ +export type TypedSchema = + | IntegerSchema + | NumberSchema + | StringSchema + | BooleanSchema + | ObjectSchema + | ArraySchema; + +/** + * Schema class for "integer" types. + * @public + */ +export class IntegerSchema extends Schema { + constructor(schemaParams?: SchemaParams) { + super({ + type: SchemaType.INTEGER, + ...schemaParams, + }); + } +} + +/** + * Schema class for "number" types. + * @public + */ +export class NumberSchema extends Schema { + constructor(schemaParams?: SchemaParams) { + super({ + type: SchemaType.NUMBER, + ...schemaParams, + }); + } +} + +/** + * Schema class for "boolean" types. + * @public + */ +export class BooleanSchema extends Schema { + constructor(schemaParams?: SchemaParams) { + super({ + type: SchemaType.BOOLEAN, + ...schemaParams, + }); + } +} + +/** + * Schema class for "string" types. Can be used with or without + * enum values. + * @public + */ +export class StringSchema extends Schema { + enum?: string[]; + constructor(schemaParams?: SchemaParams, enumValues?: string[]) { + super({ + type: SchemaType.STRING, + ...schemaParams, + }); + this.enum = enumValues; + } + + /** + * @internal + */ + toJSON(): SchemaRequest { + const obj = super.toJSON(); + if (this.enum) { + obj['enum'] = this.enum; + } + return obj as SchemaRequest; + } +} + +/** + * Schema class for "array" types. + * The `items` param should refer to the type of item that can be a member + * of the array. + * @public + */ +export class ArraySchema extends Schema { + constructor( + schemaParams: SchemaParams, + public items: TypedSchema, + ) { + super({ + type: SchemaType.ARRAY, + ...schemaParams, + }); + } + + /** + * @internal + */ + toJSON(): SchemaRequest { + const obj = super.toJSON(); + obj.items = this.items.toJSON(); + return obj; + } +} + +/** + * Schema class for "object" types. + * The `properties` param must be a map of `Schema` objects. + * @public + */ +export class ObjectSchema extends Schema { + constructor( + schemaParams: SchemaParams, + public properties: { + [k: string]: TypedSchema; + }, + public optionalProperties: string[] = [], + ) { + super({ + type: SchemaType.OBJECT, + ...schemaParams, + }); + } + + /** + * @internal + */ + toJSON(): SchemaRequest { + const obj = super.toJSON(); + obj.properties = { ...this.properties }; + const required = []; + if (this.optionalProperties) { + for (const propertyKey of this.optionalProperties) { + if (!this.properties.hasOwnProperty(propertyKey)) { + throw new VertexAIError( + VertexAIErrorCode.INVALID_SCHEMA, + `Property "${propertyKey}" specified in "optionalProperties" does not exist.`, + ); + } + } + } + for (const propertyKey in this.properties) { + if (this.properties.hasOwnProperty(propertyKey)) { + obj.properties[propertyKey] = this.properties[propertyKey]!.toJSON() as SchemaRequest; + if (!this.optionalProperties.includes(propertyKey)) { + required.push(propertyKey); + } + } + } + if (required.length > 0) { + obj.required = required; + } + delete (obj as ObjectSchemaInterface).optionalProperties; + return obj as SchemaRequest; + } +} diff --git a/packages/ai/lib/requests/stream-reader.ts b/packages/ai/lib/requests/stream-reader.ts new file mode 100644 index 0000000000..d24f6d44bf --- /dev/null +++ b/packages/ai/lib/requests/stream-reader.ts @@ -0,0 +1,213 @@ +/** + * @license + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { ReadableStream } from 'web-streams-polyfill'; +import { + EnhancedGenerateContentResponse, + GenerateContentCandidate, + GenerateContentResponse, + GenerateContentStreamResult, + Part, + VertexAIErrorCode, +} from '../types'; +import { VertexAIError } from '../errors'; +import { createEnhancedContentResponse } from './response-helpers'; + +const responseLineRE = /^data\: (.*)(?:\n\n|\r\r|\r\n\r\n)/; + +/** + * Process a response.body stream from the backend and return an + * iterator that provides one complete GenerateContentResponse at a time + * and a promise that resolves with a single aggregated + * GenerateContentResponse. + * + * @param response - Response from a fetch call + */ +export function processStream(response: Response): GenerateContentStreamResult { + const inputStream = new ReadableStream({ + async start(controller) { + const reader = response.body!.getReader(); + const decoder = new TextDecoder('utf-8'); + while (true) { + const { done, value } = await reader.read(); + if (done) { + controller.close(); + break; + } + const decodedValue = decoder.decode(value, { stream: true }); + controller.enqueue(decodedValue); + } + }, + }); + const responseStream = getResponseStream(inputStream); + const [stream1, stream2] = responseStream.tee(); + return { + stream: generateResponseSequence(stream1), + response: getResponsePromise(stream2), + }; +} + +async function getResponsePromise( + stream: ReadableStream, +): Promise { + const allResponses: GenerateContentResponse[] = []; + const reader = stream.getReader(); + while (true) { + const { done, value } = await reader.read(); + if (done) { + const enhancedResponse = createEnhancedContentResponse(aggregateResponses(allResponses)); + return enhancedResponse; + } + allResponses.push(value); + } +} + +async function* generateResponseSequence( + stream: ReadableStream, +): AsyncGenerator { + const reader = stream.getReader(); + while (true) { + const { value, done } = await reader.read(); + if (done) { + break; + } + + const enhancedResponse = createEnhancedContentResponse(value); + yield enhancedResponse; + } +} + +/** + * Reads a raw stream from the fetch response and join incomplete + * chunks, returning a new stream that provides a single complete + * GenerateContentResponse in each iteration. + */ +export function getResponseStream(inputStream: ReadableStream): ReadableStream { + const reader = inputStream.getReader(); + const stream = new ReadableStream({ + start(controller) { + let currentText = ''; + return pump().then(() => undefined); + function pump(): Promise<(() => Promise) | undefined> { + return reader.read().then(({ value, done }) => { + if (done) { + if (currentText.trim()) { + controller.error( + new VertexAIError(VertexAIErrorCode.PARSE_FAILED, 'Failed to parse stream'), + ); + return; + } + controller.close(); + return; + } + + currentText += value; + let match = currentText.match(responseLineRE); + let parsedResponse: T; + while (match) { + try { + parsedResponse = JSON.parse(match[1]!); + } catch (_) { + controller.error( + new VertexAIError( + VertexAIErrorCode.PARSE_FAILED, + `Error parsing JSON response: "${match[1]}`, + ), + ); + return; + } + controller.enqueue(parsedResponse); + currentText = currentText.substring(match[0].length); + match = currentText.match(responseLineRE); + } + return pump(); + }); + } + }, + }); + return stream; +} + +/** + * Aggregates an array of `GenerateContentResponse`s into a single + * GenerateContentResponse. + */ +export function aggregateResponses(responses: GenerateContentResponse[]): GenerateContentResponse { + const lastResponse = responses[responses.length - 1]; + const aggregatedResponse: GenerateContentResponse = { + promptFeedback: lastResponse?.promptFeedback, + }; + for (const response of responses) { + if (response.candidates) { + for (const candidate of response.candidates) { + // Index will be undefined if it's the first index (0), so we should use 0 if it's undefined. + // See: https://github.com/firebase/firebase-js-sdk/issues/8566 + const i = candidate.index || 0; + if (!aggregatedResponse.candidates) { + aggregatedResponse.candidates = []; + } + if (!aggregatedResponse.candidates[i]) { + aggregatedResponse.candidates[i] = { + index: candidate.index, + } as GenerateContentCandidate; + } + // Keep overwriting, the last one will be final + aggregatedResponse.candidates[i].citationMetadata = candidate.citationMetadata; + aggregatedResponse.candidates[i].finishReason = candidate.finishReason; + aggregatedResponse.candidates[i].finishMessage = candidate.finishMessage; + aggregatedResponse.candidates[i].safetyRatings = candidate.safetyRatings; + + /** + * Candidates should always have content and parts, but this handles + * possible malformed responses. + */ + if (candidate.content && candidate.content.parts) { + if (!aggregatedResponse.candidates[i].content) { + aggregatedResponse.candidates[i].content = { + role: candidate.content.role || 'user', + parts: [], + }; + } + const newPart: Partial = {}; + for (const part of candidate.content.parts) { + if (part.text !== undefined) { + // The backend can send empty text parts. If these are sent back + // (e.g. in chat history), the backend will respond with an error. + // To prevent this, ignore empty text parts. + if (part.text === '') { + continue; + } + newPart.text = part.text; + } + if (part.functionCall) { + newPart.functionCall = part.functionCall; + } + if (Object.keys(newPart).length === 0) { + throw new VertexAIError( + VertexAIErrorCode.INVALID_CONTENT, + 'Part should have at least one property, but there are none. This is likely caused ' + + 'by a malformed response from the backend.', + ); + } + aggregatedResponse.candidates[i].content.parts.push(newPart as Part); + } + } + } + } + } + return aggregatedResponse; +} diff --git a/packages/ai/lib/service.ts b/packages/ai/lib/service.ts new file mode 100644 index 0000000000..e90ffa9668 --- /dev/null +++ b/packages/ai/lib/service.ts @@ -0,0 +1,39 @@ +/** + * @license + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { ReactNativeFirebase } from '@react-native-firebase/app'; +import { VertexAI, VertexAIOptions } from './public-types'; +import { DEFAULT_LOCATION } from './constants'; +import { FirebaseAuthTypes } from '@react-native-firebase/auth'; +import { FirebaseAppCheckTypes } from '@react-native-firebase/app-check'; + +export class VertexAIService implements VertexAI { + auth: FirebaseAuthTypes.Module | null; + appCheck: FirebaseAppCheckTypes.Module | null; + location: string; + + constructor( + public app: ReactNativeFirebase.FirebaseApp, + auth?: FirebaseAuthTypes.Module, + appCheck?: FirebaseAppCheckTypes.Module, + public options?: VertexAIOptions, + ) { + this.auth = auth || null; + this.appCheck = appCheck || null; + this.location = this.options?.location || DEFAULT_LOCATION; + } +} diff --git a/packages/ai/lib/types/content.ts b/packages/ai/lib/types/content.ts new file mode 100644 index 0000000000..abf5d29222 --- /dev/null +++ b/packages/ai/lib/types/content.ts @@ -0,0 +1,162 @@ +/** + * @license + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { Role } from './enums'; + +/** + * Content type for both prompts and response candidates. + * @public + */ +export interface Content { + role: Role; + parts: Part[]; +} + +/** + * Content part - includes text, image/video, or function call/response + * part types. + * @public + */ +export type Part = + | TextPart + | InlineDataPart + | FunctionCallPart + | FunctionResponsePart + | FileDataPart; + +/** + * Content part interface if the part represents a text string. + * @public + */ +export interface TextPart { + text: string; + inlineData?: never; + functionCall?: never; + functionResponse?: never; +} + +/** + * Content part interface if the part represents an image. + * @public + */ +export interface InlineDataPart { + text?: never; + inlineData: GenerativeContentBlob; + functionCall?: never; + functionResponse?: never; + /** + * Applicable if `inlineData` is a video. + */ + videoMetadata?: VideoMetadata; +} + +/** + * Describes the input video content. + * @public + */ +export interface VideoMetadata { + /** + * The start offset of the video in + * protobuf {@link https://cloud.google.com/ruby/docs/reference/google-cloud-workflows-v1/latest/Google-Protobuf-Duration#json-mapping | Duration} format. + */ + startOffset: string; + /** + * The end offset of the video in + * protobuf {@link https://cloud.google.com/ruby/docs/reference/google-cloud-workflows-v1/latest/Google-Protobuf-Duration#json-mapping | Duration} format. + */ + endOffset: string; +} + +/** + * Content part interface if the part represents a {@link FunctionCall}. + * @public + */ +export interface FunctionCallPart { + text?: never; + inlineData?: never; + functionCall: FunctionCall; + functionResponse?: never; +} + +/** + * Content part interface if the part represents {@link FunctionResponse}. + * @public + */ +export interface FunctionResponsePart { + text?: never; + inlineData?: never; + functionCall?: never; + functionResponse: FunctionResponse; +} + +/** + * Content part interface if the part represents {@link FileData} + * @public + */ +export interface FileDataPart { + text?: never; + inlineData?: never; + functionCall?: never; + functionResponse?: never; + fileData: FileData; +} + +/** + * A predicted {@link FunctionCall} returned from the model + * that contains a string representing the {@link FunctionDeclaration.name} + * and a structured JSON object containing the parameters and their values. + * @public + */ +export interface FunctionCall { + name: string; + args: object; +} + +/** + * The result output from a {@link FunctionCall} that contains a string + * representing the {@link FunctionDeclaration.name} + * and a structured JSON object containing any output + * from the function is used as context to the model. + * This should contain the result of a {@link FunctionCall} + * made based on model prediction. + * @public + */ +export interface FunctionResponse { + name: string; + response: object; +} + +/** + * Interface for sending an image. + * @public + */ +export interface GenerativeContentBlob { + mimeType: string; + /** + * Image as a base64 string. + */ + data: string; +} + +/** + * Data pointing to a file uploaded on Google Cloud Storage. + * @public + */ +export interface FileData { + mimeType: string; + fileUri: string; +} diff --git a/packages/ai/lib/types/enums.ts b/packages/ai/lib/types/enums.ts new file mode 100644 index 0000000000..010aff903a --- /dev/null +++ b/packages/ai/lib/types/enums.ts @@ -0,0 +1,149 @@ +/** + * @license + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Role is the producer of the content. + * @public + */ +export type Role = (typeof POSSIBLE_ROLES)[number]; + +/** + * Possible roles. + * @public + */ +export const POSSIBLE_ROLES = ['user', 'model', 'function', 'system'] as const; + +/** + * Harm categories that would cause prompts or candidates to be blocked. + * @public + */ +export enum HarmCategory { + HARM_CATEGORY_HATE_SPEECH = 'HARM_CATEGORY_HATE_SPEECH', + HARM_CATEGORY_SEXUALLY_EXPLICIT = 'HARM_CATEGORY_SEXUALLY_EXPLICIT', + HARM_CATEGORY_HARASSMENT = 'HARM_CATEGORY_HARASSMENT', + HARM_CATEGORY_DANGEROUS_CONTENT = 'HARM_CATEGORY_DANGEROUS_CONTENT', +} + +/** + * Threshold above which a prompt or candidate will be blocked. + * @public + */ +export enum HarmBlockThreshold { + // Content with NEGLIGIBLE will be allowed. + BLOCK_LOW_AND_ABOVE = 'BLOCK_LOW_AND_ABOVE', + // Content with NEGLIGIBLE and LOW will be allowed. + BLOCK_MEDIUM_AND_ABOVE = 'BLOCK_MEDIUM_AND_ABOVE', + // Content with NEGLIGIBLE, LOW, and MEDIUM will be allowed. + BLOCK_ONLY_HIGH = 'BLOCK_ONLY_HIGH', + // All content will be allowed. + BLOCK_NONE = 'BLOCK_NONE', +} + +/** + * @public + */ +export enum HarmBlockMethod { + // The harm block method uses both probability and severity scores. + SEVERITY = 'SEVERITY', + // The harm block method uses the probability score. + PROBABILITY = 'PROBABILITY', +} + +/** + * Probability that a prompt or candidate matches a harm category. + * @public + */ +export enum HarmProbability { + // Content has a negligible chance of being unsafe. + NEGLIGIBLE = 'NEGLIGIBLE', + // Content has a low chance of being unsafe. + LOW = 'LOW', + // Content has a medium chance of being unsafe. + MEDIUM = 'MEDIUM', + // Content has a high chance of being unsafe. + HIGH = 'HIGH', +} + +/** + * Harm severity levels. + * @public + */ +export enum HarmSeverity { + // Negligible level of harm severity. + HARM_SEVERITY_NEGLIGIBLE = 'HARM_SEVERITY_NEGLIGIBLE', + // Low level of harm severity. + HARM_SEVERITY_LOW = 'HARM_SEVERITY_LOW', + // Medium level of harm severity. + HARM_SEVERITY_MEDIUM = 'HARM_SEVERITY_MEDIUM', + // High level of harm severity. + HARM_SEVERITY_HIGH = 'HARM_SEVERITY_HIGH', +} + +/** + * Reason that a prompt was blocked. + * @public + */ +export enum BlockReason { + // The prompt was blocked because it contained terms from the terminology blocklist. + BLOCKLIST = 'BLOCKLIST', + // The prompt was blocked due to prohibited content. + PROHIBITED_CONTENT = 'PROHIBITED_CONTENT', + // Content was blocked by safety settings. + SAFETY = 'SAFETY', + // Content was blocked, but the reason is uncategorized. + OTHER = 'OTHER', +} + +/** + * Reason that a candidate finished. + * @public + */ +export enum FinishReason { + // Token generation was stopped because the response contained forbidden terms. + BLOCKLIST = 'BLOCKLIST', + // Token generation was stopped because the response contained potentially prohibited content. + PROHIBITED_CONTENT = 'PROHIBITED_CONTENT', + // Token generation was stopped because of Sensitive Personally Identifiable Information (SPII). + SPII = 'SPII', + // Natural stop point of the model or provided stop sequence. + STOP = 'STOP', + // The maximum number of tokens as specified in the request was reached. + MAX_TOKENS = 'MAX_TOKENS', + // The candidate content was flagged for safety reasons. + SAFETY = 'SAFETY', + // The candidate content was flagged for recitation reasons. + RECITATION = 'RECITATION', + // Unknown reason. + OTHER = 'OTHER', +} + +/** + * @public + */ +export enum FunctionCallingMode { + // Default model behavior, model decides to predict either a function call + // or a natural language response. + AUTO = 'AUTO', + // Model is constrained to always predicting a function call only. + // If "allowed_function_names" is set, the predicted function call will be + // limited to any one of "allowed_function_names", else the predicted + // function call will be any one of the provided "function_declarations". + ANY = 'ANY', + // Model will not predict any function call. Model behavior is same as when + // not passing any function declarations. + NONE = 'NONE', +} diff --git a/packages/ai/lib/types/error.ts b/packages/ai/lib/types/error.ts new file mode 100644 index 0000000000..c65e09c55f --- /dev/null +++ b/packages/ai/lib/types/error.ts @@ -0,0 +1,98 @@ +/** + * @license + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { GenerateContentResponse } from './responses'; + +/** + * Details object that may be included in an error response. + * + * @public + */ +export interface ErrorDetails { + '@type'?: string; + + /** The reason for the error. */ + reason?: string; + + /** The domain where the error occurred. */ + domain?: string; + + /** Additional metadata about the error. */ + metadata?: Record; + + /** Any other relevant information about the error. */ + [key: string]: unknown; +} + +/** + * Details object that contains data originating from a bad HTTP response. + * + * @public + */ +export interface CustomErrorData { + /** HTTP status code of the error response. */ + status?: number; + + /** HTTP status text of the error response. */ + statusText?: string; + + /** Response from a {@link GenerateContentRequest} */ + response?: GenerateContentResponse; + + /** Optional additional details about the error. */ + errorDetails?: ErrorDetails[]; +} + +/** + * Standardized error codes that {@link VertexAIError} can have. + * + * @public + */ +export const enum VertexAIErrorCode { + /** A generic error occurred. */ + ERROR = 'error', + + /** An error occurred in a request. */ + REQUEST_ERROR = 'request-error', + + /** An error occurred in a response. */ + RESPONSE_ERROR = 'response-error', + + /** An error occurred while performing a fetch. */ + FETCH_ERROR = 'fetch-error', + + /** An error associated with a Content object. */ + INVALID_CONTENT = 'invalid-content', + + /** An error due to the Firebase API not being enabled in the Console. */ + API_NOT_ENABLED = 'api-not-enabled', + + /** An error due to invalid Schema input. */ + INVALID_SCHEMA = 'invalid-schema', + + /** An error occurred due to a missing Firebase API key. */ + NO_API_KEY = 'no-api-key', + + /** An error occurred due to a model name not being specified during initialization. */ + NO_MODEL = 'no-model', + + /** An error occurred due to a missing project ID. */ + NO_PROJECT_ID = 'no-project-id', + + /** An error occurred while parsing. */ + PARSE_FAILED = 'parse-failed', +} diff --git a/packages/ai/lib/types/index.ts b/packages/ai/lib/types/index.ts new file mode 100644 index 0000000000..85133aa07c --- /dev/null +++ b/packages/ai/lib/types/index.ts @@ -0,0 +1,23 @@ +/** + * @license + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +export * from './content'; +export * from './enums'; +export * from './requests'; +export * from './responses'; +export * from './error'; +export * from './schema'; diff --git a/packages/ai/lib/types/internal.ts b/packages/ai/lib/types/internal.ts new file mode 100644 index 0000000000..ee60d476c9 --- /dev/null +++ b/packages/ai/lib/types/internal.ts @@ -0,0 +1,25 @@ +/** + * @license + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import { FirebaseAppCheckTypes } from '@react-native-firebase/app-check'; + +export interface ApiSettings { + apiKey: string; + project: string; + location: string; + getAuthToken?: () => Promise; + getAppCheckToken?: () => Promise; +} diff --git a/packages/ai/lib/types/polyfills.d.ts b/packages/ai/lib/types/polyfills.d.ts new file mode 100644 index 0000000000..06fdf29b09 --- /dev/null +++ b/packages/ai/lib/types/polyfills.d.ts @@ -0,0 +1,15 @@ +declare module 'react-native-fetch-api' { + export function fetch(input: RequestInfo, init?: RequestInit): Promise; +} + +declare global { + interface RequestInit { + /** + * @description Polyfilled to enable text ReadableStream for React Native: + * @link https://github.com/facebook/react-native/issues/27741#issuecomment-2362901032 + */ + reactNative?: { + textStreaming: boolean; + }; + } +} diff --git a/packages/ai/lib/types/requests.ts b/packages/ai/lib/types/requests.ts new file mode 100644 index 0000000000..708a55a11c --- /dev/null +++ b/packages/ai/lib/types/requests.ts @@ -0,0 +1,198 @@ +/** + * @license + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { TypedSchema } from '../requests/schema-builder'; +import { Content, Part } from './content'; +import { FunctionCallingMode, HarmBlockMethod, HarmBlockThreshold, HarmCategory } from './enums'; +import { ObjectSchemaInterface, SchemaRequest } from './schema'; + +/** + * Base parameters for a number of methods. + * @public + */ +export interface BaseParams { + safetySettings?: SafetySetting[]; + generationConfig?: GenerationConfig; +} + +/** + * Params passed to {@link getGenerativeModel}. + * @public + */ +export interface ModelParams extends BaseParams { + model: string; + tools?: Tool[]; + toolConfig?: ToolConfig; + systemInstruction?: string | Part | Content; +} + +/** + * Request sent through {@link GenerativeModel.generateContent} + * @public + */ +export interface GenerateContentRequest extends BaseParams { + contents: Content[]; + tools?: Tool[]; + toolConfig?: ToolConfig; + systemInstruction?: string | Part | Content; +} + +/** + * Safety setting that can be sent as part of request parameters. + * @public + */ +export interface SafetySetting { + category: HarmCategory; + threshold: HarmBlockThreshold; + method?: HarmBlockMethod; +} + +/** + * Config options for content-related requests + * @public + */ +export interface GenerationConfig { + candidateCount?: number; + stopSequences?: string[]; + maxOutputTokens?: number; + temperature?: number; + topP?: number; + topK?: number; + presencePenalty?: number; + frequencyPenalty?: number; + /** + * Output response MIME type of the generated candidate text. + * Supported MIME types are `text/plain` (default, text output), + * `application/json` (JSON response in the candidates), and + * `text/x.enum`. + */ + responseMimeType?: string; + /** + * Output response schema of the generated candidate text. This + * value can be a class generated with a {@link Schema} static method + * like `Schema.string()` or `Schema.object()` or it can be a plain + * JS object matching the {@link SchemaRequest} interface. + *
Note: This only applies when the specified `responseMIMEType` supports a schema; currently + * this is limited to `application/json` and `text/x.enum`. + */ + responseSchema?: TypedSchema | SchemaRequest; +} + +/** + * Params for {@link GenerativeModel.startChat}. + * @public + */ +export interface StartChatParams extends BaseParams { + history?: Content[]; + tools?: Tool[]; + toolConfig?: ToolConfig; + systemInstruction?: string | Part | Content; +} + +/** + * Params for calling {@link GenerativeModel.countTokens} + * @public + */ +export interface CountTokensRequest { + contents: Content[]; +} + +/** + * Params passed to {@link getGenerativeModel}. + * @public + */ +export interface RequestOptions { + /** + * Request timeout in milliseconds. Defaults to 180 seconds (180000ms). + */ + timeout?: number; + /** + * Base url for endpoint. Defaults to https://firebasevertexai.googleapis.com + */ + baseUrl?: string; +} + +/** + * Defines a tool that model can call to access external knowledge. + * @public + */ +export declare type Tool = FunctionDeclarationsTool; + +/** + * Structured representation of a function declaration as defined by the + * {@link https://spec.openapis.org/oas/v3.0.3 | OpenAPI 3.0 specification}. + * Included + * in this declaration are the function name and parameters. This + * `FunctionDeclaration` is a representation of a block of code that can be used + * as a Tool by the model and executed by the client. + * @public + */ +export declare interface FunctionDeclaration { + /** + * The name of the function to call. Must start with a letter or an + * underscore. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with + * a max length of 64. + */ + name: string; + /** + * Description and purpose of the function. Model uses it to decide + * how and whether to call the function. + */ + description: string; + /** + * Optional. Describes the parameters to this function in JSON Schema Object + * format. Reflects the Open API 3.03 Parameter Object. Parameter names are + * case-sensitive. For a function with no parameters, this can be left unset. + */ + parameters?: ObjectSchemaInterface; +} + +/** + * A `FunctionDeclarationsTool` is a piece of code that enables the system to + * interact with external systems to perform an action, or set of actions, + * outside of knowledge and scope of the model. + * @public + */ +export declare interface FunctionDeclarationsTool { + /** + * Optional. One or more function declarations + * to be passed to the model along with the current user query. Model may + * decide to call a subset of these functions by populating + * {@link FunctionCall} in the response. User should + * provide a {@link FunctionResponse} for each + * function call in the next turn. Based on the function responses, the model will + * generate the final response back to the user. Maximum 64 function + * declarations can be provided. + */ + functionDeclarations?: FunctionDeclaration[]; +} + +/** + * Tool config. This config is shared for all tools provided in the request. + * @public + */ +export interface ToolConfig { + functionCallingConfig?: FunctionCallingConfig; +} + +/** + * @public + */ +export interface FunctionCallingConfig { + mode?: FunctionCallingMode; + allowedFunctionNames?: string[]; +} diff --git a/packages/ai/lib/types/responses.ts b/packages/ai/lib/types/responses.ts new file mode 100644 index 0000000000..013391e98b --- /dev/null +++ b/packages/ai/lib/types/responses.ts @@ -0,0 +1,209 @@ +/** + * @license + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { Content, FunctionCall } from './content'; +import { BlockReason, FinishReason, HarmCategory, HarmProbability, HarmSeverity } from './enums'; + +/** + * Result object returned from {@link GenerativeModel.generateContent} call. + * + * @public + */ +export interface GenerateContentResult { + response: EnhancedGenerateContentResponse; +} + +/** + * Result object returned from {@link GenerativeModel.generateContentStream} call. + * Iterate over `stream` to get chunks as they come in and/or + * use the `response` promise to get the aggregated response when + * the stream is done. + * + * @public + */ +export interface GenerateContentStreamResult { + stream: AsyncGenerator; + response: Promise; +} + +/** + * Response object wrapped with helper methods. + * + * @public + */ +export interface EnhancedGenerateContentResponse extends GenerateContentResponse { + /** + * Returns the text string from the response, if available. + * Throws if the prompt or candidate was blocked. + */ + text: () => string; + functionCalls: () => FunctionCall[] | undefined; +} + +/** + * Individual response from {@link GenerativeModel.generateContent} and + * {@link GenerativeModel.generateContentStream}. + * `generateContentStream()` will return one in each chunk until + * the stream is done. + * @public + */ +export interface GenerateContentResponse { + candidates?: GenerateContentCandidate[]; + promptFeedback?: PromptFeedback; + usageMetadata?: UsageMetadata; +} + +/** + * Usage metadata about a {@link GenerateContentResponse}. + * + * @public + */ +export interface UsageMetadata { + promptTokenCount: number; + candidatesTokenCount: number; + totalTokenCount: number; +} + +/** + * If the prompt was blocked, this will be populated with `blockReason` and + * the relevant `safetyRatings`. + * @public + */ +export interface PromptFeedback { + blockReason?: BlockReason; + safetyRatings: SafetyRating[]; + blockReasonMessage?: string; +} + +/** + * A candidate returned as part of a {@link GenerateContentResponse}. + * @public + */ +export interface GenerateContentCandidate { + index: number; + content: Content; + finishReason?: FinishReason; + finishMessage?: string; + safetyRatings?: SafetyRating[]; + citationMetadata?: CitationMetadata; + groundingMetadata?: GroundingMetadata; +} + +/** + * Citation metadata that may be found on a {@link GenerateContentCandidate}. + * @public + */ +export interface CitationMetadata { + citations: Citation[]; +} + +/** + * A single citation. + * @public + */ +export interface Citation { + startIndex?: number; + endIndex?: number; + uri?: string; + license?: string; + title?: string; + publicationDate?: Date; +} + +/** + * Metadata returned to client when grounding is enabled. + * @public + */ +export interface GroundingMetadata { + webSearchQueries?: string[]; + retrievalQueries?: string[]; + groundingAttributions: GroundingAttribution[]; +} + +/** + * @public + */ +export interface GroundingAttribution { + segment: Segment; + confidenceScore?: number; + web?: WebAttribution; + retrievedContext?: RetrievedContextAttribution; +} + +/** + * @public + */ +export interface Segment { + partIndex: number; + startIndex: number; + endIndex: number; +} + +/** + * @public + */ +export interface WebAttribution { + uri: string; + title: string; +} + +/** + * @public + */ +export interface RetrievedContextAttribution { + uri: string; + title: string; +} + +/** + * Protobuf google.type.Date + * @public + */ +export interface Date { + year: number; + month: number; + day: number; +} + +/** + * A safety rating associated with a {@link GenerateContentCandidate} + * @public + */ +export interface SafetyRating { + category: HarmCategory; + probability: HarmProbability; + severity: HarmSeverity; + probabilityScore: number; + severityScore: number; + blocked: boolean; +} + +/** + * Response from calling {@link GenerativeModel.countTokens}. + * @public + */ +export interface CountTokensResponse { + /** + * The total number of tokens counted across all instances from the request. + */ + totalTokens: number; + /** + * The total number of billable characters counted across all instances + * from the request. + */ + totalBillableCharacters?: number; +} diff --git a/packages/ai/lib/types/schema.ts b/packages/ai/lib/types/schema.ts new file mode 100644 index 0000000000..c1376b9aa1 --- /dev/null +++ b/packages/ai/lib/types/schema.ts @@ -0,0 +1,104 @@ +/** + * @license + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/** + * Contains the list of OpenAPI data types + * as defined by the + * {@link https://swagger.io/docs/specification/data-models/data-types/ | OpenAPI specification} + * @public + */ +export enum SchemaType { + /** String type. */ + STRING = 'string', + /** Number type. */ + NUMBER = 'number', + /** Integer type. */ + INTEGER = 'integer', + /** Boolean type. */ + BOOLEAN = 'boolean', + /** Array type. */ + ARRAY = 'array', + /** Object type. */ + OBJECT = 'object', +} + +/** + * Basic {@link Schema} properties shared across several Schema-related + * types. + * @public + */ +export interface SchemaShared { + /** Optional. The format of the property. */ + format?: string; + /** Optional. The description of the property. */ + description?: string; + /** Optional. The items of the property. */ + items?: T; + /** Optional. Map of `Schema` objects. */ + properties?: { + [k: string]: T; + }; + /** Optional. The enum of the property. */ + enum?: string[]; + /** Optional. The example of the property. */ + example?: unknown; + /** Optional. Whether the property is nullable. */ + nullable?: boolean; + [key: string]: unknown; +} + +/** + * Params passed to {@link Schema} static methods to create specific + * {@link Schema} classes. + * @public + */ +// eslint-disable-next-line @typescript-eslint/no-empty-object-type +export interface SchemaParams extends SchemaShared {} + +/** + * Final format for {@link Schema} params passed to backend requests. + * @public + */ +export interface SchemaRequest extends SchemaShared { + /** + * The type of the property. {@link + * SchemaType}. + */ + type: SchemaType; + /** Optional. Array of required property. */ + required?: string[]; +} + +/** + * Interface for {@link Schema} class. + * @public + */ +export interface SchemaInterface extends SchemaShared { + /** + * The type of the property. {@link + * SchemaType}. + */ + type: SchemaType; +} + +/** + * Interface for {@link ObjectSchema} class. + * @public + */ +export interface ObjectSchemaInterface extends SchemaInterface { + type: SchemaType.OBJECT; + optionalProperties?: string[]; +} diff --git a/packages/ai/package.json b/packages/ai/package.json new file mode 100644 index 0000000000..1f1c6c9e1d --- /dev/null +++ b/packages/ai/package.json @@ -0,0 +1,88 @@ +{ + "name": "@react-native-firebase/vertexai", + "version": "22.2.0", + "author": "Invertase (http://invertase.io)", + "description": "React Native Firebase - Vertex AI is a fully-managed, unified AI development platform for building and using generative AI", + "main": "./dist/commonjs/index.js", + "module": "./dist/module/index.js", + "types": "./dist/typescript/module/lib/index.d.ts", + "scripts": { + "build": "genversion --esm --semi lib/version.ts", + "build:clean": "rimraf dist", + "compile": "bob build", + "prepare": "yarn tests:vertex:mocks && yarn run build && yarn compile" + }, + "repository": { + "type": "git", + "url": "https://github.com/invertase/react-native-firebase/tree/main/packages/vertexai" + }, + "license": "Apache-2.0", + "keywords": [ + "react", + "react-native", + "firebase", + "vertexai", + "gemini", + "generative-ai" + ], + "peerDependencies": { + "@react-native-firebase/app": "22.2.0" + }, + "publishConfig": { + "access": "public", + "provenance": true + }, + "devDependencies": { + "@types/text-encoding": "^0.0.40", + "react-native-builder-bob": "^0.40.6", + "typescript": "^5.8.3" + }, + "source": "./lib/index.ts", + "exports": { + ".": { + "import": { + "types": "./dist/typescript/module/lib/index.d.ts", + "default": "./dist/module/index.js" + }, + "require": { + "types": "./dist/typescript/commonjs/lib/index.d.ts", + "default": "./dist/commonjs/index.js" + } + } + }, + "files": [ + "lib", + "dist", + "!**/__tests__", + "!**/__fixtures__", + "!**/__mocks__" + ], + "react-native-builder-bob": { + "source": "lib", + "output": "dist", + "targets": [ + [ + "commonjs", + { + "esm": true + } + ], + [ + "module", + { + "esm": true + } + ], + "typescript" + ] + }, + "eslintIgnore": [ + "node_modules/", + "dist/" + ], + "dependencies": { + "react-native-fetch-api": "^3.0.0", + "text-encoding": "^0.7.0", + "web-streams-polyfill": "^4.1.0" + } +} diff --git a/packages/ai/tsconfig.json b/packages/ai/tsconfig.json new file mode 100644 index 0000000000..f1d9865812 --- /dev/null +++ b/packages/ai/tsconfig.json @@ -0,0 +1,32 @@ +{ + "compilerOptions": { + "rootDir": ".", + "allowUnreachableCode": false, + "allowUnusedLabels": false, + "esModuleInterop": true, + "forceConsistentCasingInFileNames": true, + "jsx": "react-jsx", + "lib": [ + "ESNext" + ], + "module": "ESNext", + "target": "ESNext", + "moduleResolution": "Bundler", + "noFallthroughCasesInSwitch": true, + "noImplicitReturns": true, + "noImplicitUseStrict": false, + "noStrictGenericChecks": false, + "noUncheckedIndexedAccess": true, + "noUnusedLocals": true, + "noUnusedParameters": true, + "resolveJsonModule": true, + "skipLibCheck": true, + "strict": true, + "baseUrl": ".", + "paths": { + "@react-native-firebase/app": ["../app/lib"], + "@react-native-firebase/auth": ["../auth/lib"], + "@react-native-firebase/app-check": ["../app-check/lib"], + } + } +} From 06a24a703b03c8c8b421fb944a9de6d5ced046f7 Mon Sep 17 00:00:00 2001 From: russellwheatley Date: Tue, 3 Jun 2025 10:07:23 +0100 Subject: [PATCH 02/85] refactor: move to AI implementation, including AIModel base class --- packages/ai/lib/backend.ts | 92 +++++++++++++++ packages/ai/lib/constants.ts | 2 +- packages/ai/lib/errors.ts | 15 ++- packages/ai/lib/index.ts | 64 +++++++---- packages/ai/lib/models/ai-model.ts | 124 +++++++++++++++++++++ packages/ai/lib/models/generative-model.ts | 63 +++-------- packages/ai/lib/public-types.ts | 111 ++++++++++++++++++ packages/ai/lib/service.ts | 14 ++- packages/ai/lib/types/error.ts | 8 +- packages/ai/lib/types/internal.ts | 7 ++ 10 files changed, 412 insertions(+), 88 deletions(-) create mode 100644 packages/ai/lib/backend.ts create mode 100644 packages/ai/lib/models/ai-model.ts diff --git a/packages/ai/lib/backend.ts b/packages/ai/lib/backend.ts new file mode 100644 index 0000000000..7209828122 --- /dev/null +++ b/packages/ai/lib/backend.ts @@ -0,0 +1,92 @@ +/** + * @license + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { DEFAULT_LOCATION } from './constants'; +import { BackendType } from './public-types'; + +/** + * Abstract base class representing the configuration for an AI service backend. + * This class should not be instantiated directly. Use its subclasses; {@link GoogleAIBackend} for + * the Gemini Developer API (via {@link https://ai.google/ | Google AI}), and + * {@link VertexAIBackend} for the Vertex AI Gemini API. + * + * @public + */ +export abstract class Backend { + /** + * Specifies the backend type. + */ + readonly backendType: BackendType; + + /** + * Protected constructor for use by subclasses. + * @param type - The backend type. + */ + protected constructor(type: BackendType) { + this.backendType = type; + } +} + +/** + * Configuration class for the Gemini Developer API. + * + * Use this with {@link AIOptions} when initializing the AI service via + * {@link getAI | getAI()} to specify the Gemini Developer API as the backend. + * + * @public + */ +export class GoogleAIBackend extends Backend { + /** + * Creates a configuration object for the Gemini Developer API backend. + */ + constructor() { + super(BackendType.GOOGLE_AI); + } +} + +/** + * Configuration class for the Vertex AI Gemini API. + * + * Use this with {@link AIOptions} when initializing the AI service via + * {@link getAI | getAI()} to specify the Vertex AI Gemini API as the backend. + * + * @public + */ +export class VertexAIBackend extends Backend { + /** + * The region identifier. + * See {@link https://firebase.google.com/docs/vertex-ai/locations#available-locations | Vertex AI locations} + * for a list of supported locations. + */ + readonly location: string; + + /** + * Creates a configuration object for the Vertex AI backend. + * + * @param location - The region identifier, defaulting to `us-central1`; + * see {@link https://firebase.google.com/docs/vertex-ai/locations#available-locations | Vertex AI locations} + * for a list of supported locations. + */ + constructor(location: string = DEFAULT_LOCATION) { + super(BackendType.VERTEX_AI); + if (!location) { + this.location = DEFAULT_LOCATION; + } else { + this.location = location; + } + } +} diff --git a/packages/ai/lib/constants.ts b/packages/ai/lib/constants.ts index 816f5194a2..a0cffa49ad 100644 --- a/packages/ai/lib/constants.ts +++ b/packages/ai/lib/constants.ts @@ -17,7 +17,7 @@ import { version } from './version'; -export const VERTEX_TYPE = 'vertexAI'; +export const AI_TYPE = 'AI'; export const DEFAULT_LOCATION = 'us-central1'; diff --git a/packages/ai/lib/errors.ts b/packages/ai/lib/errors.ts index 370c19aeb0..ea09d2f162 100644 --- a/packages/ai/lib/errors.ts +++ b/packages/ai/lib/errors.ts @@ -16,15 +16,15 @@ */ import { FirebaseError } from '@firebase/util'; -import { VertexAIErrorCode, CustomErrorData } from './types'; -import { VERTEX_TYPE } from './constants'; +import { AIErrorCode, CustomErrorData } from './types'; +import { AI_TYPE } from './constants'; /** * Error class for the Vertex AI in Firebase SDK. * * @public */ -export class VertexAIError extends FirebaseError { +export class AIError extends FirebaseError { /** * Constructs a new instance of the `VertexAIError` class. * @@ -33,18 +33,17 @@ export class VertexAIError extends FirebaseError { * @param customErrorData - Optional error data. */ constructor( - readonly code: VertexAIErrorCode, + readonly code: AIErrorCode, message: string, readonly customErrorData?: CustomErrorData, ) { // Match error format used by FirebaseError from ErrorFactory - const service = VERTEX_TYPE; - const serviceName = 'VertexAI'; + const service = AI_TYPE; const fullCode = `${service}/${code}`; - const fullMessage = `${serviceName}: ${message} (${fullCode})`; + const fullMessage = `${service}: ${message} (${fullCode})`; super(code, fullMessage); - Object.setPrototypeOf(this, VertexAIError.prototype); + Object.setPrototypeOf(this, AIError.prototype); // Since Error is an interface, we don't inherit toString and so we define it ourselves. this.toString = () => fullMessage; diff --git a/packages/ai/lib/index.ts b/packages/ai/lib/index.ts index 580f1bc86b..0d5281fbc4 100644 --- a/packages/ai/lib/index.ts +++ b/packages/ai/lib/index.ts @@ -17,57 +17,73 @@ import './polyfills'; import { getApp, ReactNativeFirebase } from '@react-native-firebase/app'; -import { ModelParams, RequestOptions, VertexAIErrorCode } from './types'; -import { DEFAULT_LOCATION } from './constants'; -import { VertexAI, VertexAIOptions } from './public-types'; -import { VertexAIError } from './errors'; +import { GoogleAIBackend, VertexAIBackend } from './backend'; +import { AIErrorCode, ModelParams, RequestOptions } from './types'; +import { AI, AIOptions } from './public-types'; +import { AIError } from './errors'; import { GenerativeModel } from './models/generative-model'; -import { VertexAIService } from './service'; export { ChatSession } from './methods/chat-session'; export * from './requests/schema-builder'; export { GenerativeModel }; -export { VertexAIError }; +export { AIError }; /** - * Returns a {@link VertexAI} instance for the given app. + * Returns the default {@link AI} instance that is associated with the provided + * {@link @firebase/app#FirebaseApp}. If no instance exists, initializes a new instance with the + * default settings. * - * @public + * @example + * ```javascript + * const ai = getAI(app); + * ``` + * + * @example + * ```javascript + * // Get an AI instance configured to use the Gemini Developer API (via Google AI). + * const ai = getAI(app, { backend: new GoogleAIBackend() }); + * ``` + * + * @example + * ```javascript + * // Get an AI instance configured to use the Vertex AI Gemini API. + * const ai = getAI(app, { backend: new VertexAIBackend() }); + * ``` * - * @param app - The {@link @FirebaseApp} to use. - * @param options - The {@link VertexAIOptions} to use. - * @param appCheck - The {@link @AppCheck} to use. - * @param auth - The {@link @Auth} to use. + * @param app - The {@link @firebase/app#FirebaseApp} to use. + * @param options - {@link AIOptions} that configure the AI instance. + * @returns The default {@link AI} instance for the given {@link @firebase/app#FirebaseApp}. + * + * @public */ -export function getVertexAI( +export function getAI( app: ReactNativeFirebase.FirebaseApp = getApp(), - options?: VertexAIOptions, -): VertexAI { + options: AIOptions = { backend: new GoogleAIBackend() }, +): AI { return { app, - location: options?.location || DEFAULT_LOCATION, - appCheck: options?.appCheck || null, - auth: options?.auth || null, - } as VertexAIService; + backend: options.backend, + location: (options.backend as VertexAIBackend)?.location || '', + } as AI; } /** - * Returns a {@link GenerativeModel} class with methods for inference + * Returns a {@link GenerativeModel} class with methods for inference * and other functionality. * * @public */ export function getGenerativeModel( - vertexAI: VertexAI, + ai: AI, modelParams: ModelParams, requestOptions?: RequestOptions, ): GenerativeModel { if (!modelParams.model) { - throw new VertexAIError( - VertexAIErrorCode.NO_MODEL, + throw new AIError( + AIErrorCode.NO_MODEL, `Must provide a model name. Example: getGenerativeModel({ model: 'my-model-name' })`, ); } - return new GenerativeModel(vertexAI, modelParams, requestOptions); + return new GenerativeModel(ai, modelParams, requestOptions); } diff --git a/packages/ai/lib/models/ai-model.ts b/packages/ai/lib/models/ai-model.ts new file mode 100644 index 0000000000..948daf4e28 --- /dev/null +++ b/packages/ai/lib/models/ai-model.ts @@ -0,0 +1,124 @@ +import { ApiSettings } from '../types/internal'; +import { AIError } from '../errors'; +import { AIErrorCode } from '../types'; +import { AI, BackendType } from '../public-types'; +import { AIService } from '../service'; + +/** + * Base class for Firebase AI model APIs. + * + * Instances of this class are associated with a specific Firebase AI {@link Backend} + * and provide methods for interacting with the configured generative model. + * + * @public + */ +export abstract class AIModel { + /** + * The fully qualified model resource name to use for generating images + * (for example, `publishers/google/models/imagen-3.0-generate-002`). + */ + readonly model: string; + + /** + * @internal + */ + protected _apiSettings: ApiSettings; + + /** + * Constructs a new instance of the {@link AIModel} class. + * + * This constructor should only be called from subclasses that provide + * a model API. + * + * @param ai - an {@link AI} instance. + * @param modelName - The name of the model being used. It can be in one of the following formats: + * - `my-model` (short name, will resolve to `publishers/google/models/my-model`) + * - `models/my-model` (will resolve to `publishers/google/models/my-model`) + * - `publishers/my-publisher/models/my-model` (fully qualified model name) + * + * @throws If the `apiKey` or `projectId` fields are missing in your + * Firebase config. + * + * @internal + */ + protected constructor(ai: AI, modelName: string) { + if (!ai.app?.options?.apiKey) { + throw new AIError( + AIErrorCode.NO_API_KEY, + `The "apiKey" field is empty in the local Firebase config. Firebase AI requires this field to contain a valid API key.`, + ); + } else if (!ai.app?.options?.projectId) { + throw new AIError( + AIErrorCode.NO_PROJECT_ID, + `The "projectId" field is empty in the local Firebase config. Firebase AI requires this field to contain a valid project ID.`, + ); + } else if (!ai.app?.options?.appId) { + throw new AIError( + AIErrorCode.NO_APP_ID, + `The "appId" field is empty in the local Firebase config. Firebase AI requires this field to contain a valid app ID.`, + ); + } else { + this._apiSettings = { + apiKey: ai.app.options.apiKey, + project: ai.app.options.projectId, + appId: ai.app.options.appId, + automaticDataCollectionEnabled: ai.app.automaticDataCollectionEnabled, + location: ai.location, + backend: ai.backend, + }; + if ((ai as AIService).appCheck) { + this._apiSettings.getAppCheckToken = () => (ai as AIService).appCheck!.getToken(); + } + + if ((ai as AIService).auth?.currentUser) { + this._apiSettings.getAuthToken = () => (ai as AIService).auth!.currentUser!.getIdToken(); + } + + this.model = AIModel.normalizeModelName(modelName, this._apiSettings.backend.backendType); + } + } + + /** + * Normalizes the given model name to a fully qualified model resource name. + * + * @param modelName - The model name to normalize. + * @returns The fully qualified model resource name. + * + * @internal + */ + static normalizeModelName(modelName: string, backendType: BackendType): string { + if (backendType === BackendType.GOOGLE_AI) { + return AIModel.normalizeGoogleAIModelName(modelName); + } else { + return AIModel.normalizeVertexAIModelName(modelName); + } + } + + /** + * @internal + */ + private static normalizeGoogleAIModelName(modelName: string): string { + return `models/${modelName}`; + } + + /** + * @internal + */ + private static normalizeVertexAIModelName(modelName: string): string { + let model: string; + if (modelName.includes('/')) { + if (modelName.startsWith('models/')) { + // Add 'publishers/google' if the user is only passing in 'models/model-name'. + model = `publishers/google/${modelName}`; + } else { + // Any other custom format (e.g. tuned models) must be passed in correctly. + model = modelName; + } + } else { + // If path is not included, assume it's a non-tuned model. + model = `publishers/google/models/${modelName}`; + } + + return model; + } +} diff --git a/packages/ai/lib/models/generative-model.ts b/packages/ai/lib/models/generative-model.ts index 111cefa427..c3bba041e0 100644 --- a/packages/ai/lib/models/generative-model.ts +++ b/packages/ai/lib/models/generative-model.ts @@ -31,23 +31,18 @@ import { StartChatParams, Tool, ToolConfig, - VertexAIErrorCode, } from '../types'; -import { VertexAIError } from '../errors'; import { ChatSession } from '../methods/chat-session'; import { countTokens } from '../methods/count-tokens'; import { formatGenerateContentInput, formatSystemInstruction } from '../requests/request-helpers'; -import { VertexAI } from '../public-types'; -import { ApiSettings } from '../types/internal'; -import { VertexAIService } from '../service'; +import { AIModel } from './ai-model'; +import { AI } from '../public-types'; /** * Class for generative model APIs. * @public */ -export class GenerativeModel { - private _apiSettings: ApiSettings; - model: string; +export class GenerativeModel extends AIModel { generationConfig: GenerationConfig; safetySettings: SafetySetting[]; requestOptions?: RequestOptions; @@ -55,45 +50,8 @@ export class GenerativeModel { toolConfig?: ToolConfig; systemInstruction?: Content; - constructor(vertexAI: VertexAI, modelParams: ModelParams, requestOptions?: RequestOptions) { - if (!vertexAI.app?.options?.apiKey) { - throw new VertexAIError( - VertexAIErrorCode.NO_API_KEY, - `The "apiKey" field is empty in the local Firebase config. Firebase VertexAI requires this field to contain a valid API key.`, - ); - } else if (!vertexAI.app?.options?.projectId) { - throw new VertexAIError( - VertexAIErrorCode.NO_PROJECT_ID, - `The "projectId" field is empty in the local Firebase config. Firebase VertexAI requires this field to contain a valid project ID.`, - ); - } else { - this._apiSettings = { - apiKey: vertexAI.app.options.apiKey, - project: vertexAI.app.options.projectId, - location: vertexAI.location, - }; - if ((vertexAI as VertexAIService).appCheck) { - this._apiSettings.getAppCheckToken = () => - (vertexAI as VertexAIService).appCheck!.getToken(); - } - - if ((vertexAI as VertexAIService).auth?.currentUser) { - this._apiSettings.getAuthToken = () => - (vertexAI as VertexAIService).auth!.currentUser!.getIdToken(); - } - } - if (modelParams.model.includes('/')) { - if (modelParams.model.startsWith('models/')) { - // Add "publishers/google" if the user is only passing in 'models/model-name'. - this.model = `publishers/google/${modelParams.model}`; - } else { - // Any other custom format (e.g. tuned models) must be passed in correctly. - this.model = modelParams.model; - } - } else { - // If path is not included, assume it's a non-tuned model. - this.model = `publishers/google/models/${modelParams.model}`; - } + constructor(ai: AI, modelParams: ModelParams, requestOptions?: RequestOptions) { + super(ai, modelParams.model); this.generationConfig = modelParams.generationConfig || {}; this.safetySettings = modelParams.safetySettings || []; this.tools = modelParams.tools; @@ -104,7 +62,7 @@ export class GenerativeModel { /** * Makes a single non-streaming call to the model - * and returns an object containing a single {@link GenerateContentResponse}. + * and returns an object containing a single {@link GenerateContentResponse}. */ async generateContent( request: GenerateContentRequest | string | Array, @@ -151,7 +109,7 @@ export class GenerativeModel { } /** - * Gets a new {@link ChatSession} instance which can be used for + * Gets a new {@link ChatSession} instance which can be used for * multi-turn chats. */ startChat(startChatParams?: StartChatParams): ChatSession { @@ -162,6 +120,13 @@ export class GenerativeModel { tools: this.tools, toolConfig: this.toolConfig, systemInstruction: this.systemInstruction, + generationConfig: this.generationConfig, + safetySettings: this.safetySettings, + /** + * Overrides params inherited from GenerativeModel with those explicitly set in the + * StartChatParams. For example, if startChatParams.generationConfig is set, it'll override + * this.generationConfig. + */ ...startChatParams, }, this.requestOptions, diff --git a/packages/ai/lib/public-types.ts b/packages/ai/lib/public-types.ts index 24c6be6efa..95d9a1dc63 100644 --- a/packages/ai/lib/public-types.ts +++ b/packages/ai/lib/public-types.ts @@ -44,3 +44,114 @@ export interface VertexAIOptions { appCheck?: FirebaseAppCheckTypes.Module | null; auth?: FirebaseAuthTypes.Module | null; } + +/** + * Options for initializing the AI service using {@link getAI | getAI()}. + * This allows specifying which backend to use (Vertex AI Gemini API or Gemini Developer API) + * and configuring its specific options (like location for Vertex AI). + * + * @public + */ +export interface AIOptions { + /** + * The backend configuration to use for the AI service instance. + */ + backend: Backend; +} + +/** + * Abstract base class representing the configuration for an AI service backend. + * This class should not be instantiated directly. Use its subclasses; {@link GoogleAIBackend} for + * the Gemini Developer API (via {@link https://ai.google/ | Google AI}), and + * {@link VertexAIBackend} for the Vertex AI Gemini API. + * + * @public + */ +export abstract class Backend { + /** + * Specifies the backend type. + */ + readonly backendType: BackendType; + + /** + * Protected constructor for use by subclasses. + * @param type - The backend type. + */ + protected constructor(type: BackendType) { + this.backendType = type; + } +} + +/** + * An enum-like object containing constants that represent the supported backends + * for the Firebase AI SDK. + * This determines which backend service (Vertex AI Gemini API or Gemini Developer API) + * the SDK will communicate with. + * + * These values are assigned to the `backendType` property within the specific backend + * configuration objects ({@link GoogleAIBackend} or {@link VertexAIBackend}) to identify + * which service to target. + * + * @public + */ +export const BackendType = { + /** + * Identifies the backend service for the Vertex AI Gemini API provided through Google Cloud. + * Use this constant when creating a {@link VertexAIBackend} configuration. + */ + VERTEX_AI: 'VERTEX_AI', + + /** + * Identifies the backend service for the Gemini Developer API ({@link https://ai.google/ | Google AI}). + * Use this constant when creating a {@link GoogleAIBackend} configuration. + */ + GOOGLE_AI: 'GOOGLE_AI', +} as const; // Using 'as const' makes the string values literal types + +/** + * Type alias representing valid backend types. + * It can be either `'VERTEX_AI'` or `'GOOGLE_AI'`. + * + * @public + */ +export type BackendType = (typeof BackendType)[keyof typeof BackendType]; + +/** + * Options for initializing the AI service using {@link getAI | getAI()}. + * This allows specifying which backend to use (Vertex AI Gemini API or Gemini Developer API) + * and configuring its specific options (like location for Vertex AI). + * + * @public + */ +export interface AIOptions { + /** + * The backend configuration to use for the AI service instance. + */ + backend: Backend; +} + +/** + * An instance of the Firebase AI SDK. + * + * Do not create this instance directly. Instead, use {@link getAI | getAI()}. + * + * @public + */ +export interface AI { + /** + * The {@link @firebase/app#FirebaseApp} this {@link AI} instance is associated with. + */ + app: ReactNativeFirebase.FirebaseApp; + /** + * A {@link Backend} instance that specifies the configuration for the target backend, + * either the Gemini Developer API (using {@link GoogleAIBackend}) or the + * Vertex AI Gemini API (using {@link VertexAIBackend}). + */ + backend: Backend; + /** + * @deprecated use `AI.backend.location` instead. + * + * The location configured for this AI service instance, relevant for Vertex AI backends. + */ + location: string; +} diff --git a/packages/ai/lib/service.ts b/packages/ai/lib/service.ts index e90ffa9668..79bf741303 100644 --- a/packages/ai/lib/service.ts +++ b/packages/ai/lib/service.ts @@ -16,24 +16,28 @@ */ import { ReactNativeFirebase } from '@react-native-firebase/app'; -import { VertexAI, VertexAIOptions } from './public-types'; -import { DEFAULT_LOCATION } from './constants'; +import { AI, Backend } from './public-types'; import { FirebaseAuthTypes } from '@react-native-firebase/auth'; import { FirebaseAppCheckTypes } from '@react-native-firebase/app-check'; +import { VertexAIBackend } from './backend'; -export class VertexAIService implements VertexAI { +export class AIService implements AI { auth: FirebaseAuthTypes.Module | null; appCheck: FirebaseAppCheckTypes.Module | null; location: string; constructor( public app: ReactNativeFirebase.FirebaseApp, + public backend: Backend, auth?: FirebaseAuthTypes.Module, appCheck?: FirebaseAppCheckTypes.Module, - public options?: VertexAIOptions, ) { this.auth = auth || null; this.appCheck = appCheck || null; - this.location = this.options?.location || DEFAULT_LOCATION; + if (backend instanceof VertexAIBackend) { + this.location = backend.location; + } else { + this.location = ''; + } } } diff --git a/packages/ai/lib/types/error.ts b/packages/ai/lib/types/error.ts index c65e09c55f..38383fddc9 100644 --- a/packages/ai/lib/types/error.ts +++ b/packages/ai/lib/types/error.ts @@ -62,7 +62,7 @@ export interface CustomErrorData { * * @public */ -export const enum VertexAIErrorCode { +export const enum AIErrorCode { /** A generic error occurred. */ ERROR = 'error', @@ -87,6 +87,9 @@ export const enum VertexAIErrorCode { /** An error occurred due to a missing Firebase API key. */ NO_API_KEY = 'no-api-key', + /** An error occurred due to a missing Firebase app ID. */ + NO_APP_ID = 'no-app-id', + /** An error occurred due to a model name not being specified during initialization. */ NO_MODEL = 'no-model', @@ -95,4 +98,7 @@ export const enum VertexAIErrorCode { /** An error occurred while parsing. */ PARSE_FAILED = 'parse-failed', + + /** An error occured due an attempt to use an unsupported feature. */ + UNSUPPORTED = 'unsupported', } diff --git a/packages/ai/lib/types/internal.ts b/packages/ai/lib/types/internal.ts index ee60d476c9..8b51e8c846 100644 --- a/packages/ai/lib/types/internal.ts +++ b/packages/ai/lib/types/internal.ts @@ -15,11 +15,18 @@ * limitations under the License. */ import { FirebaseAppCheckTypes } from '@react-native-firebase/app-check'; +import { Backend } from '../public-types'; export interface ApiSettings { apiKey: string; + appId: string; project: string; + /** + * @deprecated Use `backend.location` instead. + */ location: string; + automaticDataCollectionEnabled?: boolean; + backend: Backend; getAuthToken?: () => Promise; getAppCheckToken?: () => Promise; } From 44183d687bd84299e070587d49393246ff602169 Mon Sep 17 00:00:00 2001 From: russellwheatley Date: Tue, 3 Jun 2025 10:08:09 +0100 Subject: [PATCH 03/85] chore: update types for FirebaseApp to match JS SDK --- packages/app/lib/index.d.ts | 5 +++++ packages/app/lib/modular/index.d.ts | 11 +++++++++++ packages/app/lib/modular/index.js | 6 +++--- 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/packages/app/lib/index.d.ts b/packages/app/lib/index.d.ts index 02374120a4..a51dce1ad2 100644 --- a/packages/app/lib/index.d.ts +++ b/packages/app/lib/index.d.ts @@ -149,6 +149,11 @@ export namespace ReactNativeFirebase { */ readonly options: FirebaseAppOptions; + /** + * The settable config flag for GDPR opt-in/opt-out + */ + automaticDataCollectionEnabled: boolean; + /** * Make this app unusable and free up resources. */ diff --git a/packages/app/lib/modular/index.d.ts b/packages/app/lib/modular/index.d.ts index ec90f4f3d2..49a8337d82 100644 --- a/packages/app/lib/modular/index.d.ts +++ b/packages/app/lib/modular/index.d.ts @@ -3,6 +3,7 @@ import { ReactNativeFirebase } from '..'; import FirebaseApp = ReactNativeFirebase.FirebaseApp; import FirebaseAppOptions = ReactNativeFirebase.FirebaseAppOptions; import LogLevelString = ReactNativeFirebase.LogLevelString; +import FirebaseAppConfig = ReactNativeFirebase.FirebaseAppConfig; /** * Renders this app unusable and frees the resources of all associated services. @@ -57,6 +58,16 @@ export function getApps(): FirebaseApp[]; */ export function initializeApp(options: FirebaseAppOptions, name?: string): Promise; +/** + * Initializes a Firebase app with the provided options and config. + * @param options - Options to configure the services used in the app. + * @param config - The optional config for your firebase app. + * @returns Promise - The initialized Firebase app. + */ +export function initializeApp( + options: FirebaseAppOptions, + config?: FirebaseAppConfig, +): Promise; /** * Retrieves an instance of a Firebase app. * @param name - The optional name of the app to return ('[DEFAULT]' if omitted). diff --git a/packages/app/lib/modular/index.js b/packages/app/lib/modular/index.js index bc4b0b1951..2cb4a6c4ff 100644 --- a/packages/app/lib/modular/index.js +++ b/packages/app/lib/modular/index.js @@ -60,11 +60,11 @@ export function getApps() { /** * Initializes a Firebase app with the provided options and name. * @param {FirebaseAppOptions} options - Options to configure the services used in the app. - * @param {string} [name] - The optional name of the app to initialize ('[DEFAULT]' if omitted). + * @param {string | FirebaseAppConfig} [configOrName] - The optional name of the app to initialize ('[DEFAULT]' if omitted). * @returns {FirebaseApp} - The initialized Firebase app. */ -export function initializeApp(options, name) { - return initializeAppCompat.call(null, options, name, MODULAR_DEPRECATION_ARG); +export function initializeApp(options, configOrName) { + return initializeAppCompat.call(null, options, configOrName, MODULAR_DEPRECATION_ARG); } /** From fd0df2b63b3927b24208bbec390f45716bcde22c Mon Sep 17 00:00:00 2001 From: russellwheatley Date: Tue, 3 Jun 2025 10:16:27 +0100 Subject: [PATCH 04/85] chore: update CHANGELOG & package.json --- packages/ai/CHANGELOG.md | 42 +++++----------------------------------- packages/ai/package.json | 9 +++++---- 2 files changed, 10 insertions(+), 41 deletions(-) diff --git a/packages/ai/CHANGELOG.md b/packages/ai/CHANGELOG.md index 1ac68b1cef..8db65e03d0 100644 --- a/packages/ai/CHANGELOG.md +++ b/packages/ai/CHANGELOG.md @@ -3,44 +3,12 @@ All notable changes to this project will be documented in this file. See [Conventional Commits](https://conventionalcommits.org) for commit guidelines. -## [22.2.0](https://github.com/invertase/react-native-firebase/compare/v22.1.0...v22.2.0) (2025-05-12) +## Feature -**Note:** Version bump only for package @react-native-firebase/vertexai +Initial release of the Firebase AI Logic SDK (`FirebaseAI`). This SDK *replaces* the previous Vertex AI in Firebase SDK (`FirebaseVertexAI`) to accommodate the evolving set of supported features and services. +The new Firebase AI Logic SDK provides **preview** support for the Gemini Developer API, including its free tier offering. +Using the Firebase AI Logic SDK with the Vertex AI Gemini API is still generally available (GA). -## [22.1.0](https://github.com/invertase/react-native-firebase/compare/v22.0.0...v22.1.0) (2025-04-30) +To start using the new SDK, import the `@react-native-firebase/ai` package and use the modular method `getAI()` to initialize. See details in the [migration guide](https://firebase.google.com/docs/vertex-ai/migrate-to-latest-sdk). -### Bug Fixes -- **vertexai:** package.json main needs updating to commonjs path ([abe6c19](https://github.com/invertase/react-native-firebase/commit/abe6c190e6a22676fc58a4c5c7740ddeba2efd93)) - -## [22.0.0](https://github.com/invertase/react-native-firebase/compare/v21.14.0...v22.0.0) (2025-04-25) - -### Bug Fixes - -- enable provenance signing during publish ([4535f0d](https://github.com/invertase/react-native-firebase/commit/4535f0d5756c89aeb8f8e772348c71d8176348be)) - -## [21.14.0](https://github.com/invertase/react-native-firebase/compare/v21.13.0...v21.14.0) (2025-04-14) - -**Note:** Version bump only for package @react-native-firebase/vertexai - -## [21.13.0](https://github.com/invertase/react-native-firebase/compare/v21.12.3...v21.13.0) (2025-03-31) - -**Note:** Version bump only for package @react-native-firebase/vertexai - -## [21.12.3](https://github.com/invertase/react-native-firebase/compare/v21.12.2...v21.12.3) (2025-03-26) - -**Note:** Version bump only for package @react-native-firebase/vertexai - -## [21.12.2](https://github.com/invertase/react-native-firebase/compare/v21.12.1...v21.12.2) (2025-03-23) - -**Note:** Version bump only for package @react-native-firebase/vertexai - -## [21.12.1](https://github.com/invertase/react-native-firebase/compare/v21.12.0...v21.12.1) (2025-03-22) - -**Note:** Version bump only for package @react-native-firebase/vertexai - -## [21.12.0](https://github.com/invertase/react-native-firebase/compare/v21.11.0...v21.12.0) (2025-03-03) - -### Features - -- vertexAI package support ([#8236](https://github.com/invertase/react-native-firebase/issues/8236)) ([a1d1361](https://github.com/invertase/react-native-firebase/commit/a1d13610f443a96a7195b3f769f77d9676c0e577)) diff --git a/packages/ai/package.json b/packages/ai/package.json index 1f1c6c9e1d..6b230fa141 100644 --- a/packages/ai/package.json +++ b/packages/ai/package.json @@ -1,8 +1,8 @@ { - "name": "@react-native-firebase/vertexai", + "name": "@react-native-firebase/ai", "version": "22.2.0", "author": "Invertase (http://invertase.io)", - "description": "React Native Firebase - Vertex AI is a fully-managed, unified AI development platform for building and using generative AI", + "description": "React Native Firebase - Firebase AI is a fully-managed, unified AI development platform for building and using generative AI", "main": "./dist/commonjs/index.js", "module": "./dist/module/index.js", "types": "./dist/typescript/module/lib/index.d.ts", @@ -14,14 +14,15 @@ }, "repository": { "type": "git", - "url": "https://github.com/invertase/react-native-firebase/tree/main/packages/vertexai" + "url": "https://github.com/invertase/react-native-firebase/tree/main/packages/ai" }, "license": "Apache-2.0", "keywords": [ "react", "react-native", "firebase", - "vertexai", + "firebase-ai", + "ai", "gemini", "generative-ai" ], From 649eaf560bc6d2226d3f1af842e7be52a7506682 Mon Sep 17 00:00:00 2001 From: russellwheatley Date: Tue, 3 Jun 2025 10:30:46 +0100 Subject: [PATCH 05/85] format --- packages/ai/CHANGELOG.md | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/packages/ai/CHANGELOG.md b/packages/ai/CHANGELOG.md index 8db65e03d0..b47a3c0062 100644 --- a/packages/ai/CHANGELOG.md +++ b/packages/ai/CHANGELOG.md @@ -9,6 +9,4 @@ Initial release of the Firebase AI Logic SDK (`FirebaseAI`). This SDK *replaces* The new Firebase AI Logic SDK provides **preview** support for the Gemini Developer API, including its free tier offering. Using the Firebase AI Logic SDK with the Vertex AI Gemini API is still generally available (GA). -To start using the new SDK, import the `@react-native-firebase/ai` package and use the modular method `getAI()` to initialize. See details in the [migration guide](https://firebase.google.com/docs/vertex-ai/migrate-to-latest-sdk). - - +To start using the new SDK, import the `@react-native-firebase/ai` package and use the modular method `getAI()` to initialize. See details in the [migration guide](https://firebase.google.com/docs/vertex-ai/migrate-to-latest-sdk). \ No newline at end of file From 547ce8a0e31e43e2efc764ec1ce4c03531c95295 Mon Sep 17 00:00:00 2001 From: russellwheatley Date: Tue, 3 Jun 2025 10:37:38 +0100 Subject: [PATCH 06/85] request-helpers --- packages/ai/lib/requests/request-helpers.ts | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/packages/ai/lib/requests/request-helpers.ts b/packages/ai/lib/requests/request-helpers.ts index 9de045a4ee..6d468f4023 100644 --- a/packages/ai/lib/requests/request-helpers.ts +++ b/packages/ai/lib/requests/request-helpers.ts @@ -15,8 +15,8 @@ * limitations under the License. */ -import { Content, GenerateContentRequest, Part, VertexAIErrorCode } from '../types'; -import { VertexAIError } from '../errors'; +import { Content, GenerateContentRequest, Part, AIErrorCode } from '../types'; +import { AIError } from '../errors'; export function formatSystemInstruction(input?: string | Part | Content): Content | undefined { if (input == null) { @@ -32,7 +32,6 @@ export function formatSystemInstruction(input?: string | Part | Content): Conten return input as Content; } } - return undefined; } @@ -76,15 +75,15 @@ function assignRoleToPartsAndValidateSendMessageRequest(parts: Part[]): Content } if (hasUserContent && hasFunctionContent) { - throw new VertexAIError( - VertexAIErrorCode.INVALID_CONTENT, + throw new AIError( + AIErrorCode.INVALID_CONTENT, 'Within a single message, FunctionResponse cannot be mixed with other type of Part in the request for sending chat message.', ); } if (!hasUserContent && !hasFunctionContent) { - throw new VertexAIError( - VertexAIErrorCode.INVALID_CONTENT, + throw new AIError( + AIErrorCode.INVALID_CONTENT, 'No Content is provided for sending chat message.', ); } From ea4c130f56f627ad08c5f456517e15711b29854f Mon Sep 17 00:00:00 2001 From: russellwheatley Date: Tue, 3 Jun 2025 10:44:14 +0100 Subject: [PATCH 07/85] request.ts --- packages/ai/lib/requests/request.ts | 78 +++++++++++++++++------------ 1 file changed, 46 insertions(+), 32 deletions(-) diff --git a/packages/ai/lib/requests/request.ts b/packages/ai/lib/requests/request.ts index e055094f90..77a1dd8691 100644 --- a/packages/ai/lib/requests/request.ts +++ b/packages/ai/lib/requests/request.ts @@ -15,8 +15,8 @@ * limitations under the License. */ import { Platform } from 'react-native'; -import { ErrorDetails, RequestOptions, VertexAIErrorCode } from '../types'; -import { VertexAIError } from '../errors'; +import { AIErrorCode, ErrorDetails, RequestOptions } from '../types'; +import { AIError } from '../errors'; import { ApiSettings } from '../types/internal'; import { DEFAULT_API_VERSION, @@ -26,11 +26,13 @@ import { PACKAGE_VERSION, } from '../constants'; import { logger } from '../logger'; +import { GoogleAIBackend, VertexAIBackend } from '../backend'; export enum Task { GENERATE_CONTENT = 'generateContent', STREAM_GENERATE_CONTENT = 'streamGenerateContent', COUNT_TOKENS = 'countTokens', + PREDICT = 'predict', } export class RequestUrl { @@ -58,29 +60,40 @@ export class RequestUrl { } return emulatorUrl; } + const url = new URL(this.baseUrl); // Throws if the URL is invalid + url.pathname = `/${this.apiVersion}/${this.modelPath}:${this.task}`; + url.search = this.queryParams.toString(); + return url.toString(); + } - const apiVersion = DEFAULT_API_VERSION; - const baseUrl = this.requestOptions?.baseUrl || DEFAULT_BASE_URL; - let url = `${baseUrl}/${apiVersion}`; - url += `/projects/${this.apiSettings.project}`; - url += `/locations/${this.apiSettings.location}`; - url += `/${this.model}`; - url += `:${this.task}`; - if (this.stream) { - url += '?alt=sse'; + private get baseUrl(): string { + return this.requestOptions?.baseUrl || DEFAULT_BASE_URL; + } + + private get apiVersion(): string { + return DEFAULT_API_VERSION; // TODO: allow user-set options if that feature becomes available + } + + private get modelPath(): string { + if (this.apiSettings.backend instanceof GoogleAIBackend) { + return `projects/${this.apiSettings.project}/${this.model}`; + } else if (this.apiSettings.backend instanceof VertexAIBackend) { + return `projects/${this.apiSettings.project}/locations/${this.apiSettings.backend.location}/${this.model}`; + } else { + throw new AIError( + AIErrorCode.ERROR, + `Invalid backend: ${JSON.stringify(this.apiSettings.backend)}`, + ); } - return url; } - /** - * If the model needs to be passed to the backend, it needs to - * include project and location path. - */ - get fullModelString(): string { - let modelString = `projects/${this.apiSettings.project}`; - modelString += `/locations/${this.apiSettings.location}`; - modelString += `/${this.model}`; - return modelString; + private get queryParams(): URLSearchParams { + const params = new URLSearchParams(); + if (this.stream) { + params.set('alt', 'sse'); + } + + return params; } } @@ -99,6 +112,9 @@ export async function getHeaders(url: RequestUrl): Promise { headers.append('Content-Type', 'application/json'); headers.append('x-goog-api-client', getClientHeaders()); headers.append('x-goog-api-key', url.apiSettings.apiKey); + if (url.apiSettings.automaticDataCollectionEnabled) { + headers.append('X-Firebase-Appid', url.apiSettings.appId); + } if (url.apiSettings.getAppCheckToken) { let appCheckToken; @@ -154,6 +170,7 @@ export async function makeRequest( let fetchTimeoutId: string | number | NodeJS.Timeout | undefined; try { const request = await constructRequest(model, task, apiSettings, stream, body, requestOptions); + // Timeout is 180s by default const timeoutMillis = requestOptions?.timeout != null && requestOptions.timeout >= 0 ? requestOptions.timeout @@ -192,9 +209,9 @@ export async function makeRequest( ), ) ) { - throw new VertexAIError( - VertexAIErrorCode.API_NOT_ENABLED, - `The Vertex AI in Firebase SDK requires the Vertex AI in Firebase ` + + throw new AIError( + AIErrorCode.API_NOT_ENABLED, + `The Firebase AI SDK requires the Firebase AI ` + `API ('firebasevertexai.googleapis.com') to be enabled in your ` + `Firebase project. Enable this API by visiting the Firebase Console ` + `at https://console.firebase.google.com/project/${url.apiSettings.project}/genai/ ` + @@ -208,8 +225,8 @@ export async function makeRequest( }, ); } - throw new VertexAIError( - VertexAIErrorCode.FETCH_ERROR, + throw new AIError( + AIErrorCode.FETCH_ERROR, `Error fetching from ${url}: [${response.status} ${response.statusText}] ${message}`, { status: response.status, @@ -221,14 +238,11 @@ export async function makeRequest( } catch (e) { let err = e as Error; if ( - (e as VertexAIError).code !== VertexAIErrorCode.FETCH_ERROR && - (e as VertexAIError).code !== VertexAIErrorCode.API_NOT_ENABLED && + (e as AIError).code !== AIErrorCode.FETCH_ERROR && + (e as AIError).code !== AIErrorCode.API_NOT_ENABLED && e instanceof Error ) { - err = new VertexAIError( - VertexAIErrorCode.ERROR, - `Error fetching from ${url.toString()}: ${e.message}`, - ); + err = new AIError(AIErrorCode.ERROR, `Error fetching from ${url.toString()}: ${e.message}`); err.stack = e.stack; } From bc0ab3b5308ed2bb654da3c70a9739adb56e8a2e Mon Sep 17 00:00:00 2001 From: russellwheatley Date: Tue, 3 Jun 2025 10:50:28 +0100 Subject: [PATCH 08/85] response-helpers.ts --- packages/ai/lib/requests/response-helpers.ts | 84 +++++++++++++++++--- packages/ai/lib/types/responses.ts | 11 ++- 2 files changed, 82 insertions(+), 13 deletions(-) diff --git a/packages/ai/lib/requests/response-helpers.ts b/packages/ai/lib/requests/response-helpers.ts index c7abc9d923..4fdb2362bd 100644 --- a/packages/ai/lib/requests/response-helpers.ts +++ b/packages/ai/lib/requests/response-helpers.ts @@ -21,9 +21,10 @@ import { FunctionCall, GenerateContentCandidate, GenerateContentResponse, - VertexAIErrorCode, + AIErrorCode, + InlineDataPart, } from '../types'; -import { VertexAIError } from '../errors'; +import { AIError } from '../errors'; import { logger } from '../logger'; /** @@ -62,8 +63,8 @@ export function addHelpers(response: GenerateContentResponse): EnhancedGenerateC ); } if (hadBadFinishReason(response.candidates[0]!)) { - throw new VertexAIError( - VertexAIErrorCode.RESPONSE_ERROR, + throw new AIError( + AIErrorCode.RESPONSE_ERROR, `Response error: ${formatBlockErrorMessage( response, )}. Response body stored in error.response`, @@ -74,8 +75,8 @@ export function addHelpers(response: GenerateContentResponse): EnhancedGenerateC } return getText(response); } else if (response.promptFeedback) { - throw new VertexAIError( - VertexAIErrorCode.RESPONSE_ERROR, + throw new AIError( + AIErrorCode.RESPONSE_ERROR, `Text not available. ${formatBlockErrorMessage(response)}`, { response, @@ -84,6 +85,40 @@ export function addHelpers(response: GenerateContentResponse): EnhancedGenerateC } return ''; }; + (response as EnhancedGenerateContentResponse).inlineDataParts = (): + | InlineDataPart[] + | undefined => { + if (response.candidates && response.candidates.length > 0) { + if (response.candidates.length > 1) { + logger.warn( + `This response had ${response.candidates.length} ` + + `candidates. Returning data from the first candidate only. ` + + `Access response.candidates directly to use the other candidates.`, + ); + } + if (hadBadFinishReason(response.candidates[0]!)) { + throw new AIError( + AIErrorCode.RESPONSE_ERROR, + `Response error: ${formatBlockErrorMessage( + response, + )}. Response body stored in error.response`, + { + response, + }, + ); + } + return getInlineDataParts(response); + } else if (response.promptFeedback) { + throw new AIError( + AIErrorCode.RESPONSE_ERROR, + `Data not available. ${formatBlockErrorMessage(response)}`, + { + response, + }, + ); + } + return undefined; + }; (response as EnhancedGenerateContentResponse).functionCalls = () => { if (response.candidates && response.candidates.length > 0) { if (response.candidates.length > 1) { @@ -94,8 +129,8 @@ export function addHelpers(response: GenerateContentResponse): EnhancedGenerateC ); } if (hadBadFinishReason(response.candidates[0]!)) { - throw new VertexAIError( - VertexAIErrorCode.RESPONSE_ERROR, + throw new AIError( + AIErrorCode.RESPONSE_ERROR, `Response error: ${formatBlockErrorMessage( response, )}. Response body stored in error.response`, @@ -106,8 +141,8 @@ export function addHelpers(response: GenerateContentResponse): EnhancedGenerateC } return getFunctionCalls(response); } else if (response.promptFeedback) { - throw new VertexAIError( - VertexAIErrorCode.RESPONSE_ERROR, + throw new AIError( + AIErrorCode.RESPONSE_ERROR, `Function call not available. ${formatBlockErrorMessage(response)}`, { response, @@ -125,7 +160,7 @@ export function addHelpers(response: GenerateContentResponse): EnhancedGenerateC export function getText(response: GenerateContentResponse): string { const textStrings = []; if (response.candidates?.[0]?.content?.parts) { - for (const part of response.candidates?.[0].content?.parts) { + for (const part of response.candidates?.[0]?.content?.parts) { if (part.text) { textStrings.push(part.text); } @@ -139,7 +174,7 @@ export function getText(response: GenerateContentResponse): string { } /** - * Returns {@link FunctionCall}s associated with first candidate. + * Returns {@link FunctionCall}s associated with first candidate. */ export function getFunctionCalls(response: GenerateContentResponse): FunctionCall[] | undefined { const functionCalls: FunctionCall[] = []; @@ -157,6 +192,31 @@ export function getFunctionCalls(response: GenerateContentResponse): FunctionCal } } +/** + * Returns {@link InlineDataPart}s in the first candidate if present. + * + * @internal + */ +export function getInlineDataParts( + response: GenerateContentResponse, +): InlineDataPart[] | undefined { + const data: InlineDataPart[] = []; + + if (response.candidates?.[0]?.content?.parts) { + for (const part of response.candidates?.[0]?.content?.parts) { + if (part.inlineData) { + data.push(part); + } + } + } + + if (data.length > 0) { + return data; + } else { + return undefined; + } +} + const badFinishReasons = [FinishReason.RECITATION, FinishReason.SAFETY]; function hadBadFinishReason(candidate: GenerateContentCandidate): boolean { diff --git a/packages/ai/lib/types/responses.ts b/packages/ai/lib/types/responses.ts index 013391e98b..d162aa8015 100644 --- a/packages/ai/lib/types/responses.ts +++ b/packages/ai/lib/types/responses.ts @@ -15,7 +15,7 @@ * limitations under the License. */ -import { Content, FunctionCall } from './content'; +import { Content, FunctionCall, InlineDataPart } from './content'; import { BlockReason, FinishReason, HarmCategory, HarmProbability, HarmSeverity } from './enums'; /** @@ -51,6 +51,15 @@ export interface EnhancedGenerateContentResponse extends GenerateContentResponse * Throws if the prompt or candidate was blocked. */ text: () => string; + /** + * Aggregates and returns all {@link InlineDataPart}s from the {@link GenerateContentResponse}'s + * first candidate. + * + * @returns An array of {@link InlineDataPart}s containing data from the response, if available. + * + * @throws If the prompt or candidate was blocked. + */ + inlineDataParts: () => InlineDataPart[] | undefined; functionCalls: () => FunctionCall[] | undefined; } From 86ddd948e8107642c497911c4a4f4bd40f115a6b Mon Sep 17 00:00:00 2001 From: russellwheatley Date: Tue, 3 Jun 2025 11:02:29 +0100 Subject: [PATCH 09/85] schema-builder.ts --- packages/ai/lib/requests/schema-builder.ts | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/packages/ai/lib/requests/schema-builder.ts b/packages/ai/lib/requests/schema-builder.ts index 92003a0950..21c5605cb7 100644 --- a/packages/ai/lib/requests/schema-builder.ts +++ b/packages/ai/lib/requests/schema-builder.ts @@ -15,8 +15,8 @@ * limitations under the License. */ -import { VertexAIError } from '../errors'; -import { VertexAIErrorCode } from '../types'; +import { AIError } from '../errors'; +import { AIErrorCode } from '../types'; import { SchemaInterface, SchemaType, @@ -49,6 +49,12 @@ export abstract class Schema implements SchemaInterface { format?: string; /** Optional. The description of the property. */ description?: string; + /** Optional. The items of the property. */ + items?: SchemaInterface; + /** The minimum number of items (elements) in a schema of type {@link SchemaType.ARRAY}. */ + minItems?: number; + /** The maximum number of items (elements) in a schema of type {@link SchemaType.ARRAY}. */ + maxItems?: number; /** Optional. Whether the property is nullable. Defaults to false. */ nullable: boolean; /** Optional. The example of the property. */ @@ -257,8 +263,8 @@ export class ObjectSchema extends Schema { if (this.optionalProperties) { for (const propertyKey of this.optionalProperties) { if (!this.properties.hasOwnProperty(propertyKey)) { - throw new VertexAIError( - VertexAIErrorCode.INVALID_SCHEMA, + throw new AIError( + AIErrorCode.INVALID_SCHEMA, `Property "${propertyKey}" specified in "optionalProperties" does not exist.`, ); } From 68f66695ed598ca2d9d1f5e2f744de1c2f83e7ff Mon Sep 17 00:00:00 2001 From: russellwheatley Date: Tue, 3 Jun 2025 11:29:44 +0100 Subject: [PATCH 10/85] stream-reader.ts --- packages/ai/lib/googleai-mappers.ts | 218 ++++++++++++++++++++++ packages/ai/lib/requests/stream-reader.ts | 52 ++++-- packages/ai/lib/types/enums.ts | 7 + packages/ai/lib/types/googleai.ts | 70 +++++++ 4 files changed, 330 insertions(+), 17 deletions(-) create mode 100644 packages/ai/lib/googleai-mappers.ts create mode 100644 packages/ai/lib/types/googleai.ts diff --git a/packages/ai/lib/googleai-mappers.ts b/packages/ai/lib/googleai-mappers.ts new file mode 100644 index 0000000000..2f6724b8d8 --- /dev/null +++ b/packages/ai/lib/googleai-mappers.ts @@ -0,0 +1,218 @@ +/** + * @license + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { AIError } from './errors'; +import { logger } from './logger'; +import { + CitationMetadata, + CountTokensRequest, + GenerateContentCandidate, + GenerateContentRequest, + GenerateContentResponse, + HarmSeverity, + InlineDataPart, + PromptFeedback, + SafetyRating, + AIErrorCode, +} from './types'; +import { + GoogleAIGenerateContentResponse, + GoogleAIGenerateContentCandidate, + GoogleAICountTokensRequest, +} from './types/googleai'; + +/** + * This SDK supports both the Vertex AI Gemini API and the Gemini Developer API (using Google AI). + * The public API prioritizes the format used by the Vertex AI Gemini API. + * We avoid having two sets of types by translating requests and responses between the two API formats. + * This translation allows developers to switch between the Vertex AI Gemini API and the Gemini Developer API + * with minimal code changes. + * + * In here are functions that map requests and responses between the two API formats. + * Requests in the Vertex AI format are mapped to the Google AI format before being sent. + * Responses from the Google AI backend are mapped back to the Vertex AI format before being returned to the user. + */ + +/** + * Maps a Vertex AI {@link GenerateContentRequest} to a format that can be sent to Google AI. + * + * @param generateContentRequest The {@link GenerateContentRequest} to map. + * @returns A {@link GenerateContentResponse} that conforms to the Google AI format. + * + * @throws If the request contains properties that are unsupported by Google AI. + * + * @internal + */ +export function mapGenerateContentRequest( + generateContentRequest: GenerateContentRequest, +): GenerateContentRequest { + generateContentRequest.safetySettings?.forEach(safetySetting => { + if (safetySetting.method) { + throw new AIError( + AIErrorCode.UNSUPPORTED, + 'SafetySetting.method is not supported in the the Gemini Developer API. Please remove this property.', + ); + } + }); + + if (generateContentRequest.generationConfig?.topK) { + const roundedTopK = Math.round(generateContentRequest.generationConfig.topK); + + if (roundedTopK !== generateContentRequest.generationConfig.topK) { + logger.warn( + 'topK in GenerationConfig has been rounded to the nearest integer to match the format for requests to the Gemini Developer API.', + ); + generateContentRequest.generationConfig.topK = roundedTopK; + } + } + + return generateContentRequest; +} + +/** + * Maps a {@link GenerateContentResponse} from Google AI to the format of the + * {@link GenerateContentResponse} that we get from VertexAI that is exposed in the public API. + * + * @param googleAIResponse The {@link GenerateContentResponse} from Google AI. + * @returns A {@link GenerateContentResponse} that conforms to the public API's format. + * + * @internal + */ +export function mapGenerateContentResponse( + googleAIResponse: GoogleAIGenerateContentResponse, +): GenerateContentResponse { + const generateContentResponse = { + candidates: googleAIResponse.candidates + ? mapGenerateContentCandidates(googleAIResponse.candidates) + : undefined, + prompt: googleAIResponse.promptFeedback + ? mapPromptFeedback(googleAIResponse.promptFeedback) + : undefined, + usageMetadata: googleAIResponse.usageMetadata, + }; + + return generateContentResponse; +} + +/** + * Maps a Vertex AI {@link CountTokensRequest} to a format that can be sent to Google AI. + * + * @param countTokensRequest The {@link CountTokensRequest} to map. + * @param model The model to count tokens with. + * @returns A {@link CountTokensRequest} that conforms to the Google AI format. + * + * @internal + */ +export function mapCountTokensRequest( + countTokensRequest: CountTokensRequest, + model: string, +): GoogleAICountTokensRequest { + const mappedCountTokensRequest: GoogleAICountTokensRequest = { + generateContentRequest: { + model, + ...countTokensRequest, + }, + }; + + return mappedCountTokensRequest; +} + +/** + * Maps a Google AI {@link GoogleAIGenerateContentCandidate} to a format that conforms + * to the Vertex AI API format. + * + * @param candidates The {@link GoogleAIGenerateContentCandidate} to map. + * @returns A {@link GenerateContentCandidate} that conforms to the Vertex AI format. + * + * @throws If any {@link Part} in the candidates has a `videoMetadata` property. + * + * @internal + */ +export function mapGenerateContentCandidates( + candidates: GoogleAIGenerateContentCandidate[], +): GenerateContentCandidate[] { + const mappedCandidates: GenerateContentCandidate[] = []; + let mappedSafetyRatings: SafetyRating[]; + if (mappedCandidates) { + candidates.forEach(candidate => { + // Map citationSources to citations. + let citationMetadata: CitationMetadata | undefined; + if (candidate.citationMetadata) { + citationMetadata = { + citations: candidate.citationMetadata.citationSources, + }; + } + + // Assign missing candidate SafetyRatings properties to their defaults if undefined. + if (candidate.safetyRatings) { + mappedSafetyRatings = candidate.safetyRatings.map(safetyRating => { + return { + ...safetyRating, + severity: safetyRating.severity ?? HarmSeverity.HARM_SEVERITY_UNSUPPORTED, + probabilityScore: safetyRating.probabilityScore ?? 0, + severityScore: safetyRating.severityScore ?? 0, + }; + }); + } + + // videoMetadata is not supported. + // Throw early since developers may send a long video as input and only expect to pay + // for inference on a small portion of the video. + if (candidate.content?.parts.some(part => (part as InlineDataPart)?.videoMetadata)) { + throw new AIError( + AIErrorCode.UNSUPPORTED, + 'Part.videoMetadata is not supported in the Gemini Developer API. Please remove this property.', + ); + } + + const mappedCandidate = { + index: candidate.index, + content: candidate.content, + finishReason: candidate.finishReason, + finishMessage: candidate.finishMessage, + safetyRatings: mappedSafetyRatings, + citationMetadata, + groundingMetadata: candidate.groundingMetadata, + }; + mappedCandidates.push(mappedCandidate); + }); + } + + return mappedCandidates; +} + +export function mapPromptFeedback(promptFeedback: PromptFeedback): PromptFeedback { + // Assign missing SafetyRating properties to their defaults if undefined. + const mappedSafetyRatings: SafetyRating[] = []; + promptFeedback.safetyRatings.forEach(safetyRating => { + mappedSafetyRatings.push({ + category: safetyRating.category, + probability: safetyRating.probability, + severity: safetyRating.severity ?? HarmSeverity.HARM_SEVERITY_UNSUPPORTED, + probabilityScore: safetyRating.probabilityScore ?? 0, + severityScore: safetyRating.severityScore ?? 0, + blocked: safetyRating.blocked, + }); + }); + + const mappedPromptFeedback: PromptFeedback = { + blockReason: promptFeedback.blockReason, + safetyRatings: mappedSafetyRatings, + blockReasonMessage: promptFeedback.blockReasonMessage, + }; + return mappedPromptFeedback; +} diff --git a/packages/ai/lib/requests/stream-reader.ts b/packages/ai/lib/requests/stream-reader.ts index d24f6d44bf..6fea165c26 100644 --- a/packages/ai/lib/requests/stream-reader.ts +++ b/packages/ai/lib/requests/stream-reader.ts @@ -22,10 +22,14 @@ import { GenerateContentResponse, GenerateContentStreamResult, Part, - VertexAIErrorCode, + AIErrorCode, } from '../types'; -import { VertexAIError } from '../errors'; +import { AIError } from '../errors'; import { createEnhancedContentResponse } from './response-helpers'; +import { ApiSettings } from '../types/internal'; +import { BackendType } from '../public-types'; +import * as GoogleAIMapper from '../googleai-mappers'; +import { GoogleAIGenerateContentResponse } from '../types/googleai'; const responseLineRE = /^data\: (.*)(?:\n\n|\r\r|\r\n\r\n)/; @@ -37,7 +41,10 @@ const responseLineRE = /^data\: (.*)(?:\n\n|\r\r|\r\n\r\n)/; * * @param response - Response from a fetch call */ -export function processStream(response: Response): GenerateContentStreamResult { +export function processStream( + response: Response, + apiSettings: ApiSettings, +): GenerateContentStreamResult { const inputStream = new ReadableStream({ async start(controller) { const reader = response.body!.getReader(); @@ -56,28 +63,36 @@ export function processStream(response: Response): GenerateContentStreamResult { const responseStream = getResponseStream(inputStream); const [stream1, stream2] = responseStream.tee(); return { - stream: generateResponseSequence(stream1), - response: getResponsePromise(stream2), + stream: generateResponseSequence(stream1, apiSettings), + response: getResponsePromise(stream2, apiSettings), }; } async function getResponsePromise( stream: ReadableStream, + apiSettings: ApiSettings, ): Promise { const allResponses: GenerateContentResponse[] = []; const reader = stream.getReader(); while (true) { const { done, value } = await reader.read(); if (done) { - const enhancedResponse = createEnhancedContentResponse(aggregateResponses(allResponses)); - return enhancedResponse; + let generateContentResponse = aggregateResponses(allResponses); + if (apiSettings.backend.backendType === BackendType.GOOGLE_AI) { + generateContentResponse = GoogleAIMapper.mapGenerateContentResponse( + generateContentResponse as GoogleAIGenerateContentResponse, + ); + } + return createEnhancedContentResponse(generateContentResponse); } + allResponses.push(value); } } async function* generateResponseSequence( stream: ReadableStream, + apiSettings: ApiSettings, ): AsyncGenerator { const reader = stream.getReader(); while (true) { @@ -86,7 +101,15 @@ async function* generateResponseSequence( break; } - const enhancedResponse = createEnhancedContentResponse(value); + let enhancedResponse: EnhancedGenerateContentResponse; + if (apiSettings.backend.backendType === BackendType.GOOGLE_AI) { + enhancedResponse = createEnhancedContentResponse( + GoogleAIMapper.mapGenerateContentResponse(value as GoogleAIGenerateContentResponse), + ); + } else { + enhancedResponse = createEnhancedContentResponse(value); + } + yield enhancedResponse; } } @@ -106,9 +129,7 @@ export function getResponseStream(inputStream: ReadableStream): Reada return reader.read().then(({ value, done }) => { if (done) { if (currentText.trim()) { - controller.error( - new VertexAIError(VertexAIErrorCode.PARSE_FAILED, 'Failed to parse stream'), - ); + controller.error(new AIError(AIErrorCode.PARSE_FAILED, 'Failed to parse stream')); return; } controller.close(); @@ -123,10 +144,7 @@ export function getResponseStream(inputStream: ReadableStream): Reada parsedResponse = JSON.parse(match[1]!); } catch (_) { controller.error( - new VertexAIError( - VertexAIErrorCode.PARSE_FAILED, - `Error parsing JSON response: "${match[1]}`, - ), + new AIError(AIErrorCode.PARSE_FAILED, `Error parsing JSON response: "${match[1]}`), ); return; } @@ -197,8 +215,8 @@ export function aggregateResponses(responses: GenerateContentResponse[]): Genera newPart.functionCall = part.functionCall; } if (Object.keys(newPart).length === 0) { - throw new VertexAIError( - VertexAIErrorCode.INVALID_CONTENT, + throw new AIError( + AIErrorCode.INVALID_CONTENT, 'Part should have at least one property, but there are none. This is likely caused ' + 'by a malformed response from the backend.', ); diff --git a/packages/ai/lib/types/enums.ts b/packages/ai/lib/types/enums.ts index 010aff903a..886faba7d3 100644 --- a/packages/ai/lib/types/enums.ts +++ b/packages/ai/lib/types/enums.ts @@ -91,6 +91,13 @@ export enum HarmSeverity { HARM_SEVERITY_MEDIUM = 'HARM_SEVERITY_MEDIUM', // High level of harm severity. HARM_SEVERITY_HIGH = 'HARM_SEVERITY_HIGH', + /** + * Harm severity is not supported. + * + * @remarks + * The GoogleAI backend does not support `HarmSeverity`, so this value is used as a fallback. + */ + HARM_SEVERITY_UNSUPPORTED = 'HARM_SEVERITY_UNSUPPORTED', } /** diff --git a/packages/ai/lib/types/googleai.ts b/packages/ai/lib/types/googleai.ts new file mode 100644 index 0000000000..4c7dfe30bb --- /dev/null +++ b/packages/ai/lib/types/googleai.ts @@ -0,0 +1,70 @@ +/** + * @license + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { + Tool, + GenerationConfig, + Citation, + FinishReason, + GroundingMetadata, + PromptFeedback, + SafetyRating, + UsageMetadata, +} from '../public-types'; +import { Content, Part } from './content'; + +/** + * @internal + */ +export interface GoogleAICountTokensRequest { + generateContentRequest: { + model: string; // 'models/model-name' + contents: Content[]; + systemInstruction?: string | Part | Content; + tools?: Tool[]; + generationConfig?: GenerationConfig; + }; +} + +/** + * @internal + */ +export interface GoogleAIGenerateContentResponse { + candidates?: GoogleAIGenerateContentCandidate[]; + promptFeedback?: PromptFeedback; + usageMetadata?: UsageMetadata; +} + +/** + * @internal + */ +export interface GoogleAIGenerateContentCandidate { + index: number; + content: Content; + finishReason?: FinishReason; + finishMessage?: string; + safetyRatings?: SafetyRating[]; + citationMetadata?: GoogleAICitationMetadata; + groundingMetadata?: GroundingMetadata; +} + +/** + * @internal + */ +export interface GoogleAICitationMetadata { + citationSources: Citation[]; // Maps to `citations` +} From dc3cacb9fb91dbb942690f0e077798fcca6db1a2 Mon Sep 17 00:00:00 2001 From: russellwheatley Date: Tue, 3 Jun 2025 11:36:00 +0100 Subject: [PATCH 11/85] chat-session-helpers.ts --- .../ai/lib/methods/chat-session-helpers.ts | 31 +++++++++---------- 1 file changed, 14 insertions(+), 17 deletions(-) diff --git a/packages/ai/lib/methods/chat-session-helpers.ts b/packages/ai/lib/methods/chat-session-helpers.ts index 4b9bb56db0..ea8cd826b9 100644 --- a/packages/ai/lib/methods/chat-session-helpers.ts +++ b/packages/ai/lib/methods/chat-session-helpers.ts @@ -15,8 +15,8 @@ * limitations under the License. */ -import { Content, POSSIBLE_ROLES, Part, Role, VertexAIErrorCode } from '../types'; -import { VertexAIError } from '../errors'; +import { Content, POSSIBLE_ROLES, Part, Role, AIErrorCode } from '../types'; +import { AIError } from '../errors'; // https://ai.google.dev/api/rest/v1beta/Content#part @@ -48,14 +48,14 @@ export function validateChatHistory(history: Content[]): void { for (const currContent of history) { const { role, parts } = currContent; if (!prevContent && role !== 'user') { - throw new VertexAIError( - VertexAIErrorCode.INVALID_CONTENT, + throw new AIError( + AIErrorCode.INVALID_CONTENT, `First Content should be with role 'user', got ${role}`, ); } if (!POSSIBLE_ROLES.includes(role)) { - throw new VertexAIError( - VertexAIErrorCode.INVALID_CONTENT, + throw new AIError( + AIErrorCode.INVALID_CONTENT, `Each item should include role field. Got ${role} but valid roles are: ${JSON.stringify( POSSIBLE_ROLES, )}`, @@ -63,17 +63,14 @@ export function validateChatHistory(history: Content[]): void { } if (!Array.isArray(parts)) { - throw new VertexAIError( - VertexAIErrorCode.INVALID_CONTENT, + throw new AIError( + AIErrorCode.INVALID_CONTENT, `Content should have 'parts' but property with an array of Parts`, ); } if (parts.length === 0) { - throw new VertexAIError( - VertexAIErrorCode.INVALID_CONTENT, - `Each Content should have at least one part`, - ); + throw new AIError(AIErrorCode.INVALID_CONTENT, `Each Content should have at least one part`); } const countFields: Record = { @@ -93,8 +90,8 @@ export function validateChatHistory(history: Content[]): void { const validParts = VALID_PARTS_PER_ROLE[role]; for (const key of VALID_PART_FIELDS) { if (!validParts.includes(key) && countFields[key] > 0) { - throw new VertexAIError( - VertexAIErrorCode.INVALID_CONTENT, + throw new AIError( + AIErrorCode.INVALID_CONTENT, `Content with role '${role}' can't contain '${key}' part`, ); } @@ -103,9 +100,9 @@ export function validateChatHistory(history: Content[]): void { if (prevContent) { const validPreviousContentRoles = VALID_PREVIOUS_CONTENT_ROLES[role]; if (!validPreviousContentRoles.includes(prevContent.role)) { - throw new VertexAIError( - VertexAIErrorCode.INVALID_CONTENT, - `Content with role '${role} can't follow '${ + throw new AIError( + AIErrorCode.INVALID_CONTENT, + `Content with role '${role}' can't follow '${ prevContent.role }'. Valid previous roles: ${JSON.stringify(VALID_PREVIOUS_CONTENT_ROLES)}`, ); From 0582d8fd43f8d111322bc5f1f3e51c3a1792466d Mon Sep 17 00:00:00 2001 From: russellwheatley Date: Tue, 3 Jun 2025 11:38:19 +0100 Subject: [PATCH 12/85] chat-session.ts --- packages/ai/lib/methods/chat-session.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/ai/lib/methods/chat-session.ts b/packages/ai/lib/methods/chat-session.ts index e3e9cf905f..6bbb6f526c 100644 --- a/packages/ai/lib/methods/chat-session.ts +++ b/packages/ai/lib/methods/chat-session.ts @@ -73,7 +73,7 @@ export class ChatSession { /** * Sends a chat message and receives a non-streaming - * {@link GenerateContentResult} + * {@link GenerateContentResult} */ async sendMessage(request: string | Array): Promise { await this._sendPromise; @@ -117,7 +117,7 @@ export class ChatSession { /** * Sends a chat message and receives the response as a - * {@link GenerateContentStreamResult} containing an iterable stream + * {@link GenerateContentStreamResult} containing an iterable stream * and a response promise. */ async sendMessageStream( From d65fe3d05e71b57a709c040235eccb24149d85ed Mon Sep 17 00:00:00 2001 From: russellwheatley Date: Tue, 3 Jun 2025 11:49:20 +0100 Subject: [PATCH 13/85] count-tokens.ts --- packages/ai/lib/methods/count-tokens.ts | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/packages/ai/lib/methods/count-tokens.ts b/packages/ai/lib/methods/count-tokens.ts index 10d41cffa8..e5711642fc 100644 --- a/packages/ai/lib/methods/count-tokens.ts +++ b/packages/ai/lib/methods/count-tokens.ts @@ -18,6 +18,8 @@ import { CountTokensRequest, CountTokensResponse, RequestOptions } from '../types'; import { Task, makeRequest } from '../requests/request'; import { ApiSettings } from '../types/internal'; +import { BackendType } from '../public-types'; +import * as GoogleAIMapper from '../googleai-mappers'; export async function countTokens( apiSettings: ApiSettings, @@ -25,12 +27,19 @@ export async function countTokens( params: CountTokensRequest, requestOptions?: RequestOptions, ): Promise { + let body: string = ''; + if (apiSettings.backend.backendType === BackendType.GOOGLE_AI) { + const mappedParams = GoogleAIMapper.mapCountTokensRequest(params, model); + body = JSON.stringify(mappedParams); + } else { + body = JSON.stringify(params); + } const response = await makeRequest( model, Task.COUNT_TOKENS, apiSettings, false, - JSON.stringify(params), + body, requestOptions, ); return response.json(); From 70bf812c43b4fab062fa82749ee73108991c3bad Mon Sep 17 00:00:00 2001 From: russellwheatley Date: Tue, 3 Jun 2025 11:51:10 +0100 Subject: [PATCH 14/85] generate-content.ts --- packages/ai/lib/methods/generate-content.ts | 26 ++++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/packages/ai/lib/methods/generate-content.ts b/packages/ai/lib/methods/generate-content.ts index 6d1a6ecb27..72a081cc95 100644 --- a/packages/ai/lib/methods/generate-content.ts +++ b/packages/ai/lib/methods/generate-content.ts @@ -26,6 +26,8 @@ import { Task, makeRequest } from '../requests/request'; import { createEnhancedContentResponse } from '../requests/response-helpers'; import { processStream } from '../requests/stream-reader'; import { ApiSettings } from '../types/internal'; +import { BackendType } from '../public-types'; +import * as GoogleAIMapper from '../googleai-mappers'; export async function generateContentStream( apiSettings: ApiSettings, @@ -33,6 +35,9 @@ export async function generateContentStream( params: GenerateContentRequest, requestOptions?: RequestOptions, ): Promise { + if (apiSettings.backend.backendType === BackendType.GOOGLE_AI) { + params = GoogleAIMapper.mapGenerateContentRequest(params); + } const response = await makeRequest( model, Task.STREAM_GENERATE_CONTENT, @@ -41,7 +46,7 @@ export async function generateContentStream( JSON.stringify(params), requestOptions, ); - return processStream(response); + return processStream(response, apiSettings); } export async function generateContent( @@ -50,6 +55,9 @@ export async function generateContent( params: GenerateContentRequest, requestOptions?: RequestOptions, ): Promise { + if (apiSettings.backend.backendType === BackendType.GOOGLE_AI) { + params = GoogleAIMapper.mapGenerateContentRequest(params); + } const response = await makeRequest( model, Task.GENERATE_CONTENT, @@ -58,9 +66,21 @@ export async function generateContent( JSON.stringify(params), requestOptions, ); - const responseJson: GenerateContentResponse = await response.json(); - const enhancedResponse = createEnhancedContentResponse(responseJson); + const generateContentResponse = await processGenerateContentResponse(response, apiSettings); + const enhancedResponse = createEnhancedContentResponse(generateContentResponse); return { response: enhancedResponse, }; } + +async function processGenerateContentResponse( + response: Response, + apiSettings: ApiSettings, +): Promise { + const responseJson = await response.json(); + if (apiSettings.backend.backendType === BackendType.GOOGLE_AI) { + return GoogleAIMapper.mapGenerateContentResponse(responseJson); + } else { + return responseJson; + } +} From fc7c77e4cf2a4b0d9cb8ab7e46d8743f26cf862d Mon Sep 17 00:00:00 2001 From: russellwheatley Date: Tue, 3 Jun 2025 11:53:23 +0100 Subject: [PATCH 15/85] models index --- packages/ai/lib/models/index.ts | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 packages/ai/lib/models/index.ts diff --git a/packages/ai/lib/models/index.ts b/packages/ai/lib/models/index.ts new file mode 100644 index 0000000000..fcfba15507 --- /dev/null +++ b/packages/ai/lib/models/index.ts @@ -0,0 +1,2 @@ +export * from './ai-model'; +export * from './generative-model'; From 64ae76336b3285c2386827c1982ea6fe25586374 Mon Sep 17 00:00:00 2001 From: russellwheatley Date: Tue, 3 Jun 2025 12:07:57 +0100 Subject: [PATCH 16/85] enums.ts --- packages/ai/lib/types/enums.ts | 203 ++++++++++++++++++++++++++------- 1 file changed, 164 insertions(+), 39 deletions(-) diff --git a/packages/ai/lib/types/enums.ts b/packages/ai/lib/types/enums.ts index 886faba7d3..035d26703e 100644 --- a/packages/ai/lib/types/enums.ts +++ b/packages/ai/lib/types/enums.ts @@ -43,23 +43,42 @@ export enum HarmCategory { * @public */ export enum HarmBlockThreshold { - // Content with NEGLIGIBLE will be allowed. + /** + * Content with `NEGLIGIBLE` will be allowed. + */ BLOCK_LOW_AND_ABOVE = 'BLOCK_LOW_AND_ABOVE', - // Content with NEGLIGIBLE and LOW will be allowed. + /** + * Content with `NEGLIGIBLE` and `LOW` will be allowed. + */ BLOCK_MEDIUM_AND_ABOVE = 'BLOCK_MEDIUM_AND_ABOVE', - // Content with NEGLIGIBLE, LOW, and MEDIUM will be allowed. + /** + * Content with `NEGLIGIBLE`, `LOW`, and `MEDIUM` will be allowed. + */ BLOCK_ONLY_HIGH = 'BLOCK_ONLY_HIGH', - // All content will be allowed. + /** + * All content will be allowed. + */ BLOCK_NONE = 'BLOCK_NONE', + /** + * All content will be allowed. This is the same as `BLOCK_NONE`, but the metadata corresponding + * to the {@link HarmCategory} will not be present in the response. + */ + OFF = 'OFF', } /** + * This property is not supported in the Gemini Developer API ({@link GoogleAIBackend}). + * * @public */ export enum HarmBlockMethod { - // The harm block method uses both probability and severity scores. + /** + * The harm block method uses both probability and severity scores. + */ SEVERITY = 'SEVERITY', - // The harm block method uses the probability score. + /** + * The harm block method uses the probability score. + */ PROBABILITY = 'PROBABILITY', } @@ -68,13 +87,21 @@ export enum HarmBlockMethod { * @public */ export enum HarmProbability { - // Content has a negligible chance of being unsafe. + /** + * Content has a negligible chance of being unsafe. + */ NEGLIGIBLE = 'NEGLIGIBLE', - // Content has a low chance of being unsafe. + /** + * Content has a low chance of being unsafe. + */ LOW = 'LOW', - // Content has a medium chance of being unsafe. + /** + * Content has a medium chance of being unsafe. + */ MEDIUM = 'MEDIUM', - // Content has a high chance of being unsafe. + /** + * Content has a high chance of being unsafe. + */ HIGH = 'HIGH', } @@ -83,13 +110,21 @@ export enum HarmProbability { * @public */ export enum HarmSeverity { - // Negligible level of harm severity. + /** + * Negligible level of harm severity. + */ HARM_SEVERITY_NEGLIGIBLE = 'HARM_SEVERITY_NEGLIGIBLE', - // Low level of harm severity. + /** + * Low level of harm severity. + */ HARM_SEVERITY_LOW = 'HARM_SEVERITY_LOW', - // Medium level of harm severity. + /** + * Medium level of harm severity. + */ HARM_SEVERITY_MEDIUM = 'HARM_SEVERITY_MEDIUM', - // High level of harm severity. + /** + * High level of harm severity. + */ HARM_SEVERITY_HIGH = 'HARM_SEVERITY_HIGH', /** * Harm severity is not supported. @@ -105,14 +140,22 @@ export enum HarmSeverity { * @public */ export enum BlockReason { - // The prompt was blocked because it contained terms from the terminology blocklist. - BLOCKLIST = 'BLOCKLIST', - // The prompt was blocked due to prohibited content. - PROHIBITED_CONTENT = 'PROHIBITED_CONTENT', - // Content was blocked by safety settings. + /** + * Content was blocked by safety settings. + */ SAFETY = 'SAFETY', - // Content was blocked, but the reason is uncategorized. + /** + * Content was blocked, but the reason is uncategorized. + */ OTHER = 'OTHER', + /** + * Content was blocked because it contained terms from the terminology blocklist. + */ + BLOCKLIST = 'BLOCKLIST', + /** + * Content was blocked due to prohibited content. + */ + PROHIBITED_CONTENT = 'PROHIBITED_CONTENT', } /** @@ -120,37 +163,119 @@ export enum BlockReason { * @public */ export enum FinishReason { - // Token generation was stopped because the response contained forbidden terms. - BLOCKLIST = 'BLOCKLIST', - // Token generation was stopped because the response contained potentially prohibited content. - PROHIBITED_CONTENT = 'PROHIBITED_CONTENT', - // Token generation was stopped because of Sensitive Personally Identifiable Information (SPII). - SPII = 'SPII', - // Natural stop point of the model or provided stop sequence. + /** + * Natural stop point of the model or provided stop sequence. + */ STOP = 'STOP', - // The maximum number of tokens as specified in the request was reached. + /** + * The maximum number of tokens as specified in the request was reached. + */ MAX_TOKENS = 'MAX_TOKENS', - // The candidate content was flagged for safety reasons. + /** + * The candidate content was flagged for safety reasons. + */ SAFETY = 'SAFETY', - // The candidate content was flagged for recitation reasons. + /** + * The candidate content was flagged for recitation reasons. + */ RECITATION = 'RECITATION', - // Unknown reason. + /** + * Unknown reason. + */ OTHER = 'OTHER', + /** + * The candidate content contained forbidden terms. + */ + BLOCKLIST = 'BLOCKLIST', + /** + * The candidate content potentially contained prohibited content. + */ + PROHIBITED_CONTENT = 'PROHIBITED_CONTENT', + /** + * The candidate content potentially contained Sensitive Personally Identifiable Information (SPII). + */ + SPII = 'SPII', + /** + * The function call generated by the model was invalid. + */ + MALFORMED_FUNCTION_CALL = 'MALFORMED_FUNCTION_CALL', } /** * @public */ export enum FunctionCallingMode { - // Default model behavior, model decides to predict either a function call - // or a natural language response. + /** + * Default model behavior; model decides to predict either a function call + * or a natural language response. + */ AUTO = 'AUTO', - // Model is constrained to always predicting a function call only. - // If "allowed_function_names" is set, the predicted function call will be - // limited to any one of "allowed_function_names", else the predicted - // function call will be any one of the provided "function_declarations". + /** + * Model is constrained to always predicting a function call only. + * If `allowed_function_names` is set, the predicted function call will be + * limited to any one of `allowed_function_names`, else the predicted + * function call will be any one of the provided `function_declarations`. + */ ANY = 'ANY', - // Model will not predict any function call. Model behavior is same as when - // not passing any function declarations. + /** + * Model will not predict any function call. Model behavior is same as when + * not passing any function declarations. + */ NONE = 'NONE', } + +/** + * Content part modality. + * @public + */ +export enum Modality { + /** + * Unspecified modality. + */ + MODALITY_UNSPECIFIED = 'MODALITY_UNSPECIFIED', + /** + * Plain text. + */ + TEXT = 'TEXT', + /** + * Image. + */ + IMAGE = 'IMAGE', + /** + * Video. + */ + VIDEO = 'VIDEO', + /** + * Audio. + */ + AUDIO = 'AUDIO', + /** + * Document (for example, PDF). + */ + DOCUMENT = 'DOCUMENT', +} + +/** + * Generation modalities to be returned in generation responses. + * + * @beta + */ +export const ResponseModality = { + /** + * Text. + * @beta + */ + TEXT: 'TEXT', + /** + * Image. + * @beta + */ + IMAGE: 'IMAGE', +} as const; + +/** + * Generation modalities to be returned in generation responses. + * + * @beta + */ +export type ResponseModality = (typeof ResponseModality)[keyof typeof ResponseModality]; From 3047eb19ebf541ddb530cab44af2974a21d50175 Mon Sep 17 00:00:00 2001 From: russellwheatley Date: Tue, 3 Jun 2025 12:08:41 +0100 Subject: [PATCH 17/85] error.ts --- packages/ai/lib/types/error.ts | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/packages/ai/lib/types/error.ts b/packages/ai/lib/types/error.ts index 38383fddc9..4fcc1ac483 100644 --- a/packages/ai/lib/types/error.ts +++ b/packages/ai/lib/types/error.ts @@ -50,7 +50,7 @@ export interface CustomErrorData { /** HTTP status text of the error response. */ statusText?: string; - /** Response from a {@link GenerateContentRequest} */ + /** Response from a {@link GenerateContentRequest} */ response?: GenerateContentResponse; /** Optional additional details about the error. */ @@ -58,7 +58,7 @@ export interface CustomErrorData { } /** - * Standardized error codes that {@link VertexAIError} can have. + * Standardized error codes that {@link AIError} can have. * * @public */ @@ -99,6 +99,6 @@ export const enum AIErrorCode { /** An error occurred while parsing. */ PARSE_FAILED = 'parse-failed', - /** An error occured due an attempt to use an unsupported feature. */ + /** An error occurred due an attempt to use an unsupported feature. */ UNSUPPORTED = 'unsupported', } From 2febfdc9458404365b65f88519c9d4bc03011aca Mon Sep 17 00:00:00 2001 From: russellwheatley Date: Tue, 3 Jun 2025 12:24:08 +0100 Subject: [PATCH 18/85] chore: add googleai to exports --- packages/ai/lib/types/index.ts | 1 + 1 file changed, 1 insertion(+) diff --git a/packages/ai/lib/types/index.ts b/packages/ai/lib/types/index.ts index 85133aa07c..6d77a4a935 100644 --- a/packages/ai/lib/types/index.ts +++ b/packages/ai/lib/types/index.ts @@ -21,3 +21,4 @@ export * from './requests'; export * from './responses'; export * from './error'; export * from './schema'; +export * from './googleai'; From 68421efa1acf7604d5879f73b7cf838ea4bc8157 Mon Sep 17 00:00:00 2001 From: russellwheatley Date: Tue, 3 Jun 2025 12:25:36 +0100 Subject: [PATCH 19/85] request.ts types --- packages/ai/lib/types/requests.ts | 49 ++++++++++++++++++++++++++----- 1 file changed, 42 insertions(+), 7 deletions(-) diff --git a/packages/ai/lib/types/requests.ts b/packages/ai/lib/types/requests.ts index 708a55a11c..53b35b1196 100644 --- a/packages/ai/lib/types/requests.ts +++ b/packages/ai/lib/types/requests.ts @@ -17,7 +17,13 @@ import { TypedSchema } from '../requests/schema-builder'; import { Content, Part } from './content'; -import { FunctionCallingMode, HarmBlockMethod, HarmBlockThreshold, HarmCategory } from './enums'; +import { + FunctionCallingMode, + HarmBlockMethod, + HarmBlockThreshold, + HarmCategory, + ResponseModality, +} from './enums'; import { ObjectSchemaInterface, SchemaRequest } from './schema'; /** @@ -30,7 +36,7 @@ export interface BaseParams { } /** - * Params passed to {@link getGenerativeModel}. + * Params passed to {@link getGenerativeModel}. * @public */ export interface ModelParams extends BaseParams { @@ -58,6 +64,13 @@ export interface GenerateContentRequest extends BaseParams { export interface SafetySetting { category: HarmCategory; threshold: HarmBlockThreshold; + /** + * The harm block method. + * + * This property is only supported in the Vertex AI Gemini API ({@link VertexAIBackend}). + * When using the Gemini Developer API ({@link GoogleAIBackend}), an {@link AIError} will be + * thrown if this property is defined. + */ method?: HarmBlockMethod; } @@ -83,13 +96,23 @@ export interface GenerationConfig { responseMimeType?: string; /** * Output response schema of the generated candidate text. This - * value can be a class generated with a {@link Schema} static method + * value can be a class generated with a {@link Schema} static method * like `Schema.string()` or `Schema.object()` or it can be a plain - * JS object matching the {@link SchemaRequest} interface. + * JS object matching the {@link SchemaRequest} interface. *
Note: This only applies when the specified `responseMIMEType` supports a schema; currently * this is limited to `application/json` and `text/x.enum`. */ responseSchema?: TypedSchema | SchemaRequest; + /** + * Generation modalities to be returned in generation responses. + * + * @remarks + * - Multimodal response generation is only supported by some Gemini models and versions; see {@link https://firebase.google.com/docs/vertex-ai/models | model versions}. + * - Only image generation (`ResponseModality.IMAGE`) is supported. + * + * @beta + */ + responseModalities?: ResponseModality[]; } /** @@ -109,10 +132,22 @@ export interface StartChatParams extends BaseParams { */ export interface CountTokensRequest { contents: Content[]; + /** + * Instructions that direct the model to behave a certain way. + */ + systemInstruction?: string | Part | Content; + /** + * {@link Tool} configuration. + */ + tools?: Tool[]; + /** + * Configuration options that control how the model generates a response. + */ + generationConfig?: GenerationConfig; } /** - * Params passed to {@link getGenerativeModel}. + * Params passed to {@link getGenerativeModel}. * @public */ export interface RequestOptions { @@ -172,8 +207,8 @@ export declare interface FunctionDeclarationsTool { * Optional. One or more function declarations * to be passed to the model along with the current user query. Model may * decide to call a subset of these functions by populating - * {@link FunctionCall} in the response. User should - * provide a {@link FunctionResponse} for each + * {@link FunctionCall} in the response. User should + * provide a {@link FunctionResponse} for each * function call in the next turn. Based on the function responses, the model will * generate the final response back to the user. Maximum 64 function * declarations can be provided. From e8af3fc8915ddae92670e193f9e059f5a8d2720a Mon Sep 17 00:00:00 2001 From: russellwheatley Date: Tue, 3 Jun 2025 12:28:21 +0100 Subject: [PATCH 20/85] types/responses.ts --- packages/ai/lib/types/responses.ts | 75 ++++++++++++++++++++++++++++-- 1 file changed, 70 insertions(+), 5 deletions(-) diff --git a/packages/ai/lib/types/responses.ts b/packages/ai/lib/types/responses.ts index d162aa8015..450a388992 100644 --- a/packages/ai/lib/types/responses.ts +++ b/packages/ai/lib/types/responses.ts @@ -16,7 +16,14 @@ */ import { Content, FunctionCall, InlineDataPart } from './content'; -import { BlockReason, FinishReason, HarmCategory, HarmProbability, HarmSeverity } from './enums'; +import { + BlockReason, + FinishReason, + HarmCategory, + HarmProbability, + HarmSeverity, + Modality, +} from './enums'; /** * Result object returned from {@link GenerativeModel.generateContent} call. @@ -77,7 +84,7 @@ export interface GenerateContentResponse { } /** - * Usage metadata about a {@link GenerateContentResponse}. + * Usage metadata about a {@link GenerateContentResponse}. * * @public */ @@ -85,6 +92,20 @@ export interface UsageMetadata { promptTokenCount: number; candidatesTokenCount: number; totalTokenCount: number; + promptTokensDetails?: ModalityTokenCount[]; + candidatesTokensDetails?: ModalityTokenCount[]; +} + +/** + * Represents token counting info for a single modality. + * + * @public + */ +export interface ModalityTokenCount { + /** The modality associated with this token count. */ + modality: Modality; + /** The number of tokens counted. */ + tokenCount: number; } /** @@ -95,11 +116,16 @@ export interface UsageMetadata { export interface PromptFeedback { blockReason?: BlockReason; safetyRatings: SafetyRating[]; + /** + * A human-readable description of the `blockReason`. + * + * This property is only supported in the Vertex AI Gemini API ({@link VertexAIBackend}). + */ blockReasonMessage?: string; } /** - * A candidate returned as part of a {@link GenerateContentResponse}. + * A candidate returned as part of a {@link GenerateContentResponse}. * @public */ export interface GenerateContentCandidate { @@ -113,7 +139,7 @@ export interface GenerateContentCandidate { } /** - * Citation metadata that may be found on a {@link GenerateContentCandidate}. + * Citation metadata that may be found on a {@link GenerateContentCandidate}. * @public */ export interface CitationMetadata { @@ -129,7 +155,17 @@ export interface Citation { endIndex?: number; uri?: string; license?: string; + /** + * The title of the cited source, if available. + * + * This property is only supported in the Vertex AI Gemini API ({@link VertexAIBackend}). + */ title?: string; + /** + * The publication date of the cited source, if available. + * + * This property is only supported in the Vertex AI Gemini API ({@link VertexAIBackend}). + */ publicationDate?: Date; } @@ -140,10 +176,14 @@ export interface Citation { export interface GroundingMetadata { webSearchQueries?: string[]; retrievalQueries?: string[]; + /** + * @deprecated + */ groundingAttributions: GroundingAttribution[]; } /** + * @deprecated * @public */ export interface GroundingAttribution { @@ -189,14 +229,32 @@ export interface Date { } /** - * A safety rating associated with a {@link GenerateContentCandidate} + * A safety rating associated with a {@link GenerateContentCandidate} * @public */ export interface SafetyRating { category: HarmCategory; probability: HarmProbability; + /** + * The harm severity level. + * + * This property is only supported when using the Vertex AI Gemini API ({@link VertexAIBackend}). + * When using the Gemini Developer API ({@link GoogleAIBackend}), this property is not supported and will default to `HarmSeverity.UNSUPPORTED`. + */ severity: HarmSeverity; + /** + * The probability score of the harm category. + * + * This property is only supported when using the Vertex AI Gemini API ({@link VertexAIBackend}). + * When using the Gemini Developer API ({@link GoogleAIBackend}), this property is not supported and will default to 0. + */ probabilityScore: number; + /** + * The severity score of the harm category. + * + * This property is only supported when using the Vertex AI Gemini API ({@link VertexAIBackend}). + * When using the Gemini Developer API ({@link GoogleAIBackend}), this property is not supported and will default to 0. + */ severityScore: number; blocked: boolean; } @@ -213,6 +271,13 @@ export interface CountTokensResponse { /** * The total number of billable characters counted across all instances * from the request. + * + * This property is only supported when using the Vertex AI Gemini API ({@link VertexAIBackend}). + * When using the Gemini Developer API ({@link GoogleAIBackend}), this property is not supported and will default to 0. */ totalBillableCharacters?: number; + /** + * The breakdown, by modality, of how many tokens are consumed by the prompt. + */ + promptTokensDetails?: ModalityTokenCount[]; } From 1c24dd9277430256e16162939fc3511cd286f6db Mon Sep 17 00:00:00 2001 From: russellwheatley Date: Tue, 3 Jun 2025 12:29:41 +0100 Subject: [PATCH 21/85] types/schema.ts --- packages/ai/lib/types/schema.ts | 34 ++++++++++++++++++++++++++------- 1 file changed, 27 insertions(+), 7 deletions(-) diff --git a/packages/ai/lib/types/schema.ts b/packages/ai/lib/types/schema.ts index c1376b9aa1..60a23a2d56 100644 --- a/packages/ai/lib/types/schema.ts +++ b/packages/ai/lib/types/schema.ts @@ -14,6 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + /** * Contains the list of OpenAPI data types * as defined by the @@ -36,40 +37,59 @@ export enum SchemaType { } /** - * Basic {@link Schema} properties shared across several Schema-related + * Basic {@link Schema} properties shared across several Schema-related * types. * @public */ export interface SchemaShared { - /** Optional. The format of the property. */ + /** Optional. The format of the property. + * When using the Gemini Developer API ({@link GoogleAIBackend}), this must be either `'enum'` or + * `'date-time'`, otherwise requests will fail. + */ format?: string; /** Optional. The description of the property. */ description?: string; + /** + * The title of the property. This helps document the schema's purpose but does not typically + * constrain the generated value. It can subtly guide the model by clarifying the intent of a + * field. + */ + title?: string; /** Optional. The items of the property. */ items?: T; + /** The minimum number of items (elements) in a schema of type {@link SchemaType.ARRAY}. */ + minItems?: number; + /** The maximum number of items (elements) in a schema of type {@link SchemaType.ARRAY}. */ + maxItems?: number; /** Optional. Map of `Schema` objects. */ properties?: { [k: string]: T; }; + /** A hint suggesting the order in which the keys should appear in the generated JSON string. */ + propertyOrdering?: string[]; /** Optional. The enum of the property. */ enum?: string[]; /** Optional. The example of the property. */ example?: unknown; /** Optional. Whether the property is nullable. */ nullable?: boolean; + /** The minimum value of a numeric type. */ + minimum?: number; + /** The maximum value of a numeric type. */ + maximum?: number; [key: string]: unknown; } /** - * Params passed to {@link Schema} static methods to create specific - * {@link Schema} classes. + * Params passed to {@link Schema} static methods to create specific + * {@link Schema} classes. * @public */ // eslint-disable-next-line @typescript-eslint/no-empty-object-type export interface SchemaParams extends SchemaShared {} /** - * Final format for {@link Schema} params passed to backend requests. + * Final format for {@link Schema} params passed to backend requests. * @public */ export interface SchemaRequest extends SchemaShared { @@ -83,7 +103,7 @@ export interface SchemaRequest extends SchemaShared { } /** - * Interface for {@link Schema} class. + * Interface for {@link Schema} class. * @public */ export interface SchemaInterface extends SchemaShared { @@ -95,7 +115,7 @@ export interface SchemaInterface extends SchemaShared { } /** - * Interface for {@link ObjectSchema} class. + * Interface for {@link ObjectSchema} class. * @public */ export interface ObjectSchemaInterface extends SchemaInterface { From e28877ecb9c37f00415a066e8d6106d155d8b699 Mon Sep 17 00:00:00 2001 From: russellwheatley Date: Tue, 3 Jun 2025 14:45:03 +0100 Subject: [PATCH 22/85] test: backend --- packages/ai/__tests__/backend.test.ts | 38 +++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 packages/ai/__tests__/backend.test.ts diff --git a/packages/ai/__tests__/backend.test.ts b/packages/ai/__tests__/backend.test.ts new file mode 100644 index 0000000000..6e5637a83b --- /dev/null +++ b/packages/ai/__tests__/backend.test.ts @@ -0,0 +1,38 @@ +import { describe, it, expect } from '@jest/globals'; +import { GoogleAIBackend, VertexAIBackend } from '../lib/backend'; +import { BackendType } from 'lib/public-types'; +import { DEFAULT_LOCATION } from 'lib/constants'; +describe('Backend', () => { + describe('GoogleAIBackend', () => { + it('sets backendType to GOOGLE_AI', () => { + const backend = new GoogleAIBackend(); + expect(backend.backendType).toBe(BackendType.GOOGLE_AI); // Use toBe instead of to.equal + }); + }); + + describe('VertexAIBackend', () => { + it('set backendType to VERTEX_AI', () => { + const backend = new VertexAIBackend(); + expect(backend.backendType).toBe(BackendType.VERTEX_AI); // Use toBe instead of to.equal + expect(backend.location).toBe(DEFAULT_LOCATION); // Use toBe instead of to.equal + }); + + it('sets custom location', () => { + const backend = new VertexAIBackend('test-location'); + expect(backend.backendType).toBe(BackendType.VERTEX_AI); // Use toBe instead of to.equal + expect(backend.location).toBe('test-location'); // Use toBe instead of to.equal + }); + + it('uses default location if location is empty string', () => { + const backend = new VertexAIBackend(''); + expect(backend.backendType).toBe(BackendType.VERTEX_AI); // Use toBe instead of to.equal + expect(backend.location).toBe(DEFAULT_LOCATION); // Use toBe instead of to.equal + }); + + it('uses default location if location is null', () => { + const backend = new VertexAIBackend(null as any); + expect(backend.backendType).toBe(BackendType.VERTEX_AI); // Use toBe instead of to.equal + expect(backend.location).toBe(DEFAULT_LOCATION); // Use toBe instead of to.equal + }); + }); +}); From ba719a5ad877e9608239a668f218792b3d537377 Mon Sep 17 00:00:00 2001 From: russellwheatley Date: Tue, 3 Jun 2025 14:56:38 +0100 Subject: [PATCH 23/85] add license header to test file --- packages/ai/__tests__/backend.test.ts | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/packages/ai/__tests__/backend.test.ts b/packages/ai/__tests__/backend.test.ts index 6e5637a83b..bdab4be957 100644 --- a/packages/ai/__tests__/backend.test.ts +++ b/packages/ai/__tests__/backend.test.ts @@ -1,3 +1,19 @@ +/** + * @license + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ import { describe, it, expect } from '@jest/globals'; import { GoogleAIBackend, VertexAIBackend } from '../lib/backend'; import { BackendType } from 'lib/public-types'; From cfc0bc8986cb0a3c49aa2c272a7d6a7e66cbe523 Mon Sep 17 00:00:00 2001 From: russellwheatley Date: Wed, 4 Jun 2025 09:30:21 +0100 Subject: [PATCH 24/85] test: googleai-mapper --- .../ai/__tests__/googleai-mappers.test.ts | 358 ++++++++++++++++++ 1 file changed, 358 insertions(+) create mode 100644 packages/ai/__tests__/googleai-mappers.test.ts diff --git a/packages/ai/__tests__/googleai-mappers.test.ts b/packages/ai/__tests__/googleai-mappers.test.ts new file mode 100644 index 0000000000..42e8773c82 --- /dev/null +++ b/packages/ai/__tests__/googleai-mappers.test.ts @@ -0,0 +1,358 @@ +/** + * @license + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import { describe, it, expect, beforeEach, afterEach, jest } from '@jest/globals'; +import { AIError } from 'lib'; +import { + mapCountTokensRequest, + mapGenerateContentCandidates, + mapGenerateContentRequest, + mapGenerateContentResponse, + mapPromptFeedback, +} from 'lib/googleai-mappers'; +import { + AIErrorCode, + BlockReason, + CountTokensRequest, + Content, + FinishReason, + GenerateContentRequest, + GoogleAICountTokensRequest, + GoogleAIGenerateContentCandidate, + GoogleAIGenerateContentResponse, + HarmBlockMethod, + HarmBlockThreshold, + HarmCategory, + HarmProbability, + HarmSeverity, + PromptFeedback, + SafetyRating, +} from 'lib/public-types'; +import { getMockResponse } from './test-utils/mock-response'; +import { SpiedFunction } from 'jest-mock'; + +const fakeModel = 'models/gemini-pro'; + +const fakeContents: Content[] = [{ role: 'user', parts: [{ text: 'hello' }] }]; + +describe('Google AI Mappers', () => { + let loggerWarnSpy: SpiedFunction<{ + (message?: any, ...optionalParams: any[]): void; + (message?: any, ...optionalParams: any[]): void; + }>; + + beforeEach(() => { + loggerWarnSpy = jest.spyOn(console, 'warn').mockImplementation(() => {}); + }); + + afterEach(() => { + jest.restoreAllMocks(); + }); + + describe('mapGenerateContentRequest', () => { + it('should throw if safetySettings contain method', () => { + const request: GenerateContentRequest = { + contents: fakeContents, + safetySettings: [ + { + category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + threshold: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + method: HarmBlockMethod.SEVERITY, + }, + ], + }; + const error = new AIError( + AIErrorCode.UNSUPPORTED, + 'SafetySettings.method is not supported in requests to the Gemini Developer API', + ); + expect(() => mapGenerateContentRequest(request)).toThrowError(error); + }); + + it('should warn and round topK if present', () => { + const request: GenerateContentRequest = { + contents: fakeContents, + generationConfig: { + topK: 15.7, + }, + }; + const mappedRequest = mapGenerateContentRequest(request); + expect(loggerWarnSpy).toHaveBeenCalledWith( + 'topK in GenerationConfig has been rounded to the nearest integer to match the format for requests to the Gemini Developer API.', + ); + expect(mappedRequest.generationConfig?.topK).toBe(16); + }); + + it('should not modify topK if it is already an integer', () => { + const request: GenerateContentRequest = { + contents: fakeContents, + generationConfig: { + topK: 16, + }, + }; + const mappedRequest = mapGenerateContentRequest(request); + expect(loggerWarnSpy).not.toHaveBeenCalled(); + expect(mappedRequest.generationConfig?.topK).toBe(16); + }); + + it('should return the request mostly unchanged if valid', () => { + const request: GenerateContentRequest = { + contents: fakeContents, + safetySettings: [ + { + category: HarmCategory.HARM_CATEGORY_HATE_SPEECH, + threshold: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, + }, + ], + generationConfig: { + temperature: 0.5, + }, + }; + const mappedRequest = mapGenerateContentRequest({ ...request }); + expect(mappedRequest).toEqual(request); + expect(loggerWarnSpy).not.toHaveBeenCalled(); + }); + }); + + describe('mapGenerateContentResponse', () => { + it('should map a full Google AI response', async () => { + const googleAIMockResponse: GoogleAIGenerateContentResponse = await ( + getMockResponse('unary-success-citations.json') as Response + ).json(); + const mappedResponse = mapGenerateContentResponse(googleAIMockResponse); + + expect(mappedResponse.candidates).toBeDefined(); + expect(mappedResponse.candidates?.[0]?.content.parts[0]?.text).toContain('quantum mechanics'); + + // Mapped citations + expect(mappedResponse.candidates?.[0]?.citationMetadata?.citations[0]?.startIndex).toBe( + googleAIMockResponse.candidates?.[0]?.citationMetadata?.citationSources[0]?.startIndex, + ); + expect(mappedResponse.candidates?.[0]?.citationMetadata?.citations[0]?.endIndex).toBe( + googleAIMockResponse.candidates?.[0]?.citationMetadata?.citationSources[0]?.endIndex, + ); + + // Mapped safety ratings + expect(mappedResponse.candidates?.[0]?.safetyRatings?.[0]?.probabilityScore).toBe(0); + expect(mappedResponse.candidates?.[0]?.safetyRatings?.[0]?.severityScore).toBe(0); + expect(mappedResponse.candidates?.[0]?.safetyRatings?.[0]?.severity).toBe( + HarmSeverity.HARM_SEVERITY_UNSUPPORTED, + ); + + expect(mappedResponse.candidates?.[0]?.finishReason).toBe(FinishReason.STOP); + + // Check usage metadata passthrough + expect(mappedResponse.usageMetadata).toEqual(googleAIMockResponse.usageMetadata); + }); + + it('should handle missing candidates and promptFeedback', () => { + const googleAIResponse: GoogleAIGenerateContentResponse = { + // No candidates + // No promptFeedback + usageMetadata: { + promptTokenCount: 5, + candidatesTokenCount: 0, + totalTokenCount: 5, + }, + }; + const mappedResponse = mapGenerateContentResponse(googleAIResponse); + expect(mappedResponse.candidates).toBeUndefined(); + expect(mappedResponse.promptFeedback).toBeUndefined(); // Mapped to undefined + expect(mappedResponse.usageMetadata).toEqual(googleAIResponse.usageMetadata); + }); + + it('should handle empty candidates array', () => { + const googleAIResponse: GoogleAIGenerateContentResponse = { + candidates: [], + usageMetadata: { + promptTokenCount: 5, + candidatesTokenCount: 0, + totalTokenCount: 5, + }, + }; + const mappedResponse = mapGenerateContentResponse(googleAIResponse); + expect(mappedResponse.candidates).toEqual([]); + expect(mappedResponse.promptFeedback).toBeUndefined(); + expect(mappedResponse.usageMetadata).toEqual(googleAIResponse.usageMetadata); + }); + }); + + describe('mapCountTokensRequest', () => { + it('should map a Vertex AI CountTokensRequest to Google AI format', () => { + const vertexRequest: CountTokensRequest = { + contents: fakeContents, + systemInstruction: { role: 'system', parts: [{ text: 'Be nice' }] }, + tools: [{ functionDeclarations: [{ name: 'foo', description: 'bar' }] }], + generationConfig: { temperature: 0.8 }, + }; + + const expectedGoogleAIRequest: GoogleAICountTokensRequest = { + generateContentRequest: { + model: fakeModel, + contents: vertexRequest.contents, + systemInstruction: vertexRequest.systemInstruction, + tools: vertexRequest.tools, + generationConfig: vertexRequest.generationConfig, + }, + }; + + const mappedRequest = mapCountTokensRequest(vertexRequest, fakeModel); + expect(mappedRequest).toEqual(expectedGoogleAIRequest); + }); + + it('should map a minimal Vertex AI CountTokensRequest', () => { + const vertexRequest: CountTokensRequest = { + contents: fakeContents, + systemInstruction: { role: 'system', parts: [{ text: 'Be nice' }] }, + generationConfig: { temperature: 0.8 }, + }; + + const expectedGoogleAIRequest: GoogleAICountTokensRequest = { + generateContentRequest: { + model: fakeModel, + contents: vertexRequest.contents, + systemInstruction: { role: 'system', parts: [{ text: 'Be nice' }] }, + generationConfig: { temperature: 0.8 }, + }, + }; + + const mappedRequest = mapCountTokensRequest(vertexRequest, fakeModel); + expect(mappedRequest).toEqual(expectedGoogleAIRequest); + }); + }); + + describe('mapGenerateContentCandidates', () => { + it('should map citationSources to citationMetadata.citations', () => { + const candidates: GoogleAIGenerateContentCandidate[] = [ + { + index: 0, + content: { role: 'model', parts: [{ text: 'Cited text' }] }, + citationMetadata: { + citationSources: [ + { startIndex: 0, endIndex: 5, uri: 'uri1', license: 'MIT' }, + { startIndex: 6, endIndex: 10, uri: 'uri2' }, + ], + }, + }, + ]; + const mapped = mapGenerateContentCandidates(candidates); + expect(mapped[0]?.citationMetadata).toBeDefined(); + expect(mapped[0]?.citationMetadata?.citations).toEqual( + candidates[0]?.citationMetadata?.citationSources, + ); + expect(mapped[0]?.citationMetadata?.citations[0]?.title).toBeUndefined(); // Not in Google AI + expect(mapped[0]?.citationMetadata?.citations[0]?.publicationDate).toBeUndefined(); // Not in Google AI + }); + + it('should add default safety rating properties', () => { + const candidates: GoogleAIGenerateContentCandidate[] = [ + { + index: 0, + content: { role: 'model', parts: [{ text: 'Maybe unsafe' }] }, + safetyRatings: [ + { + category: HarmCategory.HARM_CATEGORY_HARASSMENT, + probability: HarmProbability.MEDIUM, + blocked: false, + // Missing severity, probabilityScore, severityScore + } as any, + ], + }, + ]; + const mapped = mapGenerateContentCandidates(candidates); + expect(mapped[0]?.safetyRatings).toBeDefined(); + const safetyRating = mapped[0]?.safetyRatings?.[0] as SafetyRating; // Type assertion + expect(safetyRating.severity).toBe(HarmSeverity.HARM_SEVERITY_UNSUPPORTED); + expect(safetyRating.probabilityScore).toBe(0); + expect(safetyRating.severityScore).toBe(0); + // Existing properties should be preserved + expect(safetyRating.category).toBe(HarmCategory.HARM_CATEGORY_HARASSMENT); + expect(safetyRating.probability).toBe(HarmProbability.MEDIUM); + expect(safetyRating.blocked).toBe(false); + }); + + it('should throw if videoMetadata is present in parts', () => { + const candidates: GoogleAIGenerateContentCandidate[] = [ + { + index: 0, + content: { + role: 'model', + parts: [ + { + inlineData: { mimeType: 'video/mp4', data: 'base64==' }, + videoMetadata: { startOffset: '0s', endOffset: '5s' }, // Unsupported + }, + ], + }, + }, + ]; + expect(() => mapGenerateContentCandidates(candidates)).toThrowError( + new AIError(AIErrorCode.UNSUPPORTED, 'Part.videoMetadata is not supported'), + ); + }); + + it('should handle candidates without citation or safety ratings', () => { + const candidates: GoogleAIGenerateContentCandidate[] = [ + { + index: 0, + content: { role: 'model', parts: [{ text: 'Simple text' }] }, + finishReason: FinishReason.STOP, + }, + ]; + const mapped = mapGenerateContentCandidates(candidates); + expect(mapped[0]?.citationMetadata).toBeUndefined(); + expect(mapped[0]?.safetyRatings).toBeUndefined(); + expect(mapped[0]?.content?.parts[0]?.text).toBe('Simple text'); + expect(loggerWarnSpy).not.toHaveBeenCalled(); + }); + + it('should handle empty candidate array', () => { + const candidates: GoogleAIGenerateContentCandidate[] = []; + const mapped = mapGenerateContentCandidates(candidates); + expect(mapped).toEqual([]); + expect(loggerWarnSpy).not.toHaveBeenCalled(); + }); + }); + + describe('mapPromptFeedback', () => { + it('should add default safety rating properties', () => { + const feedback: PromptFeedback = { + blockReason: BlockReason.OTHER, + safetyRatings: [ + { + category: HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + probability: HarmProbability.HIGH, + blocked: true, + // Missing severity, probabilityScore, severityScore + } as any, + ], + // Missing blockReasonMessage + }; + const mapped = mapPromptFeedback(feedback); + expect(mapped.safetyRatings).toBeDefined(); + const safetyRating = mapped.safetyRatings[0] as SafetyRating; // Type assertion + expect(safetyRating.severity).toBe(HarmSeverity.HARM_SEVERITY_UNSUPPORTED); + expect(safetyRating.probabilityScore).toBe(0); + expect(safetyRating.severityScore).toBe(0); + // Existing properties should be preserved + expect(safetyRating.category).toBe(HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT); + expect(safetyRating.probability).toBe(HarmProbability.HIGH); + expect(safetyRating.blocked).toBe(true); + // Other properties + expect(mapped.blockReason).toBe(BlockReason.OTHER); + expect(mapped.blockReasonMessage).toBeUndefined(); // Not present in input + }); + }); +}); From d275f19209e57601be918b96350e563554638bdb Mon Sep 17 00:00:00 2001 From: russellwheatley Date: Wed, 4 Jun 2025 12:24:46 +0100 Subject: [PATCH 25/85] chore: firebase_ai yarn.lock --- yarn.lock | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/yarn.lock b/yarn.lock index eb71c99b65..4f0de566e6 100644 --- a/yarn.lock +++ b/yarn.lock @@ -5348,6 +5348,21 @@ __metadata: languageName: node linkType: hard +"@react-native-firebase/ai@workspace:packages/ai": + version: 0.0.0-use.local + resolution: "@react-native-firebase/ai@workspace:packages/ai" + dependencies: + "@types/text-encoding": "npm:^0.0.40" + react-native-builder-bob: "npm:^0.40.6" + react-native-fetch-api: "npm:^3.0.0" + text-encoding: "npm:^0.7.0" + typescript: "npm:^5.8.3" + web-streams-polyfill: "npm:^4.1.0" + peerDependencies: + "@react-native-firebase/app": 22.2.0 + languageName: unknown + linkType: soft + "@react-native-firebase/analytics@npm:23.0.0, @react-native-firebase/analytics@workspace:packages/analytics": version: 0.0.0-use.local resolution: "@react-native-firebase/analytics@workspace:packages/analytics" From 38c179947f609514a9ff7afc8163b1216468f573 Mon Sep 17 00:00:00 2001 From: russellwheatley Date: Wed, 4 Jun 2025 12:25:17 +0100 Subject: [PATCH 26/85] test(ai): update unit test for firebase ai --- packages/ai/__tests__/api.test.ts | 75 +++++++------ packages/ai/__tests__/backend.test.ts | 5 +- packages/ai/__tests__/chat-session.test.ts | 3 + packages/ai/__tests__/count-tokens.test.ts | 3 + .../ai/__tests__/generate-content.test.ts | 101 +++++++++++++++++- .../ai/__tests__/generative-model.test.ts | 78 ++++++++------ .../ai/__tests__/googleai-mappers.test.ts | 6 +- packages/ai/__tests__/request.test.ts | 59 +++++----- packages/ai/__tests__/schema-builder.test.ts | 4 +- packages/ai/__tests__/service.test.ts | 12 +-- packages/ai/__tests__/stream-reader.test.ts | 26 +++-- 11 files changed, 258 insertions(+), 114 deletions(-) diff --git a/packages/ai/__tests__/api.test.ts b/packages/ai/__tests__/api.test.ts index 3199157e76..79d9c24c92 100644 --- a/packages/ai/__tests__/api.test.ts +++ b/packages/ai/__tests__/api.test.ts @@ -15,20 +15,19 @@ * limitations under the License. */ import { describe, expect, it } from '@jest/globals'; -import { getApp, type ReactNativeFirebase } from '../../app/lib'; +import { type ReactNativeFirebase } from '../../app/lib'; -import { ModelParams, VertexAIErrorCode } from '../lib/types'; -import { VertexAIError } from '../lib/errors'; -import { getGenerativeModel, getVertexAI } from '../lib/index'; +import { ModelParams, AIErrorCode } from '../lib/types'; +import { AIError } from '../lib/errors'; +import { getGenerativeModel } from '../lib/index'; -import { VertexAI } from '../lib/public-types'; +import { AI } from '../lib/public-types'; import { GenerativeModel } from '../lib/models/generative-model'; -import '../../auth/lib'; -import '../../app-check/lib'; -import { getAuth } from '../../auth/lib'; +import { AI_TYPE } from '../lib/constants'; +import { VertexAIBackend } from '../lib/backend'; -const fakeVertexAI: VertexAI = { +const fakeAI: AI = { app: { name: 'DEFAULT', options: { @@ -37,66 +36,76 @@ const fakeVertexAI: VertexAI = { projectId: 'my-project', }, } as ReactNativeFirebase.FirebaseApp, + backend: new VertexAIBackend('us-central1'), location: 'us-central1', }; describe('Top level API', () => { - it('should allow auth and app check instances to be passed in', () => { - const app = getApp(); - const auth = getAuth(); - const appCheck = app.appCheck(); - - getVertexAI(app, { appCheck, auth }); - }); - it('getGenerativeModel throws if no model is provided', () => { try { - getGenerativeModel(fakeVertexAI, {} as ModelParams); + getGenerativeModel(fakeAI, {} as ModelParams); } catch (e) { - expect((e as VertexAIError).code).toContain(VertexAIErrorCode.NO_MODEL); - expect((e as VertexAIError).message).toContain( + expect((e as AIError).code).toContain(AIErrorCode.NO_MODEL); + expect((e as AIError).message).toContain( `VertexAI: Must provide a model name. Example: ` + - `getGenerativeModel({ model: 'my-model-name' }) (vertexAI/${VertexAIErrorCode.NO_MODEL})`, + `getGenerativeModel({ model: 'my-model-name' }) (vertexAI/${AIErrorCode.NO_MODEL})`, ); } }); it('getGenerativeModel throws if no apiKey is provided', () => { const fakeVertexNoApiKey = { - ...fakeVertexAI, - app: { options: { projectId: 'my-project' } }, - } as VertexAI; + ...fakeAI, + app: { options: { projectId: 'my-project', appId: 'my-appid' } }, + } as AI; try { getGenerativeModel(fakeVertexNoApiKey, { model: 'my-model' }); } catch (e) { - expect((e as VertexAIError).code).toContain(VertexAIErrorCode.NO_API_KEY); - expect((e as VertexAIError).message).toBe( + expect((e as AIError).code).toContain(AIErrorCode.NO_API_KEY); + expect((e as AIError).message).toBe( `VertexAI: The "apiKey" field is empty in the local ` + `Firebase config. Firebase VertexAI requires this field to` + - ` contain a valid API key. (vertexAI/${VertexAIErrorCode.NO_API_KEY})`, + ` contain a valid API key. (vertexAI/${AIErrorCode.NO_API_KEY})`, ); } }); it('getGenerativeModel throws if no projectId is provided', () => { const fakeVertexNoProject = { - ...fakeVertexAI, + ...fakeAI, app: { options: { apiKey: 'my-key' } }, - } as VertexAI; + } as AI; try { getGenerativeModel(fakeVertexNoProject, { model: 'my-model' }); } catch (e) { - expect((e as VertexAIError).code).toContain(VertexAIErrorCode.NO_PROJECT_ID); - expect((e as VertexAIError).message).toBe( + expect((e as AIError).code).toContain(AIErrorCode.NO_PROJECT_ID); + expect((e as AIError).message).toBe( `VertexAI: The "projectId" field is empty in the local` + ` Firebase config. Firebase VertexAI requires this field ` + - `to contain a valid project ID. (vertexAI/${VertexAIErrorCode.NO_PROJECT_ID})`, + `to contain a valid project ID. (vertexAI/${AIErrorCode.NO_PROJECT_ID})`, + ); + } + }); + + it('getGenerativeModel throws if no appId is provided', () => { + const fakeVertexNoProject = { + ...fakeAI, + app: { options: { apiKey: 'my-key', projectId: 'my-projectid' } }, + } as AI; + try { + getGenerativeModel(fakeVertexNoProject, { model: 'my-model' }); + } catch (e) { + expect((e as AIError).code).toContain(AIErrorCode.NO_APP_ID); + expect((e as AIError).message).toBe( + `AI: The "appId" field is empty in the local` + + ` Firebase config. Firebase AI requires this field ` + + `to contain a valid app ID. (${AI_TYPE}/${AIErrorCode.NO_APP_ID})`, ); } }); it('getGenerativeModel gets a GenerativeModel', () => { - const genModel = getGenerativeModel(fakeVertexAI, { model: 'my-model' }); + const genModel = getGenerativeModel(fakeAI, { model: 'my-model' }); expect(genModel).toBeInstanceOf(GenerativeModel); expect(genModel.model).toBe('publishers/google/models/my-model'); }); diff --git a/packages/ai/__tests__/backend.test.ts b/packages/ai/__tests__/backend.test.ts index bdab4be957..ad42ac3e7c 100644 --- a/packages/ai/__tests__/backend.test.ts +++ b/packages/ai/__tests__/backend.test.ts @@ -16,8 +16,9 @@ */ import { describe, it, expect } from '@jest/globals'; import { GoogleAIBackend, VertexAIBackend } from '../lib/backend'; -import { BackendType } from 'lib/public-types'; -import { DEFAULT_LOCATION } from 'lib/constants'; +import { BackendType } from '../lib/public-types'; +import { DEFAULT_LOCATION } from '../lib/constants'; + describe('Backend', () => { describe('GoogleAIBackend', () => { it('sets backendType to GOOGLE_AI', () => { diff --git a/packages/ai/__tests__/chat-session.test.ts b/packages/ai/__tests__/chat-session.test.ts index cd96aa32e6..10b025c62a 100644 --- a/packages/ai/__tests__/chat-session.test.ts +++ b/packages/ai/__tests__/chat-session.test.ts @@ -21,11 +21,14 @@ import { GenerateContentStreamResult } from '../lib/types'; import { ChatSession } from '../lib/methods/chat-session'; import { ApiSettings } from '../lib/types/internal'; import { RequestOptions } from '../lib/types/requests'; +import { VertexAIBackend } from '../lib/backend'; const fakeApiSettings: ApiSettings = { apiKey: 'key', project: 'my-project', + appId: 'my-appid', location: 'us-central1', + backend: new VertexAIBackend(), }; const requestOptions: RequestOptions = { diff --git a/packages/ai/__tests__/count-tokens.test.ts b/packages/ai/__tests__/count-tokens.test.ts index 3cd7b78970..d0c68dd61b 100644 --- a/packages/ai/__tests__/count-tokens.test.ts +++ b/packages/ai/__tests__/count-tokens.test.ts @@ -21,11 +21,14 @@ import { countTokens } from '../lib/methods/count-tokens'; import { CountTokensRequest } from '../lib/types'; import { ApiSettings } from '../lib/types/internal'; import { Task } from '../lib/requests/request'; +import { GoogleAIBackend } from '../lib/backend'; const fakeApiSettings: ApiSettings = { apiKey: 'key', project: 'my-project', location: 'us-central1', + appId: '', + backend: new GoogleAIBackend(), }; const fakeRequestParams: CountTokensRequest = { diff --git a/packages/ai/__tests__/generate-content.test.ts b/packages/ai/__tests__/generate-content.test.ts index 3bc733e370..afa9249c12 100644 --- a/packages/ai/__tests__/generate-content.test.ts +++ b/packages/ai/__tests__/generate-content.test.ts @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -import { describe, expect, it, afterEach, jest } from '@jest/globals'; +import { describe, expect, it, afterEach, jest, beforeEach } from '@jest/globals'; import { getMockResponse } from './test-utils/mock-response'; import * as request from '../lib/requests/request'; import { generateContent } from '../lib/methods/generate-content'; @@ -27,11 +27,25 @@ import { } from '../lib/types'; import { ApiSettings } from '../lib/types/internal'; import { Task } from '../lib/requests/request'; +import { GoogleAIBackend, VertexAIBackend } from '../lib/backend'; +import { SpiedFunction } from 'jest-mock'; +import { AIError } from '../lib/errors'; +import { mapGenerateContentRequest } from '../lib/googleai-mappers'; const fakeApiSettings: ApiSettings = { apiKey: 'key', project: 'my-project', + appId: 'my-appid', location: 'us-central1', + backend: new VertexAIBackend(), +}; + +const fakeGoogleAIApiSettings: ApiSettings = { + apiKey: 'key', + project: 'my-project', + appId: 'my-appid', + location: 'us-central1', + backend: new GoogleAIBackend(), }; const fakeRequestParams: GenerateContentRequest = { @@ -48,6 +62,19 @@ const fakeRequestParams: GenerateContentRequest = { ], }; +const fakeGoogleAIRequestParams: GenerateContentRequest = { + contents: [{ parts: [{ text: 'hello' }], role: 'user' }], + generationConfig: { + topK: 16, + }, + safetySettings: [ + { + category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + threshold: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + }, + ], +}; + describe('generateContent()', () => { afterEach(() => { jest.restoreAllMocks(); @@ -88,6 +115,28 @@ describe('generateContent()', () => { ); }); + it('long response with token details', async () => { + const mockResponse = getMockResponse('unary-success-basic-response-long-usage-metadata.json'); + const makeRequestStub = jest + .spyOn(request, 'makeRequest') + .mockResolvedValue(mockResponse as Response); + const result = await generateContent(fakeApiSettings, 'model', fakeRequestParams); + expect(result.response.usageMetadata?.totalTokenCount).toEqual(1913); + expect(result.response.usageMetadata?.candidatesTokenCount).toEqual(76); + expect(result.response.usageMetadata?.promptTokensDetails?.[0]?.modality).toEqual('IMAGE'); + expect(result.response.usageMetadata?.promptTokensDetails?.[0]?.tokenCount).toEqual(1806); + expect(result.response.usageMetadata?.candidatesTokensDetails?.[0]?.modality).toEqual('TEXT'); + expect(result.response.usageMetadata?.candidatesTokensDetails?.[0]?.tokenCount).toEqual(76); + expect(makeRequestStub).toHaveBeenCalledWith( + 'model', + Task.GENERATE_CONTENT, + fakeApiSettings, + false, + expect.anything(), + undefined, + ); + }); + it('citations', async () => { const mockResponse = getMockResponse('unary-success-citations.json'); const makeRequestStub = jest @@ -201,4 +250,54 @@ describe('generateContent()', () => { ); expect(mockFetch).toHaveBeenCalled(); }); + + describe('googleAI', () => { + let makeRequestStub: SpiedFunction; + + beforeEach(() => { + makeRequestStub = jest.spyOn(request, 'makeRequest'); + }); + + afterEach(() => { + jest.restoreAllMocks(); + }); + + it('throws error when method is defined', async () => { + const mockResponse = getMockResponse('unary-success-basic-reply-short.txt'); + makeRequestStub.mockResolvedValue(mockResponse as Response); + + const requestParamsWithMethod: GenerateContentRequest = { + contents: [{ parts: [{ text: 'hello' }], role: 'user' }], + safetySettings: [ + { + category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + threshold: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + method: HarmBlockMethod.SEVERITY, // Unsupported in Google AI. + }, + ], + }; + + // Expect generateContent to throw a AIError that method is not supported. + await expect( + generateContent(fakeGoogleAIApiSettings, 'model', requestParamsWithMethod), + ).rejects.toThrow(AIError); + expect(makeRequestStub).not.toHaveBeenCalled(); + }); + + it('maps request to GoogleAI format', async () => { + const mockResponse = getMockResponse('unary-success-basic-reply-short.txt'); + makeRequestStub.mockResolvedValue(mockResponse as Response); + + await generateContent(fakeGoogleAIApiSettings, 'model', fakeGoogleAIRequestParams); + + expect(makeRequestStub).toHaveBeenCalledWith( + 'model', + Task.GENERATE_CONTENT, + fakeGoogleAIApiSettings, + false, + JSON.stringify(mapGenerateContentRequest(fakeGoogleAIRequestParams)), + undefined, + ); + }); + }); }); diff --git a/packages/ai/__tests__/generative-model.test.ts b/packages/ai/__tests__/generative-model.test.ts index e62862b6aa..cab6321586 100644 --- a/packages/ai/__tests__/generative-model.test.ts +++ b/packages/ai/__tests__/generative-model.test.ts @@ -17,50 +17,28 @@ import { describe, expect, it, jest } from '@jest/globals'; import { type ReactNativeFirebase } from '@react-native-firebase/app'; import { GenerativeModel } from '../lib/models/generative-model'; -import { FunctionCallingMode, VertexAI } from '../lib/public-types'; +import { AI, FunctionCallingMode } from '../lib/public-types'; import * as request from '../lib/requests/request'; import { getMockResponse } from './test-utils/mock-response'; +import { VertexAIBackend } from '../lib/backend'; -const fakeVertexAI: VertexAI = { +const fakeAI: AI = { app: { name: 'DEFAULT', + automaticDataCollectionEnabled: true, options: { apiKey: 'key', projectId: 'my-project', + appId: 'my-appid', }, } as ReactNativeFirebase.FirebaseApp, + backend: new VertexAIBackend('us-central1'), location: 'us-central1', }; describe('GenerativeModel', () => { - it('handles plain model name', () => { - const genModel = new GenerativeModel(fakeVertexAI, { model: 'my-model' }); - expect(genModel.model).toBe('publishers/google/models/my-model'); - }); - - it('handles models/ prefixed model name', () => { - const genModel = new GenerativeModel(fakeVertexAI, { - model: 'models/my-model', - }); - expect(genModel.model).toBe('publishers/google/models/my-model'); - }); - - it('handles full model name', () => { - const genModel = new GenerativeModel(fakeVertexAI, { - model: 'publishers/google/models/my-model', - }); - expect(genModel.model).toBe('publishers/google/models/my-model'); - }); - - it('handles prefixed tuned model name', () => { - const genModel = new GenerativeModel(fakeVertexAI, { - model: 'tunedModels/my-model', - }); - expect(genModel.model).toBe('tunedModels/my-model'); - }); - it('passes params through to generateContent', async () => { - const genModel = new GenerativeModel(fakeVertexAI, { + const genModel = new GenerativeModel(fakeAI, { model: 'my-model', tools: [ { @@ -95,7 +73,7 @@ describe('GenerativeModel', () => { }); it('passes text-only systemInstruction through to generateContent', async () => { - const genModel = new GenerativeModel(fakeVertexAI, { + const genModel = new GenerativeModel(fakeAI, { model: 'my-model', systemInstruction: 'be friendly', }); @@ -117,7 +95,7 @@ describe('GenerativeModel', () => { }); it('generateContent overrides model values', async () => { - const genModel = new GenerativeModel(fakeVertexAI, { + const genModel = new GenerativeModel(fakeAI, { model: 'my-model', tools: [ { @@ -160,8 +138,38 @@ describe('GenerativeModel', () => { makeRequestStub.mockRestore(); }); + it('passes base model params through to ChatSession when there are no startChatParams', async () => { + const genModel = new GenerativeModel(fakeAI, { + model: 'my-model', + generationConfig: { + topK: 1, + }, + }); + const chatSession = genModel.startChat(); + expect(chatSession.params?.generationConfig).toEqual({ + topK: 1, + }); + }); + + it('overrides base model params with startChatParams', () => { + const genModel = new GenerativeModel(fakeAI, { + model: 'my-model', + generationConfig: { + topK: 1, + }, + }); + const chatSession = genModel.startChat({ + generationConfig: { + topK: 2, + }, + }); + expect(chatSession.params?.generationConfig).toEqual({ + topK: 2, + }); + }); + it('passes params through to chat.sendMessage', async () => { - const genModel = new GenerativeModel(fakeVertexAI, { + const genModel = new GenerativeModel(fakeAI, { model: 'my-model', tools: [{ functionDeclarations: [{ name: 'myfunc', description: 'mydesc' }] }], toolConfig: { functionCallingConfig: { mode: FunctionCallingMode.NONE } }, @@ -187,7 +195,7 @@ describe('GenerativeModel', () => { }); it('passes text-only systemInstruction through to chat.sendMessage', async () => { - const genModel = new GenerativeModel(fakeVertexAI, { + const genModel = new GenerativeModel(fakeAI, { model: 'my-model', systemInstruction: 'be friendly', }); @@ -209,7 +217,7 @@ describe('GenerativeModel', () => { }); it('startChat overrides model values', async () => { - const genModel = new GenerativeModel(fakeVertexAI, { + const genModel = new GenerativeModel(fakeAI, { model: 'my-model', tools: [{ functionDeclarations: [{ name: 'myfunc', description: 'mydesc' }] }], toolConfig: { functionCallingConfig: { mode: FunctionCallingMode.NONE } }, @@ -247,7 +255,7 @@ describe('GenerativeModel', () => { }); it('calls countTokens', async () => { - const genModel = new GenerativeModel(fakeVertexAI, { model: 'my-model' }); + const genModel = new GenerativeModel(fakeAI, { model: 'my-model' }); const mockResponse = getMockResponse('unary-success-total-tokens.json'); const makeRequestStub = jest .spyOn(request, 'makeRequest') diff --git a/packages/ai/__tests__/googleai-mappers.test.ts b/packages/ai/__tests__/googleai-mappers.test.ts index 42e8773c82..50efdf4346 100644 --- a/packages/ai/__tests__/googleai-mappers.test.ts +++ b/packages/ai/__tests__/googleai-mappers.test.ts @@ -15,14 +15,14 @@ * limitations under the License. */ import { describe, it, expect, beforeEach, afterEach, jest } from '@jest/globals'; -import { AIError } from 'lib'; +import { AIError } from '../lib/errors'; import { mapCountTokensRequest, mapGenerateContentCandidates, mapGenerateContentRequest, mapGenerateContentResponse, mapPromptFeedback, -} from 'lib/googleai-mappers'; +} from '../lib/googleai-mappers'; import { AIErrorCode, BlockReason, @@ -40,7 +40,7 @@ import { HarmSeverity, PromptFeedback, SafetyRating, -} from 'lib/public-types'; +} from '../lib/public-types'; import { getMockResponse } from './test-utils/mock-response'; import { SpiedFunction } from 'jest-mock'; diff --git a/packages/ai/__tests__/request.test.ts b/packages/ai/__tests__/request.test.ts index c992b062e9..b52cc1da27 100644 --- a/packages/ai/__tests__/request.test.ts +++ b/packages/ai/__tests__/request.test.ts @@ -18,14 +18,17 @@ import { describe, expect, it, jest, afterEach } from '@jest/globals'; import { RequestUrl, Task, getHeaders, makeRequest } from '../lib/requests/request'; import { ApiSettings } from '../lib/types/internal'; import { DEFAULT_API_VERSION } from '../lib/constants'; -import { VertexAIErrorCode } from '../lib/types'; -import { VertexAIError } from '../lib/errors'; +import { AIErrorCode } from '../lib/types'; +import { AIError } from '../lib/errors'; import { getMockResponse } from './test-utils/mock-response'; +import { VertexAIBackend } from '../lib/backend'; const fakeApiSettings: ApiSettings = { apiKey: 'key', project: 'my-project', + appId: 'my-appid', location: 'us-central1', + backend: new VertexAIBackend(), }; describe('request methods', () => { @@ -106,7 +109,9 @@ describe('request methods', () => { const fakeApiSettings: ApiSettings = { apiKey: 'key', project: 'myproject', + appId: 'my-appid', location: 'moon', + backend: new VertexAIBackend(), getAuthToken: () => Promise.resolve('authtoken'), getAppCheckToken: () => Promise.resolve({ token: 'appchecktoken' }), }; @@ -140,7 +145,9 @@ describe('request methods', () => { { apiKey: 'key', project: 'myproject', + appId: 'my-appid', location: 'moon', + backend: new VertexAIBackend(), }, true, {}, @@ -176,6 +183,8 @@ describe('request methods', () => { project: 'myproject', location: 'moon', getAppCheckToken: () => Promise.reject(new Error('oops')), + backend: new VertexAIBackend(), + appId: 'my-appid', }, true, {}, @@ -204,7 +213,9 @@ describe('request methods', () => { { apiKey: 'key', project: 'myproject', + appId: 'my-appid', location: 'moon', + backend: new VertexAIBackend(), }, true, {}, @@ -260,10 +271,10 @@ describe('request methods', () => { timeout: 180000, }); } catch (e) { - expect((e as VertexAIError).code).toBe(VertexAIErrorCode.FETCH_ERROR); - expect((e as VertexAIError).customErrorData?.status).toBe(500); - expect((e as VertexAIError).customErrorData?.statusText).toBe('AbortError'); - expect((e as VertexAIError).message).toContain('500 AbortError'); + expect((e as AIError).code).toBe(AIErrorCode.FETCH_ERROR); + expect((e as AIError).customErrorData?.status).toBe(500); + expect((e as AIError).customErrorData?.statusText).toBe('AbortError'); + expect((e as AIError).message).toContain('500 AbortError'); } expect(fetchMock).toHaveBeenCalledTimes(1); @@ -278,10 +289,10 @@ describe('request methods', () => { try { await makeRequest('models/model-name', Task.GENERATE_CONTENT, fakeApiSettings, false, ''); } catch (e) { - expect((e as VertexAIError).code).toBe(VertexAIErrorCode.FETCH_ERROR); - expect((e as VertexAIError).customErrorData?.status).toBe(500); - expect((e as VertexAIError).customErrorData?.statusText).toBe('Server Error'); - expect((e as VertexAIError).message).toContain('500 Server Error'); + expect((e as AIError).code).toBe(AIErrorCode.FETCH_ERROR); + expect((e as AIError).customErrorData?.status).toBe(500); + expect((e as AIError).customErrorData?.statusText).toBe('Server Error'); + expect((e as AIError).message).toContain('500 Server Error'); } expect(fetchMock).toHaveBeenCalledTimes(1); }); @@ -296,11 +307,11 @@ describe('request methods', () => { try { await makeRequest('models/model-name', Task.GENERATE_CONTENT, fakeApiSettings, false, ''); } catch (e) { - expect((e as VertexAIError).code).toBe(VertexAIErrorCode.FETCH_ERROR); - expect((e as VertexAIError).customErrorData?.status).toBe(500); - expect((e as VertexAIError).customErrorData?.statusText).toBe('Server Error'); - expect((e as VertexAIError).message).toContain('500 Server Error'); - expect((e as VertexAIError).message).toContain('extra info'); + expect((e as AIError).code).toBe(AIErrorCode.FETCH_ERROR); + expect((e as AIError).customErrorData?.status).toBe(500); + expect((e as AIError).customErrorData?.statusText).toBe('Server Error'); + expect((e as AIError).message).toContain('500 Server Error'); + expect((e as AIError).message).toContain('extra info'); } expect(fetchMock).toHaveBeenCalledTimes(1); }); @@ -327,12 +338,12 @@ describe('request methods', () => { try { await makeRequest('models/model-name', Task.GENERATE_CONTENT, fakeApiSettings, false, ''); } catch (e) { - expect((e as VertexAIError).code).toBe(VertexAIErrorCode.FETCH_ERROR); - expect((e as VertexAIError).customErrorData?.status).toBe(500); - expect((e as VertexAIError).customErrorData?.statusText).toBe('Server Error'); - expect((e as VertexAIError).message).toContain('500 Server Error'); - expect((e as VertexAIError).message).toContain('extra info'); - expect((e as VertexAIError).message).toContain('generic::invalid_argument'); + expect((e as AIError).code).toBe(AIErrorCode.FETCH_ERROR); + expect((e as AIError).customErrorData?.status).toBe(500); + expect((e as AIError).customErrorData?.statusText).toBe('Server Error'); + expect((e as AIError).message).toContain('500 Server Error'); + expect((e as AIError).message).toContain('extra info'); + expect((e as AIError).message).toContain('generic::invalid_argument'); } expect(fetchMock).toHaveBeenCalledTimes(1); }); @@ -344,9 +355,9 @@ describe('request methods', () => { try { await makeRequest('models/model-name', Task.GENERATE_CONTENT, fakeApiSettings, false, ''); } catch (e) { - expect((e as VertexAIError).code).toBe(VertexAIErrorCode.API_NOT_ENABLED); - expect((e as VertexAIError).message).toContain('my-project'); - expect((e as VertexAIError).message).toContain('googleapis.com'); + expect((e as AIError).code).toBe(AIErrorCode.API_NOT_ENABLED); + expect((e as AIError).message).toContain('my-project'); + expect((e as AIError).message).toContain('googleapis.com'); } expect(fetchMock).toHaveBeenCalledTimes(1); }); diff --git a/packages/ai/__tests__/schema-builder.test.ts b/packages/ai/__tests__/schema-builder.test.ts index bec1f6a8d2..738bd17a21 100644 --- a/packages/ai/__tests__/schema-builder.test.ts +++ b/packages/ai/__tests__/schema-builder.test.ts @@ -16,7 +16,7 @@ */ import { describe, expect, it } from '@jest/globals'; import { Schema } from '../lib/requests/schema-builder'; -import { VertexAIErrorCode } from '../lib/types'; +import { AIErrorCode } from '../lib/types'; describe('Schema builder', () => { it('builds integer schema', () => { @@ -252,7 +252,7 @@ describe('Schema builder', () => { }, optionalProperties: ['cat'], }); - expect(() => schema.toJSON()).toThrow(VertexAIErrorCode.INVALID_SCHEMA); + expect(() => schema.toJSON()).toThrow(AIErrorCode.INVALID_SCHEMA); }); }); diff --git a/packages/ai/__tests__/service.test.ts b/packages/ai/__tests__/service.test.ts index 9f9503f2c9..1de537df17 100644 --- a/packages/ai/__tests__/service.test.ts +++ b/packages/ai/__tests__/service.test.ts @@ -17,7 +17,8 @@ import { describe, expect, it } from '@jest/globals'; import { type ReactNativeFirebase } from '@react-native-firebase/app'; import { DEFAULT_LOCATION } from '../lib/constants'; -import { VertexAIService } from '../lib/service'; +import { AIService } from '../lib/service'; +import { VertexAIBackend } from '../lib/backend'; const fakeApp = { name: 'DEFAULT', @@ -27,18 +28,17 @@ const fakeApp = { }, } as ReactNativeFirebase.FirebaseApp; -describe('VertexAIService', () => { +describe('AIService', () => { it('uses default location if not specified', () => { - const vertexAI = new VertexAIService(fakeApp); + const vertexAI = new AIService(fakeApp, new VertexAIBackend()); expect(vertexAI.location).toBe(DEFAULT_LOCATION); }); it('uses custom location if specified', () => { - const vertexAI = new VertexAIService( + const vertexAI = new AIService( fakeApp, - /* authProvider */ undefined, + new VertexAIBackend('somewhere'), /* appCheckProvider */ undefined, - { location: 'somewhere' }, ); expect(vertexAI.location).toBe('somewhere'); }); diff --git a/packages/ai/__tests__/stream-reader.test.ts b/packages/ai/__tests__/stream-reader.test.ts index 4a5ae8aef5..5e695316de 100644 --- a/packages/ai/__tests__/stream-reader.test.ts +++ b/packages/ai/__tests__/stream-reader.test.ts @@ -30,9 +30,19 @@ import { HarmCategory, HarmProbability, SafetyRating, - VertexAIErrorCode, + AIErrorCode, } from '../lib/types'; -import { VertexAIError } from '../lib/errors'; +import { AIError } from '../lib/errors'; +import { ApiSettings } from '../lib/types/internal'; +import { VertexAIBackend } from '../lib/backend'; + +const fakeApiSettings: ApiSettings = { + apiKey: 'key', + project: 'my-project', + appId: 'my-appid', + location: 'us-central1', + backend: new VertexAIBackend(), +}; describe('stream-reader', () => { describe('getResponseStream', () => { @@ -86,7 +96,7 @@ describe('stream-reader', () => { it('streaming response - short', async () => { const fakeResponse = getMockResponseStreaming('streaming-success-basic-reply-short.txt'); - const result = processStream(fakeResponse as Response); + const result = processStream(fakeResponse as Response, fakeApiSettings); for await (const response of result.stream) { expect(response.text()).not.toBe(''); } @@ -96,7 +106,7 @@ describe('stream-reader', () => { it('streaming response - functioncall', async () => { const fakeResponse = getMockResponseStreaming('streaming-success-function-call-short.txt'); - const result = processStream(fakeResponse as Response); + const result = processStream(fakeResponse as Response, fakeApiSettings); for await (const response of result.stream) { expect(response.text()).toBe(''); expect(response.functionCalls()).toEqual([ @@ -118,7 +128,7 @@ describe('stream-reader', () => { it('handles citations', async () => { const fakeResponse = getMockResponseStreaming('streaming-success-citations.txt'); - const result = processStream(fakeResponse as Response); + const result = processStream(fakeResponse as Response, fakeApiSettings); const aggregatedResponse = await result.response; expect(aggregatedResponse.text()).toContain('Quantum mechanics is'); expect(aggregatedResponse.candidates?.[0]!.citationMetadata?.citations.length).toBe(3); @@ -134,7 +144,7 @@ describe('stream-reader', () => { it('removes empty text parts', async () => { const fakeResponse = getMockResponseStreaming('streaming-success-empty-text-part.txt'); - const result = processStream(fakeResponse as Response); + const result = processStream(fakeResponse as Response, fakeApiSettings); const aggregatedResponse = await result.response; expect(aggregatedResponse.text()).toBe('1'); expect(aggregatedResponse.candidates?.length).toBe(1); @@ -358,8 +368,8 @@ describe('stream-reader', () => { try { aggregateResponses(responsesToAggregate); } catch (e) { - expect((e as VertexAIError).code).toBe(VertexAIErrorCode.INVALID_CONTENT); - expect((e as VertexAIError).message).toContain( + expect((e as AIError).code).toBe(AIErrorCode.INVALID_CONTENT); + expect((e as AIError).message).toContain( 'Part should have at least one property, but there are none. This is likely caused ' + 'by a malformed response from the backend.', ); From b56cd784893ed5950a73be8bd38cee9f82dd650e Mon Sep 17 00:00:00 2001 From: russellwheatley Date: Wed, 4 Jun 2025 13:11:37 +0100 Subject: [PATCH 27/85] test: update convert mocks in line with latest mock response repo --- .../ai/__tests__/test-utils/convert-mocks.ts | 24 +++++++++++-------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/packages/ai/__tests__/test-utils/convert-mocks.ts b/packages/ai/__tests__/test-utils/convert-mocks.ts index 97a5ed75df..76bfd6a8aa 100644 --- a/packages/ai/__tests__/test-utils/convert-mocks.ts +++ b/packages/ai/__tests__/test-utils/convert-mocks.ts @@ -20,7 +20,7 @@ const fs = require('fs'); // eslint-disable-next-line @typescript-eslint/no-require-imports const { join } = require('path'); -function findMockResponseDir(): string { +function findMockResponseDir(backend: string): string { const directories = fs .readdirSync(__dirname, { withFileTypes: true }) .filter( @@ -36,18 +36,22 @@ function findMockResponseDir(): string { throw new Error('Multiple directories starting with "vertexai-sdk-test-data*" found'); } - return join(__dirname, directories[0], 'mock-responses', 'vertexai'); + return join(__dirname, directories[0], 'mock-responses', backend); } async function main(): Promise { - const mockResponseDir = findMockResponseDir(); - const list = fs.readdirSync(mockResponseDir); - const lookup: Record = {}; - // eslint-disable-next-line guard-for-in - for (const fileName of list) { - console.log(`attempting to read ${mockResponseDir}/${fileName}`) - const fullText = fs.readFileSync(join(mockResponseDir, fileName), 'utf-8'); - lookup[fileName] = fullText; + const backendNames = ['googleai', 'vertexai']; + const lookup: Record> = {}; + + for (const backend of backendNames) { + const mockResponseDir = findMockResponseDir(backend); + const list = fs.readdirSync(mockResponseDir); + lookup[backend] = {}; + const backendLookup = lookup[backend]; + for (const fileName of list) { + const fullText = fs.readFileSync(join(mockResponseDir, fileName), 'utf-8'); + backendLookup[fileName] = fullText; + } } let fileText = `// Generated from mocks text files.`; From 25dbed08438a2c535fbe9847dfb8435654d847d9 Mon Sep 17 00:00:00 2001 From: russellwheatley Date: Wed, 4 Jun 2025 13:12:14 +0100 Subject: [PATCH 28/85] test(ai): update script name to run --- packages/ai/package.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/ai/package.json b/packages/ai/package.json index 6b230fa141..dedb3e2d0e 100644 --- a/packages/ai/package.json +++ b/packages/ai/package.json @@ -10,7 +10,7 @@ "build": "genversion --esm --semi lib/version.ts", "build:clean": "rimraf dist", "compile": "bob build", - "prepare": "yarn tests:vertex:mocks && yarn run build && yarn compile" + "prepare": "yarn tests:ai:mocks && yarn run build && yarn compile" }, "repository": { "type": "git", From cb61e8c90300a3bcaf7d762f348c3c4f0523fb2a Mon Sep 17 00:00:00 2001 From: russellwheatley Date: Wed, 4 Jun 2025 13:17:09 +0100 Subject: [PATCH 29/85] test(ai): getMockResponse() --- packages/ai/__tests__/test-utils/mock-response.ts | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/packages/ai/__tests__/test-utils/mock-response.ts b/packages/ai/__tests__/test-utils/mock-response.ts index 52eb0eb04e..9e15fa2d8e 100644 --- a/packages/ai/__tests__/test-utils/mock-response.ts +++ b/packages/ai/__tests__/test-utils/mock-response.ts @@ -46,7 +46,6 @@ export function getMockResponseStreaming( const fullText = mocksLookup[filename]; return { - // Really tangled typescript error here from our transitive dependencies. // Ignoring it now, but uncomment and run `yarn lerna:prepare` in top-level // of the repo to see if you get it or if it has gone away. @@ -60,10 +59,17 @@ export function getMockResponseStreaming( }; } -export function getMockResponse(filename: string): Partial { - const fullText = mocksLookup[filename]; +type BackendName = 'vertexai' | 'googleai'; +export function getMockResponse(backendName: BackendName, filename: string): Partial { + // @ts-ignore + const backendMocksLookup: Record = mocksLookup[backendName]; + if (!(filename in backendMocksLookup)) { + throw Error(`${backendName} mock response file '${filename}' not found.`); + } + const fullText = backendMocksLookup[filename] as string; + return { ok: true, - json: () => Promise.resolve(JSON.parse(fullText!)), + json: () => Promise.resolve(JSON.parse(fullText)), }; } From b5c61dc4f3256148a6684cbfea4d964dc525c5ce Mon Sep 17 00:00:00 2001 From: russellwheatley Date: Wed, 4 Jun 2025 13:48:40 +0100 Subject: [PATCH 30/85] test(ai): mock-response update for new mocks --- .../ai/__tests__/test-utils/mock-response.ts | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/packages/ai/__tests__/test-utils/mock-response.ts b/packages/ai/__tests__/test-utils/mock-response.ts index 9e15fa2d8e..6e021428d5 100644 --- a/packages/ai/__tests__/test-utils/mock-response.ts +++ b/packages/ai/__tests__/test-utils/mock-response.ts @@ -17,6 +17,11 @@ import { ReadableStream } from 'web-streams-polyfill'; import { mocksLookup } from './mocks-lookup'; +export enum BackendName { + VertexAI = 'vertexai', + GoogleAI = 'googleai', +} + /** * Mock native Response.body * Streams contents of json file in 20 character chunks @@ -40,10 +45,16 @@ export function getChunkedStream(input: string, chunkLength = 20): ReadableStrea return stream; } export function getMockResponseStreaming( + backendName: BackendName, filename: string, chunkLength: number = 20, ): Partial { - const fullText = mocksLookup[filename]; + // @ts-ignore + const backendMocksLookup: Record = mocksLookup[backendName]; + if (!backendMocksLookup[filename]) { + throw Error(`${backendName} mock response file '${filename}' not found.`); + } + const fullText = backendMocksLookup[filename] as string; return { // Really tangled typescript error here from our transitive dependencies. @@ -59,11 +70,10 @@ export function getMockResponseStreaming( }; } -type BackendName = 'vertexai' | 'googleai'; export function getMockResponse(backendName: BackendName, filename: string): Partial { // @ts-ignore const backendMocksLookup: Record = mocksLookup[backendName]; - if (!(filename in backendMocksLookup)) { + if (!backendMocksLookup[filename]) { throw Error(`${backendName} mock response file '${filename}' not found.`); } const fullText = backendMocksLookup[filename] as string; From 214c23566f6a095d0008d9c47410e6d2159dd3ed Mon Sep 17 00:00:00 2001 From: russellwheatley Date: Wed, 4 Jun 2025 13:49:03 +0100 Subject: [PATCH 31/85] test(ai): count tokens unit tests --- packages/ai/__tests__/count-tokens.test.ts | 88 ++++++++++++++++++++-- 1 file changed, 82 insertions(+), 6 deletions(-) diff --git a/packages/ai/__tests__/count-tokens.test.ts b/packages/ai/__tests__/count-tokens.test.ts index d0c68dd61b..96faf96254 100644 --- a/packages/ai/__tests__/count-tokens.test.ts +++ b/packages/ai/__tests__/count-tokens.test.ts @@ -14,14 +14,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -import { describe, expect, it, afterEach, jest } from '@jest/globals'; -import { getMockResponse } from './test-utils/mock-response'; +import { describe, expect, it, afterEach, jest, beforeEach } from '@jest/globals'; +import { BackendName, getMockResponse } from './test-utils/mock-response'; import * as request from '../lib/requests/request'; import { countTokens } from '../lib/methods/count-tokens'; -import { CountTokensRequest } from '../lib/types'; +import { CountTokensRequest, RequestOptions } from '../lib/types'; import { ApiSettings } from '../lib/types/internal'; import { Task } from '../lib/requests/request'; import { GoogleAIBackend } from '../lib/backend'; +import { SpiedFunction } from 'jest-mock'; +import { mapCountTokensRequest } from '../lib/googleai-mappers'; const fakeApiSettings: ApiSettings = { apiKey: 'key', @@ -31,6 +33,14 @@ const fakeApiSettings: ApiSettings = { backend: new GoogleAIBackend(), }; +const fakeGoogleAIApiSettings: ApiSettings = { + apiKey: 'key', + project: 'my-project', + appId: 'my-appid', + location: '', + backend: new GoogleAIBackend(), +}; + const fakeRequestParams: CountTokensRequest = { contents: [{ parts: [{ text: 'hello' }], role: 'user' }], }; @@ -41,7 +51,7 @@ describe('countTokens()', () => { }); it('total tokens', async () => { - const mockResponse = getMockResponse('unary-success-total-tokens.json'); + const mockResponse = getMockResponse(BackendName.VertexAI, 'unary-success-total-tokens.json'); const makeRequestStub = jest .spyOn(request, 'makeRequest') .mockResolvedValue(mockResponse as Response); @@ -58,8 +68,35 @@ describe('countTokens()', () => { ); }); + it('total tokens with modality details', async () => { + const mockResponse = getMockResponse( + BackendName.VertexAI, + 'unary-success-detailed-token-response.json', + ); + const makeRequestStub = jest + .spyOn(request, 'makeRequest') + .mockResolvedValue(mockResponse as Response); + const result = await countTokens(fakeApiSettings, 'model', fakeRequestParams); + + expect(result.totalTokens).toBe(1837); + expect(result.totalBillableCharacters).toBe(117); + expect(result.promptTokensDetails?.[0]?.modality).toBe('IMAGE'); + expect(result.promptTokensDetails?.[0]?.tokenCount).toBe(1806); + expect(makeRequestStub).toHaveBeenCalledWith( + 'model', + Task.COUNT_TOKENS, + fakeApiSettings, + false, + expect.stringContaining('contents'), + undefined, + ); + }); + it('total tokens no billable characters', async () => { - const mockResponse = getMockResponse('unary-success-no-billable-characters.json'); + const mockResponse = getMockResponse( + BackendName.VertexAI, + 'unary-success-no-billable-characters.json', + ); const makeRequestStub = jest .spyOn(request, 'makeRequest') .mockResolvedValue(mockResponse as Response); @@ -77,7 +114,10 @@ describe('countTokens()', () => { }); it('model not found', async () => { - const mockResponse = getMockResponse('unary-failure-model-not-found.json'); + const mockResponse = getMockResponse( + BackendName.VertexAI, + 'unary-failure-model-not-found.json', + ); const mockFetch = jest.spyOn(globalThis, 'fetch').mockResolvedValue({ ok: false, status: 404, @@ -88,4 +128,40 @@ describe('countTokens()', () => { ); expect(mockFetch).toHaveBeenCalled(); }); + + describe('googleAI', () => { + let makeRequestStub: SpiedFunction< + ( + model: string, + task: Task, + apiSettings: ApiSettings, + stream: boolean, + body: string, + requestOptions?: RequestOptions, + ) => Promise + >; + + beforeEach(() => { + makeRequestStub = jest.spyOn(request, 'makeRequest'); + }); + + afterEach(() => { + jest.restoreAllMocks(); + }); + + it('maps request to GoogleAI format', async () => { + makeRequestStub.mockResolvedValue({ ok: true, json: () => {} } as Response); // Unused + + await countTokens(fakeGoogleAIApiSettings, 'model', fakeRequestParams); + + expect(makeRequestStub).toHaveBeenCalledWith( + 'model', + Task.COUNT_TOKENS, + fakeGoogleAIApiSettings, + false, + JSON.stringify(mapCountTokensRequest(fakeRequestParams, 'model')), + undefined, + ); + }); + }); }); From cf2bbd7a87fce67103799871c8a72c44b81e3776 Mon Sep 17 00:00:00 2001 From: russellwheatley Date: Wed, 4 Jun 2025 14:04:57 +0100 Subject: [PATCH 32/85] test: generate-content mock response --- .../ai/__tests__/generate-content.test.ts | 53 ++++++++++++++----- 1 file changed, 40 insertions(+), 13 deletions(-) diff --git a/packages/ai/__tests__/generate-content.test.ts b/packages/ai/__tests__/generate-content.test.ts index afa9249c12..48c2b19970 100644 --- a/packages/ai/__tests__/generate-content.test.ts +++ b/packages/ai/__tests__/generate-content.test.ts @@ -15,7 +15,7 @@ * limitations under the License. */ import { describe, expect, it, afterEach, jest, beforeEach } from '@jest/globals'; -import { getMockResponse } from './test-utils/mock-response'; +import { BackendName, getMockResponse } from './test-utils/mock-response'; import * as request from '../lib/requests/request'; import { generateContent } from '../lib/methods/generate-content'; import { @@ -81,7 +81,10 @@ describe('generateContent()', () => { }); it('short response', async () => { - const mockResponse = getMockResponse('unary-success-basic-reply-short.json'); + const mockResponse = getMockResponse( + BackendName.VertexAI, + 'unary-success-basic-reply-short.json', + ); const makeRequestStub = jest .spyOn(request, 'makeRequest') .mockResolvedValue(mockResponse as Response); @@ -98,7 +101,10 @@ describe('generateContent()', () => { }); it('long response', async () => { - const mockResponse = getMockResponse('unary-success-basic-reply-long.json'); + const mockResponse = getMockResponse( + BackendName.VertexAI, + 'unary-success-basic-reply-long.json', + ); const makeRequestStub = jest .spyOn(request, 'makeRequest') .mockResolvedValue(mockResponse as Response); @@ -116,7 +122,10 @@ describe('generateContent()', () => { }); it('long response with token details', async () => { - const mockResponse = getMockResponse('unary-success-basic-response-long-usage-metadata.json'); + const mockResponse = getMockResponse( + BackendName.VertexAI, + 'unary-success-basic-response-long-usage-metadata.json', + ); const makeRequestStub = jest .spyOn(request, 'makeRequest') .mockResolvedValue(mockResponse as Response); @@ -138,7 +147,7 @@ describe('generateContent()', () => { }); it('citations', async () => { - const mockResponse = getMockResponse('unary-success-citations.json'); + const mockResponse = getMockResponse(BackendName.VertexAI, 'unary-success-citations.json'); const makeRequestStub = jest .spyOn(request, 'makeRequest') .mockResolvedValue(mockResponse as Response); @@ -156,7 +165,10 @@ describe('generateContent()', () => { }); it('blocked prompt', async () => { - const mockResponse = getMockResponse('unary-failure-prompt-blocked-safety.json'); + const mockResponse = getMockResponse( + BackendName.VertexAI, + 'unary-failure-prompt-blocked-safety.json', + ); const makeRequestStub = jest .spyOn(request, 'makeRequest') .mockResolvedValue(mockResponse as Response); @@ -175,7 +187,10 @@ describe('generateContent()', () => { }); it('finishReason safety', async () => { - const mockResponse = getMockResponse('unary-failure-finish-reason-safety.json'); + const mockResponse = getMockResponse( + BackendName.VertexAI, + 'unary-failure-finish-reason-safety.json', + ); const makeRequestStub = jest .spyOn(request, 'makeRequest') .mockResolvedValue(mockResponse as Response); @@ -192,7 +207,7 @@ describe('generateContent()', () => { }); it('empty content', async () => { - const mockResponse = getMockResponse('unary-failure-empty-content.json'); + const mockResponse = getMockResponse(BackendName.VertexAI, 'unary-failure-empty-content.json'); const makeRequestStub = jest .spyOn(request, 'makeRequest') .mockResolvedValue(mockResponse as Response); @@ -209,7 +224,10 @@ describe('generateContent()', () => { }); it('unknown enum - should ignore', async () => { - const mockResponse = getMockResponse('unary-success-unknown-enum-safety-ratings.json'); + const mockResponse = getMockResponse( + BackendName.VertexAI, + 'unary-success-unknown-enum-safety-ratings.json', + ); const makeRequestStub = jest .spyOn(request, 'makeRequest') .mockResolvedValue(mockResponse as Response); @@ -226,7 +244,7 @@ describe('generateContent()', () => { }); it('image rejected (400)', async () => { - const mockResponse = getMockResponse('unary-failure-image-rejected.json'); + const mockResponse = getMockResponse(BackendName.VertexAI, 'unary-failure-image-rejected.json'); const mockFetch = jest.spyOn(globalThis, 'fetch').mockResolvedValue({ ok: false, status: 400, @@ -239,7 +257,10 @@ describe('generateContent()', () => { }); it('api not enabled (403)', async () => { - const mockResponse = getMockResponse('unary-failure-firebasevertexai-api-not-enabled.json'); + const mockResponse = getMockResponse( + BackendName.VertexAI, + 'unary-failure-firebasevertexai-api-not-enabled.json', + ); const mockFetch = jest.spyOn(globalThis, 'fetch').mockResolvedValue({ ok: false, status: 403, @@ -263,7 +284,10 @@ describe('generateContent()', () => { }); it('throws error when method is defined', async () => { - const mockResponse = getMockResponse('unary-success-basic-reply-short.txt'); + const mockResponse = getMockResponse( + BackendName.GoogleAI, + 'unary-success-basic-reply-short.txt', + ); makeRequestStub.mockResolvedValue(mockResponse as Response); const requestParamsWithMethod: GenerateContentRequest = { @@ -285,7 +309,10 @@ describe('generateContent()', () => { }); it('maps request to GoogleAI format', async () => { - const mockResponse = getMockResponse('unary-success-basic-reply-short.txt'); + const mockResponse = getMockResponse( + BackendName.GoogleAI, + 'unary-success-basic-reply-short.txt', + ); makeRequestStub.mockResolvedValue(mockResponse as Response); await generateContent(fakeGoogleAIApiSettings, 'model', fakeGoogleAIRequestParams); From 10d081df1d5fdcdfde2f84c3ce6b054453845d64 Mon Sep 17 00:00:00 2001 From: russellwheatley Date: Wed, 4 Jun 2025 14:44:12 +0100 Subject: [PATCH 33/85] test(ai): update unit tests to use updated mocks --- .../ai/__tests__/generative-model.test.ts | 34 ++++++++++++++----- .../ai/__tests__/googleai-mappers.test.ts | 4 +-- packages/ai/__tests__/request.test.ts | 7 ++-- packages/ai/__tests__/stream-reader.test.ts | 26 +++++++++++--- 4 files changed, 54 insertions(+), 17 deletions(-) diff --git a/packages/ai/__tests__/generative-model.test.ts b/packages/ai/__tests__/generative-model.test.ts index cab6321586..7d29c501c8 100644 --- a/packages/ai/__tests__/generative-model.test.ts +++ b/packages/ai/__tests__/generative-model.test.ts @@ -19,7 +19,7 @@ import { type ReactNativeFirebase } from '@react-native-firebase/app'; import { GenerativeModel } from '../lib/models/generative-model'; import { AI, FunctionCallingMode } from '../lib/public-types'; import * as request from '../lib/requests/request'; -import { getMockResponse } from './test-utils/mock-response'; +import { BackendName, getMockResponse } from './test-utils/mock-response'; import { VertexAIBackend } from '../lib/backend'; const fakeAI: AI = { @@ -56,7 +56,10 @@ describe('GenerativeModel', () => { expect(genModel.tools?.length).toBe(1); expect(genModel.toolConfig?.functionCallingConfig?.mode).toBe(FunctionCallingMode.NONE); expect(genModel.systemInstruction?.parts[0]!.text).toBe('be friendly'); - const mockResponse = getMockResponse('unary-success-basic-reply-short.json'); + const mockResponse = getMockResponse( + BackendName.VertexAI, + 'unary-success-basic-reply-short.json', + ); const makeRequestStub = jest .spyOn(request, 'makeRequest') .mockResolvedValue(mockResponse as Response); @@ -78,7 +81,10 @@ describe('GenerativeModel', () => { systemInstruction: 'be friendly', }); expect(genModel.systemInstruction?.parts[0]!.text).toBe('be friendly'); - const mockResponse = getMockResponse('unary-success-basic-reply-short.json'); + const mockResponse = getMockResponse( + BackendName.VertexAI, + 'unary-success-basic-reply-short.json', + ); const makeRequestStub = jest .spyOn(request, 'makeRequest') .mockResolvedValue(mockResponse as Response); @@ -113,7 +119,10 @@ describe('GenerativeModel', () => { expect(genModel.tools?.length).toBe(1); expect(genModel.toolConfig?.functionCallingConfig?.mode).toBe(FunctionCallingMode.NONE); expect(genModel.systemInstruction?.parts[0]!.text).toBe('be friendly'); - const mockResponse = getMockResponse('unary-success-basic-reply-short.json'); + const mockResponse = getMockResponse( + BackendName.VertexAI, + 'unary-success-basic-reply-short.json', + ); const makeRequestStub = jest .spyOn(request, 'makeRequest') .mockResolvedValue(mockResponse as Response); @@ -178,7 +187,10 @@ describe('GenerativeModel', () => { expect(genModel.tools?.length).toBe(1); expect(genModel.toolConfig?.functionCallingConfig?.mode).toBe(FunctionCallingMode.NONE); expect(genModel.systemInstruction?.parts[0]!.text).toBe('be friendly'); - const mockResponse = getMockResponse('unary-success-basic-reply-short.json'); + const mockResponse = getMockResponse( + BackendName.VertexAI, + 'unary-success-basic-reply-short.json', + ); const makeRequestStub = jest .spyOn(request, 'makeRequest') .mockResolvedValue(mockResponse as Response); @@ -200,7 +212,10 @@ describe('GenerativeModel', () => { systemInstruction: 'be friendly', }); expect(genModel.systemInstruction?.parts[0]!.text).toBe('be friendly'); - const mockResponse = getMockResponse('unary-success-basic-reply-short.json'); + const mockResponse = getMockResponse( + BackendName.VertexAI, + 'unary-success-basic-reply-short.json', + ); const makeRequestStub = jest .spyOn(request, 'makeRequest') .mockResolvedValue(mockResponse as Response); @@ -226,7 +241,10 @@ describe('GenerativeModel', () => { expect(genModel.tools?.length).toBe(1); expect(genModel.toolConfig?.functionCallingConfig?.mode).toBe(FunctionCallingMode.NONE); expect(genModel.systemInstruction?.parts[0]!.text).toBe('be friendly'); - const mockResponse = getMockResponse('unary-success-basic-reply-short.json'); + const mockResponse = getMockResponse( + BackendName.VertexAI, + 'unary-success-basic-reply-short.json', + ); const makeRequestStub = jest .spyOn(request, 'makeRequest') .mockResolvedValue(mockResponse as Response); @@ -256,7 +274,7 @@ describe('GenerativeModel', () => { it('calls countTokens', async () => { const genModel = new GenerativeModel(fakeAI, { model: 'my-model' }); - const mockResponse = getMockResponse('unary-success-total-tokens.json'); + const mockResponse = getMockResponse(BackendName.VertexAI, 'unary-success-total-tokens.json'); const makeRequestStub = jest .spyOn(request, 'makeRequest') .mockResolvedValue(mockResponse as Response); diff --git a/packages/ai/__tests__/googleai-mappers.test.ts b/packages/ai/__tests__/googleai-mappers.test.ts index 50efdf4346..c418b4408c 100644 --- a/packages/ai/__tests__/googleai-mappers.test.ts +++ b/packages/ai/__tests__/googleai-mappers.test.ts @@ -41,7 +41,7 @@ import { PromptFeedback, SafetyRating, } from '../lib/public-types'; -import { getMockResponse } from './test-utils/mock-response'; +import { BackendName, getMockResponse } from './test-utils/mock-response'; import { SpiedFunction } from 'jest-mock'; const fakeModel = 'models/gemini-pro'; @@ -129,7 +129,7 @@ describe('Google AI Mappers', () => { describe('mapGenerateContentResponse', () => { it('should map a full Google AI response', async () => { const googleAIMockResponse: GoogleAIGenerateContentResponse = await ( - getMockResponse('unary-success-citations.json') as Response + getMockResponse(BackendName.GoogleAI, 'unary-success-citations.json') as Response ).json(); const mappedResponse = mapGenerateContentResponse(googleAIMockResponse); diff --git a/packages/ai/__tests__/request.test.ts b/packages/ai/__tests__/request.test.ts index b52cc1da27..54b3378492 100644 --- a/packages/ai/__tests__/request.test.ts +++ b/packages/ai/__tests__/request.test.ts @@ -20,7 +20,7 @@ import { ApiSettings } from '../lib/types/internal'; import { DEFAULT_API_VERSION } from '../lib/constants'; import { AIErrorCode } from '../lib/types'; import { AIError } from '../lib/errors'; -import { getMockResponse } from './test-utils/mock-response'; +import { BackendName, getMockResponse } from './test-utils/mock-response'; import { VertexAIBackend } from '../lib/backend'; const fakeApiSettings: ApiSettings = { @@ -350,7 +350,10 @@ describe('request methods', () => { }); it('Network error, API not enabled', async () => { - const mockResponse = getMockResponse('unary-failure-firebasevertexai-api-not-enabled.json'); + const mockResponse = getMockResponse( + BackendName.VertexAI, + 'unary-failure-firebasevertexai-api-not-enabled.json', + ); const fetchMock = jest.spyOn(globalThis, 'fetch').mockResolvedValue(mockResponse as Response); try { await makeRequest('models/model-name', Task.GENERATE_CONTENT, fakeApiSettings, false, ''); diff --git a/packages/ai/__tests__/stream-reader.test.ts b/packages/ai/__tests__/stream-reader.test.ts index 5e695316de..345e32cc6e 100644 --- a/packages/ai/__tests__/stream-reader.test.ts +++ b/packages/ai/__tests__/stream-reader.test.ts @@ -22,7 +22,11 @@ import { processStream, } from '../lib/requests/stream-reader'; -import { getChunkedStream, getMockResponseStreaming } from './test-utils/mock-response'; +import { + BackendName, + getChunkedStream, + getMockResponseStreaming, +} from './test-utils/mock-response'; import { BlockReason, FinishReason, @@ -95,7 +99,10 @@ describe('stream-reader', () => { }); it('streaming response - short', async () => { - const fakeResponse = getMockResponseStreaming('streaming-success-basic-reply-short.txt'); + const fakeResponse = getMockResponseStreaming( + BackendName.VertexAI, + 'streaming-success-basic-reply-short.txt', + ); const result = processStream(fakeResponse as Response, fakeApiSettings); for await (const response of result.stream) { expect(response.text()).not.toBe(''); @@ -105,7 +112,10 @@ describe('stream-reader', () => { }); it('streaming response - functioncall', async () => { - const fakeResponse = getMockResponseStreaming('streaming-success-function-call-short.txt'); + const fakeResponse = getMockResponseStreaming( + BackendName.VertexAI, + 'streaming-success-function-call-short.txt', + ); const result = processStream(fakeResponse as Response, fakeApiSettings); for await (const response of result.stream) { expect(response.text()).toBe(''); @@ -127,7 +137,10 @@ describe('stream-reader', () => { }); it('handles citations', async () => { - const fakeResponse = getMockResponseStreaming('streaming-success-citations.txt'); + const fakeResponse = getMockResponseStreaming( + BackendName.VertexAI, + 'streaming-success-citations.txt', + ); const result = processStream(fakeResponse as Response, fakeApiSettings); const aggregatedResponse = await result.response; expect(aggregatedResponse.text()).toContain('Quantum mechanics is'); @@ -143,7 +156,10 @@ describe('stream-reader', () => { }); it('removes empty text parts', async () => { - const fakeResponse = getMockResponseStreaming('streaming-success-empty-text-part.txt'); + const fakeResponse = getMockResponseStreaming( + BackendName.VertexAI, + 'streaming-success-empty-text-part.txt', + ); const result = processStream(fakeResponse as Response, fakeApiSettings); const aggregatedResponse = await result.response; expect(aggregatedResponse.text()).toBe('1'); From 6541dee01aeeaf9942b436d12a01c22afb86c89e Mon Sep 17 00:00:00 2001 From: russellwheatley Date: Wed, 4 Jun 2025 15:17:46 +0100 Subject: [PATCH 34/85] test: fix unit tests --- packages/ai/__tests__/api.test.ts | 16 ++++++++-------- packages/ai/__tests__/googleai-mappers.test.ts | 4 ++-- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/packages/ai/__tests__/api.test.ts b/packages/ai/__tests__/api.test.ts index 79d9c24c92..9ff9163ba1 100644 --- a/packages/ai/__tests__/api.test.ts +++ b/packages/ai/__tests__/api.test.ts @@ -47,8 +47,8 @@ describe('Top level API', () => { } catch (e) { expect((e as AIError).code).toContain(AIErrorCode.NO_MODEL); expect((e as AIError).message).toContain( - `VertexAI: Must provide a model name. Example: ` + - `getGenerativeModel({ model: 'my-model-name' }) (vertexAI/${AIErrorCode.NO_MODEL})`, + `AI: Must provide a model name. Example: ` + + `getGenerativeModel({ model: 'my-model-name' }) (${AI_TYPE}/${AIErrorCode.NO_MODEL})`, ); } }); @@ -63,9 +63,9 @@ describe('Top level API', () => { } catch (e) { expect((e as AIError).code).toContain(AIErrorCode.NO_API_KEY); expect((e as AIError).message).toBe( - `VertexAI: The "apiKey" field is empty in the local ` + - `Firebase config. Firebase VertexAI requires this field to` + - ` contain a valid API key. (vertexAI/${AIErrorCode.NO_API_KEY})`, + `AI: The "apiKey" field is empty in the local ` + + `Firebase config. Firebase AI requires this field to` + + ` contain a valid API key. (${AI_TYPE}/${AIErrorCode.NO_API_KEY})`, ); } }); @@ -80,9 +80,9 @@ describe('Top level API', () => { } catch (e) { expect((e as AIError).code).toContain(AIErrorCode.NO_PROJECT_ID); expect((e as AIError).message).toBe( - `VertexAI: The "projectId" field is empty in the local` + - ` Firebase config. Firebase VertexAI requires this field ` + - `to contain a valid project ID. (vertexAI/${AIErrorCode.NO_PROJECT_ID})`, + `AI: The "projectId" field is empty in the local` + + ` Firebase config. Firebase AI requires this field ` + + `to contain a valid project ID. (${AI_TYPE}/${AIErrorCode.NO_PROJECT_ID})`, ); } }); diff --git a/packages/ai/__tests__/googleai-mappers.test.ts b/packages/ai/__tests__/googleai-mappers.test.ts index c418b4408c..15cb297cd0 100644 --- a/packages/ai/__tests__/googleai-mappers.test.ts +++ b/packages/ai/__tests__/googleai-mappers.test.ts @@ -76,7 +76,7 @@ describe('Google AI Mappers', () => { }; const error = new AIError( AIErrorCode.UNSUPPORTED, - 'SafetySettings.method is not supported in requests to the Gemini Developer API', + 'AI: SafetySetting.method is not supported in the the Gemini Developer API. Please remove this property. (AI/unsupported)', ); expect(() => mapGenerateContentRequest(request)).toThrowError(error); }); @@ -90,7 +90,7 @@ describe('Google AI Mappers', () => { }; const mappedRequest = mapGenerateContentRequest(request); expect(loggerWarnSpy).toHaveBeenCalledWith( - 'topK in GenerationConfig has been rounded to the nearest integer to match the format for requests to the Gemini Developer API.', + expect.stringContaining('topK in GenerationConfig has been rounded to the nearest integer'), ); expect(mappedRequest.generationConfig?.topK).toBe(16); }); From 410257402f7e875cd9a11fea37b6c4813f76f0f4 Mon Sep 17 00:00:00 2001 From: russellwheatley Date: Wed, 4 Jun 2025 15:29:47 +0100 Subject: [PATCH 35/85] test: fixed another unit test suite --- .../ai/__tests__/googleai-mappers.test.ts | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/packages/ai/__tests__/googleai-mappers.test.ts b/packages/ai/__tests__/googleai-mappers.test.ts index 15cb297cd0..00ecd39635 100644 --- a/packages/ai/__tests__/googleai-mappers.test.ts +++ b/packages/ai/__tests__/googleai-mappers.test.ts @@ -15,7 +15,6 @@ * limitations under the License. */ import { describe, it, expect, beforeEach, afterEach, jest } from '@jest/globals'; -import { AIError } from '../lib/errors'; import { mapCountTokensRequest, mapGenerateContentCandidates, @@ -74,11 +73,14 @@ describe('Google AI Mappers', () => { }, ], }; - const error = new AIError( - AIErrorCode.UNSUPPORTED, - 'AI: SafetySetting.method is not supported in the the Gemini Developer API. Please remove this property. (AI/unsupported)', + + expect(() => mapGenerateContentRequest(request)).toThrowError( + expect.objectContaining({ + code: AIErrorCode.UNSUPPORTED, + message: + 'AI: SafetySetting.method is not supported in the the Gemini Developer API. Please remove this property. (AI/unsupported)', + }), ); - expect(() => mapGenerateContentRequest(request)).toThrowError(error); }); it('should warn and round topK if present', () => { @@ -90,6 +92,7 @@ describe('Google AI Mappers', () => { }; const mappedRequest = mapGenerateContentRequest(request); expect(loggerWarnSpy).toHaveBeenCalledWith( + expect.any(String), // First argument (timestamp) expect.stringContaining('topK in GenerationConfig has been rounded to the nearest integer'), ); expect(mappedRequest.generationConfig?.topK).toBe(16); @@ -299,7 +302,11 @@ describe('Google AI Mappers', () => { }, ]; expect(() => mapGenerateContentCandidates(candidates)).toThrowError( - new AIError(AIErrorCode.UNSUPPORTED, 'Part.videoMetadata is not supported'), + expect.objectContaining({ + code: AIErrorCode.UNSUPPORTED, + message: + 'AI: Part.videoMetadata is not supported in the Gemini Developer API. Please remove this property. (AI/unsupported)', + }), ); }); From 6fb7bda0593fbae6f1db418f8e0b6f96d0cfbe4b Mon Sep 17 00:00:00 2001 From: russellwheatley Date: Wed, 4 Jun 2025 15:30:11 +0100 Subject: [PATCH 36/85] test: add TS no-check to generated mocks --- packages/ai/__tests__/test-utils/convert-mocks.ts | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/packages/ai/__tests__/test-utils/convert-mocks.ts b/packages/ai/__tests__/test-utils/convert-mocks.ts index 76bfd6a8aa..87e18a478e 100644 --- a/packages/ai/__tests__/test-utils/convert-mocks.ts +++ b/packages/ai/__tests__/test-utils/convert-mocks.ts @@ -53,9 +53,10 @@ async function main(): Promise { backendLookup[fileName] = fullText; } } - let fileText = `// Generated from mocks text files.`; + let fileText = `// Generated from mocks text files. Do not edit.`; fileText += '\n\n'; + fileText += `// @ts-nocheck\n`; fileText += `export const mocksLookup: Record = ${JSON.stringify( lookup, null, From f210ac449ed39c7888953166ddbe855a72ba303c Mon Sep 17 00:00:00 2001 From: russellwheatley Date: Thu, 5 Jun 2025 10:28:03 +0100 Subject: [PATCH 37/85] refactor(vertexai): initial setup to use ai package --- packages/ai/lib/index.ts | 12 +++++++++--- packages/vertexai/lib/index.ts | 20 +++++++++++++++----- packages/vertexai/tsconfig.json | 5 +++-- tsconfig-jest.json | 1 + 4 files changed, 28 insertions(+), 10 deletions(-) diff --git a/packages/ai/lib/index.ts b/packages/ai/lib/index.ts index 0d5281fbc4..88622d850c 100644 --- a/packages/ai/lib/index.ts +++ b/packages/ai/lib/index.ts @@ -25,9 +25,15 @@ import { GenerativeModel } from './models/generative-model'; export { ChatSession } from './methods/chat-session'; export * from './requests/schema-builder'; -export { GenerativeModel }; - -export { AIError }; +export { + GenerativeModel, + AIError, + AIErrorCode, + ModelParams, + RequestOptions, + GoogleAIBackend, + VertexAIBackend, +}; /** * Returns the default {@link AI} instance that is associated with the provided diff --git a/packages/vertexai/lib/index.ts b/packages/vertexai/lib/index.ts index 580f1bc86b..dfd3308e40 100644 --- a/packages/vertexai/lib/index.ts +++ b/packages/vertexai/lib/index.ts @@ -17,17 +17,22 @@ import './polyfills'; import { getApp, ReactNativeFirebase } from '@react-native-firebase/app'; -import { ModelParams, RequestOptions, VertexAIErrorCode } from './types'; +import { + getGenerativeModel as getGenerativeModelFromAI, + getAI, + VertexAIBackend, + GenerativeModel, + RequestOptions, + ModelParams, +} from '@react-native-firebase/ai'; +import { VertexAIErrorCode } from './types'; import { DEFAULT_LOCATION } from './constants'; import { VertexAI, VertexAIOptions } from './public-types'; import { VertexAIError } from './errors'; -import { GenerativeModel } from './models/generative-model'; import { VertexAIService } from './service'; export { ChatSession } from './methods/chat-session'; export * from './requests/schema-builder'; -export { GenerativeModel }; - export { VertexAIError }; /** @@ -69,5 +74,10 @@ export function getGenerativeModel( `Must provide a model name. Example: getGenerativeModel({ model: 'my-model-name' })`, ); } - return new GenerativeModel(vertexAI, modelParams, requestOptions); + + const ai = getAI(vertexAI.app, { + backend: new VertexAIBackend(vertexAI.location), + }); + + return getGenerativeModelFromAI(ai, modelParams, requestOptions); } diff --git a/packages/vertexai/tsconfig.json b/packages/vertexai/tsconfig.json index f1d9865812..0eace61ca4 100644 --- a/packages/vertexai/tsconfig.json +++ b/packages/vertexai/tsconfig.json @@ -1,6 +1,5 @@ { "compilerOptions": { - "rootDir": ".", "allowUnreachableCode": false, "allowUnusedLabels": false, "esModuleInterop": true, @@ -27,6 +26,8 @@ "@react-native-firebase/app": ["../app/lib"], "@react-native-firebase/auth": ["../auth/lib"], "@react-native-firebase/app-check": ["../app-check/lib"], + "@react-native-firebase/ai": ["../ai/lib"], } - } + }, + "include": ["lib/**/*", "../ai/lib/**/*"] } diff --git a/tsconfig-jest.json b/tsconfig-jest.json index 0149111b06..8a42f66917 100644 --- a/tsconfig-jest.json +++ b/tsconfig-jest.json @@ -6,6 +6,7 @@ "@react-native-firebase/app": ["packages/app/lib"], "@react-native-firebase/auth": ["packages/auth/lib"], "@react-native-firebase/app-check": ["packages/app-check/lib"], + "@react-native-firebase/ai": ["packages/ai/lib"], } } } From 01f9da0483a2c93b65bfd2926436f2956ad8fd01 Mon Sep 17 00:00:00 2001 From: russellwheatley Date: Thu, 5 Jun 2025 11:07:36 +0100 Subject: [PATCH 38/85] refactor: remove vertexai code and wrap around firebase ai package --- packages/ai/lib/index.ts | 15 +- packages/ai/lib/public-types.ts | 24 ++ packages/ai/lib/types/error.ts | 72 +++++ packages/vertexai/lib/constants.ts | 2 - packages/vertexai/lib/errors.ts | 52 ---- packages/vertexai/lib/index.ts | 11 +- .../lib/methods/chat-session-helpers.ts | 116 -------- packages/vertexai/lib/methods/chat-session.ts | 182 ------------ packages/vertexai/lib/methods/count-tokens.ts | 37 --- .../vertexai/lib/methods/generate-content.ts | 66 ---- .../vertexai/lib/models/generative-model.ts | 180 ----------- packages/vertexai/lib/public-types.ts | 46 --- .../vertexai/lib/requests/request-helpers.ts | 116 -------- packages/vertexai/lib/requests/request.ts | 242 --------------- .../vertexai/lib/requests/response-helpers.ts | 186 ------------ .../vertexai/lib/requests/schema-builder.ts | 281 ------------------ .../vertexai/lib/requests/stream-reader.ts | 213 ------------- packages/vertexai/lib/service.ts | 2 +- packages/vertexai/lib/types/content.ts | 162 ---------- packages/vertexai/lib/types/enums.ts | 149 ---------- packages/vertexai/lib/types/error.ts | 98 ------ packages/vertexai/lib/types/index.ts | 23 -- packages/vertexai/lib/types/internal.ts | 25 -- packages/vertexai/lib/types/polyfills.d.ts | 15 - packages/vertexai/lib/types/requests.ts | 198 ------------ packages/vertexai/lib/types/responses.ts | 209 ------------- packages/vertexai/lib/types/schema.ts | 104 ------- 27 files changed, 107 insertions(+), 2719 deletions(-) delete mode 100644 packages/vertexai/lib/errors.ts delete mode 100644 packages/vertexai/lib/methods/chat-session-helpers.ts delete mode 100644 packages/vertexai/lib/methods/chat-session.ts delete mode 100644 packages/vertexai/lib/methods/count-tokens.ts delete mode 100644 packages/vertexai/lib/methods/generate-content.ts delete mode 100644 packages/vertexai/lib/models/generative-model.ts delete mode 100644 packages/vertexai/lib/public-types.ts delete mode 100644 packages/vertexai/lib/requests/request-helpers.ts delete mode 100644 packages/vertexai/lib/requests/request.ts delete mode 100644 packages/vertexai/lib/requests/response-helpers.ts delete mode 100644 packages/vertexai/lib/requests/schema-builder.ts delete mode 100644 packages/vertexai/lib/requests/stream-reader.ts delete mode 100644 packages/vertexai/lib/types/content.ts delete mode 100644 packages/vertexai/lib/types/enums.ts delete mode 100644 packages/vertexai/lib/types/error.ts delete mode 100644 packages/vertexai/lib/types/index.ts delete mode 100644 packages/vertexai/lib/types/internal.ts delete mode 100644 packages/vertexai/lib/types/polyfills.d.ts delete mode 100644 packages/vertexai/lib/types/requests.ts delete mode 100644 packages/vertexai/lib/types/responses.ts delete mode 100644 packages/vertexai/lib/types/schema.ts diff --git a/packages/ai/lib/index.ts b/packages/ai/lib/index.ts index 88622d850c..04a318a3c9 100644 --- a/packages/ai/lib/index.ts +++ b/packages/ai/lib/index.ts @@ -19,21 +19,16 @@ import './polyfills'; import { getApp, ReactNativeFirebase } from '@react-native-firebase/app'; import { GoogleAIBackend, VertexAIBackend } from './backend'; import { AIErrorCode, ModelParams, RequestOptions } from './types'; -import { AI, AIOptions } from './public-types'; +import { AI, AIOptions, VertexAI, VertexAIOptions } from './public-types'; import { AIError } from './errors'; import { GenerativeModel } from './models/generative-model'; export { ChatSession } from './methods/chat-session'; + export * from './requests/schema-builder'; +export * from './types'; +export * from './backend'; -export { - GenerativeModel, - AIError, - AIErrorCode, - ModelParams, - RequestOptions, - GoogleAIBackend, - VertexAIBackend, -}; +export { GenerativeModel, AIError, VertexAI, VertexAIOptions }; /** * Returns the default {@link AI} instance that is associated with the provided diff --git a/packages/ai/lib/public-types.ts b/packages/ai/lib/public-types.ts index 95d9a1dc63..bd58c018c0 100644 --- a/packages/ai/lib/public-types.ts +++ b/packages/ai/lib/public-types.ts @@ -155,3 +155,27 @@ export interface AI { */ location: string; } + +/** + * An instance of the Vertex AI in Firebase SDK. + * @public + */ +export interface VertexAI { + /** + * The {@link @firebase/app#FirebaseApp} this {@link VertexAI} instance is associated with. + */ + app: ReactNativeFirebase.FirebaseApp; + location: string; + appCheck?: FirebaseAppCheckTypes.Module | null; + auth?: FirebaseAuthTypes.Module | null; +} + +/** + * Options when initializing the Vertex AI in Firebase SDK. + * @public + */ +export interface VertexAIOptions { + location?: string; + appCheck?: FirebaseAppCheckTypes.Module | null; + auth?: FirebaseAuthTypes.Module | null; +} diff --git a/packages/ai/lib/types/error.ts b/packages/ai/lib/types/error.ts index 4fcc1ac483..4f976f0901 100644 --- a/packages/ai/lib/types/error.ts +++ b/packages/ai/lib/types/error.ts @@ -102,3 +102,75 @@ export const enum AIErrorCode { /** An error occurred due an attempt to use an unsupported feature. */ UNSUPPORTED = 'unsupported', } + +/** + * Standardized error codes that {@link VertexAIError} can have. + * + * @public + */ +export const enum VertexAIErrorCode { + /** A generic error occurred. */ + ERROR = 'error', + + /** An error occurred in a request. */ + REQUEST_ERROR = 'request-error', + + /** An error occurred in a response. */ + RESPONSE_ERROR = 'response-error', + + /** An error occurred while performing a fetch. */ + FETCH_ERROR = 'fetch-error', + + /** An error associated with a Content object. */ + INVALID_CONTENT = 'invalid-content', + + /** An error due to the Firebase API not being enabled in the Console. */ + API_NOT_ENABLED = 'api-not-enabled', + + /** An error due to invalid Schema input. */ + INVALID_SCHEMA = 'invalid-schema', + + /** An error occurred due to a missing Firebase API key. */ + NO_API_KEY = 'no-api-key', + + /** An error occurred due to a model name not being specified during initialization. */ + NO_MODEL = 'no-model', + + /** An error occurred due to a missing project ID. */ + NO_PROJECT_ID = 'no-project-id', + + /** An error occurred while parsing. */ + PARSE_FAILED = 'parse-failed', +} + +/** + * Error class for the Vertex AI in Firebase SDK. + * + * @public + */ +export class VertexAIError extends FirebaseError { + /** + * Constructs a new instance of the `VertexAIError` class. + * + * @param code - The error code from {@link VertexAIErrorCode}. + * @param message - A human-readable message describing the error. + * @param customErrorData - Optional error data. + */ + constructor( + readonly code: VertexAIErrorCode, + message: string, + readonly customErrorData?: CustomErrorData, + ) { + // Match error format used by FirebaseError from ErrorFactory + const service = VERTEX_TYPE; + const serviceName = 'VertexAI'; + const fullCode = `${service}/${code}`; + const fullMessage = `${serviceName}: ${message} (${fullCode})`; + super(code, fullMessage); + + Object.setPrototypeOf(this, VertexAIError.prototype); + + // Since Error is an interface, we don't inherit toString and so we define it ourselves. + this.toString = () => fullMessage; + } +} diff --git a/packages/vertexai/lib/constants.ts b/packages/vertexai/lib/constants.ts index 816f5194a2..338ed8b80e 100644 --- a/packages/vertexai/lib/constants.ts +++ b/packages/vertexai/lib/constants.ts @@ -17,8 +17,6 @@ import { version } from './version'; -export const VERTEX_TYPE = 'vertexAI'; - export const DEFAULT_LOCATION = 'us-central1'; export const DEFAULT_BASE_URL = 'https://firebasevertexai.googleapis.com'; diff --git a/packages/vertexai/lib/errors.ts b/packages/vertexai/lib/errors.ts deleted file mode 100644 index 370c19aeb0..0000000000 --- a/packages/vertexai/lib/errors.ts +++ /dev/null @@ -1,52 +0,0 @@ -/** - * @license - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import { FirebaseError } from '@firebase/util'; -import { VertexAIErrorCode, CustomErrorData } from './types'; -import { VERTEX_TYPE } from './constants'; - -/** - * Error class for the Vertex AI in Firebase SDK. - * - * @public - */ -export class VertexAIError extends FirebaseError { - /** - * Constructs a new instance of the `VertexAIError` class. - * - * @param code - The error code from {@link VertexAIErrorCode}. - * @param message - A human-readable message describing the error. - * @param customErrorData - Optional error data. - */ - constructor( - readonly code: VertexAIErrorCode, - message: string, - readonly customErrorData?: CustomErrorData, - ) { - // Match error format used by FirebaseError from ErrorFactory - const service = VERTEX_TYPE; - const serviceName = 'VertexAI'; - const fullCode = `${service}/${code}`; - const fullMessage = `${serviceName}: ${message} (${fullCode})`; - super(code, fullMessage); - - Object.setPrototypeOf(this, VertexAIError.prototype); - - // Since Error is an interface, we don't inherit toString and so we define it ourselves. - this.toString = () => fullMessage; - } -} diff --git a/packages/vertexai/lib/index.ts b/packages/vertexai/lib/index.ts index dfd3308e40..6c01fe00ba 100644 --- a/packages/vertexai/lib/index.ts +++ b/packages/vertexai/lib/index.ts @@ -24,16 +24,15 @@ import { GenerativeModel, RequestOptions, ModelParams, + VertexAIErrorCode, + VertexAIError, + VertexAI, + VertexAIOptions, } from '@react-native-firebase/ai'; -import { VertexAIErrorCode } from './types'; import { DEFAULT_LOCATION } from './constants'; -import { VertexAI, VertexAIOptions } from './public-types'; -import { VertexAIError } from './errors'; import { VertexAIService } from './service'; -export { ChatSession } from './methods/chat-session'; -export * from './requests/schema-builder'; -export { VertexAIError }; +export * from '@react-native-firebase/ai'; /** * Returns a {@link VertexAI} instance for the given app. diff --git a/packages/vertexai/lib/methods/chat-session-helpers.ts b/packages/vertexai/lib/methods/chat-session-helpers.ts deleted file mode 100644 index 4b9bb56db0..0000000000 --- a/packages/vertexai/lib/methods/chat-session-helpers.ts +++ /dev/null @@ -1,116 +0,0 @@ -/** - * @license - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import { Content, POSSIBLE_ROLES, Part, Role, VertexAIErrorCode } from '../types'; -import { VertexAIError } from '../errors'; - -// https://ai.google.dev/api/rest/v1beta/Content#part - -const VALID_PART_FIELDS: Array = [ - 'text', - 'inlineData', - 'functionCall', - 'functionResponse', -]; - -const VALID_PARTS_PER_ROLE: { [key in Role]: Array } = { - user: ['text', 'inlineData'], - function: ['functionResponse'], - model: ['text', 'functionCall'], - // System instructions shouldn't be in history anyway. - system: ['text'], -}; - -const VALID_PREVIOUS_CONTENT_ROLES: { [key in Role]: Role[] } = { - user: ['model'], - function: ['model'], - model: ['user', 'function'], - // System instructions shouldn't be in history. - system: [], -}; - -export function validateChatHistory(history: Content[]): void { - let prevContent: Content | null = null; - for (const currContent of history) { - const { role, parts } = currContent; - if (!prevContent && role !== 'user') { - throw new VertexAIError( - VertexAIErrorCode.INVALID_CONTENT, - `First Content should be with role 'user', got ${role}`, - ); - } - if (!POSSIBLE_ROLES.includes(role)) { - throw new VertexAIError( - VertexAIErrorCode.INVALID_CONTENT, - `Each item should include role field. Got ${role} but valid roles are: ${JSON.stringify( - POSSIBLE_ROLES, - )}`, - ); - } - - if (!Array.isArray(parts)) { - throw new VertexAIError( - VertexAIErrorCode.INVALID_CONTENT, - `Content should have 'parts' but property with an array of Parts`, - ); - } - - if (parts.length === 0) { - throw new VertexAIError( - VertexAIErrorCode.INVALID_CONTENT, - `Each Content should have at least one part`, - ); - } - - const countFields: Record = { - text: 0, - inlineData: 0, - functionCall: 0, - functionResponse: 0, - }; - - for (const part of parts) { - for (const key of VALID_PART_FIELDS) { - if (key in part) { - countFields[key] += 1; - } - } - } - const validParts = VALID_PARTS_PER_ROLE[role]; - for (const key of VALID_PART_FIELDS) { - if (!validParts.includes(key) && countFields[key] > 0) { - throw new VertexAIError( - VertexAIErrorCode.INVALID_CONTENT, - `Content with role '${role}' can't contain '${key}' part`, - ); - } - } - - if (prevContent) { - const validPreviousContentRoles = VALID_PREVIOUS_CONTENT_ROLES[role]; - if (!validPreviousContentRoles.includes(prevContent.role)) { - throw new VertexAIError( - VertexAIErrorCode.INVALID_CONTENT, - `Content with role '${role} can't follow '${ - prevContent.role - }'. Valid previous roles: ${JSON.stringify(VALID_PREVIOUS_CONTENT_ROLES)}`, - ); - } - } - prevContent = currContent; - } -} diff --git a/packages/vertexai/lib/methods/chat-session.ts b/packages/vertexai/lib/methods/chat-session.ts deleted file mode 100644 index e3e9cf905f..0000000000 --- a/packages/vertexai/lib/methods/chat-session.ts +++ /dev/null @@ -1,182 +0,0 @@ -/** - * @license - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import { - Content, - GenerateContentRequest, - GenerateContentResult, - GenerateContentStreamResult, - Part, - RequestOptions, - StartChatParams, - EnhancedGenerateContentResponse, -} from '../types'; -import { formatNewContent } from '../requests/request-helpers'; -import { formatBlockErrorMessage } from '../requests/response-helpers'; -import { validateChatHistory } from './chat-session-helpers'; -import { generateContent, generateContentStream } from './generate-content'; -import { ApiSettings } from '../types/internal'; -import { logger } from '../logger'; - -/** - * Do not log a message for this error. - */ -const SILENT_ERROR = 'SILENT_ERROR'; - -/** - * ChatSession class that enables sending chat messages and stores - * history of sent and received messages so far. - * - * @public - */ -export class ChatSession { - private _apiSettings: ApiSettings; - private _history: Content[] = []; - private _sendPromise: Promise = Promise.resolve(); - - constructor( - apiSettings: ApiSettings, - public model: string, - public params?: StartChatParams, - public requestOptions?: RequestOptions, - ) { - this._apiSettings = apiSettings; - if (params?.history) { - validateChatHistory(params.history); - this._history = params.history; - } - } - - /** - * Gets the chat history so far. Blocked prompts are not added to history. - * Neither blocked candidates nor the prompts that generated them are added - * to history. - */ - async getHistory(): Promise { - await this._sendPromise; - return this._history; - } - - /** - * Sends a chat message and receives a non-streaming - * {@link GenerateContentResult} - */ - async sendMessage(request: string | Array): Promise { - await this._sendPromise; - const newContent = formatNewContent(request); - const generateContentRequest: GenerateContentRequest = { - safetySettings: this.params?.safetySettings, - generationConfig: this.params?.generationConfig, - tools: this.params?.tools, - toolConfig: this.params?.toolConfig, - systemInstruction: this.params?.systemInstruction, - contents: [...this._history, newContent], - }; - let finalResult = {} as GenerateContentResult; - // Add onto the chain. - this._sendPromise = this._sendPromise - .then(() => - generateContent(this._apiSettings, this.model, generateContentRequest, this.requestOptions), - ) - .then((result: GenerateContentResult) => { - if (result.response.candidates && result.response.candidates.length > 0) { - this._history.push(newContent); - const responseContent: Content = { - parts: result.response.candidates?.[0]?.content.parts || [], - // Response seems to come back without a role set. - role: result.response.candidates?.[0]?.content.role || 'model', - }; - this._history.push(responseContent); - } else { - const blockErrorMessage = formatBlockErrorMessage(result.response); - if (blockErrorMessage) { - logger.warn( - `sendMessage() was unsuccessful. ${blockErrorMessage}. Inspect response object for details.`, - ); - } - } - finalResult = result; - }); - await this._sendPromise; - return finalResult; - } - - /** - * Sends a chat message and receives the response as a - * {@link GenerateContentStreamResult} containing an iterable stream - * and a response promise. - */ - async sendMessageStream( - request: string | Array, - ): Promise { - await this._sendPromise; - const newContent = formatNewContent(request); - const generateContentRequest: GenerateContentRequest = { - safetySettings: this.params?.safetySettings, - generationConfig: this.params?.generationConfig, - tools: this.params?.tools, - toolConfig: this.params?.toolConfig, - systemInstruction: this.params?.systemInstruction, - contents: [...this._history, newContent], - }; - const streamPromise = generateContentStream( - this._apiSettings, - this.model, - generateContentRequest, - this.requestOptions, - ); - - // Add onto the chain. - this._sendPromise = this._sendPromise - .then(() => streamPromise) - // This must be handled to avoid unhandled rejection, but jump - // to the final catch block with a label to not log this error. - .catch(_ignored => { - throw new Error(SILENT_ERROR); - }) - .then(streamResult => streamResult.response) - .then((response: EnhancedGenerateContentResponse) => { - if (response.candidates && response.candidates.length > 0) { - this._history.push(newContent); - const responseContent = { ...response.candidates[0]?.content }; - // Response seems to come back without a role set. - if (!responseContent.role) { - responseContent.role = 'model'; - } - this._history.push(responseContent as Content); - } else { - const blockErrorMessage = formatBlockErrorMessage(response); - if (blockErrorMessage) { - logger.warn( - `sendMessageStream() was unsuccessful. ${blockErrorMessage}. Inspect response object for details.`, - ); - } - } - }) - .catch(e => { - // Errors in streamPromise are already catchable by the user as - // streamPromise is returned. - // Avoid duplicating the error message in logs. - if (e.message !== SILENT_ERROR) { - // Users do not have access to _sendPromise to catch errors - // downstream from streamPromise, so they should not throw. - logger.error(e); - } - }); - return streamPromise; - } -} diff --git a/packages/vertexai/lib/methods/count-tokens.ts b/packages/vertexai/lib/methods/count-tokens.ts deleted file mode 100644 index 10d41cffa8..0000000000 --- a/packages/vertexai/lib/methods/count-tokens.ts +++ /dev/null @@ -1,37 +0,0 @@ -/** - * @license - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import { CountTokensRequest, CountTokensResponse, RequestOptions } from '../types'; -import { Task, makeRequest } from '../requests/request'; -import { ApiSettings } from '../types/internal'; - -export async function countTokens( - apiSettings: ApiSettings, - model: string, - params: CountTokensRequest, - requestOptions?: RequestOptions, -): Promise { - const response = await makeRequest( - model, - Task.COUNT_TOKENS, - apiSettings, - false, - JSON.stringify(params), - requestOptions, - ); - return response.json(); -} diff --git a/packages/vertexai/lib/methods/generate-content.ts b/packages/vertexai/lib/methods/generate-content.ts deleted file mode 100644 index 6d1a6ecb27..0000000000 --- a/packages/vertexai/lib/methods/generate-content.ts +++ /dev/null @@ -1,66 +0,0 @@ -/** - * @license - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import { - GenerateContentRequest, - GenerateContentResponse, - GenerateContentResult, - GenerateContentStreamResult, - RequestOptions, -} from '../types'; -import { Task, makeRequest } from '../requests/request'; -import { createEnhancedContentResponse } from '../requests/response-helpers'; -import { processStream } from '../requests/stream-reader'; -import { ApiSettings } from '../types/internal'; - -export async function generateContentStream( - apiSettings: ApiSettings, - model: string, - params: GenerateContentRequest, - requestOptions?: RequestOptions, -): Promise { - const response = await makeRequest( - model, - Task.STREAM_GENERATE_CONTENT, - apiSettings, - /* stream */ true, - JSON.stringify(params), - requestOptions, - ); - return processStream(response); -} - -export async function generateContent( - apiSettings: ApiSettings, - model: string, - params: GenerateContentRequest, - requestOptions?: RequestOptions, -): Promise { - const response = await makeRequest( - model, - Task.GENERATE_CONTENT, - apiSettings, - /* stream */ false, - JSON.stringify(params), - requestOptions, - ); - const responseJson: GenerateContentResponse = await response.json(); - const enhancedResponse = createEnhancedContentResponse(responseJson); - return { - response: enhancedResponse, - }; -} diff --git a/packages/vertexai/lib/models/generative-model.ts b/packages/vertexai/lib/models/generative-model.ts deleted file mode 100644 index 111cefa427..0000000000 --- a/packages/vertexai/lib/models/generative-model.ts +++ /dev/null @@ -1,180 +0,0 @@ -/** - * @license - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import { generateContent, generateContentStream } from '../methods/generate-content'; -import { - Content, - CountTokensRequest, - CountTokensResponse, - GenerateContentRequest, - GenerateContentResult, - GenerateContentStreamResult, - GenerationConfig, - ModelParams, - Part, - RequestOptions, - SafetySetting, - StartChatParams, - Tool, - ToolConfig, - VertexAIErrorCode, -} from '../types'; -import { VertexAIError } from '../errors'; -import { ChatSession } from '../methods/chat-session'; -import { countTokens } from '../methods/count-tokens'; -import { formatGenerateContentInput, formatSystemInstruction } from '../requests/request-helpers'; -import { VertexAI } from '../public-types'; -import { ApiSettings } from '../types/internal'; -import { VertexAIService } from '../service'; - -/** - * Class for generative model APIs. - * @public - */ -export class GenerativeModel { - private _apiSettings: ApiSettings; - model: string; - generationConfig: GenerationConfig; - safetySettings: SafetySetting[]; - requestOptions?: RequestOptions; - tools?: Tool[]; - toolConfig?: ToolConfig; - systemInstruction?: Content; - - constructor(vertexAI: VertexAI, modelParams: ModelParams, requestOptions?: RequestOptions) { - if (!vertexAI.app?.options?.apiKey) { - throw new VertexAIError( - VertexAIErrorCode.NO_API_KEY, - `The "apiKey" field is empty in the local Firebase config. Firebase VertexAI requires this field to contain a valid API key.`, - ); - } else if (!vertexAI.app?.options?.projectId) { - throw new VertexAIError( - VertexAIErrorCode.NO_PROJECT_ID, - `The "projectId" field is empty in the local Firebase config. Firebase VertexAI requires this field to contain a valid project ID.`, - ); - } else { - this._apiSettings = { - apiKey: vertexAI.app.options.apiKey, - project: vertexAI.app.options.projectId, - location: vertexAI.location, - }; - if ((vertexAI as VertexAIService).appCheck) { - this._apiSettings.getAppCheckToken = () => - (vertexAI as VertexAIService).appCheck!.getToken(); - } - - if ((vertexAI as VertexAIService).auth?.currentUser) { - this._apiSettings.getAuthToken = () => - (vertexAI as VertexAIService).auth!.currentUser!.getIdToken(); - } - } - if (modelParams.model.includes('/')) { - if (modelParams.model.startsWith('models/')) { - // Add "publishers/google" if the user is only passing in 'models/model-name'. - this.model = `publishers/google/${modelParams.model}`; - } else { - // Any other custom format (e.g. tuned models) must be passed in correctly. - this.model = modelParams.model; - } - } else { - // If path is not included, assume it's a non-tuned model. - this.model = `publishers/google/models/${modelParams.model}`; - } - this.generationConfig = modelParams.generationConfig || {}; - this.safetySettings = modelParams.safetySettings || []; - this.tools = modelParams.tools; - this.toolConfig = modelParams.toolConfig; - this.systemInstruction = formatSystemInstruction(modelParams.systemInstruction); - this.requestOptions = requestOptions || {}; - } - - /** - * Makes a single non-streaming call to the model - * and returns an object containing a single {@link GenerateContentResponse}. - */ - async generateContent( - request: GenerateContentRequest | string | Array, - ): Promise { - const formattedParams = formatGenerateContentInput(request); - return generateContent( - this._apiSettings, - this.model, - { - generationConfig: this.generationConfig, - safetySettings: this.safetySettings, - tools: this.tools, - toolConfig: this.toolConfig, - systemInstruction: this.systemInstruction, - ...formattedParams, - }, - this.requestOptions, - ); - } - - /** - * Makes a single streaming call to the model - * and returns an object containing an iterable stream that iterates - * over all chunks in the streaming response as well as - * a promise that returns the final aggregated response. - */ - async generateContentStream( - request: GenerateContentRequest | string | Array, - ): Promise { - const formattedParams = formatGenerateContentInput(request); - return generateContentStream( - this._apiSettings, - this.model, - { - generationConfig: this.generationConfig, - safetySettings: this.safetySettings, - tools: this.tools, - toolConfig: this.toolConfig, - systemInstruction: this.systemInstruction, - ...formattedParams, - }, - this.requestOptions, - ); - } - - /** - * Gets a new {@link ChatSession} instance which can be used for - * multi-turn chats. - */ - startChat(startChatParams?: StartChatParams): ChatSession { - return new ChatSession( - this._apiSettings, - this.model, - { - tools: this.tools, - toolConfig: this.toolConfig, - systemInstruction: this.systemInstruction, - ...startChatParams, - }, - this.requestOptions, - ); - } - - /** - * Counts the tokens in the provided request. - */ - async countTokens( - request: CountTokensRequest | string | Array, - ): Promise { - const formattedParams = formatGenerateContentInput(request); - return countTokens(this._apiSettings, this.model, formattedParams); - } -} diff --git a/packages/vertexai/lib/public-types.ts b/packages/vertexai/lib/public-types.ts deleted file mode 100644 index 24c6be6efa..0000000000 --- a/packages/vertexai/lib/public-types.ts +++ /dev/null @@ -1,46 +0,0 @@ -/** - * @license - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import { ReactNativeFirebase } from '@react-native-firebase/app'; -import { FirebaseAuthTypes } from '@react-native-firebase/auth'; -import { FirebaseAppCheckTypes } from '@react-native-firebase/app-check'; - -export * from './types'; - -/** - * An instance of the Vertex AI in Firebase SDK. - * @public - */ -export interface VertexAI { - /** - * The {@link @firebase/app#FirebaseApp} this {@link VertexAI} instance is associated with. - */ - app: ReactNativeFirebase.FirebaseApp; - location: string; - appCheck?: FirebaseAppCheckTypes.Module | null; - auth?: FirebaseAuthTypes.Module | null; -} - -/** - * Options when initializing the Vertex AI in Firebase SDK. - * @public - */ -export interface VertexAIOptions { - location?: string; - appCheck?: FirebaseAppCheckTypes.Module | null; - auth?: FirebaseAuthTypes.Module | null; -} diff --git a/packages/vertexai/lib/requests/request-helpers.ts b/packages/vertexai/lib/requests/request-helpers.ts deleted file mode 100644 index 9de045a4ee..0000000000 --- a/packages/vertexai/lib/requests/request-helpers.ts +++ /dev/null @@ -1,116 +0,0 @@ -/** - * @license - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import { Content, GenerateContentRequest, Part, VertexAIErrorCode } from '../types'; -import { VertexAIError } from '../errors'; - -export function formatSystemInstruction(input?: string | Part | Content): Content | undefined { - if (input == null) { - return undefined; - } else if (typeof input === 'string') { - return { role: 'system', parts: [{ text: input }] } as Content; - } else if ((input as Part).text) { - return { role: 'system', parts: [input as Part] }; - } else if ((input as Content).parts) { - if (!(input as Content).role) { - return { role: 'system', parts: (input as Content).parts }; - } else { - return input as Content; - } - } - - return undefined; -} - -export function formatNewContent(request: string | Array): Content { - let newParts: Part[] = []; - if (typeof request === 'string') { - newParts = [{ text: request }]; - } else { - for (const partOrString of request) { - if (typeof partOrString === 'string') { - newParts.push({ text: partOrString }); - } else { - newParts.push(partOrString); - } - } - } - return assignRoleToPartsAndValidateSendMessageRequest(newParts); -} - -/** - * When multiple Part types (i.e. FunctionResponsePart and TextPart) are - * passed in a single Part array, we may need to assign different roles to each - * part. Currently only FunctionResponsePart requires a role other than 'user'. - * @private - * @param parts Array of parts to pass to the model - * @returns Array of content items - */ -function assignRoleToPartsAndValidateSendMessageRequest(parts: Part[]): Content { - const userContent: Content = { role: 'user', parts: [] }; - const functionContent: Content = { role: 'function', parts: [] }; - let hasUserContent = false; - let hasFunctionContent = false; - for (const part of parts) { - if ('functionResponse' in part) { - functionContent.parts.push(part); - hasFunctionContent = true; - } else { - userContent.parts.push(part); - hasUserContent = true; - } - } - - if (hasUserContent && hasFunctionContent) { - throw new VertexAIError( - VertexAIErrorCode.INVALID_CONTENT, - 'Within a single message, FunctionResponse cannot be mixed with other type of Part in the request for sending chat message.', - ); - } - - if (!hasUserContent && !hasFunctionContent) { - throw new VertexAIError( - VertexAIErrorCode.INVALID_CONTENT, - 'No Content is provided for sending chat message.', - ); - } - - if (hasUserContent) { - return userContent; - } - - return functionContent; -} - -export function formatGenerateContentInput( - params: GenerateContentRequest | string | Array, -): GenerateContentRequest { - let formattedRequest: GenerateContentRequest; - if ((params as GenerateContentRequest).contents) { - formattedRequest = params as GenerateContentRequest; - } else { - // Array or string - const content = formatNewContent(params as string | Array); - formattedRequest = { contents: [content] }; - } - if ((params as GenerateContentRequest).systemInstruction) { - formattedRequest.systemInstruction = formatSystemInstruction( - (params as GenerateContentRequest).systemInstruction, - ); - } - return formattedRequest; -} diff --git a/packages/vertexai/lib/requests/request.ts b/packages/vertexai/lib/requests/request.ts deleted file mode 100644 index e055094f90..0000000000 --- a/packages/vertexai/lib/requests/request.ts +++ /dev/null @@ -1,242 +0,0 @@ -/** - * @license - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -import { Platform } from 'react-native'; -import { ErrorDetails, RequestOptions, VertexAIErrorCode } from '../types'; -import { VertexAIError } from '../errors'; -import { ApiSettings } from '../types/internal'; -import { - DEFAULT_API_VERSION, - DEFAULT_BASE_URL, - DEFAULT_FETCH_TIMEOUT_MS, - LANGUAGE_TAG, - PACKAGE_VERSION, -} from '../constants'; -import { logger } from '../logger'; - -export enum Task { - GENERATE_CONTENT = 'generateContent', - STREAM_GENERATE_CONTENT = 'streamGenerateContent', - COUNT_TOKENS = 'countTokens', -} - -export class RequestUrl { - constructor( - public model: string, - public task: Task, - public apiSettings: ApiSettings, - public stream: boolean, - public requestOptions?: RequestOptions, - ) {} - toString(): string { - // @ts-ignore - const isTestEnvironment = globalThis.RNFB_VERTEXAI_EMULATOR_URL; - if (isTestEnvironment) { - let emulatorUrl; - logger.info( - 'Running VertexAI in test environment, pointing to Firebase Functions emulator URL', - ); - const isAndroid = Platform.OS === 'android'; - - if (this.stream) { - emulatorUrl = `http://${isAndroid ? '10.0.2.2' : '127.0.0.1'}:5001/react-native-firebase-testing/us-central1/testFetchStream`; - } else { - emulatorUrl = `http://${isAndroid ? '10.0.2.2' : '127.0.0.1'}:5001/react-native-firebase-testing/us-central1/testFetch`; - } - return emulatorUrl; - } - - const apiVersion = DEFAULT_API_VERSION; - const baseUrl = this.requestOptions?.baseUrl || DEFAULT_BASE_URL; - let url = `${baseUrl}/${apiVersion}`; - url += `/projects/${this.apiSettings.project}`; - url += `/locations/${this.apiSettings.location}`; - url += `/${this.model}`; - url += `:${this.task}`; - if (this.stream) { - url += '?alt=sse'; - } - return url; - } - - /** - * If the model needs to be passed to the backend, it needs to - * include project and location path. - */ - get fullModelString(): string { - let modelString = `projects/${this.apiSettings.project}`; - modelString += `/locations/${this.apiSettings.location}`; - modelString += `/${this.model}`; - return modelString; - } -} - -/** - * Log language and "fire/version" to x-goog-api-client - */ -function getClientHeaders(): string { - const loggingTags = []; - loggingTags.push(`${LANGUAGE_TAG}/${PACKAGE_VERSION}`); - loggingTags.push(`fire/${PACKAGE_VERSION}`); - return loggingTags.join(' '); -} - -export async function getHeaders(url: RequestUrl): Promise { - const headers = new Headers(); - headers.append('Content-Type', 'application/json'); - headers.append('x-goog-api-client', getClientHeaders()); - headers.append('x-goog-api-key', url.apiSettings.apiKey); - if (url.apiSettings.getAppCheckToken) { - let appCheckToken; - - try { - appCheckToken = await url.apiSettings.getAppCheckToken(); - } catch (e) { - logger.warn(`Unable to obtain a valid App Check token: ${e}`); - } - if (appCheckToken) { - headers.append('X-Firebase-AppCheck', appCheckToken.token); - } - } - - if (url.apiSettings.getAuthToken) { - const authToken = await url.apiSettings.getAuthToken(); - if (authToken) { - headers.append('Authorization', `Firebase ${authToken}`); - } - } - - return headers; -} - -export async function constructRequest( - model: string, - task: Task, - apiSettings: ApiSettings, - stream: boolean, - body: string, - requestOptions?: RequestOptions, -): Promise<{ url: string; fetchOptions: RequestInit }> { - const url = new RequestUrl(model, task, apiSettings, stream, requestOptions); - return { - url: url.toString(), - fetchOptions: { - method: 'POST', - headers: await getHeaders(url), - body, - }, - }; -} - -export async function makeRequest( - model: string, - task: Task, - apiSettings: ApiSettings, - stream: boolean, - body: string, - requestOptions?: RequestOptions, -): Promise { - const url = new RequestUrl(model, task, apiSettings, stream, requestOptions); - let response; - let fetchTimeoutId: string | number | NodeJS.Timeout | undefined; - try { - const request = await constructRequest(model, task, apiSettings, stream, body, requestOptions); - const timeoutMillis = - requestOptions?.timeout != null && requestOptions.timeout >= 0 - ? requestOptions.timeout - : DEFAULT_FETCH_TIMEOUT_MS; - const abortController = new AbortController(); - fetchTimeoutId = setTimeout(() => abortController.abort(), timeoutMillis); - request.fetchOptions.signal = abortController.signal; - const fetchOptions = stream - ? { - ...request.fetchOptions, - reactNative: { - textStreaming: true, - }, - } - : request.fetchOptions; - response = await fetch(request.url, fetchOptions); - if (!response.ok) { - let message = ''; - let errorDetails; - try { - const json = await response.json(); - message = json.error.message; - if (json.error.details) { - message += ` ${JSON.stringify(json.error.details)}`; - errorDetails = json.error.details; - } - } catch (_) { - // ignored - } - if ( - response.status === 403 && - errorDetails.some((detail: ErrorDetails) => detail.reason === 'SERVICE_DISABLED') && - errorDetails.some((detail: ErrorDetails) => - (detail.links as Array>)?.[0]?.description?.includes( - 'Google developers console API activation', - ), - ) - ) { - throw new VertexAIError( - VertexAIErrorCode.API_NOT_ENABLED, - `The Vertex AI in Firebase SDK requires the Vertex AI in Firebase ` + - `API ('firebasevertexai.googleapis.com') to be enabled in your ` + - `Firebase project. Enable this API by visiting the Firebase Console ` + - `at https://console.firebase.google.com/project/${url.apiSettings.project}/genai/ ` + - `and clicking "Get started". If you enabled this API recently, ` + - `wait a few minutes for the action to propagate to our systems and ` + - `then retry.`, - { - status: response.status, - statusText: response.statusText, - errorDetails, - }, - ); - } - throw new VertexAIError( - VertexAIErrorCode.FETCH_ERROR, - `Error fetching from ${url}: [${response.status} ${response.statusText}] ${message}`, - { - status: response.status, - statusText: response.statusText, - errorDetails, - }, - ); - } - } catch (e) { - let err = e as Error; - if ( - (e as VertexAIError).code !== VertexAIErrorCode.FETCH_ERROR && - (e as VertexAIError).code !== VertexAIErrorCode.API_NOT_ENABLED && - e instanceof Error - ) { - err = new VertexAIError( - VertexAIErrorCode.ERROR, - `Error fetching from ${url.toString()}: ${e.message}`, - ); - err.stack = e.stack; - } - - throw err; - } finally { - if (fetchTimeoutId) { - clearTimeout(fetchTimeoutId); - } - } - return response; -} diff --git a/packages/vertexai/lib/requests/response-helpers.ts b/packages/vertexai/lib/requests/response-helpers.ts deleted file mode 100644 index c7abc9d923..0000000000 --- a/packages/vertexai/lib/requests/response-helpers.ts +++ /dev/null @@ -1,186 +0,0 @@ -/** - * @license - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import { - EnhancedGenerateContentResponse, - FinishReason, - FunctionCall, - GenerateContentCandidate, - GenerateContentResponse, - VertexAIErrorCode, -} from '../types'; -import { VertexAIError } from '../errors'; -import { logger } from '../logger'; - -/** - * Creates an EnhancedGenerateContentResponse object that has helper functions and - * other modifications that improve usability. - */ -export function createEnhancedContentResponse( - response: GenerateContentResponse, -): EnhancedGenerateContentResponse { - /** - * The Vertex AI backend omits default values. - * This causes the `index` property to be omitted from the first candidate in the - * response, since it has index 0, and 0 is a default value. - * See: https://github.com/firebase/firebase-js-sdk/issues/8566 - */ - if (response.candidates && !response.candidates[0]?.hasOwnProperty('index')) { - response.candidates[0]!.index = 0; - } - - const responseWithHelpers = addHelpers(response); - return responseWithHelpers; -} - -/** - * Adds convenience helper methods to a response object, including stream - * chunks (as long as each chunk is a complete GenerateContentResponse JSON). - */ -export function addHelpers(response: GenerateContentResponse): EnhancedGenerateContentResponse { - (response as EnhancedGenerateContentResponse).text = () => { - if (response.candidates && response.candidates.length > 0) { - if (response.candidates.length > 1) { - logger.warn( - `This response had ${response.candidates.length} ` + - `candidates. Returning text from the first candidate only. ` + - `Access response.candidates directly to use the other candidates.`, - ); - } - if (hadBadFinishReason(response.candidates[0]!)) { - throw new VertexAIError( - VertexAIErrorCode.RESPONSE_ERROR, - `Response error: ${formatBlockErrorMessage( - response, - )}. Response body stored in error.response`, - { - response, - }, - ); - } - return getText(response); - } else if (response.promptFeedback) { - throw new VertexAIError( - VertexAIErrorCode.RESPONSE_ERROR, - `Text not available. ${formatBlockErrorMessage(response)}`, - { - response, - }, - ); - } - return ''; - }; - (response as EnhancedGenerateContentResponse).functionCalls = () => { - if (response.candidates && response.candidates.length > 0) { - if (response.candidates.length > 1) { - logger.warn( - `This response had ${response.candidates.length} ` + - `candidates. Returning function calls from the first candidate only. ` + - `Access response.candidates directly to use the other candidates.`, - ); - } - if (hadBadFinishReason(response.candidates[0]!)) { - throw new VertexAIError( - VertexAIErrorCode.RESPONSE_ERROR, - `Response error: ${formatBlockErrorMessage( - response, - )}. Response body stored in error.response`, - { - response, - }, - ); - } - return getFunctionCalls(response); - } else if (response.promptFeedback) { - throw new VertexAIError( - VertexAIErrorCode.RESPONSE_ERROR, - `Function call not available. ${formatBlockErrorMessage(response)}`, - { - response, - }, - ); - } - return undefined; - }; - return response as EnhancedGenerateContentResponse; -} - -/** - * Returns all text found in all parts of first candidate. - */ -export function getText(response: GenerateContentResponse): string { - const textStrings = []; - if (response.candidates?.[0]?.content?.parts) { - for (const part of response.candidates?.[0].content?.parts) { - if (part.text) { - textStrings.push(part.text); - } - } - } - if (textStrings.length > 0) { - return textStrings.join(''); - } else { - return ''; - } -} - -/** - * Returns {@link FunctionCall}s associated with first candidate. - */ -export function getFunctionCalls(response: GenerateContentResponse): FunctionCall[] | undefined { - const functionCalls: FunctionCall[] = []; - if (response.candidates?.[0]?.content?.parts) { - for (const part of response.candidates?.[0].content?.parts) { - if (part.functionCall) { - functionCalls.push(part.functionCall); - } - } - } - if (functionCalls.length > 0) { - return functionCalls; - } else { - return undefined; - } -} - -const badFinishReasons = [FinishReason.RECITATION, FinishReason.SAFETY]; - -function hadBadFinishReason(candidate: GenerateContentCandidate): boolean { - return !!candidate.finishReason && badFinishReasons.includes(candidate.finishReason); -} - -export function formatBlockErrorMessage(response: GenerateContentResponse): string { - let message = ''; - if ((!response.candidates || response.candidates.length === 0) && response.promptFeedback) { - message += 'Response was blocked'; - if (response.promptFeedback?.blockReason) { - message += ` due to ${response.promptFeedback.blockReason}`; - } - if (response.promptFeedback?.blockReasonMessage) { - message += `: ${response.promptFeedback.blockReasonMessage}`; - } - } else if (response.candidates?.[0]) { - const firstCandidate = response.candidates[0]; - if (hadBadFinishReason(firstCandidate)) { - message += `Candidate was blocked due to ${firstCandidate.finishReason}`; - if (firstCandidate.finishMessage) { - message += `: ${firstCandidate.finishMessage}`; - } - } - } - return message; -} diff --git a/packages/vertexai/lib/requests/schema-builder.ts b/packages/vertexai/lib/requests/schema-builder.ts deleted file mode 100644 index 92003a0950..0000000000 --- a/packages/vertexai/lib/requests/schema-builder.ts +++ /dev/null @@ -1,281 +0,0 @@ -/** - * @license - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import { VertexAIError } from '../errors'; -import { VertexAIErrorCode } from '../types'; -import { - SchemaInterface, - SchemaType, - SchemaParams, - SchemaRequest, - ObjectSchemaInterface, -} from '../types/schema'; - -/** - * Parent class encompassing all Schema types, with static methods that - * allow building specific Schema types. This class can be converted with - * `JSON.stringify()` into a JSON string accepted by Vertex AI REST endpoints. - * (This string conversion is automatically done when calling SDK methods.) - * @public - */ -export abstract class Schema implements SchemaInterface { - /** - * Optional. The type of the property. {@link - * SchemaType}. - */ - type: SchemaType; - /** Optional. The format of the property. - * Supported formats:
- *
    - *
  • for NUMBER type: "float", "double"
  • - *
  • for INTEGER type: "int32", "int64"
  • - *
  • for STRING type: "email", "byte", etc
  • - *
- */ - format?: string; - /** Optional. The description of the property. */ - description?: string; - /** Optional. Whether the property is nullable. Defaults to false. */ - nullable: boolean; - /** Optional. The example of the property. */ - example?: unknown; - /** - * Allows user to add other schema properties that have not yet - * been officially added to the SDK. - */ - [key: string]: unknown; - - constructor(schemaParams: SchemaInterface) { - for (const paramKey in schemaParams) { - this[paramKey] = schemaParams[paramKey]; - } - // Ensure these are explicitly set to avoid TS errors. - this.type = schemaParams.type; - this.nullable = schemaParams.hasOwnProperty('nullable') ? !!schemaParams.nullable : false; - } - - /** - * Defines how this Schema should be serialized as JSON. - * See https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/JSON/stringify#tojson_behavior - * @internal - */ - toJSON(): SchemaRequest { - const obj: { type: SchemaType; [key: string]: unknown } = { - type: this.type, - }; - for (const prop in this) { - if (this.hasOwnProperty(prop) && this[prop] !== undefined) { - if (prop !== 'required' || this.type === SchemaType.OBJECT) { - obj[prop] = this[prop]; - } - } - } - return obj as SchemaRequest; - } - - static array(arrayParams: SchemaParams & { items: Schema }): ArraySchema { - return new ArraySchema(arrayParams, arrayParams.items); - } - - static object( - objectParams: SchemaParams & { - properties: { - [k: string]: Schema; - }; - optionalProperties?: string[]; - }, - ): ObjectSchema { - return new ObjectSchema(objectParams, objectParams.properties, objectParams.optionalProperties); - } - - static string(stringParams?: SchemaParams): StringSchema { - return new StringSchema(stringParams); - } - - static enumString(stringParams: SchemaParams & { enum: string[] }): StringSchema { - return new StringSchema(stringParams, stringParams.enum); - } - - static integer(integerParams?: SchemaParams): IntegerSchema { - return new IntegerSchema(integerParams); - } - - static number(numberParams?: SchemaParams): NumberSchema { - return new NumberSchema(numberParams); - } - - static boolean(booleanParams?: SchemaParams): BooleanSchema { - return new BooleanSchema(booleanParams); - } -} - -/** - * A type that includes all specific Schema types. - * @public - */ -export type TypedSchema = - | IntegerSchema - | NumberSchema - | StringSchema - | BooleanSchema - | ObjectSchema - | ArraySchema; - -/** - * Schema class for "integer" types. - * @public - */ -export class IntegerSchema extends Schema { - constructor(schemaParams?: SchemaParams) { - super({ - type: SchemaType.INTEGER, - ...schemaParams, - }); - } -} - -/** - * Schema class for "number" types. - * @public - */ -export class NumberSchema extends Schema { - constructor(schemaParams?: SchemaParams) { - super({ - type: SchemaType.NUMBER, - ...schemaParams, - }); - } -} - -/** - * Schema class for "boolean" types. - * @public - */ -export class BooleanSchema extends Schema { - constructor(schemaParams?: SchemaParams) { - super({ - type: SchemaType.BOOLEAN, - ...schemaParams, - }); - } -} - -/** - * Schema class for "string" types. Can be used with or without - * enum values. - * @public - */ -export class StringSchema extends Schema { - enum?: string[]; - constructor(schemaParams?: SchemaParams, enumValues?: string[]) { - super({ - type: SchemaType.STRING, - ...schemaParams, - }); - this.enum = enumValues; - } - - /** - * @internal - */ - toJSON(): SchemaRequest { - const obj = super.toJSON(); - if (this.enum) { - obj['enum'] = this.enum; - } - return obj as SchemaRequest; - } -} - -/** - * Schema class for "array" types. - * The `items` param should refer to the type of item that can be a member - * of the array. - * @public - */ -export class ArraySchema extends Schema { - constructor( - schemaParams: SchemaParams, - public items: TypedSchema, - ) { - super({ - type: SchemaType.ARRAY, - ...schemaParams, - }); - } - - /** - * @internal - */ - toJSON(): SchemaRequest { - const obj = super.toJSON(); - obj.items = this.items.toJSON(); - return obj; - } -} - -/** - * Schema class for "object" types. - * The `properties` param must be a map of `Schema` objects. - * @public - */ -export class ObjectSchema extends Schema { - constructor( - schemaParams: SchemaParams, - public properties: { - [k: string]: TypedSchema; - }, - public optionalProperties: string[] = [], - ) { - super({ - type: SchemaType.OBJECT, - ...schemaParams, - }); - } - - /** - * @internal - */ - toJSON(): SchemaRequest { - const obj = super.toJSON(); - obj.properties = { ...this.properties }; - const required = []; - if (this.optionalProperties) { - for (const propertyKey of this.optionalProperties) { - if (!this.properties.hasOwnProperty(propertyKey)) { - throw new VertexAIError( - VertexAIErrorCode.INVALID_SCHEMA, - `Property "${propertyKey}" specified in "optionalProperties" does not exist.`, - ); - } - } - } - for (const propertyKey in this.properties) { - if (this.properties.hasOwnProperty(propertyKey)) { - obj.properties[propertyKey] = this.properties[propertyKey]!.toJSON() as SchemaRequest; - if (!this.optionalProperties.includes(propertyKey)) { - required.push(propertyKey); - } - } - } - if (required.length > 0) { - obj.required = required; - } - delete (obj as ObjectSchemaInterface).optionalProperties; - return obj as SchemaRequest; - } -} diff --git a/packages/vertexai/lib/requests/stream-reader.ts b/packages/vertexai/lib/requests/stream-reader.ts deleted file mode 100644 index d24f6d44bf..0000000000 --- a/packages/vertexai/lib/requests/stream-reader.ts +++ /dev/null @@ -1,213 +0,0 @@ -/** - * @license - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import { ReadableStream } from 'web-streams-polyfill'; -import { - EnhancedGenerateContentResponse, - GenerateContentCandidate, - GenerateContentResponse, - GenerateContentStreamResult, - Part, - VertexAIErrorCode, -} from '../types'; -import { VertexAIError } from '../errors'; -import { createEnhancedContentResponse } from './response-helpers'; - -const responseLineRE = /^data\: (.*)(?:\n\n|\r\r|\r\n\r\n)/; - -/** - * Process a response.body stream from the backend and return an - * iterator that provides one complete GenerateContentResponse at a time - * and a promise that resolves with a single aggregated - * GenerateContentResponse. - * - * @param response - Response from a fetch call - */ -export function processStream(response: Response): GenerateContentStreamResult { - const inputStream = new ReadableStream({ - async start(controller) { - const reader = response.body!.getReader(); - const decoder = new TextDecoder('utf-8'); - while (true) { - const { done, value } = await reader.read(); - if (done) { - controller.close(); - break; - } - const decodedValue = decoder.decode(value, { stream: true }); - controller.enqueue(decodedValue); - } - }, - }); - const responseStream = getResponseStream(inputStream); - const [stream1, stream2] = responseStream.tee(); - return { - stream: generateResponseSequence(stream1), - response: getResponsePromise(stream2), - }; -} - -async function getResponsePromise( - stream: ReadableStream, -): Promise { - const allResponses: GenerateContentResponse[] = []; - const reader = stream.getReader(); - while (true) { - const { done, value } = await reader.read(); - if (done) { - const enhancedResponse = createEnhancedContentResponse(aggregateResponses(allResponses)); - return enhancedResponse; - } - allResponses.push(value); - } -} - -async function* generateResponseSequence( - stream: ReadableStream, -): AsyncGenerator { - const reader = stream.getReader(); - while (true) { - const { value, done } = await reader.read(); - if (done) { - break; - } - - const enhancedResponse = createEnhancedContentResponse(value); - yield enhancedResponse; - } -} - -/** - * Reads a raw stream from the fetch response and join incomplete - * chunks, returning a new stream that provides a single complete - * GenerateContentResponse in each iteration. - */ -export function getResponseStream(inputStream: ReadableStream): ReadableStream { - const reader = inputStream.getReader(); - const stream = new ReadableStream({ - start(controller) { - let currentText = ''; - return pump().then(() => undefined); - function pump(): Promise<(() => Promise) | undefined> { - return reader.read().then(({ value, done }) => { - if (done) { - if (currentText.trim()) { - controller.error( - new VertexAIError(VertexAIErrorCode.PARSE_FAILED, 'Failed to parse stream'), - ); - return; - } - controller.close(); - return; - } - - currentText += value; - let match = currentText.match(responseLineRE); - let parsedResponse: T; - while (match) { - try { - parsedResponse = JSON.parse(match[1]!); - } catch (_) { - controller.error( - new VertexAIError( - VertexAIErrorCode.PARSE_FAILED, - `Error parsing JSON response: "${match[1]}`, - ), - ); - return; - } - controller.enqueue(parsedResponse); - currentText = currentText.substring(match[0].length); - match = currentText.match(responseLineRE); - } - return pump(); - }); - } - }, - }); - return stream; -} - -/** - * Aggregates an array of `GenerateContentResponse`s into a single - * GenerateContentResponse. - */ -export function aggregateResponses(responses: GenerateContentResponse[]): GenerateContentResponse { - const lastResponse = responses[responses.length - 1]; - const aggregatedResponse: GenerateContentResponse = { - promptFeedback: lastResponse?.promptFeedback, - }; - for (const response of responses) { - if (response.candidates) { - for (const candidate of response.candidates) { - // Index will be undefined if it's the first index (0), so we should use 0 if it's undefined. - // See: https://github.com/firebase/firebase-js-sdk/issues/8566 - const i = candidate.index || 0; - if (!aggregatedResponse.candidates) { - aggregatedResponse.candidates = []; - } - if (!aggregatedResponse.candidates[i]) { - aggregatedResponse.candidates[i] = { - index: candidate.index, - } as GenerateContentCandidate; - } - // Keep overwriting, the last one will be final - aggregatedResponse.candidates[i].citationMetadata = candidate.citationMetadata; - aggregatedResponse.candidates[i].finishReason = candidate.finishReason; - aggregatedResponse.candidates[i].finishMessage = candidate.finishMessage; - aggregatedResponse.candidates[i].safetyRatings = candidate.safetyRatings; - - /** - * Candidates should always have content and parts, but this handles - * possible malformed responses. - */ - if (candidate.content && candidate.content.parts) { - if (!aggregatedResponse.candidates[i].content) { - aggregatedResponse.candidates[i].content = { - role: candidate.content.role || 'user', - parts: [], - }; - } - const newPart: Partial = {}; - for (const part of candidate.content.parts) { - if (part.text !== undefined) { - // The backend can send empty text parts. If these are sent back - // (e.g. in chat history), the backend will respond with an error. - // To prevent this, ignore empty text parts. - if (part.text === '') { - continue; - } - newPart.text = part.text; - } - if (part.functionCall) { - newPart.functionCall = part.functionCall; - } - if (Object.keys(newPart).length === 0) { - throw new VertexAIError( - VertexAIErrorCode.INVALID_CONTENT, - 'Part should have at least one property, but there are none. This is likely caused ' + - 'by a malformed response from the backend.', - ); - } - aggregatedResponse.candidates[i].content.parts.push(newPart as Part); - } - } - } - } - } - return aggregatedResponse; -} diff --git a/packages/vertexai/lib/service.ts b/packages/vertexai/lib/service.ts index e90ffa9668..54599db3fa 100644 --- a/packages/vertexai/lib/service.ts +++ b/packages/vertexai/lib/service.ts @@ -16,7 +16,7 @@ */ import { ReactNativeFirebase } from '@react-native-firebase/app'; -import { VertexAI, VertexAIOptions } from './public-types'; +import { VertexAI, VertexAIOptions } from '@react-native-firebase/ai'; import { DEFAULT_LOCATION } from './constants'; import { FirebaseAuthTypes } from '@react-native-firebase/auth'; import { FirebaseAppCheckTypes } from '@react-native-firebase/app-check'; diff --git a/packages/vertexai/lib/types/content.ts b/packages/vertexai/lib/types/content.ts deleted file mode 100644 index abf5d29222..0000000000 --- a/packages/vertexai/lib/types/content.ts +++ /dev/null @@ -1,162 +0,0 @@ -/** - * @license - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import { Role } from './enums'; - -/** - * Content type for both prompts and response candidates. - * @public - */ -export interface Content { - role: Role; - parts: Part[]; -} - -/** - * Content part - includes text, image/video, or function call/response - * part types. - * @public - */ -export type Part = - | TextPart - | InlineDataPart - | FunctionCallPart - | FunctionResponsePart - | FileDataPart; - -/** - * Content part interface if the part represents a text string. - * @public - */ -export interface TextPart { - text: string; - inlineData?: never; - functionCall?: never; - functionResponse?: never; -} - -/** - * Content part interface if the part represents an image. - * @public - */ -export interface InlineDataPart { - text?: never; - inlineData: GenerativeContentBlob; - functionCall?: never; - functionResponse?: never; - /** - * Applicable if `inlineData` is a video. - */ - videoMetadata?: VideoMetadata; -} - -/** - * Describes the input video content. - * @public - */ -export interface VideoMetadata { - /** - * The start offset of the video in - * protobuf {@link https://cloud.google.com/ruby/docs/reference/google-cloud-workflows-v1/latest/Google-Protobuf-Duration#json-mapping | Duration} format. - */ - startOffset: string; - /** - * The end offset of the video in - * protobuf {@link https://cloud.google.com/ruby/docs/reference/google-cloud-workflows-v1/latest/Google-Protobuf-Duration#json-mapping | Duration} format. - */ - endOffset: string; -} - -/** - * Content part interface if the part represents a {@link FunctionCall}. - * @public - */ -export interface FunctionCallPart { - text?: never; - inlineData?: never; - functionCall: FunctionCall; - functionResponse?: never; -} - -/** - * Content part interface if the part represents {@link FunctionResponse}. - * @public - */ -export interface FunctionResponsePart { - text?: never; - inlineData?: never; - functionCall?: never; - functionResponse: FunctionResponse; -} - -/** - * Content part interface if the part represents {@link FileData} - * @public - */ -export interface FileDataPart { - text?: never; - inlineData?: never; - functionCall?: never; - functionResponse?: never; - fileData: FileData; -} - -/** - * A predicted {@link FunctionCall} returned from the model - * that contains a string representing the {@link FunctionDeclaration.name} - * and a structured JSON object containing the parameters and their values. - * @public - */ -export interface FunctionCall { - name: string; - args: object; -} - -/** - * The result output from a {@link FunctionCall} that contains a string - * representing the {@link FunctionDeclaration.name} - * and a structured JSON object containing any output - * from the function is used as context to the model. - * This should contain the result of a {@link FunctionCall} - * made based on model prediction. - * @public - */ -export interface FunctionResponse { - name: string; - response: object; -} - -/** - * Interface for sending an image. - * @public - */ -export interface GenerativeContentBlob { - mimeType: string; - /** - * Image as a base64 string. - */ - data: string; -} - -/** - * Data pointing to a file uploaded on Google Cloud Storage. - * @public - */ -export interface FileData { - mimeType: string; - fileUri: string; -} diff --git a/packages/vertexai/lib/types/enums.ts b/packages/vertexai/lib/types/enums.ts deleted file mode 100644 index 010aff903a..0000000000 --- a/packages/vertexai/lib/types/enums.ts +++ /dev/null @@ -1,149 +0,0 @@ -/** - * @license - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/** - * Role is the producer of the content. - * @public - */ -export type Role = (typeof POSSIBLE_ROLES)[number]; - -/** - * Possible roles. - * @public - */ -export const POSSIBLE_ROLES = ['user', 'model', 'function', 'system'] as const; - -/** - * Harm categories that would cause prompts or candidates to be blocked. - * @public - */ -export enum HarmCategory { - HARM_CATEGORY_HATE_SPEECH = 'HARM_CATEGORY_HATE_SPEECH', - HARM_CATEGORY_SEXUALLY_EXPLICIT = 'HARM_CATEGORY_SEXUALLY_EXPLICIT', - HARM_CATEGORY_HARASSMENT = 'HARM_CATEGORY_HARASSMENT', - HARM_CATEGORY_DANGEROUS_CONTENT = 'HARM_CATEGORY_DANGEROUS_CONTENT', -} - -/** - * Threshold above which a prompt or candidate will be blocked. - * @public - */ -export enum HarmBlockThreshold { - // Content with NEGLIGIBLE will be allowed. - BLOCK_LOW_AND_ABOVE = 'BLOCK_LOW_AND_ABOVE', - // Content with NEGLIGIBLE and LOW will be allowed. - BLOCK_MEDIUM_AND_ABOVE = 'BLOCK_MEDIUM_AND_ABOVE', - // Content with NEGLIGIBLE, LOW, and MEDIUM will be allowed. - BLOCK_ONLY_HIGH = 'BLOCK_ONLY_HIGH', - // All content will be allowed. - BLOCK_NONE = 'BLOCK_NONE', -} - -/** - * @public - */ -export enum HarmBlockMethod { - // The harm block method uses both probability and severity scores. - SEVERITY = 'SEVERITY', - // The harm block method uses the probability score. - PROBABILITY = 'PROBABILITY', -} - -/** - * Probability that a prompt or candidate matches a harm category. - * @public - */ -export enum HarmProbability { - // Content has a negligible chance of being unsafe. - NEGLIGIBLE = 'NEGLIGIBLE', - // Content has a low chance of being unsafe. - LOW = 'LOW', - // Content has a medium chance of being unsafe. - MEDIUM = 'MEDIUM', - // Content has a high chance of being unsafe. - HIGH = 'HIGH', -} - -/** - * Harm severity levels. - * @public - */ -export enum HarmSeverity { - // Negligible level of harm severity. - HARM_SEVERITY_NEGLIGIBLE = 'HARM_SEVERITY_NEGLIGIBLE', - // Low level of harm severity. - HARM_SEVERITY_LOW = 'HARM_SEVERITY_LOW', - // Medium level of harm severity. - HARM_SEVERITY_MEDIUM = 'HARM_SEVERITY_MEDIUM', - // High level of harm severity. - HARM_SEVERITY_HIGH = 'HARM_SEVERITY_HIGH', -} - -/** - * Reason that a prompt was blocked. - * @public - */ -export enum BlockReason { - // The prompt was blocked because it contained terms from the terminology blocklist. - BLOCKLIST = 'BLOCKLIST', - // The prompt was blocked due to prohibited content. - PROHIBITED_CONTENT = 'PROHIBITED_CONTENT', - // Content was blocked by safety settings. - SAFETY = 'SAFETY', - // Content was blocked, but the reason is uncategorized. - OTHER = 'OTHER', -} - -/** - * Reason that a candidate finished. - * @public - */ -export enum FinishReason { - // Token generation was stopped because the response contained forbidden terms. - BLOCKLIST = 'BLOCKLIST', - // Token generation was stopped because the response contained potentially prohibited content. - PROHIBITED_CONTENT = 'PROHIBITED_CONTENT', - // Token generation was stopped because of Sensitive Personally Identifiable Information (SPII). - SPII = 'SPII', - // Natural stop point of the model or provided stop sequence. - STOP = 'STOP', - // The maximum number of tokens as specified in the request was reached. - MAX_TOKENS = 'MAX_TOKENS', - // The candidate content was flagged for safety reasons. - SAFETY = 'SAFETY', - // The candidate content was flagged for recitation reasons. - RECITATION = 'RECITATION', - // Unknown reason. - OTHER = 'OTHER', -} - -/** - * @public - */ -export enum FunctionCallingMode { - // Default model behavior, model decides to predict either a function call - // or a natural language response. - AUTO = 'AUTO', - // Model is constrained to always predicting a function call only. - // If "allowed_function_names" is set, the predicted function call will be - // limited to any one of "allowed_function_names", else the predicted - // function call will be any one of the provided "function_declarations". - ANY = 'ANY', - // Model will not predict any function call. Model behavior is same as when - // not passing any function declarations. - NONE = 'NONE', -} diff --git a/packages/vertexai/lib/types/error.ts b/packages/vertexai/lib/types/error.ts deleted file mode 100644 index c65e09c55f..0000000000 --- a/packages/vertexai/lib/types/error.ts +++ /dev/null @@ -1,98 +0,0 @@ -/** - * @license - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import { GenerateContentResponse } from './responses'; - -/** - * Details object that may be included in an error response. - * - * @public - */ -export interface ErrorDetails { - '@type'?: string; - - /** The reason for the error. */ - reason?: string; - - /** The domain where the error occurred. */ - domain?: string; - - /** Additional metadata about the error. */ - metadata?: Record; - - /** Any other relevant information about the error. */ - [key: string]: unknown; -} - -/** - * Details object that contains data originating from a bad HTTP response. - * - * @public - */ -export interface CustomErrorData { - /** HTTP status code of the error response. */ - status?: number; - - /** HTTP status text of the error response. */ - statusText?: string; - - /** Response from a {@link GenerateContentRequest} */ - response?: GenerateContentResponse; - - /** Optional additional details about the error. */ - errorDetails?: ErrorDetails[]; -} - -/** - * Standardized error codes that {@link VertexAIError} can have. - * - * @public - */ -export const enum VertexAIErrorCode { - /** A generic error occurred. */ - ERROR = 'error', - - /** An error occurred in a request. */ - REQUEST_ERROR = 'request-error', - - /** An error occurred in a response. */ - RESPONSE_ERROR = 'response-error', - - /** An error occurred while performing a fetch. */ - FETCH_ERROR = 'fetch-error', - - /** An error associated with a Content object. */ - INVALID_CONTENT = 'invalid-content', - - /** An error due to the Firebase API not being enabled in the Console. */ - API_NOT_ENABLED = 'api-not-enabled', - - /** An error due to invalid Schema input. */ - INVALID_SCHEMA = 'invalid-schema', - - /** An error occurred due to a missing Firebase API key. */ - NO_API_KEY = 'no-api-key', - - /** An error occurred due to a model name not being specified during initialization. */ - NO_MODEL = 'no-model', - - /** An error occurred due to a missing project ID. */ - NO_PROJECT_ID = 'no-project-id', - - /** An error occurred while parsing. */ - PARSE_FAILED = 'parse-failed', -} diff --git a/packages/vertexai/lib/types/index.ts b/packages/vertexai/lib/types/index.ts deleted file mode 100644 index 85133aa07c..0000000000 --- a/packages/vertexai/lib/types/index.ts +++ /dev/null @@ -1,23 +0,0 @@ -/** - * @license - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -export * from './content'; -export * from './enums'; -export * from './requests'; -export * from './responses'; -export * from './error'; -export * from './schema'; diff --git a/packages/vertexai/lib/types/internal.ts b/packages/vertexai/lib/types/internal.ts deleted file mode 100644 index ee60d476c9..0000000000 --- a/packages/vertexai/lib/types/internal.ts +++ /dev/null @@ -1,25 +0,0 @@ -/** - * @license - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -import { FirebaseAppCheckTypes } from '@react-native-firebase/app-check'; - -export interface ApiSettings { - apiKey: string; - project: string; - location: string; - getAuthToken?: () => Promise; - getAppCheckToken?: () => Promise; -} diff --git a/packages/vertexai/lib/types/polyfills.d.ts b/packages/vertexai/lib/types/polyfills.d.ts deleted file mode 100644 index 06fdf29b09..0000000000 --- a/packages/vertexai/lib/types/polyfills.d.ts +++ /dev/null @@ -1,15 +0,0 @@ -declare module 'react-native-fetch-api' { - export function fetch(input: RequestInfo, init?: RequestInit): Promise; -} - -declare global { - interface RequestInit { - /** - * @description Polyfilled to enable text ReadableStream for React Native: - * @link https://github.com/facebook/react-native/issues/27741#issuecomment-2362901032 - */ - reactNative?: { - textStreaming: boolean; - }; - } -} diff --git a/packages/vertexai/lib/types/requests.ts b/packages/vertexai/lib/types/requests.ts deleted file mode 100644 index 708a55a11c..0000000000 --- a/packages/vertexai/lib/types/requests.ts +++ /dev/null @@ -1,198 +0,0 @@ -/** - * @license - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import { TypedSchema } from '../requests/schema-builder'; -import { Content, Part } from './content'; -import { FunctionCallingMode, HarmBlockMethod, HarmBlockThreshold, HarmCategory } from './enums'; -import { ObjectSchemaInterface, SchemaRequest } from './schema'; - -/** - * Base parameters for a number of methods. - * @public - */ -export interface BaseParams { - safetySettings?: SafetySetting[]; - generationConfig?: GenerationConfig; -} - -/** - * Params passed to {@link getGenerativeModel}. - * @public - */ -export interface ModelParams extends BaseParams { - model: string; - tools?: Tool[]; - toolConfig?: ToolConfig; - systemInstruction?: string | Part | Content; -} - -/** - * Request sent through {@link GenerativeModel.generateContent} - * @public - */ -export interface GenerateContentRequest extends BaseParams { - contents: Content[]; - tools?: Tool[]; - toolConfig?: ToolConfig; - systemInstruction?: string | Part | Content; -} - -/** - * Safety setting that can be sent as part of request parameters. - * @public - */ -export interface SafetySetting { - category: HarmCategory; - threshold: HarmBlockThreshold; - method?: HarmBlockMethod; -} - -/** - * Config options for content-related requests - * @public - */ -export interface GenerationConfig { - candidateCount?: number; - stopSequences?: string[]; - maxOutputTokens?: number; - temperature?: number; - topP?: number; - topK?: number; - presencePenalty?: number; - frequencyPenalty?: number; - /** - * Output response MIME type of the generated candidate text. - * Supported MIME types are `text/plain` (default, text output), - * `application/json` (JSON response in the candidates), and - * `text/x.enum`. - */ - responseMimeType?: string; - /** - * Output response schema of the generated candidate text. This - * value can be a class generated with a {@link Schema} static method - * like `Schema.string()` or `Schema.object()` or it can be a plain - * JS object matching the {@link SchemaRequest} interface. - *
Note: This only applies when the specified `responseMIMEType` supports a schema; currently - * this is limited to `application/json` and `text/x.enum`. - */ - responseSchema?: TypedSchema | SchemaRequest; -} - -/** - * Params for {@link GenerativeModel.startChat}. - * @public - */ -export interface StartChatParams extends BaseParams { - history?: Content[]; - tools?: Tool[]; - toolConfig?: ToolConfig; - systemInstruction?: string | Part | Content; -} - -/** - * Params for calling {@link GenerativeModel.countTokens} - * @public - */ -export interface CountTokensRequest { - contents: Content[]; -} - -/** - * Params passed to {@link getGenerativeModel}. - * @public - */ -export interface RequestOptions { - /** - * Request timeout in milliseconds. Defaults to 180 seconds (180000ms). - */ - timeout?: number; - /** - * Base url for endpoint. Defaults to https://firebasevertexai.googleapis.com - */ - baseUrl?: string; -} - -/** - * Defines a tool that model can call to access external knowledge. - * @public - */ -export declare type Tool = FunctionDeclarationsTool; - -/** - * Structured representation of a function declaration as defined by the - * {@link https://spec.openapis.org/oas/v3.0.3 | OpenAPI 3.0 specification}. - * Included - * in this declaration are the function name and parameters. This - * `FunctionDeclaration` is a representation of a block of code that can be used - * as a Tool by the model and executed by the client. - * @public - */ -export declare interface FunctionDeclaration { - /** - * The name of the function to call. Must start with a letter or an - * underscore. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with - * a max length of 64. - */ - name: string; - /** - * Description and purpose of the function. Model uses it to decide - * how and whether to call the function. - */ - description: string; - /** - * Optional. Describes the parameters to this function in JSON Schema Object - * format. Reflects the Open API 3.03 Parameter Object. Parameter names are - * case-sensitive. For a function with no parameters, this can be left unset. - */ - parameters?: ObjectSchemaInterface; -} - -/** - * A `FunctionDeclarationsTool` is a piece of code that enables the system to - * interact with external systems to perform an action, or set of actions, - * outside of knowledge and scope of the model. - * @public - */ -export declare interface FunctionDeclarationsTool { - /** - * Optional. One or more function declarations - * to be passed to the model along with the current user query. Model may - * decide to call a subset of these functions by populating - * {@link FunctionCall} in the response. User should - * provide a {@link FunctionResponse} for each - * function call in the next turn. Based on the function responses, the model will - * generate the final response back to the user. Maximum 64 function - * declarations can be provided. - */ - functionDeclarations?: FunctionDeclaration[]; -} - -/** - * Tool config. This config is shared for all tools provided in the request. - * @public - */ -export interface ToolConfig { - functionCallingConfig?: FunctionCallingConfig; -} - -/** - * @public - */ -export interface FunctionCallingConfig { - mode?: FunctionCallingMode; - allowedFunctionNames?: string[]; -} diff --git a/packages/vertexai/lib/types/responses.ts b/packages/vertexai/lib/types/responses.ts deleted file mode 100644 index 013391e98b..0000000000 --- a/packages/vertexai/lib/types/responses.ts +++ /dev/null @@ -1,209 +0,0 @@ -/** - * @license - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import { Content, FunctionCall } from './content'; -import { BlockReason, FinishReason, HarmCategory, HarmProbability, HarmSeverity } from './enums'; - -/** - * Result object returned from {@link GenerativeModel.generateContent} call. - * - * @public - */ -export interface GenerateContentResult { - response: EnhancedGenerateContentResponse; -} - -/** - * Result object returned from {@link GenerativeModel.generateContentStream} call. - * Iterate over `stream` to get chunks as they come in and/or - * use the `response` promise to get the aggregated response when - * the stream is done. - * - * @public - */ -export interface GenerateContentStreamResult { - stream: AsyncGenerator; - response: Promise; -} - -/** - * Response object wrapped with helper methods. - * - * @public - */ -export interface EnhancedGenerateContentResponse extends GenerateContentResponse { - /** - * Returns the text string from the response, if available. - * Throws if the prompt or candidate was blocked. - */ - text: () => string; - functionCalls: () => FunctionCall[] | undefined; -} - -/** - * Individual response from {@link GenerativeModel.generateContent} and - * {@link GenerativeModel.generateContentStream}. - * `generateContentStream()` will return one in each chunk until - * the stream is done. - * @public - */ -export interface GenerateContentResponse { - candidates?: GenerateContentCandidate[]; - promptFeedback?: PromptFeedback; - usageMetadata?: UsageMetadata; -} - -/** - * Usage metadata about a {@link GenerateContentResponse}. - * - * @public - */ -export interface UsageMetadata { - promptTokenCount: number; - candidatesTokenCount: number; - totalTokenCount: number; -} - -/** - * If the prompt was blocked, this will be populated with `blockReason` and - * the relevant `safetyRatings`. - * @public - */ -export interface PromptFeedback { - blockReason?: BlockReason; - safetyRatings: SafetyRating[]; - blockReasonMessage?: string; -} - -/** - * A candidate returned as part of a {@link GenerateContentResponse}. - * @public - */ -export interface GenerateContentCandidate { - index: number; - content: Content; - finishReason?: FinishReason; - finishMessage?: string; - safetyRatings?: SafetyRating[]; - citationMetadata?: CitationMetadata; - groundingMetadata?: GroundingMetadata; -} - -/** - * Citation metadata that may be found on a {@link GenerateContentCandidate}. - * @public - */ -export interface CitationMetadata { - citations: Citation[]; -} - -/** - * A single citation. - * @public - */ -export interface Citation { - startIndex?: number; - endIndex?: number; - uri?: string; - license?: string; - title?: string; - publicationDate?: Date; -} - -/** - * Metadata returned to client when grounding is enabled. - * @public - */ -export interface GroundingMetadata { - webSearchQueries?: string[]; - retrievalQueries?: string[]; - groundingAttributions: GroundingAttribution[]; -} - -/** - * @public - */ -export interface GroundingAttribution { - segment: Segment; - confidenceScore?: number; - web?: WebAttribution; - retrievedContext?: RetrievedContextAttribution; -} - -/** - * @public - */ -export interface Segment { - partIndex: number; - startIndex: number; - endIndex: number; -} - -/** - * @public - */ -export interface WebAttribution { - uri: string; - title: string; -} - -/** - * @public - */ -export interface RetrievedContextAttribution { - uri: string; - title: string; -} - -/** - * Protobuf google.type.Date - * @public - */ -export interface Date { - year: number; - month: number; - day: number; -} - -/** - * A safety rating associated with a {@link GenerateContentCandidate} - * @public - */ -export interface SafetyRating { - category: HarmCategory; - probability: HarmProbability; - severity: HarmSeverity; - probabilityScore: number; - severityScore: number; - blocked: boolean; -} - -/** - * Response from calling {@link GenerativeModel.countTokens}. - * @public - */ -export interface CountTokensResponse { - /** - * The total number of tokens counted across all instances from the request. - */ - totalTokens: number; - /** - * The total number of billable characters counted across all instances - * from the request. - */ - totalBillableCharacters?: number; -} diff --git a/packages/vertexai/lib/types/schema.ts b/packages/vertexai/lib/types/schema.ts deleted file mode 100644 index c1376b9aa1..0000000000 --- a/packages/vertexai/lib/types/schema.ts +++ /dev/null @@ -1,104 +0,0 @@ -/** - * @license - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -/** - * Contains the list of OpenAPI data types - * as defined by the - * {@link https://swagger.io/docs/specification/data-models/data-types/ | OpenAPI specification} - * @public - */ -export enum SchemaType { - /** String type. */ - STRING = 'string', - /** Number type. */ - NUMBER = 'number', - /** Integer type. */ - INTEGER = 'integer', - /** Boolean type. */ - BOOLEAN = 'boolean', - /** Array type. */ - ARRAY = 'array', - /** Object type. */ - OBJECT = 'object', -} - -/** - * Basic {@link Schema} properties shared across several Schema-related - * types. - * @public - */ -export interface SchemaShared { - /** Optional. The format of the property. */ - format?: string; - /** Optional. The description of the property. */ - description?: string; - /** Optional. The items of the property. */ - items?: T; - /** Optional. Map of `Schema` objects. */ - properties?: { - [k: string]: T; - }; - /** Optional. The enum of the property. */ - enum?: string[]; - /** Optional. The example of the property. */ - example?: unknown; - /** Optional. Whether the property is nullable. */ - nullable?: boolean; - [key: string]: unknown; -} - -/** - * Params passed to {@link Schema} static methods to create specific - * {@link Schema} classes. - * @public - */ -// eslint-disable-next-line @typescript-eslint/no-empty-object-type -export interface SchemaParams extends SchemaShared {} - -/** - * Final format for {@link Schema} params passed to backend requests. - * @public - */ -export interface SchemaRequest extends SchemaShared { - /** - * The type of the property. {@link - * SchemaType}. - */ - type: SchemaType; - /** Optional. Array of required property. */ - required?: string[]; -} - -/** - * Interface for {@link Schema} class. - * @public - */ -export interface SchemaInterface extends SchemaShared { - /** - * The type of the property. {@link - * SchemaType}. - */ - type: SchemaType; -} - -/** - * Interface for {@link ObjectSchema} class. - * @public - */ -export interface ObjectSchemaInterface extends SchemaInterface { - type: SchemaType.OBJECT; - optionalProperties?: string[]; -} From 680e9a70e72f9410117cddbfcc7aa7293c106d0e Mon Sep 17 00:00:00 2001 From: russellwheatley Date: Thu, 5 Jun 2025 11:50:44 +0100 Subject: [PATCH 39/85] chore: update script name --- packages/vertexai/package.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/vertexai/package.json b/packages/vertexai/package.json index 2cd1575e64..20bd946caa 100644 --- a/packages/vertexai/package.json +++ b/packages/vertexai/package.json @@ -10,7 +10,7 @@ "build": "genversion --esm --semi lib/version.ts", "build:clean": "rimraf dist", "compile": "bob build", - "prepare": "yarn tests:vertex:mocks && yarn run build && yarn compile" + "prepare": "yarn tests:ai:mocks && yarn run build && yarn compile" }, "repository": { "type": "git", From f26b73b5399ea368630afdb7c7d850c2bcf21794 Mon Sep 17 00:00:00 2001 From: russellwheatley Date: Thu, 5 Jun 2025 14:19:51 +0100 Subject: [PATCH 40/85] chore: firebaseerror from utils --- packages/ai/lib/types/error.ts | 1 + 1 file changed, 1 insertion(+) diff --git a/packages/ai/lib/types/error.ts b/packages/ai/lib/types/error.ts index 4f976f0901..7b9432ef8a 100644 --- a/packages/ai/lib/types/error.ts +++ b/packages/ai/lib/types/error.ts @@ -15,6 +15,7 @@ * limitations under the License. */ +import { FirebaseError } from '@firebase/util'; import { GenerateContentResponse } from './responses'; /** From 4bd631baf57a5db6ca9cd9bea1200636527402b6 Mon Sep 17 00:00:00 2001 From: russellwheatley Date: Thu, 5 Jun 2025 14:21:43 +0100 Subject: [PATCH 41/85] chore: import vertex constant --- packages/ai/lib/constants.ts | 2 ++ packages/ai/lib/types/error.ts | 1 + 2 files changed, 3 insertions(+) diff --git a/packages/ai/lib/constants.ts b/packages/ai/lib/constants.ts index a0cffa49ad..8af57ec0f6 100644 --- a/packages/ai/lib/constants.ts +++ b/packages/ai/lib/constants.ts @@ -19,6 +19,8 @@ import { version } from './version'; export const AI_TYPE = 'AI'; +export const VERTEX_TYPE = 'vertexAI'; + export const DEFAULT_LOCATION = 'us-central1'; export const DEFAULT_BASE_URL = 'https://firebasevertexai.googleapis.com'; diff --git a/packages/ai/lib/types/error.ts b/packages/ai/lib/types/error.ts index 7b9432ef8a..f7579ad539 100644 --- a/packages/ai/lib/types/error.ts +++ b/packages/ai/lib/types/error.ts @@ -17,6 +17,7 @@ import { FirebaseError } from '@firebase/util'; import { GenerateContentResponse } from './responses'; +import { VERTEX_TYPE } from '../constants'; /** * Details object that may be included in an error response. From daa488f94a5381512c78ec612a0a5bd579d9eb0f Mon Sep 17 00:00:00 2001 From: russellwheatley Date: Thu, 5 Jun 2025 14:24:57 +0100 Subject: [PATCH 42/85] chore(vertexai): rm testing from vertex package --- packages/vertexai/__tests__/api.test.ts | 103 ----- .../__tests__/chat-session-helpers.test.ts | 154 ------- .../vertexai/__tests__/chat-session.test.ts | 100 ----- .../vertexai/__tests__/count-tokens.test.ts | 88 ---- .../__tests__/generate-content.test.ts | 204 --------- .../__tests__/generative-model.test.ts | 266 ------------ .../__tests__/request-helpers.test.ts | 204 --------- packages/vertexai/__tests__/request.test.ts | 353 ---------------- .../__tests__/response-helpers.test.ts | 236 ----------- .../vertexai/__tests__/schema-builder.test.ts | 389 ------------------ packages/vertexai/__tests__/service.test.ts | 45 -- .../vertexai/__tests__/stream-reader.test.ts | 370 ----------------- .../__tests__/test-utils/convert-mocks.ts | 67 --- .../__tests__/test-utils/mock-response.ts | 69 ---- 14 files changed, 2648 deletions(-) delete mode 100644 packages/vertexai/__tests__/api.test.ts delete mode 100644 packages/vertexai/__tests__/chat-session-helpers.test.ts delete mode 100644 packages/vertexai/__tests__/chat-session.test.ts delete mode 100644 packages/vertexai/__tests__/count-tokens.test.ts delete mode 100644 packages/vertexai/__tests__/generate-content.test.ts delete mode 100644 packages/vertexai/__tests__/generative-model.test.ts delete mode 100644 packages/vertexai/__tests__/request-helpers.test.ts delete mode 100644 packages/vertexai/__tests__/request.test.ts delete mode 100644 packages/vertexai/__tests__/response-helpers.test.ts delete mode 100644 packages/vertexai/__tests__/schema-builder.test.ts delete mode 100644 packages/vertexai/__tests__/service.test.ts delete mode 100644 packages/vertexai/__tests__/stream-reader.test.ts delete mode 100644 packages/vertexai/__tests__/test-utils/convert-mocks.ts delete mode 100644 packages/vertexai/__tests__/test-utils/mock-response.ts diff --git a/packages/vertexai/__tests__/api.test.ts b/packages/vertexai/__tests__/api.test.ts deleted file mode 100644 index 3199157e76..0000000000 --- a/packages/vertexai/__tests__/api.test.ts +++ /dev/null @@ -1,103 +0,0 @@ -/** - * @license - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -import { describe, expect, it } from '@jest/globals'; -import { getApp, type ReactNativeFirebase } from '../../app/lib'; - -import { ModelParams, VertexAIErrorCode } from '../lib/types'; -import { VertexAIError } from '../lib/errors'; -import { getGenerativeModel, getVertexAI } from '../lib/index'; - -import { VertexAI } from '../lib/public-types'; -import { GenerativeModel } from '../lib/models/generative-model'; - -import '../../auth/lib'; -import '../../app-check/lib'; -import { getAuth } from '../../auth/lib'; - -const fakeVertexAI: VertexAI = { - app: { - name: 'DEFAULT', - options: { - apiKey: 'key', - appId: 'appId', - projectId: 'my-project', - }, - } as ReactNativeFirebase.FirebaseApp, - location: 'us-central1', -}; - -describe('Top level API', () => { - it('should allow auth and app check instances to be passed in', () => { - const app = getApp(); - const auth = getAuth(); - const appCheck = app.appCheck(); - - getVertexAI(app, { appCheck, auth }); - }); - - it('getGenerativeModel throws if no model is provided', () => { - try { - getGenerativeModel(fakeVertexAI, {} as ModelParams); - } catch (e) { - expect((e as VertexAIError).code).toContain(VertexAIErrorCode.NO_MODEL); - expect((e as VertexAIError).message).toContain( - `VertexAI: Must provide a model name. Example: ` + - `getGenerativeModel({ model: 'my-model-name' }) (vertexAI/${VertexAIErrorCode.NO_MODEL})`, - ); - } - }); - - it('getGenerativeModel throws if no apiKey is provided', () => { - const fakeVertexNoApiKey = { - ...fakeVertexAI, - app: { options: { projectId: 'my-project' } }, - } as VertexAI; - try { - getGenerativeModel(fakeVertexNoApiKey, { model: 'my-model' }); - } catch (e) { - expect((e as VertexAIError).code).toContain(VertexAIErrorCode.NO_API_KEY); - expect((e as VertexAIError).message).toBe( - `VertexAI: The "apiKey" field is empty in the local ` + - `Firebase config. Firebase VertexAI requires this field to` + - ` contain a valid API key. (vertexAI/${VertexAIErrorCode.NO_API_KEY})`, - ); - } - }); - - it('getGenerativeModel throws if no projectId is provided', () => { - const fakeVertexNoProject = { - ...fakeVertexAI, - app: { options: { apiKey: 'my-key' } }, - } as VertexAI; - try { - getGenerativeModel(fakeVertexNoProject, { model: 'my-model' }); - } catch (e) { - expect((e as VertexAIError).code).toContain(VertexAIErrorCode.NO_PROJECT_ID); - expect((e as VertexAIError).message).toBe( - `VertexAI: The "projectId" field is empty in the local` + - ` Firebase config. Firebase VertexAI requires this field ` + - `to contain a valid project ID. (vertexAI/${VertexAIErrorCode.NO_PROJECT_ID})`, - ); - } - }); - - it('getGenerativeModel gets a GenerativeModel', () => { - const genModel = getGenerativeModel(fakeVertexAI, { model: 'my-model' }); - expect(genModel).toBeInstanceOf(GenerativeModel); - expect(genModel.model).toBe('publishers/google/models/my-model'); - }); -}); diff --git a/packages/vertexai/__tests__/chat-session-helpers.test.ts b/packages/vertexai/__tests__/chat-session-helpers.test.ts deleted file mode 100644 index 8bc81f4eab..0000000000 --- a/packages/vertexai/__tests__/chat-session-helpers.test.ts +++ /dev/null @@ -1,154 +0,0 @@ -/** - * @license - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -import { describe, expect, it } from '@jest/globals'; -import { validateChatHistory } from '../lib/methods/chat-session-helpers'; -import { Content } from '../lib/types'; -import { FirebaseError } from '@firebase/util'; - -describe('chat-session-helpers', () => { - describe('validateChatHistory', () => { - it('check chat history', () => { - const TCS: Array<{ history: Content[]; isValid: boolean }> = [ - { - history: [{ role: 'user', parts: [{ text: 'hi' }] }], - isValid: true, - }, - { - history: [ - { - role: 'user', - parts: [{ text: 'hi' }, { inlineData: { mimeType: 'image/jpeg', data: 'base64==' } }], - }, - ], - isValid: true, - }, - { - history: [ - { role: 'user', parts: [{ text: 'hi' }] }, - { role: 'model', parts: [{ text: 'hi' }, { text: 'hi' }] }, - ], - isValid: true, - }, - { - history: [ - { role: 'user', parts: [{ text: 'hi' }] }, - { - role: 'model', - parts: [{ functionCall: { name: 'greet', args: { name: 'user' } } }], - }, - ], - isValid: true, - }, - { - history: [ - { role: 'user', parts: [{ text: 'hi' }] }, - { - role: 'model', - parts: [{ functionCall: { name: 'greet', args: { name: 'user' } } }], - }, - { - role: 'function', - parts: [ - { - functionResponse: { name: 'greet', response: { name: 'user' } }, - }, - ], - }, - ], - isValid: true, - }, - { - history: [ - { role: 'user', parts: [{ text: 'hi' }] }, - { - role: 'model', - parts: [{ functionCall: { name: 'greet', args: { name: 'user' } } }], - }, - { - role: 'function', - parts: [ - { - functionResponse: { name: 'greet', response: { name: 'user' } }, - }, - ], - }, - { - role: 'model', - parts: [{ text: 'hi name' }], - }, - ], - isValid: true, - }, - { - //@ts-expect-error - history: [{ role: 'user', parts: '' }], - isValid: false, - }, - { - //@ts-expect-error - history: [{ role: 'user' }], - isValid: false, - }, - { - history: [{ role: 'user', parts: [] }], - isValid: false, - }, - { - history: [{ role: 'model', parts: [{ text: 'hi' }] }], - isValid: false, - }, - { - history: [ - { - role: 'function', - parts: [ - { - functionResponse: { name: 'greet', response: { name: 'user' } }, - }, - ], - }, - ], - isValid: false, - }, - { - history: [ - { role: 'user', parts: [{ text: 'hi' }] }, - { role: 'user', parts: [{ text: 'hi' }] }, - ], - isValid: false, - }, - { - history: [ - { role: 'user', parts: [{ text: 'hi' }] }, - { role: 'model', parts: [{ text: 'hi' }] }, - { role: 'model', parts: [{ text: 'hi' }] }, - ], - isValid: false, - }, - ]; - - TCS.forEach(tc => { - const fn = (): void => validateChatHistory(tc.history); - if (tc.isValid) { - expect(fn).not.toThrow(); - } else { - expect(fn).toThrow(FirebaseError); - } - }); - }); - }); -}); diff --git a/packages/vertexai/__tests__/chat-session.test.ts b/packages/vertexai/__tests__/chat-session.test.ts deleted file mode 100644 index cd96aa32e6..0000000000 --- a/packages/vertexai/__tests__/chat-session.test.ts +++ /dev/null @@ -1,100 +0,0 @@ -/** - * @license - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -import { describe, expect, it, afterEach, jest } from '@jest/globals'; - -import * as generateContentMethods from '../lib/methods/generate-content'; -import { GenerateContentStreamResult } from '../lib/types'; -import { ChatSession } from '../lib/methods/chat-session'; -import { ApiSettings } from '../lib/types/internal'; -import { RequestOptions } from '../lib/types/requests'; - -const fakeApiSettings: ApiSettings = { - apiKey: 'key', - project: 'my-project', - location: 'us-central1', -}; - -const requestOptions: RequestOptions = { - timeout: 1000, -}; - -describe('ChatSession', () => { - afterEach(() => { - jest.restoreAllMocks(); - }); - - describe('sendMessage()', () => { - it('generateContent errors should be catchable', async () => { - const generateContentStub = jest - .spyOn(generateContentMethods, 'generateContent') - .mockRejectedValue('generateContent failed'); - - const chatSession = new ChatSession(fakeApiSettings, 'a-model', {}, requestOptions); - - await expect(chatSession.sendMessage('hello')).rejects.toMatch(/generateContent failed/); - - expect(generateContentStub).toHaveBeenCalledWith( - fakeApiSettings, - 'a-model', - expect.anything(), - requestOptions, - ); - }); - }); - - describe('sendMessageStream()', () => { - it('generateContentStream errors should be catchable', async () => { - jest.useFakeTimers(); - const consoleStub = jest.spyOn(console, 'error').mockImplementation(() => {}); - const generateContentStreamStub = jest - .spyOn(generateContentMethods, 'generateContentStream') - .mockRejectedValue('generateContentStream failed'); - const chatSession = new ChatSession(fakeApiSettings, 'a-model', {}, requestOptions); - await expect(chatSession.sendMessageStream('hello')).rejects.toMatch( - /generateContentStream failed/, - ); - expect(generateContentStreamStub).toHaveBeenCalledWith( - fakeApiSettings, - 'a-model', - expect.anything(), - requestOptions, - ); - jest.runAllTimers(); - expect(consoleStub).not.toHaveBeenCalled(); - jest.useRealTimers(); - }); - - it('downstream sendPromise errors should log but not throw', async () => { - const consoleStub = jest.spyOn(console, 'error').mockImplementation(() => {}); - // make response undefined so that response.candidates errors - const generateContentStreamStub = jest - .spyOn(generateContentMethods, 'generateContentStream') - .mockResolvedValue({} as unknown as GenerateContentStreamResult); - const chatSession = new ChatSession(fakeApiSettings, 'a-model', {}, requestOptions); - await chatSession.sendMessageStream('hello'); - expect(generateContentStreamStub).toHaveBeenCalledWith( - fakeApiSettings, - 'a-model', - expect.anything(), - requestOptions, - ); - // wait for the console.error to be called, due to number of promises in the chain - await new Promise(resolve => setTimeout(resolve, 100)); - expect(consoleStub).toHaveBeenCalledTimes(1); - }); - }); -}); diff --git a/packages/vertexai/__tests__/count-tokens.test.ts b/packages/vertexai/__tests__/count-tokens.test.ts deleted file mode 100644 index 3cd7b78970..0000000000 --- a/packages/vertexai/__tests__/count-tokens.test.ts +++ /dev/null @@ -1,88 +0,0 @@ -/** - * @license - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -import { describe, expect, it, afterEach, jest } from '@jest/globals'; -import { getMockResponse } from './test-utils/mock-response'; -import * as request from '../lib/requests/request'; -import { countTokens } from '../lib/methods/count-tokens'; -import { CountTokensRequest } from '../lib/types'; -import { ApiSettings } from '../lib/types/internal'; -import { Task } from '../lib/requests/request'; - -const fakeApiSettings: ApiSettings = { - apiKey: 'key', - project: 'my-project', - location: 'us-central1', -}; - -const fakeRequestParams: CountTokensRequest = { - contents: [{ parts: [{ text: 'hello' }], role: 'user' }], -}; - -describe('countTokens()', () => { - afterEach(() => { - jest.restoreAllMocks(); - }); - - it('total tokens', async () => { - const mockResponse = getMockResponse('unary-success-total-tokens.json'); - const makeRequestStub = jest - .spyOn(request, 'makeRequest') - .mockResolvedValue(mockResponse as Response); - const result = await countTokens(fakeApiSettings, 'model', fakeRequestParams); - expect(result.totalTokens).toBe(6); - expect(result.totalBillableCharacters).toBe(16); - expect(makeRequestStub).toHaveBeenCalledWith( - 'model', - Task.COUNT_TOKENS, - fakeApiSettings, - false, - expect.stringContaining('contents'), - undefined, - ); - }); - - it('total tokens no billable characters', async () => { - const mockResponse = getMockResponse('unary-success-no-billable-characters.json'); - const makeRequestStub = jest - .spyOn(request, 'makeRequest') - .mockResolvedValue(mockResponse as Response); - const result = await countTokens(fakeApiSettings, 'model', fakeRequestParams); - expect(result.totalTokens).toBe(258); - expect(result).not.toHaveProperty('totalBillableCharacters'); - expect(makeRequestStub).toHaveBeenCalledWith( - 'model', - Task.COUNT_TOKENS, - fakeApiSettings, - false, - expect.stringContaining('contents'), - undefined, - ); - }); - - it('model not found', async () => { - const mockResponse = getMockResponse('unary-failure-model-not-found.json'); - const mockFetch = jest.spyOn(globalThis, 'fetch').mockResolvedValue({ - ok: false, - status: 404, - json: mockResponse.json, - } as Response); - await expect(countTokens(fakeApiSettings, 'model', fakeRequestParams)).rejects.toThrow( - /404.*not found/, - ); - expect(mockFetch).toHaveBeenCalled(); - }); -}); diff --git a/packages/vertexai/__tests__/generate-content.test.ts b/packages/vertexai/__tests__/generate-content.test.ts deleted file mode 100644 index 3bc733e370..0000000000 --- a/packages/vertexai/__tests__/generate-content.test.ts +++ /dev/null @@ -1,204 +0,0 @@ -/** - * @license - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -import { describe, expect, it, afterEach, jest } from '@jest/globals'; -import { getMockResponse } from './test-utils/mock-response'; -import * as request from '../lib/requests/request'; -import { generateContent } from '../lib/methods/generate-content'; -import { - GenerateContentRequest, - HarmBlockMethod, - HarmBlockThreshold, - HarmCategory, - // RequestOptions, -} from '../lib/types'; -import { ApiSettings } from '../lib/types/internal'; -import { Task } from '../lib/requests/request'; - -const fakeApiSettings: ApiSettings = { - apiKey: 'key', - project: 'my-project', - location: 'us-central1', -}; - -const fakeRequestParams: GenerateContentRequest = { - contents: [{ parts: [{ text: 'hello' }], role: 'user' }], - generationConfig: { - topK: 16, - }, - safetySettings: [ - { - category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, - threshold: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, - method: HarmBlockMethod.SEVERITY, - }, - ], -}; - -describe('generateContent()', () => { - afterEach(() => { - jest.restoreAllMocks(); - }); - - it('short response', async () => { - const mockResponse = getMockResponse('unary-success-basic-reply-short.json'); - const makeRequestStub = jest - .spyOn(request, 'makeRequest') - .mockResolvedValue(mockResponse as Response); - const result = await generateContent(fakeApiSettings, 'model', fakeRequestParams); - expect(await result.response.text()).toContain('Mountain View, California'); - expect(makeRequestStub).toHaveBeenCalledWith( - 'model', - Task.GENERATE_CONTENT, - fakeApiSettings, - false, - expect.stringContaining('contents'), - undefined, - ); - }); - - it('long response', async () => { - const mockResponse = getMockResponse('unary-success-basic-reply-long.json'); - const makeRequestStub = jest - .spyOn(request, 'makeRequest') - .mockResolvedValue(mockResponse as Response); - const result = await generateContent(fakeApiSettings, 'model', fakeRequestParams); - expect(result.response.text()).toContain('Use Freshly Ground Coffee'); - expect(result.response.text()).toContain('30 minutes of brewing'); - expect(makeRequestStub).toHaveBeenCalledWith( - 'model', - Task.GENERATE_CONTENT, - fakeApiSettings, - false, - expect.anything(), - undefined, - ); - }); - - it('citations', async () => { - const mockResponse = getMockResponse('unary-success-citations.json'); - const makeRequestStub = jest - .spyOn(request, 'makeRequest') - .mockResolvedValue(mockResponse as Response); - const result = await generateContent(fakeApiSettings, 'model', fakeRequestParams); - expect(result.response.text()).toContain('Some information cited from an external source'); - expect(result.response.candidates?.[0]!.citationMetadata?.citations.length).toBe(3); - expect(makeRequestStub).toHaveBeenCalledWith( - 'model', - Task.GENERATE_CONTENT, - fakeApiSettings, - false, - expect.anything(), - undefined, - ); - }); - - it('blocked prompt', async () => { - const mockResponse = getMockResponse('unary-failure-prompt-blocked-safety.json'); - const makeRequestStub = jest - .spyOn(request, 'makeRequest') - .mockResolvedValue(mockResponse as Response); - - const result = await generateContent(fakeApiSettings, 'model', fakeRequestParams); - - expect(() => result.response.text()).toThrowError('SAFETY'); - expect(makeRequestStub).toHaveBeenCalledWith( - 'model', - Task.GENERATE_CONTENT, - fakeApiSettings, - false, - expect.anything(), - undefined, - ); - }); - - it('finishReason safety', async () => { - const mockResponse = getMockResponse('unary-failure-finish-reason-safety.json'); - const makeRequestStub = jest - .spyOn(request, 'makeRequest') - .mockResolvedValue(mockResponse as Response); - const result = await generateContent(fakeApiSettings, 'model', fakeRequestParams); - expect(() => result.response.text()).toThrow('SAFETY'); - expect(makeRequestStub).toHaveBeenCalledWith( - 'model', - Task.GENERATE_CONTENT, - fakeApiSettings, - false, - expect.anything(), - undefined, - ); - }); - - it('empty content', async () => { - const mockResponse = getMockResponse('unary-failure-empty-content.json'); - const makeRequestStub = jest - .spyOn(request, 'makeRequest') - .mockResolvedValue(mockResponse as Response); - const result = await generateContent(fakeApiSettings, 'model', fakeRequestParams); - expect(result.response.text()).toBe(''); - expect(makeRequestStub).toHaveBeenCalledWith( - 'model', - Task.GENERATE_CONTENT, - fakeApiSettings, - false, - expect.anything(), - undefined, - ); - }); - - it('unknown enum - should ignore', async () => { - const mockResponse = getMockResponse('unary-success-unknown-enum-safety-ratings.json'); - const makeRequestStub = jest - .spyOn(request, 'makeRequest') - .mockResolvedValue(mockResponse as Response); - const result = await generateContent(fakeApiSettings, 'model', fakeRequestParams); - expect(result.response.text()).toContain('Some text'); - expect(makeRequestStub).toHaveBeenCalledWith( - 'model', - Task.GENERATE_CONTENT, - fakeApiSettings, - false, - expect.anything(), - undefined, - ); - }); - - it('image rejected (400)', async () => { - const mockResponse = getMockResponse('unary-failure-image-rejected.json'); - const mockFetch = jest.spyOn(globalThis, 'fetch').mockResolvedValue({ - ok: false, - status: 400, - json: mockResponse.json, - } as Response); - await expect(generateContent(fakeApiSettings, 'model', fakeRequestParams)).rejects.toThrow( - /400.*invalid argument/, - ); - expect(mockFetch).toHaveBeenCalled(); - }); - - it('api not enabled (403)', async () => { - const mockResponse = getMockResponse('unary-failure-firebasevertexai-api-not-enabled.json'); - const mockFetch = jest.spyOn(globalThis, 'fetch').mockResolvedValue({ - ok: false, - status: 403, - json: mockResponse.json, - } as Response); - await expect(generateContent(fakeApiSettings, 'model', fakeRequestParams)).rejects.toThrow( - /firebasevertexai\.googleapis[\s\S]*my-project[\s\S]*api-not-enabled/, - ); - expect(mockFetch).toHaveBeenCalled(); - }); -}); diff --git a/packages/vertexai/__tests__/generative-model.test.ts b/packages/vertexai/__tests__/generative-model.test.ts deleted file mode 100644 index e62862b6aa..0000000000 --- a/packages/vertexai/__tests__/generative-model.test.ts +++ /dev/null @@ -1,266 +0,0 @@ -/** - * @license - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -import { describe, expect, it, jest } from '@jest/globals'; -import { type ReactNativeFirebase } from '@react-native-firebase/app'; -import { GenerativeModel } from '../lib/models/generative-model'; -import { FunctionCallingMode, VertexAI } from '../lib/public-types'; -import * as request from '../lib/requests/request'; -import { getMockResponse } from './test-utils/mock-response'; - -const fakeVertexAI: VertexAI = { - app: { - name: 'DEFAULT', - options: { - apiKey: 'key', - projectId: 'my-project', - }, - } as ReactNativeFirebase.FirebaseApp, - location: 'us-central1', -}; - -describe('GenerativeModel', () => { - it('handles plain model name', () => { - const genModel = new GenerativeModel(fakeVertexAI, { model: 'my-model' }); - expect(genModel.model).toBe('publishers/google/models/my-model'); - }); - - it('handles models/ prefixed model name', () => { - const genModel = new GenerativeModel(fakeVertexAI, { - model: 'models/my-model', - }); - expect(genModel.model).toBe('publishers/google/models/my-model'); - }); - - it('handles full model name', () => { - const genModel = new GenerativeModel(fakeVertexAI, { - model: 'publishers/google/models/my-model', - }); - expect(genModel.model).toBe('publishers/google/models/my-model'); - }); - - it('handles prefixed tuned model name', () => { - const genModel = new GenerativeModel(fakeVertexAI, { - model: 'tunedModels/my-model', - }); - expect(genModel.model).toBe('tunedModels/my-model'); - }); - - it('passes params through to generateContent', async () => { - const genModel = new GenerativeModel(fakeVertexAI, { - model: 'my-model', - tools: [ - { - functionDeclarations: [ - { - name: 'myfunc', - description: 'mydesc', - }, - ], - }, - ], - toolConfig: { functionCallingConfig: { mode: FunctionCallingMode.NONE } }, - systemInstruction: { role: 'system', parts: [{ text: 'be friendly' }] }, - }); - expect(genModel.tools?.length).toBe(1); - expect(genModel.toolConfig?.functionCallingConfig?.mode).toBe(FunctionCallingMode.NONE); - expect(genModel.systemInstruction?.parts[0]!.text).toBe('be friendly'); - const mockResponse = getMockResponse('unary-success-basic-reply-short.json'); - const makeRequestStub = jest - .spyOn(request, 'makeRequest') - .mockResolvedValue(mockResponse as Response); - await genModel.generateContent('hello'); - expect(makeRequestStub).toHaveBeenCalledWith( - 'publishers/google/models/my-model', - request.Task.GENERATE_CONTENT, - expect.anything(), - false, - expect.stringMatching(new RegExp(`myfunc|be friendly|${FunctionCallingMode.NONE}`)), - {}, - ); - makeRequestStub.mockRestore(); - }); - - it('passes text-only systemInstruction through to generateContent', async () => { - const genModel = new GenerativeModel(fakeVertexAI, { - model: 'my-model', - systemInstruction: 'be friendly', - }); - expect(genModel.systemInstruction?.parts[0]!.text).toBe('be friendly'); - const mockResponse = getMockResponse('unary-success-basic-reply-short.json'); - const makeRequestStub = jest - .spyOn(request, 'makeRequest') - .mockResolvedValue(mockResponse as Response); - await genModel.generateContent('hello'); - expect(makeRequestStub).toHaveBeenCalledWith( - 'publishers/google/models/my-model', - request.Task.GENERATE_CONTENT, - expect.anything(), - false, - expect.stringContaining('be friendly'), - {}, - ); - makeRequestStub.mockRestore(); - }); - - it('generateContent overrides model values', async () => { - const genModel = new GenerativeModel(fakeVertexAI, { - model: 'my-model', - tools: [ - { - functionDeclarations: [ - { - name: 'myfunc', - description: 'mydesc', - }, - ], - }, - ], - toolConfig: { functionCallingConfig: { mode: FunctionCallingMode.NONE } }, - systemInstruction: { role: 'system', parts: [{ text: 'be friendly' }] }, - }); - expect(genModel.tools?.length).toBe(1); - expect(genModel.toolConfig?.functionCallingConfig?.mode).toBe(FunctionCallingMode.NONE); - expect(genModel.systemInstruction?.parts[0]!.text).toBe('be friendly'); - const mockResponse = getMockResponse('unary-success-basic-reply-short.json'); - const makeRequestStub = jest - .spyOn(request, 'makeRequest') - .mockResolvedValue(mockResponse as Response); - await genModel.generateContent({ - contents: [{ role: 'user', parts: [{ text: 'hello' }] }], - tools: [ - { - functionDeclarations: [{ name: 'otherfunc', description: 'otherdesc' }], - }, - ], - toolConfig: { functionCallingConfig: { mode: FunctionCallingMode.AUTO } }, - systemInstruction: { role: 'system', parts: [{ text: 'be formal' }] }, - }); - expect(makeRequestStub).toHaveBeenCalledWith( - 'publishers/google/models/my-model', - request.Task.GENERATE_CONTENT, - expect.anything(), - false, - expect.stringMatching(new RegExp(`be formal|otherfunc|${FunctionCallingMode.AUTO}`)), - {}, - ); - makeRequestStub.mockRestore(); - }); - - it('passes params through to chat.sendMessage', async () => { - const genModel = new GenerativeModel(fakeVertexAI, { - model: 'my-model', - tools: [{ functionDeclarations: [{ name: 'myfunc', description: 'mydesc' }] }], - toolConfig: { functionCallingConfig: { mode: FunctionCallingMode.NONE } }, - systemInstruction: { role: 'system', parts: [{ text: 'be friendly' }] }, - }); - expect(genModel.tools?.length).toBe(1); - expect(genModel.toolConfig?.functionCallingConfig?.mode).toBe(FunctionCallingMode.NONE); - expect(genModel.systemInstruction?.parts[0]!.text).toBe('be friendly'); - const mockResponse = getMockResponse('unary-success-basic-reply-short.json'); - const makeRequestStub = jest - .spyOn(request, 'makeRequest') - .mockResolvedValue(mockResponse as Response); - await genModel.startChat().sendMessage('hello'); - expect(makeRequestStub).toHaveBeenCalledWith( - 'publishers/google/models/my-model', - request.Task.GENERATE_CONTENT, - expect.anything(), - false, - expect.stringMatching(new RegExp(`myfunc|be friendly|${FunctionCallingMode.NONE}`)), - {}, - ); - makeRequestStub.mockRestore(); - }); - - it('passes text-only systemInstruction through to chat.sendMessage', async () => { - const genModel = new GenerativeModel(fakeVertexAI, { - model: 'my-model', - systemInstruction: 'be friendly', - }); - expect(genModel.systemInstruction?.parts[0]!.text).toBe('be friendly'); - const mockResponse = getMockResponse('unary-success-basic-reply-short.json'); - const makeRequestStub = jest - .spyOn(request, 'makeRequest') - .mockResolvedValue(mockResponse as Response); - await genModel.startChat().sendMessage('hello'); - expect(makeRequestStub).toHaveBeenCalledWith( - 'publishers/google/models/my-model', - request.Task.GENERATE_CONTENT, - expect.anything(), - false, - expect.stringContaining('be friendly'), - {}, - ); - makeRequestStub.mockRestore(); - }); - - it('startChat overrides model values', async () => { - const genModel = new GenerativeModel(fakeVertexAI, { - model: 'my-model', - tools: [{ functionDeclarations: [{ name: 'myfunc', description: 'mydesc' }] }], - toolConfig: { functionCallingConfig: { mode: FunctionCallingMode.NONE } }, - systemInstruction: { role: 'system', parts: [{ text: 'be friendly' }] }, - }); - expect(genModel.tools?.length).toBe(1); - expect(genModel.toolConfig?.functionCallingConfig?.mode).toBe(FunctionCallingMode.NONE); - expect(genModel.systemInstruction?.parts[0]!.text).toBe('be friendly'); - const mockResponse = getMockResponse('unary-success-basic-reply-short.json'); - const makeRequestStub = jest - .spyOn(request, 'makeRequest') - .mockResolvedValue(mockResponse as Response); - await genModel - .startChat({ - tools: [ - { - functionDeclarations: [{ name: 'otherfunc', description: 'otherdesc' }], - }, - ], - toolConfig: { - functionCallingConfig: { mode: FunctionCallingMode.AUTO }, - }, - systemInstruction: { role: 'system', parts: [{ text: 'be formal' }] }, - }) - .sendMessage('hello'); - expect(makeRequestStub).toHaveBeenCalledWith( - 'publishers/google/models/my-model', - request.Task.GENERATE_CONTENT, - expect.anything(), - false, - expect.stringMatching(new RegExp(`otherfunc|be formal|${FunctionCallingMode.AUTO}`)), - {}, - ); - makeRequestStub.mockRestore(); - }); - - it('calls countTokens', async () => { - const genModel = new GenerativeModel(fakeVertexAI, { model: 'my-model' }); - const mockResponse = getMockResponse('unary-success-total-tokens.json'); - const makeRequestStub = jest - .spyOn(request, 'makeRequest') - .mockResolvedValue(mockResponse as Response); - await genModel.countTokens('hello'); - expect(makeRequestStub).toHaveBeenCalledWith( - 'publishers/google/models/my-model', - request.Task.COUNT_TOKENS, - expect.anything(), - false, - expect.stringContaining('hello'), - undefined, - ); - makeRequestStub.mockRestore(); - }); -}); diff --git a/packages/vertexai/__tests__/request-helpers.test.ts b/packages/vertexai/__tests__/request-helpers.test.ts deleted file mode 100644 index 05433f5ba2..0000000000 --- a/packages/vertexai/__tests__/request-helpers.test.ts +++ /dev/null @@ -1,204 +0,0 @@ -/** - * @license - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -import { describe, expect, it } from '@jest/globals'; -import { Content } from '../lib/types'; -import { formatGenerateContentInput } from '../lib/requests/request-helpers'; - -describe('request formatting methods', () => { - describe('formatGenerateContentInput', () => { - it('formats a text string into a request', () => { - const result = formatGenerateContentInput('some text content'); - expect(result).toEqual({ - contents: [ - { - role: 'user', - parts: [{ text: 'some text content' }], - }, - ], - }); - }); - - it('formats an array of strings into a request', () => { - const result = formatGenerateContentInput(['txt1', 'txt2']); - expect(result).toEqual({ - contents: [ - { - role: 'user', - parts: [{ text: 'txt1' }, { text: 'txt2' }], - }, - ], - }); - }); - - it('formats an array of Parts into a request', () => { - const result = formatGenerateContentInput([{ text: 'txt1' }, { text: 'txtB' }]); - expect(result).toEqual({ - contents: [ - { - role: 'user', - parts: [{ text: 'txt1' }, { text: 'txtB' }], - }, - ], - }); - }); - - it('formats a mixed array into a request', () => { - const result = formatGenerateContentInput(['txtA', { text: 'txtB' }]); - expect(result).toEqual({ - contents: [ - { - role: 'user', - parts: [{ text: 'txtA' }, { text: 'txtB' }], - }, - ], - }); - }); - - it('preserves other properties of request', () => { - const result = formatGenerateContentInput({ - contents: [ - { - role: 'user', - parts: [{ text: 'txtA' }], - }, - ], - generationConfig: { topK: 100 }, - }); - expect(result).toEqual({ - contents: [ - { - role: 'user', - parts: [{ text: 'txtA' }], - }, - ], - generationConfig: { topK: 100 }, - }); - }); - - it('formats systemInstructions if provided as text', () => { - const result = formatGenerateContentInput({ - contents: [ - { - role: 'user', - parts: [{ text: 'txtA' }], - }, - ], - systemInstruction: 'be excited', - }); - expect(result).toEqual({ - contents: [ - { - role: 'user', - parts: [{ text: 'txtA' }], - }, - ], - systemInstruction: { role: 'system', parts: [{ text: 'be excited' }] }, - }); - }); - - it('formats systemInstructions if provided as Part', () => { - const result = formatGenerateContentInput({ - contents: [ - { - role: 'user', - parts: [{ text: 'txtA' }], - }, - ], - systemInstruction: { text: 'be excited' }, - }); - expect(result).toEqual({ - contents: [ - { - role: 'user', - parts: [{ text: 'txtA' }], - }, - ], - systemInstruction: { role: 'system', parts: [{ text: 'be excited' }] }, - }); - }); - - it('formats systemInstructions if provided as Content (no role)', () => { - const result = formatGenerateContentInput({ - contents: [ - { - role: 'user', - parts: [{ text: 'txtA' }], - }, - ], - systemInstruction: { parts: [{ text: 'be excited' }] } as Content, - }); - expect(result).toEqual({ - contents: [ - { - role: 'user', - parts: [{ text: 'txtA' }], - }, - ], - systemInstruction: { role: 'system', parts: [{ text: 'be excited' }] }, - }); - }); - - it('passes thru systemInstructions if provided as Content', () => { - const result = formatGenerateContentInput({ - contents: [ - { - role: 'user', - parts: [{ text: 'txtA' }], - }, - ], - systemInstruction: { role: 'system', parts: [{ text: 'be excited' }] }, - }); - expect(result).toEqual({ - contents: [ - { - role: 'user', - parts: [{ text: 'txtA' }], - }, - ], - systemInstruction: { role: 'system', parts: [{ text: 'be excited' }] }, - }); - }); - - it('formats fileData as part if provided as part', () => { - const result = formatGenerateContentInput([ - 'What is this?', - { - fileData: { - mimeType: 'image/jpeg', - fileUri: 'gs://sample.appspot.com/image.jpeg', - }, - }, - ]); - expect(result).toEqual({ - contents: [ - { - role: 'user', - parts: [ - { text: 'What is this?' }, - { - fileData: { - mimeType: 'image/jpeg', - fileUri: 'gs://sample.appspot.com/image.jpeg', - }, - }, - ], - }, - ], - }); - }); - }); -}); diff --git a/packages/vertexai/__tests__/request.test.ts b/packages/vertexai/__tests__/request.test.ts deleted file mode 100644 index c992b062e9..0000000000 --- a/packages/vertexai/__tests__/request.test.ts +++ /dev/null @@ -1,353 +0,0 @@ -/** - * @license - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -import { describe, expect, it, jest, afterEach } from '@jest/globals'; -import { RequestUrl, Task, getHeaders, makeRequest } from '../lib/requests/request'; -import { ApiSettings } from '../lib/types/internal'; -import { DEFAULT_API_VERSION } from '../lib/constants'; -import { VertexAIErrorCode } from '../lib/types'; -import { VertexAIError } from '../lib/errors'; -import { getMockResponse } from './test-utils/mock-response'; - -const fakeApiSettings: ApiSettings = { - apiKey: 'key', - project: 'my-project', - location: 'us-central1', -}; - -describe('request methods', () => { - afterEach(() => { - jest.restoreAllMocks(); // Use Jest's restoreAllMocks - }); - - describe('RequestUrl', () => { - it('stream', async () => { - const url = new RequestUrl( - 'models/model-name', - Task.GENERATE_CONTENT, - fakeApiSettings, - true, - {}, - ); - const urlStr = url.toString(); - expect(urlStr).toContain('models/model-name:generateContent'); - expect(urlStr).toContain(fakeApiSettings.project); - expect(urlStr).toContain(fakeApiSettings.location); - expect(urlStr).toContain('alt=sse'); - }); - - it('non-stream', async () => { - const url = new RequestUrl( - 'models/model-name', - Task.GENERATE_CONTENT, - fakeApiSettings, - false, - {}, - ); - const urlStr = url.toString(); - expect(urlStr).toContain('models/model-name:generateContent'); - expect(urlStr).toContain(fakeApiSettings.project); - expect(urlStr).toContain(fakeApiSettings.location); - expect(urlStr).not.toContain('alt=sse'); - }); - - it('default apiVersion', async () => { - const url = new RequestUrl( - 'models/model-name', - Task.GENERATE_CONTENT, - fakeApiSettings, - false, - {}, - ); - expect(url.toString()).toContain(DEFAULT_API_VERSION); - }); - - it('custom baseUrl', async () => { - const url = new RequestUrl( - 'models/model-name', - Task.GENERATE_CONTENT, - fakeApiSettings, - false, - { baseUrl: 'https://my.special.endpoint' }, - ); - expect(url.toString()).toContain('https://my.special.endpoint'); - }); - - it('non-stream - tunedModels/', async () => { - const url = new RequestUrl( - 'tunedModels/model-name', - Task.GENERATE_CONTENT, - fakeApiSettings, - false, - {}, - ); - const urlStr = url.toString(); - expect(urlStr).toContain('tunedModels/model-name:generateContent'); - expect(urlStr).toContain(fakeApiSettings.location); - expect(urlStr).toContain(fakeApiSettings.project); - expect(urlStr).not.toContain('alt=sse'); - }); - }); - - describe('getHeaders', () => { - const fakeApiSettings: ApiSettings = { - apiKey: 'key', - project: 'myproject', - location: 'moon', - getAuthToken: () => Promise.resolve('authtoken'), - getAppCheckToken: () => Promise.resolve({ token: 'appchecktoken' }), - }; - const fakeUrl = new RequestUrl( - 'models/model-name', - Task.GENERATE_CONTENT, - fakeApiSettings, - true, - {}, - ); - - it('adds client headers', async () => { - const headers = await getHeaders(fakeUrl); - expect(headers.get('x-goog-api-client')).toMatch(/gl-rn\/[0-9\.]+ fire\/[0-9\.]+/); - }); - - it('adds api key', async () => { - const headers = await getHeaders(fakeUrl); - expect(headers.get('x-goog-api-key')).toBe('key'); - }); - - it('adds app check token if it exists', async () => { - const headers = await getHeaders(fakeUrl); - expect(headers.get('X-Firebase-AppCheck')).toBe('appchecktoken'); - }); - - it('ignores app check token header if no appcheck service', async () => { - const fakeUrl = new RequestUrl( - 'models/model-name', - Task.GENERATE_CONTENT, - { - apiKey: 'key', - project: 'myproject', - location: 'moon', - }, - true, - {}, - ); - const headers = await getHeaders(fakeUrl); - expect(headers.has('X-Firebase-AppCheck')).toBe(false); - }); - - it('ignores app check token header if returned token was undefined', async () => { - const fakeUrl = new RequestUrl( - 'models/model-name', - Task.GENERATE_CONTENT, - { - apiKey: 'key', - project: 'myproject', - location: 'moon', - //@ts-ignore - getAppCheckToken: () => Promise.resolve(), - }, - true, - {}, - ); - const headers = await getHeaders(fakeUrl); - expect(headers.has('X-Firebase-AppCheck')).toBe(false); - }); - - it('ignores app check token header if returned token had error', async () => { - const fakeUrl = new RequestUrl( - 'models/model-name', - Task.GENERATE_CONTENT, - { - apiKey: 'key', - project: 'myproject', - location: 'moon', - getAppCheckToken: () => Promise.reject(new Error('oops')), - }, - true, - {}, - ); - - const warnSpy = jest.spyOn(console, 'warn').mockImplementation(() => {}); - await getHeaders(fakeUrl); - // NOTE - no app check header if there is no token, this is different to firebase-js-sdk - // See: https://github.com/firebase/firebase-js-sdk/blob/main/packages/vertexai/src/requests/request.test.ts#L172 - // expect(headers.get('X-Firebase-AppCheck')).toBe('dummytoken'); - expect(warnSpy).toHaveBeenCalledWith( - expect.stringMatching(/vertexai/), - expect.stringMatching(/App Check.*oops/), - ); - }); - - it('adds auth token if it exists', async () => { - const headers = await getHeaders(fakeUrl); - expect(headers.get('Authorization')).toBe('Firebase authtoken'); - }); - - it('ignores auth token header if no auth service', async () => { - const fakeUrl = new RequestUrl( - 'models/model-name', - Task.GENERATE_CONTENT, - { - apiKey: 'key', - project: 'myproject', - location: 'moon', - }, - true, - {}, - ); - const headers = await getHeaders(fakeUrl); - expect(headers.has('Authorization')).toBe(false); - }); - - it('ignores auth token header if returned token was undefined', async () => { - const fakeUrl = new RequestUrl( - 'models/model-name', - Task.GENERATE_CONTENT, - { - apiKey: 'key', - project: 'myproject', - location: 'moon', - //@ts-ignore - getAppCheckToken: () => Promise.resolve(), - }, - true, - {}, - ); - const headers = await getHeaders(fakeUrl); - expect(headers.has('Authorization')).toBe(false); - }); - }); - - describe('makeRequest', () => { - it('no error', async () => { - const fetchMock = jest.spyOn(globalThis, 'fetch').mockResolvedValue({ - ok: true, - } as Response); - const response = await makeRequest( - 'models/model-name', - Task.GENERATE_CONTENT, - fakeApiSettings, - false, - '', - ); - expect(fetchMock).toHaveBeenCalledTimes(1); - expect(response.ok).toBe(true); - }); - - it('error with timeout', async () => { - const fetchMock = jest.spyOn(globalThis, 'fetch').mockResolvedValue({ - ok: false, - status: 500, - statusText: 'AbortError', - } as Response); - - try { - await makeRequest('models/model-name', Task.GENERATE_CONTENT, fakeApiSettings, false, '', { - timeout: 180000, - }); - } catch (e) { - expect((e as VertexAIError).code).toBe(VertexAIErrorCode.FETCH_ERROR); - expect((e as VertexAIError).customErrorData?.status).toBe(500); - expect((e as VertexAIError).customErrorData?.statusText).toBe('AbortError'); - expect((e as VertexAIError).message).toContain('500 AbortError'); - } - - expect(fetchMock).toHaveBeenCalledTimes(1); - }); - - it('Network error, no response.json()', async () => { - const fetchMock = jest.spyOn(globalThis, 'fetch').mockResolvedValue({ - ok: false, - status: 500, - statusText: 'Server Error', - } as Response); - try { - await makeRequest('models/model-name', Task.GENERATE_CONTENT, fakeApiSettings, false, ''); - } catch (e) { - expect((e as VertexAIError).code).toBe(VertexAIErrorCode.FETCH_ERROR); - expect((e as VertexAIError).customErrorData?.status).toBe(500); - expect((e as VertexAIError).customErrorData?.statusText).toBe('Server Error'); - expect((e as VertexAIError).message).toContain('500 Server Error'); - } - expect(fetchMock).toHaveBeenCalledTimes(1); - }); - - it('Network error, includes response.json()', async () => { - const fetchMock = jest.spyOn(globalThis, 'fetch').mockResolvedValue({ - ok: false, - status: 500, - statusText: 'Server Error', - json: () => Promise.resolve({ error: { message: 'extra info' } }), - } as Response); - try { - await makeRequest('models/model-name', Task.GENERATE_CONTENT, fakeApiSettings, false, ''); - } catch (e) { - expect((e as VertexAIError).code).toBe(VertexAIErrorCode.FETCH_ERROR); - expect((e as VertexAIError).customErrorData?.status).toBe(500); - expect((e as VertexAIError).customErrorData?.statusText).toBe('Server Error'); - expect((e as VertexAIError).message).toContain('500 Server Error'); - expect((e as VertexAIError).message).toContain('extra info'); - } - expect(fetchMock).toHaveBeenCalledTimes(1); - }); - - it('Network error, includes response.json() and details', async () => { - const fetchMock = jest.spyOn(globalThis, 'fetch').mockResolvedValue({ - ok: false, - status: 500, - statusText: 'Server Error', - json: () => - Promise.resolve({ - error: { - message: 'extra info', - details: [ - { - '@type': 'type.googleapis.com/google.rpc.DebugInfo', - detail: - '[ORIGINAL ERROR] generic::invalid_argument: invalid status photos.thumbnailer.Status.Code::5: Source image 0 too short', - }, - ], - }, - }), - } as Response); - try { - await makeRequest('models/model-name', Task.GENERATE_CONTENT, fakeApiSettings, false, ''); - } catch (e) { - expect((e as VertexAIError).code).toBe(VertexAIErrorCode.FETCH_ERROR); - expect((e as VertexAIError).customErrorData?.status).toBe(500); - expect((e as VertexAIError).customErrorData?.statusText).toBe('Server Error'); - expect((e as VertexAIError).message).toContain('500 Server Error'); - expect((e as VertexAIError).message).toContain('extra info'); - expect((e as VertexAIError).message).toContain('generic::invalid_argument'); - } - expect(fetchMock).toHaveBeenCalledTimes(1); - }); - }); - - it('Network error, API not enabled', async () => { - const mockResponse = getMockResponse('unary-failure-firebasevertexai-api-not-enabled.json'); - const fetchMock = jest.spyOn(globalThis, 'fetch').mockResolvedValue(mockResponse as Response); - try { - await makeRequest('models/model-name', Task.GENERATE_CONTENT, fakeApiSettings, false, ''); - } catch (e) { - expect((e as VertexAIError).code).toBe(VertexAIErrorCode.API_NOT_ENABLED); - expect((e as VertexAIError).message).toContain('my-project'); - expect((e as VertexAIError).message).toContain('googleapis.com'); - } - expect(fetchMock).toHaveBeenCalledTimes(1); - }); -}); diff --git a/packages/vertexai/__tests__/response-helpers.test.ts b/packages/vertexai/__tests__/response-helpers.test.ts deleted file mode 100644 index cc0fddc658..0000000000 --- a/packages/vertexai/__tests__/response-helpers.test.ts +++ /dev/null @@ -1,236 +0,0 @@ -/** - * @license - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -import { describe, expect, it, jest, afterEach } from '@jest/globals'; -import { addHelpers, formatBlockErrorMessage } from '../lib/requests/response-helpers'; - -import { BlockReason, Content, FinishReason, GenerateContentResponse } from '../lib/types'; - -const fakeResponseText: GenerateContentResponse = { - candidates: [ - { - index: 0, - content: { - role: 'model', - parts: [{ text: 'Some text' }, { text: ' and some more text' }], - }, - }, - ], -}; - -const functionCallPart1 = { - functionCall: { - name: 'find_theaters', - args: { - location: 'Mountain View, CA', - movie: 'Barbie', - }, - }, -}; - -const functionCallPart2 = { - functionCall: { - name: 'find_times', - args: { - location: 'Mountain View, CA', - movie: 'Barbie', - time: '20:00', - }, - }, -}; - -const fakeResponseFunctionCall: GenerateContentResponse = { - candidates: [ - { - index: 0, - content: { - role: 'model', - parts: [functionCallPart1], - }, - }, - ], -}; - -const fakeResponseFunctionCalls: GenerateContentResponse = { - candidates: [ - { - index: 0, - content: { - role: 'model', - parts: [functionCallPart1, functionCallPart2], - }, - }, - ], -}; - -const fakeResponseMixed1: GenerateContentResponse = { - candidates: [ - { - index: 0, - content: { - role: 'model', - parts: [{ text: 'some text' }, functionCallPart2], - }, - }, - ], -}; - -const fakeResponseMixed2: GenerateContentResponse = { - candidates: [ - { - index: 0, - content: { - role: 'model', - parts: [functionCallPart1, { text: 'some text' }], - }, - }, - ], -}; - -const fakeResponseMixed3: GenerateContentResponse = { - candidates: [ - { - index: 0, - content: { - role: 'model', - parts: [{ text: 'some text' }, functionCallPart1, { text: ' and more text' }], - }, - }, - ], -}; - -const badFakeResponse: GenerateContentResponse = { - promptFeedback: { - blockReason: BlockReason.SAFETY, - safetyRatings: [], - }, -}; - -describe('response-helpers methods', () => { - afterEach(() => { - jest.restoreAllMocks(); // Use Jest's restore function - }); - - describe('addHelpers', () => { - it('good response text', () => { - const enhancedResponse = addHelpers(fakeResponseText); - expect(enhancedResponse.text()).toBe('Some text and some more text'); - expect(enhancedResponse.functionCalls()).toBeUndefined(); - }); - - it('good response functionCall', () => { - const enhancedResponse = addHelpers(fakeResponseFunctionCall); - expect(enhancedResponse.text()).toBe(''); - expect(enhancedResponse.functionCalls()).toEqual([functionCallPart1.functionCall]); - }); - - it('good response functionCalls', () => { - const enhancedResponse = addHelpers(fakeResponseFunctionCalls); - expect(enhancedResponse.text()).toBe(''); - expect(enhancedResponse.functionCalls()).toEqual([ - functionCallPart1.functionCall, - functionCallPart2.functionCall, - ]); - }); - - it('good response text/functionCall', () => { - const enhancedResponse = addHelpers(fakeResponseMixed1); - expect(enhancedResponse.functionCalls()).toEqual([functionCallPart2.functionCall]); - expect(enhancedResponse.text()).toBe('some text'); - }); - - it('good response functionCall/text', () => { - const enhancedResponse = addHelpers(fakeResponseMixed2); - expect(enhancedResponse.functionCalls()).toEqual([functionCallPart1.functionCall]); - expect(enhancedResponse.text()).toBe('some text'); - }); - - it('good response text/functionCall/text', () => { - const enhancedResponse = addHelpers(fakeResponseMixed3); - expect(enhancedResponse.functionCalls()).toEqual([functionCallPart1.functionCall]); - expect(enhancedResponse.text()).toBe('some text and more text'); - }); - - it('bad response safety', () => { - const enhancedResponse = addHelpers(badFakeResponse); - expect(() => enhancedResponse.text()).toThrow('SAFETY'); - }); - }); - - describe('getBlockString', () => { - it('has no promptFeedback or bad finishReason', () => { - const message = formatBlockErrorMessage({ - candidates: [ - { - index: 0, - finishReason: FinishReason.STOP, - finishMessage: 'this was fine', - content: {} as Content, - }, - ], - }); - expect(message).toBe(''); - }); - - it('has promptFeedback and blockReason only', () => { - const message = formatBlockErrorMessage({ - promptFeedback: { - blockReason: BlockReason.SAFETY, - safetyRatings: [], - }, - }); - expect(message).toContain('Response was blocked due to SAFETY'); - }); - - it('has promptFeedback with blockReason and blockMessage', () => { - const message = formatBlockErrorMessage({ - promptFeedback: { - blockReason: BlockReason.SAFETY, - blockReasonMessage: 'safety reasons', - safetyRatings: [], - }, - }); - expect(message).toContain('Response was blocked due to SAFETY: safety reasons'); - }); - - it('has bad finishReason only', () => { - const message = formatBlockErrorMessage({ - candidates: [ - { - index: 0, - finishReason: FinishReason.SAFETY, - content: {} as Content, - }, - ], - }); - expect(message).toContain('Candidate was blocked due to SAFETY'); - }); - - it('has finishReason and finishMessage', () => { - const message = formatBlockErrorMessage({ - candidates: [ - { - index: 0, - finishReason: FinishReason.SAFETY, - finishMessage: 'unsafe candidate', - content: {} as Content, - }, - ], - }); - expect(message).toContain('Candidate was blocked due to SAFETY: unsafe candidate'); - }); - }); -}); diff --git a/packages/vertexai/__tests__/schema-builder.test.ts b/packages/vertexai/__tests__/schema-builder.test.ts deleted file mode 100644 index bec1f6a8d2..0000000000 --- a/packages/vertexai/__tests__/schema-builder.test.ts +++ /dev/null @@ -1,389 +0,0 @@ -/** - * @license - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -import { describe, expect, it } from '@jest/globals'; -import { Schema } from '../lib/requests/schema-builder'; -import { VertexAIErrorCode } from '../lib/types'; - -describe('Schema builder', () => { - it('builds integer schema', () => { - const schema = Schema.integer(); - expect(schema.toJSON()).toEqual({ - type: 'integer', - nullable: false, - }); - }); - - it('builds integer schema with options and overrides', () => { - const schema = Schema.integer({ nullable: true, format: 'int32' }); - expect(schema.toJSON()).toEqual({ - type: 'integer', - format: 'int32', - nullable: true, - }); - }); - - it('builds number schema', () => { - const schema = Schema.number(); - expect(schema.toJSON()).toEqual({ - type: 'number', - nullable: false, - }); - }); - - it('builds number schema with options and unknown options', () => { - const schema = Schema.number({ format: 'float', futureOption: 'test' }); - expect(schema.toJSON()).toEqual({ - type: 'number', - format: 'float', - futureOption: 'test', - nullable: false, - }); - }); - - it('builds boolean schema', () => { - const schema = Schema.boolean(); - expect(schema.toJSON()).toEqual({ - type: 'boolean', - nullable: false, - }); - }); - - it('builds string schema', () => { - const schema = Schema.string({ description: 'hey' }); - expect(schema.toJSON()).toEqual({ - type: 'string', - description: 'hey', - nullable: false, - }); - }); - - it('builds enumString schema', () => { - const schema = Schema.enumString({ - example: 'east', - enum: ['east', 'west'], - }); - expect(schema.toJSON()).toEqual({ - type: 'string', - example: 'east', - enum: ['east', 'west'], - nullable: false, - }); - }); - - it('builds an object schema', () => { - const schema = Schema.object({ - properties: { - someInput: Schema.string(), - }, - }); - - expect(schema.toJSON()).toEqual({ - type: 'object', - nullable: false, - properties: { - someInput: { - type: 'string', - nullable: false, - }, - }, - required: ['someInput'], - }); - }); - - it('builds an object schema with optional properties', () => { - const schema = Schema.object({ - properties: { - someInput: Schema.string(), - someBool: Schema.boolean(), - }, - optionalProperties: ['someBool'], - }); - expect(schema.toJSON()).toEqual({ - type: 'object', - nullable: false, - properties: { - someInput: { - type: 'string', - nullable: false, - }, - someBool: { - type: 'boolean', - nullable: false, - }, - }, - required: ['someInput'], - }); - }); - - it('builds layered schema - partially filled out', () => { - const schema = Schema.array({ - items: Schema.object({ - properties: { - country: Schema.string({ - description: 'A country name', - }), - population: Schema.integer(), - coordinates: Schema.object({ - properties: { - latitude: Schema.number({ format: 'float' }), - longitude: Schema.number({ format: 'double' }), - }, - }), - hemisphere: Schema.object({ - properties: { - latitudinal: Schema.enumString({ enum: ['N', 'S'] }), - longitudinal: Schema.enumString({ enum: ['E', 'W'] }), - }, - }), - isCapital: Schema.boolean(), - }, - }), - }); - expect(schema.toJSON()).toEqual(layeredSchemaOutputPartial); - }); - - it('builds layered schema - fully filled out', () => { - const schema = Schema.array({ - items: Schema.object({ - description: 'A country profile', - nullable: false, - properties: { - country: Schema.string({ - nullable: false, - description: 'Country name', - format: undefined, - }), - population: Schema.integer({ - nullable: false, - description: 'Number of people in country', - format: 'int64', - }), - coordinates: Schema.object({ - nullable: false, - description: 'Latitude and longitude', - properties: { - latitude: Schema.number({ - nullable: false, - description: 'Latitude of capital', - format: 'float', - }), - longitude: Schema.number({ - nullable: false, - description: 'Longitude of capital', - format: 'double', - }), - }, - }), - hemisphere: Schema.object({ - nullable: false, - description: 'Hemisphere(s) country is in', - properties: { - latitudinal: Schema.enumString({ enum: ['N', 'S'] }), - longitudinal: Schema.enumString({ enum: ['E', 'W'] }), - }, - }), - isCapital: Schema.boolean({ - nullable: false, - description: "This doesn't make a lot of sense but it's a demo", - }), - elevation: Schema.integer({ - nullable: false, - description: 'Average elevation', - format: 'float', - }), - }, - optionalProperties: [], - }), - }); - - expect(schema.toJSON()).toEqual(layeredSchemaOutput); - }); - - it('can override "nullable" and set optional properties', () => { - const schema = Schema.object({ - properties: { - country: Schema.string(), - elevation: Schema.number(), - population: Schema.integer({ nullable: true }), - }, - optionalProperties: ['elevation'], - }); - expect(schema.toJSON()).toEqual({ - type: 'object', - nullable: false, - properties: { - country: { - type: 'string', - nullable: false, - }, - elevation: { - type: 'number', - nullable: false, - }, - population: { - type: 'integer', - nullable: true, - }, - }, - required: ['country', 'population'], - }); - }); - - it('throws if an optionalProperties item does not exist', () => { - const schema = Schema.object({ - properties: { - country: Schema.string(), - elevation: Schema.number(), - population: Schema.integer({ nullable: true }), - }, - optionalProperties: ['cat'], - }); - expect(() => schema.toJSON()).toThrow(VertexAIErrorCode.INVALID_SCHEMA); - }); -}); - -const layeredSchemaOutputPartial = { - type: 'array', - nullable: false, - items: { - type: 'object', - nullable: false, - properties: { - country: { - type: 'string', - description: 'A country name', - nullable: false, - }, - population: { - type: 'integer', - nullable: false, - }, - coordinates: { - type: 'object', - nullable: false, - properties: { - latitude: { - type: 'number', - format: 'float', - nullable: false, - }, - longitude: { - type: 'number', - format: 'double', - nullable: false, - }, - }, - required: ['latitude', 'longitude'], - }, - hemisphere: { - type: 'object', - nullable: false, - properties: { - latitudinal: { - type: 'string', - nullable: false, - enum: ['N', 'S'], - }, - longitudinal: { - type: 'string', - nullable: false, - enum: ['E', 'W'], - }, - }, - required: ['latitudinal', 'longitudinal'], - }, - isCapital: { - type: 'boolean', - nullable: false, - }, - }, - required: ['country', 'population', 'coordinates', 'hemisphere', 'isCapital'], - }, -}; - -const layeredSchemaOutput = { - type: 'array', - nullable: false, - items: { - type: 'object', - description: 'A country profile', - nullable: false, - required: ['country', 'population', 'coordinates', 'hemisphere', 'isCapital', 'elevation'], - properties: { - country: { - type: 'string', - description: 'Country name', - nullable: false, - }, - population: { - type: 'integer', - format: 'int64', - description: 'Number of people in country', - nullable: false, - }, - coordinates: { - type: 'object', - description: 'Latitude and longitude', - nullable: false, - required: ['latitude', 'longitude'], - properties: { - latitude: { - type: 'number', - format: 'float', - description: 'Latitude of capital', - nullable: false, - }, - longitude: { - type: 'number', - format: 'double', - description: 'Longitude of capital', - nullable: false, - }, - }, - }, - hemisphere: { - type: 'object', - description: 'Hemisphere(s) country is in', - nullable: false, - required: ['latitudinal', 'longitudinal'], - properties: { - latitudinal: { - type: 'string', - nullable: false, - enum: ['N', 'S'], - }, - longitudinal: { - type: 'string', - nullable: false, - enum: ['E', 'W'], - }, - }, - }, - isCapital: { - type: 'boolean', - description: "This doesn't make a lot of sense but it's a demo", - nullable: false, - }, - elevation: { - type: 'integer', - format: 'float', - description: 'Average elevation', - nullable: false, - }, - }, - }, -}; diff --git a/packages/vertexai/__tests__/service.test.ts b/packages/vertexai/__tests__/service.test.ts deleted file mode 100644 index 9f9503f2c9..0000000000 --- a/packages/vertexai/__tests__/service.test.ts +++ /dev/null @@ -1,45 +0,0 @@ -/** - * @license - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -import { describe, expect, it } from '@jest/globals'; -import { type ReactNativeFirebase } from '@react-native-firebase/app'; -import { DEFAULT_LOCATION } from '../lib/constants'; -import { VertexAIService } from '../lib/service'; - -const fakeApp = { - name: 'DEFAULT', - options: { - apiKey: 'key', - projectId: 'my-project', - }, -} as ReactNativeFirebase.FirebaseApp; - -describe('VertexAIService', () => { - it('uses default location if not specified', () => { - const vertexAI = new VertexAIService(fakeApp); - expect(vertexAI.location).toBe(DEFAULT_LOCATION); - }); - - it('uses custom location if specified', () => { - const vertexAI = new VertexAIService( - fakeApp, - /* authProvider */ undefined, - /* appCheckProvider */ undefined, - { location: 'somewhere' }, - ); - expect(vertexAI.location).toBe('somewhere'); - }); -}); diff --git a/packages/vertexai/__tests__/stream-reader.test.ts b/packages/vertexai/__tests__/stream-reader.test.ts deleted file mode 100644 index 4a5ae8aef5..0000000000 --- a/packages/vertexai/__tests__/stream-reader.test.ts +++ /dev/null @@ -1,370 +0,0 @@ -/** - * @license - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -import { describe, expect, it, jest, afterEach, beforeAll } from '@jest/globals'; -import { ReadableStream } from 'web-streams-polyfill'; -import { - aggregateResponses, - getResponseStream, - processStream, -} from '../lib/requests/stream-reader'; - -import { getChunkedStream, getMockResponseStreaming } from './test-utils/mock-response'; -import { - BlockReason, - FinishReason, - GenerateContentResponse, - HarmCategory, - HarmProbability, - SafetyRating, - VertexAIErrorCode, -} from '../lib/types'; -import { VertexAIError } from '../lib/errors'; - -describe('stream-reader', () => { - describe('getResponseStream', () => { - afterEach(() => { - jest.restoreAllMocks(); - }); - - it('two lines', async () => { - const src = [{ text: 'A' }, { text: 'B' }]; - const inputStream = getChunkedStream( - src - .map(v => JSON.stringify(v)) - .map(v => 'data: ' + v + '\r\n\r\n') - .join(''), - ); - - const decodeStream = new ReadableStream({ - async start(controller) { - const reader = inputStream.getReader(); - const decoder = new TextDecoder('utf-8'); - while (true) { - const { done, value } = await reader.read(); - if (done) { - controller.close(); - break; - } - const decodedValue = decoder.decode(value, { stream: true }); - controller.enqueue(decodedValue); - } - }, - }); - - const responseStream = getResponseStream<{ text: string }>(decodeStream); - const reader = responseStream.getReader(); - const responses: Array<{ text: string }> = []; - while (true) { - const { done, value } = await reader.read(); - if (done) { - break; - } - responses.push(value); - } - expect(responses).toEqual(src); - }); - }); - - describe('processStream', () => { - afterEach(() => { - jest.restoreAllMocks(); - }); - - it('streaming response - short', async () => { - const fakeResponse = getMockResponseStreaming('streaming-success-basic-reply-short.txt'); - const result = processStream(fakeResponse as Response); - for await (const response of result.stream) { - expect(response.text()).not.toBe(''); - } - const aggregatedResponse = await result.response; - expect(aggregatedResponse.text()).toContain('Cheyenne'); - }); - - it('streaming response - functioncall', async () => { - const fakeResponse = getMockResponseStreaming('streaming-success-function-call-short.txt'); - const result = processStream(fakeResponse as Response); - for await (const response of result.stream) { - expect(response.text()).toBe(''); - expect(response.functionCalls()).toEqual([ - { - name: 'getTemperature', - args: { city: 'San Jose' }, - }, - ]); - } - const aggregatedResponse = await result.response; - expect(aggregatedResponse.text()).toBe(''); - expect(aggregatedResponse.functionCalls()).toEqual([ - { - name: 'getTemperature', - args: { city: 'San Jose' }, - }, - ]); - }); - - it('handles citations', async () => { - const fakeResponse = getMockResponseStreaming('streaming-success-citations.txt'); - const result = processStream(fakeResponse as Response); - const aggregatedResponse = await result.response; - expect(aggregatedResponse.text()).toContain('Quantum mechanics is'); - expect(aggregatedResponse.candidates?.[0]!.citationMetadata?.citations.length).toBe(3); - let foundCitationMetadata = false; - for await (const response of result.stream) { - expect(response.text()).not.toBe(''); - if (response.candidates?.[0]?.citationMetadata) { - foundCitationMetadata = true; - } - } - expect(foundCitationMetadata).toBe(true); - }); - - it('removes empty text parts', async () => { - const fakeResponse = getMockResponseStreaming('streaming-success-empty-text-part.txt'); - const result = processStream(fakeResponse as Response); - const aggregatedResponse = await result.response; - expect(aggregatedResponse.text()).toBe('1'); - expect(aggregatedResponse.candidates?.length).toBe(1); - expect(aggregatedResponse.candidates?.[0]?.content.parts.length).toBe(1); - - // The chunk with the empty text part will still go through the stream - let numChunks = 0; - for await (const _ of result.stream) { - numChunks++; - } - expect(numChunks).toBe(2); - }); - }); - - describe('aggregateResponses', () => { - it('handles no candidates, and promptFeedback', () => { - const responsesToAggregate: GenerateContentResponse[] = [ - { - promptFeedback: { - blockReason: BlockReason.SAFETY, - safetyRatings: [ - { - category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, - probability: HarmProbability.LOW, - } as SafetyRating, - ], - }, - }, - ]; - const response = aggregateResponses(responsesToAggregate); - expect(response.candidates).toBeUndefined(); - expect(response.promptFeedback?.blockReason).toBe(BlockReason.SAFETY); - }); - - describe('multiple responses, has candidates', () => { - let response: GenerateContentResponse; - beforeAll(() => { - const responsesToAggregate: GenerateContentResponse[] = [ - { - candidates: [ - { - index: 0, - content: { - role: 'user', - parts: [{ text: 'hello.' }], - }, - finishReason: FinishReason.STOP, - finishMessage: 'something', - safetyRatings: [ - { - category: HarmCategory.HARM_CATEGORY_HARASSMENT, - probability: HarmProbability.NEGLIGIBLE, - } as SafetyRating, - ], - }, - ], - promptFeedback: { - blockReason: BlockReason.SAFETY, - safetyRatings: [ - { - category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, - probability: HarmProbability.LOW, - } as SafetyRating, - ], - }, - }, - { - candidates: [ - { - index: 0, - content: { - role: 'user', - parts: [{ text: 'angry stuff' }], - }, - finishReason: FinishReason.STOP, - finishMessage: 'something', - safetyRatings: [ - { - category: HarmCategory.HARM_CATEGORY_HARASSMENT, - probability: HarmProbability.NEGLIGIBLE, - } as SafetyRating, - ], - citationMetadata: { - citations: [ - { - startIndex: 0, - endIndex: 20, - uri: 'sourceurl', - license: '', - }, - ], - }, - }, - ], - promptFeedback: { - blockReason: BlockReason.OTHER, - safetyRatings: [ - { - category: HarmCategory.HARM_CATEGORY_HATE_SPEECH, - probability: HarmProbability.HIGH, - } as SafetyRating, - ], - }, - }, - { - candidates: [ - { - index: 0, - content: { - role: 'user', - parts: [{ text: '...more stuff' }], - }, - finishReason: FinishReason.MAX_TOKENS, - finishMessage: 'too many tokens', - safetyRatings: [ - { - category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, - probability: HarmProbability.MEDIUM, - } as SafetyRating, - ], - citationMetadata: { - citations: [ - { - startIndex: 0, - endIndex: 20, - uri: 'sourceurl', - license: '', - }, - { - startIndex: 150, - endIndex: 155, - uri: 'sourceurl', - license: '', - }, - ], - }, - }, - ], - promptFeedback: { - blockReason: BlockReason.OTHER, - safetyRatings: [ - { - category: HarmCategory.HARM_CATEGORY_HATE_SPEECH, - probability: HarmProbability.HIGH, - } as SafetyRating, - ], - }, - }, - ]; - response = aggregateResponses(responsesToAggregate); - }); - - it('aggregates text across responses', () => { - expect(response.candidates?.length).toBe(1); - expect(response.candidates?.[0]!.content.parts.map(({ text }) => text)).toEqual([ - 'hello.', - 'angry stuff', - '...more stuff', - ]); - }); - - it("takes the last response's promptFeedback", () => { - expect(response.promptFeedback?.blockReason).toBe(BlockReason.OTHER); - }); - - it("takes the last response's finishReason", () => { - expect(response.candidates?.[0]!.finishReason).toBe(FinishReason.MAX_TOKENS); - }); - - it("takes the last response's finishMessage", () => { - expect(response.candidates?.[0]!.finishMessage).toBe('too many tokens'); - }); - - it("takes the last response's candidate safetyRatings", () => { - expect(response.candidates?.[0]!.safetyRatings?.[0]!.category).toBe( - HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, - ); - expect(response.candidates?.[0]!.safetyRatings?.[0]!.probability).toBe( - HarmProbability.MEDIUM, - ); - }); - - it('collects all citations into one array', () => { - expect(response.candidates?.[0]!.citationMetadata?.citations.length).toBe(2); - expect(response.candidates?.[0]!.citationMetadata?.citations[0]!.startIndex).toBe(0); - expect(response.candidates?.[0]!.citationMetadata?.citations[1]!.startIndex).toBe(150); - }); - - it('throws if a part has no properties', () => { - const responsesToAggregate: GenerateContentResponse[] = [ - { - candidates: [ - { - index: 0, - content: { - role: 'user', - parts: [{} as any], // Empty - }, - finishReason: FinishReason.STOP, - finishMessage: 'something', - safetyRatings: [ - { - category: HarmCategory.HARM_CATEGORY_HARASSMENT, - probability: HarmProbability.NEGLIGIBLE, - } as SafetyRating, - ], - }, - ], - promptFeedback: { - blockReason: BlockReason.SAFETY, - safetyRatings: [ - { - category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, - probability: HarmProbability.LOW, - } as SafetyRating, - ], - }, - }, - ]; - - try { - aggregateResponses(responsesToAggregate); - } catch (e) { - expect((e as VertexAIError).code).toBe(VertexAIErrorCode.INVALID_CONTENT); - expect((e as VertexAIError).message).toContain( - 'Part should have at least one property, but there are none. This is likely caused ' + - 'by a malformed response from the backend.', - ); - } - }); - }); - }); -}); diff --git a/packages/vertexai/__tests__/test-utils/convert-mocks.ts b/packages/vertexai/__tests__/test-utils/convert-mocks.ts deleted file mode 100644 index 97a5ed75df..0000000000 --- a/packages/vertexai/__tests__/test-utils/convert-mocks.ts +++ /dev/null @@ -1,67 +0,0 @@ -/** - * @license - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// eslint-disable-next-line @typescript-eslint/no-require-imports -const fs = require('fs'); -// eslint-disable-next-line @typescript-eslint/no-require-imports -const { join } = require('path'); - -function findMockResponseDir(): string { - const directories = fs - .readdirSync(__dirname, { withFileTypes: true }) - .filter( - (dirent: any) => dirent.isDirectory() && dirent.name.startsWith('vertexai-sdk-test-data'), - ) - .map((dirent: any) => dirent.name); - - if (directories.length === 0) { - throw new Error('No directory starting with "vertexai-sdk-test-data*" found.'); - } - - if (directories.length > 1) { - throw new Error('Multiple directories starting with "vertexai-sdk-test-data*" found'); - } - - return join(__dirname, directories[0], 'mock-responses', 'vertexai'); -} - -async function main(): Promise { - const mockResponseDir = findMockResponseDir(); - const list = fs.readdirSync(mockResponseDir); - const lookup: Record = {}; - // eslint-disable-next-line guard-for-in - for (const fileName of list) { - console.log(`attempting to read ${mockResponseDir}/${fileName}`) - const fullText = fs.readFileSync(join(mockResponseDir, fileName), 'utf-8'); - lookup[fileName] = fullText; - } - let fileText = `// Generated from mocks text files.`; - - fileText += '\n\n'; - fileText += `export const mocksLookup: Record = ${JSON.stringify( - lookup, - null, - 2, - )}`; - fileText += ';\n'; - fs.writeFileSync(join(__dirname, 'mocks-lookup.ts'), fileText, 'utf-8'); -} - -main().catch(e => { - console.error(e); - process.exit(1); -}); diff --git a/packages/vertexai/__tests__/test-utils/mock-response.ts b/packages/vertexai/__tests__/test-utils/mock-response.ts deleted file mode 100644 index 52eb0eb04e..0000000000 --- a/packages/vertexai/__tests__/test-utils/mock-response.ts +++ /dev/null @@ -1,69 +0,0 @@ -/** - * @license - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -import { ReadableStream } from 'web-streams-polyfill'; -import { mocksLookup } from './mocks-lookup'; - -/** - * Mock native Response.body - * Streams contents of json file in 20 character chunks - */ -export function getChunkedStream(input: string, chunkLength = 20): ReadableStream { - const encoder = new TextEncoder(); - let currentChunkStart = 0; - - const stream = new ReadableStream({ - start(controller) { - while (currentChunkStart < input.length) { - const substring = input.slice(currentChunkStart, currentChunkStart + chunkLength); - currentChunkStart += chunkLength; - const chunk = encoder.encode(substring); - controller.enqueue(chunk); - } - controller.close(); - }, - }); - - return stream; -} -export function getMockResponseStreaming( - filename: string, - chunkLength: number = 20, -): Partial { - const fullText = mocksLookup[filename]; - - return { - - // Really tangled typescript error here from our transitive dependencies. - // Ignoring it now, but uncomment and run `yarn lerna:prepare` in top-level - // of the repo to see if you get it or if it has gone away. - // - // last stack frame of the error is from node_modules/undici-types/fetch.d.ts - // - // > Property 'value' is optional in type 'ReadableStreamReadDoneResult' but required in type '{ done: true; value: T | undefined; }'. - // - // @ts-ignore - body: getChunkedStream(fullText!, chunkLength), - }; -} - -export function getMockResponse(filename: string): Partial { - const fullText = mocksLookup[filename]; - return { - ok: true, - json: () => Promise.resolve(JSON.parse(fullText!)), - }; -} From 96de44a1b9b2985a985a726ef61fe7c5e8cc20bd Mon Sep 17 00:00:00 2001 From: russellwheatley Date: Thu, 5 Jun 2025 14:33:08 +0100 Subject: [PATCH 43/85] chore: update logger to ai --- packages/ai/lib/logger.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/ai/lib/logger.ts b/packages/ai/lib/logger.ts index dbc3e84059..55c6e4658a 100644 --- a/packages/ai/lib/logger.ts +++ b/packages/ai/lib/logger.ts @@ -17,4 +17,4 @@ // @ts-ignore import { Logger } from '@react-native-firebase/app/lib/internal/logger'; -export const logger = new Logger('@firebase/vertexai'); +export const logger = new Logger('@firebase/ai'); From 1ef56fb3da4f43ec7114138721f667cd6f85bc29 Mon Sep 17 00:00:00 2001 From: russellwheatley Date: Thu, 5 Jun 2025 14:40:34 +0100 Subject: [PATCH 44/85] chore(vertexai): rm obsolete files --- packages/vertexai/lib/constants.ts | 33 ---------------------------- packages/vertexai/lib/polyfills.ts | 35 ------------------------------ 2 files changed, 68 deletions(-) delete mode 100644 packages/vertexai/lib/constants.ts delete mode 100644 packages/vertexai/lib/polyfills.ts diff --git a/packages/vertexai/lib/constants.ts b/packages/vertexai/lib/constants.ts deleted file mode 100644 index 338ed8b80e..0000000000 --- a/packages/vertexai/lib/constants.ts +++ /dev/null @@ -1,33 +0,0 @@ -/** - * @license - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import { version } from './version'; - -export const DEFAULT_LOCATION = 'us-central1'; - -export const DEFAULT_BASE_URL = 'https://firebasevertexai.googleapis.com'; - -// This is the default API version for the VertexAI API. At some point, should be able to change when the feature becomes available. -// `v1beta` & `stable` available: https://cloud.google.com/vertex-ai/docs/reference#versions -export const DEFAULT_API_VERSION = 'v1beta'; - -export const PACKAGE_VERSION = version; - -export const LANGUAGE_TAG = 'gl-rn'; - -// Timeout is 180s by default -export const DEFAULT_FETCH_TIMEOUT_MS = 180 * 1000; diff --git a/packages/vertexai/lib/polyfills.ts b/packages/vertexai/lib/polyfills.ts deleted file mode 100644 index cbe2cfecb0..0000000000 --- a/packages/vertexai/lib/polyfills.ts +++ /dev/null @@ -1,35 +0,0 @@ -/** - * @license - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -// @ts-ignore -import { polyfillGlobal } from 'react-native/Libraries/Utilities/PolyfillFunctions'; -// @ts-ignore -import { ReadableStream } from 'web-streams-polyfill/dist/ponyfill'; -// @ts-ignore -import { fetch, Headers, Request, Response } from 'react-native-fetch-api'; - -polyfillGlobal( - 'fetch', - () => - (...args: any[]) => - fetch(args[0], { ...args[1], reactNative: { textStreaming: true } }), -); -polyfillGlobal('Headers', () => Headers); -polyfillGlobal('Request', () => Request); -polyfillGlobal('Response', () => Response); -polyfillGlobal('ReadableStream', () => ReadableStream); - -import 'text-encoding'; From a285a7c0875d4760e73055f689c6993a38ef98a2 Mon Sep 17 00:00:00 2001 From: russellwheatley Date: Thu, 5 Jun 2025 14:40:54 +0100 Subject: [PATCH 45/85] chore(vertexai): revert back to prev. tsconfig --- packages/vertexai/tsconfig.json | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/packages/vertexai/tsconfig.json b/packages/vertexai/tsconfig.json index 0eace61ca4..f371c6edc9 100644 --- a/packages/vertexai/tsconfig.json +++ b/packages/vertexai/tsconfig.json @@ -1,5 +1,6 @@ { "compilerOptions": { + "rootDir": ".", "allowUnreachableCode": false, "allowUnusedLabels": false, "esModuleInterop": true, @@ -25,9 +26,7 @@ "paths": { "@react-native-firebase/app": ["../app/lib"], "@react-native-firebase/auth": ["../auth/lib"], - "@react-native-firebase/app-check": ["../app-check/lib"], - "@react-native-firebase/ai": ["../ai/lib"], + "@react-native-firebase/app-check": ["../app-check/lib"] } - }, - "include": ["lib/**/*", "../ai/lib/**/*"] + } } From 70fb2fed82f3389b8645d39929d82199e898915f Mon Sep 17 00:00:00 2001 From: russellwheatley Date: Thu, 5 Jun 2025 14:43:46 +0100 Subject: [PATCH 46/85] chore(vertexai): symlink ai lib/ folder --- packages/vertexai/lib/ai-symlink | 1 + 1 file changed, 1 insertion(+) create mode 120000 packages/vertexai/lib/ai-symlink diff --git a/packages/vertexai/lib/ai-symlink b/packages/vertexai/lib/ai-symlink new file mode 120000 index 0000000000..e5413df506 --- /dev/null +++ b/packages/vertexai/lib/ai-symlink @@ -0,0 +1 @@ +../../ai/lib \ No newline at end of file From e520bf163ae2c59729e218c5d3f9266e114719ce Mon Sep 17 00:00:00 2001 From: russellwheatley Date: Thu, 5 Jun 2025 14:44:06 +0100 Subject: [PATCH 47/85] chore(vertexai): wrap around ai --- packages/vertexai/lib/index.ts | 6 ++---- packages/vertexai/lib/service.ts | 4 ++-- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/packages/vertexai/lib/index.ts b/packages/vertexai/lib/index.ts index 6c01fe00ba..dd154b133f 100644 --- a/packages/vertexai/lib/index.ts +++ b/packages/vertexai/lib/index.ts @@ -28,12 +28,10 @@ import { VertexAIError, VertexAI, VertexAIOptions, -} from '@react-native-firebase/ai'; -import { DEFAULT_LOCATION } from './constants'; +} from './ai-symlink/index'; +import { DEFAULT_LOCATION } from './ai-symlink/constants'; import { VertexAIService } from './service'; -export * from '@react-native-firebase/ai'; - /** * Returns a {@link VertexAI} instance for the given app. * diff --git a/packages/vertexai/lib/service.ts b/packages/vertexai/lib/service.ts index 54599db3fa..ba4cf8e8fa 100644 --- a/packages/vertexai/lib/service.ts +++ b/packages/vertexai/lib/service.ts @@ -16,8 +16,8 @@ */ import { ReactNativeFirebase } from '@react-native-firebase/app'; -import { VertexAI, VertexAIOptions } from '@react-native-firebase/ai'; -import { DEFAULT_LOCATION } from './constants'; +import { VertexAI, VertexAIOptions } from './ai-symlink/index'; +import { DEFAULT_LOCATION } from './ai-symlink/constants'; import { FirebaseAuthTypes } from '@react-native-firebase/auth'; import { FirebaseAppCheckTypes } from '@react-native-firebase/app-check'; From f4161fc8b615cdad94bdd3a6fab8beacef8b7ffb Mon Sep 17 00:00:00 2001 From: russellwheatley Date: Thu, 5 Jun 2025 14:49:58 +0100 Subject: [PATCH 48/85] chore(vertexai): remove logger --- packages/vertexai/lib/logger.ts | 20 -------------------- 1 file changed, 20 deletions(-) delete mode 100644 packages/vertexai/lib/logger.ts diff --git a/packages/vertexai/lib/logger.ts b/packages/vertexai/lib/logger.ts deleted file mode 100644 index dbc3e84059..0000000000 --- a/packages/vertexai/lib/logger.ts +++ /dev/null @@ -1,20 +0,0 @@ -/** - * @license - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -// @ts-ignore -import { Logger } from '@react-native-firebase/app/lib/internal/logger'; - -export const logger = new Logger('@firebase/vertexai'); From 5a8382fd53eed852b70f12c9602931974de24136 Mon Sep 17 00:00:00 2001 From: russellwheatley Date: Thu, 5 Jun 2025 15:03:31 +0100 Subject: [PATCH 49/85] test(vertexai): rm e2e tests --- packages/vertexai/e2e/fetch.e2e.js | 67 ------------------------------ 1 file changed, 67 deletions(-) delete mode 100644 packages/vertexai/e2e/fetch.e2e.js diff --git a/packages/vertexai/e2e/fetch.e2e.js b/packages/vertexai/e2e/fetch.e2e.js deleted file mode 100644 index de8832ac56..0000000000 --- a/packages/vertexai/e2e/fetch.e2e.js +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Copyright (c) 2016-present Invertase Limited & Contributors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this library except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ -import { getGenerativeModel } from '../lib/index'; - -const fakeVertexAI = { - app: { - name: 'DEFAULT', - options: { - appId: 'appId', - projectId: 'my-project', - apiKey: 'key', - }, - }, - location: 'us-central1', -}; -// See emulator setup: packages/vertexai/lib/requests/request.ts -globalThis.RNFB_VERTEXAI_EMULATOR_URL = true; - -// It calls firebase functions emulator that mimics responses from VertexAI server -describe('fetch requests()', function () { - it('should fetch', async function () { - const model = getGenerativeModel(fakeVertexAI, { model: 'gemini-1.5-flash' }); - const result = await model.generateContent("What is google's mission statement?"); - const text = result.response.text(); - // See vertexAI function emulator for response - text.should.containEql( - 'Google\'s mission is to "organize the world\'s information and make it universally accessible and useful."', - ); - }); - - it('should fetch stream', async function () { - const model = getGenerativeModel(fakeVertexAI, { model: 'gemini-1.5-flash' }); - // See vertexAI function emulator for response - const poem = [ - 'The wind whispers secrets through the trees,', - 'Rustling leaves in a gentle breeze.', - 'Sunlight dances on the grass,', - 'A fleeting moment, sure to pass.', - 'Birdsong fills the air so bright,', - 'A symphony of pure delight.', - 'Time stands still, a peaceful pause,', - "In nature's beauty, no flaws.", - ]; - const result = await model.generateContentStream('Write me a short poem'); - - const text = []; - for await (const chunk of result.stream) { - const chunkText = chunk.text(); - text.push(chunkText); - } - text.should.deepEqual(poem); - }); -}); From c307142d2542f16fe3a4165b257af434676693da Mon Sep 17 00:00:00 2001 From: russellwheatley Date: Fri, 6 Jun 2025 09:55:48 +0100 Subject: [PATCH 50/85] refactor: make ai dependency on vertexai --- packages/ai/lib/constants.ts | 2 - packages/ai/lib/errors.ts | 4 +- packages/ai/lib/index.ts | 12 +-- packages/ai/lib/public-types.ts | 50 +---------- packages/ai/lib/types/error.ts | 74 ----------------- .../__tests__/backwards-compatbility.test.ts | 83 +++++++++++++++++++ packages/vertexai/lib/ai-symlink | 1 - packages/vertexai/lib/index.ts | 68 +++++++-------- packages/vertexai/lib/public-types.ts | 42 ++++++++++ packages/vertexai/lib/service.ts | 39 --------- packages/vertexai/package.json | 13 +-- yarn.lock | 3 +- 12 files changed, 178 insertions(+), 213 deletions(-) create mode 100644 packages/vertexai/__tests__/backwards-compatbility.test.ts delete mode 120000 packages/vertexai/lib/ai-symlink create mode 100644 packages/vertexai/lib/public-types.ts delete mode 100644 packages/vertexai/lib/service.ts diff --git a/packages/ai/lib/constants.ts b/packages/ai/lib/constants.ts index 8af57ec0f6..a0cffa49ad 100644 --- a/packages/ai/lib/constants.ts +++ b/packages/ai/lib/constants.ts @@ -19,8 +19,6 @@ import { version } from './version'; export const AI_TYPE = 'AI'; -export const VERTEX_TYPE = 'vertexAI'; - export const DEFAULT_LOCATION = 'us-central1'; export const DEFAULT_BASE_URL = 'https://firebasevertexai.googleapis.com'; diff --git a/packages/ai/lib/errors.ts b/packages/ai/lib/errors.ts index ea09d2f162..3a7e18ec3a 100644 --- a/packages/ai/lib/errors.ts +++ b/packages/ai/lib/errors.ts @@ -26,9 +26,9 @@ import { AI_TYPE } from './constants'; */ export class AIError extends FirebaseError { /** - * Constructs a new instance of the `VertexAIError` class. + * Constructs a new instance of the `AIError` class. * - * @param code - The error code from {@link VertexAIErrorCode}. + * @param code - The error code from {@link AIErrorCode}. * @param message - A human-readable message describing the error. * @param customErrorData - Optional error data. */ diff --git a/packages/ai/lib/index.ts b/packages/ai/lib/index.ts index 04a318a3c9..e5e5078a0f 100644 --- a/packages/ai/lib/index.ts +++ b/packages/ai/lib/index.ts @@ -19,16 +19,16 @@ import './polyfills'; import { getApp, ReactNativeFirebase } from '@react-native-firebase/app'; import { GoogleAIBackend, VertexAIBackend } from './backend'; import { AIErrorCode, ModelParams, RequestOptions } from './types'; -import { AI, AIOptions, VertexAI, VertexAIOptions } from './public-types'; +import { AI, AIOptions } from './public-types'; import { AIError } from './errors'; import { GenerativeModel } from './models/generative-model'; -export { ChatSession } from './methods/chat-session'; +import { AIModel } from './models/ai-model'; +export * from './public-types'; +export { ChatSession } from './methods/chat-session'; export * from './requests/schema-builder'; -export * from './types'; -export * from './backend'; - -export { GenerativeModel, AIError, VertexAI, VertexAIOptions }; +export { GoogleAIBackend, VertexAIBackend } from './backend'; +export { GenerativeModel, AIError, AIModel }; /** * Returns the default {@link AI} instance that is associated with the provided diff --git a/packages/ai/lib/public-types.ts b/packages/ai/lib/public-types.ts index bd58c018c0..918eea1b68 100644 --- a/packages/ai/lib/public-types.ts +++ b/packages/ai/lib/public-types.ts @@ -21,30 +21,6 @@ import { FirebaseAppCheckTypes } from '@react-native-firebase/app-check'; export * from './types'; -/** - * An instance of the Vertex AI in Firebase SDK. - * @public - */ -export interface VertexAI { - /** - * The {@link @firebase/app#FirebaseApp} this {@link VertexAI} instance is associated with. - */ - app: ReactNativeFirebase.FirebaseApp; - location: string; - appCheck?: FirebaseAppCheckTypes.Module | null; - auth?: FirebaseAuthTypes.Module | null; -} - -/** - * Options when initializing the Vertex AI in Firebase SDK. - * @public - */ -export interface VertexAIOptions { - location?: string; - appCheck?: FirebaseAppCheckTypes.Module | null; - auth?: FirebaseAuthTypes.Module | null; -} - /** * Options for initializing the AI service using {@link getAI | getAI()}. * This allows specifying which backend to use (Vertex AI Gemini API or Gemini Developer API) @@ -142,6 +118,8 @@ export interface AI { * The {@link @firebase/app#FirebaseApp} this {@link AI} instance is associated with. */ app: ReactNativeFirebase.FirebaseApp; + appCheck?: FirebaseAppCheckTypes.Module | null; + auth?: FirebaseAuthTypes.Module | null; /** * A {@link Backend} instance that specifies the configuration for the target backend, * either the Gemini Developer API (using {@link GoogleAIBackend}) or the @@ -155,27 +133,3 @@ export interface AI { */ location: string; } - -/** - * An instance of the Vertex AI in Firebase SDK. - * @public - */ -export interface VertexAI { - /** - * The {@link @firebase/app#FirebaseApp} this {@link VertexAI} instance is associated with. - */ - app: ReactNativeFirebase.FirebaseApp; - location: string; - appCheck?: FirebaseAppCheckTypes.Module | null; - auth?: FirebaseAuthTypes.Module | null; -} - -/** - * Options when initializing the Vertex AI in Firebase SDK. - * @public - */ -export interface VertexAIOptions { - location?: string; - appCheck?: FirebaseAppCheckTypes.Module | null; - auth?: FirebaseAuthTypes.Module | null; -} diff --git a/packages/ai/lib/types/error.ts b/packages/ai/lib/types/error.ts index f7579ad539..4fcc1ac483 100644 --- a/packages/ai/lib/types/error.ts +++ b/packages/ai/lib/types/error.ts @@ -15,9 +15,7 @@ * limitations under the License. */ -import { FirebaseError } from '@firebase/util'; import { GenerateContentResponse } from './responses'; -import { VERTEX_TYPE } from '../constants'; /** * Details object that may be included in an error response. @@ -104,75 +102,3 @@ export const enum AIErrorCode { /** An error occurred due an attempt to use an unsupported feature. */ UNSUPPORTED = 'unsupported', } - -/** - * Standardized error codes that {@link VertexAIError} can have. - * - * @public - */ -export const enum VertexAIErrorCode { - /** A generic error occurred. */ - ERROR = 'error', - - /** An error occurred in a request. */ - REQUEST_ERROR = 'request-error', - - /** An error occurred in a response. */ - RESPONSE_ERROR = 'response-error', - - /** An error occurred while performing a fetch. */ - FETCH_ERROR = 'fetch-error', - - /** An error associated with a Content object. */ - INVALID_CONTENT = 'invalid-content', - - /** An error due to the Firebase API not being enabled in the Console. */ - API_NOT_ENABLED = 'api-not-enabled', - - /** An error due to invalid Schema input. */ - INVALID_SCHEMA = 'invalid-schema', - - /** An error occurred due to a missing Firebase API key. */ - NO_API_KEY = 'no-api-key', - - /** An error occurred due to a model name not being specified during initialization. */ - NO_MODEL = 'no-model', - - /** An error occurred due to a missing project ID. */ - NO_PROJECT_ID = 'no-project-id', - - /** An error occurred while parsing. */ - PARSE_FAILED = 'parse-failed', -} - -/** - * Error class for the Vertex AI in Firebase SDK. - * - * @public - */ -export class VertexAIError extends FirebaseError { - /** - * Constructs a new instance of the `VertexAIError` class. - * - * @param code - The error code from {@link VertexAIErrorCode}. - * @param message - A human-readable message describing the error. - * @param customErrorData - Optional error data. - */ - constructor( - readonly code: VertexAIErrorCode, - message: string, - readonly customErrorData?: CustomErrorData, - ) { - // Match error format used by FirebaseError from ErrorFactory - const service = VERTEX_TYPE; - const serviceName = 'VertexAI'; - const fullCode = `${service}/${code}`; - const fullMessage = `${serviceName}: ${message} (${fullCode})`; - super(code, fullMessage); - - Object.setPrototypeOf(this, VertexAIError.prototype); - - // Since Error is an interface, we don't inherit toString and so we define it ourselves. - this.toString = () => fullMessage; - } -} diff --git a/packages/vertexai/__tests__/backwards-compatbility.test.ts b/packages/vertexai/__tests__/backwards-compatbility.test.ts new file mode 100644 index 0000000000..81681bbbc4 --- /dev/null +++ b/packages/vertexai/__tests__/backwards-compatbility.test.ts @@ -0,0 +1,83 @@ +/** + * @license + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import { describe, expect, it } from '@jest/globals'; +import { + AIError, + AIModel, + GenerativeModel, + VertexAIError, + VertexAIErrorCode, + VertexAIModel, + VertexAI, + getGenerativeModel, +} from '../lib/index'; +import { AI, AIErrorCode } from '@react-native-firebase/ai'; +import { VertexAIBackend } from '@react-native-firebase/ai'; +import { ReactNativeFirebase } from '@react-native-firebase/app'; + +function assertAssignable(): void {} + +const fakeAI: AI = { + app: { + name: 'DEFAULT', + automaticDataCollectionEnabled: true, + options: { + apiKey: 'key', + projectId: 'my-project', + appId: 'app-id', + }, + } as ReactNativeFirebase.FirebaseApp, + backend: new VertexAIBackend('us-central1'), + location: 'us-central1', +}; + +const fakeVertexAI: VertexAI = fakeAI; + +describe('backwards-compatible types', function () { + it('AI is backwards compatible with VertexAI', function () { + assertAssignable(); + }); + + it('AIError is backwards compatible with VertexAIError', function () { + assertAssignable(); + const err = new VertexAIError(VertexAIErrorCode.ERROR, ''); + expect(err).toBeInstanceOf(AIError); + expect(err).toBeInstanceOf(VertexAIError); + }); + + it('AIErrorCode is backwards compatible with VertexAIErrorCode', () => { + assertAssignable(); + const errCode = AIErrorCode.ERROR; + expect(errCode).toBe(VertexAIErrorCode.ERROR); + }); + + it('AIModel is backwards compatible with VertexAIModel', () => { + assertAssignable(); + + const model = new GenerativeModel(fakeAI, { model: 'model-name' }); + expect(model).toBeInstanceOf(AIModel); + expect(model).toBeInstanceOf(VertexAIModel); + }); + + describe('backward-compatible functions', () => { + it('getGenerativeModel', () => { + const model = getGenerativeModel(fakeVertexAI, { model: 'model-name' }); + expect(model).toBeInstanceOf(AIModel); + expect(model).toBeInstanceOf(VertexAIModel); + }); + }); +}); diff --git a/packages/vertexai/lib/ai-symlink b/packages/vertexai/lib/ai-symlink deleted file mode 120000 index e5413df506..0000000000 --- a/packages/vertexai/lib/ai-symlink +++ /dev/null @@ -1 +0,0 @@ -../../ai/lib \ No newline at end of file diff --git a/packages/vertexai/lib/index.ts b/packages/vertexai/lib/index.ts index dd154b133f..7efc7a2656 100644 --- a/packages/vertexai/lib/index.ts +++ b/packages/vertexai/lib/index.ts @@ -15,24 +15,23 @@ * limitations under the License. */ -import './polyfills'; import { getApp, ReactNativeFirebase } from '@react-native-firebase/app'; -import { - getGenerativeModel as getGenerativeModelFromAI, - getAI, - VertexAIBackend, - GenerativeModel, - RequestOptions, - ModelParams, - VertexAIErrorCode, - VertexAIError, - VertexAI, - VertexAIOptions, -} from './ai-symlink/index'; -import { DEFAULT_LOCATION } from './ai-symlink/constants'; -import { VertexAIService } from './service'; +import { VertexAIBackend, AIModel, AIError, AIErrorCode } from '@react-native-firebase/ai'; +import { VertexAIOptions, VertexAI } from './public-types'; +export * from './public-types'; +export * from '@react-native-firebase/ai'; + +const DEFAULT_LOCATION = 'us-central1'; /** + * @deprecated Use the new {@link getAI | getAI()} instead. The Vertex AI in Firebase SDK has been + * replaced with the Firebase AI SDK to accommodate the evolving set of supported features and + * services. For migration details, see the {@link https://firebase.google.com/docs/vertex-ai/migrate-to-latest-sdk | migration guide}. + * + * Returns a {@link VertexAI} instance for the given app, configured to use the + * Vertex AI Gemini API. This instance will be + * configured to use the Vertex AI Gemini API. + * * Returns a {@link VertexAI} instance for the given app. * * @public @@ -51,30 +50,31 @@ export function getVertexAI( location: options?.location || DEFAULT_LOCATION, appCheck: options?.appCheck || null, auth: options?.auth || null, - } as VertexAIService; + backend: new VertexAIBackend(options?.location || DEFAULT_LOCATION), + }; } /** - * Returns a {@link GenerativeModel} class with methods for inference - * and other functionality. + * @deprecated Use the new {@link AIModel} instead. The Vertex AI in Firebase SDK has been + * replaced with the Firebase AI SDK to accommodate the evolving set of supported features and + * services. For migration details, see the {@link https://firebase.google.com/docs/vertex-ai/migrate-to-latest-sdk | migration guide}. + * + * Base class for Firebase AI model APIs. * * @public */ -export function getGenerativeModel( - vertexAI: VertexAI, - modelParams: ModelParams, - requestOptions?: RequestOptions, -): GenerativeModel { - if (!modelParams.model) { - throw new VertexAIError( - VertexAIErrorCode.NO_MODEL, - `Must provide a model name. Example: getGenerativeModel({ model: 'my-model-name' })`, - ); - } +export const VertexAIModel = AIModel; - const ai = getAI(vertexAI.app, { - backend: new VertexAIBackend(vertexAI.location), - }); +/** + * @deprecated Use the new {@link AIError} instead. The Vertex AI in Firebase SDK has been + * replaced with the Firebase AI SDK to accommodate the evolving set of supported features and + * services. For migration details, see the {@link https://firebase.google.com/docs/vertex-ai/migrate-to-latest-sdk | migration guide}. + * + * Error class for the Firebase AI SDK. + * + * @public + */ +export const VertexAIError = AIError; - return getGenerativeModelFromAI(ai, modelParams, requestOptions); -} +export { AIErrorCode as VertexAIErrorCode }; +export { VertexAIBackend, AIModel, AIError }; diff --git a/packages/vertexai/lib/public-types.ts b/packages/vertexai/lib/public-types.ts new file mode 100644 index 0000000000..1ce05549ae --- /dev/null +++ b/packages/vertexai/lib/public-types.ts @@ -0,0 +1,42 @@ +/** + * @license + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import { AI, AIErrorCode } from '@react-native-firebase/ai'; +import { FirebaseAuthTypes } from '@react-native-firebase/auth'; +import { FirebaseAppCheckTypes } from '@react-native-firebase/app-check'; + +/** + * @deprecated Use the new {@link AI | AI} instead. The Vertex AI in Firebase SDK has been + * replaced with the Firebase AI SDK to accommodate the evolving set of supported features and + * services. For migration details, see the {@link https://firebase.google.com/docs/vertex-ai/migrate-to-latest-sdk | migration guide}. + * An instance of the Vertex AI in Firebase SDK. + * @public + */ +export type VertexAI = AI; + +/** + * Options when initializing the Vertex AI in Firebase SDK. + * @public + */ +export interface VertexAIOptions { + location?: string; + appCheck?: FirebaseAppCheckTypes.Module | null; + auth?: FirebaseAuthTypes.Module | null; +} + +export type VertexAIErrorCode = AIErrorCode; + +export const VERTEX_TYPE = 'vertexAI'; diff --git a/packages/vertexai/lib/service.ts b/packages/vertexai/lib/service.ts deleted file mode 100644 index ba4cf8e8fa..0000000000 --- a/packages/vertexai/lib/service.ts +++ /dev/null @@ -1,39 +0,0 @@ -/** - * @license - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import { ReactNativeFirebase } from '@react-native-firebase/app'; -import { VertexAI, VertexAIOptions } from './ai-symlink/index'; -import { DEFAULT_LOCATION } from './ai-symlink/constants'; -import { FirebaseAuthTypes } from '@react-native-firebase/auth'; -import { FirebaseAppCheckTypes } from '@react-native-firebase/app-check'; - -export class VertexAIService implements VertexAI { - auth: FirebaseAuthTypes.Module | null; - appCheck: FirebaseAppCheckTypes.Module | null; - location: string; - - constructor( - public app: ReactNativeFirebase.FirebaseApp, - auth?: FirebaseAuthTypes.Module, - appCheck?: FirebaseAppCheckTypes.Module, - public options?: VertexAIOptions, - ) { - this.auth = auth || null; - this.appCheck = appCheck || null; - this.location = this.options?.location || DEFAULT_LOCATION; - } -} diff --git a/packages/vertexai/package.json b/packages/vertexai/package.json index 20bd946caa..22cee02600 100644 --- a/packages/vertexai/package.json +++ b/packages/vertexai/package.json @@ -25,6 +25,12 @@ "gemini", "generative-ai" ], + "dependencies": { + "@react-native-firebase/ai": "22.2.0", + "react-native-fetch-api": "^3.0.0", + "text-encoding": "^0.7.0", + "web-streams-polyfill": "^4.1.0" + }, "peerDependencies": { "@react-native-firebase/app": "23.0.0" }, @@ -79,10 +85,5 @@ "eslintIgnore": [ "node_modules/", "dist/" - ], - "dependencies": { - "react-native-fetch-api": "^3.0.0", - "text-encoding": "^0.7.0", - "web-streams-polyfill": "^4.1.0" - } + ] } diff --git a/yarn.lock b/yarn.lock index 4f0de566e6..e09222d366 100644 --- a/yarn.lock +++ b/yarn.lock @@ -5348,7 +5348,7 @@ __metadata: languageName: node linkType: hard -"@react-native-firebase/ai@workspace:packages/ai": +"@react-native-firebase/ai@npm:22.2.0, @react-native-firebase/ai@workspace:packages/ai": version: 0.0.0-use.local resolution: "@react-native-firebase/ai@workspace:packages/ai" dependencies: @@ -5565,6 +5565,7 @@ __metadata: version: 0.0.0-use.local resolution: "@react-native-firebase/vertexai@workspace:packages/vertexai" dependencies: + "@react-native-firebase/ai": "npm:22.2.0" "@types/text-encoding": "npm:^0.0.40" react-native-builder-bob: "npm:^0.40.6" react-native-fetch-api: "npm:^3.0.0" From 53ba822a18328e004bb056246a48ca9a9f095616 Mon Sep 17 00:00:00 2001 From: russellwheatley Date: Fri, 6 Jun 2025 10:13:12 +0100 Subject: [PATCH 51/85] test: fix test --- packages/ai/__tests__/request.test.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/ai/__tests__/request.test.ts b/packages/ai/__tests__/request.test.ts index 54b3378492..3e5e58e415 100644 --- a/packages/ai/__tests__/request.test.ts +++ b/packages/ai/__tests__/request.test.ts @@ -196,7 +196,7 @@ describe('request methods', () => { // See: https://github.com/firebase/firebase-js-sdk/blob/main/packages/vertexai/src/requests/request.test.ts#L172 // expect(headers.get('X-Firebase-AppCheck')).toBe('dummytoken'); expect(warnSpy).toHaveBeenCalledWith( - expect.stringMatching(/vertexai/), + expect.stringMatching(/firebase\/ai/), expect.stringMatching(/App Check.*oops/), ); }); From ac4a1ab5c601f875355fd1feddc5d95cf3763ff6 Mon Sep 17 00:00:00 2001 From: russellwheatley Date: Fri, 6 Jun 2025 10:59:33 +0100 Subject: [PATCH 52/85] fix: do not use RN URL to construct url --- packages/ai/lib/requests/request.ts | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/packages/ai/lib/requests/request.ts b/packages/ai/lib/requests/request.ts index 77a1dd8691..c468de28bb 100644 --- a/packages/ai/lib/requests/request.ts +++ b/packages/ai/lib/requests/request.ts @@ -60,10 +60,18 @@ export class RequestUrl { } return emulatorUrl; } - const url = new URL(this.baseUrl); // Throws if the URL is invalid - url.pathname = `/${this.apiVersion}/${this.modelPath}:${this.task}`; - url.search = this.queryParams.toString(); - return url.toString(); + + // Manually construct URL to avoid React Native URL API issues + let baseUrl = this.baseUrl; + // Remove trailing slash if present + if (baseUrl.endsWith('/')) { + baseUrl = baseUrl.slice(0, -1); + } + + const pathname = `/${this.apiVersion}/${this.modelPath}:${this.task}`; + const queryString = this.queryParams; + + return `${baseUrl}${pathname}${queryString ? `?${queryString}` : ''}`; } private get baseUrl(): string { @@ -87,10 +95,10 @@ export class RequestUrl { } } - private get queryParams(): URLSearchParams { - const params = new URLSearchParams(); + private get queryParams(): string { + let params = ''; if (this.stream) { - params.set('alt', 'sse'); + params += 'alt=sse'; } return params; From 202bbcb2f1434a12e38def4600f0413d978459ca Mon Sep 17 00:00:00 2001 From: russellwheatley Date: Fri, 6 Jun 2025 11:06:33 +0100 Subject: [PATCH 53/85] chore(ai): write example app for ai package --- tests/test-app/examples/ai/ai.js | 328 +++++++++++++++++++++++++++++++ 1 file changed, 328 insertions(+) create mode 100644 tests/test-app/examples/ai/ai.js diff --git a/tests/test-app/examples/ai/ai.js b/tests/test-app/examples/ai/ai.js new file mode 100644 index 0000000000..6a9d9a428e --- /dev/null +++ b/tests/test-app/examples/ai/ai.js @@ -0,0 +1,328 @@ +import React, { useState } from 'react'; +import { AppRegistry, Button, View, Text, Pressable } from 'react-native'; + +import { getApp } from '@react-native-firebase/app'; +import { getAI, getGenerativeModel, Schema } from '@react-native-firebase/ai'; +import { + PDF_BASE_64, + POEM_BASE_64, + VIDEO_BASE_64, + IMAGE_BASE_64, + EMOJI_BASE_64, +} from '../vertexai/base-64-media'; + +// eslint-disable-next-line react/prop-types +function OptionSelector({ selectedOption, setSelectedOption }) { + const options = ['image', 'pdf', 'video', 'audio', 'emoji']; + + return ( + + {options.map(option => { + const isSelected = selectedOption === option; + return ( + setSelectedOption(option)} + style={{ + paddingVertical: 10, + paddingHorizontal: 15, + margin: 5, + borderRadius: 8, + borderWidth: 1, + borderColor: isSelected ? '#007bff' : '#ccc', + backgroundColor: isSelected ? '#007bff' : '#fff', + }} + > + + {option.toUpperCase()} + + + ); + })} + + ); +} + +function App() { + const [selectedOption, setSelectedOption] = useState('image'); + const getMediaDetails = option => { + switch (option) { + case 'image': + return { data: IMAGE_BASE_64.trim(), mimeType: 'image/jpeg', prompt: 'What can you see?' }; + case 'pdf': + return { + data: PDF_BASE_64.trim(), + mimeType: 'application/pdf', + prompt: 'What can you see?', + }; + case 'video': + return { data: VIDEO_BASE_64.trim(), mimeType: 'video/mp4', prompt: 'What can you see?' }; + case 'audio': + return { data: POEM_BASE_64.trim(), mimeType: 'audio/mp3', prompt: 'What can you hear?' }; + case 'emoji': + return { data: EMOJI_BASE_64.trim(), mimeType: 'image/png', prompt: 'What can you see?' }; + default: + console.error('Invalid option selected'); + return null; + } + }; + + return ( + + +