
Commit e948f6d

Separate out type and accum type
1 parent ac786ea commit e948f6d

7 files changed (+103 additions, -46 deletions)

applications/flash_attention_v2/collective/xe_flash_attn_prefill_epilogue.hpp

Lines changed: 12 additions & 5 deletions
@@ -53,15 +53,14 @@ template <class DispatchPolicy, class MMAOperation_, class TileShapeOutput_, cla
   static_assert(cutlass::detail::dependent_false<DispatchPolicy>, "Could not find an epilogue specialization.");
 };
 
-template <class MMAOperation_, class TileShapeOutput_, class SubgroupLayout_, class ElementO_, class StrideO_, class ElementLSE_, class CopyOpO_>
-class FlashPrefillEpilogue<epilogue::IntelXeXMX16, MMAOperation_, TileShapeOutput_, SubgroupLayout_, ElementO_, StrideO_, ElementLSE_, CopyOpO_> {
+template <class MMAOperation_, class TileShapeOutput_, class SubgroupLayout_, class ElementCompute_, class ElementO_, class StrideO_, class ElementLSE_, class CopyOpO_>
+class FlashPrefillEpilogue<epilogue::IntelXeXMX16, MMAOperation_, TileShapeOutput_, SubgroupLayout_, ElementCompute_, ElementO_, StrideO_, ElementLSE_, CopyOpO_> {
 public:
   //
   // Type Aliases
   //
   using DispatchPolicy = epilogue::IntelXeXMX16;
   using ElementO = ElementO_;
-  using ElementAccumulator = ElementO_;
   using StrideO = StrideO_;
   using ElementLSE = ElementLSE_;
   using CopyOpO = CopyOpO_;
@@ -70,7 +69,8 @@ class FlashPrefillEpilogue<epilogue::IntelXeXMX16, MMAOperation_, TileShapeOutp
   using TiledMmaOutput = typename TiledMMAHelper<MMA_Atom<MMAOperation_>, Layout<TileShapeOutput>, SubgroupLayout>::TiledMMA;
   using GmemTiledCopyO = CopyOpO;
   using ElementOutput = ElementO_;
-  using ElementCompute = ElementO_;
+  using ElementCompute = ElementCompute_;
+  using ElementAccumulator = ElementCompute_;
   using SubgroupTileShape = decltype(cute::shape_div(TileShapeOutput{}, (SubgroupLayout{}.shape())));
 
   static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize;
@@ -196,7 +196,14 @@ class FlashPrefillEpilogue<epilogue::IntelXeXMX16, MMAOperation_, TileShapeOutp
     auto thread_xe_store_o = params.xe_store_o.get_thread_slice(ThreadIdxX());
     Tensor tOgO = thread_xe_store_o.partition_D(gO);
 
-    copy(params.xe_store_o, out_reg, tOgO);
+    Tensor final_out_reg = make_fragment_like<ElementOutput>(out_reg);
+    if constexpr (cute::is_any_of_v<ElementOutput, cute::float_e5m2_t, cute::float_e4m3_t>) {
+      copy(out_reg, final_out_reg);
+    } else {
+      Tensor temp = convert_type<ElementOutput>(out_reg);
+      copy(temp, final_out_reg);
+    }
+    copy(params.xe_store_o, final_out_reg, tOgO);
   }
 
   // SequenceLengthShapeType = Shape<int, int>
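
Note on the hunk above: the epilogue now keeps accumulation in `ElementCompute_` and narrows to `ElementOutput` only in the final register fragment, right before the global-memory store (FP8 outputs take the direct `copy` branch, everything else goes through `convert_type` first). A minimal, self-contained sketch of that "accumulate wide, narrow on store" pattern, using plain C++ with `double`/`float` stand-ins for the CUTLASS fragment and element types (not the actual CUTLASS API):

```cpp
#include <algorithm>
#include <cstdio>
#include <vector>

using ElementCompute = double;  // stand-in for the compute/accumulator precision
using ElementOutput  = float;   // stand-in for the storage precision of O

// Mirror of the epilogue step: build an output-typed fragment from the
// compute-typed register fragment, then hand it to the store.
std::vector<ElementOutput> narrow_before_store(const std::vector<ElementCompute>& out_reg) {
  std::vector<ElementOutput> final_out_reg(out_reg.size());
  std::transform(out_reg.begin(), out_reg.end(), final_out_reg.begin(),
                 [](ElementCompute v) { return static_cast<ElementOutput>(v); });
  return final_out_reg;  // the real code then does copy(params.xe_store_o, final_out_reg, tOgO)
}

int main() {
  std::vector<ElementCompute> out_reg{1.0, 0.5, -2.25};
  for (ElementOutput v : narrow_before_store(out_reg)) std::printf("%f\n", v);
  return 0;
}
```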

applications/flash_attention_v2/kernel/xe_flash_attn_prefill.hpp

Lines changed: 1 addition & 1 deletion
@@ -370,7 +370,7 @@ class FMHAPrefill {
       CUTLASS_PRAGMA_UNROLL
       for (int row = 0; row < Vec; row++, row_idx++) { // 8
         if (col_idx - full_tile_offset > row_idx - discard_seq_coord) {
-          tSr(row, m, n) = -INFINITY;
+          tSr(row, m, n) = ElementAccumulator{-INFINITY};
        }
      }
    }
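
The one-line kernel change above writes the causal-mask value as an explicit `ElementAccumulator` instead of a bare float `-INFINITY`. A small host-side sketch of the same mask write, with `float` standing in for `ElementAccumulator` (an assumption for illustration, not taken from the kernel):

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

using ElementAccumulator = float;  // stand-in accumulator element type

// Mask out positions above the causal diagonal, building the mask value
// explicitly in the accumulator type.
void apply_causal_mask(std::vector<ElementAccumulator>& scores, int rows, int cols,
                       int full_tile_offset, int discard_seq_coord) {
  for (int r = 0; r < rows; ++r)
    for (int c = 0; c < cols; ++c)
      if (c - full_tile_offset > r - discard_seq_coord)
        scores[r * cols + c] = ElementAccumulator{-INFINITY};  // masked-out position
}

int main() {
  std::vector<ElementAccumulator> s(4 * 4, 0.f);
  apply_causal_mask(s, 4, 4, /*full_tile_offset=*/0, /*discard_seq_coord=*/0);
  for (int r = 0; r < 4; ++r) {
    for (int c = 0; c < 4; ++c) std::printf("% .1f ", s[r * 4 + c]);
    std::printf("\n");
  }
  return 0;
}
```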

benchmarks/flash_attention/flash_attention_prefill/benchmark_runner.hpp

Lines changed: 29 additions & 13 deletions
@@ -189,18 +189,17 @@ template <class FMHAPrefillConfiguration> struct BenchmarkRunnerFMHA {
      }
      int kv_group_update=1;
      for (int h = 0; h < num_heads_q; h++) {
-       cutlass::DeviceAllocation<ElementOutput> block_S;
+       cutlass::DeviceAllocation<ElementAccumulator> block_S;
        block_S.reset(seq_len_qo * seq_len_kv);
 
        cutlass::TensorRef ref_Q(block_Q[0].get() + offset_q, LayoutQ::packed({seq_len_qo, head_size_qk}));
        cutlass::TensorRef ref_K(block_K[0].get() + offset_k, LayoutK::packed({head_size_qk, seq_len_kv}));
        cutlass::TensorRef ref_V(block_V[0].get() + offset_v, LayoutV::packed({seq_len_kv, head_size_vo}));
        cutlass::TensorRef ref_S(block_S.get(), LayoutQ::packed({seq_len_qo, seq_len_kv}));
-       cutlass::TensorRef ref_O(block_ref_O.get() + offset_o, LayoutO::packed({seq_len_qo, head_size_vo}));
 
-       cutlass::reference::device::GemmComplex({seq_len_qo, seq_len_kv, head_size_qk}, 1.f, ref_Q,
+       cutlass::reference::device::GemmComplex({seq_len_qo, seq_len_kv, head_size_qk}, ElementAccumulator{1.f}, ref_Q,
                                                cutlass::ComplexTransform::kNone, ref_K, cutlass::ComplexTransform::kNone,
-                                               0.f, ref_S, ref_S, ElementAccumulator(0),
+                                               ElementAccumulator{0}, ref_S, ref_S, ElementAccumulator{0},
                                                1, // batch_count
                                                seq_len_qo * head_size_qk, // batch_stride_Q
                                                seq_len_kv * head_size_qk, // batch_stride_K
@@ -210,8 +209,8 @@ template <class FMHAPrefillConfiguration> struct BenchmarkRunnerFMHA {
 
        syclcompat::wait();
 
-       std::vector<ElementOutput> host_S(block_S.size());
-       syclcompat::memcpy<ElementOutput>(host_S.data(), block_S.get(), host_S.size());
+       std::vector<ElementAccumulator> host_S(block_S.size());
+       syclcompat::memcpy<ElementAccumulator>(host_S.data(), block_S.get(), host_S.size());
        syclcompat::wait();
 
        // delete this memory as it is no longer needed
@@ -224,13 +223,13 @@ template <class FMHAPrefillConfiguration> struct BenchmarkRunnerFMHA {
          for (int row = 0; row < seq_len_qo; row++) {
            for (int col = 0; col < seq_len_kv; col++) {
              if ((col - full_tile_offset) > (row - discard_seq_coord))
-               host_S[col + row * seq_len_kv] = -INFINITY;
+               host_S[col + row * seq_len_kv] = ElementAccumulator{-INFINITY};
            }
          }
        }
 
        // compute max element per row of S
-       std::vector<ElementOutput> max_vec(seq_len_qo, -INFINITY);
+       std::vector<ElementAccumulator> max_vec(seq_len_qo, ElementAccumulator{-INFINITY});
        for (int row = 0; row < seq_len_qo; row++) {
          int idx = row * seq_len_kv;
          int max_idx = row;
@@ -246,12 +245,12 @@ template <class FMHAPrefillConfiguration> struct BenchmarkRunnerFMHA {
          int idx = row * seq_len_kv;
          int max_idx = row;
          for (int col = 0; col < seq_len_kv; col++, idx++) {
-           host_S[idx] = expf((host_S[idx] - max_vec[max_idx]) / std::sqrt(static_cast<ElementOutput>((head_size_qk))));
+           host_S[idx] = expf((host_S[idx] - max_vec[max_idx]) / std::sqrt(static_cast<ElementAccumulator>((head_size_qk))));
          }
        }
 
        // compute sum per row of S
-       std::vector<ElementOutput> sum_vec(seq_len_qo, ElementOutput{0});
+       std::vector<ElementAccumulator> sum_vec(seq_len_qo, ElementAccumulator{0});
        for (int row = 0; row < seq_len_qo; row++) {
          int idx = row * seq_len_kv;
          int sum_idx = row;
@@ -283,9 +282,13 @@ template <class FMHAPrefillConfiguration> struct BenchmarkRunnerFMHA {
 
        cutlass::TensorRef ref_P(block_P.get(), LayoutQ::packed({seq_len_qo, seq_len_kv}));
 
-       cutlass::reference::device::GemmComplex({seq_len_qo, head_size_vo, seq_len_kv}, 1.f, ref_P,
+       cutlass::DeviceAllocation<ElementAccumulator> block_acc;
+       block_acc.reset(seq_len_qo * head_size_vo);
+       cutlass::TensorRef ref_acc(block_acc.get(), LayoutO::packed({seq_len_qo, head_size_vo}));
+
+       cutlass::reference::device::GemmComplex({seq_len_qo, head_size_vo, seq_len_kv}, ElementAccumulator{1}, ref_P,
                                                cutlass::ComplexTransform::kNone, ref_V, cutlass::ComplexTransform::kNone,
-                                               0.f, ref_O, ref_O, ElementAccumulator(0),
+                                               ElementAccumulator{0}, ref_acc, ref_acc, ElementAccumulator{0},
                                                1, // batch_count
                                                seq_len_qo * seq_len_kv, // batch_stride_P
                                                seq_len_kv * head_size_vo, // batch_stride_V
@@ -297,6 +300,19 @@ template <class FMHAPrefillConfiguration> struct BenchmarkRunnerFMHA {
        // delete this memory as it is no longer needed
        block_P.reset();
 
+       std::vector<ElementAccumulator> vec_acc(block_acc.size());
+       syclcompat::memcpy<ElementAccumulator>(vec_acc.data(), block_acc.get(), vec_acc.size());
+       syclcompat::wait();
+
+       // delete this memory as it is no longer needed
+       block_acc.reset();
+       std::vector<ElementOutput> vec_out(vec_acc.size());
+       for(int i = 0; i < vec_out.size(); i++) {
+         vec_out[i] = static_cast<ElementOutput>(vec_acc[i]);
+       }
+       syclcompat::memcpy<ElementOutput>(block_ref_O.get() + offset_o, vec_out.data(), vec_out.size());
+       syclcompat::wait();
+
        offset_q += seq_len_qo * head_size_qk;
        if(kv_group_update % q_group_size==0) {
          offset_k += seq_len_kv * head_size_qk;
@@ -311,7 +327,7 @@ template <class FMHAPrefillConfiguration> struct BenchmarkRunnerFMHA {
 
    // Check if output from CUTLASS kernel and reference kernel are equal or not
    bool passed = cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_O.get(), block_O.get(),
-                                                                         block_O.size(), 0.5f, 0.5f);
+                                                                         block_O.size(), ElementOutput{0.5}, ElementOutput{0.5});
 
    return passed;
  }
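
The benchmark changes above keep every intermediate of the reference attention (S, the row max/sum, and the P·V accumulator) in `ElementAccumulator`, and cast to `ElementOutput` only when writing the final reference output. A self-contained host sketch of that ordering for a row softmax, using `double`/`float` stand-ins and omitting the 1/sqrt(head_size_qk) scaling and the causal mask from the real reference:

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

using ElementAccumulator = double;  // stand-in for the benchmark's accumulator type
using ElementOutput      = float;   // stand-in for the kernel's output type

// Row softmax computed entirely in the accumulator type; the cast to the
// output type happens once, at the very end (like vec_acc -> vec_out above).
std::vector<ElementOutput> reference_row_softmax(std::vector<ElementAccumulator> host_S,
                                                 int rows, int cols) {
  for (int row = 0; row < rows; ++row) {
    ElementAccumulator row_max = -INFINITY;
    for (int col = 0; col < cols; ++col)
      row_max = std::max(row_max, host_S[row * cols + col]);
    ElementAccumulator row_sum = ElementAccumulator{0};
    for (int col = 0; col < cols; ++col) {
      host_S[row * cols + col] = std::exp(host_S[row * cols + col] - row_max);
      row_sum += host_S[row * cols + col];
    }
    for (int col = 0; col < cols; ++col)
      host_S[row * cols + col] /= row_sum;
  }
  std::vector<ElementOutput> out(host_S.size());
  for (std::size_t i = 0; i < out.size(); ++i)
    out[i] = static_cast<ElementOutput>(host_S[i]);
  return out;
}

int main() {
  std::vector<ElementAccumulator> S{1.0, 2.0, 3.0, 0.5, 0.5, 0.5};
  for (ElementOutput p : reference_row_softmax(S, /*rows=*/2, /*cols=*/3)) std::printf("%f\n", p);
  return 0;
}
```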

benchmarks/flash_attention/flash_attention_prefill/fmha_prefill_configuration.hpp

Lines changed: 1 addition & 1 deletion
@@ -67,7 +67,7 @@ struct FMHAPrefillConfig {
   using MMAOperation = typename MMAOP<GEMMDispatchPolicy, ElementInputType,ElementAccumulator>::Type;
   using CollectiveEpilogue = cutlass::flash_attention::collective::FlashPrefillEpilogue<
       EpilogueDispatchPolicy, MMAOperation, TileShapeOutput,
-      SubgroupLayout, ElementAccumulator,
+      SubgroupLayout, ElementAccumulator, ElementOutputType,
       cutlass::gemm::TagToStrideC_t<LayoutO>, ElementOutput,
       GmemTiledCopyO>;

examples/sycl/06_bmg_flash_attention/bmg_flash_attn_prefill_runner.hpp

Lines changed: 28 additions & 12 deletions
@@ -230,14 +230,13 @@ template <class FMHAPrefillKernel, bool isVarLen> struct ExampleRunner {
      }
      int kv_group_update=1;
      for (int h = 0; h < num_heads_q; h++) {
-       cutlass::DeviceAllocation<ElementOutput> block_S;
+       cutlass::DeviceAllocation<ElementAccumulator> block_S;
        block_S.reset(seq_len_qo * seq_len_kv);
 
        cutlass::TensorRef ref_Q(block_Q_.get() + offset_q, LayoutQ::packed({seq_len_qo, head_size_qk}));
        cutlass::TensorRef ref_K(block_K_.get() + offset_k, LayoutK::packed({head_size_qk, seq_len_kv}));
        cutlass::TensorRef ref_V(block_V_.get() + offset_v, LayoutV::packed({seq_len_kv, head_size_vo}));
        cutlass::TensorRef ref_S(block_S.get(), LayoutQ::packed({seq_len_qo, seq_len_kv}));
-       cutlass::TensorRef ref_O(block_ref_O.get() + offset_o, LayoutO::packed({seq_len_qo, head_size_vo}));
 
        cutlass::reference::device::GemmComplex({seq_len_qo, seq_len_kv, head_size_qk}, 1.f, ref_Q,
                                                cutlass::ComplexTransform::kNone, ref_K, cutlass::ComplexTransform::kNone,
@@ -251,8 +250,8 @@ template <class FMHAPrefillKernel, bool isVarLen> struct ExampleRunner {
 
        syclcompat::wait();
 
-       std::vector<ElementOutput> host_S(block_S.size());
-       syclcompat::memcpy<ElementOutput>(host_S.data(), block_S.get(), host_S.size());
+       std::vector<ElementAccumulator> host_S(block_S.size());
+       syclcompat::memcpy<ElementAccumulator>(host_S.data(), block_S.get(), host_S.size());
        syclcompat::wait();
 
        // delete this memory as it is no longer needed
@@ -265,13 +264,13 @@ template <class FMHAPrefillKernel, bool isVarLen> struct ExampleRunner {
          for (int row = 0; row < seq_len_qo; row++) {
            for (int col = 0; col < seq_len_kv; col++) {
              if ((col - full_tile_offset) > (row - discard_seq_coord))
-               host_S[col + row * seq_len_kv] = -INFINITY;
+               host_S[col + row * seq_len_kv] = ElementAccumulator{-INFINITY};
            }
          }
        }
 
        // compute max element per row of S
-       std::vector<ElementOutput> max_vec(seq_len_qo, -INFINITY);
+       std::vector<ElementAccumulator> max_vec(seq_len_qo, ElementAccumulator{-INFINITY});
        for (int row = 0; row < seq_len_qo; row++) {
          int idx = row * seq_len_kv;
          int max_idx = row;
@@ -287,12 +286,12 @@ template <class FMHAPrefillKernel, bool isVarLen> struct ExampleRunner {
          int idx = row * seq_len_kv;
          int max_idx = row;
          for (int col = 0; col < seq_len_kv; col++, idx++) {
-           host_S[idx] = expf((host_S[idx] - max_vec[max_idx]) / sqrt(static_cast<ElementOutput>((head_size_qk))));
+           host_S[idx] = expf((host_S[idx] - max_vec[max_idx]) / sqrt(static_cast<ElementAccumulator>((head_size_qk))));
          }
        }
 
        // compute sum per row of S
-       std::vector<ElementOutput> sum_vec(seq_len_qo, ElementOutput{0});
+       std::vector<ElementAccumulator> sum_vec(seq_len_qo, ElementAccumulator{0});
        for (int row = 0; row < seq_len_qo; row++) {
          int idx = row * seq_len_kv;
          int sum_idx = row;
@@ -324,9 +323,13 @@ template <class FMHAPrefillKernel, bool isVarLen> struct ExampleRunner {
 
        cutlass::TensorRef ref_P(block_P.get(), LayoutQ::packed({seq_len_qo, seq_len_kv}));
 
-       cutlass::reference::device::GemmComplex({seq_len_qo, head_size_vo, seq_len_kv}, 1.f, ref_P,
+       cutlass::DeviceAllocation<ElementAccumulator> block_acc;
+       block_acc.reset(seq_len_qo * head_size_vo);
+       cutlass::TensorRef ref_acc(block_acc.get(), LayoutO::packed({seq_len_qo, head_size_vo}));
+
+       cutlass::reference::device::GemmComplex({seq_len_qo, head_size_vo, seq_len_kv}, ElementAccumulator{1}, ref_P,
                                                cutlass::ComplexTransform::kNone, ref_V, cutlass::ComplexTransform::kNone,
-                                               0.f, ref_O, ref_O, ElementAccumulator(0),
+                                               ElementAccumulator{0}, ref_acc, ref_acc, ElementAccumulator{0},
                                                1, // batch_count
                                                seq_len_qo * seq_len_kv, // batch_stride_P
                                                seq_len_kv * head_size_vo, // batch_stride_V
@@ -338,6 +341,19 @@ template <class FMHAPrefillKernel, bool isVarLen> struct ExampleRunner {
        // delete this memory as it is no longer needed
        block_P.reset();
 
+       std::vector<ElementAccumulator> vec_acc(block_acc.size());
+       syclcompat::memcpy<ElementAccumulator>(vec_acc.data(), block_acc.get(), vec_acc.size());
+       syclcompat::wait();
+
+       // delete this memory as it is no longer needed
+       block_acc.reset();
+       std::vector<ElementOutput> vec_out(vec_acc.size());
+       for(int i = 0; i < vec_out.size(); i++) {
+         vec_out[i] = static_cast<ElementOutput>(vec_acc[i]);
+       }
+       syclcompat::memcpy<ElementOutput>(block_ref_O.get() + offset_o, vec_out.data(), vec_out.size());
+       syclcompat::wait();
+
        offset_q += seq_len_qo * head_size_qk;
        if(kv_group_update % q_group_size==0) {
          offset_k += seq_len_kv * head_size_qk;
@@ -352,7 +368,7 @@ template <class FMHAPrefillKernel, bool isVarLen> struct ExampleRunner {
 
    // Check if output from CUTLASS kernel and reference kernel are equal or not
    bool passed = cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_O.get(), block_O.get(),
-                                                                         block_O.size(), 0.5f, 0.5f);
+                                                                         block_O.size(), ElementOutput{0.5}, ElementOutput{0.5});
 
    return passed;
  }
@@ -619,7 +635,7 @@ template <bool Causal,
   using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16;
   using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16<PipelineStages>;
   using CollectiveEpilogue = cutlass::flash_attention::collective::FlashPrefillEpilogue<
-      EpilogueDispatchPolicy, MMAOperation, TileShapeOutput, SubgroupLayout, ElementAccumulator, cutlass::gemm::TagToStrideC_t<LayoutO>, ElementOutput,
+      EpilogueDispatchPolicy, MMAOperation, TileShapeOutput, SubgroupLayout, ElementComputeEpilogue, ElementOutput, cutlass::gemm::TagToStrideC_t<LayoutO>, ElementOutput,
       GmemTiledCopyStore>;
   using CollectiveSoftmaxEpilogue = cutlass::flash_attention::collective::FlashPrefillSoftmaxEpilogue<Causal, EpilogueDispatchPolicy, ElementAccumulator>;
 
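
The verification in the hunks above now passes its tolerances to `BlockCompareRelativelyEqual` as `ElementOutput` values rather than `float` literals. As background, here is a rough host-side sketch of what a relative block comparison with an epsilon and a nonzero floor looks like; this is only an illustration of the idea, not the `cutlass::reference::device::BlockCompareRelativelyEqual` implementation, whose exact floor/epsilon handling may differ:

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

using ElementOutput = float;  // stand-in for the kernel's output element type

// Elementwise relative comparison: |a - b| / max(|a|, |b|, floor) <= epsilon.
bool relatively_equal(const std::vector<ElementOutput>& ref,
                      const std::vector<ElementOutput>& out,
                      ElementOutput epsilon, ElementOutput nonzero_floor) {
  if (ref.size() != out.size()) return false;
  for (std::size_t i = 0; i < ref.size(); ++i) {
    ElementOutput a = ref[i], b = out[i];
    ElementOutput denom = std::max(std::max(std::fabs(a), std::fabs(b)), nonzero_floor);
    if (std::fabs(a - b) / denom > epsilon) return false;
  }
  return true;
}

int main() {
  std::vector<ElementOutput> ref{1.0f, 2.0f}, out{1.01f, 1.99f};
  std::printf("passed: %d\n", relatively_equal(ref, out, ElementOutput{0.5}, ElementOutput{0.5}));
  return 0;
}
```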

include/cute/arch/copy_xe_U8.hpp

Lines changed: 1 addition & 0 deletions
@@ -682,6 +682,7 @@ struct XE_2D_U8x4x16_ST_N {
 };
 
 struct XE_2D_U8x8x16_ST_N {
+  using BlockShape = Shape<_8, _16>;
   template <class T>
   CUTE_HOST_DEVICE static void copy(void *baseoffset, int width, int height,
                                     int pitch, intel::coord_t coord,
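
The `BlockShape` alias added to `XE_2D_U8x8x16_ST_N` lets generic code query the store atom's block extent at compile time. A plain C++ sketch of that pattern with hypothetical names (this is not the cute API, which uses `Shape<_8, _16>`):

```cpp
#include <cstdio>

// Hypothetical illustration: a copy atom advertising its block shape as a
// nested type so generic tiling code can query it at compile time.
template <int Rows, int Cols>
struct BlockShape {
  static constexpr int rows = Rows;
  static constexpr int cols = Cols;
};

struct MyStoreAtom {
  using Shape = BlockShape<8, 16>;  // analogous to `using BlockShape = Shape<_8, _16>;`
};

// Generic code can now size buffers or assert tile compatibility from the atom alone.
template <class Atom>
constexpr int elements_per_block() {
  return Atom::Shape::rows * Atom::Shape::cols;
}

int main() {
  static_assert(elements_per_block<MyStoreAtom>() == 128, "8x16 block");
  std::printf("block elements: %d\n", elements_per_block<MyStoreAtom>());
  return 0;
}
```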
