[NVIDIA] Explicitly disable shuffled weights for flashinfer blockscale moe fp8 kernels #21411

Merged · 1 commit · Jul 26, 2025
1 change: 1 addition & 0 deletions — vllm/model_executor/layers/fused_moe/fused_moe.py

    @@ -1127,6 +1127,7 @@ def flashinfer_fused_moe_blockscale_fp8(
             tile_tokens_dim=_get_tile_tokens_dim(x.shape[0], top_k,
                                                  global_num_experts),
             routing_method_type=2,  # DeepSeek-styled routing method
    +        use_shuffled_weight=False,
         )
Contributor · critical

This change explicitly adds the use_shuffled_weight argument, which was introduced in FlashInfer v0.2.9. This will cause a TypeError for users with older versions of FlashInfer, breaking backward compatibility.
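As a minimal, hypothetical repro of the failure mode (the stand-in function below mimics a pre-v0.2.9 kernel signature; it is not the real FlashInfer API):

    def trtllm_fp8_block_scale_moe(routing_logits, hidden_states, top_k):
        """Stand-in for a pre-v0.2.9 kernel signature without the new kwarg."""
        return hidden_states

    # The new keyword is rejected by the old signature at call time:
    trtllm_fp8_block_scale_moe(routing_logits=None, hidden_states=None,
                               top_k=8, use_shuffled_weight=False)
    # TypeError: trtllm_fp8_block_scale_moe() got an unexpected
    # keyword argument 'use_shuffled_weight'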

To fix this, we should only pass the argument if the installed FlashInfer version supports it. This can be done with a version check, which requires refactoring the function call to use a kwargs dictionary.

Here is a suggested implementation to replace lines 1097-1131:

    from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe

    a_q, a_sf = per_token_group_quant_fp8(x, block_shape[1])
    # NOTE: scales of hidden states have to be transposed!
    a_sf_t = a_sf.t().contiguous()

    kwargs = dict(
        routing_logits=routing_logits,
        routing_bias=routing_bias,
        hidden_states=a_q,
        hidden_states_scale=a_sf_t,
        gemm1_weights=w13_weight,
        gemm1_weights_scale=w13_weight_scale_inv,
        gemm2_weights=w2_weight,
        gemm2_weights_scale=w2_weight_scale_inv,
        num_experts=global_num_experts,
        top_k=top_k,
        n_group=num_expert_group,
        topk_group=topk_group,
        intermediate_size=intermediate_size,
        local_expert_offset=expert_offset,
        local_num_experts=local_num_experts,
        routed_scaling_factor=routed_scaling,
        tile_tokens_dim=_get_tile_tokens_dim(x.shape[0], top_k,
                                             global_num_experts),
        routing_method_type=2,  # DeepSeek-styled routing method
    )

    try:
        import flashinfer
        from packaging.version import Version
        # The use_shuffled_weight argument was added in flashinfer v0.2.9
        if Version(flashinfer.__version__) >= Version("0.2.9"):
            kwargs["use_shuffled_weight"] = False
    except (ImportError, AttributeError):
        # Older flashinfer version or flashinfer not installed.
        # The lazy loader will handle the ImportError later if it's missing.
        pass

    return flashinfer_trtllm_fp8_block_scale_moe(**kwargs)
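A possible refinement, sketched under the assumption that flashinfer exposes `__version__`: hoist the version check to module import time so it runs once rather than on every forward pass (the `_FI_SUPPORTS_SHUFFLED_WEIGHT` flag name is hypothetical):

    from packaging.version import Version

    try:
        import flashinfer
        # Hypothetical module-level flag: evaluate the version gate once at
        # import time instead of on every call into the MoE kernel.
        _FI_SUPPORTS_SHUFFLED_WEIGHT = (
            Version(flashinfer.__version__) >= Version("0.2.9"))
    except (ImportError, AttributeError):
        # FlashInfer missing, or too old to report a version.
        _FI_SUPPORTS_SHUFFLED_WEIGHT = False

With this in place, the per-call try/except inside flashinfer_fused_moe_blockscale_fp8 reduces to a single `if _FI_SUPPORTS_SHUFFLED_WEIGHT: kwargs["use_shuffled_weight"] = False`.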
