Skip to content

feat(js/plugins/google-genai): Added support for global Vertex endpoint #3249

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions js/plugins/google-genai/src/common/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,14 @@ export declare interface GoogleSearchRetrievalTool {
googleSearchRetrieval?: GoogleSearchRetrieval;
googleSearch?: GoogleSearchRetrieval;
}
/**
 * Type guard for {@link GoogleSearchRetrievalTool}.
 *
 * @param tool The tool to inspect.
 * @returns `true` when either `googleSearchRetrieval` or `googleSearch` is
 *   set to something other than `undefined`.
 */
export function isGoogleSearchRetrievalTool(
  tool: Tool
): tool is GoogleSearchRetrievalTool {
  const candidate = tool as GoogleSearchRetrievalTool;
  // De Morgan form of "at least one of the two fields is defined".
  return !(
    candidate.googleSearchRetrieval === undefined &&
    candidate.googleSearch === undefined
  );
}

/**
* Grounding support.
Expand Down Expand Up @@ -737,6 +745,11 @@ export declare interface FunctionDeclarationsTool {
*/
functionDeclarations?: FunctionDeclaration[];
}
/**
 * Type guard for {@link FunctionDeclarationsTool}.
 *
 * @param tool The tool to inspect.
 * @returns `true` when `functionDeclarations` is not `undefined`.
 */
export function isFunctionDeclarationsTool(
  tool: Tool
): tool is FunctionDeclarationsTool {
  const { functionDeclarations } = tool as FunctionDeclarationsTool;
  return functionDeclarations !== undefined;
}

/**
* Google AI Only. Enables the model to execute code as part of generation.
Expand All @@ -749,6 +762,9 @@ export declare interface CodeExecutionTool {
*/
codeExecution: {};
}
/**
 * Type guard for {@link CodeExecutionTool}.
 *
 * @param tool The tool to inspect.
 * @returns `true` when `codeExecution` is not `undefined`.
 */
export function isCodeExecutionTool(tool: Tool): tool is CodeExecutionTool {
  const candidate = tool as CodeExecutionTool;
  const hasCodeExecution = candidate.codeExecution !== undefined;
  return hasCodeExecution;
}

/**
* Vertex AI Only. Retrieve from Vertex AI Search datastore for grounding.
Expand Down Expand Up @@ -830,6 +846,9 @@ export declare interface RetrievalTool {
/** Optional. {@link Retrieval}. */
retrieval?: Retrieval;
}
/**
 * Type guard for {@link RetrievalTool}.
 *
 * @param tool The tool to inspect.
 * @returns `true` when `retrieval` is not `undefined`.
 */
export function isRetrievalTool(tool: Tool): tool is RetrievalTool {
  const { retrieval } = tool as RetrievalTool;
  if (retrieval === undefined) {
    return false;
  }
  return true;
}

/**
* Tool to retrieve public web data for grounding, powered by Google.
Expand Down
10 changes: 6 additions & 4 deletions js/plugins/google-genai/src/common/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,12 @@ export function checkModelName(name?: string): string {
}

/**
 * Concatenates the text of every content part in the last message of a
 * request.
 *
 * Parts without text contribute an empty string. A request with no messages
 * yields an empty string instead of throwing.
 *
 * @param request The generate request whose final message is read.
 * @returns The joined text of the last message, or `''` when there is none.
 */
export function extractText(request: GenerateRequest): string {
  // Index access (rather than a non-null assertion) so an empty `messages`
  // array falls through to the '' default below.
  const lastMessage = request.messages[request.messages.length - 1];
  return lastMessage?.content.map((part) => part.text || '').join('') ?? '';
}

export function extractImagenImage(
Expand Down
2 changes: 1 addition & 1 deletion js/plugins/google-genai/src/googleai/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ export async function veoPredict(
return response.json() as Promise<VeoOperation>;
}

export async function checkVeoOperation(
export async function veoCheckOperation(
apiKey: string,
operation: string,
clientOptions?: ClientOptions
Expand Down
3 changes: 0 additions & 3 deletions js/plugins/google-genai/src/googleai/gemini.ts
Original file line number Diff line number Diff line change
Expand Up @@ -309,9 +309,6 @@ const KNOWN_GEMINI_MODELS = {
'gemini-2.0-flash-preview-image-generation'
),
'gemini-2.0-flash-lite': commonRef('gemini-2.0-flash-lite'),
'gemini-1.5-flash': commonRef('gemini-1.5-flash'),
'gemini-1.5-flash-8b': commonRef('gemini-1.5-flash-8b'),
'gemini-1.5-pro': commonRef('gemini-1.5-pro'),
};
/** Union of the literal keys of `KNOWN_GEMINI_MODELS`. */
export type KnownGeminiModels = keyof typeof KNOWN_GEMINI_MODELS;
/** Any string that starts with the `gemini-` prefix. */
export type GeminiModelName = `gemini-${string}`;
Expand Down
4 changes: 2 additions & 2 deletions js/plugins/google-genai/src/googleai/veo.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import {
type ModelInfo,
type ModelReference,
} from 'genkit/model';
import { checkVeoOperation, veoPredict } from './client.js';
import { veoCheckOperation, veoPredict } from './client.js';
import {
ClientOptions,
GoogleAIPluginOptions,
Expand Down Expand Up @@ -201,7 +201,7 @@ export function defineModel(
},
async check(operation) {
const apiKey = calculateApiKey(pluginOptions?.apiKey, undefined);
const response = await checkVeoOperation(
const response = await veoCheckOperation(
apiKey,
operation.id,
clientOptions
Expand Down
75 changes: 48 additions & 27 deletions js/plugins/google-genai/src/vertexai/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* limitations under the License.
*/

import { GENKIT_CLIENT_HEADER } from 'genkit';
import { GENKIT_CLIENT_HEADER, GenkitError } from 'genkit';
import { GoogleAuth } from 'google-auth-library';
import { extractErrMsg } from '../common/utils';
import {
Expand Down Expand Up @@ -137,19 +137,6 @@ export async function imagenPredict(
return response.json() as Promise<ImagenPredictResponse>;
}

// TODO(ifielker): update with 'global' and APIKey in the options.
// See genai SDK for how to handle the apiKey
// if (
// this.clientOptions.project &&
// this.clientOptions.location &&
// this.clientOptions.location !== 'global'
// ) {
// // Regional endpoint
// return `https://${this.clientOptions.location}-aiplatform.googleapis.com/`;
// }
// // Global endpoint (covers 'global' location and API key usage)
// return `https://aiplatform.googleapis.com/`;

export function getVertexAIUrl(params: {
includeProjectAndLocation: boolean; // False for listModels, true for most others
resourcePath: string;
Expand All @@ -160,11 +147,19 @@ export function getVertexAIUrl(params: {
const DEFAULT_API_VERSION = 'v1beta1';
const API_BASE_PATH = 'aiplatform.googleapis.com';

const region = params.clientOptions.location || 'us-central1';
const basePath = `${region}-${API_BASE_PATH}`;
let basePath: string;

if (params.clientOptions.kind == 'regional') {
basePath = `${params.clientOptions.location}-${API_BASE_PATH}`;
} else {
basePath = API_BASE_PATH;
}

let resourcePath = params.resourcePath;
if (params.includeProjectAndLocation) {
if (
params.clientOptions.kind != 'express' &&
params.includeProjectAndLocation
) {
const parent = `projects/${params.clientOptions.projectId}/locations/${params.clientOptions.location}`;
resourcePath = `${parent}/${params.resourcePath}`;
}
Expand All @@ -173,24 +168,50 @@ export function getVertexAIUrl(params: {
if (params.resourceMethod) {
url += `:${params.resourceMethod}`;
}

let joiner = '?';
if (params.queryParams) {
url += `?${params.queryParams}`;
url += `${joiner}${params.queryParams}`;
joiner = '&';
}
if (params.resourceMethod === 'streamGenerateContent') {
url += `${params.queryParams ? '&' : '?'}alt=sse`;
url += `${joiner}alt=sse`;
joiner = '&';
}
if (params.clientOptions.kind == 'express' && !params.clientOptions.apiKey) {
throw new GenkitError({
status: 'INVALID_ARGUMENT',
message: 'missing api key',
});
}
if (params.clientOptions.apiKey) {
// Required for express, optional for others
url += `${joiner}key=${params.clientOptions.apiKey}`;
joiner = '&';
}
return url;
}

/**
 * Builds the HTTP headers for a Vertex AI request.
 *
 * For `express` client options no auth token is fetched and no
 * `Authorization` / `x-goog-user-project` headers are sent (the API key is
 * in the url query params). For every other kind a bearer token is obtained
 * from `clientOptions.authClient` and the project id is attached.
 *
 * @param clientOptions The resolved client options for this request.
 * @returns The headers to send with the request.
 */
async function getHeaders(clientOptions: ClientOptions): Promise<HeadersInit> {
  // Headers common to every client kind.
  const baseHeaders: Record<string, string> = {
    'Content-Type': 'application/json',
    'X-Goog-Api-Client': GENKIT_CLIENT_HEADER,
    'User-Agent': GENKIT_CLIENT_HEADER,
  };

  if (clientOptions.kind === 'express') {
    // ApiKey is in the url query params
    return baseHeaders;
  }

  const token = await getToken(clientOptions.authClient);
  return {
    Authorization: `Bearer ${token}`,
    'x-goog-user-project': clientOptions.projectId,
    ...baseHeaders,
  };
}

async function getToken(authClient: GoogleAuth): Promise<string> {
Expand Down
22 changes: 22 additions & 0 deletions js/plugins/google-genai/src/vertexai/converters.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,25 @@ export function toGeminiSafetySettings(
};
});
}

/**
 * Copies request labels into a fresh object, dropping entries whose key is
 * the empty string.
 *
 * @param labels Optional key-value labels.
 * @returns A new label map, or `undefined` when `labels` is missing or no
 *   entries remain after filtering.
 */
export function toGeminiLabels(
  labels?: Record<string, string>
): Record<string, string> | undefined {
  if (!labels) return undefined;

  const sanitized: Record<string, string> = {};
  for (const [name, value] of Object.entries(labels)) {
    if (name) {
      sanitized[name] = value;
    }
  }

  return Object.keys(sanitized).length === 0 ? undefined : sanitized;
}
77 changes: 70 additions & 7 deletions js/plugins/google-genai/src/vertexai/gemini.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ import {
generateContentStream,
getVertexAIUrl,
} from './client';
import { toGeminiSafetySettings } from './converters';
import { toGeminiLabels, toGeminiSafetySettings } from './converters';
import {
ClientOptions,
Content,
Expand All @@ -59,6 +59,7 @@ import {
ToolConfig,
VertexPluginOptions,
} from './types';
import { calculateApiKey } from './utils';

export const SafetySettingsSchema = z.object({
category: z.enum([
Expand Down Expand Up @@ -123,6 +124,14 @@ const GoogleSearchRetrievalSchema = z.object({
* Please refer to: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference#generationconfig, for further information.
*/
export const GeminiConfigSchema = GenerationCommonConfigSchema.extend({
apiKey: z
.string()
.describe('Overrides the plugin-configured API key, if specified.')
.optional(),
labels: z
.record(z.string())
.optional()
.describe('Key-value labels to attach to the request for cost tracking.'),
temperature: z
.number()
.min(0.0)
Expand Down Expand Up @@ -419,6 +428,8 @@ export function defineModel(
use: middlewares,
},
async (request, sendChunk) => {
let clientOpt = { ...clientOptions };

// Make a copy of messages to avoid side-effects
const messages = structuredClone(request.messages);
if (messages.length === 0) throw new Error('No messages provided.');
Expand All @@ -431,19 +442,57 @@ export function defineModel(
systemInstruction = toGeminiSystemInstruction(systemMessage);
}

const requestConfig = request.config as GeminiConfig;
const requestConfig = { ...request.config };

const {
apiKey: apiKeyFromConfig,
functionCallingConfig,
version: versionFromConfig,
googleSearchRetrieval,
tools: toolsFromConfig,
vertexRetrieval,
location, // location can be overridden via config, take it out.
location,
safetySettings,
labels: labelsFromConfig,
...restOfConfig
} = requestConfig;

if (
location &&
clientOptions.kind != 'express' &&
clientOptions.location != location
) {
// Override the location if it's specified in the request
if (location == 'global') {
clientOpt = {
kind: 'global',
location: 'global',
projectId: clientOptions.projectId,
authClient: clientOptions.authClient,
apiKey: clientOptions.apiKey,
};
} else {
clientOpt = {
kind: 'regional',
location,
projectId: clientOptions.projectId,
authClient: clientOptions.authClient,
apiKey: clientOptions.apiKey,
};
}
}
if (clientOptions.kind == 'express') {
clientOpt.apiKey = calculateApiKey(
clientOptions.apiKey,
apiKeyFromConfig
);
} else if (apiKeyFromConfig) {
// Regional or Global can still use APIKey for billing (not auth)
clientOpt.apiKey = apiKeyFromConfig;
}

const labels = toGeminiLabels(labelsFromConfig);

const tools: Tool[] = [];
if (request.tools?.length) {
tools.push({
Expand Down Expand Up @@ -492,10 +541,23 @@ export function defineModel(

if (vertexRetrieval) {
const _projectId =
vertexRetrieval.datastore.projectId || clientOptions.projectId;
vertexRetrieval.datastore.projectId ||
(clientOptions.kind != 'express'
? clientOptions.projectId
: undefined);
const _location =
vertexRetrieval.datastore.location || clientOptions.location;
vertexRetrieval.datastore.location ||
(clientOptions.kind == 'regional'
? clientOptions.location
: undefined);
const _dataStoreId = vertexRetrieval.datastore.dataStoreId;
if (!_projectId || !_location || !_dataStoreId) {
throw new GenkitError({
status: 'INVALID_ARGUMENT',
message:
'projectId, location and datastoreId are required for vertexRetrieval and could not be determined from configuration',
});
}
const datastore = `projects/${_projectId}/locations/${_location}/collections/default_collection/dataStores/${_dataStoreId}`;
tools.push({
retrieval: {
Expand All @@ -518,6 +580,7 @@ export function defineModel(
toolConfig,
safetySettings: toGeminiSafetySettings(safetySettings),
contents: messages.map((message) => toGeminiMessage(message, ref)),
labels,
};

const modelVersion = versionFromConfig || (ref.version as string);
Expand All @@ -536,7 +599,7 @@ export function defineModel(
const result = await generateContentStream(
modelVersion,
generateContentRequest,
clientOptions
clientOpt
);

for await (const item of result.stream) {
Expand All @@ -555,7 +618,7 @@ export function defineModel(
response = await generateContent(
modelVersion,
generateContentRequest,
clientOptions
clientOpt
);
}

Expand Down
Loading