@@ -2080,6 +2080,74 @@ def _convert_non_diffusers_ltxv_lora_to_diffusers(state_dict, non_diffusers_pref
20802080
20812081
20822082def _convert_non_diffusers_qwen_lora_to_diffusers (state_dict ):
2083+ has_lora_unet = any (k .startswith ("lora_unet_" ) for k in state_dict )
2084+ if has_lora_unet :
2085+ state_dict = {k .removeprefix ("lora_unet_" ): v for k , v in state_dict .items ()}
2086+
2087+ def convert_key (key : str ) -> str :
2088+ prefix = "transformer_blocks"
2089+ if "." in key :
2090+ base , suffix = key .rsplit ("." , 1 )
2091+ else :
2092+ base , suffix = key , ""
2093+
2094+ start = f"{ prefix } _"
2095+ rest = base [len (start ) :]
2096+
2097+ if "." in rest :
2098+ head , tail = rest .split ("." , 1 )
2099+ tail = "." + tail
2100+ else :
2101+ head , tail = rest , ""
2102+
2103+ # Protected n-grams that must keep their internal underscores
2104+ protected = {
2105+ # pairs
2106+ ("to" , "q" ),
2107+ ("to" , "k" ),
2108+ ("to" , "v" ),
2109+ ("to" , "out" ),
2110+ ("add" , "q" ),
2111+ ("add" , "k" ),
2112+ ("add" , "v" ),
2113+ ("txt" , "mlp" ),
2114+ ("img" , "mlp" ),
2115+ ("txt" , "mod" ),
2116+ ("img" , "mod" ),
2117+ # triplets
2118+ ("add" , "q" , "proj" ),
2119+ ("add" , "k" , "proj" ),
2120+ ("add" , "v" , "proj" ),
2121+ ("to" , "add" , "out" ),
2122+ }
2123+
2124+ prot_by_len = {}
2125+ for ng in protected :
2126+ prot_by_len .setdefault (len (ng ), set ()).add (ng )
2127+
2128+ parts = head .split ("_" )
2129+ merged = []
2130+ i = 0
2131+ lengths_desc = sorted (prot_by_len .keys (), reverse = True )
2132+
2133+ while i < len (parts ):
2134+ matched = False
2135+ for L in lengths_desc :
2136+ if i + L <= len (parts ) and tuple (parts [i : i + L ]) in prot_by_len [L ]:
2137+ merged .append ("_" .join (parts [i : i + L ]))
2138+ i += L
2139+ matched = True
2140+ break
2141+ if not matched :
2142+ merged .append (parts [i ])
2143+ i += 1
2144+
2145+ head_converted = "." .join (merged )
2146+ converted_base = f"{ prefix } .{ head_converted } { tail } "
2147+ return converted_base + (("." + suffix ) if suffix else "" )
2148+
2149+ state_dict = {convert_key (k ): v for k , v in state_dict .items ()}
2150+
20832151 converted_state_dict = {}
20842152 all_keys = list (state_dict .keys ())
20852153 down_key = ".lora_down.weight"
0 commit comments