[NVIDIA] Add SM100 Flashinfer MoE per tensor scale fp8 backend #21458
Conversation
Code Review
This pull request introduces a new FlashInfer backend for per-tensor scaled FP8 Mixture of Experts (MoE), which shows promising performance improvements on SM100 architectures. The changes include adding a new custom operator, refactoring some utility functions into a shared module, and updating the quantization layers to use this new backend.
The code is generally well-structured, and the refactoring of utility functions into flashinfer_utils.py is a good step towards better code organization.
However, there are a couple of areas that could be improved for better maintainability and potentially better performance:
- There is significant code duplication in the logic that invokes the new MoE kernel from both Fp8MoEMethod and ModelOptFp8MoEMethod. This should be refactored into a shared helper function.
- The tile_tokens_dim parameter for the new kernel is hardcoded, which might not be optimal for all workloads and differs from the dynamic approach used in the existing block-scale kernel (see the sketch after this summary).
Addressing these points will enhance the quality and robustness of the new backend.
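For reference, here is a minimal sketch of how tile_tokens_dim could be derived dynamically from the routing shape rather than hardcoded; the function name, rounding rule, and clamp bounds are assumptions, not the PR's actual helper:
# Hypothetical sketch: pick the tile size from the expected tokens per expert
# instead of hard-coding it. The clamp bounds are assumptions.
def compute_tile_tokens_dim(num_tokens: int, top_k: int, num_experts: int) -> int:
    # Average number of tokens routed to each expert.
    tokens_per_expert = max(1, (num_tokens * top_k) // num_experts)
    # Round up to the next power of two so the kernel gets an aligned tile.
    tile = 1 << (tokens_per_expert - 1).bit_length()
    # Clamp to a range of tile sizes the kernel is assumed to support.
    return min(max(tile, 8), 64)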
return torch.ops.vllm.flashinfer_fused_moe_per_tensor_scale_fp8(
    routing_logits=router_logits,
    routing_bias=e_score_correction_bias,
    hidden_states=x,
    input_scale=layer.w13_input_scale,
    gemm1_weights=layer.w13_weight,
    gemm1_weights_scale=layer.w13_weight_scale,
    gemm2_weights=layer.w2_weight,
    gemm2_weights_scale=layer.w2_weight_scale,
    activation_scale=layer.w2_input_scale,
    num_experts=global_num_experts,
    top_k=top_k,
    num_expert_group=num_expert_group,
    topk_group=topk_group,
    intermediate_size=layer.intermediate_size_per_partition,
    local_expert_offset=layer.ep_rank * layer.local_num_experts,
    local_num_experts=layer.local_num_experts,
    use_routing_scales_on_input=apply_router_weight_on_input,
)
There appears to be significant code duplication here. The logic inside this if self.flashinfer_moe_enabled: block is nearly identical to the logic in vllm/model_executor/layers/quantization/fp8.py (lines 993-1016).
Duplicating this code block makes future maintenance harder, as changes would need to be applied in two places.
To improve maintainability, I suggest refactoring this shared logic into a common helper function. This function could be placed in a utility module, perhaps vllm/model_executor/layers/quantization/utils/flashinfer_utils.py, and called from both Fp8MoEMethod.apply and ModelOptFp8MoEMethod.apply.
For example, you could create a helper like this:
# In a shared utility file
from typing import Optional

import torch


def apply_flashinfer_per_tensor_scale_fp8(
    layer: torch.nn.Module,
    x: torch.Tensor,
    router_logits: torch.Tensor,
    e_score_correction_bias: Optional[torch.Tensor],
    top_k: int,
    num_expert_group: Optional[int],
    topk_group: Optional[int],
    global_num_experts: int,
    apply_router_weight_on_input: bool,
) -> torch.Tensor:
    return torch.ops.vllm.flashinfer_fused_moe_per_tensor_scale_fp8(
        routing_logits=router_logits,
        routing_bias=e_score_correction_bias,
        hidden_states=x,
        input_scale=layer.w13_input_scale,
        gemm1_weights=layer.w13_weight,
        gemm1_weights_scale=layer.w13_weight_scale,
        gemm2_weights=layer.w2_weight,
        gemm2_weights_scale=layer.w2_weight_scale,
        activation_scale=layer.w2_input_scale,
        num_experts=global_num_experts,
        top_k=top_k,
        num_expert_group=num_expert_group,
        topk_group=topk_group,
        intermediate_size=layer.intermediate_size_per_partition,
        local_expert_offset=layer.ep_rank * layer.local_num_experts,
        local_num_experts=layer.local_num_experts,
        use_routing_scales_on_input=apply_router_weight_on_input,
    )
This would centralize the logic and make the code cleaner and easier to maintain.
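A possible call site from either apply method might then look like the following; the flashinfer_moe_enabled guard mirrors the flag already quoted in this review, and the surrounding variable names are assumed to match the existing apply signatures:
# Hypothetical call site inside Fp8MoEMethod.apply / ModelOptFp8MoEMethod.apply.
if self.flashinfer_moe_enabled:
    return apply_flashinfer_per_tensor_scale_fp8(
        layer=layer,
        x=x,
        router_logits=router_logits,
        e_score_correction_bias=e_score_correction_bias,
        top_k=top_k,
        num_expert_group=num_expert_group,
        topk_group=topk_group,
        global_num_experts=global_num_experts,
        apply_router_weight_on_input=apply_router_weight_on_input,
    )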
I would like this utility to be implemented to help avoid drift between the two call sites
    routing_method_type: int = 3  # Llama4-styled routing method
) -> torch.Tensor:
    if routing_bias is None:
        routing_bias = torch.zeros(num_experts,
I am a little worried about this line breaking CUDA graph capture because we are creating a new tensor on the fly. Should we create this zero bias in the caller instead? Or maybe ask FlashInfer to support routing_bias=None so that we don't need to pass in a fake bias.
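One way to avoid the on-the-fly allocation would be to create the zero bias once and reuse it across forward passes; a minimal sketch, where the helper name and the _zero_routing_bias attribute are hypothetical:
import torch

# Hypothetical sketch: allocate the fake zero bias once (e.g. during weight
# post-processing) and reuse it, so no tensor is created while a CUDA graph
# is being captured.
def ensure_zero_routing_bias(layer: torch.nn.Module, num_experts: int,
                             device: torch.device) -> torch.Tensor:
    bias = getattr(layer, "_zero_routing_bias", None)
    if bias is None or bias.numel() != num_experts:
        bias = torch.zeros(num_experts, dtype=torch.bfloat16, device=device)
        layer._zero_routing_bias = bias  # cached for subsequent calls
    return bias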
not a blocking issue for now. We will fix this later if we really see it becoming an issue.
Fair point, I think asking FlashInfer to support routing_bias=None is probably better.
FlashInfer has fixed this in 0.2.9rc2. Do you think this is a blocker? If not, I prefer that we merge this PR first and then file another PR after we have upgraded to FlashInfer v0.2.9rc2.
However, if you think this is a blocker, we can wait for the FlashInfer v0.2.9rc2 upgrade, which should happen very soon.
    local_num_experts: int,
    use_routing_scales_on_input: bool,
    routed_scaling_factor: float = 1.0,
    routing_method_type: int = 3  # Llama4-styled routing method
Should we use RoutingMethodType.Llama4 instead of a hard-coded "3"?
not a blocking issue, just code style
It is better, but the issue is that if a different version of flashinfer is installed (or flashinfer isn't installed at all) we'll get an import error. I thought about doing this conversion inside the function, after we know that the correct version of flashinfer is installed, wdyt?
Or we can define our own class to mimic FlashInfer's class?
from enum import IntEnum

has_flashinfer = False
try:
    from flashinfer.fused_moe import RoutingMethodType
    has_flashinfer = True
except ImportError:
    pass

class FlashInferRoutingMethodType(IntEnum):
    # Default: Softmax -> TopK
    Default = RoutingMethodType.Default if has_flashinfer else 0
    # Renormalize: TopK -> Softmax
    Renormalize = RoutingMethodType.Renormalize if has_flashinfer else 1
    # DeepSeekV3: Sigmoid -> RoutingBiasAdd -> Top2 in group -> Top4 groups
    # -> Top8 experts from the Top4 groups
    DeepSeekV3 = RoutingMethodType.DeepSeekV3 if has_flashinfer else 2
    # Llama4: Top1 -> Sigmoid
    Llama4 = RoutingMethodType.Llama4 if has_flashinfer else 3
    # Qwen3: Softmax -> TopK -> Renormalize
    RenormalizeNaive = RoutingMethodType.RenormalizeNaive if has_flashinfer else 4
    # Unspecified
    Unspecified = RoutingMethodType.Unspecified if has_flashinfer else 5
This is not critical and can be handled in later PRs
I think we can make this class lazily imported and remove the default arg, so we only need to import it once in the function.
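A minimal sketch of the lazy-import idea, assuming RoutingMethodType stays importable from flashinfer.fused_moe as in the snippet above; the helper name is illustrative:
# Hypothetical sketch: resolve the routing method inside the op, after the
# FlashInfer import is known to succeed, instead of defaulting the argument.
def _resolve_routing_method_type(name: str) -> int:
    # Lazy import: only touch FlashInfer once we know it is installed.
    from flashinfer.fused_moe import RoutingMethodType
    return int(RoutingMethodType[name].value)
The kernel wrapper could then call _resolve_routing_method_type("Llama4") at the point of use instead of baking 3 into the signature.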
Depends on #21485
        routing_bias = torch.zeros(num_experts,
                                   dtype=torch.bfloat16,
                                   device=hidden_states.device)
    num_expert_group = num_expert_group if num_expert_group is not None else 1
num_expert_group = num_expert_group if num_expert_group is not None else 1: this should default to 0, not 1, when num_expert_group is None (see the snippet below).
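A one-line version of that fix, assuming the kernel treats 0 as "no expert grouping":
# Hypothetical fix per the comment above: default to 0 (no grouping), not 1.
num_expert_group = num_expert_group if num_expert_group is not None else 0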
    local_num_experts: int,
    use_routing_scales_on_input: bool,
    routed_scaling_factor: float = 1.0,
    routing_method_type: int = 3  # Llama4-styled routing method
I'm not a fan of defaulting this parameter if it is going to dictate model support. For instance, in the current usage of this function the parameter isn't set, but there is no check that the model actually needs Llama 4 routing, i.e. it would be silently incorrect for a Mixtral with the same quant.
@amirkl94 Maybe let's remove the default value for routing_method_type and make this arg a required argument? And from llama4.py we should pass this into fused_moe.py?
llama4 already does this by defining its own custom routing function and passing that into FusedMoE (vllm/model_executor/models/llama4.py, line 79 in e18f085):
custom_routing_function=Llama4MoE.custom_routing_function,
I suppose you could just check if custom_routing_function == Llama4MoE.custom_routing_function
I think I can't check custom_routing_function == Llama4MoE.custom_routing_function, unless you meant in llama4.py?
Should I just make this parameter optional, pass it only from llama4, and if it's not passed default to the non-flashinfer implementation?
I think what @mgoin meant is that in modelopt.py:
return apply_flashinfer_per_tensor_scale_fp8(
the layer object is just an instance of FusedMoE, so you can dispatch routing_method using:
if layer.routing_method == Llama4MoE.custom_routing_function:
    routing_method = 3
@mgoin is this what you meant?
Yes this was what I meant. Obviously not optimal, but should be okay
@mgoin Currently, FlashInfer's per-tensor FP8 MoE only supports the Llama4 routing mode, so I told @amirkl94 to assert that layer.routing_method == Llama4MoE.custom_routing_function is True. If it is not, an exception will be raised.
This is done so that in the future, if anyone wants to use FlashInfer per-tensor FP8 MoE for another model, it will fail loudly, telling the user why that is not supported. My philosophy is: a loud failure is better than a silent corruption.
Could you check if the current implementation is acceptable to you? Thanks!
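A minimal sketch of that loud-failure guard; the custom_routing_function attribute name follows the discussion above and may differ from the final code:
# Hypothetical sketch: fail loudly when a non-Llama4 routing function reaches
# the FlashInfer per-tensor FP8 MoE path.
from vllm.model_executor.models.llama4 import Llama4MoE

if layer.custom_routing_function is not Llama4MoE.custom_routing_function:
    raise NotImplementedError(
        "FlashInfer per-tensor FP8 MoE currently supports only Llama4-style "
        "routing; use the default MoE path for other routing methods.")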
pipeline failure doesn't seem to be caused by this PR:
I fixed some issues with the PR and validated accuracy + performance. I see about a 10% throughput improvement on gsm8k on 1xB200.
Will do a final review now.
The failure is:
Doesn't seem to be related to this PR.
@mgoin The CI errors seem to be unrelated to my PR, as I saw they're happening on other branches as well: https://github.com/vllm-project/vllm/pull/21747/commits
Yes, this is what I've found too. I've requested force merge, thank you.
Purpose
This PR introduces a new backend for per-tensor scaled FP8 MoE from FlashInfer. This backend gives a performance improvement as described below.
Accuracy tests
Ran manual lm_eval gsm8k, using the following command:
Results:
Perf tests
Tested on a 1xB200 GPU, using the latency benchmark:
Results: