[Wan 2.2 LoRA] add support for 2nd transformer lora loading + wan 2.2 lightx2v lora #12074
base: main
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Curious to see an example @linoytsaban, would love to try this out.
Thanks for working on this. Left some comments.
has_alpha = f"blocks.{i}.cross_attn.{o}.alpha" in original_state_dict
original_key_A = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
converted_key_A = f"blocks.{i}.attn2.{c}.lora_A.weight"

original_key_B = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
converted_key_B = f"blocks.{i}.attn2.{c}.lora_B.weight"

if has_alpha:
    down_weight = original_state_dict.pop(original_key_A)
    up_weight = original_state_dict.pop(original_key_B)
    scale_down, scale_up = get_alpha_scales(down_weight, f"blocks.{i}.cross_attn.{o}.alpha")
    converted_state_dict[converted_key_A] = down_weight * scale_down
    converted_state_dict[converted_key_B] = up_weight * scale_up
else:
    if original_key_A in original_state_dict:
        converted_state_dict[converted_key_A] = original_state_dict.pop(original_key_A)
Same as above.
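For context on the `get_alpha_scales` call in the hunk above: such a helper presumably turns a stored `alpha` into the usual `alpha / rank` LoRA scale and splits it between the down and up weights. A hedged sketch of that convention only; the closure over `original_state_dict` and the even square-root split are assumptions, not a quote of the PR:

def get_alpha_scales(down_weight, alpha_key):
    rank = down_weight.shape[0]  # lora_A has shape (rank, in_features)
    alpha = original_state_dict.pop(alpha_key).item()  # assumes the dict is in scope
    scale = alpha / rank
    # split the scale across A and B so neither matrix dominates numerically
    scale_down = scale**0.5
    scale_up = scale**0.5
    return scale_down, scale_up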
    hotswap=hotswap,
)
load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False)
if load_into_transformer_2:
Should raise in case `getattr(self, "transformer_2", None)` is None.
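A minimal sketch of the suggested guard, reusing the names from the hunk above (the error type and message are illustrative):

load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False)
if load_into_transformer_2:
    if getattr(self, "transformer_2", None) is None:
        raise ValueError(
            "`load_into_transformer_2=True` requires this pipeline to have a "
            "`transformer_2` component; Wan 2.1 pipelines only have a single transformer."
        )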
@@ -5064,7 +5064,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
     Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`WanPipeline`] and [`WanImageToVideoPipeline`].
     """

-    _lora_loadable_modules = ["transformer"]
+    _lora_loadable_modules = ["transformer", "transformer_2"]
Just to note that this loader is shared between Wan 2.1 and Wan 2.2, as the pipelines are also one and the same. For Wan 2.1, we won't have any `transformer_2`.
else:
    self.load_lora_into_transformer(
        state_dict,
        transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
        adapter_name=adapter_name,
        metadata=metadata,
        _pipeline=self,
        low_cpu_mem_usage=low_cpu_mem_usage,
        hotswap=hotswap,
    )
Why put it under `else`?
My thought process was that, as opposed to LoRAs with weights for the transformer and text encoder (for example), which we load in one `load_lora_weights` op, here we can have a situation where we have different weights for each transformer but the `state_dict` keys are identical. Also, this way we can load the LoRA into each transformer separately with different adapter names, making it easy to use different scales for each transformer's LoRA (which was seen to be beneficial for quality). I'm happy to improve this logic, but these are the considerations to keep in mind.
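A hedged sketch of the intended usage under this design; the repo IDs, weight file names, and adapter names below are illustrative placeholders, and the per-adapter scales assume `set_adapters` routes each scale to whichever transformer holds that adapter:

import torch
from diffusers import WanPipeline

pipe = WanPipeline.from_pretrained(
    "Wan-AI/Wan2.2-T2V-A14B-Diffusers", torch_dtype=torch.bfloat16
)

# high-noise LoRA goes into `transformer` (the default target)
pipe.load_lora_weights(
    "some-org/wan2.2-lightning-lora",  # hypothetical repo
    weight_name="high_noise.safetensors",
    adapter_name="light_high",
)
# low-noise LoRA goes into `transformer_2` via the new kwarg from this PR
pipe.load_lora_weights(
    "some-org/wan2.2-lightning-lora",
    weight_name="low_noise.safetensors",
    adapter_name="light_low",
    load_into_transformer_2=True,
)

# different scales per transformer, as discussed above
pipe.set_adapters(["light_high", "light_low"], adapter_weights=[1.0, 0.8])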
Yeah. So, in case users want to load LoRAs into both transformers, won't it just load into one if `load_into_transformer_2=True`?
Yep, it would; they would need to load into each transformer separately.
Can you show some pseudo-code expected from the users? This is another way of loading another adapter into `transformer_2`:
#12040 (comment)
Here: #12074 (comment)
I don't feel strongly about it staying that exact way, but I do think it should remain possible to load different LoRA weights into the two transformers and at different scales.
Makes sense. Let's go with this but with a note in the docstrings saying it's experimental in nature.
I2V example: using Wan2.2 with Wan2.1 lightning LoRA
i2v_output-84.mp4
Thanks a lot for the amazing work @linoytsaban. Just FYI, issue #12047 also applies to this PR: I tried it and got the mismatch error with GGUF models. Reporting since GGUF models are the most popular way to run Wan on consumer hardware.
@linoytsaban are we sure that if we don't pass the `boundary_ratio` arg, our generation pipe would still choose `transformer_2` for low noise? Because I can see the first Wan 2.2 PR, #12004 by @yiyixuxu, has these lines:
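For reference, the timestep-based selection in the Wan 2.2 pipeline works roughly like this; a paraphrased sketch of the denoising loop, not an exact quote of the pipeline source:

# `boundary_ratio` is stored on the pipeline config at init time
if self.config.boundary_ratio is not None:
    boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps
else:
    boundary_timestep = None

for t in timesteps:
    if boundary_timestep is None or t >= boundary_timestep:
        current_model = self.transformer    # high-noise stage
    else:
        current_model = self.transformer_2  # low-noise stage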
Co-authored-by: Sayak Paul <[email protected]>
Yes @mayankagrawal10198, it should still use `transformer_2`.
@bot /style
Style fix is beginning... View the workflow run here.
@bot /style
Style bot fixed some files and pushed the changes.
Hi @linoytsaban, thanks for the reply. Please correct my understanding of this.
Hey guys, this is amazing work. There is now a new concept to do this in 3 stages. 3-stage approach: the first stage uses the original Wan2.2 model, without the Lightx2v LoRA, which allows faster motions to be generated. The 2nd and 3rd stages use the high and low Lightx2v LoRAs as normal. I will do some experiments on this :) A rough sketch of the idea follows.
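One way to prototype the 3-stage idea with the current API; a hedged sketch only, where the stage boundary, adapter setup, and the use of a step-end callback to toggle the LoRAs are assumptions rather than anything this PR prescribes:

# assumes the lightning LoRAs were loaded as "light_high" / "light_low" as sketched earlier
LIGHTX2V_START_STEP = 4  # hypothetical boundary between stage 1 and stage 2

def three_stage_callback(pipe, step, timestep, callback_kwargs):
    if step == 0:
        pipe.disable_lora()  # stage 1: base Wan2.2 high-noise model, no Lightx2v
    elif step == LIGHTX2V_START_STEP:
        pipe.enable_lora()   # stages 2 and 3: Lightx2v high/low LoRAs as usual
    return callback_kwargs

video = pipe(prompt="...", callback_on_step_end=three_stage_callback).frames[0]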
Wan 2.2 has 2 transformers; the community has found it beneficial to load Wan LoRAs into both transformers, occasionally at different scales as well (this also applies to Wan 2.1 LoRAs, loaded into `transformer` and `transformer_2`).

Recently, a new lightning LoRA was released for Wan2.2 T2V, with separate weights for `transformer` (high-noise stage) and `transformer_2` (low-noise stage).

This PR adds support for LoRA loading into `transformer_2`, plus support for the lightning LoRA (which has `alpha` keys).

T2V example:
t2v_out-5.mp4