@parambole commented on Aug 6, 2025:

Feat: Add support for Qwen3-MoE Model

TL;DR

  • What: This PR integrates the Qwen3-MoE architecture into MaxText, with a specific implementation for the Qwen3-235B-A22B model. It includes a new MoE decoder layer and a dedicated checkpoint conversion script.

  • Why: To enable the pre-training of the open-source Qwen3 family of Mixture-of-Experts models within the MaxText framework.

  • How: By introducing a Qwen3MoeDecoderLayer, adding a qwen3-235b-a22b.yml configuration file, implementing a norm_topk_prob flag for Qwen-specific router logic, and creating a new convert_qwen3_moe.py script to convert Hugging Face checkpoints.


Detailed Description

Background and Motivation

This pull request adds support for the Qwen3-MoE family of models, as introduced in the technical report: Qwen3: A More Powerful and General Base Model Series. This initial integration focuses on the Qwen3-235B-A22B variant, a large-scale Mixture-of-Experts (MoE) model.

Architectural Changes

To support the Qwen3-MoE architecture, the following key changes have been implemented:

  1. Qwen3 Decoder Layers (MaxText/layers/qwen3.py):

    • The existing Qwen3 module has been refactored to support both dense and MoE variants.
    • A self_attention_with_norm helper function was extracted to share the attention block logic between both decoder types.
    • A new Qwen3MoeDecoderLayer has been added. It utilizes the shared self-attention block and integrates with the existing moe.get_routed_moe function for the expert layers.
  2. MoE Router Normalization (MaxText/layers/moe.py):

    • A new boolean configuration flag, norm_topk_prob, has been added.
    • When enabled, this flag applies Qwen-style normalization to the router output: the top-k selected probabilities are divided by their sum so that the routing weights for each token sum to 1. This is essential for matching the original model's behavior (see the illustrative sketch after this list).
  3. New Model Configuration (MaxText/configs/models/qwen3-235b-a22b.yml):

    • A new YAML configuration file is provided to define the specific architectural parameters for the Qwen3-235B-A22B model, including num_experts, num_experts_per_tok, base_moe_mlp_dim, and use_qk_norm.
  4. Checkpoint Conversion Script (MaxText/convert_qwen3_moe.py):

    • A new, standalone conversion script has been added to convert official Hugging Face checkpoints for the Qwen3-235B-A22B model into the MaxText format.
    • This script handles the mapping of weights, including the complex stacking of MoE layers, and performs the necessary transpositions to make them compatible with MaxText's scanned layer format.
  5. Framework Integration (MaxText/common_types.py, MaxText/layers/decoders.py):

    • The DecoderBlockType enum has been updated with QWEN3_MOE.
    • The decoder factory in decoders.py has been updated to correctly instantiate Qwen3MoeDecoderLayer when decoder_block: "qwen3_moe" is set in the config.
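
For intuition, the effect of the norm_topk_prob flag from point 2 can be shown with a small, self-contained JAX sketch. This is illustrative only and is not the code added to MaxText/layers/moe.py; the function name, argument names, and shapes are assumptions made for the example.

import jax
import jax.numpy as jnp

def route_tokens(router_logits, num_experts_per_tok, norm_topk_prob=True):
  """Toy top-k router: returns expert indices and routing weights per token."""
  # Softmax over the expert dimension; router_logits is [num_tokens, num_experts].
  probs = jax.nn.softmax(router_logits.astype(jnp.float32), axis=-1)
  # Keep the k largest probabilities and the indices of the chosen experts.
  topk_probs, topk_idx = jax.lax.top_k(probs, num_experts_per_tok)
  if norm_topk_prob:
    # Qwen-style renormalization: rescale the selected probabilities so that
    # the routing weights for each token sum to 1.
    topk_probs = topk_probs / jnp.sum(topk_probs, axis=-1, keepdims=True)
  return topk_idx, topk_probs

# Example: route 2 tokens over 4 experts, selecting the top 2 per token.
logits = jnp.array([[0.1, 2.0, 0.3, 1.5], [1.0, 0.2, 0.2, 0.9]])
print(route_tokens(logits, 2))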

How to Use

1. Convert Hugging Face Checkpoint

First, convert the original Hugging Face checkpoint to the MaxText format using the new script.

python3 -m MaxText.convert_qwen3_moe \
  --base_model_path /path/to/hf/Qwen3-235B-A22B-Thinking-2507 \
  --maxtext_model_path gs://your-gcs-bucket/qwen3/scanned/step_0/ \
  --model_size qwen3-235b-a22b
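
For orientation, the script's MoE handling amounts to gathering the per-expert Hugging Face tensors for each layer into one stacked array and transposing it into MaxText's expected layout. The snippet below is a rough illustration only; the real tensor names, shapes, and axis order are defined in MaxText/convert_qwen3_moe.py, and the dimensions here are made-up placeholders.

import numpy as np

# Hypothetical sizes for illustration only.
num_experts, hidden_dim, moe_mlp_dim = 4, 8, 16

# In the HF checkpoint, each expert's projection is a separate [out, in] matrix.
hf_gate_proj = [np.random.randn(moe_mlp_dim, hidden_dim) for _ in range(num_experts)]

# Stack the experts into a single array and transpose each matrix so the
# contraction axis matches the stacked, scanned layout MaxText expects.
stacked = np.stack([w.T for w in hf_gate_proj], axis=0)
print(stacked.shape)  # (4, 8, 16) -> [num_experts, hidden_dim, moe_mlp_dim]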

2. Run Training

Once the checkpoint is converted, you can run training or validation workloads by specifying the model name and loading the converted weights.

# Example for running validation
JAX_PLATFORMS=cpu python3 -m MaxText.tests.forward_pass_logit_checker MaxText/configs/base.yml \
  tokenizer_path=assets/qwen3-tokenizer \
  load_parameters_path=gs://your-gcs-bucket/qwen3/scanned/step_0/items/ \
  model_name=qwen3-235b-a22b \
  --hf_model_path=Qwen/Qwen3-235B-A22B-Thinking-2507
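
Although only the validation command is shown here, a pre-training run would follow the usual MaxText pattern of overriding base.yml with the model name and converted checkpoint path. The command below is a hedged sketch: the run name, output bucket, batch size, and step count are placeholders, and the exact flags depend on your environment.

python3 -m MaxText.train MaxText/configs/base.yml \
  model_name=qwen3-235b-a22b \
  tokenizer_path=assets/qwen3-tokenizer \
  load_parameters_path=gs://your-gcs-bucket/qwen3/scanned/step_0/items/ \
  base_output_directory=gs://your-gcs-bucket/qwen3/output \
  run_name=qwen3_moe_pretrain \
  per_device_batch_size=1 \
  steps=100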


Testing Strategy & Validation

The correctness of the implementation and the conversion script was validated by comparing the logits produced by the MaxText model against the original Hugging Face model (Qwen/Qwen3-235B-A22B-Thinking-2507).

The forward_pass_logit_checker test was used with several prompts.

Key Validation Metrics:

| Prompt        | Overlap Count | Jaccard Similarity | Rank Agreement % | Avg KL Divergence |
|---------------|---------------|--------------------|------------------|-------------------|
| "I love to"   | 10/10         | 1.0                | 40.0%            | 0.000254          |
| "Today is a"  | 10/10         | 1.0                | 80.0%            | 0.001808          |
| "What is the" | 10/10         | 1.0                | 100.0%           | -0.000182         |

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided the validation results above.
  • I have made or will make corresponding changes to the documentation if needed.

@parambole force-pushed the parambole/maxtext_qwen_moe branch 2 times, most recently from 3585f64 to 1f8cb84, on August 11, 2025 17:19
@parambole marked this pull request as ready for review on August 11, 2025 17:36
@parambole changed the title from "Qwen3 MoE" to "Add support for Qwen3-MoE Model" on Aug 11, 2025
@gagika left a comment:

Great work. One comment: follow the latest method of checkpoint conversion, through param mapping.

@parambole force-pushed the parambole/maxtext_qwen_moe branch from 97d289a to 29e1e72 on August 11, 2025 18:54
@RissyRan left a comment:

Just a few comments.

Quick double check: the fine-tuning end-to-end is verified, right?

# -----------------------------------------
# The MoE Decoder Layer for Qwen3
# -----------------------------------------
class Qwen3MoeDecoderLayer(nn.Module):
@RissyRan commented on the Qwen3MoeDecoderLayer snippet above (an inline review comment in MaxText/layers/qwen3.py):

Do you know if we could directly use nnx.module? Otherwise, we will migrate it again. One example here: main...gpt_layer_nnx. However, I met some memory overheads. cc @bvandermoon @cgarciae if you have some ideas.

But I am totally fine if we cannot refactor to nnx at this moment.

@parambole replied:

> Do you know if we could directly use nnx.module?

That is a great call-out. TBH I am not sure. @bvandermoon, any suggestions?

@bvandermoon replied:

@RissyRan where did you see memory overheads? We are currently debugging additional memory during inference

@RissyRan replied:

@bvandermoon, please check this comment for details. I probably missed something there in the NNX version for training.

# NN:
Total memory size: 57.9 GB, Output size: 30.7 GB, Temp size: 27.2 GB, Argument size: 30.7 GB, Host temp size: 0.0 GB.

# NNX:
Total memory size: 79.8 GB, Output size: 30.7 GB, Temp size: 49.1 GB, Argument size: 30.7 GB, Host temp size: 0.0 GB.

@parambole If we still have issues about decoding, totally fine to push nn.module for Qwen3 at this moment.

@bvandermoon replied:

Thanks @RissyRan. Followed up on the bug.

For your comment about implementing this in NNX: ideally we would do that. Unfortunately, we are in the process of migrating the first decoder layer currently: #2178.

When do you need this by? If we can wait for that PR to merge, then NNX makes sense here and will avoid a fast-follow migration (it would probably need to be migrated next week).

@parambole force-pushed the parambole/maxtext_qwen_moe branch from 704f597 to ee98113 on August 12, 2025 02:07
@parambole replied:

> Quick double check: the fine-tuning end-to-end is verified, right?

@RissyRan Can you please elaborate?

@parambole force-pushed the parambole/maxtext_qwen_moe branch from aa39b95 to 2b83085 on August 13, 2025 01:28
@RissyRan replied:

> Quick double check: the fine-tuning end-to-end is verified, right?
>
> @RissyRan Can you please elaborate?

Oh, I meant to ask whether you have loaded the ckpt and run SFT for some steps?

@RissyRan left a comment:

LGTM! Just some minor comments, and please address other comments if any. Great work!

@parambole force-pushed the parambole/maxtext_qwen_moe branch from dd912b7 to dd80c34 on August 13, 2025 21:06
@gobbleturk left a comment:

Thanks Param, left one small change to make

@parambole force-pushed the parambole/maxtext_qwen_moe branch 2 times, most recently from 29996a8 to 1d170ac, on August 13, 2025 21:31
@parambole force-pushed the parambole/maxtext_qwen_moe branch from d895dcc to 64c6087 on August 13, 2025 22:33
@parambole force-pushed the parambole/maxtext_qwen_moe branch from 64c6087 to bf4df75 on August 13, 2025 23:28
copybara-service bot merged commit 4dce50f into main on Aug 14, 2025
18 checks passed
copybara-service bot deleted the parambole/maxtext_qwen_moe branch on August 14, 2025 01:52