Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit 22c2123

Browse files
committed
rename quantmode
1 parent 4c65a40 commit 22c2123

File tree

8 files changed

+23
-31
lines changed

8 files changed

+23
-31
lines changed

include/common/core/common_types.hpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,7 @@ enum class grf_mode : uint8_t { normal = 0, double_grf = 1 };
2727

2828
enum class mem_layout : uint8_t { row_major = 0, col_major = 1 };
2929

30-
enum class quant_mode : uint8_t { S4_ASYM = 0, S4_FULLRANGE_NO_ZP = 1 };
30+
enum class quant_mode : uint8_t { I4_ASYM = 0, I4_SYM = 1 };
3131

3232
struct quant_info {
3333
quant_mode quant_mode;

include/experimental/group/gemm/impl/int4_dequantize_xe.hpp

Lines changed: 6 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -520,8 +520,7 @@ class gemm_t<
520520
// TODO 1D prefetch need pack to U32/U64
521521
subgroup::tile_prefetch<cache_hint::cached, cache_hint::cached>(
522522
scale_prefetch_payload);
523-
if constexpr (
524-
compute_policy::quant_mode != quant_mode::S4_FULLRANGE_NO_ZP) {
523+
if constexpr (compute_policy::quant_mode != quant_mode::I4_SYM) {
525524
// TODO 1D prefetch need pack to U32/U64
526525
subgroup::tile_prefetch<cache_hint::cached, cache_hint::cached>(
527526
zero_pt_prefetch_payload);
@@ -534,8 +533,7 @@ class gemm_t<
534533
if ((scale_prefetch_addr_i % scale_addr_update_freq) == 0) {
535534
scale_prefetch_payload.template update_tdesc<update_dir_b>(
536535
scale_t::tile_size_y);
537-
if constexpr (
538-
compute_policy::quant_mode != quant_mode::S4_FULLRANGE_NO_ZP) {
536+
if constexpr (compute_policy::quant_mode != quant_mode::I4_SYM) {
539537
zero_pt_prefetch_payload
540538
.template update_tdesc<tdesc_update_dir::y_dir>(
541539
zero_pt_t::tile_size_y);
@@ -564,8 +562,7 @@ class gemm_t<
564562
// matB, matB_payload);
565563
subgroup::tile_load<cache_hint::cached, cache_hint::cached>(
566564
scale, scale_payload);
567-
if constexpr (
568-
compute_policy::quant_mode != quant_mode::S4_FULLRANGE_NO_ZP) {
565+
if constexpr (compute_policy::quant_mode != quant_mode::I4_SYM) {
569566
subgroup::tile_load<cache_hint::cached, cache_hint::cached>(
570567
zero_pt, zero_pt_payload);
571568
}
@@ -579,8 +576,7 @@ class gemm_t<
579576
// TODO 1D prefetch need pack to U32/U64
580577
subgroup::tile_prefetch<cache_hint::cached, cache_hint::cached>(
581578
scale_prefetch_payload);
582-
if constexpr (
583-
compute_policy::quant_mode != quant_mode::S4_FULLRANGE_NO_ZP) {
579+
if constexpr (compute_policy::quant_mode != quant_mode::I4_SYM) {
584580
// TODO 1D prefetch need pack to U32/U64
585581
subgroup::tile_prefetch<cache_hint::cached, cache_hint::cached>(
586582
zero_pt_prefetch_payload);
@@ -593,8 +589,7 @@ class gemm_t<
593589
if (tile_k_idx % scale_addr_update_freq == 0) {
594590
scale_payload.template update_tdesc<update_dir_b>(scale_t::tile_size_y);
595591
}
596-
if constexpr (
597-
compute_policy::quant_mode != quant_mode::S4_FULLRANGE_NO_ZP) {
592+
if constexpr (compute_policy::quant_mode != quant_mode::I4_SYM) {
598593
if (tile_k_idx % zero_pt_addr_update_freq == 0) {
599594
zero_pt_payload.template update_tdesc<tdesc_update_dir::y_dir>(
600595
zero_pt_t::tile_size_y);
@@ -608,8 +603,7 @@ class gemm_t<
608603
if ((scale_prefetch_addr_i % scale_addr_update_freq) == 0) {
609604
scale_prefetch_payload.template update_tdesc<tdesc_update_dir::y_dir>(
610605
scale_t::tile_size_y);
611-
if constexpr (
612-
compute_policy::quant_mode != quant_mode::S4_FULLRANGE_NO_ZP) {
606+
if constexpr (compute_policy::quant_mode != quant_mode::I4_SYM) {
613607
zero_pt_prefetch_payload
614608
.template update_tdesc<tdesc_update_dir::y_dir>(
615609
zero_pt_t::tile_size_y);

include/experimental/kernel/gemm/impl/int4_dequantize_kslicing_xe.hpp

Lines changed: 4 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -159,7 +159,7 @@ class gemm_universal_t<
159159
/// @brief GEMM arguments.
160160
/// This is the interface for users to pass the application-related runtime
161161
/// variables.
162-
template <quant_mode quant_mode = quant_mode::S4_FULLRANGE_NO_ZP>
162+
template <quant_mode quant_mode = quant_mode::I4_SYM>
163163
struct arguments_t {
164164
/// @brief Is the size of the m dimension of the matrix multiplication (m x
165165
/// k x n).
@@ -295,7 +295,7 @@ class gemm_universal_t<
295295
}
296296
};
297297
template <>
298-
struct arguments_t<quant_mode::S4_FULLRANGE_NO_ZP> {
298+
struct arguments_t<quant_mode::I4_SYM> {
299299
/// @brief Is the size of the m dimension of the matrix multiplication (m x
300300
/// k x n).
301301
uint32_t matrix_m;
@@ -570,8 +570,7 @@ class gemm_universal_t<
570570
// check for int4x2
571571
implementable &=
572572
((args.matB_ld % pack_ratio == 0) && (args.matrix_n % pack_ratio == 0));
573-
if constexpr (
574-
gemm_t::compute_policy::quant_mode != quant_mode::S4_FULLRANGE_NO_ZP) {
573+
if constexpr (gemm_t::compute_policy::quant_mode != quant_mode::I4_SYM) {
575574
implementable &= (args.zero_pt_ld % pack_ratio == 0);
576575
}
577576

@@ -668,8 +667,7 @@ class gemm_universal_t<
668667
uint32_t inner_loop_start = (start_k + k_stride - 1) / k_stride;
669668
uint32_t inner_loop_count = (wg_tile_k + k_stride - 1) / k_stride;
670669
gemm_args_t gemm_args;
671-
if constexpr (
672-
gemm_t::compute_policy::quant_mode == quant_mode::S4_FULLRANGE_NO_ZP) {
670+
if constexpr (gemm_t::compute_policy::quant_mode == quant_mode::I4_SYM) {
673671
gemm_args = gemm_args_t(
674672
mem_desc_a,
675673
mem_desc_b,

include/subgroup/tile/impl/tile_op_functor.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ struct dequant_int4_weight_t {
130130
(offset_y_in_tile) / dequant_s * scale_t::block_size_x +
131131
offset_x_in_tile;
132132

133-
if constexpr (quant_mode == quant_mode::S4_ASYM) {
133+
if constexpr (quant_mode == quant_mode::I4_ASYM) {
134134
uint32_t zero_pt_idx =
135135
offset_y_in_tile / dequant_s * zero_pt_t::block_size_x +
136136
offset_x_in_tile / pack_ratio;
@@ -149,7 +149,7 @@ struct dequant_int4_weight_t {
149149
cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) =
150150
cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) -
151151
zero_pt_i8;
152-
} else if constexpr (quant_mode == quant_mode::S4_FULLRANGE_NO_ZP) {
152+
} else if constexpr (quant_mode == quant_mode::I4_SYM) {
153153
cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) =
154154
cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) -
155155
int8_t(8);

tests/integration/gemm/int4_dequantization/main.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -231,7 +231,7 @@ void dequantize_gemm_run(uint32_t iter) {
231231
perf_tuning_knob_t<sg_tile_k, prefetch_distance, periodic_sync_interval>;
232232

233233
static constexpr quant_info quant_info{
234-
quant_mode::S4_ASYM, Test::dequant_s, layout_b};
234+
quant_mode::I4_ASYM, Test::dequant_s, layout_b};
235235

236236
using compute_policy = xetla::group::compute_policy_int4_dequantize<
237237
compute_attr,

tests/integration/gemm/int4_dequantization_bias/main_client.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -622,7 +622,7 @@ void dequantize_gemm_run(int iter) {
622622
perf_tuning_knob_t<sg_tile_k, prefetch_distance, periodic_sync_interval>;
623623

624624
static constexpr quant_info quant_info{
625-
quant_mode::S4_FULLRANGE_NO_ZP, Test::dequant_s, layout_b};
625+
quant_mode::I4_SYM, Test::dequant_s, layout_b};
626626

627627
using compute_policy = xetla::group::compute_policy_int4_dequantize<
628628
compute_attr,

tests/integration/gemm/int4_dequantization_bias/main_xe.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -388,7 +388,7 @@ void dequantize_gemm_run(int iter) {
388388
using perf_tuning_knob = xetla::group::
389389
perf_tuning_knob_t<sg_tile_k, prefetch_distance, periodic_sync_interval>;
390390
static constexpr quant_info quant_info{
391-
quant_mode::S4_FULLRANGE_NO_ZP, Test::dequant_s, layout_b};
391+
quant_mode::I4_SYM, Test::dequant_s, layout_b};
392392

393393
using compute_policy = xetla::group::compute_policy_int4_dequantize<
394394
compute_attr,

tests/integration/gemv/int4/main.cpp

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -40,8 +40,8 @@ class test_col_major_1 {
4040
static constexpr size_t sg_n = 1;
4141
static constexpr size_t sg_k = 512 / sg_m;
4242
static constexpr size_t dequant_s = 128;
43-
// static constexpr quant_mode quant_mode = quant_mode::S4_ASYM;
44-
static constexpr quant_mode quant_mode = quant_mode::S4_FULLRANGE_NO_ZP;
43+
// static constexpr quant_mode quant_mode = quant_mode::I4_ASYM;
44+
static constexpr quant_mode quant_mode = quant_mode::I4_SYM;
4545

4646
static constexpr size_t local_kslicing = 1;
4747
static constexpr size_t global_kslicing = 1;
@@ -121,7 +121,7 @@ int gemm_result_validate(
121121
}
122122

123123
template <
124-
quant_mode quant_mode = quant_mode::S4_FULLRANGE_NO_ZP,
124+
quant_mode quant_mode = quant_mode::I4_SYM,
125125
typename data_type_acc_in = fp16,
126126
typename data_type_b,
127127
typename data_type_scale,
@@ -135,7 +135,7 @@ std::vector<fp16> convert_int4(
135135
int8_t zero_pt_i8 = zero_pt & 0xf;
136136
for (uint32_t i = 0; i < dequant_fp16.size(); i++) {
137137
int8_t dequant_8bit = data_b & 0xf;
138-
if constexpr (quant_mode == quant_mode::S4_FULLRANGE_NO_ZP) {
138+
if constexpr (quant_mode == quant_mode::I4_SYM) {
139139
dequant_fp16[i] = scale * (dequant_8bit - 8);
140140
} else {
141141
dequant_fp16[i] = scale * (dequant_8bit - zero_pt_i8);
@@ -148,7 +148,7 @@ std::vector<fp16> convert_int4(
148148
template <
149149
size_t dequant_s,
150150
mem_layout layout_b = mem_layout::col_major,
151-
quant_mode quant_mode = quant_mode::S4_FULLRANGE_NO_ZP,
151+
quant_mode quant_mode = quant_mode::I4_SYM,
152152
typename data_type_acc_in = fp16,
153153
typename data_type_b,
154154
typename data_type_scale,
@@ -475,7 +475,7 @@ void dequantize_gemv_run(int iter) {
475475
// It accepts the base pointer to matrix D, and its dimensions
476476
{bias_d, bias_add_shape}});
477477
typename gemm_op_t::template arguments_t<compute_policy::quant_mode> gemm_arg;
478-
if constexpr (compute_policy::quant_mode == quant_mode::S4_FULLRANGE_NO_ZP) {
478+
if constexpr (compute_policy::quant_mode == quant_mode::I4_SYM) {
479479
gemm_arg =
480480
typename gemm_op_t::template arguments_t<compute_policy::quant_mode>(
481481
matrix_m,
@@ -492,7 +492,7 @@ void dequantize_gemv_run(int iter) {
492492
Acc_d,
493493
Cnt_d,
494494
epilogue_args);
495-
} else if constexpr (compute_policy::quant_mode == quant_mode::S4_ASYM) {
495+
} else if constexpr (compute_policy::quant_mode == quant_mode::I4_ASYM) {
496496
gemm_arg =
497497
typename gemm_op_t::template arguments_t<compute_policy::quant_mode>(
498498
matrix_m,

0 commit comments

Comments
 (0)