Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit 4c65a40

Browse files
committed
int4 with bf16 support
1 parent 7848595 commit 4c65a40

File tree

5 files changed

+43
-12
lines changed

5 files changed

+43
-12
lines changed

include/common/core/explicit_conv.hpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,19 @@ xetla_cvt(xetla_vector<T_src, N> src) {
6262
return dst;
6363
}
6464

65+
/// @brief xetla explicit data conversion, bf16->fp16.
66+
/// @tparam T_dst is the float16 data type.
67+
/// @tparam T_src is the bfloat16 data type.
68+
/// @tparam N is the element number in xetla_vector.
69+
template <typename T_dst, typename T_src, int N>
70+
__XETLA_API typename std::enable_if_t<
71+
std::is_same<T_dst, fp16>::value && std::is_same<T_src, bf16>::value,
72+
xetla_vector<T_dst, N>>
73+
xetla_cvt(xetla_vector<T_src, N> src) {
74+
xetla_vector<T_dst, N> dst = src;
75+
return dst;
76+
}
77+
6578
/// @brief xetla explicit data conversion, bf16->fp32.
6679
/// @tparam T_dst is the float32 data type.
6780
/// @tparam T_src is the bfloat16 data type.

include/experimental/kernel/gemm/impl/int4_dequantize_kslicing_xe.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -526,6 +526,10 @@ class gemm_universal_t<
526526
template <quant_mode quant_mode>
527527
static bool can_implement(arguments_t<quant_mode>& args) {
528528
bool implementable = true;
529+
if (arch_tag == gpu_arch::XeLpg) {
530+
implementable &= !std::is_same_v<dtype_a, bf16>; // XeLpg arch doesn't
531+
// have bf16-related ISA.
532+
}
529533
if (gemm_t::msg_type_a != msg_type::unaligned_2d) {
530534
if (gemm_t::msg_type_a == msg_type::block_2d) {
531535
implementable &= kernel::block_2d<arch_tag, dtype_a>::check_tensor(

tests/integration/gemm/int4_dequantization/main.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -229,8 +229,9 @@ void dequantize_gemm_run(uint32_t iter) {
229229
compute_attr_t<data_type_acc_in, data_type_acc_in, data_type_acc>;
230230
using perf_tuning_knob = xetla::group::
231231
perf_tuning_knob_t<sg_tile_k, prefetch_distance, periodic_sync_interval>;
232-
233-
static constexpr quant_info quant_info{quant_mode::S4_ASYM, Test::dequant_s, layout_b};
232+
233+
static constexpr quant_info quant_info{
234+
quant_mode::S4_ASYM, Test::dequant_s, layout_b};
234235

235236
using compute_policy = xetla::group::compute_policy_int4_dequantize<
236237
compute_attr,

tests/integration/gemm/int4_dequantization_bias/main_client.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1043,4 +1043,4 @@ REGISTER_TYPED_TEST_SUITE_P(dequantize_gemm_act_shuf_test, esimd);
10431043
INSTANTIATE_TYPED_TEST_SUITE_P(
10441044
dequantize_gemm_act_shuf_test_suite,
10451045
dequantize_gemm_act_shuf_test,
1046-
tests);
1046+
tests);

tests/integration/gemv/int4/main.cpp

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ constexpr int ITER = 200;
2727
#endif
2828
constexpr size_t UNDEFINED_DATA_SIZE = 1024;
2929

30+
template <typename scalar_t>
3031
class test_col_major_1 {
3132
public:
3233
// Extract the parameters required by different test cases
@@ -48,9 +49,9 @@ class test_col_major_1 {
4849
static constexpr mem_layout layout_b = mem_layout::col_major;
4950
static constexpr mma_engine mma_eng = mma_engine::fpu;
5051
static constexpr gpu_arch arch = gpu_arch::XeLpg;
51-
using data_type_a = fp16;
52+
using data_type_a = scalar_t;
5253
using data_type_b = int4x8;
53-
using data_type_c = fp16;
54+
using data_type_c = scalar_t;
5455
};
5556
class test_col_major_2 {
5657
public:
@@ -173,11 +174,11 @@ std::vector<data_type_acc_in> dequantize_weight(
173174
(j / step) * (matrix_n / pack_radio) + i / pack_radio;
174175
int start_out =
175176
layout_b == mem_layout::row_major ? 0 : i * matrix_k + j * pack_radio;
177+
data_type_zero_pt zp_value = zero_pt[start_zero_pt_in];
178+
zp_value = zp_value >> (4 * (i % pack_radio));
176179
for (uint32_t jj = 0; jj < step; jj++) {
177180
std::vector<fp16> dequant_fp16 = convert_int4<quant_mode>(
178-
b[start_b_in + jj],
179-
scale[start_scale_in],
180-
zero_pt[start_zero_pt_in] >> (4 * (i % pack_radio)));
181+
b[start_b_in + jj], scale[start_scale_in], zp_value);
181182
for (uint32_t jjj = 0; jjj < dequant_fp16.size(); jjj++) {
182183
b_out[start_out + pack_radio * jj + jjj] = dequant_fp16[jjj];
183184
}
@@ -551,9 +552,11 @@ void dequantize_gemv_run(int iter) {
551552
// performance
552553
prof.print_profiling_result(profiling_selector::GPU);
553554
// check result
554-
std::vector<typename Test::data_type_a> dequantize_b =
555-
dequantize_weight<dequant_s, layout_b, compute_policy::quant_mode>(
556-
matrix_k, matrix_n, B_h, scale_h, zero_pt_h);
555+
std::vector<typename Test::data_type_a> dequantize_b = dequantize_weight<
556+
dequant_s,
557+
layout_b,
558+
compute_policy::quant_mode,
559+
data_type_c>(matrix_k, matrix_n, B_h, scale_h, zero_pt_h);
557560

558561
queue.memcpy((void*)C_h, (void*)C_d, size_c * sizeof(data_type_c)).wait();
559562
ASSERT_EQ(
@@ -585,6 +588,12 @@ void dequantize_gemv_run(int iter) {
585588
free(Cnt_d, context);
586589
}
587590

591+
// Placeholder for void test param
592+
template <>
593+
void dequantize_gemv_run<void>(int) {
594+
GTEST_SKIP();
595+
}
596+
588597
template <typename T>
589598
class dequantize_gemv_test : public ::testing::Test {};
590599
TYPED_TEST_SUITE_P(dequantize_gemv_test);
@@ -594,7 +603,11 @@ TYPED_TEST_P(dequantize_gemv_test, esimd) {
594603
}
595604

596605
REGISTER_TYPED_TEST_SUITE_P(dequantize_gemv_test, esimd);
597-
using tests = ::testing::Types<test_col_major_1>;
606+
using tests = ::testing::Types< //
607+
test_col_major_1<fp16>,
608+
test_col_major_1<bf16>,
609+
// test_col_major_2,
610+
void>;
598611

599612
INSTANTIATE_TYPED_TEST_SUITE_P(
600613
dequantize_gemv_test_suite,

0 commit comments

Comments
 (0)