Skip to content

Update flashinfer CUTLASS MoE Kernel #21408

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 24, 2025

Conversation

wenscarl
Copy link
Contributor

@wenscarl wenscarl commented Jul 22, 2025

Essential Elements of an Effective PR Description Checklist

  • Update the block_scale_interleave API
  • Use per expert global scaling factor for gemm1 and gemm2.

Before the change:
INFO:lm_eval.loggers.evaluation_tracker:Output path not provided, skipping saving results aggregated
vllm (pretrained=nvidia/DeepSeek-R1-FP4,quantization=modelopt_fp4,tensor_parallel_size=4,enforce_eager=True,max_model_len=2048,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto

Tasks Version Filter n-shot Metric Value Stderr
gsm8k 3 flexible-extract 5 exact_match 0.8734 ± 0.0092
strict-match 5 exact_match 0.8726 ± 0.0092

After:
INFO:lm_eval.loggers.evaluation_tracker:Output path not provided, skipping saving results aggregated
vllm (pretrained=nvidia/DeepSeek-R1-FP4,quantization=modelopt_fp4,tensor_parallel_size=4,enforce_eager=True,max_model_len=2048,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto

Tasks Version Filter n-shot Metric Value Stderr
gsm8k 3 flexible-extract 5 exact_match 0.9409 ± 0.0065
strict-match 5 exact_match 0.9378 ± 0.0067

Purpose

Test Plan

Test Result

(Optional) Documentation Update

Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request updates the FlashInfer CUTLASS MoE kernel usage. The changes mostly involve renaming block_scale_interleave to nvfp4_block_scale_interleave to align with an API update. However, there's a potential issue in how activation scales are handled for FP4 MoE layers. A per-expert scale tensor is being incorrectly used to quantize the entire input activation tensor, which requires a single scalar scale. I've suggested a fix to restore the correct behavior.

Comment on lines +1257 to +1258
a1_gscale = layer.w13_input_scale_quant
a2_gscale = layer.w2_input_scale_quant
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The change from torch.min(layer.w13_input_scale_quant) to layer.w13_input_scale_quant may be incorrect. The a1_gscale variable is used in extra_prepare_args and passed to FlashInferCutlassMoEPrepareAndFinalize.prepare. This method uses the scale to quantize the input activation tensor (hidden_states) before the tokens are routed to the experts. The input activation tensor has a shape of (num_tokens, hidden_dim). To quantize this tensor, a single scalar scale is required. The original code, torch.min(layer.w13_input_scale_quant), correctly produced this scalar by selecting the most conservative scale among all per-expert scales. The new code passes layer.w13_input_scale_quant, which is a tensor of per-expert scales with shape (num_experts,). Using per-expert scales to quantize the entire activation tensor before expert routing is logically incorrect and may cause a runtime error or incorrect results from flashinfer.fp4_quantize. Please revert this change to use torch.min to ensure a correct scalar scale is used for the initial input quantization.

Suggested change
a1_gscale = layer.w13_input_scale_quant
a2_gscale = layer.w2_input_scale_quant
a1_gscale = torch.min(layer.w13_input_scale_quant)
a2_gscale = torch.min(layer.w2_input_scale_quant)

@wenscarl wenscarl marked this pull request as ready for review July 23, 2025 17:05
Copy link
Collaborator

@alexm-redhat alexm-redhat left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Verified locally that works, LGTM!

@alexm-redhat alexm-redhat enabled auto-merge (squash) July 23, 2025 17:08
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Jul 23, 2025
auto-merge was automatically disabled July 23, 2025 17:13

Head branch was pushed to by a user without write access

@wenscarl wenscarl force-pushed the flashinfer_moe_update branch from 54ac27d to f0b0755 Compare July 23, 2025 17:13
@alexm-redhat alexm-redhat enabled auto-merge (squash) July 23, 2025 17:22
Signed-off-by: Shu Wang. <[email protected]>
auto-merge was automatically disabled July 23, 2025 17:29

Head branch was pushed to by a user without write access

@alexm-redhat alexm-redhat enabled auto-merge (squash) July 24, 2025 14:36
@vllm-bot vllm-bot merged commit 1b25f1f into vllm-project:main Jul 24, 2025
72 of 74 checks passed
@mgoin mgoin mentioned this pull request Jul 24, 2025
4 tasks
x22x22 pushed a commit to x22x22/vllm that referenced this pull request Aug 5, 2025
Pradyun92 pushed a commit to Pradyun92/vllm that referenced this pull request Aug 6, 2025
npanpaliya pushed a commit to odh-on-pz/vllm-upstream that referenced this pull request Aug 6, 2025
jinzhen-lin pushed a commit to jinzhen-lin/vllm that referenced this pull request Aug 9, 2025
paulpak58 pushed a commit to paulpak58/vllm that referenced this pull request Aug 13, 2025
taneem-ibrahim pushed a commit to taneem-ibrahim/vllm that referenced this pull request Aug 14, 2025
BoyuanFeng pushed a commit to BoyuanFeng/vllm that referenced this pull request Aug 14, 2025
diegocastanibm pushed a commit to diegocastanibm/vllm that referenced this pull request Aug 15, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants