@@ -41,7 +41,7 @@ class test_col_major_1 {
4141 static constexpr size_t dequant_s = 128 ;
4242 // static constexpr quant_mode quant_mode = quant_mode::S4_ASYM;
4343 // static constexpr quant_mode quant_mode = quant_mode::S4_FULLRANGE_NO_ZP;
44- static constexpr quant_mode quant_mode = quant_mode::INT4_ASYM_ZERO_NO_DEGRAD ;
44+ static constexpr quant_mode quant_mode = quant_mode::INT4_ASYM_FP_ZERO ;
4545
4646 static constexpr size_t local_kslicing = 1 ;
4747 static constexpr size_t global_kslicing = 1 ;
@@ -133,15 +133,15 @@ std::vector<fp16> convert_int4(
133133 std::vector<fp16> dequant_fp16 (sizeof (data_type_b) * 2 );
134134
135135 int8_t zero_pt_i8;
136- if constexpr (quant_mode != quant_mode::INT4_ASYM_ZERO_NO_DEGRAD )
136+ if constexpr (quant_mode != quant_mode::INT4_ASYM_FP_ZERO )
137137 zero_pt_i8 = zero_pt & 0xf ;
138138 for (uint32_t i = 0 ; i < dequant_fp16.size (); i++) {
139139 int8_t dequant_8bit = data_b & 0xf ;
140140 if constexpr (quant_mode == quant_mode::S4_FULLRANGE_NO_ZP) {
141141 dequant_fp16[i] = scale * (dequant_8bit - 8 );
142142 } else if constexpr (quant_mode == quant_mode::S4_ASYM) {
143143 dequant_fp16[i] = scale * (dequant_8bit - zero_pt_i8);
144- } else if constexpr (quant_mode == quant_mode::INT4_ASYM_ZERO_NO_DEGRAD ) {
144+ } else if constexpr (quant_mode == quant_mode::INT4_ASYM_FP_ZERO ) {
145145 dequant_fp16[i] = scale * (dequant_8bit - 8 ) + zero_pt;
146146 } else {
147147 assert (0 );
@@ -176,13 +176,13 @@ std::vector<data_type_acc_in> dequantize_weight(
176176 for (uint32_t j = 0 ; j < width; j += step) {
177177 int start_b_in = i * width + j;
178178 int start_scale_in = start_b_in / step;
179- int start_zero_pt_in = quant_mode == quant_mode::INT4_ASYM_ZERO_NO_DEGRAD
179+ int start_zero_pt_in = quant_mode == quant_mode::INT4_ASYM_FP_ZERO
180180 ? (j / step) * matrix_n + i
181181 : (j / step) * (matrix_n / pack_radio) + i / pack_radio;
182182 int start_out =
183183 layout_b == mem_layout::row_major ? 0 : i * matrix_k + j * pack_radio;
184184 data_type_zero_pt zp_value = zero_pt[start_zero_pt_in];
185- if constexpr (quant_mode != quant_mode::INT4_ASYM_ZERO_NO_DEGRAD )
185+ if constexpr (quant_mode != quant_mode::INT4_ASYM_FP_ZERO )
186186 zp_value = zp_value >> (4 * (i % pack_radio));
187187 for (uint32_t jj = 0 ; jj < step; jj++) {
188188 std::vector<fp16> dequant_fp16 = convert_int4<quant_mode>(
@@ -225,7 +225,7 @@ void dequantize_gemv_run(int iter) {
225225 using data_type_b = typename Test::data_type_b;
226226 using data_type_c = typename Test::data_type_c;
227227 using data_type_zero_pt = std::conditional_t <
228- Test::quant_mode == quant_mode::INT4_ASYM_ZERO_NO_DEGRAD ,
228+ Test::quant_mode == quant_mode::INT4_ASYM_FP_ZERO ,
229229 data_type_c,
230230 data_type_b>;
231231 using data_type_scale = fp16;
@@ -246,7 +246,7 @@ void dequantize_gemv_run(int iter) {
246246 constexpr size_t size_zero_pt_k = matrix_k / dequant_s;
247247 constexpr size_t size_zero_pt_n = matrix_n;
248248 constexpr size_t size_zero_pt =
249- Test::quant_mode != quant_mode::INT4_ASYM_ZERO_NO_DEGRAD
249+ Test::quant_mode != quant_mode::INT4_ASYM_FP_ZERO
250250 ? size_zero_pt_k * size_zero_pt_n / 2
251251 : size_zero_pt_k * size_zero_pt_n;
252252
@@ -509,7 +509,7 @@ void dequantize_gemv_run(int iter) {
509509 epilogue_args);
510510 } else if constexpr (
511511 compute_policy::quant_mode == quant_mode::S4_ASYM ||
512- compute_policy::quant_mode == quant_mode::INT4_ASYM_ZERO_NO_DEGRAD ) {
512+ compute_policy::quant_mode == quant_mode::INT4_ASYM_FP_ZERO ) {
513513 gemm_arg =
514514 typename gemm_op_t ::template arguments_t <compute_policy::quant_mode>(
515515 matrix_m,
0 commit comments