@@ -27,11 +27,14 @@ namespace gpu::xetla {
27
27
// / @{
28
28
29
29
template <msg_type message_type, gpu_arch arch_tag>
30
- struct load_store_attr_t {};
30
+ struct load_store_attr_t {
31
+ static constexpr bool has_hw_block_2d = false ;
32
+ };
31
33
32
34
template <>
33
35
struct load_store_attr_t <msg_type::block_2d, gpu_arch::XeHpc> {
34
36
// / HW limitation checks https://gfxspecs.intel.com/Predator/Home/Index/55490
37
+ static constexpr bool has_hw_block_2d = true ;
35
38
static constexpr uint32_t max_load_height_in_elem = 32 ;
36
39
static constexpr uint32_t max_load_width_in_bytes = 64 ;
37
40
static constexpr uint32_t max_trans_load_width_in_bytes = 32 ;
@@ -53,6 +56,7 @@ struct load_store_attr_t<msg_type::block_2d, gpu_arch::XeHpc> {
53
56
template <msg_type message_type, gpu_arch arg_tag>
54
57
struct client_load_store_attr_base_t {
55
58
// / HW limitation checks https://gfxspecs.intel.com/Predator/Home/Index/55490
59
+ static constexpr bool has_hw_block_2d = false ;
56
60
static constexpr uint32_t max_load_height_in_elem = 32 ;
57
61
static constexpr uint32_t max_load_width_in_bytes = 64 ;
58
62
static constexpr uint32_t max_trans_load_width_in_bytes = 32 ;
@@ -83,74 +87,116 @@ struct load_store_attr_t<msg_type::block_2d, gpu_arch::XeLpg>
83
87
msg_type::block_2d,
84
88
gpu_arch::XeLpg> {};
85
89
90
+ template <gpu_arch arch_tag>
91
+ inline constexpr bool arch_has_2d_load_store =
92
+ load_store_attr_t <msg_type::block_2d, arch_tag>::has_hw_block_2d;
93
+
86
94
template <gpu_arch arch_tag>
87
95
struct load_store_attr_t <msg_type::block_1d, arch_tag> {
96
+ static constexpr uint32_t max_load_vec_len = 32 ;
97
+ static constexpr uint32_t max_store_vec_len = 32 ;
98
+ static constexpr uint32_t max_prefetch_vec_len = 32 ;
99
+ };
100
+
101
+ template <>
102
+ struct load_store_attr_t <msg_type::block_1d, gpu_arch::XeHpc> {
88
103
static constexpr uint32_t max_load_vec_len = 64 ;
89
104
static constexpr uint32_t max_store_vec_len = 64 ;
105
+ static constexpr uint32_t max_prefetch_vec_len = 64 ;
90
106
};
91
107
92
- template <gpu_arch arch_tag>
93
- struct mma_attr_t {};
108
+ struct dpas_attr_base_t {
109
+ static constexpr bool has_xmx = true ;
110
+ static constexpr uint32_t systolic_depth = 8 ;
111
+ static constexpr uint32_t rcount_max = 8 ;
112
+ static constexpr uint32_t op_per_channel_bits = 32 ;
113
+ static constexpr uint32_t op_per_channel_bytes = (op_per_channel_bits >> 3 );
114
+ static constexpr uint32_t op_per_channel_max = 8 ;
115
+ };
94
116
95
117
template <gpu_arch arch_tag>
96
- struct client_mma_atr_base_t {
97
- static constexpr uint32_t mma_m_in_elem = 8 ;
98
- static constexpr uint32_t mma_n_in_elem = 8 ;
99
- static constexpr uint32_t mma_k_in_bytes = 32 ;
118
+ struct dpas_attr_t {
119
+ static constexpr bool has_xmx = false ;
100
120
};
101
121
102
122
template <>
103
- struct mma_attr_t <gpu_arch::XeHpc> {
104
- static constexpr uint32_t mma_m_in_elem = 8 ;
105
- static constexpr uint32_t mma_n_in_elem = 16 ;
106
- static constexpr uint32_t mma_k_in_bytes = 32 ;
123
+ struct dpas_attr_t <gpu_arch::XeHpc> : public dpas_attr_base_t {
124
+ static constexpr uint32_t n_fixed_limit = 16 ;
107
125
};
108
126
109
127
template <>
110
- struct mma_attr_t <gpu_arch::XeHpg>
111
- : public client_mma_atr_base_t <gpu_arch::XeHpg> {};
128
+ struct dpas_attr_t <gpu_arch::XeHpg> : public dpas_attr_base_t {
129
+ static constexpr uint32_t n_fixed_limit = 8 ;
130
+ };
112
131
113
- template <grf_mode grf_num_mode, gpu_arch arch_tag>
114
- struct register_attr_t {} ;
132
+ template <gpu_arch arch_tag>
133
+ inline constexpr bool arch_has_xmx = dpas_attr_t <arch_tag>::has_xmx ;
115
134
116
- template <grf_mode grf_num_mode, gpu_arch arch_tag>
117
- struct client_register_attr_base_t {
118
- static constexpr uint32_t acc_reg_in_bytes =
119
- (grf_num_mode == grf_mode::normal) ? 4 * 64 : 8 * 64 ;
120
- static constexpr uint32_t grf_in_bytes =
121
- (grf_num_mode == grf_mode::normal) ? 128 * 64 : 256 * 64 ;
122
- static constexpr uint32_t reg_in_bytes = 64 ;
135
+ template <gpu_arch arch_tag>
136
+ struct fpu_attr_t {
137
+ static constexpr bool has_fpu = true ;
123
138
};
124
139
140
+ template <gpu_arch arch_tag>
141
+ inline constexpr bool arch_has_fpu = fpu_attr_t <arch_tag>::has_fpu;
142
+
125
143
template <grf_mode grf_num_mode>
126
- struct register_attr_t <grf_num_mode, gpu_arch::XeHpc> {
127
- static constexpr uint32_t acc_reg_in_bytes =
128
- (grf_num_mode == grf_mode::normal) ? 4 * 64 : 8 * 64 ;
129
- static constexpr uint32_t grf_in_bytes =
130
- (grf_num_mode == grf_mode::normal) ? 128 * 64 : 256 * 64 ;
144
+ struct register_nums_t {
145
+ static constexpr uint32_t register_nums =
146
+ (grf_num_mode == grf_mode::normal) ? 128 : 256 ;
147
+ static constexpr uint32_t acc_register_nums =
148
+ (grf_num_mode == grf_mode::normal) ? 4 : 8 ;
149
+ };
150
+
151
+ template <gpu_arch arch_tag>
152
+ struct register_bytes_t {
131
153
static constexpr uint32_t reg_in_bytes = 64 ;
132
154
};
133
155
134
- template <grf_mode grf_num_mode>
135
- struct register_attr_t <grf_num_mode, gpu_arch::XeHpg>
136
- : public client_register_attr_base_t <grf_num_mode, gpu_arch::XeHpg> {};
156
+ template <grf_mode grf_num_mode, gpu_arch arch_tag>
157
+ struct register_attr_t {
158
+ static constexpr uint32_t reg_in_bytes =
159
+ register_bytes_t <arch_tag>::reg_in_bytes;
160
+ static constexpr uint32_t register_nums =
161
+ register_nums_t <grf_num_mode>::register_nums;
162
+ static constexpr uint32_t acc_register_nums =
163
+ register_nums_t <grf_num_mode>::acc_register_nums;
164
+ static constexpr uint32_t acc_reg_in_bytes = acc_register_nums * reg_in_bytes;
165
+ static constexpr uint32_t grf_in_bytes = register_nums * reg_in_bytes;
166
+ };
137
167
138
- template <grf_mode grf_num_mode>
139
- struct register_attr_t <grf_num_mode, gpu_arch::XeLpg>
140
- : public client_register_attr_base_t <grf_num_mode, gpu_arch::XeLpg> {};
168
+ template <gpu_arch arch_tag, uint32_t m, class enable = void >
169
+ struct mma_attr_t {};
170
+
171
+ template <gpu_arch arch_tag, uint32_t m>
172
+ struct mma_attr_t <arch_tag, m, std::enable_if_t <arch_has_xmx<arch_tag>>> {
173
+ using dpas_attr = dpas_attr_t <arch_tag>;
174
+ static constexpr uint32_t mma_m_in_elem =
175
+ (m > dpas_attr::rcount_max) ? dpas_attr::rcount_max : m;
176
+ static constexpr uint32_t mma_n_in_elem = dpas_attr::n_fixed_limit;
177
+ static constexpr uint32_t mma_k_in_bytes =
178
+ dpas_attr::systolic_depth * dpas_attr::op_per_channel_bytes;
179
+ };
180
+
181
+ template <gpu_arch arch_tag, uint32_t m>
182
+ struct mma_attr_t <arch_tag, m, std::enable_if_t <!arch_has_xmx<arch_tag>>> {
183
+ static constexpr uint32_t mma_m_in_elem = (m > 8 ) ? 8 : m;
184
+ static constexpr uint32_t mma_n_in_elem = 16 ;
185
+ static constexpr uint32_t mma_k_in_bytes = 32 ;
186
+ };
141
187
142
188
template <gpu_arch arch_tag>
143
189
struct arch_attr_t {};
144
190
145
191
template <gpu_arch arch_tag>
146
192
struct client_arch_attr_base_t {
147
193
template <msg_type message_type = msg_type::block_2d>
148
- using load_store_attr = load_store_attr_t <message_type, gpu_arch::XeHpg >;
194
+ using load_store_attr = load_store_attr_t <message_type, arch_tag >;
149
195
150
- template <grf_mode grf_num_mode = grf_mode::double_grf >
151
- using register_attr = register_attr_t <grf_num_mode, gpu_arch::XeHpg >;
196
+ template <grf_mode grf_num_mode = grf_mode::normal >
197
+ using register_attr = register_attr_t <grf_num_mode, arch_tag >;
152
198
153
- using mma_attr = mma_attr_t <gpu_arch::XeHpg >;
199
+ using dpas_attr = dpas_attr_t <arch_tag >;
154
200
155
201
static constexpr uint32_t max_wg_num = 64 ;
156
202
static constexpr uint32_t local_mem_size = 64 * 1024 ;
@@ -164,7 +210,7 @@ struct arch_attr_t<gpu_arch::XeHpc> {
164
210
template <grf_mode grf_num_mode = grf_mode::double_grf>
165
211
using register_attr = register_attr_t <grf_num_mode, gpu_arch::XeHpc>;
166
212
167
- using mma_attr = mma_attr_t <gpu_arch::XeHpc>;
213
+ using dpas_attr = dpas_attr_t <gpu_arch::XeHpc>;
168
214
169
215
static constexpr uint32_t max_wg_num = 64 ;
170
216
static constexpr uint32_t local_mem_size = 128 * 1024 ;
0 commit comments