Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit b13e02f

Browse files
[XETLA] Add dpas attr, refine mma, load, store attr (#242)
1 parent 4929d80 commit b13e02f

File tree

13 files changed

+178
-191
lines changed

13 files changed

+178
-191
lines changed

include/common/core/arch_config.hpp

Lines changed: 84 additions & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -27,11 +27,14 @@ namespace gpu::xetla {
2727
/// @{
2828

2929
template <msg_type message_type, gpu_arch arch_tag>
30-
struct load_store_attr_t {};
30+
struct load_store_attr_t {
31+
static constexpr bool has_hw_block_2d = false;
32+
};
3133

3234
template <>
3335
struct load_store_attr_t<msg_type::block_2d, gpu_arch::XeHpc> {
3436
/// HW limitation checks https://gfxspecs.intel.com/Predator/Home/Index/55490
37+
static constexpr bool has_hw_block_2d = true;
3538
static constexpr uint32_t max_load_height_in_elem = 32;
3639
static constexpr uint32_t max_load_width_in_bytes = 64;
3740
static constexpr uint32_t max_trans_load_width_in_bytes = 32;
@@ -53,6 +56,7 @@ struct load_store_attr_t<msg_type::block_2d, gpu_arch::XeHpc> {
5356
template <msg_type message_type, gpu_arch arg_tag>
5457
struct client_load_store_attr_base_t {
5558
/// HW limitation checks https://gfxspecs.intel.com/Predator/Home/Index/55490
59+
static constexpr bool has_hw_block_2d = false;
5660
static constexpr uint32_t max_load_height_in_elem = 32;
5761
static constexpr uint32_t max_load_width_in_bytes = 64;
5862
static constexpr uint32_t max_trans_load_width_in_bytes = 32;
@@ -83,74 +87,116 @@ struct load_store_attr_t<msg_type::block_2d, gpu_arch::XeLpg>
8387
msg_type::block_2d,
8488
gpu_arch::XeLpg> {};
8589

90+
template <gpu_arch arch_tag>
91+
inline constexpr bool arch_has_2d_load_store =
92+
load_store_attr_t<msg_type::block_2d, arch_tag>::has_hw_block_2d;
93+
8694
template <gpu_arch arch_tag>
8795
struct load_store_attr_t<msg_type::block_1d, arch_tag> {
96+
static constexpr uint32_t max_load_vec_len = 32;
97+
static constexpr uint32_t max_store_vec_len = 32;
98+
static constexpr uint32_t max_prefetch_vec_len = 32;
99+
};
100+
101+
template <>
102+
struct load_store_attr_t<msg_type::block_1d, gpu_arch::XeHpc> {
88103
static constexpr uint32_t max_load_vec_len = 64;
89104
static constexpr uint32_t max_store_vec_len = 64;
105+
static constexpr uint32_t max_prefetch_vec_len = 64;
90106
};
91107

92-
template <gpu_arch arch_tag>
93-
struct mma_attr_t {};
108+
struct dpas_attr_base_t {
109+
static constexpr bool has_xmx = true;
110+
static constexpr uint32_t systolic_depth = 8;
111+
static constexpr uint32_t rcount_max = 8;
112+
static constexpr uint32_t op_per_channel_bits = 32;
113+
static constexpr uint32_t op_per_channel_bytes = (op_per_channel_bits >> 3);
114+
static constexpr uint32_t op_per_channel_max = 8;
115+
};
94116

95117
template <gpu_arch arch_tag>
96-
struct client_mma_atr_base_t {
97-
static constexpr uint32_t mma_m_in_elem = 8;
98-
static constexpr uint32_t mma_n_in_elem = 8;
99-
static constexpr uint32_t mma_k_in_bytes = 32;
118+
struct dpas_attr_t {
119+
static constexpr bool has_xmx = false;
100120
};
101121

102122
template <>
103-
struct mma_attr_t<gpu_arch::XeHpc> {
104-
static constexpr uint32_t mma_m_in_elem = 8;
105-
static constexpr uint32_t mma_n_in_elem = 16;
106-
static constexpr uint32_t mma_k_in_bytes = 32;
123+
struct dpas_attr_t<gpu_arch::XeHpc> : public dpas_attr_base_t {
124+
static constexpr uint32_t n_fixed_limit = 16;
107125
};
108126

109127
template <>
110-
struct mma_attr_t<gpu_arch::XeHpg>
111-
: public client_mma_atr_base_t<gpu_arch::XeHpg> {};
128+
struct dpas_attr_t<gpu_arch::XeHpg> : public dpas_attr_base_t {
129+
static constexpr uint32_t n_fixed_limit = 8;
130+
};
112131

113-
template <grf_mode grf_num_mode, gpu_arch arch_tag>
114-
struct register_attr_t {};
132+
template <gpu_arch arch_tag>
133+
inline constexpr bool arch_has_xmx = dpas_attr_t<arch_tag>::has_xmx;
115134

116-
template <grf_mode grf_num_mode, gpu_arch arch_tag>
117-
struct client_register_attr_base_t {
118-
static constexpr uint32_t acc_reg_in_bytes =
119-
(grf_num_mode == grf_mode::normal) ? 4 * 64 : 8 * 64;
120-
static constexpr uint32_t grf_in_bytes =
121-
(grf_num_mode == grf_mode::normal) ? 128 * 64 : 256 * 64;
122-
static constexpr uint32_t reg_in_bytes = 64;
135+
template <gpu_arch arch_tag>
136+
struct fpu_attr_t {
137+
static constexpr bool has_fpu = true;
123138
};
124139

140+
template <gpu_arch arch_tag>
141+
inline constexpr bool arch_has_fpu = fpu_attr_t<arch_tag>::has_fpu;
142+
125143
template <grf_mode grf_num_mode>
126-
struct register_attr_t<grf_num_mode, gpu_arch::XeHpc> {
127-
static constexpr uint32_t acc_reg_in_bytes =
128-
(grf_num_mode == grf_mode::normal) ? 4 * 64 : 8 * 64;
129-
static constexpr uint32_t grf_in_bytes =
130-
(grf_num_mode == grf_mode::normal) ? 128 * 64 : 256 * 64;
144+
struct register_nums_t {
145+
static constexpr uint32_t register_nums =
146+
(grf_num_mode == grf_mode::normal) ? 128 : 256;
147+
static constexpr uint32_t acc_register_nums =
148+
(grf_num_mode == grf_mode::normal) ? 4 : 8;
149+
};
150+
151+
template <gpu_arch arch_tag>
152+
struct register_bytes_t {
131153
static constexpr uint32_t reg_in_bytes = 64;
132154
};
133155

134-
template <grf_mode grf_num_mode>
135-
struct register_attr_t<grf_num_mode, gpu_arch::XeHpg>
136-
: public client_register_attr_base_t<grf_num_mode, gpu_arch::XeHpg> {};
156+
template <grf_mode grf_num_mode, gpu_arch arch_tag>
157+
struct register_attr_t {
158+
static constexpr uint32_t reg_in_bytes =
159+
register_bytes_t<arch_tag>::reg_in_bytes;
160+
static constexpr uint32_t register_nums =
161+
register_nums_t<grf_num_mode>::register_nums;
162+
static constexpr uint32_t acc_register_nums =
163+
register_nums_t<grf_num_mode>::acc_register_nums;
164+
static constexpr uint32_t acc_reg_in_bytes = acc_register_nums * reg_in_bytes;
165+
static constexpr uint32_t grf_in_bytes = register_nums * reg_in_bytes;
166+
};
137167

138-
template <grf_mode grf_num_mode>
139-
struct register_attr_t<grf_num_mode, gpu_arch::XeLpg>
140-
: public client_register_attr_base_t<grf_num_mode, gpu_arch::XeLpg> {};
168+
template <gpu_arch arch_tag, uint32_t m, class enable = void>
169+
struct mma_attr_t {};
170+
171+
template <gpu_arch arch_tag, uint32_t m>
172+
struct mma_attr_t<arch_tag, m, std::enable_if_t<arch_has_xmx<arch_tag>>> {
173+
using dpas_attr = dpas_attr_t<arch_tag>;
174+
static constexpr uint32_t mma_m_in_elem =
175+
(m > dpas_attr::rcount_max) ? dpas_attr::rcount_max : m;
176+
static constexpr uint32_t mma_n_in_elem = dpas_attr::n_fixed_limit;
177+
static constexpr uint32_t mma_k_in_bytes =
178+
dpas_attr::systolic_depth * dpas_attr::op_per_channel_bytes;
179+
};
180+
181+
template <gpu_arch arch_tag, uint32_t m>
182+
struct mma_attr_t<arch_tag, m, std::enable_if_t<!arch_has_xmx<arch_tag>>> {
183+
static constexpr uint32_t mma_m_in_elem = (m > 8) ? 8 : m;
184+
static constexpr uint32_t mma_n_in_elem = 16;
185+
static constexpr uint32_t mma_k_in_bytes = 32;
186+
};
141187

142188
template <gpu_arch arch_tag>
143189
struct arch_attr_t {};
144190

145191
template <gpu_arch arch_tag>
146192
struct client_arch_attr_base_t {
147193
template <msg_type message_type = msg_type::block_2d>
148-
using load_store_attr = load_store_attr_t<message_type, gpu_arch::XeHpg>;
194+
using load_store_attr = load_store_attr_t<message_type, arch_tag>;
149195

150-
template <grf_mode grf_num_mode = grf_mode::double_grf>
151-
using register_attr = register_attr_t<grf_num_mode, gpu_arch::XeHpg>;
196+
template <grf_mode grf_num_mode = grf_mode::normal>
197+
using register_attr = register_attr_t<grf_num_mode, arch_tag>;
152198

153-
using mma_attr = mma_attr_t<gpu_arch::XeHpg>;
199+
using dpas_attr = dpas_attr_t<arch_tag>;
154200

155201
static constexpr uint32_t max_wg_num = 64;
156202
static constexpr uint32_t local_mem_size = 64 * 1024;
@@ -164,7 +210,7 @@ struct arch_attr_t<gpu_arch::XeHpc> {
164210
template <grf_mode grf_num_mode = grf_mode::double_grf>
165211
using register_attr = register_attr_t<grf_num_mode, gpu_arch::XeHpc>;
166212

167-
using mma_attr = mma_attr_t<gpu_arch::XeHpc>;
213+
using dpas_attr = dpas_attr_t<gpu_arch::XeHpc>;
168214

169215
static constexpr uint32_t max_wg_num = 64;
170216
static constexpr uint32_t local_mem_size = 128 * 1024;

include/common/core/common_types.hpp

Lines changed: 0 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -22,12 +22,6 @@
2222

2323
namespace gpu::xetla {
2424
enum class gpu_arch : uint8_t { XeLpg = 0, XeHpg = 1, XeHpc = 2 };
25-
inline constexpr bool arch_has_xmx(gpu_arch arch) {
26-
return arch >= gpu_arch::XeHpg;
27-
}
28-
inline constexpr bool arch_has_2d_load_store(gpu_arch arch) {
29-
return arch >= gpu_arch::XeHpc;
30-
}
3125

3226
enum class grf_mode : uint8_t { normal = 0, double_grf = 1 };
3327

include/common/utils/limitation.hpp

Lines changed: 7 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -91,7 +91,8 @@ class block_2d {
9191
ret = ((block_width * block_height * element_size) <= (32 * bytes_per_grf));
9292
XETLA_ASSERT(
9393
ret,
94-
"2D Block Loads upto 32 GRFs are can be read but is %u:%u",
94+
"2D Block Loads upto 32 * %u bytes are can be read but is %u:%u",
95+
bytes_per_grf,
9596
block_width,
9697
block_height);
9798
if (!ret) {
@@ -318,7 +319,7 @@ class block_2d {
318319
static constexpr auto element_size = sizeof(T);
319320
static constexpr uint32_t max_24bit = 16 * 1024 * 1024; // 2 ^ 24
320321
static constexpr auto bytes_per_grf =
321-
register_attr_t<grf_mode::double_grf, gpu_arch::XeHpc>::reg_in_bytes;
322+
register_attr_t<grf_mode::double_grf, arch_tag>::reg_in_bytes;
322323

323324
static inline bool check_base_address(uint64_t base) {
324325
bool ret = ((base % 64) == 0);
@@ -746,11 +747,8 @@ struct check_store {
746747
} // namespace subgroup
747748

748749
namespace group {
749-
template <gpu_arch arch = gpu_arch::XeHpc, class enable = void>
750-
struct gemm {};
751-
752-
template <gpu_arch arch>
753-
struct gemm<arch, std::enable_if_t<(arch <= gpu_arch::XeHpc)>> {
750+
template <gpu_arch arch = gpu_arch::XeHpc>
751+
struct gemm {
754752
struct default_fpu {
755753
template <
756754
typename dtype_a,
@@ -802,7 +800,7 @@ struct gemm<arch, std::enable_if_t<(arch <= gpu_arch::XeHpc)>> {
802800
int block_size_y_b>
803801
struct check_tile_size_default {
804802
static constexpr uint32_t reg_in_bytes =
805-
register_attr_t<grf_mode::double_grf, gpu_arch::XeHpc>::reg_in_bytes;
803+
register_attr_t<grf_mode::double_grf, arch>::reg_in_bytes;
806804
static constexpr uint32_t simd_len = reg_in_bytes / sizeof(dtype_mma);
807805

808806
static_assert(
@@ -878,7 +876,7 @@ struct gemm<arch, std::enable_if_t<(arch <= gpu_arch::XeHpc)>> {
878876
int block_size_x_b,
879877
int block_size_y_b>
880878
struct check_tile_size_default {
881-
using mma_attr = mma_attr_t<gpu_arch::XeHpc>;
879+
using mma_attr = mma_attr_t<arch, block_size_y_a>;
882880
static constexpr int32_t mma_m = mma_attr::mma_m_in_elem;
883881
static constexpr int32_t mma_n = mma_attr::mma_n_in_elem;
884882
static constexpr int32_t mma_k =

include/experimental/group/gemm/compute_policy.hpp

Lines changed: 10 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -62,33 +62,28 @@ struct compute_policy_int4_dequantize<
6262
arch_tag_,
6363
std::enable_if_t<(arch_tag_ <= gpu_arch::XeHpc)>> {
6464
using compute_attr = compute_attr_;
65+
using dtype_mma_acc = typename compute_attr::dtype_acc;
66+
using dtype_mma_a = typename compute_attr::dtype_a;
67+
using dtype_mma_b = typename compute_attr::dtype_b;
68+
6569
using perf_tuning_knob = perf_tuning_knob_;
66-
static constexpr int k_stride = perf_tuning_knob::k_stride;
6770
static constexpr int stages = perf_tuning_knob::stages;
6871
static constexpr int sync_freq = perf_tuning_knob::sync_freq;
72+
static constexpr int k_stride = perf_tuning_knob::k_stride;
6973
static constexpr mma_engine mma_engine = mma_engine_;
7074
static constexpr gpu_arch arch_tag = arch_tag_;
7175

7276
static_assert(
7377
!(mma_engine == mma_engine::xmx && arch_tag == gpu_arch::XeLpg),
7478
"XeLpg does not support xmx");
7579

76-
using dtype_mma_acc = typename compute_attr::dtype_acc;
77-
using dtype_mma_a = typename compute_attr::dtype_a;
78-
using dtype_mma_b = typename compute_attr::dtype_b;
79-
80-
static constexpr uint32_t block_bytes_x_a = 32;
81-
static constexpr uint32_t block_size_y_a = 16;
82-
8380
static constexpr bool is_int4_matB_policy = true;
8481

85-
static constexpr uint32_t block_size_x_b = (mma_engine == mma_engine::xmx)
86-
? arch_attr_t<arch_tag>::mma_attr::mma_n_in_elem
87-
: 32;
88-
static constexpr uint32_t block_bytes_y_b = 32;
89-
static_assert(
90-
block_bytes_x_a == block_bytes_y_b,
91-
"mat_a x need to match with mat_b y");
82+
static constexpr uint32_t block_size_y_a = 16;
83+
using mma_attr = mma_attr_t<arch_tag_, block_size_y_a>;
84+
static constexpr uint32_t block_bytes_x_a = mma_attr::mma_k_in_bytes;
85+
static constexpr uint32_t block_size_x_b = mma_attr::mma_n_in_elem;
86+
static constexpr uint32_t block_bytes_y_b = block_bytes_x_a;
9287

9388
static constexpr uint32_t dequant_s = dequant_s_;
9489
static_assert(

0 commit comments

Comments
 (0)