Skip to content

Commit 872160e

Browse files
committed
Adding flashinfer fp8 FusedMoE
Signed-off-by: Amir Klein <[email protected]>
1 parent 1cbf951 commit 872160e

File tree

5 files changed

+250
-49
lines changed

5 files changed

+250
-49
lines changed

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 99 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
TopKWeightAndReduceNoOP)
3030
from vllm.model_executor.layers.fused_moe.utils import (
3131
_resize_cache, moe_kernel_quantize_input, per_token_group_quant_fp8)
32+
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
33+
calculate_tile_tokens_dim)
3234
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
3335
dequant_mxfp4)
3436
from vllm.platforms import current_platform
@@ -1061,22 +1063,6 @@ def inplace_fused_experts_fake(
10611063
)
10621064

10631065

1064-
def next_positive_power_of_2(x: int) -> int:
    """Return the smallest power of two that is >= ``x``, and at least 1.

    Any ``x`` below 1 (zero or negative) maps to 1.
    """
    # Clamping to 1 first makes the shift amount 0 for every x < 1,
    # which yields 1 without a separate branch.
    floor_one = max(x, 1)
    return 1 << (floor_one - 1).bit_length()
1068-
1069-
1070-
def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
    """Pick the per-CTA token tile size for the fused MoE kernel.

    The per-expert load is estimated as if routing were perfectly balanced,
    rounded up to a power of two, and clamped to the 8-64 range supported by
    the kernel.
    """
    # Balanced-routing estimate: every expert receives an equal share of the
    # (token, expert) assignments.
    balanced_load = (num_tokens * top_k) // num_experts
    padded = next_positive_power_of_2(balanced_load)

    # Clamp to 8-64 tokens per CTA tile, the range supported by the kernel.
    if padded < 8:
        return 8
    if padded > 64:
        return 64
    return padded
1078-
1079-
10801066
def flashinfer_fused_moe_blockscale_fp8(
10811067
routing_logits: torch.Tensor,
10821068
routing_bias: torch.Tensor,
@@ -1124,8 +1110,8 @@ def flashinfer_fused_moe_blockscale_fp8(
11241110
local_expert_offset=expert_offset,
11251111
local_num_experts=local_num_experts,
11261112
routed_scaling_factor=routed_scaling,
1127-
tile_tokens_dim=_get_tile_tokens_dim(x.shape[0], top_k,
1128-
global_num_experts),
1113+
tile_tokens_dim=calculate_tile_tokens_dim(x.shape[0], top_k,
1114+
global_num_experts),
11291115
routing_method_type=2, # DeepSeek-styled routing method
11301116
use_shuffled_weight=False,
11311117
)
@@ -1160,6 +1146,101 @@ def flashinfer_fused_moe_blockscale_fp8_fake(
11601146
)
11611147

11621148

1149+
def flashinfer_fused_moe_per_tensor_scale_fp8(
        routing_logits: torch.Tensor,
        routing_bias: Optional[torch.Tensor],
        hidden_states: torch.Tensor,
        input_scale: torch.Tensor,
        gemm1_weights: torch.Tensor,
        gemm1_weights_scale: torch.Tensor,
        activation_scale: torch.Tensor,
        gemm2_weights: torch.Tensor,
        gemm2_weights_scale: torch.Tensor,
        num_experts: int,
        top_k: int,
        num_expert_group: Optional[int],
        topk_group: Optional[int],
        intermediate_size: int,
        local_expert_offset: int,
        local_num_experts: int,
        use_routing_scales_on_input: bool,
        routed_scaling_factor: float = 1.0,
        routing_method_type: int = 3  # Llama4-styled routing method
) -> torch.Tensor:
    """Run the FlashInfer per-tensor-scale FP8 fused MoE kernel.

    ``hidden_states`` is quantized to ``torch.float8_e4m3fn`` with a
    per-tensor (not per-token) scale, the weight/activation scales are
    folded into the three scalar factors the kernel expects, and the call is
    forwarded to ``flashinfer_trtllm_fp8_per_tensor_scale_moe``.

    Args:
        routing_logits: Router logits, passed through to the kernel.
        routing_bias: Optional routing bias. When ``None`` a zero bias of
            length ``num_experts`` (bfloat16, on ``hidden_states``' device)
            is substituted.
        hidden_states: Unquantized activations; dimension 0 is used as the
            token count when choosing the tile size.
        input_scale: Scale used to quantize ``hidden_states``.
        gemm1_weights / gemm1_weights_scale: First GEMM weights and scale.
        activation_scale: Scale of the activation feeding the second GEMM.
        gemm2_weights / gemm2_weights_scale: Second GEMM weights and scale.
        num_experts: Expert count handed to the kernel and used to size the
            zero routing bias.
        top_k: Number of experts selected per token.
        num_expert_group: Expert group count; ``None`` is replaced by 1.
        topk_group: Top-k group count; ``None`` is replaced by 0.
        intermediate_size: Forwarded to the kernel unchanged.
        local_expert_offset: Forwarded to the kernel unchanged.
        local_num_experts: Forwarded to the kernel unchanged.
        use_routing_scales_on_input: Forwarded to the kernel unchanged.
        routed_scaling_factor: Forwarded to the kernel unchanged.
        routing_method_type: Kernel routing mode (3 = Llama4-styled).

    Returns:
        The tensor returned by the FlashInfer kernel.
    """
    if routing_bias is None:
        routing_bias = torch.zeros(num_experts,
                                   dtype=torch.bfloat16,
                                   device=hidden_states.device)
    num_expert_group = num_expert_group if num_expert_group is not None else 1
    topk_group = topk_group if topk_group is not None else 0

    quant_hidden_states, input_scale = moe_kernel_quantize_input(
        hidden_states,
        input_scale,
        quant_dtype=torch.float8_e4m3fn,
        per_act_token_quant=False)

    # Fold the per-tensor scales into the scalar factors consumed by the
    # kernel for the first GEMM output, its gate branch, and the second GEMM.
    output1_scales_scalar = gemm1_weights_scale * input_scale * (
        1.0 / activation_scale)
    output1_scales_gate_scalar = gemm1_weights_scale * input_scale
    output2_scales_scalar = activation_scale * gemm2_weights_scale

    # Size the CTA token tile from the actual batch instead of pinning it to
    # 8, matching how the block-scale FP8 path in this file picks its tile.
    tile_tokens_dim = calculate_tile_tokens_dim(hidden_states.shape[0], top_k,
                                                num_experts)

    # Imported lazily so this module can be loaded without FlashInfer.
    from vllm.utils.flashinfer import (
        flashinfer_trtllm_fp8_per_tensor_scale_moe)
    return flashinfer_trtllm_fp8_per_tensor_scale_moe(
        routing_logits=routing_logits,
        routing_bias=routing_bias,
        hidden_states=quant_hidden_states,
        gemm1_weights=gemm1_weights,
        output1_scales_scalar=output1_scales_scalar,
        output1_scales_gate_scalar=output1_scales_gate_scalar,
        gemm2_weights=gemm2_weights,
        output2_scales_scalar=output2_scales_scalar,
        num_experts=num_experts,
        top_k=top_k,
        n_group=num_expert_group,
        topk_group=topk_group,
        intermediate_size=intermediate_size,
        local_expert_offset=local_expert_offset,
        local_num_experts=local_num_experts,
        routed_scaling_factor=routed_scaling_factor,
        use_routing_scales_on_input=use_routing_scales_on_input,
        tile_tokens_dim=tile_tokens_dim,
        routing_method_type=routing_method_type)
1210+
1211+
1212+
def flashinfer_fused_moe_per_tensor_scale_fp8_fake(
        routing_logits: torch.Tensor,
        routing_bias: Optional[torch.Tensor],
        hidden_states: torch.Tensor,
        input_scale: torch.Tensor,
        gemm1_weights: torch.Tensor,
        gemm1_weights_scale: torch.Tensor,
        activation_scale: torch.Tensor,
        gemm2_weights: torch.Tensor,
        gemm2_weights_scale: torch.Tensor,
        num_experts: int,
        top_k: int,
        num_expert_group: Optional[int],
        topk_group: Optional[int],
        intermediate_size: int,
        local_expert_offset: int,
        local_num_experts: int,
        use_routing_scales_on_input: bool,
        routed_scaling_factor: float = 1.0,
        routing_method_type: int = 3) -> torch.Tensor:
    """Fake (meta) implementation of ``flashinfer_fused_moe_per_tensor_scale_fp8``.

    The parameter list mirrors the real op one-for-one so the arguments the
    dispatcher forwards bind to the same names in both implementations. No
    computation is performed; only an output placeholder is produced so that
    tracing sees a tensor rather than ``None``.

    Returns:
        An uninitialized tensor with the shape, dtype and device of
        ``hidden_states``.
    """
    # NOTE(review): assumes the real kernel's output matches the unquantized
    # ``hidden_states`` in shape and dtype -- confirm against FlashInfer.
    return torch.empty_like(hidden_states)
1233+
1234+
1235+
# Register the per-tensor-scale FP8 MoE entry point as a custom op named
# "flashinfer_fused_moe_per_tensor_scale_fp8", paired with its `_fake`
# counterpart. Call sites reach it through
# `torch.ops.vllm.flashinfer_fused_moe_per_tensor_scale_fp8`.
# NOTE(review): `hidden_states` is declared as mutated, but the op body
# quantizes into a separate tensor and passes that to the kernel -- confirm
# whether the declaration is needed.
direct_register_custom_op(
    op_name="flashinfer_fused_moe_per_tensor_scale_fp8",
    op_func=flashinfer_fused_moe_per_tensor_scale_fp8,
    mutates_args=["hidden_states"],
    fake_impl=flashinfer_fused_moe_per_tensor_scale_fp8_fake,
    tags=(torch.Tag.needs_fixed_stride_order, ),
)
1242+
1243+
11631244
def outplace_fused_experts(
11641245
hidden_states: torch.Tensor,
11651246
w1: torch.Tensor,

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 53 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
from vllm.model_executor.layers.quantization.base_config import (
2424
QuantizationConfig, QuantizeMethodBase)
2525
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
26+
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
27+
rotate_flashinfer_fp8_moe_weights, swap_w13_to_w31)
2628
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
2729
get_col_major_tma_aligned_tensor, requant_weight_ue8m0_inplace)
2830
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
@@ -53,11 +55,6 @@
5355
logger = init_logger(__name__)
5456

5557

56-
def _swap_w13_to_w31(x: torch.Tensor) -> torch.Tensor:
    """Swap the two halves of ``x`` along its second-to-last dimension.

    The second-to-last dimension is treated as two equal halves stacked
    back-to-back; the result has the same shape with the halves exchanged.
    """
    rows, cols = x.shape[-2], x.shape[-1]
    # Expose the two halves as an explicit axis of length 2 ...
    halves = x.reshape(-1, 2, rows // 2, cols)
    # ... reverse that axis, then restore the caller's original shape.
    swapped = halves.flip(dims=[1])
    return swapped.reshape(x.shape)
59-
60-
6158
def _is_col_major(x: torch.Tensor) -> bool:
6259
assert x.dim() == 3
6360
b, m, n = x.shape
@@ -695,11 +692,13 @@ def process_weights_after_loading(self, layer: Module) -> None:
695692
elif self.flashinfer_moe_enabled:
696693
# NOTE: weights have to be swapped since the activation is
697694
# applied on different half for flashinfer vs vllm
698-
w13_weight = _swap_w13_to_w31(layer.w13_weight.data)
699-
w13_weight_scale_inv = _swap_w13_to_w31(
695+
w13_weight = swap_w13_to_w31(layer.w13_weight.data)
696+
w13_weight_scale_inv = swap_w13_to_w31(
700697
layer.w13_weight_scale_inv.data)
701698
w2_weight = layer.w2_weight.data
702699
w2_weight_scale_inv = layer.w2_weight_scale_inv.data
700+
if not self.block_quant:
701+
rotate_flashinfer_fp8_moe_weights(w13_weight, w2_weight)
703702
else:
704703
w13_weight = layer.w13_weight.data
705704
w13_weight_scale_inv = layer.w13_weight_scale_inv.data
@@ -998,30 +997,53 @@ def apply(
998997
global_num_experts=global_num_experts,
999998
expert_map=expert_map)
1000999
elif self.flashinfer_moe_enabled:
1001-
# Currently only work with DS models
1002-
assert self.block_quant
1003-
assert (renormalize and use_grouped_topk
1004-
and scoring_func == 'sigmoid'
1005-
and custom_routing_function is None)
1006-
assert activation == "silu"
1007-
return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
1008-
routing_logits=router_logits.to(torch.float32),
1009-
routing_bias=e_score_correction_bias,
1010-
x=x,
1011-
w13_weight=layer.w13_weight,
1012-
w13_weight_scale_inv=layer.w13_weight_scale_inv,
1013-
w2_weight=layer.w2_weight,
1014-
w2_weight_scale_inv=layer.w2_weight_scale_inv,
1015-
global_num_experts=global_num_experts,
1016-
top_k=top_k,
1017-
num_expert_group=num_expert_group,
1018-
topk_group=topk_group,
1019-
intermediate_size=layer.intermediate_size_per_partition,
1020-
expert_offset=layer.ep_rank * layer.local_num_experts,
1021-
local_num_experts=layer.local_num_experts,
1022-
block_shape=self.quant_config.weight_block_size,
1023-
routed_scaling=1.0,
1024-
)
1000+
assert activation == 'silu'
1001+
assert scoring_func == 'sigmoid'
1002+
if self.block_quant:
1003+
assert (renormalize and use_grouped_topk
1004+
and custom_routing_function is None)
1005+
1006+
return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
1007+
routing_logits=router_logits.to(torch.float32),
1008+
routing_bias=e_score_correction_bias,
1009+
x=x,
1010+
w13_weight=layer.w13_weight,
1011+
w13_weight_scale_inv=layer.w13_weight_scale_inv,
1012+
w2_weight=layer.w2_weight,
1013+
w2_weight_scale_inv=layer.w2_weight_scale_inv,
1014+
global_num_experts=global_num_experts,
1015+
top_k=top_k,
1016+
num_expert_group=num_expert_group,
1017+
topk_group=topk_group,
1018+
intermediate_size=layer.intermediate_size_per_partition,
1019+
expert_offset=layer.ep_rank * layer.local_num_experts,
1020+
local_num_experts=layer.local_num_experts,
1021+
block_shape=self.quant_config.weight_block_size,
1022+
routed_scaling=1.0,
1023+
)
1024+
else:
1025+
assert (not renormalize
1026+
and custom_routing_function is not None)
1027+
return torch.ops.vllm.flashinfer_fused_moe_per_tensor_scale_fp8(
1028+
routing_logits=router_logits,
1029+
routing_bias=e_score_correction_bias,
1030+
hidden_states=x,
1031+
input_scale=layer.w13_input_scale,
1032+
gemm1_weights=layer.w13_weight,
1033+
gemm1_weights_scale=layer.w13_weight_scale,
1034+
gemm2_weights=layer.w2_weight,
1035+
gemm2_weights_scale=layer.w2_weight_scale,
1036+
activation_scale=layer.w2_input_scale,
1037+
num_experts=global_num_experts,
1038+
top_k=top_k,
1039+
num_expert_group=num_expert_group,
1040+
topk_group=topk_group,
1041+
intermediate_size=layer.intermediate_size_per_partition,
1042+
local_expert_offset=layer.ep_rank *
1043+
layer.local_num_experts,
1044+
local_num_experts=layer.local_num_experts,
1045+
use_routing_scales_on_input=apply_router_weight_on_input,
1046+
)
10251047
else:
10261048
return self.fused_experts(
10271049
hidden_states=x,

vllm/model_executor/layers/quantization/modelopt.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
from vllm.model_executor.layers.quantization.base_config import (
2525
QuantizationConfig, QuantizeMethodBase)
2626
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
27+
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
28+
rotate_flashinfer_fp8_moe_weights, swap_w13_to_w31)
2729
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
2830
apply_fp4_marlin_linear, is_fp4_marlin_supported,
2931
prepare_fp4_layer_for_marlin, prepare_moe_fp4_layer_for_marlin)
@@ -35,6 +37,7 @@
3537
PerTensorScaleParameter)
3638
from vllm.platforms import current_platform
3739
from vllm.scalar_type import scalar_types
40+
from vllm.utils.flashinfer import has_flashinfer_moe
3841

3942
logger = init_logger(__name__)
4043

@@ -268,6 +271,11 @@ def __init__(self, quant_config: ModelOptFp8Config):
268271
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
269272
cutlass_fp8_supported)
270273
self.cutlass_fp8_supported = cutlass_fp8_supported()
274+
self.flashinfer_moe_enabled = False
275+
if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
276+
logger.info_once(
277+
"Using FlashInfer MoE FP8 kernels for ModelOptFp8MoEMethod.")
278+
self.flashinfer_moe_enabled = True
271279

272280
def create_weights(
273281
self,
@@ -411,6 +419,11 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
411419
layer.w2_input_scale = Parameter(layer.w2_input_scale.max(),
412420
requires_grad=False)
413421

422+
if self.flashinfer_moe_enabled:
423+
layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
424+
rotate_flashinfer_fp8_moe_weights(layer.w13_weight,
425+
layer.w2_weight)
426+
414427
def apply(
415428
self,
416429
layer: torch.nn.Module,
@@ -437,6 +450,29 @@ def apply(
437450
raise NotImplementedError(
438451
"EPLB not supported for `ModelOptFp8MoEMethod` yet.")
439452

453+
if self.flashinfer_moe_enabled:
454+
assert activation == 'silu'
455+
assert not renormalize
456+
return torch.ops.vllm.flashinfer_fused_moe_per_tensor_scale_fp8(
457+
routing_logits=router_logits,
458+
routing_bias=e_score_correction_bias,
459+
hidden_states=x,
460+
input_scale=layer.w13_input_scale,
461+
gemm1_weights=layer.w13_weight,
462+
gemm1_weights_scale=layer.w13_weight_scale,
463+
gemm2_weights=layer.w2_weight,
464+
gemm2_weights_scale=layer.w2_weight_scale,
465+
activation_scale=layer.w2_input_scale,
466+
num_experts=global_num_experts,
467+
top_k=top_k,
468+
num_expert_group=num_expert_group,
469+
topk_group=topk_group,
470+
intermediate_size=layer.intermediate_size_per_partition,
471+
local_expert_offset=layer.ep_rank * layer.local_num_experts,
472+
local_num_experts=layer.local_num_experts,
473+
use_routing_scales_on_input=apply_router_weight_on_input,
474+
)
475+
440476
# Expert selection
441477
topk_weights, topk_ids = FusedMoE.select_experts(
442478
hidden_states=x,
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
import torch
4+
5+
6+
def calculate_tile_tokens_dim(num_tokens, top_k, num_experts):
    """Pick the per-CTA token tile size for the FlashInfer MoE kernels.

    The per-expert load is estimated as if routing were perfectly balanced,
    rounded up to a power of two, and clamped to the 8-64 range supported by
    the kernel.
    """
    from flashinfer import next_positive_power_of_2

    # Balanced-routing estimate: every expert receives an equal share of the
    # (token, expert) assignments.
    balanced_load = (num_tokens * top_k) // num_experts
    padded = next_positive_power_of_2(balanced_load)

    # Clamp to 8-64 tokens per CTA tile, the range supported by the kernel.
    if padded < 8:
        return 8
    if padded > 64:
        return 64
    return padded
16+
17+
18+
def swap_w13_to_w31(x: torch.Tensor) -> torch.Tensor:
    """Swap the two halves of ``x`` along its second-to-last dimension.

    The second-to-last dimension is treated as two equal halves stacked
    back-to-back; the result has the same shape with the halves exchanged.
    """
    rows, cols = x.shape[-2], x.shape[-1]
    # Expose the two halves as an explicit axis of length 2 ...
    halves = x.reshape(-1, 2, rows // 2, cols)
    # ... reverse that axis, then restore the caller's original shape.
    swapped = halves.flip(dims=[1])
    return swapped.reshape(x.shape)
21+
22+
23+
def rotate_flashinfer_fp8_moe_weights(gemm1_weights: torch.Tensor,
                                      gemm2_weights: torch.Tensor):
    """Rewrite both MoE weight tensors in place into FlashInfer's layout.

    Each expert's ``gemm1_weights`` slice is row-reordered with
    ``reorder_rows_for_gated_act_gemm``; then every expert slice of both
    tensors is passed through ``shuffle_matrix_a`` (viewed as ``uint8``,
    epilogue tile of 128). The stacked results are written back to the
    ``.data`` of the two arguments as ``torch.float8_e4m3fn``. Nothing is
    returned.
    """
    from flashinfer import reorder_rows_for_gated_act_gemm, shuffle_matrix_a

    epilogue_tile_m = 128
    num_experts = gemm1_weights.shape[0]
    hidden_size = gemm1_weights.shape[-1]
    intermediate_size = gemm1_weights.shape[1] // 2

    # Reorder rows of W1 for fused gated activation, then stack all experts.
    interleaved_w1 = torch.stack([
        reorder_rows_for_gated_act_gemm(gemm1_weights[expert])
        for expert in range(num_experts)
    ]).reshape(num_experts, 2 * intermediate_size, hidden_size)

    # Shuffle weights for transposed mma output. Each tuple is built
    # left-to-right, so per expert the W1 shuffle runs before the W2 shuffle.
    shuffled_pairs = [
        (shuffle_matrix_a(interleaved_w1[expert].view(torch.uint8),
                          epilogue_tile_m),
         shuffle_matrix_a(gemm2_weights[expert].view(torch.uint8),
                          epilogue_tile_m))
        for expert in range(num_experts)
    ]

    # Stack weights for all experts and publish them through `.data`.
    gemm1_weights.data = torch.stack([w1 for w1, _ in shuffled_pairs
                                      ]).view(torch.float8_e4m3fn)
    gemm2_weights.data = torch.stack([w2 for _, w2 in shuffled_pairs
                                      ]).view(torch.float8_e4m3fn)

0 commit comments

Comments
 (0)