11import type { TextToImageArgs } from "../tasks/cv/textToImage.js" ;
22import type { ImageToImageArgs } from "../tasks/cv/imageToImage.js" ;
33import type { TextToVideoArgs } from "../tasks/cv/textToVideo.js" ;
4+ import type { ImageToVideoArgs } from "../tasks/cv/imageToVideo.js" ;
45import type { BodyParams , RequestArgs , UrlParams } from "../types.js" ;
56import { delay } from "../utils/delay.js" ;
67import { omit } from "../utils/omit.js" ;
78import { base64FromBytes } from "../utils/base64FromBytes.js" ;
8- import type { TextToImageTaskHelper , TextToVideoTaskHelper , ImageToImageTaskHelper } from "./providerHelper.js" ;
9+ import type {
10+ TextToImageTaskHelper ,
11+ TextToVideoTaskHelper ,
12+ ImageToImageTaskHelper ,
13+ ImageToVideoTaskHelper ,
14+ } from "./providerHelper.js" ;
915import { TaskProviderHelper } from "./providerHelper.js" ;
1016import {
1117 InferenceClientInputError ,
@@ -72,7 +78,9 @@ abstract class WavespeedAITask extends TaskProviderHelper {
7278 return `/api/v3/${ params . model } ` ;
7379 }
7480
75- preparePayload ( params : BodyParams < ImageToImageArgs | TextToImageArgs | TextToVideoArgs > ) : Record < string , unknown > {
81+ preparePayload (
82+ params : BodyParams < ImageToImageArgs | TextToImageArgs | TextToVideoArgs | ImageToVideoArgs >
83+ ) : Record < string , unknown > {
7684 const payload : Record < string , unknown > = {
7785 ...omit ( params . args , [ "inputs" , "parameters" ] ) ,
7886 ...params . args . parameters ,
@@ -95,11 +103,17 @@ abstract class WavespeedAITask extends TaskProviderHelper {
95103 url ?: string ,
96104 headers ?: Record < string , string >
97105 ) : Promise < Blob > {
98- if ( ! headers ) {
106+ if ( ! url || ! headers ) {
99107 throw new InferenceClientInputError ( "Headers are required for WaveSpeed AI API calls" ) ;
100108 }
101109
102- const resultUrl = response . data . urls . get ;
110+ const parsedUrl = new URL ( url ) ;
111+ const resultPath = new URL ( response . data . urls . get ) . pathname ;
112+ /// override the base url to use the router.huggingface.co if going through huggingface router
113+ const baseUrl = `${ parsedUrl . protocol } //${ parsedUrl . host } ${
114+ parsedUrl . host === "router.huggingface.co" ? "/wavespeed" : ""
115+ } `;
116+ const resultUrl = `${ baseUrl } ${ resultPath } ` ;
103117
104118 // Poll for results until completion
105119 while ( true ) {
@@ -183,3 +197,19 @@ export class WavespeedAIImageToImageTask extends WavespeedAITask implements Imag
183197 } ;
184198 }
185199}
200+
201+ export class WavespeedAIImageToVideoTask extends WavespeedAITask implements ImageToVideoTaskHelper {
202+ constructor ( ) {
203+ super ( WAVESPEEDAI_API_BASE_URL ) ;
204+ }
205+
206+ async preparePayloadAsync ( args : ImageToVideoArgs ) : Promise < RequestArgs > {
207+ return {
208+ ...args ,
209+ inputs : args . parameters ?. prompt ,
210+ image : base64FromBytes (
211+ new Uint8Array ( args . inputs instanceof ArrayBuffer ? args . inputs : await ( args . inputs as Blob ) . arrayBuffer ( ) )
212+ ) ,
213+ } ;
214+ }
215+ }
0 commit comments