@@ -112,10 +112,20 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
        vae ([`AutoencoderKLWan`]):
            Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+        transformer_2 ([`WanTransformer3DModel`], *optional*):
+            Conditional Transformer to denoise the input latents during the low-noise stage. If provided, enables
+            two-stage denoising where `transformer` handles high-noise stages and `transformer_2` handles low-noise
+            stages. If not provided, only `transformer` is used.
+        boundary_ratio (`float`, *optional*, defaults to `None`):
+            Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising.
+            The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided,
+            `transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps <
+            boundary_timestep. If `None`, only `transformer` is used for the entire denoising process.
    """

-    model_cpu_offload_seq = "text_encoder->transformer->vae"
+    model_cpu_offload_seq = "text_encoder->transformer->transformer_2->vae"
    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+    _optional_components = ["transformer_2"]

    def __init__(
        self,
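The switching rule documented for `boundary_ratio` can be restated as a short, self-contained sketch. The helper name `select_stage` and the example numbers are illustrative only and are not part of this diff:

```python
from typing import Optional

# Minimal sketch of the boundary rule documented above (hypothetical helper, not in the PR).
def select_stage(t: float, boundary_ratio: Optional[float], num_train_timesteps: int = 1000) -> str:
    if boundary_ratio is None:
        return "transformer"  # single-transformer pipelines keep the old behavior
    boundary_timestep = boundary_ratio * num_train_timesteps
    # high-noise timesteps (t >= boundary) -> `transformer`; low-noise -> `transformer_2`
    return "transformer" if t >= boundary_timestep else "transformer_2"

# e.g. boundary_ratio=0.875 with 1000 training timesteps puts the switch at t=875.
```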
@@ -124,6 +134,9 @@ def __init__(
        transformer: WanTransformer3DModel,
        vae: AutoencoderKLWan,
        scheduler: FlowMatchEulerDiscreteScheduler,
+        transformer_2: Optional[WanTransformer3DModel] = None,
+        boundary_ratio: Optional[float] = None,
+        expand_timesteps: bool = False,  # Wan2.2 ti2v
    ):
        super().__init__()
@@ -133,10 +146,12 @@ def __init__(
            tokenizer=tokenizer,
            transformer=transformer,
            scheduler=scheduler,
+            transformer_2=transformer_2,
        )
-
-        self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
-        self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
+        self.register_to_config(boundary_ratio=boundary_ratio)
+        self.register_to_config(expand_timesteps=expand_timesteps)
+        self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
+        self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
        self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)

    def _get_t5_prompt_embeds(
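For orientation, the registered scale factors determine how pixel-space video dimensions map to latent dimensions later in the pipeline (e.g. when latents are prepared). A rough sketch using the fallback values 4 and 8 from the lines above; the sizes and the latent-frame formula are illustrative assumptions, not something this hunk defines:

```python
# Illustrative mapping from video size to latent size using the fallback scale factors above.
vae_scale_factor_temporal = 4
vae_scale_factor_spatial = 8

num_frames, height, width = 81, 480, 832
num_latent_frames = (num_frames - 1) // vae_scale_factor_temporal + 1  # 21
latent_height = height // vae_scale_factor_spatial                     # 60
latent_width = width // vae_scale_factor_spatial                       # 104
```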
@@ -270,6 +285,7 @@ def check_inputs(
        prompt_embeds=None,
        negative_prompt_embeds=None,
        callback_on_step_end_tensor_inputs=None,
+        guidance_scale_2=None,
    ):
        if height % 16 != 0 or width % 16 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
@@ -302,6 +318,9 @@ def check_inputs(
        ):
            raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")

+        if self.config.boundary_ratio is None and guidance_scale_2 is not None:
+            raise ValueError("`guidance_scale_2` is only supported when the pipeline's `boundary_ratio` is not None.")
+
    def prepare_latents(
        self,
        batch_size: int,
@@ -369,6 +388,7 @@ def __call__(
        num_frames: int = 81,
        num_inference_steps: int = 50,
        guidance_scale: float = 5.0,
+        guidance_scale_2: Optional[float] = None,
        num_videos_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
@@ -407,6 +427,10 @@ def __call__(
                of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
                `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
                the text `prompt`, usually at the expense of lower image quality.
+            guidance_scale_2 (`float`, *optional*, defaults to `None`):
+                Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's
+                `boundary_ratio` is not None, uses the same value as `guidance_scale`. Only used when `transformer_2`
+                and the pipeline's `boundary_ratio` are not None.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
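As a quick orientation for reviewers, a hedged usage sketch of the two-stage call path follows. The checkpoint id and the specific guidance values are assumptions for illustration, not something this diff prescribes:

```python
import torch
from diffusers import WanPipeline

# Assumed checkpoint id for illustration: any Wan 2.2 repo that ships `transformer`,
# `transformer_2`, and a `boundary_ratio` in its pipeline config would be used the same way.
pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.2-T2V-A14B-Diffusers", torch_dtype=torch.bfloat16)
pipe.to("cuda")

video = pipe(
    prompt="A corgi surfing a wave at sunset",
    num_frames=81,
    guidance_scale=4.0,    # used by `transformer` (high-noise stage)
    guidance_scale_2=3.0,  # used by `transformer_2` (low-noise stage)
).frames[0]
```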
@@ -461,6 +485,7 @@ def __call__(
            prompt_embeds,
            negative_prompt_embeds,
            callback_on_step_end_tensor_inputs,
+            guidance_scale_2,
        )

        if num_frames % self.vae_scale_factor_temporal != 1:
@@ -470,7 +495,11 @@ def __call__(
            num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
        num_frames = max(num_frames, 1)

+        if self.config.boundary_ratio is not None and guidance_scale_2 is None:
+            guidance_scale_2 = guidance_scale
+
        self._guidance_scale = guidance_scale
+        self._guidance_scale_2 = guidance_scale_2
        self._attention_kwargs = attention_kwargs
        self._current_timestep = None
        self._interrupt = False
@@ -520,21 +549,44 @@ def __call__(
            latents,
        )

+        mask = torch.ones(latents.shape, dtype=torch.float32, device=device)
+
        # 6. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        self._num_timesteps = len(timesteps)

+        if self.config.boundary_ratio is not None:
+            boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps
+        else:
+            boundary_timestep = None
+
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                self._current_timestep = t
-                latent_model_input = latents.to(transformer_dtype)
-                timestep = t.expand(latents.shape[0])

-                with self.transformer.cache_context("cond"):
-                    noise_pred = self.transformer(
+                if boundary_timestep is None or t >= boundary_timestep:
+                    # wan2.1 or high-noise stage in wan2.2
+                    current_model = self.transformer
+                    current_guidance_scale = guidance_scale
+                else:
+                    # low-noise stage in wan2.2
+                    current_model = self.transformer_2
+                    current_guidance_scale = guidance_scale_2
+
+                latent_model_input = latents.to(transformer_dtype)
+                if self.config.expand_timesteps:
+                    # seq_len: num_latent_frames * latent_height//2 * latent_width//2
+                    temp_ts = (mask[0][0][:, ::2, ::2] * t).flatten()
+                    # batch_size, seq_len
+                    timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
+                else:
+                    timestep = t.expand(latents.shape[0])
+
+                with current_model.cache_context("cond"):
+                    noise_pred = current_model(
                        hidden_states=latent_model_input,
                        timestep=timestep,
                        encoder_hidden_states=prompt_embeds,
@@ -543,15 +595,15 @@ def __call__(
                    )[0]

                if self.do_classifier_free_guidance:
-                    with self.transformer.cache_context("uncond"):
-                        noise_uncond = self.transformer(
+                    with current_model.cache_context("uncond"):
+                        noise_uncond = current_model(
                            hidden_states=latent_model_input,
                            timestep=timestep,
                            encoder_hidden_states=negative_prompt_embeds,
                            attention_kwargs=attention_kwargs,
                            return_dict=False,
                        )[0]
-                    noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
+                    noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
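To make the `expand_timesteps` branch above easier to verify, here is a standalone sketch of the per-token timestep expansion. The tensor sizes and channel count are assumptions for illustration, not values fixed by this PR:

```python
import torch

# Illustrative shapes only: (batch, latent channels, latent frames, latent height, latent width).
latents = torch.randn(1, 16, 21, 60, 104)
mask = torch.ones(latents.shape, dtype=torch.float32)
t = torch.tensor(875.0)

# one timestep value per 2x2 spatial patch of every latent frame
temp_ts = (mask[0][0][:, ::2, ::2] * t).flatten()             # shape: (21 * 30 * 52,)
timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)  # shape: (batch, seq_len)
print(timestep.shape)  # torch.Size([1, 32760])
```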