Add support for Qwen3-MoE Model #2092

Conversation
Force-pushed from 3585f64 to 1f8cb84.
Great work! One comment: follow the latest method of checkpoint conversion through param mapping.
Force-pushed from 97d289a to 29e1e72.
Just a few comments.
Quick double-check: the fine-tuning end-to-end is verified, right?
# -----------------------------------------
# The MoE Decoder Layer for Qwen3
# -----------------------------------------
class Qwen3MoeDecoderLayer(nn.Module):
Do you know if we could directly use nnx.Module? Otherwise, we will have to migrate it again. One example here: main...gpt_layer_nnx. However, I ran into some memory overheads. cc @bvandermoon @cgarciae if you have any ideas.
But I am totally fine if we cannot refactor to nnx at this moment.
> Do you know if we could directly use nnx.Module?

That is a great call-out. TBH I am not sure. @bvandermoon, any suggestions?
@RissyRan where did you see memory overheads? We are currently debugging additional memory during inference.
@bvandermoon, please check this comment for details. I probably missed something there in the NNX version for training.
# NN:
Total memory size: 57.9 GB, Output size: 30.7 GB, Temp size: 27.2 GB, Argument size: 30.7 GB, Host temp size: 0.0 GB.
# NNX:
Total memory size: 79.8 GB, Output size: 30.7 GB, Temp size: 49.1 GB, Argument size: 30.7 GB, Host temp size: 0.0 GB.
@parambole If we still have issues with decoding, it's totally fine to push nn.Module for Qwen3 at this moment.
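The breakdown above looks like XLA's compiled-memory statistics. A hedged sketch of pulling similar numbers for a jitted step via JAX's AOT API (the function and shapes here are toy stand-ins, not the Qwen3 training step):

```python
import jax
import jax.numpy as jnp

# Toy stand-in for a train/forward step; shapes are illustrative only.
def step(w, x):
  return jnp.tanh(x @ w).sum()

w = jnp.zeros((1024, 1024))
x = jnp.zeros((8, 1024))

# Lower and compile ahead of time, then inspect the compiled memory stats
# (output/temp/argument buffer sizes, as quoted above).
compiled = jax.jit(step).lower(w, x).compile()
print(compiled.memory_analysis())
```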
Thanks @RissyRan. Followed up on the bug.
For your comment around implementing this in NNX: ideally we would do that. Unfortunately, we are currently in the process of migrating the first decoder layer: #2178.
When do you need this by? If we can wait for that PR to merge, then NNX makes sense here and will avoid the fast-follow migration (it will probably need to be migrated next week).
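For context on the nn vs. nnx question in this thread, a minimal sketch contrasting the two Flax module styles (a toy block for illustration, not MaxText's decoder layer):

```python
import flax.linen as nn
from flax import nnx

# Linen: hyperparameters are dataclass fields; parameters are created
# lazily inside __call__ under @nn.compact.
class LinenBlock(nn.Module):
  features: int

  @nn.compact
  def __call__(self, x):
    return nn.relu(nn.Dense(self.features)(x))

# NNX: parameters are created eagerly in __init__, so the module holds
# real state and takes explicit RNGs at construction time.
class NnxBlock(nnx.Module):
  def __init__(self, in_features: int, out_features: int, *, rngs: nnx.Rngs):
    self.dense = nnx.Linear(in_features, out_features, rngs=rngs)

  def __call__(self, x):
    return nnx.relu(self.dense(x))

block = NnxBlock(16, 32, rngs=nnx.Rngs(0))
```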
Force-pushed from 704f597 to ee98113.
@RissyRan Can you please elaborate?
Force-pushed from aa39b95 to 2b83085.
Oh, I meant to ask: have you loaded the ckpt and run SFT for some steps?
LGTM! Just some minor comments, and please address the other comments, if any. Great work!
Force-pushed from dd912b7 to dd80c34.
Thanks, Param. Left one small change to make.
Force-pushed from 29996a8 to 1d170ac.
Force-pushed from d895dcc to 64c6087.
Force-pushed from 64c6087 to bf4df75.
Feat: Add support for Qwen3-MoE Model
TL;DR

- What: This PR integrates the Qwen3-MoE architecture into MaxText, with a specific implementation for the `Qwen3-235B-A22B` model. It includes a new MoE decoder layer and a dedicated checkpoint conversion script.
- Why: To enable the pre-training of the open-source Qwen3 family of Mixture-of-Experts models within the MaxText framework.
- How: By introducing a `Qwen3MoeDecoderLayer`, adding a `qwen3-235b-a22b.yml` configuration file, implementing a `norm_topk_prob` flag for Qwen-specific router logic, and creating a new `convert_qwen3_moe.py` script to convert Hugging Face checkpoints.

Detailed Description
Background and Motivation
This pull request adds support for the Qwen3-MoE family of models, as introduced in the technical report: Qwen3: A More Powerful and General Base Model Series. This initial integration focuses on the `Qwen3-235B-A22B` variant, a large-scale Mixture-of-Experts (MoE) model.

Architectural Changes
To support the Qwen3-MoE architecture, the following key changes have been implemented:
- Qwen3 Decoder Layers (`MaxText/layers/qwen3.py`): A `self_attention_with_norm` helper function was extracted to share the attention block logic between both decoder types. A new `Qwen3MoeDecoderLayer` has been added; it utilizes the shared self-attention block and integrates with the existing `moe.get_routed_moe` function for the expert layers.
- MoE Router Normalization (`MaxText/layers/moe.py`): A new flag, `norm_topk_prob`, has been added for the Qwen-specific router logic (see the sketch after this list).
- New Model Configuration (`MaxText/configs/models/qwen3-235b-a22b.yml`): Defines the `Qwen3-235B-A22B` model, including `num_experts`, `num_experts_per_tok`, `base_moe_mlp_dim`, and `use_qk_norm`.
- Checkpoint Conversion Script (`MaxText/convert_qwen3_moe.py`): Converts Hugging Face checkpoints for the `Qwen3-235B-A22B` model into the MaxText format.
- Framework Integration (`MaxText/common_types.py`, `MaxText/layers/decoders.py`): The `DecoderBlockType` enum has been updated with `QWEN3_MOE`, and `decoders.py` has been updated to correctly instantiate `Qwen3MoeDecoderLayer` when `decoder_block: "qwen3_moe"` is set in the config.
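A minimal sketch of the router behavior the `norm_topk_prob` flag controls; the function name and shapes below are illustrative, not MaxText's actual `moe.py` code:

```python
import jax
import jax.numpy as jnp

def route_tokens(router_logits, num_experts_per_tok, norm_topk_prob):
  """Pick top-k experts per token and compute their combine weights.

  router_logits: [num_tokens, num_experts] raw router scores.
  Returns (weights, expert_ids), each [num_tokens, num_experts_per_tok].
  """
  probs = jax.nn.softmax(router_logits, axis=-1)
  weights, expert_ids = jax.lax.top_k(probs, num_experts_per_tok)
  if norm_topk_prob:
    # Qwen-style normalization: rescale the selected top-k probabilities
    # so they sum to 1 per token, rather than using raw softmax values.
    weights = weights / jnp.sum(weights, axis=-1, keepdims=True)
  return weights, expert_ids
```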
How to Use

1. Convert Hugging Face Checkpoint
First, convert the original Hugging Face checkpoint to the MaxText format using the new script.
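A hypothetical invocation: the flag names below are assumptions for illustration, not the script's verified CLI; check `convert_qwen3_moe.py` for its actual arguments.

```bash
# Flag names are illustrative assumptions; see the script for real arguments.
python3 MaxText/convert_qwen3_moe.py \
  --base_model_path /path/to/hf/Qwen3-235B-A22B \
  --maxtext_model_path gs://your-bucket/qwen3-235b-a22b/maxtext-ckpt
```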
2. Run Training
Once the checkpoint is converted, you can run training workloads by specifying the model name and loading the converted weights.
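A sketch of such a run using MaxText's usual config-override pattern; the bucket paths, run name, and checkpoint item path are placeholders:

```bash
# Paths and run name are placeholders; adjust for your environment.
python3 MaxText/train.py MaxText/configs/base.yml \
  model_name=qwen3-235b-a22b \
  load_parameters_path=gs://your-bucket/qwen3-235b-a22b/maxtext-ckpt/0/items \
  run_name=qwen3_moe_test \
  base_output_directory=gs://your-bucket/maxtext-output \
  per_device_batch_size=1
```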
Testing Strategy & Validation
The correctness of the implementation and the conversion script was validated by comparing the logits produced by the MaxText model against the original Hugging Face model (`Qwen/Qwen3-235B-A22B-Thinking-2507`). The `forward_pass_logit_checker` test was used with several prompts.

Key Validation Metrics:
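As an illustration of what this style of check computes, a hypothetical sketch (tolerances, names, and shapes are assumptions, not the actual `forward_pass_logit_checker`):

```python
import numpy as np

def check_logits(maxtext_logits, hf_logits, atol=0.1, kl_tol=0.02):
  """Compare [seq_len, vocab] logits from the two models on one prompt."""
  max_abs_diff = np.max(np.abs(maxtext_logits - hf_logits))
  # Per-position softmax distributions, compared via KL(HF || MaxText).
  p = np.exp(hf_logits - hf_logits.max(-1, keepdims=True))
  p /= p.sum(-1, keepdims=True)
  q = np.exp(maxtext_logits - maxtext_logits.max(-1, keepdims=True))
  q /= q.sum(-1, keepdims=True)
  max_kl = np.sum(p * (np.log(p + 1e-9) - np.log(q + 1e-9)), axis=-1).max()
  assert max_abs_diff < atol and max_kl < kl_tol, (max_abs_diff, max_kl)
```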
Checklist
Before submitting this PR, please make sure (put X in square brackets):