fix the mxfp4 packed qk weight loading issue for llama4 #21722
Conversation
Signed-off-by: xuebwang <[email protected]>
Code Review
This PR fixes a weight loading issue for llama4 with mxfp4 packed weights. The change correctly adjusts the output dimension for packed tensors. My review includes a suggestion to make the check for packed weights more robust by using the quantization configuration, which will prevent potential issues with other 8-bit quantization schemes.
if w.dtype in [torch.uint8, torch.int8]:
    attn_out = attn_out // 2
While this fix is likely correct for the intended mxfp4 case, relying solely on w.dtype is a bit fragile. It assumes that any uint8 or int8 weights encountered here are packed 4-bit tensors.

To make this more robust and prevent potential issues if other 8-bit quantization schemes are used in the future, I suggest explicitly checking for the mxfp4 quantization configuration. This makes the intention clearer and safer.

I'm assuming the quantization name is 'mxfp4' based on the PR title. Please verify and adjust if needed.
if (self.quant_config and self.quant_config.get_name() == "mxfp4"
        and w.dtype in [torch.uint8, torch.int8]):
    attn_out = attn_out // 2
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a small, essential subset of CI tests runs automatically. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge. 🚀
this should be fixed in my #21499
if w.dtype in [torch.uint8, torch.int8]:
    attn_out = attn_out // 2
Agree with Gemini, this is not very robust. e.g. wouldn't this dtype check also hit plain (unpacked) int8 weights, such as the ones created in vllm/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py (lines 57 to 63 in 9ace2ea)?

weight = ModelWeightParameter(data=torch.empty(
    sum(output_partition_sizes),
    input_size_per_partition,
    dtype=torch.int8),
    input_dim=1,
    output_dim=0,
    weight_loader=weight_loader)

Shouldn't there be a pack factor, defined for all quantization methods, being used here?
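For illustration only, here is a minimal sketch of what a pack-factor-based adjustment could look like. The pack_factor attribute and the helper name are assumptions for this sketch (some vLLM quant configs, e.g. GPTQ/AWQ, expose a pack_factor, but not all do), not a claim about the current API:

```python
def packed_output_dim(attn_out: int, quant_config) -> int:
    # Sketch only: derive the packing from the quantization config rather than
    # from the storage dtype. pack_factor is assumed to be 1 for unpacked
    # schemes and 2 for 4-bit weights packed into uint8 (e.g. mxfp4).
    pack_factor = getattr(quant_config, "pack_factor", 1) if quant_config else 1
    return attn_out // pack_factor
```

This would keep the dtype heuristic out of the model code and let each quantization method declare how its weights are stored.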
This pull request has merge conflicts that must be resolved before it can be merged.
this should already be fixed by #21499
Essential Elements of an Effective PR Description Checklist

(Optional) Documentation update, e.g. supported_models.md and examples for a new model.

Purpose
Error:
Model configuration: the model is quantized with amd-quark. The affected weights are language_model.model.layers.*.self_attn.q_proj and language_model.model.layers.*.self_attn.k_proj.

Root cause analysis:
For MXFP4 quantization, the exported weights are packed into uint8 with column-wise compression, i.e., from [M, N] to [M, N/2]. Therefore, the weight w has shape [attn_in, attn_out//2].

On the other hand, the permutation applied to the qk weights only swaps adjacent rows; the order of the columns within a given row is never broken.

Therefore, one can safely and simply halve the attn_out dimension with attn_out = attn_out // 2.
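As a rough illustration (the helper below is a sketch, not the exact permutation vLLM applies during llama4 weight loading; the view/transpose pattern and names are assumptions), this shows why the packed tensor can be permuted with the halved output dimension: the permutation only reorders rows, so every packed uint8 byte, which holds two 4-bit values along the column axis, stays intact.

```python
import torch

def permute_qk_rows(w: torch.Tensor, n_heads: int, attn_in: int, attn_out: int) -> torch.Tensor:
    # Illustrative rotary-style permutation: rows within each head are
    # reordered, while columns are passed through untouched.
    return (w.view(n_heads, attn_in // n_heads // 2, 2, attn_out)
             .transpose(1, 2)
             .reshape(attn_in, attn_out))

attn_in, attn_out, n_heads = 16, 8, 2

# Packed MXFP4 weight: two 4-bit values per uint8, so only attn_out // 2 columns.
w_packed = torch.randint(0, 256, (attn_in, attn_out // 2), dtype=torch.uint8)

# Passing the halved output dimension keeps every packed byte whole while the
# rows are permuted exactly as they would be for an unpacked weight.
out = permute_qk_rows(w_packed, n_heads, attn_in, attn_out // 2)
assert out.shape == (attn_in, attn_out // 2)
```

Because the columns are never split or interleaved, halving attn_out is all the packed case requires.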
Test Plan
Test Result
(Optional) Documentation Update