Skip to content

Commit 54ac27d

Browse files
committed
init
1 parent ae268b6 commit 54ac27d

File tree

3 files changed, +8 −8 lines changed

3 files changed, +8 −8 lines changed

vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
1212
from vllm.model_executor.layers.fused_moe.utils import (
1313
extract_required_args, moe_kernel_quantize_input)
14-
from vllm.utils.flashinfer import block_scale_interleave
14+
from vllm.utils.flashinfer import nvfp4_block_scale_interleave
1515

1616

1717
def get_local_sizes(local_tokens):
@@ -92,7 +92,7 @@ def prepare(
9292
dim=0,
9393
sizes=get_local_sizes(local_tokens))
9494
a1_m, a1_n = a1q.shape
95-
a1q_scale = block_scale_interleave(a1q_scale)
95+
a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
9696

9797
return a1q, a1q_scale, None, topk_ids, topk_weights
9898

vllm/model_executor/layers/quantization/modelopt.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1254,8 +1254,8 @@ def apply(
12541254
x, layer.w13_weight, layer.w2_weight), (
12551255
"Flashinfer CUTLASS Fused MoE not applicable!")
12561256

1257-
a1_gscale = torch.min(layer.w13_input_scale_quant)
1258-
a2_gscale = torch.min(layer.w2_input_scale_quant)
1257+
a1_gscale = layer.w13_input_scale_quant
1258+
a2_gscale = layer.w2_input_scale_quant
12591259
extra_expert_args = {
12601260
'g1_alphas': layer.g1_alphas,
12611261
'g2_alphas': layer.g2_alphas,

vllm/utils/flashinfer.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,8 @@ def wrapper(*args, **kwargs):
6969
flashinfer_cutlass_fused_moe = _lazy_import_wrapper("flashinfer.fused_moe",
7070
"cutlass_fused_moe")
7171
fp4_quantize = _lazy_import_wrapper("flashinfer", "fp4_quantize")
72-
block_scale_interleave = _lazy_import_wrapper("flashinfer",
73-
"block_scale_interleave")
72+
nvfp4_block_scale_interleave = _lazy_import_wrapper("flashinfer",
73+
"nvfp4_block_scale_interleave")
7474

7575
# Special case for autotune since it returns a context manager
7676
autotune = _lazy_import_wrapper(
@@ -95,7 +95,7 @@ def has_flashinfer_cutlass_fused_moe() -> bool:
9595
required_functions = [
9696
("flashinfer.fused_moe", "cutlass_fused_moe"),
9797
("flashinfer", "fp4_quantize"),
98-
("flashinfer", "block_scale_interleave"),
98+
("flashinfer", "nvfp4_block_scale_interleave"),
9999
]
100100

101101
for module_name, attr_name in required_functions:
@@ -110,7 +110,7 @@ def has_flashinfer_cutlass_fused_moe() -> bool:
110110
"flashinfer_trtllm_fp8_block_scale_moe",
111111
"flashinfer_cutlass_fused_moe",
112112
"fp4_quantize",
113-
"block_scale_interleave",
113+
"nvfp4_block_scale_interleave",
114114
"autotune",
115115
"has_flashinfer_moe",
116116
"has_flashinfer_cutlass_fused_moe",

0 commit comments

Comments (0)