Muon with QK-Clip support #1198
Conversation
Summary of Changes
Hello @zhtmike, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request introduces the Muon optimizer with QK-Clip support, primarily aimed at optimizing attention mechanisms in large language models. The implementation addresses the issue of QK value explosion, as observed in a modified Qwen2 model, by applying a clipping mechanism to the Query-Key products. This enhancement is designed to improve the stability and performance of models utilizing such attention architectures.
Highlights
- New Muon Optimizer: I've added a new Muon optimizer class (mindone/trainers/muon.py) that implements the Muon optimization algorithm, including a zeropower_via_newtonschulz5 function for orthogonalization (a sketch of this step follows the list).
- QK-Clip Mechanism: I've implemented a QK-clip mechanism within the Muon optimizer. It dynamically scales down the Query and Key projections based on their QK product to prevent value explosion.
- MLA Attention Integration: I've introduced a Qwen2MLAAttention class in mindone/transformers/models/qwen2/modeling_qwen2.py to support the new attention mechanism. This allows MLA attention to be used conditionally when the USE_MLA environment variable is set.
- QK Product Exposure: The eager_attention_forward function and the Qwen2DecoderLayer in the Qwen2 model now return the Query-Key product, making it accessible to the QK-clip mechanism in the optimizer.
- Refined Weight Initialization: I've updated the _init_weights method in Qwen2PreTrainedModel to use normal_ and zeros_ for more consistent weight initialization.
- Toy Training Scripts: I've included new toy training scripts for both MindSpore and PyTorch (tests/trainer_tests/muon/mindspore/toy_train_ms.py and tests/trainer_tests/muon/torch/toy_train.py) to demonstrate the usage and effectiveness of the Muon optimizer with QK-Clip.
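For context, the core of the Muon update is a Newton-Schulz orthogonalization of each 2D momentum matrix. Below is a minimal PyTorch-style sketch of the publicly known quintic Newton-Schulz iteration (coefficients follow the reference Muon implementation); it is an illustration only, and the actual zeropower_via_newtonschulz5 in mindone/trainers/muon.py is written against MindSpore ops and may differ in details such as dtype handling and reshaping.

```python
import torch

def newton_schulz_5(G: torch.Tensor, steps: int = 5, eps: float = 1e-7) -> torch.Tensor:
    # Quintic Newton-Schulz iteration that approximately maps G to the nearest
    # semi-orthogonal matrix (singular values pushed toward 1).
    assert G.ndim == 2
    a, b, c = 3.4445, -4.7750, 2.0315   # coefficients from the reference Muon code
    X = G / (G.norm() + eps)            # normalize so the iteration converges
    transposed = G.shape[0] > G.shape[1]
    if transposed:                      # iterate on the wide orientation
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * (A @ A)
        X = a * X + B @ X
    if transposed:
        X = X.T
    return X
```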
Code Review
This pull request introduces the Muon optimizer with QK-clip support, along with a new MLA attention mechanism for Qwen2 models. The implementation looks solid, but there are a few critical areas for improvement. The use of environment variables to control model architecture is a major concern for reproducibility and should be replaced with explicit configuration options. The optimizer's dependency on hardcoded parameter names for QK-clipping is fragile and should be made more robust. Additionally, there are some minor issues with assertions and magic numbers that could be improved for better code quality and maintainability. The new test scripts are a great addition for verifying the functionality.
if os.environ.get("USE_MLA", None) == "1":
    logger.info("Use MLA attention.")
    config.head_dim = config.hidden_size // config.num_attention_heads // 2
Using an environment variable (os.environ.get("USE_MLA", None) == "1") to control the model's architecture (e.g., config.head_dim) is not good practice. It makes the model's behavior dependent on an implicit external state, which can lead to reproducibility issues and makes debugging difficult. It would be much better to make this an explicit configuration option within Qwen2Config, similar to how _attn_implementation is handled. This same issue appears later in the file for selecting the attention class.
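As an illustration only, an explicit switch could look like the sketch below; the use_mla field name and the example dimensions are hypothetical stand-ins, not the PR's actual API.

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class Qwen2ConfigSketch:
    # Stand-in for Qwen2Config; the use_mla field is illustrative only.
    hidden_size: int = 896
    num_attention_heads: int = 14
    use_mla: bool = False           # explicit and serialized with the config
    head_dim: Optional[int] = None  # derived once, in one place

    def __post_init__(self):
        if self.head_dim is None:
            base = self.hidden_size // self.num_attention_heads
            # Mirror the PR's experiment: MLA halves the per-head dimension.
            self.head_dim = base // 2 if self.use_mla else base
```

Reading the switch from the config means the architecture is reproducible from the saved configuration alone, rather than depending on the environment at load time.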
elif os.environ.get("USE_MLA", None) == "1":
    self.self_attn = Qwen2MLAAttention(config=config, layer_idx=layer_idx)
As mentioned in a previous comment, using an environment variable (os.environ.get("USE_MLA", None) == "1") to select the attention implementation is not robust. This should be controlled via the model's configuration, for example by adding a new _attn_implementation type like "mla". This would make the model's configuration explicit and self-contained.
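For illustration, the decoder layer could then pick its attention class from the same hypothetical config flag; the import below assumes both classes are exposed by the PR's modeling_qwen2 module.

```python
from mindone.transformers.models.qwen2.modeling_qwen2 import Qwen2Attention, Qwen2MLAAttention

def build_self_attn(config, layer_idx):
    # Hypothetical helper: select the attention implementation from the config
    # instead of an environment variable.
    if getattr(config, "use_mla", False):
        return Qwen2MLAAttention(config=config, layer_idx=layer_idx)
    return Qwen2Attention(config=config, layer_idx=layer_idx)
```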
if len(shape) > 2:
    G = G.view(G.shape[0], -1)
assert len(shape) == 2
The assertion assert len(shape) == 2 is incorrect. The shape variable holds the shape of the tensor G before it is potentially reshaped inside the if len(shape) > 2: block, so the assertion will fail for any tensor that has more than 2 dimensions, and also for 1D tensors. The intent seems to be to ensure that G is a 2D tensor before proceeding, so the assertion should check the shape of G after the reshape operation.

Suggested change:
- assert len(shape) == 2
+ assert G.ndim == 2
for x in self.parameters:
    if x.name.endswith("q_b_proj.weight"):
        layer_idx = int(x.name.split(".")[2])
        q_b_projs.append((layer_idx, x))
    elif x.name.endswith("kv_b_proj.weight"):
        layer_idx = int(x.name.split(".")[2])
        kv_b_projs.append((layer_idx, x))
The logic for identifying the Q and KV projection parameters for QK-clipping relies on hardcoded parameter names (.endswith("q_b_proj.weight"), .endswith("kv_b_proj.weight")) and on parsing the layer index from the name string. This is very fragile and tightly couples the optimizer to a specific model's naming convention; a change in the model architecture or parameter naming would break it. As noted in the TODO on line 200, a more robust approach would be to pass these parameters explicitly to the optimizer, for instance as an argument to the construct method. This would decouple the optimizer from the model's internal structure.
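One way this could look (the class name, signature, and attribute paths below are hypothetical, not the PR's actual Muon API): the caller, which knows its own module structure, hands the per-layer projection weights to the optimizer.

```python
class MuonWithQKClipSketch:
    # Hypothetical sketch of decoupling QK-clip from parameter naming.
    def __init__(self, params, qk_clip_pairs=None, lr=1e-3, clip_threshold=100.0):
        # qk_clip_pairs: list of (q_proj_weight, kv_proj_weight) tuples, one per
        # attention layer; the optimizer only iterates over this list and never
        # inspects parameter names.
        self.params = list(params)
        self.qk_clip_pairs = list(qk_clip_pairs or [])
        self.lr = lr
        self.clip_threshold = clip_threshold

# Example wiring (attribute names follow the MLA block in this PR; verify against
# the actual model before use):
# pairs = [(layer.self_attn.q_b_proj.weight, layer.self_attn.kv_b_proj.weight)
#          for layer in model.model.layers]
# optimizer = MuonWithQKClipSketch(model.trainable_params(), qk_clip_pairs=pairs)
```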
self.q_lora_rank = config.intermediate_size // 14
self.qk_nope_head_dim = self.head_dim
self.qk_rope_head_dim = self.head_dim // 2
self.v_head_dim = self.head_dim
self.qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim
self.kv_lora_rank = config.hidden_size // 14
The divisor 14 used to derive q_lora_rank and kv_lora_rank is a magic number. Consider defining it as a named constant or exposing these ranks as explicit configuration values so the intent is clear and the values are tunable.
if model_name == "qwen":
    tokenizer = Qwen2Tokenizer.from_pretrained("Qwen/Qwen2.5-0.5B", trust_remote_code=True)
else:
    assert 0, f"model {model_name} not supported"
Using assert 0, "..." for error handling is generally discouraged. It's better to raise a more specific and descriptive exception, such as ValueError or NotImplementedError, which provides clearer error messages and improves code quality. This pattern is repeated elsewhere in the file.

Suggested change:
- assert 0, f"model {model_name} not supported"
+ raise ValueError(f"model {model_name} not supported")
if model_name == "qwen":
    tokenizer = Qwen2Tokenizer.from_pretrained("Qwen/Qwen2.5-0.5B", trust_remote_code=True)
else:
    assert 0, f"model {model_name} not supported"
Using assert 0, "..." for error handling is not ideal. It's better to raise a more specific exception like ValueError or NotImplementedError to provide more informative error messages. This pattern is repeated elsewhere in the file.

Suggested change:
- assert 0, f"model {model_name} not supported"
+ raise ValueError(f"model {model_name} not supported")
This PR introduces the Muon optimizer with QK-clip support, as announced in the Kimi-K2 paper.
The experiment is based on a modified version of Qwen2 in which the attention blocks are replaced by MLA blocks. We verified that the QK values explode when training with Muon without QK-clip, and that they are suppressed to reasonable magnitudes once QK-clip is applied (a minimal sketch of the clipping rule is given below).
NOTE: The Qwen2 modification is for experimentation only and should be fixed before code review.
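For readers unfamiliar with QK-clip, here is a minimal PyTorch sketch of the idea; the threshold value, the per-layer granularity, and the function name are illustrative, and the PR's MindSpore implementation (as well as the per-head treatment described in the Kimi-K2 report) may differ.

```python
import torch

@torch.no_grad()
def qk_clip_(q_weight: torch.Tensor, k_weight: torch.Tensor,
             max_qk_logit: float, tau: float = 100.0) -> None:
    # Rescale the Q and K projection weights in place after an optimizer step.
    # max_qk_logit is the largest pre-softmax attention score observed during
    # the step (derived from the QK product this PR exposes); scaling both
    # projections by sqrt(tau / max_qk_logit) caps the product at about tau.
    if max_qk_logit > tau:
        gamma = (tau / max_qk_logit) ** 0.5
        q_weight.mul_(gamma)
        k_weight.mul_(gamma)
```

Applied after each Muon update, this keeps the QK logits bounded, which is what suppresses the explosion observed in the experiment.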
What does this PR do?
Fixes # (issue)
Adds # (feature)
Before submitting
- Did you make sure to update the documentation with your changes? E.g. record bug fixes or new features in What's New. Here are the documentation guidelines.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@xxx