@@ -27,7 +27,7 @@ constexpr int ITER = 200;
2727#endif
2828constexpr size_t UNDEFINED_DATA_SIZE = 1024 ;
2929
30- template <typename scalar_t >
30+ template <typename scalar_t , quant_mode quant_mode_ >
3131class test_col_major_1 {
3232 public:
3333 // Extract the parameters required by different test cases
@@ -41,7 +41,7 @@ class test_col_major_1 {
4141 static constexpr size_t sg_k = 512 / sg_m;
4242 static constexpr size_t dequant_s = 128 ;
4343 // static constexpr quant_mode quant_mode = quant_mode::I4_ASYM;
44- static constexpr quant_mode quant_mode = quant_mode::I4_SYM ;
44+ static constexpr quant_mode quant_mode = quant_mode_ ;
4545
4646 static constexpr size_t local_kslicing = 1 ;
4747 static constexpr size_t global_kslicing = 1 ;
@@ -132,13 +132,19 @@ std::vector<fp16> convert_int4(
132132 data_type_zero_pt zero_pt) {
133133 std::vector<fp16> dequant_fp16 (sizeof (data_type_b) * 2 );
134134
135- int8_t zero_pt_i8 = zero_pt & 0xf ;
135+ int8_t zero_pt_i8;
136+ if constexpr (quant_mode != quant_mode::I4_ASYM_FP_ZERO)
137+ zero_pt_i8 = zero_pt & 0xf ;
136138 for (uint32_t i = 0 ; i < dequant_fp16.size (); i++) {
137139 int8_t dequant_8bit = data_b & 0xf ;
138140 if constexpr (quant_mode == quant_mode::I4_SYM) {
139141 dequant_fp16[i] = scale * (dequant_8bit - 8 );
140- } else {
142+ } else if constexpr (quant_mode == quant_mode::I4_ASYM) {
141143 dequant_fp16[i] = scale * (dequant_8bit - zero_pt_i8);
144+ } else if constexpr (quant_mode == quant_mode::I4_ASYM_FP_ZERO) {
145+ dequant_fp16[i] = scale * (dequant_8bit - 8 ) + zero_pt;
146+ } else {
147+ assert (0 );
142148 }
143149 data_b = data_b >> 4 ;
144150 }
@@ -170,12 +176,14 @@ std::vector<data_type_acc_in> dequantize_weight(
170176 for (uint32_t j = 0 ; j < width; j += step) {
171177 int start_b_in = i * width + j;
172178 int start_scale_in = start_b_in / step;
173- int start_zero_pt_in =
174- (j / step) * (matrix_n / pack_radio) + i / pack_radio;
179+ int start_zero_pt_in = quant_mode == quant_mode::I4_ASYM_FP_ZERO
180+ ? (j / step) * matrix_n + i
181+ : (j / step) * (matrix_n / pack_radio) + i / pack_radio;
175182 int start_out =
176183 layout_b == mem_layout::row_major ? 0 : i * matrix_k + j * pack_radio;
177184 data_type_zero_pt zp_value = zero_pt[start_zero_pt_in];
178- zp_value = zp_value >> (4 * (i % pack_radio));
185+ if constexpr (quant_mode != quant_mode::I4_ASYM_FP_ZERO)
186+ zp_value = zp_value >> (4 * (i % pack_radio));
179187 for (uint32_t jj = 0 ; jj < step; jj++) {
180188 std::vector<fp16> dequant_fp16 = convert_int4<quant_mode>(
181189 b[start_b_in + jj], scale[start_scale_in], zp_value);
@@ -216,7 +224,10 @@ void dequantize_gemv_run(int iter) {
216224 using data_type_a = typename Test::data_type_a;
217225 using data_type_b = typename Test::data_type_b;
218226 using data_type_c = typename Test::data_type_c;
219- using data_type_zero_pt = data_type_b;
227+ using data_type_zero_pt = std::conditional_t <
228+ Test::quant_mode == quant_mode::I4_ASYM_FP_ZERO,
229+ data_type_c,
230+ data_type_b>;
220231 using data_type_scale = fp16;
221232 using data_type_acc_in = fp16;
222233 using data_type_acc = float ;
@@ -226,7 +237,7 @@ void dequantize_gemv_run(int iter) {
226237 constexpr mem_layout layout_b = Test::layout_b;
227238
228239 constexpr size_t size_a = matrix_m * matrix_k;
229- constexpr size_t size_b = matrix_k * matrix_n / ( 2 * sizeof (data_type_b)) ;
240+ constexpr size_t size_b = matrix_k * matrix_n / 2 ;
230241
231242 constexpr size_t size_scale_k = matrix_k / dequant_s;
232243 constexpr size_t size_scale_n = matrix_n;
@@ -235,7 +246,9 @@ void dequantize_gemv_run(int iter) {
235246 constexpr size_t size_zero_pt_k = matrix_k / dequant_s;
236247 constexpr size_t size_zero_pt_n = matrix_n;
237248 constexpr size_t size_zero_pt =
238- size_zero_pt_k * size_zero_pt_n / (2 * sizeof (data_type_b));
249+ Test::quant_mode != quant_mode::I4_ASYM_FP_ZERO
250+ ? size_zero_pt_k * size_zero_pt_n / 2
251+ : size_zero_pt_k * size_zero_pt_n;
239252
240253 constexpr size_t size_c = matrix_m * matrix_n;
241254 constexpr size_t size_bias = matrix_n;
@@ -406,16 +419,18 @@ void dequantize_gemv_run(int iter) {
406419 scale_h[i] = INFINITY;
407420 }
408421 for (unsigned i = 0 ; i < size_zero_pt + UNDEFINED_DATA_SIZE; ++i) {
409- if constexpr (std::is_same_v<int4x2, data_type_b >) {
422+ if constexpr (std::is_same_v<int4x2, data_type_zero_pt >) {
410423 zero_pt_h[i] = random_uint8 ();
411424#ifdef UT_DEBUG
412425 zero_pt_h[i] = 0x12 << i;
413426#endif
414- } else if constexpr (std::is_same_v<int4x8, data_type_b >) {
427+ } else if constexpr (std::is_same_v<int4x8, data_type_zero_pt >) {
415428 zero_pt_h[i] = random_uint32 ();
416429#ifdef UT_DEBUG
417430 zero_pt_h[i] = 0x33333333 ;
418431#endif
432+ } else if constexpr (std::is_same_v<fp16, data_type_zero_pt>) {
433+ zero_pt_h[i] = random_float ();
419434 }
420435 }
421436
@@ -492,7 +507,9 @@ void dequantize_gemv_run(int iter) {
492507 Acc_d,
493508 Cnt_d,
494509 epilogue_args);
495- } else if constexpr (compute_policy::quant_mode == quant_mode::I4_ASYM) {
510+ } else if constexpr (
511+ compute_policy::quant_mode == quant_mode::I4_ASYM ||
512+ compute_policy::quant_mode == quant_mode::I4_ASYM_FP_ZERO) {
496513 gemm_arg =
497514 typename gemm_op_t ::template arguments_t <compute_policy::quant_mode>(
498515 matrix_m,
@@ -604,8 +621,10 @@ TYPED_TEST_P(dequantize_gemv_test, esimd) {
604621
605622REGISTER_TYPED_TEST_SUITE_P (dequantize_gemv_test, esimd);
606623using tests = ::testing::Types< //
607- test_col_major_1<fp16>,
608- test_col_major_1<bf16 >,
624+ test_col_major_1<fp16, quant_mode::I4_SYM>,
625+ test_col_major_1<bf16 , quant_mode::I4_SYM>,
626+ test_col_major_1<fp16, quant_mode::I4_ASYM_FP_ZERO>,
627+ test_col_major_1<bf16 , quant_mode::I4_ASYM_FP_ZERO>,
609628 // test_col_major_2,
610629 void >;
611630
0 commit comments