
Commit a6d9f6a

yiyixuxu and bghira authored
[WIP] Wan2.2 (#12004)
* support wan 2.2 i2v
* add t2v + vae2.2
* add conversion script for vae 2.2
* add
* add 5b t2v
* conversion script
* refactor out rearrange
* remove a copied from in skyreels
* Apply suggestions from code review
  Co-authored-by: bagheera <[email protected]>
* Update src/diffusers/models/transformers/transformer_wan.py
* fix fast tests
* style

---------

Co-authored-by: bagheera <[email protected]>
1 parent 2841504 commit a6d9f6a

File tree

9 files changed: +1049 additions, -78 deletions


scripts/convert_wan_to_diffusers.py

Lines changed: 419 additions & 5 deletions
Large diffs are not rendered by default.

src/diffusers/models/autoencoders/autoencoder_kl_wan.py

Lines changed: 399 additions & 41 deletions
Large diffs are not rendered by default.

src/diffusers/models/transformers/transformer_wan.py

Lines changed: 42 additions & 6 deletions
@@ -170,8 +170,11 @@ def forward(
         timestep: torch.Tensor,
         encoder_hidden_states: torch.Tensor,
         encoder_hidden_states_image: Optional[torch.Tensor] = None,
+        timestep_seq_len: Optional[int] = None,
     ):
         timestep = self.timesteps_proj(timestep)
+        if timestep_seq_len is not None:
+            timestep = timestep.unflatten(0, (1, timestep_seq_len))
 
         time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
         if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
@@ -309,9 +312,23 @@ def forward(
         temb: torch.Tensor,
         rotary_emb: torch.Tensor,
     ) -> torch.Tensor:
-        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
-            self.scale_shift_table + temb.float()
-        ).chunk(6, dim=1)
+        if temb.ndim == 4:
+            # temb: batch_size, seq_len, 6, inner_dim (wan2.2 ti2v)
+            shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
+                self.scale_shift_table.unsqueeze(0) + temb.float()
+            ).chunk(6, dim=2)
+            # batch_size, seq_len, 1, inner_dim
+            shift_msa = shift_msa.squeeze(2)
+            scale_msa = scale_msa.squeeze(2)
+            gate_msa = gate_msa.squeeze(2)
+            c_shift_msa = c_shift_msa.squeeze(2)
+            c_scale_msa = c_scale_msa.squeeze(2)
+            c_gate_msa = c_gate_msa.squeeze(2)
+        else:
+            # temb: batch_size, 6, inner_dim (wan2.1/wan2.2 14B)
+            shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
+                self.scale_shift_table + temb.float()
+            ).chunk(6, dim=1)
 
         # 1. Self-attention
         norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
@@ -469,10 +486,22 @@ def forward(
         hidden_states = self.patch_embedding(hidden_states)
         hidden_states = hidden_states.flatten(2).transpose(1, 2)
 
+        # timestep shape: batch_size, or batch_size, seq_len (wan 2.2 ti2v)
+        if timestep.ndim == 2:
+            ts_seq_len = timestep.shape[1]
+            timestep = timestep.flatten()  # batch_size * seq_len
+        else:
+            ts_seq_len = None
+
         temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
-            timestep, encoder_hidden_states, encoder_hidden_states_image
+            timestep, encoder_hidden_states, encoder_hidden_states_image, timestep_seq_len=ts_seq_len
         )
-        timestep_proj = timestep_proj.unflatten(1, (6, -1))
+        if ts_seq_len is not None:
+            # batch_size, seq_len, 6, inner_dim
+            timestep_proj = timestep_proj.unflatten(2, (6, -1))
+        else:
+            # batch_size, 6, inner_dim
+            timestep_proj = timestep_proj.unflatten(1, (6, -1))
 
         if encoder_hidden_states_image is not None:
             encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
@@ -488,7 +517,14 @@ def forward(
                 hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
 
         # 5. Output norm, projection & unpatchify
-        shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
+        if temb.ndim == 3:
+            # batch_size, seq_len, inner_dim (wan 2.2 ti2v)
+            shift, scale = (self.scale_shift_table.unsqueeze(0) + temb.unsqueeze(2)).chunk(2, dim=2)
+            shift = shift.squeeze(2)
+            scale = scale.squeeze(2)
+        else:
+            # batch_size, inner_dim
+            shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
 
         # Move the shift and scale tensors to the same device as hidden_states.
         # When using multi-GPU inference via accelerate these will be on the
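The shape handling above comes down to one change: Wan 2.2 ti2v conditions every latent token on its own timestep, so `temb` gains a sequence dimension and the learned `scale_shift_table` has to broadcast across it. A minimal standalone sketch of the two broadcasting paths mirrored from this diff, using dummy tensors rather than the actual module:

import torch

batch_size, seq_len, inner_dim = 2, 16, 8
# stand-in for WanTransformerBlock.scale_shift_table, a learned parameter of shape (1, 6, dim)
scale_shift_table = torch.randn(1, 6, inner_dim)

# wan2.1 / wan2.2 14B: one timestep per sample -> temb is (batch_size, 6, inner_dim)
temb = torch.randn(batch_size, 6, inner_dim)
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
    scale_shift_table + temb.float()
).chunk(6, dim=1)
print(shift_msa.shape)  # torch.Size([2, 1, 8]) -- broadcasts over the token axis of hidden_states

# wan2.2 ti2v: one timestep per token -> temb is (batch_size, seq_len, 6, inner_dim)
temb = torch.randn(batch_size, seq_len, 6, inner_dim)
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
    scale_shift_table.unsqueeze(0) + temb.float()
).chunk(6, dim=2)
print(shift_msa.squeeze(2).shape)  # torch.Size([2, 16, 8]) -- one modulation vector per token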

src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py

Lines changed: 0 additions & 1 deletion
@@ -275,7 +275,6 @@ def encode_prompt(
 
         return prompt_embeds, negative_prompt_embeds
 
-    # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.check_inputs
     def check_inputs(
         self,
         prompt,

src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py

Lines changed: 0 additions & 1 deletion
@@ -316,7 +316,6 @@ def encode_prompt(
 
         return prompt_embeds, negative_prompt_embeds
 
-    # Copied from diffusers.pipelines.wan.pipeline_wan_i2v.WanImageToVideoPipeline.check_inputs
     def check_inputs(
         self,
         prompt,

src/diffusers/pipelines/wan/pipeline_wan.py

Lines changed: 63 additions & 11 deletions
@@ -112,10 +112,20 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
             A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
         vae ([`AutoencoderKLWan`]):
             Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+        transformer_2 ([`WanTransformer3DModel`], *optional*):
+            Conditional Transformer to denoise the input latents during the low-noise stage. If provided, enables
+            two-stage denoising where `transformer` handles high-noise stages and `transformer_2` handles low-noise
+            stages. If not provided, only `transformer` is used.
+        boundary_ratio (`float`, *optional*, defaults to `None`):
+            Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising.
+            The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided,
+            `transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps <
+            boundary_timestep. If `None`, only `transformer` is used for the entire denoising process.
     """
 
-    model_cpu_offload_seq = "text_encoder->transformer->vae"
+    model_cpu_offload_seq = "text_encoder->transformer->transformer_2->vae"
     _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+    _optional_components = ["transformer_2"]
 
     def __init__(
         self,
@@ -124,6 +134,9 @@ def __init__(
         transformer: WanTransformer3DModel,
         vae: AutoencoderKLWan,
         scheduler: FlowMatchEulerDiscreteScheduler,
+        transformer_2: Optional[WanTransformer3DModel] = None,
+        boundary_ratio: Optional[float] = None,
+        expand_timesteps: bool = False,  # Wan2.2 ti2v
     ):
         super().__init__()
 
@@ -133,10 +146,12 @@ def __init__(
             tokenizer=tokenizer,
             transformer=transformer,
             scheduler=scheduler,
+            transformer_2=transformer_2,
         )
-
-        self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
-        self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
+        self.register_to_config(boundary_ratio=boundary_ratio)
+        self.register_to_config(expand_timesteps=expand_timesteps)
+        self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
+        self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
 
     def _get_t5_prompt_embeds(
@@ -270,6 +285,7 @@ def check_inputs(
         prompt_embeds=None,
         negative_prompt_embeds=None,
        callback_on_step_end_tensor_inputs=None,
+        guidance_scale_2=None,
     ):
         if height % 16 != 0 or width % 16 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
@@ -302,6 +318,9 @@ def check_inputs(
         ):
             raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
 
+        if self.config.boundary_ratio is None and guidance_scale_2 is not None:
+            raise ValueError("`guidance_scale_2` is only supported when the pipeline's `boundary_ratio` is not None.")
+
     def prepare_latents(
         self,
         batch_size: int,
@@ -369,6 +388,7 @@ def __call__(
         num_frames: int = 81,
         num_inference_steps: int = 50,
         guidance_scale: float = 5.0,
+        guidance_scale_2: Optional[float] = None,
         num_videos_per_prompt: Optional[int] = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.Tensor] = None,
@@ -407,6 +427,10 @@ def __call__(
                 of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
                 `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
                 the text `prompt`, usually at the expense of lower image quality.
+            guidance_scale_2 (`float`, *optional*, defaults to `None`):
+                Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's
+                `boundary_ratio` is not None, uses the same value as `guidance_scale`. Only used when `transformer_2`
+                and the pipeline's `boundary_ratio` are not None.
             num_videos_per_prompt (`int`, *optional*, defaults to 1):
                 The number of images to generate per prompt.
             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -461,6 +485,7 @@ def __call__(
             prompt_embeds,
             negative_prompt_embeds,
             callback_on_step_end_tensor_inputs,
+            guidance_scale_2,
         )
 
         if num_frames % self.vae_scale_factor_temporal != 1:
@@ -470,7 +495,11 @@ def __call__(
             num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
         num_frames = max(num_frames, 1)
 
+        if self.config.boundary_ratio is not None and guidance_scale_2 is None:
+            guidance_scale_2 = guidance_scale
+
         self._guidance_scale = guidance_scale
+        self._guidance_scale_2 = guidance_scale_2
         self._attention_kwargs = attention_kwargs
         self._current_timestep = None
         self._interrupt = False
@@ -520,21 +549,44 @@ def __call__(
             latents,
         )
 
+        mask = torch.ones(latents.shape, dtype=torch.float32, device=device)
+
         # 6. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         self._num_timesteps = len(timesteps)
 
+        if self.config.boundary_ratio is not None:
+            boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps
+        else:
+            boundary_timestep = None
+
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue
 
                 self._current_timestep = t
-                latent_model_input = latents.to(transformer_dtype)
-                timestep = t.expand(latents.shape[0])
 
-                with self.transformer.cache_context("cond"):
-                    noise_pred = self.transformer(
+                if boundary_timestep is None or t >= boundary_timestep:
+                    # wan2.1 or high-noise stage in wan2.2
+                    current_model = self.transformer
+                    current_guidance_scale = guidance_scale
+                else:
+                    # low-noise stage in wan2.2
+                    current_model = self.transformer_2
+                    current_guidance_scale = guidance_scale_2
+
+                latent_model_input = latents.to(transformer_dtype)
+                if self.config.expand_timesteps:
+                    # seq_len: num_latent_frames * latent_height//2 * latent_width//2
+                    temp_ts = (mask[0][0][:, ::2, ::2] * t).flatten()
+                    # batch_size, seq_len
+                    timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
+                else:
+                    timestep = t.expand(latents.shape[0])
+
+                with current_model.cache_context("cond"):
+                    noise_pred = current_model(
                         hidden_states=latent_model_input,
                         timestep=timestep,
                         encoder_hidden_states=prompt_embeds,
@@ -543,15 +595,15 @@ def __call__(
                     )[0]
 
                 if self.do_classifier_free_guidance:
-                    with self.transformer.cache_context("uncond"):
-                        noise_uncond = self.transformer(
+                    with current_model.cache_context("uncond"):
+                        noise_uncond = current_model(
                             hidden_states=latent_model_input,
                             timestep=timestep,
                             encoder_hidden_states=negative_prompt_embeds,
                             attention_kwargs=attention_kwargs,
                             return_dict=False,
                         )[0]
-                    noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
+                    noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond)
 
                 # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
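For orientation, the stage selection inside the loop is a single threshold comparison on the current timestep. Below is a standalone sketch of that control flow with illustrative numbers; the 0.9 boundary ratio, the guidance values, and the string stand-ins for the two denoisers are made up for the example, not taken from this PR:

import torch

num_train_timesteps = 1000           # plays the role of scheduler.config.num_train_timesteps
boundary_ratio = 0.9                 # illustrative; the real value is registered in the pipeline config
guidance_scale, guidance_scale_2 = 5.0, 3.5   # illustrative guidance values

boundary_timestep = boundary_ratio * num_train_timesteps  # 900.0

# stand-ins for self.transformer / self.transformer_2
models = {"high_noise": "transformer", "low_noise": "transformer_2"}

for t in torch.tensor([999.0, 950.0, 900.0, 450.0, 20.0]):
    if t >= boundary_timestep:
        # wan2.1 behaviour, or the high-noise stage of wan2.2
        current_model, cfg = models["high_noise"], guidance_scale
    else:
        # low-noise stage of wan2.2
        current_model, cfg = models["low_noise"], guidance_scale_2
    print(f"t={t.item():.0f} -> {current_model} (guidance {cfg})")

In the `expand_timesteps` (ti2v) branch, `mask[0][0][:, ::2, ::2]` downsamples the all-ones latent mask by 2 in height and width, so the flattened tensor has one entry per latent token (num_latent_frames * latent_height//2 * latent_width//2, as the in-diff comment notes); multiplying by `t` and expanding over the batch yields the `(batch_size, seq_len)` timestep that the per-token path added to `transformer_wan.py` above expects.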
