
Commit ae31f1d

wenscarl and xinli-sw authored and committed
Support Tensorrt-LLM MoE fp4 for low-latency (vllm-project#21331)
Signed-off-by: Shu Wang <[email protected]>
Signed-off-by: Po-Han Huang <[email protected]>
Signed-off-by: Shu Wang. <[email protected]>
Signed-off-by: XIn Li <[email protected]>
Co-authored-by: XIn Li <[email protected]>
1 parent 0404d3c commit ae31f1d

File tree

7 files changed: +288 -43 lines changed


vllm/envs.py

Lines changed: 15 additions & 0 deletions
@@ -129,6 +129,7 @@
     VLLM_SKIP_DEEP_GEMM_WARMUP: bool = False
     VLLM_USE_FLASHINFER_MOE_FP8: bool = False
     VLLM_USE_FLASHINFER_MOE_FP4: bool = False
+    VLLM_FLASHINFER_MOE_BACKEND: str = "throughput"
     VLLM_XGRAMMAR_CACHE_MB: int = 0
     VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256
     VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False
@@ -982,6 +983,20 @@ def get_vllm_port() -> Optional[int]:
     "VLLM_ALL2ALL_BACKEND":
     lambda: os.getenv("VLLM_ALL2ALL_BACKEND", "naive"),

+    # Flashinfer MoE backend for vLLM's fused Mixture-of-Experts support. Both
+    # require compute capability 10.0 or above.
+    # Available options:
+    # - "throughput": [default]
+    #   Uses CUTLASS kernels optimized for high-throughput batch inference.
+    # - "latency":
+    #   Uses TensorRT-LLM kernels optimized for low-latency inference.
+    # To set this backend, define the environment variable:
+    #     export VLLM_FLASHINFER_MOE_BACKEND=latency
+    # If not set, defaults to "throughput".
+    "VLLM_FLASHINFER_MOE_BACKEND": lambda: os.getenv(
+        "VLLM_FLASHINFER_MOE_BACKEND", "throughput"
+    ),
+
     # Control the maximum number of tokens per expert supported by the
     # NVFP4 MoE CUTLASS Kernel. This value is used to create a buffer for
     # the blockscale tensor of activations NVFP4 Quantization.
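
For reference, a minimal usage sketch (not part of the diff): the new variable is resolved through vllm.envs when accessed, so it can be set in the shell or in-process before vLLM reads it. The lazy lookup assumed below reflects how the registered lambda in envs.py is expected to be evaluated; consult envs.py for the authoritative behavior.

    # Minimal sketch: opting into the TensorRT-LLM low-latency MoE kernels.
    # Assumes vllm.envs resolves variables lazily via the registered lambdas.
    import os

    os.environ["VLLM_FLASHINFER_MOE_BACKEND"] = "latency"

    import vllm.envs as envs

    # When the variable is unset, the getenv default "throughput" applies.
    print(envs.VLLM_FLASHINFER_MOE_BACKEND)  # -> "latency"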

vllm/model_executor/layers/fused_moe/config.py

Lines changed: 2 additions & 1 deletion
@@ -192,7 +192,8 @@ def use_deepep_ll_kernels(self):
     @property
     def use_flashinfer_cutlass_kernels(self):
         return (envs.VLLM_USE_FLASHINFER_MOE_FP4
-                and has_flashinfer_cutlass_fused_moe())
+                and has_flashinfer_cutlass_fused_moe()
+                and envs.VLLM_FLASHINFER_MOE_BACKEND == "throughput")

     @staticmethod
     def make(tp_size_: int, dp_size_: int,
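
To make the new gating concrete, here is a small illustrative sketch of the same predicate (not vLLM's code): has_flashinfer_cutlass_fused_moe is stubbed out, and in the real property both the FP4 flag and the backend come from vllm.envs.

    # Sketch of the dispatch after this change: the FlashInfer CUTLASS path is
    # taken only for the "throughput" backend; "latency" is left to the
    # TensorRT-LLM kernels selected elsewhere.
    def has_flashinfer_cutlass_fused_moe() -> bool:
        return True  # stub standing in for vLLM's availability check

    def use_flashinfer_cutlass_kernels(use_fp4: bool, backend: str) -> bool:
        return (use_fp4
                and has_flashinfer_cutlass_fused_moe()
                and backend == "throughput")

    assert use_flashinfer_cutlass_kernels(True, "throughput")
    assert not use_flashinfer_cutlass_kernels(True, "latency")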

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py

Lines changed: 4 additions & 5 deletions
@@ -105,7 +105,7 @@ def __init__(self):
             detect_nvfp4_moe_support)
         _nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__)
         self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported
-        self.allow_flashinfer_cutlass = _nvfp4.allow_flashinfer_cutlass
+        self.allow_flashinfer = _nvfp4.allow_flashinfer
         self.use_marlin = _nvfp4.use_marlin
         self.group_size = 16
         self.fused_experts = None  # type: ignore[assignment]
@@ -212,7 +212,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
                                                requires_grad=False)

         # reorder GEMM1 weights and block scales for FlashInfer CUTLASS kernel.
-        if self.allow_flashinfer_cutlass:
+        if self.allow_flashinfer:
             w, s = reorder_w1w3_to_w3w1(layer.w13_weight.data,
                                         layer.w13_weight_scale.data,
                                         dim=-2)
@@ -266,7 +266,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             (layer.w2_input_global_scale), requires_grad=False)

     def maybe_swap_experts_impl(self, moe_parallel_config):
-        if not self.allow_flashinfer_cutlass:
+        if not self.allow_flashinfer:
             return
         self.fused_experts = build_flashinfer_fp4_cutlass_moe_kernel(
             moe_parallel_config)
@@ -277,8 +277,7 @@ def select_gemm_impl(self, prepare_finalize, moe):
         from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (  # noqa: E501
             select_nvfp4_gemm_impl)

-        return select_nvfp4_gemm_impl(self.allow_flashinfer_cutlass, moe,
-                                      logger)
+        return select_nvfp4_gemm_impl(self.allow_flashinfer, moe, logger)

     def apply(
         self,
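
The rename from allow_flashinfer_cutlass to allow_flashinfer reflects that the flag now gates FlashInfer usage in general, while the concrete kernel family is chosen separately by the backend variable. A hypothetical illustration of that two-step decision (not vLLM's actual selection code; the kernel names below are made up for clarity):

    import os

    def pick_moe_kernel(allow_flashinfer: bool) -> str:
        # Step 1: is FlashInfer allowed at all for this platform/config?
        if not allow_flashinfer:
            return "non-flashinfer-fallback"
        # Step 2: which FlashInfer kernel family does the backend ask for?
        backend = os.getenv("VLLM_FLASHINFER_MOE_BACKEND", "throughput")
        if backend == "latency":
            return "flashinfer-trtllm"   # TensorRT-LLM low-latency kernels
        return "flashinfer-cutlass"      # CUTLASS high-throughput kernels

    print(pick_moe_kernel(True))  # -> "flashinfer-cutlass" unless the env var is "latency"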
