diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index 227aacf25c0b..b69575c7e96d 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -30,6 +30,8 @@
     TopKWeightAndReduceNoOP)
 from vllm.model_executor.layers.fused_moe.utils import (
     _resize_cache, moe_kernel_quantize_input, per_token_group_quant_fp8)
+from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
+    calculate_tile_tokens_dim)
 from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
     dequant_mxfp4)
 from vllm.platforms import current_platform
@@ -1065,22 +1067,6 @@ def inplace_fused_experts_fake(
     )
 
 
-def next_positive_power_of_2(x: int) -> int:
-    if x < 1:
-        return 1
-    return 1 << (x - 1).bit_length()
-
-
-def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
-    # Guess tokens per expert assuming perfect expert distribution first.
-    num_tokens_per_expert = (num_tokens * top_k) // num_experts
-    # And pad the number to the next power of 2.
-    tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
-    # Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
-    tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
-    return tile_tokens_dim
-
-
 def flashinfer_fused_moe_blockscale_fp8(
         routing_logits: torch.Tensor,
         routing_bias: torch.Tensor,
@@ -1128,8 +1114,8 @@ def flashinfer_fused_moe_blockscale_fp8(
         local_expert_offset=expert_offset,
         local_num_experts=local_num_experts,
         routed_scaling_factor=routed_scaling,
-        tile_tokens_dim=_get_tile_tokens_dim(x.shape[0], top_k,
-                                             global_num_experts),
+        tile_tokens_dim=calculate_tile_tokens_dim(x.shape[0], top_k,
+                                                  global_num_experts),
         routing_method_type=2,  # DeepSeek-styled routing method
         use_shuffled_weight=False,
     )
@@ -1164,6 +1150,97 @@ def flashinfer_fused_moe_blockscale_fp8_fake(
     )
 
 
+def flashinfer_fused_moe_per_tensor_scale_fp8(
+        routing_logits: torch.Tensor,
+        routing_bias: Optional[torch.Tensor],
+        hidden_states: torch.Tensor,
+        input_scale: torch.Tensor,
+        gemm1_weights: torch.Tensor,
+        gemm1_weights_scale: torch.Tensor,
+        activation_scale: torch.Tensor,
+        gemm2_weights: torch.Tensor,
+        gemm2_weights_scale: torch.Tensor,
+        num_experts: int,
+        top_k: int,
+        num_expert_group: Optional[int],
+        topk_group: Optional[int],
+        intermediate_size: int,
+        local_expert_offset: int,
+        local_num_experts: int,
+        use_routing_scales_on_input: bool,
+        routing_method_type: int,
+        routed_scaling_factor: float = 1.0) -> torch.Tensor:
+    num_expert_group = num_expert_group if num_expert_group is not None else 0
+    topk_group = topk_group if topk_group is not None else 0
+
+    quant_hidden_states, input_scale = moe_kernel_quantize_input(
+        hidden_states,
+        input_scale,
+        quant_dtype=torch.float8_e4m3fn,
+        per_act_token_quant=False)
+
+    output1_scales_scalar = gemm1_weights_scale * input_scale * (
+        1.0 / activation_scale)
+    output1_scales_gate_scalar = gemm1_weights_scale * input_scale
+    output2_scales_scalar = activation_scale * gemm2_weights_scale
+
+    from vllm.utils.flashinfer import (
+        flashinfer_trtllm_fp8_per_tensor_scale_moe)
+    return flashinfer_trtllm_fp8_per_tensor_scale_moe(
+        routing_logits=routing_logits,
+        routing_bias=routing_bias,
+        hidden_states=quant_hidden_states,
+        gemm1_weights=gemm1_weights,
+        output1_scales_scalar=output1_scales_scalar,
+        output1_scales_gate_scalar=output1_scales_gate_scalar,
+        gemm2_weights=gemm2_weights,
+        output2_scales_scalar=output2_scales_scalar,
+        num_experts=num_experts,
+        top_k=top_k,
+        n_group=num_expert_group,
+        topk_group=topk_group,
+        intermediate_size=intermediate_size,
+        local_expert_offset=local_expert_offset,
+        local_num_experts=local_num_experts,
+        routed_scaling_factor=routed_scaling_factor,
+        use_routing_scales_on_input=use_routing_scales_on_input,
+        tile_tokens_dim=calculate_tile_tokens_dim(hidden_states.shape[0],
+                                                  top_k, num_experts),
+        routing_method_type=routing_method_type)
+
+
+def flashinfer_fused_moe_per_tensor_scale_fp8_fake(
+        routing_logits: torch.Tensor,
+        routing_bias: torch.Tensor,
+        hidden_states: torch.Tensor,
+        gemm1_weights: torch.Tensor,
+        output1_scales_scalar: torch.Tensor,
+        output1_scales_gate_scalar: torch.Tensor,
+        gemm2_weights: torch.Tensor,
+        output2_scales_scalar: torch.Tensor,
+        num_experts: int,
+        top_k: int,
+        num_expert_group: int,
+        topk_group: int,
+        intermediate_size: int,
+        local_expert_offset: int,
+        local_num_experts: int,
+        routed_scaling_factor: float = 1.0,
+        use_routing_scales_on_input: bool = False,
+        tile_tokens_dim: int = 8,
+        routing_method_type: int = 0) -> torch.Tensor:
+    pass
+
+
+direct_register_custom_op(
+    op_name="flashinfer_fused_moe_per_tensor_scale_fp8",
+    op_func=flashinfer_fused_moe_per_tensor_scale_fp8,
+    mutates_args=["hidden_states"],
+    fake_impl=flashinfer_fused_moe_per_tensor_scale_fp8_fake,
+    tags=(torch.Tag.needs_fixed_stride_order, ),
+)
+
+
 def outplace_fused_experts(
         hidden_states: torch.Tensor,
         w1: torch.Tensor,
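The new `flashinfer_fused_moe_per_tensor_scale_fp8` op folds the per-tensor weight, input, and activation scales into the three scalars the TRT-LLM kernel consumes. A minimal sketch of that scale bookkeeping, using made-up scalar values (real scales come from the quantized checkpoint):

```python
# Minimal sketch of the per-tensor scale arithmetic performed by the new op.
# The scale values below are illustrative only.
import torch

gemm1_weights_scale = torch.tensor(0.05)  # per-tensor scale of w13
gemm2_weights_scale = torch.tensor(0.04)  # per-tensor scale of w2
input_scale = torch.tensor(0.02)          # input scale of the first GEMM
activation_scale = torch.tensor(0.03)     # input scale of the second GEMM

# Mirrors the expressions in flashinfer_fused_moe_per_tensor_scale_fp8 above.
output1_scales_scalar = gemm1_weights_scale * input_scale * (
    1.0 / activation_scale)
output1_scales_gate_scalar = gemm1_weights_scale * input_scale
output2_scales_scalar = activation_scale * gemm2_weights_scale

print(output1_scales_scalar, output1_scales_gate_scalar, output2_scales_scalar)
```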
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 75f8adf34f7d..8b6ed154bdbe 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -23,6 +23,9 @@
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
+from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
+    apply_flashinfer_per_tensor_scale_fp8, rotate_flashinfer_fp8_moe_weights,
+    swap_w13_to_w31)
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     get_col_major_tma_aligned_tensor, requant_weight_ue8m0_inplace)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
@@ -53,11 +56,6 @@
 logger = init_logger(__name__)
 
 
-def _swap_w13_to_w31(x: torch.Tensor) -> torch.Tensor:
-    return x.reshape(-1, 2, x.shape[-2] // 2,
-                     x.shape[-1]).flip(dims=[1]).reshape(x.shape)
-
-
 def _is_col_major(x: torch.Tensor) -> bool:
     assert x.dim() == 3
     b, m, n = x.shape
@@ -695,11 +693,13 @@ def process_weights_after_loading(self, layer: Module) -> None:
         elif self.flashinfer_moe_enabled:
             # NOTE: weights have to be swapped since the activation is
             # applied on different half for flashinfer vs vllm
-            w13_weight = _swap_w13_to_w31(layer.w13_weight.data)
-            w13_weight_scale_inv = _swap_w13_to_w31(
+            w13_weight = swap_w13_to_w31(layer.w13_weight.data)
+            w13_weight_scale_inv = swap_w13_to_w31(
                 layer.w13_weight_scale_inv.data)
             w2_weight = layer.w2_weight.data
             w2_weight_scale_inv = layer.w2_weight_scale_inv.data
+            if not self.block_quant:
+                rotate_flashinfer_fp8_moe_weights(w13_weight, w2_weight)
         else:
             w13_weight = layer.w13_weight.data
             w13_weight_scale_inv = layer.w13_weight_scale_inv.data
@@ -998,30 +998,43 @@ def apply(
                 global_num_experts=global_num_experts,
                 expert_map=expert_map)
         elif self.flashinfer_moe_enabled:
-            # Currently only work with DS models
-            assert self.block_quant
-            assert (renormalize and use_grouped_topk
-                    and scoring_func == 'sigmoid'
-                    and custom_routing_function is None)
-            assert activation == "silu"
-            return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
-                routing_logits=router_logits.to(torch.float32),
-                routing_bias=e_score_correction_bias,
-                x=x,
-                w13_weight=layer.w13_weight,
-                w13_weight_scale_inv=layer.w13_weight_scale_inv,
-                w2_weight=layer.w2_weight,
-                w2_weight_scale_inv=layer.w2_weight_scale_inv,
-                global_num_experts=global_num_experts,
-                top_k=top_k,
-                num_expert_group=num_expert_group,
-                topk_group=topk_group,
-                intermediate_size=layer.intermediate_size_per_partition,
-                expert_offset=layer.ep_rank * layer.local_num_experts,
-                local_num_experts=layer.local_num_experts,
-                block_shape=self.quant_config.weight_block_size,
-                routed_scaling=1.0,
-            )
+            assert activation == 'silu'
+            assert scoring_func == 'sigmoid'
+            if self.block_quant:
+                assert (renormalize and use_grouped_topk
+                        and custom_routing_function is None)
+
+                return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
+                    routing_logits=router_logits.to(torch.float32),
+                    routing_bias=e_score_correction_bias,
+                    x=x,
+                    w13_weight=layer.w13_weight,
+                    w13_weight_scale_inv=layer.w13_weight_scale_inv,
+                    w2_weight=layer.w2_weight,
+                    w2_weight_scale_inv=layer.w2_weight_scale_inv,
+                    global_num_experts=global_num_experts,
+                    top_k=top_k,
+                    num_expert_group=num_expert_group,
+                    topk_group=topk_group,
+                    intermediate_size=layer.intermediate_size_per_partition,
+                    expert_offset=layer.ep_rank * layer.local_num_experts,
+                    local_num_experts=layer.local_num_experts,
+                    block_shape=self.quant_config.weight_block_size,
+                    routed_scaling=1.0,
+                )
+            else:
+                assert (not renormalize
+                        and custom_routing_function is not None)
+                return apply_flashinfer_per_tensor_scale_fp8(
+                    layer=layer,
+                    hidden_states=x,
+                    router_logits=router_logits,
+                    routing_bias=e_score_correction_bias,
+                    global_num_experts=global_num_experts,
+                    top_k=top_k,
+                    num_expert_group=num_expert_group,
+                    topk_group=topk_group,
+                    apply_router_weight_on_input=apply_router_weight_on_input)
         else:
             return self.fused_experts(
                 hidden_states=x,
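The `Fp8MoEMethod` changes keep block-quantized checkpoints on the blockscale kernel and send per-tensor checkpoints to the new path, after swapping the two halves of the fused `w13` weight as the NOTE comment describes. A standalone demo of that swap, with the two-line body copied from the new `flashinfer_utils` module so it runs with only PyTorch; the toy sizes are arbitrary:

```python
# Demo of swap_w13_to_w31 (body copied from flashinfer_utils.py in this PR):
# the two halves of the fused w13 weight along dim -2 are exchanged, which is
# the gate/up layout difference between vLLM and FlashInfer.
import torch


def swap_w13_to_w31(x: torch.Tensor) -> torch.Tensor:
    return x.reshape(-1, 2, x.shape[-2] // 2,
                     x.shape[-1]).flip(dims=[1]).reshape(x.shape)


num_experts, intermediate_size, hidden_size = 2, 4, 8
first_half = torch.zeros(num_experts, intermediate_size, hidden_size)
second_half = torch.ones(num_experts, intermediate_size, hidden_size)
w13 = torch.cat([first_half, second_half], dim=1)

# After the swap, the halves trade places along dim -2.
assert torch.equal(swap_w13_to_w31(w13),
                   torch.cat([second_half, first_half], dim=1))
```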
diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py
index 8fbc3231d86c..b8ffcf90c022 100644
--- a/vllm/model_executor/layers/quantization/modelopt.py
+++ b/vllm/model_executor/layers/quantization/modelopt.py
@@ -23,6 +23,9 @@
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
+from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
+    apply_flashinfer_per_tensor_scale_fp8, rotate_flashinfer_fp8_moe_weights,
+    swap_w13_to_w31)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
     apply_fp4_marlin_linear, is_fp4_marlin_supported,
     prepare_fp4_layer_for_marlin, prepare_moe_fp4_layer_for_marlin)
@@ -34,6 +37,7 @@
     PerTensorScaleParameter)
 from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types
+from vllm.utils.flashinfer import has_flashinfer_moe
 
 logger = init_logger(__name__)
 
@@ -267,6 +271,11 @@ def __init__(self, quant_config: ModelOptFp8Config):
         from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
             cutlass_fp8_supported)
         self.cutlass_fp8_supported = cutlass_fp8_supported()
+        self.flashinfer_moe_enabled = False
+        if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
+            logger.info_once(
+                "Using FlashInfer MoE FP8 kernels for ModelOptFp8MoEMethod.")
+            self.flashinfer_moe_enabled = True
 
     def create_weights(
         self,
@@ -410,6 +419,11 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             layer.w2_input_scale = Parameter(layer.w2_input_scale.max(),
                                              requires_grad=False)
 
+        if self.flashinfer_moe_enabled:
+            layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
+            rotate_flashinfer_fp8_moe_weights(layer.w13_weight,
+                                              layer.w2_weight)
+
     def apply(
         self,
         layer: torch.nn.Module,
@@ -436,6 +450,20 @@ def apply(
             raise NotImplementedError(
                 "EPLB not supported for `ModelOptFp8MoEMethod` yet.")
 
+        if self.flashinfer_moe_enabled:
+            assert activation == 'silu'
+            assert not renormalize
+            return apply_flashinfer_per_tensor_scale_fp8(
+                layer=layer,
+                hidden_states=x,
+                router_logits=router_logits,
+                routing_bias=e_score_correction_bias,
+                global_num_experts=global_num_experts,
+                top_k=top_k,
+                num_expert_group=num_expert_group,
+                topk_group=topk_group,
+                apply_router_weight_on_input=apply_router_weight_on_input)
+
         # Expert selection
         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,
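`ModelOptFp8MoEMethod` only takes the FlashInfer path when it is explicitly enabled and the kernels are importable. A minimal sketch of that gating check, assuming a vLLM build that already contains this change:

```python
# Sketch of the enablement check added to ModelOptFp8MoEMethod.__init__,
# assuming a vLLM build that includes this PR.
import vllm.envs as envs
from vllm.utils.flashinfer import has_flashinfer_moe

flashinfer_moe_enabled = bool(envs.VLLM_USE_FLASHINFER_MOE_FP8
                              and has_flashinfer_moe())
print(f"FlashInfer FP8 MoE enabled: {flashinfer_moe_enabled}")
```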
diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
new file mode 100644
index 000000000000..c6f914febc0a
--- /dev/null
+++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
@@ -0,0 +1,100 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Optional
+
+import torch
+
+
+def calculate_tile_tokens_dim(num_tokens, top_k, num_experts):
+    from flashinfer import next_positive_power_of_2
+
+    # Guess tokens per expert assuming perfect expert distribution first.
+    num_tokens_per_expert = (num_tokens * top_k) // num_experts
+    # And pad the number to the next power of 2.
+    tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
+    # Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
+    tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
+    return tile_tokens_dim
+
+
+def swap_w13_to_w31(x: torch.Tensor) -> torch.Tensor:
+    return x.reshape(-1, 2, x.shape[-2] // 2,
+                     x.shape[-1]).flip(dims=[1]).reshape(x.shape)
+
+
+def rotate_flashinfer_fp8_moe_weights(gemm1_weights: torch.Tensor,
+                                      gemm2_weights: torch.Tensor):
+    from flashinfer import reorder_rows_for_gated_act_gemm, shuffle_matrix_a
+    epilogue_tile_m = 128
+    num_experts = gemm1_weights.shape[0]
+    hidden_size = gemm1_weights.shape[-1]
+    intermediate_size = gemm1_weights.shape[1] // 2
+
+    # Reorder rows of W1 for fused gated activation
+    gemm1_weights_fp8_interleaved = []
+    for i in range(num_experts):
+        gemm1_weights_fp8_interleaved.append(
+            reorder_rows_for_gated_act_gemm(gemm1_weights[i]))
+
+    # Stack weights and scales for all experts
+    gemm1_weights_fp8_interleaved = torch.stack(
+        gemm1_weights_fp8_interleaved).reshape(num_experts,
+                                               2 * intermediate_size,
+                                               hidden_size)
+
+    # Shuffle weights and scaling factors for transposed mma output
+    gemm1_weights_fp8_shuffled = []
+    gemm2_weights_fp8_shuffled = []
+    for i in range(num_experts):
+        gemm1_weights_fp8_shuffled.append(
+            shuffle_matrix_a(
+                gemm1_weights_fp8_interleaved[i].view(torch.uint8),
+                epilogue_tile_m))
+
+        gemm2_weights_fp8_shuffled.append(
+            shuffle_matrix_a(gemm2_weights[i].view(torch.uint8),
+                             epilogue_tile_m))
+
+    # Stack weights for all experts
+    gemm1_weights.data = torch.stack(gemm1_weights_fp8_shuffled).view(
+        torch.float8_e4m3fn)
+    gemm2_weights.data = torch.stack(gemm2_weights_fp8_shuffled).view(
+        torch.float8_e4m3fn)
+
+
+def apply_flashinfer_per_tensor_scale_fp8(
+    layer: torch.nn.Module,
+    hidden_states: torch.Tensor,
+    router_logits: torch.Tensor,
+    routing_bias: Optional[torch.Tensor],
+    top_k: int,
+    num_expert_group: Optional[int],
+    topk_group: Optional[int],
+    global_num_experts: int,
+    apply_router_weight_on_input: bool,
+) -> torch.Tensor:
+    from flashinfer.fused_moe import RoutingMethodType
+
+    from vllm.model_executor.models.llama4 import Llama4MoE
+    assert layer.custom_routing_function == Llama4MoE.custom_routing_function, \
+        "FusedMoE flashinfer kernels are only supported for Llama4"
+    return torch.ops.vllm.flashinfer_fused_moe_per_tensor_scale_fp8(
+        routing_logits=router_logits,
+        routing_bias=routing_bias,
+        hidden_states=hidden_states,
+        input_scale=layer.w13_input_scale,
+        gemm1_weights=layer.w13_weight,
+        gemm1_weights_scale=layer.w13_weight_scale,
+        gemm2_weights=layer.w2_weight,
+        gemm2_weights_scale=layer.w2_weight_scale,
+        activation_scale=layer.w2_input_scale,
+        num_experts=global_num_experts,
+        top_k=top_k,
+        num_expert_group=num_expert_group,
+        topk_group=topk_group,
+        intermediate_size=layer.intermediate_size_per_partition,
+        local_expert_offset=layer.ep_rank * layer.local_num_experts,
+        local_num_experts=layer.local_num_experts,
+        use_routing_scales_on_input=apply_router_weight_on_input,
+        routing_method_type=RoutingMethodType.Llama4,
+    )
\ No newline at end of file
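To make the tile sizing concrete, here is a worked example of the `calculate_tile_tokens_dim` heuristic with the power-of-2 helper inlined (same logic as the `next_positive_power_of_2` this PR removes from `fused_moe.py`), so the snippet runs without FlashInfer installed; the numbers are illustrative:

```python
# Worked example of the tile_tokens_dim heuristic.
def next_positive_power_of_2(x: int) -> int:
    if x < 1:
        return 1
    return 1 << (x - 1).bit_length()


num_tokens, top_k, num_experts = 100, 8, 256
# Perfect-distribution guess: 100 * 8 // 256 = 3 tokens per expert.
num_tokens_per_expert = (num_tokens * top_k) // num_experts
# Pad to the next power of 2 (4), then clamp into the kernel's 8-64 range.
tile_tokens_dim = min(max(next_positive_power_of_2(num_tokens_per_expert), 8),
                      64)
assert tile_tokens_dim == 8
```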
diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py
index ebc54fd029da..3bfb9808c0a0 100644
--- a/vllm/utils/flashinfer.py
+++ b/vllm/utils/flashinfer.py
@@ -66,6 +66,8 @@ def wrapper(*args, **kwargs):
 # Create lazy wrappers for each function
 flashinfer_trtllm_fp8_block_scale_moe = _lazy_import_wrapper(
     "flashinfer.fused_moe", "trtllm_fp8_block_scale_moe")
+flashinfer_trtllm_fp8_per_tensor_scale_moe = _lazy_import_wrapper(
+    "flashinfer.fused_moe", "trtllm_fp8_per_tensor_scale_moe")
 flashinfer_cutlass_fused_moe = _lazy_import_wrapper("flashinfer.fused_moe",
                                                     "cutlass_fused_moe")
 fp4_quantize = _lazy_import_wrapper("flashinfer", "fp4_quantize")
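The lazily-bound wrapper added to `vllm/utils/flashinfer.py` mirrors the existing block-scale one, so it can be imported unconditionally and only resolves the FlashInfer symbol when called. A small usage sketch, assuming a vLLM build containing this change:

```python
# Usage sketch for the new lazy wrapper; importing it does not require
# flashinfer, since resolution happens at call time.
from vllm.utils.flashinfer import (flashinfer_trtllm_fp8_per_tensor_scale_moe,
                                   has_flashinfer_moe)

if has_flashinfer_moe():
    # Forwards its arguments to
    # flashinfer.fused_moe.trtllm_fp8_per_tensor_scale_moe.
    print(flashinfer_trtllm_fp8_per_tensor_scale_moe)
else:
    print("flashinfer.fused_moe is not available")
```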