|
| 1 | +""" |
| 2 | +Copyright 2025 Google LLC |
| 3 | +Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +you may not use this file except in compliance with the License. |
| 5 | +You may obtain a copy of the License at |
| 6 | + https://www.apache.org/licenses/LICENSE-2.0 |
| 7 | +Unless required by applicable law or agreed to in writing, software |
| 8 | +distributed under the License is distributed on an "AS IS" BASIS, |
| 9 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 10 | +See the License for the specific language governing permissions and |
| 11 | +limitations under the License. |
| 12 | +""" |
| 13 | + |
| 14 | +r"""Convert weights from a Qwen3-MoE style model to a MaxText one. |
| 15 | +
|
| 16 | +This script rigorously follows the two-stage conversion process (map-then-transform) |
| 17 | +required for generating a MaxText checkpoint compatible with scanned model layers. |
| 18 | +
|
| 19 | +Example cmd: |
| 20 | +
|
| 21 | +python3 -m MaxText.convert_qwen3_moe_ckpt --base_model_path <path/to/hf/ckpt> \ |
| 22 | + --maxtext_model_path gs://<gcs_bucket>/<path/to/save/ckpt> --model_size qwen3-235b-a22b |
| 23 | +""" |
| 24 | + |
| 25 | +import argparse |
| 26 | +import gc |
| 27 | +import os |
| 28 | +import pathlib |
| 29 | + |
| 30 | +import numpy as np |
| 31 | +import torch |
| 32 | +from safetensors import safe_open |
| 33 | +from tqdm import tqdm |
| 34 | + |
| 35 | +from MaxText import llama_or_mistral_ckpt, max_logging |
| 36 | +from MaxText.inference_utils import str2bool |
| 37 | + |
| 38 | +# Static model parameters dictionary |
| 39 | +MODEL_PARAMS_DICT = { |
| 40 | + "qwen3-235b-a22b": { |
| 41 | + "num_hidden_layers": 94, |
| 42 | + "num_attention_heads": 64, |
| 43 | + "num_key_value_heads": 4, |
| 44 | + "hidden_size": 4096, |
| 45 | + "head_dim": 128, |
| 46 | + "num_experts": 128, |
| 47 | + "moe_intermediate_size": 1536, |
| 48 | + } |
| 49 | +} |
| 50 | + |
| 51 | + |
| 52 | +def hf_to_maxtext_mapping(layer_idx: int, num_experts: int) -> dict: |
| 53 | + """Creates a mapping from HF weight names to MaxText weight names.""" |
| 54 | + mapping = { |
| 55 | + "model.embed_tokens.weight": "token_embedder.embedding", |
| 56 | + "model.norm.weight": "decoder.decoder_norm.scale", |
| 57 | + "lm_head.weight": "decoder.logits_dense.kernel", |
| 58 | + } |
| 59 | + # Layer-specific mappings for a pure MoE/scanned model |
| 60 | + mapping.update({ |
| 61 | + f"model.layers.{layer_idx}.input_layernorm.weight": ( |
| 62 | + f"decoder.layers.{layer_idx}.pre_self_attention_layer_norm.scale" |
| 63 | + ), |
| 64 | + f"model.layers.{layer_idx}.post_attention_layernorm.weight": ( |
| 65 | + f"decoder.layers.{layer_idx}.post_self_attention_layer_norm.scale" |
| 66 | + ), |
| 67 | + f"model.layers.{layer_idx}.self_attn.q_proj.weight": f"decoder.layers.{layer_idx}.self_attention.query.kernel", |
| 68 | + f"model.layers.{layer_idx}.self_attn.k_proj.weight": f"decoder.layers.{layer_idx}.self_attention.key.kernel", |
| 69 | + f"model.layers.{layer_idx}.self_attn.v_proj.weight": f"decoder.layers.{layer_idx}.self_attention.value.kernel", |
| 70 | + f"model.layers.{layer_idx}.self_attn.o_proj.weight": f"decoder.layers.{layer_idx}.self_attention.out.kernel", |
| 71 | + f"model.layers.{layer_idx}.self_attn.q_norm.weight": f"decoder.layers.{layer_idx}.self_attention.query_norm.scale", |
| 72 | + f"model.layers.{layer_idx}.self_attn.k_norm.weight": f"decoder.layers.{layer_idx}.self_attention.key_norm.scale", |
| 73 | + f"model.layers.{layer_idx}.mlp.gate.weight": f"decoder.layers.{layer_idx}.moe_block.gate.kernel", |
| 74 | + }) |
| 75 | + |
| 76 | + # MoE expert mappings |
| 77 | + for i in range(num_experts): |
| 78 | + mapping[f"model.layers.{layer_idx}.mlp.experts.{i}.gate_proj.weight"] = ( |
| 79 | + f"decoder.layers.{layer_idx}.moe_block.{i}.wi_0" |
| 80 | + ) |
| 81 | + mapping[f"model.layers.{layer_idx}.mlp.experts.{i}.up_proj.weight"] = f"decoder.layers.{layer_idx}.moe_block.{i}.wi_1" |
| 82 | + mapping[f"model.layers.{layer_idx}.mlp.experts.{i}.down_proj.weight"] = f"decoder.layers.{layer_idx}.moe_block.{i}.wo" |
| 83 | + |
| 84 | + return mapping |
| 85 | + |
| 86 | + |
| 87 | +def convert_hf_to_maxtext(base_model_path: str, model_params: dict) -> dict: |
| 88 | + """Converts a Hugging Face Qwen3-MoE checkpoint to a MaxText compatible format.""" |
| 89 | + num_layers = model_params["num_hidden_layers"] |
| 90 | + num_experts = model_params["num_experts"] |
| 91 | + hidden_size = model_params["hidden_size"] |
| 92 | + num_heads = model_params["num_attention_heads"] |
| 93 | + num_kv_heads = model_params["num_key_value_heads"] |
| 94 | + head_dim = model_params["head_dim"] |
| 95 | + moe_intermediate_size = model_params["moe_intermediate_size"] |
| 96 | + |
| 97 | + # Part 1: Load all weights from safetensors into a flat dictionary with MaxText names |
| 98 | + ckpt_paths = sorted(pathlib.Path(base_model_path).glob("*.safetensors")) |
| 99 | + chkpt_vars = {} |
| 100 | + for i, ckpt_path in enumerate(ckpt_paths): |
| 101 | + max_logging.log(f"Loading checkpoint {i+1} of {len(ckpt_paths)}...") |
| 102 | + with safe_open(ckpt_path, framework="pt", device="cpu") as f: |
| 103 | + for key in f.keys(): |
| 104 | + if "layers" not in key and "embed_tokens" not in key and "norm" not in key and "lm_head" not in key: |
| 105 | + continue |
| 106 | + |
| 107 | + layer_idx_str = key.split(".")[2] if "layers" in key else "0" |
| 108 | + layer_idx = int(layer_idx_str) if layer_idx_str.isdigit() else 0 |
| 109 | + |
| 110 | + maxtext_key = hf_to_maxtext_mapping(layer_idx, num_experts).get(key) |
| 111 | + if maxtext_key: |
| 112 | + chkpt_vars[maxtext_key] = f.get_tensor(key) |
| 113 | + |
| 114 | + # Part 2: Initialize, populate, and transform the weights for MaxText |
| 115 | + maxtext_weights = { |
| 116 | + "decoder": { |
| 117 | + "layers": { |
| 118 | + "pre_self_attention_layer_norm": {"scale": None}, |
| 119 | + "post_self_attention_layer_norm": {"scale": None}, |
| 120 | + "self_attention": { |
| 121 | + "query": {"kernel": None}, |
| 122 | + "key": {"kernel": None}, |
| 123 | + "value": {"kernel": None}, |
| 124 | + "out": {"kernel": None}, |
| 125 | + "query_norm": {"scale": None}, |
| 126 | + "key_norm": {"scale": None}, |
| 127 | + }, |
| 128 | + "moe_block": { |
| 129 | + "gate": {"kernel": None}, |
| 130 | + "wi_0": None, |
| 131 | + "wi_1": None, |
| 132 | + "wo": None, |
| 133 | + }, |
| 134 | + }, |
| 135 | + "decoder_norm": {"scale": None}, |
| 136 | + "logits_dense": {"kernel": None}, |
| 137 | + }, |
| 138 | + "token_embedder": {"embedding": None}, |
| 139 | + } |
| 140 | + |
| 141 | + max_logging.log("Populating non-layer weights...") |
| 142 | + maxtext_weights["token_embedder"]["embedding"] = chkpt_vars["token_embedder.embedding"].to(torch.float16).numpy() |
| 143 | + maxtext_weights["decoder"]["decoder_norm"]["scale"] = chkpt_vars["decoder.decoder_norm.scale"].to(torch.float16).numpy() |
| 144 | + maxtext_weights["decoder"]["logits_dense"]["kernel"] = ( |
| 145 | + chkpt_vars["decoder.logits_dense.kernel"].to(torch.float16).numpy().transpose() |
| 146 | + ) |
| 147 | + |
| 148 | + max_logging.log("Allocating and stacking layer weights...") |
| 149 | + ln = maxtext_weights["decoder"]["layers"] |
| 150 | + s_attn = ln["self_attention"] |
| 151 | + moe = ln["moe_block"] |
| 152 | + |
| 153 | + # Pre-allocate stacked arrays with the 'layer' dimension first |
| 154 | + ln["pre_self_attention_layer_norm"]["scale"] = np.zeros((num_layers, hidden_size), dtype=np.float16) |
| 155 | + ln["post_self_attention_layer_norm"]["scale"] = np.zeros((num_layers, hidden_size), dtype=np.float16) |
| 156 | + s_attn["query"]["kernel"] = np.zeros((num_layers, hidden_size, num_heads, head_dim), dtype=np.float16) |
| 157 | + s_attn["key"]["kernel"] = np.zeros((num_layers, hidden_size, num_kv_heads, head_dim), dtype=np.float16) |
| 158 | + s_attn["value"]["kernel"] = np.zeros((num_layers, hidden_size, num_kv_heads, head_dim), dtype=np.float16) |
| 159 | + s_attn["out"]["kernel"] = np.zeros((num_layers, num_heads, head_dim, hidden_size), dtype=np.float16) |
| 160 | + s_attn["query_norm"]["scale"] = np.zeros((num_layers, head_dim), dtype=np.float16) |
| 161 | + s_attn["key_norm"]["scale"] = np.zeros((num_layers, head_dim), dtype=np.float16) |
| 162 | + moe["gate"]["kernel"] = np.zeros((num_layers, hidden_size, num_experts), dtype=np.float16) |
| 163 | + moe["wi_0"] = np.zeros((num_experts, num_layers, hidden_size, moe_intermediate_size), dtype=np.float16) |
| 164 | + moe["wi_1"] = np.zeros((num_experts, num_layers, hidden_size, moe_intermediate_size), dtype=np.float16) |
| 165 | + moe["wo"] = np.zeros((num_experts, num_layers, moe_intermediate_size, hidden_size), dtype=np.float16) |
| 166 | + |
| 167 | + # Loop through layers and populate the stacked arrays |
| 168 | + # pylint: disable=unsupported-assignment-operation |
| 169 | + for l in tqdm(range(num_layers), desc="Stacking layer weights"): |
| 170 | + ln["pre_self_attention_layer_norm"]["scale"][l, :] = ( |
| 171 | + chkpt_vars[f"decoder.layers.{l}.pre_self_attention_layer_norm.scale"].to(torch.float16).numpy() |
| 172 | + ) |
| 173 | + ln["post_self_attention_layer_norm"]["scale"][l, :] = ( |
| 174 | + chkpt_vars[f"decoder.layers.{l}.post_self_attention_layer_norm.scale"].to(torch.float16).numpy() |
| 175 | + ) |
| 176 | + |
| 177 | + s_attn["query"]["kernel"][l, ...] = ( |
| 178 | + chkpt_vars[f"decoder.layers.{l}.self_attention.query.kernel"] |
| 179 | + .to(torch.float16) |
| 180 | + .numpy() |
| 181 | + .transpose() |
| 182 | + .reshape(hidden_size, num_heads, head_dim) |
| 183 | + ) |
| 184 | + s_attn["key"]["kernel"][l, ...] = ( |
| 185 | + chkpt_vars[f"decoder.layers.{l}.self_attention.key.kernel"] |
| 186 | + .to(torch.float16) |
| 187 | + .numpy() |
| 188 | + .transpose() |
| 189 | + .reshape(hidden_size, num_kv_heads, head_dim) |
| 190 | + ) |
| 191 | + s_attn["value"]["kernel"][l, ...] = ( |
| 192 | + chkpt_vars[f"decoder.layers.{l}.self_attention.value.kernel"] |
| 193 | + .to(torch.float16) |
| 194 | + .numpy() |
| 195 | + .transpose() |
| 196 | + .reshape(hidden_size, num_kv_heads, head_dim) |
| 197 | + ) |
| 198 | + s_attn["out"]["kernel"][l, ...] = ( |
| 199 | + chkpt_vars[f"decoder.layers.{l}.self_attention.out.kernel"] |
| 200 | + .to(torch.float16) |
| 201 | + .numpy() |
| 202 | + .transpose() |
| 203 | + .reshape(num_heads, head_dim, hidden_size) |
| 204 | + ) |
| 205 | + |
| 206 | + s_attn["query_norm"]["scale"][l, ...] = ( |
| 207 | + chkpt_vars[f"decoder.layers.{l}.self_attention.query_norm.scale"].to(torch.float16).numpy() |
| 208 | + ) |
| 209 | + s_attn["key_norm"]["scale"][l, ...] = ( |
| 210 | + chkpt_vars[f"decoder.layers.{l}.self_attention.key_norm.scale"].to(torch.float16).numpy() |
| 211 | + ) |
| 212 | + |
| 213 | + moe["gate"]["kernel"][l, ...] = ( |
| 214 | + chkpt_vars[f"decoder.layers.{l}.moe_block.gate.kernel"].to(torch.float16).numpy().transpose() |
| 215 | + ) |
| 216 | + for i in range(num_experts): |
| 217 | + moe["wi_0"][i, l, ...] = chkpt_vars[f"decoder.layers.{l}.moe_block.{i}.wi_0"].to(torch.float16).numpy().transpose() |
| 218 | + moe["wi_1"][i, l, ...] = chkpt_vars[f"decoder.layers.{l}.moe_block.{i}.wi_1"].to(torch.float16).numpy().transpose() |
| 219 | + moe["wo"][i, l, ...] = chkpt_vars[f"decoder.layers.{l}.moe_block.{i}.wo"].to(torch.float16).numpy().transpose() |
| 220 | + |
| 221 | + # Final transformations for scanned weights (swap layer and feature axes) |
| 222 | + max_logging.log("Transposing layer weights for MaxText scanned format...") |
| 223 | + |
| 224 | + ln["pre_self_attention_layer_norm"]["scale"] = np.transpose(ln["pre_self_attention_layer_norm"]["scale"], axes=(1, 0)) |
| 225 | + ln["post_self_attention_layer_norm"]["scale"] = np.transpose(ln["post_self_attention_layer_norm"]["scale"], axes=(1, 0)) |
| 226 | + s_attn["query_norm"]["scale"] = np.transpose(s_attn["query_norm"]["scale"], axes=(1, 0)) |
| 227 | + s_attn["key_norm"]["scale"] = np.transpose(s_attn["key_norm"]["scale"], axes=(1, 0)) |
| 228 | + |
| 229 | + s_attn["query"]["kernel"] = np.transpose(s_attn["query"]["kernel"], axes=(1, 0, 2, 3)) |
| 230 | + s_attn["key"]["kernel"] = np.transpose(s_attn["key"]["kernel"], axes=(1, 0, 2, 3)) |
| 231 | + s_attn["value"]["kernel"] = np.transpose(s_attn["value"]["kernel"], axes=(1, 0, 2, 3)) |
| 232 | + s_attn["out"]["kernel"] = np.transpose(s_attn["out"]["kernel"], axes=(1, 0, 2, 3)) |
| 233 | + |
| 234 | + moe["gate"]["kernel"] = np.transpose(moe["gate"]["kernel"], axes=(1, 0, 2)) |
| 235 | + |
| 236 | + gc.collect() |
| 237 | + return maxtext_weights |
| 238 | + |
| 239 | + |
| 240 | +def main(args): |
| 241 | + """Main function to run the conversion.""" |
| 242 | + # Set up JAX simulated environment |
| 243 | + os.environ["JAX_PLATFORMS"] = "cpu" |
| 244 | + os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={args.simulated_cpu_devices_count}" |
| 245 | + |
| 246 | + if args.model_size not in MODEL_PARAMS_DICT: |
| 247 | + raise ValueError(f"Model size '{args.model_size}' not found in MODEL_PARAMS_DICT.") |
| 248 | + |
| 249 | + model_params = MODEL_PARAMS_DICT[args.model_size] |
| 250 | + max_logging.log(f"Starting conversion for Qwen3-MoE model size: {args.model_size}") |
| 251 | + jax_weights = convert_hf_to_maxtext(args.base_model_path, model_params) |
| 252 | + max_logging.log(f"Conversion complete. Saving MaxText checkpoint to {args.maxtext_model_path}") |
| 253 | + llama_or_mistral_ckpt.save_weights_to_checkpoint( |
| 254 | + args.maxtext_model_path, jax_weights, args.simulated_cpu_devices_count, args.use_ocdbt, args.use_zarr3 |
| 255 | + ) |
| 256 | + max_logging.log("Checkpoint saved successfully.") |
| 257 | + |
| 258 | + |
| 259 | +if __name__ == "__main__": |
| 260 | + parser = argparse.ArgumentParser(description="Convert Qwen3-MoE HF weights to MaxText.") |
| 261 | + parser.add_argument("--base_model_path", type=str, required=True, help="Path to the HF Qwen3-MoE checkpoint files.") |
| 262 | + parser.add_argument( |
| 263 | + "--maxtext_model_path", type=str, required=True, help="Path to save the MaxText checkpoint (local or GCS)." |
| 264 | + ) |
| 265 | + parser.add_argument( |
| 266 | + "--model_size", type=str, required=True, choices=MODEL_PARAMS_DICT.keys(), help="The model size to convert." |
| 267 | + ) |
| 268 | + parser.add_argument( |
| 269 | + "--simulated_cpu_devices_count", type=int, default=16, help="Number of simulated CPU devices for saving." |
| 270 | + ) |
| 271 | + parser.add_argument("--use-ocdbt", type=str2bool, default=True, help="Use OCDBT format for saving.") |
| 272 | + parser.add_argument("--use-zarr3", type=str2bool, default=True, help="Use Zarr3 format for saving.") |
| 273 | + |
| 274 | + parsed_args = parser.parse_args() |
| 275 | + main(parsed_args) |
0 commit comments