
Commit 29e1e72

Add support for Qwen3-MoE Model
1 parent 3078cd3 commit 29e1e72

10 files changed (+662, -113 lines)

MaxText/common_types.py

Lines changed: 1 addition & 0 deletions

@@ -86,6 +86,7 @@ class DecoderBlockType(enum.Enum):
   GEMMA2 = "gemma2"
   GEMMA3 = "gemma3"
   QWEN3 = "qwen3"
+  QWEN3_MOE = "qwen3_moe"
   GPT3 = "gpt3"
   SIMPLE = "simple"
   SIMPLE_MLP = "simple_mlp"

MaxText/configs/base.yml

Lines changed: 1 addition & 0 deletions

@@ -168,6 +168,7 @@ use_random_routing: False # whether to use random routing for debug/test purpose
 tile_batch_seq: 512
 tile_activation_dim: 1024
 tile_weight_dim: 1024
+norm_topk_prob: False # Whether to normalize the top-k routing probabilities so they sum to 1.

 # How the expert axis is used to shard attention weights and activations
 # "fsdp" (ep acts as fsdp parallelism)
Lines changed: 40 additions & 0 deletions

@@ -0,0 +1,40 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Model config for Qwen3-235B-A22B

# Core Architectural Parameters
base_emb_dim: 4096
base_num_query_heads: 64
base_num_kv_heads: 4
base_num_decoder_layers: 94
head_dim: 128
mlp_activations: ["silu", "linear"]
vocab_size: 151936
normalization_layer_epsilon: 1.0e-6
use_qk_norm: True

# MoE Specific Parameters
decoder_block: "qwen3_moe"
num_experts: 128
num_experts_per_tok: 8
base_moe_mlp_dim: 1536
load_balance_loss_weight: 0.001
norm_topk_prob: true

# RoPE Settings
rope_max_timescale: 5000000

# General Model Settings
enable_dropout: False
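
As a rough sanity check on the numbers above (a back-of-envelope sketch, not part of the commit; it ignores the router gate, norm scales, and biases), the expert MLPs alone come to roughly 227B parameters, of which about 14B are active per token under top-8 routing. Adding attention and embedding weights lands near the 235B-total / ~22B-active figures implied by the model name:

# Back-of-envelope parameter count from the config values above
# (ignores router gates, norm scales, and biases).
emb, moe_mlp, experts, topk, layers = 4096, 1536, 128, 8, 94
heads, kv_heads, head_dim, vocab = 64, 4, 128, 151936

per_expert = 3 * emb * moe_mlp                        # wi_0 + wi_1 + wo per expert
expert_params = per_expert * experts * layers         # ~227B total expert weights
active_expert_params = per_expert * topk * layers     # ~14.2B active per token

attn_params = (emb * heads * head_dim * 2             # query + out projections
               + emb * kv_heads * head_dim * 2) * layers  # key + value projections
embed_params = vocab * emb * 2                        # token embedding + lm_head

print(f"total  ~ {(expert_params + attn_params + embed_params) / 1e9:.0f}B")        # ~235B
print(f"active ~ {(active_expert_params + attn_params + embed_params) / 1e9:.0f}B")  # ~22B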

MaxText/convert_qwen3_moe.py

Lines changed: 275 additions & 0 deletions

@@ -0,0 +1,275 @@
"""
Copyright 2025 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

r"""Convert weights from a Qwen3-MoE style model to a MaxText one.

This script follows a two-stage conversion process (map, then transform) to
generate a MaxText checkpoint compatible with scanned model layers.

Example cmd:

python3 -m MaxText.convert_qwen3_moe --base_model_path <path/to/hf/ckpt> \
    --maxtext_model_path gs://<gcs_bucket>/<path/to/save/ckpt> --model_size qwen3-235b-a22b
"""

import argparse
import gc
import os
import pathlib

import numpy as np
import torch
from safetensors import safe_open
from tqdm import tqdm

from MaxText import llama_or_mistral_ckpt, max_logging
from MaxText.inference_utils import str2bool

# Static model parameters dictionary
MODEL_PARAMS_DICT = {
    "qwen3-235b-a22b": {
        "num_hidden_layers": 94,
        "num_attention_heads": 64,
        "num_key_value_heads": 4,
        "hidden_size": 4096,
        "head_dim": 128,
        "num_experts": 128,
        "moe_intermediate_size": 1536,
    }
}


def hf_to_maxtext_mapping(layer_idx: int, num_experts: int) -> dict:
  """Creates a mapping from HF weight names to MaxText weight names."""
  mapping = {
      "model.embed_tokens.weight": "token_embedder.embedding",
      "model.norm.weight": "decoder.decoder_norm.scale",
      "lm_head.weight": "decoder.logits_dense.kernel",
  }
  # Layer-specific mappings for a pure MoE/scanned model
  mapping.update({
      f"model.layers.{layer_idx}.input_layernorm.weight": (
          f"decoder.layers.{layer_idx}.pre_self_attention_layer_norm.scale"
      ),
      f"model.layers.{layer_idx}.post_attention_layernorm.weight": (
          f"decoder.layers.{layer_idx}.post_self_attention_layer_norm.scale"
      ),
      f"model.layers.{layer_idx}.self_attn.q_proj.weight": f"decoder.layers.{layer_idx}.self_attention.query.kernel",
      f"model.layers.{layer_idx}.self_attn.k_proj.weight": f"decoder.layers.{layer_idx}.self_attention.key.kernel",
      f"model.layers.{layer_idx}.self_attn.v_proj.weight": f"decoder.layers.{layer_idx}.self_attention.value.kernel",
      f"model.layers.{layer_idx}.self_attn.o_proj.weight": f"decoder.layers.{layer_idx}.self_attention.out.kernel",
      f"model.layers.{layer_idx}.self_attn.q_norm.weight": f"decoder.layers.{layer_idx}.self_attention.query_norm.scale",
      f"model.layers.{layer_idx}.self_attn.k_norm.weight": f"decoder.layers.{layer_idx}.self_attention.key_norm.scale",
      f"model.layers.{layer_idx}.mlp.gate.weight": f"decoder.layers.{layer_idx}.moe_block.gate.kernel",
  })

  # MoE expert mappings
  for i in range(num_experts):
    mapping[f"model.layers.{layer_idx}.mlp.experts.{i}.gate_proj.weight"] = (
        f"decoder.layers.{layer_idx}.moe_block.{i}.wi_0"
    )
    mapping[f"model.layers.{layer_idx}.mlp.experts.{i}.up_proj.weight"] = f"decoder.layers.{layer_idx}.moe_block.{i}.wi_1"
    mapping[f"model.layers.{layer_idx}.mlp.experts.{i}.down_proj.weight"] = f"decoder.layers.{layer_idx}.moe_block.{i}.wo"

  return mapping


def convert_hf_to_maxtext(base_model_path: str, model_params: dict) -> dict:
  """Converts a Hugging Face Qwen3-MoE checkpoint to a MaxText compatible format."""
  num_layers = model_params["num_hidden_layers"]
  num_experts = model_params["num_experts"]
  hidden_size = model_params["hidden_size"]
  num_heads = model_params["num_attention_heads"]
  num_kv_heads = model_params["num_key_value_heads"]
  head_dim = model_params["head_dim"]
  moe_intermediate_size = model_params["moe_intermediate_size"]

  # Part 1: Load all weights from safetensors into a flat dictionary with MaxText names
  ckpt_paths = sorted(pathlib.Path(base_model_path).glob("*.safetensors"))
  chkpt_vars = {}
  for i, ckpt_path in enumerate(ckpt_paths):
    max_logging.log(f"Loading checkpoint {i+1} of {len(ckpt_paths)}...")
    with safe_open(ckpt_path, framework="pt", device="cpu") as f:
      for key in f.keys():
        if "layers" not in key and "embed_tokens" not in key and "norm" not in key and "lm_head" not in key:
          continue

        layer_idx_str = key.split(".")[2] if "layers" in key else "0"
        layer_idx = int(layer_idx_str) if layer_idx_str.isdigit() else 0

        maxtext_key = hf_to_maxtext_mapping(layer_idx, num_experts).get(key)
        if maxtext_key:
          chkpt_vars[maxtext_key] = f.get_tensor(key)

  # Part 2: Initialize, populate, and transform the weights for MaxText
  maxtext_weights = {
      "decoder": {
          "layers": {
              "pre_self_attention_layer_norm": {"scale": None},
              "post_self_attention_layer_norm": {"scale": None},
              "self_attention": {
                  "query": {"kernel": None},
                  "key": {"kernel": None},
                  "value": {"kernel": None},
                  "out": {"kernel": None},
                  "query_norm": {"scale": None},
                  "key_norm": {"scale": None},
              },
              "moe_block": {
                  "gate": {"kernel": None},
                  "wi_0": None,
                  "wi_1": None,
                  "wo": None,
              },
          },
          "decoder_norm": {"scale": None},
          "logits_dense": {"kernel": None},
      },
      "token_embedder": {"embedding": None},
  }

  max_logging.log("Populating non-layer weights...")
  maxtext_weights["token_embedder"]["embedding"] = chkpt_vars["token_embedder.embedding"].to(torch.float16).numpy()
  maxtext_weights["decoder"]["decoder_norm"]["scale"] = chkpt_vars["decoder.decoder_norm.scale"].to(torch.float16).numpy()
  maxtext_weights["decoder"]["logits_dense"]["kernel"] = (
      chkpt_vars["decoder.logits_dense.kernel"].to(torch.float16).numpy().transpose()
  )

  max_logging.log("Allocating and stacking layer weights...")
  ln = maxtext_weights["decoder"]["layers"]
  s_attn = ln["self_attention"]
  moe = ln["moe_block"]

  # Pre-allocate stacked arrays with the 'layer' dimension first
  ln["pre_self_attention_layer_norm"]["scale"] = np.zeros((num_layers, hidden_size), dtype=np.float16)
  ln["post_self_attention_layer_norm"]["scale"] = np.zeros((num_layers, hidden_size), dtype=np.float16)
  s_attn["query"]["kernel"] = np.zeros((num_layers, hidden_size, num_heads, head_dim), dtype=np.float16)
  s_attn["key"]["kernel"] = np.zeros((num_layers, hidden_size, num_kv_heads, head_dim), dtype=np.float16)
  s_attn["value"]["kernel"] = np.zeros((num_layers, hidden_size, num_kv_heads, head_dim), dtype=np.float16)
  s_attn["out"]["kernel"] = np.zeros((num_layers, num_heads, head_dim, hidden_size), dtype=np.float16)
  s_attn["query_norm"]["scale"] = np.zeros((num_layers, head_dim), dtype=np.float16)
  s_attn["key_norm"]["scale"] = np.zeros((num_layers, head_dim), dtype=np.float16)
  moe["gate"]["kernel"] = np.zeros((num_layers, hidden_size, num_experts), dtype=np.float16)
  moe["wi_0"] = np.zeros((num_experts, num_layers, hidden_size, moe_intermediate_size), dtype=np.float16)
  moe["wi_1"] = np.zeros((num_experts, num_layers, hidden_size, moe_intermediate_size), dtype=np.float16)
  moe["wo"] = np.zeros((num_experts, num_layers, moe_intermediate_size, hidden_size), dtype=np.float16)

  # Loop through layers and populate the stacked arrays
  # pylint: disable=unsupported-assignment-operation
  for l in tqdm(range(num_layers), desc="Stacking layer weights"):
    ln["pre_self_attention_layer_norm"]["scale"][l, :] = (
        chkpt_vars[f"decoder.layers.{l}.pre_self_attention_layer_norm.scale"].to(torch.float16).numpy()
    )
    ln["post_self_attention_layer_norm"]["scale"][l, :] = (
        chkpt_vars[f"decoder.layers.{l}.post_self_attention_layer_norm.scale"].to(torch.float16).numpy()
    )

    s_attn["query"]["kernel"][l, ...] = (
        chkpt_vars[f"decoder.layers.{l}.self_attention.query.kernel"]
        .to(torch.float16)
        .numpy()
        .transpose()
        .reshape(hidden_size, num_heads, head_dim)
    )
    s_attn["key"]["kernel"][l, ...] = (
        chkpt_vars[f"decoder.layers.{l}.self_attention.key.kernel"]
        .to(torch.float16)
        .numpy()
        .transpose()
        .reshape(hidden_size, num_kv_heads, head_dim)
    )
    s_attn["value"]["kernel"][l, ...] = (
        chkpt_vars[f"decoder.layers.{l}.self_attention.value.kernel"]
        .to(torch.float16)
        .numpy()
        .transpose()
        .reshape(hidden_size, num_kv_heads, head_dim)
    )
    s_attn["out"]["kernel"][l, ...] = (
        chkpt_vars[f"decoder.layers.{l}.self_attention.out.kernel"]
        .to(torch.float16)
        .numpy()
        .transpose()
        .reshape(num_heads, head_dim, hidden_size)
    )

    s_attn["query_norm"]["scale"][l, ...] = (
        chkpt_vars[f"decoder.layers.{l}.self_attention.query_norm.scale"].to(torch.float16).numpy()
    )
    s_attn["key_norm"]["scale"][l, ...] = (
        chkpt_vars[f"decoder.layers.{l}.self_attention.key_norm.scale"].to(torch.float16).numpy()
    )

    moe["gate"]["kernel"][l, ...] = (
        chkpt_vars[f"decoder.layers.{l}.moe_block.gate.kernel"].to(torch.float16).numpy().transpose()
    )
    for i in range(num_experts):
      moe["wi_0"][i, l, ...] = chkpt_vars[f"decoder.layers.{l}.moe_block.{i}.wi_0"].to(torch.float16).numpy().transpose()
      moe["wi_1"][i, l, ...] = chkpt_vars[f"decoder.layers.{l}.moe_block.{i}.wi_1"].to(torch.float16).numpy().transpose()
      moe["wo"][i, l, ...] = chkpt_vars[f"decoder.layers.{l}.moe_block.{i}.wo"].to(torch.float16).numpy().transpose()

  # Final transformations for scanned weights (swap layer and feature axes)
  max_logging.log("Transposing layer weights for MaxText scanned format...")

  ln["pre_self_attention_layer_norm"]["scale"] = np.transpose(ln["pre_self_attention_layer_norm"]["scale"], axes=(1, 0))
  ln["post_self_attention_layer_norm"]["scale"] = np.transpose(ln["post_self_attention_layer_norm"]["scale"], axes=(1, 0))
  s_attn["query_norm"]["scale"] = np.transpose(s_attn["query_norm"]["scale"], axes=(1, 0))
  s_attn["key_norm"]["scale"] = np.transpose(s_attn["key_norm"]["scale"], axes=(1, 0))

  s_attn["query"]["kernel"] = np.transpose(s_attn["query"]["kernel"], axes=(1, 0, 2, 3))
  s_attn["key"]["kernel"] = np.transpose(s_attn["key"]["kernel"], axes=(1, 0, 2, 3))
  s_attn["value"]["kernel"] = np.transpose(s_attn["value"]["kernel"], axes=(1, 0, 2, 3))
  s_attn["out"]["kernel"] = np.transpose(s_attn["out"]["kernel"], axes=(1, 0, 2, 3))

  moe["gate"]["kernel"] = np.transpose(moe["gate"]["kernel"], axes=(1, 0, 2))

  gc.collect()
  return maxtext_weights


def main(args):
  """Main function to run the conversion."""
  # Set up JAX simulated environment
  os.environ["JAX_PLATFORMS"] = "cpu"
  os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={args.simulated_cpu_devices_count}"

  if args.model_size not in MODEL_PARAMS_DICT:
    raise ValueError(f"Model size '{args.model_size}' not found in MODEL_PARAMS_DICT.")

  model_params = MODEL_PARAMS_DICT[args.model_size]
  max_logging.log(f"Starting conversion for Qwen3-MoE model size: {args.model_size}")
  jax_weights = convert_hf_to_maxtext(args.base_model_path, model_params)
  max_logging.log(f"Conversion complete. Saving MaxText checkpoint to {args.maxtext_model_path}")
  llama_or_mistral_ckpt.save_weights_to_checkpoint(
      args.maxtext_model_path, jax_weights, args.simulated_cpu_devices_count, args.use_ocdbt, args.use_zarr3
  )
  max_logging.log("Checkpoint saved successfully.")


if __name__ == "__main__":
  parser = argparse.ArgumentParser(description="Convert Qwen3-MoE HF weights to MaxText.")
  parser.add_argument("--base_model_path", type=str, required=True, help="Path to the HF Qwen3-MoE checkpoint files.")
  parser.add_argument(
      "--maxtext_model_path", type=str, required=True, help="Path to save the MaxText checkpoint (local or GCS)."
  )
  parser.add_argument(
      "--model_size", type=str, required=True, choices=MODEL_PARAMS_DICT.keys(), help="The model size to convert."
  )
  parser.add_argument(
      "--simulated_cpu_devices_count", type=int, default=16, help="Number of simulated CPU devices for saving."
  )
  parser.add_argument("--use-ocdbt", type=str2bool, default=True, help="Use OCDBT format for saving.")
  parser.add_argument("--use-zarr3", type=str2bool, default=True, help="Use Zarr3 format for saving.")

  parsed_args = parser.parse_args()
  main(parsed_args)
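
For intuition, the standalone sketch below (toy dimensions, not part of the commit) shows the stack-then-swap-axes pattern the converter applies to, for example, the attention query kernel: per-layer Hugging Face weights are transposed and reshaped with the layer axis first, and the final transpose moves the layer axis into position 1, the layout the converter produces for MaxText's scanned layers.

import numpy as np

# Toy dimensions (not the real model sizes) to illustrate the converter's
# stack-then-transpose pattern for the attention query kernel.
num_layers, hidden_size, num_heads, head_dim = 3, 8, 2, 4

# Stage 1: each per-layer HF kernel of shape (num_heads * head_dim, hidden_size)
# is transposed, reshaped to (hidden_size, num_heads, head_dim), and stacked
# into an array whose leading axis is the layer index.
stacked = np.zeros((num_layers, hidden_size, num_heads, head_dim), dtype=np.float16)
for l in range(num_layers):
  hf_kernel = np.random.randn(num_heads * head_dim, hidden_size).astype(np.float16)
  stacked[l, ...] = hf_kernel.transpose().reshape(hidden_size, num_heads, head_dim)

# Stage 2: move the layer axis into position 1, matching the final transpose
# the converter performs before saving.
scanned = np.transpose(stacked, axes=(1, 0, 2, 3))
print(scanned.shape)  # (hidden_size, num_layers, num_heads, head_dim) -> (8, 3, 2, 4)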

MaxText/layers/decoders.py

Lines changed: 5 additions & 12 deletions

@@ -29,7 +29,6 @@

 from MaxText.common_types import DecoderBlockType, Config, MODEL_MODE_TRAIN, MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE
 from MaxText import max_logging
-from MaxText import max_utils
 from MaxText.inference import page_manager
 from MaxText.layers import linears
 from MaxText.layers import quantizations

@@ -359,6 +358,8 @@ def get_decoder_layers(self):
         return [gpt3.Gpt3DecoderLayer]
       case DecoderBlockType.QWEN3:
         return [qwen3.Qwen3DecoderLayer]
+      case DecoderBlockType.QWEN3_MOE:
+        return [qwen3.Qwen3MoeDecoderLayer]
       case DecoderBlockType.SIMPLE:
         return [simple_layer.SimpleDecoderLayer]
       case DecoderBlockType.SIMPLE_MLP:

@@ -380,9 +381,7 @@ def move_to_device(variables):

     def map_fn(path, value):
       max_logging.log(f"models.py: Moving parameter {path} to device")
-      return jax.device_put(
-          value, max_utils.device_space()
-      )
+      return jax.device_put(value, jax.memory.Space.Device)

     return jax.tree_util.tree_map_with_path(map_fn, variables)

@@ -411,6 +410,7 @@ def get_norm_layer(self, num_features: int):
         DecoderBlockType.GEMMA2,
         DecoderBlockType.GEMMA3,
         DecoderBlockType.QWEN3,
+        DecoderBlockType.QWEN3_MOE,
         DecoderBlockType.SIMPLE,
         DecoderBlockType.SIMPLE_MLP,
         DecoderBlockType.LLAMA4,

@@ -443,14 +443,7 @@ def scan_decoder_layers(self, cfg, decoder_layer, length, metadata_axis_name, me
         length=length,
         metadata_params={nn.PARTITION_NAME: metadata_axis_name},
     )
-    return scan_fn(
-        config=cfg,
-        mesh=mesh,
-        name=metadata_axis_name,
-        quant=self.quant,
-        model_mode=model_mode,
-        **kwargs
-    )
+    return scan_fn(config=cfg, mesh=mesh, name=metadata_axis_name, quant=self.quant, model_mode=model_mode, **kwargs)

   def get_pipeline_stage_module(self, decoder_blocks):
     """get pipeline stage module"""

MaxText/layers/moe.py

Lines changed: 5 additions & 0 deletions

@@ -341,6 +341,11 @@ def get_topk(self, gate_logits, pre_bias_logits):
       top_k_weights = self.deepseek_scale_weights(top_k_weights)
     elif self.config.decoder_block != ctypes.DecoderBlockType.LLAMA4:
       top_k_weights = jax.nn.softmax(top_k_weights.astype(jnp.float32), axis=-1).astype(self.dtype)
+
+    # This is the Qwen3-specific normalization of router weights.
+    if self.config.norm_topk_prob:
+      top_k_weights /= top_k_weights.sum(axis=-1, keepdims=True)
+
     return top_k_weights, top_k_indices

   def deepseek_scale_weights(self, weights):
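
For intuition, the sketch below (toy numbers, not part of the commit) shows what a norm_topk_prob-style renormalization accomplishes in the common formulation where the router softmax runs over all experts before top-k selection, as in the reference Hugging Face Qwen-MoE routers: the kept top-k weights no longer sum to 1, so they are rescaled before weighting the expert outputs.

import jax
import jax.numpy as jnp

# Toy router logits for one token over 8 experts; top-2 routing.
logits = jnp.array([2.0, 1.0, 0.5, 0.0, -0.5, -1.0, -1.5, -2.0])
probs = jax.nn.softmax(logits)                      # sums to 1 over all experts
top_k_weights, top_k_indices = jax.lax.top_k(probs, k=2)
print(top_k_weights.sum())                          # < 1.0 after truncation to top-k

# norm_topk_prob-style renormalization: rescale the kept weights to sum to 1.
top_k_weights = top_k_weights / top_k_weights.sum(axis=-1, keepdims=True)
print(top_k_weights.sum())                          # 1.0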
