Commit f73fe03

IDK again
1 parent: 08f57be

File tree: 1 file changed (+12 −15 lines)


MaxText/utils/ckpt_conversion/utils/param_mapping.py (12 additions & 15 deletions)
@@ -680,34 +680,31 @@ def QWEN3_MOE_MAXTEXT_TO_HF_PARAM_MAPPING(config, scan_layers=False):
       if f"params-decoder-layers_{i}-{key}" in mapping:
         del mapping[f"params-decoder-layers_{i}-{key}"]
 
-  # Add MoE-specific mappings.
+  # Add MoE-specific mappings WITHOUT the "-kernel" suffix for experts.
   if scan_layers:
     mapping["params-decoder-layers-moe_block-gate-kernel"] = [f"model.layers.{i}.mlp.gate.weight" for i in range(n_layers)]
-    mapping["params-decoder-layers-moe_block-wi_0-kernel"] = [[f"model.layers.{i}.mlp.experts.{j}.gate_proj.weight" for j in range(num_experts)] for i in range(n_layers)]
-    mapping["params-decoder-layers-moe_block-wi_1-kernel"] = [[f"model.layers.{i}.mlp.experts.{j}.up_proj.weight" for j in range(num_experts)] for i in range(n_layers)]
-    mapping["params-decoder-layers-moe_block-wo-kernel"] = [[f"model.layers.{i}.mlp.experts.{j}.down_proj.weight" for j in range(num_experts)] for i in range(n_layers)]
+    mapping["params-decoder-layers-moe_block-wi_0"] = [[f"model.layers.{i}.mlp.experts.{j}.gate_proj.weight" for j in range(num_experts)] for i in range(n_layers)]
+    mapping["params-decoder-layers-moe_block-wi_1"] = [[f"model.layers.{i}.mlp.experts.{j}.up_proj.weight" for j in range(num_experts)] for i in range(n_layers)]
+    mapping["params-decoder-layers-moe_block-wo"] = [[f"model.layers.{i}.mlp.experts.{j}.down_proj.weight" for j in range(num_experts)] for i in range(n_layers)]
   else:
     for i in range(n_layers):
       mapping[f"params-decoder-layers_{i}-moe_block-gate-kernel"] = f"model.layers.{i}.mlp.gate.weight"
-      mapping[f"params-decoder-layers_{i}-moe_block-wi_0-kernel"] = [f"model.layers.{i}.mlp.experts.{j}.gate_proj.weight" for j in range(num_experts)]
-      mapping[f"params-decoder-layers_{i}-moe_block-wi_1-kernel"] = [f"model.layers.{i}.mlp.experts.{j}.up_proj.weight" for j in range(num_experts)]
-      mapping[f"params-decoder-layers_{i}-moe_block-wo-kernel"] = [f"model.layers.{i}.mlp.experts.{j}.down_proj.weight" for j in range(num_experts)]
+      mapping[f"params-decoder-layers_{i}-moe_block-wi_0"] = [f"model.layers.{i}.mlp.experts.{j}.gate_proj.weight" for j in range(num_experts)]
+      mapping[f"params-decoder-layers_{i}-moe_block-wi_1"] = [f"model.layers.{i}.mlp.experts.{j}.up_proj.weight" for j in range(num_experts)]
+      mapping[f"params-decoder-layers_{i}-moe_block-wo"] = [f"model.layers.{i}.mlp.experts.{j}.down_proj.weight" for j in range(num_experts)]
 
   return mapping
 
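Net effect of the mapping change: gate kernels keep the "-kernel" suffix and map to a single HF tensor, while expert weights drop the suffix and map to a list of per-expert HF tensor names (nested one level deeper when scan_layers=True). A minimal standalone sketch of the resulting keys, using assumed toy values (n_layers=2, num_experts=2, scan_layers=False), not a real Qwen3-MoE config:

    # Sketch only: toy sizes, mirrors the unscanned branch of the diff above.
    n_layers, num_experts = 2, 2
    mapping = {}
    for i in range(n_layers):
      # Router gate: ordinary kernel, single HF tensor.
      mapping[f"params-decoder-layers_{i}-moe_block-gate-kernel"] = f"model.layers.{i}.mlp.gate.weight"
      # Expert weights: no "-kernel" suffix, one HF tensor name per expert.
      mapping[f"params-decoder-layers_{i}-moe_block-wi_0"] = [
          f"model.layers.{i}.mlp.experts.{j}.gate_proj.weight" for j in range(num_experts)
      ]

    print(mapping["params-decoder-layers_0-moe_block-wi_0"])
    # ['model.layers.0.mlp.experts.0.gate_proj.weight',
    #  'model.layers.0.mlp.experts.1.gate_proj.weight']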

 def QWEN3_MOE_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, scan_layers=False, saving_to_hf=False):
   """Creates transformation hooks for Qwen3-MoE, replicating the original script's logic."""
-  # Start with the standard hooks for attention, norms, etc.
   mapping = QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, scan_layers, saving_to_hf)
-  reshape_kernel = mapping["params-decoder-logits_dense-kernel"]  # This is the "Standard Rule"
+  reshape_kernel = mapping["params-decoder-logits_dense-kernel"]
+  simple_transpose = lambda x, _: x.T  # Expert weights need only the individual transpose, in either direction.
 
-  # The "Special Rule" for expert weights that only does the individual transpose.
-  def simple_transpose(input_tensor, target_shape):
-    return input_tensor.T.reshape(target_shape) if not saving_to_hf else input_tensor.reshape(np.flip(np.array(target_shape))).T
-
-  # Define which kernels get which rule.
+  # The gate kernel DOES have a "-kernel" suffix and gets the standard rule.
   gate_kernel = ["moe_block-gate-kernel"]
-  expert_kernels = ["moe_block-wi_0-kernel", "moe_block-wi_1-kernel", "moe_block-wo-kernel"]
+  # The expert weights DO NOT have a "-kernel" suffix and get the special rule.
+  expert_kernels = ["moe_block-wi_0", "moe_block-wi_1", "moe_block-wo"]
 
   if scan_layers:
     mapping[f"params-decoder-layers-{gate_kernel[0]}"] = reshape_kernel
