@@ -764,39 +764,41 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
764
764
return None
765
765
return remapped_name
766
766
767
- possible_scale_names = [".k_scale" , ".v_scale" ]
768
- modelopt_scale_names = [
769
- ".self_attn.k_proj.k_scale" , ".self_attn.v_proj.v_scale"
767
+ # Define scale name mapping patterns in order of precedence
768
+ scale_mapping_patterns = [
769
+ # ModelOpt format: .self_attn.{k,v}_proj.{k,v}_scale ->
770
+ # .self_attn.attn.{k,v}_scale
771
+ (r"\.self_attn\.([kv])_proj\.([kv])_scale$" ,
772
+ r".self_attn.attn.\2_scale" ),
773
+ # QKV proj format: .self_attn.qkv_proj.{k,v}_scale ->
774
+ # .self_attn.attn.{k,v}_scale
775
+ (r"\.self_attn\.qkv_proj\.([kv])_scale$" , r".self_attn.attn.\1_scale" ),
776
+ # Qwen3 MoE format: .self_attn.qkqkv_proj.{k,v}_scale ->
777
+ # .self_attn.attn.{k,v}_scale
778
+ (r"\.self_attn\.qkqkv_proj\.([kv])_scale$" , r".self_attn.attn.\1_scale"
779
+ ),
780
+ # Default format: .{k,v}_scale -> .attn.{k,v}_scale
781
+ (r"\.([kv])_scale$" , r".attn.\1_scale" ),
770
782
]
771
- # Also support qkv_proj scale parameters (from stacked parameter processing)
772
- qkv_proj_scale_names = [
773
- ".self_attn.qkv_proj.k_scale" , ".self_attn.qkv_proj.v_scale"
774
- ]
775
- for scale_name in possible_scale_names :
776
- if name .endswith (scale_name ):
777
- if any (mo_scale_name in name
778
- for mo_scale_name in modelopt_scale_names ):
779
- remapped_name = name .replace (
780
- f".self_attn.{ scale_name [1 ]} _proj{ scale_name } " ,
781
- f".self_attn.attn{ scale_name } " )
782
- elif any (qkv_scale_name in name
783
- for qkv_scale_name in qkv_proj_scale_names ):
784
- # Handle qkv_proj scale parameters
785
- remapped_name = name .replace (
786
- f".self_attn.qkv_proj{ scale_name } " ,
787
- f".self_attn.attn{ scale_name } " )
788
- else :
789
- remapped_name = name .replace (scale_name , f".attn{ scale_name } " )
790
- if remapped_name not in params_dict :
791
- logger .warning_once (
792
- "Found %s in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). %s is not loaded." , # noqa: E501
793
- scale_name ,
794
- name ,
795
- remapped_name ,
796
- scale_name ,
797
- )
798
- return None
799
- return remapped_name
783
+
784
+ # Check if name ends with k_scale or v_scale
785
+ if name .endswith ((".k_scale" , ".v_scale" )):
786
+ import regex as re
787
+
788
+ for pattern , replacement in scale_mapping_patterns :
789
+ if re .search (pattern , name ):
790
+ remapped_name = re .sub (pattern , replacement , name )
791
+ if remapped_name not in params_dict :
792
+ scale_type = name .split ("." )[- 1 ]
793
+ logger .warning_once (
794
+ "Found %s in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). %s is not loaded." , # noqa: E501
795
+ scale_type ,
796
+ name ,
797
+ remapped_name ,
798
+ scale_type ,
799
+ )
800
+ return None
801
+ return remapped_name
800
802
801
803
# If there were no matches, return the untouched param name
802
804
return name
0 commit comments