
Commit fde623f

amirkl94 authored and mgoin committed

[NVIDIA] Add SM100 Flashinfer MoE per tensor scale fp8 backend (vllm-project#21458)

Signed-off-by: Amir Klein <[email protected]>
Signed-off-by: mgoin <[email protected]>
Co-authored-by: mgoin <[email protected]>
Signed-off-by: Noam Gat <[email protected]>

1 parent fa63275 commit fde623f

File tree: 5 files changed, +269 −49 lines

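Before the per-file diffs, a hypothetical end-to-end sketch of how this backend would be switched on. The VLLM_USE_FLASHINFER_MOE_FP8 environment variable and the has_flashinfer_moe() gate are taken from the diffs below; the checkpoint name, and setting the variable from Python rather than the shell, are assumptions.

```python
# Hypothetical usage sketch (not part of this commit). Assumes an SM100 GPU,
# a FlashInfer build with MoE support, and a ModelOpt-style FP8 per-tensor
# MoE checkpoint; the model name below is a placeholder.
import os

os.environ["VLLM_USE_FLASHINFER_MOE_FP8"] = "1"  # checked via vllm.envs in the diffs below

from vllm import LLM

llm = LLM(model="some-org/some-modelopt-fp8-moe-checkpoint")  # placeholder name
print(llm.generate("Hello, world!")[0].outputs[0].text)
```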

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 95 additions & 18 deletions
@@ -30,6 +30,8 @@
     TopKWeightAndReduceNoOP)
 from vllm.model_executor.layers.fused_moe.utils import (
     _resize_cache, moe_kernel_quantize_input, per_token_group_quant_fp8)
+from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
+    calculate_tile_tokens_dim)
 from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
     dequant_mxfp4)
 from vllm.platforms import current_platform
@@ -1065,22 +1067,6 @@ def inplace_fused_experts_fake(
     )


-def next_positive_power_of_2(x: int) -> int:
-    if x < 1:
-        return 1
-    return 1 << (x - 1).bit_length()
-
-
-def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
-    # Guess tokens per expert assuming perfect expert distribution first.
-    num_tokens_per_expert = (num_tokens * top_k) // num_experts
-    # And pad the number to the next power of 2.
-    tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
-    # Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
-    tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
-    return tile_tokens_dim
-
-
 def flashinfer_fused_moe_blockscale_fp8(
         routing_logits: torch.Tensor,
         routing_bias: torch.Tensor,
@@ -1128,8 +1114,8 @@ def flashinfer_fused_moe_blockscale_fp8(
         local_expert_offset=expert_offset,
         local_num_experts=local_num_experts,
         routed_scaling_factor=routed_scaling,
-        tile_tokens_dim=_get_tile_tokens_dim(x.shape[0], top_k,
-                                             global_num_experts),
+        tile_tokens_dim=calculate_tile_tokens_dim(x.shape[0], top_k,
+                                                  global_num_experts),
         routing_method_type=2,  # DeepSeek-styled routing method
         use_shuffled_weight=False,
     )
@@ -1164,6 +1150,97 @@ def flashinfer_fused_moe_blockscale_fp8_fake(
     )


+def flashinfer_fused_moe_per_tensor_scale_fp8(
+        routing_logits: torch.Tensor,
+        routing_bias: Optional[torch.Tensor],
+        hidden_states: torch.Tensor,
+        input_scale: torch.Tensor,
+        gemm1_weights: torch.Tensor,
+        gemm1_weights_scale: torch.Tensor,
+        activation_scale: torch.Tensor,
+        gemm2_weights: torch.Tensor,
+        gemm2_weights_scale: torch.Tensor,
+        num_experts: int,
+        top_k: int,
+        num_expert_group: Optional[int],
+        topk_group: Optional[int],
+        intermediate_size: int,
+        local_expert_offset: int,
+        local_num_experts: int,
+        use_routing_scales_on_input: bool,
+        routing_method_type: int,
+        routed_scaling_factor: float = 1.0) -> torch.Tensor:
+    num_expert_group = num_expert_group if num_expert_group is not None else 0
+    topk_group = topk_group if topk_group is not None else 0
+
+    quant_hidden_states, input_scale = moe_kernel_quantize_input(
+        hidden_states,
+        input_scale,
+        quant_dtype=torch.float8_e4m3fn,
+        per_act_token_quant=False)
+
+    output1_scales_scalar = gemm1_weights_scale * input_scale * (
+        1.0 / activation_scale)
+    output1_scales_gate_scalar = gemm1_weights_scale * input_scale
+    output2_scales_scalar = activation_scale * gemm2_weights_scale
+
+    from vllm.utils.flashinfer import (
+        flashinfer_trtllm_fp8_per_tensor_scale_moe)
+    return flashinfer_trtllm_fp8_per_tensor_scale_moe(
+        routing_logits=routing_logits,
+        routing_bias=routing_bias,
+        hidden_states=quant_hidden_states,
+        gemm1_weights=gemm1_weights,
+        output1_scales_scalar=output1_scales_scalar,
+        output1_scales_gate_scalar=output1_scales_gate_scalar,
+        gemm2_weights=gemm2_weights,
+        output2_scales_scalar=output2_scales_scalar,
+        num_experts=num_experts,
+        top_k=top_k,
+        n_group=num_expert_group,
+        topk_group=topk_group,
+        intermediate_size=intermediate_size,
+        local_expert_offset=local_expert_offset,
+        local_num_experts=local_num_experts,
+        routed_scaling_factor=routed_scaling_factor,
+        use_routing_scales_on_input=use_routing_scales_on_input,
+        tile_tokens_dim=calculate_tile_tokens_dim(hidden_states.shape[0],
+                                                  top_k, num_experts),
+        routing_method_type=routing_method_type)
+
+
+def flashinfer_fused_moe_per_tensor_scale_fp8_fake(
+        routing_logits: torch.Tensor,
+        routing_bias: torch.Tensor,
+        hidden_states: torch.Tensor,
+        gemm1_weights: torch.Tensor,
+        output1_scales_scalar: torch.Tensor,
+        output1_scales_gate_scalar: torch.Tensor,
+        gemm2_weights: torch.Tensor,
+        output2_scales_scalar: torch.Tensor,
+        num_experts: int,
+        top_k: int,
+        num_expert_group: int,
+        topk_group: int,
+        intermediate_size: int,
+        local_expert_offset: int,
+        local_num_experts: int,
+        routed_scaling_factor: float = 1.0,
+        use_routing_scales_on_input: bool = False,
+        tile_tokens_dim: int = 8,
+        routing_method_type: int = 0) -> torch.Tensor:
+    pass
+
+
+direct_register_custom_op(
+    op_name="flashinfer_fused_moe_per_tensor_scale_fp8",
+    op_func=flashinfer_fused_moe_per_tensor_scale_fp8,
+    mutates_args=["hidden_states"],
+    fake_impl=flashinfer_fused_moe_per_tensor_scale_fp8_fake,
+    tags=(torch.Tag.needs_fixed_stride_order, ),
+)
+
+
 def outplace_fused_experts(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,

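The tile-size heuristic that previously lived inline here (next_positive_power_of_2 and _get_tile_tokens_dim, removed above) is now shared through calculate_tile_tokens_dim in flashinfer_utils, so both the block-scale op and the new per-tensor op can use it. A minimal standalone sketch of that heuristic, assuming the relocated helper keeps the same logic as the removed code:

```python
def next_positive_power_of_2(x: int) -> int:
    if x < 1:
        return 1
    return 1 << (x - 1).bit_length()


def tile_tokens_dim(num_tokens: int, top_k: int, num_experts: int) -> int:
    # Guess tokens per expert assuming a perfectly even expert distribution,
    # round up to the next power of two, then clamp to the 8-64 range the
    # kernel supports (per the comments in the removed helper above).
    num_tokens_per_expert = (num_tokens * top_k) // num_experts
    return min(max(next_positive_power_of_2(num_tokens_per_expert), 8), 64)


# E.g. 256 tokens routed top-2 over 64 experts -> 8 tokens/expert -> tile 8.
assert tile_tokens_dim(256, 2, 64) == 8
```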
vllm/model_executor/layers/quantization/fp8.py

Lines changed: 44 additions & 31 deletions
@@ -23,6 +23,9 @@
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
+from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
+    apply_flashinfer_per_tensor_scale_fp8, rotate_flashinfer_fp8_moe_weights,
+    swap_w13_to_w31)
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     get_col_major_tma_aligned_tensor, requant_weight_ue8m0_inplace)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
@@ -53,11 +56,6 @@
 logger = init_logger(__name__)


-def _swap_w13_to_w31(x: torch.Tensor) -> torch.Tensor:
-    return x.reshape(-1, 2, x.shape[-2] // 2,
-                     x.shape[-1]).flip(dims=[1]).reshape(x.shape)
-
-
 def _is_col_major(x: torch.Tensor) -> bool:
     assert x.dim() == 3
     b, m, n = x.shape
@@ -695,11 +693,13 @@ def process_weights_after_loading(self, layer: Module) -> None:
         elif self.flashinfer_moe_enabled:
             # NOTE: weights have to be swapped since the activation is
             # applied on different half for flashinfer vs vllm
-            w13_weight = _swap_w13_to_w31(layer.w13_weight.data)
-            w13_weight_scale_inv = _swap_w13_to_w31(
+            w13_weight = swap_w13_to_w31(layer.w13_weight.data)
+            w13_weight_scale_inv = swap_w13_to_w31(
                 layer.w13_weight_scale_inv.data)
             w2_weight = layer.w2_weight.data
             w2_weight_scale_inv = layer.w2_weight_scale_inv.data
+            if not self.block_quant:
+                rotate_flashinfer_fp8_moe_weights(w13_weight, w2_weight)
         else:
             w13_weight = layer.w13_weight.data
             w13_weight_scale_inv = layer.w13_weight_scale_inv.data
@@ -998,30 +998,43 @@ def apply(
                 global_num_experts=global_num_experts,
                 expert_map=expert_map)
         elif self.flashinfer_moe_enabled:
-            # Currently only work with DS models
-            assert self.block_quant
-            assert (renormalize and use_grouped_topk
-                    and scoring_func == 'sigmoid'
-                    and custom_routing_function is None)
-            assert activation == "silu"
-            return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
-                routing_logits=router_logits.to(torch.float32),
-                routing_bias=e_score_correction_bias,
-                x=x,
-                w13_weight=layer.w13_weight,
-                w13_weight_scale_inv=layer.w13_weight_scale_inv,
-                w2_weight=layer.w2_weight,
-                w2_weight_scale_inv=layer.w2_weight_scale_inv,
-                global_num_experts=global_num_experts,
-                top_k=top_k,
-                num_expert_group=num_expert_group,
-                topk_group=topk_group,
-                intermediate_size=layer.intermediate_size_per_partition,
-                expert_offset=layer.ep_rank * layer.local_num_experts,
-                local_num_experts=layer.local_num_experts,
-                block_shape=self.quant_config.weight_block_size,
-                routed_scaling=1.0,
-            )
+            assert activation == 'silu'
+            assert scoring_func == 'sigmoid'
+            if self.block_quant:
+                assert (renormalize and use_grouped_topk
+                        and custom_routing_function is None)
+
+                return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
+                    routing_logits=router_logits.to(torch.float32),
+                    routing_bias=e_score_correction_bias,
+                    x=x,
+                    w13_weight=layer.w13_weight,
+                    w13_weight_scale_inv=layer.w13_weight_scale_inv,
+                    w2_weight=layer.w2_weight,
+                    w2_weight_scale_inv=layer.w2_weight_scale_inv,
+                    global_num_experts=global_num_experts,
+                    top_k=top_k,
+                    num_expert_group=num_expert_group,
+                    topk_group=topk_group,
+                    intermediate_size=layer.intermediate_size_per_partition,
+                    expert_offset=layer.ep_rank * layer.local_num_experts,
+                    local_num_experts=layer.local_num_experts,
+                    block_shape=self.quant_config.weight_block_size,
+                    routed_scaling=1.0,
+                )
+            else:
+                assert (not renormalize
+                        and custom_routing_function is not None)
+                return apply_flashinfer_per_tensor_scale_fp8(
+                    layer=layer,
+                    hidden_states=x,
+                    router_logits=router_logits,
+                    routing_bias=e_score_correction_bias,
+                    global_num_experts=global_num_experts,
+                    top_k=top_k,
+                    num_expert_group=num_expert_group,
+                    topk_group=topk_group,
+                    apply_router_weight_on_input=apply_router_weight_on_input)
         else:
             return self.fused_experts(
                 hidden_states=x,

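fp8.py now imports swap_w13_to_w31 from flashinfer_utils instead of its private _swap_w13_to_w31 (removed above). A minimal sketch, assuming the shared helper keeps the same reshape/flip trick as the removed code, with a small check illustrating the effect:

```python
import torch


def swap_w13_to_w31(x: torch.Tensor) -> torch.Tensor:
    # Split the fused w13 expert dimension into its two halves, flip their
    # order, and restore the original shape. Per the NOTE in
    # process_weights_after_loading above, the swap is needed because the
    # activation is applied on a different half for FlashInfer vs vLLM.
    return x.reshape(-1, 2, x.shape[-2] // 2,
                     x.shape[-1]).flip(dims=[1]).reshape(x.shape)


# Sanity check on a dummy (num_experts, 2 * intermediate, hidden) weight.
w13 = torch.arange(2 * 8 * 4).reshape(2, 8, 4).float()
swapped = swap_w13_to_w31(w13)
assert torch.equal(swapped[:, :4], w13[:, 4:])  # second half moved to the front
assert torch.equal(swapped[:, 4:], w13[:, :4])  # first half moved to the back
```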
vllm/model_executor/layers/quantization/modelopt.py

Lines changed: 28 additions & 0 deletions
@@ -23,6 +23,9 @@
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
+from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
+    apply_flashinfer_per_tensor_scale_fp8, rotate_flashinfer_fp8_moe_weights,
+    swap_w13_to_w31)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
     apply_fp4_marlin_linear, is_fp4_marlin_supported,
     prepare_fp4_layer_for_marlin, prepare_moe_fp4_layer_for_marlin)
@@ -34,6 +37,7 @@
     PerTensorScaleParameter)
 from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types
+from vllm.utils.flashinfer import has_flashinfer_moe

 logger = init_logger(__name__)

@@ -267,6 +271,11 @@ def __init__(self, quant_config: ModelOptFp8Config):
         from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
             cutlass_fp8_supported)
         self.cutlass_fp8_supported = cutlass_fp8_supported()
+        self.flashinfer_moe_enabled = False
+        if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
+            logger.info_once(
+                "Using FlashInfer MoE FP8 kernels for ModelOptFp8MoEMethod.")
+            self.flashinfer_moe_enabled = True

     def create_weights(
         self,
@@ -410,6 +419,11 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             layer.w2_input_scale = Parameter(layer.w2_input_scale.max(),
                                              requires_grad=False)

+        if self.flashinfer_moe_enabled:
+            layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
+            rotate_flashinfer_fp8_moe_weights(layer.w13_weight,
+                                              layer.w2_weight)
+
     def apply(
         self,
         layer: torch.nn.Module,
@@ -436,6 +450,20 @@ def apply(
             raise NotImplementedError(
                 "EPLB not supported for `ModelOptFp8MoEMethod` yet.")

+        if self.flashinfer_moe_enabled:
+            assert activation == 'silu'
+            assert not renormalize
+            return apply_flashinfer_per_tensor_scale_fp8(
+                layer=layer,
+                hidden_states=x,
+                router_logits=router_logits,
+                routing_bias=e_score_correction_bias,
+                global_num_experts=global_num_experts,
+                top_k=top_k,
+                num_expert_group=num_expert_group,
+                topk_group=topk_group,
+                apply_router_weight_on_input=apply_router_weight_on_input)
+
         # Expert selection
         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,

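Both the fp8.py and modelopt.py branches route through apply_flashinfer_per_tensor_scale_fp8, which presumably ends up in the new flashinfer_fused_moe_per_tensor_scale_fp8 op added to fused_moe.py above. A minimal sketch of how that op folds the per-tensor scales before calling the TRT-LLM kernel; the formulas mirror the op body, the numeric values are illustrative, and the 1 / activation_scale factor presumably accounts for requantizing gemm1's output to FP8 for gemm2:

```python
import torch

# Illustrative per-tensor scales for an FP8 per-tensor checkpoint.
input_scale = torch.tensor(0.02)         # quant scale of the gemm1 input activations
gemm1_weights_scale = torch.tensor(0.01)
activation_scale = torch.tensor(0.05)    # quant scale of the intermediate activations
gemm2_weights_scale = torch.tensor(0.03)

# Same algebra as in flashinfer_fused_moe_per_tensor_scale_fp8 above:
# gemm1's output scale folds in 1 / activation_scale, the gate scale omits
# that factor, and gemm2's output scale combines the activation and w2 scales.
output1_scales_scalar = gemm1_weights_scale * input_scale * (1.0 / activation_scale)
output1_scales_gate_scalar = gemm1_weights_scale * input_scale
output2_scales_scalar = activation_scale * gemm2_weights_scale

print(output1_scales_scalar.item(),        # ~0.004
      output1_scales_gate_scalar.item(),   # ~0.0002
      output2_scales_scalar.item())        # ~0.0015
```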