diff --git a/js/plugins/google-genai/src/common/types.ts b/js/plugins/google-genai/src/common/types.ts index 454948fedd..d8a023e0f6 100644 --- a/js/plugins/google-genai/src/common/types.ts +++ b/js/plugins/google-genai/src/common/types.ts @@ -140,6 +140,14 @@ export declare interface GoogleSearchRetrievalTool { googleSearchRetrieval?: GoogleSearchRetrieval; googleSearch?: GoogleSearchRetrieval; } +export function isGoogleSearchRetrievalTool( + tool: Tool +): tool is GoogleSearchRetrievalTool { + return ( + (tool as GoogleSearchRetrievalTool).googleSearchRetrieval !== undefined || + (tool as GoogleSearchRetrievalTool).googleSearch !== undefined + ); +} /** * Grounding support. @@ -737,6 +745,11 @@ export declare interface FunctionDeclarationsTool { */ functionDeclarations?: FunctionDeclaration[]; } +export function isFunctionDeclarationsTool( + tool: Tool +): tool is FunctionDeclarationsTool { + return (tool as FunctionDeclarationsTool).functionDeclarations !== undefined; +} /** * Google AI Only. Enables the model to execute code as part of generation. @@ -749,6 +762,9 @@ export declare interface CodeExecutionTool { */ codeExecution: {}; } +export function isCodeExecutionTool(tool: Tool): tool is CodeExecutionTool { + return (tool as CodeExecutionTool).codeExecution !== undefined; +} /** * Vertex AI Only. Retrieve from Vertex AI Search datastore for grounding. @@ -830,6 +846,9 @@ export declare interface RetrievalTool { /** Optional. {@link Retrieval}. */ retrieval?: Retrieval; } +export function isRetrievalTool(tool: Tool): tool is RetrievalTool { + return (tool as RetrievalTool).retrieval !== undefined; +} /** * Tool to retrieve public web data for grounding, powered by Google. diff --git a/js/plugins/google-genai/src/common/utils.ts b/js/plugins/google-genai/src/common/utils.ts index 3f8ce5e732..50b9bbf915 100644 --- a/js/plugins/google-genai/src/common/utils.ts +++ b/js/plugins/google-genai/src/common/utils.ts @@ -68,10 +68,12 @@ export function checkModelName(name?: string): string { } export function extractText(request: GenerateRequest) { - return request.messages - .at(-1)! - .content.map((c) => c.text || '') - .join(''); + return ( + request.messages + .at(-1) + ?.content.map((c) => c.text || '') + .join('') ?? '' + ); } export function extractImagenImage( diff --git a/js/plugins/google-genai/src/googleai/client.ts b/js/plugins/google-genai/src/googleai/client.ts index f8917331e4..1fdc8b5cb8 100644 --- a/js/plugins/google-genai/src/googleai/client.ts +++ b/js/plugins/google-genai/src/googleai/client.ts @@ -198,7 +198,7 @@ export async function veoPredict( return response.json() as Promise; } -export async function checkVeoOperation( +export async function veoCheckOperation( apiKey: string, operation: string, clientOptions?: ClientOptions diff --git a/js/plugins/google-genai/src/googleai/gemini.ts b/js/plugins/google-genai/src/googleai/gemini.ts index bf0ddaa73c..1021dba3af 100644 --- a/js/plugins/google-genai/src/googleai/gemini.ts +++ b/js/plugins/google-genai/src/googleai/gemini.ts @@ -309,9 +309,6 @@ const KNOWN_GEMINI_MODELS = { 'gemini-2.0-flash-preview-image-generation' ), 'gemini-2.0-flash-lite': commonRef('gemini-2.0-flash-lite'), - 'gemini-1.5-flash': commonRef('gemini-1.5-flash'), - 'gemini-1.5-flash-8b': commonRef('gemini-1.5-flash-8b'), - 'gemini-1.5-pro': commonRef('gemini-1.5-pro'), }; export type KnownGeminiModels = keyof typeof KNOWN_GEMINI_MODELS; export type GeminiModelName = `gemini-${string}`; diff --git a/js/plugins/google-genai/src/googleai/veo.ts b/js/plugins/google-genai/src/googleai/veo.ts index b6ff99abf7..0d2d6b5fe6 100644 --- a/js/plugins/google-genai/src/googleai/veo.ts +++ b/js/plugins/google-genai/src/googleai/veo.ts @@ -29,7 +29,7 @@ import { type ModelInfo, type ModelReference, } from 'genkit/model'; -import { checkVeoOperation, veoPredict } from './client.js'; +import { veoCheckOperation, veoPredict } from './client.js'; import { ClientOptions, GoogleAIPluginOptions, @@ -201,7 +201,7 @@ export function defineModel( }, async check(operation) { const apiKey = calculateApiKey(pluginOptions?.apiKey, undefined); - const response = await checkVeoOperation( + const response = await veoCheckOperation( apiKey, operation.id, clientOptions diff --git a/js/plugins/google-genai/src/vertexai/client.ts b/js/plugins/google-genai/src/vertexai/client.ts index 2683b57d99..a13e1728bb 100644 --- a/js/plugins/google-genai/src/vertexai/client.ts +++ b/js/plugins/google-genai/src/vertexai/client.ts @@ -14,7 +14,7 @@ * limitations under the License. */ -import { GENKIT_CLIENT_HEADER } from 'genkit'; +import { GENKIT_CLIENT_HEADER, GenkitError } from 'genkit'; import { GoogleAuth } from 'google-auth-library'; import { extractErrMsg } from '../common/utils'; import { @@ -137,19 +137,6 @@ export async function imagenPredict( return response.json() as Promise; } -// TODO(ifielker): update with 'global' and APIKey in the options. -// See genai SDK for how to handle the apiKey -// if ( -// this.clientOptions.project && -// this.clientOptions.location && -// this.clientOptions.location !== 'global' -// ) { -// // Regional endpoint -// return `https://${this.clientOptions.location}-aiplatform.googleapis.com/`; -// } -// // Global endpoint (covers 'global' location and API key usage) -// return `https://aiplatform.googleapis.com/`; - export function getVertexAIUrl(params: { includeProjectAndLocation: boolean; // False for listModels, true for most others resourcePath: string; @@ -160,11 +147,19 @@ export function getVertexAIUrl(params: { const DEFAULT_API_VERSION = 'v1beta1'; const API_BASE_PATH = 'aiplatform.googleapis.com'; - const region = params.clientOptions.location || 'us-central1'; - const basePath = `${region}-${API_BASE_PATH}`; + let basePath: string; + + if (params.clientOptions.kind == 'regional') { + basePath = `${params.clientOptions.location}-${API_BASE_PATH}`; + } else { + basePath = API_BASE_PATH; + } let resourcePath = params.resourcePath; - if (params.includeProjectAndLocation) { + if ( + params.clientOptions.kind != 'express' && + params.includeProjectAndLocation + ) { const parent = `projects/${params.clientOptions.projectId}/locations/${params.clientOptions.location}`; resourcePath = `${parent}/${params.resourcePath}`; } @@ -173,24 +168,50 @@ export function getVertexAIUrl(params: { if (params.resourceMethod) { url += `:${params.resourceMethod}`; } + + let joiner = '?'; if (params.queryParams) { - url += `?${params.queryParams}`; + url += `${joiner}${params.queryParams}`; + joiner = '&'; } if (params.resourceMethod === 'streamGenerateContent') { - url += `${params.queryParams ? '&' : '?'}alt=sse`; + url += `${joiner}alt=sse`; + joiner = '&'; + } + if (params.clientOptions.kind == 'express' && !params.clientOptions.apiKey) { + throw new GenkitError({ + status: 'INVALID_ARGUMENT', + message: 'missing api key', + }); + } + if (params.clientOptions.apiKey) { + // Required for express, optional for others + url += `${joiner}key=${params.clientOptions.apiKey}`; + joiner = '&'; } return url; } async function getHeaders(clientOptions: ClientOptions): Promise { - const token = await getToken(clientOptions.authClient); - const headers: HeadersInit = { - Authorization: `Bearer ${token}`, - 'x-goog-user-project': clientOptions.projectId, - 'Content-Type': 'application/json', - 'X-Goog-Api-Client': GENKIT_CLIENT_HEADER, - }; - return headers; + if (clientOptions.kind == 'express') { + const headers: HeadersInit = { + // ApiKey is in the url query params + 'Content-Type': 'application/json', + 'X-Goog-Api-Client': GENKIT_CLIENT_HEADER, + 'User-Agent': GENKIT_CLIENT_HEADER, + }; + return headers; + } else { + const token = await getToken(clientOptions.authClient); + const headers: HeadersInit = { + Authorization: `Bearer ${token}`, + 'x-goog-user-project': clientOptions.projectId, + 'Content-Type': 'application/json', + 'X-Goog-Api-Client': GENKIT_CLIENT_HEADER, + 'User-Agent': GENKIT_CLIENT_HEADER, + }; + return headers; + } } async function getToken(authClient: GoogleAuth): Promise { diff --git a/js/plugins/google-genai/src/vertexai/converters.ts b/js/plugins/google-genai/src/vertexai/converters.ts index 72283d8ada..7e5aafb6b3 100644 --- a/js/plugins/google-genai/src/vertexai/converters.ts +++ b/js/plugins/google-genai/src/vertexai/converters.ts @@ -33,3 +33,25 @@ export function toGeminiSafetySettings( }; }); } + +export function toGeminiLabels( + labels?: Record +): Record | undefined { + if (!labels) { + return undefined; + } + const keys = Object.keys(labels); + const newLabels: Record = {}; + for (const key of keys) { + const value = labels[key]; + if (!key) { + continue; + } + newLabels[key] = value; + } + + if (Object.keys(newLabels).length == 0) { + return undefined; + } + return newLabels; +} diff --git a/js/plugins/google-genai/src/vertexai/gemini.ts b/js/plugins/google-genai/src/vertexai/gemini.ts index 1887a08155..444bf75c09 100644 --- a/js/plugins/google-genai/src/vertexai/gemini.ts +++ b/js/plugins/google-genai/src/vertexai/gemini.ts @@ -46,7 +46,7 @@ import { generateContentStream, getVertexAIUrl, } from './client'; -import { toGeminiSafetySettings } from './converters'; +import { toGeminiLabels, toGeminiSafetySettings } from './converters'; import { ClientOptions, Content, @@ -59,6 +59,7 @@ import { ToolConfig, VertexPluginOptions, } from './types'; +import { calculateApiKey } from './utils'; export const SafetySettingsSchema = z.object({ category: z.enum([ @@ -123,6 +124,14 @@ const GoogleSearchRetrievalSchema = z.object({ * Please refer to: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference#generationconfig, for further information. */ export const GeminiConfigSchema = GenerationCommonConfigSchema.extend({ + apiKey: z + .string() + .describe('Overrides the plugin-configured API key, if specified.') + .optional(), + labels: z + .record(z.string()) + .optional() + .describe('Key-value labels to attach to the request for cost tracking.'), temperature: z .number() .min(0.0) @@ -419,6 +428,8 @@ export function defineModel( use: middlewares, }, async (request, sendChunk) => { + let clientOpt = { ...clientOptions }; + // Make a copy of messages to avoid side-effects const messages = structuredClone(request.messages); if (messages.length === 0) throw new Error('No messages provided.'); @@ -431,19 +442,57 @@ export function defineModel( systemInstruction = toGeminiSystemInstruction(systemMessage); } - const requestConfig = request.config as GeminiConfig; + const requestConfig = { ...request.config }; const { + apiKey: apiKeyFromConfig, functionCallingConfig, version: versionFromConfig, googleSearchRetrieval, tools: toolsFromConfig, vertexRetrieval, - location, // location can be overridden via config, take it out. + location, safetySettings, + labels: labelsFromConfig, ...restOfConfig } = requestConfig; + if ( + location && + clientOptions.kind != 'express' && + clientOptions.location != location + ) { + // Override the location if it's specified in the request + if (location == 'global') { + clientOpt = { + kind: 'global', + location: 'global', + projectId: clientOptions.projectId, + authClient: clientOptions.authClient, + apiKey: clientOptions.apiKey, + }; + } else { + clientOpt = { + kind: 'regional', + location, + projectId: clientOptions.projectId, + authClient: clientOptions.authClient, + apiKey: clientOptions.apiKey, + }; + } + } + if (clientOptions.kind == 'express') { + clientOpt.apiKey = calculateApiKey( + clientOptions.apiKey, + apiKeyFromConfig + ); + } else if (apiKeyFromConfig) { + // Regional or Global can still use APIKey for billing (not auth) + clientOpt.apiKey = apiKeyFromConfig; + } + + const labels = toGeminiLabels(labelsFromConfig); + const tools: Tool[] = []; if (request.tools?.length) { tools.push({ @@ -492,10 +541,23 @@ export function defineModel( if (vertexRetrieval) { const _projectId = - vertexRetrieval.datastore.projectId || clientOptions.projectId; + vertexRetrieval.datastore.projectId || + (clientOptions.kind != 'express' + ? clientOptions.projectId + : undefined); const _location = - vertexRetrieval.datastore.location || clientOptions.location; + vertexRetrieval.datastore.location || + (clientOptions.kind == 'regional' + ? clientOptions.location + : undefined); const _dataStoreId = vertexRetrieval.datastore.dataStoreId; + if (!_projectId || !_location || !_dataStoreId) { + throw new GenkitError({ + status: 'INVALID_ARGUMENT', + message: + 'projectId, location and datastoreId are required for vertexRetrieval and could not be determined from configuration', + }); + } const datastore = `projects/${_projectId}/locations/${_location}/collections/default_collection/dataStores/${_dataStoreId}`; tools.push({ retrieval: { @@ -518,6 +580,7 @@ export function defineModel( toolConfig, safetySettings: toGeminiSafetySettings(safetySettings), contents: messages.map((message) => toGeminiMessage(message, ref)), + labels, }; const modelVersion = versionFromConfig || (ref.version as string); @@ -536,7 +599,7 @@ export function defineModel( const result = await generateContentStream( modelVersion, generateContentRequest, - clientOptions + clientOpt ); for await (const item of result.stream) { @@ -555,7 +618,7 @@ export function defineModel( response = await generateContent( modelVersion, generateContentRequest, - clientOptions + clientOpt ); } diff --git a/js/plugins/google-genai/src/vertexai/index.ts b/js/plugins/google-genai/src/vertexai/index.ts index d2dec7671f..1a45571e74 100644 --- a/js/plugins/google-genai/src/vertexai/index.ts +++ b/js/plugins/google-genai/src/vertexai/index.ts @@ -68,8 +68,8 @@ async function resolver( } async function listActions(options?: VertexPluginOptions) { - const clientOptions = await getDerivedOptions(options); try { + const clientOptions = await getDerivedOptions(options); const models = await listModels(clientOptions); return [ ...gemini.listActions(models), diff --git a/js/plugins/google-genai/src/vertexai/types.ts b/js/plugins/google-genai/src/vertexai/types.ts index 56adcb2af1..5d492929a8 100644 --- a/js/plugins/google-genai/src/vertexai/types.ts +++ b/js/plugins/google-genai/src/vertexai/types.ts @@ -17,8 +17,10 @@ import { GoogleAuth, GoogleAuthOptions } from 'google-auth-library'; import { CitationMetadata, + CodeExecutionTool, Content, FunctionCallingMode, + FunctionDeclarationsTool, GenerateContentCandidate, GenerateContentRequest, GenerateContentResponse, @@ -33,10 +35,15 @@ import { ImagenPredictRequest, ImagenPredictResponse, ImagenPrediction, + RetrievalTool, TaskType, TaskTypeSchema, Tool, ToolConfig, + isCodeExecutionTool, + isFunctionDeclarationsTool, + isGoogleSearchRetrievalTool, + isRetrievalTool, } from '../common/types'; // This makes it easier to import all types from one place @@ -45,8 +52,14 @@ export { HarmBlockThreshold, HarmCategory, TaskTypeSchema, + isCodeExecutionTool, + isFunctionDeclarationsTool, + isGoogleSearchRetrievalTool, + isRetrievalTool, type CitationMetadata, + type CodeExecutionTool, type Content, + type FunctionDeclarationsTool, type GenerateContentCandidate, type GenerateContentRequest, type GenerateContentResponse, @@ -59,33 +72,58 @@ export { type ImagenPredictRequest, type ImagenPredictResponse, type ImagenPrediction, + type RetrievalTool, type Tool, type ToolConfig, }; /** Options for Vertex AI plugin configuration */ export interface VertexPluginOptions { + /** The Vertex API key for express mode */ + apiKey?: string | false; /** The Google Cloud project id to call. */ projectId?: string; /** The Google Cloud region to call. */ - location: string; + location?: string; /** Provide custom authentication configuration for connecting to Vertex AI. */ googleAuth?: GoogleAuthOptions; /** Enables additional debug traces (e.g. raw model API call details). */ experimental_debugTraces?: boolean; } -/** Resolved options for use with the client */ -export interface ClientOptions { +export interface RegionalClientOptions { + kind: 'regional'; location: string; projectId: string; authClient: GoogleAuth; + apiKey?: string; // In addition to regular auth +} + +export interface GlobalClientOptions { + kind: 'global'; + location: 'global'; + projectId: string; + authClient: GoogleAuth; + apiKey?: string; // In addition to regular auth +} + +export interface ExpressClientOptions { + kind: 'express'; + apiKey: string | false | undefined; // Instead of regular auth } +/** Resolved options for use with the client */ +export type ClientOptions = + | RegionalClientOptions + | GlobalClientOptions + | ExpressClientOptions; + /** * Request options params. */ export interface RequestOptions { + /** an apiKey to use for this request if applicable */ + apiKey?: string | false | undefined; /** timeout in milli seconds. time out value needs to be non negative. */ timeout?: number; /** diff --git a/js/plugins/google-genai/src/vertexai/utils.ts b/js/plugins/google-genai/src/vertexai/utils.ts index 3b5d419b62..e12798a788 100644 --- a/js/plugins/google-genai/src/vertexai/utils.ts +++ b/js/plugins/google-genai/src/vertexai/utils.ts @@ -14,11 +14,15 @@ * limitations under the License. */ +import { GenkitError } from 'genkit'; import { GenerateRequest } from 'genkit/model'; import { GoogleAuth } from 'google-auth-library'; import type { ClientOptions, + ExpressClientOptions, + GlobalClientOptions, ImagenInstance, + RegionalClientOptions, VertexPluginOptions, } from './types'; @@ -54,6 +58,115 @@ export async function getDerivedOptions( if (__mockDerivedOptions) { return Promise.resolve(__mockDerivedOptions); } + + // Figure out the type of preferred options if possible + // The order of the if statements is important. + if (options?.location == 'global') { + return await getGlobalDerivedOptions(AuthClass, options); + } else if (options?.location) { + return await getRegionalDerivedOptions(AuthClass, options); + } else if (options?.apiKey !== undefined) { + // apiKey = false still indicates apiKey expectation + return getExpressDerivedOptions(options); + } + + // If we got here then we're relying on environment variables. + // Try regional first, it's the most common usage. + try { + const regionalOptions = await getRegionalDerivedOptions(AuthClass, options); + return regionalOptions; + } catch (e: unknown) { + /* no-op - try global next */ + } + try { + const globalOptions = await getGlobalDerivedOptions(AuthClass, options); + return globalOptions; + } catch (e: unknown) { + /* no-op - try express last */ + } + try { + const expressOptions = getExpressDerivedOptions(options); + return expressOptions; + } catch (e: unknown) { + /* no-op */ + } + + // We did not have enough information in the options or in environment variables + // to properly determine client options. + throw new GenkitError({ + status: 'INVALID_ARGUMENT', + message: + 'Unable to determine client options. Please set either apiKey or projectId and location', + }); +} + +async function getGlobalDerivedOptions( + AuthClass: typeof GoogleAuth, + options?: VertexPluginOptions +): Promise { + let authOptions = options?.googleAuth; + let authClient: GoogleAuth; + const providedProjectId = + options?.projectId || + process.env.GCLOUD_PROJECT || + parseFirebaseProjectId(); + if (process.env.GCLOUD_SERVICE_ACCOUNT_CREDS) { + const serviceAccountCreds = JSON.parse( + process.env.GCLOUD_SERVICE_ACCOUNT_CREDS + ); + authOptions = { + credentials: serviceAccountCreds, + scopes: [CLOUD_PLATFORM_OAUTH_SCOPE], + projectId: providedProjectId, + }; + authClient = new AuthClass(authOptions); + } else { + authClient = new AuthClass( + authOptions ?? { + scopes: [CLOUD_PLATFORM_OAUTH_SCOPE], + projectId: providedProjectId, + } + ); + } + + const projectId = + options?.projectId || + process.env.GCLOUD_PROJECT || + (await authClient.getProjectId()); + + if (!projectId) { + throw new Error( + `VertexAI Plugin is missing the 'project' configuration. Please set the 'GCLOUD_PROJECT' environment variable or explicitly pass 'project' into genkit config.` + ); + } + + const clientOpt: GlobalClientOptions = { + kind: 'global', + location: 'global', + projectId, + authClient, + }; + if (options?.apiKey) { + clientOpt.apiKey = options.apiKey; + } + + return clientOpt; +} + +function getExpressDerivedOptions( + options?: VertexPluginOptions +): ExpressClientOptions { + const apiKey = checkApiKey(options?.apiKey); + return { + kind: 'express', + apiKey, + }; +} + +async function getRegionalDerivedOptions( + AuthClass: typeof GoogleAuth, + options?: VertexPluginOptions +): Promise { let authOptions = options?.googleAuth; let authClient: GoogleAuth; const providedProjectId = @@ -97,11 +210,111 @@ export async function getDerivedOptions( ); } - return { + const clientOpt: RegionalClientOptions = { + kind: 'regional', location, projectId, authClient, }; + if (options?.apiKey) { + clientOpt.apiKey = options.apiKey; + } + return clientOpt; +} + +/** + * Retrieves an API key from environment variables. + * + * @returns The API key as a string, or `undefined` if none of the specified + * environment variables are set. + */ +export function getApiKeyFromEnvVar(): string | undefined { + return ( + process.env.VERTEX_API_KEY || + process.env.GOOGLE_API_KEY || + process.env.GOOGLE_GENAI_API_KEY + ); +} + +export const MISSING_API_KEY_ERROR = new GenkitError({ + status: 'FAILED_PRECONDITION', + message: + 'Please pass in the API key or set the VERTEX_API_KEY or GOOGLE_API_KEY environment variable.\n' + + 'For more details see https://firebase.google.com/docs/genkit/plugins/google-genai', +}); + +export const API_KEY_FALSE_ERROR = new GenkitError({ + status: 'INVALID_ARGUMENT', + message: + 'VertexAI plugin was initialized with {apiKey: false} but no apiKey configuration was passed at call time.', +}); + +/** + * Checks and retrieves an API key based on the provided argument and environment variables. + * + * - If `pluginApiKey` is a non-empty string, it's used as the API key. + * - If `pluginApiKey` is `undefined` or an empty string, it attempts to fetch the API key from environment + * - If `pluginApiKey` is `false`, key retrieval from the environment is skipped, and the function + * will return `undefined`. This mode indicates that the API key is expected to be provided + * at a later stage or in a different context. + * + * @param pluginApiKey - An optional API key string, `undefined` to check the environment, or `false` to bypass all checks in this function. + * @returns The resolved API key as a string, or `undefined` if `pluginApiKey` is `false`. + * @throws {Error} MISSING_API_KEY_ERROR - Thrown if `pluginApiKey` is not `false` and no API key + * can be found either in the `pluginApiKey` argument or from the environment. + */ +export function checkApiKey( + pluginApiKey: string | false | undefined +): string | undefined { + let apiKey: string | undefined; + + // Don't get the key from the environment if pluginApiKey is false + if (pluginApiKey !== false) { + apiKey = pluginApiKey || getApiKeyFromEnvVar(); + } + + // If pluginApiKey is false, then we don't throw because we are waiting for + // the apiKey passed into the individual call + if (pluginApiKey !== false && !apiKey) { + throw MISSING_API_KEY_ERROR; + } + return apiKey; +} + +/** + * Calculates and returns the effective API key based on multiple potential sources. + * The order of precedence for determining the API key is: + * 1. `requestApiKey` (if provided) + * 2. `pluginApiKey` (if provided and not `false`) + * 3. Environment variable (if `pluginApiKey` is not `false` and `pluginApiKey` is not provided) + * + * @param pluginApiKey - The apiKey value provided during plugin initialization. + * @param requestApiKey - The apiKey provided to an individual generate call. + * @returns The resolved API key as a string. + * @throws {Error} API_KEY_FALSE_ERROR - Thrown if `pluginApiKey` is `false` and `requestApiKey` is not provided + * @throws {Error} MISSING_API_KEY_ERROR - Thrown if no API key can be resolved from any source + */ +export function calculateApiKey( + pluginApiKey: string | false | undefined, + requestApiKey: string | undefined +): string { + let apiKey: string | undefined; + + // Don't get the key from the environment if pluginApiKey is false + if (pluginApiKey !== false) { + apiKey = pluginApiKey || getApiKeyFromEnvVar(); + } + + apiKey = requestApiKey || apiKey; + + if (pluginApiKey === false && !requestApiKey) { + throw API_KEY_FALSE_ERROR; + } + + if (!apiKey) { + throw MISSING_API_KEY_ERROR; + } + return apiKey; } export function extractImagenMask( diff --git a/js/plugins/google-genai/tests/common/utils_test.ts b/js/plugins/google-genai/tests/common/utils_test.ts index 262d2476b6..529fe2af24 100644 --- a/js/plugins/google-genai/tests/common/utils_test.ts +++ b/js/plugins/google-genai/tests/common/utils_test.ts @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -15,28 +15,236 @@ */ import * as assert from 'assert'; -import { ModelReference } from 'genkit'; +import { GenkitError } from 'genkit'; +import { GenerateRequest } from 'genkit/model'; import { describe, it } from 'node:test'; -import { cleanSchema } from '../../src/common/utils'; - -// Mock ModelReference for testing nearestModelRef -const createMockModelRef = ( - name: string, - version?: string -): ModelReference => { - return { - name, - info: { label: `Model ${name}`, supports: {} }, - config: { version }, - withConfig: function (newConfig) { - // Return a new object to mimic immutability - return { - ...this, - config: { ...this.config, ...newConfig }, - }; - }, - } as any; -}; +import { + checkModelName, + cleanSchema, + extractErrMsg, + extractImagenImage, + extractText, + modelName, +} from '../../src/common/utils'; + +describe('extractErrMsg', () => { + it('extracts message from an Error object', () => { + const error = new Error('This is a test error.'); + assert.strictEqual(extractErrMsg(error), 'This is a test error.'); + }); + + it('returns the string if error is a string', () => { + const error = 'A simple string error.'; + assert.strictEqual(extractErrMsg(error), 'A simple string error.'); + }); + + it('stringifies other error types', () => { + const error = { code: 500, message: 'Object error' }; + assert.strictEqual( + extractErrMsg(error), + '{"code":500,"message":"Object error"}' + ); + }); + + it('provides a default message for unknown types', () => { + // Note: The function returns undefined for an undefined input because + // JSON.stringify(undefined) results in undefined. + assert.strictEqual(extractErrMsg(undefined), undefined); + assert.strictEqual(extractErrMsg(null), 'null'); + }); +}); + +describe('modelName', () => { + it('extracts model name from a full path', () => { + const name = 'models/googleai/gemini-2.5-pro'; + assert.strictEqual(modelName(name), 'gemini-2.5-pro'); + }); + + it('returns the name if no path is present', () => { + const name = 'gemini-1.5-flash'; + assert.strictEqual(modelName(name), 'gemini-1.5-flash'); + }); + + it('handles undefined input', () => { + assert.strictEqual(modelName(undefined), undefined); + }); + + it('handles empty string input', () => { + assert.strictEqual(modelName(''), ''); + }); +}); + +describe('checkModelName', () => { + it('extracts model name from a full path', () => { + const name = 'models/vertexai/gemini-2.0-pro'; + assert.strictEqual(checkModelName(name), 'gemini-2.0-pro'); + }); + + it('throws an error for undefined input', () => { + assert.throws( + () => checkModelName(undefined), + (err: GenkitError) => { + assert.strictEqual(err.status, 'INVALID_ARGUMENT'); + assert.strictEqual( + err.message, + 'INVALID_ARGUMENT: Model name is required.' + ); + return true; + } + ); + }); + + it('throws an error for an empty string', () => { + assert.throws( + () => checkModelName(''), + (err: GenkitError) => { + assert.strictEqual(err.status, 'INVALID_ARGUMENT'); + assert.strictEqual( + err.message, + 'INVALID_ARGUMENT: Model name is required.' + ); + return true; + } + ); + }); +}); + +describe('extractText', () => { + it('extracts text from the last message', () => { + const request: GenerateRequest = { + messages: [ + { role: 'user', content: [{ text: 'Hello there.' }] }, + { role: 'model', content: [{ text: 'How can I help?' }] }, + { role: 'user', content: [{ text: 'Tell me a joke.' }] }, + ], + config: {}, + }; + assert.strictEqual(extractText(request), 'Tell me a joke.'); + }); + + it('concatenates multiple text parts', () => { + const request: GenerateRequest = { + messages: [ + { + role: 'user', + content: [{ text: 'Part 1. ' }, { text: 'Part 2.' }], + }, + ], + config: {}, + }; + assert.strictEqual(extractText(request), 'Part 1. Part 2.'); + }); + + it('returns an empty string if there are no text parts', () => { + const request: GenerateRequest = { + messages: [ + { + role: 'user', + content: [ + { + media: { + url: '', + contentType: 'image/jpeg', + }, + }, + ], + }, + ], + config: {}, + }; + assert.strictEqual(extractText(request), ''); + }); + + it('returns an empty string if there are no messages', () => { + const request: GenerateRequest = { + messages: [], + config: {}, + }; + assert.strictEqual(extractText(request), ''); + }); +}); + +describe('extractImagenImage', () => { + it('extracts a base64 encoded image', () => { + const base64Image = '/9j/4AAQSkZJRg...'; + const request: GenerateRequest = { + messages: [ + { + role: 'user', + content: [ + { text: 'Create an image.' }, + { + media: { + url: `data:image/jpeg;base64,${base64Image}`, + contentType: 'image/jpeg', + }, + }, + ], + }, + ], + config: {}, + }; + const result = extractImagenImage(request); + assert.deepStrictEqual(result, { bytesBase64Encoded: base64Image }); + }); + + it('returns undefined if no image part exists', () => { + const request: GenerateRequest = { + messages: [{ role: 'user', content: [{ text: 'Hello' }] }], + config: {}, + }; + assert.strictEqual(extractImagenImage(request), undefined); + }); + + it('returns undefined if the media part is not a base64 data URI', () => { + const request: GenerateRequest = { + messages: [ + { + role: 'user', + content: [ + { + media: { + url: 'http://example.com/image.jpg', + contentType: 'image/jpeg', + }, + }, + ], + }, + ], + config: {}, + }; + assert.strictEqual(extractImagenImage(request), undefined); + }); + + it('ignores parts with metadata type not equal to "base"', () => { + const request: GenerateRequest = { + messages: [ + { + role: 'user', + content: [ + { + media: { + url: '', + contentType: 'image/png', + }, + metadata: { type: 'mask' }, + }, + ], + }, + ], + config: {}, + }; + assert.strictEqual(extractImagenImage(request), undefined); + }); + + it('returns undefined for an empty message list', () => { + const request: GenerateRequest = { + messages: [], + config: {}, + }; + assert.strictEqual(extractImagenImage(request), undefined); + }); +}); describe('cleanSchema', () => { it('strips $schema and additionalProperties', () => { diff --git a/js/plugins/google-genai/tests/googleai/veo_test.ts b/js/plugins/google-genai/tests/googleai/veo_test.ts index 54e2bdbab7..16027a412c 100644 --- a/js/plugins/google-genai/tests/googleai/veo_test.ts +++ b/js/plugins/google-genai/tests/googleai/veo_test.ts @@ -329,7 +329,7 @@ describe('Google AI Veo', () => { const operationId = 'operations/check123'; const pendingOp: Operation = { id: operationId, done: false }; - it('should call fetch for checkVeoOperation and return updated operation', async () => { + it('should call fetch for veoCheckOperation and return updated operation', async () => { const mockResponse: VeoOperation = { name: operationId, done: true, diff --git a/js/plugins/google-genai/tests/vertexai/client_test.ts b/js/plugins/google-genai/tests/vertexai/client_test.ts index d98ced0aca..ec3e76f1fc 100644 --- a/js/plugins/google-genai/tests/vertexai/client_test.ts +++ b/js/plugins/google-genai/tests/vertexai/client_test.ts @@ -45,17 +45,32 @@ describe('Vertex AI Client', () => { let fetchSpy: sinon.SinonStub; let authMock: sinon.SinonStubbedInstance; - const clientOptions: ClientOptions = { + const regionalClientOptions: ClientOptions = { + kind: 'regional', projectId: 'test-project', location: 'us-central1', authClient: {} as GoogleAuth, // Will be replaced by mock }; + const globalClientOptions: ClientOptions = { + kind: 'global', + projectId: 'test-project', + location: 'global', + authClient: {} as GoogleAuth, // Will be replaced by mock + }; + + const expressClientOptions: ClientOptions = { + kind: 'express', + apiKey: 'test-api-key', + }; + beforeEach(() => { fetchSpy = sinon.stub(global, 'fetch'); authMock = sinon.createStubInstance(GoogleAuth); authMock.getAccessToken.resolves('test-token'); - clientOptions.authClient = authMock as unknown as GoogleAuth; + (regionalClientOptions as any).authClient = + authMock as unknown as GoogleAuth; + (globalClientOptions as any).authClient = authMock as unknown as GoogleAuth; }); afterEach(() => { @@ -84,74 +99,156 @@ describe('Vertex AI Client', () => { } describe('getVertexAIUrl', () => { - const opts: ClientOptions = { - projectId: 'test-proj', - location: 'us-east1', - authClient: {} as any, - }; - - it('should build URL for listModels', () => { - const url = getVertexAIUrl({ - includeProjectAndLocation: false, - resourcePath: 'publishers/google/models', - clientOptions: opts, + describe('Regional', () => { + const opts: ClientOptions = { + kind: 'regional', + projectId: 'test-proj', + location: 'us-east1', + authClient: {} as any, + }; + + it('should build URL for listModels', () => { + const url = getVertexAIUrl({ + includeProjectAndLocation: false, + resourcePath: 'publishers/google/models', + clientOptions: opts, + }); + assert.strictEqual( + url, + 'https://us-east1-aiplatform.googleapis.com/v1beta1/publishers/google/models' + ); }); - assert.strictEqual( - url, - 'https://us-east1-aiplatform.googleapis.com/v1beta1/publishers/google/models' - ); - }); - it('should build URL for generateContent', () => { - const url = getVertexAIUrl({ - includeProjectAndLocation: true, - resourcePath: 'publishers/google/models/gemini-2.0-pro', - resourceMethod: 'generateContent', - clientOptions: opts, + it('should build URL for generateContent', () => { + const url = getVertexAIUrl({ + includeProjectAndLocation: true, + resourcePath: 'publishers/google/models/gemini-2.0-pro', + resourceMethod: 'generateContent', + clientOptions: opts, + }); + assert.strictEqual( + url, + 'https://us-east1-aiplatform.googleapis.com/v1beta1/projects/test-proj/locations/us-east1/publishers/google/models/gemini-2.0-pro:generateContent' + ); }); - assert.strictEqual( - url, - 'https://us-east1-aiplatform.googleapis.com/v1beta1/projects/test-proj/locations/us-east1/publishers/google/models/gemini-2.0-pro:generateContent' - ); - }); - it('should build URL for streamGenerateContent', () => { - const url = getVertexAIUrl({ - includeProjectAndLocation: true, - resourcePath: 'publishers/google/models/gemini-2.5-flash', - resourceMethod: 'streamGenerateContent', - clientOptions: opts, + it('should build URL for streamGenerateContent', () => { + const url = getVertexAIUrl({ + includeProjectAndLocation: true, + resourcePath: 'publishers/google/models/gemini-2.5-flash', + resourceMethod: 'streamGenerateContent', + clientOptions: opts, + }); + assert.strictEqual( + url, + 'https://us-east1-aiplatform.googleapis.com/v1beta1/projects/test-proj/locations/us-east1/publishers/google/models/gemini-2.5-flash:streamGenerateContent?alt=sse' + ); + }); + + it('should handle queryParams', () => { + const url = getVertexAIUrl({ + includeProjectAndLocation: false, + resourcePath: 'publishers/google/models', + clientOptions: opts, + queryParams: 'pageSize=10', + }); + assert.strictEqual( + url, + 'https://us-east1-aiplatform.googleapis.com/v1beta1/publishers/google/models?pageSize=10' + ); }); - assert.strictEqual( - url, - 'https://us-east1-aiplatform.googleapis.com/v1beta1/projects/test-proj/locations/us-east1/publishers/google/models/gemini-2.5-flash:streamGenerateContent?alt=sse' - ); }); - it('should build URL for predict (embedContent)', () => { - const url = getVertexAIUrl({ - includeProjectAndLocation: true, - resourcePath: 'publishers/google/models/text-embedding-005', - resourceMethod: 'predict', - clientOptions: opts, + describe('Global', () => { + const opts: ClientOptions = { + kind: 'global', + projectId: 'test-proj', + location: 'global', + authClient: {} as any, + }; + + it('should build URL for listModels', () => { + const url = getVertexAIUrl({ + includeProjectAndLocation: false, + resourcePath: 'publishers/google/models', + clientOptions: opts, + }); + assert.strictEqual( + url, + 'https://aiplatform.googleapis.com/v1beta1/publishers/google/models' + ); + }); + + it('should build URL for generateContent', () => { + const url = getVertexAIUrl({ + includeProjectAndLocation: true, + resourcePath: 'publishers/google/models/gemini-2.0-pro', + resourceMethod: 'generateContent', + clientOptions: opts, + }); + assert.strictEqual( + url, + 'https://aiplatform.googleapis.com/v1beta1/projects/test-proj/locations/global/publishers/google/models/gemini-2.0-pro:generateContent' + ); }); - assert.strictEqual( - url, - 'https://us-east1-aiplatform.googleapis.com/v1beta1/projects/test-proj/locations/us-east1/publishers/google/models/text-embedding-005:predict' - ); }); - it('should handle queryParams', () => { - const url = getVertexAIUrl({ - includeProjectAndLocation: false, - resourcePath: 'publishers/google/models', - clientOptions: opts, - queryParams: 'pageSize=10', + describe('Express', () => { + const opts: ClientOptions = { + kind: 'express', + apiKey: 'test-api-key', + }; + + it('should build URL for listModels', () => { + const url = getVertexAIUrl({ + includeProjectAndLocation: false, + resourcePath: 'publishers/google/models', + clientOptions: opts, + }); + assert.strictEqual( + url, + 'https://aiplatform.googleapis.com/v1beta1/publishers/google/models?key=test-api-key' + ); + }); + + it('should build URL for generateContent', () => { + const url = getVertexAIUrl({ + includeProjectAndLocation: true, // This is ignored for express + resourcePath: 'publishers/google/models/gemini-2.0-pro', + resourceMethod: 'generateContent', + clientOptions: opts, + }); + assert.strictEqual( + url, + 'https://aiplatform.googleapis.com/v1beta1/publishers/google/models/gemini-2.0-pro:generateContent?key=test-api-key' + ); + }); + + it('should build URL for streamGenerateContent', () => { + const url = getVertexAIUrl({ + includeProjectAndLocation: true, // Ignored + resourcePath: 'publishers/google/models/gemini-2.5-flash', + resourceMethod: 'streamGenerateContent', + clientOptions: opts, + }); + assert.strictEqual( + url, + 'https://aiplatform.googleapis.com/v1beta1/publishers/google/models/gemini-2.5-flash:streamGenerateContent?alt=sse&key=test-api-key' + ); + }); + + it('should handle queryParams', () => { + const url = getVertexAIUrl({ + includeProjectAndLocation: false, + resourcePath: 'publishers/google/models', + clientOptions: opts, + queryParams: 'pageSize=10', + }); + assert.strictEqual( + url, + 'https://aiplatform.googleapis.com/v1beta1/publishers/google/models?pageSize=10&key=test-api-key' + ); }); - assert.strictEqual( - url, - 'https://us-east1-aiplatform.googleapis.com/v1beta1/publishers/google/models?pageSize=10' - ); }); }); @@ -159,47 +256,231 @@ describe('Vertex AI Client', () => { it('should throw a specific error if getToken fails', async () => { authMock.getAccessToken.rejects(new Error('Auth failed')); await assert.rejects( - listModels(clientOptions), + listModels(regionalClientOptions), /Unable to authenticate your request/ ); }); }); - describe('listModels', () => { - it('should return a list of models', async () => { - const mockModels: Model[] = [ - { name: 'gemini-2.0-pro', launchStage: 'GA' }, - { name: 'gemini-2.5-flash', launchStage: 'GA' }, - ]; - mockFetchResponse({ publisherModels: mockModels }); - - const result = await listModels(clientOptions); - assert.deepStrictEqual(result, mockModels); - - const expectedUrl = - 'https://us-central1-aiplatform.googleapis.com/v1beta1/publishers/google/models'; - sinon.assert.calledOnceWithExactly(fetchSpy, expectedUrl, { - method: 'GET', - headers: { - Authorization: 'Bearer test-token', - 'x-goog-user-project': 'test-project', - 'Content-Type': 'application/json', - 'X-Goog-Api-Client': GENKIT_CLIENT_HEADER, - }, + describe('API Calls', () => { + const testCases = [ + { name: 'Regional', options: regionalClientOptions }, + { name: 'Global', options: globalClientOptions }, + { name: 'Express', options: expressClientOptions }, + ]; + + for (const testCase of testCases) { + describe(`${testCase.name} Client - kind: ${testCase.options.kind}`, () => { + const currentOptions = testCase.options; + const isExpress = currentOptions.kind === 'express'; + const location = + currentOptions.kind === 'regional' + ? currentOptions.location + : 'global'; + const projectId = + currentOptions.kind !== 'express' ? currentOptions.projectId : ''; + + const getExpectedHeaders = () => { + const headers: Record = { + 'Content-Type': 'application/json', + 'X-Goog-Api-Client': GENKIT_CLIENT_HEADER, + 'User-Agent': GENKIT_CLIENT_HEADER, + }; + if (isExpress) { + return headers; + } + return { + ...headers, + Authorization: 'Bearer test-token', + 'x-goog-user-project': projectId, + }; + }; + + const getBaseUrl = (path: string) => { + const keySuffix = isExpress ? `?key=${currentOptions.apiKey}` : ''; + if (isExpress) { + return `https://aiplatform.googleapis.com/v1beta1/${path}${keySuffix}`; + } + const domain = + currentOptions.kind === 'regional' + ? `${location}-aiplatform.googleapis.com` + : 'aiplatform.googleapis.com'; + return `https://${domain}/v1beta1/${path}`; + }; + + const getResourceUrl = (model: string, method: string) => { + const isStreaming = method.includes('streamGenerateContent'); + let url; + + if (isExpress) { + url = `https://aiplatform.googleapis.com/v1beta1/publishers/google/models/${model}:${method}`; + if (isStreaming) { + url += `?alt=sse&key=${currentOptions.apiKey}`; + } else { + url += `?key=${currentOptions.apiKey}`; + } + } else { + const domain = + currentOptions.kind === 'regional' + ? `${location}-aiplatform.googleapis.com` + : 'aiplatform.googleapis.com'; + const projectLocationPrefix = `projects/${projectId}/locations/${location}`; + url = `https://${domain}/v1beta1/${projectLocationPrefix}/publishers/google/models/${model}:${method}`; + if (isStreaming) { + url += `?alt=sse`; + } + } + return url; + }; + + describe('listModels', () => { + it('should return a list of models', async () => { + const mockModels: Model[] = [ + { name: 'gemini-2.0-pro', launchStage: 'GA' }, + ]; + mockFetchResponse({ publisherModels: mockModels }); + + const result = await listModels(currentOptions); + assert.deepStrictEqual(result, mockModels); + + const expectedUrl = getBaseUrl('publishers/google/models'); + sinon.assert.calledOnceWithExactly(fetchSpy, expectedUrl, { + method: 'GET', + headers: getExpectedHeaders(), + }); + + // Corrected assertions using sinon.assert: + if (!isExpress) { + sinon.assert.calledOnce(authMock.getAccessToken); + } else { + sinon.assert.notCalled(authMock.getAccessToken); + } + }); + }); + + describe('generateContent', () => { + const request: GenerateContentRequest = { + contents: [{ role: 'user', parts: [{ text: 'hello' }] }], + }; + const model = 'gemini-2.0-pro'; + + it('should return GenerateContentResponse', async () => { + const mockResponse: GenerateContentResponse = { candidates: [] }; + mockFetchResponse(mockResponse); + + const result = await generateContent( + model, + request, + currentOptions + ); + assert.deepStrictEqual(result, mockResponse); + + const expectedUrl = getResourceUrl(model, 'generateContent'); + sinon.assert.calledOnceWithExactly(fetchSpy, expectedUrl, { + method: 'POST', + headers: getExpectedHeaders(), + body: JSON.stringify(request), + }); + }); + + it('should throw on API error', async () => { + const errorResponse = { error: { message: 'Permission denied' } }; + mockFetchResponse(errorResponse, false, 403, 'Forbidden'); + + await assert.rejects( + generateContent(model, request, currentOptions), + /Failed to fetch from .* \[403 Forbidden\] Permission denied/ + ); + }); + }); + + describe('embedContent', () => { + const request: EmbedContentRequest = { + instances: [{ content: 'test content' }], + parameters: {}, + }; + const model = 'text-embedding-005'; + + it('should return EmbedContentResponse', async () => { + const mockResponse: EmbedContentResponse = { predictions: [] }; + mockFetchResponse(mockResponse); + + await embedContent(model, request, currentOptions); + const expectedUrl = getResourceUrl(model, 'predict'); + sinon.assert.calledOnceWithExactly(fetchSpy, expectedUrl, { + method: 'POST', + headers: getExpectedHeaders(), + body: JSON.stringify(request), + }); + }); + }); + + describe('imagenPredict', () => { + const request: ImagenPredictRequest = { + instances: [{ prompt: 'a cat' }], + parameters: { sampleCount: 1 }, + }; + const model = 'imagen-3.0-generate-002'; + + it('should return ImagenPredictResponse', async () => { + const mockResponse: ImagenPredictResponse = { predictions: [] }; + mockFetchResponse(mockResponse); + await imagenPredict(model, request, currentOptions); + + const expectedUrl = getResourceUrl(model, 'predict'); + sinon.assert.calledOnceWithExactly(fetchSpy, expectedUrl, { + method: 'POST', + headers: getExpectedHeaders(), + body: JSON.stringify(request), + }); + }); + }); + + describe('generateContentStream', () => { + it('should process stream', async () => { + const request: GenerateContentRequest = { + contents: [{ role: 'user', parts: [{ text: 'stream' }] }], + }; + const chunks = [ + 'data: {"candidates": [{"index": 0, "content": {"role": "model", "parts": [{"text": "Hello "}]}}]}\n\n', + ]; + const stream = new ReadableStream({ + start(controller) { + for (const chunk of chunks) { + controller.enqueue(new TextEncoder().encode(chunk)); + } + controller.close(); + }, + }); + fetchSpy.resolves( + new Response(stream, { + headers: { 'Content-Type': 'application/json' }, + }) + ); + + await generateContentStream( + 'gemini-2.5-flash', + request, + currentOptions + ); + + const expectedUrl = getResourceUrl( + 'gemini-2.5-flash', + 'streamGenerateContent' + ); + sinon.assert.calledOnceWithExactly(fetchSpy, expectedUrl, { + method: 'POST', + headers: getExpectedHeaders(), + body: JSON.stringify(request), + }); + }); + }); }); - }); - - it('should throw an error if fetch fails with JSON error', async () => { - const errorResponse = { error: { message: 'Internal Error' } }; - mockFetchResponse(errorResponse, false, 500, 'Internal Server Error'); - - await assert.rejects( - listModels(clientOptions), - /Failed to fetch from .* \[500 Internal Server Error\] Internal Error/ - ); - }); + } + }); - it('should throw an error if fetch fails with non-JSON error', async () => { + describe('Error Handling Extras', () => { + it('listModels should throw an error if fetch fails with non-JSON error', async () => { mockFetchResponse( '

Gateway Timeout

', false, @@ -209,166 +490,30 @@ describe('Vertex AI Client', () => { ); await assert.rejects( - listModels(clientOptions), + listModels(regionalClientOptions), /Failed to fetch from .* \[504 Gateway Timeout\]

Gateway Timeout<\/h1>/ ); }); - it('should throw an error if fetch fails with empty response body', async () => { + it('listModels should throw an error if fetch fails with empty response body', async () => { mockFetchResponse(null, false, 502, 'Bad Gateway'); await assert.rejects( - listModels(clientOptions), + listModels(regionalClientOptions), /Failed to fetch from .* \[502 Bad Gateway\] $/ ); }); - it('should throw an error on network failure', async () => { + it('listModels should throw an error on network failure', async () => { fetchSpy.rejects(new Error('Network Error')); await assert.rejects( - listModels(clientOptions), + listModels(regionalClientOptions), /Failed to fetch from .* Network Error/ ); }); }); - describe('generateContent', () => { - const request: GenerateContentRequest = { - contents: [{ role: 'user', parts: [{ text: 'hello' }] }], - }; - const model = 'gemini-2.0-pro'; - - it('should return GenerateContentResponse', async () => { - const mockResponse: GenerateContentResponse = { - candidates: [ - { index: 0, content: { role: 'model', parts: [{ text: 'world' }] } }, - ], - }; - mockFetchResponse(mockResponse); - - const result = await generateContent(model, request, clientOptions); - assert.deepStrictEqual(result, mockResponse); - - const expectedUrl = - 'https://us-central1-aiplatform.googleapis.com/v1beta1/projects/test-project/locations/us-central1/publishers/google/models/gemini-2.0-pro:generateContent'; - sinon.assert.calledOnceWithExactly(fetchSpy, expectedUrl, { - method: 'POST', - headers: { - Authorization: 'Bearer test-token', - 'x-goog-user-project': 'test-project', - 'Content-Type': 'application/json', - 'X-Goog-Api-Client': GENKIT_CLIENT_HEADER, - }, - body: JSON.stringify(request), - }); - }); - - it('should throw on API error with JSON body', async () => { - const errorResponse = { error: { message: 'Permission denied' } }; - mockFetchResponse(errorResponse, false, 403, 'Forbidden'); - - await assert.rejects( - generateContent(model, request, clientOptions), - /Failed to fetch from .* \[403 Forbidden\] Permission denied/ - ); - }); - - it('should throw on API error with non-JSON body', async () => { - mockFetchResponse('Forbidden', false, 403, 'Forbidden', 'text/plain'); - - await assert.rejects( - generateContent(model, request, clientOptions), - /Failed to fetch from .* \[403 Forbidden\] Forbidden/ - ); - }); - }); - - describe('embedContent', () => { - const request: EmbedContentRequest = { - instances: [{ content: 'test content' }], - parameters: {}, - }; - const model = 'text-embedding-005'; - - it('should return EmbedContentResponse', async () => { - const mockResponse: EmbedContentResponse = { - predictions: [ - { - embeddings: { - statistics: { truncated: false, token_count: 3 }, - values: [0.1, 0.2, 0.3], - }, - }, - ], - }; - mockFetchResponse(mockResponse); - - const result = await embedContent(model, request, clientOptions); - assert.deepStrictEqual(result, mockResponse); - - const expectedUrl = - 'https://us-central1-aiplatform.googleapis.com/v1beta1/projects/test-project/locations/us-central1/publishers/google/models/text-embedding-005:predict'; - sinon.assert.calledOnceWithExactly(fetchSpy, expectedUrl, { - method: 'POST', - headers: { - Authorization: 'Bearer test-token', - 'x-goog-user-project': 'test-project', - 'Content-Type': 'application/json', - 'X-Goog-Api-Client': GENKIT_CLIENT_HEADER, - }, - body: JSON.stringify(request), - }); - }); - - it('should throw on API error non-JSON', async () => { - mockFetchResponse('Not Found', false, 404, 'Not Found', 'text/plain'); - await assert.rejects( - embedContent(model, request, clientOptions), - /Failed to fetch from .* \[404 Not Found\] Not Found/ - ); - }); - }); - - describe('imagenPredict', () => { - const request: ImagenPredictRequest = { - instances: [{ prompt: 'a cat' }], - parameters: { sampleCount: 1 }, - }; - const model = 'imagen-3.0-generate-002'; - - it('should return ImagenPredictResponse', async () => { - const mockResponse: ImagenPredictResponse = { - predictions: [{ bytesBase64Encoded: 'abc', mimeType: 'image/png' }], - }; - mockFetchResponse(mockResponse); - - const result = await imagenPredict(model, request, clientOptions); - assert.deepStrictEqual(result, mockResponse); - - const expectedUrl = - 'https://us-central1-aiplatform.googleapis.com/v1beta1/projects/test-project/locations/us-central1/publishers/google/models/imagen-3.0-generate-002:predict'; - sinon.assert.calledOnceWithExactly(fetchSpy, expectedUrl, { - method: 'POST', - headers: { - Authorization: 'Bearer test-token', - 'x-goog-user-project': 'test-project', - 'Content-Type': 'application/json', - 'X-Goog-Api-Client': GENKIT_CLIENT_HEADER, - }, - body: JSON.stringify(request), - }); - }); - - it('should throw on API error non-JSON', async () => { - mockFetchResponse('Bad Request', false, 400, 'Bad Request', 'text/plain'); - await assert.rejects( - imagenPredict(model, request, clientOptions), - /Failed to fetch from .* \[400 Bad Request\] Bad Request/ - ); - }); - }); - - describe('generateContentStream', () => { + describe('generateContentStream full aggregation tests', () => { function createMockStream(chunks: string[]): Response { const stream = new ReadableStream({ start(controller) { @@ -396,7 +541,7 @@ describe('Vertex AI Client', () => { const result: GenerateContentStreamResult = await generateContentStream( 'gemini-2.5-flash', request, - clientOptions + regionalClientOptions ); const streamResults: GenerateContentResponse[] = []; @@ -440,19 +585,6 @@ describe('Vertex AI Client', () => { ], usageMetadata: { totalTokenCount: 10 }, }); - - const expectedUrl = - 'https://us-central1-aiplatform.googleapis.com/v1beta1/projects/test-project/locations/us-central1/publishers/google/models/gemini-2.5-flash:streamGenerateContent?alt=sse'; - sinon.assert.calledOnceWithExactly(fetchSpy, expectedUrl, { - method: 'POST', - headers: { - Authorization: 'Bearer test-token', - 'x-goog-user-project': 'test-project', - 'Content-Type': 'application/json', - 'X-Goog-Api-Client': GENKIT_CLIENT_HEADER, - }, - body: JSON.stringify(request), - }); }); it('should handle stream with malformed JSON', async () => { @@ -468,13 +600,12 @@ describe('Vertex AI Client', () => { const result: GenerateContentStreamResult = await generateContentStream( 'gemini-2.5-flash', request, - clientOptions + regionalClientOptions ); - const streamResults: GenerateContentResponse[] = []; try { for await (const item of result.stream) { - streamResults.push(item); + // Consume stream } assert.fail('Stream should have thrown an error'); } catch (e: any) { @@ -495,29 +626,19 @@ describe('Vertex AI Client', () => { } }); - it('should handle stream error in fetch', async () => { - const request: GenerateContentRequest = { contents: [] }; - fetchSpy.rejects(new Error('Network failure')); - - await assert.rejects( - generateContentStream('gemini-2.5-flash', request, clientOptions), - /Failed to fetch from .* Network failure/ - ); - }); - it('should aggregate parts for multiple candidates', async () => { const request: GenerateContentRequest = { contents: [{ role: 'user', parts: [{ text: 'stream' }] }], }; const chunks = [ - 'data: {"candidates": [{"index": 0, "content": {"role": "model", "parts": [{"text": "C0 A"}]}}, {"index": 1, "content": {"role": "model", "parts": [{"text": "C1 A"}]}}]}\n\n', - 'data: {"candidates": [{"index": 0, "content": {"role": "model", "parts": [{"text": " C0 B"}]}}, {"index": 1, "content": {"role": "model", "parts": [{"text": " C1 B"}]}}]}\n\n', + 'data: {"candidates": [{"index": 0, "content": { "role": "model", "parts": [{"text": "C0 A"}]}}, {"index": 1, "content": {"role": "model", "parts": [{"text": "C1 A"}]}}]}\n\n', + 'data: {"candidates": [{"index": 0, "content": { "role": "model", "parts": [{"text": " C0 B"}]}}, {"index": 1, "content": {"role": "model", "parts": [{"text": " C1 B"}]}}]}\n\n', ]; fetchSpy.resolves(createMockStream(chunks)); const result = await generateContentStream( 'gemini-2.5-flash', request, - clientOptions + regionalClientOptions ); const aggregated = await result.response; @@ -526,14 +647,12 @@ describe('Vertex AI Client', () => { (a, b) => a.index - b.index ); - assert.deepStrictEqual(sortedCandidates[0], { - index: 0, - content: { role: 'model', parts: [{ text: 'C0 A C0 B' }] }, - }); - assert.deepStrictEqual(sortedCandidates[1], { - index: 1, - content: { role: 'model', parts: [{ text: 'C1 A C1 B' }] }, - }); + assert.deepStrictEqual(sortedCandidates[0].content.parts, [ + { text: 'C0 A C0 B' }, + ]); + assert.deepStrictEqual(sortedCandidates[1].content.parts, [ + { text: 'C1 A C1 B' }, + ]); }); it('should aggregate functionCall parts, keeping text first', async () => { @@ -549,7 +668,7 @@ describe('Vertex AI Client', () => { const result = await generateContentStream( 'gemini-2.5-flash', request, - clientOptions + regionalClientOptions ); const aggregated = await result.response; assert.deepStrictEqual(aggregated.candidates![0].content.parts, [ @@ -569,7 +688,7 @@ describe('Vertex AI Client', () => { const result = await generateContentStream( 'gemini-2.5-flash', request, - clientOptions + regionalClientOptions ); const aggregated = await result.response; assert.deepStrictEqual(aggregated.candidates![0].content.parts, [ @@ -589,7 +708,7 @@ describe('Vertex AI Client', () => { const result = await generateContentStream( 'gemini-2.5-flash', request, - clientOptions + regionalClientOptions ); const aggregated = await result.response; assert.deepStrictEqual(aggregated.candidates![0].citationMetadata, { @@ -610,7 +729,7 @@ describe('Vertex AI Client', () => { const result = await generateContentStream( 'gemini-2.5-flash', request, - clientOptions + regionalClientOptions ); const aggregated = await result.response; assert.deepStrictEqual(aggregated.candidates![0].groundingMetadata, { @@ -621,6 +740,7 @@ describe('Vertex AI Client', () => { searchEntryPoint: { renderedContent: 'test' }, }); }); + it('should take last finishReason, finishMessage, and safetyRatings', async () => { const request: GenerateContentRequest = { contents: [{ role: 'user', parts: [{ text: 'stream' }] }], @@ -633,7 +753,7 @@ describe('Vertex AI Client', () => { const result = await generateContentStream( 'gemini-2.5-flash', request, - clientOptions + regionalClientOptions ); const aggregated = await result.response; assert.strictEqual(aggregated.candidates![0].finishReason, 'MAX_TOKENS'); @@ -642,25 +762,5 @@ describe('Vertex AI Client', () => { { category: 'HARM_CATEGORY_DANGEROUS_CONTENT', probability: 'LOW' }, ]); }); - - it('handles candidates appearing in later chunks', async () => { - const request: GenerateContentRequest = { - contents: [{ role: 'user', parts: [{ text: 'stream' }] }], - }; - const chunks = [ - 'data: {"candidates": []}\n\n', - 'data: {"candidates": [{"index": 0, "content": { "role": "model", "parts": [{"text": "A"}]}}]}\n\n', - ]; - fetchSpy.resolves(createMockStream(chunks)); - const result = await generateContentStream( - 'gemini-2.5-flash', - request, - clientOptions - ); - const aggregated = await result.response; - assert.deepStrictEqual(aggregated.candidates![0].content.parts, [ - { text: 'A' }, - ]); - }); }); }); diff --git a/js/plugins/google-genai/tests/vertexai/converters_test.ts b/js/plugins/google-genai/tests/vertexai/converters_test.ts index f20bf9053b..3f9b9816e4 100644 --- a/js/plugins/google-genai/tests/vertexai/converters_test.ts +++ b/js/plugins/google-genai/tests/vertexai/converters_test.ts @@ -18,7 +18,10 @@ import * as assert from 'assert'; import { z } from 'genkit'; import { describe, it } from 'node:test'; import { HarmBlockThreshold, HarmCategory } from '../../src/common/types'; -import { toGeminiSafetySettings } from '../../src/vertexai/converters'; +import { + toGeminiLabels, + toGeminiSafetySettings, +} from '../../src/vertexai/converters'; import { SafetySettingsSchema } from '../../src/vertexai/gemini'; describe('Vertex AI Converters', () => { @@ -60,4 +63,64 @@ describe('Vertex AI Converters', () => { assert.deepStrictEqual(result, expected); }); }); + + describe('toGeminiLabels', () => { + it('returns undefined for undefined input', () => { + const result = toGeminiLabels(undefined); + assert.strictEqual(result, undefined); + }); + + it('returns undefined for an empty object input', () => { + const result = toGeminiLabels({}); + assert.strictEqual(result, undefined); + }); + + it('converts an object with valid labels', () => { + const labels = { + env: 'production', + 'my-label': 'my-value', + }; + const result = toGeminiLabels(labels); + assert.deepStrictEqual(result, labels); + }); + + it('filters out empty string keys', () => { + const labels = { + env: 'dev', + '': 'should-be-ignored', + 'valid-key': 'valid-value', + }; + const expected = { + env: 'dev', + 'valid-key': 'valid-value', + }; + const result = toGeminiLabels(labels); + assert.deepStrictEqual(result, expected); + }); + + it('returns undefined if all keys are empty strings', () => { + const labels = { + '': 'value1', + ' ': 'value2', // This key is not empty string, so it will be kept + }; + const expected = { + ' ': 'value2', + }; + const result = toGeminiLabels(labels); + assert.deepStrictEqual(result, expected); + }); + + it('handles labels with empty values', () => { + const labels = { + key1: '', + key2: 'value2', + }; + const expected = { + key1: '', + key2: 'value2', + }; + const result = toGeminiLabels(labels); + assert.deepStrictEqual(result, expected); + }); + }); }); diff --git a/js/plugins/google-genai/tests/vertexai/embedder_test.ts b/js/plugins/google-genai/tests/vertexai/embedder_test.ts index c6b4a45941..b17e78e000 100644 --- a/js/plugins/google-genai/tests/vertexai/embedder_test.ts +++ b/js/plugins/google-genai/tests/vertexai/embedder_test.ts @@ -15,11 +15,12 @@ */ import * as assert from 'assert'; -import { Document, Genkit } from 'genkit'; +import { Document, Genkit, GENKIT_CLIENT_HEADER } from 'genkit'; import { GoogleAuth } from 'google-auth-library'; import { afterEach, beforeEach, describe, it } from 'node:test'; import * as sinon from 'sinon'; -import { EmbeddingConfig, defineEmbedder } from '../../src/vertexai/embedder'; +import { getVertexAIUrl } from '../../src/vertexai/client'; +import { defineEmbedder, EmbeddingConfig } from '../../src/vertexai/embedder'; import { ClientOptions, EmbedContentResponse, @@ -31,11 +32,26 @@ describe('defineEmbedder', () => { let fetchStub: sinon.SinonStub; let authMock: sinon.SinonStubbedInstance; - const clientOptions: ClientOptions = { + const regionalClientOptions: ClientOptions = { + kind: 'regional', projectId: 'test-project', location: 'us-central1', - authClient: {} as GoogleAuth, // Will be replaced + authClient: {} as GoogleAuth, }; + + const globalClientOptions: ClientOptions = { + kind: 'global', + projectId: 'test-project', + location: 'global', + authClient: {} as GoogleAuth, + apiKey: 'test-global-api-key', + }; + + const expressClientOptions: ClientOptions = { + kind: 'express', + apiKey: 'test-express-api-key', + }; + let embedderFunc: ( input: Document[], options?: EmbeddingConfig @@ -47,7 +63,8 @@ describe('defineEmbedder', () => { authMock = sinon.createStubInstance(GoogleAuth); authMock.getAccessToken.resolves('test-token'); - clientOptions.authClient = authMock as unknown as GoogleAuth; + regionalClientOptions.authClient = authMock as unknown as GoogleAuth; + globalClientOptions.authClient = authMock as unknown as GoogleAuth; mockGenkit.defineEmbedder.callsFake((config, func) => { embedderFunc = func; @@ -70,127 +87,184 @@ describe('defineEmbedder', () => { fetchStub.resolves(response); } - it('defines an embedder with the correct name and info for known model', () => { - defineEmbedder(mockGenkit, 'text-embedding-005', clientOptions); - sinon.assert.calledOnce(mockGenkit.defineEmbedder); - const args = mockGenkit.defineEmbedder.lastCall.args[0]; - assert.strictEqual(args.name, 'vertexai/text-embedding-005'); - assert.strictEqual(args.info?.dimensions, 768); - }); - - it('defines an embedder with a custom name', () => { - defineEmbedder(mockGenkit, 'custom-model', clientOptions); - sinon.assert.calledOnce(mockGenkit.defineEmbedder); - const args = mockGenkit.defineEmbedder.lastCall.args[0]; - assert.strictEqual(args.name, 'vertexai/custom-model'); - }); - - describe('Embedder Functionality', () => { - const testDoc1: Document = new Document({ content: [{ text: 'Hello' }] }); - const testDoc2: Document = new Document({ content: [{ text: 'World' }] }); - - it('calls embedContent with text-only documents', async () => { - defineEmbedder(mockGenkit, 'text-embedding-005', clientOptions); - - const mockResponse: EmbedContentResponse = { - predictions: [ - { - embeddings: { - values: [0.1, 0.2], - statistics: { token_count: 1, truncated: false }, - }, - }, - { - embeddings: { - values: [0.3, 0.4], - statistics: { token_count: 1, truncated: false }, - }, - }, - ], - }; - mockFetchResponse(mockResponse); - - const result = await embedderFunc([testDoc1, testDoc2]); - - sinon.assert.calledOnce(fetchStub); - const fetchArgs = fetchStub.lastCall.args; - const expectedUrl = - 'https://us-central1-aiplatform.googleapis.com/v1beta1/projects/test-project/locations/us-central1/publishers/google/models/text-embedding-005:predict'; - assert.strictEqual(fetchArgs[0], expectedUrl); - - // Corrected expectedRequest: Keys with undefined values are omitted by JSON.stringify - const expectedRequest = { - instances: [{ content: 'Hello' }, { content: 'World' }], - parameters: {}, // outputDimensionality is undefined, so key is omitted - }; - assert.deepStrictEqual(JSON.parse(fetchArgs[1].body), expectedRequest); + function getExpectedHeaders( + clientOptions: ClientOptions + ): Record { + const headers: Record = { + 'Content-Type': 'application/json', + 'X-Goog-Api-Client': GENKIT_CLIENT_HEADER, + 'User-Agent': GENKIT_CLIENT_HEADER, + }; + if (clientOptions.kind !== 'express') { + headers['Authorization'] = 'Bearer test-token'; + headers['x-goog-user-project'] = clientOptions.projectId; + } + return headers; + } - assert.deepStrictEqual(result, { - embeddings: [{ embedding: [0.1, 0.2] }, { embedding: [0.3, 0.4] }], + function runTestsForClientOptions(clientOptions: ClientOptions) { + describe(`with ${clientOptions.kind} client options`, () => { + it('defines an embedder with the correct name and info for known model', () => { + defineEmbedder(mockGenkit, 'text-embedding-005', clientOptions); + sinon.assert.calledOnce(mockGenkit.defineEmbedder); + const args = mockGenkit.defineEmbedder.lastCall.args[0]; + assert.strictEqual(args.name, 'vertexai/text-embedding-005'); + assert.strictEqual(args.info?.dimensions, 768); }); - }); - - it('calls embedContent with taskType and title options', async () => { - defineEmbedder(mockGenkit, 'text-embedding-005', clientOptions); - mockFetchResponse({ predictions: [] }); - - const config: EmbeddingConfig = { - taskType: 'RETRIEVAL_DOCUMENT', - title: 'Doc Title', - }; - await embedderFunc([testDoc1], config); - - sinon.assert.calledOnce(fetchStub); - const fetchOptions = fetchStub.lastCall.args[1]; - const body = JSON.parse(fetchOptions.body); - assert.strictEqual(body.instances[0].task_type, 'RETRIEVAL_DOCUMENT'); - assert.strictEqual(body.instances[0].title, 'Doc Title'); - }); - it('handles multimodal embeddings for images (base64)', async () => { - defineEmbedder(mockGenkit, 'multimodalembedding@001', clientOptions); - const docWithImage: Document = new Document({ - content: [ - { text: 'A picture' }, - { media: { url: 'base64data', contentType: 'image/png' } }, - ], + it('defines an embedder with a custom name', () => { + defineEmbedder(mockGenkit, 'custom-model', clientOptions); + sinon.assert.calledOnce(mockGenkit.defineEmbedder); + const args = mockGenkit.defineEmbedder.lastCall.args[0]; + assert.strictEqual(args.name, 'vertexai/custom-model'); }); - const mockResponse: EmbedContentResponse = { - predictions: [{ textEmbedding: [0.1], imageEmbedding: [0.2] }], - }; - mockFetchResponse(mockResponse); - - const result = await embedderFunc([docWithImage]); - - const expectedInstance: EmbeddingInstance = { - text: 'A picture', - image: { bytesBase64Encoded: 'base64data', mimeType: 'image/png' }, - }; - const fetchBody = JSON.parse(fetchStub.lastCall.args[1].body); - assert.deepStrictEqual(fetchBody.instances[0], expectedInstance); - assert.deepStrictEqual(result.embeddings.length, 2); - }); - - it('handles multimodal embeddings for images (gcs)', async () => { - defineEmbedder(mockGenkit, 'multimodalembedding@001', clientOptions); - const docWithImage: Document = new Document({ - content: [ - { - media: { url: 'gs://bucket/image.jpg', contentType: 'image/jpeg' }, - }, - ], + describe('Embedder Functionality', () => { + const testDoc1: Document = new Document({ + content: [{ text: 'Hello' }], + }); + const testDoc2: Document = new Document({ + content: [{ text: 'World' }], + }); + + it('calls embedContent with text-only documents', async () => { + defineEmbedder(mockGenkit, 'text-embedding-005', clientOptions); + + const mockResponse: EmbedContentResponse = { + predictions: [ + { + embeddings: { + values: [0.1, 0.2], + statistics: { token_count: 1, truncated: false }, + }, + }, + { + embeddings: { + values: [0.3, 0.4], + statistics: { token_count: 1, truncated: false }, + }, + }, + ], + }; + mockFetchResponse(mockResponse); + + const result = await embedderFunc([testDoc1, testDoc2]); + + sinon.assert.calledOnce(fetchStub); + const fetchArgs = fetchStub.lastCall.args; + const expectedUrl = getVertexAIUrl({ + includeProjectAndLocation: true, + resourcePath: 'publishers/google/models/text-embedding-005', + resourceMethod: 'predict', + clientOptions, + }); + assert.strictEqual(fetchArgs[0], expectedUrl); + + const expectedRequest = { + instances: [{ content: 'Hello' }, { content: 'World' }], + parameters: {}, // Undefined properties are omitted + }; + assert.deepStrictEqual( + JSON.parse(fetchArgs[1].body), + expectedRequest + ); + assert.deepStrictEqual( + fetchArgs[1].headers, + getExpectedHeaders(clientOptions) + ); + + assert.deepStrictEqual(result, { + embeddings: [{ embedding: [0.1, 0.2] }, { embedding: [0.3, 0.4] }], + }); + }); + + it('calls embedContent with taskType and title options', async () => { + defineEmbedder(mockGenkit, 'text-embedding-005', clientOptions); + mockFetchResponse({ predictions: [] }); + + const config: EmbeddingConfig = { + taskType: 'RETRIEVAL_DOCUMENT', + title: 'Doc Title', + }; + await embedderFunc([testDoc1], config); + + sinon.assert.calledOnce(fetchStub); + const fetchOptions = fetchStub.lastCall.args[1]; + const body = JSON.parse(fetchOptions.body); + assert.strictEqual(body.instances[0].task_type, 'RETRIEVAL_DOCUMENT'); + assert.strictEqual(body.instances[0].title, 'Doc Title'); + }); + + it('handles multimodal embeddings for images (base64)', async () => { + defineEmbedder(mockGenkit, 'multimodalembedding@001', clientOptions); + const docWithImage: Document = new Document({ + content: [ + { text: 'A picture' }, + { media: { url: 'base64data', contentType: 'image/png' } }, + ], + }); + + const mockResponse: EmbedContentResponse = { + predictions: [{ textEmbedding: [0.1], imageEmbedding: [0.2] }], + }; + mockFetchResponse(mockResponse); + + const result = await embedderFunc([docWithImage]); + + const expectedInstance: EmbeddingInstance = { + text: 'A picture', + image: { bytesBase64Encoded: 'base64data', mimeType: 'image/png' }, + }; + const fetchBody = JSON.parse(fetchStub.lastCall.args[1].body); + assert.deepStrictEqual(fetchBody.instances[0], expectedInstance); + assert.deepStrictEqual(result.embeddings.length, 2); + }); + + it('handles multimodal embeddings for images (gcs)', async () => { + defineEmbedder(mockGenkit, 'multimodalembedding@001', clientOptions); + const docWithImage: Document = new Document({ + content: [ + { + media: { + url: 'gs://bucket/image.jpg', + contentType: 'image/jpeg', + }, + }, + ], + }); + mockFetchResponse({ predictions: [] }); + await embedderFunc([docWithImage]); + + const expectedInstance: EmbeddingInstance = { + image: { gcsUri: 'gs://bucket/image.jpg', mimeType: 'image/jpeg' }, + }; + const fetchBody = JSON.parse(fetchStub.lastCall.args[1].body); + assert.deepStrictEqual(fetchBody.instances[0], expectedInstance); + }); + + it('passes outputDimensionality to the API call', async () => { + defineEmbedder(mockGenkit, 'text-embedding-005', clientOptions); + mockFetchResponse({ predictions: [] }); + + const config: EmbeddingConfig = { outputDimensionality: 256 }; + await embedderFunc([testDoc1], config); + + sinon.assert.calledOnce(fetchStub); + const fetchOptions = fetchStub.lastCall.args[1]; + const body = JSON.parse(fetchOptions.body); + assert.strictEqual(body.parameters.outputDimensionality, 256); + }); }); - mockFetchResponse({ predictions: [] }); - await embedderFunc([docWithImage]); - - const expectedInstance: EmbeddingInstance = { - image: { gcsUri: 'gs://bucket/image.jpg', mimeType: 'image/jpeg' }, - }; - const fetchBody = JSON.parse(fetchStub.lastCall.args[1].body); - assert.deepStrictEqual(fetchBody.instances[0], expectedInstance); }); + } + + runTestsForClientOptions(regionalClientOptions); + runTestsForClientOptions(globalClientOptions); + runTestsForClientOptions(expressClientOptions); + // Tests specific to regional (or not applicable to express) + describe('with regional client options only', () => { + const clientOptions = regionalClientOptions; it('handles multimodal embeddings for video', async () => { defineEmbedder(mockGenkit, 'multimodalembedding@001', clientOptions); const docWithVideo: Document = new Document({ @@ -259,19 +333,6 @@ describe('defineEmbedder', () => { }); }); - it('passes outputDimensionality to the API call', async () => { - defineEmbedder(mockGenkit, 'text-embedding-005', clientOptions); - mockFetchResponse({ predictions: [] }); - - const config: EmbeddingConfig = { outputDimensionality: 256 }; - await embedderFunc([testDoc1], config); - - sinon.assert.calledOnce(fetchStub); - const fetchOptions = fetchStub.lastCall.args[1]; - const body = JSON.parse(fetchOptions.body); - assert.strictEqual(body.parameters.outputDimensionality, 256); - }); - it('throws on unsupported media type', async () => { defineEmbedder(mockGenkit, 'multimodalembedding@001', clientOptions); const docWithInvalidMedia: Document = new Document({ diff --git a/js/plugins/google-genai/tests/vertexai/gemini_test.ts b/js/plugins/google-genai/tests/vertexai/gemini_test.ts index 86b542f7b7..9ce04463ae 100644 --- a/js/plugins/google-genai/tests/vertexai/gemini_test.ts +++ b/js/plugins/google-genai/tests/vertexai/gemini_test.ts @@ -15,15 +15,16 @@ */ import * as assert from 'assert'; -import { Genkit, z } from 'genkit'; +import { Genkit, GENKIT_CLIENT_HEADER, z } from 'genkit'; import { GenerateRequest, ModelReference } from 'genkit/model'; +import { GoogleAuth } from 'google-auth-library'; import { AsyncLocalStorage } from 'node:async_hooks'; import { afterEach, beforeEach, describe, it } from 'node:test'; import * as sinon from 'sinon'; import { FinishReason } from '../../src/common/types'; import { - GeminiConfigSchema, defineModel, + GeminiConfigSchema, model, } from '../../src/vertexai/gemini'; import { @@ -32,6 +33,9 @@ import { GenerateContentResponse, HarmBlockThreshold, HarmCategory, + isFunctionDeclarationsTool, + isGoogleSearchRetrievalTool, + isRetrievalTool, } from '../../src/vertexai/types'; describe('Vertex AI Gemini', () => { @@ -43,37 +47,45 @@ describe('Vertex AI Gemini', () => { let fetchStub: sinon.SinonStub; let mockAsyncStore: sinon.SinonStubbedInstance>; + let authMock: sinon.SinonStubbedInstance; - const defaultClientOptions: ClientOptions = { + const defaultRegionalClientOptions: ClientOptions = { + kind: 'regional', projectId: 'test-project', location: 'us-central1', - authClient: { - getAccessToken: async () => 'test-token', - } as any, // Mock auth client + authClient: {} as any, + }; + + const defaultGlobalClientOptions: ClientOptions = { + kind: 'global', + projectId: 'test-project', + location: 'global', + authClient: {} as any, + apiKey: 'test-api-key', + }; + + const defaultExpressClientOptions: ClientOptions = { + kind: 'express', + apiKey: 'test-express-api-key', }; beforeEach(() => { mockGenkit = sinon.createStubInstance(Genkit); mockAsyncStore = sinon.createStubInstance(AsyncLocalStorage); + authMock = sinon.createStubInstance(GoogleAuth); - // Setup mock registry and asyncStore - mockAsyncStore.getStore.returns(undefined); // Simulate no parent span + authMock.getAccessToken.resolves('test-token'); + defaultRegionalClientOptions.authClient = authMock as unknown as GoogleAuth; + defaultGlobalClientOptions.authClient = authMock as unknown as GoogleAuth; - mockAsyncStore.run.callsFake((arg1, arg2, callback) => { - if (typeof callback === 'function') { - return callback(); - } - // Fallback or error if the structure isn't as expected - throw new Error( - 'AsyncLocalStorage.run mock expected a function as the third argument' - ); - }); + mockAsyncStore.getStore.returns(undefined); + mockAsyncStore.run.callsFake((_, callback) => callback()); (mockGenkit as any).registry = { lookupAction: () => undefined, lookupFlow: () => undefined, generateTraceId: () => 'test-trace-id', - asyncStore: mockAsyncStore, // Provide the mock asyncStore + asyncStore: mockAsyncStore, }; fetchStub = sinon.stub(global, 'fetch'); @@ -88,7 +100,6 @@ describe('Vertex AI Gemini', () => { sinon.restore(); }); - // Mock fetch for non-streaming responses function mockFetchResponse(body: any, status = 200) { const response = new Response(JSON.stringify(body), { status: status, @@ -98,7 +109,6 @@ describe('Vertex AI Gemini', () => { fetchStub.resolves(Promise.resolve(response)); } - // Mock fetch for streaming responses (SSE) function mockFetchStreamResponse(responses: GenerateContentResponse[]) { const encoder = new TextEncoder(); const stream = new ReadableStream({ @@ -121,7 +131,7 @@ describe('Vertex AI Gemini', () => { const minimalRequest: GenerateRequest = { messages: [{ role: 'user', content: [{ text: 'Hello' }] }], - config: {}, // Add empty config + config: {}, }; const mockCandidate = { @@ -147,7 +157,7 @@ describe('Vertex AI Gemini', () => { const name = 'gemini-new-model'; const modelRef: ModelReference = model(name); assert.strictEqual(modelRef.name, `vertexai/${name}`); - assert.ok(modelRef.info?.supports?.multiturn); // Defaults to generic + assert.ok(modelRef.info?.supports?.multiturn); assert.strictEqual(modelRef.configSchema, GeminiConfigSchema); }); @@ -161,25 +171,67 @@ describe('Vertex AI Gemini', () => { }); }); - describe('defineGeminiModel', () => { - it('defines a model with the correct name', () => { - defineModel(mockGenkit, 'gemini-2.0-flash', defaultClientOptions); - sinon.assert.calledOnce(mockGenkit.defineModel); - const args = mockGenkit.defineModel.lastCall.args[0]; - assert.strictEqual(args.name, 'vertexai/gemini-2.0-flash'); - }); - - it('defines a model with a custom name', () => { - defineModel(mockGenkit, 'my-custom-gemini', defaultClientOptions); - const args = mockGenkit.defineModel.lastCall.args[0]; - assert.strictEqual(args.name, 'vertexai/my-custom-gemini'); - }); - - describe('Model Action Callback', () => { + function runCommonTests(clientOptions: ClientOptions) { + describe(`Model Action Callback ${clientOptions.kind}`, () => { beforeEach(() => { - defineModel(mockGenkit, 'gemini-2.5-flash', defaultClientOptions); + defineModel(mockGenkit, 'gemini-2.5-flash', clientOptions); }); + function getExpectedHeaders(): Record { + const headers: Record = { + 'Content-Type': 'application/json', + 'X-Goog-Api-Client': GENKIT_CLIENT_HEADER, + 'User-Agent': GENKIT_CLIENT_HEADER, + }; + if (clientOptions.kind !== 'express') { + headers['Authorization'] = 'Bearer test-token'; + headers['x-goog-user-project'] = clientOptions.projectId; + } + return headers; + } + + function getExpectedUrl( + modelName: string, + method: string, + queryParams: string[] = [], + configApiKey?: string + ): string { + let baseUrl: string; + let projectAndLocation = ''; + if (clientOptions.kind != 'express') { + projectAndLocation = `projects/${clientOptions.projectId}/locations/${clientOptions.location}`; + } + + if (clientOptions.kind === 'regional') { + baseUrl = `https://${clientOptions.location}-aiplatform.googleapis.com/v1beta1/${projectAndLocation}`; + } else if (clientOptions.kind === 'global') { + baseUrl = `https://aiplatform.googleapis.com/v1beta1/${projectAndLocation}`; + } else { + // express + baseUrl = `https://aiplatform.googleapis.com/v1beta1`; + } + + let url = `${baseUrl}/publishers/google/models/${modelName}:${method}`; + const params = [...queryParams]; + let effectiveApiKey = configApiKey; + if (!effectiveApiKey) { + if (clientOptions.kind === 'express') { + effectiveApiKey = clientOptions.apiKey as string; + } else { + effectiveApiKey = clientOptions.apiKey; + } + } + + if (effectiveApiKey) { + params.push(`key=${effectiveApiKey}`); + } + + if (params.length > 0) { + url += '?' + params.join('&'); + } + return url; + } + it('throws if no messages are provided', async () => { await assert.rejects( modelActionCallback({ messages: [], config: {} }), @@ -196,16 +248,19 @@ describe('Vertex AI Gemini', () => { const url = fetchArgs[0]; const options = fetchArgs[1]; - assert.ok( - url.includes( - 'us-central1-aiplatform.googleapis.com/v1beta1/projects/test-project/locations/us-central1/publishers/google/models/gemini-2.5-flash:generateContent' - ) + const expectedUrl = getExpectedUrl( + 'gemini-2.5-flash', + 'generateContent' ); + assert.strictEqual(url, expectedUrl); assert.strictEqual(options.method, 'POST'); const body = JSON.parse(options.body); assert.deepStrictEqual(body.contents, [ { role: 'user', parts: [{ text: 'Hello' }] }, ]); + assert.strictEqual(body.labels, undefined); + + assert.deepStrictEqual(options.headers, getExpectedHeaders()); assert.strictEqual(result.candidates.length, 1); assert.strictEqual( @@ -223,12 +278,12 @@ describe('Vertex AI Gemini', () => { sinon.assert.calledOnce(fetchStub); const fetchArgs = fetchStub.lastCall.args; const url = fetchArgs[0]; - assert.ok( - url.includes( - 'us-central1-aiplatform.googleapis.com/v1beta1/projects/test-project/locations/us-central1/publishers/google/models/gemini-2.5-flash:streamGenerateContent' - ) + const expectedUrl = getExpectedUrl( + 'gemini-2.5-flash', + 'streamGenerateContent', + ['alt=sse'] ); - assert.ok(url.includes('alt=sse')); + assert.strictEqual(url, expectedUrl); await new Promise((resolve) => setTimeout(resolve, 10)); @@ -279,6 +334,21 @@ describe('Vertex AI Gemini', () => { assert.strictEqual(apiRequest.generationConfig?.maxOutputTokens, 100); }); + it('sends labels when provided in config', async () => { + mockFetchResponse(defaultApiResponse); + const myLabels = { env: 'test', version: '1' }; + const request: GenerateRequest = { + ...minimalRequest, + config: { labels: myLabels }, + }; + await modelActionCallback(request); + + const apiRequest: GenerateContentRequest = JSON.parse( + fetchStub.lastCall.args[1].body + ); + assert.deepStrictEqual(apiRequest.labels, myLabels); + }); + it('constructs tools array with functionDeclarations', async () => { mockFetchResponse(defaultApiResponse); const request: GenerateRequest = { @@ -300,38 +370,19 @@ describe('Vertex AI Gemini', () => { ); assert.ok(Array.isArray(apiRequest.tools)); assert.strictEqual(apiRequest.tools?.length, 1); - assert.ok(apiRequest.tools?.[0].functionDeclarations); - assert.strictEqual( - apiRequest.tools?.[0].functionDeclarations?.length, - 1 - ); - assert.strictEqual( - apiRequest.tools?.[0].functionDeclarations?.[0].name, - 'myFunc' - ); - }); - - it('handles googleSearchRetrieval tool for gemini-1.5', async () => { - defineModel(mockGenkit, 'gemini-1.5-pro', defaultClientOptions); - mockFetchResponse(defaultApiResponse); - const request: GenerateRequest = { - ...minimalRequest, - config: { - googleSearchRetrieval: {}, - }, - }; - await modelActionCallback(request); - const apiRequest: GenerateContentRequest = JSON.parse( - fetchStub.lastCall.args[1].body - ); - const searchTool = apiRequest.tools?.find( - (t) => t.googleSearchRetrieval + const tool = apiRequest.tools![0]; + assert.ok( + isFunctionDeclarationsTool(tool), + 'Expected FunctionDeclarationsTool' ); - assert.ok(searchTool, 'Expected googleSearchRetrieval tool'); - assert.deepStrictEqual(searchTool, { googleSearchRetrieval: {} }); + if (isFunctionDeclarationsTool(tool)) { + assert.ok(tool.functionDeclarations); + assert.strictEqual(tool.functionDeclarations?.length, 1); + assert.strictEqual(tool.functionDeclarations?.[0].name, 'myFunc'); + } }); - it('handles googleSearchRetrieval tool for other models (as googleSearch)', async () => { + it('handles googleSearchRetrieval tool (as googleSearch)', async () => { mockFetchResponse(defaultApiResponse); const request: GenerateRequest = { ...minimalRequest, @@ -343,38 +394,46 @@ describe('Vertex AI Gemini', () => { const apiRequest: GenerateContentRequest = JSON.parse( fetchStub.lastCall.args[1].body ); - const searchTool = apiRequest.tools?.find((t) => t.googleSearch); - assert.ok(searchTool, 'Expected googleSearch tool'); - assert.deepStrictEqual(searchTool, { googleSearch: {} }); + const searchTool = apiRequest.tools?.find(isGoogleSearchRetrievalTool); + assert.ok(searchTool, 'Expected GoogleSearchRetrievalTool'); + if (searchTool) { + assert.ok(searchTool.googleSearch, 'Expected googleSearch property'); + assert.deepStrictEqual(searchTool, { googleSearch: {} }); + } }); - it('handles vertexRetrieval tool', async () => { - mockFetchResponse(defaultApiResponse); - const request: GenerateRequest = { - ...minimalRequest, - config: { - vertexRetrieval: { - datastore: { dataStoreId: 'my-store' }, - disableAttribution: true, - }, - }, - }; - await modelActionCallback(request); - const apiRequest: GenerateContentRequest = JSON.parse( - fetchStub.lastCall.args[1].body - ); - const retrievalTool = apiRequest.tools?.find((t) => t.retrieval); - assert.ok(retrievalTool, 'Expected vertexRetrieval tool'); - assert.deepStrictEqual(retrievalTool, { - retrieval: { - vertexAiSearch: { - datastore: - 'projects/test-project/locations/us-central1/collections/default_collection/dataStores/my-store', + if (clientOptions.kind === 'regional') { + it('handles vertexRetrieval tool', async () => { + mockFetchResponse(defaultApiResponse); + const request: GenerateRequest = { + ...minimalRequest, + config: { + vertexRetrieval: { + datastore: { dataStoreId: 'my-store' }, + disableAttribution: true, + }, }, - disableAttribution: true, - }, + }; + await modelActionCallback(request); + const apiRequest: GenerateContentRequest = JSON.parse( + fetchStub.lastCall.args[1].body + ); + const retrievalTool = apiRequest.tools?.find(isRetrievalTool); + assert.ok(retrievalTool, 'Expected RetrievalTool'); + if (retrievalTool) { + assert.ok(retrievalTool.retrieval, 'Expected retrieval property'); + assert.deepStrictEqual(retrievalTool, { + retrieval: { + vertexAiSearch: { + datastore: + 'projects/test-project/locations/us-central1/collections/default_collection/dataStores/my-store', + }, + disableAttribution: true, + }, + }); + } }); - }); + } it('applies safetySettings', async () => { mockFetchResponse(defaultApiResponse); @@ -427,36 +486,91 @@ describe('Vertex AI Gemini', () => { }); it('handles API call error', async () => { - fetchStub.rejects(new Error('API Error')); + mockFetchResponse({ error: { message: 'API Error' } }, 400); await assert.rejects( modelActionCallback(minimalRequest), - /Failed to fetch from https:\/\/us-central1-aiplatform.googleapis.com\/v1beta1\/projects\/test-project\/locations\/us-central1\/publishers\/google\/models\/gemini-2.5-flash:generateContent: API Error/ + /Error fetching from .*?: \[400 Error\] API Error/ ); }); - }); - describe('Debug Traces', () => { - it('API call works with debugTraces: true', async () => { - defineModel(mockGenkit, 'gemini-2.5-flash', defaultClientOptions, { - location: 'us-central1', - experimental_debugTraces: true, - }); + it('handles config.apiKey override in URL', async () => { mockFetchResponse(defaultApiResponse); - - await assert.doesNotReject(modelActionCallback(minimalRequest)); + const overrideKey = 'override-api-key'; + const request: GenerateRequest = { + ...minimalRequest, + config: { apiKey: overrideKey }, + }; + await modelActionCallback(request); sinon.assert.calledOnce(fetchStub); + const fetchArgs = fetchStub.lastCall.args; + const url = fetchArgs[0]; + + const expectedUrl = getExpectedUrl( + 'gemini-2.5-flash', + 'generateContent', + [], + overrideKey + ); + assert.strictEqual(url, expectedUrl); + assert.deepStrictEqual(fetchArgs[1].headers, getExpectedHeaders()); }); + }); + } - it('API call works without extra logging with debugTraces: false', async () => { - defineModel(mockGenkit, 'gemini-2.0-flash', defaultClientOptions, { - location: 'us-central1', - experimental_debugTraces: false, - }); - mockFetchResponse(defaultApiResponse); + describe('defineModel - Regional Client', () => { + it('defines a model with the correct name', () => { + defineModel(mockGenkit, 'gemini-2.0-flash', defaultRegionalClientOptions); + sinon.assert.calledOnce(mockGenkit.defineModel); + const args = mockGenkit.defineModel.lastCall.args[0]; + assert.strictEqual(args.name, 'vertexai/gemini-2.0-flash'); + }); - await assert.doesNotReject(modelActionCallback(minimalRequest)); - sinon.assert.calledOnce(fetchStub); - }); + runCommonTests(defaultRegionalClientOptions); + + it('handles googleSearchRetrieval tool for gemini-1.5', async () => { + defineModel(mockGenkit, 'gemini-1.5-pro', defaultRegionalClientOptions); + mockFetchResponse(defaultApiResponse); + const request: GenerateRequest = { + ...minimalRequest, + config: { + googleSearchRetrieval: {}, + }, + }; + await modelActionCallback(request); + const apiRequest: GenerateContentRequest = JSON.parse( + fetchStub.lastCall.args[1].body + ); + const searchTool = apiRequest.tools?.find(isGoogleSearchRetrievalTool); + assert.ok(searchTool, 'Expected GoogleSearchRetrievalTool'); + if (searchTool) { + assert.ok( + searchTool.googleSearchRetrieval, + 'Expected googleSearchRetrieval property' + ); + assert.deepStrictEqual(searchTool, { googleSearchRetrieval: {} }); + } }); }); + + describe('defineModel - Global Client', () => { + it('defines a model with the correct name', () => { + defineModel(mockGenkit, 'gemini-2.0-flash', defaultGlobalClientOptions); + sinon.assert.calledOnce(mockGenkit.defineModel); + const args = mockGenkit.defineModel.lastCall.args[0]; + assert.strictEqual(args.name, 'vertexai/gemini-2.0-flash'); + }); + + runCommonTests(defaultGlobalClientOptions); + }); + + describe('defineModel - Express Client', () => { + it('defines a model with the correct name', () => { + defineModel(mockGenkit, 'gemini-2.0-flash', defaultExpressClientOptions); + sinon.assert.calledOnce(mockGenkit.defineModel); + const args = mockGenkit.defineModel.lastCall.args[0]; + assert.strictEqual(args.name, 'vertexai/gemini-2.0-flash'); + }); + + runCommonTests(defaultExpressClientOptions); + }); }); diff --git a/js/plugins/google-genai/tests/vertexai/imagen_test.ts b/js/plugins/google-genai/tests/vertexai/imagen_test.ts index 06e982b095..da6e56aaed 100644 --- a/js/plugins/google-genai/tests/vertexai/imagen_test.ts +++ b/js/plugins/google-genai/tests/vertexai/imagen_test.ts @@ -15,8 +15,9 @@ */ import * as assert from 'assert'; -import { Genkit } from 'genkit'; +import { GENKIT_CLIENT_HEADER, Genkit } from 'genkit'; import { GenerateRequest, getBasicUsageStats } from 'genkit/model'; +import { GoogleAuth } from 'google-auth-library'; import { afterEach, beforeEach, describe, it } from 'node:test'; import * as sinon from 'sinon'; import { getVertexAIUrl } from '../../src/vertexai/client'; @@ -37,6 +38,11 @@ import * as utils from '../../src/vertexai/utils'; const { toImagenParameters, fromImagenPrediction } = TEST_ONLY; +// Helper function to escape special characters for use in a RegExp +function escapeRegExp(string) { + return string.replace(/[.*+?^${}()|[\]\\]/g, '\\$&'); // $& means the whole matched string +} + describe('Vertex AI Imagen', () => { describe('KNOWN_IMAGEN_MODELS', () => { it('should contain non-zero number of models', () => { @@ -165,24 +171,36 @@ describe('Vertex AI Imagen', () => { describe('defineImagenModel()', () => { let mockAi: sinon.SinonStubbedInstance; let fetchStub: sinon.SinonStub; - const clientOptions: ClientOptions = { + const modelName = 'imagen-test-model'; + let authMock: sinon.SinonStubbedInstance; + + const regionalClientOptions: ClientOptions = { + kind: 'regional', projectId: 'test-project', location: 'us-central1', - authClient: { - getAccessToken: async () => 'test-token', - } as any, + authClient: {} as any, + }; + + const globalClientOptions: ClientOptions = { + kind: 'global', + projectId: 'test-project', + location: 'global', + authClient: {} as any, + apiKey: 'test-api-key', + }; + + const expressClientOptions: ClientOptions = { + kind: 'express', + apiKey: 'test-express-api-key', }; - const modelName = 'imagen-test-model'; - const expectedUrl = getVertexAIUrl({ - includeProjectAndLocation: true, - resourcePath: `publishers/google/models/${modelName}`, - resourceMethod: 'predict', - clientOptions, - }); beforeEach(() => { mockAi = sinon.createStubInstance(Genkit); fetchStub = sinon.stub(global, 'fetch'); + authMock = sinon.createStubInstance(GoogleAuth); + authMock.getAccessToken.resolves('test-token'); + regionalClientOptions.authClient = authMock as unknown as GoogleAuth; + globalClientOptions.authClient = authMock as unknown as GoogleAuth; }); afterEach(() => { @@ -198,7 +216,9 @@ describe('Vertex AI Imagen', () => { fetchStub.resolves(Promise.resolve(response)); } - function captureModelRunner(): (request: GenerateRequest) => Promise { + function captureModelRunner( + clientOptions: ClientOptions + ): (request: GenerateRequest) => Promise { defineModel(mockAi as any, modelName, clientOptions); assert.ok(mockAi.defineModel.calledOnce); const callArgs = mockAi.defineModel.firstCall.args; @@ -207,103 +227,155 @@ describe('Vertex AI Imagen', () => { return callArgs[1]; } - it('should define a model and call fetch successfully', async () => { - const request: GenerateRequest = { - messages: [{ role: 'user', content: [{ text: 'A cat' }] }], - candidates: 2, - config: { seed: 42 }, - }; - - const mockPrediction: ImagenPrediction = { - bytesBase64Encoded: 'abc', - mimeType: 'image/png', - }; - const mockResponse: ImagenPredictResponse = { - predictions: [mockPrediction, mockPrediction], + function getExpectedHeaders( + clientOptions: ClientOptions + ): Record { + const headers: Record = { + 'Content-Type': 'application/json', + 'X-Goog-Api-Client': GENKIT_CLIENT_HEADER, + 'User-Agent': GENKIT_CLIENT_HEADER, }; - mockFetchResponse(mockResponse); + if (clientOptions.kind !== 'express') { + headers['Authorization'] = 'Bearer test-token'; + headers['x-goog-user-project'] = clientOptions.projectId; + } + return headers; + } - const modelRunner = captureModelRunner(); - const result = await modelRunner(request); + function runTestsForClientOptions(clientOptions: ClientOptions) { + const keySuffix = + clientOptions.kind === 'express' ? `?key=${clientOptions.apiKey}` : ''; - sinon.assert.calledOnce(fetchStub); - const fetchArgs = fetchStub.lastCall.args; - assert.strictEqual(fetchArgs[0], expectedUrl); - assert.strictEqual(fetchArgs[1].method, 'POST'); - assert.ok(fetchArgs[1].headers['Authorization'].startsWith('Bearer ')); + const expectedUrl = getVertexAIUrl({ + includeProjectAndLocation: true, + resourcePath: `publishers/google/models/${modelName}`, + resourceMethod: 'predict', + clientOptions, + }); - // Build the expected instance, only adding keys if they have values - const prompt = utils.extractText(request); - const image = utils.extractImagenImage(request); - const mask = utils.extractImagenMask(request); + it(`should define a model and call fetch successfully for ${clientOptions.kind}`, async () => { + const request: GenerateRequest = { + messages: [{ role: 'user', content: [{ text: 'A cat' }] }], + candidates: 2, + config: { seed: 42 }, + }; + + const mockPrediction: ImagenPrediction = { + bytesBase64Encoded: 'abc', + mimeType: 'image/png', + }; + const mockResponse: ImagenPredictResponse = { + predictions: [mockPrediction, mockPrediction], + }; + mockFetchResponse(mockResponse); + + const modelRunner = captureModelRunner(clientOptions); + const result = await modelRunner(request); + + sinon.assert.calledOnce(fetchStub); + const fetchArgs = fetchStub.lastCall.args; + let actualUrl = fetchArgs[0]; + if (clientOptions.kind === 'express') { + assert.ok(actualUrl.startsWith(expectedUrl.split('?')[0])); + assert.ok(actualUrl.endsWith(keySuffix)); + } else { + assert.strictEqual(actualUrl, expectedUrl); + } + assert.strictEqual(fetchArgs[1].method, 'POST'); + assert.deepStrictEqual( + fetchArgs[1].headers, + getExpectedHeaders(clientOptions) + ); + + const prompt = utils.extractText(request); + // extractImagenImage and extractImagenMask return undefined, + // so JSON.stringify will omit these keys. + const expectedInstance: any = { + prompt, + }; + const expectedImagenPredictRequest: ImagenPredictRequest = { + instances: [expectedInstance], + parameters: toImagenParameters(request), + }; + + assert.deepStrictEqual( + JSON.parse(fetchArgs[1].body), + expectedImagenPredictRequest + ); + + const expectedCandidates = mockResponse.predictions!.map((p, i) => + fromImagenPrediction(p, i) + ); + assert.deepStrictEqual(result.candidates, expectedCandidates); + assert.deepStrictEqual(result.usage, { + ...getBasicUsageStats(request.messages, expectedCandidates), + custom: { generations: 2 }, + }); + assert.deepStrictEqual(result.custom, mockResponse); + }); - const expectedInstance: any = { prompt }; - if (image !== undefined) expectedInstance.image = image; - if (mask !== undefined) expectedInstance.mask = mask; + it(`should throw an error if model returns no predictions for ${clientOptions.kind}`, async () => { + const request: GenerateRequest = { + messages: [{ role: 'user', content: [{ text: 'A dog' }] }], + }; + mockFetchResponse({ predictions: [] }); + + const modelRunner = captureModelRunner(clientOptions); + await assert.rejects( + modelRunner(request), + /Model returned no predictions/ + ); + sinon.assert.calledOnce(fetchStub); + }); - const expectedImagenPredictRequest: ImagenPredictRequest = { - instances: [expectedInstance], - parameters: toImagenParameters(request), - }; + it(`should propagate network errors from fetch for ${clientOptions.kind}`, async () => { + const request: GenerateRequest = { + messages: [{ role: 'user', content: [{ text: 'A fish' }] }], + }; + const error = new Error('Network Error'); + fetchStub.rejects(error); + + const modelRunner = captureModelRunner(clientOptions); + await assert.rejects( + modelRunner(request), + new RegExp( + `^Error: Failed to fetch from ${escapeRegExp(expectedUrl)}: Network Error` + ) + ); + }); - assert.deepStrictEqual( - JSON.parse(fetchArgs[1].body), - expectedImagenPredictRequest - ); - - const expectedCandidates = mockResponse.predictions!.map((p, i) => - fromImagenPrediction(p, i) - ); - assert.deepStrictEqual(result.candidates, expectedCandidates); - assert.deepStrictEqual(result.usage, { - ...getBasicUsageStats(request.messages, expectedCandidates), - custom: { generations: 2 }, + it(`should handle API error response for ${clientOptions.kind}`, async () => { + const request: GenerateRequest = { + messages: [{ role: 'user', content: [{ text: 'A bird' }] }], + }; + const errorMsg = 'Invalid argument'; + const errorBody = { error: { message: errorMsg, code: 400 } }; + mockFetchResponse(errorBody, 400); + + const modelRunner = captureModelRunner(clientOptions); + let expectedUrlRegex = escapeRegExp(expectedUrl); + if (clientOptions.kind === 'express') { + expectedUrlRegex = expectedUrl.split('?')[0] + '.*'; + } + await assert.rejects( + modelRunner(request), + new RegExp( + `^Error: Failed to fetch from ${expectedUrlRegex}: Error fetching from ${expectedUrlRegex}: \\[400 Error\\] ${errorMsg}` + ) + ); }); - assert.deepStrictEqual(result.custom, mockResponse); - }); + } - it('should throw an error if model returns no predictions', async () => { - const request: GenerateRequest = { - messages: [{ role: 'user', content: [{ text: 'A dog' }] }], - }; - mockFetchResponse({ predictions: [] }); - - const modelRunner = captureModelRunner(); - await assert.rejects( - modelRunner(request), - /Model returned no predictions/ - ); - sinon.assert.calledOnce(fetchStub); + describe('with RegionalClientOptions', () => { + runTestsForClientOptions(regionalClientOptions); }); - it('should propagate network errors from fetch', async () => { - const request: GenerateRequest = { - messages: [{ role: 'user', content: [{ text: 'A fish' }] }], - }; - const error = new Error('Network Error'); - fetchStub.rejects(error); - - const modelRunner = captureModelRunner(); - await assert.rejects( - modelRunner(request), - new RegExp(`Failed to fetch from ${expectedUrl}: Network Error`) - ); + describe('with GlobalClientOptions', () => { + runTestsForClientOptions(globalClientOptions); }); - it('should handle API error response', async () => { - const request: GenerateRequest = { - messages: [{ role: 'user', content: [{ text: 'A bird' }] }], - }; - const errorBody = { error: { message: 'Invalid argument', code: 400 } }; - mockFetchResponse(errorBody, 400); - - const modelRunner = captureModelRunner(); - await assert.rejects( - modelRunner(request), - new RegExp( - `Error fetching from ${expectedUrl}: \\[400 Error\\] Invalid argument` - ) - ); + describe('with ExpressClientOptions', () => { + runTestsForClientOptions(expressClientOptions); }); }); }); diff --git a/js/plugins/google-genai/tests/vertexai/index_test.ts b/js/plugins/google-genai/tests/vertexai/index_test.ts index 7a9f369b68..d022b98f83 100644 --- a/js/plugins/google-genai/tests/vertexai/index_test.ts +++ b/js/plugins/google-genai/tests/vertexai/index_test.ts @@ -16,16 +16,10 @@ import * as assert from 'assert'; import { genkit, type Genkit } from 'genkit'; +import { EmbedRequest } from 'genkit/embedder'; +import { GenerateRequest } from 'genkit/model'; import { GoogleAuth } from 'google-auth-library'; -import { - after, - afterEach, - before, - beforeEach, - describe, - it, - mock, -} from 'node:test'; +import { afterEach, beforeEach, describe, it, mock } from 'node:test'; import { TEST_ONLY as EMBEDDER_TEST_ONLY, EmbeddingConfigSchema, @@ -39,26 +33,49 @@ import { ImagenConfigSchema, } from '../../src/vertexai/imagen.js'; import { vertexAI } from '../../src/vertexai/index.js'; +import type { + ExpressClientOptions, + GlobalClientOptions, + RegionalClientOptions, +} from '../../src/vertexai/types.js'; import { TEST_ONLY as UTILS_TEST_ONLY } from '../../src/vertexai/utils.js'; describe('VertexAI Plugin', () => { - const originalMockDerivedOptions = { + const regionalMockDerivedOptions: RegionalClientOptions = { + kind: 'regional' as const, location: 'us-central1', projectId: 'test-project', - authClient: {} as GoogleAuth, + authClient: { + getAccessToken: async () => 'fake-test-token', + } as unknown as GoogleAuth, + }; + const globalMockDerivedOptions: GlobalClientOptions = { + kind: 'global' as const, + location: 'global', + projectId: 'test-project', + authClient: { + getAccessToken: async () => 'fake-test-token', + } as unknown as GoogleAuth, }; + const expressMockDerivedOptions: ExpressClientOptions = { + kind: 'express' as const, + apiKey: 'test-express-api-key', + }; + + let ai: Genkit; - before(() => { - UTILS_TEST_ONLY.setMockDerivedOptions(originalMockDerivedOptions); + // Default to regional options for most tests + beforeEach(() => { + UTILS_TEST_ONLY.setMockDerivedOptions(regionalMockDerivedOptions); + ai = genkit({ plugins: [vertexAI()] }); }); - after(() => { + afterEach(() => { UTILS_TEST_ONLY.setMockDerivedOptions(undefined as any); }); describe('Initializer', () => { it('should pre-register flagship Gemini models', async () => { - const ai = genkit({ plugins: [vertexAI()] }); const model1Name = Object.keys(GEMINI_TEST_ONLY.KNOWN_MODELS)[0]; const model1Path = `/model/vertexai/${model1Name}`; const expectedBaseName = `vertexai/${model1Name}`; @@ -68,7 +85,6 @@ describe('VertexAI Plugin', () => { }); it('should register all known Gemini models', async () => { - const ai = genkit({ plugins: [vertexAI()] }); for (const modelName in GEMINI_TEST_ONLY.KNOWN_MODELS) { const modelPath = `/model/vertexai/${modelName}`; const expectedBaseName = `vertexai/${modelName}`; @@ -79,7 +95,6 @@ describe('VertexAI Plugin', () => { }); it('should pre-register flagship Imagen models', async () => { - const ai = genkit({ plugins: [vertexAI()] }); const modelKeys = Object.keys(IMAGEN_TEST_ONLY.KNOWN_MODELS); if (modelKeys.length > 0) { const model1Name = modelKeys[0]; @@ -95,7 +110,6 @@ describe('VertexAI Plugin', () => { }); it('should register all known Imagen models', async () => { - const ai = genkit({ plugins: [vertexAI()] }); for (const modelName in IMAGEN_TEST_ONLY.KNOWN_MODELS) { const modelPath = `/model/vertexai/${modelName}`; const expectedBaseName = `vertexai/${modelName}`; @@ -106,7 +120,6 @@ describe('VertexAI Plugin', () => { }); it('should pre-register flagship Embedder models', async () => { - const ai = genkit({ plugins: [vertexAI()] }); const modelKeys = Object.keys(EMBEDDER_TEST_ONLY.KNOWN_MODELS); if (modelKeys.length > 0) { const model1Name = modelKeys[0]; @@ -122,7 +135,6 @@ describe('VertexAI Plugin', () => { }); it('should register all known Embedder models', async () => { - const ai = genkit({ plugins: [vertexAI()] }); for (const modelName in EMBEDDER_TEST_ONLY.KNOWN_MODELS) { const modelPath = `/embedder/vertexai/${modelName}`; const expectedBaseName = `vertexai/${modelName}`; @@ -142,21 +154,18 @@ describe('VertexAI Plugin', () => { const testEmbedderPath = `/embedder/vertexai/${testEmbedderName}`; it('should register a new Gemini model when looked up', async () => { - const ai = genkit({ plugins: [vertexAI()] }); const model = await ai.registry.lookupAction(testModelPath); assert.ok(model, `${testModelName} should be resolvable and registered`); assert.strictEqual(model?.__action.name, `vertexai/${testModelName}`); }); it('should register a new Imagen model when looked up', async () => { - const ai = genkit({ plugins: [vertexAI()] }); const model = await ai.registry.lookupAction(testImagenPath); assert.ok(model, `${testImagenName} should be resolvable and registered`); assert.strictEqual(model?.__action.name, `vertexai/${testImagenName}`); }); it('should register a new Embedder when looked up', async () => { - const ai = genkit({ plugins: [vertexAI()] }); const embedder = await ai.registry.lookupAction(testEmbedderPath); assert.ok( embedder, @@ -205,28 +214,6 @@ describe('VertexAI Plugin', () => { ); }); - it('vertexAI.model should handle custom names for Gemini family', () => { - const modelName = 'gemini-custom-model'; - const modelRef = vertexAI.model(modelName); - assert.strictEqual(modelRef.name, `vertexai/${modelName}`); - assert.strictEqual( - modelRef.configSchema, - GeminiConfigSchema, - 'Custom Gemini should still use GeminiConfigSchema' - ); - }); - - it('vertexAI.model should handle custom names for Imagen family', () => { - const modelName = 'imagen-custom-model'; - const modelRef = vertexAI.model(modelName); - assert.strictEqual(modelRef.name, `vertexai/${modelName}`); - assert.strictEqual( - modelRef.configSchema, - ImagenConfigSchema, - 'Custom Imagen should still use ImagenConfigSchema' - ); - }); - it('vertexAI.embedder should return an EmbedderReference with correct schema', () => { const embedderName = 'text-embedding-005'; const embedderRef = vertexAI.embedder(embedderName); @@ -242,25 +229,13 @@ describe('VertexAI Plugin', () => { describe('listActions Function', () => { let fetchMock: any; - let ai: Genkit; beforeEach(() => { - ai = genkit({ plugins: [vertexAI()] }); - const authClientMock = { - getAccessToken: async () => 'fake-test-token', - // Add other methods if needed, though getAccessToken is the key one for headers - }; - UTILS_TEST_ONLY.setMockDerivedOptions({ - location: 'us-central1', - projectId: 'test-project', - authClient: authClientMock as any, - }); fetchMock = mock.method(global, 'fetch'); }); afterEach(() => { fetchMock.mock.restore(); - UTILS_TEST_ONLY.setMockDerivedOptions(originalMockDerivedOptions); }); const createMockResponse = (models: Array<{ name: string }>) => { @@ -299,65 +274,355 @@ describe('VertexAI Plugin', () => { }); }); - it('should filter out known decommissioned models', async () => { - const mockModels = [ - { name: 'publishers/google/models/gemini-1.5-flash' }, - { name: 'publishers/google/models/gemini-pro' }, - ]; - fetchMock.mock.mockImplementation(async () => - createMockResponse(mockModels) - ); + it('should call fetch with auth token and location-specific URL for local options', async () => { + fetchMock.mock.mockImplementation(async () => createMockResponse([])); const pluginProvider = vertexAI()(ai); - const actions = await pluginProvider.listActions!(); - const actionNames = actions.map((a) => a.name); - assert.deepStrictEqual(actionNames, ['vertexai/gemini-1.5-flash']); - }); + await pluginProvider.listActions!(); - it('should filter out embedding models from gemini results', async () => { - const mockModels = [ - { name: 'publishers/google/models/gemini-2.0-pro' }, - { name: 'publishers/google/models/gemini-embedding-001' }, - ]; - fetchMock.mock.mockImplementation(async () => - createMockResponse(mockModels) + const fetchCall = fetchMock.mock.calls[0]; + const headers = fetchCall.arguments[1].headers; + const url = fetchCall.arguments[0]; + + assert.strictEqual(headers['Authorization'], 'Bearer fake-test-token'); + assert.strictEqual(headers['x-goog-user-project'], 'test-project'); + assert.ok( + url.startsWith('https://us-central1-aiplatform.googleapis.com') ); - const pluginProvider = vertexAI()(ai); - const actions = await pluginProvider.listActions!(); - const actionNames = actions.map((a) => a.name); - assert.deepStrictEqual(actionNames, ['vertexai/gemini-2.0-pro']); }); - it('should handle fetch errors gracefully', async () => { - fetchMock.mock.mockImplementation(async () => { - return Promise.resolve({ - ok: false, - status: 500, - statusText: 'Internal Error', - json: async () => ({ error: { message: 'API Error' } }), - }); - }); + it('should call fetch with API key and global URL for global options', async () => { + const globalWithOptions = { + ...globalMockDerivedOptions, + apiKey: 'test-api-key', + }; + UTILS_TEST_ONLY.setMockDerivedOptions(globalWithOptions); + ai = genkit({ plugins: [vertexAI()] }); // Re-init + fetchMock.mock.mockImplementation(async () => createMockResponse([])); const pluginProvider = vertexAI()(ai); - const actions = await pluginProvider.listActions!(); - assert.deepStrictEqual( - actions, - [], - 'Should return empty array on fetch error' - ); + await pluginProvider.listActions!(); + + const fetchCall = fetchMock.mock.calls[0]; + const headers = fetchCall.arguments[1].headers; + const url = fetchCall.arguments[0]; + + assert.strictEqual(headers['Authorization'], 'Bearer fake-test-token'); + assert.strictEqual(headers['x-goog-user-project'], 'test-project'); + assert.ok(url.startsWith('https://aiplatform.googleapis.com')); + assert.ok(url.includes('?key=test-api-key')); + assert.ok(!url.includes('us-central1-')); }); - it('should use listActions cache', async () => { - const mockModels = [{ name: 'publishers/google/models/gemini-1.0-pro' }]; - fetchMock.mock.mockImplementation(async () => - createMockResponse(mockModels) - ); + it('should call fetch with API key in URL and no Auth for express options', async () => { + UTILS_TEST_ONLY.setMockDerivedOptions(expressMockDerivedOptions); + ai = genkit({ plugins: [vertexAI()] }); // Re-init + fetchMock.mock.mockImplementation(async () => createMockResponse([])); const pluginProvider = vertexAI()(ai); await pluginProvider.listActions!(); - await pluginProvider.listActions!(); - assert.strictEqual( - fetchMock.mock.callCount(), - 1, - 'fetch should only be called once' - ); + + const fetchCall = fetchMock.mock.calls[0]; + const headers = fetchCall.arguments[1].headers; + const url = fetchCall.arguments[0]; + + assert.strictEqual(headers['x-goog-api-key'], undefined); + assert.strictEqual(headers['Authorization'], undefined); + assert.ok(url.startsWith('https://aiplatform.googleapis.com')); + assert.ok(url.includes('?key=test-express-api-key')); + }); + }); + + describe('API Calls', () => { + let fetchMock: any; + + beforeEach(() => { + fetchMock = mock.method(global, 'fetch'); + }); + + afterEach(() => { + fetchMock.mock.restore(); + }); + + const createMockApiResponse = (data: object) => { + return Promise.resolve({ + ok: true, + json: async () => data, + }); + }; + + describe('With Local Options', () => { + beforeEach(() => { + UTILS_TEST_ONLY.setMockDerivedOptions(regionalMockDerivedOptions); + ai = genkit({ plugins: [vertexAI()] }); + }); + + it('should use auth token for Gemini generateContent', async () => { + const modelRef = vertexAI.model('gemini-1.5-flash'); + const generateAction = await ai.registry.lookupAction( + '/model/' + modelRef.name + ); + assert.ok(generateAction, `/model/${modelRef.name} action not found`); + + fetchMock.mock.mockImplementation(async () => + createMockApiResponse({ + candidates: [ + { + index: 0, + content: { role: 'model', parts: [{ text: 'response' }] }, + }, + ], + }) + ); + + await generateAction({ + messages: [{ role: 'user', content: [{ text: 'hi' }] }], + config: {}, + } as GenerateRequest); + + const fetchCall = fetchMock.mock.calls[0]; + const headers = fetchCall.arguments[1].headers; + const url = fetchCall.arguments[0]; + assert.strictEqual(headers['Authorization'], 'Bearer fake-test-token'); + assert.ok(url.includes('us-central1-aiplatform.googleapis.com')); + }); + + it('should use auth token for Embedder embedContent', async () => { + const embedderRef = vertexAI.embedder('text-embedding-004'); + const embedAction = await ai.registry.lookupAction( + '/embedder/' + embedderRef.name + ); + assert.ok( + embedAction, + `/embedder/${embedderRef.name} action not found` + ); + + fetchMock.mock.mockImplementation(async () => + createMockApiResponse({ + predictions: [{ embeddings: { values: [0.1] } }], + }) + ); + + await embedAction({ + input: [{ content: [{ text: 'test' }] }], + } as EmbedRequest); + + const fetchCall = fetchMock.mock.calls[0]; + const headers = fetchCall.arguments[1].headers; + assert.strictEqual(headers['Authorization'], 'Bearer fake-test-token'); + }); + + it('should use auth token for Imagen predict', async () => { + const modelRef = vertexAI.model('imagen-3.0-generate-001'); + const generateAction = await ai.registry.lookupAction( + '/model/' + modelRef.name + ); + assert.ok(generateAction, `/model/${modelRef.name} action not found`); + + fetchMock.mock.mockImplementation(async () => + createMockApiResponse({ + predictions: [{ bytesBase64Encoded: 'abc', mimeType: 'image/png' }], + }) + ); + + await generateAction({ + messages: [{ role: 'user', content: [{ text: 'a cat' }] }], + config: {}, + } as GenerateRequest); + const fetchCall = fetchMock.mock.calls[0]; + const headers = fetchCall.arguments[1].headers; + assert.strictEqual(headers['Authorization'], 'Bearer fake-test-token'); + }); + }); + + describe('With Global Options', () => { + beforeEach(() => { + const globalWithOptions = { + ...globalMockDerivedOptions, + apiKey: 'test-api-key', + }; + UTILS_TEST_ONLY.setMockDerivedOptions(globalWithOptions); + ai = genkit({ plugins: [vertexAI()] }); + }); + + it('should use API key for Gemini generateContent', async () => { + const modelRef = vertexAI.model('gemini-1.5-flash'); + const generateAction = await ai.registry.lookupAction( + '/model/' + modelRef.name + ); + assert.ok(generateAction, `/model/${modelRef.name} action not found`); + + fetchMock.mock.mockImplementation(async () => + createMockApiResponse({ + candidates: [ + { + index: 0, + content: { role: 'model', parts: [{ text: 'response' }] }, + }, + ], + }) + ); + + await generateAction({ + messages: [{ role: 'user', content: [{ text: 'hi' }] }], + config: {}, + } as GenerateRequest); + + const fetchCall = fetchMock.mock.calls[0]; + const headers = fetchCall.arguments[1].headers; + const url = fetchCall.arguments[0]; + assert.strictEqual(headers['Authorization'], 'Bearer fake-test-token'); + assert.strictEqual(headers['x-goog-user-project'], 'test-project'); + assert.ok(url.includes('?key=test-api-key')); + assert.ok( + url.includes('aiplatform.googleapis.com') && + !url.includes('us-central1-') + ); + }); + + it('should use API key for Embedder embedContent', async () => { + const embedderRef = vertexAI.embedder('text-embedding-004'); + const embedAction = await ai.registry.lookupAction( + '/embedder/' + embedderRef.name + ); + assert.ok( + embedAction, + `/embedder/${embedderRef.name} action not found` + ); + + fetchMock.mock.mockImplementation(async () => + createMockApiResponse({ + predictions: [{ embeddings: { values: [0.1] } }], + }) + ); + + await embedAction({ + input: [{ content: [{ text: 'test' }] }], + } as EmbedRequest); + const fetchCall = fetchMock.mock.calls[0]; + const headers = fetchCall.arguments[1].headers; + const url = fetchCall.arguments[0]; + assert.strictEqual(headers['Authorization'], 'Bearer fake-test-token'); + assert.ok(url.includes('?key=test-api-key')); + }); + + it('should use API key for Imagen predict', async () => { + const modelRef = vertexAI.model('imagen-3.0-generate-001'); + const generateAction = await ai.registry.lookupAction( + '/model/' + modelRef.name + ); + assert.ok(generateAction, `/model/${modelRef.name} action not found`); + + fetchMock.mock.mockImplementation(async () => + createMockApiResponse({ + predictions: [{ bytesBase64Encoded: 'abc', mimeType: 'image/png' }], + }) + ); + + await generateAction({ + messages: [{ role: 'user', content: [{ text: 'a cat' }] }], + config: {}, + } as GenerateRequest); + const fetchCall = fetchMock.mock.calls[0]; + const headers = fetchCall.arguments[1].headers; + const url = fetchCall.arguments[0]; + assert.strictEqual(headers['Authorization'], 'Bearer fake-test-token'); + assert.ok(url.includes('?key=test-api-key')); + }); + }); + + describe('With Express Options', () => { + beforeEach(() => { + UTILS_TEST_ONLY.setMockDerivedOptions(expressMockDerivedOptions); + ai = genkit({ plugins: [vertexAI()] }); + }); + + it('should use API key in URL for Gemini generateContent', async () => { + const modelRef = vertexAI.model('gemini-1.5-flash'); + const generateAction = await ai.registry.lookupAction( + '/model/' + modelRef.name + ); + assert.ok(generateAction, `/model/${modelRef.name} action not found`); + + fetchMock.mock.mockImplementation(async () => + createMockApiResponse({ + candidates: [ + { + index: 0, + content: { role: 'model', parts: [{ text: 'response' }] }, + }, + ], + }) + ); + + await generateAction({ + messages: [{ role: 'user', content: [{ text: 'hi' }] }], + config: {}, + } as GenerateRequest); + + const fetchCall = fetchMock.mock.calls[0]; + const headers = fetchCall.arguments[1].headers; + const url = fetchCall.arguments[0]; + + assert.strictEqual(headers['Authorization'], undefined); + assert.strictEqual(headers['x-goog-api-key'], undefined); + assert.ok(url.includes('?key=test-express-api-key')); + assert.ok( + url.includes('aiplatform.googleapis.com') && + !url.includes('us-central1-') + ); + }); + + it('should use API key in URL for Embedder embedContent', async () => { + const embedderRef = vertexAI.embedder('text-embedding-004'); + const embedAction = await ai.registry.lookupAction( + '/embedder/' + embedderRef.name + ); + assert.ok( + embedAction, + `/embedder/${embedderRef.name} action not found` + ); + + fetchMock.mock.mockImplementation(async () => + createMockApiResponse({ + predictions: [{ embeddings: { values: [0.1] } }], + }) + ); + + await embedAction({ + input: [{ content: [{ text: 'test' }] }], + } as EmbedRequest); + const fetchCall = fetchMock.mock.calls[0]; + const headers = fetchCall.arguments[1].headers; + const url = fetchCall.arguments[0]; + + assert.strictEqual(headers['Authorization'], undefined); + assert.strictEqual(headers['x-goog-api-key'], undefined); + assert.ok(url.includes('?key=test-express-api-key')); + }); + + it('should use API key in URL for Imagen predict', async () => { + const modelRef = vertexAI.model('imagen-3.0-generate-001'); + const generateAction = await ai.registry.lookupAction( + '/model/' + modelRef.name + ); + assert.ok(generateAction, `/model/${modelRef.name} action not found`); + + fetchMock.mock.mockImplementation(async () => + createMockApiResponse({ + predictions: [{ bytesBase64Encoded: 'abc', mimeType: 'image/png' }], + }) + ); + + await generateAction({ + messages: [{ role: 'user', content: [{ text: 'a cat' }] }], + config: {}, + } as GenerateRequest); + const fetchCall = fetchMock.mock.calls[0]; + const headers = fetchCall.arguments[1].headers; + const url = fetchCall.arguments[0]; + + assert.strictEqual(headers['Authorization'], undefined); + assert.strictEqual(headers['x-goog-api-key'], undefined); + assert.ok(url.includes('?key=test-express-api-key')); + }); }); }); }); diff --git a/js/plugins/google-genai/tests/vertexai/utils_test.ts b/js/plugins/google-genai/tests/vertexai/utils_test.ts index 0a6659e818..3911c6ba3d 100644 --- a/js/plugins/google-genai/tests/vertexai/utils_test.ts +++ b/js/plugins/google-genai/tests/vertexai/utils_test.ts @@ -15,12 +15,21 @@ */ import * as assert from 'assert'; +import { GenkitError } from 'genkit'; import { GenerateRequest } from 'genkit/model'; import { GoogleAuth } from 'google-auth-library'; import { afterEach, beforeEach, describe, it } from 'node:test'; import * as sinon from 'sinon'; -import { VertexPluginOptions } from '../../src/vertexai/types'; import { + ExpressClientOptions, + GlobalClientOptions, + RegionalClientOptions, + VertexPluginOptions, +} from '../../src/vertexai/types'; +import { + API_KEY_FALSE_ERROR, + MISSING_API_KEY_ERROR, + calculateApiKey, extractImagenImage, extractImagenMask, extractText, @@ -47,9 +56,14 @@ describe('getDerivedOptions', () => { delete process.env.GCLOUD_LOCATION; delete process.env.FIREBASE_CONFIG; delete process.env.GCLOUD_SERVICE_ACCOUNT_CREDS; + delete process.env.VERTEX_API_KEY; + delete process.env.GOOGLE_API_KEY; + delete process.env.GOOGLE_GENAI_API_KEY; authInstance = sinon.createStubInstance(GoogleAuth); authInstance.getAccessToken.resolves('test-token'); + // Default to simulating project ID not found, tests that need it should override. + authInstance.getProjectId.resolves(undefined); mockAuthClass = sinon.stub().returns(authInstance); }); @@ -58,108 +72,420 @@ describe('getDerivedOptions', () => { sinon.restore(); }); - it('should use defaults when no options or env vars are provided', async () => { - authInstance.getProjectId.resolves('default-project'); - const options = await getDerivedOptions(undefined, mockAuthClass as any); - assert.strictEqual(options.projectId, 'default-project'); - assert.strictEqual(options.location, 'us-central1'); - assert.ok(options.authClient); - sinon.assert.calledOnce(mockAuthClass); - sinon.assert.calledOnce(authInstance.getProjectId); + describe('Regional Options', () => { + it('should use options for projectId and location', async () => { + const pluginOptions: VertexPluginOptions = { + projectId: 'options-project', + location: 'options-location', + }; + const options = (await getDerivedOptions( + pluginOptions, + mockAuthClass as any + )) as RegionalClientOptions; + assert.strictEqual(options.kind, 'regional'); + assert.strictEqual(options.projectId, 'options-project'); + assert.strictEqual(options.location, 'options-location'); + assert.ok(options.authClient); + sinon.assert.calledOnce(mockAuthClass); + sinon.assert.notCalled(authInstance.getProjectId); + }); + + it('should use GCLOUD_PROJECT and GCLOUD_LOCATION env vars', async () => { + process.env.GCLOUD_PROJECT = 'env-project'; + process.env.GCLOUD_LOCATION = 'env-location'; + const options = (await getDerivedOptions( + undefined, + mockAuthClass as any + )) as RegionalClientOptions; + assert.strictEqual(options.kind, 'regional'); + assert.strictEqual(options.projectId, 'env-project'); + assert.strictEqual(options.location, 'env-location'); + sinon.assert.calledOnce(mockAuthClass); + const authOptions = mockAuthClass.lastCall.args[0]; + assert.strictEqual(authOptions.projectId, 'env-project'); + sinon.assert.notCalled(authInstance.getProjectId); + }); + + it('should use default location when only projectId is available', async () => { + authInstance.getProjectId.resolves('default-project'); + const options = (await getDerivedOptions( + undefined, + mockAuthClass as any + )) as RegionalClientOptions; + assert.strictEqual(options.kind, 'regional'); + assert.strictEqual(options.projectId, 'default-project'); + assert.strictEqual(options.location, 'us-central1'); + sinon.assert.calledOnce(mockAuthClass); + sinon.assert.calledOnce(authInstance.getProjectId); + }); + + it('should use FIREBASE_CONFIG for GoogleAuth constructor, but final projectId from getProjectId', async () => { + process.env.FIREBASE_CONFIG = JSON.stringify({ + projectId: 'firebase-project', + }); + authInstance.getProjectId.resolves('auth-client-project'); + + const options = (await getDerivedOptions( + { location: 'fb-location' }, + mockAuthClass as any + )) as RegionalClientOptions; + + assert.strictEqual(options.kind, 'regional'); + sinon.assert.calledOnce(mockAuthClass); + const authOptions = mockAuthClass.lastCall.args[0]; + assert.strictEqual(authOptions.projectId, 'firebase-project'); + sinon.assert.calledOnce(authInstance.getProjectId); + assert.strictEqual(options.projectId, 'auth-client-project'); + assert.strictEqual(options.location, 'fb-location'); + }); + + it('should prioritize plugin options over env vars', async () => { + process.env.GCLOUD_PROJECT = 'env-project'; + process.env.GCLOUD_LOCATION = 'env-location'; + const pluginOptions: VertexPluginOptions = { + projectId: 'options-project', + location: 'options-location', + }; + const options = (await getDerivedOptions( + pluginOptions, + mockAuthClass as any + )) as RegionalClientOptions; + assert.strictEqual(options.kind, 'regional'); + assert.strictEqual(options.projectId, 'options-project'); + assert.strictEqual(options.location, 'options-location'); + }); + + it('should use GCLOUD_SERVICE_ACCOUNT_CREDS for auth', async () => { + const creds = { + client_email: '', + private_key: 'private_key', + }; + process.env.GCLOUD_SERVICE_ACCOUNT_CREDS = JSON.stringify(creds); + authInstance.getProjectId.resolves('creds-project'); + + const options = (await getDerivedOptions( + { location: 'creds-location' }, + mockAuthClass as any + )) as RegionalClientOptions; + + assert.strictEqual(options.kind, 'regional'); + assert.strictEqual(options.projectId, 'creds-project'); + assert.strictEqual(options.location, 'creds-location'); + sinon.assert.calledOnce(mockAuthClass); + const authOptions = mockAuthClass.lastCall.args[0]; + assert.deepStrictEqual(authOptions.credentials, creds); + assert.strictEqual(authOptions.projectId, undefined); + sinon.assert.calledOnce(authInstance.getProjectId); + }); + + it('should throw error if projectId cannot be determined for regional', async () => { + authInstance.getProjectId.resolves(undefined); + await assert.rejects( + getDerivedOptions({ location: 'some-location' }, mockAuthClass as any), + /VertexAI Plugin is missing the 'project' configuration/ + ); + sinon.assert.calledOnce(mockAuthClass); + sinon.assert.calledOnce(authInstance.getProjectId); + }); + + it('should prefer regional if location is specified, even with apiKey', async () => { + const pluginOptions: VertexPluginOptions = { + location: 'us-central1', + apiKey: 'test-api-key', + projectId: 'options-project', + }; + const options = (await getDerivedOptions( + pluginOptions, + mockAuthClass as any + )) as RegionalClientOptions; + assert.strictEqual(options.kind, 'regional'); + assert.strictEqual(options.projectId, 'options-project'); + assert.strictEqual(options.location, 'us-central1'); + assert.ok(options.authClient); + sinon.assert.calledOnce(mockAuthClass); + }); }); - it('should use options for projectId and location', async () => { - const pluginOptions: VertexPluginOptions = { - projectId: 'options-project', - location: 'options-location', - }; - const options = await getDerivedOptions( - pluginOptions, - mockAuthClass as any - ); - assert.strictEqual(options.projectId, 'options-project'); - assert.strictEqual(options.location, 'options-location'); - sinon.assert.calledOnce(mockAuthClass); - sinon.assert.notCalled(authInstance.getProjectId); - }); - - it('should use GCLOUD_PROJECT and GCLOUD_LOCATION env vars', async () => { - process.env.GCLOUD_PROJECT = 'env-project'; - process.env.GCLOUD_LOCATION = 'env-location'; - const options = await getDerivedOptions(undefined, mockAuthClass as any); - assert.strictEqual(options.projectId, 'env-project'); - assert.strictEqual(options.location, 'env-location'); - sinon.assert.calledOnce(mockAuthClass); - const authOptions = mockAuthClass.lastCall.args[0]; - assert.strictEqual(authOptions.projectId, 'env-project'); - sinon.assert.notCalled(authInstance.getProjectId); - }); - - it('should use FIREBASE_CONFIG for GoogleAuth constructor, but final projectId from getProjectId', async () => { - process.env.FIREBASE_CONFIG = JSON.stringify({ - projectId: 'firebase-project', + describe('Global Options', () => { + it('should use global options when location is global', async () => { + const pluginOptions: VertexPluginOptions = { + location: 'global', + projectId: 'options-project', + }; + const options = (await getDerivedOptions( + pluginOptions, + mockAuthClass as any + )) as GlobalClientOptions; + assert.strictEqual(options.kind, 'global'); + assert.strictEqual(options.location, 'global'); + assert.strictEqual(options.projectId, 'options-project'); + assert.ok(options.authClient); + sinon.assert.calledOnce(mockAuthClass); + sinon.assert.notCalled(authInstance.getProjectId); }); - // This will be called because options.projectId and GCLOUD_PROJECT are missing - authInstance.getProjectId.resolves('auth-client-project'); - const options = await getDerivedOptions(undefined, mockAuthClass as any); + it('should use env project for global options', async () => { + process.env.GCLOUD_PROJECT = 'env-project'; + const pluginOptions: VertexPluginOptions = { location: 'global' }; + const options = (await getDerivedOptions( + pluginOptions, + mockAuthClass as any + )) as GlobalClientOptions; + assert.strictEqual(options.kind, 'global'); + assert.strictEqual(options.projectId, 'env-project'); + sinon.assert.calledOnce(mockAuthClass); + }); + + it('should use auth project for global options', async () => { + authInstance.getProjectId.resolves('auth-project'); + const pluginOptions: VertexPluginOptions = { location: 'global' }; + const options = (await getDerivedOptions( + pluginOptions, + mockAuthClass as any + )) as GlobalClientOptions; + assert.strictEqual(options.kind, 'global'); + assert.strictEqual(options.projectId, 'auth-project'); + sinon.assert.calledOnce(mockAuthClass); + sinon.assert.calledOnce(authInstance.getProjectId); + }); - // Assert that the constructor received the project ID from FIREBASE_CONFIG - sinon.assert.calledOnce(mockAuthClass); - const authOptions = mockAuthClass.lastCall.args[0]; - assert.strictEqual(authOptions.projectId, 'firebase-project'); + it('should throw error if projectId cannot be determined for global', async () => { + authInstance.getProjectId.resolves(undefined); + await assert.rejects( + getDerivedOptions({ location: 'global' }, mockAuthClass as any), + /VertexAI Plugin is missing the 'project' configuration/ + ); + sinon.assert.calledOnce(mockAuthClass); + sinon.assert.calledOnce(authInstance.getProjectId); + }); - // Assert that getProjectId was called to determine the final projectId - sinon.assert.calledOnce(authInstance.getProjectId); - assert.strictEqual(options.projectId, 'auth-client-project'); // Final ID is from getProjectId - assert.strictEqual(options.location, 'us-central1'); + it('should prefer global if location is global, even with apiKey', async () => { + const pluginOptions: VertexPluginOptions = { + location: 'global', + apiKey: 'test-api-key', + projectId: 'options-project', + }; + const options = (await getDerivedOptions( + pluginOptions, + mockAuthClass as any + )) as GlobalClientOptions; + assert.strictEqual(options.kind, 'global'); + assert.strictEqual(options.projectId, 'options-project'); + assert.ok(options.authClient); + sinon.assert.calledOnce(mockAuthClass); + }); }); - it('should prioritize options over env vars', async () => { - process.env.GCLOUD_PROJECT = 'env-project'; - process.env.GCLOUD_LOCATION = 'env-location'; - const pluginOptions: VertexPluginOptions = { - projectId: 'options-project', - location: 'options-location', - }; - const options = await getDerivedOptions( - pluginOptions, - mockAuthClass as any + describe('Express Options', () => { + it('should use express options with apiKey in options', async () => { + const pluginOptions: VertexPluginOptions = { apiKey: 'key1' }; + const options = (await getDerivedOptions( + pluginOptions, + mockAuthClass as any + )) as ExpressClientOptions; + assert.strictEqual(options.kind, 'express'); + assert.strictEqual(options.apiKey, 'key1'); + sinon.assert.notCalled(mockAuthClass); + }); + + it('should use express options with apiKey false in options', async () => { + const pluginOptions: VertexPluginOptions = { apiKey: false }; + const options = (await getDerivedOptions( + pluginOptions, + mockAuthClass as any + )) as ExpressClientOptions; + assert.strictEqual(options.kind, 'express'); + assert.strictEqual(options.apiKey, undefined); + sinon.assert.notCalled(mockAuthClass); + }); + + it('should use VERTEX_API_KEY env var for express', async () => { + process.env.VERTEX_API_KEY = 'key2'; + const options = (await getDerivedOptions( + undefined, + mockAuthClass as any + )) as ExpressClientOptions; + assert.strictEqual(options.kind, 'express'); + assert.strictEqual(options.apiKey, 'key2'); + // mockAuthClass is called during regional/global fallbacks + sinon.assert.calledTwice(mockAuthClass); + }); + + it('should use GOOGLE_API_KEY env var for express', async () => { + process.env.GOOGLE_API_KEY = 'key3'; + const options = (await getDerivedOptions( + undefined, + mockAuthClass as any + )) as ExpressClientOptions; + assert.strictEqual(options.kind, 'express'); + assert.strictEqual(options.apiKey, 'key3'); + // mockAuthClass is called during regional/global fallbacks + sinon.assert.calledTwice(mockAuthClass); + }); + + it('should prioritize VERTEX_API_KEY over GOOGLE_API_KEY for express', async () => { + process.env.VERTEX_API_KEY = 'keyV'; + process.env.GOOGLE_API_KEY = 'keyG'; + const options = (await getDerivedOptions( + undefined, + mockAuthClass as any + )) as ExpressClientOptions; + assert.strictEqual(options.kind, 'express'); + assert.strictEqual(options.apiKey, 'keyV'); + // mockAuthClass is called during regional/global fallbacks + sinon.assert.calledTwice(mockAuthClass); + }); + }); + + describe('Fallback Determination (No Options)', () => { + it('should default to regional if project can be determined', async () => { + authInstance.getProjectId.resolves('fallback-project'); + const options = (await getDerivedOptions( + undefined, + mockAuthClass as any + )) as RegionalClientOptions; + assert.strictEqual(options.kind, 'regional'); + assert.strictEqual(options.projectId, 'fallback-project'); + assert.strictEqual(options.location, 'us-central1'); + sinon.assert.calledOnce(mockAuthClass); + }); + + it('should fallback to express if regional/global fail and API key env exists', async () => { + authInstance.getProjectId.resolves(undefined); // Fail regional/global project lookup + process.env.GOOGLE_API_KEY = 'fallback-api-key'; + + const options = (await getDerivedOptions( + undefined, + mockAuthClass as any + )) as ExpressClientOptions; + + assert.strictEqual(options.kind, 'express'); + assert.strictEqual(options.apiKey, 'fallback-api-key'); + // getRegionalDerivedOptions, getGlobalDerivedOptions are called first + sinon.assert.calledTwice(mockAuthClass); + }); + }); + + describe('Error Scenarios', () => { + it('should throw error if no options or env vars provide sufficient info', async () => { + authInstance.getProjectId.resolves(undefined); // Simulate failure to get project ID + // No API key env vars set + + await assert.rejects(getDerivedOptions(undefined, mockAuthClass as any), { + name: 'GenkitError', + status: 'INVALID_ARGUMENT', + message: + 'INVALID_ARGUMENT: Unable to determine client options. Please set either apiKey or projectId and location', + }); + // Tries Regional, Global, then Express paths. Regional and Global attempts create AuthClient. + sinon.assert.calledTwice(mockAuthClass); + }); + }); +}); + +describe('calculateApiKey', () => { + const originalEnv = { ...process.env }; + + beforeEach(() => { + // Reset env + for (const key in process.env) { + if (!originalEnv.hasOwnProperty(key)) { + delete process.env[key]; + } + } + for (const key in originalEnv) { + process.env[key] = originalEnv[key]; + } + delete process.env.VERTEX_API_KEY; + delete process.env.GOOGLE_API_KEY; + delete process.env.GOOGLE_GENAI_API_KEY; + }); + + function assertThrowsGenkitError( + block: () => void, + expectedError: GenkitError + ) { + let caughtError: any; + try { + block(); + } catch (e: any) { + caughtError = e; + } + + if (!caughtError) { + assert.fail('Should have thrown an error, but nothing was caught.'); + } + + assert.strictEqual( + caughtError.name, + 'GenkitError', + `Caught error is not a GenkitError. Got: ${caughtError.name}, Message: ${caughtError.message}` ); - assert.strictEqual(options.projectId, 'options-project'); - assert.strictEqual(options.location, 'options-location'); + assert.strictEqual(caughtError.status, expectedError.status); + assert.strictEqual(caughtError.message, expectedError.message); + } + + it('should use requestApiKey when provided', () => { + assert.strictEqual(calculateApiKey(undefined, 'reqKey'), 'reqKey'); + assert.strictEqual(calculateApiKey('pluginKey', 'reqKey'), 'reqKey'); + assert.strictEqual(calculateApiKey(false, 'reqKey'), 'reqKey'); }); - it('should use GCLOUD_SERVICE_ACCOUNT_CREDS', async () => { - const creds = { - client_email: '', - private_key: 'private_key', - }; - process.env.GCLOUD_SERVICE_ACCOUNT_CREDS = JSON.stringify(creds); - authInstance.getProjectId.resolves('creds-project'); + it('should use pluginApiKey if requestApiKey is undefined', () => { + assert.strictEqual(calculateApiKey('pluginKey', undefined), 'pluginKey'); + }); + + it('should use VERTEX_API_KEY from env if keys are undefined', () => { + process.env.VERTEX_API_KEY = 'vertexEnvKey'; + assert.strictEqual(calculateApiKey(undefined, undefined), 'vertexEnvKey'); + }); + + it('should use GOOGLE_API_KEY from env if VERTEX_API_KEY is not set', () => { + process.env.GOOGLE_API_KEY = 'googleEnvKey'; + assert.strictEqual(calculateApiKey(undefined, undefined), 'googleEnvKey'); + }); + + it('should prioritize pluginApiKey over env keys', () => { + process.env.VERTEX_API_KEY = 'vertexEnvKey'; + assert.strictEqual(calculateApiKey('pluginKey', undefined), 'pluginKey'); + }); - const options = await getDerivedOptions( - { location: 'creds-location' }, - mockAuthClass as any + it('should throw MISSING_API_KEY_ERROR if no key is found', () => { + assert.strictEqual( + process.env.VERTEX_API_KEY, + undefined, + 'VERTEX_API_KEY should be undefined' + ); + assert.strictEqual( + process.env.GOOGLE_API_KEY, + undefined, + 'GOOGLE_API_KEY should be undefined' + ); + assert.strictEqual( + process.env.GOOGLE_GENAI_API_KEY, + undefined, + 'GOOGLE_GENAI_API_KEY should be undefined' ); - assert.strictEqual(options.projectId, 'creds-project'); - assert.strictEqual(options.location, 'creds-location'); - sinon.assert.calledOnce(mockAuthClass); - const authOptions = mockAuthClass.lastCall.args[0]; - assert.deepStrictEqual(authOptions.credentials, creds); - assert.strictEqual(authOptions.projectId, undefined); - sinon.assert.calledOnce(authInstance.getProjectId); + assertThrowsGenkitError( + () => calculateApiKey(undefined, undefined), + MISSING_API_KEY_ERROR + ); }); - it('should throw error if projectId cannot be determined', async () => { - authInstance.getProjectId.resolves(undefined); - await assert.rejects( - getDerivedOptions({ location: 'some-location' }, mockAuthClass as any), - /VertexAI Plugin is missing the 'project' configuration/ + it('should throw API_KEY_FALSE_ERROR if pluginApiKey is false and requestApiKey is undefined', () => { + assertThrowsGenkitError( + () => calculateApiKey(false, undefined), + API_KEY_FALSE_ERROR + ); + }); + + it('should not use env keys if pluginApiKey is false', () => { + process.env.VERTEX_API_KEY = 'vertexEnvKey'; + assertThrowsGenkitError( + () => calculateApiKey(false, undefined), + API_KEY_FALSE_ERROR ); - sinon.assert.calledOnce(mockAuthClass); - sinon.assert.calledOnce(authInstance.getProjectId); }); }); @@ -179,7 +505,14 @@ describe('extractText', () => { messages: [ { role: 'user', - content: [{ media: { url: '' } }], + content: [ + { + media: { + url: '', + contentType: 'image/png', + }, + }, + ], }, ], }; @@ -193,7 +526,12 @@ describe('extractText', () => { role: 'user', content: [ { text: 'A ' }, - { media: { url: '' } }, + { + media: { + url: '', + contentType: 'image/png', + }, + }, { text: 'B' }, ], }, @@ -263,7 +601,10 @@ describe('extractImagenImage', () => { role: 'user', content: [ { - media: { url: '' }, + media: { + url: '', + contentType: 'image/png', + }, metadata: { type: 'mask' }, }, ], @@ -285,7 +626,14 @@ describe('extractImagenImage', () => { messages: [ { role: 'user', - content: [{ media: { url: 'http://example.com/image.png' } }], + content: [ + { + media: { + url: 'http://example.com/image.png', + contentType: 'image/png', + }, + }, + ], }, ], }; @@ -330,7 +678,10 @@ describe('extractImagenMask', () => { role: 'user', content: [ { - media: { url: '' }, + media: { + url: '', + contentType: 'image/jpeg', + }, metadata: { type: 'base' }, }, ], @@ -345,7 +696,14 @@ describe('extractImagenMask', () => { messages: [ { role: 'user', - content: [{ media: { url: '' } }], + content: [ + { + media: { + url: '', + contentType: 'image/jpeg', + }, + }, + ], }, ], }; @@ -366,7 +724,10 @@ describe('extractImagenMask', () => { role: 'user', content: [ { - media: { url: 'http://example.com/mask.png' }, + media: { + url: 'http://example.com/mask.png', + contentType: 'image/png', + }, metadata: { type: 'mask' }, }, ],