vllm/model_executor/layers/fused_moe/fused_moe.py (113 changes: 95 additions & 18 deletions)
@@ -30,6 +30,8 @@
TopKWeightAndReduceNoOP)
from vllm.model_executor.layers.fused_moe.utils import (
_resize_cache, moe_kernel_quantize_input, per_token_group_quant_fp8)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
calculate_tile_tokens_dim)
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
dequant_mxfp4)
from vllm.platforms import current_platform
@@ -1065,22 +1067,6 @@ def inplace_fused_experts_fake(
)


def next_positive_power_of_2(x: int) -> int:
if x < 1:
return 1
return 1 << (x - 1).bit_length()


def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
# Guess tokens per expert assuming perfect expert distribution first.
num_tokens_per_expert = (num_tokens * top_k) // num_experts
# And pad the number to the next power of 2.
tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
# Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
return tile_tokens_dim


def flashinfer_fused_moe_blockscale_fp8(
routing_logits: torch.Tensor,
routing_bias: torch.Tensor,
Expand Down Expand Up @@ -1128,8 +1114,8 @@ def flashinfer_fused_moe_blockscale_fp8(
local_expert_offset=expert_offset,
local_num_experts=local_num_experts,
routed_scaling_factor=routed_scaling,
tile_tokens_dim=_get_tile_tokens_dim(x.shape[0], top_k,
global_num_experts),
tile_tokens_dim=calculate_tile_tokens_dim(x.shape[0], top_k,
global_num_experts),
routing_method_type=2, # DeepSeek-styled routing method
use_shuffled_weight=False,
)
@@ -1164,6 +1150,97 @@ def flashinfer_fused_moe_blockscale_fp8_fake(
)


def flashinfer_fused_moe_per_tensor_scale_fp8(
routing_logits: torch.Tensor,
routing_bias: Optional[torch.Tensor],
hidden_states: torch.Tensor,
input_scale: torch.Tensor,
gemm1_weights: torch.Tensor,
gemm1_weights_scale: torch.Tensor,
activation_scale: torch.Tensor,
gemm2_weights: torch.Tensor,
gemm2_weights_scale: torch.Tensor,
num_experts: int,
top_k: int,
num_expert_group: Optional[int],
topk_group: Optional[int],
intermediate_size: int,
local_expert_offset: int,
local_num_experts: int,
use_routing_scales_on_input: bool,
routing_method_type: int,
routed_scaling_factor: float = 1.0) -> torch.Tensor:
num_expert_group = num_expert_group if num_expert_group is not None else 0
topk_group = topk_group if topk_group is not None else 0

quant_hidden_states, input_scale = moe_kernel_quantize_input(
hidden_states,
input_scale,
quant_dtype=torch.float8_e4m3fn,
per_act_token_quant=False)

output1_scales_scalar = gemm1_weights_scale * input_scale * (
1.0 / activation_scale)
output1_scales_gate_scalar = gemm1_weights_scale * input_scale
output2_scales_scalar = activation_scale * gemm2_weights_scale

from vllm.utils.flashinfer import (
flashinfer_trtllm_fp8_per_tensor_scale_moe)
return flashinfer_trtllm_fp8_per_tensor_scale_moe(
routing_logits=routing_logits,
routing_bias=routing_bias,
hidden_states=quant_hidden_states,
gemm1_weights=gemm1_weights,
output1_scales_scalar=output1_scales_scalar,
output1_scales_gate_scalar=output1_scales_gate_scalar,
gemm2_weights=gemm2_weights,
output2_scales_scalar=output2_scales_scalar,
num_experts=num_experts,
top_k=top_k,
n_group=num_expert_group,
topk_group=topk_group,
intermediate_size=intermediate_size,
local_expert_offset=local_expert_offset,
local_num_experts=local_num_experts,
routed_scaling_factor=routed_scaling_factor,
use_routing_scales_on_input=use_routing_scales_on_input,
tile_tokens_dim=calculate_tile_tokens_dim(hidden_states.shape[0],
top_k, num_experts),
routing_method_type=routing_method_type)


def flashinfer_fused_moe_per_tensor_scale_fp8_fake(
routing_logits: torch.Tensor,
routing_bias: torch.Tensor,
hidden_states: torch.Tensor,
gemm1_weights: torch.Tensor,
output1_scales_scalar: torch.Tensor,
output1_scales_gate_scalar: torch.Tensor,
gemm2_weights: torch.Tensor,
output2_scales_scalar: torch.Tensor,
num_experts: int,
top_k: int,
num_expert_group: int,
topk_group: int,
intermediate_size: int,
local_expert_offset: int,
local_num_experts: int,
routed_scaling_factor: float = 1.0,
use_routing_scales_on_input: bool = False,
tile_tokens_dim: int = 8,
routing_method_type: int = 0) -> torch.Tensor:
pass


direct_register_custom_op(
op_name="flashinfer_fused_moe_per_tensor_scale_fp8",
op_func=flashinfer_fused_moe_per_tensor_scale_fp8,
mutates_args=["hidden_states"],
fake_impl=flashinfer_fused_moe_per_tensor_scale_fp8_fake,
tags=(torch.Tag.needs_fixed_stride_order, ),
)


def outplace_fused_experts(
hidden_states: torch.Tensor,
w1: torch.Tensor,
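
The helpers deleted in the hunk above are not dropped from the codebase: the new import at the top of this file pulls `calculate_tile_tokens_dim` from `vllm.model_executor.layers.quantization.utils.flashinfer_utils`, so the block-scale and per-tensor FP8 paths share one copy. A minimal sketch of that shared helper, assuming it mirrors the removed `_get_tile_tokens_dim` / `next_positive_power_of_2` logic:

```python
# Hypothetical sketch; assumes the relocated helper keeps the behavior of the
# removed _get_tile_tokens_dim / next_positive_power_of_2 pair.
def next_positive_power_of_2(x: int) -> int:
    if x < 1:
        return 1
    return 1 << (x - 1).bit_length()


def calculate_tile_tokens_dim(num_tokens: int, top_k: int,
                              num_experts: int) -> int:
    # Estimate tokens per expert assuming perfectly even routing, round up to
    # the next power of two, then clamp to the 8-64 tokens-per-CTA-tile range
    # supported by the FlashInfer kernel.
    num_tokens_per_expert = (num_tokens * top_k) // num_experts
    tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
    return min(max(tile_tokens_dim, 8), 64)
```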
vllm/model_executor/layers/quantization/fp8.py (75 changes: 44 additions & 31 deletions)
@@ -23,6 +23,9 @@
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_flashinfer_per_tensor_scale_fp8, rotate_flashinfer_fp8_moe_weights,
swap_w13_to_w31)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
get_col_major_tma_aligned_tensor, requant_weight_ue8m0_inplace)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
@@ -53,11 +56,6 @@
logger = init_logger(__name__)


def _swap_w13_to_w31(x: torch.Tensor) -> torch.Tensor:
return x.reshape(-1, 2, x.shape[-2] // 2,
x.shape[-1]).flip(dims=[1]).reshape(x.shape)


def _is_col_major(x: torch.Tensor) -> bool:
assert x.dim() == 3
b, m, n = x.shape
@@ -695,11 +693,13 @@ def process_weights_after_loading(self, layer: Module) -> None:
elif self.flashinfer_moe_enabled:
# NOTE: weights have to be swapped since the activation is
# applied on different half for flashinfer vs vllm
w13_weight = _swap_w13_to_w31(layer.w13_weight.data)
w13_weight_scale_inv = _swap_w13_to_w31(
w13_weight = swap_w13_to_w31(layer.w13_weight.data)
w13_weight_scale_inv = swap_w13_to_w31(
layer.w13_weight_scale_inv.data)
w2_weight = layer.w2_weight.data
w2_weight_scale_inv = layer.w2_weight_scale_inv.data
if not self.block_quant:
rotate_flashinfer_fp8_moe_weights(w13_weight, w2_weight)
else:
w13_weight = layer.w13_weight.data
w13_weight_scale_inv = layer.w13_weight_scale_inv.data
@@ -998,30 +998,43 @@ def apply(
global_num_experts=global_num_experts,
expert_map=expert_map)
elif self.flashinfer_moe_enabled:
# Currently only work with DS models
assert self.block_quant
assert (renormalize and use_grouped_topk
and scoring_func == 'sigmoid'
and custom_routing_function is None)
assert activation == "silu"
return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
routing_logits=router_logits.to(torch.float32),
routing_bias=e_score_correction_bias,
x=x,
w13_weight=layer.w13_weight,
w13_weight_scale_inv=layer.w13_weight_scale_inv,
w2_weight=layer.w2_weight,
w2_weight_scale_inv=layer.w2_weight_scale_inv,
global_num_experts=global_num_experts,
top_k=top_k,
num_expert_group=num_expert_group,
topk_group=topk_group,
intermediate_size=layer.intermediate_size_per_partition,
expert_offset=layer.ep_rank * layer.local_num_experts,
local_num_experts=layer.local_num_experts,
block_shape=self.quant_config.weight_block_size,
routed_scaling=1.0,
)
assert activation == 'silu'
assert scoring_func == 'sigmoid'
if self.block_quant:
assert (renormalize and use_grouped_topk
and custom_routing_function is None)

return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
routing_logits=router_logits.to(torch.float32),
routing_bias=e_score_correction_bias,
x=x,
w13_weight=layer.w13_weight,
w13_weight_scale_inv=layer.w13_weight_scale_inv,
w2_weight=layer.w2_weight,
w2_weight_scale_inv=layer.w2_weight_scale_inv,
global_num_experts=global_num_experts,
top_k=top_k,
num_expert_group=num_expert_group,
topk_group=topk_group,
intermediate_size=layer.intermediate_size_per_partition,
expert_offset=layer.ep_rank * layer.local_num_experts,
local_num_experts=layer.local_num_experts,
block_shape=self.quant_config.weight_block_size,
routed_scaling=1.0,
)
else:
assert (not renormalize
and custom_routing_function is not None)
return apply_flashinfer_per_tensor_scale_fp8(
layer=layer,
hidden_states=x,
router_logits=router_logits,
routing_bias=e_score_correction_bias,
global_num_experts=global_num_experts,
top_k=top_k,
num_expert_group=num_expert_group,
topk_group=topk_group,
apply_router_weight_on_input=apply_router_weight_on_input)
else:
return self.fused_experts(
hidden_states=x,
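
Similarly, the private `_swap_w13_to_w31` removed above now lives in `flashinfer_utils` as `swap_w13_to_w31`, imported by both `fp8.py` and `modelopt.py`. As the NOTE comment in `process_weights_after_loading` says, the swap is needed because FlashInfer applies the activation to the opposite half of the fused gate/up projection from vLLM. A sketch of the shared helper, assuming it preserves the removed implementation:

```python
import torch


# Sketch under the assumption that the shared helper keeps the removed
# _swap_w13_to_w31 body: swap the two stacked halves of the fused w13
# weight along the row dimension.
def swap_w13_to_w31(x: torch.Tensor) -> torch.Tensor:
    return x.reshape(-1, 2, x.shape[-2] // 2,
                     x.shape[-1]).flip(dims=[1]).reshape(x.shape)
```

The per-tensor (non-block-quant) branch additionally calls `rotate_flashinfer_fp8_moe_weights`, which is defined in the new `flashinfer_utils` module and is not shown in this excerpt.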
vllm/model_executor/layers/quantization/modelopt.py (28 changes: 28 additions & 0 deletions)
@@ -23,6 +23,9 @@
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_flashinfer_per_tensor_scale_fp8, rotate_flashinfer_fp8_moe_weights,
swap_w13_to_w31)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
apply_fp4_marlin_linear, is_fp4_marlin_supported,
prepare_fp4_layer_for_marlin, prepare_moe_fp4_layer_for_marlin)
@@ -34,6 +37,7 @@
PerTensorScaleParameter)
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from vllm.utils.flashinfer import has_flashinfer_moe

logger = init_logger(__name__)

@@ -267,6 +271,11 @@ def __init__(self, quant_config: ModelOptFp8Config):
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
cutlass_fp8_supported)
self.cutlass_fp8_supported = cutlass_fp8_supported()
self.flashinfer_moe_enabled = False
if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
logger.info_once(
"Using FlashInfer MoE FP8 kernels for ModelOptFp8MoEMethod.")
self.flashinfer_moe_enabled = True

def create_weights(
self,
@@ -410,6 +419,11 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.w2_input_scale = Parameter(layer.w2_input_scale.max(),
requires_grad=False)

if self.flashinfer_moe_enabled:
layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
rotate_flashinfer_fp8_moe_weights(layer.w13_weight,
layer.w2_weight)

def apply(
self,
layer: torch.nn.Module,
@@ -436,6 +450,20 @@ def apply(
raise NotImplementedError(
"EPLB not supported for `ModelOptFp8MoEMethod` yet.")

if self.flashinfer_moe_enabled:
assert activation == 'silu'
assert not renormalize
return apply_flashinfer_per_tensor_scale_fp8(
layer=layer,
hidden_states=x,
router_logits=router_logits,
routing_bias=e_score_correction_bias,
global_num_experts=global_num_experts,
top_k=top_k,
num_expert_group=num_expert_group,
topk_group=topk_group,
apply_router_weight_on_input=apply_router_weight_on_input)

# Expert selection
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
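
Both the `Fp8MoEMethod` and `ModelOptFp8MoEMethod` branches above route per-tensor FP8 MoE through `apply_flashinfer_per_tensor_scale_fp8`, which ends up in the `flashinfer_fused_moe_per_tensor_scale_fp8` op added in `fused_moe.py`. One way to read the three output scalars that op computes (illustrative values only; real scales come from the checkpoint):

```python
import torch

# Illustrative per-tensor scales; actual values come from calibration.
input_scale = torch.tensor(0.05)          # hidden_states ~ fp8_x * input_scale
gemm1_weights_scale = torch.tensor(0.02)  # w13 ~ fp8_w13 * gemm1_weights_scale
activation_scale = torch.tensor(0.10)     # quant scale of the intermediate activation
gemm2_weights_scale = torch.tensor(0.03)  # w2 ~ fp8_w2 * gemm2_weights_scale

# gemm1's fp8 x fp8 accumulator is short by input_scale * gemm1_weights_scale;
# dividing by activation_scale additionally re-quantizes the up projection for
# gemm2's fp8 input.
output1_scales_scalar = gemm1_weights_scale * input_scale * (1.0 / activation_scale)
# The gate half only needs the dequantization factor.
output1_scales_gate_scalar = gemm1_weights_scale * input_scale
# gemm2's accumulator is rescaled back to the model dtype.
output2_scales_scalar = activation_scale * gemm2_weights_scale
```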
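
For completeness, a hedged end-to-end example of turning the new ModelOpt path on, based only on the `VLLM_USE_FLASHINFER_MOE_FP8` / `has_flashinfer_moe()` guard shown above (the checkpoint name is a placeholder):

```python
import os

# The guard in ModelOptFp8MoEMethod.__init__ checks this env flag plus
# FlashInfer MoE availability, so set the flag before constructing the engine.
os.environ["VLLM_USE_FLASHINFER_MOE_FP8"] = "1"

from vllm import LLM

# "org/model-fp8" is a placeholder for a ModelOpt FP8 MoE checkpoint.
llm = LLM(model="org/model-fp8", quantization="modelopt")
```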