
Commit f9a4087

Remove weight_scale.T special case for SM90 Block FP8 CUTLASS kernel (vllm-project#28431)
Signed-off-by: mgoin <[email protected]>
1 parent 287bbbe commit f9a4087
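
For context: in block-quantized FP8, an (N, K) weight carries one float32 scale per (block_n, block_k) tile, so weight_scale has shape (ceil(N / block_n), ceil(K / block_k)). Below is a minimal sketch with hypothetical sizes of the layout this commit standardizes on; the actual vLLM code paths are in the diffs that follow.

import torch

# Hypothetical weight shape and a typical block size; illustration only.
N, K = 4096, 4096
block_n, block_k = 128, 128

n_tiles = (N + block_n - 1) // block_n   # one scale row per block of output channels
k_tiles = (K + block_k - 1) // block_k   # one scale column per block of the reduction dim
weight_scale = torch.rand(n_tiles, k_tiles, dtype=torch.float32)

# Before this commit, the SM90 Block FP8 CUTLASS path re-stored weight_scale as a
# transposed, contiguous copy after weight loading and skipped the transpose inside
# cutlass_scaled_mm. After this commit, weight_scale keeps the (n_tiles, k_tiles)
# layout on every platform and cutlass_scaled_mm always passes scale_b=weight_scale.T.
print(weight_scale.shape)                # torch.Size([32, 32])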

File tree

benchmarks/kernels/bench_block_fp8_gemm.py
csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/utils/fp8_utils.py

5 files changed: +36 -36 lines changed

benchmarks/kernels/bench_block_fp8_gemm.py

Lines changed: 29 additions & 14 deletions
@@ -1,10 +1,18 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

+import os
+
+# Disable DeepGEMM for this benchmark to use CUTLASS
+os.environ["VLLM_USE_DEEP_GEMM"] = "0"
+
 import torch

 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    apply_w8a8_block_fp8_linear,
+    W8A8BlockFp8LinearOp,
+)
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    GroupShape,
 )
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     CUTLASS_BLOCK_FP8_SUPPORTED,
@@ -39,13 +47,14 @@ def build_w8a8_block_fp8_runner(M, N, K, block_size, device, use_cutlass):
     fp8_info = torch.finfo(torch.float8_e4m3fn)
     fp8_max, fp8_min = fp8_info.max, fp8_info.min

-    # Create random FP8 tensors
+    # Create random input tensor (bfloat16, will be quantized by W8A8BlockFp8LinearOp)
     A_ref = (torch.rand(M, K, dtype=torch.bfloat16, device=device) - 0.5) * 2 * fp8_max

+    # Create quantized weight tensor
     B_ref = (torch.rand(N, K, dtype=torch.bfloat16, device=device) - 0.5) * 2 * fp8_max
     B = B_ref.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

-    # Create scales
+    # Create weight scales
     block_n, block_k = block_size[0], block_size[1]
     n_tiles = (N + block_n - 1) // block_n
     k_tiles = (K + block_k - 1) // block_k
@@ -55,19 +64,25 @@ def build_w8a8_block_fp8_runner(M, N, K, block_size, device, use_cutlass):
         * factor_for_scale
     )

-    # SM90 CUTLASS requires row-major format for scales
-    if use_cutlass and current_platform.is_device_capability(90):
-        Bs = Bs.T.contiguous()
+    # Create W8A8BlockFp8LinearOp instance
+    weight_group_shape = GroupShape(block_n, block_k)
+    act_quant_group_shape = GroupShape(1, block_k)  # Per-token, per-group quantization
+
+    linear_op = W8A8BlockFp8LinearOp(
+        weight_group_shape=weight_group_shape,
+        act_quant_group_shape=act_quant_group_shape,
+        cutlass_block_fp8_supported=use_cutlass,
+        use_aiter_and_is_supported=False,
+    )

     def run():
-        if use_cutlass:
-            return apply_w8a8_block_fp8_linear(
-                A_ref, B, block_size, Bs, cutlass_block_fp8_supported=True
-            )
-        else:
-            return apply_w8a8_block_fp8_linear(
-                A_ref, B, block_size, Bs, cutlass_block_fp8_supported=False
-            )
+        return linear_op.apply(
+            input=A_ref,
+            weight=B,
+            weight_scale=Bs,
+            input_scale=None,
+            bias=None,
+        )

     return run

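
For readers updating their own scripts against this change, here is a condensed usage sketch of the new benchmark path. The class, constructor arguments, and apply() keywords are taken verbatim from the diff above; the shapes and random tensors are illustrative, and a CUDA GPU with FP8 support is assumed.

import torch

from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8LinearOp
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    CUTLASS_BLOCK_FP8_SUPPORTED,
)

M, N, K = 64, 4096, 4096                 # illustrative problem size
block_n, block_k = 128, 128
device = "cuda"

# bfloat16 activations; the op quantizes them per token / per k-group internally.
A = torch.randn(M, K, dtype=torch.bfloat16, device=device)

# FP8 weight plus one float32 scale per (block_n, block_k) tile.
B = torch.randn(N, K, dtype=torch.bfloat16, device=device).to(torch.float8_e4m3fn)
Bs = torch.rand(
    (N + block_n - 1) // block_n,
    (K + block_k - 1) // block_k,
    dtype=torch.float32,
    device=device,
)

linear_op = W8A8BlockFp8LinearOp(
    weight_group_shape=GroupShape(block_n, block_k),
    act_quant_group_shape=GroupShape(1, block_k),   # per-token, per-group activation scales
    cutlass_block_fp8_supported=CUTLASS_BLOCK_FP8_SUPPORTED,
    use_aiter_and_is_supported=False,
)
out = linear_op.apply(input=A, weight=B, weight_scale=Bs, input_scale=None, bias=None)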

csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh

Lines changed: 2 additions & 1 deletion
@@ -48,7 +48,8 @@ struct cutlass_3x_gemm_fp8_blockwise {
   using ElementBlockScale = float;

   using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<
-      ScaleGranularityM, ScaleGranularityN, ScaleGranularityK>;
+      ScaleGranularityM, ScaleGranularityN, ScaleGranularityK,
+      cute::GMMA::Major::MN, cute::GMMA::Major::K>;

   using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
   using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());

vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py

Lines changed: 1 addition & 1 deletion
@@ -173,7 +173,7 @@ def process_weights_after_loading(self, layer) -> None:
         layer.input_scale = None

         if self.strategy == QuantizationStrategy.BLOCK:
-            maybe_post_process_fp8_weight_block(layer, self.cutlass_block_fp8_supported)
+            maybe_post_process_fp8_weight_block(layer)

     def apply_weights(
         self,

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 1 addition & 1 deletion
@@ -540,7 +540,7 @@ def process_weights_after_loading(self, layer: Module) -> None:
             return

         if self.block_quant:
-            maybe_post_process_fp8_weight_block(layer, self.cutlass_block_fp8_supported)
+            maybe_post_process_fp8_weight_block(layer)

     def apply(
         self,

vllm/model_executor/layers/quantization/utils/fp8_utils.py

Lines changed: 3 additions & 19 deletions
@@ -55,17 +55,13 @@ def cutlass_scaled_mm(
     Bs: torch.Tensor,
     block_size: list[int],
     output_dtype: torch.dtype = torch.float16,
-    is_hopper: bool | None = None,
 ) -> torch.Tensor:
-    if is_hopper is None:
-        is_hopper = current_platform.is_device_capability(90)
     return ops.cutlass_scaled_mm(
         A,
         B.T,
         out_dtype=output_dtype,
         scale_a=As,
-        # SM90 block FP8 requires row-major scale_b, which we do ahead of time
-        scale_b=Bs if block_size is not None and is_hopper else Bs.T,
+        scale_b=Bs.T,
     )


@@ -130,7 +126,7 @@ def _padded_cutlass(
     padded_x_scale[0 : x_scale.shape[0], ...].copy_(x_scale)

     output = cutlass_scaled_mm(
-        padded_qx, weight, padded_x_scale, weight_scale, block_size, output_dtype, True
+        padded_qx, weight, padded_x_scale, weight_scale, block_size, output_dtype
     )
     return output[0 : qx.shape[0], ...]

@@ -303,7 +299,6 @@ def _run_cutlass(
             weight_scale,
             list(self.weight_group_shape),
             input_2d.dtype,
-            False,
         )

     def _run_aiter(
@@ -1125,9 +1120,7 @@ def process_fp8_weight_block_strategy(
     return weight, weight_scale


-def maybe_post_process_fp8_weight_block(
-    layer: torch.nn.Module, cutlass_block_fp8_supported: bool
-):
+def maybe_post_process_fp8_weight_block(layer: torch.nn.Module):
    assert layer.weight_block_size is not None

     from vllm.utils.deep_gemm import (
@@ -1146,15 +1139,6 @@ def maybe_post_process_fp8_weight_block(
         requant_weight_ue8m0_inplace(
             layer.weight.data, layer.weight_scale.data, block_sz
         )
-    # SM90 Block FP8 CUTLASS requires row-major weight scales
-    elif (
-        current_platform.is_device_capability(90)
-        and cutlass_block_fp8_supported
-        and not should_use_deepgemm
-    ):
-        layer.weight_scale = torch.nn.Parameter(
-            layer.weight_scale.data.T.contiguous(), requires_grad=False
-        )


 def expert_weight_is_col_major(x: torch.Tensor) -> bool:
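
One way to see why the ahead-of-time transpose could be dropped (a minimal PyTorch illustration with made-up tile counts, not vLLM code): Bs.T is a zero-copy view over the same storage, so always passing scale_b=Bs.T costs nothing at the call site, and with the Sm90BlockwiseScaleConfig change above apparently declaring the expected scale majorness in the kernel itself, the SM90 path no longer needs a pre-materialized contiguous transpose of the weight scales.

import torch

# Made-up tile counts for illustration.
n_tiles, k_tiles = 32, 8
Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32)

scale_b = Bs.T                                  # shape (k_tiles, n_tiles)
assert scale_b.data_ptr() == Bs.data_ptr()      # same storage: no copy is made
assert scale_b.stride() == (1, k_tiles)         # column-major view of the row-major scales

# The removed special case materialized this copy once per weight at load time:
old_sm90_scale = Bs.T.contiguous()              # separate (k_tiles, n_tiles) allocation
assert old_sm90_scale.data_ptr() != Bs.data_ptr()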
