Commit 1b0a155

Authored by yewentao256
[Perf] Using __nv_fp8_e4m3 instead of c10::e4m3 for per_token_group_quant (#21867)
Signed-off-by: yewentao256 <[email protected]>
1 parent 44bc46d commit 1b0a155

File tree

1 file changed (+2, -4)

csrc/quantization/fp8/per_token_group_quant.cu

Lines changed: 2 additions & 4 deletions
@@ -1,12 +1,10 @@
 #include <ATen/cuda/CUDAContext.h>
-#include <c10/util/Float8_e4m3fn.h>
 
 #include "../per_token_group_quant_8bit.h"
 
 #include <cmath>
 
-#include <cuda_fp16.h>
-#include <cuda_bf16.h>
+#include <cuda_fp8.h>
 
 #include <torch/all.h>
 
@@ -199,7 +197,7 @@ void per_token_group_quant_8bit(const torch::Tensor& input,
   VLLM_DISPATCH_FLOATING_TYPES(
       input.scalar_type(), "per_token_group_quant_8bit", ([&] {
         if (dst_type == at::ScalarType::Float8_e4m3fn) {
-          LAUNCH_KERNEL(scalar_t, c10::Float8_e4m3fn);
+          LAUNCH_KERNEL(scalar_t, __nv_fp8_e4m3);
         } else if (dst_type == at::ScalarType::Char) {
           LAUNCH_KERNEL(scalar_t, int8_t);
         }
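
The "[Perf]" in the title presumably refers to the conversion path: __nv_fp8_e4m3 is the native CUDA FP8 type from <cuda_fp8.h>, and constructing it from a float can lower to the hardware float->E4M3 conversion on GPUs that support FP8 (e.g. Ada/Hopper), whereas c10::Float8_e4m3fn rounds and packs the value in portable software code. A minimal sketch of the idea, assuming a standalone example kernel (quant_to_fp8 is hypothetical and not part of this commit):

    #include <cuda_fp8.h>

    // Hypothetical kernel illustrating the conversion path this commit switches to:
    // storing through __nv_fp8_e4m3 lets nvcc use the native float->E4M3 conversion
    // where the hardware provides it, instead of the software conversion performed
    // by c10::Float8_e4m3fn.
    __global__ void quant_to_fp8(const float* __restrict__ in,
                                 __nv_fp8_e4m3* __restrict__ out, int n) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) {
        out[i] = __nv_fp8_e4m3(in[i]);  // native float -> e4m3 conversion
      }
    }

In the diff above, the second argument to LAUNCH_KERNEL is the destination element type, so after this change the per-token-group quantization kernel writes its FP8 output as __nv_fp8_e4m3 directly.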
