Skip to content

Commit 35a3d28

Browse files
authored
feat(js/plugins/google-genai): Added support for global Vertex endpoint (#3249)
1 parent 026bbf4 commit 35a3d28

File tree

20 files changed

+2541
-922
lines changed

20 files changed

+2541
-922
lines changed

js/plugins/google-genai/src/common/types.ts

Lines changed: 19 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -140,6 +140,14 @@ export declare interface GoogleSearchRetrievalTool {
140140
googleSearchRetrieval?: GoogleSearchRetrieval;
141141
googleSearch?: GoogleSearchRetrieval;
142142
}
143+
/**
 * Type guard that narrows a `Tool` to `GoogleSearchRetrievalTool`.
 *
 * @param tool The tool to inspect.
 * @returns `true` when either `googleSearchRetrieval` or `googleSearch` is
 *   anything other than `undefined` (a `null` value therefore also matches).
 */
export function isGoogleSearchRetrievalTool(
144+
tool: Tool
145+
): tool is GoogleSearchRetrievalTool {
146+
return (
147+
(tool as GoogleSearchRetrievalTool).googleSearchRetrieval !== undefined ||
148+
(tool as GoogleSearchRetrievalTool).googleSearch !== undefined
149+
);
150+
}
143151

144152
/**
145153
* Grounding support.
@@ -737,6 +745,11 @@ export declare interface FunctionDeclarationsTool {
737745
*/
738746
functionDeclarations?: FunctionDeclaration[];
739747
}
748+
/**
 * Type guard that narrows a `Tool` to `FunctionDeclarationsTool`.
 *
 * @param tool The tool to inspect.
 * @returns `true` when `functionDeclarations` is anything other than
 *   `undefined`; an empty array still counts as present.
 */
export function isFunctionDeclarationsTool(
749+
tool: Tool
750+
): tool is FunctionDeclarationsTool {
751+
return (tool as FunctionDeclarationsTool).functionDeclarations !== undefined;
752+
}
740753

741754
/**
742755
* Google AI Only. Enables the model to execute code as part of generation.
@@ -749,6 +762,9 @@ export declare interface CodeExecutionTool {
749762
*/
750763
codeExecution: {};
751764
}
765+
/**
 * Type guard that narrows a `Tool` to `CodeExecutionTool`.
 *
 * @param tool The tool to inspect.
 * @returns `true` when the `codeExecution` property is anything other than
 *   `undefined`.
 */
export function isCodeExecutionTool(tool: Tool): tool is CodeExecutionTool {
766+
return (tool as CodeExecutionTool).codeExecution !== undefined;
767+
}
752768

753769
/**
754770
* Vertex AI Only. Retrieve from Vertex AI Search datastore for grounding.
@@ -830,6 +846,9 @@ export declare interface RetrievalTool {
830846
/** Optional. {@link Retrieval}. */
831847
retrieval?: Retrieval;
832848
}
849+
/**
 * Type guard that narrows a `Tool` to `RetrievalTool`.
 *
 * @param tool The tool to inspect.
 * @returns `true` when the `retrieval` property is anything other than
 *   `undefined`.
 */
export function isRetrievalTool(tool: Tool): tool is RetrievalTool {
850+
return (tool as RetrievalTool).retrieval !== undefined;
851+
}
833852

834853
/**
835854
* Tool to retrieve public web data for grounding, powered by Google.

js/plugins/google-genai/src/common/utils.ts

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,10 +68,12 @@ export function checkModelName(name?: string): string {
6868
}
6969

7070
/**
 * Concatenates the `text` of every content part in the last message of the
 * request; parts with a falsy `text` contribute an empty string.
 *
 * This hunk shows both versions of the body. The removed lines (`-` gutter)
 * apply a non-null assertion to `messages.at(-1)`. The added lines (`+`
 * gutter) use optional chaining and fall back to `''`, so a request with no
 * messages yields an empty string.
 *
 * @param request The generate request whose last message is read.
 * @returns The joined text, or `''` (added version) when there is no last
 *   message.
 */
export function extractText(request: GenerateRequest) {
71-
return request.messages
72-
.at(-1)!
73-
.content.map((c) => c.text || '')
74-
.join('');
71+
return (
72+
request.messages
73+
.at(-1)
74+
?.content.map((c) => c.text || '')
75+
.join('') ?? ''
76+
);
7577
}
7678

7779
export function extractImagenImage(

js/plugins/google-genai/src/googleai/client.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ export async function veoPredict(
198198
return response.json() as Promise<VeoOperation>;
199199
}
200200

201-
export async function checkVeoOperation(
201+
export async function veoCheckOperation(
202202
apiKey: string,
203203
operation: string,
204204
clientOptions?: ClientOptions

js/plugins/google-genai/src/googleai/gemini.ts

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -309,9 +309,6 @@ const KNOWN_GEMINI_MODELS = {
309309
'gemini-2.0-flash-preview-image-generation'
310310
),
311311
'gemini-2.0-flash-lite': commonRef('gemini-2.0-flash-lite'),
312-
'gemini-1.5-flash': commonRef('gemini-1.5-flash'),
313-
'gemini-1.5-flash-8b': commonRef('gemini-1.5-flash-8b'),
314-
'gemini-1.5-pro': commonRef('gemini-1.5-pro'),
315312
};
316313
export type KnownGeminiModels = keyof typeof KNOWN_GEMINI_MODELS;
317314
export type GeminiModelName = `gemini-${string}`;

js/plugins/google-genai/src/googleai/veo.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ import {
2929
type ModelInfo,
3030
type ModelReference,
3131
} from 'genkit/model';
32-
import { checkVeoOperation, veoPredict } from './client.js';
32+
import { veoCheckOperation, veoPredict } from './client.js';
3333
import {
3434
ClientOptions,
3535
GoogleAIPluginOptions,
@@ -201,7 +201,7 @@ export function defineModel(
201201
},
202202
async check(operation) {
203203
const apiKey = calculateApiKey(pluginOptions?.apiKey, undefined);
204-
const response = await checkVeoOperation(
204+
const response = await veoCheckOperation(
205205
apiKey,
206206
operation.id,
207207
clientOptions

js/plugins/google-genai/src/vertexai/client.ts

Lines changed: 48 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
* limitations under the License.
1515
*/
1616

17-
import { GENKIT_CLIENT_HEADER } from 'genkit';
17+
import { GENKIT_CLIENT_HEADER, GenkitError } from 'genkit';
1818
import { GoogleAuth } from 'google-auth-library';
1919
import { extractErrMsg } from '../common/utils';
2020
import {
@@ -137,19 +137,6 @@ export async function imagenPredict(
137137
return response.json() as Promise<ImagenPredictResponse>;
138138
}
139139

140-
// TODO(ifielker): update with 'global' and APIKey in the options.
141-
// See genai SDK for how to handle the apiKey
142-
// if (
143-
// this.clientOptions.project &&
144-
// this.clientOptions.location &&
145-
// this.clientOptions.location !== 'global'
146-
// ) {
147-
// // Regional endpoint
148-
// return `https://${this.clientOptions.location}-aiplatform.googleapis.com/`;
149-
// }
150-
// // Global endpoint (covers 'global' location and API key usage)
151-
// return `https://aiplatform.googleapis.com/`;
152-
153140
export function getVertexAIUrl(params: {
154141
includeProjectAndLocation: boolean; // False for listModels, true for most others
155142
resourcePath: string;
@@ -160,11 +147,19 @@ export function getVertexAIUrl(params: {
160147
const DEFAULT_API_VERSION = 'v1beta1';
161148
const API_BASE_PATH = 'aiplatform.googleapis.com';
162149

163-
const region = params.clientOptions.location || 'us-central1';
164-
const basePath = `${region}-${API_BASE_PATH}`;
150+
let basePath: string;
151+
152+
if (params.clientOptions.kind == 'regional') {
153+
basePath = `${params.clientOptions.location}-${API_BASE_PATH}`;
154+
} else {
155+
basePath = API_BASE_PATH;
156+
}
165157

166158
let resourcePath = params.resourcePath;
167-
if (params.includeProjectAndLocation) {
159+
if (
160+
params.clientOptions.kind != 'express' &&
161+
params.includeProjectAndLocation
162+
) {
168163
const parent = `projects/${params.clientOptions.projectId}/locations/${params.clientOptions.location}`;
169164
resourcePath = `${parent}/${params.resourcePath}`;
170165
}
@@ -173,24 +168,50 @@ export function getVertexAIUrl(params: {
173168
if (params.resourceMethod) {
174169
url += `:${params.resourceMethod}`;
175170
}
171+
172+
let joiner = '?';
176173
if (params.queryParams) {
177-
url += `?${params.queryParams}`;
174+
url += `${joiner}${params.queryParams}`;
175+
joiner = '&';
178176
}
179177
if (params.resourceMethod === 'streamGenerateContent') {
180-
url += `${params.queryParams ? '&' : '?'}alt=sse`;
178+
url += `${joiner}alt=sse`;
179+
joiner = '&';
180+
}
181+
if (params.clientOptions.kind == 'express' && !params.clientOptions.apiKey) {
182+
throw new GenkitError({
183+
status: 'INVALID_ARGUMENT',
184+
message: 'missing api key',
185+
});
186+
}
187+
if (params.clientOptions.apiKey) {
188+
// Required for express, optional for others
189+
url += `${joiner}key=${params.clientOptions.apiKey}`;
190+
joiner = '&';
181191
}
182192
return url;
183193
}
184194

185195
/**
 * Builds the HTTP headers sent with a Vertex AI request.
 *
 * This hunk shows both versions of the body. The removed lines (`-` gutter)
 * always fetch a token and return a single header set. The added lines (`+`
 * gutter) branch on `clientOptions.kind`:
 *  - `'express'`: no token is fetched and no `Authorization` or
 *    `x-goog-user-project` header is set; only `Content-Type`,
 *    `X-Goog-Api-Client` and `User-Agent` are returned.
 *  - any other kind: a token is obtained via `getToken(authClient)` and sent
 *    as a bearer `Authorization` header, with `x-goog-user-project` set to
 *    the project id, plus the same three common headers.
 *
 * @param clientOptions Client configuration; `kind`, `authClient` and
 *   `projectId` are read here.
 * @returns The headers object for the request.
 */
async function getHeaders(clientOptions: ClientOptions): Promise<HeadersInit> {
186-
const token = await getToken(clientOptions.authClient);
187-
const headers: HeadersInit = {
188-
Authorization: `Bearer ${token}`,
189-
'x-goog-user-project': clientOptions.projectId,
190-
'Content-Type': 'application/json',
191-
'X-Goog-Api-Client': GENKIT_CLIENT_HEADER,
192-
};
193-
return headers;
196+
if (clientOptions.kind == 'express') {
197+
const headers: HeadersInit = {
198+
// ApiKey is in the url query params
199+
'Content-Type': 'application/json',
200+
'X-Goog-Api-Client': GENKIT_CLIENT_HEADER,
201+
'User-Agent': GENKIT_CLIENT_HEADER,
202+
};
203+
return headers;
204+
} else {
205+
const token = await getToken(clientOptions.authClient);
206+
const headers: HeadersInit = {
207+
Authorization: `Bearer ${token}`,
208+
'x-goog-user-project': clientOptions.projectId,
209+
'Content-Type': 'application/json',
210+
'X-Goog-Api-Client': GENKIT_CLIENT_HEADER,
211+
'User-Agent': GENKIT_CLIENT_HEADER,
212+
};
213+
return headers;
214+
}
194215
}
195216

196217
async function getToken(authClient: GoogleAuth): Promise<string> {

js/plugins/google-genai/src/vertexai/converters.ts

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,25 @@ export function toGeminiSafetySettings(
3333
};
3434
});
3535
}
36+
37+
/**
 * Normalizes request labels before they are attached to a Gemini request.
 *
 * Entries whose key is the empty string are dropped; all other entries are
 * copied into a fresh object, so the caller's object is never mutated or
 * shared.
 *
 * @param labels Optional key/value labels supplied in the request config.
 * @returns A new label map, or `undefined` when `labels` is missing or no
 *   entry survives filtering.
 */
export function toGeminiLabels(
  labels?: Record<string, string>
): Record<string, string> | undefined {
  if (!labels) {
    return undefined;
  }

  const newLabels: Record<string, string> = {};
  // Object.entries yields typed [key, value] pairs, so no indexed lookup is
  // needed (which would be `string | undefined` under
  // `noUncheckedIndexedAccess`).
  for (const [key, value] of Object.entries(labels)) {
    if (!key) {
      // An empty key is not a usable label name; skip it.
      continue;
    }
    newLabels[key] = value;
  }

  // Collapse "nothing left" to undefined so the field is omitted entirely.
  return Object.keys(newLabels).length === 0 ? undefined : newLabels;
}

js/plugins/google-genai/src/vertexai/gemini.ts

Lines changed: 70 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ import {
4646
generateContentStream,
4747
getVertexAIUrl,
4848
} from './client';
49-
import { toGeminiSafetySettings } from './converters';
49+
import { toGeminiLabels, toGeminiSafetySettings } from './converters';
5050
import {
5151
ClientOptions,
5252
Content,
@@ -59,6 +59,7 @@ import {
5959
ToolConfig,
6060
VertexPluginOptions,
6161
} from './types';
62+
import { calculateApiKey } from './utils';
6263

6364
export const SafetySettingsSchema = z.object({
6465
category: z.enum([
@@ -123,6 +124,14 @@ const GoogleSearchRetrievalSchema = z.object({
123124
* Please refer to: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference#generationconfig, for further information.
124125
*/
125126
export const GeminiConfigSchema = GenerationCommonConfigSchema.extend({
127+
apiKey: z
128+
.string()
129+
.describe('Overrides the plugin-configured API key, if specified.')
130+
.optional(),
131+
labels: z
132+
.record(z.string())
133+
.optional()
134+
.describe('Key-value labels to attach to the request for cost tracking.'),
126135
temperature: z
127136
.number()
128137
.min(0.0)
@@ -419,6 +428,8 @@ export function defineModel(
419428
use: middlewares,
420429
},
421430
async (request, sendChunk) => {
431+
let clientOpt = { ...clientOptions };
432+
422433
// Make a copy of messages to avoid side-effects
423434
const messages = structuredClone(request.messages);
424435
if (messages.length === 0) throw new Error('No messages provided.');
@@ -431,19 +442,57 @@ export function defineModel(
431442
systemInstruction = toGeminiSystemInstruction(systemMessage);
432443
}
433444

434-
const requestConfig = request.config as GeminiConfig;
445+
const requestConfig = { ...request.config };
435446

436447
const {
448+
apiKey: apiKeyFromConfig,
437449
functionCallingConfig,
438450
version: versionFromConfig,
439451
googleSearchRetrieval,
440452
tools: toolsFromConfig,
441453
vertexRetrieval,
442-
location, // location can be overridden via config, take it out.
454+
location,
443455
safetySettings,
456+
labels: labelsFromConfig,
444457
...restOfConfig
445458
} = requestConfig;
446459

460+
if (
461+
location &&
462+
clientOptions.kind != 'express' &&
463+
clientOptions.location != location
464+
) {
465+
// Override the location if it's specified in the request
466+
if (location == 'global') {
467+
clientOpt = {
468+
kind: 'global',
469+
location: 'global',
470+
projectId: clientOptions.projectId,
471+
authClient: clientOptions.authClient,
472+
apiKey: clientOptions.apiKey,
473+
};
474+
} else {
475+
clientOpt = {
476+
kind: 'regional',
477+
location,
478+
projectId: clientOptions.projectId,
479+
authClient: clientOptions.authClient,
480+
apiKey: clientOptions.apiKey,
481+
};
482+
}
483+
}
484+
if (clientOptions.kind == 'express') {
485+
clientOpt.apiKey = calculateApiKey(
486+
clientOptions.apiKey,
487+
apiKeyFromConfig
488+
);
489+
} else if (apiKeyFromConfig) {
490+
// Regional or Global can still use APIKey for billing (not auth)
491+
clientOpt.apiKey = apiKeyFromConfig;
492+
}
493+
494+
const labels = toGeminiLabels(labelsFromConfig);
495+
447496
const tools: Tool[] = [];
448497
if (request.tools?.length) {
449498
tools.push({
@@ -492,10 +541,23 @@ export function defineModel(
492541

493542
if (vertexRetrieval) {
494543
const _projectId =
495-
vertexRetrieval.datastore.projectId || clientOptions.projectId;
544+
vertexRetrieval.datastore.projectId ||
545+
(clientOptions.kind != 'express'
546+
? clientOptions.projectId
547+
: undefined);
496548
const _location =
497-
vertexRetrieval.datastore.location || clientOptions.location;
549+
vertexRetrieval.datastore.location ||
550+
(clientOptions.kind == 'regional'
551+
? clientOptions.location
552+
: undefined);
498553
const _dataStoreId = vertexRetrieval.datastore.dataStoreId;
554+
if (!_projectId || !_location || !_dataStoreId) {
555+
throw new GenkitError({
556+
status: 'INVALID_ARGUMENT',
557+
message:
558+
'projectId, location and datastoreId are required for vertexRetrieval and could not be determined from configuration',
559+
});
560+
}
499561
const datastore = `projects/${_projectId}/locations/${_location}/collections/default_collection/dataStores/${_dataStoreId}`;
500562
tools.push({
501563
retrieval: {
@@ -518,6 +580,7 @@ export function defineModel(
518580
toolConfig,
519581
safetySettings: toGeminiSafetySettings(safetySettings),
520582
contents: messages.map((message) => toGeminiMessage(message, ref)),
583+
labels,
521584
};
522585

523586
const modelVersion = versionFromConfig || (ref.version as string);
@@ -536,7 +599,7 @@ export function defineModel(
536599
const result = await generateContentStream(
537600
modelVersion,
538601
generateContentRequest,
539-
clientOptions
602+
clientOpt
540603
);
541604

542605
for await (const item of result.stream) {
@@ -555,7 +618,7 @@ export function defineModel(
555618
response = await generateContent(
556619
modelVersion,
557620
generateContentRequest,
558-
clientOptions
621+
clientOpt
559622
);
560623
}
561624

0 commit comments

Comments
 (0)