
Commit 5377d14

Separate output and accumulator type for Flash Attention Prefill Cached (#448)

This PR separates the output type and the accumulator type for Flash Attention Prefill Cached. The supported combinations are:

* bf16 inputs, fp32 accumulator, bf16 | fp32 output
* fp16 inputs, fp32 accumulator, fp16 | fp32 output

It also fixes PagedKV cache support when used with variable-length sequences.

Co-authored-by: Alejandro Acosta <[email protected]>
1 parent d5f1886 commit 5377d14
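In code terms, the change decouples the accumulator element type from the output element type. A compile-checkable sketch of the combinations listed above (the type tags are illustrative stand-ins, not the library's types):

#include <type_traits>

// Illustrative tags; in CUTLASS these would be cutlass::bfloat16_t,
// cutlass::half_t, and float.
struct bf16 {};
struct fp16 {};
using fp32 = float;

template <class In, class Acc, class Out>
constexpr bool is_supported_combination =
    std::is_same_v<Acc, fp32> &&  // accumulation always happens in fp32
    ((std::is_same_v<In, bf16> && (std::is_same_v<Out, bf16> || std::is_same_v<Out, fp32>)) ||
     (std::is_same_v<In, fp16> && (std::is_same_v<Out, fp16> || std::is_same_v<Out, fp32>)));

static_assert(is_supported_combination<bf16, fp32, bf16>);
static_assert(is_supported_combination<bf16, fp32, fp32>);
static_assert(is_supported_combination<fp16, fp32, fp16>);
static_assert(!is_supported_combination<fp16, fp32, bf16>);  // mixed 16-bit types are not listed

int main() {}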

File tree

7 files changed: +309 -212 lines changed

applications/flash_attention_v2/collective/xe_flash_attn_prefill_epilogue_cachedKV.hpp

Lines changed: 15 additions & 5 deletions
@@ -53,15 +53,14 @@ template <class DispatchPolicy, class MMAOperation_, class TileShapeOutput_, cla
   static_assert(cutlass::detail::dependent_false<DispatchPolicy>, "Could not find an epilogue specialization.");
 };
 
-template <class MMAOperation_, class TileShapeOutput_, class SubgroupLayout_, class ElementO_, class StrideO_, class ElementLSE_, class CopyOpO_>
-class FlashPrefillCachedEpilogue<epilogue::IntelXeXMX16, MMAOperation_, TileShapeOutput_, SubgroupLayout_, ElementO_, StrideO_, ElementLSE_, CopyOpO_> {
+template <class MMAOperation_, class TileShapeOutput_, class SubgroupLayout_, class ElementCompute_, class ElementO_, class StrideO_, class ElementLSE_, class CopyOpO_>
+class FlashPrefillCachedEpilogue<epilogue::IntelXeXMX16, MMAOperation_, TileShapeOutput_, SubgroupLayout_, ElementCompute_, ElementO_, StrideO_, ElementLSE_, CopyOpO_> {
 public:
   //
   // Type Aliases
   //
   using DispatchPolicy = epilogue::IntelXeXMX16;
   using ElementO = ElementO_;
-  using ElementAccumulator = ElementO_;
   using StrideO = StrideO_;
   using ElementLSE = ElementLSE_;
   using CopyOpO = CopyOpO_;
@@ -70,7 +69,8 @@ class FlashPrefillCachedEpilogue<epilogue::IntelXeXMX16, MMAOperation_, TileSha
   using TiledMmaOutput = typename TiledMMAHelper<MMA_Atom<MMAOperation_>, Layout<TileShapeOutput>, SubgroupLayout>::TiledMMA;
   using GmemTiledCopyO = CopyOpO;
   using ElementOutput = ElementO_;
-  using ElementCompute = ElementO_;
+  using ElementCompute = ElementCompute_;
+  using ElementAccumulator = ElementCompute_;
   using SubgroupTileShape = decltype(cute::shape_div(TileShapeOutput{}, (SubgroupLayout{}.shape())));
 
   static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize;
@@ -197,7 +197,17 @@ class FlashPrefillCachedEpilogue<epilogue::IntelXeXMX16, MMAOperation_, TileSha
     auto thread_xe_store_o = params.xe_store_o.get_thread_slice(ThreadIdxX());
     Tensor tOgO = thread_xe_store_o.partition_D(gO);
 
-    copy(params.xe_store_o, out_reg, tOgO);
+    Tensor final_out_reg = make_fragment_like<ElementOutput>(out_reg);
+    // If ElementOutput == ElementAccumulator, convert_type does not perform the right conversion,
+    // so we call copy(), which internally applies a static_cast to the data.
+    // For ElementOutput == bf16 | fp16, convert_type calls the relevant NumericConverter specialization.
+    if constexpr (cute::is_same_v<ElementOutput, ElementCompute>) {
+      copy(out_reg, final_out_reg);
+    } else {
+      Tensor temp = convert_type<ElementOutput>(out_reg);
+      copy(temp, final_out_reg);
+    }
+    copy(params.xe_store_o, final_out_reg, tOgO);
   }
 
   // SequenceLengthShapeType = Shape<int, int, int>
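The branch above is a compile-time dispatch on type equality: same type means a plain element-wise copy (a static_cast under the hood), different types route through the numeric converter. A minimal standalone model of that pattern, with a hypothetical truncating bf16 stand-in rather than the real cutlass::bfloat16_t:

#include <cstdint>
#include <cstring>
#include <type_traits>

// Hypothetical 16-bit output type; the real code uses cutlass::bfloat16_t.
struct bf16_t {
  uint16_t bits = 0;
  bf16_t() = default;
  explicit bf16_t(float f) {
    uint32_t u;
    std::memcpy(&u, &f, sizeof(u));
    bits = static_cast<uint16_t>(u >> 16);  // truncating fp32 -> bf16
  }
};

// Convert an accumulator fragment into the requested output type.
template <class ElementOutput, class ElementCompute, int N>
void store_fragment(const ElementCompute (&acc)[N], ElementOutput (&out)[N]) {
  if constexpr (std::is_same_v<ElementOutput, ElementCompute>) {
    for (int i = 0; i < N; ++i) out[i] = acc[i];                  // plain copy path
  } else {
    for (int i = 0; i < N; ++i) out[i] = ElementOutput(acc[i]);  // converter path
  }
}

int main() {
  float acc[4] = {0.5f, 1.25f, -2.0f, 3.0f};
  float out_f32[4];
  bf16_t out_bf16[4];
  store_fragment(acc, out_f32);   // fp32 accumulator, fp32 output
  store_fragment(acc, out_bf16);  // fp32 accumulator, bf16 output
}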

applications/flash_attention_v2/collective/xe_flash_attn_prefill_mma_cachedKV.hpp

Lines changed: 2 additions & 2 deletions
@@ -164,7 +164,7 @@ struct FlashPrefillCachedMma<gemm::MainloopIntelXeXMX16<Stages>, ProblemShapeTyp
     // Paged KV Cache
     int const* ptr_page_table;
     int page_size;
-    int num_pages_per_seq;
+    int const* num_pages_per_seq;
   };
 
   struct Params {
@@ -176,7 +176,7 @@ struct FlashPrefillCachedMma<gemm::MainloopIntelXeXMX16<Stages>, ProblemShapeTyp
     // Paged KV Cache
     int const* ptr_page_table;
     int page_size;
-    int num_pages_per_seq;
+    int const* num_pages_per_seq;
   };
 
   //
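Turning num_pages_per_seq into a pointer is what makes variable-length batches work: each sequence can own a different number of pages, so a single int can no longer describe the layout. Judging by how the kernel indexes it below, the array behaves like an exclusive prefix sum over per-sequence page counts. A small sketch of that layout (the helper name is ours, not the repo's):

#include <cassert>
#include <vector>

// num_pages_per_seq as an exclusive prefix sum: entry b is the offset of
// sequence b's first entry in the flat page table, and entry b+1 minus
// entry b is that sequence's page count.
std::vector<int> build_num_pages_per_seq(const std::vector<int>& pages_per_seq) {
  std::vector<int> prefix(pages_per_seq.size() + 1, 0);
  for (size_t b = 0; b < pages_per_seq.size(); ++b)
    prefix[b + 1] = prefix[b] + pages_per_seq[b];
  return prefix;
}

int main() {
  // three sequences holding 3, 1, and 4 KV-cache pages respectively
  auto prefix = build_num_pages_per_seq({3, 1, 4});  // {0, 3, 4, 8}
  assert(prefix[1] == 3);              // sequence 1 starts at page-table slot 3
  assert(prefix[2] - prefix[1] == 1);  // and owns a single page
}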

applications/flash_attention_v2/kernel/xe_flash_attn_prefill_cachedKV.hpp

Lines changed: 56 additions & 21 deletions
@@ -178,7 +178,8 @@ class FMHAPrefillCached {
   static bool can_implement(Arguments const &args) {
     bool mode_implementable = args.mode == gemm::GemmUniversalMode::kGemm or
                               (args.mode == gemm::GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4);
-    return mode_implementable;
+    bool valid_page_size = !PagedKV || (args.mainloop.page_size >= QK_BLK_N && args.mainloop.page_size % QK_BLK_N == 0);
+    return mode_implementable && valid_page_size;
   }
 
   static int get_workspace_size(Arguments const &args) { return 0; }
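The new can_implement clause rejects page sizes that K tiles cannot cover exactly. The same predicate with concrete numbers (the QK_BLK_N value here is only an example):

#include <cassert>

// A page must hold a whole number of K tiles: at least one tile wide,
// and an exact multiple of the tile width.
constexpr bool valid_page_size(int page_size, int qk_blk_n) {
  return page_size >= qk_blk_n && page_size % qk_blk_n == 0;
}

int main() {
  constexpr int QK_BLK_N = 64;                    // illustrative tile width
  static_assert(valid_page_size(128, QK_BLK_N));  // two tiles per page: accepted
  static_assert(!valid_page_size(96, QK_BLK_N));  // 1.5 tiles: rejected
  static_assert(!valid_page_size(32, QK_BLK_N));  // smaller than a tile: rejected
}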
@@ -314,10 +315,22 @@
     }
     auto& prefetch_K = (seq_len_kv_cache == 0) ? tiled_prefetch_k : tiled_prefetch_k_cache;
     auto& pKgK1_ = (seq_len_kv_cache == 0) ? pKgK : pKgK_cache;
+
+    int cached_nblock = 0;
+    if constexpr (PagedKV) {
+      if (seq_len_kv_cache != 0) {
+        int curr_batch_pages = is_var_len ? mainloop_params.num_pages_per_seq[batch_coord + 1] - mainloop_params.num_pages_per_seq[batch_coord]
+                                          : ceil_div(seq_len_kv_cache, mainloop_params.page_size);
+        int batch_offset = is_var_len ? mainloop_params.num_pages_per_seq[batch_coord] : batch_coord * curr_batch_pages;
+        cached_nblock = mainloop_params.ptr_page_table[
+                            batch_offset      // page table for this batch
+                        ] * tiles_per_page;   // base block idx of physical page
+      }
+    }
     // The headsize for both cached and non-cached version is the same
     for (int j = 0; j < size<4>(pKgK1_); j++) {
       CUTLASS_PRAGMA_UNROLL
-      for (int i = 0; i < DispatchPolicy::Stages; i++) {
+      for (int i = cached_nblock; i < cached_nblock + DispatchPolicy::Stages; i++) {
         prefetch(prefetch_K, pKgK1_(_, _, _, i, j));
       }
     }
@@ -345,18 +358,6 @@
 
       bool is_KV_cache = nblock < nblock_cache;
 
-      int cached_nblock = nblock;
-      if constexpr (PagedKV) {
-        if (is_KV_cache) {
-          // get physical page idx from page table
-          cached_nblock = params.mainloop.ptr_page_table[
-                              batch_coord * params.mainloop.num_pages_per_seq + // page table for this batch
-                              nblock * QK_BLK_N / params.mainloop.page_size     // nblock (tile idx) to logical page idx
-                          ] * tiles_per_page +                                  // base block idx of physical page
-                          nblock % tiles_per_page;                              // offset within page
-        }
-      }
-
       // 1) Load KV (performed inside mmaQK)
       auto gK_ = is_KV_cache ? gK_cache(_, _, cached_nblock, _) : gK(_, _, nblock - nblock_cache, _);
       auto gV_ = is_KV_cache ? gV_cache(_, _, cached_nblock) : gV(_, _, nblock - nblock_cache);
@@ -372,8 +373,32 @@
       // prefetching it the same way as cutlass K matrix does not make sense
       auto& tiled_prefetch_v_ = is_KV_cache ? tiled_prefetch_v_cache : tiled_prefetch_v;
       auto& pVgV_ = is_KV_cache ? pVgV_cache : pVgV;
-      for (int i = 0; i < size<1>(pVgV); i++) {
-        prefetch(tiled_prefetch_v_, pVgV_cache(_, i, _, nblock - (!is_KV_cache) * nblock_cache));
+      int v_prefetch_idx = is_KV_cache ? PagedKV ? cached_nblock : nblock
+                                       : nblock - nblock_cache;
+      for (int i = 0; i < size<1>(pVgV_); i++) {
+        prefetch(tiled_prefetch_v_, pVgV_(_, i, _, v_prefetch_idx));
+      }
+
+      int next_cached_nblock = nblock + 1;
+      bool is_next_KV_cache = next_cached_nblock < nblock_cache;
+      if constexpr (PagedKV) {
+        if (is_next_KV_cache) {
+          int curr_batch_pages = is_var_len ? mainloop_params.num_pages_per_seq[batch_coord + 1] - mainloop_params.num_pages_per_seq[batch_coord]
+                                            : ceil_div(seq_len_kv_cache, mainloop_params.page_size);
+          int next_page_logical_idx = next_cached_nblock * QK_BLK_N / params.mainloop.page_size;
+          int batch_offset = is_var_len ? mainloop_params.num_pages_per_seq[batch_coord] : batch_coord * curr_batch_pages;
+          bool valid_page = next_page_logical_idx < curr_batch_pages;
+          // get physical page idx from page table
+          if (valid_page) {
+            next_cached_nblock = params.mainloop.ptr_page_table[
+                                     batch_offset +          // page table for this batch
+                                     next_page_logical_idx   // nblock (tile idx) to logical page idx
+                                 ] * tiles_per_page +        // base block idx of physical page
+                                 next_cached_nblock % tiles_per_page; // offset within page
+          } else {
+            next_cached_nblock = curr_batch_pages * tiles_per_page; // push idx out of bounds to respect the boundary between batches
+          }
+        }
       }
 
      // 4) Fused softmax
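This lookahead block is the heart of the paged-KV fix: a KV tile index is mapped to a logical page, the page table translates that to a physical page, and the tile's offset within the page is added back. A standalone model of the index arithmetic (values illustrative):

#include <cassert>
#include <vector>

// Map a logical KV tile index to a physical tile index through a page table,
// where tiles_per_page = page_size / QK_BLK_N as in the kernel.
int logical_to_physical_tile(const std::vector<int>& page_table, int batch_offset,
                             int tile_idx, int tiles_per_page) {
  int logical_page  = tile_idx / tiles_per_page;                // page holding the tile
  int physical_page = page_table[batch_offset + logical_page];  // page-table indirection
  return physical_page * tiles_per_page + tile_idx % tiles_per_page;
}

int main() {
  // One batch owning 3 logical pages scattered to physical pages 7, 2, 5;
  // two tiles fit in each page.
  std::vector<int> page_table = {7, 2, 5};
  assert(logical_to_physical_tile(page_table, 0, 0, 2) == 14);  // page 7, first tile
  assert(logical_to_physical_tile(page_table, 0, 3, 2) == 5);   // page 2, second tile
}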
@@ -382,16 +407,26 @@
 
       // 5) Perform GEMM O = S*V
       collective_mma.template mmaPV<VSlicer>(out_reg, tSr, gV_, out_reg, mainloop_params, is_KV_cache);
-
+
+      // Prefetch the next Q tile
+      CUTLASS_PRAGMA_UNROLL
+      for (int i = 0; i < size<3>(pQgQ); i++) {
+        prefetch(tiled_prefetch_q, pQgQ(_, _, _, i));
+      }
+
+      is_KV_cache = is_next_KV_cache;
+      cached_nblock = next_cached_nblock;
       // Prefetch the next K tile
       // there is no need to guard it with an if statement as prefetch will ignore out-of-bound reads
 
       bool sel_prefetch_k = (nblock + DispatchPolicy::Stages) < nblock_cache;
       auto& prefetch_k_selector = sel_prefetch_k ? tiled_prefetch_k_cache : tiled_prefetch_k;
       auto& pKgK_ = sel_prefetch_k ? pKgK_cache : pKgK;
+      int k_prefetch_idx = sel_prefetch_k ? PagedKV ? cached_nblock : nblock + DispatchPolicy::Stages
+                                          : nblock + DispatchPolicy::Stages - nblock_cache;
       CUTLASS_PRAGMA_UNROLL
       for (int j = 0; j < size<4>(pKgK_); j++) {
-        prefetch(prefetch_k_selector, pKgK_(_, _, _, (nblock + DispatchPolicy::Stages) - (!sel_prefetch_k) * nblock_cache, j));
+        prefetch(prefetch_k_selector, pKgK_(_, _, _, k_prefetch_idx, j));
       }
       barrier_wait(barrier_scope);
     }
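Structurally, the mainloop now resolves the next iteration's physical block while the current tiles are still being consumed, then commits the lookahead just before the K prefetch. A stripped-down model of that pipelining, indices only (no CUTLASS):

#include <vector>

// Skeleton of the lookahead pattern: translate block n+1 while computing on
// block n, then swap so the K prefetch always targets the right page.
void process_blocks(const std::vector<int>& physical_of_logical, int nblocks) {
  int cur = physical_of_logical[0];  // resolved once before the loop
  for (int n = 0; n < nblocks; ++n) {
    // ... compute on block 'cur' (mmaQK, softmax, mmaPV in the kernel) ...
    int next = (n + 1 < nblocks) ? physical_of_logical[n + 1]
                                 : nblocks;  // out-of-bounds sentinel, as in the kernel
    cur = next;  // commit the lookahead before issuing the K prefetch
    // ... prefetch K for block 'cur' ...
  }
}

int main() { process_blocks({2, 0, 1}, 3); }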
@@ -406,8 +441,8 @@
     collective_mma.mmaQK(tSr, gQ, gK(_, _, nblock_new - 1, _), tSr, ceil_div(head_size_qk, QK_BLK_K), mainloop_params, false);
     // we only need one block ahead, there is enough gap to prefetch it while doing softmax. because the gap between the two MMA is big,
     // prefetching it the same way as cutlass K matrix does not make sense
-    for(int i=0; i< size<1>(pVgV); i++) {
-      prefetch(tiled_prefetch_v, pVgV(_, i, _, nblock_new - 1));
+    for (int i = 0; i < size<1>(pVgV); i++) {
+      prefetch(tiled_prefetch_v, pVgV(_, i, _, nblock_new - 1));
     }
     // mask the elements of each tile where j > i
     const int item_id = thread_idx % SubgroupSize;
@@ -420,7 +455,7 @@
       CUTLASS_PRAGMA_UNROLL
       for (int row = 0; row < Vec; row++, row_idx++) { // 8
         if (col_idx - full_tile_offset > row_idx - discard_seq_coord) {
-          tSr(row, m, n) = -INFINITY;
+          tSr(row, m, n) = ElementAccumulator{-INFINITY};
         }
       }
     }
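For reference, the masking predicate above implements the causal rule: discard scores whose key column lies beyond the query row once the cached prefix is accounted for. A simplified host-side rendering of the same condition (sizes illustrative, discard_seq_coord folded to zero):

#include <cstdio>

int main() {
  const int seq_len_qo = 4, seq_len_kv = 6;
  const int full_tile_offset = seq_len_kv - seq_len_qo;  // queries align with the last rows
  for (int row = 0; row < seq_len_qo; ++row) {
    for (int col = 0; col < seq_len_kv; ++col)
      std::putchar(col - full_tile_offset > row ? '.' : 'x');  // '.' = masked to -inf
    std::putchar('\n');
  }
}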

benchmarks/flash_attention/flash_attention_prefill_cachedKV/benchmark_runner.hpp

Lines changed: 29 additions & 13 deletions
@@ -207,7 +207,7 @@ template <class FMHAPrefillConfiguration> struct BenchmarkRunnerFMHA {
     int kv_group_update = 1;
 
     for (int h = 0; h < num_heads_q; h++) {
-      cutlass::DeviceAllocation<ElementOutput> block_S;
+      cutlass::DeviceAllocation<ElementAccumulator> block_S;
       block_S.reset(seq_len_qo * seq_len_kv_total);
 
       ElementK* k_ptr;
@@ -254,11 +254,10 @@
       cutlass::TensorRef ref_K(k_ptr, LayoutK::packed({head_size_qk, seq_len_kv_total}));
       cutlass::TensorRef ref_V(v_ptr, LayoutV::packed({seq_len_kv_total, head_size_vo}));
       cutlass::TensorRef ref_S(block_S.get(), LayoutQ::packed({seq_len_qo, seq_len_kv_total}));
-      cutlass::TensorRef ref_O(block_ref_O.get() + offset_o, LayoutO::packed({seq_len_qo, head_size_vo}));
 
-      cutlass::reference::device::GemmComplex({seq_len_qo, seq_len_kv_total, head_size_qk}, 1.f, ref_Q,
+      cutlass::reference::device::GemmComplex({seq_len_qo, seq_len_kv_total, head_size_qk}, ElementAccumulator{1}, ref_Q,
                                               cutlass::ComplexTransform::kNone, ref_K, cutlass::ComplexTransform::kNone,
-                                              0.f, ref_S, ref_S, ElementAccumulator(0),
+                                              ElementAccumulator{0}, ref_S, ref_S, ElementAccumulator{0},
                                               1,                               // batch_count
                                               seq_len_qo * head_size_qk,       // batch_stride_Q
                                               seq_len_kv_total * head_size_qk, // batch_stride_K
@@ -268,8 +267,8 @@
 
       syclcompat::wait();
 
-      std::vector<ElementOutput> host_S(block_S.size());
-      syclcompat::memcpy<ElementOutput>(host_S.data(), block_S.get(), host_S.size());
+      std::vector<ElementAccumulator> host_S(block_S.size());
+      syclcompat::memcpy<ElementAccumulator>(host_S.data(), block_S.get(), host_S.size());
       syclcompat::wait();
 
       // delete this memory as it is no longer needed
@@ -283,13 +282,13 @@
       for (int row = 0; row < seq_len_qo; row++) {
         for (int col = start_col; col < seq_len_kv_total; col++) {
           if (col - full_tile_offset > row + start_col - discard_seq_coord)
-            host_S[col + row * seq_len_kv_total] = -INFINITY;
+            host_S[col + row * seq_len_kv_total] = ElementAccumulator{-INFINITY};
         }
       }
     }
 
     // compute max element per row of S
-    std::vector<ElementOutput> max_vec(seq_len_qo, -INFINITY);
+    std::vector<ElementAccumulator> max_vec(seq_len_qo, ElementAccumulator{-INFINITY});
     for (int row = 0; row < seq_len_qo; row++) {
       int idx = row * seq_len_kv_total;
       int max_idx = row;
@@ -305,12 +304,12 @@
       int idx = row * seq_len_kv_total;
       int max_idx = row;
       for (int col = 0; col < seq_len_kv_total; col++, idx++) {
-        host_S[idx] = expf((host_S[idx] - max_vec[max_idx]) / std::sqrt(static_cast<ElementOutput>(head_size_qk)));
+        host_S[idx] = expf((host_S[idx] - max_vec[max_idx]) / std::sqrt(static_cast<ElementAccumulator>(head_size_qk)));
       }
     }
 
     // compute sum per row of S
-    std::vector<ElementOutput> sum_vec(seq_len_qo, ElementOutput{0});
+    std::vector<ElementAccumulator> sum_vec(seq_len_qo, ElementAccumulator{0});
     for (int row = 0; row < seq_len_qo; row++) {
       int idx = row * seq_len_kv_total;
       int sum_idx = row;
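The host reference above now carries S in ElementAccumulator through all three softmax passes: row max, exp((s - max)/sqrt(head_size_qk)), and row sum. A compact standalone version of that reference softmax (fp32 everywhere, shapes illustrative):

#include <algorithm>
#include <cmath>
#include <vector>

// Numerically stable scaled softmax over each row of S, mirroring the
// reference: subtract the row max, scale by 1/sqrt(head_size_qk),
// exponentiate, then normalize by the row sum.
void softmax_rows(std::vector<float>& S, int rows, int cols, int head_size_qk) {
  const float inv_sqrt_d = 1.0f / std::sqrt(static_cast<float>(head_size_qk));
  for (int r = 0; r < rows; ++r) {
    float mx = -INFINITY;
    for (int c = 0; c < cols; ++c) mx = std::max(mx, S[r * cols + c]);
    float sum = 0.0f;
    for (int c = 0; c < cols; ++c) {
      float e = std::exp((S[r * cols + c] - mx) * inv_sqrt_d);
      S[r * cols + c] = e;
      sum += e;
    }
    for (int c = 0; c < cols; ++c) S[r * cols + c] /= sum;
  }
}

int main() {
  std::vector<float> S = {1.0f, 2.0f, 3.0f, 4.0f};
  softmax_rows(S, /*rows=*/2, /*cols=*/2, /*head_size_qk=*/64);
}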
@@ -342,9 +341,13 @@
 
       cutlass::TensorRef ref_P(block_P.get(), LayoutQ::packed({seq_len_qo, seq_len_kv_total}));
 
-      cutlass::reference::device::GemmComplex({seq_len_qo, head_size_vo, seq_len_kv_total}, 1.f, ref_P,
+      cutlass::DeviceAllocation<ElementAccumulator> block_acc;
+      block_acc.reset(seq_len_qo * head_size_vo);
+      cutlass::TensorRef ref_acc(block_acc.get(), LayoutO::packed({seq_len_qo, head_size_vo}));
+
+      cutlass::reference::device::GemmComplex({seq_len_qo, head_size_vo, seq_len_kv_total}, ElementAccumulator{1}, ref_P,
                                               cutlass::ComplexTransform::kNone, ref_V, cutlass::ComplexTransform::kNone,
-                                              0.f, ref_O, ref_O, ElementAccumulator(0),
+                                              ElementAccumulator{0}, ref_acc, ref_acc, ElementAccumulator{0},
                                               1,                               // batch_count
                                               seq_len_qo * seq_len_kv_total,   // batch_stride_P
                                               seq_len_kv_total * head_size_vo, // batch_stride_V
@@ -356,6 +359,19 @@
       // delete this memory as it is no longer needed
       block_P.reset();
 
+      std::vector<ElementAccumulator> vec_acc(block_acc.size());
+      syclcompat::memcpy<ElementAccumulator>(vec_acc.data(), block_acc.get(), vec_acc.size());
+      syclcompat::wait();
+
+      // delete this memory as it is no longer needed
+      block_acc.reset();
+      std::vector<ElementOutput> vec_out(vec_acc.size());
+      for (int i = 0; i < vec_out.size(); i++) {
+        vec_out[i] = static_cast<ElementOutput>(vec_acc[i]);
+      }
+      syclcompat::memcpy<ElementOutput>(block_ref_O.get() + offset_o, vec_out.data(), vec_out.size());
+      syclcompat::wait();
+
       offset_q += seq_len_qo * head_size_qk;
       if (kv_group_update % q_group_size == 0) {
         offset_k += seq_len_kv * head_size_qk;
@@ -372,7 +388,7 @@
 
     // Check if output from CUTLASS kernel and reference kernel are equal or not
     bool passed = cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_O.get(), block_O.get(),
-                                                                          block_O.size(), 0.5f, 0.5f);
+                                                                          block_O.size(), ElementOutput{0.5}, ElementOutput{0.5});
 
     return passed;
   }
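The net effect on the reference path: the P·V product now lands in an fp32 staging buffer (block_acc) and is downcast to ElementOutput exactly once before the comparison. A host-side model of that final step, with a truncating 16-bit stand-in for the real output type:

#include <cstdint>
#include <cstring>
#include <vector>

// Hypothetical 16-bit output element; the real code uses cutlass::half_t
// or cutlass::bfloat16_t.
struct out16_t { uint16_t bits = 0; };

out16_t to_output(float x) {  // the single fp32 -> output downcast
  uint32_t u;
  std::memcpy(&u, &x, sizeof(u));
  return out16_t{static_cast<uint16_t>(u >> 16)};  // bf16-style truncation
}

int main() {
  std::vector<float> vec_acc = {0.1f, 0.2f, 0.3f};  // fp32 reference accumulator
  std::vector<out16_t> vec_out(vec_acc.size());
  for (size_t i = 0; i < vec_out.size(); ++i)
    vec_out[i] = to_output(vec_acc[i]);             // convert once, at the very end
}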

benchmarks/flash_attention/flash_attention_prefill_cachedKV/fmha_prefill_configuration.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ struct FMHAPrefillConfig {
6969
using MMAOperation = typename MMAOP<GEMMDispatchPolicy, ElementInputType,ElementAccumulator>::Type;
7070
using CollectiveEpilogue = cutlass::flash_attention::collective::FlashPrefillCachedEpilogue<
7171
EpilogueDispatchPolicy, MMAOperation, TileShapeOutput,
72-
SubgroupLayout, ElementAccumulator,
72+
SubgroupLayout, ElementAccumulator, ElementOutput,
7373
cutlass::gemm::TagToStrideC_t<LayoutO>, ElementOutput,
7474
GmemTiledCopyO>;
7575
