From fdf635bb732521858584a5a98a1cdc0e527c646b Mon Sep 17 00:00:00 2001 From: Amir Klein <203507526+amirkl94@users.noreply.github.com> Date: Mon, 21 Jul 2025 10:27:39 +0300 Subject: [PATCH 1/3] Adding flashinfer fp8 FusedMoE Signed-off-by: Amir Klein <203507526+amirkl94@users.noreply.github.com> --- .../layers/fused_moe/fused_moe.py | 117 +++++++++++++++--- .../model_executor/layers/quantization/fp8.py | 84 ++++++++----- .../layers/quantization/modelopt.py | 36 ++++++ .../quantization/utils/flashinfer_utils.py | 60 +++++++++ vllm/utils/flashinfer.py | 2 + 5 files changed, 250 insertions(+), 49 deletions(-) create mode 100644 vllm/model_executor/layers/quantization/utils/flashinfer_utils.py diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 1985e8612da3..88bc9ef564d9 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -29,6 +29,8 @@ TopKWeightAndReduceNoOP) from vllm.model_executor.layers.fused_moe.utils import ( _resize_cache, moe_kernel_quantize_input, per_token_group_quant_fp8) +from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( + calculate_tile_tokens_dim) from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( dequant_mxfp4) from vllm.platforms import current_platform @@ -1061,22 +1063,6 @@ def inplace_fused_experts_fake( ) -def next_positive_power_of_2(x: int) -> int: - if x < 1: - return 1 - return 1 << (x - 1).bit_length() - - -def _get_tile_tokens_dim(num_tokens, top_k, num_experts): - # Guess tokens per expert assuming perfect expert distribution first. - num_tokens_per_expert = (num_tokens * top_k) // num_experts - # And pad the number to the next power of 2. - tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert) - # Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel. 
- tile_tokens_dim = min(max(tile_tokens_dim, 8), 64) - return tile_tokens_dim - - def flashinfer_fused_moe_blockscale_fp8( routing_logits: torch.Tensor, routing_bias: torch.Tensor, @@ -1124,8 +1110,8 @@ def flashinfer_fused_moe_blockscale_fp8( local_expert_offset=expert_offset, local_num_experts=local_num_experts, routed_scaling_factor=routed_scaling, - tile_tokens_dim=_get_tile_tokens_dim(x.shape[0], top_k, - global_num_experts), + tile_tokens_dim=calculate_tile_tokens_dim(x.shape[0], top_k, + global_num_experts), routing_method_type=2, # DeepSeek-styled routing method use_shuffled_weight=False, ) @@ -1160,6 +1146,101 @@ def flashinfer_fused_moe_blockscale_fp8_fake( ) +def flashinfer_fused_moe_per_tensor_scale_fp8( + routing_logits: torch.Tensor, + routing_bias: Optional[torch.Tensor], + hidden_states: torch.Tensor, + input_scale: torch.Tensor, + gemm1_weights: torch.Tensor, + gemm1_weights_scale: torch.Tensor, + activation_scale: torch.Tensor, + gemm2_weights: torch.Tensor, + gemm2_weights_scale: torch.Tensor, + num_experts: int, + top_k: int, + num_expert_group: Optional[int], + topk_group: Optional[int], + intermediate_size: int, + local_expert_offset: int, + local_num_experts: int, + use_routing_scales_on_input: bool, + routed_scaling_factor: float = 1.0, + routing_method_type: int = 3 # Llama4-styled routing method +) -> torch.Tensor: + if routing_bias is None: + routing_bias = torch.zeros(num_experts, + dtype=torch.bfloat16, + device=hidden_states.device) + num_expert_group = num_expert_group if num_expert_group is not None else 0 + topk_group = topk_group if topk_group is not None else 0 + + quant_hidden_states, input_scale = moe_kernel_quantize_input( + hidden_states, + input_scale, + quant_dtype=torch.float8_e4m3fn, + per_act_token_quant=False) + + output1_scales_scalar = gemm1_weights_scale * input_scale * ( + 1.0 / activation_scale) + output1_scales_gate_scalar = gemm1_weights_scale * input_scale + output2_scales_scalar = activation_scale * gemm2_weights_scale + + from vllm.utils.flashinfer import ( + flashinfer_trtllm_fp8_per_tensor_scale_moe) + return flashinfer_trtllm_fp8_per_tensor_scale_moe( + routing_logits=routing_logits, + routing_bias=routing_bias, + hidden_states=quant_hidden_states, + gemm1_weights=gemm1_weights, + output1_scales_scalar=output1_scales_scalar, + output1_scales_gate_scalar=output1_scales_gate_scalar, + gemm2_weights=gemm2_weights, + output2_scales_scalar=output2_scales_scalar, + num_experts=num_experts, + top_k=top_k, + n_group=num_expert_group, + topk_group=topk_group, + intermediate_size=intermediate_size, + local_expert_offset=local_expert_offset, + local_num_experts=local_num_experts, + routed_scaling_factor=routed_scaling_factor, + use_routing_scales_on_input=use_routing_scales_on_input, + tile_tokens_dim=8, + routing_method_type=routing_method_type) + + +def flashinfer_fused_moe_per_tensor_scale_fp8_fake( + routing_logits: torch.Tensor, + routing_bias: torch.Tensor, + hidden_states: torch.Tensor, + gemm1_weights: torch.Tensor, + output1_scales_scalar: torch.Tensor, + output1_scales_gate_scalar: torch.Tensor, + gemm2_weights: torch.Tensor, + output2_scales_scalar: torch.Tensor, + num_experts: int, + top_k: int, + num_expert_group: int, + topk_group: int, + intermediate_size: int, + local_expert_offset: int, + local_num_experts: int, + routed_scaling_factor: float = 1.0, + use_routing_scales_on_input: bool = False, + tile_tokens_dim: int = 8, + routing_method_type: int = 0) -> torch.Tensor: + pass + + +direct_register_custom_op( + 
op_name="flashinfer_fused_moe_per_tensor_scale_fp8", + op_func=flashinfer_fused_moe_per_tensor_scale_fp8, + mutates_args=["hidden_states"], + fake_impl=flashinfer_fused_moe_per_tensor_scale_fp8_fake, + tags=(torch.Tag.needs_fixed_stride_order, ), +) + + def outplace_fused_experts( hidden_states: torch.Tensor, w1: torch.Tensor, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 75f8adf34f7d..c80145cca2ff 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -23,6 +23,8 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod +from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( + rotate_flashinfer_fp8_moe_weights, swap_w13_to_w31) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( get_col_major_tma_aligned_tensor, requant_weight_ue8m0_inplace) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( @@ -53,11 +55,6 @@ logger = init_logger(__name__) -def _swap_w13_to_w31(x: torch.Tensor) -> torch.Tensor: - return x.reshape(-1, 2, x.shape[-2] // 2, - x.shape[-1]).flip(dims=[1]).reshape(x.shape) - - def _is_col_major(x: torch.Tensor) -> bool: assert x.dim() == 3 b, m, n = x.shape @@ -695,11 +692,13 @@ def process_weights_after_loading(self, layer: Module) -> None: elif self.flashinfer_moe_enabled: # NOTE: weights have to be swapped since the activation is # applied on different half for flashinfer vs vllm - w13_weight = _swap_w13_to_w31(layer.w13_weight.data) - w13_weight_scale_inv = _swap_w13_to_w31( + w13_weight = swap_w13_to_w31(layer.w13_weight.data) + w13_weight_scale_inv = swap_w13_to_w31( layer.w13_weight_scale_inv.data) w2_weight = layer.w2_weight.data w2_weight_scale_inv = layer.w2_weight_scale_inv.data + if not self.block_quant: + rotate_flashinfer_fp8_moe_weights(w13_weight, w2_weight) else: w13_weight = layer.w13_weight.data w13_weight_scale_inv = layer.w13_weight_scale_inv.data @@ -998,30 +997,53 @@ def apply( global_num_experts=global_num_experts, expert_map=expert_map) elif self.flashinfer_moe_enabled: - # Currently only work with DS models - assert self.block_quant - assert (renormalize and use_grouped_topk - and scoring_func == 'sigmoid' - and custom_routing_function is None) - assert activation == "silu" - return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8( - routing_logits=router_logits.to(torch.float32), - routing_bias=e_score_correction_bias, - x=x, - w13_weight=layer.w13_weight, - w13_weight_scale_inv=layer.w13_weight_scale_inv, - w2_weight=layer.w2_weight, - w2_weight_scale_inv=layer.w2_weight_scale_inv, - global_num_experts=global_num_experts, - top_k=top_k, - num_expert_group=num_expert_group, - topk_group=topk_group, - intermediate_size=layer.intermediate_size_per_partition, - expert_offset=layer.ep_rank * layer.local_num_experts, - local_num_experts=layer.local_num_experts, - block_shape=self.quant_config.weight_block_size, - routed_scaling=1.0, - ) + assert activation == 'silu' + assert scoring_func == 'sigmoid' + if self.block_quant: + assert (renormalize and use_grouped_topk + and custom_routing_function is None) + + return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8( + routing_logits=router_logits.to(torch.float32), + routing_bias=e_score_correction_bias, + x=x, + w13_weight=layer.w13_weight, + 
w13_weight_scale_inv=layer.w13_weight_scale_inv, + w2_weight=layer.w2_weight, + w2_weight_scale_inv=layer.w2_weight_scale_inv, + global_num_experts=global_num_experts, + top_k=top_k, + num_expert_group=num_expert_group, + topk_group=topk_group, + intermediate_size=layer.intermediate_size_per_partition, + expert_offset=layer.ep_rank * layer.local_num_experts, + local_num_experts=layer.local_num_experts, + block_shape=self.quant_config.weight_block_size, + routed_scaling=1.0, + ) + else: + assert (not renormalize + and custom_routing_function is not None) + return torch.ops.vllm.flashinfer_fused_moe_per_tensor_scale_fp8( + routing_logits=router_logits, + routing_bias=e_score_correction_bias, + hidden_states=x, + input_scale=layer.w13_input_scale, + gemm1_weights=layer.w13_weight, + gemm1_weights_scale=layer.w13_weight_scale, + gemm2_weights=layer.w2_weight, + gemm2_weights_scale=layer.w2_weight_scale, + activation_scale=layer.w2_input_scale, + num_experts=global_num_experts, + top_k=top_k, + num_expert_group=num_expert_group, + topk_group=topk_group, + intermediate_size=layer.intermediate_size_per_partition, + local_expert_offset=layer.ep_rank * + layer.local_num_experts, + local_num_experts=layer.local_num_experts, + use_routing_scales_on_input=apply_router_weight_on_input, + ) else: return self.fused_experts( hidden_states=x, diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 81611ed07aaa..1017c2c9478c 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -24,6 +24,8 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod +from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( + rotate_flashinfer_fp8_moe_weights, swap_w13_to_w31) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( apply_fp4_marlin_linear, is_fp4_marlin_supported, prepare_fp4_layer_for_marlin, prepare_moe_fp4_layer_for_marlin) @@ -35,6 +37,7 @@ PerTensorScaleParameter) from vllm.platforms import current_platform from vllm.scalar_type import scalar_types +from vllm.utils.flashinfer import has_flashinfer_moe logger = init_logger(__name__) @@ -268,6 +271,11 @@ def __init__(self, quant_config: ModelOptFp8Config): from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( cutlass_fp8_supported) self.cutlass_fp8_supported = cutlass_fp8_supported() + self.flashinfer_moe_enabled = False + if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe(): + logger.info_once( + "Using FlashInfer MoE FP8 kernels for ModelOptFp8MoEMethod.") + self.flashinfer_moe_enabled = True def create_weights( self, @@ -411,6 +419,11 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.w2_input_scale = Parameter(layer.w2_input_scale.max(), requires_grad=False) + if self.flashinfer_moe_enabled: + layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data) + rotate_flashinfer_fp8_moe_weights(layer.w13_weight, + layer.w2_weight) + def apply( self, layer: torch.nn.Module, @@ -437,6 +450,29 @@ def apply( raise NotImplementedError( "EPLB not supported for `ModelOptFp8MoEMethod` yet.") + if self.flashinfer_moe_enabled: + assert activation == 'silu' + assert not renormalize + return torch.ops.vllm.flashinfer_fused_moe_per_tensor_scale_fp8( + routing_logits=router_logits, + 
routing_bias=e_score_correction_bias, + hidden_states=x, + input_scale=layer.w13_input_scale, + gemm1_weights=layer.w13_weight, + gemm1_weights_scale=layer.w13_weight_scale, + gemm2_weights=layer.w2_weight, + gemm2_weights_scale=layer.w2_weight_scale, + activation_scale=layer.w2_input_scale, + num_experts=global_num_experts, + top_k=top_k, + num_expert_group=num_expert_group, + topk_group=topk_group, + intermediate_size=layer.intermediate_size_per_partition, + local_expert_offset=layer.ep_rank * layer.local_num_experts, + local_num_experts=layer.local_num_experts, + use_routing_scales_on_input=apply_router_weight_on_input, + ) + # Expert selection topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py new file mode 100644 index 000000000000..fe434561982f --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py @@ -0,0 +1,60 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch + + +def calculate_tile_tokens_dim(num_tokens, top_k, num_experts): + from flashinfer import next_positive_power_of_2 + + # Guess tokens per expert assuming perfect expert distribution first. + num_tokens_per_expert = (num_tokens * top_k) // num_experts + # And pad the number to the next power of 2. + tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert) + # Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel. + tile_tokens_dim = min(max(tile_tokens_dim, 8), 64) + return tile_tokens_dim + + +def swap_w13_to_w31(x: torch.Tensor) -> torch.Tensor: + return x.reshape(-1, 2, x.shape[-2] // 2, + x.shape[-1]).flip(dims=[1]).reshape(x.shape) + + +def rotate_flashinfer_fp8_moe_weights(gemm1_weights: torch.Tensor, + gemm2_weights: torch.Tensor): + from flashinfer import reorder_rows_for_gated_act_gemm, shuffle_matrix_a + epilogue_tile_m = 128 + num_experts = gemm1_weights.shape[0] + hidden_size = gemm1_weights.shape[-1] + intermediate_size = gemm1_weights.shape[1] // 2 + + # Reorder rows of W1 for fused gated activation + gemm1_weights_fp8_interleaved = [] + for i in range(num_experts): + gemm1_weights_fp8_interleaved.append( + reorder_rows_for_gated_act_gemm(gemm1_weights[i])) + + # Stack weights and scales for all experts + gemm1_weights_fp8_interleaved = torch.stack( + gemm1_weights_fp8_interleaved).reshape(num_experts, + 2 * intermediate_size, + hidden_size) + + # Shuffle weights and scaling factors for transposed mma output + gemm1_weights_fp8_shuffled = [] + gemm2_weights_fp8_shuffled = [] + for i in range(num_experts): + gemm1_weights_fp8_shuffled.append( + shuffle_matrix_a( + gemm1_weights_fp8_interleaved[i].view(torch.uint8), + epilogue_tile_m)) + + gemm2_weights_fp8_shuffled.append( + shuffle_matrix_a(gemm2_weights[i].view(torch.uint8), + epilogue_tile_m)) + + # Stack weights for all experts + gemm1_weights.data = torch.stack(gemm1_weights_fp8_shuffled).view( + torch.float8_e4m3fn) + gemm2_weights.data = torch.stack(gemm2_weights_fp8_shuffled).view( + torch.float8_e4m3fn) diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index ebc54fd029da..3bfb9808c0a0 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -66,6 +66,8 @@ def wrapper(*args, **kwargs): # Create lazy wrappers for each function flashinfer_trtllm_fp8_block_scale_moe = _lazy_import_wrapper( "flashinfer.fused_moe", 
"trtllm_fp8_block_scale_moe") +flashinfer_trtllm_fp8_per_tensor_scale_moe = _lazy_import_wrapper( + "flashinfer.fused_moe", "trtllm_fp8_per_tensor_scale_moe") flashinfer_cutlass_fused_moe = _lazy_import_wrapper("flashinfer.fused_moe", "cutlass_fused_moe") fp4_quantize = _lazy_import_wrapper("flashinfer", "fp4_quantize") From 6582abc6da94248aaf594fe3b237e5a1f79ea567 Mon Sep 17 00:00:00 2001 From: Amir Klein <203507526+amirkl94@users.noreply.github.com> Date: Tue, 29 Jul 2025 12:47:51 +0300 Subject: [PATCH 2/3] CR Fixes Signed-off-by: Amir Klein <203507526+amirkl94@users.noreply.github.com> --- .../layers/fused_moe/fused_moe.py | 46 +++++++++---------- .../model_executor/layers/quantization/fp8.py | 24 ++++------ .../layers/quantization/modelopt.py | 23 ++++------ .../quantization/utils/flashinfer_utils.py | 40 ++++++++++++++++ 4 files changed, 77 insertions(+), 56 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 88bc9ef564d9..dcc4c97b9c68 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1147,30 +1147,25 @@ def flashinfer_fused_moe_blockscale_fp8_fake( def flashinfer_fused_moe_per_tensor_scale_fp8( - routing_logits: torch.Tensor, - routing_bias: Optional[torch.Tensor], - hidden_states: torch.Tensor, - input_scale: torch.Tensor, - gemm1_weights: torch.Tensor, - gemm1_weights_scale: torch.Tensor, - activation_scale: torch.Tensor, - gemm2_weights: torch.Tensor, - gemm2_weights_scale: torch.Tensor, - num_experts: int, - top_k: int, - num_expert_group: Optional[int], - topk_group: Optional[int], - intermediate_size: int, - local_expert_offset: int, - local_num_experts: int, - use_routing_scales_on_input: bool, - routed_scaling_factor: float = 1.0, - routing_method_type: int = 3 # Llama4-styled routing method -) -> torch.Tensor: - if routing_bias is None: - routing_bias = torch.zeros(num_experts, - dtype=torch.bfloat16, - device=hidden_states.device) + routing_logits: torch.Tensor, + routing_bias: Optional[torch.Tensor], + hidden_states: torch.Tensor, + input_scale: torch.Tensor, + gemm1_weights: torch.Tensor, + gemm1_weights_scale: torch.Tensor, + activation_scale: torch.Tensor, + gemm2_weights: torch.Tensor, + gemm2_weights_scale: torch.Tensor, + num_experts: int, + top_k: int, + num_expert_group: Optional[int], + topk_group: Optional[int], + intermediate_size: int, + local_expert_offset: int, + local_num_experts: int, + use_routing_scales_on_input: bool, + routing_method_type: int, + routed_scaling_factor: float = 1.0) -> torch.Tensor: num_expert_group = num_expert_group if num_expert_group is not None else 0 topk_group = topk_group if topk_group is not None else 0 @@ -1205,7 +1200,8 @@ def flashinfer_fused_moe_per_tensor_scale_fp8( local_num_experts=local_num_experts, routed_scaling_factor=routed_scaling_factor, use_routing_scales_on_input=use_routing_scales_on_input, - tile_tokens_dim=8, + tile_tokens_dim=calculate_tile_tokens_dim(hidden_states.shape[0], + top_k, num_experts), routing_method_type=routing_method_type) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index c80145cca2ff..3527f28f3625 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -6,6 +6,8 @@ import torch import torch.nn.functional as F +from quantization.utils.flashinfer_utils import ( + apply_flashinfer_per_tensor_scale_fp8) from torch.nn 
import Module from torch.nn.parameter import Parameter @@ -1024,26 +1026,16 @@ def apply( else: assert (not renormalize and custom_routing_function is not None) - return torch.ops.vllm.flashinfer_fused_moe_per_tensor_scale_fp8( - routing_logits=router_logits, - routing_bias=e_score_correction_bias, + return apply_flashinfer_per_tensor_scale_fp8( + layer=layer, hidden_states=x, - input_scale=layer.w13_input_scale, - gemm1_weights=layer.w13_weight, - gemm1_weights_scale=layer.w13_weight_scale, - gemm2_weights=layer.w2_weight, - gemm2_weights_scale=layer.w2_weight_scale, - activation_scale=layer.w2_input_scale, - num_experts=global_num_experts, + router_logits=router_logits, + routing_bias=e_score_correction_bias, + global_num_experts=global_num_experts, top_k=top_k, num_expert_group=num_expert_group, topk_group=topk_group, - intermediate_size=layer.intermediate_size_per_partition, - local_expert_offset=layer.ep_rank * - layer.local_num_experts, - local_num_experts=layer.local_num_experts, - use_routing_scales_on_input=apply_router_weight_on_input, - ) + apply_router_weight_on_input=apply_router_weight_on_input) else: return self.fused_experts( hidden_states=x, diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 1017c2c9478c..b5488a2a9fb4 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -4,6 +4,8 @@ from typing import Any, Callable, Optional, Union import torch +from quantization.utils.flashinfer_utils import ( + apply_flashinfer_per_tensor_scale_fp8) from torch.nn import Module from torch.nn.parameter import Parameter @@ -453,25 +455,16 @@ def apply( if self.flashinfer_moe_enabled: assert activation == 'silu' assert not renormalize - return torch.ops.vllm.flashinfer_fused_moe_per_tensor_scale_fp8( - routing_logits=router_logits, - routing_bias=e_score_correction_bias, + return apply_flashinfer_per_tensor_scale_fp8( + layer=layer, hidden_states=x, - input_scale=layer.w13_input_scale, - gemm1_weights=layer.w13_weight, - gemm1_weights_scale=layer.w13_weight_scale, - gemm2_weights=layer.w2_weight, - gemm2_weights_scale=layer.w2_weight_scale, - activation_scale=layer.w2_input_scale, - num_experts=global_num_experts, + router_logits=router_logits, + routing_bias=e_score_correction_bias, + global_num_experts=global_num_experts, top_k=top_k, num_expert_group=num_expert_group, topk_group=topk_group, - intermediate_size=layer.intermediate_size_per_partition, - local_expert_offset=layer.ep_rank * layer.local_num_experts, - local_num_experts=layer.local_num_experts, - use_routing_scales_on_input=apply_router_weight_on_input, - ) + apply_router_weight_on_input=apply_router_weight_on_input) # Expert selection topk_weights, topk_ids = FusedMoE.select_experts( diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py index fe434561982f..98240d858233 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + import torch @@ -58,3 +60,41 @@ def rotate_flashinfer_fp8_moe_weights(gemm1_weights: torch.Tensor, torch.float8_e4m3fn) gemm2_weights.data = torch.stack(gemm2_weights_fp8_shuffled).view( torch.float8_e4m3fn) + + +def 
apply_flashinfer_per_tensor_scale_fp8( + layer: torch.nn.Module, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + routing_bias: Optional[torch.Tensor], + top_k: int, + num_expert_group: Optional[int], + topk_group: Optional[int], + global_num_experts: int, + apply_router_weight_on_input: bool, +) -> torch.Tensor: + from flashinfer.fushed_moe import RoutingMethodType + + from vllm.model_executor.models.llama4 import Llama4MoE + assert layer.custom_routing_function == Llama4MoE.custom_routing_function, \ + "FusedMoE flashinfer kernels are only supported for Llama4" + return torch.ops.vllm.flashinfer_fused_moe_per_tensor_scale_fp8( + routing_logits=router_logits, + routing_bias=routing_bias, + hidden_states=hidden_states, + input_scale=layer.w13_input_scale, + gemm1_weights=layer.w13_weight, + gemm1_weights_scale=layer.w13_weight_scale, + gemm2_weights=layer.w2_weight, + gemm2_weights_scale=layer.w2_weight_scale, + activation_scale=layer.w2_input_scale, + num_experts=global_num_experts, + top_k=top_k, + num_expert_group=num_expert_group, + topk_group=topk_group, + intermediate_size=layer.intermediate_size_per_partition, + local_expert_offset=layer.ep_rank * layer.local_num_experts, + local_num_experts=layer.local_num_experts, + use_routing_scales_on_input=apply_router_weight_on_input, + routing_method=RoutingMethodType.Llama4, + ) \ No newline at end of file From 740627f7119711bee543e80713e84b5a2df17039 Mon Sep 17 00:00:00 2001 From: mgoin Date: Wed, 30 Jul 2025 15:08:03 -0400 Subject: [PATCH 3/3] Fixes to imports and names Signed-off-by: mgoin --- vllm/model_executor/layers/quantization/fp8.py | 5 ++--- vllm/model_executor/layers/quantization/modelopt.py | 5 ++--- .../layers/quantization/utils/flashinfer_utils.py | 4 ++-- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 3527f28f3625..8b6ed154bdbe 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -6,8 +6,6 @@ import torch import torch.nn.functional as F -from quantization.utils.flashinfer_utils import ( - apply_flashinfer_per_tensor_scale_fp8) from torch.nn import Module from torch.nn.parameter import Parameter @@ -26,7 +24,8 @@ QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( - rotate_flashinfer_fp8_moe_weights, swap_w13_to_w31) + apply_flashinfer_per_tensor_scale_fp8, rotate_flashinfer_fp8_moe_weights, + swap_w13_to_w31) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( get_col_major_tma_aligned_tensor, requant_weight_ue8m0_inplace) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index b5488a2a9fb4..e3f4c59a84cd 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -4,8 +4,6 @@ from typing import Any, Callable, Optional, Union import torch -from quantization.utils.flashinfer_utils import ( - apply_flashinfer_per_tensor_scale_fp8) from torch.nn import Module from torch.nn.parameter import Parameter @@ -27,7 +25,8 @@ QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from 
vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( - rotate_flashinfer_fp8_moe_weights, swap_w13_to_w31) + apply_flashinfer_per_tensor_scale_fp8, rotate_flashinfer_fp8_moe_weights, + swap_w13_to_w31) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( apply_fp4_marlin_linear, is_fp4_marlin_supported, prepare_fp4_layer_for_marlin, prepare_moe_fp4_layer_for_marlin) diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py index 98240d858233..c6f914febc0a 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py @@ -73,7 +73,7 @@ def apply_flashinfer_per_tensor_scale_fp8( global_num_experts: int, apply_router_weight_on_input: bool, ) -> torch.Tensor: - from flashinfer.fushed_moe import RoutingMethodType + from flashinfer.fused_moe import RoutingMethodType from vllm.model_executor.models.llama4 import Llama4MoE assert layer.custom_routing_function == Llama4MoE.custom_routing_function, \ @@ -96,5 +96,5 @@ def apply_flashinfer_per_tensor_scale_fp8( local_expert_offset=layer.ep_rank * layer.local_num_experts, local_num_experts=layer.local_num_experts, use_routing_scales_on_input=apply_router_weight_on_input, - routing_method=RoutingMethodType.Llama4, + routing_method_type=RoutingMethodType.Llama4, ) \ No newline at end of file
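
---

For reference, a small self-contained Python sketch of the two helpers this series adds in flashinfer_utils.py: the tile-size heuristic behind tile_tokens_dim and the w13 -> w31 half-swap applied to the fused gate/up weights before they are handed to the FlashInfer kernels. This is only an illustration under the assumption that FlashInfer is not installed, so flashinfer.next_positive_power_of_2 is replaced by a local equivalent (matching the helper removed from fused_moe.py in PATCH 1/3); it is not part of the patch itself.

import torch


def next_positive_power_of_2(x: int) -> int:
    # Local stand-in for flashinfer.next_positive_power_of_2.
    if x < 1:
        return 1
    return 1 << (x - 1).bit_length()


def calculate_tile_tokens_dim(num_tokens: int, top_k: int, num_experts: int) -> int:
    # Guess tokens per expert assuming perfect expert distribution, pad to the
    # next power of 2, then clamp to the 8-64 range supported by the kernel.
    num_tokens_per_expert = (num_tokens * top_k) // num_experts
    return min(max(next_positive_power_of_2(num_tokens_per_expert), 8), 64)


def swap_w13_to_w31(x: torch.Tensor) -> torch.Tensor:
    # Swap the two halves of the fused w13 weight: FlashInfer applies the gated
    # activation to the opposite half compared to vLLM's fused MoE kernels.
    return x.reshape(-1, 2, x.shape[-2] // 2,
                     x.shape[-1]).flip(dims=[1]).reshape(x.shape)


# tile_tokens_dim: 256 tokens, top-1, 128 experts -> 2 tokens/expert -> floor of 8;
# 4096 tokens, top-8, 128 experts -> 256 tokens/expert -> capped at 64.
assert calculate_tile_tokens_dim(256, 1, 128) == 8
assert calculate_tile_tokens_dim(4096, 8, 128) == 64

# swap_w13_to_w31: for (num_experts, 2 * intermediate_size, hidden_size) weights,
# the first and second halves of dim 1 are exchanged per expert.
w13 = torch.arange(2 * 8 * 3, dtype=torch.float32).reshape(2, 8, 3)
swapped = swap_w13_to_w31(w13)
assert torch.equal(swapped[:, :4], w13[:, 4:])
assert torch.equal(swapped[:, 4:], w13[:, :4])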