Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit 1ba81d7

Browse files
committed
zp no degrad
1 parent 8f0abc4 commit 1ba81d7

File tree

5 files changed

+79
-23
lines changed

5 files changed

+79
-23
lines changed

include/common/core/common_types.hpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,11 @@ enum class grf_mode : uint8_t { normal = 0, double_grf = 1 };
2727

2828
enum class mem_layout : uint8_t { row_major = 0, col_major = 1 };
2929

30-
enum class quant_mode : uint8_t { S4_ASYM = 0, S4_FULLRANGE_NO_ZP = 1 };
30+
enum class quant_mode : uint8_t {
31+
S4_ASYM = 0,
32+
S4_FULLRANGE_NO_ZP = 1,
33+
INT4_ASYM_ZERO_NO_DEGRAD = 2
34+
};
3135

3236
struct quant_info {
3337
quant_mode quant_mode;

include/experimental/group/gemm/impl/int4_dequantize_xe.hpp

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -102,10 +102,11 @@ class gemm_t<
102102
std::is_same<remove_const_t<dtype_b>, remove_const_t<int4x8>>::value,
103103
"this is for 4bit matB ");
104104
static_assert(
105-
std::is_same<remove_const_t<dtype_zero_pt>, remove_const_t<int4x2>>::
106-
value ||
107-
std::is_same<remove_const_t<dtype_zero_pt>, remove_const_t<int4x8>>::
108-
value,
105+
quant_info_.quant_mode != quant_mode::INT4_ASYM_ZERO_NO_DEGRAD &&
106+
(std::is_same<remove_const_t<dtype_zero_pt>, remove_const_t<int4x2>>::
107+
value ||
108+
std::is_same<remove_const_t<dtype_zero_pt>, remove_const_t<int4x8>>::
109+
value),
109110
"this is for 4bit zero_pt ");
110111

111112
/******** set memory attribute **********/
@@ -290,12 +291,20 @@ class gemm_t<
290291
arch_tag>;
291292

292293
// compress int4 along N dimensions
293-
using zero_pt_tile_desc_t = subgroup::tile_desc_t<
294-
(tile_size_x_b + pack_ratio - 1) / pack_ratio,
295-
tile_size_y_zero_pt,
296-
(block_size_x_b + pack_ratio - 1) / pack_ratio,
297-
block_size_y_zero_pt,
298-
reg_layout::tiled>;
294+
using zero_pt_tile_desc_t = std::conditional_t<
295+
quant_info_.quant_mode != quant_mode::INT4_ASYM_ZERO_NO_DEGRAD,
296+
subgroup::tile_desc_t<
297+
(tile_size_x_b + pack_ratio - 1) / pack_ratio,
298+
tile_size_y_zero_pt,
299+
(block_size_x_b + pack_ratio - 1) / pack_ratio,
300+
block_size_y_zero_pt,
301+
reg_layout::tiled>,
302+
subgroup::tile_desc_t<
303+
tile_size_x_b,
304+
tile_size_y_zero_pt,
305+
block_size_x_b,
306+
block_size_y_zero_pt,
307+
reg_layout::tiled>>;
299308

300309
using zero_pt_t = subgroup::tile_t<dtype_zero_pt, zero_pt_tile_desc_t>;
301310
using zero_pt_payload_t = subgroup::mem_payload_t<

include/experimental/kernel/gemm/impl/int4_dequantize_kslicing_xe.hpp

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -566,8 +566,7 @@ class gemm_universal_t<
566566
// check for int4x2
567567
implementable &=
568568
((args.matB_ld % pack_ratio == 0) && (args.matrix_n % pack_ratio == 0));
569-
if constexpr (
570-
gemm_t::compute_policy::quant_mode != quant_mode::S4_FULLRANGE_NO_ZP) {
569+
if constexpr (gemm_t::compute_policy::quant_mode == quant_mode::S4_ASYM) {
571570
implementable &= (args.zero_pt_ld % pack_ratio == 0);
572571
}
573572

@@ -618,7 +617,10 @@ class gemm_universal_t<
618617
int start_x_scale = start_n;
619618
int start_y_scale = start_k / dequant_s;
620619

621-
int start_x_zero_pt = start_n / pack_ratio;
620+
int start_x_zero_pt = gemm_t::compute_policy::quant_mode ==
621+
quant_mode::INT4_ASYM_ZERO_NO_DEGRAD
622+
? start_n
623+
: start_n / pack_ratio;
622624
int start_y_zero_pt = start_k / dequant_s;
623625

624626
// set up arguments
@@ -672,7 +674,8 @@ class gemm_universal_t<
672674
inner_loop_start,
673675
inner_loop_count,
674676
mem_desc_scale);
675-
} else {
677+
} else if constexpr (
678+
gemm_t::compute_policy::quant_mode == quant_mode::S4_ASYM) {
676679
mem_desc_zero_pt_t mem_desc_zero_pt(
677680
args.zero_pt_base,
678681
{(args.matrix_n + pack_ratio - 1) / pack_ratio,
@@ -686,6 +689,24 @@ class gemm_universal_t<
686689
inner_loop_count,
687690
mem_desc_scale,
688691
mem_desc_zero_pt);
692+
} else if constexpr (
693+
gemm_t::compute_policy::quant_mode ==
694+
quant_mode::INT4_ASYM_ZERO_NO_DEGRAD) {
695+
mem_desc_zero_pt_t mem_desc_zero_pt(
696+
args.zero_pt_base,
697+
{args.matrix_n,
698+
((args.matrix_k + dequant_s - 1) / dequant_s),
699+
args.zero_pt_ld},
700+
{start_x_zero_pt, start_y_zero_pt});
701+
gemm_args = gemm_args_t(
702+
mem_desc_a,
703+
mem_desc_b,
704+
inner_loop_start,
705+
inner_loop_count,
706+
mem_desc_scale,
707+
mem_desc_zero_pt);
708+
} else {
709+
assert(0);
689710
}
690711
matAcc_t matAcc;
691712
matAcc.init(0);

include/subgroup/tile/impl/tile_op_functor.hpp

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,15 +149,26 @@ struct dequant_int4_weight_t {
149149
cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) =
150150
cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) -
151151
zero_pt_i8;
152-
} else if constexpr (quant_mode == quant_mode::S4_FULLRANGE_NO_ZP) {
152+
} else if constexpr (
153+
quant_mode == quant_mode::S4_FULLRANGE_NO_ZP ||
154+
quant_mode == quant_mode::INT4_ASYM_ZERO_NO_DEGRAD) {
153155
cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) =
154156
cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) -
155157
int8_t(8);
156158
}
157159
dst_blk.xetla_select<step, 1>(jj * block_size_y_b + ii) =
158160
cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) *
159161
scale.reg[scale_idx];
160-
162+
if constexpr (quant_mode == quant_mode::INT4_ASYM_ZERO_NO_DEGRAD) {
163+
uint32_t zero_pt_idx =
164+
offset_y_in_tile / dequant_s * zero_pt_t::block_size_x +
165+
offset_x_in_tile;
166+
native_type_t<typename zero_pt_t::dtype> zero_pt_pack =
167+
zero_pt.reg[zero_pt_idx];
168+
dst_blk.xetla_select<step, 1>(jj * block_size_y_b + ii) =
169+
dst_blk.xetla_select<step, 1>(jj * block_size_y_b + ii) +
170+
zero_pt_pack;
171+
}
161172
// sycl::ext::oneapi::experimental::printf(
162173
// "scale[%d] %f \n",
163174
// scale_idx,

tests/integration/gemv/int4/main.cpp

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -136,8 +136,12 @@ std::vector<fp16> convert_int4(
136136
int8_t dequant_8bit = data_b & 0xf;
137137
if constexpr (quant_mode == quant_mode::S4_FULLRANGE_NO_ZP) {
138138
dequant_fp16[i] = scale * (dequant_8bit - 8);
139-
} else {
139+
} else if constexpr (quant_mode == quant_mode::S4_ASYM) {
140140
dequant_fp16[i] = scale * (dequant_8bit - zero_pt_i8);
141+
} else if constexpr (quant_mode == quant_mode::INT4_ASYM_ZERO_NO_DEGRAD) {
142+
dequant_fp16[i] = scale * (dequant_8bit - 8) + zero_pt;
143+
} else {
144+
assert(0);
141145
}
142146
data_b = data_b >> 4;
143147
}
@@ -215,7 +219,10 @@ void dequantize_gemv_run(int iter) {
215219
using data_type_a = typename Test::data_type_a;
216220
using data_type_b = typename Test::data_type_b;
217221
using data_type_c = typename Test::data_type_c;
218-
using data_type_zero_pt = data_type_b;
222+
using data_type_zero_pt = std::conditional_t<
223+
Test::quant_mode == quant_mode::INT4_ASYM_ZERO_NO_DEGRAD,
224+
data_type_c,
225+
data_type_b>;
219226
using data_type_scale = fp16;
220227
using data_type_acc_in = fp16;
221228
using data_type_acc = float;
@@ -225,7 +232,7 @@ void dequantize_gemv_run(int iter) {
225232
constexpr mem_layout layout_b = Test::layout_b;
226233

227234
constexpr size_t size_a = matrix_m * matrix_k;
228-
constexpr size_t size_b = matrix_k * matrix_n / (2 * sizeof(data_type_b));
235+
constexpr size_t size_b = matrix_k * matrix_n / 2;
229236

230237
constexpr size_t size_scale_k = matrix_k / dequant_s;
231238
constexpr size_t size_scale_n = matrix_n;
@@ -234,7 +241,9 @@ void dequantize_gemv_run(int iter) {
234241
constexpr size_t size_zero_pt_k = matrix_k / dequant_s;
235242
constexpr size_t size_zero_pt_n = matrix_n;
236243
constexpr size_t size_zero_pt =
237-
size_zero_pt_k * size_zero_pt_n / (2 * sizeof(data_type_b));
244+
Test::quant_mode != quant_mode::INT4_ASYM_ZERO_NO_DEGRAD
245+
? size_zero_pt_k * size_zero_pt_n / 2
246+
: size_zero_pt_k * size_zero_pt_n;
238247

239248
constexpr size_t size_c = matrix_m * matrix_n;
240249
constexpr size_t size_bias = matrix_n;
@@ -405,16 +414,18 @@ void dequantize_gemv_run(int iter) {
405414
scale_h[i] = INFINITY;
406415
}
407416
for (unsigned i = 0; i < size_zero_pt + UNDEFINED_DATA_SIZE; ++i) {
408-
if constexpr (std::is_same_v<int4x2, data_type_b>) {
417+
if constexpr (std::is_same_v<int4x2, data_type_zero_pt>) {
409418
zero_pt_h[i] = random_uint8();
410419
#ifdef UT_DEBUG
411420
zero_pt_h[i] = 0x12 << i;
412421
#endif
413-
} else if constexpr (std::is_same_v<int4x8, data_type_b>) {
422+
} else if constexpr (std::is_same_v<int4x8, data_type_zero_pt>) {
414423
zero_pt_h[i] = random_uint32();
415424
#ifdef UT_DEBUG
416425
zero_pt_h[i] = 0x33333333;
417426
#endif
427+
} else if constexpr (std::is_same_v<fp16, data_type_zero_pt>) {
428+
zero_pt_h[i] = random_float();
418429
}
419430
}
420431

0 commit comments

Comments
 (0)