Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 21 additions & 18 deletions js/plugins/vertexai/src/embedder.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@
* limitations under the License.
*/

import { z, type Document, type Genkit } from 'genkit';
import { z, type Document } from 'genkit';
import {
embedderRef,
embedderRef as createEmbedderRef,
type EmbedderAction,
type EmbedderReference,
} from 'genkit/embedder';
import { embedder } from 'genkit/plugin';
import type { GoogleAuth } from 'google-auth-library';
import type { PluginOptions } from './common/types.js';
import { predictModel, type PredictClient } from './predict.js';
Expand Down Expand Up @@ -61,7 +62,7 @@ function commonRef(
name: string,
input?: InputType[]
): EmbedderReference<typeof VertexEmbeddingConfigSchema> {
return embedderRef({
return createEmbedderRef({
name: `vertexai/${name}`,
configSchema: VertexEmbeddingConfigSchema,
info: {
Expand All @@ -88,7 +89,7 @@ export const multimodalEmbedding001 = commonRef('multimodalembedding@001', [
'image',
'video',
]);
export const geminiEmbedding001 = embedderRef({
export const geminiEmbedding001 = createEmbedderRef({
name: 'vertexai/gemini-embedding-001',
configSchema: VertexEmbeddingConfigSchema,
info: {
Expand Down Expand Up @@ -254,14 +255,13 @@ type EmbeddingResult = {
};

export function defineVertexAIEmbedder(
ai: Genkit,
name: string,
client: GoogleAuth,
options: PluginOptions
): EmbedderAction<any> {
const embedder =
const embedderRef =
SUPPORTED_EMBEDDER_MODELS[name] ??
embedderRef({
createEmbedderRef({
name: `vertexai/${name}`,
configSchema: VertexEmbeddingConfigSchema,
info: {
Expand Down Expand Up @@ -298,18 +298,21 @@ export function defineVertexAIEmbedder(
return predictClients[requestLocation];
};

return ai.defineEmbedder(
return embedder(
{
name: embedder.name,
configSchema: embedder.configSchema,
info: embedder.info!,
name: embedderRef.name,
configSchema: embedderRef.configSchema,
info: embedderRef.info!,
},
async (input, options) => {
const predictClient = predictClientFactory(options);
async (request, options) => {
const predictClient = predictClientFactory(embedderRef.config);
const response = await predictClient(
input.map((doc: Document) => {
request.input.map((doc: Document) => {
let instance: EmbeddingInstance;
if (isMultiModal(embedder) && checkValidDocument(embedder, doc)) {
if (
isMultiModal(embedderRef) &&
checkValidDocument(embedderRef, doc)
) {
instance = {};
if (doc.text) {
instance.text = doc.text;
Expand Down Expand Up @@ -370,13 +373,13 @@ export function defineVertexAIEmbedder(
// Text only embedder
instance = {
content: doc.text,
task_type: options?.taskType,
title: options?.title,
task_type: embedderRef.config?.taskType,
title: embedderRef.config?.title,
};
}
return instance;
}),
{ outputDimensionality: options?.outputDimensionality }
{ outputDimensionality: embedderRef.config?.outputDimensionality }
);
return {
embeddings: response.predictions
Expand Down
35 changes: 9 additions & 26 deletions js/plugins/vertexai/src/evaluation/evaluation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* limitations under the License.
*/

import { z, type Action, type Genkit } from 'genkit';
import { z, type Action } from 'genkit';
import type { GoogleAuth } from 'google-auth-library';
import { EvaluatorFactory } from './evaluator_factory.js';

Expand Down Expand Up @@ -54,7 +54,6 @@ function stringify(input: unknown) {
}

export function vertexEvaluators(
ai: Genkit,
auth: GoogleAuth,
metrics: VertexAIEvaluationMetric[],
projectId: string,
Expand All @@ -67,28 +66,28 @@ export function vertexEvaluators(

switch (metricType) {
case VertexAIEvaluationMetricType.BLEU: {
return createBleuEvaluator(ai, factory, metricSpec);
return createBleuEvaluator(factory, metricSpec);
}
case VertexAIEvaluationMetricType.ROUGE: {
return createRougeEvaluator(ai, factory, metricSpec);
return createRougeEvaluator(factory, metricSpec);
}
case VertexAIEvaluationMetricType.FLUENCY: {
return createFluencyEvaluator(ai, factory, metricSpec);
return createFluencyEvaluator(factory, metricSpec);
}
case VertexAIEvaluationMetricType.SAFETY: {
return createSafetyEvaluator(ai, factory, metricSpec);
return createSafetyEvaluator(factory, metricSpec);
}
case VertexAIEvaluationMetricType.GROUNDEDNESS: {
return createGroundednessEvaluator(ai, factory, metricSpec);
return createGroundednessEvaluator(factory, metricSpec);
}
case VertexAIEvaluationMetricType.SUMMARIZATION_QUALITY: {
return createSummarizationQualityEvaluator(ai, factory, metricSpec);
return createSummarizationQualityEvaluator(factory, metricSpec);
}
case VertexAIEvaluationMetricType.SUMMARIZATION_HELPFULNESS: {
return createSummarizationHelpfulnessEvaluator(ai, factory, metricSpec);
return createSummarizationHelpfulnessEvaluator(factory, metricSpec);
}
case VertexAIEvaluationMetricType.SUMMARIZATION_VERBOSITY: {
return createSummarizationVerbosityEvaluator(ai, factory, metricSpec);
return createSummarizationVerbosityEvaluator(factory, metricSpec);
}
}
});
Expand All @@ -108,12 +107,10 @@ const BleuResponseSchema = z.object({

// TODO: Add support for batch inputs
function createBleuEvaluator(
ai: Genkit,
factory: EvaluatorFactory,
metricSpec: any
): Action {
return factory.create(
ai,
{
metric: VertexAIEvaluationMetricType.BLEU,
displayName: 'BLEU',
Expand Down Expand Up @@ -150,12 +147,10 @@ const RougeResponseSchema = z.object({

// TODO: Add support for batch inputs
function createRougeEvaluator(
ai: Genkit,
factory: EvaluatorFactory,
metricSpec: any
): Action {
return factory.create(
ai,
{
metric: VertexAIEvaluationMetricType.ROUGE,
displayName: 'ROUGE',
Expand Down Expand Up @@ -191,12 +186,10 @@ const FluencyResponseSchema = z.object({
});

function createFluencyEvaluator(
ai: Genkit,
factory: EvaluatorFactory,
metricSpec: any
): Action {
return factory.create(
ai,
{
metric: VertexAIEvaluationMetricType.FLUENCY,
displayName: 'Fluency',
Expand Down Expand Up @@ -233,12 +226,10 @@ const SafetyResponseSchema = z.object({
});

function createSafetyEvaluator(
ai: Genkit,
factory: EvaluatorFactory,
metricSpec: any
): Action {
return factory.create(
ai,
{
metric: VertexAIEvaluationMetricType.SAFETY,
displayName: 'Safety',
Expand Down Expand Up @@ -275,12 +266,10 @@ const GroundednessResponseSchema = z.object({
});

function createGroundednessEvaluator(
ai: Genkit,
factory: EvaluatorFactory,
metricSpec: any
): Action {
return factory.create(
ai,
{
metric: VertexAIEvaluationMetricType.GROUNDEDNESS,
displayName: 'Groundedness',
Expand Down Expand Up @@ -319,12 +308,10 @@ const SummarizationQualityResponseSchema = z.object({
});

function createSummarizationQualityEvaluator(
ai: Genkit,
factory: EvaluatorFactory,
metricSpec: any
): Action {
return factory.create(
ai,
{
metric: VertexAIEvaluationMetricType.SUMMARIZATION_QUALITY,
displayName: 'Summarization quality',
Expand Down Expand Up @@ -363,12 +350,10 @@ const SummarizationHelpfulnessResponseSchema = z.object({
});

function createSummarizationHelpfulnessEvaluator(
ai: Genkit,
factory: EvaluatorFactory,
metricSpec: any
): Action {
return factory.create(
ai,
{
metric: VertexAIEvaluationMetricType.SUMMARIZATION_HELPFULNESS,
displayName: 'Summarization helpfulness',
Expand Down Expand Up @@ -408,12 +393,10 @@ const SummarizationVerbositySchema = z.object({
});

function createSummarizationVerbosityEvaluator(
ai: Genkit,
factory: EvaluatorFactory,
metricSpec: any
): Action {
return factory.create(
ai,
{
metric: VertexAIEvaluationMetricType.SUMMARIZATION_VERBOSITY,
displayName: 'Summarization verbosity',
Expand Down
9 changes: 3 additions & 6 deletions js/plugins/vertexai/src/evaluation/evaluator_factory.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@
* limitations under the License.
*/

import { type Action, type Genkit, type z } from 'genkit';
import { type Action, type z } from 'genkit';
import type { BaseEvalDataPoint, Score } from 'genkit/evaluator';
import { evaluator } from 'genkit/plugin';
import { runInNewSpan } from 'genkit/tracing';
import type { GoogleAuth } from 'google-auth-library';
import { getGenkitClientHeader } from '../common/index.js';
Expand All @@ -29,7 +30,6 @@ export class EvaluatorFactory {
) {}

create<ResponseType extends z.ZodTypeAny>(
ai: Genkit,
config: {
metric: VertexAIEvaluationMetricType;
displayName: string;
Expand All @@ -39,7 +39,7 @@ export class EvaluatorFactory {
toRequest: (datapoint: BaseEvalDataPoint) => any,
responseHandler: (response: z.infer<ResponseType>) => Score
): Action {
return ai.defineEvaluator(
return evaluator(
{
name: `vertexai/${config.metric.toLocaleLowerCase()}`,
displayName: config.displayName,
Expand All @@ -48,7 +48,6 @@ export class EvaluatorFactory {
async (datapoint: BaseEvalDataPoint) => {
const responseSchema = config.responseSchema;
const response = await this.evaluateInstances(
ai,
toRequest(datapoint),
responseSchema
);
Expand All @@ -62,13 +61,11 @@ export class EvaluatorFactory {
}

async evaluateInstances<ResponseType extends z.ZodTypeAny>(
ai: Genkit,
partialRequest: any,
responseSchema: ResponseType
): Promise<z.infer<ResponseType>> {
const locationName = `projects/${this.projectId}/locations/${this.location}`;
return await runInNewSpan(
ai,
{
metadata: {
name: 'EvaluationService#evaluateInstances',
Expand Down
15 changes: 9 additions & 6 deletions js/plugins/vertexai/src/evaluation/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@
* limitations under the License.
*/

import type { Genkit } from 'genkit';
import { genkitPlugin, type GenkitPlugin } from 'genkit/plugin';
import { genkitPluginV2, type GenkitPluginV2 } from 'genkit/plugin';
import { getDerivedParams } from '../common/index.js';
import { vertexEvaluators } from './evaluation.js';
import type { PluginOptions } from './types.js';
Expand All @@ -25,10 +24,14 @@ export type { PluginOptions };
/**
* Add Google Cloud Vertex AI Rerankers API to Genkit.
*/
export function vertexAIEvaluation(options: PluginOptions): GenkitPlugin {
return genkitPlugin('vertexAIEvaluation', async (ai: Genkit) => {
const { projectId, location, authClient } = await getDerivedParams(options);
export function vertexAIEvaluation(options: PluginOptions): GenkitPluginV2 {
return genkitPluginV2({
name: 'vertexAIEvaluation',
init: async () => {
const { projectId, location, authClient } =
await getDerivedParams(options);

vertexEvaluators(ai, authClient, options.metrics, projectId, location);
return vertexEvaluators(authClient, options.metrics, projectId, location);
},
});
}
Loading