Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit c4e8540

Browse files
sunjiweiswift, DDEle, airMeng
authored
[GPU]Xetla support MTL (#176)
* bugfix
* add AOT
* Update examples/CMakeLists.txt
  Co-authored-by: Yi DING <[email protected]>
* XETLA_PRINTF replace cout
* SLMSIZE 128KB 64KB
* Update include/subgroup/tile/impl/load_xe.hpp
  Co-authored-by: Meng, Hengyu <[email protected]>
* use arch_attr_t
* add more shape for int4
* bugfix dump_mat
* save

---------

Co-authored-by: Yi DING <[email protected]>
Co-authored-by: Meng, Hengyu <[email protected]>
1 parent b3d8f65 commit c4e8540

File tree

25 files changed

+889
-121
lines changed

25 files changed

+889
-121
lines changed

CMakeLists.txt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,16 @@ if (${LOG} STREQUAL "on")
4343
add_definitions(-DLOG_PRINT)
4444
endif ()
4545

46+
# AOT device
47+
set(AOT_DEVICE "" CACHE STRING "Set device list for AOT build")
48+
4649
add_compile_options(-fsycl)
4750
add_link_options(-fsycl)
4851
if(UNIX)
52+
if (AOT_DEVICE)
53+
add_compile_options(-fsycl-targets=spir64_gen)
54+
add_link_options(-fsycl-targets=spir64_gen -Xs "-device ${AOT_DEVICE}") # MTL
55+
endif()
4956
add_compile_options(-fp-model=precise -Wall -Wextra -Werror)
5057
add_link_options(-lmkl_intel_lp64 -lmkl_sequential -lmkl_core -lpthread -lm)
5158
link_libraries(-lgtest -lgtest_main)

examples/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ include_directories(${CMAKE_SOURCE_DIR}/include)
22
include_directories(${CMAKE_SOURCE_DIR})
33

44
# Creates a separate device code module for each SYCL* kernel
5-
# so that kernel for Dg2 and Xe will be JIT separately
5+
# so that kernels for XeHpc, XeHpg, and XeLpg will be JIT-compiled separately
66
add_compile_options(-fsycl-device-code-split=per_kernel)
77
add_link_options(-fsycl-device-code-split=per_kernel)
88

include/common/core/arch_config.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -108,10 +108,10 @@ struct register_attr_t {};
108108
template <grf_mode grf_num_mode, gpu_arch arch_tag>
109109
struct client_register_attr_base_t {
110110
static constexpr uint32_t acc_reg_in_bytes =
111-
(grf_num_mode == grf_mode::normal) ? 4 * 32 : 8 * 32;
111+
(grf_num_mode == grf_mode::normal) ? 4 * 64 : 8 * 64;
112112
static constexpr uint32_t grf_in_bytes =
113-
(grf_num_mode == grf_mode::normal) ? 128 * 32 : 256 * 32;
114-
static constexpr uint32_t reg_in_bytes = 32;
113+
(grf_num_mode == grf_mode::normal) ? 128 * 64 : 256 * 64;
114+
static constexpr uint32_t reg_in_bytes = 64;
115115
};
116116

117117
template <grf_mode grf_num_mode>
@@ -139,7 +139,7 @@ struct client_arch_attr_base_t {
139139
template <msg_type message_type = msg_type::block_2d>
140140
using load_store_attr = load_store_attr_t<message_type, gpu_arch::XeHpg>;
141141

142-
template <grf_mode grf_num_mode = grf_mode::normal>
142+
template <grf_mode grf_num_mode = grf_mode::double_grf>
143143
using register_attr = register_attr_t<grf_num_mode, gpu_arch::XeHpg>;
144144

145145
using mma_attr = mma_attr_t<gpu_arch::XeHpg>;

include/experimental/group/gemm/compute_policy.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,9 @@ struct compute_policy_int4_dequantize<
8282

8383
static constexpr bool is_int4_matB_policy = true;
8484

85-
static constexpr uint32_t block_size_x_b =
86-
arch_attr_t<arch_tag>::mma_attr::mma_n_in_elem;
85+
static constexpr uint32_t block_size_x_b = (mma_engine == mma_engine::xmx)
86+
? arch_attr_t<arch_tag>::mma_attr::mma_n_in_elem
87+
: 32;
8788
static constexpr uint32_t block_bytes_y_b = 32;
8889
static_assert(
8990
block_bytes_x_a == block_bytes_y_b,

include/experimental/group/gemm/impl/int4_dequantize_xe.hpp

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -560,20 +560,6 @@ class gemm_t<
560560
}
561561

562562
private:
563-
template <typename T>
564-
void dump_mat(T mat, size_t tile_x, size_t tile_y) {
565-
#pragma unroll
566-
for (size_t row = 0; row < tile_x; row++) {
567-
#pragma unroll
568-
for (size_t col = 0; col < tile_y; col++) {
569-
sycl::ext::oneapi::experimental::printf(
570-
"%0.1f ", (float)(sycl::half)mat.reg[row * tile_y + col]);
571-
}
572-
sycl::ext::oneapi::experimental::printf("\n ");
573-
}
574-
sycl::ext::oneapi::experimental::printf("\n ");
575-
}
576-
577563
inline void dequantize(
578564
matB_acc_t& matB_acc,
579565
matB_t& matB,

include/experimental/kernel/gemm/impl/int4_dequantize_kslicing_xe.hpp

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -451,9 +451,11 @@ class gemm_universal_t<
451451
static cl::sycl::range<3> get_local_range() {
452452
uint32_t local_range_m = (wg_tile_m + sg_tile_m - 1) / sg_tile_m;
453453
uint32_t local_range_n = (wg_tile_n + sg_tile_n - 1) / sg_tile_n;
454-
// std::cout << "Local range: {" << num_local_kslicing << ", " <<
455-
// local_range_m
456-
// << ", " << local_range_n << "} \n";
454+
XETLA_PRINTF(
455+
"Local range: {%d, %d, %d}",
456+
num_local_kslicing,
457+
local_range_m,
458+
local_range_n);
457459
assert(local_range_m * local_range_n * num_local_kslicing <= 32);
458460
return cl::sycl::range<3>{num_local_kslicing, local_range_m, local_range_n};
459461
};
@@ -471,8 +473,11 @@ class gemm_universal_t<
471473
uint32_t group_range_m = (matrix_m + wg_tile_m - 1) / wg_tile_m;
472474
uint32_t group_range_n = (matrix_n + wg_tile_n - 1) / wg_tile_n;
473475
group_swizzle_t::update_group_range(group_range_m, group_range_n);
474-
// std::cout << "Group range: {" << num_global_kslicing << ", "
475-
// << group_range_m << ", " << group_range_n << "} \n";
476+
XETLA_PRINTF(
477+
"Group range: {%d, %d, %d}",
478+
num_global_kslicing,
479+
group_range_m,
480+
group_range_n);
476481
return cl::sycl::range<3>{
477482
num_global_kslicing, group_range_m, group_range_n};
478483
};

include/subgroup/tile/impl/load_xe.hpp

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ namespace gpu::xetla::subgroup {
2828
namespace detail {
2929
template <typename tile_t, typename payload_t>
3030
struct check_load_type {
31+
static constexpr bool is_lsc_gather = true;
3132
static constexpr bool is_global_block_2d =
3233
(payload_t::memory_space == mem_space::global &&
3334
(payload_t::message_type == msg_type::block_2d) &&
@@ -444,6 +445,7 @@ template <
444445
typename payload_t>
445446
__XETLA_API typename std::enable_if_t<
446447
detail::check_load_type<tile_t, payload_t>::is_global_block_2d &&
448+
detail::check_load_type<tile_t, payload_t>::is_lsc_gather &&
447449
payload_t::arch_tag <= gpu_arch::XeHpg>
448450
tile_load(tile_t& tile, payload_t& payload) {
449451
using dtype = typename payload_t::dtype;
@@ -531,6 +533,77 @@ tile_load(tile_t& tile, payload_t& payload) {
531533
}
532534
}
533535

536+
/// @brief This function loads data from unaligned-2D memory surface.
537+
/// Loads an array of rectangular regions (X,Y)..(X+W,Y+H) from memory into
538+
/// registers. Each block will be loaded serially by its corresponding payload.
539+
/// @tparam tile_t Is the tile_t struct that contains registers.
540+
/// These registers will be the destination of load operation.
541+
/// @tparam payload_t Is the mem_payload_t struct describing the memory
542+
/// information. Payload indicates the source of load operation.
543+
/// @tparam L1 Is the cache hint for L1 cache.
544+
/// @tparam L3 Is the cache hint for L3 cache.
545+
/// @param tile Is the tile object with type tile_t, holds the return data of
546+
/// the loads.
547+
/// @param payload Is the payload object with type payload_t. Contains all the
548+
/// information for loads.
549+
/// @return No return, update in place.
550+
/// @note This overload fills the tile one block row at a time, issuing one
/// xetla_load_global call per row of each block.
template <
551+
cache_hint L1 = cache_hint::cached,
552+
cache_hint L3 = cache_hint::cached,
553+
typename tile_t,
554+
typename payload_t>
555+
// Overload is enabled only when the payload is classified as
// is_global_block_2d, is_lsc_gather is false, and
// arch_has_2d_load_store(payload_t::arch_tag) is false.
// NOTE(review): check_load_type::is_lsc_gather is declared as a constant
// `true` in this same change, so this condition looks like it can never
// hold -- confirm whether that is intentional.
__XETLA_API typename std::enable_if_t<
556+
detail::check_load_type<tile_t, payload_t>::is_global_block_2d &&
557+
!detail::check_load_type<tile_t, payload_t>::is_lsc_gather &&
558+
!arch_has_2d_load_store(payload_t::arch_tag)>
559+
tile_load(tile_t& tile, payload_t& payload) {
560+
using dtype = typename payload_t::dtype;
561+
using tile_desc = typename payload_t::tile_desc;
562+
using load_dtype = typename payload_t::mem_dtype;
563+
constexpr uint32_t load_elems = payload_t::simd_exec_size;
564+
constexpr uint32_t pack_factor = payload_t::pack_factor;
565+
566+
// Walk the tile block by block: i indexes blocks along y, j along x.
#pragma unroll
567+
for (uint32_t i = 0; i < tile_desc::num_block_y; i++) {
568+
uint32_t offset_y = i * tile_desc::block_size_y;
569+
#pragma unroll
570+
for (uint32_t j = 0; j < tile_desc::num_block_x; j++) {
571+
uint32_t offset_x = j * tile_desc::block_size_x;
572+
// Sub-range of the tile registers that holds block (i, j).
auto reg_sub = tile.reg.xetla_select<tile_desc::block_elems, 1>(
573+
(i * tile_desc::num_block_x + j) * tile_desc::block_elems);
574+
// One load per row of the block.
#pragma unroll
575+
for (uint32_t sub_block_y = 0; sub_block_y < tile_desc::block_size_y;
576+
sub_block_y += 1) {
577+
xetla_vector<load_dtype, load_elems> reg_tmp = 0;
578+
// Byte offset of this row relative to the payload base. When
// payload_t::trans is set, the x offset is scaled by the pitch and the
// y offset by the element size; otherwise the roles are swapped.
uint32_t address_offset = payload_t::trans
579+
? offset_x * payload.pitch_in_bytes +
580+
(offset_y + sub_block_y) * sizeof(dtype)
581+
: offset_x * sizeof(dtype) +
582+
(offset_y + sub_block_y) * payload.pitch_in_bytes;
583+
reg_tmp = xetla_load_global<
584+
load_dtype,
585+
payload_t::simd_exec_size,
586+
data_size::default_size,
587+
L1,
588+
L3>(payload.base_ptr, payload.base_offset + address_offset);
589+
590+
// Write the loaded vector into this row's slot of the block, viewing the
// destination elements as load_dtype.
reg_sub
591+
.xetla_select<load_elems * pack_factor, 1>(
592+
sub_block_y * tile_desc::block_size_x)
593+
.xetla_format<load_dtype>() = reg_tmp;
594+
}
595+
}
596+
}
597+
598+
// Optional post-processing steps, each preceded by SW_BARRIER().
if constexpr (payload_t::trans) {
599+
SW_BARRIER();
600+
tile_transpose(tile);
601+
}
602+
if constexpr (payload_t::mem_transform) {
603+
SW_BARRIER();
604+
vnni_convert(tile);
605+
}
606+
}
534607
/// @brief This function loads data from unaligned-2D memory surface.
535608
/// Loads an array of rectangular regions (X,Y)..(X+W,Y+H) from memory into
536609
/// registers. Each block will be loaded serially by its corresponding payload.

include/subgroup/tile/impl/op_function.hpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -676,4 +676,35 @@ layout_convert(T_dst& dst, T_src& src) {
676676
}
677677
}
678678
}
679+
680+
/// @brief Debug helper that prints the elements of mat.reg as integers,
/// one row per output line.
/// Each element is cast to sycl::half and then to int before being printed
/// with "%d", so values are shown without a fractional part.
/// @tparam T Type of the object to dump. It must expose a `reg` member that
/// can be indexed with operator[]; the default arguments additionally require
/// static members T::tile_size_x and T::tile_size_y.
/// @param mat Object whose `reg` contents are printed (taken by value).
/// @param tile_x Number of elements per row; also used as the row stride when
/// indexing mat.reg.
/// @param tile_y Number of rows to print.
template <typename T>
681+
void dump_mat(
682+
T mat,
683+
size_t tile_x = T::tile_size_x,
684+
size_t tile_y = T::tile_size_y) {
685+
#pragma unroll
686+
for (size_t row = 0; row < tile_y; row++) {
687+
#pragma unroll
688+
for (size_t col = 0; col < tile_x; col++) {
689+
// Element (row, col) is read at linear index row * tile_x + col.
sycl::ext::oneapi::experimental::printf(
690+
"%d ", (int)(sycl::half)mat.reg[row * tile_x + col]);
691+
}
692+
// End of one row.
sycl::ext::oneapi::experimental::printf("\n ");
693+
}
694+
// Extra line break after the whole dump.
sycl::ext::oneapi::experimental::printf("\n ");
695+
}
696+
/// @brief Debug helper that prints the elements of an indexable object as
/// integers, one row per output line.
/// Unlike dump_mat, this indexes `mat` directly rather than a `reg` member,
/// and the dimensions have no default values.
/// Each element is cast to sycl::half and then to int before being printed
/// with "%d", so values are shown without a fractional part.
/// @tparam T Type of the object to dump; must support operator[].
/// @param mat Object whose elements are printed (taken by value).
/// @param tile_x Number of elements per row; also used as the row stride when
/// indexing mat.
/// @param tile_y Number of rows to print.
template <typename T>
697+
void dump_mat_reg(T mat, size_t tile_x, size_t tile_y) {
698+
#pragma unroll
699+
for (size_t row = 0; row < tile_y; row++) {
700+
#pragma unroll
701+
for (size_t col = 0; col < tile_x; col++) {
702+
// Element (row, col) is read at linear index row * tile_x + col.
sycl::ext::oneapi::experimental::printf(
703+
"%d ", (int)(sycl::half)mat[row * tile_x + col]);
704+
}
705+
// End of one row.
sycl::ext::oneapi::experimental::printf("\n ");
706+
}
707+
// Extra line break after the whole dump.
sycl::ext::oneapi::experimental::printf("\n ");
708+
}
709+
679710
} // namespace gpu::xetla::subgroup

include/subgroup/tile/impl/payload_xe.hpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -404,14 +404,15 @@ struct mem_payload_t<
404404
// for pvc, we can use simd16 or simd32
405405
static constexpr uint32_t min_store_bytes = 16 * sizeof(dtype);
406406
static constexpr uint32_t max_store_bytes = 32 * sizeof(dtype);
407-
static constexpr uint32_t num_channel =
407+
static constexpr uint32_t simd_channel =
408408
((tile_bytes % max_store_bytes) == 0 &&
409409
(block_bytes % max_store_bytes) == 0)
410410
? 32
411411
: 16;
412-
413-
static constexpr uint32_t num_channel_x = block_size_x;
414-
static constexpr uint32_t num_channel_y = num_channel / num_channel_x;
412+
static constexpr uint32_t num_channel =
413+
(simd_channel >= block_size_x) ? block_size_x : simd_channel;
414+
static constexpr uint32_t num_channel_x = block_size_x; // 16
415+
static constexpr uint32_t num_channel_y = num_channel / num_channel_x; // 1
415416
static constexpr uint32_t store_elems = num_channel_y * block_size_x;
416417

417418
xetla_vector<uint32_t, num_channel> channel_offset;

tests/integration/data_transformer/common.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ class TestBase {
122122
using data_type_in = float;
123123
using data_type_out = bf16;
124124
using data_type_acc = float;
125+
static constexpr gpu_arch gpu_arch = gpu_arch::XeHpc;
125126
};
126127

127128
class Test_fp32tobf16_128_64 : public TestBase {

0 commit comments

Comments
 (0)