@@ -680,34 +680,31 @@ def QWEN3_MOE_MAXTEXT_TO_HF_PARAM_MAPPING(config, scan_layers=False):
       if f"params-decoder-layers_{i}-{key}" in mapping:
         del mapping[f"params-decoder-layers_{i}-{key}"]
 
-  # Add MoE-specific mappings.
+  # Add MoE-specific mappings WITHOUT the "-kernel" suffix for experts.
   if scan_layers:
     mapping["params-decoder-layers-moe_block-gate-kernel"] = [f"model.layers.{i}.mlp.gate.weight" for i in range(n_layers)]
-    mapping["params-decoder-layers-moe_block-wi_0-kernel"] = [[f"model.layers.{i}.mlp.experts.{j}.gate_proj.weight" for j in range(num_experts)] for i in range(n_layers)]
-    mapping["params-decoder-layers-moe_block-wi_1-kernel"] = [[f"model.layers.{i}.mlp.experts.{j}.up_proj.weight" for j in range(num_experts)] for i in range(n_layers)]
-    mapping["params-decoder-layers-moe_block-wo-kernel"] = [[f"model.layers.{i}.mlp.experts.{j}.down_proj.weight" for j in range(num_experts)] for i in range(n_layers)]
+    mapping["params-decoder-layers-moe_block-wi_0"] = [[f"model.layers.{i}.mlp.experts.{j}.gate_proj.weight" for j in range(num_experts)] for i in range(n_layers)]
+    mapping["params-decoder-layers-moe_block-wi_1"] = [[f"model.layers.{i}.mlp.experts.{j}.up_proj.weight" for j in range(num_experts)] for i in range(n_layers)]
+    mapping["params-decoder-layers-moe_block-wo"] = [[f"model.layers.{i}.mlp.experts.{j}.down_proj.weight" for j in range(num_experts)] for i in range(n_layers)]
   else:
     for i in range(n_layers):
       mapping[f"params-decoder-layers_{i}-moe_block-gate-kernel"] = f"model.layers.{i}.mlp.gate.weight"
-      mapping[f"params-decoder-layers_{i}-moe_block-wi_0-kernel"] = [f"model.layers.{i}.mlp.experts.{j}.gate_proj.weight" for j in range(num_experts)]
-      mapping[f"params-decoder-layers_{i}-moe_block-wi_1-kernel"] = [f"model.layers.{i}.mlp.experts.{j}.up_proj.weight" for j in range(num_experts)]
-      mapping[f"params-decoder-layers_{i}-moe_block-wo-kernel"] = [f"model.layers.{i}.mlp.experts.{j}.down_proj.weight" for j in range(num_experts)]
+      mapping[f"params-decoder-layers_{i}-moe_block-wi_0"] = [f"model.layers.{i}.mlp.experts.{j}.gate_proj.weight" for j in range(num_experts)]
+      mapping[f"params-decoder-layers_{i}-moe_block-wi_1"] = [f"model.layers.{i}.mlp.experts.{j}.up_proj.weight" for j in range(num_experts)]
+      mapping[f"params-decoder-layers_{i}-moe_block-wo"] = [f"model.layers.{i}.mlp.experts.{j}.down_proj.weight" for j in range(num_experts)]
 
   return mapping
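For illustration only (this is not part of the patch): a minimal, self-contained sketch of what the un-scanned branch above produces, assuming hypothetical sizes `n_layers=2` and `num_experts=2`. The key and weight names come from the diff; note how the gate keeps its `-kernel` suffix and maps one-to-one to a single HF tensor, while each expert key drops the suffix and fans out to a list of per-expert HF tensors.

```python
# Hypothetical sizes, chosen only for this example.
n_layers, num_experts = 2, 2

mapping = {}
for i in range(n_layers):
    # Gate: keeps "-kernel" and maps one-to-one to a single HF weight.
    mapping[f"params-decoder-layers_{i}-moe_block-gate-kernel"] = f"model.layers.{i}.mlp.gate.weight"
    # Experts: no "-kernel" suffix; one MaxText key maps to a list of HF weights.
    mapping[f"params-decoder-layers_{i}-moe_block-wi_0"] = [
        f"model.layers.{i}.mlp.experts.{j}.gate_proj.weight" for j in range(num_experts)
    ]

print(mapping["params-decoder-layers_0-moe_block-wi_0"])
# ['model.layers.0.mlp.experts.0.gate_proj.weight',
#  'model.layers.0.mlp.experts.1.gate_proj.weight']
```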
 
 def QWEN3_MOE_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, scan_layers=False, saving_to_hf=False):
   """Creates transformation hooks for Qwen3-MoE, replicating the original script's logic."""
-  # Start with the standard hooks for attention, norms, etc.
   mapping = QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, scan_layers, saving_to_hf)
-  reshape_kernel = mapping["params-decoder-logits_dense-kernel"]  # This is the "Standard Rule"
+  reshape_kernel = mapping["params-decoder-logits_dense-kernel"]
+  simple_transpose = lambda x, _: x.T  # a transpose is its own inverse, so one rule covers both directions
 
-  # The "Special Rule" for expert weights that only does the individual transpose.
-  def simple_transpose(input_tensor, target_shape):
-    return input_tensor.T.reshape(target_shape) if not saving_to_hf else input_tensor.reshape(np.flip(np.array(target_shape))).T
-
-  # Define which kernels get which rule.
+  # The gate kernel DOES have a "-kernel" suffix and gets the standard rule.
   gate_kernel = ["moe_block-gate-kernel"]
-  expert_kernels = ["moe_block-wi_0-kernel", "moe_block-wi_1-kernel", "moe_block-wo-kernel"]
+  # The expert weights DO NOT have a "-kernel" suffix and get the special rule.
+  expert_kernels = ["moe_block-wi_0", "moe_block-wi_1", "moe_block-wo"]
 
   if scan_layers:
     mapping[f"params-decoder-layers-{gate_kernel[0]}"] = reshape_kernel
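Also for illustration (not part of the patch): why a bare transpose suffices for the expert weights. Assuming a hypothetical layout where each HF expert matrix is `(out_features, in_features)` and the corresponding MaxText slice is its transpose, the hook's special rule needs no reshape, and because a transpose is its own inverse, the same lambda serves both the loading and saving directions.

```python
import numpy as np

# The "special rule" from the hook above: a per-matrix transpose only.
simple_transpose = lambda x, _: x.T

hf_expert = np.arange(12).reshape(4, 3)        # hypothetical HF layout: (out, in)
mt_expert = simple_transpose(hf_expert, None)  # MaxText layout: (in, out)
assert mt_expert.shape == (3, 4)

# Transpose is self-inverse, so converting back recovers the HF tensor exactly,
# which is why no branch on saving_to_hf is needed.
assert np.array_equal(simple_transpose(mt_expert, None), hf_expert)
```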