
Commit 94ca46e

yewentao256 authored and x22x22 committed
[Perf] Cuda Kernel for Int8 Per Token Group Quant (vllm-project#21476)
Signed-off-by: yewentao256 <[email protected]>
Signed-off-by: x22x22 <[email protected]>
1 parent a2b51da commit 94ca46e

6 files changed: +47 −3 lines changed

csrc/ops.h

Lines changed: 5 additions & 0 deletions
@@ -292,6 +292,11 @@ void per_token_group_quant_fp8(const torch::Tensor& input,
                                 torch::Tensor& output_q, torch::Tensor& output_s,
                                 int64_t group_size, double eps, double fp8_min,
                                 double fp8_max, bool scale_ue8m0);
+
+void per_token_group_quant_int8(const torch::Tensor& input,
+                                torch::Tensor& output_q,
+                                torch::Tensor& output_s, int64_t group_size,
+                                double eps, double int8_min, double int8_max);
 #endif
 
 void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,

csrc/quantization/compressed_tensors/int8_quant_kernels.cu

Lines changed: 10 additions & 0 deletions
@@ -1,6 +1,8 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/all.h>
 
+#include "../per_token_group_quant_8bit.h"
+
 #include <cmath>
 
 #include "../../dispatch_utils.h"
@@ -336,3 +338,11 @@ void dynamic_scaled_int8_quant(
   }
 });
 }
+
+void per_token_group_quant_int8(const torch::Tensor& input,
+                                torch::Tensor& output_q,
+                                torch::Tensor& output_s, int64_t group_size,
+                                double eps, double int8_min, double int8_max) {
+  per_token_group_quant_8bit(input, output_q, output_s, group_size, eps,
+                             int8_min, int8_max);
+}

csrc/quantization/fp8/per_token_group_quant.cu

Lines changed: 5 additions & 1 deletion
@@ -1,6 +1,8 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/util/Float8_e4m3fn.h>
 
+#include "../per_token_group_quant_8bit.h"
+
 #include <cmath>
 
 #include <cuda_fp16.h>
@@ -120,7 +122,7 @@ void per_token_group_quant_8bit(const torch::Tensor& input,
                                 torch::Tensor& output_q,
                                 torch::Tensor& output_s, int64_t group_size,
                                 double eps, double min_8bit, double max_8bit,
-                                bool scale_ue8m0 = false) {
+                                bool scale_ue8m0) {
   TORCH_CHECK(input.is_contiguous());
   TORCH_CHECK(output_q.is_contiguous());
 
@@ -198,6 +200,8 @@ void per_token_group_quant_8bit(const torch::Tensor& input,
       input.scalar_type(), "per_token_group_quant_8bit", ([&] {
         if (dst_type == at::ScalarType::Float8_e4m3fn) {
           LAUNCH_KERNEL(scalar_t, c10::Float8_e4m3fn);
+        } else if (dst_type == at::ScalarType::Char) {
+          LAUNCH_KERNEL(scalar_t, int8_t);
         }
       }));
 

csrc/quantization/per_token_group_quant_8bit.h

Lines changed: 10 additions & 0 deletions

@@ -0,0 +1,10 @@
+#pragma once
+#include <torch/all.h>
+
+// TODO(wentao): refactor the folder to 8bit, then includes fp8 and int8 folders
+// 8-bit per-token-group quantization helper used by both FP8 and INT8
+void per_token_group_quant_8bit(const torch::Tensor& input,
+                                torch::Tensor& output_q,
+                                torch::Tensor& output_s, int64_t group_size,
+                                double eps, double min_8bit, double max_8bit,
+                                bool scale_ue8m0 = false);

csrc/torch_bindings.cpp

Lines changed: 8 additions & 0 deletions
@@ -624,6 +624,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.impl("per_token_group_fp8_quant", torch::kCUDA,
            &per_token_group_quant_fp8);
 
+  // Compute per-token-group INT8 quantized tensor and scaling factor.
+  ops.def(
+      "per_token_group_quant_int8(Tensor input, Tensor! output_q, Tensor! "
+      "output_s, int group_size, float eps, float int8_min, float int8_max) -> "
+      "()");
+  ops.impl("per_token_group_quant_int8", torch::kCUDA,
+           &per_token_group_quant_int8);
+
   // reorder weight for AllSpark Ampere W8A16 Fused Gemm kernel
   ops.def(
       "rearrange_kn_weight_as_n32k16_order(Tensor b_qweight, Tensor b_scales, "

vllm/model_executor/layers/quantization/utils/int8_utils.py

Lines changed: 9 additions & 2 deletions
@@ -238,13 +238,20 @@ def per_token_group_quant_int8(
     int8_min = iinfo.min
 
     x_q = torch.empty_like(x, device=x.device, dtype=dtype)
-    M = x.numel() // group_size
-    N = group_size
     x_s = torch.empty(
         x.shape[:-1] + (x.shape[-1] // group_size, ),
         device=x.device,
         dtype=torch.float32,
     )
+    # prefer CUDA kernel if available
+    if current_platform.is_cuda():
+        torch.ops._C.per_token_group_quant_int8(x, x_q, x_s, group_size, eps,
+                                                float(int8_min),
+                                                float(int8_max))
+        return x_q, x_s
+
+    M = x.numel() // group_size
+    N = group_size
 
     BLOCK = triton.next_power_of_2(N)
     # heuristics for number of warps
0 commit comments
