@@ -2519,6 +2519,13 @@ def normalize_out_key(k: str) -> str:
# NOTE(review): this chunk is a unified-diff view of a larger conversion
# function; the fused numbers and "+"/"-" prefixes are diff markers, not code.
# Strip the optional "default." adapter prefix so all keys share one namespace.
25192519 if has_default :
25202520 state_dict = {k .replace ("default." , "" ): v for k , v in state_dict .items ()}
25212521
2522+ # Normalize ZImage-specific dot-separated module names to underscore form so they
2523+ # match the diffusers model parameter names (context_refiner, noise_refiner).
2524+ state_dict = {
2525+ k .replace ("context.refiner." , "context_refiner." ).replace ("noise.refiner." , "noise_refiner." ): v
2526+ for k , v in state_dict .items ()
2527+ }
2528+
# Accumulator for the converted entries; iterate over a snapshot of the keys
# because entries are popped from state_dict while converting.
25222529 converted_state_dict = {}
25232530 all_keys = list (state_dict .keys ())
# Suffix used by the non-diffusers ("kohya-style") LoRA naming scheme.
25242531 down_key = ".lora_down.weight"
@@ -2529,19 +2536,18 @@ def normalize_out_key(k: str) -> str:
# Detect which naming scheme the checkpoint uses:
# - non-diffusers: ".lora_down"/".lora_up" suffixes (up_key is defined in the
#   lines hidden between hunks — not visible in this chunk, confirm upstream)
# - diffusers: ".lora_A"/".lora_B" suffixes (a_key/b_key likewise hidden)
25292536 has_non_diffusers_lora_id = any (down_key in k or up_key in k for k in all_keys )
25302537 has_diffusers_lora_id = any (a_key in k or b_key in k for k in all_keys )
25312538
# Deleted by this patch: get_alpha_scales used to be defined only inside the
# non-diffusers branch, so the other branches could not reuse it.  The patch
# hoists an otherwise identical definition out of the conditional (the "+"
# lines that follow in the diff).
2532- if has_non_diffusers_lora_id :
2533-
2534- def get_alpha_scales (down_weight , alpha_key ):
2535- rank = down_weight .shape [0 ]
2536- alpha = state_dict .pop (alpha_key ).item ()
2537- scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
2538- scale_down = scale
2539- scale_up = 1.0
2540- while scale_down * 2 < scale_up :
2541- scale_down *= 2
2542- scale_up /= 2
2543- return scale_down , scale_up
# Hoisted helper, now shared by all conversion branches below.
# Converts a stored LoRA "alpha" into the pair of factors applied to the
# down (lora_A) and up (lora_B) matrices; down_weight.shape[0] is the rank.
# NOTE(review): pops alpha_key from the enclosing state_dict with no default,
# so it raises KeyError when the checkpoint has no alpha for this module —
# confirm every supported checkpoint stores alpha before calling this.
2539+ def get_alpha_scales (down_weight , alpha_key ):
2540+ rank = down_weight .shape [0 ]
2541+ alpha = state_dict .pop (alpha_key ).item ()
2542+ scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
# The loop rebalances the two factors while keeping their product equal to
# `scale` (it only runs when scale < 0.5) — presumably to limit magnitude
# disparity between the two matrices in low-precision weights; verify intent
# against the upstream conversion utilities.
2543+ scale_down = scale
2544+ scale_up = 1.0
2545+ while scale_down * 2 < scale_up :
2546+ scale_down *= 2
2547+ scale_up /= 2
2548+ return scale_down , scale_up
25442549
# Non-diffusers scheme: rename ".lora_down"/".lora_up" to ".lora_A"/".lora_B"
# and fold the alpha scaling into the weights.  Part of this loop body is
# hidden by the hunk boundary below — only the first and last lines are shown.
2550+ if has_non_diffusers_lora_id :
25452551 for k in all_keys :
25462552 if k .endswith (down_key ):
25472553 diffusers_down_key = k .replace (down_key , ".lora_A.weight" )
@@ -2554,13 +2560,69 @@ def get_alpha_scales(down_weight, alpha_key):
25542560 converted_state_dict [diffusers_down_key ] = down_weight * scale_down
25552561 converted_state_dict [diffusers_up_key ] = up_weight * scale_up
25562562
2557- # Already in diffusers format (lora_A/lora_B), just pop
2563+ # Already in diffusers format (lora_A/lora_B), apply alpha scaling and pop.
25582564 elif has_diffusers_lora_id :
25592565 for k in all_keys :
2560- if a_key in k or b_key in k :
2561- converted_state_dict [k ] = state_dict .pop (k )
2562- elif ".alpha" in k :
# Behaviour change introduced by this patch: ".alpha" entries used to be
# popped and discarded; they are now folded into the weights.
# NOTE(review): get_alpha_scales pops the alpha with no default, so a
# diffusers-format checkpoint WITHOUT ".alpha" entries now raises KeyError
# where it previously loaded fine — confirm this regression is intended.
# NOTE(review): iteration keys on ".lora_A.weight" endings; a stray alpha
# with no matching lora_A pair is no longer popped and will trip the final
# "state_dict should be empty" check.
2566+ if k .endswith (a_key ):
2567+ diffusers_up_key = k .replace (a_key , b_key )
2568+ alpha_key = k .replace (a_key , ".alpha" )
2569+
2570+ down_weight = state_dict .pop (k )
2571+ up_weight = state_dict .pop (diffusers_up_key )
2572+ scale_down , scale_up = get_alpha_scales (down_weight , alpha_key )
2573+ converted_state_dict [k ] = down_weight * scale_down
2574+ converted_state_dict [diffusers_up_key ] = up_weight * scale_up
2575+
2576+ # Handle dot-format LoRA keys: ".lora.down.weight" / ".lora.up.weight".
2577+ # Some external ZImage trainers (e.g. Anime-Z) use dots instead of underscores in
2578+ # lora weight names and also include redundant keys:
2579+ # - "qkv.lora.*" duplicates individual "to.q/k/v.lora.*" keys → skip qkv
2580+ # - "out.lora.*" duplicates "to_out.0.lora.*" keys → skip bare out
2581+ # - "to.q/k/v.lora.*" → normalise to "to_q/k/v.lora_A/B.weight"
2582+ lora_dot_down_key = ".lora.down.weight"
2583+ lora_dot_up_key = ".lora.up.weight"
2584+ has_lora_dot_format = any (lora_dot_down_key in k for k in state_dict )
2585+
2586+ if has_lora_dot_format :
# Snapshot the keys: entries are popped from state_dict inside the loop.
2587+ dot_keys = list (state_dict .keys ())
2588+ for k in dot_keys :
2589+ if lora_dot_down_key not in k :
2590+ continue
2591+ if k not in state_dict :
2592+ continue # already popped by a prior iteration
2593+
2594+ base = k [: - len (lora_dot_down_key )]
2595+
2596+ # Skip combined "qkv" projection — individual to.q/k/v keys are also present.
2597+ if base .endswith (".qkv" ):
2598+ state_dict .pop (k )
2599+ state_dict .pop (k .replace (lora_dot_down_key , lora_dot_up_key ), None )
2600+ state_dict .pop (base + ".alpha" , None )
2601+ continue
2602+
2603+ # Skip bare "out.lora.*" — "to_out.0.lora.*" covers the same projection.
# NOTE(review): relies on a module-level "import re" that is not visible in
# this chunk — confirm it exists at the top of the file.
# NOTE(review): unlike the qkv branch above, this skip does NOT pop the
# sibling `base + ".alpha"` entry; if bare-out keys carry an alpha it will
# survive and trip the final "state_dict should be empty" ValueError.
2604+ if re .search (r"\.out$" , base ) and ".to_out" not in base :
25632605 state_dict .pop (k )
2606+ state_dict .pop (k .replace (lora_dot_down_key , lora_dot_up_key ), None )
2607+ continue
2608+
2609+ # Normalise "to.q/k/v" → "to_q/k/v" for the diffusers output key.
# The regex only rewrites keys ending exactly in ".to.<q|k|v>.lora.down.weight";
# any other dot-form name passes through unchanged.
2610+ norm_k = re .sub (
2611+ r"\.to\.([qkv])" + re .escape (lora_dot_down_key ) + r"$" ,
2612+ r".to_\1" + lora_dot_down_key ,
2613+ k ,
2614+ )
2615+ norm_base = norm_k [: - len (lora_dot_down_key )]
# NOTE(review): alpha_key is built from the *normalized* name ("…to_q.alpha"),
# but this trainer's raw keys use dots ("…to.q.*"); if the checkpoint stores
# alpha under the dot form, get_alpha_scales' pop raises KeyError — verify
# against an actual Anime-Z checkpoint.
2616+ alpha_key = norm_base + ".alpha"
2617+
2618+ diffusers_down = norm_k .replace (lora_dot_down_key , ".lora_A.weight" )
2619+ diffusers_up = norm_k .replace (lora_dot_down_key , ".lora_B.weight" )
2620+
2621+ down_weight = state_dict .pop (k )
# NOTE(review): pop with no default — raises KeyError if a down weight has
# no matching ".lora.up.weight" partner.
2622+ up_weight = state_dict .pop (k .replace (lora_dot_down_key , lora_dot_up_key ))
2623+ scale_down , scale_up = get_alpha_scales (down_weight , alpha_key )
2624+ converted_state_dict [diffusers_down ] = down_weight * scale_down
2625+ converted_state_dict [diffusers_up ] = up_weight * scale_up
25642626
# Final guard: every source entry must have been consumed by one of the
# branches above; anything left over (unknown key scheme, unconsumed alpha)
# aborts the conversion loudly instead of silently dropping weights.
25652627 if len (state_dict ) > 0 :
25662628 raise ValueError (f"`state_dict` should be empty at this point but has { state_dict .keys ()= } " )
0 commit comments