Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit 4929d80

Browse files
authored
align with IPEX master 0921c6332e3d3e357b9849acc0893a63d9b34b4d ca3e7d24329483babdda0ebff3bca0204c15f735 bac2d0d759c483378bbb41138bf1dc3fe6010026 (#241)
1 parent a9d5cc9 commit 4929d80

File tree

12 files changed

+170
-59
lines changed

12 files changed

+170
-59
lines changed

include/common/core/arch_config.hpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ namespace gpu::xetla {
2828

2929
template <msg_type message_type, gpu_arch arch_tag>
3030
struct load_store_attr_t {};
31+
3132
template <>
3233
struct load_store_attr_t<msg_type::block_2d, gpu_arch::XeHpc> {
3334
/// HW limitation checks https://gfxspecs.intel.com/Predator/Home/Index/55490
@@ -75,12 +76,19 @@ struct load_store_attr_t<msg_type::block_2d, gpu_arch::XeHpg>
7576
: public client_load_store_attr_base_t<
7677
msg_type::block_2d,
7778
gpu_arch::XeHpg> {};
79+
7880
template <>
7981
struct load_store_attr_t<msg_type::block_2d, gpu_arch::XeLpg>
8082
: public client_load_store_attr_base_t<
8183
msg_type::block_2d,
8284
gpu_arch::XeLpg> {};
8385

86+
template <gpu_arch arch_tag>
87+
struct load_store_attr_t<msg_type::block_1d, arch_tag> {
88+
static constexpr uint32_t max_load_vec_len = 64;
89+
static constexpr uint32_t max_store_vec_len = 64;
90+
};
91+
8492
template <gpu_arch arch_tag>
8593
struct mma_attr_t {};
8694

include/common/core/base_types.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,8 @@ concept xetla_matrix_ref = __ESIMD_NS::detail::is_simd_view_type_v<Ta> &&
232232

233233
} // namespace gpu::xetla
234234

235-
#if (__LIBSYCL_MAJOR_VERSION >= 7) && (__LIBSYCL_MINOR_VERSION >= 1)
235+
#if (__LIBSYCL_MAJOR_VERSION > 7) || \
236+
((__LIBSYCL_MAJOR_VERSION == 7) && (__LIBSYCL_MINOR_VERSION >= 1))
236237

237238
namespace sycl::detail {
238239
template <typename T>

include/common/core/common.hpp

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#pragma once
2121

2222
#include <CL/sycl.hpp>
23+
#include <common/core/common_types.hpp>
2324
#include <ext/intel/esimd.hpp>
2425
#include <version.hpp>
2526

@@ -70,15 +71,6 @@ __XETLA_API int32_t xetla_get_subdevice_id() {
7071
}
7172

7273
namespace gpu::xetla {
73-
74-
enum class gpu_arch : uint8_t { XeLpg = 0, XeHpg = 1, XeHpc = 2 };
75-
inline constexpr bool arch_has_xmx(gpu_arch arch) {
76-
return arch >= gpu_arch::XeHpg;
77-
}
78-
79-
enum class grf_mode : uint8_t { normal = 0, double_grf = 1 };
80-
81-
enum class mem_layout : uint8_t { row_major = 0, col_major = 1 };
8274
enum class mem_space : uint8_t { global = 0, local = 1 };
8375
enum class msg_type : uint8_t {
8476
block_2d = 0,

include/common/core/common_types.hpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
/*******************************************************************************
2+
* Copyright (c) 2022-2023 Intel Corporation
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*******************************************************************************/
16+
17+
/// @file
18+
/// C++ API
19+
20+
#pragma once
21+
#include <cstdint>
22+
23+
namespace gpu::xetla {
24+
enum class gpu_arch : uint8_t { XeLpg = 0, XeHpg = 1, XeHpc = 2 };
25+
inline constexpr bool arch_has_xmx(gpu_arch arch) {
26+
return arch >= gpu_arch::XeHpg;
27+
}
28+
inline constexpr bool arch_has_2d_load_store(gpu_arch arch) {
29+
return arch >= gpu_arch::XeHpc;
30+
}
31+
32+
enum class grf_mode : uint8_t { normal = 0, double_grf = 1 };
33+
34+
enum class mem_layout : uint8_t { row_major = 0, col_major = 1 };
35+
} // namespace gpu::xetla

include/experimental/common/base_types.hpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,6 @@
1919

2020
#pragma once
2121

22-
#include <common/common.hpp>
23-
2422
namespace gpu::xetla {
2523

2624
/// @brief xetla 4bits data packed as 8bits data type.

include/experimental/common/common.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,5 @@
1919

2020
#pragma once
2121

22-
#include <experimental/common/base_types.hpp>
22+
#include <common/common.hpp>
23+
#include <experimental/common/base_types.hpp>

include/group/gemm/impl/default_xmx_xe.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,8 @@ class gemm_t<
125125

126126
/******** set tile **********/
127127
static constexpr reg_layout reg_layout_a = reg_layout::tiled;
128+
129+
public:
128130
using matA_tile_desc_t = subgroup::tile_desc_t<
129131
tile_size_x_a,
130132
tile_size_y_a,
@@ -165,7 +167,6 @@ class gemm_t<
165167
wg_size_y,
166168
arch_tag>;
167169

168-
public:
169170
using matAcc_tile_desc_t = subgroup::tile_desc_t<
170171
tile_size_x_c,
171172
tile_size_y_c,

include/group/gemm/impl/unaligned_xmx_xe.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,8 @@ class gemm_t<
129129

130130
/******** set tile **********/
131131
static constexpr reg_layout reg_layout_a = reg_layout::tiled;
132+
133+
public:
132134
using matA_tile_desc_t = subgroup::tile_desc_t<
133135
tile_size_x_a,
134136
tile_size_y_a,
@@ -214,7 +216,6 @@ class gemm_t<
214216
wg_size_y,
215217
arch_tag>;
216218

217-
public:
218219
using matAcc_tile_desc_t = subgroup::tile_desc_t<
219220
tile_size_x_c,
220221
tile_size_y_c,

include/subgroup/tile/impl/load_xe.hpp

Lines changed: 50 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ template <
7979
typename payload_t>
8080
__XETLA_API typename std::enable_if_t<
8181
detail::check_load_type<tile_t, payload_t>::is_global_block_2d &&
82-
payload_t::arch_tag == gpu_arch::XeHpc>
82+
arch_has_2d_load_store(payload_t::arch_tag)>
8383
tile_load(tile_t& tile, payload_t& payload) {
8484
using dtype = typename tile_t::dtype;
8585
using load_dtype = typename payload_t::mem_dtype;
@@ -405,23 +405,37 @@ tile_load(tile_t& tile, payload_t& payload) {
405405

406406
static constexpr uint32_t tile_size_x = tile_t::tile_size_x;
407407
static constexpr uint32_t scale_factor = payload_t::scale_factor;
408-
constexpr uint32_t load_len = tile_size_x / scale_factor;
408+
static constexpr uint32_t load_len = tile_size_x / scale_factor;
409+
static constexpr gpu_arch arch_tag = payload_t::arch_tag;
410+
using load_store_attr = load_store_attr_t<msg_type::block_1d, arch_tag>;
411+
static constexpr uint32_t max_load_vec_len =
412+
load_store_attr::max_load_vec_len;
409413

410-
if constexpr (load_len >= 64) {
414+
static constexpr uint32_t load_iter_steps = load_len / max_load_vec_len;
415+
if constexpr (load_len >= max_load_vec_len) {
411416
#pragma unroll
412-
for (uint32_t i = 0; i < load_len / 64; i++) {
413-
uint32_t offset_x = i * 64 * scale_factor;
414-
auto reg_sub = tile.reg.xetla_select<64 * scale_factor, 1>(offset_x);
417+
for (uint32_t i = 0; i < load_iter_steps; i++) {
418+
uint32_t offset_x = i * max_load_vec_len * scale_factor;
419+
auto reg_sub =
420+
tile.reg.xetla_select<max_load_vec_len * scale_factor, 1>(offset_x);
415421
uint32_t address_offset = offset_x * sizeof(dtype);
416-
reg_sub.xetla_format<load_dtype>() =
417-
xetla_load_global<load_dtype, 64, data_size::default_size, L1, L2>(
418-
payload.base_ptr, payload.base_offset + address_offset);
422+
reg_sub.xetla_format<load_dtype>() = xetla_load_global<
423+
load_dtype,
424+
max_load_vec_len,
425+
data_size::default_size,
426+
L1,
427+
L2>(payload.base_ptr, payload.base_offset + address_offset);
419428
}
420429
}
421-
constexpr uint32_t tail_len = load_len % 64;
422-
uint32_t tail_offset = load_len / 64 * 64 * scale_factor;
423-
detail::process_1d_tail<tail_len, 32, detail::process_flag::load, L1, L2>(
424-
tile, payload, tail_offset);
430+
431+
constexpr uint32_t tail_len = load_len % max_load_vec_len;
432+
uint32_t tail_offset = load_iter_steps * max_load_vec_len * scale_factor;
433+
detail::process_1d_tail<
434+
tail_len,
435+
(max_load_vec_len >> 1),
436+
detail::process_flag::load,
437+
L1,
438+
L2>(tile, payload, tail_offset);
425439
}
426440

427441
/// @brief This function loads data from unaligned-2D memory surface.
@@ -850,21 +864,33 @@ tile_load(tile_t& tile, payload_t& payload) {
850864
using load_dtype = typename payload_t::mem_dtype;
851865

852866
constexpr uint32_t scale_factor = payload_t::scale_factor;
853-
constexpr uint32_t load_len = tile_desc::tile_size_x / scale_factor;
854-
if constexpr (load_len >= 64) {
867+
static constexpr uint32_t load_len = tile_desc::tile_size_x / scale_factor;
868+
static constexpr gpu_arch arch_tag = payload_t::arch_tag;
869+
using load_store_attr = load_store_attr_t<msg_type::block_1d, arch_tag>;
870+
static constexpr uint32_t max_load_vec_len =
871+
load_store_attr::max_load_vec_len;
872+
873+
static constexpr uint32_t load_iter_steps = load_len / max_load_vec_len;
874+
875+
if constexpr (load_len >= max_load_vec_len) {
855876
#pragma unroll
856-
for (uint32_t j = 0; j < load_len / 64; j++) {
857-
uint32_t offset_x = j * 64 * scale_factor;
858-
auto reg_sub = tile.reg.xetla_select<64 * scale_factor, 1>(offset_x);
877+
for (uint32_t j = 0; j < load_iter_steps; j++) {
878+
uint32_t offset_x = j * max_load_vec_len * scale_factor;
879+
auto reg_sub =
880+
tile.reg.xetla_select<max_load_vec_len * scale_factor, 1>(offset_x);
859881
uint32_t address_offset = offset_x * sizeof(dtype);
860-
reg_sub.xetla_format<load_dtype>() =
861-
xetla_load_local<load_dtype, 64, data_size::default_size>(
862-
payload.address + address_offset);
882+
reg_sub.xetla_format<load_dtype>() = xetla_load_local<
883+
load_dtype,
884+
max_load_vec_len,
885+
data_size::default_size>(payload.address + address_offset);
863886
}
864887
}
865-
detail::
866-
process_1d_tail<load_len % 64, 32, detail::process_flag::load, L1, L2>(
867-
tile, payload, load_len / 64 * 64 * scale_factor);
888+
detail::process_1d_tail<
889+
load_len % max_load_vec_len,
890+
(max_load_vec_len >> 1),
891+
detail::process_flag::load,
892+
L1,
893+
L2>(tile, payload, load_iter_steps * max_load_vec_len * scale_factor);
868894
}
869895

870896
} // namespace gpu::xetla::subgroup

include/subgroup/tile/impl/mma_xe.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ struct tile_mma_t<
3838
matA_t_,
3939
mma_engine::xmx,
4040
arch_tag_,
41-
std::enable_if_t<(arch_tag_ <= gpu_arch::XeHpc)>> {
41+
std::enable_if_t<arch_has_xmx(arch_tag_)>> {
4242
using matA_t = matA_t_;
4343
using matB_t = matB_t_;
4444
using matSrc_t = matAcc_src_t_;

0 commit comments

Comments
 (0)