[NVIDIA] Explicitly disable shuffled weights for flashinfer blockscale moe fp8 kernels #21411
Conversation
Signed-off-by: kaixih <[email protected]>
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a limited subset of checks runs automatically. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
Code Review
This pull request aims to restore previous behavior for FlashInfer MoE kernels by explicitly disabling a new use_shuffled_weight flag. While the change is correct in its intent, it introduces a critical backward compatibility issue for users with older versions of FlashInfer. I've provided a comment with a suggested fix to address this.
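To make the failure mode concrete, here is a minimal illustration with a hypothetical stand-in function (the real API lives in FlashInfer) of the TypeError raised when a keyword the installed version does not accept is passed:

```python
# Hypothetical stand-in for a pre-0.2.9 API surface: no use_shuffled_weight.
def trtllm_fp8_block_scale_moe_old(routing_logits, hidden_states):
    return hidden_states

trtllm_fp8_block_scale_moe_old(
    routing_logits=None,
    hidden_states=None,
    use_shuffled_weight=False,  # unknown keyword on older versions
)
# TypeError: trtllm_fp8_block_scale_moe_old() got an unexpected keyword
# argument 'use_shuffled_weight'
```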
@@ -1127,6 +1127,7 @@ def flashinfer_fused_moe_blockscale_fp8(
         tile_tokens_dim=_get_tile_tokens_dim(x.shape[0], top_k,
                                              global_num_experts),
         routing_method_type=2,  # DeepSeek-styled routing method
+        use_shuffled_weight=False,
This change explicitly adds the use_shuffled_weight argument, which was introduced in FlashInfer v0.2.9. This will cause a TypeError for users with older versions of FlashInfer, breaking backward compatibility.
To fix this, we should only pass the argument if the installed FlashInfer version supports it. This can be done with a version check, which requires refactoring the function call to use a kwargs dictionary.
Here is a suggested implementation to replace lines 1097-1131:
from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe

a_q, a_sf = per_token_group_quant_fp8(x, block_shape[1])
# NOTE: scales of hidden states have to be transposed!
a_sf_t = a_sf.t().contiguous()
kwargs = dict(
    routing_logits=routing_logits,
    routing_bias=routing_bias,
    hidden_states=a_q,
    hidden_states_scale=a_sf_t,
    gemm1_weights=w13_weight,
    gemm1_weights_scale=w13_weight_scale_inv,
    gemm2_weights=w2_weight,
    gemm2_weights_scale=w2_weight_scale_inv,
    num_experts=global_num_experts,
    top_k=top_k,
    n_group=num_expert_group,
    topk_group=topk_group,
    intermediate_size=intermediate_size,
    local_expert_offset=expert_offset,
    local_num_experts=local_num_experts,
    routed_scaling_factor=routed_scaling,
    tile_tokens_dim=_get_tile_tokens_dim(x.shape[0], top_k,
                                         global_num_experts),
    routing_method_type=2,  # DeepSeek-styled routing method
)
try:
    import flashinfer
    from packaging.version import Version

    # The use_shuffled_weight argument was added in flashinfer v0.2.9.
    if Version(flashinfer.__version__) >= Version("0.2.9"):
        kwargs["use_shuffled_weight"] = False
except (ImportError, AttributeError):
    # Older flashinfer version or flashinfer not installed.
    # The lazy loader will handle the ImportError later if it's missing.
    pass
return flashinfer_trtllm_fp8_block_scale_moe(**kwargs)
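An alternative to parsing the version string is to check whether the callable actually accepts the keyword. This is only a sketch: `_supports_kwarg` is a hypothetical helper, not part of this PR or of FlashInfer, and signature inspection may not work on lazily wrapped or C-extension callables.

```python
import inspect

def _supports_kwarg(fn, name: str) -> bool:
    """Return True if callable `fn` accepts a keyword argument `name`."""
    try:
        return name in inspect.signature(fn).parameters
    except (TypeError, ValueError):
        # C extensions or wrappers may not expose an inspectable signature;
        # conservatively assume the keyword is unsupported.
        return False

# Hypothetical usage: only add the flag when the installed kernel accepts it.
# if _supports_kwarg(flashinfer_trtllm_fp8_block_scale_moe, "use_shuffled_weight"):
#     kwargs["use_shuffled_weight"] = False
```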
Thanks Kaixi, should we wait for the 0.2.9 release to land this?
@mgoin Sure.
@mgoin Can we merge this PR now that flashinfer 0.2.9rc1 is in?
The latest FlashInfer (PR) introduces a new use_shuffled_weight flag to the trtllm_fp8_block_scale_moe API, which defaults to True. This PR explicitly disables it to restore the previous behavior. I have verified the performance and accuracy with top-of-tree FlashInfer, and we recommend using flashinfer v0.2.9.
cc. @kushanam @mgoin
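For reference, a minimal guard that fails fast when the installed FlashInfer predates the new keyword. This is a sketch: flashinfer.__version__ and the 0.2.9 threshold come from the review discussion above; the guard itself is not part of this PR.

```python
# Sketch: ensure the installed FlashInfer accepts use_shuffled_weight
# before relying on the explicit use_shuffled_weight=False call.
import flashinfer
from packaging.version import Version

if Version(flashinfer.__version__) < Version("0.2.9"):
    raise RuntimeError(
        "flashinfer >= 0.2.9 is required to pass use_shuffled_weight=False")
```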