diff --git a/include/common/core/common_types.hpp b/include/common/core/common_types.hpp index cbd174462..28aab784a 100644 --- a/include/common/core/common_types.hpp +++ b/include/common/core/common_types.hpp @@ -27,7 +27,11 @@ enum class grf_mode : uint8_t { normal = 0, double_grf = 1 }; enum class mem_layout : uint8_t { row_major = 0, col_major = 1 }; -enum class quant_mode : uint8_t { S4_ASYM = 0, S4_FULLRANGE_NO_ZP = 1 }; +enum class quant_mode : uint8_t { + I4_ASYM = 0, + I4_SYM = 1, + I4_ASYM_FP_ZERO = 2 +}; struct quant_info { quant_mode quant_mode; diff --git a/include/common/core/explicit_conv.hpp b/include/common/core/explicit_conv.hpp index 0c61f12bc..ba553ad2d 100644 --- a/include/common/core/explicit_conv.hpp +++ b/include/common/core/explicit_conv.hpp @@ -62,6 +62,19 @@ xetla_cvt(xetla_vector src) { return dst; } +/// @brief xetla explicit data conversion, bf16->fp16. +/// @tparam T_dst is the float16 data type. +/// @tparam T_src is the bfloat16 data type. +/// @tparam N is the element number in xetla_vector. +template +__XETLA_API typename std::enable_if_t< + std::is_same::value && std::is_same::value, + xetla_vector> +xetla_cvt(xetla_vector src) { + xetla_vector dst = src; + return dst; +} + /// @brief xetla explicit data conversion, bf16->fp32. /// @tparam T_dst is the bfloat16 data type. /// @tparam T_src is the float32 data type. diff --git a/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp b/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp index cfb09fe22..c9b5ed925 100644 --- a/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp +++ b/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp @@ -102,10 +102,16 @@ class gemm_t< std::is_same, remove_const_t>::value, "this is for 4bit matB "); static_assert( - std::is_same, remove_const_t>:: - value || - std::is_same, remove_const_t>:: - value, + quant_info_.quant_mode == quant_mode::I4_ASYM_FP_ZERO + ? 
std::is_same_v< + remove_const_t, + remove_const_t> + : (std::is_same_v< + remove_const_t, + remove_const_t> || + std::is_same_v< + remove_const_t, + remove_const_t>), "this is for 4bit zero_pt "); /******** set memory attribute **********/ @@ -284,12 +290,20 @@ class gemm_t< arch_tag>; // compress int4 along N dimensions - using zero_pt_tile_desc_t = subgroup::tile_desc_t< - (tile_size_x_b + pack_ratio - 1) / pack_ratio, - tile_size_y_zero_pt, - (block_size_x_b + pack_ratio - 1) / pack_ratio, - block_size_y_zero_pt, - reg_layout::tiled>; + using zero_pt_tile_desc_t = std::conditional_t< + quant_info_.quant_mode != quant_mode::I4_ASYM_FP_ZERO, + subgroup::tile_desc_t< + (tile_size_x_b + pack_ratio - 1) / pack_ratio, + tile_size_y_zero_pt, + (block_size_x_b + pack_ratio - 1) / pack_ratio, + block_size_y_zero_pt, + reg_layout::tiled>, + subgroup::tile_desc_t< + tile_size_x_b, + tile_size_y_zero_pt, + block_size_x_b, + block_size_y_zero_pt, + reg_layout::tiled>>; using zero_pt_t = subgroup::tile_t; using zero_pt_payload_t = subgroup::mem_payload_t< @@ -520,8 +534,7 @@ class gemm_t< // TODO 1D prefetch need pack to U32/U64 subgroup::tile_prefetch( scale_prefetch_payload); - if constexpr ( - compute_policy::quant_mode != quant_mode::S4_FULLRANGE_NO_ZP) { + if constexpr (compute_policy::quant_mode != quant_mode::I4_SYM) { // TODO 1D prefetch need pack to U32/U64 subgroup::tile_prefetch( zero_pt_prefetch_payload); @@ -534,8 +547,7 @@ class gemm_t< if ((scale_prefetch_addr_i % scale_addr_update_freq) == 0) { scale_prefetch_payload.template update_tdesc( scale_t::tile_size_y); - if constexpr ( - compute_policy::quant_mode != quant_mode::S4_FULLRANGE_NO_ZP) { + if constexpr (compute_policy::quant_mode != quant_mode::I4_SYM) { zero_pt_prefetch_payload .template update_tdesc( zero_pt_t::tile_size_y); @@ -564,8 +576,7 @@ class gemm_t< // matB, matB_payload); subgroup::tile_load( scale, scale_payload); - if constexpr ( - compute_policy::quant_mode != 
quant_mode::S4_FULLRANGE_NO_ZP) { + if constexpr (compute_policy::quant_mode != quant_mode::I4_SYM) { subgroup::tile_load( zero_pt, zero_pt_payload); } @@ -579,8 +590,7 @@ class gemm_t< // TODO 1D prefetch need pack to U32/U64 subgroup::tile_prefetch( scale_prefetch_payload); - if constexpr ( - compute_policy::quant_mode != quant_mode::S4_FULLRANGE_NO_ZP) { + if constexpr (compute_policy::quant_mode != quant_mode::I4_SYM) { // TODO 1D prefetch need pack to U32/U64 subgroup::tile_prefetch( zero_pt_prefetch_payload); @@ -593,8 +603,7 @@ class gemm_t< if (tile_k_idx % scale_addr_update_freq == 0) { scale_payload.template update_tdesc(scale_t::tile_size_y); } - if constexpr ( - compute_policy::quant_mode != quant_mode::S4_FULLRANGE_NO_ZP) { + if constexpr (compute_policy::quant_mode != quant_mode::I4_SYM) { if (tile_k_idx % zero_pt_addr_update_freq == 0) { zero_pt_payload.template update_tdesc( zero_pt_t::tile_size_y); @@ -608,8 +617,7 @@ class gemm_t< if ((scale_prefetch_addr_i % scale_addr_update_freq) == 0) { scale_prefetch_payload.template update_tdesc( scale_t::tile_size_y); - if constexpr ( - compute_policy::quant_mode != quant_mode::S4_FULLRANGE_NO_ZP) { + if constexpr (compute_policy::quant_mode != quant_mode::I4_SYM) { zero_pt_prefetch_payload .template update_tdesc( zero_pt_t::tile_size_y); diff --git a/include/experimental/kernel/gemm/impl/int4_dequantize_kslicing_xe.hpp b/include/experimental/kernel/gemm/impl/int4_dequantize_kslicing_xe.hpp index 79d37d517..bff5914f3 100644 --- a/include/experimental/kernel/gemm/impl/int4_dequantize_kslicing_xe.hpp +++ b/include/experimental/kernel/gemm/impl/int4_dequantize_kslicing_xe.hpp @@ -159,7 +159,7 @@ class gemm_universal_t< /// @brief GEMM arguments. /// This is the interface for users to pass the application-related runtime /// variables. - template + template struct arguments_t { /// @brief Is the size of the m dimension of the matrix multiplication (m x /// k x n). 
@@ -295,7 +295,7 @@ class gemm_universal_t< } }; template <> - struct arguments_t { + struct arguments_t { /// @brief Is the size of the m dimension of the matrix multiplication (m x /// k x n). uint32_t matrix_m; @@ -526,6 +526,10 @@ class gemm_universal_t< template static bool can_implement(arguments_t& args) { bool implementable = true; + if (arch_tag == gpu_arch::XeLpg) { + implementable &= !std::is_same_v; // XeLpg arch doesn't + // have bf16 related isa. + } if (gemm_t::msg_type_a != msg_type::unaligned_2d) { if (gemm_t::msg_type_a == msg_type::block_2d) { implementable &= kernel::block_2d::check_tensor( @@ -566,8 +570,7 @@ class gemm_universal_t< // check for int4x2 implementable &= ((args.matB_ld % pack_ratio == 0) && (args.matrix_n % pack_ratio == 0)); - if constexpr ( - gemm_t::compute_policy::quant_mode != quant_mode::S4_FULLRANGE_NO_ZP) { + if constexpr (gemm_t::compute_policy::quant_mode == quant_mode::I4_ASYM) { implementable &= (args.zero_pt_ld % pack_ratio == 0); } @@ -618,7 +621,10 @@ class gemm_universal_t< int start_x_scale = start_n; int start_y_scale = start_k / dequant_s; - int start_x_zero_pt = start_n / pack_ratio; + int start_x_zero_pt = + gemm_t::compute_policy::quant_mode == quant_mode::I4_ASYM_FP_ZERO + ? 
start_n + : start_n / pack_ratio; int start_y_zero_pt = start_k / dequant_s; // set up arguments @@ -664,15 +670,15 @@ class gemm_universal_t< uint32_t inner_loop_start = (start_k + k_stride - 1) / k_stride; uint32_t inner_loop_count = (wg_tile_k + k_stride - 1) / k_stride; gemm_args_t gemm_args; - if constexpr ( - gemm_t::compute_policy::quant_mode == quant_mode::S4_FULLRANGE_NO_ZP) { + if constexpr (gemm_t::compute_policy::quant_mode == quant_mode::I4_SYM) { gemm_args = gemm_args_t( mem_desc_a, mem_desc_b, inner_loop_start, inner_loop_count, mem_desc_scale); - } else { + } else if constexpr ( + gemm_t::compute_policy::quant_mode == quant_mode::I4_ASYM) { mem_desc_zero_pt_t mem_desc_zero_pt( args.zero_pt_base, {(args.matrix_n + pack_ratio - 1) / pack_ratio, @@ -686,6 +692,23 @@ class gemm_universal_t< inner_loop_count, mem_desc_scale, mem_desc_zero_pt); + } else if constexpr ( + gemm_t::compute_policy::quant_mode == quant_mode::I4_ASYM_FP_ZERO) { + mem_desc_zero_pt_t mem_desc_zero_pt( + args.zero_pt_base, + {args.matrix_n, + ((args.matrix_k + dequant_s - 1) / dequant_s), + args.zero_pt_ld}, + {start_x_zero_pt, start_y_zero_pt}); + gemm_args = gemm_args_t( + mem_desc_a, + mem_desc_b, + inner_loop_start, + inner_loop_count, + mem_desc_scale, + mem_desc_zero_pt); + } else { + assert(0); } matAcc_t matAcc; matAcc.init(0); diff --git a/include/subgroup/tile/impl/tile_op_functor.hpp b/include/subgroup/tile/impl/tile_op_functor.hpp index 060302448..ad022d90b 100644 --- a/include/subgroup/tile/impl/tile_op_functor.hpp +++ b/include/subgroup/tile/impl/tile_op_functor.hpp @@ -130,7 +130,7 @@ struct dequant_int4_weight_t { (offset_y_in_tile) / dequant_s * scale_t::block_size_x + offset_x_in_tile; - if constexpr (quant_mode == quant_mode::S4_ASYM) { + if constexpr (quant_mode == quant_mode::I4_ASYM) { uint32_t zero_pt_idx = offset_y_in_tile / dequant_s * zero_pt_t::block_size_x + offset_x_in_tile / pack_ratio; @@ -149,7 +149,9 @@ struct dequant_int4_weight_t { 
cvt_blk_i8.xetla_select(jj * block_size_y_b + ii) = cvt_blk_i8.xetla_select(jj * block_size_y_b + ii) - zero_pt_i8; - } else if constexpr (quant_mode == quant_mode::S4_FULLRANGE_NO_ZP) { + } else if constexpr ( + quant_mode == quant_mode::I4_SYM || + quant_mode == quant_mode::I4_ASYM_FP_ZERO) { cvt_blk_i8.xetla_select(jj * block_size_y_b + ii) = cvt_blk_i8.xetla_select(jj * block_size_y_b + ii) - int8_t(8); @@ -157,7 +159,15 @@ struct dequant_int4_weight_t { dst_blk.xetla_select(jj * block_size_y_b + ii) = cvt_blk_i8.xetla_select(jj * block_size_y_b + ii) * scale.reg[scale_idx]; - + if constexpr (quant_mode == quant_mode::I4_ASYM_FP_ZERO) { + uint32_t zero_pt_idx = + offset_y_in_tile / dequant_s * zero_pt_t::block_size_x + + offset_x_in_tile; + xetla_vector zero_pt_pack = zero_pt.reg[zero_pt_idx]; + dst_blk.xetla_select(jj * block_size_y_b + ii) = + dst_blk.xetla_select(jj * block_size_y_b + ii) + + zero_pt_pack[0]; + } // sycl::ext::oneapi::experimental::printf( // "scale[%d] %f \n", // scale_idx, diff --git a/tests/integration/gemm/int4_dequantization/main.cpp b/tests/integration/gemm/int4_dequantization/main.cpp index 18e40ded5..88c21250e 100644 --- a/tests/integration/gemm/int4_dequantization/main.cpp +++ b/tests/integration/gemm/int4_dequantization/main.cpp @@ -229,8 +229,9 @@ void dequantize_gemm_run(uint32_t iter) { compute_attr_t; using perf_tuning_knob = xetla::group:: perf_tuning_knob_t; - - static constexpr quant_info quant_info{quant_mode::S4_ASYM, Test::dequant_s, layout_b}; + + static constexpr quant_info quant_info{ + quant_mode::I4_ASYM, Test::dequant_s, layout_b}; using compute_policy = xetla::group::compute_policy_int4_dequantize< compute_attr, diff --git a/tests/integration/gemm/int4_dequantization_bias/main_client.cpp b/tests/integration/gemm/int4_dequantization_bias/main_client.cpp index 69fdfc1fe..0597af758 100644 --- a/tests/integration/gemm/int4_dequantization_bias/main_client.cpp +++ 
b/tests/integration/gemm/int4_dequantization_bias/main_client.cpp @@ -622,7 +622,7 @@ void dequantize_gemm_run(int iter) { perf_tuning_knob_t; static constexpr quant_info quant_info{ - quant_mode::S4_FULLRANGE_NO_ZP, Test::dequant_s, layout_b}; + quant_mode::I4_SYM, Test::dequant_s, layout_b}; using compute_policy = xetla::group::compute_policy_int4_dequantize< compute_attr, @@ -1043,4 +1043,4 @@ REGISTER_TYPED_TEST_SUITE_P(dequantize_gemm_act_shuf_test, esimd); INSTANTIATE_TYPED_TEST_SUITE_P( dequantize_gemm_act_shuf_test_suite, dequantize_gemm_act_shuf_test, - tests); \ No newline at end of file + tests); diff --git a/tests/integration/gemm/int4_dequantization_bias/main_xe.cpp b/tests/integration/gemm/int4_dequantization_bias/main_xe.cpp index 1c42454df..0cc9d8d6f 100644 --- a/tests/integration/gemm/int4_dequantization_bias/main_xe.cpp +++ b/tests/integration/gemm/int4_dequantization_bias/main_xe.cpp @@ -388,7 +388,7 @@ void dequantize_gemm_run(int iter) { using perf_tuning_knob = xetla::group:: perf_tuning_knob_t; static constexpr quant_info quant_info{ - quant_mode::S4_FULLRANGE_NO_ZP, Test::dequant_s, layout_b}; + quant_mode::I4_SYM, Test::dequant_s, layout_b}; using compute_policy = xetla::group::compute_policy_int4_dequantize< compute_attr, diff --git a/tests/integration/gemv/int4/main.cpp b/tests/integration/gemv/int4/main.cpp index 04f79c862..c8d819722 100644 --- a/tests/integration/gemv/int4/main.cpp +++ b/tests/integration/gemv/int4/main.cpp @@ -27,6 +27,7 @@ constexpr int ITER = 200; #endif constexpr size_t UNDEFINED_DATA_SIZE = 1024; +template class test_col_major_1 { public: // Extract the parameters required by different test cases @@ -39,8 +40,8 @@ class test_col_major_1 { static constexpr size_t sg_n = 1; static constexpr size_t sg_k = 512 / sg_m; static constexpr size_t dequant_s = 128; - // static constexpr quant_mode quant_mode = quant_mode::S4_ASYM; - static constexpr quant_mode quant_mode = quant_mode::S4_FULLRANGE_NO_ZP; + // static constexpr 
quant_mode quant_mode = quant_mode::I4_ASYM; + static constexpr quant_mode quant_mode = quant_mode_; static constexpr size_t local_kslicing = 1; static constexpr size_t global_kslicing = 1; @@ -48,9 +49,9 @@ class test_col_major_1 { static constexpr mem_layout layout_b = mem_layout::col_major; static constexpr mma_engine mma_eng = mma_engine::fpu; static constexpr gpu_arch arch = gpu_arch::XeLpg; - using data_type_a = fp16; + using data_type_a = scalar_t; using data_type_b = int4x8; - using data_type_c = fp16; + using data_type_c = scalar_t; }; class test_col_major_2 { public: @@ -120,7 +121,7 @@ int gemm_result_validate( } template < - quant_mode quant_mode = quant_mode::S4_FULLRANGE_NO_ZP, + quant_mode quant_mode = quant_mode::I4_SYM, typename data_type_acc_in = fp16, typename data_type_b, typename data_type_scale, @@ -131,13 +132,19 @@ std::vector convert_int4( data_type_zero_pt zero_pt) { std::vector dequant_fp16(sizeof(data_type_b) * 2); - int8_t zero_pt_i8 = zero_pt & 0xf; + int8_t zero_pt_i8; + if constexpr (quant_mode != quant_mode::I4_ASYM_FP_ZERO) + zero_pt_i8 = zero_pt & 0xf; for (uint32_t i = 0; i < dequant_fp16.size(); i++) { int8_t dequant_8bit = data_b & 0xf; - if constexpr (quant_mode == quant_mode::S4_FULLRANGE_NO_ZP) { + if constexpr (quant_mode == quant_mode::I4_SYM) { dequant_fp16[i] = scale * (dequant_8bit - 8); - } else { + } else if constexpr (quant_mode == quant_mode::I4_ASYM) { dequant_fp16[i] = scale * (dequant_8bit - zero_pt_i8); + } else if constexpr (quant_mode == quant_mode::I4_ASYM_FP_ZERO) { + dequant_fp16[i] = scale * (dequant_8bit - 8) + zero_pt; + } else { + assert(0); } data_b = data_b >> 4; } @@ -147,7 +154,7 @@ std::vector convert_int4( template < size_t dequant_s, mem_layout layout_b = mem_layout::col_major, - quant_mode quant_mode = quant_mode::S4_FULLRANGE_NO_ZP, + quant_mode quant_mode = quant_mode::I4_SYM, typename data_type_acc_in = fp16, typename data_type_b, typename data_type_scale, @@ -169,15 +176,17 @@ std::vector 
dequantize_weight( for (uint32_t j = 0; j < width; j += step) { int start_b_in = i * width + j; int start_scale_in = start_b_in / step; - int start_zero_pt_in = - (j / step) * (matrix_n / pack_radio) + i / pack_radio; + int start_zero_pt_in = quant_mode == quant_mode::I4_ASYM_FP_ZERO + ? (j / step) * matrix_n + i + : (j / step) * (matrix_n / pack_radio) + i / pack_radio; int start_out = layout_b == mem_layout::row_major ? 0 : i * matrix_k + j * pack_radio; + data_type_zero_pt zp_value = zero_pt[start_zero_pt_in]; + if constexpr (quant_mode != quant_mode::I4_ASYM_FP_ZERO) + zp_value = zp_value >> (4 * (i % pack_radio)); for (uint32_t jj = 0; jj < step; jj++) { std::vector dequant_fp16 = convert_int4( - b[start_b_in + jj], - scale[start_scale_in], - zero_pt[start_zero_pt_in] >> (4 * (i % pack_radio))); + b[start_b_in + jj], scale[start_scale_in], zp_value); for (uint32_t jjj = 0; jjj < dequant_fp16.size(); jjj++) { b_out[start_out + pack_radio * jj + jjj] = dequant_fp16[jjj]; } @@ -215,7 +224,10 @@ void dequantize_gemv_run(int iter) { using data_type_a = typename Test::data_type_a; using data_type_b = typename Test::data_type_b; using data_type_c = typename Test::data_type_c; - using data_type_zero_pt = data_type_b; + using data_type_zero_pt = std::conditional_t< + Test::quant_mode == quant_mode::I4_ASYM_FP_ZERO, + data_type_c, + data_type_b>; using data_type_scale = fp16; using data_type_acc_in = fp16; using data_type_acc = float; @@ -225,7 +237,7 @@ void dequantize_gemv_run(int iter) { constexpr mem_layout layout_b = Test::layout_b; constexpr size_t size_a = matrix_m * matrix_k; - constexpr size_t size_b = matrix_k * matrix_n / (2 * sizeof(data_type_b)); + constexpr size_t size_b = matrix_k * matrix_n / 2; constexpr size_t size_scale_k = matrix_k / dequant_s; constexpr size_t size_scale_n = matrix_n; @@ -234,7 +246,9 @@ void dequantize_gemv_run(int iter) { constexpr size_t size_zero_pt_k = matrix_k / dequant_s; constexpr size_t size_zero_pt_n = matrix_n; constexpr 
size_t size_zero_pt = - size_zero_pt_k * size_zero_pt_n / (2 * sizeof(data_type_b)); + Test::quant_mode != quant_mode::I4_ASYM_FP_ZERO + ? size_zero_pt_k * size_zero_pt_n / 2 + : size_zero_pt_k * size_zero_pt_n; constexpr size_t size_c = matrix_m * matrix_n; constexpr size_t size_bias = matrix_n; @@ -405,16 +419,18 @@ void dequantize_gemv_run(int iter) { scale_h[i] = INFINITY; } for (unsigned i = 0; i < size_zero_pt + UNDEFINED_DATA_SIZE; ++i) { - if constexpr (std::is_same_v) { + if constexpr (std::is_same_v) { zero_pt_h[i] = random_uint8(); #ifdef UT_DEBUG zero_pt_h[i] = 0x12 << i; #endif - } else if constexpr (std::is_same_v) { + } else if constexpr (std::is_same_v) { zero_pt_h[i] = random_uint32(); #ifdef UT_DEBUG zero_pt_h[i] = 0x33333333; #endif + } else if constexpr (std::is_same_v) { + zero_pt_h[i] = random_float(); } } @@ -474,7 +490,7 @@ void dequantize_gemv_run(int iter) { // It accepts the base pointer to matrix D, and its dimensions {bias_d, bias_add_shape}}); typename gemm_op_t::template arguments_t gemm_arg; - if constexpr (compute_policy::quant_mode == quant_mode::S4_FULLRANGE_NO_ZP) { + if constexpr (compute_policy::quant_mode == quant_mode::I4_SYM) { gemm_arg = typename gemm_op_t::template arguments_t( matrix_m, @@ -491,7 +507,9 @@ void dequantize_gemv_run(int iter) { Acc_d, Cnt_d, epilogue_args); - } else if constexpr (compute_policy::quant_mode == quant_mode::S4_ASYM) { + } else if constexpr ( + compute_policy::quant_mode == quant_mode::I4_ASYM || + compute_policy::quant_mode == quant_mode::I4_ASYM_FP_ZERO) { gemm_arg = typename gemm_op_t::template arguments_t( matrix_m, @@ -551,9 +569,11 @@ void dequantize_gemv_run(int iter) { // performance prof.print_profiling_result(profiling_selector::GPU); // check result - std::vector dequantize_b = - dequantize_weight( - matrix_k, matrix_n, B_h, scale_h, zero_pt_h); + std::vector dequantize_b = dequantize_weight< + dequant_s, + layout_b, + compute_policy::quant_mode, + data_type_c>(matrix_k, matrix_n, 
B_h, scale_h, zero_pt_h); queue.memcpy((void*)C_h, (void*)C_d, size_c * sizeof(data_type_c)).wait(); ASSERT_EQ( @@ -585,6 +605,12 @@ void dequantize_gemv_run(int iter) { free(Cnt_d, context); } +// Placeholder for void test param +template <> +void dequantize_gemv_run(int) { + GTEST_SKIP(); +} + template class dequantize_gemv_test : public ::testing::Test {}; TYPED_TEST_SUITE_P(dequantize_gemv_test); @@ -594,7 +620,13 @@ TYPED_TEST_P(dequantize_gemv_test, esimd) { } REGISTER_TYPED_TEST_SUITE_P(dequantize_gemv_test, esimd); -using tests = ::testing::Types; +using tests = ::testing::Types< // + test_col_major_1, + test_col_major_1, + test_col_major_1, + test_col_major_1, + // test_col_major_2, + void>; INSTANTIATE_TYPED_TEST_SUITE_P( dequantize_gemv_test_suite,