Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions include/oneapi/dnnl/dnnl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1256,6 +1256,19 @@ struct memory : public handle<dnnl_memory_t> {
ABc4a4b = dnnl_ABc4a4b,
aBc16b = dnnl_aBc16b,
aBc32b = dnnl_aBc32b,
aBC8c8b2c = dnnl_aBC8c8b2c,
aBC8c16b2c = dnnl_aBC8c16b2c,
aBC8c24b2c = dnnl_aBC8c24b2c,
aBC8c32b2c = dnnl_aBC8c32b2c,
aBC8c64b2c = dnnl_aBC8c64b2c,
aBC16c16b2c = dnnl_aBC16c16b2c,
aBC16c32b2c = dnnl_aBC16c32b2c,
aBC16c48b2c = dnnl_aBC16c48b2c,
aBC16c64b2c = dnnl_aBC16c64b2c,
aBC16c16b4c = dnnl_aBC16c16b4c,
aBC16c32b4c = dnnl_aBC16c32b4c,
aBC16c48b4c = dnnl_aBC16c48b4c,
aBC16c64b4c = dnnl_aBC16c64b4c,
ABc16b16a = dnnl_ABc16b16a,
AcB16b16a = dnnl_AcB16b16a,
ABc16b32a = dnnl_ABc16b32a,
Expand Down
13 changes: 13 additions & 0 deletions include/oneapi/dnnl/dnnl_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -730,6 +730,19 @@ typedef enum {
dnnl_aCB16b32c,
dnnl_aCB16b48c,
dnnl_aCB16b64c,
dnnl_aBC8c8b2c,
dnnl_aBC8c16b2c,
dnnl_aBC8c24b2c,
dnnl_aBC8c32b2c,
dnnl_aBC8c64b2c,
dnnl_aBC16c16b2c,
dnnl_aBC16c32b2c,
dnnl_aBC16c48b2c,
dnnl_aBC16c64b2c,
dnnl_aBC16c16b4c,
dnnl_aBC16c32b4c,
dnnl_aBC16c48b4c,
dnnl_aBC16c64b4c,
dnnl_aCB16b16c2b,
dnnl_aCB16b32c2b,
dnnl_aCB16b48c2b,
Expand Down
13 changes: 13 additions & 0 deletions src/common/c_types_map.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,19 @@ const format_tag_t AB16b16a2b = dnnl_AB16b16a2b;
const format_tag_t AB16b32a2b = dnnl_AB16b32a2b;
const format_tag_t AB16b48a2b = dnnl_AB16b48a2b;
const format_tag_t AB16b64a2b = dnnl_AB16b64a2b;
// Short aliases for the public C API format tags of the aBC family whose
// innermost block is a 2- or 4-element block of the third dimension
// (the "...2c" / "...4c" suffix). Each constant simply mirrors the
// corresponding dnnl_* enumerator.
const format_tag_t aBC8c8b2c = dnnl_aBC8c8b2c;
const format_tag_t aBC8c16b2c = dnnl_aBC8c16b2c;
const format_tag_t aBC8c24b2c = dnnl_aBC8c24b2c;
const format_tag_t aBC8c32b2c = dnnl_aBC8c32b2c;
const format_tag_t aBC8c64b2c = dnnl_aBC8c64b2c;
const format_tag_t aBC16c16b2c = dnnl_aBC16c16b2c;
const format_tag_t aBC16c32b2c = dnnl_aBC16c32b2c;
const format_tag_t aBC16c48b2c = dnnl_aBC16c48b2c;
const format_tag_t aBC16c64b2c = dnnl_aBC16c64b2c;
const format_tag_t aBC16c16b4c = dnnl_aBC16c16b4c;
const format_tag_t aBC16c32b4c = dnnl_aBC16c32b4c;
const format_tag_t aBC16c48b4c = dnnl_aBC16c48b4c;
const format_tag_t aBC16c64b4c = dnnl_aBC16c64b4c;
const format_tag_t BA4b4a = dnnl_BA4b4a;
const format_tag_t BA8b4a = dnnl_BA8b4a;
const format_tag_t BA16a16b = dnnl_BA16a16b;
Expand Down
13 changes: 13 additions & 0 deletions src/common/memory_desc_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,19 @@ status_t process_tag(F f, format_tag_t tag, Args&&... args) {
C(aBC32b32c, {0, 1, 2}, {32, 32}, {1, 2});
C(aBC48b16c, {0, 1, 2}, {48, 16}, {1, 2});
C(aBC48b32c, {0, 1, 2}, {48, 32}, {1, 2});
C(aBC8c8b2c, {0, 1, 2}, {8, 8, 2}, {2, 1, 2});
C(aBC8c16b2c, {0, 1, 2}, {8, 16, 2}, {2, 1, 2});
C(aBC8c24b2c, {0, 1, 2}, {8, 24, 2}, {2, 1, 2});
C(aBC8c32b2c, {0, 1, 2}, {8, 32, 2}, {2, 1, 2});
C(aBC8c64b2c, {0, 1, 2}, {8, 64, 2}, {2, 1, 2});
C(aBC16c16b2c, {0, 1, 2}, {16, 16, 2}, {2, 1, 2});
C(aBC16c32b2c, {0, 1, 2}, {16, 32, 2}, {2, 1, 2});
C(aBC16c48b2c, {0, 1, 2}, {16, 48, 2}, {2, 1, 2});
C(aBC16c64b2c, {0, 1, 2}, {16, 64, 2}, {2, 1, 2});
C(aBC16c16b4c, {0, 1, 2}, {16, 16, 4}, {2, 1, 2});
C(aBC16c32b4c, {0, 1, 2}, {16, 32, 4}, {2, 1, 2});
C(aBC16c48b4c, {0, 1, 2}, {16, 48, 4}, {2, 1, 2});
C(aBC16c64b4c, {0, 1, 2}, {16, 64, 4}, {2, 1, 2});
C(aCB4c8b8c2b, {0, 2, 1}, {4, 8, 8, 2}, {2, 1, 2, 1});
C(aCB4c8b8c4b, {0, 2, 1}, {4, 8, 8, 4}, {2, 1, 2, 1});
C(aCB4c8b16c2b, {0, 2, 1}, {4, 8, 16, 2}, {2, 1, 2, 1});
Expand Down
25 changes: 22 additions & 3 deletions src/common/tag_traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,10 @@ enum class inner_blk_t {
_8c2b,
_8c4b,
_8c8b,
_8c8b2c,
_8c24b2c,
_8c32b2c,
_8c64b2c,
_16a16b,
_16a32b,
_16a48b,
Expand Down Expand Up @@ -308,10 +312,11 @@ constexpr int AB_or_BC_blk_off(int x0, int x1) {
: (f == ib::_16b16a2b || f == ib::_16c16b2c) ? (x1 / 2) * 32 + x0 * 2 + x1 % 2
: (f == ib::_16b16a4b || f == ib::_16c16b4c) ? (x1 / 4) * 64 + x0 * 4 + x1 % 4
: (f == ib::_8b8a2b) ? (x1 / 2) * 16 + x0 * 2 + x1 % 2
: (f == ib::_8c8b2c) ? (x1 / 2) * 16 + x0 * 2 + x1 % 2
: (f == ib::_8b16a2b || f == ib::_8c16b2c) ? (x1 / 2) * 32 + x0 * 2 + x1 % 2
: (f == ib::_8b24a2b) ? (x1 / 2) * 48 + x0 * 2 + x1 % 2
: (f == ib::_8b32a2b) ? (x1 / 2) * 64 + x0 * 2 + x1 % 2
: (f == ib::_8b64a2b) ? (x1 / 2) * 128 + x0 * 2 + x1 % 2
: (f == ib::_8b24a2b || f == ib::_8c24b2c) ? (x1 / 2) * 48 + x0 * 2 + x1 % 2
: (f == ib::_8b32a2b || f == ib::_8c32b2c) ? (x1 / 2) * 64 + x0 * 2 + x1 % 2
: (f == ib::_8b64a2b || f == ib::_8c64b2c) ? (x1 / 2) * 128 + x0 * 2 + x1 % 2
: (f == ib::_2b4c2b || f == ib::_2c4b2c) ? (x0 / 2) * 8 + x1 * 2 + x0 % 2
: (f == ib::_4b8c2b || f == ib::_4c8b2c) ? (x0 / 2) * 16 + x1 * 2 + x0 % 2
: (f == ib::_2a8b8a2b || f == ib::_2b8c8b2c) ? (x0 / 8) * 128 + (x1 / 2) * 16 + (x0 % 8) * 2 + x1 % 2
Expand Down Expand Up @@ -472,6 +477,20 @@ DECL_TRAITS(BA16a16b4a, _AB, _16a16b4a, 2);
DECL_TRAITS(BA16a32b4a, _AB, _16a32b4a, 2);
DECL_TRAITS(BA16a48b4a, _AB, _16a48b4a, 2);
DECL_TRAITS(BA16a64b4a, _AB, _16a64b4a, 2);

// Tag traits for the aBC tags whose innermost block is 2c or 4c.
// Every entry passes the tag, `_BC`, the matching inner-block enumerator
// and the literal 3.
// NOTE(review): the arguments are presumably (tag, blocked dims, inner block
// kind, ndims) -- confirm against the DECL_TRAITS definition.
DECL_TRAITS(aBC8c8b2c, _BC, _8c8b2c, 3);
DECL_TRAITS(aBC8c16b2c, _BC, _8c16b2c, 3);
DECL_TRAITS(aBC8c24b2c, _BC, _8c24b2c, 3);
DECL_TRAITS(aBC8c32b2c, _BC, _8c32b2c, 3);
DECL_TRAITS(aBC8c64b2c, _BC, _8c64b2c, 3);
DECL_TRAITS(aBC16c16b2c, _BC, _16c16b2c, 3);
DECL_TRAITS(aBC16c32b2c, _BC, _16c32b2c, 3);
DECL_TRAITS(aBC16c48b2c, _BC, _16c48b2c, 3);
DECL_TRAITS(aBC16c64b2c, _BC, _16c64b2c, 3);
DECL_TRAITS(aBC16c16b4c, _BC, _16c16b4c, 3);
DECL_TRAITS(aBC16c32b4c, _BC, _16c32b4c, 3);
DECL_TRAITS(aBC16c48b4c, _BC, _16c48b4c, 3);
DECL_TRAITS(aBC16c64b4c, _BC, _16c64b4c, 3);
DECL_TRAITS(aCB16b16c, _BC, _16b16c, 2);
DECL_TRAITS(aCB16b32c, _BC, _16b32c, 2);
DECL_TRAITS(aCB16b48c, _BC, _16b48c, 2);
Expand Down
13 changes: 13 additions & 0 deletions src/cpu/reorder/cpu_reorder_regular_s4.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,19 @@ const impl_list_map_t &regular_s4_impl_list_map() {
REG_SR(s4, any, s4, OI16i32o2i, fmt_order_keep)
REG_SR(s4, any, s4, OI16i48o2i, fmt_order_keep)
REG_SR(s4, any, s4, OI16i64o2i, fmt_order_keep)
REG_SR(s4, any, s4, aBC8c8b2c, fmt_order_keep)
REG_SR(s4, any, s4, aBC8c16b2c, fmt_order_keep)
REG_SR(s4, any, s4, aBC8c24b2c, fmt_order_keep)
REG_SR(s4, any, s4, aBC8c32b2c, fmt_order_keep)
REG_SR(s4, any, s4, aBC8c64b2c, fmt_order_keep)
REG_SR(s4, any, s4, aBC16c16b2c, fmt_order_keep)
REG_SR(s4, any, s4, aBC16c32b2c, fmt_order_keep)
REG_SR(s4, any, s4, aBC16c48b2c, fmt_order_keep)
REG_SR(s4, any, s4, aBC16c64b2c, fmt_order_keep)
REG_SR(s4, any, s4, aBC16c16b4c, fmt_order_keep)
REG_SR(s4, any, s4, aBC16c32b4c, fmt_order_keep)
REG_SR(s4, any, s4, aBC16c48b4c, fmt_order_keep)
REG_SR(s4, any, s4, aBC16c64b4c, fmt_order_keep)
REG_SR(s4, any, u8, any, fmt_order_keep, spec::reference)
REG_SR(s4, any, f32, any, fmt_order_keep, spec::reference)
REG_SR(s4, any, f32, any, fmt_order::any, spec::reference)
Expand Down
13 changes: 13 additions & 0 deletions src/cpu/reorder/cpu_reorder_regular_u4.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,19 @@ const impl_list_map_t &regular_u4_impl_list_map() {
REG_SR(u4, any, u4, OI16i32o4i, fmt_order_keep)
REG_SR(u4, any, u4, OI16i48o4i, fmt_order_keep)
REG_SR(u4, any, u4, OI16i64o4i, fmt_order_keep)
REG_SR(u4, any, u4, aBC8c8b2c, fmt_order_keep)
REG_SR(u4, any, u4, aBC8c16b2c, fmt_order_keep)
REG_SR(u4, any, u4, aBC8c24b2c, fmt_order_keep)
REG_SR(u4, any, u4, aBC8c32b2c, fmt_order_keep)
REG_SR(u4, any, u4, aBC8c64b2c, fmt_order_keep)
REG_SR(u4, any, u4, aBC16c16b2c, fmt_order_keep)
REG_SR(u4, any, u4, aBC16c32b2c, fmt_order_keep)
REG_SR(u4, any, u4, aBC16c48b2c, fmt_order_keep)
REG_SR(u4, any, u4, aBC16c64b2c, fmt_order_keep)
REG_SR(u4, any, u4, aBC16c16b4c, fmt_order_keep)
REG_SR(u4, any, u4, aBC16c32b4c, fmt_order_keep)
REG_SR(u4, any, u4, aBC16c48b4c, fmt_order_keep)
REG_SR(u4, any, u4, aBC16c64b4c, fmt_order_keep)
REG_SR(u4, any, u8, any, fmt_order_keep, spec::reference)
REG_SR(u4, any, f32, any, fmt_order_keep, spec::reference)
REG_SR(u4, any, f32, any, fmt_order::any, spec::reference)
Expand Down
131 changes: 131 additions & 0 deletions src/cpu/reorder/simple_reorder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2407,6 +2407,137 @@ typename utils::enable_if<tag_i == format_tag::any &&
}
};

/// Reorder for 4-bit data types (nf4, s4, u4, f4_e2m1) with identical source
/// and destination types, targeting a destination tag that is blocked over
/// dimensions 1 and 2 ("B" and "C", i.e. `bd::_BC`).
///
/// The destination must have exactly three inner blocks, the innermost being a
/// 2- or 4-element block over dimension 2 (see is_applicable()).
///
/// Two 4-bit elements share one byte, so elements are moved nibble by nibble:
/// element index `i` lives in byte `i / 2` and `i % 2` selects the half.
template <SIMPLE_REORDER_TEMPL_DECL>
struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
        typename utils::enable_if<tag_i == format_tag::any
                && tag_traits_t<tag_o>::block_dims == bd::_BC
                && utils::one_of(type_i, data_type::nf4, data_type::s4,
                        data_type::u4, data_type::f4_e2m1)
                && type_i == type_o>::type> {
    /// Checks whether this implementation can handle the given reorder.
    ///
    /// @param input_d Source memory descriptor.
    /// @param output_d Destination memory descriptor.
    /// @param attr Primitive attributes, validated by simple_attr_check().
    /// @return status::success when the reorder is supported,
    ///     status::invalid_arguments otherwise.
    static status_t is_applicable(const memory_desc_wrapper &input_d,
            const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
        if (!(!input_d.has_runtime_dims_or_strides()
                    && simple_attr_check(attr, false, true)
                    && (order_keep ? output_d.matches_tag(tag_o)
                                            && input_d.is_plain()
                                   : input_d.matches_tag(tag_o)
                                            && output_d.is_plain())))
            return status::invalid_arguments;

        // execute() hard-codes the intra-block addressing for a three-level
        // blocking whose innermost block is 2c or 4c; reject anything else.
        if (output_d.blocking_desc().inner_nblks != 3
                || !utils::one_of(output_d.blocking_desc().inner_blks[2], 2, 4)
                || output_d.blocking_desc().inner_idxs[2] != 2)
            return status::invalid_arguments;

        return status::success;
    }

    GET_SCRATCHPAD_SIZE_ZERO();

    /// Copies every source nibble to its position in the blocked destination.
    ///
    /// @param pd Reorder primitive descriptor.
    /// @param ctx Execution context providing the source/destination memory.
    /// @return status::success.
    static status_t execute(const cpu_reorder_pd_t *pd, const exec_ctx_t &ctx) {
        DECLARE_COMMON_PARAMS();

        const auto &out_blk = output_d.blocking_desc();
        const auto &in_blk = input_d.blocking_desc();

        // Full block size per dimension = product of all inner blocks that
        // are applied to that dimension.
        int blksize_b = 1;
        int blksize_c = 1;

        for (int i = 0; i < out_blk.inner_nblks; i++) {
            if (out_blk.inner_idxs[i] == 1)
                blksize_b *= out_blk.inner_blks[i];
            else if (out_blk.inner_idxs[i] == 2)
                blksize_c *= out_blk.inner_blks[i];
        }

        const auto &dims = input_d.dims();
        const auto &pdims
                = order_keep ? output_d.padded_dims() : input_d.padded_dims();

        const int A = dims[0];
        const int B = dims[1];
        const int NB_B = pdims[1] / blksize_b;
        const int C = dims[2];
        const int NB_C = pdims[2] / blksize_c;

        // Byte views of the packed 4-bit buffers.
        const uint8_t *packed_val = reinterpret_cast<const uint8_t *>(input);
        uint8_t *output_val = reinterpret_cast<uint8_t *>(output);

        // Reads one nibble of `val`: bits 4..7 when `high_half`, else 0..3.
        auto extract_half_byte = [](uint8_t val, bool high_half) -> uint8_t {
            const uint8_t shift = high_half ? 4 : 0;
            return static_cast<uint8_t>((val >> shift) & 0x000F);
        };

        // ORs `val` into one nibble of `dst`. The target nibble is not
        // cleared first, so callers must start from a zeroed byte.
        // NOTE(review): the shift is the opposite of extract_half_byte(): an
        // odd destination index lands in bits 0..3 while an odd source index
        // is read from bits 4..7. This looks like a deliberate difference in
        // nibble order between source and destination -- confirm.
        auto insert_half_byte
                = [](uint8_t dst, uint8_t val, bool high_half) -> uint8_t {
            const uint8_t shift = high_half ? 0 : 4;
            return dst | static_cast<uint8_t>(val << shift);
        };

        if (out_blk.inner_blks[2] == 4) {
            parallel_nd(A, NB_B, NB_C, [&](int a, int nb_b, int nb_c) {
                // Number of real (non-padded) elements in this block.
                const int b_block = nstl::min(blksize_b, B - nb_b * blksize_b);
                const int c_block = nstl::min(blksize_c, C - nb_c * blksize_c);

                // C is walked in groups of 8 elements.
                for (int cb = 0; cb < utils::div_up(c_block, 8); ++cb) {
                    for (int b = 0; b < b_block; ++b) {
                        const int c_int_block = nstl::min(8, c_block - cb * 8);
                        for (int c = 0; c < c_int_block; ++c) {
                            const size_t iidx = a * in_blk.strides[0]
                                    + (blksize_b * nb_b + b) * in_blk.strides[1]
                                    + (blksize_c * nb_c + cb * 8 + c)
                                            * in_blk.strides[2];
                            // Within a group, element c is stored at
                            // 2 * (c % 4) + c / 4, i.e. c and c + 4 end up in
                            // the same byte.
                            const size_t oidx
                                    = output_d.blk_off<false>(a, nb_b, nb_c)
                                    + cb * blksize_b * 8 + b * 8 + 2 * (c % 4)
                                    + c / 4;
                            const uint8_t src_val = extract_half_byte(
                                    packed_val[iidx / 2], iidx % 2 != 0);
                            // The even slot of each byte is visited first
                            // (c < 4) and starts the byte from zero; the odd
                            // slot (c >= 4) is merged into it afterwards.
                            uint8_t dst_val
                                    = oidx % 2 == 0 ? 0 : output_val[oidx / 2];
                            dst_val = insert_half_byte(
                                    dst_val, src_val, oidx % 2 != 0);
                            output_val[oidx / 2] = dst_val;
                        }
                    }
                }
            });
        } else {
            parallel_nd(A, NB_B, NB_C, [&](int a, int nb_b, int nb_c) {
                // Number of real (non-padded) elements in this block.
                const int b_block = nstl::min(blksize_b, B - nb_b * blksize_b);
                const int c_block = nstl::min(blksize_c, C - nb_c * blksize_c);

                // C is walked in pairs; each pair fills exactly one byte.
                for (int cb = 0; cb < utils::div_up(c_block, 2); ++cb) {
                    for (int b = 0; b < b_block; ++b) {
                        // Bound the pair by the real tail of C. Without this
                        // an odd c_block made the loop read source element C,
                        // which is past the end of the row (and past the end
                        // of the buffer for the very last row), and stored it
                        // in the padded slot. The padded nibble now stays 0.
                        const int c_int_block = nstl::min(2, c_block - cb * 2);
                        for (int c = 0; c < c_int_block; ++c) {
                            const size_t iidx = a * in_blk.strides[0]
                                    + (blksize_b * nb_b + b) * in_blk.strides[1]
                                    + (blksize_c * nb_c + cb * 2 + c)
                                            * in_blk.strides[2];
                            const size_t oidx
                                    = output_d.blk_off<false>(a, nb_b, nb_c)
                                    + cb * blksize_b * 2 + b * 2 + c;
                            const uint8_t src_val = extract_half_byte(
                                    packed_val[iidx / 2], iidx % 2 != 0);
                            // c == 0 starts the byte from zero, c == 1 merges
                            // into the byte written by c == 0.
                            uint8_t dst_val = c == 1 ? output_val[oidx / 2] : 0;
                            dst_val = insert_half_byte(
                                    dst_val, src_val, oidx % 2 != 0);
                            output_val[oidx / 2] = dst_val;
                        }
                    }
                }
            });
        }

        return status::success;
    }
};

template <SIMPLE_REORDER_TEMPL_DECL>
struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
typename utils::enable_if<tag_i == format_tag::any
Expand Down
Loading