
[Wan 2.2 LoRA] add support for 2nd transformer lora loading + wan 2.2 lightx2v lora #12074

Open · wants to merge 23 commits into main

Conversation

@linoytsaban (Collaborator) commented Aug 5, 2025

Wan 2.2 has two transformers. The community has found it beneficial to load Wan LoRAs into both transformers, sometimes at different scales (this also applies to Wan 2.1 LoRAs loaded into transformer and transformer_2).
Recently, a new Lightning LoRA was released for Wan 2.2 T2V, with separate weights for transformer (high-noise stage) and transformer_2 (low-noise stage).

This PR adds support for LoRA loading into transformer_2 and adds support for the Lightning LoRA (which has alpha keys).

T2V example:

import torch
from diffusers import WanPipeline, AutoencoderKLWan
from diffusers.utils import export_to_video

dtype = torch.bfloat16
device = "cuda"
vae = AutoencoderKLWan.from_pretrained("Wan-AI/Wan2.2-T2V-A14B-Diffusers", subfolder="vae", torch_dtype=torch.float32)
pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.2-T2V-A14B-Diffusers", vae=vae, torch_dtype=dtype)
pipe.to(device)

pipe.load_lora_weights(
    "Kijai/WanVideo_comfy",
    weight_name="Wan22-Lightning/Wan2.2-Lightning_T2V-A14B-4steps-lora_HIGH_fp16.safetensors",
    adapter_name="lightning",
)
pipe.load_lora_weights(
    "Kijai/WanVideo_comfy",
    weight_name="Wan22-Lightning/Wan2.2-Lightning_T2V-A14B-4steps-lora_LOW_fp16.safetensors",
    adapter_name="lightning_2",
    load_into_transformer_2=True,
)
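# "lightning" (HIGH weights) is loaded into transformer (high-noise stage);
# load_into_transformer_2=True routes "lightning_2" (LOW weights) into transformer_2 (low-noise stage).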
pipe.set_adapters(["lightning", "lightning_2"], adapter_weights=[1., 1.])

height = 480
width = 832

prompt = "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
output = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=height,
    width=width,
    num_frames=81,
    guidance_scale=1.0,
    guidance_scale_2=1.0,
    num_inference_steps=4,
    generator=torch.manual_seed(0),
).frames[0]
export_to_video(output, "t2v_out.mp4", fps=16)
Output video: t2v_out-5.mp4


@luke14free

Curious to see an example @linoytsaban, would love to try this out.

@sayakpaul (Member) left a comment

Thanks for working on this. Left some comments.

Comment on lines 1890 to 1905
has_alpha = f"blocks.{i}.cross_attn.{o}.alpha" in original_state_dict
original_key_A = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
converted_key_A = f"blocks.{i}.attn2.{c}.lora_A.weight"

original_key_B = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
converted_key_B = f"blocks.{i}.attn2.{c}.lora_B.weight"

if has_alpha:
    down_weight = original_state_dict.pop(original_key_A)
    up_weight = original_state_dict.pop(original_key_B)
    scale_down, scale_up = get_alpha_scales(down_weight, f"blocks.{i}.cross_attn.{o}.alpha")
    converted_state_dict[converted_key_A] = down_weight * scale_down
    converted_state_dict[converted_key_B] = up_weight * scale_up
else:
    if original_key_A in original_state_dict:
        converted_state_dict[converted_key_A] = original_state_dict.pop(original_key_A)
Member

Same as above.

    hotswap=hotswap,
)
load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False)
if load_into_transformer_2:
Member

Should raise in case getattr(self, "transformer_2", None) is None.
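
A minimal sketch of such a guard (illustrative only; the exact error message and placement inside load_lora_weights are assumptions):

load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False)
if load_into_transformer_2:
    # Wan 2.1 pipelines share this loader but have no second transformer,
    # so fail loudly instead of silently loading nothing.
    if getattr(self, "transformer_2", None) is None:
        raise ValueError(
            "load_into_transformer_2=True requires a pipeline with a transformer_2 component."
        )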

@@ -5064,7 +5064,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`WanPipeline`] and [`WanImageToVideoPipeline`].
"""

_lora_loadable_modules = ["transformer"]
_lora_loadable_modules = ["transformer", "transformer_2"]
Member
Just to note that this loader is shared amongst Wan 2.1 and 2.2 as the pipelines are also one and the same. For Wan 2.1, we won't have any transformer_2.

Comment on lines 5283 to 5293
else:
    self.load_lora_into_transformer(
        state_dict,
        transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
        adapter_name=adapter_name,
        metadata=metadata,
        _pipeline=self,
        low_cpu_mem_usage=low_cpu_mem_usage,
        hotswap=hotswap,
    )
Member
Why put it under else?

Collaborator Author

My thought process was that, unlike LoRAs that carry weights for, say, the transformer and the text encoder and are loaded in a single load_lora_weights call, here we can have different weights for each transformer while the state_dict keys are identical. Loading into each transformer separately also lets us use different adapter names, which makes it easy to apply different scales to each transformer's LoRA (which has been found to be beneficial for quality). I'm happy to improve this logic, but these are the considerations to keep in mind.

Member

Yeah. So, in case users want to load into both transformers, won't it just load into one if load_into_transformer_2=True?

Collaborator Author

Yep, it would; they would need to load into each one separately.

Member

Can you show some pseudo-code of what's expected from users? This is another way of loading another adapter into transformer_2:
#12040 (comment)

Collaborator Author

I don't feel strongly about keeping it exactly this way, but I do think it should remain possible to load different LoRA weights into the two transformers and at different scales.
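
For illustration, a minimal sketch of the usage pattern described above, assuming the load_into_transformer_2 kwarg added in this PR (the repo id, weight file names, adapter names, and scales are placeholders):

# Load one LoRA into the high-noise transformer and a different one into
# transformer_2, under distinct adapter names.
pipe.load_lora_weights(
    "org/high-noise-lora",                 # placeholder repo id
    weight_name="high_noise.safetensors",  # placeholder file name
    adapter_name="high",
)
pipe.load_lora_weights(
    "org/low-noise-lora",                  # placeholder repo id
    weight_name="low_noise.safetensors",   # placeholder file name
    adapter_name="low",
    load_into_transformer_2=True,
)
# Distinct adapter names make it easy to apply a different scale per transformer.
pipe.set_adapters(["high", "low"], adapter_weights=[1.0, 0.8])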

Member
Makes sense. Let's go with this but with a note in the docstrings saying it's experimental in nature.

@linoytsaban (Collaborator Author)

I2V example: using Wan 2.2 with the Wan 2.1 lightning LoRA

import torch
import numpy as np
from diffusers import WanImageToVideoPipeline
from diffusers.utils import export_to_video, load_image

model_id = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
dtype = torch.bfloat16
device = "cuda"

pipe = WanImageToVideoPipeline.from_pretrained(model_id, torch_dtype=dtype)
pipe.to(device)

pipe.load_lora_weights(
    "Kijai/WanVideo_comfy",
    weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors",
    adapter_name="lightning",
)
pipe.load_lora_weights(
    "Kijai/WanVideo_comfy",
    weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors",
    adapter_name="lightning_2",
    load_into_transformer_2=True,
)
pipe.set_adapters(["lightning", "lightning_2"], adapter_weights=[1., 1.])
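# Fuse each adapter into its matching transformer with a different lora_scale,
# then unload the adapters; the scaled LoRA weights stay baked into each transformer.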
pipe.fuse_lora(adapter_names=["lightning"], lora_scale=3., components=["transformer"])
pipe.fuse_lora(adapter_names=["lightning_2"], lora_scale=1., components=["transformer_2"])
pipe.unload_lora_weights()

image = load_image(
    "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/wan_i2v_input.JPG"
)
max_area = 480 * 832
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
image = image.resize((width, height))
prompt = "POV selfie video, white cat with sunglasses standing on surfboard, relaxed smile, tropical beach behind (clear water, green hills, blue sky with clouds). Surfboard tips, cat falls into ocean, camera plunges underwater with bubbles and sunlight beams. Brief underwater view of cat’s face, then cat resurfaces, still filming selfie, playful summer vacation mood."

negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
generator = torch.Generator(device=device).manual_seed(42)
output = pipe(
    image=image,
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=height,
    width=width,
    num_frames=81,
    guidance_scale=1,
    num_inference_steps=4,
    generator=generator,
).frames[0]
export_to_video(output, "i2v_output.mp4", fps=16)
Output video: i2v_output-84.mp4

@luke14free

Thanks a lot for the amazing work @linoytsaban. Just FYI, issue #12047 also applies to this PR: I tried it and got the mismatch error with GGUF models. Reporting since GGUF is the most popular way to run Wan on consumer hardware.

@mayankagrawal10198 commented Aug 6, 2025

@linoytsaban are we sure that if we don't pass the boundary_ratio arg, our generation pipe would still choose transformer_2 for low noise? Because I can see the first Wan 2.2 PR, #12004 by @yiyixuxu, has these lines:

if self.config.boundary_ratio is not None:
    boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps
else:
    boundary_timestep = None

with self.progress_bar(total=num_inference_steps) as progress_bar:
    for i, t in enumerate(timesteps):
        if self.interrupt:
            continue

        self._current_timestep = t

        if boundary_timestep is None or t >= boundary_timestep:
            # wan2.1 or high-noise stage in wan2.2
            current_model = self.transformer
            current_guidance_scale = guidance_scale
        else:
            # low-noise stage in wan2.2
            current_model = self.transformer_2
            current_guidance_scale = guidance_scale_2

@linoytsaban (Collaborator Author)

Yes @mayankagrawal10198, it should still use transformer_2 for the low-noise stage, since the default config sets the boundary ratio to 0.9 for I2V and 0.875 for T2V (https://huggingface.co/Wan-AI/Wan2.2-I2V-A14B-Diffusers/blob/main/model_index.json, https://huggingface.co/Wan-AI/Wan2.2-T2V-A14B-Diffusers/blob/main/model_index.json). You can pass the boundary ratio explicitly to the pipeline if you wish to experiment with different values.
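
To make the arithmetic concrete, a small illustration, assuming num_train_timesteps=1000 and a hypothetical 4-step timestep schedule (the actual values come from the scheduler):

num_train_timesteps = 1000
boundary_ratio = 0.9  # I2V default from model_index.json
boundary_timestep = boundary_ratio * num_train_timesteps  # 900.0

timesteps = [1000, 750, 500, 250]  # hypothetical 4-step schedule
for t in timesteps:
    stage = "high noise -> transformer" if t >= boundary_timestep else "low noise -> transformer_2"
    print(t, stage)
# Each step is assigned by comparing its timestep to the boundary value, so every
# transformer gets a whole number of steps (here: 1 high-noise, 3 low-noise),
# rather than a fractional split of num_inference_steps.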

@linoytsaban (Collaborator Author)
@bot /style

github-actions bot (Contributor) commented Aug 7, 2025

Style fix is beginning. View the workflow run here.

@linoytsaban (Collaborator Author)
@bot /style

github-actions bot (Contributor) commented Aug 11, 2025

Style bot fixed some files and pushed the changes.

@mayankagrawal10198

Hi @linoytsaban, thanks for the reply. Please correct my understanding of this.
If the boundary ratio is 0.9 by default and we set num_inference_steps to 4 for the Lightx2v LoRA, does the high-noise stage get only 0.4 steps and the low-noise stage 3.6 steps? Can a stage get a fractional number of steps? I thought the high-noise and low-noise stages each need whole numbers of steps, e.g. with num_inference_steps of 8 and a boundary ratio of 0.75, the high-noise stage would get at least 2 steps and the low-noise stage 6 steps.
I really need to understand this.

@linoytsaban linoytsaban requested a review from sayakpaul August 13, 2025 13:47
@innokria

Hey guys, this is amazing work. There is now a new concept that does this in 3 stages.

3-stage approach: the first stage uses the original Wan 2.2 model without the Lightx2v LoRA, which allows faster motions to be generated. The 2nd and 3rd stages use the high and low Lightx2v LoRAs as normal.

I will do some experiments on this :)
