diff --git a/js/ai/src/formats/types.ts b/js/ai/src/formats/types.ts
index 0e8b57a880..377b936c96 100644
--- a/js/ai/src/formats/types.ts
+++ b/js/ai/src/formats/types.ts
@@ -15,7 +15,7 @@
  */
 
 import type { JSONSchema } from '@genkit-ai/core';
-import type { GenerateResponseChunk } from '../generate.js';
+import type { GenerateResponseChunk } from '../generate/chunk.js';
 import type { Message } from '../message.js';
 import type { ModelRequest } from '../model.js';
diff --git a/js/ai/src/generate/action.ts b/js/ai/src/generate/action.ts
index 6d9a19985d..fa3c46ad6c 100644
--- a/js/ai/src/generate/action.ts
+++ b/js/ai/src/generate/action.ts
@@ -33,10 +33,10 @@ import {
 import type { Formatter } from '../formats/types.js';
 import {
   GenerateResponse,
-  GenerateResponseChunk,
   GenerationResponseError,
   tagAsPreamble,
 } from '../generate.js';
+import { GenerateResponseChunk } from '../generate/chunk.js';
 import {
   GenerateActionOptionsSchema,
   GenerateResponseChunkSchema,
diff --git a/js/plugins/google-genai/src/common/types.ts b/js/plugins/google-genai/src/common/types.ts
index 338c87b40f..595ed6f2ff 100644
--- a/js/plugins/google-genai/src/common/types.ts
+++ b/js/plugins/google-genai/src/common/types.ts
@@ -40,7 +40,7 @@ export enum FunctionCallingMode {
 }
 
 /**
- * The reason why the reponse is blocked.
+ * The reason why the response is blocked.
  */
 export enum BlockReason {
   /** Unspecified block reason. */
@@ -156,7 +156,7 @@ export declare interface GroundingSupport {
   /** Optional. Segment of the content this support belongs to. */
   segment?: GroundingSupportSegment;
   /**
-   * Optional. A arrau of indices (into {@link GroundingChunk}) specifying the
+   * Optional. An array of indices (into {@link GroundingChunk}) specifying the
    * citations associated with the claim. For instance [1,3,4] means
    * that grounding_chunk[1], grounding_chunk[3],
    * grounding_chunk[4] are the retrieved content attributed to the claim.
@@ -441,7 +441,7 @@ export declare interface GoogleDate {
   year?: number;
   /**
    * Month of the date. Must be from 1 to 12, or 0 to specify a year without a
-   * monthi and day.
+   * month and day.
    */
   month?: number;
   /**
@@ -983,7 +983,7 @@ export declare interface GenerateContentRequest {
 
 /**
  * Result from calling generateContentStream.
- * It constains both the stream and the final aggregated response.
+ * It contains both the stream and the final aggregated response.
  * @public
  */
 export declare interface GenerateContentStreamResult {
diff --git a/js/plugins/google-genai/src/common/utils.ts b/js/plugins/google-genai/src/common/utils.ts
index a688fae450..3d65a1550a 100644
--- a/js/plugins/google-genai/src/common/utils.ts
+++ b/js/plugins/google-genai/src/common/utils.ts
@@ -18,11 +18,12 @@ import {
   EmbedderReference,
   GenkitError,
   JSONSchema,
+  MediaPart,
   ModelReference,
+  Part,
   z,
 } from 'genkit';
 import { GenerateRequest } from 'genkit/model';
-import { ImagenInstance } from './types';
 
 /**
  * Safely extracts the error message from the error.
@@ -64,8 +65,9 @@ export function extractVersion(
 
 export function modelName(name?: string): string | undefined {
   if (!name) return name;
 
-  // Remove any of these prefixes: (but keep tunedModels e.g.) 
- const prefixesToRemove = /models\/|embedders\/|googleai\/|vertexai\//g; + // Remove any of these prefixes: + const prefixesToRemove = + /background-model\/|model\/|models\/|embedders\/|googleai\/|vertexai\//g; return name.replace(prefixesToRemove, ''); } @@ -95,20 +97,114 @@ export function extractText(request: GenerateRequest) { ); } -export function extractImagenImage( - request: GenerateRequest -): ImagenInstance['image'] | undefined { - const image = request.messages - .at(-1) - ?.content.find( - (p) => !!p.media && (!p.metadata?.type || p.metadata?.type === 'base') - ) - ?.media?.url.split(',')[1]; - - if (image) { - return { bytesBase64Encoded: image }; +const KNOWN_MIME_TYPES = { + jpg: 'image/jpeg', + jpeg: 'image/jpeg', + png: 'image/png', + mp4: 'video/mp4', + pdf: 'application/pdf', +}; + +export function extractMimeType(url?: string): string { + if (!url) { + return ''; + } + + const dataPrefix = 'data:'; + if (!url.startsWith(dataPrefix)) { + // Not a data url, try suffix + url.lastIndexOf('.'); + const key = url.substring(url.lastIndexOf('.') + 1); + if (Object.keys(KNOWN_MIME_TYPES).includes(key)) { + return KNOWN_MIME_TYPES[key]; + } + return ''; + } + + const commaIndex = url.indexOf(','); + if (commaIndex == -1) { + // Invalid - missing separator + return ''; + } + + // The part between 'data:' and the comma + let mimeType = url.substring(dataPrefix.length, commaIndex); + const base64Marker = ';base64'; + if (mimeType.endsWith(base64Marker)) { + mimeType = mimeType.substring(0, mimeType.length - base64Marker.length); + } + + return mimeType.trim(); +} + +export function checkSupportedMimeType( + media: MediaPart['media'], + supportedTypes: string[] +) { + if (!supportedTypes.includes(media.contentType ?? '')) { + throw new GenkitError({ + status: 'INVALID_ARGUMENT', + message: `Invalid mimeType for ${displayUrl(media.url)}: "${media.contentType}". Supported mimeTypes: ${supportedTypes.join(', ')}`, + }); + } +} + +/** + * + * @param url The url to show (e.g. in an error message) + * @returns The appropriately sized url + */ +export function displayUrl(url: string): string { + if (url.length <= 50) { + return url; } - return undefined; + + return url.substring(0, 25) + '...' + url.substring(url.length - 25); +} + +/** + * + * @param request A generate request to extract from + * @param metadataType The media must have metadata matching this type if isDefault is false + * @param isDefault 'true' allows missing metadata type to match as well. 
 * @returns
 */
+export function extractMedia(
+  request: GenerateRequest,
+  params: {
+    metadataType?: string;
+    /* If there is no metadata type, it will match if isDefault is true */
+    isDefault?: boolean;
+  }
+): MediaPart['media'] | undefined {
+  const predicate = (part: Part) => {
+    const media = part.media;
+    if (!media) {
+      return false;
+    }
+    if (params.metadataType || params.isDefault) {
+      // We need to check the metadata type
+      const metadata = part.metadata;
+      if (!metadata?.type) {
+        return !!params.isDefault;
+      } else {
+        return metadata.type == params.metadataType;
+      }
+    }
+    return true;
+  };
+
+  const media = request.messages.at(-1)?.content.find(predicate)?.media;
+
+  // Add the mimeType
+  if (media && !media?.contentType) {
+    return {
+      url: media.url,
+      contentType: extractMimeType(media.url),
+    };
+  }
+
+  return media;
 }
 
 /**
diff --git a/js/plugins/google-genai/src/googleai/utils.ts b/js/plugins/google-genai/src/googleai/utils.ts
index 7d2a55de71..4798eebd50 100644
--- a/js/plugins/google-genai/src/googleai/utils.ts
+++ b/js/plugins/google-genai/src/googleai/utils.ts
@@ -16,12 +16,12 @@
 
 import { GenerateRequest, GenkitError } from 'genkit';
 import process from 'process';
-import { VeoImage } from './types.js';
+import { extractMedia } from '../common/utils.js';
+import { ImagenInstance, VeoImage } from './types.js';
 
 export {
   checkModelName,
   cleanSchema,
-  extractImagenImage,
   extractText,
   extractVersion,
   modelName,
@@ -143,3 +143,17 @@ export function extractVeoImage(
   }
   return undefined;
 }
+
+export function extractImagenImage(
+  request: GenerateRequest
+): ImagenInstance['image'] | undefined {
+  const image = extractMedia(request, {
+    metadataType: 'base',
+    isDefault: true,
+  })?.url.split(',')[1];
+
+  if (image) {
+    return { bytesBase64Encoded: image };
+  }
+  return undefined;
+}
diff --git a/js/plugins/google-genai/src/vertexai/client.ts b/js/plugins/google-genai/src/vertexai/client.ts
index d273533d1b..311f4604e1 100644
--- a/js/plugins/google-genai/src/vertexai/client.ts
+++ b/js/plugins/google-genai/src/vertexai/client.ts
@@ -31,9 +31,14 @@ import {
   ImagenPredictRequest,
   ImagenPredictResponse,
   ListModelsResponse,
+  LyriaPredictRequest,
+  LyriaPredictResponse,
   Model,
+  VeoOperation,
+  VeoOperationRequest,
+  VeoPredictRequest,
 } from './types';
-import { calculateApiKey, checkIsSupported } from './utils';
+import { calculateApiKey, checkSupportedResourceMethod } from './utils';
 
 export async function listModels(
   clientOptions: ClientOptions
@@ -94,25 +99,37 @@ export async function generateContentStream(
   return processStream(response);
 }
 
-export async function embedContent(
+async function internalPredict(
   model: string,
-  embedContentRequest: EmbedContentRequest,
+  body: string,
   clientOptions: ClientOptions
-): Promise<EmbedContentResponse> {
+): Promise<Response> {
   const url = getVertexAIUrl({
     includeProjectAndLocation: true,
     resourcePath: `publishers/google/models/${model}`,
-    resourceMethod: 'predict', // embedContent is a Vertex API predict call
+    resourceMethod: 'predict',
     clientOptions,
   });
   const fetchOptions = await getFetchOptions({
     method: 'POST',
     clientOptions,
-    body: JSON.stringify(embedContentRequest),
+    body,
   });
 
-  const response = await makeRequest(url, fetchOptions);
+  return await makeRequest(url, fetchOptions);
+}
+
+export async function embedContent(
+  model: string,
+  embedContentRequest: EmbedContentRequest,
+  clientOptions: ClientOptions
+): Promise<EmbedContentResponse> {
+  const response = await internalPredict(
+    model,
+    JSON.stringify(embedContentRequest),
+    clientOptions
+  );
   return 
response.json() as Promise; } @@ -121,31 +138,78 @@ export async function imagenPredict( imagenPredictRequest: ImagenPredictRequest, clientOptions: ClientOptions ): Promise { + const response = await internalPredict( + model, + JSON.stringify(imagenPredictRequest), + clientOptions + ); + return response.json() as Promise; +} + +export async function lyriaPredict( + model: string, + lyriaPredictRequest: LyriaPredictRequest, + clientOptions: ClientOptions +): Promise { + const response = await internalPredict( + model, + JSON.stringify(lyriaPredictRequest), + clientOptions + ); + return response.json() as Promise; +} + +export async function veoPredict( + model: string, + veoPredictRequest: VeoPredictRequest, + clientOptions: ClientOptions +): Promise { const url = getVertexAIUrl({ includeProjectAndLocation: true, resourcePath: `publishers/google/models/${model}`, - resourceMethod: 'predict', + resourceMethod: 'predictLongRunning', clientOptions, }); const fetchOptions = await getFetchOptions({ method: 'POST', clientOptions, - body: JSON.stringify(imagenPredictRequest), + body: JSON.stringify(veoPredictRequest), }); const response = await makeRequest(url, fetchOptions); - return response.json() as Promise; + return response.json() as Promise; +} + +export async function veoCheckOperation( + model: string, + veoOperationRequest: VeoOperationRequest, + clientOptions: ClientOptions +): Promise { + const url = getVertexAIUrl({ + includeProjectAndLocation: true, + resourcePath: `publishers/google/models/${model}`, + resourceMethod: 'fetchPredictOperation', + clientOptions, + }); + const fetchOptions = await getFetchOptions({ + method: 'POST', + clientOptions, + body: JSON.stringify(veoOperationRequest), + }); + + const response = await makeRequest(url, fetchOptions); + return response.json() as Promise; } export function getVertexAIUrl(params: { includeProjectAndLocation: boolean; // False for listModels, true for most others resourcePath: string; - resourceMethod?: 'streamGenerateContent' | 'generateContent' | 'predict'; + resourceMethod?: string; queryParams?: string; clientOptions: ClientOptions; }): string { - checkIsSupported(params); + checkSupportedResourceMethod(params); const DEFAULT_API_VERSION = 'v1beta1'; const API_BASE_PATH = 'aiplatform.googleapis.com'; diff --git a/js/plugins/google-genai/src/vertexai/converters.ts b/js/plugins/google-genai/src/vertexai/converters.ts index 7e5aafb6b3..8cd1bd302b 100644 --- a/js/plugins/google-genai/src/vertexai/converters.ts +++ b/js/plugins/google-genai/src/vertexai/converters.ts @@ -14,13 +14,47 @@ * limitations under the License. 
*/ -import { z } from 'genkit'; +import { + GenerateRequest, + GenerateResponseData, + GenkitError, + MediaPart, + Operation, + z, +} from 'genkit'; +import { CandidateData, getBasicUsageStats } from 'genkit/model'; import { HarmBlockThreshold, HarmCategory, + ImagenInstance, + ImagenParameters, + ImagenPredictRequest, + ImagenPredictResponse, + ImagenPrediction, SafetySetting, } from '../common/types'; import { SafetySettingsSchema } from './gemini'; +import { ImagenConfigSchemaType } from './imagen'; +import { LyriaConfigSchemaType } from './lyria'; +import { + LyriaInstance, + LyriaParameters, + LyriaPredictRequest, + LyriaPredictResponse, + LyriaPrediction, + VeoInstance, + VeoMedia, + VeoOperation, + VeoOperationRequest, + VeoPredictRequest, +} from './types'; +import { + checkSupportedMimeType, + extractMedia, + extractMimeType, + extractText, +} from './utils'; +import { VeoConfigSchemaType } from './veo'; export function toGeminiSafetySettings( genkitSettings?: z.infer[] @@ -55,3 +89,295 @@ export function toGeminiLabels( } return newLabels; } + +export function toImagenPredictRequest( + request: GenerateRequest +): ImagenPredictRequest { + return { + instances: toImagenInstances(request), + parameters: toImagenParameters(request), + }; +} + +function toImagenInstances( + request: GenerateRequest +): ImagenInstance[] { + let instance: ImagenInstance = { + prompt: extractText(request), + }; + + const imageMedia = extractMedia(request, { + metadataType: 'image', + isDefault: true, + }); + if (imageMedia) { + const image = imageMedia.url.split(',')[1]; + instance.image = { + bytesBase64Encoded: image, + }; + } + + const maskMedia = extractMedia(request, { metadataType: 'mask' }); + if (maskMedia) { + const mask = maskMedia.url.split(',')[1]; + instance.mask = { + image: { + bytesBase64Encoded: mask, + }, + }; + } + + return [instance]; +} + +function toImagenParameters( + request: GenerateRequest +): ImagenParameters { + const params = { + sampleCount: request.candidates ?? 
1, + ...request?.config, + }; + + for (const k in params) { + if (!params[k]) delete params[k]; + } + + return params; +} + +function fromImagenPrediction(p: ImagenPrediction, i: number): CandidateData { + const b64data = p.bytesBase64Encoded; + const mimeType = p.mimeType; + return { + index: i, + finishReason: 'stop', + message: { + role: 'model', + content: [ + { + media: { + url: `data:${mimeType};base64,${b64data}`, + contentType: mimeType, + }, + }, + ], + }, + }; +} + +/** + * + * @param response The response to convert + * @param request The request (for usage stats) + * @returns The converted response + */ +export function fromImagenResponse( + response: ImagenPredictResponse, + request: GenerateRequest +): GenerateResponseData { + const candidates = response.predictions.map(fromImagenPrediction); + return { + candidates, + usage: { + ...getBasicUsageStats(request.messages, candidates), + custom: { generations: candidates.length }, + }, + custom: response, + }; +} + +export function toLyriaPredictRequest( + request: GenerateRequest +): LyriaPredictRequest { + return { + instances: toLyriaInstances(request), + parameters: toLyriaParameters(request), + }; +} + +function toLyriaInstances( + request: GenerateRequest +): LyriaInstance[] { + let config = { ...request.config }; + delete config.sampleCount; // Sample count goes in parameters, the rest go in instances + return [ + { + prompt: extractText(request), + ...config, + }, + ]; +} + +function toLyriaParameters( + request: GenerateRequest +): LyriaParameters { + return { + sampleCount: request.config?.sampleCount || 1, + }; +} + +function fromLyriaPrediction(p: LyriaPrediction, i: number): CandidateData { + const b64data = p.bytesBase64Encoded; + const mimeType = p.mimeType; + return { + index: i, + finishReason: 'stop', + message: { + role: 'model', + content: [ + { + media: { + url: `data:${mimeType};base64,${b64data}`, + contentType: mimeType, + }, + }, + ], + }, + }; +} + +export function fromLyriaResponse( + response: LyriaPredictResponse, + request: GenerateRequest +): GenerateResponseData { + const candidates: CandidateData[] = + response.predictions.map(fromLyriaPrediction); + return { + candidates, + usage: { + ...getBasicUsageStats(request.messages, candidates), + custom: { generations: candidates.length }, + }, + custom: response, + }; +} + +export function toVeoPredictRequest( + request: GenerateRequest +): VeoPredictRequest { + return { + instances: toVeoInstances(request), + parameters: { ...request.config }, + }; +} + +function toVeoInstances( + request: GenerateRequest +): VeoInstance[] { + let instance: VeoInstance = { + prompt: extractText(request), + }; + const supportedImageTypes = ['image/jpeg', 'image/png']; + const supportedVideoTypes = ['video/mp4']; + + const imageMedia = extractMedia(request, { metadataType: 'image' }); + if (imageMedia) { + checkSupportedMimeType(imageMedia, supportedImageTypes); + instance.image = toVeoMedia(imageMedia); + } + + const lastFrameMedia = extractMedia(request, { metadataType: 'lastFrame' }); + if (lastFrameMedia) { + checkSupportedMimeType(lastFrameMedia, supportedImageTypes); + instance.lastFrame = toVeoMedia(lastFrameMedia); + } + + const videoMedia = extractMedia(request, { metadataType: 'video' }); + if (videoMedia) { + checkSupportedMimeType(videoMedia, supportedVideoTypes); + instance.video = toVeoMedia(videoMedia); + } + return [instance]; +} + +export function toVeoMedia(media: MediaPart['media']): VeoMedia { + let mimeType = media.contentType; + if (!mimeType) { + 
mimeType = extractMimeType(media.url);
+    if (!mimeType) {
+      throw new GenkitError({
+        status: 'INVALID_ARGUMENT',
+        message: 'Content type is required.',
+      });
+    }
+  }
+  if (media.url.startsWith('data:')) {
+    return {
+      bytesBase64Encoded: media.url?.split(',')[1],
+      mimeType,
+    };
+  } else if (media.url.startsWith('gs://')) {
+    return {
+      gcsUri: media.url,
+      mimeType,
+    };
+  } else if (media.url.startsWith('http')) {
+    throw new GenkitError({
+      status: 'INVALID_ARGUMENT',
+      message:
+        'Veo does not support http(s) URIs. Please specify a Cloud Storage URI.',
+    });
+  } else {
+    // Assume it's a non-prefixed data url
+    return {
+      bytesBase64Encoded: media.url,
+      mimeType,
+    };
+  }
+}
+
+export function fromVeoOperation(
+  fromOp: VeoOperation
+): Operation<GenerateResponseData> {
+  const toOp: Operation<GenerateResponseData> = { id: fromOp.name };
+  if (fromOp.done !== undefined) {
+    toOp.done = fromOp.done;
+  }
+  if (fromOp.error) {
+    toOp.error = { message: fromOp.error.message };
+  }
+
+  if (fromOp.response) {
+    toOp.output = {
+      finishReason: 'stop',
+      raw: fromOp.response,
+      message: {
+        role: 'model',
+        content: fromOp.response.videos.map((veoMedia) => {
+          if (veoMedia.bytesBase64Encoded) {
+            return {
+              media: {
+                url: `data:${veoMedia.mimeType};base64,${veoMedia.bytesBase64Encoded}`,
+                contentType: veoMedia.mimeType,
+              },
+            };
+          }
+
+          return {
+            media: {
+              url: veoMedia.gcsUri ?? '',
+              contentType: veoMedia.mimeType,
+            },
+          };
+        }),
+      },
+    };
+  }
+
+  return toOp;
+}
+
+export function toVeoModel(op: Operation<GenerateResponseData>): string {
+  return op.id.substring(
+    op.id.indexOf('models/') + 7,
+    op.id.indexOf('/operations/')
+  );
+}
+
+export function toVeoOperationRequest(
+  op: Operation<GenerateResponseData>
+): VeoOperationRequest {
+  return {
+    operationName: op.id,
+  };
+}
diff --git a/js/plugins/google-genai/src/vertexai/imagen.ts b/js/plugins/google-genai/src/vertexai/imagen.ts
index 33f22d82d2..c6490be135 100644
--- a/js/plugins/google-genai/src/vertexai/imagen.ts
+++ b/js/plugins/google-genai/src/vertexai/imagen.ts
@@ -16,32 +16,16 @@
 
 import { ActionMetadata, Genkit, modelActionMetadata, z } from 'genkit';
 import {
-  CandidateData,
-  GenerateRequest,
   GenerationCommonConfigSchema,
   ModelAction,
   ModelInfo,
   ModelReference,
-  getBasicUsageStats,
   modelRef,
 } from 'genkit/model';
 import { imagenPredict } from './client';
-import {
-  ClientOptions,
-  ImagenParameters,
-  ImagenPredictRequest,
-  ImagenPrediction,
-  Model,
-  VertexPluginOptions,
-} from './types.js';
-import {
-  checkModelName,
-  extractImagenImage,
-  extractImagenMask,
-  extractText,
-  extractVersion,
-  modelName,
-} from './utils';
+import { fromImagenResponse, toImagenPredictRequest } from './converters';
+import { ClientOptions, Model, VertexPluginOptions } from './types.js';
+import { checkModelName, extractVersion, modelName } from './utils';
 
 /**
  * See https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/imagen-api. 
@@ -215,6 +199,8 @@ const GENERIC_MODEL = commonRef('imagen', { export const KNOWN_MODELS = { 'imagen-3.0-generate-002': commonRef('imagen-3.0-generate-002'), + 'imagen-3.0-generate-001': commonRef('imagen-3.0-generate-001'), + 'imagen-3.0-capability-001': commonRef('imagen-3.0-capability-001'), 'imagen-3.0-fast-generate-001': commonRef('imagen-3.0-fast-generate-001'), 'imagen-4.0-generate-preview-06-06': commonRef( 'imagen-4.0-generate-preview-06-06' @@ -287,16 +273,7 @@ export function defineModel( }, async (request, { abortSignal }) => { const clientOpt = { ...clientOptions, signal: abortSignal }; - const imagenPredictRequest: ImagenPredictRequest = { - instances: [ - { - prompt: extractText(request), - image: extractImagenImage(request), - mask: extractImagenMask(request), - }, - ], - parameters: toImagenParameters(request), - }; + const imagenPredictRequest = toImagenPredictRequest(request); const response = await imagenPredict( extractVersion(ref), @@ -310,58 +287,12 @@ export function defineModel( ); } - const candidates = response.predictions.map(fromImagenPrediction); - - return { - candidates, - usage: { - ...getBasicUsageStats(request.messages, candidates), - custom: { generations: candidates.length }, - }, - custom: response, - }; + return fromImagenResponse(response, request); } ); } -function toImagenParameters( - request: GenerateRequest -): ImagenParameters { - const params = { - sampleCount: request.candidates ?? 1, - ...request?.config, - }; - - for (const k in params) { - if (!params[k]) delete params[k]; - } - - return params; -} - -function fromImagenPrediction(p: ImagenPrediction, i: number): CandidateData { - const b64data = p.bytesBase64Encoded; - const mimeType = p.mimeType; - return { - index: i, - finishReason: 'stop', - message: { - role: 'model', - content: [ - { - media: { - url: `data:${mimeType};base64,${b64data}`, - contentType: mimeType, - }, - }, - ], - }, - }; -} - export const TEST_ONLY = { - toImagenParameters, - fromImagenPrediction, GENERIC_MODEL, KNOWN_MODELS, }; diff --git a/js/plugins/google-genai/src/vertexai/index.ts b/js/plugins/google-genai/src/vertexai/index.ts index 1a45571e74..6e07f9b593 100644 --- a/js/plugins/google-genai/src/vertexai/index.ts +++ b/js/plugins/google-genai/src/vertexai/index.ts @@ -28,6 +28,8 @@ import { listModels } from './client.js'; import * as embedder from './embedder.js'; import * as gemini from './gemini.js'; import * as imagen from './imagen.js'; +import * as lyria from './lyria.js'; +import * as veo from './veo.js'; import { VertexPluginOptions } from './types.js'; import { getDerivedOptions } from './utils.js'; @@ -35,11 +37,15 @@ import { getDerivedOptions } from './utils.js'; export { type EmbeddingConfig } from './embedder.js'; export { type GeminiConfig } from './gemini.js'; export { type ImagenConfig } from './imagen.js'; +export { type LyriaConfig } from './lyria.js'; export { type VertexPluginOptions } from './types.js'; +export { type VeoConfig } from './veo.js'; async function initializer(ai: Genkit, pluginOptions?: VertexPluginOptions) { const clientOptions = await getDerivedOptions(pluginOptions); + veo.defineKnownModels(ai, clientOptions, pluginOptions); imagen.defineKnownModels(ai, clientOptions, pluginOptions); + lyria.defineKnownModels(ai, clientOptions, pluginOptions); gemini.defineKnownModels(ai, clientOptions, pluginOptions); embedder.defineKnownModels(ai, clientOptions, pluginOptions); } @@ -53,12 +59,21 @@ async function resolver( const clientOptions = await getDerivedOptions(pluginOptions); 
switch (actionType) { case 'model': - if (imagen.isImagenModelName(actionName)) { + if (lyria.isLyriaModelName(actionName)) { + lyria.defineModel(ai, actionName, clientOptions, pluginOptions); + } else if (imagen.isImagenModelName(actionName)) { imagen.defineModel(ai, actionName, clientOptions, pluginOptions); + } else if (veo.isVeoModelName(actionName)) { + // no-op (not gemini) } else { gemini.defineModel(ai, actionName, clientOptions, pluginOptions); } break; + case 'background-model': + if (veo.isVeoModelName(actionName)) { + veo.defineModel(ai, actionName, clientOptions, pluginOptions); + } + break; case 'embedder': embedder.defineEmbedder(ai, actionName, clientOptions, pluginOptions); break; @@ -74,6 +89,8 @@ async function listActions(options?: VertexPluginOptions) { return [ ...gemini.listActions(models), ...imagen.listActions(models), + ...lyria.listActions(models), + ...veo.listActions(models), // We don't list embedders here ]; } catch (e: unknown) { @@ -110,6 +127,14 @@ export type VertexAIPlugin = { name: imagen.KnownModels | (imagen.ImagenModelName & {}), config?: imagen.ImagenConfig ): ModelReference; + model( + name: lyria.KnownModels | (lyria.LyriaModelName & {}), + config: lyria.LyriaConfig + ): ModelReference; + model( + name: veo.KnownModels | (veo.VeoModelName & {}), + config: veo.VeoConfig + ): ModelReference; model(name: string, config?: any): ModelReference; embedder( @@ -131,6 +156,12 @@ export const vertexAI = vertexAIPlugin as VertexAIPlugin; if (imagen.isImagenModelName(name)) { return imagen.model(name, config); } + if (lyria.isLyriaModelName(name)) { + return lyria.model(name, config); + } + if (veo.isVeoModelName(name)) { + return veo.model(name, config); + } // gemini and unknown model families return gemini.model(name, config); }; diff --git a/js/plugins/google-genai/src/vertexai/lyria.ts b/js/plugins/google-genai/src/vertexai/lyria.ts new file mode 100644 index 0000000000..aaa5436e33 --- /dev/null +++ b/js/plugins/google-genai/src/vertexai/lyria.ts @@ -0,0 +1,164 @@ +/** + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { + ActionMetadata, + Genkit, + modelActionMetadata, + modelRef, + ModelReference, + z, +} from 'genkit'; +import { ModelAction, ModelInfo } from 'genkit/model'; +import { lyriaPredict } from './client'; +import { fromLyriaResponse, toLyriaPredictRequest } from './converters'; +import { ClientOptions, Model, VertexPluginOptions } from './types'; +import { checkModelName, extractVersion } from './utils'; + +export const LyriaConfigSchema = z + .object({ + negativePrompt: z + .string() + .optional() + .describe( + 'Optional. A description of what to exclude from the generated audio.' + ), + seed: z + .number() + .optional() + .describe( + 'Optional. A seed for deterministic generation. If provided, the model will attempt to produce the same audio given the same prompt and other parameters. Cannot be used with sample_count in the same request.' 
+ ), + sampleCount: z + .number() + .optional() + .describe( + 'Optional. The number of audio samples to generate. Default is 1 if not specified and seed is not used. Cannot be used with seed in the same request.' + ), + }) + .passthrough(); +export type LyriaConfigSchemaType = typeof LyriaConfigSchema; +export type LyriaConfig = z.infer; + +type ConfigSchemaType = LyriaConfigSchemaType; + +function commonRef( + name: string, + info?: ModelInfo, + configSchema: ConfigSchemaType = LyriaConfigSchema +): ModelReference { + return modelRef({ + name: `vertexai/${name}`, + configSchema, + info: info ?? { + supports: { + media: true, + multiturn: false, + tools: false, + systemRole: false, + output: ['media'], + }, + }, + }); +} + +const GENERIC_MODEL = commonRef('lyria'); + +const KNOWN_MODELS = { + 'lyria-002': commonRef('lyria-002'), +} as const; +export type KnownModels = keyof typeof KNOWN_MODELS; // For autocorrect +export type LyriaModelName = `lyria-${string}`; +export function isLyriaModelName(value?: string): value is LyriaModelName { + return !!value?.startsWith('lyria-'); +} + +export function model( + version: string, + config: LyriaConfig = {} +): ModelReference { + const name = checkModelName(version); + return modelRef({ + name: `vertexai/${name}`, + config, + configSchema: LyriaConfigSchema, + info: { ...GENERIC_MODEL.info }, + }); +} + +export function listActions(models: Model[]): ActionMetadata[] { + return models + .filter((m: Model) => isLyriaModelName(m.name)) + .map((m: Model) => { + const ref = model(m.name); + return modelActionMetadata({ + name: ref.name, + info: ref.info, + configSchema: ref.configSchema, + }); + }); +} + +export function defineKnownModels( + ai: Genkit, + clientOptions: ClientOptions, + pluginOptions?: VertexPluginOptions +) { + for (const name of Object.keys(KNOWN_MODELS)) { + defineModel(ai, name, clientOptions, pluginOptions); + } +} + +export function defineModel( + ai: Genkit, + name: string, + clientOptions: ClientOptions, + pluginOptions?: VertexPluginOptions +): ModelAction { + const ref = model(name); + + return ai.defineModel( + { + apiVersion: 'v2', + name: ref.name, + ...ref.info, + configSchema: ref.configSchema, + }, + async (request, { abortSignal }) => { + const clientOpt = { ...clientOptions, signal: abortSignal }; + const lyriaPredictRequest = toLyriaPredictRequest(request); + + const response = await lyriaPredict( + extractVersion(ref), + lyriaPredictRequest, + clientOpt + ); + + if (!response.predictions || response.predictions.length == 0) { + throw new Error( + 'Model returned no predictions. Possibly due to content filters.' 
+ ); + } + + return fromLyriaResponse(response, request); + } + ); +} + +export const TEST_ONLY = { + GENERIC_MODEL, + KNOWN_MODELS, +}; diff --git a/js/plugins/google-genai/src/vertexai/types.ts b/js/plugins/google-genai/src/vertexai/types.ts index 3154c1cb62..e2283e565d 100644 --- a/js/plugins/google-genai/src/vertexai/types.ts +++ b/js/plugins/google-genai/src/vertexai/types.ts @@ -295,3 +295,79 @@ export declare type EmbeddingResult = { embedding: number[]; metadata?: Record; }; + +export declare interface VeoMedia { + bytesBase64Encoded?: string; + gcsUri?: string; + mimeType?: string; +} + +export declare interface VeoInstance { + prompt: string; + image?: VeoMedia; + lastFrame?: VeoMedia; + video?: VeoMedia; +} + +export declare interface VeoParameters { + aspectRatio?: string; + durationSeconds?: number; + enhancePrompt?: boolean; + generateAudio?: boolean; + negativePrompt?: string; + personGeneration?: string; + resolution?: string; // Veo 3 + sampleCount?: number; + seed?: number; + storageUri?: string; +} + +export declare interface VeoPredictRequest { + instances: VeoInstance[]; + parameters: VeoParameters; +} + +export declare interface Operation { + name: string; + done?: boolean; + error?: { + code: number; + message: string; + details?: unknown; + }; +} + +export declare interface VeoOperation extends Operation { + response?: { + raiMediaFilteredCount?: number; + videos: VeoMedia[]; + }; +} + +export declare interface VeoOperationRequest { + operationName: string; +} + +export declare interface LyriaParameters { + sampleCount?: number; +} + +export declare interface LyriaPredictRequest { + instances: LyriaInstance[]; + parameters: LyriaParameters; +} + +export declare interface LyriaPredictResponse { + predictions: LyriaPrediction[]; +} + +export declare interface LyriaPrediction { + bytesBase64Encoded: string; // Base64 encoded Wav string + mimeType: string; // autio/wav +} + +export declare interface LyriaInstance { + prompt: string; + negativePrompt?: string; + seed?: number; +} diff --git a/js/plugins/google-genai/src/vertexai/utils.ts b/js/plugins/google-genai/src/vertexai/utils.ts index 1361449baa..f3f573e8c3 100644 --- a/js/plugins/google-genai/src/vertexai/utils.ts +++ b/js/plugins/google-genai/src/vertexai/utils.ts @@ -15,21 +15,21 @@ */ import { GenkitError } from 'genkit'; -import { GenerateRequest } from 'genkit/model'; import { GoogleAuth } from 'google-auth-library'; import type { ClientOptions, ExpressClientOptions, GlobalClientOptions, - ImagenInstance, RegionalClientOptions, VertexPluginOptions, } from './types'; export { checkModelName, + checkSupportedMimeType, cleanSchema, - extractImagenImage, + extractMedia, + extractMimeType, extractText, extractVersion, modelName, @@ -327,7 +327,7 @@ export function calculateApiKey( } /** Vertex Express Mode lets you try a *subset* of Vertex AI features */ -export function checkIsSupported(params: { +export function checkSupportedResourceMethod(params: { clientOptions: ClientOptions; resourcePath?: string; resourceMethod?: string; @@ -350,17 +350,3 @@ export function checkIsSupported(params: { throw NOT_SUPPORTED_IN_EXPRESS_ERROR; } } - -export function extractImagenMask( - request: GenerateRequest -): ImagenInstance['mask'] | undefined { - const mask = request.messages - .at(-1) - ?.content.find((p) => !!p.media && p.metadata?.type === 'mask') - ?.media?.url.split(',')[1]; - - if (mask) { - return { image: { bytesBase64Encoded: mask } }; - } - return undefined; -} diff --git 
a/js/plugins/google-genai/src/vertexai/veo.ts b/js/plugins/google-genai/src/vertexai/veo.ts new file mode 100644 index 0000000000..5cfea28ddb --- /dev/null +++ b/js/plugins/google-genai/src/vertexai/veo.ts @@ -0,0 +1,221 @@ +/** + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { + ActionMetadata, + ModelReference, + modelActionMetadata, + modelRef, + z, + type Genkit, +} from 'genkit'; +import { BackgroundModelAction, ModelInfo } from 'genkit/model'; +import { veoCheckOperation, veoPredict } from './client'; +import { + fromVeoOperation, + toVeoModel, + toVeoOperationRequest, + toVeoPredictRequest, +} from './converters'; +import { ClientOptions, Model, VertexPluginOptions } from './types'; +import { checkModelName, extractVersion } from './utils'; + +export const VeoConfigSchema = z + .object({ + sampleCount: z.number().optional().describe('Number of output videos'), + storageUri: z + .string() + .optional() + .describe('The gcs bucket where to save the generated videos'), + fps: z + .number() + .optional() + .describe('Frames per second for video generation'), + durationSeconds: z + .number() + .optional() + .describe('Duration of the clip for video generation in seconds'), + seed: z + .number() + .optional() + .describe( + 'The RNG seed. If RNG seed is exactly same for each request with unchanged ' + + 'inputs, the prediction results will be consistent. Otherwise, a random RNG ' + + 'seed will be used each time to produce a different result. If the sample ' + + 'count is greater than 1, random seeds will be used for each sample.' + ), + aspectRatio: z + .enum(['9:16', '16:9']) + .optional() + .describe('The aspect ratio for the generated video'), + resolution: z + .enum(['720p', '1080p']) + .optional() + .describe('The resolution for the generated video'), + personGeneration: z + .enum(['dont_allow', 'allow_adult', 'allow_all']) + .optional() + .describe( + 'Specifies the policy for generating persons in videos, including age restrictions' + ), + pubsubTopic: z + .string() + .optional() + .describe('The pubsub topic to publish the video generation progress to'), + negativePrompt: z + .string() + .optional() + .describe( + 'In addition to the text context, negative prompts can be explicitly stated here to help generate the video' + ), + enhancePrompt: z + .boolean() + .optional() + .describe( + 'If true, the prompt will be improved before it is used to generate videos. ' + + 'The RNG seed, if provided, will not result in consistent results if prompts are enhanced.' 
+ ), + generateAudio: z + .boolean() + .optional() + .describe('If true, audio will be generated along with the video'), + compressionQuality: z + .enum(['optimized', 'lossless']) + .default('optimized') + .optional() + .describe('Compression quality of the generated video'), + }) + .passthrough(); +export type VeoConfigSchemaType = typeof VeoConfigSchema; +export type VeoConfig = z.infer; + +// This contains all the Veo config schema types +type ConfigSchemaType = VeoConfigSchemaType; + +function commonRef( + name: string, + info?: ModelInfo, + configSchema: ConfigSchemaType = VeoConfigSchema +): ModelReference { + return modelRef({ + name: `vertexai/${name}`, + configSchema, + info: + info ?? + ({ + supports: { + media: true, + multiturn: false, + tools: false, + systemRole: false, + output: ['media'], + longRunning: true, + }, + } as ModelInfo), // TODO(ifielker): Remove this cast if we fix longRunning + }); +} + +const GENERIC_MODEL = commonRef('veo'); + +const KNOWN_MODELS = { + 'veo-2.0-generate-001': commonRef('veo-2.0-generate-001'), + 'veo-3.0-generate-001': commonRef('veo-3.0-generate-001'), + 'veo-3.0-fast-generate-001': commonRef('veo-3.0-fast-generate-001'), + 'veo-3.0-generate-preview': commonRef('veo-3.0-generate-preview'), + 'veo-3.0-fast-generate-preview': commonRef('veo-3.0-fast-generate-preview'), +} as const; +export type KnownModels = keyof typeof KNOWN_MODELS; // For autocomplete +export type VeoModelName = `veo-${string}`; +export function isVeoModelName(value?: string): value is VeoModelName { + return !!value?.startsWith('veo-'); +} + +export function model( + version: string, + config: VeoConfig = {} +): ModelReference { + const name = checkModelName(version); + return modelRef({ + name: `vertexai/${name}`, + config, + configSchema: VeoConfigSchema, + info: { ...GENERIC_MODEL.info }, + }); +} + +// Takes a full list of models, filters for current Veo models only +// and returns a modelActionMetadata for each. 
+export function listActions(models: Model[]): ActionMetadata[] { + return models + .filter((m: Model) => isVeoModelName(m.name)) + .map((m: Model) => { + const ref = model(m.name); + return modelActionMetadata({ + name: ref.name, + info: ref.info, + configSchema: ref.configSchema, + }); + }); +} + +export function defineKnownModels( + ai: Genkit, + clientOptions: ClientOptions, + pluginOptions?: VertexPluginOptions +) { + for (const name of Object.keys(KNOWN_MODELS)) { + defineModel(ai, name, clientOptions, pluginOptions); + } +} + +export function defineModel( + ai: Genkit, + name: string, + clientOptions: ClientOptions, + pluginOptions?: VertexPluginOptions +): BackgroundModelAction { + const ref = model(name); + + return ai.defineBackgroundModel({ + name: ref.name, + ...ref.info, + configSchema: ref.configSchema, + async start(request) { + const veoPredictRequest = toVeoPredictRequest(request); + + const response = await veoPredict( + extractVersion(ref), + veoPredictRequest, + clientOptions + ); + + return fromVeoOperation(response); + }, + async check(operation) { + const response = await veoCheckOperation( + toVeoModel(operation), + toVeoOperationRequest(operation), + clientOptions + ); + return fromVeoOperation(response); + }, + }); +} + +export const TEST_ONLY = { + GENERIC_MODEL, + KNOWN_MODELS, +}; diff --git a/js/plugins/google-genai/tests/common/converters_test.ts b/js/plugins/google-genai/tests/common/converters_test.ts index c48cc68a6e..cf570d4250 100644 --- a/js/plugins/google-genai/tests/common/converters_test.ts +++ b/js/plugins/google-genai/tests/common/converters_test.ts @@ -503,8 +503,6 @@ describe('fromGeminiCandidate', () => { }, }, }, - // NOTE: This test will fail until the bug in fromGeminiFileData is fixed. - // The code currently checks for .url instead of .fileUri. { should: 'should transform gemini candidate (fileData) correctly', geminiCandidate: { diff --git a/js/plugins/google-genai/tests/common/utils_test.ts b/js/plugins/google-genai/tests/common/utils_test.ts index 034a12881e..dcee880d93 100644 --- a/js/plugins/google-genai/tests/common/utils_test.ts +++ b/js/plugins/google-genai/tests/common/utils_test.ts @@ -5,7 +5,7 @@ * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -15,323 +15,497 @@ */ import * as assert from 'assert'; -import { GenkitError } from 'genkit'; +import { GenkitError, embedderRef, modelRef } from 'genkit'; import { GenerateRequest } from 'genkit/model'; import { describe, it } from 'node:test'; import { checkModelName, + checkSupportedMimeType, cleanSchema, + displayUrl, extractErrMsg, - extractImagenImage, + extractMedia, + extractMimeType, extractText, + extractVersion, modelName, } from '../../src/common/utils'; -describe('extractErrMsg', () => { - it('extracts message from an Error object', () => { - const error = new Error('This is a test error.'); - assert.strictEqual(extractErrMsg(error), 'This is a test error.'); - }); +describe('Common Utils', () => { + describe('extractErrMsg', () => { + it('extracts message from an Error object', () => { + const error = new Error('This is a test error.'); + assert.strictEqual(extractErrMsg(error), 'This is a test error.'); + }); - it('returns the string if error is a string', () => { - const error = 'A simple string error.'; - assert.strictEqual(extractErrMsg(error), 'A simple string error.'); - }); + it('returns the string if error is a string', () => { + const error = 'A simple string error.'; + assert.strictEqual(extractErrMsg(error), 'A simple string error.'); + }); - it('stringifies other error types', () => { - const error = { code: 500, message: 'Object error' }; - assert.strictEqual( - extractErrMsg(error), - '{"code":500,"message":"Object error"}' - ); + it('stringifies other error types', () => { + const error = { code: 500, message: 'Object error' }; + assert.strictEqual( + extractErrMsg(error), + '{"code":500,"message":"Object error"}' + ); + }); }); - it('provides a default message for unknown types', () => { - // Note: The function returns undefined for an undefined input because - // JSON.stringify(undefined) results in undefined. 
- assert.strictEqual(extractErrMsg(undefined), undefined); - assert.strictEqual(extractErrMsg(null), 'null'); - }); -}); + describe('extractVersion', () => { + it('should return version from modelRef if present', () => { + const ref = modelRef({ + name: 'vertexai/gemini-1.5-pro', + version: 'gemini-1.5-pro-001', + }); + assert.strictEqual(extractVersion(ref), 'gemini-1.5-pro-001'); + }); -describe('modelName', () => { - it('extracts model name from a full path', () => { - const name = 'models/googleai/gemini-2.5-pro'; - assert.strictEqual(modelName(name), 'gemini-2.5-pro'); - }); + it('should extract version from name if version field is missing', () => { + const ref = modelRef({ name: 'vertexai/gemini-1.5-flash' }); + assert.strictEqual(extractVersion(ref), 'gemini-1.5-flash'); + }); - it('returns the name if no path is present', () => { - const name = 'gemini-1.5-flash'; - assert.strictEqual(modelName(name), 'gemini-1.5-flash'); + it('should work with embedderRef', () => { + const ref = embedderRef({ name: 'vertexai/text-embedding-004' }); + assert.strictEqual(extractVersion(ref), 'text-embedding-004'); + }); }); - it('handles undefined input', () => { - assert.strictEqual(modelName(undefined), undefined); + describe('modelName', () => { + it('extracts model name from a full path', () => { + assert.strictEqual( + modelName('models/googleai/gemini-1.5-pro'), + 'gemini-1.5-pro' + ); + assert.strictEqual( + modelName('vertexai/gemini-1.5-flash'), + 'gemini-1.5-flash' + ); + assert.strictEqual(modelName('model/foo'), 'foo'); + assert.strictEqual(modelName('embedders/bar'), 'bar'); + assert.strictEqual(modelName('background-model/baz'), 'baz'); + }); + + it('returns the name if no known prefix is present', () => { + assert.strictEqual(modelName('gemini-1.0-ultra'), 'gemini-1.0-ultra'); + }); + + it('handles undefined input', () => { + assert.strictEqual(modelName(undefined), undefined); + }); + + it('handles empty string input', () => { + assert.strictEqual(modelName(''), ''); + }); }); - it('handles empty string input', () => { - assert.strictEqual(modelName(''), ''); + describe('checkModelName', () => { + it('extracts model name from a full path', () => { + const name = 'models/vertexai/gemini-1.5-pro'; + assert.strictEqual(checkModelName(name), 'gemini-1.5-pro'); + }); + + it('returns name if no prefix', () => { + assert.strictEqual(checkModelName('foo-bar'), 'foo-bar'); + }); + + it('throws an error for undefined input', () => { + assert.throws( + () => checkModelName(undefined), + (err: any) => { + assert.ok(err instanceof GenkitError, 'Expected GenkitError'); + assert.strictEqual(err.status, 'INVALID_ARGUMENT'); + assert.strictEqual( + err.message, + 'INVALID_ARGUMENT: Model name is required.' + ); + return true; + } + ); + }); + + it('throws an error for an empty string', () => { + assert.throws( + () => checkModelName(''), + (err: any) => { + assert.ok(err instanceof GenkitError, 'Expected GenkitError'); + assert.strictEqual(err.status, 'INVALID_ARGUMENT'); + assert.strictEqual( + err.message, + 'INVALID_ARGUMENT: Model name is required.' + ); + return true; + } + ); + }); }); - it('keeps prefixes like tunedModels', () => { - assert.strictEqual( - modelName('tunedModels/my-tuned-model'), - 'tunedModels/my-tuned-model' - ); + describe('extractText', () => { + it('extracts text from the last message', () => { + const request: GenerateRequest = { + messages: [ + { role: 'user', content: [{ text: 'Hello there.' }] }, + { role: 'model', content: [{ text: 'How can I help?' 
}] }, + { role: 'user', content: [{ text: 'Tell me a joke.' }] }, + ], + }; + assert.strictEqual(extractText(request), 'Tell me a joke.'); + }); + + it('concatenates multiple text parts in the last message', () => { + const request: GenerateRequest = { + messages: [ + { + role: 'user', + content: [{ text: 'Part 1. ' }, { text: 'Part 2.' }], + }, + ], + }; + assert.strictEqual(extractText(request), 'Part 1. Part 2.'); + }); + + it('ignores non-text parts in the last message', () => { + const request: GenerateRequest = { + messages: [ + { + role: 'user', + content: [ + { text: 'A ' }, + { media: { url: 'data:image/jpeg;base64,IMAGEDATA' } }, + { text: 'B' }, + ], + }, + ], + }; + assert.strictEqual(extractText(request), 'A B'); + }); + + it('returns an empty string if there are no text parts in the last message', () => { + const request: GenerateRequest = { + messages: [ + { + role: 'user', + content: [{ media: { url: 'data:image/jpeg;base64,IMAGEDATA' } }], + }, + ], + }; + assert.strictEqual(extractText(request), ''); + }); + + it('returns an empty string if there are no messages', () => { + const request: GenerateRequest = { + messages: [], + }; + assert.strictEqual(extractText(request), ''); + }); }); -}); -describe('checkModelName', () => { - it('extracts model name from a full path', () => { - const name = 'models/vertexai/gemini-2.0-pro'; - assert.strictEqual(checkModelName(name), 'gemini-2.0-pro'); + describe('extractMimeType', () => { + it('extracts from data URL with base64', () => { + assert.strictEqual( + extractMimeType('data:image/png;base64,iVBORw0KGgoAAAANSUhEUgA...'), + 'image/png' + ); + assert.strictEqual( + extractMimeType('data:application/pdf;base64,JVBERi0xLjQKJ...'), + 'application/pdf' + ); + }); + + it('returns empty string for invalid data URL format', () => { + assert.strictEqual(extractMimeType('data:image/png'), ''); + assert.strictEqual(extractMimeType('data:,text'), ''); + }); + + it('extracts from known file extensions', () => { + assert.strictEqual(extractMimeType('image.jpg'), 'image/jpeg'); + assert.strictEqual(extractMimeType('path/to/document.png'), 'image/png'); + assert.strictEqual(extractMimeType('video.mp4'), 'video/mp4'); + }); + + it('returns empty string for unknown file extensions', () => { + assert.strictEqual(extractMimeType('file.unknown'), ''); + assert.strictEqual(extractMimeType('archive.zip'), ''); + }); + + it('returns empty string for URL without extension', () => { + assert.strictEqual(extractMimeType('http://example.com/image'), ''); + }); + + it('returns empty string for undefined or empty input', () => { + assert.strictEqual(extractMimeType(undefined), ''); + assert.strictEqual(extractMimeType(''), ''); + }); }); - it('throws an error for undefined input', () => { - assert.throws( - () => checkModelName(undefined), - (err: GenkitError) => { - assert.strictEqual(err.status, 'INVALID_ARGUMENT'); - assert.strictEqual( - err.message, - 'INVALID_ARGUMENT: Model name is required.' 
+ describe('checkSupportedMimeType', () => { + const supported = ['image/jpeg', 'image/png']; + it('should not throw for supported mime types', () => { + assert.doesNotThrow(() => + checkSupportedMimeType( + { url: 'test.jpg', contentType: 'image/jpeg' }, + supported + ) + ); + assert.doesNotThrow(() => + checkSupportedMimeType( + { url: 'test.png', contentType: 'image/png' }, + supported + ) + ); + }); + + it('should throw GenkitError for unsupported mime types', () => { + try { + checkSupportedMimeType( + { url: 'test.gif', contentType: 'image/gif' }, + supported + ); + assert.fail('Should have thrown'); + } catch (e: any) { + assert.ok(e instanceof GenkitError, 'Expected GenkitError'); + assert.strictEqual(e.status, 'INVALID_ARGUMENT'); + assert.ok( + e.message.includes('Invalid mimeType for test.gif: "image/gif"') + ); + assert.ok( + e.message.includes('Supported mimeTypes: image/jpeg, image/png') ); - return true; } - ); - }); + }); - it('throws an error for an empty string', () => { - assert.throws( - () => checkModelName(''), - (err: GenkitError) => { - assert.strictEqual(err.status, 'INVALID_ARGUMENT'); - assert.strictEqual( - err.message, - 'INVALID_ARGUMENT: Model name is required.' + it('should throw GenkitError if contentType is missing', () => { + try { + checkSupportedMimeType({ url: 'test.jpg' }, supported); + assert.fail('Should have thrown'); + } catch (e: any) { + assert.ok(e instanceof GenkitError, 'Expected GenkitError'); + assert.strictEqual(e.status, 'INVALID_ARGUMENT'); + assert.ok( + e.message.includes('Invalid mimeType for test.jpg: "undefined"') ); - return true; } - ); + }); }); -}); -describe('extractText', () => { - it('extracts text from the last message', () => { - const request: GenerateRequest = { - messages: [ - { role: 'user', content: [{ text: 'Hello there.' }] }, - { role: 'model', content: [{ text: 'How can I help?' }] }, - { role: 'user', content: [{ text: 'Tell me a joke.' }] }, - ], - config: {}, - }; - assert.strictEqual(extractText(request), 'Tell me a joke.'); - }); + describe('displayUrl', () => { + it('should return the full URL if short', () => { + const url = 'http://example.com/short'; + assert.strictEqual(displayUrl(url), url); + }); - it('concatenates multiple text parts', () => { - const request: GenerateRequest = { - messages: [ - { - role: 'user', - content: [{ text: 'Part 1. ' }, { text: 'Part 2.' }], - }, - ], - config: {}, - }; - assert.strictEqual(extractText(request), 'Part 1. 
Part 2.'); - }); + it('should truncate long URLs', () => { + const longUrl = + 'http://example.com/this/is/a/very/long/url/that/needs/truncation/to/fit'; + const expected = 'http://example.com/this/i...t/needs/truncation/to/fit'; + assert.strictEqual(displayUrl(longUrl), expected); + }); - it('returns an empty string if there are no text parts', () => { - const request: GenerateRequest = { - messages: [ - { - role: 'user', - content: [ - { - media: { - url: 'data:image/jpeg;base64,IMAGEDATA', - contentType: 'image/jpeg', - }, - }, - ], - }, - ], - config: {}, - }; - assert.strictEqual(extractText(request), ''); + it('should handle URLs exactly at the limit', () => { + const url = 'a'.repeat(50); + assert.strictEqual(displayUrl(url), url); + }); }); - it('returns an empty string if there are no messages', () => { - const request: GenerateRequest = { - messages: [], - config: {}, + describe('extractMedia', () => { + const imageMedia = { + url: 'data:image/png;base64,IMAGEDATA', + contentType: 'image/png', }; - assert.strictEqual(extractText(request), ''); - }); -}); - -describe('extractImagenImage', () => { - it('extracts a base64 encoded image', () => { - const base64Image = '/9j/4AAQSkZJRg...'; - const request: GenerateRequest = { - messages: [ - { - role: 'user', - content: [ - { text: 'Create an image.' }, - { - media: { - url: `data:image/jpeg;base64,${base64Image}`, - contentType: 'image/jpeg', - }, - }, - ], - }, - ], - config: {}, + const videoMedia = { + url: 'data:video/mp4;base64,VIDEODATA', + contentType: 'video/mp4', }; - const result = extractImagenImage(request); - assert.deepStrictEqual(result, { bytesBase64Encoded: base64Image }); - }); - it('returns undefined if no image part exists', () => { - const request: GenerateRequest = { - messages: [{ role: 'user', content: [{ text: 'Hello' }] }], - config: {}, - }; - assert.strictEqual(extractImagenImage(request), undefined); - }); + it('extracts any media from the last message if no params', () => { + const request: GenerateRequest = { + messages: [ + { role: 'user', content: [{ text: 'A ' }, { media: imageMedia }] }, + ], + }; + assert.deepStrictEqual(extractMedia(request, {}), imageMedia); + }); - it('returns undefined if the media part is not a base64 data URI', () => { - const request: GenerateRequest = { - messages: [ - { - role: 'user', - content: [ - { - media: { - url: 'http://example.com/image.jpg', - contentType: 'image/jpeg', - }, - }, - ], - }, - ], - config: {}, - }; - assert.strictEqual(extractImagenImage(request), undefined); - }); + it('extracts media matching metadataType', () => { + const request: GenerateRequest = { + messages: [ + { + role: 'user', + content: [ + { media: imageMedia, metadata: { type: 'image' } }, + { media: videoMedia, metadata: { type: 'video' } }, + ], + }, + ], + }; + assert.deepStrictEqual( + extractMedia(request, { metadataType: 'video' }), + videoMedia + ); + assert.deepStrictEqual( + extractMedia(request, { metadataType: 'image' }), + imageMedia + ); + }); - it('ignores parts with metadata type not equal to "base"', () => { - const request: GenerateRequest = { - messages: [ - { - role: 'user', - content: [ - { - media: { - url: 'data:image/png;base64,MASKDATA', - contentType: 'image/png', - }, - metadata: { type: 'mask' }, - }, - ], - }, - ], - config: {}, - }; - assert.strictEqual(extractImagenImage(request), undefined); - }); + it('extracts media with no metadata type if isDefault is true', () => { + const request: GenerateRequest = { + messages: [ + { + role: 'user', + content: [ + 
{ media: imageMedia }, + { media: videoMedia, metadata: { type: 'video' } }, + ], + }, + ], + }; + assert.deepStrictEqual( + extractMedia(request, { metadataType: 'image', isDefault: true }), + imageMedia + ); + }); - it('returns undefined for an empty message list', () => { - const request: GenerateRequest = { - messages: [], - config: {}, - }; - assert.strictEqual(extractImagenImage(request), undefined); - }); -}); + it('does not extract media with different metadataType even if isDefault is true', () => { + const request: GenerateRequest = { + messages: [ + { + role: 'user', + content: [{ media: videoMedia, metadata: { type: 'video' } }], + }, + ], + }; + assert.strictEqual( + extractMedia(request, { metadataType: 'image', isDefault: true }), + undefined + ); + }); -describe('cleanSchema', () => { - it('strips $schema and additionalProperties', () => { - const schema = { - type: 'object', - properties: { name: { type: 'string' } }, - $schema: 'http://json-schema.org/draft-07/schema#', - additionalProperties: false, - }; - const cleaned = cleanSchema(schema); - assert.deepStrictEqual(cleaned, { - type: 'object', - properties: { name: { type: 'string' } }, + it('returns undefined if no media matches metadataType', () => { + const request: GenerateRequest = { + messages: [ + { + role: 'user', + content: [{ media: imageMedia, metadata: { type: 'image' } }], + }, + ], + }; + assert.strictEqual( + extractMedia(request, { metadataType: 'video' }), + undefined + ); + }); + + it('infers contentType if missing', () => { + const request: GenerateRequest = { + messages: [ + { + role: 'user', + content: [{ media: { url: 'data:image/jpeg;base64,DATA' } }], + }, + ], + }; + const result = extractMedia(request, {}); + assert.deepStrictEqual(result, { + url: 'data:image/jpeg;base64,DATA', + contentType: 'image/jpeg', + }); + }); + + it('returns undefined if no media parts in the last message', () => { + const request: GenerateRequest = { + messages: [{ role: 'user', content: [{ text: 'No media' }] }], + }; + assert.strictEqual(extractMedia(request, {}), undefined); + }); + + it('returns undefined for empty messages array', () => { + const request: GenerateRequest = { messages: [] }; + assert.strictEqual(extractMedia(request, {}), undefined); }); }); - it('handles nested objects', () => { - const schema = { - type: 'object', - properties: { - user: { - type: 'object', - properties: { id: { type: 'number' } }, - additionalProperties: true, + describe('cleanSchema', () => { + it('strips $schema and additionalProperties', () => { + const schema = { + type: 'object', + properties: { name: { type: 'string' } }, + $schema: 'http://json-schema.org/draft-07/schema#', + additionalProperties: false, + }; + const cleaned = cleanSchema(schema); + assert.deepStrictEqual(cleaned, { + type: 'object', + properties: { name: { type: 'string' } }, + }); + }); + + it('handles nested objects', () => { + const schema = { + type: 'object', + properties: { + user: { + type: 'object', + properties: { id: { type: 'number' } }, + additionalProperties: true, + }, }, - }, - }; - const cleaned = cleanSchema(schema); - assert.deepStrictEqual(cleaned, { - type: 'object', - properties: { - user: { - type: 'object', - properties: { id: { type: 'number' } }, + }; + const cleaned = cleanSchema(schema); + assert.deepStrictEqual(cleaned, { + type: 'object', + properties: { + user: { + type: 'object', + properties: { id: { type: 'number' } }, + }, }, - }, + }); }); - }); - it('converts type ["string", "null"] to "string"', () => { - const 
schema = { - type: 'object', - properties: { - name: { type: ['string', 'null'] }, - age: { type: ['number', 'null'] }, - }, - }; - const cleaned = cleanSchema(schema); - assert.deepStrictEqual(cleaned, { - type: 'object', - properties: { - name: { type: 'string' }, - age: { type: 'number' }, - }, + it('converts type ["string", "null"] to "string"', () => { + const schema = { + type: 'object', + properties: { + name: { type: ['string', 'null'] }, + age: { type: ['number', 'null'] }, + }, + }; + const cleaned = cleanSchema(schema); + assert.deepStrictEqual(cleaned, { + type: 'object', + properties: { + name: { type: 'string' }, + age: { type: 'number' }, + }, + }); }); - }); - it('converts type ["null", "string"] to "string"', () => { - const schema = { - type: 'object', - properties: { - name: { type: ['null', 'string'] }, - }, - }; - const cleaned = cleanSchema(schema); - assert.deepStrictEqual(cleaned, { - type: 'object', - properties: { - name: { type: 'string' }, - }, + it('converts type ["null", "boolean"] to "boolean"', () => { + const schema = { + type: 'object', + properties: { + isActive: { type: ['null', 'boolean'] }, + }, + }; + const cleaned = cleanSchema(schema); + assert.deepStrictEqual(cleaned, { + type: 'object', + properties: { + isActive: { type: 'boolean' }, + }, + }); }); - }); - it('leaves other properties untouched', () => { - const schema = { - type: 'string', - description: 'A name', - maxLength: 100, - }; - const cleaned = cleanSchema(schema); - assert.deepStrictEqual(cleaned, schema); + it('leaves other properties untouched', () => { + const schema = { + type: 'string', + description: 'A name', + maxLength: 100, + }; + const cleaned = cleanSchema(schema); + assert.deepStrictEqual(cleaned, schema); + }); }); }); diff --git a/js/plugins/google-genai/tests/vertexai/client_test.ts b/js/plugins/google-genai/tests/vertexai/client_test.ts index ac66b6618c..ae69ba7a38 100644 --- a/js/plugins/google-genai/tests/vertexai/client_test.ts +++ b/js/plugins/google-genai/tests/vertexai/client_test.ts @@ -27,6 +27,9 @@ import { getVertexAIUrl, imagenPredict, listModels, + lyriaPredict, + veoCheckOperation, + veoPredict, } from '../../src/vertexai/client'; import { ClientOptions, @@ -38,7 +41,12 @@ import { GenerateContentStreamResult, ImagenPredictRequest, ImagenPredictResponse, + LyriaPredictRequest, + LyriaPredictResponse, Model, + VeoOperation, + VeoOperationRequest, + VeoPredictRequest, } from '../../src/vertexai/types'; describe('Vertex AI Client', () => { @@ -145,6 +153,45 @@ describe('Vertex AI Client', () => { ); }); + it('should build URL for predict', () => { + const url = getVertexAIUrl({ + includeProjectAndLocation: true, + resourcePath: 'publishers/google/models/imagen-3.0', + resourceMethod: 'predict', + clientOptions: opts, + }); + assert.strictEqual( + url, + 'https://us-east1-aiplatform.googleapis.com/v1beta1/projects/test-proj/locations/us-east1/publishers/google/models/imagen-3.0:predict' + ); + }); + + it('should build URL for predictLongRunning', () => { + const url = getVertexAIUrl({ + includeProjectAndLocation: true, + resourcePath: 'publishers/google/models/veo-2.0', + resourceMethod: 'predictLongRunning', + clientOptions: opts, + }); + assert.strictEqual( + url, + 'https://us-east1-aiplatform.googleapis.com/v1beta1/projects/test-proj/locations/us-east1/publishers/google/models/veo-2.0:predictLongRunning' + ); + }); + + it('should build URL for fetchPredictOperation', () => { + const url = getVertexAIUrl({ + includeProjectAndLocation: true, + resourcePath: 
'publishers/google/models/veo-2.0', + resourceMethod: 'fetchPredictOperation', + clientOptions: opts, + }); + assert.strictEqual( + url, + 'https://us-east1-aiplatform.googleapis.com/v1beta1/projects/test-proj/locations/us-east1/publishers/google/models/veo-2.0:fetchPredictOperation' + ); + }); + it('should handle queryParams', () => { const url = getVertexAIUrl({ includeProjectAndLocation: false, @@ -206,7 +253,29 @@ describe('Vertex AI Client', () => { resourcePath: 'publishers/google/models', clientOptions: opts, }); - }, 'This method is not supported in Vertex AI Express Mode/'); + }, /This method is not supported in Vertex AI Express Mode/); + }); + + it('should not support predict', () => { + assert.throws(() => { + return getVertexAIUrl({ + includeProjectAndLocation: true, + resourcePath: 'publishers/google/models/imagen-3.0', + resourceMethod: 'predict', + clientOptions: opts, + }); + }, /This method is not supported in Vertex AI Express Mode/); + }); + + it('should not support predictLongRunning', () => { + assert.throws(() => { + return getVertexAIUrl({ + includeProjectAndLocation: true, + resourcePath: 'publishers/google/models/veo-2.0', + resourceMethod: 'predictLongRunning', + clientOptions: opts, + }); + }, /This method is not supported in Vertex AI Express Mode/); }); it('should build URL for generateContent', () => { @@ -309,7 +378,11 @@ describe('Vertex AI Client', () => { return `https://${domain}/v1beta1/${path}`; }; - const getResourceUrl = (model: string, method: string) => { + const getResourceUrl = ( + model: string, + method: string, + isLongRunning = false + ) => { const isStreaming = method.includes('streamGenerateContent'); let url; @@ -349,7 +422,6 @@ describe('Vertex AI Client', () => { headers: getExpectedHeaders(), }); - // Corrected assertions using sinon.assert: if (!isExpress) { sinon.assert.calledOnce(authMock.getAccessToken); } else { @@ -461,6 +533,103 @@ describe('Vertex AI Client', () => { } }); + describe('lyriaPredict', () => { + const request: LyriaPredictRequest = { + instances: [{ prompt: 'a song' }], + parameters: { sampleCount: 1 }, + }; + const model = 'lyria-002'; + if (!isExpress) { + it('should return LyriaPredictResponse', async () => { + const mockResponse: LyriaPredictResponse = { predictions: [] }; + mockFetchResponse(mockResponse); + await lyriaPredict(model, request, currentOptions); + + const expectedUrl = getResourceUrl(model, 'predict'); + sinon.assert.calledOnceWithExactly(fetchSpy, expectedUrl, { + method: 'POST', + headers: getExpectedHeaders(), + body: JSON.stringify(request), + }); + }); + } else { + it('should throw with unsupported for Express', async () => { + await assert.rejects( + lyriaPredict(model, request, currentOptions), + /This method is not supported in Vertex AI Express Mode/ + ); + }); + } + }); + + describe('veoPredict', () => { + const request: VeoPredictRequest = { + instances: [{ prompt: 'a video' }], + parameters: {}, + }; + const model = 'veo-2.0-generate-001'; + if (!isExpress) { + it('should return VeoOperation', async () => { + const mockResponse: VeoOperation = { name: 'operations/123' }; + mockFetchResponse(mockResponse); + await veoPredict(model, request, currentOptions); + + const expectedUrl = getResourceUrl( + model, + 'predictLongRunning', + true + ); + sinon.assert.calledOnceWithExactly(fetchSpy, expectedUrl, { + method: 'POST', + headers: getExpectedHeaders(), + body: JSON.stringify(request), + }); + }); + } else { + it('should throw with unsupported for Express', async () => { + await 
assert.rejects( + veoPredict(model, request, currentOptions), + /This method is not supported in Vertex AI Express Mode/ + ); + }); + } + }); + + describe('veoCheckOperation', () => { + const request: VeoOperationRequest = { + operationName: 'operations/123', + }; + const model = 'veo-2.0-generate-001'; + if (!isExpress) { + it('should return VeoOperation', async () => { + const mockResponse: VeoOperation = { + name: 'operations/123', + done: true, + }; + mockFetchResponse(mockResponse); + await veoCheckOperation(model, request, currentOptions); + + const expectedUrl = getResourceUrl( + model, + 'fetchPredictOperation', + true + ); + sinon.assert.calledOnceWithExactly(fetchSpy, expectedUrl, { + method: 'POST', + headers: getExpectedHeaders(), + body: JSON.stringify(request), + }); + }); + } else { + it('should throw with unsupported for Express', async () => { + await assert.rejects( + veoCheckOperation(model, request, currentOptions), + /This method is not supported in Vertex AI Express Mode/ + ); + }); + } + }); + describe('generateContentStream', () => { it('should process stream', async () => { const request: GenerateContentRequest = { diff --git a/js/plugins/google-genai/tests/vertexai/converters_test.ts b/js/plugins/google-genai/tests/vertexai/converters_test.ts index 3f9b9816e4..a769eb0b0c 100644 --- a/js/plugins/google-genai/tests/vertexai/converters_test.ts +++ b/js/plugins/google-genai/tests/vertexai/converters_test.ts @@ -15,14 +15,31 @@ */ import * as assert from 'assert'; -import { z } from 'genkit'; +import { GenerateRequest, z } from 'genkit'; import { describe, it } from 'node:test'; import { HarmBlockThreshold, HarmCategory } from '../../src/common/types'; import { + fromImagenResponse, + fromLyriaResponse, + fromVeoOperation, toGeminiLabels, toGeminiSafetySettings, + toImagenPredictRequest, + toLyriaPredictRequest, + toVeoMedia, + toVeoModel, + toVeoOperationRequest, + toVeoPredictRequest, } from '../../src/vertexai/converters'; import { SafetySettingsSchema } from '../../src/vertexai/gemini'; +import { ImagenConfigSchema } from '../../src/vertexai/imagen'; +import { LyriaConfigSchema } from '../../src/vertexai/lyria'; +import { + ImagenPredictResponse, + LyriaPredictResponse, + VeoOperation, +} from '../../src/vertexai/types'; +import { VeoConfigSchema } from '../../src/vertexai/veo'; describe('Vertex AI Converters', () => { describe('toGeminiSafetySettings', () => { @@ -101,13 +118,9 @@ describe('Vertex AI Converters', () => { it('returns undefined if all keys are empty strings', () => { const labels = { '': 'value1', - ' ': 'value2', // This key is not empty string, so it will be kept - }; - const expected = { - ' ': 'value2', }; const result = toGeminiLabels(labels); - assert.deepStrictEqual(result, expected); + assert.strictEqual(result, undefined); }); it('handles labels with empty values', () => { @@ -123,4 +136,424 @@ describe('Vertex AI Converters', () => { assert.deepStrictEqual(result, expected); }); }); + + describe('toImagenPredictRequest', () => { + const baseRequest: GenerateRequest = { + messages: [{ role: 'user', content: [{ text: 'A cat on a mat' }] }], + }; + + it('should create a basic ImagenPredictRequest with default sampleCount', () => { + const result = toImagenPredictRequest(baseRequest); + assert.deepStrictEqual(result, { + instances: [{ prompt: 'A cat on a mat' }], + parameters: { sampleCount: 1 }, + }); + }); + + it('should handle candidates and config parameters', () => { + const request: GenerateRequest = { + ...baseRequest, + candidates: 2, 
+ config: { + seed: 42, + negativePrompt: 'ugly', + aspectRatio: '16:9', + }, + }; + const result = toImagenPredictRequest(request); + assert.deepStrictEqual(result, { + instances: [{ prompt: 'A cat on a mat' }], + parameters: { + sampleCount: 2, + seed: 42, + negativePrompt: 'ugly', + aspectRatio: '16:9', + }, + }); + }); + + it('should omit undefined or null config parameters', () => { + const request: GenerateRequest = { + ...baseRequest, + config: { + negativePrompt: undefined, + seed: null as any, + aspectRatio: '1:1', + }, + }; + const result = toImagenPredictRequest(request); + assert.deepStrictEqual(result, { + instances: [{ prompt: 'A cat on a mat' }], + parameters: { + sampleCount: 1, + aspectRatio: '1:1', + }, + }); + }); + + it('should handle image and mask media', () => { + const request: GenerateRequest = { + messages: [ + { + role: 'user', + content: [ + { text: 'A dog on a rug' }, + { + media: { + url: 'data:image/png;base64,IMAGEDATA', + contentType: 'image/png', + }, + metadata: { type: 'image' }, + }, + { + media: { + url: 'data:image/png;base64,MASKDATA', + contentType: 'image/png', + }, + metadata: { type: 'mask' }, + }, + ], + }, + ], + }; + const result = toImagenPredictRequest(request); + assert.deepStrictEqual(result, { + instances: [ + { + prompt: 'A dog on a rug', + image: { bytesBase64Encoded: 'IMAGEDATA' }, + mask: { image: { bytesBase64Encoded: 'MASKDATA' } }, + }, + ], + parameters: { sampleCount: 1 }, + }); + }); + }); + + describe('fromImagenResponse', () => { + it('should convert ImagenPredictResponse to GenerateResponseData', () => { + const response: ImagenPredictResponse = { + predictions: [ + { bytesBase64Encoded: 'IMAGE1', mimeType: 'image/jpeg' }, + { bytesBase64Encoded: 'IMAGE2', mimeType: 'image/png' }, + ], + }; + const request: GenerateRequest = { + messages: [{ role: 'user', content: [{ text: 'test' }] }], + }; + const result = fromImagenResponse(response, request); + + assert.strictEqual(result.candidates?.length, 2); + // Test structure from fromImagenPrediction logic + assert.deepStrictEqual(result.candidates?.[0], { + index: 0, + finishReason: 'stop', + message: { + role: 'model', + content: [ + { + media: { + url: 'data:image/jpeg;base64,IMAGE1', + contentType: 'image/jpeg', + }, + }, + ], + }, + }); + assert.deepStrictEqual(result.candidates?.[1], { + index: 1, + finishReason: 'stop', + message: { + role: 'model', + content: [ + { + media: { + url: 'data:image/png;base64,IMAGE2', + contentType: 'image/png', + }, + }, + ], + }, + }); + assert.strictEqual(result.custom, response); + assert.ok(result.usage); + }); + }); + + describe('toLyriaPredictRequest', () => { + const baseRequest: GenerateRequest = { + messages: [{ role: 'user', content: [{ text: 'A happy song' }] }], + }; + + it('should create a basic LyriaPredictRequest', () => { + const result = toLyriaPredictRequest(baseRequest); + assert.deepStrictEqual(result, { + instances: [{ prompt: 'A happy song' }], + parameters: { sampleCount: 1 }, + }); + }); + + it('should handle config parameters', () => { + const request: GenerateRequest = { + ...baseRequest, + config: { + negativePrompt: 'sad', + seed: 123, + sampleCount: 3, + }, + }; + const result = toLyriaPredictRequest(request); + assert.deepStrictEqual(result, { + instances: [ + { prompt: 'A happy song', negativePrompt: 'sad', seed: 123 }, + ], + parameters: { sampleCount: 3 }, + }); + }); + }); + + describe('fromLyriaResponse', () => { + it('should convert LyriaPredictResponse to GenerateResponseData', () => { + const response: 
LyriaPredictResponse = { + predictions: [{ bytesBase64Encoded: 'AUDIO1', mimeType: 'audio/wav' }], + }; + const request: GenerateRequest = { + messages: [{ role: 'user', content: [{ text: 'test' }] }], + }; + const result = fromLyriaResponse(response, request); + + assert.strictEqual(result.candidates?.length, 1); + + assert.deepStrictEqual(result.candidates?.[0], { + index: 0, + finishReason: 'stop', + message: { + role: 'model', + content: [ + { + media: { + url: 'data:audio/wav;base64,AUDIO1', + contentType: 'audio/wav', + }, + }, + ], + }, + }); + assert.strictEqual(result.custom, response); + assert.ok(result.usage); + }); + }); + + describe('toVeoMedia', () => { + it('should convert data URL', () => { + const mediaPart = { + url: 'data:image/png;base64,VEODATA', + contentType: 'image/png', + }; + const result = toVeoMedia(mediaPart); + assert.deepStrictEqual(result, { + bytesBase64Encoded: 'VEODATA', + mimeType: 'image/png', + }); + }); + + it('should convert gs URL', () => { + const mediaPart = { + url: 'gs://bucket/object', + contentType: 'video/mp4', + }; + const result = toVeoMedia(mediaPart); + assert.deepStrictEqual(result, { + gcsUri: 'gs://bucket/object', + mimeType: 'video/mp4', + }); + }); + + it('should throw on http URL', () => { + const mediaPart = { + url: 'http://example.com/image.jpg', + contentType: 'image/jpeg', + }; + assert.throws(() => toVeoMedia(mediaPart), /Veo does not support http/); + }); + + it('should infer mimeType if missing', () => { + const mediaPart = { url: 'data:image/jpeg;base64,VEODATA' }; + const result = toVeoMedia(mediaPart as any); + assert.deepStrictEqual(result, { + bytesBase64Encoded: 'VEODATA', + mimeType: 'image/jpeg', + }); + }); + }); + + describe('toVeoPredictRequest', () => { + const baseRequest: GenerateRequest = { + messages: [{ role: 'user', content: [{ text: 'A video of a sunset' }] }], + }; + + it('should create a basic VeoPredictRequest', () => { + const result = toVeoPredictRequest(baseRequest); + assert.deepStrictEqual(result, { + instances: [{ prompt: 'A video of a sunset' }], + parameters: {}, + }); + }); + + it('should handle config parameters', () => { + const request: GenerateRequest = { + ...baseRequest, + config: { + durationSeconds: 5, + fps: 24, + aspectRatio: '16:9', + }, + }; + const result = toVeoPredictRequest(request); + assert.deepStrictEqual(result, { + instances: [{ prompt: 'A video of a sunset' }], + parameters: { + durationSeconds: 5, + fps: 24, + aspectRatio: '16:9', + }, + }); + }); + + it('should handle media parts', () => { + const request: GenerateRequest = { + messages: [ + { + role: 'user', + content: [ + { text: 'A video of a sunrise' }, + { + media: { + url: 'data:image/jpeg;base64,IMAGEDATA', + contentType: 'image/jpeg', + }, + metadata: { type: 'image' }, + }, + { + media: { + url: 'gs://bucket/video.mp4', + contentType: 'video/mp4', + }, + metadata: { type: 'video' }, + }, + ], + }, + ], + }; + const result = toVeoPredictRequest(request); + assert.deepStrictEqual(result, { + instances: [ + { + prompt: 'A video of a sunrise', + image: { + bytesBase64Encoded: 'IMAGEDATA', + mimeType: 'image/jpeg', + }, + video: { gcsUri: 'gs://bucket/video.mp4', mimeType: 'video/mp4' }, + }, + ], + parameters: {}, + }); + }); + }); + + describe('fromVeoOperation', () => { + it('should convert basic pending operation', () => { + const veoOp: VeoOperation = { + name: 'operations/123', + done: false, + }; + const result = fromVeoOperation(veoOp); + assert.deepStrictEqual(result, { + id: 'operations/123', + done: 
false, + }); + }); + + it('should convert done operation with videos', () => { + const veoOp: VeoOperation = { + name: 'operations/456', + done: true, + response: { + videos: [ + { + gcsUri: 'gs://bucket/vid1.mp4', + mimeType: 'video/mp4', + }, + { + bytesBase64Encoded: 'VID2DATA', + mimeType: 'video/webm', + }, + ], + }, + }; + const result = fromVeoOperation(veoOp); + assert.deepStrictEqual(result, { + id: 'operations/456', + done: true, + output: { + finishReason: 'stop', + raw: veoOp.response, + message: { + role: 'model', + content: [ + { + media: { + url: 'gs://bucket/vid1.mp4', + contentType: 'video/mp4', + }, + }, + { + media: { + url: 'data:video/webm:base64,VID2DATA', + contentType: 'video/webm', + }, + }, + ], + }, + }, + }); + }); + + it('should convert operation with error', () => { + const veoOp: VeoOperation = { + name: 'operations/789', + done: true, + error: { code: 3, message: 'Invalid argument' }, + }; + const result = fromVeoOperation(veoOp); + assert.deepStrictEqual(result, { + id: 'operations/789', + done: true, + error: { message: 'Invalid argument' }, + }); + }); + }); + + describe('toVeoModel', () => { + it('should extract model name from operation id', () => { + const op = { + id: 'projects/test-project/locations/us-central1/models/veo-1.0/operations/12345', + }; + const result = toVeoModel(op); + assert.strictEqual(result, 'veo-1.0'); + }); + }); + + describe('toVeoOperationRequest', () => { + it('should create VeoOperationRequest from Operation', () => { + const op = { + id: 'operations/abcdef', + }; + const result = toVeoOperationRequest(op); + assert.deepStrictEqual(result, { + operationName: 'operations/abcdef', + }); + }); + }); }); diff --git a/js/plugins/google-genai/tests/vertexai/imagen_test.ts b/js/plugins/google-genai/tests/vertexai/imagen_test.ts index e93dbe7dcd..5ecd610f89 100644 --- a/js/plugins/google-genai/tests/vertexai/imagen_test.ts +++ b/js/plugins/google-genai/tests/vertexai/imagen_test.ts @@ -21,6 +21,10 @@ import { GoogleAuth } from 'google-auth-library'; import { afterEach, beforeEach, describe, it } from 'node:test'; import * as sinon from 'sinon'; import { getVertexAIUrl } from '../../src/vertexai/client'; +import { + fromImagenResponse, + toImagenPredictRequest, +} from '../../src/vertexai/converters'; import { ImagenConfig, ImagenConfigSchema, @@ -34,9 +38,6 @@ import { ImagenPredictResponse, ImagenPrediction, } from '../../src/vertexai/types.js'; -import * as utils from '../../src/vertexai/utils'; - -const { toImagenParameters, fromImagenPrediction } = TEST_ONLY; // Helper function to escape special characters for use in a RegExp function escapeRegExp(string) { @@ -88,85 +89,6 @@ describe('Vertex AI Imagen', () => { }); }); - describe('toImagenParameters', () => { - const baseRequest: GenerateRequest = { - messages: [], - }; - - it('should set default sampleCount to 1 if candidates is not provided', () => { - const result = toImagenParameters(baseRequest); - assert.strictEqual(result.sampleCount, 1); - }); - - it('should use request.candidates for sampleCount', () => { - const request: GenerateRequest = { - ...baseRequest, - candidates: 3, - }; - const result = toImagenParameters(request); - assert.strictEqual(result.sampleCount, 3); - }); - - it('should include config parameters', () => { - const request: GenerateRequest = { - ...baseRequest, - config: { - seed: 12345, - aspectRatio: '16:9', - negativePrompt: 'No red colors', - }, - }; - const result = toImagenParameters(request); - assert.strictEqual(result.sampleCount, 1); - 
assert.strictEqual(result.negativePrompt, 'No red colors'); - assert.strictEqual(result.seed, 12345); - assert.strictEqual(result.aspectRatio, '16:9'); - }); - - it('should omit undefined or null config parameters', () => { - const request: GenerateRequest = { - ...baseRequest, - config: { - negativePrompt: undefined, - seed: null as any, - aspectRatio: '1:1', - }, - }; - const result = toImagenParameters(request); - assert.strictEqual(result.sampleCount, 1); - assert.strictEqual(result.hasOwnProperty('negativePrompt'), false); - assert.strictEqual(result.hasOwnProperty('seed'), false); - assert.strictEqual(result.aspectRatio, '1:1'); - }); - }); - - describe('fromImagenPrediction', () => { - it('should convert ImagenPrediction to CandidateData', () => { - const prediction: ImagenPrediction = { - bytesBase64Encoded: 'dGVzdGJ5dGVz', - mimeType: 'image/png', - }; - const index = 2; - const result = fromImagenPrediction(prediction, index); - - assert.deepStrictEqual(result, { - index: 2, - finishReason: 'stop', - message: { - role: 'model', - content: [ - { - media: { - url: 'data:image/png;base64,dGVzdGJ5dGVz', - contentType: 'image/png', - }, - }, - ], - }, - }); - }); - }); - describe('defineImagenModel()', () => { let mockAi: sinon.SinonStubbedInstance; let fetchStub: sinon.SinonStub; @@ -275,28 +197,19 @@ describe('Vertex AI Imagen', () => { getExpectedHeaders(clientOptions) ); - const prompt = utils.extractText(request); - // extractImagenImage and extractImagenMask return undefined, - // so JSON.stringify will omit these keys. - const expectedInstance: any = { - prompt, - }; - const expectedImagenPredictRequest: ImagenPredictRequest = { - instances: [expectedInstance], - parameters: toImagenParameters(request), - }; + const expectedImagenPredictRequest: ImagenPredictRequest = + toImagenPredictRequest(request); assert.deepStrictEqual( JSON.parse(fetchArgs[1].body), expectedImagenPredictRequest ); - const expectedCandidates = mockResponse.predictions!.map((p, i) => - fromImagenPrediction(p, i) - ); + const expectedResponse = fromImagenResponse(mockResponse, request); + const expectedCandidates = expectedResponse.candidates; assert.deepStrictEqual(result.candidates, expectedCandidates); assert.deepStrictEqual(result.usage, { - ...getBasicUsageStats(request.messages, expectedCandidates), + ...getBasicUsageStats(request.messages, expectedCandidates as any), custom: { generations: 2 }, }); assert.deepStrictEqual(result.custom, mockResponse); diff --git a/js/plugins/google-genai/tests/vertexai/lyria_test.ts b/js/plugins/google-genai/tests/vertexai/lyria_test.ts new file mode 100644 index 0000000000..959aa1ec05 --- /dev/null +++ b/js/plugins/google-genai/tests/vertexai/lyria_test.ts @@ -0,0 +1,229 @@ +/** + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import * as assert from 'assert'; +import { Genkit, GENKIT_CLIENT_HEADER } from 'genkit'; +import { GenerateRequest } from 'genkit/model'; +import { GoogleAuth } from 'google-auth-library'; +import { afterEach, beforeEach, describe, it } from 'node:test'; +import * as sinon from 'sinon'; +import { getVertexAIUrl } from '../../src/vertexai/client'; +import { + fromLyriaResponse, + toLyriaPredictRequest, +} from '../../src/vertexai/converters'; +import { + defineModel, + LyriaConfigSchema, + model, + TEST_ONLY, +} from '../../src/vertexai/lyria'; +import { + LyriaPredictResponse, + RegionalClientOptions, +} from '../../src/vertexai/types'; + +const { GENERIC_MODEL, KNOWN_MODELS } = TEST_ONLY; + +describe('Vertex AI Lyria', () => { + let mockGenkit: sinon.SinonStubbedInstance; + let fetchStub: sinon.SinonStub; + let authMock: sinon.SinonStubbedInstance; + let modelActionCallback: ( + request: GenerateRequest, + options: { + abortSignal?: AbortSignal; + } + ) => Promise; + + const modelName = 'lyria-test-model'; + + const defaultRegionalClientOptions: RegionalClientOptions = { + kind: 'regional', + projectId: 'test-project', + location: 'us-central1', + authClient: {} as any, + }; + + beforeEach(() => { + mockGenkit = sinon.createStubInstance(Genkit); + fetchStub = sinon.stub(global, 'fetch'); + authMock = sinon.createStubInstance(GoogleAuth); + + authMock.getAccessToken.resolves('test-token'); + defaultRegionalClientOptions.authClient = authMock as unknown as GoogleAuth; + + mockGenkit.defineModel.callsFake((config: any, func: any) => { + modelActionCallback = func; + return { name: config.name } as any; + }); + }); + + afterEach(() => { + sinon.restore(); + }); + + function mockFetchResponse(body: any, status = 200) { + const response = new Response(JSON.stringify(body), { + status: status, + statusText: status === 200 ? 
'OK' : 'Error', + headers: { 'Content-Type': 'application/json' }, + }); + fetchStub.resolves(Promise.resolve(response)); + } + + function getExpectedHeaders(): Record { + return { + 'Content-Type': 'application/json', + 'X-Goog-Api-Client': GENKIT_CLIENT_HEADER, + 'User-Agent': GENKIT_CLIENT_HEADER, + Authorization: 'Bearer test-token', + 'x-goog-user-project': defaultRegionalClientOptions.projectId, + }; + } + + describe('model()', () => { + it('should return a ModelReference for a known model', () => { + const knownModelName = Object.keys(KNOWN_MODELS)[0]; + const ref = model(knownModelName); + assert.strictEqual(ref.name, `vertexai/${knownModelName}`); + assert.ok(ref.info?.supports?.media); + assert.deepStrictEqual(ref.info?.supports?.output, ['media']); + }); + + it('should return a ModelReference for an unknown model using generic info', () => { + const unknownModelName = 'lyria-unknown-model'; + const ref = model(unknownModelName); + assert.strictEqual(ref.name, `vertexai/${unknownModelName}`); + assert.deepStrictEqual(ref.info, GENERIC_MODEL.info); + }); + + it('should apply config to a known model', () => { + const knownModelName = Object.keys(KNOWN_MODELS)[0]; + const config = { negativePrompt: 'noisy' }; + const ref = model(knownModelName, config); + assert.strictEqual(ref.name, `vertexai/${knownModelName}`); + assert.deepStrictEqual(ref.config, config); + }); + }); + + describe('defineModel()', () => { + beforeEach(() => { + defineModel(mockGenkit, modelName, defaultRegionalClientOptions); + sinon.assert.calledOnce(mockGenkit.defineModel); + const args = mockGenkit.defineModel.lastCall.args[0]; + assert.strictEqual(args.name, `vertexai/${modelName}`); + }); + + const prompt = 'A funky bass line'; + const minimalRequest: GenerateRequest = { + messages: [{ role: 'user', content: [{ text: prompt }] }], + config: { sampleCount: 2 }, + }; + + const mockPrediction: LyriaPredictResponse = { + predictions: [ + { + bytesBase64Encoded: 'base64audio1', + mimeType: 'audio/wav', + }, + { + bytesBase64Encoded: 'base64audio2', + mimeType: 'audio/wav', + }, + ], + }; + + it('should call fetch with correct params and return lyria response', async () => { + mockFetchResponse(mockPrediction); + + const result = await modelActionCallback(minimalRequest, {}); + + sinon.assert.calledOnce(fetchStub); + const fetchArgs = fetchStub.lastCall.args; + const url = fetchArgs[0]; + const options = fetchArgs[1]; + + const expectedUrl = getVertexAIUrl({ + includeProjectAndLocation: true, + resourcePath: `publishers/google/models/${modelName}`, + resourceMethod: 'predict', + clientOptions: defaultRegionalClientOptions, + }); + assert.strictEqual(url, expectedUrl); + assert.strictEqual(options.method, 'POST'); + assert.deepStrictEqual(options.headers, getExpectedHeaders()); + + const expectedPredictRequest = toLyriaPredictRequest(minimalRequest); + assert.deepStrictEqual(JSON.parse(options.body), expectedPredictRequest); + + const expectedResponse = fromLyriaResponse( + mockPrediction, + minimalRequest + ); + assert.deepStrictEqual(result, expectedResponse); + assert.strictEqual(result.candidates?.length, 2); + assert.strictEqual( + result.candidates[0].message.content[0].media?.url, + 'data:audio/wav;base64,base64audio1' + ); + }); + + it('should throw if no predictions are returned', async () => { + mockFetchResponse({ predictions: [] }); + await assert.rejects( + modelActionCallback(minimalRequest, {}), + /Model returned no predictions/ + ); + }); + + it('should propagate API errors', async () => { + const 
errorBody = { error: { message: 'Quota exceeded', code: 429 } }; + mockFetchResponse(errorBody, 429); + + await assert.rejects( + modelActionCallback(minimalRequest, {}), + /Error fetching from .*predict.* Quota exceeded/ + ); + }); + + it('should pass AbortSignal to fetch', async () => { + mockFetchResponse(mockPrediction); + const controller = new AbortController(); + const abortSignal = controller.signal; + + // We need to re-register to pass the clientOptions with the signal + const clientOptionsWithSignal = { + ...defaultRegionalClientOptions, + signal: abortSignal, + }; + defineModel(mockGenkit, modelName, clientOptionsWithSignal); + + await modelActionCallback(minimalRequest, { abortSignal }); + + sinon.assert.calledOnce(fetchStub); + const fetchOptions = fetchStub.lastCall.args[1]; + assert.ok(fetchOptions.signal, 'Fetch options should have a signal'); + + const fetchSignal = fetchOptions.signal; + const abortSpy = sinon.spy(); + fetchSignal.addEventListener('abort', abortSpy); + + controller.abort(); + sinon.assert.calledOnce(abortSpy); + }); + }); +}); diff --git a/js/plugins/google-genai/tests/vertexai/utils_test.ts b/js/plugins/google-genai/tests/vertexai/utils_test.ts index 3911c6ba3d..005d38e72d 100644 --- a/js/plugins/google-genai/tests/vertexai/utils_test.ts +++ b/js/plugins/google-genai/tests/vertexai/utils_test.ts @@ -16,7 +16,6 @@ import * as assert from 'assert'; import { GenkitError } from 'genkit'; -import { GenerateRequest } from 'genkit/model'; import { GoogleAuth } from 'google-auth-library'; import { afterEach, beforeEach, describe, it } from 'node:test'; import * as sinon from 'sinon'; @@ -29,17 +28,34 @@ import { import { API_KEY_FALSE_ERROR, MISSING_API_KEY_ERROR, + NOT_SUPPORTED_IN_EXPRESS_ERROR, calculateApiKey, - extractImagenImage, - extractImagenMask, - extractText, + checkApiKey, + checkSupportedResourceMethod, + getApiKeyFromEnvVar, getDerivedOptions, } from '../../src/vertexai/utils'; -describe('getDerivedOptions', () => { +// Helper to assert GenkitError properties +function assertGenkitError(error: any, expectedError: GenkitError) { + assert.ok( + error instanceof GenkitError, + `Expected GenkitError, got ${error.name}` + ); + assert.strictEqual( + error.status, + expectedError.status, + 'Error status mismatch' + ); + assert.strictEqual( + error.message, + expectedError.message, + 'Error message mismatch' + ); +} + +describe('Vertex AI Utils', () => { const originalEnv = { ...process.env }; - let authInstance: sinon.SinonStubbedInstance; - let mockAuthClass: sinon.SinonStub; beforeEach(() => { // Reset env @@ -59,686 +75,495 @@ describe('getDerivedOptions', () => { delete process.env.VERTEX_API_KEY; delete process.env.GOOGLE_API_KEY; delete process.env.GOOGLE_GENAI_API_KEY; - - authInstance = sinon.createStubInstance(GoogleAuth); - authInstance.getAccessToken.resolves('test-token'); - // Default to simulating project ID not found, tests that need it should override. 
- authInstance.getProjectId.resolves(undefined); - - mockAuthClass = sinon.stub().returns(authInstance); }); afterEach(() => { sinon.restore(); }); - describe('Regional Options', () => { - it('should use options for projectId and location', async () => { - const pluginOptions: VertexPluginOptions = { - projectId: 'options-project', - location: 'options-location', - }; - const options = (await getDerivedOptions( - pluginOptions, - mockAuthClass as any - )) as RegionalClientOptions; - assert.strictEqual(options.kind, 'regional'); - assert.strictEqual(options.projectId, 'options-project'); - assert.strictEqual(options.location, 'options-location'); - assert.ok(options.authClient); - sinon.assert.calledOnce(mockAuthClass); - sinon.assert.notCalled(authInstance.getProjectId); - }); + describe('getDerivedOptions', () => { + let authInstance: sinon.SinonStubbedInstance; + let mockAuthClass: sinon.SinonStub; - it('should use GCLOUD_PROJECT and GCLOUD_LOCATION env vars', async () => { - process.env.GCLOUD_PROJECT = 'env-project'; - process.env.GCLOUD_LOCATION = 'env-location'; - const options = (await getDerivedOptions( - undefined, - mockAuthClass as any - )) as RegionalClientOptions; - assert.strictEqual(options.kind, 'regional'); - assert.strictEqual(options.projectId, 'env-project'); - assert.strictEqual(options.location, 'env-location'); - sinon.assert.calledOnce(mockAuthClass); - const authOptions = mockAuthClass.lastCall.args[0]; - assert.strictEqual(authOptions.projectId, 'env-project'); - sinon.assert.notCalled(authInstance.getProjectId); + beforeEach(() => { + authInstance = sinon.createStubInstance(GoogleAuth); + authInstance.getAccessToken.resolves('test-token'); + authInstance.getProjectId.resolves(undefined); // Default + mockAuthClass = sinon.stub().returns(authInstance); }); - it('should use default location when only projectId is available', async () => { - authInstance.getProjectId.resolves('default-project'); - const options = (await getDerivedOptions( - undefined, - mockAuthClass as any - )) as RegionalClientOptions; - assert.strictEqual(options.kind, 'regional'); - assert.strictEqual(options.projectId, 'default-project'); - assert.strictEqual(options.location, 'us-central1'); - sinon.assert.calledOnce(mockAuthClass); - sinon.assert.calledOnce(authInstance.getProjectId); - }); + describe('Regional Options', () => { + it('should use options for projectId and location', async () => { + const pluginOptions: VertexPluginOptions = { + projectId: 'options-project', + location: 'options-location', + }; + const options = (await getDerivedOptions( + pluginOptions, + mockAuthClass as any + )) as RegionalClientOptions; + assert.strictEqual(options.kind, 'regional'); + assert.strictEqual(options.projectId, 'options-project'); + assert.strictEqual(options.location, 'options-location'); + assert.ok(options.authClient); + sinon.assert.calledOnce(mockAuthClass); + sinon.assert.notCalled(authInstance.getProjectId); + }); - it('should use FIREBASE_CONFIG for GoogleAuth constructor, but final projectId from getProjectId', async () => { - process.env.FIREBASE_CONFIG = JSON.stringify({ - projectId: 'firebase-project', + it('should use GCLOUD_PROJECT and GCLOUD_LOCATION env vars', async () => { + process.env.GCLOUD_PROJECT = 'env-project'; + process.env.GCLOUD_LOCATION = 'env-location'; + const options = (await getDerivedOptions( + undefined, + mockAuthClass as any + )) as RegionalClientOptions; + assert.strictEqual(options.kind, 'regional'); + assert.strictEqual(options.projectId, 'env-project'); + 
assert.strictEqual(options.location, 'env-location'); + sinon.assert.calledOnce(mockAuthClass); + const authOptions = mockAuthClass.lastCall.args[0]; + assert.strictEqual(authOptions.projectId, 'env-project'); + sinon.assert.notCalled(authInstance.getProjectId); }); - authInstance.getProjectId.resolves('auth-client-project'); - - const options = (await getDerivedOptions( - { location: 'fb-location' }, - mockAuthClass as any - )) as RegionalClientOptions; - - assert.strictEqual(options.kind, 'regional'); - sinon.assert.calledOnce(mockAuthClass); - const authOptions = mockAuthClass.lastCall.args[0]; - assert.strictEqual(authOptions.projectId, 'firebase-project'); - sinon.assert.calledOnce(authInstance.getProjectId); - assert.strictEqual(options.projectId, 'auth-client-project'); - assert.strictEqual(options.location, 'fb-location'); - }); - it('should prioritize plugin options over env vars', async () => { - process.env.GCLOUD_PROJECT = 'env-project'; - process.env.GCLOUD_LOCATION = 'env-location'; - const pluginOptions: VertexPluginOptions = { - projectId: 'options-project', - location: 'options-location', - }; - const options = (await getDerivedOptions( - pluginOptions, - mockAuthClass as any - )) as RegionalClientOptions; - assert.strictEqual(options.kind, 'regional'); - assert.strictEqual(options.projectId, 'options-project'); - assert.strictEqual(options.location, 'options-location'); - }); + it('should use default location when only projectId is available', async () => { + authInstance.getProjectId.resolves('default-project'); + const options = (await getDerivedOptions( + undefined, + mockAuthClass as any + )) as RegionalClientOptions; + assert.strictEqual(options.kind, 'regional'); + assert.strictEqual(options.projectId, 'default-project'); + assert.strictEqual(options.location, 'us-central1'); + sinon.assert.calledOnce(mockAuthClass); + sinon.assert.calledOnce(authInstance.getProjectId); + }); - it('should use GCLOUD_SERVICE_ACCOUNT_CREDS for auth', async () => { - const creds = { - client_email: '', - private_key: 'private_key', - }; - process.env.GCLOUD_SERVICE_ACCOUNT_CREDS = JSON.stringify(creds); - authInstance.getProjectId.resolves('creds-project'); - - const options = (await getDerivedOptions( - { location: 'creds-location' }, - mockAuthClass as any - )) as RegionalClientOptions; - - assert.strictEqual(options.kind, 'regional'); - assert.strictEqual(options.projectId, 'creds-project'); - assert.strictEqual(options.location, 'creds-location'); - sinon.assert.calledOnce(mockAuthClass); - const authOptions = mockAuthClass.lastCall.args[0]; - assert.deepStrictEqual(authOptions.credentials, creds); - assert.strictEqual(authOptions.projectId, undefined); - sinon.assert.calledOnce(authInstance.getProjectId); - }); + it('should use FIREBASE_CONFIG for GoogleAuth constructor, but final projectId from getProjectId', async () => { + process.env.FIREBASE_CONFIG = JSON.stringify({ + projectId: 'firebase-project', + }); + authInstance.getProjectId.resolves('auth-client-project'); + + const options = (await getDerivedOptions( + { location: 'fb-location' }, + mockAuthClass as any + )) as RegionalClientOptions; + + assert.strictEqual(options.kind, 'regional'); + sinon.assert.calledOnce(mockAuthClass); + const authOptions = mockAuthClass.lastCall.args[0]; + assert.strictEqual(authOptions.projectId, 'firebase-project'); + sinon.assert.calledOnce(authInstance.getProjectId); + assert.strictEqual(options.projectId, 'auth-client-project'); + assert.strictEqual(options.location, 'fb-location'); + }); 
- it('should throw error if projectId cannot be determined for regional', async () => { - authInstance.getProjectId.resolves(undefined); - await assert.rejects( - getDerivedOptions({ location: 'some-location' }, mockAuthClass as any), - /VertexAI Plugin is missing the 'project' configuration/ - ); - sinon.assert.calledOnce(mockAuthClass); - sinon.assert.calledOnce(authInstance.getProjectId); - }); + it('should prioritize plugin options over env vars', async () => { + process.env.GCLOUD_PROJECT = 'env-project'; + process.env.GCLOUD_LOCATION = 'env-location'; + const pluginOptions: VertexPluginOptions = { + projectId: 'options-project', + location: 'options-location', + }; + const options = (await getDerivedOptions( + pluginOptions, + mockAuthClass as any + )) as RegionalClientOptions; + assert.strictEqual(options.kind, 'regional'); + assert.strictEqual(options.projectId, 'options-project'); + assert.strictEqual(options.location, 'options-location'); + }); - it('should prefer regional if location is specified, even with apiKey', async () => { - const pluginOptions: VertexPluginOptions = { - location: 'us-central1', - apiKey: 'test-api-key', - projectId: 'options-project', - }; - const options = (await getDerivedOptions( - pluginOptions, - mockAuthClass as any - )) as RegionalClientOptions; - assert.strictEqual(options.kind, 'regional'); - assert.strictEqual(options.projectId, 'options-project'); - assert.strictEqual(options.location, 'us-central1'); - assert.ok(options.authClient); - sinon.assert.calledOnce(mockAuthClass); - }); - }); + it('should use GCLOUD_SERVICE_ACCOUNT_CREDS for auth', async () => { + const creds = { + client_email: 'clientEmail', + private_key: 'private_key', + }; + process.env.GCLOUD_SERVICE_ACCOUNT_CREDS = JSON.stringify(creds); + authInstance.getProjectId.resolves('creds-project'); + + const options = (await getDerivedOptions( + { location: 'creds-location' }, + mockAuthClass as any + )) as RegionalClientOptions; + + assert.strictEqual(options.kind, 'regional'); + assert.strictEqual(options.projectId, 'creds-project'); + assert.strictEqual(options.location, 'creds-location'); + sinon.assert.calledOnce(mockAuthClass); + const authOptions = mockAuthClass.lastCall.args[0]; + assert.deepStrictEqual(authOptions.credentials, creds); + assert.strictEqual(authOptions.projectId, undefined); + sinon.assert.calledOnce(authInstance.getProjectId); + }); - describe('Global Options', () => { - it('should use global options when location is global', async () => { - const pluginOptions: VertexPluginOptions = { - location: 'global', - projectId: 'options-project', - }; - const options = (await getDerivedOptions( - pluginOptions, - mockAuthClass as any - )) as GlobalClientOptions; - assert.strictEqual(options.kind, 'global'); - assert.strictEqual(options.location, 'global'); - assert.strictEqual(options.projectId, 'options-project'); - assert.ok(options.authClient); - sinon.assert.calledOnce(mockAuthClass); - sinon.assert.notCalled(authInstance.getProjectId); - }); + it('should throw error if projectId cannot be determined for regional', async () => { + authInstance.getProjectId.resolves(undefined); + await assert.rejects( + getDerivedOptions( + { location: 'some-location' }, + mockAuthClass as any + ), + /VertexAI Plugin is missing the 'project' configuration/ + ); + sinon.assert.calledOnce(mockAuthClass); + sinon.assert.calledOnce(authInstance.getProjectId); + }); - it('should use env project for global options', async () => { - process.env.GCLOUD_PROJECT = 'env-project'; - const 
pluginOptions: VertexPluginOptions = { location: 'global' }; - const options = (await getDerivedOptions( - pluginOptions, - mockAuthClass as any - )) as GlobalClientOptions; - assert.strictEqual(options.kind, 'global'); - assert.strictEqual(options.projectId, 'env-project'); - sinon.assert.calledOnce(mockAuthClass); + it('should prefer regional if location is specified, even with apiKey', async () => { + const pluginOptions: VertexPluginOptions = { + location: 'us-central1', + apiKey: 'test-api-key', + projectId: 'options-project', + }; + const options = (await getDerivedOptions( + pluginOptions, + mockAuthClass as any + )) as RegionalClientOptions; + assert.strictEqual(options.kind, 'regional'); + assert.strictEqual(options.projectId, 'options-project'); + assert.strictEqual(options.location, 'us-central1'); + assert.ok(options.authClient); + assert.strictEqual(options.apiKey, 'test-api-key'); + sinon.assert.calledOnce(mockAuthClass); + }); }); - it('should use auth project for global options', async () => { - authInstance.getProjectId.resolves('auth-project'); - const pluginOptions: VertexPluginOptions = { location: 'global' }; - const options = (await getDerivedOptions( - pluginOptions, - mockAuthClass as any - )) as GlobalClientOptions; - assert.strictEqual(options.kind, 'global'); - assert.strictEqual(options.projectId, 'auth-project'); - sinon.assert.calledOnce(mockAuthClass); - sinon.assert.calledOnce(authInstance.getProjectId); - }); + describe('Global Options', () => { + it('should use global options when location is global', async () => { + const pluginOptions: VertexPluginOptions = { + location: 'global', + projectId: 'options-project', + }; + const options = (await getDerivedOptions( + pluginOptions, + mockAuthClass as any + )) as GlobalClientOptions; + assert.strictEqual(options.kind, 'global'); + assert.strictEqual(options.location, 'global'); + assert.strictEqual(options.projectId, 'options-project'); + assert.ok(options.authClient); + sinon.assert.calledOnce(mockAuthClass); + sinon.assert.notCalled(authInstance.getProjectId); + }); - it('should throw error if projectId cannot be determined for global', async () => { - authInstance.getProjectId.resolves(undefined); - await assert.rejects( - getDerivedOptions({ location: 'global' }, mockAuthClass as any), - /VertexAI Plugin is missing the 'project' configuration/ - ); - sinon.assert.calledOnce(mockAuthClass); - sinon.assert.calledOnce(authInstance.getProjectId); - }); + it('should use env project for global options', async () => { + process.env.GCLOUD_PROJECT = 'env-project'; + const pluginOptions: VertexPluginOptions = { location: 'global' }; + const options = (await getDerivedOptions( + pluginOptions, + mockAuthClass as any + )) as GlobalClientOptions; + assert.strictEqual(options.kind, 'global'); + assert.strictEqual(options.projectId, 'env-project'); + sinon.assert.calledOnce(mockAuthClass); + }); - it('should prefer global if location is global, even with apiKey', async () => { - const pluginOptions: VertexPluginOptions = { - location: 'global', - apiKey: 'test-api-key', - projectId: 'options-project', - }; - const options = (await getDerivedOptions( - pluginOptions, - mockAuthClass as any - )) as GlobalClientOptions; - assert.strictEqual(options.kind, 'global'); - assert.strictEqual(options.projectId, 'options-project'); - assert.ok(options.authClient); - sinon.assert.calledOnce(mockAuthClass); - }); - }); + it('should use auth project for global options', async () => { + 
authInstance.getProjectId.resolves('auth-project'); + const pluginOptions: VertexPluginOptions = { location: 'global' }; + const options = (await getDerivedOptions( + pluginOptions, + mockAuthClass as any + )) as GlobalClientOptions; + assert.strictEqual(options.kind, 'global'); + assert.strictEqual(options.projectId, 'auth-project'); + sinon.assert.calledOnce(mockAuthClass); + sinon.assert.calledOnce(authInstance.getProjectId); + }); - describe('Express Options', () => { - it('should use express options with apiKey in options', async () => { - const pluginOptions: VertexPluginOptions = { apiKey: 'key1' }; - const options = (await getDerivedOptions( - pluginOptions, - mockAuthClass as any - )) as ExpressClientOptions; - assert.strictEqual(options.kind, 'express'); - assert.strictEqual(options.apiKey, 'key1'); - sinon.assert.notCalled(mockAuthClass); - }); + it('should throw error if projectId cannot be determined for global', async () => { + authInstance.getProjectId.resolves(undefined); + await assert.rejects( + getDerivedOptions({ location: 'global' }, mockAuthClass as any), + /VertexAI Plugin is missing the 'project' configuration/ + ); + sinon.assert.calledOnce(mockAuthClass); + sinon.assert.calledOnce(authInstance.getProjectId); + }); - it('should use express options with apiKey false in options', async () => { - const pluginOptions: VertexPluginOptions = { apiKey: false }; - const options = (await getDerivedOptions( - pluginOptions, - mockAuthClass as any - )) as ExpressClientOptions; - assert.strictEqual(options.kind, 'express'); - assert.strictEqual(options.apiKey, undefined); - sinon.assert.notCalled(mockAuthClass); + it('should prefer global if location is global, even with apiKey', async () => { + const pluginOptions: VertexPluginOptions = { + location: 'global', + apiKey: 'test-api-key', + projectId: 'options-project', + }; + const options = (await getDerivedOptions( + pluginOptions, + mockAuthClass as any + )) as GlobalClientOptions; + assert.strictEqual(options.kind, 'global'); + assert.strictEqual(options.projectId, 'options-project'); + assert.ok(options.authClient); + assert.strictEqual(options.apiKey, 'test-api-key'); + sinon.assert.calledOnce(mockAuthClass); + }); }); - it('should use VERTEX_API_KEY env var for express', async () => { - process.env.VERTEX_API_KEY = 'key2'; - const options = (await getDerivedOptions( - undefined, - mockAuthClass as any - )) as ExpressClientOptions; - assert.strictEqual(options.kind, 'express'); - assert.strictEqual(options.apiKey, 'key2'); - // mockAuthClass is called during regional/global fallbacks - sinon.assert.calledTwice(mockAuthClass); - }); + describe('Express Options', () => { + it('should use express options with apiKey in options', async () => { + const pluginOptions: VertexPluginOptions = { apiKey: 'key1' }; + const options = (await getDerivedOptions( + pluginOptions, + mockAuthClass as any + )) as ExpressClientOptions; + assert.strictEqual(options.kind, 'express'); + assert.strictEqual(options.apiKey, 'key1'); + sinon.assert.notCalled(mockAuthClass); + }); - it('should use GOOGLE_API_KEY env var for express', async () => { - process.env.GOOGLE_API_KEY = 'key3'; - const options = (await getDerivedOptions( - undefined, - mockAuthClass as any - )) as ExpressClientOptions; - assert.strictEqual(options.kind, 'express'); - assert.strictEqual(options.apiKey, 'key3'); - // mockAuthClass is called during regional/global fallbacks - sinon.assert.calledTwice(mockAuthClass); - }); + it('should use express options with apiKey false in 
options', async () => { + const pluginOptions: VertexPluginOptions = { apiKey: false }; + const options = (await getDerivedOptions( + pluginOptions, + mockAuthClass as any + )) as ExpressClientOptions; + assert.strictEqual(options.kind, 'express'); + assert.strictEqual(options.apiKey, undefined); + sinon.assert.notCalled(mockAuthClass); + }); - it('should prioritize VERTEX_API_KEY over GOOGLE_API_KEY for express', async () => { - process.env.VERTEX_API_KEY = 'keyV'; - process.env.GOOGLE_API_KEY = 'keyG'; - const options = (await getDerivedOptions( - undefined, - mockAuthClass as any - )) as ExpressClientOptions; - assert.strictEqual(options.kind, 'express'); - assert.strictEqual(options.apiKey, 'keyV'); - // mockAuthClass is called during regional/global fallbacks - sinon.assert.calledTwice(mockAuthClass); - }); - }); + it('should use VERTEX_API_KEY env var for express', async () => { + process.env.VERTEX_API_KEY = 'key2'; + const options = (await getDerivedOptions( + undefined, + mockAuthClass as any + )) as ExpressClientOptions; + assert.strictEqual(options.kind, 'express'); + assert.strictEqual(options.apiKey, 'key2'); + sinon.assert.calledTwice(mockAuthClass); // Fallback attempts + }); - describe('Fallback Determination (No Options)', () => { - it('should default to regional if project can be determined', async () => { - authInstance.getProjectId.resolves('fallback-project'); - const options = (await getDerivedOptions( - undefined, - mockAuthClass as any - )) as RegionalClientOptions; - assert.strictEqual(options.kind, 'regional'); - assert.strictEqual(options.projectId, 'fallback-project'); - assert.strictEqual(options.location, 'us-central1'); - sinon.assert.calledOnce(mockAuthClass); + it('should use GOOGLE_API_KEY env var for express', async () => { + process.env.GOOGLE_API_KEY = 'key3'; + const options = (await getDerivedOptions( + undefined, + mockAuthClass as any + )) as ExpressClientOptions; + assert.strictEqual(options.kind, 'express'); + assert.strictEqual(options.apiKey, 'key3'); + sinon.assert.calledTwice(mockAuthClass); // Fallback attempts + }); }); - it('should fallback to express if regional/global fail and API key env exists', async () => { - authInstance.getProjectId.resolves(undefined); // Fail regional/global project lookup - process.env.GOOGLE_API_KEY = 'fallback-api-key'; - - const options = (await getDerivedOptions( - undefined, - mockAuthClass as any - )) as ExpressClientOptions; + describe('Fallback Determination (No Options)', () => { + it('should default to regional if project can be determined', async () => { + authInstance.getProjectId.resolves('fallback-project'); + const options = (await getDerivedOptions( + undefined, + mockAuthClass as any + )) as RegionalClientOptions; + assert.strictEqual(options.kind, 'regional'); + assert.strictEqual(options.projectId, 'fallback-project'); + assert.strictEqual(options.location, 'us-central1'); + sinon.assert.calledOnce(mockAuthClass); + }); - assert.strictEqual(options.kind, 'express'); - assert.strictEqual(options.apiKey, 'fallback-api-key'); - // getRegionalDerivedOptions, getGlobalDerivedOptions are called first - sinon.assert.calledTwice(mockAuthClass); - }); - }); + it('should fallback to express if regional/global fail and API key env exists', async () => { + authInstance.getProjectId.resolves(undefined); + process.env.GOOGLE_API_KEY = 'fallback-api-key'; - describe('Error Scenarios', () => { - it('should throw error if no options or env vars provide sufficient info', async () => { - 
authInstance.getProjectId.resolves(undefined); // Simulate failure to get project ID - // No API key env vars set + const options = (await getDerivedOptions( + undefined, + mockAuthClass as any + )) as ExpressClientOptions; - await assert.rejects(getDerivedOptions(undefined, mockAuthClass as any), { - name: 'GenkitError', - status: 'INVALID_ARGUMENT', - message: - 'INVALID_ARGUMENT: Unable to determine client options. Please set either apiKey or projectId and location', + assert.strictEqual(options.kind, 'express'); + assert.strictEqual(options.apiKey, 'fallback-api-key'); + sinon.assert.calledTwice(mockAuthClass); }); - // Tries Regional, Global, then Express paths. Regional and Global attempts create AuthClient. - sinon.assert.calledTwice(mockAuthClass); }); - }); -}); - -describe('calculateApiKey', () => { - const originalEnv = { ...process.env }; - beforeEach(() => { - // Reset env - for (const key in process.env) { - if (!originalEnv.hasOwnProperty(key)) { - delete process.env[key]; - } - } - for (const key in originalEnv) { - process.env[key] = originalEnv[key]; - } - delete process.env.VERTEX_API_KEY; - delete process.env.GOOGLE_API_KEY; - delete process.env.GOOGLE_GENAI_API_KEY; + describe('Error Scenarios', () => { + it('should throw error if no options or env vars provide sufficient info', async () => { + authInstance.getProjectId.resolves(undefined); + await assert.rejects( + getDerivedOptions(undefined, mockAuthClass as any), + /Unable to determine client options/ + ); + sinon.assert.calledTwice(mockAuthClass); + }); + }); }); - function assertThrowsGenkitError( - block: () => void, - expectedError: GenkitError - ) { - let caughtError: any; - try { - block(); - } catch (e: any) { - caughtError = e; - } - - if (!caughtError) { - assert.fail('Should have thrown an error, but nothing was caught.'); - } + describe('getApiKeyFromEnvVar', () => { + it('should return VERTEX_API_KEY if set', () => { + process.env.VERTEX_API_KEY = 'vertexKey'; + process.env.GOOGLE_API_KEY = 'googleKey'; + assert.strictEqual(getApiKeyFromEnvVar(), 'vertexKey'); + }); - assert.strictEqual( - caughtError.name, - 'GenkitError', - `Caught error is not a GenkitError. 
Got: ${caughtError.name}, Message: ${caughtError.message}`
-    );
-    assert.strictEqual(caughtError.status, expectedError.status);
-    assert.strictEqual(caughtError.message, expectedError.message);
-  }
-
-  it('should use requestApiKey when provided', () => {
-    assert.strictEqual(calculateApiKey(undefined, 'reqKey'), 'reqKey');
-    assert.strictEqual(calculateApiKey('pluginKey', 'reqKey'), 'reqKey');
-    assert.strictEqual(calculateApiKey(false, 'reqKey'), 'reqKey');
-  });
+    it('should return GOOGLE_API_KEY if VERTEX_API_KEY is not set', () => {
+      process.env.GOOGLE_API_KEY = 'googleKey';
+      process.env.GOOGLE_GENAI_API_KEY = 'genaiKey';
+      assert.strictEqual(getApiKeyFromEnvVar(), 'googleKey');
+    });
-  it('should use pluginApiKey if requestApiKey is undefined', () => {
-    assert.strictEqual(calculateApiKey('pluginKey', undefined), 'pluginKey');
-  });
+    it('should return GOOGLE_GENAI_API_KEY if others are not set', () => {
+      process.env.GOOGLE_GENAI_API_KEY = 'genaiKey';
+      assert.strictEqual(getApiKeyFromEnvVar(), 'genaiKey');
+    });
-  it('should use VERTEX_API_KEY from env if keys are undefined', () => {
-    process.env.VERTEX_API_KEY = 'vertexEnvKey';
-    assert.strictEqual(calculateApiKey(undefined, undefined), 'vertexEnvKey');
+    it('should return undefined if no key env vars are set', () => {
+      assert.strictEqual(getApiKeyFromEnvVar(), undefined);
+    });
   });
-  it('should use GOOGLE_API_KEY from env if VERTEX_API_KEY is not set', () => {
-    process.env.GOOGLE_API_KEY = 'googleEnvKey';
-    assert.strictEqual(calculateApiKey(undefined, undefined), 'googleEnvKey');
-  });
+  describe('checkApiKey', () => {
+    it('should return pluginApiKey if it is a string', () => {
+      assert.strictEqual(checkApiKey('pluginKey'), 'pluginKey');
+    });
-  it('should prioritize pluginApiKey over env keys', () => {
-    process.env.VERTEX_API_KEY = 'vertexEnvKey';
-    assert.strictEqual(calculateApiKey('pluginKey', undefined), 'pluginKey');
-  });
+    it('should return undefined if pluginApiKey is false', () => {
+      assert.strictEqual(checkApiKey(false), undefined);
+    });
-  it('should throw MISSING_API_KEY_ERROR if no key is found', () => {
-    assert.strictEqual(
-      process.env.VERTEX_API_KEY,
-      undefined,
-      'VERTEX_API_KEY should be undefined'
-    );
-    assert.strictEqual(
-      process.env.GOOGLE_API_KEY,
-      undefined,
-      'GOOGLE_API_KEY should be undefined'
-    );
-    assert.strictEqual(
-      process.env.GOOGLE_GENAI_API_KEY,
-      undefined,
-      'GOOGLE_GENAI_API_KEY should be undefined'
-    );
-
-    assertThrowsGenkitError(
-      () => calculateApiKey(undefined, undefined),
-      MISSING_API_KEY_ERROR
-    );
-  });
+    it('should return env var key if pluginApiKey is undefined', () => {
+      process.env.VERTEX_API_KEY = 'envKey';
+      assert.strictEqual(checkApiKey(undefined), 'envKey');
+    });
-  it('should throw API_KEY_FALSE_ERROR if pluginApiKey is false and requestApiKey is undefined', () => {
-    assertThrowsGenkitError(
-      () => calculateApiKey(false, undefined),
-      API_KEY_FALSE_ERROR
-    );
-  });
+    it('should throw MISSING_API_KEY_ERROR if no key found', () => {
+      try {
+        checkApiKey(undefined);
+        assert.fail('Should have thrown');
+      } catch (e) {
+        assertGenkitError(e, MISSING_API_KEY_ERROR);
+      }
+    });
-  it('should not use env keys if pluginApiKey is false', () => {
-    process.env.VERTEX_API_KEY = 'vertexEnvKey';
-    assertThrowsGenkitError(
-      () => calculateApiKey(false, undefined),
-      API_KEY_FALSE_ERROR
-    );
+    it('should not throw if pluginApiKey is false, even if no env var', () => {
+      assert.doesNotThrow(() => {
+        checkApiKey(false);
+      });
+    });
   });
-});
-describe('extractText', () => {
-  it('should extract text from the last message', () => {
-    const request: GenerateRequest = {
-      messages: [
-        { role: 'user', content: [{ text: 'ignore this' }] },
-        { role: 'user', content: [{ text: 'Hello ' }, { text: 'World' }] },
-      ],
-    };
-    assert.strictEqual(extractText(request), 'Hello World');
-  });
+  describe('calculateApiKey', () => {
+    it('should use requestApiKey when provided', () => {
+      assert.strictEqual(calculateApiKey(undefined, 'reqKey'), 'reqKey');
+      assert.strictEqual(calculateApiKey('pluginKey', 'reqKey'), 'reqKey');
+      assert.strictEqual(calculateApiKey(false, 'reqKey'), 'reqKey');
+    });
-  it('should return empty string if last message has no text parts', () => {
-    const request: GenerateRequest = {
-      messages: [
-        {
-          role: 'user',
-          content: [
-            {
-              media: {
-                url: 'data:image/png;base64,abc',
-                contentType: 'image/png',
-              },
-            },
-          ],
-        },
-      ],
-    };
-    assert.strictEqual(extractText(request), '');
-  });
+    it('should use pluginApiKey if requestApiKey is undefined', () => {
+      assert.strictEqual(calculateApiKey('pluginKey', undefined), 'pluginKey');
+    });
-  it('should handle messages with mixed content', () => {
-    const request: GenerateRequest = {
-      messages: [
-        {
-          role: 'user',
-          content: [
-            { text: 'A ' },
-            {
-              media: {
-                url: 'data:image/png;base64,abc',
-                contentType: 'image/png',
-              },
-            },
-            { text: 'B' },
-          ],
-        },
-      ],
-    };
-    assert.strictEqual(extractText(request), 'A B');
-  });
+    it('should use env key if plugin and request keys are undefined', () => {
+      process.env.VERTEX_API_KEY = 'envKey';
+      assert.strictEqual(calculateApiKey(undefined, undefined), 'envKey');
+    });
-  it('should return empty string for empty content array', () => {
-    const request: GenerateRequest = {
-      messages: [{ role: 'user', content: [] }],
-    };
-    assert.strictEqual(extractText(request), '');
-  });
-});
+    it('should prioritize pluginApiKey over env keys', () => {
+      process.env.VERTEX_API_KEY = 'envKey';
+      assert.strictEqual(calculateApiKey('pluginKey', undefined), 'pluginKey');
+    });
-describe('extractImagenImage', () => {
-  it('should extract base image from last message', () => {
-    const request: GenerateRequest = {
-      messages: [
-        { role: 'user', content: [{ text: 'test' }] },
-        {
-          role: 'user',
-          content: [
-            { text: 'An image' },
-            {
-              media: {
-                url: 'data:image/jpeg;base64,base64imagedata',
-                contentType: 'image/jpeg',
-              },
-              metadata: { type: 'base' },
-            },
-          ],
-        },
-      ],
-    };
-    assert.deepStrictEqual(extractImagenImage(request), {
-      bytesBase64Encoded: 'base64imagedata',
+    it('should throw MISSING_API_KEY_ERROR if no key is found', () => {
+      try {
+        calculateApiKey(undefined, undefined);
+        assert.fail('Should have thrown');
+      } catch (e) {
+        assertGenkitError(e, MISSING_API_KEY_ERROR);
+      }
     });
-  });
-  it('should extract image if metadata type is missing', () => {
-    const request: GenerateRequest = {
-      messages: [
-        {
-          role: 'user',
-          content: [
-            {
-              media: {
-                url: 'data:image/png;base64,anotherimage',
-                contentType: 'image/png',
-              },
-            },
-          ],
-        },
-      ],
-    };
-    assert.deepStrictEqual(extractImagenImage(request), {
-      bytesBase64Encoded: 'anotherimage',
+    it('should throw API_KEY_FALSE_ERROR if pluginApiKey is false and requestApiKey is undefined', () => {
+      try {
+        calculateApiKey(false, undefined);
+        assert.fail('Should have thrown');
+      } catch (e) {
+        assertGenkitError(e, API_KEY_FALSE_ERROR);
+      }
     });
-  });
-  it('should ignore images with metadata type mask', () => {
-    const request: GenerateRequest = {
-      messages: [
-        {
-          role: 'user',
-          content: [
-            {
-              media: {
-                url: 'data:image/png;base64,maskdata',
-                contentType: 'image/png',
-              },
-              metadata: { type: 'mask' },
-            },
-          ],
-        },
-      ],
-    };
-    assert.strictEqual(extractImagenImage(request), undefined);
+    it('should not use env keys if pluginApiKey is false', () => {
+      process.env.VERTEX_API_KEY = 'envKey';
+      try {
+        calculateApiKey(false, undefined);
+        assert.fail('Should have thrown');
+      } catch (e) {
+        assertGenkitError(e, API_KEY_FALSE_ERROR);
+      }
+    });
   });
-  });
-  it('should return undefined if no media in last message', () => {
-    const request: GenerateRequest = {
-      messages: [{ role: 'user', content: [{ text: 'No image here' }] }],
+  describe('checkSupportedResourceMethod', () => {
+    const expressOptions: ExpressClientOptions = {
+      kind: 'express',
+      apiKey: 'testKey',
     };
-    assert.strictEqual(extractImagenImage(request), undefined);
-  });
-
-  it('should return undefined if media url is not base64 data URL', () => {
-    const request: GenerateRequest = {
-      messages: [
-        {
-          role: 'user',
-          content: [
-            {
-              media: {
-                url: 'http://example.com/image.png',
-                contentType: 'image/png',
-              },
-            },
-          ],
-        },
-      ],
+    const regionalOptions: RegionalClientOptions = {
+      kind: 'regional',
+      location: 'us-central1',
+      projectId: 'testProject',
+      authClient: {} as any,
     };
-    assert.strictEqual(extractImagenImage(request), undefined);
-  });
-
-  it('should return undefined for empty messages array', () => {
-    const request: GenerateRequest = { messages: [] };
-    assert.strictEqual(extractImagenImage(request), undefined);
-  });
-});
-describe('extractImagenMask', () => {
-  it('should extract mask image from last message', () => {
-    const request: GenerateRequest = {
-      messages: [
-        { role: 'user', content: [{ text: 'test' }] },
-        {
-          role: 'user',
-          content: [
-            { text: 'A mask' },
-            {
-              media: {
-                url: 'data:image/png;base64,maskbytes',
-                contentType: 'image/png',
-              },
-              metadata: { type: 'mask' },
-            },
-          ],
-        },
-      ],
-    };
-    assert.deepStrictEqual(extractImagenMask(request), {
-      image: { bytesBase64Encoded: 'maskbytes' },
+    it('should allow empty resourcePath', () => {
+      assert.doesNotThrow(() => {
+        checkSupportedResourceMethod({
+          clientOptions: expressOptions,
+          resourcePath: '',
+        });
+      });
     });
-  });
-  it('should ignore images with metadata type base', () => {
-    const request: GenerateRequest = {
-      messages: [
-        {
-          role: 'user',
-          content: [
-            {
-              media: {
-                url: 'data:image/jpeg;base64,basedata',
-                contentType: 'image/jpeg',
-              },
-              metadata: { type: 'base' },
-            },
-          ],
-        },
-      ],
-    };
-    assert.strictEqual(extractImagenMask(request), undefined);
-  });
-
-  it('should ignore images with no metadata type', () => {
-    const request: GenerateRequest = {
-      messages: [
-        {
-          role: 'user',
-          content: [
-            {
-              media: {
-                url: 'data:image/jpeg;base64,imagedata',
-                contentType: 'image/jpeg',
-              },
-            },
-          ],
-        },
-      ],
-    };
-    assert.strictEqual(extractImagenMask(request), undefined);
-  });
-
-  it('should return undefined if no media in last message', () => {
-    const request: GenerateRequest = {
-      messages: [{ role: 'user', content: [{ text: 'No mask here' }] }],
-    };
-    assert.strictEqual(extractImagenMask(request), undefined);
-  });
+    it('should allow supported methods for Express', () => {
+      const supported = [
+        'countTokens',
+        'generateContent',
+        'streamGenerateContent',
+      ];
+      supported.forEach((method) => {
+        assert.doesNotThrow(() => {
+          checkSupportedResourceMethod({
+            clientOptions: expressOptions,
+            resourceMethod: method,
+          });
+        }, `Express should support ${method}`);
+      });
+    });
-  it('should return undefined if media url is not base64 data URL', () => {
-    const request: GenerateRequest = {
-      messages: [
-        {
-          role: 'user',
-          content: [
-            {
-              media: {
-                url: 'http://example.com/mask.png',
-                contentType: 'image/png',
-              },
-              metadata: { type: 'mask' },
-            },
-          ],
-        },
-      ],
-    };
-    assert.strictEqual(extractImagenMask(request), undefined);
-  });
+    it('should throw NOT_SUPPORTED_IN_EXPRESS_ERROR for unsupported methods in Express', () => {
+      const unsupported = [
+        'predict',
+        'predictLongRunning',
+        'fetchPredictOperation',
+        'listModels',
+      ];
+      unsupported.forEach((method) => {
+        try {
+          checkSupportedResourceMethod({
+            clientOptions: expressOptions,
+            resourceMethod: method,
+          });
+          assert.fail(`Should have thrown for Express method ${method}`);
+        } catch (e) {
+          assertGenkitError(e, NOT_SUPPORTED_IN_EXPRESS_ERROR);
+        }
+      });
+    });
-  it('should return undefined for empty messages array', () => {
-    const request: GenerateRequest = { messages: [] };
-    assert.strictEqual(extractImagenMask(request), undefined);
+    it('should allow any method for non-Express options', () => {
+      const methods = [
+        'countTokens',
+        'generateContent',
+        'streamGenerateContent',
+        'predict',
+        'predictLongRunning',
+        'fetchPredictOperation',
+      ];
+      methods.forEach((method) => {
+        assert.doesNotThrow(() => {
+          checkSupportedResourceMethod({
+            clientOptions: regionalOptions,
+            resourceMethod: method,
+          });
+        }, `Regional should support ${method}`);
+      });
+    });
   });
 });
diff --git a/js/plugins/google-genai/tests/vertexai/veo_test.ts b/js/plugins/google-genai/tests/vertexai/veo_test.ts
new file mode 100644
index 0000000000..854fa3a907
--- /dev/null
+++ b/js/plugins/google-genai/tests/vertexai/veo_test.ts
@@ -0,0 +1,270 @@
+/**
+ * Copyright 2025 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import * as assert from 'assert';
+import { Genkit, GENKIT_CLIENT_HEADER, Operation } from 'genkit';
+import { GenerateRequest } from 'genkit/model';
+import { GoogleAuth } from 'google-auth-library';
+import { afterEach, beforeEach, describe, it } from 'node:test';
+import * as sinon from 'sinon';
+import { getVertexAIUrl } from '../../src/vertexai/client';
+import {
+  fromVeoOperation,
+  toVeoOperationRequest,
+  toVeoPredictRequest,
+} from '../../src/vertexai/converters';
+import {
+  ClientOptions,
+  RegionalClientOptions,
+  VeoOperation,
+  VeoOperationRequest,
+  VeoPredictRequest,
+} from '../../src/vertexai/types';
+import {
+  defineModel,
+  model,
+  TEST_ONLY,
+  VeoConfigSchema,
+} from '../../src/vertexai/veo';
+
+const { GENERIC_MODEL, KNOWN_MODELS } = TEST_ONLY;
+
+describe('Vertex AI Veo', () => {
+  let mockGenkit: sinon.SinonStubbedInstance<Genkit>;
+  let fetchStub: sinon.SinonStub;
+  let authMock: sinon.SinonStubbedInstance<GoogleAuth>;
+
+  const modelName = 'veo-test-model';
+
+  const defaultRegionalClientOptions: RegionalClientOptions = {
+    kind: 'regional',
+    projectId: 'test-project',
+    location: 'us-central1',
+    authClient: {} as any,
+  };
+
+  beforeEach(() => {
+    mockGenkit = sinon.createStubInstance(Genkit);
+    fetchStub = sinon.stub(global, 'fetch');
+    authMock = sinon.createStubInstance(GoogleAuth);
+
+    authMock.getAccessToken.resolves('test-token');
+    defaultRegionalClientOptions.authClient = authMock as unknown as GoogleAuth;
+
+    // Mock Genkit registry methods if needed, though defineBackgroundModel is the key
+    (mockGenkit as any).registry = {
+      lookupAction: () => undefined,
+      generateTraceId: () => 'test-trace-id',
+    };
+  });
+
+  afterEach(() => {
+    sinon.restore();
+  });
+
+  function mockFetchResponse(body: any, status = 200) {
+    const response = new Response(JSON.stringify(body), {
+      status: status,
+      statusText: status === 200 ? 'OK' : 'Error',
+      headers: { 'Content-Type': 'application/json' },
+    });
+    fetchStub.resolves(Promise.resolve(response));
+  }
+
+  function getExpectedHeaders(): Record<string, string> {
+    return {
+      'Content-Type': 'application/json',
+      'X-Goog-Api-Client': GENKIT_CLIENT_HEADER,
+      'User-Agent': GENKIT_CLIENT_HEADER,
+      Authorization: 'Bearer test-token',
+      'x-goog-user-project': defaultRegionalClientOptions.projectId,
+    };
+  }
+
+  describe('model()', () => {
+    it('should return a ModelReference for a known model', () => {
+      const knownModelName = Object.keys(KNOWN_MODELS)[0];
+      const ref = model(knownModelName);
+      assert.strictEqual(ref.name, `vertexai/${knownModelName}`);
+      assert.ok(ref.info?.supports?.media);
+      assert.ok(ref.info?.supports?.longRunning);
+    });
+
+    it('should return a ModelReference for an unknown model using generic info', () => {
+      const unknownModelName = 'veo-unknown-model';
+      const ref = model(unknownModelName);
+      assert.strictEqual(ref.name, `vertexai/${unknownModelName}`);
+      assert.deepStrictEqual(ref.info, GENERIC_MODEL.info);
+    });
+  });
+
+  describe('defineModel()', () => {
+    function captureModelRunner(clientOptions: ClientOptions): {
+      start: (
+        request: GenerateRequest<typeof VeoConfigSchema>
+      ) => Promise<Operation>;
+      check: (operation: Operation) => Promise<Operation>;
+    } {
+      defineModel(mockGenkit, modelName, clientOptions);
+      sinon.assert.calledOnce(mockGenkit.defineBackgroundModel);
+      const callArgs = mockGenkit.defineBackgroundModel.firstCall.args;
+      assert.strictEqual(callArgs[0].name, `vertexai/${modelName}`);
+      assert.strictEqual(callArgs[0].configSchema, VeoConfigSchema);
+      return {
+        start: callArgs[0].start,
+        check: callArgs[0].check,
+      };
+    }
+
+    describe('start()', () => {
+      const prompt = 'A unicycle on the moon';
+      const request: GenerateRequest = {
+        messages: [{ role: 'user', content: [{ text: prompt }] }],
+        config: {
+          aspectRatio: '16:9',
+          durationSeconds: 5,
+        },
+      };
+
+      it('should call fetch for veoPredict and return operation', async () => {
+        const mockOp: VeoOperation = {
+          name: `projects/test-project/locations/us-central1/publishers/google/models/${modelName}/operations/start123`,
+          done: false,
+        };
+        mockFetchResponse(mockOp);
+
+        const { start } = captureModelRunner(defaultRegionalClientOptions);
+        const result = await start(request);
+
+        sinon.assert.calledOnce(fetchStub);
+        const fetchArgs = fetchStub.lastCall.args;
+        const url = fetchArgs[0];
+        const options = fetchArgs[1];
+
+        const expectedUrl = getVertexAIUrl({
+          includeProjectAndLocation: true,
+          resourcePath: `publishers/google/models/${modelName}`,
+          resourceMethod: 'predictLongRunning',
+          clientOptions: defaultRegionalClientOptions,
+        });
+        assert.strictEqual(url, expectedUrl);
+        assert.strictEqual(options.method, 'POST');
+        assert.deepStrictEqual(options.headers, getExpectedHeaders());
+
+        const expectedPredictRequest: VeoPredictRequest =
+          toVeoPredictRequest(request);
+        assert.deepStrictEqual(
+          JSON.parse(options.body),
+          expectedPredictRequest
+        );
+
+        assert.deepStrictEqual(result, fromVeoOperation(mockOp));
+      });
+
+      it('should propagate API errors', async () => {
+        const errorBody = { error: { message: 'Invalid arg', code: 400 } };
+        mockFetchResponse(errorBody, 400);
+
+        const { start } = captureModelRunner(defaultRegionalClientOptions);
+        await assert.rejects(
+          start(request),
+          /Error fetching from .*predictLongRunning.* Invalid arg/
+        );
+      });
+
+      it('should pass AbortSignal to fetch', async () => {
+        mockFetchResponse({ name: 'operations/abort', done: false });
+        const controller = new AbortController();
+        const abortSignal = controller.signal;
+
+        const clientOptionsWithSignal = {
+          ...defaultRegionalClientOptions,
+          signal: abortSignal,
+        };
+        const { start } = captureModelRunner(clientOptionsWithSignal);
+        await start(request);
+
+        sinon.assert.calledOnce(fetchStub);
+        const fetchOptions = fetchStub.lastCall.args[1];
+        assert.ok(fetchOptions.signal, 'Fetch options should have a signal');
+
+        const fetchSignal = fetchOptions.signal;
+        const abortSpy = sinon.spy();
+        fetchSignal.addEventListener('abort', abortSpy);
+
+        // verify that aborting the original signal aborts the fetch signal
+        controller.abort();
+        sinon.assert.calledOnce(abortSpy);
+      });
+    });
+
+    describe('check()', () => {
+      const operationId = `projects/test-project/locations/us-central1/publishers/google/models/${modelName}/operations/check123`;
+      const pendingOp: Operation = { id: operationId, done: false };
+
+      it('should call fetch for veoCheckOperation and return updated operation', async () => {
+        const mockResponse: VeoOperation = {
+          name: operationId,
+          done: true,
+          response: {
+            videos: [
+              {
+                gcsUri: 'gs://test-bucket/video.mp4',
+                mimeType: 'video/mp4',
+              },
+            ],
+          },
+        };
+        mockFetchResponse(mockResponse);
+
+        const { check } = captureModelRunner(defaultRegionalClientOptions);
+        const result = await check(pendingOp);
+
+        sinon.assert.calledOnce(fetchStub);
+        const fetchArgs = fetchStub.lastCall.args;
+        const url = fetchArgs[0];
+        const options = fetchArgs[1];
+
+        const expectedUrl = getVertexAIUrl({
+          includeProjectAndLocation: true,
+          resourcePath: `publishers/google/models/${modelName}`,
+          resourceMethod: 'fetchPredictOperation',
+          clientOptions: defaultRegionalClientOptions,
+        });
+        assert.strictEqual(url, expectedUrl);
+        assert.strictEqual(options.method, 'POST');
+        assert.deepStrictEqual(options.headers, getExpectedHeaders());
+
+        const expectedCheckRequest: VeoOperationRequest =
+          toVeoOperationRequest(pendingOp);
+        assert.deepStrictEqual(JSON.parse(options.body), expectedCheckRequest);
+
+        assert.deepStrictEqual(result, fromVeoOperation(mockResponse));
+      });
+
+      it('should propagate API errors for check', async () => {
+        const errorBody = { error: { message: 'Not found', code: 404 } };
+        mockFetchResponse(errorBody, 404);
+
+        const { check } = captureModelRunner(defaultRegionalClientOptions);
+        await assert.rejects(
+          check(pendingOp),
+          /Error fetching from .*fetchPredictOperation.* Not found/
+        );
+      });
+    });
+  });
+});
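Note on the utils_test.ts changes above: the new assertions pin down a specific key-resolution order for the Vertex AI plugin helpers. An explicit request key always wins, a plugin-level key comes next, and the VERTEX_API_KEY, GOOGLE_API_KEY, and GOOGLE_GENAI_API_KEY environment variables are consulted last, while passing false for the plugin key disables the plugin/env fallback entirely. The sketch below only restates that precedence as standalone TypeScript for reference; resolveApiKey, envApiKey, and the thrown Error values are hypothetical names, not the plugin's actual calculateApiKey implementation or its exported MISSING_API_KEY_ERROR / API_KEY_FALSE_ERROR values.

// Sketch only: mirrors the precedence asserted by the tests above.
// All identifiers here are illustrative placeholders.
function envApiKey(): string | undefined {
  // Same lookup order as the getApiKeyFromEnvVar tests expect.
  return (
    process.env.VERTEX_API_KEY ??
    process.env.GOOGLE_API_KEY ??
    process.env.GOOGLE_GENAI_API_KEY
  );
}

function resolveApiKey(
  pluginApiKey: string | false | undefined,
  requestApiKey?: string
): string {
  // 1. A per-request key always wins, even when the plugin key is false.
  if (requestApiKey) return requestApiKey;
  // 2. pluginApiKey === false means "only per-request keys are accepted".
  if (pluginApiKey === false) {
    throw new Error('API key lookup disabled and no request key provided');
  }
  // 3. Otherwise fall back to the plugin-level key, then the environment.
  const key = pluginApiKey ?? envApiKey();
  if (!key) {
    throw new Error('No API key found in request, plugin config, or env');
  }
  return key;
}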