  * limitations under the License.
  */
 
+import {
+  ActionMetadata,
+  embedderActionMetadata,
+  embedderRef,
+  EmbedderReference,
+  Genkit,
+  modelActionMetadata,
+  modelRef,
+  ModelReference,
+  z,
+} from 'genkit';
+import { GenkitPlugin } from 'genkit/plugin';
+import { ActionType } from 'genkit/registry';
+import OpenAI from 'openai';
 import {
   defineCompatOpenAISpeechModel,
   defineCompatOpenAITranscriptionModel,
@@ -23,14 +37,104 @@ import { defineCompatOpenAIEmbedder } from '../embedder.js';
 import { defineCompatOpenAIImageModel } from '../image.js';
 import openAICompatible, { PluginOptions } from '../index.js';
 import { defineCompatOpenAIModel } from '../model.js';
-import { SUPPORTED_IMAGE_MODELS } from './dalle.js';
-import { SUPPORTED_EMBEDDING_MODELS } from './embedder.js';
-import { SUPPORTED_GPT_MODELS } from './gpt.js';
-import { SUPPORTED_TTS_MODELS } from './tts.js';
-import { SUPPORTED_STT_MODELS } from './whisper.js';
+import {
+  IMAGE_GENERATION_MODEL_INFO,
+  ImageGenerationConfigSchema,
+  SUPPORTED_IMAGE_MODELS,
+} from './dalle.js';
+import {
+  SUPPORTED_EMBEDDING_MODELS,
+  TextEmbeddingConfigSchema,
+} from './embedder.js';
+import { ChatCompletionConfigSchema, SUPPORTED_GPT_MODELS } from './gpt.js';
+import {
+  SPEECH_MODEL_INFO,
+  SpeechConfigSchema,
+  SUPPORTED_TTS_MODELS,
+} from './tts.js';
+import { SUPPORTED_STT_MODELS, TranscriptionConfigSchema } from './whisper.js';
 
 export type OpenAIPluginOptions = Exclude<PluginOptions, 'name'>;
 
+const resolver = async (
+  ai: Genkit,
+  client: OpenAI,
+  actionType: ActionType,
+  actionName: string
+) => {
+  if (actionType === 'embedder') {
+    defineCompatOpenAIEmbedder({ ai, name: `openai/${actionName}`, client });
+  } else if (
+    actionName.includes('gpt-image-1') ||
+    actionName.includes('dall-e')
+  ) {
+    defineCompatOpenAIImageModel({ ai, name: `openai/${actionName}`, client });
+  } else if (actionName.includes('tts')) {
+    defineCompatOpenAISpeechModel({ ai, name: `openai/${actionName}`, client });
+  } else if (
+    actionName.includes('whisper') ||
+    actionName.includes('transcribe')
+  ) {
+    defineCompatOpenAITranscriptionModel({
+      ai,
+      name: `openai/${actionName}`,
+      client,
+    });
+  } else {
+    defineCompatOpenAIModel({
+      ai,
+      name: `openai/${actionName}`,
+      client,
+    });
+  }
+};
+
+const listActions = async (client: OpenAI): Promise<ActionMetadata[]> => {
+  return await client.models.list().then((response) =>
+    response.data
+      .filter((model) => model.object === 'model')
+      .map((model: OpenAI.Model) => {
+        if (model.id.includes('embedding')) {
+          return embedderActionMetadata({
+            name: `openai/${model.id}`,
+            configSchema: TextEmbeddingConfigSchema,
+            info: SUPPORTED_EMBEDDING_MODELS[model.id]?.info,
+          });
+        } else if (
+          model.id.includes('gpt-image-1') ||
+          model.id.includes('dall-e')
+        ) {
+          return modelActionMetadata({
+            name: `openai/${model.id}`,
+            configSchema: ImageGenerationConfigSchema,
+            info: IMAGE_GENERATION_MODEL_INFO,
+          });
+        } else if (model.id.includes('tts')) {
+          return modelActionMetadata({
+            name: `openai/${model.id}`,
+            configSchema: SpeechConfigSchema,
+            info: SPEECH_MODEL_INFO,
+          });
+        } else if (
+          model.id.includes('whisper') ||
+          model.id.includes('transcribe')
+        ) {
+          return modelActionMetadata({
+            name: `openai/${model.id}`,
+            configSchema: TranscriptionConfigSchema,
+            info: SPEECH_MODEL_INFO,
+          });
+        } else {
+          return modelActionMetadata({
+            name: `openai/${model.id}`,
+            configSchema: ChatCompletionConfigSchema,
+            info: SUPPORTED_GPT_MODELS[model.id]?.info,
+          });
+        }
+      })
+  );
+};
+
 /**
  * This module provides an interface to the OpenAI models through the Genkit
  * plugin system. It allows users to interact with various models by providing
@@ -60,8 +164,8 @@ export type OpenAIPluginOptions = Exclude<PluginOptions, 'name'>;
  * });
  * ```
  */
-export const openAI = (options?: OpenAIPluginOptions) =>
-  openAICompatible({
+export function openAIPlugin(options?: OpenAIPluginOptions): GenkitPlugin {
+  return openAICompatible({
     name: 'openai',
     ...options,
     initializer: async (ai, client) => {
@@ -101,6 +205,59 @@ export const openAI = (options?: OpenAIPluginOptions) =>
        })
      );
    },
+    resolver,
+    listActions,
+  });
+}
+
+export type OpenAIPlugin = {
+  (params?: OpenAIPluginOptions): GenkitPlugin;
+  model(name: string, config?: any): ModelReference<z.ZodTypeAny>;
+  embedder(name: string, config?: any): EmbedderReference<z.ZodTypeAny>;
+};
+
+export const openAI = openAIPlugin as OpenAIPlugin;
+// provide generic implementation for the model function overloads.
+(openAI as any).model = (
+  name: string,
+  config?: any
+): ModelReference<z.ZodTypeAny> => {
+  if (name.includes('gpt-image-1') || name.includes('dall-e')) {
+    return modelRef({
+      name: `openai/${name}`,
+      config,
+      configSchema: ImageGenerationConfigSchema,
+    });
+  }
+  if (name.includes('tts')) {
+    return modelRef({
+      name: `openai/${name}`,
+      config,
+      configSchema: SpeechConfigSchema,
+    });
+  }
+  if (name.includes('whisper') || name.includes('transcribe')) {
+    return modelRef({
+      name: `openai/${name}`,
+      config,
+      configSchema: TranscriptionConfigSchema,
+    });
+  }
+  return modelRef({
+    name: `openai/${name}`,
+    config,
+    configSchema: ChatCompletionConfigSchema,
+  });
+};
+openAI.embedder = (
+  name: string,
+  config?: any
+): EmbedderReference<z.ZodTypeAny> => {
+  return embedderRef({
+    name: `openai/${name}`,
+    config,
+    configSchema: TextEmbeddingConfigSchema,
   });
+};
 
 export default openAI;
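
A quick usage sketch for reviewers: after this change, `openAI` is callable as the plugin factory and also carries typed `model()`/`embedder()` reference helpers, while `resolver` and `listActions` let model names that were never pre-registered resolve lazily at request time. The snippet below is a minimal sketch, not part of the diff; the package import path, model ids, and config values are illustrative assumptions.

```ts
import { genkit } from 'genkit';
// Assumed entry point for this package; adjust to the published export path.
import { openAI } from '@genkit-ai/compat-oai/openai';

const ai = genkit({ plugins: [openAI()] });

async function main() {
  // openAI.model() picks the config schema from the name: 'tts' ->
  // SpeechConfigSchema, 'whisper'/'transcribe' -> TranscriptionConfigSchema,
  // 'gpt-image-1'/'dall-e' -> ImageGenerationConfigSchema, anything else ->
  // ChatCompletionConfigSchema.
  const { text } = await ai.generate({
    model: openAI.model('gpt-4o', { temperature: 0.7 }), // assumed model id
    prompt: 'Say hello.',
  });

  // openAI.embedder() always attaches TextEmbeddingConfigSchema.
  const [embedding] = await ai.embed({
    embedder: openAI.embedder('text-embedding-3-small'), // assumed model id
    content: text,
  });
  console.log(text, embedding.embedding.length);
}

main().catch(console.error);
```

Because `listActions` filters `client.models.list()` and falls back to `SUPPORTED_GPT_MODELS[model.id]?.info` (undefined for unlisted ids), newly released chat models should surface with chat-completion config by default, and `resolver` defines them on first use.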