Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

XeTLA INT4 With BF16 Support #311

Merged
merged 2 commits into from
Jul 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/common/core/common_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ enum class grf_mode : uint8_t { normal = 0, double_grf = 1 };

enum class mem_layout : uint8_t { row_major = 0, col_major = 1 };

enum class quant_mode : uint8_t { S4_ASYM = 0, S4_FULLRANGE_NO_ZP = 1 };
enum class quant_mode : uint8_t { I4_ASYM = 0, I4_SYM = 1 };

struct quant_info {
quant_mode quant_mode;
Expand Down
13 changes: 13 additions & 0 deletions include/common/core/explicit_conv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,19 @@ xetla_cvt(xetla_vector<T_src, N> src) {
return dst;
}

/// @brief xetla explicit data conversion, bf16->fp16.
/// This overload only participates in overload resolution when T_dst is fp16
/// and T_src is bf16 (see the enable_if condition below).
/// @tparam T_dst is the float16 data type.
/// @tparam T_src is the bfloat16 data type.
/// @tparam N is the element number in xetla_vector.
/// @param src is the bfloat16 input vector.
/// @return a float16 vector with N elements initialized from src.
template <typename T_dst, typename T_src, int N>
__XETLA_API typename std::enable_if_t<
std::is_same<T_dst, fp16>::value && std::is_same<T_src, bf16>::value,
xetla_vector<T_dst, N>>
xetla_cvt(xetla_vector<T_src, N> src) {
// The conversion itself is delegated to xetla_vector's converting
// initialization. NOTE(review): presumably element-wise bf16->fp16; confirm
// against the xetla_vector definition.
xetla_vector<T_dst, N> dst = src;
return dst;
}

/// @brief xetla explicit data conversion, bf16->fp32.
/// @tparam T_dst is the float32 data type.
/// @tparam T_src is the bfloat16 data type.
Expand Down
18 changes: 6 additions & 12 deletions include/experimental/group/gemm/impl/int4_dequantize_xe.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -520,8 +520,7 @@ class gemm_t<
// TODO 1D prefetch need pack to U32/U64
subgroup::tile_prefetch<cache_hint::cached, cache_hint::cached>(
scale_prefetch_payload);
if constexpr (
compute_policy::quant_mode != quant_mode::S4_FULLRANGE_NO_ZP) {
if constexpr (compute_policy::quant_mode != quant_mode::I4_SYM) {
// TODO 1D prefetch need pack to U32/U64
subgroup::tile_prefetch<cache_hint::cached, cache_hint::cached>(
zero_pt_prefetch_payload);
Expand All @@ -534,8 +533,7 @@ class gemm_t<
if ((scale_prefetch_addr_i % scale_addr_update_freq) == 0) {
scale_prefetch_payload.template update_tdesc<update_dir_b>(
scale_t::tile_size_y);
if constexpr (
compute_policy::quant_mode != quant_mode::S4_FULLRANGE_NO_ZP) {
if constexpr (compute_policy::quant_mode != quant_mode::I4_SYM) {
zero_pt_prefetch_payload
.template update_tdesc<tdesc_update_dir::y_dir>(
zero_pt_t::tile_size_y);
Expand Down Expand Up @@ -564,8 +562,7 @@ class gemm_t<
// matB, matB_payload);
subgroup::tile_load<cache_hint::cached, cache_hint::cached>(
scale, scale_payload);
if constexpr (
compute_policy::quant_mode != quant_mode::S4_FULLRANGE_NO_ZP) {
if constexpr (compute_policy::quant_mode != quant_mode::I4_SYM) {
subgroup::tile_load<cache_hint::cached, cache_hint::cached>(
zero_pt, zero_pt_payload);
}
Expand All @@ -579,8 +576,7 @@ class gemm_t<
// TODO 1D prefetch need pack to U32/U64
subgroup::tile_prefetch<cache_hint::cached, cache_hint::cached>(
scale_prefetch_payload);
if constexpr (
compute_policy::quant_mode != quant_mode::S4_FULLRANGE_NO_ZP) {
if constexpr (compute_policy::quant_mode != quant_mode::I4_SYM) {
// TODO 1D prefetch need pack to U32/U64
subgroup::tile_prefetch<cache_hint::cached, cache_hint::cached>(
zero_pt_prefetch_payload);
Expand All @@ -593,8 +589,7 @@ class gemm_t<
if (tile_k_idx % scale_addr_update_freq == 0) {
scale_payload.template update_tdesc<update_dir_b>(scale_t::tile_size_y);
}
if constexpr (
compute_policy::quant_mode != quant_mode::S4_FULLRANGE_NO_ZP) {
if constexpr (compute_policy::quant_mode != quant_mode::I4_SYM) {
if (tile_k_idx % zero_pt_addr_update_freq == 0) {
zero_pt_payload.template update_tdesc<tdesc_update_dir::y_dir>(
zero_pt_t::tile_size_y);
Expand All @@ -608,8 +603,7 @@ class gemm_t<
if ((scale_prefetch_addr_i % scale_addr_update_freq) == 0) {
scale_prefetch_payload.template update_tdesc<tdesc_update_dir::y_dir>(
scale_t::tile_size_y);
if constexpr (
compute_policy::quant_mode != quant_mode::S4_FULLRANGE_NO_ZP) {
if constexpr (compute_policy::quant_mode != quant_mode::I4_SYM) {
zero_pt_prefetch_payload
.template update_tdesc<tdesc_update_dir::y_dir>(
zero_pt_t::tile_size_y);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ class gemm_universal_t<
/// @brief GEMM arguments.
/// This is the interface for users to pass the application-related runtime
/// variables.
template <quant_mode quant_mode = quant_mode::S4_FULLRANGE_NO_ZP>
template <quant_mode quant_mode = quant_mode::I4_SYM>
struct arguments_t {
/// @brief Is the size of the m dimension of the matrix multiplication (m x
/// k x n).
Expand Down Expand Up @@ -295,7 +295,7 @@ class gemm_universal_t<
}
};
template <>
struct arguments_t<quant_mode::S4_FULLRANGE_NO_ZP> {
struct arguments_t<quant_mode::I4_SYM> {
/// @brief Is the size of the m dimension of the matrix multiplication (m x
/// k x n).
uint32_t matrix_m;
Expand Down Expand Up @@ -526,6 +526,10 @@ class gemm_universal_t<
template <quant_mode quant_mode>
static bool can_implement(arguments_t<quant_mode>& args) {
bool implementable = true;
if (arch_tag == gpu_arch::XeLpg) {
implementable &= !std::is_same_v<dtype_a, bf16>; // XeLpg arch doesn't
// have bf16-related ISA.
}
if (gemm_t::msg_type_a != msg_type::unaligned_2d) {
if (gemm_t::msg_type_a == msg_type::block_2d) {
implementable &= kernel::block_2d<arch_tag, dtype_a>::check_tensor(
Expand Down Expand Up @@ -566,8 +570,7 @@ class gemm_universal_t<
// check for int4x2
implementable &=
((args.matB_ld % pack_ratio == 0) && (args.matrix_n % pack_ratio == 0));
if constexpr (
gemm_t::compute_policy::quant_mode != quant_mode::S4_FULLRANGE_NO_ZP) {
if constexpr (gemm_t::compute_policy::quant_mode != quant_mode::I4_SYM) {
implementable &= (args.zero_pt_ld % pack_ratio == 0);
}

Expand Down Expand Up @@ -664,8 +667,7 @@ class gemm_universal_t<
uint32_t inner_loop_start = (start_k + k_stride - 1) / k_stride;
uint32_t inner_loop_count = (wg_tile_k + k_stride - 1) / k_stride;
gemm_args_t gemm_args;
if constexpr (
gemm_t::compute_policy::quant_mode == quant_mode::S4_FULLRANGE_NO_ZP) {
if constexpr (gemm_t::compute_policy::quant_mode == quant_mode::I4_SYM) {
gemm_args = gemm_args_t(
mem_desc_a,
mem_desc_b,
Expand Down
4 changes: 2 additions & 2 deletions include/subgroup/tile/impl/tile_op_functor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ struct dequant_int4_weight_t {
(offset_y_in_tile) / dequant_s * scale_t::block_size_x +
offset_x_in_tile;

if constexpr (quant_mode == quant_mode::S4_ASYM) {
if constexpr (quant_mode == quant_mode::I4_ASYM) {
uint32_t zero_pt_idx =
offset_y_in_tile / dequant_s * zero_pt_t::block_size_x +
offset_x_in_tile / pack_ratio;
Expand All @@ -149,7 +149,7 @@ struct dequant_int4_weight_t {
cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) =
cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) -
zero_pt_i8;
} else if constexpr (quant_mode == quant_mode::S4_FULLRANGE_NO_ZP) {
} else if constexpr (quant_mode == quant_mode::I4_SYM) {
cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) =
cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) -
int8_t(8);
Expand Down
5 changes: 3 additions & 2 deletions tests/integration/gemm/int4_dequantization/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -229,8 +229,9 @@ void dequantize_gemm_run(uint32_t iter) {
compute_attr_t<data_type_acc_in, data_type_acc_in, data_type_acc>;
using perf_tuning_knob = xetla::group::
perf_tuning_knob_t<sg_tile_k, prefetch_distance, periodic_sync_interval>;

static constexpr quant_info quant_info{quant_mode::S4_ASYM, Test::dequant_s, layout_b};

static constexpr quant_info quant_info{
quant_mode::I4_ASYM, Test::dequant_s, layout_b};

using compute_policy = xetla::group::compute_policy_int4_dequantize<
compute_attr,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -622,7 +622,7 @@ void dequantize_gemm_run(int iter) {
perf_tuning_knob_t<sg_tile_k, prefetch_distance, periodic_sync_interval>;

static constexpr quant_info quant_info{
quant_mode::S4_FULLRANGE_NO_ZP, Test::dequant_s, layout_b};
quant_mode::I4_SYM, Test::dequant_s, layout_b};

using compute_policy = xetla::group::compute_policy_int4_dequantize<
compute_attr,
Expand Down Expand Up @@ -1043,4 +1043,4 @@ REGISTER_TYPED_TEST_SUITE_P(dequantize_gemm_act_shuf_test, esimd);
INSTANTIATE_TYPED_TEST_SUITE_P(
dequantize_gemm_act_shuf_test_suite,
dequantize_gemm_act_shuf_test,
tests);
tests);
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ void dequantize_gemm_run(int iter) {
using perf_tuning_knob = xetla::group::
perf_tuning_knob_t<sg_tile_k, prefetch_distance, periodic_sync_interval>;
static constexpr quant_info quant_info{
quant_mode::S4_FULLRANGE_NO_ZP, Test::dequant_s, layout_b};
quant_mode::I4_SYM, Test::dequant_s, layout_b};

using compute_policy = xetla::group::compute_policy_int4_dequantize<
compute_attr,
Expand Down
45 changes: 29 additions & 16 deletions tests/integration/gemv/int4/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ constexpr int ITER = 200;
#endif
constexpr size_t UNDEFINED_DATA_SIZE = 1024;

template <typename scalar_t>
class test_col_major_1 {
public:
// Extract the parameters required by different test cases
Expand All @@ -39,18 +40,18 @@ class test_col_major_1 {
static constexpr size_t sg_n = 1;
static constexpr size_t sg_k = 512 / sg_m;
static constexpr size_t dequant_s = 128;
// static constexpr quant_mode quant_mode = quant_mode::S4_ASYM;
static constexpr quant_mode quant_mode = quant_mode::S4_FULLRANGE_NO_ZP;
// static constexpr quant_mode quant_mode = quant_mode::I4_ASYM;
static constexpr quant_mode quant_mode = quant_mode::I4_SYM;

static constexpr size_t local_kslicing = 1;
static constexpr size_t global_kslicing = 1;
static constexpr mem_layout layout_a = mem_layout::row_major;
static constexpr mem_layout layout_b = mem_layout::col_major;
static constexpr mma_engine mma_eng = mma_engine::fpu;
static constexpr gpu_arch arch = gpu_arch::XeLpg;
using data_type_a = fp16;
using data_type_a = scalar_t;
using data_type_b = int4x8;
using data_type_c = fp16;
using data_type_c = scalar_t;
};
class test_col_major_2 {
public:
Expand Down Expand Up @@ -120,7 +121,7 @@ int gemm_result_validate(
}

template <
quant_mode quant_mode = quant_mode::S4_FULLRANGE_NO_ZP,
quant_mode quant_mode = quant_mode::I4_SYM,
typename data_type_acc_in = fp16,
typename data_type_b,
typename data_type_scale,
Expand All @@ -134,7 +135,7 @@ std::vector<fp16> convert_int4(
int8_t zero_pt_i8 = zero_pt & 0xf;
for (uint32_t i = 0; i < dequant_fp16.size(); i++) {
int8_t dequant_8bit = data_b & 0xf;
if constexpr (quant_mode == quant_mode::S4_FULLRANGE_NO_ZP) {
if constexpr (quant_mode == quant_mode::I4_SYM) {
dequant_fp16[i] = scale * (dequant_8bit - 8);
} else {
dequant_fp16[i] = scale * (dequant_8bit - zero_pt_i8);
Expand All @@ -147,7 +148,7 @@ std::vector<fp16> convert_int4(
template <
size_t dequant_s,
mem_layout layout_b = mem_layout::col_major,
quant_mode quant_mode = quant_mode::S4_FULLRANGE_NO_ZP,
quant_mode quant_mode = quant_mode::I4_SYM,
typename data_type_acc_in = fp16,
typename data_type_b,
typename data_type_scale,
Expand All @@ -173,11 +174,11 @@ std::vector<data_type_acc_in> dequantize_weight(
(j / step) * (matrix_n / pack_radio) + i / pack_radio;
int start_out =
layout_b == mem_layout::row_major ? 0 : i * matrix_k + j * pack_radio;
data_type_zero_pt zp_value = zero_pt[start_zero_pt_in];
zp_value = zp_value >> (4 * (i % pack_radio));
for (uint32_t jj = 0; jj < step; jj++) {
std::vector<fp16> dequant_fp16 = convert_int4<quant_mode>(
b[start_b_in + jj],
scale[start_scale_in],
zero_pt[start_zero_pt_in] >> (4 * (i % pack_radio)));
b[start_b_in + jj], scale[start_scale_in], zp_value);
for (uint32_t jjj = 0; jjj < dequant_fp16.size(); jjj++) {
b_out[start_out + pack_radio * jj + jjj] = dequant_fp16[jjj];
}
Expand Down Expand Up @@ -474,7 +475,7 @@ void dequantize_gemv_run(int iter) {
// It accepts the base pointer to matrix D, and its dimensions
{bias_d, bias_add_shape}});
typename gemm_op_t::template arguments_t<compute_policy::quant_mode> gemm_arg;
if constexpr (compute_policy::quant_mode == quant_mode::S4_FULLRANGE_NO_ZP) {
if constexpr (compute_policy::quant_mode == quant_mode::I4_SYM) {
gemm_arg =
typename gemm_op_t::template arguments_t<compute_policy::quant_mode>(
matrix_m,
Expand All @@ -491,7 +492,7 @@ void dequantize_gemv_run(int iter) {
Acc_d,
Cnt_d,
epilogue_args);
} else if constexpr (compute_policy::quant_mode == quant_mode::S4_ASYM) {
} else if constexpr (compute_policy::quant_mode == quant_mode::I4_ASYM) {
gemm_arg =
typename gemm_op_t::template arguments_t<compute_policy::quant_mode>(
matrix_m,
Expand Down Expand Up @@ -551,9 +552,11 @@ void dequantize_gemv_run(int iter) {
// performance
prof.print_profiling_result(profiling_selector::GPU);
// check result
std::vector<typename Test::data_type_a> dequantize_b =
dequantize_weight<dequant_s, layout_b, compute_policy::quant_mode>(
matrix_k, matrix_n, B_h, scale_h, zero_pt_h);
std::vector<typename Test::data_type_a> dequantize_b = dequantize_weight<
dequant_s,
layout_b,
compute_policy::quant_mode,
data_type_c>(matrix_k, matrix_n, B_h, scale_h, zero_pt_h);

queue.memcpy((void*)C_h, (void*)C_d, size_c * sizeof(data_type_c)).wait();
ASSERT_EQ(
Expand Down Expand Up @@ -585,6 +588,12 @@ void dequantize_gemv_run(int iter) {
free(Cnt_d, context);
}

// Placeholder specialization for the `void` test parameter: it ignores the
// iteration-count argument and marks the test as skipped via GTEST_SKIP()
// instead of running anything.
template <>
void dequantize_gemv_run<void>(int) {
GTEST_SKIP();
}

template <typename T>
class dequantize_gemv_test : public ::testing::Test {};
TYPED_TEST_SUITE_P(dequantize_gemv_test);
Expand All @@ -594,7 +603,11 @@ TYPED_TEST_P(dequantize_gemv_test, esimd) {
}

REGISTER_TYPED_TEST_SUITE_P(dequantize_gemv_test, esimd);
using tests = ::testing::Types<test_col_major_1>;
using tests = ::testing::Types< //
test_col_major_1<fp16>,
test_col_major_1<bf16>,
// test_col_major_2,
void>;

INSTANTIATE_TYPED_TEST_SUITE_P(
dequantize_gemv_test_suite,
Expand Down
Loading