@@ -26,11 +26,9 @@
 import gc
 import os
 import pathlib
-import logging
 
 import numpy as np
 import torch
-import jax
 from safetensors import safe_open
 from tqdm import tqdm
 
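For context, the safe_open import retained above is what reads the HF checkpoint shards. A minimal sketch of that loading step, assuming the checkpoint is a set of *.safetensors files under base_model_path (the helper name load_hf_shards is illustrative, not code from this script):

import pathlib
import torch
from safetensors import safe_open

def load_hf_shards(base_model_path: str) -> dict:
  """Read every *.safetensors shard into one flat dict of CPU torch tensors."""
  tensors = {}
  for shard in sorted(pathlib.Path(base_model_path).glob("*.safetensors")):
    # framework="pt" yields torch tensors, matching the .to(torch.float16) calls below.
    with safe_open(str(shard), framework="pt", device="cpu") as f:
      for name in f.keys():
        tensors[name] = f.get_tensor(name)
  return tensors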
@@ -59,23 +57,27 @@ def hf_to_maxtext_mapping(layer_idx: int, num_experts: int) -> dict:
       "lm_head.weight": "decoder.logits_dense.kernel",
   }
   # Layer-specific mappings for a pure MoE/scanned model
-  mapping.update(
-      {
-          f"model.layers.{layer_idx}.input_layernorm.weight": f"decoder.layers.{layer_idx}.pre_self_attention_layer_norm.scale",
-          f"model.layers.{layer_idx}.post_attention_layernorm.weight": f"decoder.layers.{layer_idx}.post_self_attention_layer_norm.scale",
-          f"model.layers.{layer_idx}.self_attn.q_proj.weight": f"decoder.layers.{layer_idx}.self_attention.query.kernel",
-          f"model.layers.{layer_idx}.self_attn.k_proj.weight": f"decoder.layers.{layer_idx}.self_attention.key.kernel",
-          f"model.layers.{layer_idx}.self_attn.v_proj.weight": f"decoder.layers.{layer_idx}.self_attention.value.kernel",
-          f"model.layers.{layer_idx}.self_attn.o_proj.weight": f"decoder.layers.{layer_idx}.self_attention.out.kernel",
-          f"model.layers.{layer_idx}.self_attn.q_norm.weight": f"decoder.layers.{layer_idx}.self_attention.query_norm.scale",
-          f"model.layers.{layer_idx}.self_attn.k_norm.weight": f"decoder.layers.{layer_idx}.self_attention.key_norm.scale",
-          f"model.layers.{layer_idx}.mlp.gate.weight": f"decoder.layers.{layer_idx}.moe_block.gate.kernel",
-      }
-  )
+  mapping.update({
+      f"model.layers.{layer_idx}.input_layernorm.weight": (
+          f"decoder.layers.{layer_idx}.pre_self_attention_layer_norm.scale"
+      ),
+      f"model.layers.{layer_idx}.post_attention_layernorm.weight": (
+          f"decoder.layers.{layer_idx}.post_self_attention_layer_norm.scale"
+      ),
+      f"model.layers.{layer_idx}.self_attn.q_proj.weight": f"decoder.layers.{layer_idx}.self_attention.query.kernel",
+      f"model.layers.{layer_idx}.self_attn.k_proj.weight": f"decoder.layers.{layer_idx}.self_attention.key.kernel",
+      f"model.layers.{layer_idx}.self_attn.v_proj.weight": f"decoder.layers.{layer_idx}.self_attention.value.kernel",
+      f"model.layers.{layer_idx}.self_attn.o_proj.weight": f"decoder.layers.{layer_idx}.self_attention.out.kernel",
+      f"model.layers.{layer_idx}.self_attn.q_norm.weight": f"decoder.layers.{layer_idx}.self_attention.query_norm.scale",
+      f"model.layers.{layer_idx}.self_attn.k_norm.weight": f"decoder.layers.{layer_idx}.self_attention.key_norm.scale",
+      f"model.layers.{layer_idx}.mlp.gate.weight": f"decoder.layers.{layer_idx}.moe_block.gate.kernel",
+  })
 
   # MoE expert mappings
   for i in range(num_experts):
-    mapping[f"model.layers.{layer_idx}.mlp.experts.{i}.gate_proj.weight"] = f"decoder.layers.{layer_idx}.moe_block.{i}.wi_0"
+    mapping[f"model.layers.{layer_idx}.mlp.experts.{i}.gate_proj.weight"] = (
+        f"decoder.layers.{layer_idx}.moe_block.{i}.wi_0"
+    )
     mapping[f"model.layers.{layer_idx}.mlp.experts.{i}.up_proj.weight"] = f"decoder.layers.{layer_idx}.moe_block.{i}.wi_1"
     mapping[f"model.layers.{layer_idx}.mlp.experts.{i}.down_proj.weight"] = f"decoder.layers.{layer_idx}.moe_block.{i}.wo"
 
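Since hf_to_maxtext_mapping returns a plain dict of HF tensor names to MaxText parameter paths, renaming a loaded state dict reduces to one lookup per tensor. A hypothetical usage sketch (remap_to_maxtext and hf_tensors are illustrative, not part of this PR):

def remap_to_maxtext(hf_tensors: dict, num_layers: int, num_experts: int) -> dict:
  # Merge the per-layer mappings; the layer-independent entries
  # (e.g. lm_head.weight) are simply re-added on each iteration.
  mapping = {}
  for layer_idx in range(num_layers):
    mapping.update(hf_to_maxtext_mapping(layer_idx, num_experts))
  # Rename every HF tensor to its MaxText parameter path.
  return {mapping[name]: tensor for name, tensor in hf_tensors.items()}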
@@ -163,6 +165,7 @@ def convert_hf_to_maxtext(base_model_path: str, model_params: dict) -> dict:
   moe["wo"] = np.zeros((num_experts, num_layers, moe_intermediate_size, hidden_size), dtype=np.float16)
 
   # Loop through layers and populate the stacked arrays
+  # pylint: disable=unsupported-assignment-operation
   for l in tqdm(range(num_layers), desc="Stacking layer weights"):
     ln["pre_self_attention_layer_norm"]["scale"][l, :] = (
         chkpt_vars[f"decoder.layers.{l}.pre_self_attention_layer_norm.scale"].to(torch.float16).numpy()
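The pylint pragma added above silences what appears to be a false positive: pylint cannot infer that entries like moe["wo"] are NumPy arrays that support slice assignment. A toy illustration of the scanned layout the loop builds, with made-up sizes:

import numpy as np

num_experts, num_layers = 2, 3
moe_intermediate_size, hidden_size = 4, 8
# One stacked array per parameter, with experts and layers as leading axes.
wo = np.zeros((num_experts, num_layers, moe_intermediate_size, hidden_size), dtype=np.float16)
for l in range(num_layers):
  for e in range(num_experts):
    # Each per-layer, per-expert tensor fills one (expert, layer) slot.
    wo[e, l, :, :] = np.full((moe_intermediate_size, hidden_size), l + e, dtype=np.float16)
assert wo.shape == (2, 3, 4, 8)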
@@ -268,5 +271,5 @@ def main(args):
   parser.add_argument("--use-ocdbt", type=str2bool, default=True, help="Use OCDBT format for saving.")
   parser.add_argument("--use-zarr3", type=str2bool, default=True, help="Use Zarr3 format for saving.")
 
-  args = parser.parse_args()
-  main(args)
+  parsed_args = parser.parse_args()
+  main(parsed_args)
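str2bool is referenced by the flags above but defined outside this hunk; a common argparse helper of that name (an assumption, not necessarily this script's definition) looks like:

import argparse

def str2bool(v) -> bool:
  """Parse yes/true/1-style strings into booleans for argparse flags."""
  if isinstance(v, bool):
    return v
  if v.lower() in ("yes", "true", "t", "1"):
    return True
  if v.lower() in ("no", "false", "f", "0"):
    return False
  raise argparse.ArgumentTypeError(f"boolean value expected, got {v!r}")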