Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit 7848595

Browse files
authored
sync ipex 595b2e1 (#310)
1 parent 3043b5a commit 7848595

File tree

4 files changed

+21
-11
lines changed

4 files changed

+21
-11
lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -46,7 +46,7 @@ if (${LOG} STREQUAL "on")
4646
endif ()
4747

4848
# For large registers mode, enable 256 registers for kernels
49-
# set(XETLA_OFFLINE_OPTIONS "-doubleGRF")
49+
set(XETLA_OFFLINE_OPTIONS "-doubleGRF")
5050
set(XETLA_OFFLINE_OPTIONS "${XETLA_OFFLINE_OPTIONS} -vc-disable-indvars-opt")
5151
set(XETLA_OFFLINE_OPTIONS "${XETLA_OFFLINE_OPTIONS} -vc-codegen")
5252
# Enable bank conflict reduction.

include/experimental/group/gemm/compute_policy.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,11 +137,11 @@ struct compute_policy_int4_dequantize<
137137
quant_info_.weight_mem_layout == mem_layout::col_major;
138138

139139
static constexpr uint32_t block_size_y_a = is_col_major_b ? 8 : 16;
140-
static constexpr uint32_t block_bytes_x_a = is_col_major_b ? 128 : 32;
140+
static constexpr uint32_t block_bytes_x_a = is_col_major_b ? 256 : 32;
141141
static constexpr uint32_t block_size_x_a =
142142
block_bytes_x_a / sizeof(dtype_mma_a);
143143
static constexpr uint32_t block_size_x_b = is_col_major_b ? 1 : 32;
144-
static constexpr uint32_t block_bytes_y_b = is_col_major_b ? 128 : 32;
144+
static constexpr uint32_t block_bytes_y_b = is_col_major_b ? 256 : 32;
145145
static constexpr uint32_t block_size_y_b =
146146
block_bytes_y_b / sizeof(dtype_mma_b);
147147

include/subgroup/tile/impl/load_xe.hpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -524,12 +524,16 @@ tile_load(tile_t& tile, payload_t& payload) {
524524
}
525525
reg_sub
526526
.xetla_select<load_elems * pack_factor, 1>(
527-
sub_block_offset * tile_desc::block_size_x)
527+
sub_block_offset *
528+
(payload_t::mem_transpose ? tile_desc::block_size_y
529+
: tile_desc::block_size_x))
528530
.xetla_format<load_dtype>() = reg_tmp_trans;
529531
} else {
530532
reg_sub
531533
.xetla_select<load_elems * pack_factor, 1>(
532-
sub_block_offset * tile_desc::block_size_x)
534+
sub_block_offset *
535+
(payload_t::mem_transpose ? tile_desc::block_size_y
536+
: tile_desc::block_size_x))
533537
.xetla_format<load_dtype>() = reg_tmp;
534538
}
535539
}

include/subgroup/tile/impl/payload_xe.hpp

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1659,8 +1659,10 @@ struct prefetch_payload_t<
16591659
arch_tag_,
16601660
std::enable_if_t<
16611661
arch_tag_ <= gpu_arch::XeHpg &&
1662-
((block_size_y_ != 1 && mem_layout_ == mem_layout::row_major) ||
1663-
(block_size_x_ != 1 && mem_layout_ == mem_layout::col_major))>> {
1662+
(((block_size_y_ != 1 || tile_size_y_ != 1) &&
1663+
mem_layout_ == mem_layout::row_major) ||
1664+
((block_size_x_ != 1 || tile_size_x_ != 1) &&
1665+
mem_layout_ == mem_layout::col_major))>> {
16641666
using dtype = native_type_t<dtype_>;
16651667
using mem_desc_t =
16661668
mem_desc_t<dtype_, mem_layout_, mem_space::global, alignment_>;
@@ -1885,8 +1887,10 @@ struct prefetch_payload_t<
18851887
arch_tag_,
18861888
std::enable_if_t<
18871889
(arch_tag_ == gpu_arch::XeHpc) &&
1888-
(((block_size_y_ != 1) && mem_layout_ == mem_layout::row_major) ||
1889-
((block_size_x_ != 1) && mem_layout_ == mem_layout::col_major))>> {
1890+
(((block_size_y_ != 1 || tile_size_y_ != 1) &&
1891+
mem_layout_ == mem_layout::row_major) ||
1892+
((block_size_x_ != 1 || tile_size_x_ != 1) &&
1893+
mem_layout_ == mem_layout::col_major))>> {
18901894
using dtype = dtype_;
18911895
using mem_desc_t =
18921896
mem_desc_t<dtype_, mem_layout_, mem_space::global, alignment_>;
@@ -2176,8 +2180,10 @@ struct prefetch_payload_t<
21762180
num_coop_sg_,
21772181
arch_tag_,
21782182
std::enable_if_t<
2179-
((block_size_y_ == 1) && mem_layout_ == mem_layout::row_major) ||
2180-
((block_size_x_ == 1) && mem_layout_ == mem_layout::col_major)>> {
2183+
((block_size_y_ == 1 && tile_size_y_ == 1) &&
2184+
mem_layout_ == mem_layout::row_major) ||
2185+
((block_size_x_ == 1 && tile_size_x_ == 1) &&
2186+
mem_layout_ == mem_layout::col_major)>> {
21812187
using dtype = dtype_;
21822188
using mem_desc_t =
21832189
mem_desc_t<dtype_, mem_layout_, mem_space::global, alignment_>;

0 commit comments

Comments
 (0)