diff --git a/applications/flash_attention_v2/collective/xe_flash_attn_prefill_epilogue_cachedKV.hpp b/applications/flash_attention_v2/collective/xe_flash_attn_prefill_epilogue_cachedKV.hpp
index 9e26e1c94f..f115bf6005 100644
--- a/applications/flash_attention_v2/collective/xe_flash_attn_prefill_epilogue_cachedKV.hpp
+++ b/applications/flash_attention_v2/collective/xe_flash_attn_prefill_epilogue_cachedKV.hpp
@@ -53,15 +53,14 @@ template , "Could not find an epilogue specialization."); };
-template 
-class FlashPrefillCachedEpilogue {
+template 
+class FlashPrefillCachedEpilogue {
 public:
   //
   // Type Aliases
   //
   using DispatchPolicy = epilogue::IntelXeXMX16;
   using ElementO = ElementO_;
-  using ElementAccumulator = ElementO_;
   using StrideO = StrideO_;
   using ElementLSE = ElementLSE_;
   using CopyOpO = CopyOpO_;
@@ -70,7 +69,8 @@ class FlashPrefillCachedEpilogue, Layout, SubgroupLayout>::TiledMMA;
   using GmemTiledCopyO = CopyOpO;
   using ElementOutput = ElementO_;
-  using ElementCompute = ElementO_;
+  using ElementCompute = ElementCompute_;
+  using ElementAccumulator = ElementCompute_;
   using SubgroupTileShape = decltype(cute::shape_div(TileShapeOutput{}, (SubgroupLayout{}.shape())));
   static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize;
@@ -197,7 +197,17 @@ class FlashPrefillCachedEpilogue(out_reg);
+    // iff ElementOutput == ElementAccumulator, then convert_type doesn't do the right conversion
+    // so we call copy() which internally performs a static_cast op on the data.
+    // for ElementOutput == bf16 | fp16, convert_type calls relevant NumericConverter specialization.
+    if constexpr (cute::is_same_v) {
+      copy(out_reg, final_out_reg);
+    } else {
+      Tensor temp = convert_type(out_reg);
+      copy(temp, final_out_reg);
+    }
+    copy(params.xe_store_o, final_out_reg, tOgO);
   }

   // SequenceLengthShapeType = Shape
diff --git a/applications/flash_attention_v2/collective/xe_flash_attn_prefill_mma_cachedKV.hpp b/applications/flash_attention_v2/collective/xe_flash_attn_prefill_mma_cachedKV.hpp
index 8e34a80c96..b7ebd85a45 100644
--- a/applications/flash_attention_v2/collective/xe_flash_attn_prefill_mma_cachedKV.hpp
+++ b/applications/flash_attention_v2/collective/xe_flash_attn_prefill_mma_cachedKV.hpp
@@ -164,7 +164,7 @@ struct FlashPrefillCachedMma, ProblemShapeTyp
     // Paged KV Cache
     int const* ptr_page_table;
     int page_size;
-    int num_pages_per_seq;
+    int const* num_pages_per_seq;
   };

   struct Params {
@@ -176,7 +176,7 @@ struct FlashPrefillCachedMma, ProblemShapeTyp
     // Paged KV Cache
     int const* ptr_page_table;
     int page_size;
-    int num_pages_per_seq;
+    int const* num_pages_per_seq;
   };

   //
diff --git a/applications/flash_attention_v2/kernel/xe_flash_attn_prefill_cachedKV.hpp b/applications/flash_attention_v2/kernel/xe_flash_attn_prefill_cachedKV.hpp
index 693996c0b7..56e3dc242c 100644
--- a/applications/flash_attention_v2/kernel/xe_flash_attn_prefill_cachedKV.hpp
+++ b/applications/flash_attention_v2/kernel/xe_flash_attn_prefill_cachedKV.hpp
@@ -178,7 +178,8 @@ class FMHAPrefillCached {
   static bool can_implement(Arguments const &args) {
     bool mode_implementable = args.mode == gemm::GemmUniversalMode::kGemm or (args.mode == gemm::GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4);
-    return mode_implementable;
+    bool valid_page_size = !PagedKV || (args.mainloop.page_size >= QK_BLK_N && args.mainloop.page_size % QK_BLK_N == 0);
+    return mode_implementable && valid_page_size;
   }

   static int get_workspace_size(Arguments const &args) { return 0; }
@@ -314,10 +315,22 @@ class FMHAPrefillCached {
     }
     auto& prefetch_K = (seq_len_kv_cache == 0) ? tiled_prefetch_k: tiled_prefetch_k_cache;
     auto& pKgK1_ = (seq_len_kv_cache == 0) ? pKgK: pKgK_cache;
+
+    int cached_nblock = 0;
+    if constexpr (PagedKV) {
+      if (seq_len_kv_cache != 0) {
+        int curr_batch_pages = is_var_len ? mainloop_params.num_pages_per_seq[batch_coord + 1] - mainloop_params.num_pages_per_seq[batch_coord]
+                                          : ceil_div(seq_len_kv_cache, mainloop_params.page_size);
+        int batch_offset = is_var_len ? mainloop_params.num_pages_per_seq[batch_coord] : batch_coord * curr_batch_pages;
+        cached_nblock = mainloop_params.ptr_page_table[
+                            batch_offset        // page table for this batch
+                        ] * tiles_per_page;     // base block idx of physical page
+      }
+    }
     // The headsize for both cached and non-cached version is the same
     for (int j = 0; j < size<4>(pKgK1_); j++) {
       CUTLASS_PRAGMA_UNROLL
-      for (int i = 0; i < DispatchPolicy::Stages; i++) {
+      for (int i = cached_nblock; i < cached_nblock + DispatchPolicy::Stages; i++) {
         prefetch(prefetch_K, pKgK1_(_, _, _ , i, j));
       }
     }
@@ -345,18 +358,6 @@ class FMHAPrefillCached {
       bool is_KV_cache = nblock < nblock_cache;
-      int cached_nblock = nblock;
-      if constexpr (PagedKV) {
-        if (is_KV_cache) {
-          // get physical page idx from page table
-          cached_nblock = params.mainloop.ptr_page_table[
-                              batch_coord * params.mainloop.num_pages_per_seq +  // page table for this batch
-                              nblock * QK_BLK_N / params.mainloop.page_size      // nblock (tile idx) to logical page idx
-                          ] * tiles_per_page +                                   // base block idx of physical page
-                          nblock % tiles_per_page;                               // offset within page
-        }
-      }
-
       // 1) Load KV (performed inside mmaQK)
       auto gK_ = is_KV_cache ? gK_cache(_, _, cached_nblock, _) : gK(_, _, nblock - nblock_cache, _);
       auto gV_ = is_KV_cache ? gV_cache(_, _, cached_nblock) : gV(_, _, nblock - nblock_cache);
@@ -372,8 +373,32 @@ class FMHAPrefillCached {
       // prefetching it the same way as cutlass K matrix does not make sense
       auto& tiled_prefetch_v_ = is_KV_cache ? tiled_prefetch_v_cache : tiled_prefetch_v;
       auto& pVgV_ = is_KV_cache ? pVgV_cache : pVgV;
-      for(int i=0; i < size<1>(pVgV); i++) {
-        prefetch(tiled_prefetch_v_, pVgV_cache(_, i, _ , nblock - (!is_KV_cache) * nblock_cache));
+      int v_prefetch_idx = is_KV_cache ? PagedKV ? cached_nblock : nblock
+                                       : nblock - nblock_cache;
+      for(int i = 0; i < size<1>(pVgV_); i++) {
+        prefetch(tiled_prefetch_v_, pVgV_(_, i, _ , v_prefetch_idx));
+      }
+
+      int next_cached_nblock = nblock + 1;
+      bool is_next_KV_cache = next_cached_nblock < nblock_cache;
+      if constexpr (PagedKV) {
+        if (is_next_KV_cache) {
+          int curr_batch_pages = is_var_len ? mainloop_params.num_pages_per_seq[batch_coord + 1] - mainloop_params.num_pages_per_seq[batch_coord]
+                                            : ceil_div(seq_len_kv_cache, mainloop_params.page_size);
+          int next_page_logical_idx = next_cached_nblock * QK_BLK_N / params.mainloop.page_size;
+          int batch_offset = is_var_len ? mainloop_params.num_pages_per_seq[batch_coord] : batch_coord * curr_batch_pages;
+          bool valid_page = next_page_logical_idx < curr_batch_pages;
+          // get physical page idx from page table
+          if (valid_page) {
+            next_cached_nblock = params.mainloop.ptr_page_table[
+                                     batch_offset +           // page table for this batch
+                                     next_page_logical_idx    // nblock (tile idx) to logical page idx
+                                 ] * tiles_per_page +         // base block idx of physical page
+                                 next_cached_nblock % tiles_per_page; // offset within page
+          } else {
+            next_cached_nblock = curr_batch_pages * tiles_per_page; // push idx out of bounds to respect the boundary between batches
+          }
+        }
       }

       // 4) Fused softmax
@@ -382,16 +407,26 @@ class FMHAPrefillCached {
       // 5) Perform GEMM O = S*V
       collective_mma.template mmaPV(out_reg, tSr, gV_, out_reg, mainloop_params, is_KV_cache);
-
+
+      // Prefetch the next Q tile
+      CUTLASS_PRAGMA_UNROLL
+      for (int i = 0; i < size<3>(pQgQ); i++) {
+        prefetch(tiled_prefetch_q, pQgQ(_, _, _, i));
+      }
+
+      is_KV_cache = is_next_KV_cache;
+      cached_nblock = next_cached_nblock;
       // Prefetch the next K tile
       // there is no need to guard it with an if statement as prefetch will ignore out of bound reading
       bool sel_prefetch_k = (nblock + DispatchPolicy::Stages) < nblock_cache;
       auto& prefetch_k_selector = sel_prefetch_k ? tiled_prefetch_k_cache: tiled_prefetch_k;
       auto& pKgK_ = sel_prefetch_k ? pKgK_cache : pKgK;
+      int k_prefetch_idx = sel_prefetch_k ? PagedKV ? cached_nblock : nblock + DispatchPolicy::Stages
+                                          : nblock + DispatchPolicy::Stages - nblock_cache;
       CUTLASS_PRAGMA_UNROLL
       for (int j = 0; j < size<4>(pKgK_); j++) {
-        prefetch(prefetch_k_selector, pKgK_(_, _, _, (nblock + DispatchPolicy::Stages) - (!sel_prefetch_k) * nblock_cache , j));
+        prefetch(prefetch_k_selector, pKgK_(_, _, _, k_prefetch_idx , j));
       }
       barrier_wait(barrier_scope);
     }
@@ -406,8 +441,8 @@ class FMHAPrefillCached {
     collective_mma.mmaQK(tSr, gQ, gK(_, _, nblock_new - 1, _), tSr, ceil_div(head_size_qk, QK_BLK_K), mainloop_params, false);
     // we only need one block ahead, there is enough gap to prefetch it while doing softmax. because the gap between the two MMA is big,
     // prefetching it the same way as cutlass K matrix does not make sense
-    for(int i=0; i< size<1>(pVgV); i++) {
-      prefetch(tiled_prefetch_v, pVgV(_, i, _ , nblock_new - 1));
+    for(int i = 0; i< size<1>(pVgV); i++) {
+      prefetch(tiled_prefetch_v, pVgV(_, i, _ , nblock_new - 1));
     }
     // mask the elements of each tile where j > i
     const int item_id = thread_idx % SubgroupSize;
@@ -420,7 +455,7 @@ class FMHAPrefillCached {
         CUTLASS_PRAGMA_UNROLL
         for (int row = 0; row < Vec; row++, row_idx++) { // 8
           if (col_idx - full_tile_offset > row_idx - discard_seq_coord) {
-            tSr(row, m, n) = -INFINITY;
+            tSr(row, m, n) = ElementAccumulator{-INFINITY};
           }
         }
       }
diff --git a/benchmarks/flash_attention/flash_attention_prefill_cachedKV/benchmark_runner.hpp b/benchmarks/flash_attention/flash_attention_prefill_cachedKV/benchmark_runner.hpp
index d73151a035..dc81c062bc 100644
--- a/benchmarks/flash_attention/flash_attention_prefill_cachedKV/benchmark_runner.hpp
+++ b/benchmarks/flash_attention/flash_attention_prefill_cachedKV/benchmark_runner.hpp
@@ -207,7 +207,7 @@ template struct BenchmarkRunnerFMHA {
     int kv_group_update=1;
     for (int h = 0; h < num_heads_q; h++) {
-      cutlass::DeviceAllocation block_S;
+      cutlass::DeviceAllocation block_S;
       block_S.reset(seq_len_qo * seq_len_kv_total);
       ElementK* k_ptr;
@@ -254,11 +254,10 @@ template struct BenchmarkRunnerFMHA {
       cutlass::TensorRef ref_K(k_ptr, LayoutK::packed({head_size_qk, seq_len_kv_total}));
       cutlass::TensorRef ref_V(v_ptr, LayoutV::packed({seq_len_kv_total, head_size_vo}));
       cutlass::TensorRef ref_S(block_S.get(), LayoutQ::packed({seq_len_qo, seq_len_kv_total}));
-      cutlass::TensorRef ref_O(block_ref_O.get() + offset_o, LayoutO::packed({seq_len_qo, head_size_vo}));
-      cutlass::reference::device::GemmComplex({seq_len_qo, seq_len_kv_total, head_size_qk}, 1.f, ref_Q,
+      cutlass::reference::device::GemmComplex({seq_len_qo, seq_len_kv_total, head_size_qk}, ElementAccumulator{1}, ref_Q,
                                               cutlass::ComplexTransform::kNone, ref_K, cutlass::ComplexTransform::kNone,
-                                              0.f, ref_S, ref_S, ElementAccumulator(0),
+                                              ElementAccumulator{0}, ref_S, ref_S, ElementAccumulator{0},
                                               1,                                // batch_count
                                               seq_len_qo * head_size_qk,        // batch_stride_Q
                                               seq_len_kv_total * head_size_qk,  // batch_stride_K
@@ -268,8 +267,8 @@ template struct BenchmarkRunnerFMHA {
       syclcompat::wait();
-      std::vector host_S(block_S.size());
-      syclcompat::memcpy(host_S.data(), block_S.get(), host_S.size());
+      std::vector host_S(block_S.size());
+      syclcompat::memcpy(host_S.data(), block_S.get(), host_S.size());
       syclcompat::wait();

       // delete this memory as it is no longer needed
@@ -283,13 +282,13 @@ template struct BenchmarkRunnerFMHA {
         for (int row = 0; row < seq_len_qo; row++) {
           for (int col = start_col; col < seq_len_kv_total; col++) {
             if (col - full_tile_offset > row + start_col - discard_seq_coord)
-              host_S[col + row * seq_len_kv_total] = -INFINITY;
+              host_S[col + row * seq_len_kv_total] = ElementAccumulator{-INFINITY};
           }
         }
       }

       // compute max element per row of S
-      std::vector max_vec(seq_len_qo, -INFINITY);
+      std::vector max_vec(seq_len_qo, ElementAccumulator{-INFINITY});
       for (int row = 0; row < seq_len_qo; row++) {
         int idx = row * seq_len_kv_total;
         int max_idx = row;
@@ -305,12 +304,12 @@ template struct BenchmarkRunnerFMHA {
         int idx = row * seq_len_kv_total;
         int max_idx = row;
         for (int col = 0; col < seq_len_kv_total; col++, idx++) {
-          host_S[idx] = expf((host_S[idx] - max_vec[max_idx]) / std::sqrt(static_cast((head_size_qk))));
+          host_S[idx] = expf((host_S[idx] - max_vec[max_idx]) / std::sqrt(static_cast((head_size_qk))));
         }
       }

       // compute sum per row of S
-      std::vector sum_vec(seq_len_qo, ElementOutput{0});
+      std::vector sum_vec(seq_len_qo, ElementAccumulator{0});
       for (int row = 0; row < seq_len_qo; row++) {
         int idx = row * seq_len_kv_total;
         int sum_idx = row;
@@ -342,9 +341,13 @@ template struct BenchmarkRunnerFMHA {
       cutlass::TensorRef ref_P(block_P.get(), LayoutQ::packed({seq_len_qo, seq_len_kv_total}));
-      cutlass::reference::device::GemmComplex({seq_len_qo, head_size_vo, seq_len_kv_total}, 1.f, ref_P,
+      cutlass::DeviceAllocation block_acc;
+      block_acc.reset(seq_len_qo * head_size_vo);
+      cutlass::TensorRef ref_acc(block_acc.get(), LayoutO::packed({seq_len_qo, head_size_vo}));
+
+      cutlass::reference::device::GemmComplex({seq_len_qo, head_size_vo, seq_len_kv_total}, ElementAccumulator{1}, ref_P,
                                               cutlass::ComplexTransform::kNone, ref_V, cutlass::ComplexTransform::kNone,
-                                              0.f, ref_O, ref_O, ElementAccumulator(0),
+                                              ElementAccumulator{0}, ref_acc, ref_acc, ElementAccumulator{0},
                                               1,                               // batch_count
                                               seq_len_qo * seq_len_kv_total,   // batch_stride_P
                                               seq_len_kv_total * head_size_vo, // batch_stride_V
@@ -356,6 +359,19 @@ template struct BenchmarkRunnerFMHA {
       // delete this memory as it is no longer needed
       block_P.reset();
+      std::vector vec_acc(block_acc.size());
+      syclcompat::memcpy(vec_acc.data(), block_acc.get(), vec_acc.size());
+      syclcompat::wait();
+
+      // delete this memory as it is no longer needed
+      block_acc.reset();
+      std::vector vec_out(vec_acc.size());
+      for(int i = 0; i < vec_out.size(); i++) {
+        vec_out[i] = static_cast(vec_acc[i]);
+      }
+      syclcompat::memcpy(block_ref_O.get() + offset_o, vec_out.data(), vec_out.size());
+      syclcompat::wait();
+
       offset_q += seq_len_qo * head_size_qk;
       if(kv_group_update % q_group_size==0) {
         offset_k += seq_len_kv * head_size_qk;
@@ -372,7 +388,7 @@ template struct BenchmarkRunnerFMHA {
     // Check if output from CUTLASS kernel and reference kernel are equal or not
     bool passed = cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_O.get(), block_O.get(),
-                                                                          block_O.size(), 0.5f, 0.5f);
+                                                                          block_O.size(), ElementOutput{0.5}, ElementOutput{0.5});

     return passed;
   }
diff --git a/benchmarks/flash_attention/flash_attention_prefill_cachedKV/fmha_prefill_configuration.hpp b/benchmarks/flash_attention/flash_attention_prefill_cachedKV/fmha_prefill_configuration.hpp
index be437c3239..9f1ab85a3e 100644
--- a/benchmarks/flash_attention/flash_attention_prefill_cachedKV/fmha_prefill_configuration.hpp
+++ b/benchmarks/flash_attention/flash_attention_prefill_cachedKV/fmha_prefill_configuration.hpp
@@ -69,7 +69,7 @@ struct FMHAPrefillConfig {
   using MMAOperation = typename MMAOP::Type;
   using CollectiveEpilogue = cutlass::flash_attention::collective::FlashPrefillCachedEpilogue<
       EpilogueDispatchPolicy, MMAOperation, TileShapeOutput,
-      SubgroupLayout, ElementAccumulator,
+      SubgroupLayout, ElementAccumulator, ElementOutput,
       cutlass::gemm::TagToStrideC_t, ElementOutput,
       GmemTiledCopyO>;
diff --git a/examples/sycl/06_bmg_flash_attention/bmg_flash_attn_prefill_cachedKV_runner.hpp b/examples/sycl/06_bmg_flash_attention/bmg_flash_attn_prefill_cachedKV_runner.hpp
index be2af5e2bb..0600b6ce0f 100644
--- a/examples/sycl/06_bmg_flash_attention/bmg_flash_attn_prefill_cachedKV_runner.hpp
+++ b/examples/sycl/06_bmg_flash_attention/bmg_flash_attn_prefill_cachedKV_runner.hpp
@@ -199,7 +199,7 @@ template struct ExampleRunner {
   struct PagedKVParams {
     cutlass::DeviceAllocation page_table;
     int page_size = 0;
-    int num_pages_per_seq = 0;
+    cutlass::DeviceAllocation num_pages_per_seq;
   };
   PagedKVParams paged_kv_cache;
@@ -245,58 +245,55 @@ template struct ExampleRunner {
     int kv_group_update = 1;
     for (int h = 0; h < num_heads_q; h++) {
-      cutlass::DeviceAllocation block_S;
+      cutlass::DeviceAllocation block_S;
       block_S.reset(seq_len_qo * seq_len_kv_total);
       ElementK* k_ptr;
       ElementV* v_ptr;
       if (use_kv_cache) {
-        cutlass::DeviceAllocation block_K_concat(head_size_qk * seq_len_kv_total);
-        cutlass::DeviceAllocation block_V_concat(seq_len_kv_total * head_size_vo);
-
-        // Concatenate K_cache and K
-        syclcompat::memcpy(
-          block_K_concat.get(),
-          block_K_cache.get() + offset_k_cache,
-          seq_len_kv_cache * head_size_qk
-        );
-        syclcompat::memcpy(
-          block_K_concat.get() + seq_len_kv_cache * head_size_qk,
-          block_K.get() + offset_k,
-          seq_len_kv * head_size_qk
-        );
-
-        // Concatenate V_cache and V
-        syclcompat::memcpy(
-          block_V_concat.get(),
-          block_V_cache.get() + offset_v_cache,
-          seq_len_kv_cache * head_size_vo
-        );
-        syclcompat::memcpy(
-          block_V_concat.get() + seq_len_kv_cache * head_size_vo,
-          block_V.get() + offset_v,
-          seq_len_kv * head_size_vo
-        );
-        syclcompat::wait();
-
-        k_ptr = block_K_concat.get();
-        v_ptr = block_V_concat.get();
-      }
-      else {
-        k_ptr = block_K.get() + offset_k;
-        v_ptr = block_V.get() + offset_v;
+        cutlass::DeviceAllocation block_K_concat(head_size_qk * seq_len_kv_total);
+        cutlass::DeviceAllocation block_V_concat(seq_len_kv_total * head_size_vo);
+
+        // Concatenate K_cache and K
+        syclcompat::memcpy(
+          block_K_concat.get(),
+          block_K_cache.get() + offset_k_cache,
+          seq_len_kv_cache * head_size_qk
+        );
+        syclcompat::memcpy(
+          block_K_concat.get() + seq_len_kv_cache * head_size_qk,
+          block_K.get() + offset_k,
+          seq_len_kv * head_size_qk
+        );
+
+        // Concatenate V_cache and V
+        syclcompat::memcpy(
+          block_V_concat.get(),
+          block_V_cache.get() + offset_v_cache,
+          seq_len_kv_cache * head_size_vo
+        );
+        syclcompat::memcpy(
+          block_V_concat.get() + seq_len_kv_cache * head_size_vo,
+          block_V.get() + offset_v,
+          seq_len_kv * head_size_vo
+        );
+
+        k_ptr = block_K_concat.get();
+        v_ptr = block_V_concat.get();
+      } else {
+        k_ptr = block_K.get() + offset_k;
+        v_ptr = block_V.get() + offset_v;
       }
       cutlass::TensorRef ref_Q(block_Q.get() + offset_q, LayoutQ::packed({seq_len_qo, head_size_qk}));
       cutlass::TensorRef ref_K(k_ptr, LayoutK::packed({head_size_qk, seq_len_kv_total}));
       cutlass::TensorRef ref_V(v_ptr, LayoutV::packed({seq_len_kv_total, head_size_vo}));
       cutlass::TensorRef ref_S(block_S.get(), LayoutQ::packed({seq_len_qo, seq_len_kv_total}));
-      cutlass::TensorRef ref_O(block_ref_O.get() + offset_o, LayoutO::packed({seq_len_qo, head_size_vo}));
-      cutlass::reference::device::GemmComplex({seq_len_qo, seq_len_kv_total, head_size_qk}, 1.f, ref_Q,
+      cutlass::reference::device::GemmComplex({seq_len_qo, seq_len_kv_total, head_size_qk}, ElementAccumulator{1}, ref_Q,
                                               cutlass::ComplexTransform::kNone, ref_K, cutlass::ComplexTransform::kNone,
-                                              0.f, ref_S, ref_S, ElementAccumulator(0),
+                                              ElementAccumulator{0}, ref_S, ref_S, ElementAccumulator{0},
                                               1,                                // batch_count
                                               seq_len_qo * head_size_qk,        // batch_stride_Q
                                               seq_len_kv_total * head_size_qk,  // batch_stride_K
@@ -306,13 +303,11 @@ template struct ExampleRunner {
       syclcompat::wait();
-      std::vector host_S(block_S.size());
-      syclcompat::memcpy(host_S.data(), block_S.get(), host_S.size());
-      syclcompat::wait();
+      std::vector host_S(block_S.size());
+      syclcompat::memcpy(host_S.data(), block_S.get(), host_S.size());

       // delete this memory as it is no longer needed
       block_S.reset();
-
       auto offset = cute::min(seq_len_qo, seq_len_kv);
       auto discard_seq_coord = seq_len_qo - offset;
       auto full_tile_offset = seq_len_kv - offset;
@@ -322,13 +317,13 @@ template struct ExampleRunner {
         for (int row = 0; row < seq_len_qo; row++) {
           for (int col = start_col; col < seq_len_kv_total; col++) {
             if (col - full_tile_offset > row + start_col - discard_seq_coord)
-              host_S[col + row * seq_len_kv_total] = -INFINITY;
+              host_S[col + row * seq_len_kv_total] = ElementAccumulator{-INFINITY};
           }
         }
       }

       // compute max element per row of S
-      std::vector max_vec(seq_len_qo, -INFINITY);
+      std::vector max_vec(seq_len_qo, ElementAccumulator{-INFINITY});
       for (int row = 0; row < seq_len_qo; row++) {
         int idx = row * seq_len_kv_total;
         int max_idx = row;
@@ -344,12 +339,12 @@ template struct ExampleRunner {
         int idx = row * seq_len_kv_total;
         int max_idx = row;
         for (int col = 0; col < seq_len_kv_total; col++, idx++) {
-          host_S[idx] = expf((host_S[idx] - max_vec[max_idx]) / sqrt(static_cast((head_size_qk))));
+          host_S[idx] = expf((host_S[idx] - max_vec[max_idx]) / std::sqrt(static_cast((head_size_qk))));
         }
       }

       // compute sum per row of S
-      std::vector sum_vec(seq_len_qo, ElementOutput{0});
+      std::vector sum_vec(seq_len_qo, ElementAccumulator{0});
       for (int row = 0; row < seq_len_qo; row++) {
         int idx = row * seq_len_kv_total;
         int sum_idx = row;
@@ -361,7 +356,7 @@ template struct ExampleRunner {
         idx = row * seq_len_kv_total;
         sum_idx = row;
         for (int col = 0; col < seq_len_kv_total; col++, idx++) {
-          if(is_causal && row < discard_seq_coord) {
+          if(is_causal && row < discard_seq_coord) {
             host_S[idx] = 0;
           } else {
             host_S[idx] /= sum_vec[sum_idx];
           }
@@ -377,13 +372,16 @@ template struct ExampleRunner {
       block_P.reset(host_P.size());
       syclcompat::memcpy(block_P.get(), host_P.data(), host_P.size());
-      syclcompat::wait();
       cutlass::TensorRef ref_P(block_P.get(), LayoutQ::packed({seq_len_qo, seq_len_kv_total}));
-      cutlass::reference::device::GemmComplex({seq_len_qo, head_size_vo, seq_len_kv_total}, 1.f, ref_P,
+      cutlass::DeviceAllocation block_acc;
+      block_acc.reset(seq_len_qo * head_size_vo);
+      cutlass::TensorRef ref_acc(block_acc.get(), LayoutO::packed({seq_len_qo, head_size_vo}));
+
+      cutlass::reference::device::GemmComplex({seq_len_qo, head_size_vo, seq_len_kv_total}, ElementAccumulator{1}, ref_P,
                                               cutlass::ComplexTransform::kNone, ref_V, cutlass::ComplexTransform::kNone,
-                                              0.f, ref_O, ref_O, ElementAccumulator(0),
+                                              ElementAccumulator{0}, ref_acc, ref_acc, ElementAccumulator{0},
                                               1,                               // batch_count
                                               seq_len_qo * seq_len_kv_total,   // batch_stride_P
                                               seq_len_kv_total * head_size_vo, // batch_stride_V
@@ -395,6 +393,17 @@ template struct ExampleRunner {
       // delete this memory as it is no longer needed
       block_P.reset();
+      std::vector vec_acc(block_acc.size());
+      syclcompat::memcpy(vec_acc.data(), block_acc.get(), vec_acc.size());
+
+      // delete this memory as it is no longer needed
+      block_acc.reset();
+      std::vector vec_out(vec_acc.size());
+      for(int i = 0; i < vec_out.size(); i++) {
+        vec_out[i] = static_cast(vec_acc[i]);
+      }
+      syclcompat::memcpy(block_ref_O.get() + offset_o, vec_out.data(), vec_out.size());
+
       offset_q += seq_len_qo * head_size_qk;
       if(kv_group_update % q_group_size==0) {
         offset_k += seq_len_kv * head_size_qk;
@@ -411,7 +420,7 @@ template struct ExampleRunner {
     // Check if output from CUTLASS kernel and reference kernel are equal or not
     bool passed = cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_O.get(), block_O.get(),
-                                                                          block_O.size(), 0.5f, 0.5f);
+                                                                          block_O.size(), ElementOutput{0.5}, ElementOutput{0.5});

     return passed;
   }
@@ -528,27 +537,33 @@ template struct ExampleRunner {
     block_ref_O.reset(batch * num_heads_q * seq_len_qo * head_size_vo);

     if (options.use_paged_kv) {
-      paged_kv_cache.page_size = options.page_size;
-      int num_pages_per_seq = seq_len_kv_cache / paged_kv_cache.page_size;
-      paged_kv_cache.num_pages_per_seq = num_pages_per_seq;
-
-      int num_pages = num_pages_per_seq * batch;
-      paged_kv_cache.page_table.reset(batch * num_pages_per_seq);
-
-      // initialize block table with random mapping for non-contiguous layout
-      std::vector page_mapping(batch * num_pages_per_seq);
-      for (int b = 0; b < batch; ++b) {
-        std::vector physical_pages(num_pages_per_seq);
-        std::iota(physical_pages.begin(), physical_pages.end(), 0);
-        // shuffle physical pages
-        std::shuffle(physical_pages.begin(), physical_pages.end(), std::mt19937{ std::random_device{}() });
-        for (int blk = 0; blk < num_pages_per_seq; ++blk) {
-          int logical_idx = b * num_pages_per_seq + blk;
-          page_mapping[logical_idx] = physical_pages[blk];
-        }
+      paged_kv_cache.page_size = options.page_size;
+      std::vector num_pages_per_seq{0};
+      int num_pages = 0;
+      for(int b = 0; b < cute::get<0>(problem_shape); b++) {
+        int seq_len_cache = isVarLen ? cumulative_seqlen_kv_cache[b + 1] - cumulative_seqlen_kv_cache[b] : seq_len_kv_cache;
+        int pages_per_seq = ceil_div(seq_len_cache, paged_kv_cache.page_size);
+        num_pages_per_seq.push_back(num_pages_per_seq.back() + pages_per_seq);
+        num_pages += pages_per_seq;
+      }
+      paged_kv_cache.page_table.reset(num_pages);
+
+      // initialize block table with random mapping for non-contiguous layout
+      std::vector page_mapping(num_pages);
+      for (int b = 0; b < cute::get<0>(problem_shape); ++b) {
+        std::vector physical_pages(num_pages_per_seq[b + 1] - num_pages_per_seq[b]);
+        std::iota(physical_pages.begin(), physical_pages.end(), 0);
+        // shuffle physical pages
+        std::shuffle(physical_pages.begin(), physical_pages.end(), std::mt19937{ std::random_device{}() });
+        for (int blk = 0; blk < physical_pages.size(); ++blk) {
+          int logical_idx = num_pages_per_seq[b] + blk;
+          page_mapping[logical_idx] = physical_pages[blk];
         }
-      syclcompat::memcpy(paged_kv_cache.page_table.get(), page_mapping.data(), page_mapping.size() * sizeof(int));
-      syclcompat::wait();
+      }
+      syclcompat::memcpy(paged_kv_cache.page_table.get(), page_mapping.data(), page_mapping.size() * sizeof(int));
+
+      paged_kv_cache.num_pages_per_seq.reset(num_pages_per_seq.size());
+      syclcompat::memcpy(paged_kv_cache.num_pages_per_seq.get(), num_pages_per_seq.data(), num_pages_per_seq.size() * sizeof(int));
     }

     initialize_block(block_Q, seed + 2023);
@@ -631,7 +646,7 @@ template struct ExampleRunner {
                              block_V_cache.get(), stride_V_cache,
                              options.use_paged_kv ? paged_kv_cache.page_table.get() : nullptr,
                              options.use_paged_kv ? paged_kv_cache.page_size : 0,
-                             options.use_paged_kv ? paged_kv_cache.num_pages_per_seq : 0},
+                             options.use_paged_kv ? paged_kv_cache.num_pages_per_seq.get() : nullptr},
                             {options.softmax_scale},
                             {block_O.get(), stride_O},
                             hw_info};
@@ -694,7 +709,7 @@ template struct ExampleRunner {
               << "\tSeq Length KV: " << options.seq_len_kv << "\tSeq Length KV Cache: " << options.seq_len_kv_cache
               << "\tHead Size QK: " << options.head_size_qk << "\tHead Size VO: " << options.head_size_vo
               << "\tCausal Mask: " << (options.is_causal ? "true" : "false")
               << "\tVariable Sequence Length: " << (options.varlen ? "true" : "false")
-              << "\t Scheduler: " << options.scheduler;
"true" : "false"); printf("\nPerformance: %4.3f GB/s, %4.3f TFlop/s, %6.4f ms\n\n", gbps, tflops, cute_time * 1000); } @@ -702,7 +717,23 @@ template struct ExampleRunner { } }; -template struct FMHAConfig { +// the default value used for the case BF16 +template struct FMHAConfig { template static int run(const Options &options) { @@ -714,22 +745,10 @@ template ; + using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16; using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16; - using MMAOperation = XE_8x16x16_F32BF16BF16F32_TT; - using GmemTiledCopyQ = XE_2D_U16x8x32_LD_N; - using GmemTiledCopyK = XE_2D_U16x16x16_LD_T; // _T designates a transposed block load operation - using GmemTiledCopyV = XE_2D_U16x16x32_LD_V; - using GmemTiledCopyStore = XE_2D_U32x8x16_ST_N; using CollectiveEpilogue = cutlass::flash_attention::collective::FlashPrefillCachedEpilogue< - EpilogueDispatchPolicy, MMAOperation, TileShapeOutput, SubgroupLayout, ElementAccumulator, cutlass::gemm::TagToStrideC_t, ElementOutput, + EpilogueDispatchPolicy, MMAOperation, TileShapeOutput, SubgroupLayout, ElementComputeEpilogue, ElementOutput, cutlass::gemm::TagToStrideC_t, ElementOutput, GmemTiledCopyStore>; using CollectiveSoftmaxEpilogue = cutlass::flash_attention::collective::FlashPrefillSoftmaxEpilogue; @@ -758,12 +777,14 @@ template (options); - } - if(options.varlen) { - return run(options); + if (options.use_paged_kv && !options.varlen) { + return run(options); + } else if(!options.use_paged_kv && options.varlen) { + return run(options); + } else if(!options.use_paged_kv && !options.varlen) { + return run(options); + } else { + return run(options); } - return run(options); } }; diff --git a/test/unit/flash_attention/flash_attention_prefill_cachedkv/flash_prefill_cachedkv_testbed_3x.hpp b/test/unit/flash_attention/flash_attention_prefill_cachedkv/flash_prefill_cachedkv_testbed_3x.hpp index 5298490090..287302745d 100644 --- a/test/unit/flash_attention/flash_attention_prefill_cachedkv/flash_prefill_cachedkv_testbed_3x.hpp +++ b/test/unit/flash_attention/flash_attention_prefill_cachedkv/flash_prefill_cachedkv_testbed_3x.hpp @@ -91,7 +91,7 @@ struct XE_Flash_Attention_Prefill_CachedKV { using GmemTiledCopyV = cute::XE_2D_U16x16x32_LD_V; using GmemTiledCopyStore = cute::XE_2D_U32x8x16_ST_N; using CollectiveEpilogue = cutlass::flash_attention::collective::FlashPrefillCachedEpilogue< - EpilogueDispatchPolicy, MMAOperation, TileShapeOutput, SubgroupLayout, ElementAccumulator, cutlass::gemm::TagToStrideC_t, ElementOutput, + EpilogueDispatchPolicy, MMAOperation, TileShapeOutput, SubgroupLayout, ElementAccumulator, ElementOutput, cutlass::gemm::TagToStrideC_t, ElementOutput, GmemTiledCopyStore>; using FlashPrefillSoftmaxEpilogue = cutlass::flash_attention::collective::FlashPrefillSoftmaxEpilogue< HasCausalMask, EpilogueDispatchPolicy, ElementAccumulator>; @@ -196,7 +196,7 @@ struct TestbedImpl { struct PagedKVParams { cutlass::DeviceAllocation page_table; int page_size = 128; - int num_pages_per_seq = 0; + cutlass::DeviceAllocation num_pages_per_seq; }; PagedKVParams paged_kv_cache; @@ -246,27 +246,33 @@ struct TestbedImpl { block_O.reset(batch * num_heads_q * seq_len_qo * head_size_vo); block_ref_O.reset(batch * num_heads_q * seq_len_qo * head_size_vo); - if (use_kv_cache && UsePagedKV) { - int num_pages_per_seq = seq_len_kv_cache / paged_kv_cache.page_size; - paged_kv_cache.num_pages_per_seq = num_pages_per_seq; - - int num_pages = num_pages_per_seq * batch; - paged_kv_cache.page_table.reset(batch * num_pages_per_seq); 
-
-      // initialize block table with random mapping for non-contiguous layout
-      std::vector page_mapping(batch * num_pages_per_seq);
-      for (int b = 0; b < batch; ++b) {
-        std::vector physical_pages(num_pages_per_seq);
-        std::iota(physical_pages.begin(), physical_pages.end(), 0);
-        // shuffle physical pages
-        std::shuffle(physical_pages.begin(), physical_pages.end(), std::mt19937{ std::random_device{}() });
-        for (int blk = 0; blk < num_pages_per_seq; ++blk) {
-          int logical_idx = b * num_pages_per_seq + blk;
-          page_mapping[logical_idx] = physical_pages[blk];
-        }
+    if constexpr (UsePagedKV) {
+      std::vector num_pages_per_seq{0};
+      int num_pages = 0;
+      for(int b = 0; b < cute::get<0>(problem_shape); b++) {
+        int seq_len_cache = isVarLen ? cumulative_seqlen_kv_cache[b + 1] - cumulative_seqlen_kv_cache[b] : seq_len_kv_cache;
+        int pages_per_seq = ceil_div(seq_len_cache, paged_kv_cache.page_size);
+        num_pages_per_seq.push_back(num_pages_per_seq.back() + pages_per_seq);
+        num_pages += pages_per_seq;
+      }
+      paged_kv_cache.page_table.reset(num_pages);
+
+      // initialize block table with random mapping for non-contiguous layout
+      std::vector page_mapping(num_pages);
+      for (int b = 0; b < cute::get<0>(problem_shape); ++b) {
+        std::vector physical_pages(num_pages_per_seq[b + 1] - num_pages_per_seq[b]);
+        std::iota(physical_pages.begin(), physical_pages.end(), 0);
+        // shuffle physical pages
+        std::shuffle(physical_pages.begin(), physical_pages.end(), std::mt19937{ std::random_device{}() });
+        for (int blk = 0; blk < physical_pages.size(); ++blk) {
+          int logical_idx = num_pages_per_seq[b] + blk;
+          page_mapping[logical_idx] = physical_pages[blk];
         }
-      syclcompat::memcpy(paged_kv_cache.page_table.get(), page_mapping.data(), page_mapping.size() * sizeof(int));
-      syclcompat::wait();
+      }
+      syclcompat::memcpy(paged_kv_cache.page_table.get(), page_mapping.data(), page_mapping.size() * sizeof(int));
+
+      paged_kv_cache.num_pages_per_seq.reset(num_pages_per_seq.size());
+      syclcompat::memcpy(paged_kv_cache.num_pages_per_seq.get(), num_pages_per_seq.data(), num_pages_per_seq.size() * sizeof(int));
     }

     initialize_block(block_Q, seed + 2023);
@@ -415,58 +421,55 @@ struct TestbedImpl {
     int kv_group_update = 1;
     for (int h = 0; h < num_heads_q; h++) {
-      cutlass::DeviceAllocation block_S;
+      cutlass::DeviceAllocation block_S;
       block_S.reset(seq_len_qo * seq_len_kv_total);
       ElementK* k_ptr;
       ElementV* v_ptr;
       if (use_kv_cache) {
-        cutlass::DeviceAllocation block_K_concat(head_size_qk * seq_len_kv_total);
-        cutlass::DeviceAllocation block_V_concat(seq_len_kv_total * head_size_vo);
-
-        // Concatenate K_cache and K
-        syclcompat::memcpy(
-          block_K_concat.get(),
-          block_K_cache.get() + offset_k_cache,
-          seq_len_kv_cache * head_size_qk
-        );
-        syclcompat::memcpy(
-          block_K_concat.get() + seq_len_kv_cache * head_size_qk,
-          block_K.get() + offset_k,
-          seq_len_kv * head_size_qk
-        );
-
-        // Concatenate V_cache and V
-        syclcompat::memcpy(
-          block_V_concat.get(),
-          block_V_cache.get() + offset_v_cache,
-          seq_len_kv_cache * head_size_vo
-        );
-        syclcompat::memcpy(
-          block_V_concat.get() + seq_len_kv_cache * head_size_vo,
-          block_V.get() + offset_v,
-          seq_len_kv * head_size_vo
-        );
-        syclcompat::wait();
-
-        k_ptr = block_K_concat.get();
-        v_ptr = block_V_concat.get();
-      }
-      else {
-        k_ptr = block_K.get() + offset_k;
-        v_ptr = block_V.get() + offset_v;
+        cutlass::DeviceAllocation block_K_concat(head_size_qk * seq_len_kv_total);
+        cutlass::DeviceAllocation block_V_concat(seq_len_kv_total * head_size_vo);
+
+        // Concatenate K_cache and K
+        syclcompat::memcpy(
+          block_K_concat.get(),
+          block_K_cache.get() + offset_k_cache,
+          seq_len_kv_cache * head_size_qk
+        );
+        syclcompat::memcpy(
+          block_K_concat.get() + seq_len_kv_cache * head_size_qk,
+          block_K.get() + offset_k,
+          seq_len_kv * head_size_qk
+        );
+
+        // Concatenate V_cache and V
+        syclcompat::memcpy(
+          block_V_concat.get(),
+          block_V_cache.get() + offset_v_cache,
+          seq_len_kv_cache * head_size_vo
+        );
+        syclcompat::memcpy(
+          block_V_concat.get() + seq_len_kv_cache * head_size_vo,
+          block_V.get() + offset_v,
+          seq_len_kv * head_size_vo
+        );
+
+        k_ptr = block_K_concat.get();
+        v_ptr = block_V_concat.get();
+      } else {
+        k_ptr = block_K.get() + offset_k;
+        v_ptr = block_V.get() + offset_v;
       }
       cutlass::TensorRef ref_Q(block_Q.get() + offset_q, LayoutQ::packed({seq_len_qo, head_size_qk}));
       cutlass::TensorRef ref_K(k_ptr, LayoutK::packed({head_size_qk, seq_len_kv_total}));
       cutlass::TensorRef ref_V(v_ptr, LayoutV::packed({seq_len_kv_total, head_size_vo}));
       cutlass::TensorRef ref_S(block_S.get(), LayoutQ::packed({seq_len_qo, seq_len_kv_total}));
-      cutlass::TensorRef ref_O(block_ref_O.get() + offset_o, LayoutO::packed({seq_len_qo, head_size_vo}));
-      cutlass::reference::device::GemmComplex({seq_len_qo, seq_len_kv_total, head_size_qk}, 1.f, ref_Q,
+      cutlass::reference::device::GemmComplex({seq_len_qo, seq_len_kv_total, head_size_qk}, ElementAccumulator{1}, ref_Q,
                                               cutlass::ComplexTransform::kNone, ref_K, cutlass::ComplexTransform::kNone,
-                                              0.f, ref_S, ref_S, ElementAccumulator(0),
+                                              ElementAccumulator{0}, ref_S, ref_S, ElementAccumulator{0},
                                               1,                                // batch_count
                                               seq_len_qo * head_size_qk,        // batch_stride_Q
                                               seq_len_kv_total * head_size_qk,  // batch_stride_K
@@ -476,13 +479,11 @@ struct TestbedImpl {
       syclcompat::wait();
-      std::vector host_S(block_S.size());
-      syclcompat::memcpy(host_S.data(), block_S.get(), host_S.size());
-      syclcompat::wait();
+      std::vector host_S(block_S.size());
+      syclcompat::memcpy(host_S.data(), block_S.get(), host_S.size());

       // delete this memory as it is no longer needed
       block_S.reset();
-
       auto offset = cute::min(seq_len_qo, seq_len_kv);
       auto discard_seq_coord = seq_len_qo - offset;
       auto full_tile_offset = seq_len_kv - offset;
@@ -492,13 +493,13 @@ struct TestbedImpl {
         for (int row = 0; row < seq_len_qo; row++) {
           for (int col = start_col; col < seq_len_kv_total; col++) {
             if (col - full_tile_offset > row + start_col - discard_seq_coord)
-              host_S[col + row * seq_len_kv_total] = -INFINITY;
+              host_S[col + row * seq_len_kv_total] = ElementAccumulator{-INFINITY};
           }
         }
       }

       // compute max element per row of S
-      std::vector max_vec(seq_len_qo, -INFINITY);
+      std::vector max_vec(seq_len_qo, ElementAccumulator{-INFINITY});
       for (int row = 0; row < seq_len_qo; row++) {
         int idx = row * seq_len_kv_total;
         int max_idx = row;
@@ -514,12 +515,12 @@ struct TestbedImpl {
         int idx = row * seq_len_kv_total;
         int max_idx = row;
         for (int col = 0; col < seq_len_kv_total; col++, idx++) {
-          host_S[idx] = expf((host_S[idx] - max_vec[max_idx]) * softmax_scale);
+          host_S[idx] = expf((host_S[idx] - max_vec[max_idx]) * static_cast(softmax_scale));
         }
       }

       // compute sum per row of S
-      std::vector sum_vec(seq_len_qo, ElementOutput{0});
+      std::vector sum_vec(seq_len_qo, ElementAccumulator{0});
       for (int row = 0; row < seq_len_qo; row++) {
         int idx = row * seq_len_kv_total;
         int sum_idx = row;
@@ -531,7 +532,7 @@ struct TestbedImpl {
         idx = row * seq_len_kv_total;
         sum_idx = row;
         for (int col = 0; col < seq_len_kv_total; col++, idx++) {
-          if(HasCausalMask && row < discard_seq_coord) {
            host_S[idx] = 0;
           } else {
             host_S[idx] /= sum_vec[sum_idx];
           }
@@ -547,13 +548,16 @@ struct TestbedImpl {
       block_P.reset(host_P.size());
       syclcompat::memcpy(block_P.get(), host_P.data(), host_P.size());
-      syclcompat::wait();
       cutlass::TensorRef ref_P(block_P.get(), LayoutQ::packed({seq_len_qo, seq_len_kv_total}));
-      cutlass::reference::device::GemmComplex({seq_len_qo, head_size_vo, seq_len_kv_total}, 1.f, ref_P,
+      cutlass::DeviceAllocation block_acc;
+      block_acc.reset(seq_len_qo * head_size_vo);
+      cutlass::TensorRef ref_acc(block_acc.get(), LayoutO::packed({seq_len_qo, head_size_vo}));
+
+      cutlass::reference::device::GemmComplex({seq_len_qo, head_size_vo, seq_len_kv_total}, ElementAccumulator{1}, ref_P,
                                               cutlass::ComplexTransform::kNone, ref_V, cutlass::ComplexTransform::kNone,
-                                              0.f, ref_O, ref_O, ElementAccumulator(0),
+                                              ElementAccumulator{0}, ref_acc, ref_acc, ElementAccumulator{0},
                                               1,                               // batch_count
                                               seq_len_qo * seq_len_kv_total,   // batch_stride_P
                                               seq_len_kv_total * head_size_vo, // batch_stride_V
@@ -565,8 +569,19 @@ struct TestbedImpl {
       // delete this memory as it is no longer needed
       block_P.reset();
+      std::vector vec_acc(block_acc.size());
+      syclcompat::memcpy(vec_acc.data(), block_acc.get(), vec_acc.size());
+
+      // delete this memory as it is no longer needed
+      block_acc.reset();
+      std::vector vec_out(vec_acc.size());
+      for(int i = 0; i < vec_out.size(); i++) {
+        vec_out[i] = static_cast(vec_acc[i]);
+      }
+      syclcompat::memcpy(block_ref_O.get() + offset_o, vec_out.data(), vec_out.size());
+
       offset_q += seq_len_qo * head_size_qk;
-      if(kv_group_update % q_group_size == 0) {
+      if(kv_group_update % q_group_size==0) {
         offset_k += seq_len_kv * head_size_qk;
         offset_v += seq_len_kv * head_size_vo;
         offset_k_cache += seq_len_kv_cache * head_size_qk;
@@ -581,7 +596,7 @@ struct TestbedImpl {
     // Check if output from CUTLASS kernel and reference kernel are equal or not
     bool passed = cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_O.get(), block_O.get(),
-                                                                          block_O.size(), 0.5f, 0.5f);
+                                                                          block_O.size(), ElementOutput{0.5}, ElementOutput{0.5});

     return passed;
   }
@@ -629,7 +644,7 @@ struct TestbedImpl {
                              block_V_cache.get(), stride_V_cache,
                              UsePagedKV ? paged_kv_cache.page_table.get() : nullptr,
                              UsePagedKV ? paged_kv_cache.page_size : 0,
-                             UsePagedKV ? paged_kv_cache.num_pages_per_seq : 0},
+                             UsePagedKV ? paged_kv_cache.num_pages_per_seq.get() : nullptr},
                             {softmax_scale},
                             {block_O.get(), stride_O},
                             hw_info};
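Note on the paged-KV indexing introduced above: `num_pages_per_seq` is now a device-side prefix-sum array (entry `b + 1` minus entry `b` is the page count of batch `b`), and the kernel converts a logical KV tile index `nblock` into a physical tile index through the page table. The sketch below is illustrative only; the helper name, the use of `std::vector`, and the free-standing form are not part of the diff. It mirrors the variable-length branch of the kernel hunks under the assumption that `page_size` is a multiple of `QK_BLK_N`, which the new `can_implement` check enforces; the fixed-length branch differs only in computing the batch offset as `batch_coord * curr_batch_pages` instead of reading the prefix sum.

```cpp
#include <cassert>
#include <vector>

// Hypothetical host-side mirror of the kernel's logical-to-physical tile mapping.
// page_table holds physical page ids; num_pages_per_seq is the prefix-sum array
// introduced by this change.
int logical_tile_to_physical(const std::vector<int>& page_table,
                             const std::vector<int>& num_pages_per_seq,
                             int batch, int nblock, int page_size, int QK_BLK_N) {
  assert(page_size >= QK_BLK_N && page_size % QK_BLK_N == 0); // enforced by can_implement
  int tiles_per_page   = page_size / QK_BLK_N;
  int batch_offset     = num_pages_per_seq[batch];                  // first page of this batch
  int curr_batch_pages = num_pages_per_seq[batch + 1] - batch_offset;
  int logical_page     = nblock / tiles_per_page;                   // == nblock * QK_BLK_N / page_size
  if (logical_page >= curr_batch_pages) {
    // past the cached region: out-of-bounds sentinel, as in the kernel's else branch
    return curr_batch_pages * tiles_per_page;
  }
  return page_table[batch_offset + logical_page] * tiles_per_page   // base tile of the physical page
         + nblock % tiles_per_page;                                 // offset within the page
}
```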