@@ -230,14 +230,13 @@ template <class FMHAPrefillKernel, bool isVarLen> struct ExampleRunner {
      }
      int kv_group_update = 1;
      for (int h = 0; h < num_heads_q; h++) {
-       cutlass::DeviceAllocation<ElementOutput> block_S;
+       cutlass::DeviceAllocation<ElementAccumulator> block_S;
        block_S.reset(seq_len_qo * seq_len_kv);

        cutlass::TensorRef ref_Q(block_Q_.get() + offset_q, LayoutQ::packed({seq_len_qo, head_size_qk}));
        cutlass::TensorRef ref_K(block_K_.get() + offset_k, LayoutK::packed({head_size_qk, seq_len_kv}));
        cutlass::TensorRef ref_V(block_V_.get() + offset_v, LayoutV::packed({seq_len_kv, head_size_vo}));
        cutlass::TensorRef ref_S(block_S.get(), LayoutQ::packed({seq_len_qo, seq_len_kv}));
-       cutlass::TensorRef ref_O(block_ref_O.get() + offset_o, LayoutO::packed({seq_len_qo, head_size_vo}));

        cutlass::reference::device::GemmComplex({seq_len_qo, seq_len_kv, head_size_qk}, 1.f, ref_Q,
                                                cutlass::ComplexTransform::kNone, ref_K, cutlass::ComplexTransform::kNone,
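Note on this hunk: the reference now stages the intermediate scores S = Q·K^T in `ElementAccumulator` (float) instead of `ElementOutput`, and drops the `ref_O` TensorRef that the old code accumulated into, so the softmax that follows runs at full accumulator precision. A minimal standalone sketch (not from this PR; `to_bf16` is a hypothetical stand-in for storing through a narrow output type) of the rounding error that staging scores in a narrow type introduces:

```cpp
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>

// Hypothetical helper: truncate a float to bfloat16 precision
// (8-bit mantissa) by zeroing the low 16 bits, mimicking a store
// through a narrow ElementOutput type.
static float to_bf16(float x) {
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof bits);
  bits &= 0xFFFF0000u;
  std::memcpy(&x, &bits, sizeof bits);
  return x;
}

int main() {
  float s = 257.3f; // a representative pre-softmax score
  std::printf("float score: %.4f  bf16 score: %.4f\n", s, to_bf16(s));
  // The rounding error feeds straight into expf() and skews every
  // probability in that softmax row.
  std::printf("exp drift: %g\n",
              std::exp(s - 257.0f) - std::exp(to_bf16(s) - 257.0f));
}
```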
@@ -251,8 +250,8 @@ template <class FMHAPrefillKernel, bool isVarLen> struct ExampleRunner {

        syclcompat::wait();

-       std::vector<ElementOutput> host_S(block_S.size());
-       syclcompat::memcpy<ElementOutput>(host_S.data(), block_S.get(), host_S.size());
+       std::vector<ElementAccumulator> host_S(block_S.size());
+       syclcompat::memcpy<ElementAccumulator>(host_S.data(), block_S.get(), host_S.size());
        syclcompat::wait();

        // delete this memory as it is no longer needed
@@ -265,13 +264,13 @@ template <class FMHAPrefillKernel, bool isVarLen> struct ExampleRunner {
          for (int row = 0; row < seq_len_qo; row++) {
            for (int col = 0; col < seq_len_kv; col++) {
              if ((col - full_tile_offset) > (row - discard_seq_coord))
-               host_S[col + row * seq_len_kv] = -INFINITY;
+               host_S[col + row * seq_len_kv] = ElementAccumulator{-INFINITY};
            }
          }
        }

        // compute max element per row of S
-       std::vector<ElementOutput> max_vec(seq_len_qo, -INFINITY);
+       std::vector<ElementAccumulator> max_vec(seq_len_qo, ElementAccumulator{-INFINITY});
        for (int row = 0; row < seq_len_qo; row++) {
          int idx = row * seq_len_kv;
          int max_idx = row;
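The masking loop in the hunk above writes `-INFINITY` (now wrapped as `ElementAccumulator{-INFINITY}`) above the shifted causal diagonal, so the later `expf()` turns those entries into exact zeros. A small host-side sketch of the same pattern, with made-up values for `full_tile_offset` and `discard_seq_coord`:

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  int seq_len_qo = 4, seq_len_kv = 6;
  int full_tile_offset = 2, discard_seq_coord = 0; // illustrative only
  std::vector<float> S(seq_len_qo * seq_len_kv, 1.0f);
  for (int row = 0; row < seq_len_qo; row++)
    for (int col = 0; col < seq_len_kv; col++)
      if ((col - full_tile_offset) > (row - discard_seq_coord))
        S[col + row * seq_len_kv] = -INFINITY; // masked: expf() -> 0
  // Print the mask pattern: -inf above the shifted diagonal.
  for (int row = 0; row < seq_len_qo; row++) {
    for (int col = 0; col < seq_len_kv; col++)
      std::printf("%6.1f", S[col + row * seq_len_kv]);
    std::printf("\n");
  }
}
```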
@@ -287,12 +286,12 @@ template <class FMHAPrefillKernel, bool isVarLen> struct ExampleRunner {
          int idx = row * seq_len_kv;
          int max_idx = row;
          for (int col = 0; col < seq_len_kv; col++, idx++) {
-           host_S[idx] = expf((host_S[idx] - max_vec[max_idx]) / sqrt(static_cast<ElementOutput>((head_size_qk))));
+           host_S[idx] = expf((host_S[idx] - max_vec[max_idx]) / sqrt(static_cast<ElementAccumulator>((head_size_qk))));
          }
        }

        // compute sum per row of S
-       std::vector<ElementOutput> sum_vec(seq_len_qo, ElementOutput{0});
+       std::vector<ElementAccumulator> sum_vec(seq_len_qo, ElementAccumulator{0});
        for (int row = 0; row < seq_len_qo; row++) {
          int idx = row * seq_len_kv;
          int sum_idx = row;
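The last two hunks together implement the usual numerically stable softmax: subtract the per-row max, scale by 1/sqrt(head_size_qk), exponentiate, then normalize by the per-row sum. Condensed into one host-side function (a plain-float sketch, not the PR's exact code):

```cpp
#include <cmath>
#include <vector>

// Row-wise stable softmax over an M x N score matrix, scaling the
// shifted logits by 1/sqrt(d) as the reference above does.
void softmax_rows(std::vector<float>& S, int M, int N, int d) {
  const float inv_sqrt_d = 1.0f / std::sqrt(static_cast<float>(d));
  for (int row = 0; row < M; row++) {
    float* r = S.data() + row * N;
    float m = -INFINITY;
    for (int col = 0; col < N; col++) m = std::fmax(m, r[col]); // row max
    float sum = 0.0f;
    for (int col = 0; col < N; col++) {
      r[col] = std::exp((r[col] - m) * inv_sqrt_d); // shifted, scaled exp
      sum += r[col];
    }
    for (int col = 0; col < N; col++) r[col] /= sum; // normalize
  }
}
```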
@@ -324,9 +323,13 @@ template <class FMHAPrefillKernel, bool isVarLen> struct ExampleRunner {

        cutlass::TensorRef ref_P(block_P.get(), LayoutQ::packed({seq_len_qo, seq_len_kv}));

-       cutlass::reference::device::GemmComplex({seq_len_qo, head_size_vo, seq_len_kv}, 1.f, ref_P,
+       cutlass::DeviceAllocation<ElementAccumulator> block_acc;
+       block_acc.reset(seq_len_qo * head_size_vo);
+       cutlass::TensorRef ref_acc(block_acc.get(), LayoutO::packed({seq_len_qo, head_size_vo}));
+
+       cutlass::reference::device::GemmComplex({seq_len_qo, head_size_vo, seq_len_kv}, ElementAccumulator{1}, ref_P,
                                                cutlass::ComplexTransform::kNone, ref_V, cutlass::ComplexTransform::kNone,
-                                               0.f, ref_O, ref_O, ElementAccumulator(0),
+                                               ElementAccumulator{0}, ref_acc, ref_acc, ElementAccumulator{0},
                                                1, // batch_count
                                                seq_len_qo * seq_len_kv, // batch_stride_P
                                                seq_len_kv * head_size_vo, // batch_stride_V
@@ -338,6 +341,19 @@ template <class FMHAPrefillKernel, bool isVarLen> struct ExampleRunner {
        // delete this memory as it is no longer needed
        block_P.reset();

+       std::vector<ElementAccumulator> vec_acc(block_acc.size());
+       syclcompat::memcpy<ElementAccumulator>(vec_acc.data(), block_acc.get(), vec_acc.size());
+       syclcompat::wait();
+
+       // delete this memory as it is no longer needed
+       block_acc.reset();
+       std::vector<ElementOutput> vec_out(vec_acc.size());
+       for (int i = 0; i < vec_out.size(); i++) {
+         vec_out[i] = static_cast<ElementOutput>(vec_acc[i]);
+       }
+       syclcompat::memcpy<ElementOutput>(block_ref_O.get() + offset_o, vec_out.data(), vec_out.size());
+       syclcompat::wait();
+
        offset_q += seq_len_qo * head_size_qk;
        if (kv_group_update % q_group_size == 0) {
          offset_k += seq_len_kv * head_size_qk;
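The added block above is the core of the fix: the P·V GEMM now lands in a float `block_acc` and is converted to `ElementOutput` exactly once when copied into `block_ref_O`, instead of accumulating through `ref_O` in the narrow type. A standalone sketch (again with the hypothetical `to_bf16` stand-in for a narrow output type) of why one final rounding beats rounding every partial sum:

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

static float to_bf16(float x) { // truncate to bfloat16 precision
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof bits);
  bits &= 0xFFFF0000u;
  std::memcpy(&x, &bits, sizeof bits);
  return x;
}

int main() {
  float acc_wide = 0.0f, acc_narrow = 0.0f;
  for (int i = 0; i < 4096; i++) {
    acc_wide += 0.001f;                        // float accumulator
    acc_narrow = to_bf16(acc_narrow + 0.001f); // rounds every step
  }
  // The narrow accumulator stalls once its ulp exceeds the addend;
  // the wide one is rounded only on the final store.
  std::printf("round once: %f  round every step: %f\n",
              to_bf16(acc_wide), acc_narrow);
}
```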
@@ -352,7 +368,7 @@ template <class FMHAPrefillKernel, bool isVarLen> struct ExampleRunner {

        // Check if output from CUTLASS kernel and reference kernel are equal or not
        bool passed = cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_O.get(), block_O.get(),
-                                                                             block_O.size(), 0.5f, 0.5f);
+                                                                             block_O.size(), ElementOutput{0.5}, ElementOutput{0.5});

        return passed;
      }
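The tolerance arguments to `BlockCompareRelativelyEqual` are now spelled as `ElementOutput{0.5}` so they pass through the actual output type rather than raw `float` literals. For intuition, a host-side sketch of one common relative-equality formulation (an assumption for illustration; CUTLASS's own `relatively_equal` is the authoritative element-wise rule):

```cpp
#include <cmath>
#include <cstddef>

// Sketch: blocks a and b match when each pair is within a relative
// epsilon, with nonzero_floor guarding the near-zero case. This
// mirrors the shape, not necessarily the exact definition, of the
// CUTLASS device-side check.
template <class Element>
bool block_relatively_equal(const Element* a, const Element* b, std::size_t n,
                            Element epsilon, Element nonzero_floor) {
  for (std::size_t i = 0; i < n; i++) {
    float x = static_cast<float>(a[i]);
    float y = static_cast<float>(b[i]);
    float bound = static_cast<float>(epsilon) * std::fmax(std::fabs(x), std::fabs(y))
                + static_cast<float>(nonzero_floor);
    if (std::fabs(x - y) > bound)
      return false;
  }
  return true;
}
```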
@@ -619,7 +635,7 @@ template <bool Causal,
    using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16;
    using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16<PipelineStages>;
    using CollectiveEpilogue = cutlass::flash_attention::collective::FlashPrefillEpilogue<
-       EpilogueDispatchPolicy, MMAOperation, TileShapeOutput, SubgroupLayout, ElementAccumulator, cutlass::gemm::TagToStrideC_t<LayoutO>, ElementOutput,
+       EpilogueDispatchPolicy, MMAOperation, TileShapeOutput, SubgroupLayout, ElementComputeEpilogue, ElementOutput, cutlass::gemm::TagToStrideC_t<LayoutO>, ElementOutput,
        GmemTiledCopyStore>;
    using CollectiveSoftmaxEpilogue = cutlass::flash_attention::collective::FlashPrefillSoftmaxEpilogue<Causal, EpilogueDispatchPolicy, ElementAccumulator>;