
Commit b614b48

feat(js/plugins/compat-oai): Support dynamic model resolution (#3117)
1 parent e2a30b1 commit b614b48

File tree: 6 files changed (+237, −61 lines)

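This commit threads two new optional hooks (`resolver`, `listActions`) through the compat-OAI plugin base and implements them in the OpenAI plugin, so a model no longer has to appear in the hard-coded `SUPPORTED_*` maps to be usable. A minimal sketch of what this enables, assuming the `@genkit-ai/compat-oai/openai` entry point implied by the package layout and an illustrative model id:

```ts
import { genkit } from 'genkit';
// Import path assumed from the repo's package layout.
import { openAI } from '@genkit-ai/compat-oai/openai';

const ai = genkit({ plugins: [openAI()] });

// 'gpt-4.1-mini' is illustrative: with dynamic resolution the plugin defines
// the model on first reference instead of requiring a predefined ref.
const { text } = await ai.generate({
  model: openAI.model('gpt-4.1-mini'),
  prompt: 'Say hello.',
});
console.log(text);
```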

js/plugins/compat-oai/src/index.ts

Lines changed: 32 additions & 7 deletions

```diff
@@ -14,13 +14,21 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-import { type Genkit } from 'genkit';
+import { ActionMetadata, type Genkit } from 'genkit';
 import { genkitPlugin } from 'genkit/plugin';
+import { ActionType } from 'genkit/registry';
 import { OpenAI, type ClientOptions } from 'openai';
 
 export interface PluginOptions extends Partial<ClientOptions> {
   name: string;
   initializer?: (ai: Genkit, client: OpenAI) => Promise<void>;
+  resolver?: (
+    ai: Genkit,
+    client: OpenAI,
+    actionType: ActionType,
+    actionName: string
+  ) => Promise<void>;
+  listActions?: (client: OpenAI) => Promise<ActionMetadata[]>;
 }
 
 /**
@@ -75,12 +83,29 @@ export interface PluginOptions extends Partial<ClientOptions> {
  * });
  * ```
  */
-export const openAICompatible = (options: PluginOptions) =>
-  genkitPlugin(options.name, async (ai: Genkit) => {
-    const client = new OpenAI(options);
-    if (options.initializer) {
-      await options.initializer(ai, client);
+export const openAICompatible = (options: PluginOptions) => {
+  const client = new OpenAI(options);
+  let listActionsCache;
+  return genkitPlugin(
+    options.name,
+    async (ai: Genkit) => {
+      if (options.initializer) {
+        await options.initializer(ai, client);
+      }
+    },
+    async (ai: Genkit, actionType: ActionType, actionName: string) => {
+      if (options.resolver) {
+        await options.resolver(ai, client, actionType, actionName);
+      }
+    },
+    async () => {
+      if (options.listActions) {
+        if (listActionsCache) return listActionsCache;
+        listActionsCache = await options.listActions(client);
+        return listActionsCache;
+      }
     }
-  });
+  );
+};
 
 export default openAICompatible;
```
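For plugin authors, the two hooks slot into `openAICompatible` alongside the existing `initializer`. A sketch of how a downstream OpenAI-compatible provider might wire them up — the provider name, base URL, and root-level re-exports of the helpers are assumptions, not part of this commit:

```ts
import { modelActionMetadata, type Genkit } from 'genkit';
import type { OpenAI } from 'openai';
// Paths assumed; these live in the compat-oai package sources.
import openAICompatible, { defineCompatOpenAIModel } from '@genkit-ai/compat-oai';

export const myProvider = openAICompatible({
  name: 'my-provider', // hypothetical provider
  baseURL: 'https://api.my-provider.example/v1',
  apiKey: process.env.MY_PROVIDER_API_KEY,
  // Invoked when an unregistered action is referenced by name.
  resolver: async (ai: Genkit, client: OpenAI, actionType, actionName) => {
    if (actionType === 'model') {
      defineCompatOpenAIModel({
        ai,
        name: `my-provider/${actionName}`,
        client,
      });
    }
  },
  // Invoked to enumerate available actions; the base plugin caches the result.
  listActions: async (client: OpenAI) =>
    (await client.models.list()).data.map((m) =>
      modelActionMetadata({ name: `my-provider/${m.id}` })
    ),
});
```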

js/plugins/compat-oai/src/openai/dalle.ts

Lines changed: 14 additions & 10 deletions

```diff
@@ -15,9 +15,9 @@
  * limitations under the License.
  */
 import { z } from 'genkit';
-import { modelRef } from 'genkit/model';
+import { ModelInfo, modelRef } from 'genkit/model';
 
-export const DallE3ConfigSchema = z.object({
+export const ImageGenerationConfigSchema = z.object({
   size: z.enum(['1024x1024', '1792x1024', '1024x1792']).optional(),
   style: z.enum(['vivid', 'natural']).optional(),
   user: z.string().optional(),
@@ -26,19 +26,23 @@ export const DallE3ConfigSchema = z.object({
   response_format: z.enum(['b64_json', 'url']).optional(),
 });
 
+export const IMAGE_GENERATION_MODEL_INFO: ModelInfo = {
+  supports: {
+    media: false,
+    output: ['media'],
+    multiturn: false,
+    systemRole: false,
+    tools: false,
+  },
+};
+
 export const dallE3 = modelRef({
   name: 'openai/dall-e-3',
   info: {
     label: 'OpenAI - DALL-E 3',
-    supports: {
-      media: false,
-      output: ['media'],
-      multiturn: false,
-      systemRole: false,
-      tools: false,
-    },
+    ...IMAGE_GENERATION_MODEL_INFO,
   },
-  configSchema: DallE3ConfigSchema,
+  configSchema: ImageGenerationConfigSchema,
 });
 
 export const SUPPORTED_IMAGE_MODELS = {
```
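Hoisting the `supports` block into `IMAGE_GENERATION_MODEL_INFO` lets the same capability object back both the static `dallE3` ref and the metadata emitted by `listActions`. As a sketch, another image model ref could reuse it the same way (the model id and label here are illustrative):

```ts
import { modelRef } from 'genkit/model';
import {
  IMAGE_GENERATION_MODEL_INFO,
  ImageGenerationConfigSchema,
} from './dalle.js';

// Hypothetical additional image model sharing the same capability surface.
export const gptImage1 = modelRef({
  name: 'openai/gpt-image-1',
  info: {
    label: 'OpenAI - GPT Image 1',
    ...IMAGE_GENERATION_MODEL_INFO,
  },
  configSchema: ImageGenerationConfigSchema,
});
```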

js/plugins/compat-oai/src/openai/index.ts

Lines changed: 164 additions & 7 deletions

```diff
@@ -15,6 +15,20 @@
  * limitations under the License.
  */
 
+import {
+  ActionMetadata,
+  embedderActionMetadata,
+  embedderRef,
+  EmbedderReference,
+  Genkit,
+  modelActionMetadata,
+  modelRef,
+  ModelReference,
+  z,
+} from 'genkit';
+import { GenkitPlugin } from 'genkit/plugin';
+import { ActionType } from 'genkit/registry';
+import OpenAI from 'openai';
 import {
   defineCompatOpenAISpeechModel,
   defineCompatOpenAITranscriptionModel,
@@ -23,14 +37,104 @@ import { defineCompatOpenAIEmbedder } from '../embedder.js';
 import { defineCompatOpenAIImageModel } from '../image.js';
 import openAICompatible, { PluginOptions } from '../index.js';
 import { defineCompatOpenAIModel } from '../model.js';
-import { SUPPORTED_IMAGE_MODELS } from './dalle.js';
-import { SUPPORTED_EMBEDDING_MODELS } from './embedder.js';
-import { SUPPORTED_GPT_MODELS } from './gpt.js';
-import { SUPPORTED_TTS_MODELS } from './tts.js';
-import { SUPPORTED_STT_MODELS } from './whisper.js';
+import {
+  IMAGE_GENERATION_MODEL_INFO,
+  ImageGenerationConfigSchema,
+  SUPPORTED_IMAGE_MODELS,
+} from './dalle.js';
+import {
+  SUPPORTED_EMBEDDING_MODELS,
+  TextEmbeddingConfigSchema,
+} from './embedder.js';
+import { ChatCompletionConfigSchema, SUPPORTED_GPT_MODELS } from './gpt.js';
+import {
+  SPEECH_MODEL_INFO,
+  SpeechConfigSchema,
+  SUPPORTED_TTS_MODELS,
+} from './tts.js';
+import { SUPPORTED_STT_MODELS, TranscriptionConfigSchema } from './whisper.js';
 
 export type OpenAIPluginOptions = Exclude<PluginOptions, 'name'>;
 
+const resolver = async (
+  ai: Genkit,
+  client: OpenAI,
+  actionType: ActionType,
+  actionName: string
+) => {
+  if (actionType === 'embedder') {
+    defineCompatOpenAIEmbedder({ ai, name: `openai/${actionName}`, client });
+  } else if (
+    actionName.includes('gpt-image-1') ||
+    actionName.includes('dall-e')
+  ) {
+    defineCompatOpenAIImageModel({ ai, name: `openai/${actionName}`, client });
+  } else if (actionName.includes('tts')) {
+    defineCompatOpenAISpeechModel({ ai, name: `openai/${actionName}`, client });
+  } else if (
+    actionName.includes('whisper') ||
+    actionName.includes('transcribe')
+  ) {
+    defineCompatOpenAITranscriptionModel({
+      ai,
+      name: `openai/${actionName}`,
+      client,
+    });
+  } else {
+    defineCompatOpenAIModel({
+      ai,
+      name: `openai/${actionName}`,
+      client,
+    });
+  }
+};
+
+const listActions = async (client: OpenAI): Promise<ActionMetadata[]> => {
+  return await client.models.list().then((response) =>
+    response.data
+      .filter((model) => model.object === 'model')
+      .map((model: OpenAI.Model) => {
+        if (model.id.includes('embedding')) {
+          return embedderActionMetadata({
+            name: `openai/${model.id}`,
+            configSchema: TextEmbeddingConfigSchema,
+            info: SUPPORTED_EMBEDDING_MODELS[model.id]?.info,
+          });
+        } else if (
+          model.id.includes('gpt-image-1') ||
+          model.id.includes('dall-e')
+        ) {
+          return modelActionMetadata({
+            name: `openai/${model.id}`,
+            configSchema: ImageGenerationConfigSchema,
+            info: IMAGE_GENERATION_MODEL_INFO,
+          });
+        } else if (model.id.includes('tts')) {
+          return modelActionMetadata({
+            name: `openai/${model.id}`,
+            configSchema: SpeechConfigSchema,
+            info: SPEECH_MODEL_INFO,
+          });
+        } else if (
+          model.id.includes('whisper') ||
+          model.id.includes('transcribe')
+        ) {
+          return modelActionMetadata({
+            name: `openai/${model.id}`,
+            configSchema: TranscriptionConfigSchema,
+            info: SPEECH_MODEL_INFO,
+          });
+        } else {
+          return modelActionMetadata({
+            name: `openai/${model.id}`,
+            configSchema: ChatCompletionConfigSchema,
+            info: SUPPORTED_GPT_MODELS[model.id]?.info,
+          });
+        }
+      })
+  );
+};
+
 /**
  * This module provides an interface to the OpenAI models through the Genkit
  * plugin system. It allows users to interact with various models by providing
@@ -60,8 +164,8 @@ export type OpenAIPluginOptions = Exclude<PluginOptions, 'name'>;
  * });
  * ```
  */
-export const openAI = (options?: OpenAIPluginOptions) =>
-  openAICompatible({
+export function openAIPlugin(options?: OpenAIPluginOptions): GenkitPlugin {
+  return openAICompatible({
     name: 'openai',
     ...options,
     initializer: async (ai, client) => {
@@ -101,6 +205,59 @@ export const openAI = (options?: OpenAIPluginOptions) =>
       })
     );
   },
+    resolver,
+    listActions,
+  });
+}
+
+export type OpenAIPlugin = {
+  (params?: OpenAIPluginOptions): GenkitPlugin;
+  model(name: string, config?: any): ModelReference<z.ZodTypeAny>;
+  embedder(name: string, config?: any): EmbedderReference<z.ZodTypeAny>;
+};
+
+export const openAI = openAIPlugin as OpenAIPlugin;
+// provide generic implementation for the model function overloads.
+(openAI as any).model = (
+  name: string,
+  config?: any
+): ModelReference<z.ZodTypeAny> => {
+  if (name.includes('gpt-image-1') || name.includes('dall-e')) {
+    return modelRef({
+      name: `openai/${name}`,
+      config,
+      configSchema: ImageGenerationConfigSchema,
+    });
+  }
+  if (name.includes('tts')) {
+    return modelRef({
+      name: `openai/${name}`,
+      config,
+      configSchema: SpeechConfigSchema,
+    });
+  }
+  if (name.includes('whisper') || name.includes('transcribe')) {
+    return modelRef({
+      name: `openai/${name}`,
+      config,
+      configSchema: TranscriptionConfigSchema,
+    });
+  }
+  return modelRef({
+    name: `openai/${name}`,
+    config,
+    configSchema: ChatCompletionConfigSchema,
+  });
+};
+openAI.embedder = (
+  name: string,
+  config?: any
+): EmbedderReference<z.ZodTypeAny> => {
+  return embedderRef({
+    name: `openai/${name}`,
+    config,
+    configSchema: TextEmbeddingConfigSchema,
   });
+};
 
 export default openAI;
```
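Casting `openAIPlugin` to the callable-with-properties `OpenAIPlugin` type is what lets `openAI` serve both as the plugin factory and as a builder of typed model and embedder references, with the config schema picked by name-based routing (image, tts, transcription, else chat). A hedged usage sketch; the import path, model ids, and config values are illustrative:

```ts
import { genkit } from 'genkit';
import { openAI } from '@genkit-ai/compat-oai/openai'; // path assumed

const ai = genkit({ plugins: [openAI()] });

// 'tts-1' routes to SpeechConfigSchema; an embedding id routes to
// TextEmbeddingConfigSchema; anything unmatched falls back to chat config.
const tts = openAI.model('tts-1', { voice: 'alloy' });
const embedder = openAI.embedder('text-embedding-3-small');

const embeddings = await ai.embed({
  embedder,
  content: 'dynamic model resolution',
});
```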

js/plugins/compat-oai/src/openai/tts.ts

Lines changed: 14 additions & 22 deletions

```diff
@@ -15,7 +15,7 @@
  * limitations under the License.
  */
 import { z } from 'genkit';
-import { modelRef } from 'genkit/model';
+import { ModelInfo, modelRef } from 'genkit/model';
 
 export const SpeechConfigSchema = z.object({
   voice: z
@@ -27,17 +27,21 @@ export const SpeechConfigSchema = z.object({
     .optional(),
 });
 
+export const SPEECH_MODEL_INFO: ModelInfo = {
+  supports: {
+    media: false,
+    output: ['media'],
+    multiturn: false,
+    systemRole: false,
+    tools: false,
+  },
+};
+
 export const tts1 = modelRef({
   name: 'openai/tts-1',
   info: {
     label: 'OpenAI - Text-to-speech 1',
-    supports: {
-      media: false,
-      output: ['media'],
-      multiturn: false,
-      systemRole: false,
-      tools: false,
-    },
+    ...SPEECH_MODEL_INFO,
   },
   configSchema: SpeechConfigSchema,
 });
@@ -46,13 +50,7 @@ export const tts1Hd = modelRef({
   name: 'openai/tts-1-hd',
   info: {
     label: 'OpenAI - Text-to-speech 1 HD',
-    supports: {
-      media: false,
-      output: ['media'],
-      multiturn: false,
-      systemRole: false,
-      tools: false,
-    },
+    ...SPEECH_MODEL_INFO,
   },
   configSchema: SpeechConfigSchema,
 });
@@ -61,13 +59,7 @@ export const gpt4oMiniTts = modelRef({
   name: 'openai/gpt-4o-mini-tts',
   info: {
     label: 'OpenAI - GPT-4o Mini Text-to-speech',
-    supports: {
-      media: false,
-      output: ['media'],
-      multiturn: false,
-      systemRole: false,
-      tools: false,
-    },
+    ...SPEECH_MODEL_INFO,
  },
   configSchema: SpeechConfigSchema.omit({ speed: true }),
 });
```
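One detail worth noting: `gpt-4o-mini-tts` reuses the shared schema but drops the `speed` field via zod's `.omit`, so unsupported options disappear from the inferred config type. A minimal illustration of the pattern, with a reduced stand-in schema (the real `SpeechConfigSchema` has more fields):

```ts
import { z } from 'genkit';

// Stand-in for the shared schema; illustrative only.
const SpeechConfigSchema = z.object({
  voice: z.enum(['alloy', 'echo']).optional(),
  speed: z.number().optional(),
});

const MiniTtsConfigSchema = SpeechConfigSchema.omit({ speed: true });

// The inferred type no longer includes `speed`.
type MiniTtsConfig = z.infer<typeof MiniTtsConfigSchema>; // { voice?: 'alloy' | 'echo' }
```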
