
Commit 8697207

fixed formatting
1 parent a738a0c commit 8697207

File tree

2 files changed: +22 −27 lines changed


MaxText/convert_qwen3_moe.py

Lines changed: 21 additions & 18 deletions
@@ -26,11 +26,9 @@
 import gc
 import os
 import pathlib
-import logging
 
 import numpy as np
 import torch
-import jax
 from safetensors import safe_open
 from tqdm import tqdm
 
@@ -59,23 +57,27 @@ def hf_to_maxtext_mapping(layer_idx: int, num_experts: int) -> dict:
       "lm_head.weight": "decoder.logits_dense.kernel",
   }
   # Layer-specific mappings for a pure MoE/scanned model
-  mapping.update(
-      {
-          f"model.layers.{layer_idx}.input_layernorm.weight": f"decoder.layers.{layer_idx}.pre_self_attention_layer_norm.scale",
-          f"model.layers.{layer_idx}.post_attention_layernorm.weight": f"decoder.layers.{layer_idx}.post_self_attention_layer_norm.scale",
-          f"model.layers.{layer_idx}.self_attn.q_proj.weight": f"decoder.layers.{layer_idx}.self_attention.query.kernel",
-          f"model.layers.{layer_idx}.self_attn.k_proj.weight": f"decoder.layers.{layer_idx}.self_attention.key.kernel",
-          f"model.layers.{layer_idx}.self_attn.v_proj.weight": f"decoder.layers.{layer_idx}.self_attention.value.kernel",
-          f"model.layers.{layer_idx}.self_attn.o_proj.weight": f"decoder.layers.{layer_idx}.self_attention.out.kernel",
-          f"model.layers.{layer_idx}.self_attn.q_norm.weight": f"decoder.layers.{layer_idx}.self_attention.query_norm.scale",
-          f"model.layers.{layer_idx}.self_attn.k_norm.weight": f"decoder.layers.{layer_idx}.self_attention.key_norm.scale",
-          f"model.layers.{layer_idx}.mlp.gate.weight": f"decoder.layers.{layer_idx}.moe_block.gate.kernel",
-      }
-  )
+  mapping.update({
+      f"model.layers.{layer_idx}.input_layernorm.weight": (
+          f"decoder.layers.{layer_idx}.pre_self_attention_layer_norm.scale"
+      ),
+      f"model.layers.{layer_idx}.post_attention_layernorm.weight": (
+          f"decoder.layers.{layer_idx}.post_self_attention_layer_norm.scale"
+      ),
+      f"model.layers.{layer_idx}.self_attn.q_proj.weight": f"decoder.layers.{layer_idx}.self_attention.query.kernel",
+      f"model.layers.{layer_idx}.self_attn.k_proj.weight": f"decoder.layers.{layer_idx}.self_attention.key.kernel",
+      f"model.layers.{layer_idx}.self_attn.v_proj.weight": f"decoder.layers.{layer_idx}.self_attention.value.kernel",
+      f"model.layers.{layer_idx}.self_attn.o_proj.weight": f"decoder.layers.{layer_idx}.self_attention.out.kernel",
+      f"model.layers.{layer_idx}.self_attn.q_norm.weight": f"decoder.layers.{layer_idx}.self_attention.query_norm.scale",
+      f"model.layers.{layer_idx}.self_attn.k_norm.weight": f"decoder.layers.{layer_idx}.self_attention.key_norm.scale",
+      f"model.layers.{layer_idx}.mlp.gate.weight": f"decoder.layers.{layer_idx}.moe_block.gate.kernel",
+  })
 
   # MoE expert mappings
   for i in range(num_experts):
-    mapping[f"model.layers.{layer_idx}.mlp.experts.{i}.gate_proj.weight"] = f"decoder.layers.{layer_idx}.moe_block.{i}.wi_0"
+    mapping[f"model.layers.{layer_idx}.mlp.experts.{i}.gate_proj.weight"] = (
+        f"decoder.layers.{layer_idx}.moe_block.{i}.wi_0"
+    )
     mapping[f"model.layers.{layer_idx}.mlp.experts.{i}.up_proj.weight"] = f"decoder.layers.{layer_idx}.moe_block.{i}.wi_1"
     mapping[f"model.layers.{layer_idx}.mlp.experts.{i}.down_proj.weight"] = f"decoder.layers.{layer_idx}.moe_block.{i}.wo"
 
@@ -163,6 +165,7 @@ def convert_hf_to_maxtext(base_model_path: str, model_params: dict) -> dict:
   moe["wo"] = np.zeros((num_experts, num_layers, moe_intermediate_size, hidden_size), dtype=np.float16)
 
   # Loop through layers and populate the stacked arrays
+  # pylint: disable=unsupported-assignment-operation
   for l in tqdm(range(num_layers), desc="Stacking layer weights"):
     ln["pre_self_attention_layer_norm"]["scale"][l, :] = (
         chkpt_vars[f"decoder.layers.{l}.pre_self_attention_layer_norm.scale"].to(torch.float16).numpy()
@@ -268,5 +271,5 @@ def main(args):
   parser.add_argument("--use-ocdbt", type=str2bool, default=True, help="Use OCDBT format for saving.")
   parser.add_argument("--use-zarr3", type=str2bool, default=True, help="Use Zarr3 format for saving.")
 
-  args = parser.parse_args()
-  main(args)
+  parsed_args = parser.parse_args()
+  main(parsed_args)
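
For context, here is a minimal sketch of how a name table built from hf_to_maxtext_mapping (defined in this file) could be applied to a HuggingFace safetensors shard. The helper name rename_shard, the shard path, and the layer/expert counts are hypothetical placeholders; the commit's actual conversion loop in convert_hf_to_maxtext stacks weights per layer and differs in detail.

```python
# Sketch only, not the converter's real loading loop. Assumes
# hf_to_maxtext_mapping (defined in convert_qwen3_moe.py) is in scope;
# rename_shard and its arguments are illustrative placeholders.
import torch
from safetensors import safe_open


def rename_shard(shard_path: str, num_layers: int, num_experts: int) -> dict:
  """Return the tensors of one shard, keyed by MaxText parameter names."""
  # Build the full HF -> MaxText name table across all layers.
  mapping = {}
  for layer_idx in range(num_layers):
    mapping.update(hf_to_maxtext_mapping(layer_idx, num_experts))

  renamed = {}
  with safe_open(shard_path, framework="pt") as f:
    for hf_name in f.keys():
      if hf_name in mapping:
        # Cast to float16 to match the converter's stacked numpy buffers.
        renamed[mapping[hf_name]] = f.get_tensor(hf_name).to(torch.float16).numpy()
  return renamed
```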

MaxText/layers/decoders.py

Lines changed: 1 addition & 9 deletions
@@ -29,7 +29,6 @@
 
 from MaxText.common_types import DecoderBlockType, Config, MODEL_MODE_TRAIN, MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE
 from MaxText import max_logging
-from MaxText import max_utils
 from MaxText.inference import page_manager
 from MaxText.layers import linears
 from MaxText.layers import quantizations
@@ -444,14 +443,7 @@ def scan_decoder_layers(self, cfg, decoder_layer, length, metadata_axis_name, me
         length=length,
         metadata_params={nn.PARTITION_NAME: metadata_axis_name},
     )
-    return scan_fn(
-        config=cfg,
-        mesh=mesh,
-        name=metadata_axis_name,
-        quant=self.quant,
-        model_mode=model_mode,
-        **kwargs
-    )
+    return scan_fn(config=cfg, mesh=mesh, name=metadata_axis_name, quant=self.quant, model_mode=model_mode, **kwargs)
 
   def get_pipeline_stage_module(self, decoder_blocks):
     """get pipeline stage module"""
