Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion js/ai/src/formats/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
*/

import type { JSONSchema } from '@genkit-ai/core';
import type { GenerateResponseChunk } from '../generate.js';
import type { GenerateResponseChunk } from '../generate/chunk.js';
import type { Message } from '../message.js';
import type { ModelRequest } from '../model.js';

Expand Down
2 changes: 1 addition & 1 deletion js/ai/src/generate/action.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@ import {
import type { Formatter } from '../formats/types.js';
import {
GenerateResponse,
GenerateResponseChunk,
GenerationResponseError,
tagAsPreamble,
} from '../generate.js';
import { GenerateResponseChunk } from '../generate/chunk.js';
import {
GenerateActionOptionsSchema,
GenerateResponseChunkSchema,
Expand Down
8 changes: 4 additions & 4 deletions js/plugins/google-genai/src/common/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ export enum FunctionCallingMode {
}

/**
* The reason why the reponse is blocked.
* The reason why the response is blocked.
*/
export enum BlockReason {
/** Unspecified block reason. */
Expand Down Expand Up @@ -156,7 +156,7 @@ export declare interface GroundingSupport {
/** Optional. Segment of the content this support belongs to. */
segment?: GroundingSupportSegment;
/**
* Optional. A arrau of indices (into {@link GroundingChunk}) specifying the
Optional. An array of indices (into {@link GroundingChunk}) specifying the
* citations associated with the claim. For instance [1,3,4] means
* that grounding_chunk[1], grounding_chunk[3],
* grounding_chunk[4] are the retrieved content attributed to the claim.
Expand Down Expand Up @@ -441,7 +441,7 @@ export declare interface GoogleDate {
year?: number;
/**
* Month of the date. Must be from 1 to 12, or 0 to specify a year without a
* monthi and day.
* month and day.
*/
month?: number;
/**
Expand Down Expand Up @@ -983,7 +983,7 @@ export declare interface GenerateContentRequest {

/**
* Result from calling generateContentStream.
* It constains both the stream and the final aggregated response.
* It contains both the stream and the final aggregated response.
* @public
*/
export declare interface GenerateContentStreamResult {
Expand Down
128 changes: 112 additions & 16 deletions js/plugins/google-genai/src/common/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@ import {
EmbedderReference,
GenkitError,
JSONSchema,
MediaPart,
ModelReference,
Part,
z,
} from 'genkit';
import { GenerateRequest } from 'genkit/model';
import { ImagenInstance } from './types';

/**
* Safely extracts the error message from the error.
Expand Down Expand Up @@ -64,8 +65,9 @@ export function extractVersion(
/**
 * Strips known resource/plugin prefixes from a model name.
 *
 * e.g. 'models/gemini-1.5-pro' -> 'gemini-1.5-pro'
 *
 * Prefixes with different casing, such as 'tunedModels/', are intentionally
 * left in place (the regex alternatives are all lowercase).
 *
 * @param name The raw, possibly prefixed, model name.
 * @returns The bare model name, or the input unchanged when falsy.
 */
export function modelName(name?: string): string | undefined {
  if (!name) return name;

  // Remove any of these prefixes: (but keep e.g. tunedModels/)
  const prefixesToRemove =
    /background-model\/|model\/|models\/|embedders\/|googleai\/|vertexai\//g;
  return name.replace(prefixesToRemove, '');
}

Expand Down Expand Up @@ -95,20 +97,114 @@ export function extractText(request: GenerateRequest) {
);
}

export function extractImagenImage(
request: GenerateRequest
): ImagenInstance['image'] | undefined {
const image = request.messages
.at(-1)
?.content.find(
(p) => !!p.media && (!p.metadata?.type || p.metadata?.type === 'base')
)
?.media?.url.split(',')[1];

if (image) {
return { bytesBase64Encoded: image };
// Extension -> mimeType table for non-data urls. Keys are lowercase.
const KNOWN_MIME_TYPES: Record<string, string> = {
  jpg: 'image/jpeg',
  jpeg: 'image/jpeg',
  png: 'image/png',
  mp4: 'video/mp4',
  pdf: 'application/pdf',
};

/**
 * Determines the mimeType of a url.
 *
 * For 'data:' urls, the mimeType is the mediatype between 'data:' and the
 * first comma, with any ';base64' marker removed. For other urls, the file
 * extension (case-insensitive) is looked up in KNOWN_MIME_TYPES.
 *
 * @param url The url to inspect; may be undefined.
 * @returns The mimeType, or '' when it cannot be determined.
 */
export function extractMimeType(url?: string): string {
  if (!url) {
    return '';
  }

  const dataPrefix = 'data:';
  if (!url.startsWith(dataPrefix)) {
    // Not a data url: fall back to the (lowercased) file extension.
    const key = url.substring(url.lastIndexOf('.') + 1).toLowerCase();
    // hasOwnProperty guard avoids matching inherited keys like 'toString'.
    return Object.prototype.hasOwnProperty.call(KNOWN_MIME_TYPES, key)
      ? KNOWN_MIME_TYPES[key]
      : '';
  }

  const commaIndex = url.indexOf(',');
  if (commaIndex === -1) {
    // Malformed data url: missing the ',' separator.
    return '';
  }

  // The mediatype sits between 'data:' and the comma, optionally
  // suffixed with ';base64'.
  let mimeType = url.substring(dataPrefix.length, commaIndex);
  const base64Marker = ';base64';
  if (mimeType.endsWith(base64Marker)) {
    mimeType = mimeType.substring(0, mimeType.length - base64Marker.length);
  }

  return mimeType.trim();
}

/**
 * Validates that a media part's contentType is one of the supported
 * mimeTypes.
 *
 * @param media The media part to validate (url + optional contentType).
 * @param supportedTypes The allowed mimeType strings.
 * @throws GenkitError (INVALID_ARGUMENT) when the contentType is missing
 *     or not in supportedTypes.
 */
export function checkSupportedMimeType(
  media: MediaPart['media'],
  supportedTypes: string[]
) {
  const actual = media.contentType ?? '';
  if (supportedTypes.includes(actual)) {
    return;
  }
  throw new GenkitError({
    status: 'INVALID_ARGUMENT',
    message: `Invalid mimeType for ${displayUrl(media.url)}: "${media.contentType}". Supported mimeTypes: ${supportedTypes.join(', ')}`,
  });
}

/**
 * Truncates long urls for display (e.g. in an error message).
 *
 * @param url The url to show (e.g. in an error message)
 * @returns The url unchanged when it is 50 characters or fewer; otherwise
 *     its first and last 25 characters joined by '...'.
 */
export function displayUrl(url: string): string {
  if (url.length <= 50) {
    return url;
  }

  // Keep both ends so that e.g. data urls stay identifiable.
  return url.substring(0, 25) + '...' + url.substring(url.length - 25);
}

/**
 * Finds a media part in the last message of a request.
 *
 * @param request A generate request to extract from
 * @param params.metadataType When set, a media part matches only if its
 *     metadata type equals this value (untyped parts match when isDefault).
 * @param params.isDefault When true, media parts with no metadata type
 *     also match.
 * @returns The matching media, with contentType inferred from the url when
 *     the part did not carry one, or undefined when nothing matches.
 */
export function extractMedia(
  request: GenerateRequest,
  params: {
    metadataType?: string;
    /* If there is no metadata type, the part matches when isDefault is true */
    isDefault?: boolean;
  }
): MediaPart['media'] | undefined {
  const { metadataType, isDefault } = params;

  const matches = (part: Part): boolean => {
    if (!part.media) {
      return false;
    }
    if (!metadataType && !isDefault) {
      // No type filtering requested: any media part qualifies.
      return true;
    }
    const type = part.metadata?.type;
    // Untyped parts qualify only when defaults are allowed.
    return type ? type == metadataType : !!isDefault;
  };

  const found = request.messages.at(-1)?.content.find(matches)?.media;

  // Backfill the mimeType when the part did not specify one.
  if (found && !found.contentType) {
    return {
      url: found.url,
      contentType: extractMimeType(found.url),
    };
  }

  return found;
}

/**
Expand Down
18 changes: 16 additions & 2 deletions js/plugins/google-genai/src/googleai/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@

import { GenerateRequest, GenkitError } from 'genkit';
import process from 'process';
import { VeoImage } from './types.js';
import { extractMedia } from '../common/utils.js';
import { ImagenInstance, VeoImage } from './types.js';

export {
checkModelName,
cleanSchema,
extractImagenImage,
extractText,
extractVersion,
modelName,
Expand Down Expand Up @@ -143,3 +143,17 @@ export function extractVeoImage(
}
return undefined;
}

/**
 * Extracts a base image (media with metadata type 'base', or untyped media)
 * from the last message of a request, as an Imagen instance image.
 *
 * @param request The generate request to extract from.
 * @returns The base64 payload of the media's data url wrapped for the
 *     Imagen API, or undefined when no suitable media is present.
 */
export function extractImagenImage(
  request: GenerateRequest
): ImagenInstance['image'] | undefined {
  const media = extractMedia(request, {
    metadataType: 'base',
    isDefault: true,
  });

  // Data urls look like '<header>,<base64 payload>'; keep the payload.
  const payload = media?.url.split(',')[1];
  return payload ? { bytesBase64Encoded: payload } : undefined;
}
88 changes: 76 additions & 12 deletions js/plugins/google-genai/src/vertexai/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,14 @@ import {
ImagenPredictRequest,
ImagenPredictResponse,
ListModelsResponse,
LyriaPredictRequest,
LyriaPredictResponse,
Model,
VeoOperation,
VeoOperationRequest,
VeoPredictRequest,
} from './types';
import { calculateApiKey, checkIsSupported } from './utils';
import { calculateApiKey, checkSupportedResourceMethod } from './utils';

export async function listModels(
clientOptions: ClientOptions
Expand Down Expand Up @@ -94,25 +99,37 @@ export async function generateContentStream(
return processStream(response);
}

export async function embedContent(
async function internalPredict(
model: string,
embedContentRequest: EmbedContentRequest,
body: string,
clientOptions: ClientOptions
): Promise<EmbedContentResponse> {
): Promise<Response> {
const url = getVertexAIUrl({
includeProjectAndLocation: true,
resourcePath: `publishers/google/models/${model}`,
resourceMethod: 'predict', // embedContent is a Vertex API predict call
resourceMethod: 'predict',
clientOptions,
});

const fetchOptions = await getFetchOptions({
method: 'POST',
clientOptions,
body: JSON.stringify(embedContentRequest),
body,
});

const response = await makeRequest(url, fetchOptions);
return await makeRequest(url, fetchOptions);
}

/**
 * Calls the Vertex AI ':predict' endpoint to embed content.
 *
 * @param model The embedding model id.
 * @param embedContentRequest The embed request to serialize and send.
 * @param clientOptions Options used to build the url and fetch options.
 * @returns The parsed embedding response.
 */
export async function embedContent(
  model: string,
  embedContentRequest: EmbedContentRequest,
  clientOptions: ClientOptions
): Promise<EmbedContentResponse> {
  const payload = JSON.stringify(embedContentRequest);
  const res = await internalPredict(model, payload, clientOptions);
  return (await res.json()) as EmbedContentResponse;
}

Expand All @@ -121,31 +138,78 @@ export async function imagenPredict(
imagenPredictRequest: ImagenPredictRequest,
clientOptions: ClientOptions
): Promise<ImagenPredictResponse> {
const response = await internalPredict(
model,
JSON.stringify(imagenPredictRequest),
clientOptions
);
return response.json() as Promise<ImagenPredictResponse>;
}

/**
 * Calls the Vertex AI ':predict' endpoint for a Lyria model.
 *
 * @param model The Lyria model id.
 * @param lyriaPredictRequest The predict request to serialize and send.
 * @param clientOptions Options used to build the url and fetch options.
 * @returns The parsed predict response.
 */
export async function lyriaPredict(
  model: string,
  lyriaPredictRequest: LyriaPredictRequest,
  clientOptions: ClientOptions
): Promise<LyriaPredictResponse> {
  const payload = JSON.stringify(lyriaPredictRequest);
  const res = await internalPredict(model, payload, clientOptions);
  return (await res.json()) as LyriaPredictResponse;
}

/**
 * Starts a Veo long-running prediction operation via the
 * ':predictLongRunning' resource method.
 *
 * @param model The Veo model id.
 * @param veoPredictRequest The predict request to serialize and send.
 * @param clientOptions Options used to build the url and fetch options.
 * @returns The initial operation descriptor; poll it with
 *     veoCheckOperation.
 */
export async function veoPredict(
  model: string,
  veoPredictRequest: VeoPredictRequest,
  clientOptions: ClientOptions
): Promise<VeoOperation> {
  const url = getVertexAIUrl({
    includeProjectAndLocation: true,
    resourcePath: `publishers/google/models/${model}`,
    resourceMethod: 'predictLongRunning',
    clientOptions,
  });

  const fetchOptions = await getFetchOptions({
    method: 'POST',
    clientOptions,
    body: JSON.stringify(veoPredictRequest),
  });

  const response = await makeRequest(url, fetchOptions);
  return response.json() as Promise<VeoOperation>;
}

export async function veoCheckOperation(
model: string,
veoOperationRequest: VeoOperationRequest,
clientOptions: ClientOptions
): Promise<VeoOperation> {
const url = getVertexAIUrl({
includeProjectAndLocation: true,
resourcePath: `publishers/google/models/${model}`,
resourceMethod: 'fetchPredictOperation',
clientOptions,
});
const fetchOptions = await getFetchOptions({
method: 'POST',
clientOptions,
body: JSON.stringify(veoOperationRequest),
});

const response = await makeRequest(url, fetchOptions);
return response.json() as Promise<VeoOperation>;
}

export function getVertexAIUrl(params: {
includeProjectAndLocation: boolean; // False for listModels, true for most others
resourcePath: string;
resourceMethod?: 'streamGenerateContent' | 'generateContent' | 'predict';
resourceMethod?: string;
queryParams?: string;
clientOptions: ClientOptions;
}): string {
checkIsSupported(params);
checkSupportedResourceMethod(params);

const DEFAULT_API_VERSION = 'v1beta1';
const API_BASE_PATH = 'aiplatform.googleapis.com';
Expand Down
Loading