Skip to content

Commit 6582abc

Browse files
committed
CR Fixes
Signed-off-by: Amir Klein <[email protected]>
1 parent fdf635b commit 6582abc

File tree

4 files changed

+77
-56
lines changed

4 files changed

+77
-56
lines changed

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 21 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1147,30 +1147,25 @@ def flashinfer_fused_moe_blockscale_fp8_fake(
11471147

11481148

11491149
def flashinfer_fused_moe_per_tensor_scale_fp8(
1150-
routing_logits: torch.Tensor,
1151-
routing_bias: Optional[torch.Tensor],
1152-
hidden_states: torch.Tensor,
1153-
input_scale: torch.Tensor,
1154-
gemm1_weights: torch.Tensor,
1155-
gemm1_weights_scale: torch.Tensor,
1156-
activation_scale: torch.Tensor,
1157-
gemm2_weights: torch.Tensor,
1158-
gemm2_weights_scale: torch.Tensor,
1159-
num_experts: int,
1160-
top_k: int,
1161-
num_expert_group: Optional[int],
1162-
topk_group: Optional[int],
1163-
intermediate_size: int,
1164-
local_expert_offset: int,
1165-
local_num_experts: int,
1166-
use_routing_scales_on_input: bool,
1167-
routed_scaling_factor: float = 1.0,
1168-
routing_method_type: int = 3 # Llama4-styled routing method
1169-
) -> torch.Tensor:
1170-
if routing_bias is None:
1171-
routing_bias = torch.zeros(num_experts,
1172-
dtype=torch.bfloat16,
1173-
device=hidden_states.device)
1150+
routing_logits: torch.Tensor,
1151+
routing_bias: Optional[torch.Tensor],
1152+
hidden_states: torch.Tensor,
1153+
input_scale: torch.Tensor,
1154+
gemm1_weights: torch.Tensor,
1155+
gemm1_weights_scale: torch.Tensor,
1156+
activation_scale: torch.Tensor,
1157+
gemm2_weights: torch.Tensor,
1158+
gemm2_weights_scale: torch.Tensor,
1159+
num_experts: int,
1160+
top_k: int,
1161+
num_expert_group: Optional[int],
1162+
topk_group: Optional[int],
1163+
intermediate_size: int,
1164+
local_expert_offset: int,
1165+
local_num_experts: int,
1166+
use_routing_scales_on_input: bool,
1167+
routing_method_type: int,
1168+
routed_scaling_factor: float = 1.0) -> torch.Tensor:
11741169
num_expert_group = num_expert_group if num_expert_group is not None else 0
11751170
topk_group = topk_group if topk_group is not None else 0
11761171

@@ -1205,7 +1200,8 @@ def flashinfer_fused_moe_per_tensor_scale_fp8(
12051200
local_num_experts=local_num_experts,
12061201
routed_scaling_factor=routed_scaling_factor,
12071202
use_routing_scales_on_input=use_routing_scales_on_input,
1208-
tile_tokens_dim=8,
1203+
tile_tokens_dim=calculate_tile_tokens_dim(hidden_states.shape[0],
1204+
top_k, num_experts),
12091205
routing_method_type=routing_method_type)
12101206

12111207

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
import torch
88
import torch.nn.functional as F
9+
from quantization.utils.flashinfer_utils import (
10+
apply_flashinfer_per_tensor_scale_fp8)
911
from torch.nn import Module
1012
from torch.nn.parameter import Parameter
1113

@@ -1024,26 +1026,16 @@ def apply(
10241026
else:
10251027
assert (not renormalize
10261028
and custom_routing_function is not None)
1027-
return torch.ops.vllm.flashinfer_fused_moe_per_tensor_scale_fp8(
1028-
routing_logits=router_logits,
1029-
routing_bias=e_score_correction_bias,
1029+
return apply_flashinfer_per_tensor_scale_fp8(
1030+
layer=layer,
10301031
hidden_states=x,
1031-
input_scale=layer.w13_input_scale,
1032-
gemm1_weights=layer.w13_weight,
1033-
gemm1_weights_scale=layer.w13_weight_scale,
1034-
gemm2_weights=layer.w2_weight,
1035-
gemm2_weights_scale=layer.w2_weight_scale,
1036-
activation_scale=layer.w2_input_scale,
1037-
num_experts=global_num_experts,
1032+
router_logits=router_logits,
1033+
routing_bias=e_score_correction_bias,
1034+
global_num_experts=global_num_experts,
10381035
top_k=top_k,
10391036
num_expert_group=num_expert_group,
10401037
topk_group=topk_group,
1041-
intermediate_size=layer.intermediate_size_per_partition,
1042-
local_expert_offset=layer.ep_rank *
1043-
layer.local_num_experts,
1044-
local_num_experts=layer.local_num_experts,
1045-
use_routing_scales_on_input=apply_router_weight_on_input,
1046-
)
1038+
apply_router_weight_on_input=apply_router_weight_on_input)
10471039
else:
10481040
return self.fused_experts(
10491041
hidden_states=x,

vllm/model_executor/layers/quantization/modelopt.py

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from typing import Any, Callable, Optional, Union
55

66
import torch
7+
from quantization.utils.flashinfer_utils import (
8+
apply_flashinfer_per_tensor_scale_fp8)
79
from torch.nn import Module
810
from torch.nn.parameter import Parameter
911

@@ -453,25 +455,16 @@ def apply(
453455
if self.flashinfer_moe_enabled:
454456
assert activation == 'silu'
455457
assert not renormalize
456-
return torch.ops.vllm.flashinfer_fused_moe_per_tensor_scale_fp8(
457-
routing_logits=router_logits,
458-
routing_bias=e_score_correction_bias,
458+
return apply_flashinfer_per_tensor_scale_fp8(
459+
layer=layer,
459460
hidden_states=x,
460-
input_scale=layer.w13_input_scale,
461-
gemm1_weights=layer.w13_weight,
462-
gemm1_weights_scale=layer.w13_weight_scale,
463-
gemm2_weights=layer.w2_weight,
464-
gemm2_weights_scale=layer.w2_weight_scale,
465-
activation_scale=layer.w2_input_scale,
466-
num_experts=global_num_experts,
461+
router_logits=router_logits,
462+
routing_bias=e_score_correction_bias,
463+
global_num_experts=global_num_experts,
467464
top_k=top_k,
468465
num_expert_group=num_expert_group,
469466
topk_group=topk_group,
470-
intermediate_size=layer.intermediate_size_per_partition,
471-
local_expert_offset=layer.ep_rank * layer.local_num_experts,
472-
local_num_experts=layer.local_num_experts,
473-
use_routing_scales_on_input=apply_router_weight_on_input,
474-
)
467+
apply_router_weight_on_input=apply_router_weight_on_input)
475468

476469
# Expert selection
477470
topk_weights, topk_ids = FusedMoE.select_experts(

vllm/model_executor/layers/quantization/utils/flashinfer_utils.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
from typing import Optional
4+
35
import torch
46

57

@@ -58,3 +60,41 @@ def rotate_flashinfer_fp8_moe_weights(gemm1_weights: torch.Tensor,
5860
torch.float8_e4m3fn)
5961
gemm2_weights.data = torch.stack(gemm2_weights_fp8_shuffled).view(
6062
torch.float8_e4m3fn)
63+
64+
65+
def apply_flashinfer_per_tensor_scale_fp8(
66+
layer: torch.nn.Module,
67+
hidden_states: torch.Tensor,
68+
router_logits: torch.Tensor,
69+
routing_bias: Optional[torch.Tensor],
70+
top_k: int,
71+
num_expert_group: Optional[int],
72+
topk_group: Optional[int],
73+
global_num_experts: int,
74+
apply_router_weight_on_input: bool,
75+
) -> torch.Tensor:
76+
from flashinfer.fushed_moe import RoutingMethodType
77+
78+
from vllm.model_executor.models.llama4 import Llama4MoE
79+
assert layer.custom_routing_function == Llama4MoE.custom_routing_function, \
80+
"FusedMoE flashinfer kernels are only supported for Llama4"
81+
return torch.ops.vllm.flashinfer_fused_moe_per_tensor_scale_fp8(
82+
routing_logits=router_logits,
83+
routing_bias=routing_bias,
84+
hidden_states=hidden_states,
85+
input_scale=layer.w13_input_scale,
86+
gemm1_weights=layer.w13_weight,
87+
gemm1_weights_scale=layer.w13_weight_scale,
88+
gemm2_weights=layer.w2_weight,
89+
gemm2_weights_scale=layer.w2_weight_scale,
90+
activation_scale=layer.w2_input_scale,
91+
num_experts=global_num_experts,
92+
top_k=top_k,
93+
num_expert_group=num_expert_group,
94+
topk_group=topk_group,
95+
intermediate_size=layer.intermediate_size_per_partition,
96+
local_expert_offset=layer.ep_rank * layer.local_num_experts,
97+
local_num_experts=layer.local_num_experts,
98+
use_routing_scales_on_input=apply_router_weight_on_input,
99+
routing_method=RoutingMethodType.Llama4,
100+
)

0 commit comments

Comments
 (0)