
Commit f442955

[lora] support loading loras from lightx2v/Qwen-Image-Lightning (#12119)
* feat: support qwen lightning lora.
* add docs.
* fix
1 parent ff9a387 commit f442955

File tree

3 files changed: +98 −1 lines changed


docs/source/en/api/pipelines/qwenimage.md

Lines changed: 57 additions & 0 deletions

@@ -24,6 +24,63 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers)

 </Tip>

+## LoRA for faster inference
+
+Use a LoRA from `lightx2v/Qwen-Image-Lightning` to speed up inference by reducing the
+number of steps. Refer to the code snippet below:
+
+<details>
+<summary>Code</summary>
+
+```py
+from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler
+import torch
+import math
+
+ckpt_id = "Qwen/Qwen-Image"
+
+# From
+# https://github.com/ModelTC/Qwen-Image-Lightning/blob/342260e8f5468d2f24d084ce04f55e101007118b/generate_with_diffusers.py#L82C9-L97C10
+scheduler_config = {
+    "base_image_seq_len": 256,
+    "base_shift": math.log(3),  # We use shift=3 in distillation
+    "invert_sigmas": False,
+    "max_image_seq_len": 8192,
+    "max_shift": math.log(3),  # We use shift=3 in distillation
+    "num_train_timesteps": 1000,
+    "shift": 1.0,
+    "shift_terminal": None,  # set shift_terminal to None
+    "stochastic_sampling": False,
+    "time_shift_type": "exponential",
+    "use_beta_sigmas": False,
+    "use_dynamic_shifting": True,
+    "use_exponential_sigmas": False,
+    "use_karras_sigmas": False,
+}
+scheduler = FlowMatchEulerDiscreteScheduler.from_config(scheduler_config)
+pipe = DiffusionPipeline.from_pretrained(
+    ckpt_id, scheduler=scheduler, torch_dtype=torch.bfloat16
+).to("cuda")
+pipe.load_lora_weights(
+    "lightx2v/Qwen-Image-Lightning", weight_name="Qwen-Image-Lightning-8steps-V1.0.safetensors"
+)
+
+prompt = "a tiny astronaut hatching from an egg on the moon, Ultra HD, 4K, cinematic composition."
+negative_prompt = " "
+image = pipe(
+    prompt=prompt,
+    negative_prompt=negative_prompt,
+    width=1024,
+    height=1024,
+    num_inference_steps=8,
+    true_cfg_scale=1.0,
+    generator=torch.manual_seed(0),
+).images[0]
+image.save("qwen_fewsteps.png")
+```
+
+</details>
+
 ## QwenImagePipeline

 [[autodoc]] QwenImagePipeline
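One detail worth calling out in the snippet above: `base_shift` and `max_shift` are both set to `math.log(3)`, so the scheduler's dynamic shifting degenerates to a constant shift of 3, matching the shift the Lightning LoRA was distilled with. A minimal sketch of that arithmetic (the interpolation mirrors the dynamic-shifting helper used by the pipeline; names here are illustrative):

```py
import math

# Illustrative re-derivation: with use_dynamic_shifting=True, mu is linearly
# interpolated between base_shift and max_shift according to image_seq_len.
base_shift = max_shift = math.log(3)
base_seq_len, max_seq_len = 256, 8192

def mu_for(image_seq_len):
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)  # slope is 0 here
    b = base_shift - m * base_seq_len
    return image_seq_len * m + b

# Since base_shift == max_shift, exp(mu) == 3 at every resolution.
print(math.exp(mu_for(1024)))  # 3.0
```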

src/diffusers/loaders/lora_conversion_utils.py

Lines changed: 36 additions & 0 deletions

@@ -2077,3 +2077,39 @@ def _convert_non_diffusers_ltxv_lora_to_diffusers(state_dict, non_diffusers_pref
     converted_state_dict = {k.removeprefix(f"{non_diffusers_prefix}."): v for k, v in state_dict.items()}
     converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()}
     return converted_state_dict
+
+
+def _convert_non_diffusers_qwen_lora_to_diffusers(state_dict):
+    converted_state_dict = {}
+    all_keys = list(state_dict.keys())
+    down_key = ".lora_down.weight"
+    up_key = ".lora_up.weight"
+
+    def get_alpha_scales(down_weight, alpha_key):
+        rank = down_weight.shape[0]
+        alpha = state_dict.pop(alpha_key).item()
+        scale = alpha / rank  # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
+        scale_down = scale
+        scale_up = 1.0
+        while scale_down * 2 < scale_up:
+            scale_down *= 2
+            scale_up /= 2
+        return scale_down, scale_up
+
+    for k in all_keys:
+        if k.endswith(down_key):
+            diffusers_down_key = k.replace(down_key, ".lora_A.weight")
+            diffusers_up_key = k.replace(down_key, up_key).replace(up_key, ".lora_B.weight")
+            alpha_key = k.replace(down_key, ".alpha")
+
+            down_weight = state_dict.pop(k)
+            up_weight = state_dict.pop(k.replace(down_key, up_key))
+            scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
+            converted_state_dict[diffusers_down_key] = down_weight * scale_down
+            converted_state_dict[diffusers_up_key] = up_weight * scale_up
+
+    if len(state_dict) > 0:
+        raise ValueError(f"`state_dict` should be empty at this point but has {state_dict.keys()=}")
+
+    converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()}
+    return converted_state_dict
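To see what the converter produces, here is a minimal sketch run on a synthetic two-matrix LoRA. The key names and shapes are made up for illustration, and the import path is a private diffusers module, so it may change between releases. With `alpha = 2` and `rank = 4`, `get_alpha_scales` returns `(0.5, 1.0)`, i.e. the whole `alpha / rank` factor is baked into the down projection:

```py
import torch

# Private helper from the diff above; the module path is internal to diffusers.
from diffusers.loaders.lora_conversion_utils import (
    _convert_non_diffusers_qwen_lora_to_diffusers,
)

rank, dim = 4, 16
state_dict = {
    # Illustrative key names in the non-diffusers (lora_down/lora_up) format.
    "transformer_blocks.0.attn.to_q.lora_down.weight": torch.randn(rank, dim),
    "transformer_blocks.0.attn.to_q.lora_up.weight": torch.randn(dim, rank),
    "transformer_blocks.0.attn.to_q.alpha": torch.tensor(2.0),
}

converted = _convert_non_diffusers_qwen_lora_to_diffusers(state_dict)
print(sorted(converted))
# ['transformer.transformer_blocks.0.attn.to_q.lora_A.weight',
#  'transformer.transformer_blocks.0.attn.to_q.lora_B.weight']
```

The power-of-two rebalancing in `get_alpha_scales` only redistributes the factor between the two matrices; the product `scale_down * scale_up` stays equal to `alpha / rank`, so the LoRA's effective output is unchanged.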

src/diffusers/loaders/lora_pipeline.py

Lines changed: 5 additions & 1 deletion

@@ -49,6 +49,7 @@
     _convert_non_diffusers_lora_to_diffusers,
     _convert_non_diffusers_ltxv_lora_to_diffusers,
     _convert_non_diffusers_lumina2_lora_to_diffusers,
+    _convert_non_diffusers_qwen_lora_to_diffusers,
     _convert_non_diffusers_wan_lora_to_diffusers,
     _convert_xlabs_flux_lora_to_diffusers,
     _maybe_map_sgm_blocks_to_diffusers,

@@ -6548,7 +6549,6 @@ class QwenImageLoraLoaderMixin(LoraBaseMixin):

     @classmethod
     @validate_hf_hub_args
-    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
     def lora_state_dict(
         cls,
         pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],

@@ -6642,6 +6642,10 @@ def lora_state_dict(
             logger.warning(warn_msg)
             state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

+        has_alphas_in_sd = any(k.endswith(".alpha") for k in state_dict)
+        if has_alphas_in_sd:
+            state_dict = _convert_non_diffusers_qwen_lora_to_diffusers(state_dict)
+
         out = (state_dict, metadata) if return_lora_metadata else state_dict
         return out