Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit e5510c6

Browse files
sunjiweiswift and DDEle authored
Xetla Bug fix (#256)
1. Restore first token configuration.
2. Bug fixes for arch_config.
3. Bug fixes for fmha.
4. Synchronized part of the code with innersource.
5. cmake compilation parameters are the same as ipex.
6. FP16 UT bugfix: dtype_mma_a and dtype_mma_b should be fp16.
7. Updated policy for int4 and default FPU.
8. FP16 gemm MatB col_major bugfix.
---------
Co-authored-by: Ding, Yi1 <[email protected]>
1 parent b13e02f commit e5510c6

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

54 files changed

+2147
-1257
lines changed

CMakeLists.txt

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ if (NOT CMAKE_BUILD_TYPE)
88
set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo")
99
endif()
1010
if(UNIX)
11+
set(CMAKE_C_COMPILER icx)
12+
set(CMAKE_CXX_COMPILER icpx)
1113
else() # Windows
1214
# Force CMake to use icx-cl rather than the default C++ compiler/linker
1315
# (needed on Windows only)
@@ -24,7 +26,7 @@ include(CTest)
2426
enable_testing()
2527

2628
if(UNIX)
27-
list(APPEND CMAKE_MODULE_PATH "${CMAKE_SOURCE_DIR}/tools/cmake")
29+
list(APPEND CMAKE_MODULE_PATH "${CMAKE_SOURCE_DIR}/tools/cmake")
2830
endif()
2931
find_package(MKL CONFIG REQUIRED)
3032
message(STATUS "MKL_VERSION=${MKL_VERSION}")
@@ -33,7 +35,7 @@ message(STATUS "MKL_IMPORTED_TARGETS=${MKL_IMPORTED_TARGETS}")
3335
# debug option
3436
message(STATUS "'DEBUG' is set to " ${DEBUG})
3537
if (${DEBUG})
36-
add_compile_options(-debug=minimal -Rno-debug-disables-optimization -DDEBUG=${DEBUG})
38+
add_compile_options(-debug=minimal -Rno-debug-disables-optimization -DDEBUG=${DEBUG})
3739
endif ()
3840

3941
# log message print
@@ -43,20 +45,41 @@ if (${LOG} STREQUAL "on")
4345
add_definitions(-DLOG_PRINT)
4446
endif ()
4547

48+
# For large registers mode, enable 256 registers for kernels
49+
set(XETLA_OFFLINE_OPTIONS "-doubleGRF")
50+
set(XETLA_OFFLINE_OPTIONS "${XETLA_OFFLINE_OPTIONS} -vc-disable-indvars-opt")
51+
set(XETLA_OFFLINE_OPTIONS "${XETLA_OFFLINE_OPTIONS} -vc-codegen")
52+
# Enable bank conflict reduction.
53+
set(XETLA_OFFLINE_OPTIONS "${XETLA_OFFLINE_OPTIONS} -Xfinalizer -enableBCR")
54+
# Optimization to reduce the tokens used for DPAS instruction.
55+
set(XETLA_OFFLINE_OPTIONS "${XETLA_OFFLINE_OPTIONS} -Xfinalizer -DPASTokenReduction")
56+
4657
# AOT device
47-
set(AOT_DEVICE "" CACHE STRING "Set device list for AOT build")
58+
set(USE_AOT_DEVLIST "" CACHE STRING "Set device list for AOT build")
59+
if (USE_AOT_DEVLIST)
60+
add_compile_options(-fsycl-targets=spir64_gen)
61+
add_link_options(-fsycl-targets=spir64_gen)
62+
# For registers usage verbose at AOT
63+
set(XETLA_OFFLINE_OPTIONS "${XETLA_OFFLINE_OPTIONS} -Xfinalizer -printregusage")
64+
set(XETLA_KERNEL_FLAGS ${XETLA_KERNEL_FLAGS} -Xs "-options '${XETLA_OFFLINE_OPTIONS}' -device '${USE_AOT_DEVLIST}'")
65+
else()
66+
set(XETLA_KERNEL_FLAGS ${XETLA_KERNEL_FLAGS} -Xs "${XETLA_OFFLINE_OPTIONS}")
67+
endif()
68+
69+
add_compile_options(-fsycl -fsycl-device-code-split=per_kernel)
70+
add_compile_options(-Wall -Wextra -Werror)
71+
72+
include(ProcessorCount)
73+
ProcessorCount(nproc)
74+
add_link_options(-fsycl -fsycl-device-code-split=per_kernel -fsycl-max-parallel-link-jobs=${nproc})
75+
add_link_options(${XETLA_KERNEL_FLAGS})
4876

49-
add_compile_options(-fsycl)
50-
add_link_options(-fsycl)
5177
if(UNIX)
52-
if (AOT_DEVICE)
53-
add_compile_options(-fsycl-targets=spir64_gen)
54-
add_link_options(-fsycl-targets=spir64_gen -Xs "-device ${AOT_DEVICE}") # MTL
55-
endif()
56-
add_compile_options(-fp-model=precise -Wall -Wextra -Werror)
78+
add_compile_options(-fp-model=precise)
5779
add_link_options(-lmkl_intel_lp64 -lmkl_sequential -lmkl_core -lpthread -lm)
5880
link_libraries(-lgtest -lgtest_main)
5981
else() # Windows
82+
add_compile_options(/fp:precise)
6083
add_compile_options(/EHsc)
6184
if (CMAKE_BUILD_TYPE STREQUAL "Debug")
6285
add_compile_options(/MDd)
Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,6 @@
11
set(TARGET stream_k_gemm)
22

3-
set(XETLA_KERNEL_FLAGS ${XETLA_KERNEL_FLAGS} -fsycl)
4-
set(XETLA_KERNEL_FLAGS ${XETLA_KERNEL_FLAGS} -fsycl-targets=spir64_gen)
5-
6-
# disable loop invariance optimization, this is for performance
7-
set(XETLA_OFFLINE_OPTIONS "${XETLA_OFFLINE_OPTIONS} -vc-disable-indvars-opt")
8-
# For large registers mode, enable 256 registers for kernels
9-
set(XETLA_OFFLINE_OPTIONS "${XETLA_OFFLINE_OPTIONS} -doubleGRF")
10-
# For registers usage verbose at AOT
11-
set(XETLA_OFFLINE_OPTIONS "${XETLA_OFFLINE_OPTIONS} -Xfinalizer -printregusage")
12-
# Enable bank conflict reduction.
13-
set(XETLA_OFFLINE_OPTIONS "${XETLA_OFFLINE_OPTIONS} -Xfinalizer -enableBCR")
14-
# Optimization to reduce the tokens used for DPAS instruction.
15-
set(XETLA_OFFLINE_OPTIONS "${XETLA_OFFLINE_OPTIONS} -Xfinalizer -DPASTokenReduction")
16-
17-
set(XETLA_KERNEL_FLAGS ${XETLA_KERNEL_FLAGS} -Xs)
18-
set(XETLA_KERNEL_FLAGS ${XETLA_KERNEL_FLAGS} "-device pvc -options '${XETLA_OFFLINE_OPTIONS} ' ")
19-
203
#build test
214
add_executable(${TARGET} stream_k_gemm.cpp)
22-
target_link_options(${TARGET} PRIVATE ${XETLA_KERNEL_FLAGS})
235
# Disable vector combine, to remove redundant loads and stores
24-
#target_compile_options(${TARGET} PRIVATE -mllvm -disable-vector-combine -fsycl -fsycl-targets=spir64_gen)
25-
6+
# target_compile_options(${TARGET} PRIVATE -mllvm -disable-vector-combine -fsycl -fsycl-targets=spir64_gen)

include/common/core/arch_config.hpp

Lines changed: 35 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ struct load_store_attr_t<msg_type::block_2d, gpu_arch::XeLpg>
8989

9090
template <gpu_arch arch_tag>
9191
inline constexpr bool arch_has_2d_load_store =
92-
load_store_attr_t<msg_type::block_2d, arch_tag>::has_hw_block_2d;
92+
load_store_attr_t<msg_type::block_2d, arch_tag>::has_hw_block_2d;
9393

9494
template <gpu_arch arch_tag>
9595
struct load_store_attr_t<msg_type::block_1d, arch_tag> {
@@ -149,9 +149,19 @@ struct register_nums_t {
149149
};
150150

151151
template <gpu_arch arch_tag>
152-
struct register_bytes_t {
152+
struct register_bytes_t;
153+
template <>
154+
struct register_bytes_t<gpu_arch::XeHpc> {
153155
static constexpr uint32_t reg_in_bytes = 64;
154156
};
157+
template <>
158+
struct register_bytes_t<gpu_arch::XeHpg> {
159+
static constexpr uint32_t reg_in_bytes = 32;
160+
};
161+
template <>
162+
struct register_bytes_t<gpu_arch::XeLpg> {
163+
static constexpr uint32_t reg_in_bytes = 32;
164+
};
155165

156166
template <grf_mode grf_num_mode, gpu_arch arch_tag>
157167
struct register_attr_t {
@@ -188,41 +198,47 @@ struct mma_attr_t<arch_tag, m, std::enable_if_t<!arch_has_xmx<arch_tag>>> {
188198
template <gpu_arch arch_tag>
189199
struct arch_attr_t {};
190200

191-
template <gpu_arch arch_tag>
192-
struct client_arch_attr_base_t {
201+
template <>
202+
struct arch_attr_t<gpu_arch::XeHpc> {
193203
template <msg_type message_type = msg_type::block_2d>
194-
using load_store_attr = load_store_attr_t<message_type, arch_tag>;
204+
using load_store_attr = load_store_attr_t<message_type, gpu_arch::XeHpc>;
195205

196-
template <grf_mode grf_num_mode = grf_mode::normal>
197-
using register_attr = register_attr_t<grf_num_mode, arch_tag>;
206+
template <grf_mode grf_num_mode = grf_mode::double_grf>
207+
using register_attr = register_attr_t<grf_num_mode, gpu_arch::XeHpc>;
198208

199-
using dpas_attr = dpas_attr_t<arch_tag>;
209+
using dpas_attr = dpas_attr_t<gpu_arch::XeHpc>;
200210

201211
static constexpr uint32_t max_wg_num = 64;
202-
static constexpr uint32_t local_mem_size = 64 * 1024;
212+
static constexpr uint32_t local_mem_size = 128 * 1024;
203213
};
204214

205215
template <>
206-
struct arch_attr_t<gpu_arch::XeHpc> {
216+
struct arch_attr_t<gpu_arch::XeHpg> {
207217
template <msg_type message_type = msg_type::block_2d>
208-
using load_store_attr = load_store_attr_t<message_type, gpu_arch::XeHpc>;
218+
using load_store_attr = load_store_attr_t<message_type, gpu_arch::XeHpg>;
209219

210220
template <grf_mode grf_num_mode = grf_mode::double_grf>
211-
using register_attr = register_attr_t<grf_num_mode, gpu_arch::XeHpc>;
221+
using register_attr = register_attr_t<grf_num_mode, gpu_arch::XeHpg>;
212222

213-
using dpas_attr = dpas_attr_t<gpu_arch::XeHpc>;
223+
using dpas_attr = dpas_attr_t<gpu_arch::XeHpg>;
214224

215225
static constexpr uint32_t max_wg_num = 64;
216-
static constexpr uint32_t local_mem_size = 128 * 1024;
226+
static constexpr uint32_t local_mem_size = 64 * 1024;
217227
};
218228

219229
template <>
220-
struct arch_attr_t<gpu_arch::XeHpg>
221-
: public client_arch_attr_base_t<gpu_arch::XeHpg> {};
230+
struct arch_attr_t<gpu_arch::XeLpg> {
231+
template <msg_type message_type = msg_type::block_2d>
232+
using load_store_attr = load_store_attr_t<message_type, gpu_arch::XeLpg>;
222233

223-
template <>
224-
struct arch_attr_t<gpu_arch::XeLpg>
225-
: public client_arch_attr_base_t<gpu_arch::XeLpg> {};
234+
template <grf_mode grf_num_mode = grf_mode::double_grf>
235+
using register_attr = register_attr_t<grf_num_mode, gpu_arch::XeLpg>;
236+
237+
using dpas_attr = dpas_attr_t<gpu_arch::XeLpg>;
238+
239+
static constexpr uint32_t max_wg_num = 64;
240+
static constexpr uint32_t local_mem_size = 64 * 1024;
241+
};
226242

227243
/// @} xetla_core_arch_config
228244

include/common/core/base_types.hpp

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,28 @@ using fp16 = sycl::half;
5555
///
5656
using tf32 = sycl::ext::intel::experimental::esimd::tfloat32;
5757

58+
/// @brief mx_fp4(E2M1) data packed as 8bits data type.
59+
struct mx_fp4 {
60+
uint8_t data;
61+
operator uint8_t() const {
62+
return data;
63+
}
64+
mx_fp4() = default;
65+
mx_fp4(uint8_t val) {
66+
data = val;
67+
}
68+
};
69+
70+
template <typename T>
71+
struct get_packed_num {
72+
static constexpr uint32_t value = 1;
73+
};
74+
75+
template <>
76+
struct get_packed_num<mx_fp4> {
77+
static constexpr uint32_t value = 2;
78+
};
79+
5880
template <typename T, typename = void>
5981
struct is_host_callable : std::false_type {};
6082
template <typename T>
@@ -66,7 +88,8 @@ struct is_host_callable<T, std::enable_if_t<T::host_callable == true>>
6688
template <typename T>
6789
struct is_internal_type {
6890
static constexpr bool value = std::is_same<remove_const_t<T>, bf16>::value ||
69-
std::is_same<remove_const_t<T>, tf32>::value;
91+
std::is_same<remove_const_t<T>, tf32>::value ||
92+
std::is_same<remove_const_t<T>, mx_fp4>::value;
7093
};
7194
template <typename T>
7295
inline constexpr bool is_internal_type_v = is_internal_type<T>::value;
@@ -108,6 +131,12 @@ struct native_type {
108131
using type = T;
109132
};
110133

134+
/// @brief Set uint8_t as the native data type of mx_fp4.
135+
template <>
136+
struct native_type<mx_fp4> {
137+
using type = uint8_t;
138+
};
139+
111140
/// @brief Return the native data type of T
112141
template <typename T>
113142
using native_type_t = typename native_type<T>::type;

include/common/core/common.hpp

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,15 +82,44 @@ enum class msg_type : uint8_t {
8282
// prefetch_1d = 5
8383
};
8484

85-
/// L1 or L2 cache hint kinds.
85+
/// L1, L2 or L3 cache hints.
8686
enum class cache_hint : uint8_t {
8787
none = 0,
88+
/// load/store/atomic: do not cache data to cache;
8889
uncached = 1,
90+
91+
// load: cache data to cache;
8992
cached = 2,
93+
94+
/// store: write data into cache level and mark the cache line as "dirty".
95+
/// Upon eviction, the "dirty" data will be written into the furthest
96+
/// subsequent cache;
9097
write_back = 3,
98+
99+
/// store: immediately write data to the subsequent furthest cache, marking
100+
/// the cache line in the current cache as "not dirty";
91101
write_through = 4,
102+
103+
/// load: cache data to cache using the evict-first policy to minimize cache
104+
/// pollution caused by temporary streaming data that may only be accessed
105+
/// once or twice;
106+
/// store/atomic: same as write-through, but use the evict-first policy
107+
/// to limit cache pollution by streaming;
92108
streaming = 5,
93-
read_invalidate = 6
109+
110+
/// load: asserts that the cache line containing the data will not be read
111+
/// again until it’s overwritten, therefore the load operation can invalidate
112+
/// the cache line and discard "dirty" data. If the assertion is violated
113+
/// (the cache line is read again) then behavior is undefined.
114+
read_invalidate = 6,
115+
116+
// TODO: Implement the verification of this enum in check_cache_hint().
117+
/// load, L2 cache only, next gen GPU after Xe required: asserts that
118+
/// the L2 cache line containing the data will not be written until all
119+
/// invocations of the shader or kernel execution are finished.
120+
/// If the assertion is violated (the cache line is written), the behavior
121+
/// is undefined.
122+
const_cached = 7
94123
};
95124

96125
/// Data size or format to read or store

0 commit comments

Comments
 (0)