diff --git a/.changeset/fireworks-image-model-options.md b/.changeset/fireworks-image-model-options.md new file mode 100644 index 000000000000..97e79e8feff9 --- /dev/null +++ b/.changeset/fireworks-image-model-options.md @@ -0,0 +1,5 @@ +--- +'@ai-sdk/fireworks': patch +--- + +Add `FireworksImageModelOptions` type and validate image generation provider options using Zod schema with `parseProviderOptions`. diff --git a/examples/ai-functions/src/generate-image/fireworks.ts b/examples/ai-functions/src/generate-image/fireworks.ts index bc609ed1c7e3..e16c034e2fe2 100644 --- a/examples/ai-functions/src/generate-image/fireworks.ts +++ b/examples/ai-functions/src/generate-image/fireworks.ts @@ -1,4 +1,5 @@ import { fireworks } from '@ai-sdk/fireworks'; +import type { FireworksImageModelOptions } from '@ai-sdk/fireworks'; import { generateImage } from 'ai'; import { presentImages } from '../lib/present-image'; import { run } from '../lib/run'; @@ -17,7 +18,7 @@ run(async () => { // https://fireworks.ai/models/fireworks/stable-diffusion-xl-1024-v1-0/playground cfg_scale: 10, steps: 30, - }, + } satisfies FireworksImageModelOptions, }, }); diff --git a/packages/fireworks/src/fireworks-image-model.test.ts b/packages/fireworks/src/fireworks-image-model.test.ts index a37909fd9599..3536a0ce3abe 100644 --- a/packages/fireworks/src/fireworks-image-model.test.ts +++ b/packages/fireworks/src/fireworks-image-model.test.ts @@ -64,7 +64,7 @@ describe('FireworksImageModel', () => { size: undefined, aspectRatio: '16:9', seed: 42, - providerOptions: { fireworks: { additional_param: 'value' } }, + providerOptions: { fireworks: { cfg_scale: 10 } }, }); expect(await server.calls[0].requestBodyJson).toStrictEqual({ @@ -72,7 +72,7 @@ describe('FireworksImageModel', () => { aspect_ratio: '16:9', seed: 42, samples: 1, - additional_param: 'value', + cfg_scale: 10, }); }); diff --git a/packages/fireworks/src/fireworks-image-model.ts b/packages/fireworks/src/fireworks-image-model.ts index a798fc1b858e..d115ddda0b5e 100644 --- a/packages/fireworks/src/fireworks-image-model.ts +++ b/packages/fireworks/src/fireworks-image-model.ts @@ -5,9 +5,13 @@ import { createBinaryResponseHandler, createStatusCodeErrorResponseHandler, FetchFunction, + parseProviderOptions, postJsonToApi, } from '@ai-sdk/provider-utils'; -import { FireworksImageModelId } from './fireworks-image-options'; +import { + FireworksImageModelId, + fireworksImageModelOptions, +} from './fireworks-image-options'; interface FireworksImageModelBackendConfig { urlFormat: 'workflows' | 'workflows_edit' | 'image_generation'; @@ -163,20 +167,52 @@ export class FireworksImageModel implements ImageModelV3 { }); } + const fireworksOptions = await parseProviderOptions({ + provider: 'fireworks', + providerOptions, + schema: fireworksImageModelOptions, + }); + const splitSize = size?.split('x'); const currentDate = this.config._internal?.currentDate?.() ?? new Date(); + + const body: Record = { + prompt, + aspect_ratio: aspectRatio, + seed, + samples: n, + ...(inputImage && { input_image: inputImage }), + ...(splitSize && { width: splitSize[0], height: splitSize[1] }), + }; + if (fireworksOptions?.cfg_scale != null) { + body.cfg_scale = fireworksOptions.cfg_scale; + } + if (fireworksOptions?.steps != null) { + body.steps = fireworksOptions.steps; + } + if (fireworksOptions?.negative_prompt != null) { + body.negative_prompt = fireworksOptions.negative_prompt; + } + if (fireworksOptions?.strength != null) { + body.strength = fireworksOptions.strength; + } + if (fireworksOptions?.scheduler != null) { + body.scheduler = fireworksOptions.scheduler; + } + if (fireworksOptions?.safety_checker != null) { + body.safety_checker = fireworksOptions.safety_checker; + } + if (fireworksOptions?.output_format != null) { + body.output_format = fireworksOptions.output_format; + } + if (fireworksOptions?.safety_tolerance != null) { + body.safety_tolerance = fireworksOptions.safety_tolerance; + } + const { value: response, responseHeaders } = await postJsonToApi({ url: getUrlForModel(this.config.baseURL, this.modelId, hasInputImage), headers: combineHeaders(this.config.headers(), headers), - body: { - prompt, - aspect_ratio: aspectRatio, - seed, - samples: n, - ...(inputImage && { input_image: inputImage }), - ...(splitSize && { width: splitSize[0], height: splitSize[1] }), - ...(providerOptions.fireworks ?? {}), - }, + body, failedResponseHandler: createStatusCodeErrorResponseHandler(), successfulResponseHandler: createBinaryResponseHandler(), abortSignal, diff --git a/packages/fireworks/src/fireworks-image-options.ts b/packages/fireworks/src/fireworks-image-options.ts index 47355d91e010..b076910d6396 100644 --- a/packages/fireworks/src/fireworks-image-options.ts +++ b/packages/fireworks/src/fireworks-image-options.ts @@ -1,3 +1,5 @@ +import { z } from 'zod/v4'; + // https://fireworks.ai/models?type=image export type FireworksImageModelId = | 'accounts/fireworks/models/flux-1-dev-fp8' @@ -10,3 +12,19 @@ export type FireworksImageModelId = | 'accounts/fireworks/models/SSD-1B' | 'accounts/fireworks/models/stable-diffusion-xl-1024-v1-0' | (string & {}); + +// https://docs.fireworks.ai/api-reference/post-imagegeneration +export const fireworksImageModelOptions = z.object({ + cfg_scale: z.number().optional(), + steps: z.number().optional(), + negative_prompt: z.string().optional(), + strength: z.number().optional(), + scheduler: z.string().optional(), + safety_checker: z.boolean().optional(), + output_format: z.string().optional(), + safety_tolerance: z.number().optional(), +}); + +export type FireworksImageModelOptions = z.infer< + typeof fireworksImageModelOptions +>; diff --git a/packages/fireworks/src/index.ts b/packages/fireworks/src/index.ts index 1468d8113dc3..e14c26044c0d 100644 --- a/packages/fireworks/src/index.ts +++ b/packages/fireworks/src/index.ts @@ -10,7 +10,10 @@ export type { FireworksEmbeddingModelOptions as FireworksEmbeddingProviderOptions, } from './fireworks-embedding-options'; export { FireworksImageModel } from './fireworks-image-model'; -export type { FireworksImageModelId } from './fireworks-image-options'; +export type { + FireworksImageModelId, + FireworksImageModelOptions, +} from './fireworks-image-options'; export { fireworks, createFireworks } from './fireworks-provider'; export type { FireworksProvider,