Skip to content

Commit 3114f6a

Browse files
authored
[Modular] Changes for using WAN I2V (#12959)
* initial * add layers
1 parent 9d68742 commit 3114f6a

File tree

2 files changed

+32
-10
lines changed

2 files changed

+32
-10
lines changed

src/diffusers/modular_pipelines/mellon_node_utils.py

Lines changed: 22 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -68,6 +68,10 @@ def latents(cls, display: str = "input") -> "MellonParam":
6868
def image_latents(cls, display: str = "input") -> "MellonParam":
6969
return cls(name="image_latents", label="Image Latents", type="latents", display=display)
7070

71+
@classmethod
72+
def first_frame_latents(cls, display: str = "input") -> "MellonParam":
73+
return cls(name="first_frame_latents", label="First Frame Latents", type="latents", display=display)
74+
7175
@classmethod
7276
def image_latents_with_strength(cls) -> "MellonParam":
7377
return cls(
@@ -89,6 +93,10 @@ def latents_preview(cls) -> "MellonParam":
8993
def embeddings(cls, display: str = "output") -> "MellonParam":
9094
return cls(name="embeddings", label="Text Embeddings", type="embeddings", display=display)
9195

96+
@classmethod
97+
def image_embeds(cls, display: str = "output") -> "MellonParam":
98+
return cls(name="image_embeds", label="Image Embeddings", type="image_embeds", display=display)
99+
92100
@classmethod
93101
def controlnet_conditioning_scale(cls, default: float = 0.5) -> "MellonParam":
94102
return cls(
@@ -172,6 +180,10 @@ def num_inference_steps(cls, default: int = 25) -> "MellonParam":
172180
def num_frames(cls, default: int = 81) -> "MellonParam":
173181
return cls(name="num_frames", label="Frames", type="int", default=default, min=1, max=480, display="slider")
174182

183+
@classmethod
184+
def layers(cls, default: int = 4) -> "MellonParam":
185+
return cls(name="layers", label="Layers", type="int", default=default, min=1, max=10, display="slider")
186+
175187
@classmethod
176188
def videos(cls) -> "MellonParam":
177189
return cls(name="videos", label="Videos", type="video", display="output")
@@ -186,6 +198,16 @@ def vae(cls) -> "MellonParam":
186198
"""
187199
return cls(name="vae", label="VAE", type="diffusers_auto_model", display="input")
188200

201+
@classmethod
202+
def image_encoder(cls) -> "MellonParam":
203+
"""
204+
Image Encoder model info dict.
205+
206+
Contains keys like 'model_id', 'repo_id', 'execution_device' etc. Use components.get_one(model_id) to retrieve
207+
the actual model.
208+
"""
209+
return cls(name="image_encoder", label="Image Encoder", type="diffusers_auto_model", display="input")
210+
189211
@classmethod
190212
def unet(cls) -> "MellonParam":
191213
"""

src/diffusers/modular_pipelines/wan/modular_blocks.py

Lines changed: 10 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -84,7 +84,7 @@ def description(self):
8484
class WanImage2VideoVaeImageEncoderStep(SequentialPipelineBlocks):
8585
model_name = "wan"
8686
block_classes = [WanImageResizeStep, WanVaeImageEncoderStep]
87-
block_names = ["image_resize", "vae_image_encoder"]
87+
block_names = ["image_resize", "vae_encoder"]
8888

8989
@property
9090
def description(self):
@@ -142,7 +142,7 @@ def description(self):
142142
class WanFLF2VVaeImageEncoderStep(SequentialPipelineBlocks):
143143
model_name = "wan"
144144
block_classes = [WanImageResizeStep, WanImageCropResizeStep, WanFirstLastFrameVaeImageEncoderStep]
145-
block_names = ["image_resize", "last_image_resize", "vae_image_encoder"]
145+
block_names = ["image_resize", "last_image_resize", "vae_encoder"]
146146

147147
@property
148148
def description(self):
@@ -203,7 +203,7 @@ def description(self):
203203
## vae encoder
204204
class WanAutoVaeImageEncoderStep(AutoPipelineBlocks):
205205
block_classes = [WanFLF2VVaeImageEncoderStep, WanImage2VideoVaeImageEncoderStep]
206-
block_names = ["flf2v_vae_image_encoder", "image2video_vae_image_encoder"]
206+
block_names = ["flf2v_vae_encoder", "image2video_vae_encoder"]
207207
block_trigger_inputs = ["last_image", "image"]
208208

209209
@property
@@ -251,7 +251,7 @@ class WanAutoBlocks(SequentialPipelineBlocks):
251251
block_names = [
252252
"text_encoder",
253253
"image_encoder",
254-
"vae_image_encoder",
254+
"vae_encoder",
255255
"denoise",
256256
"decode",
257257
]
@@ -353,7 +353,7 @@ class Wan22AutoBlocks(SequentialPipelineBlocks):
353353
]
354354
block_names = [
355355
"text_encoder",
356-
"vae_image_encoder",
356+
"vae_encoder",
357357
"denoise",
358358
"decode",
359359
]
@@ -384,7 +384,7 @@ def description(self):
384384
[
385385
("image_resize", WanImageResizeStep),
386386
("image_encoder", WanImage2VideoImageEncoderStep),
387-
("vae_image_encoder", WanImage2VideoVaeImageEncoderStep),
387+
("vae_encoder", WanImage2VideoVaeImageEncoderStep),
388388
("input", WanTextInputStep),
389389
("additional_inputs", WanAdditionalInputsStep(image_latent_inputs=["first_frame_latents"])),
390390
("set_timesteps", WanSetTimestepsStep),
@@ -401,7 +401,7 @@ def description(self):
401401
("image_resize", WanImageResizeStep),
402402
("last_image_resize", WanImageCropResizeStep),
403403
("image_encoder", WanFLF2VImageEncoderStep),
404-
("vae_image_encoder", WanFLF2VVaeImageEncoderStep),
404+
("vae_encoder", WanFLF2VVaeImageEncoderStep),
405405
("input", WanTextInputStep),
406406
("additional_inputs", WanAdditionalInputsStep(image_latent_inputs=["first_last_frame_latents"])),
407407
("set_timesteps", WanSetTimestepsStep),
@@ -416,7 +416,7 @@ def description(self):
416416
[
417417
("text_encoder", WanTextEncoderStep),
418418
("image_encoder", WanAutoImageEncoderStep),
419-
("vae_image_encoder", WanAutoVaeImageEncoderStep),
419+
("vae_encoder", WanAutoVaeImageEncoderStep),
420420
("denoise", WanAutoDenoiseStep),
421421
("decode", WanImageVaeDecoderStep),
422422
]
@@ -438,7 +438,7 @@ def description(self):
438438
IMAGE2VIDEO_BLOCKS_WAN22 = InsertableDict(
439439
[
440440
("image_resize", WanImageResizeStep),
441-
("vae_image_encoder", WanImage2VideoVaeImageEncoderStep),
441+
("vae_encoder", WanImage2VideoVaeImageEncoderStep),
442442
("input", WanTextInputStep),
443443
("set_timesteps", WanSetTimestepsStep),
444444
("prepare_latents", WanPrepareLatentsStep),
@@ -450,7 +450,7 @@ def description(self):
450450
AUTO_BLOCKS_WAN22 = InsertableDict(
451451
[
452452
("text_encoder", WanTextEncoderStep),
453-
("vae_image_encoder", WanAutoVaeImageEncoderStep),
453+
("vae_encoder", WanAutoVaeImageEncoderStep),
454454
("denoise", Wan22AutoDenoiseStep),
455455
("decode", WanImageVaeDecoderStep),
456456
]

0 commit comments

Comments
 (0)