Skip to content

Commit 3902145

Browse files
authored
[lora] Fix ZImage LoRA conversion to support more LoRA formats. (#13209)
Fix ZImage LoRA conversion to support more LoRA formats.
1 parent 5570f81 commit 3902145

File tree

1 file changed

+78
-16
lines changed

1 file changed

+78
-16
lines changed

src/diffusers/loaders/lora_conversion_utils.py

Lines changed: 78 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2519,6 +2519,13 @@ def normalize_out_key(k: str) -> str:
25192519
if has_default:
25202520
state_dict = {k.replace("default.", ""): v for k, v in state_dict.items()}
25212521

2522+
# Normalize ZImage-specific dot-separated module names to underscore form so they
2523+
# match the diffusers model parameter names (context_refiner, noise_refiner).
2524+
state_dict = {
2525+
k.replace("context.refiner.", "context_refiner.").replace("noise.refiner.", "noise_refiner."): v
2526+
for k, v in state_dict.items()
2527+
}
2528+
25222529
converted_state_dict = {}
25232530
all_keys = list(state_dict.keys())
25242531
down_key = ".lora_down.weight"
@@ -2529,19 +2536,18 @@ def normalize_out_key(k: str) -> str:
25292536
has_non_diffusers_lora_id = any(down_key in k or up_key in k for k in all_keys)
25302537
has_diffusers_lora_id = any(a_key in k or b_key in k for k in all_keys)
25312538

2532-
if has_non_diffusers_lora_id:
2533-
2534-
def get_alpha_scales(down_weight, alpha_key):
2535-
rank = down_weight.shape[0]
2536-
alpha = state_dict.pop(alpha_key).item()
2537-
scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
2538-
scale_down = scale
2539-
scale_up = 1.0
2540-
while scale_down * 2 < scale_up:
2541-
scale_down *= 2
2542-
scale_up /= 2
2543-
return scale_down, scale_up
2539+
def get_alpha_scales(down_weight, alpha_key):
2540+
rank = down_weight.shape[0]
2541+
alpha = state_dict.pop(alpha_key).item()
2542+
scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
2543+
scale_down = scale
2544+
scale_up = 1.0
2545+
while scale_down * 2 < scale_up:
2546+
scale_down *= 2
2547+
scale_up /= 2
2548+
return scale_down, scale_up
25442549

2550+
if has_non_diffusers_lora_id:
25452551
for k in all_keys:
25462552
if k.endswith(down_key):
25472553
diffusers_down_key = k.replace(down_key, ".lora_A.weight")
@@ -2554,13 +2560,69 @@ def get_alpha_scales(down_weight, alpha_key):
25542560
converted_state_dict[diffusers_down_key] = down_weight * scale_down
25552561
converted_state_dict[diffusers_up_key] = up_weight * scale_up
25562562

2557-
# Already in diffusers format (lora_A/lora_B), just pop
2563+
# Already in diffusers format (lora_A/lora_B), apply alpha scaling and pop.
25582564
elif has_diffusers_lora_id:
25592565
for k in all_keys:
2560-
if a_key in k or b_key in k:
2561-
converted_state_dict[k] = state_dict.pop(k)
2562-
elif ".alpha" in k:
2566+
if k.endswith(a_key):
2567+
diffusers_up_key = k.replace(a_key, b_key)
2568+
alpha_key = k.replace(a_key, ".alpha")
2569+
2570+
down_weight = state_dict.pop(k)
2571+
up_weight = state_dict.pop(diffusers_up_key)
2572+
scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
2573+
converted_state_dict[k] = down_weight * scale_down
2574+
converted_state_dict[diffusers_up_key] = up_weight * scale_up
2575+
2576+
# Handle dot-format LoRA keys: ".lora.down.weight" / ".lora.up.weight".
2577+
# Some external ZImage trainers (e.g. Anime-Z) use dots instead of underscores in
2578+
# lora weight names and also include redundant keys:
2579+
# - "qkv.lora.*" duplicates individual "to.q/k/v.lora.*" keys → skip qkv
2580+
# - "out.lora.*" duplicates "to_out.0.lora.*" keys → skip bare out
2581+
# - "to.q/k/v.lora.*" → normalise to "to_q/k/v.lora_A/B.weight"
2582+
lora_dot_down_key = ".lora.down.weight"
2583+
lora_dot_up_key = ".lora.up.weight"
2584+
has_lora_dot_format = any(lora_dot_down_key in k for k in state_dict)
2585+
2586+
if has_lora_dot_format:
2587+
dot_keys = list(state_dict.keys())
2588+
for k in dot_keys:
2589+
if lora_dot_down_key not in k:
2590+
continue
2591+
if k not in state_dict:
2592+
continue # already popped by a prior iteration
2593+
2594+
base = k[: -len(lora_dot_down_key)]
2595+
2596+
# Skip combined "qkv" projection — individual to.q/k/v keys are also present.
2597+
if base.endswith(".qkv"):
2598+
state_dict.pop(k)
2599+
state_dict.pop(k.replace(lora_dot_down_key, lora_dot_up_key), None)
2600+
state_dict.pop(base + ".alpha", None)
2601+
continue
2602+
2603+
# Skip bare "out.lora.*" — "to_out.0.lora.*" covers the same projection.
2604+
if re.search(r"\.out$", base) and ".to_out" not in base:
25632605
state_dict.pop(k)
2606+
state_dict.pop(k.replace(lora_dot_down_key, lora_dot_up_key), None)
2607+
continue
2608+
2609+
# Normalise "to.q/k/v" → "to_q/k/v" for the diffusers output key.
2610+
norm_k = re.sub(
2611+
r"\.to\.([qkv])" + re.escape(lora_dot_down_key) + r"$",
2612+
r".to_\1" + lora_dot_down_key,
2613+
k,
2614+
)
2615+
norm_base = norm_k[: -len(lora_dot_down_key)]
2616+
alpha_key = norm_base + ".alpha"
2617+
2618+
diffusers_down = norm_k.replace(lora_dot_down_key, ".lora_A.weight")
2619+
diffusers_up = norm_k.replace(lora_dot_down_key, ".lora_B.weight")
2620+
2621+
down_weight = state_dict.pop(k)
2622+
up_weight = state_dict.pop(k.replace(lora_dot_down_key, lora_dot_up_key))
2623+
scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
2624+
converted_state_dict[diffusers_down] = down_weight * scale_down
2625+
converted_state_dict[diffusers_up] = up_weight * scale_up
25642626

25652627
if len(state_dict) > 0:
25662628
raise ValueError(f"`state_dict` should be empty at this point but has {state_dict.keys()=}")

0 commit comments

Comments
 (0)