[NVIDIA] Explicitly disable shuffled weights for flashinfer blockscale moe fp8 kernels #21411

Merged: 1 commit merged into vllm-project:main on Jul 26, 2025

Conversation

kaixih
Contributor

@kaixih kaixih commented Jul 22, 2025

The latest FlashInfer (PR) introduces a new use_shuffled_weight flag to the trtllm_fp8_block_scale_moe API, which defaults to True. This PR explicitly sets it to False to restore the previous behavior.

I have verified the performance and accuracy with the top-of-tree build, and we recommend using FlashInfer v0.2.9.

cc. @kushanam @mgoin


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only fastcheck CI runs, covering a small and essential subset of tests to quickly catch errors. You can run additional CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request aims to restore the previous behavior of the FlashInfer MoE kernels by explicitly disabling the new use_shuffled_weight flag. While the change is correct in its intent, it introduces a critical backward compatibility issue for users with older versions of FlashInfer. I've provided a comment with a suggested fix to address this.

@@ -1127,6 +1127,7 @@ def flashinfer_fused_moe_blockscale_fp8(
         tile_tokens_dim=_get_tile_tokens_dim(x.shape[0], top_k,
                                              global_num_experts),
         routing_method_type=2,  # DeepSeek-styled routing method
+        use_shuffled_weight=False,
Contributor


critical

This change explicitly adds the use_shuffled_weight argument, which was introduced in FlashInfer v0.2.9. This will cause a TypeError for users with older versions of FlashInfer, breaking backward compatibility.

To fix this, we should only pass the argument if the installed FlashInfer version supports it. This can be done with a version check, which requires refactoring the function call to use a kwargs dictionary.

Here is a suggested implementation to replace lines 1097-1131:

    from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe

    a_q, a_sf = per_token_group_quant_fp8(x, block_shape[1])
    # NOTE: scales of hidden states have to be transposed!
    a_sf_t = a_sf.t().contiguous()

    kwargs = dict(
        routing_logits=routing_logits,
        routing_bias=routing_bias,
        hidden_states=a_q,
        hidden_states_scale=a_sf_t,
        gemm1_weights=w13_weight,
        gemm1_weights_scale=w13_weight_scale_inv,
        gemm2_weights=w2_weight,
        gemm2_weights_scale=w2_weight_scale_inv,
        num_experts=global_num_experts,
        top_k=top_k,
        n_group=num_expert_group,
        topk_group=topk_group,
        intermediate_size=intermediate_size,
        local_expert_offset=expert_offset,
        local_num_experts=local_num_experts,
        routed_scaling_factor=routed_scaling,
        tile_tokens_dim=_get_tile_tokens_dim(x.shape[0], top_k,
                                             global_num_experts),
        routing_method_type=2,  # DeepSeek-styled routing method
    )

    try:
        import flashinfer
        from packaging.version import Version
        # The use_shuffled_weight argument was added in flashinfer v0.2.9
        if Version(flashinfer.__version__) >= Version("0.2.9"):
            kwargs["use_shuffled_weight"] = False
    except (ImportError, AttributeError):
        # Older flashinfer version or flashinfer not installed.
        # The lazy loader will handle the ImportError later if it's missing.
        pass

    return flashinfer_trtllm_fp8_block_scale_moe(**kwargs)
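
An alternative to checking the version string would be to introspect the callable's signature before passing the new keyword. A minimal, generic sketch (illustrative only, not part of this PR; note it will not see through *args/**kwargs wrappers such as lazy loaders):

    import inspect
    from typing import Callable

    def supports_kwarg(fn: Callable, name: str) -> bool:
        # True if `fn` explicitly declares a parameter called `name`.
        try:
            return name in inspect.signature(fn).parameters
        except (TypeError, ValueError):
            # Some builtins / C extensions do not expose a signature.
            return False

    # Hypothetical usage with the kwargs dict built above:
    # if supports_kwarg(flashinfer_trtllm_fp8_block_scale_moe, "use_shuffled_weight"):
    #     kwargs["use_shuffled_weight"] = False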

@mgoin
Member

mgoin commented Jul 22, 2025

Thanks Kaixi, should we wait for the 0.2.9 release to land this?

@kaixih
Contributor Author

kaixih commented Jul 23, 2025

@mgoin Sure.

@kaixih
Contributor Author

kaixih commented Jul 24, 2025

@mgoin FlashInfer has released 0.2.9rc1. I took a quick look at the vLLM codebase and noticed that only the Dockerfile explicitly references the FlashInfer version. I'm not sure how vLLM determines or enforces the FlashInfer version elsewhere. Can you advise?

@mgoin
Member

mgoin commented Jul 24, 2025

@kaixih Let's wait on landing this PR until after #21485 then, as it already updates the Dockerfile. We will enforce the version in the future by adding it to requirements/cuda.txt as well.
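
For illustration, such a pin might look like the following (the package name flashinfer-python and the version bound are assumptions, not part of this PR):

    # hypothetical entry in requirements/cuda.txt
    flashinfer-python>=0.2.9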

@kaixih
Contributor Author

kaixih commented Jul 25, 2025

@mgoin Can we merge this PR now that FlashInfer 0.2.9rc1 is in?

@mgoin mgoin enabled auto-merge (squash) July 25, 2025 22:46
@github-actions github-actions bot added the ready (ONLY add when PR is ready to merge/full CI is needed) label Jul 25, 2025
@vllm-bot vllm-bot merged commit de509ae into vllm-project:main Jul 26, 2025
70 of 73 checks passed
liuyumoye pushed a commit to liuyumoye/vllm that referenced this pull request Jul 31, 2025
HsChen-sys pushed a commit to HsChen-sys/vllm that referenced this pull request Aug 1, 2025
wenscarl pushed a commit to wenscarl/vllm that referenced this pull request Aug 4, 2025
x22x22 pushed a commit to x22x22/vllm that referenced this pull request Aug 5, 2025
Pradyun92 pushed a commit to Pradyun92/vllm that referenced this pull request Aug 6, 2025
npanpaliya pushed a commit to odh-on-pz/vllm-upstream that referenced this pull request Aug 6, 2025
jinzhen-lin pushed a commit to jinzhen-lin/vllm that referenced this pull request Aug 9, 2025
paulpak58 pushed a commit to paulpak58/vllm that referenced this pull request Aug 13, 2025
taneem-ibrahim pushed a commit to taneem-ibrahim/vllm that referenced this pull request Aug 14, 2025
BoyuanFeng pushed a commit to BoyuanFeng/vllm that referenced this pull request Aug 14, 2025
diegocastanibm pushed a commit to diegocastanibm/vllm that referenced this pull request Aug 15, 2025