diff --git a/applications/flash_attention_v2/collective/xe_flash_attn_prefill_mma_cachedKV.hpp b/applications/flash_attention_v2/collective/xe_flash_attn_prefill_mma_cachedKV.hpp
index 75ecdc9359..68a08e916f 100644
--- a/applications/flash_attention_v2/collective/xe_flash_attn_prefill_mma_cachedKV.hpp
+++ b/applications/flash_attention_v2/collective/xe_flash_attn_prefill_mma_cachedKV.hpp
@@ -31,6 +31,7 @@
 #pragma once

 #include "cutlass/cutlass.h"
+#include "cutlass/fp8_to_fp16.h"
 #include "cutlass/gemm/dispatch_policy.hpp"

 #include "cute/algorithm/functional.hpp"
@@ -147,6 +148,8 @@ struct FlashPrefillCachedMma, ProblemShapeTyp
   using atom_load_V = Copy_Atom<traits_load_V, ElementV>;
   using val_layout_load_V = decltype(make_layout(shape_div(typename traits_load_V::BlockShape{}, CopyThreadShape{})));
   using XE_Copy_V = decltype(make_tiled_copy(atom_load_V{}, Layout<CopyThreadShape>{}, val_layout_load_V{}));
+  template <typename T>
+  static constexpr bool is_fp8_v = cute::is_any_of_v<T, cutlass::float_e5m2_t, cutlass::float_e4m3_t>;

   // Host side kernel arguments
   struct Arguments {
@@ -227,10 +230,11 @@ struct FlashPrefillCachedMma, ProblemShapeTyp
     Tensor tCgK = thread_mma_k.partition_B(gK);

     // Create fragments
-    // TODO(Codeplay): fix this, this is probably not general
-    Tensor tCrQ = make_tensor<ElementQ>(make_fragment_layout(params.gmem_tiled_copy_q, take<0,3>(tCgQ.shape())));
-    Tensor tCrK = make_tensor<ElementK>(make_fragment_layout(gmem_tiled_copy_k, take<0,3>(tCgK.shape())));
-
+    using TCrQ_Type = cute::conditional_t<is_fp8_v<ElementQ>, uint8_t, ElementQ>;
+    using TCrK_Type = cute::conditional_t<is_fp8_v<ElementK>, uint8_t, ElementK>;
+    Tensor tCrQ = make_tensor<TCrQ_Type>(make_fragment_layout(params.gmem_tiled_copy_q, take<0,3>(tCgQ.shape())));
+    Tensor tCrK = make_tensor<TCrK_Type>(make_fragment_layout(gmem_tiled_copy_k, take<0,3>(tCgK.shape())));
+
     // Retile registers for copies
     Tensor tQrQ = thr_copy_Q.retile_D(tCrQ);
     Tensor tKrK = thr_copy_K.retile_D(tCrK);
@@ -270,7 +274,23 @@ struct FlashPrefillCachedMma, ProblemShapeTyp
     for (int k_tile = 0; k_tile < k_tile_count; ++k_tile) {
       copy(params.gmem_tiled_copy_q, tQgQ(_,_,_,k_tile), tQrQ);
       copy(gmem_tiled_copy_k, tKgK(_,_,_,k_tile), tKrK);
-      cute::gemm(tiled_mma, accum, tCrQ, tCrK, frag_src);
+      if constexpr (is_fp8_v<ElementQ> && is_fp8_v<ElementK>) {
+        auto tCrQ_fp16 = make_fragment_like<half_t>(tCrQ);
+        convert_FP8_to_FP16(tCrQ, tCrQ_fp16);
+        auto tCrK_fp16 = make_fragment_like<half_t>(tCrK);
+        convert_FP8_to_FP16(tCrK, tCrK_fp16);
+        cute::gemm(tiled_mma, accum, tCrQ_fp16, tCrK_fp16, frag_src);
+      } else if constexpr (is_fp8_v<ElementQ> && !is_fp8_v<ElementK>) {
+        auto tCrQ_fp16 = make_fragment_like<half_t>(tCrQ);
+        convert_FP8_to_FP16(tCrQ, tCrQ_fp16);
+        cute::gemm(tiled_mma, accum, tCrQ_fp16, tCrK, frag_src);
+      } else if constexpr (!is_fp8_v<ElementQ> && is_fp8_v<ElementK>) {
+        auto tCrK_fp16 = make_fragment_like<half_t>(tCrK);
+        convert_FP8_to_FP16(tCrK, tCrK_fp16);
+        cute::gemm(tiled_mma, accum, tCrQ, tCrK_fp16, frag_src);
+      } else {
+        cute::gemm(tiled_mma, accum, tCrQ, tCrK, frag_src);
+      }
     }
   }
@@ -289,7 +309,8 @@ struct FlashPrefillCachedMma, ProblemShapeTyp
     auto first_thread_in_sg_idx = sg.get_group_id()[0] * DispatchPolicy::SubgroupSize;
     auto thread_mma = tiled_mma.get_slice(first_thread_in_sg_idx);
     Tensor tCgV = thread_mma.partition_B(gV_);
-    Tensor tCrV = make_tensor<ElementV>(make_fragment_layout(gmem_tiled_copy_v, take<0,3>(tCgV.shape())));
+    using TCrV_Type = cute::conditional_t<is_fp8_v<ElementV>, uint8_t, ElementV>;
+    Tensor tCrV = make_tensor<TCrV_Type>(make_fragment_layout(gmem_tiled_copy_v, take<0,3>(tCgV.shape())));

     // Partition the copying of A and B tiles across the threads
     auto gmem_thr_copy_V = gmem_tiled_copy_v.get_slice(thread_idx);
@@ -321,7 +342,13 @@ struct FlashPrefillCachedMma, ProblemShapeTyp
     CUTLASS_PRAGMA_UNROLL
     for(int i = 0; i < tile_count; i++) {
       copy(gmem_tiled_copy_v, tVgV(_,_,_,i), tVrV);
-      cute::gemm(tiled_mma, accum(_,_,_,i), tPr, tCrV, frag_src(_,_,_,i));
+      if constexpr (is_fp8_v<ElementV>) {
+        auto tCrV_fp16 = make_fragment_like<half_t>(tCrV);
+        convert_FP8_to_FP16(tCrV, tCrV_fp16);
+        cute::gemm(tiled_mma, accum(_,_,_,i), tPr, tCrV_fp16, frag_src(_,_,_,i));
+      } else {
+        cute::gemm(tiled_mma, accum(_,_,_,i), tPr, tCrV, frag_src(_,_,_,i));
+      }
     }
   }
diff --git a/examples/06_bmg_flash_attention/06_bmg_prefill_attention_prefill_cachedKV_fp8.cpp b/examples/06_bmg_flash_attention/06_bmg_prefill_attention_prefill_cachedKV_fp8.cpp
new file mode 100644
index 0000000000..2e08537cab
--- /dev/null
+++ b/examples/06_bmg_flash_attention/06_bmg_prefill_attention_prefill_cachedKV_fp8.cpp
@@ -0,0 +1,122 @@
+/***************************************************************************************************
+ * Copyright (c) 2025 - 2025 Codeplay Software Ltd. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+  \brief Flash Attention V2 Prefill with KV cache (FP8 input) for Intel BMG
+
+  This example constructs and executes a Flash Attention Prefill with KV cache on Intel BMG. The
+  GEMM definition, options, etc. for this example are defined in the associated
+  bmg_flash_attn_prefill_cachedKV_runner.hpp header file.
+
+  See https://arxiv.org/pdf/2307.08691 for details of the Flash Attention V2 algorithm
+
+  To run this example:
+    $ ./examples/sycl/06_bmg_flash_attention_cachedKV/06_bmg_prefill_attention_cachedKV --seq_len_qo=512
+      --seq_len_kv=512 --seq_len_kv_cache=512 --head_size_vo=128 --head_size_qk=128
+
+  Causal masking of the first matrix multiplication is supported (`--is_causal`)
+
+  To build & run this example (from your build dir):
+
+    $ ninja 06_bmg_prefill_attention_cachedKV
+    $ ./examples/sycl/06_bmg_flash_attention_cachedKV/06_bmg_prefill_attention_cachedKV
+
+  Call with `--help` for information about available options
+*/
+
+#include "bmg_flash_attn_prefill_cachedKV_runner.hpp"
+
+int main(int argc, const char **argv) {
+  //
+  // Parse options
+  //
+
+  Options options;
+
+  options.parse(argc, argv);
+
+  if (options.help) {
+    options.print_usage(std::cout) << std::endl;
+    return 0;
+  }
+
+  if (options.error) {
+    std::cerr << "Aborting execution." << std::endl;
+    return -1;
+  }
+
+  // Define the work-group tile shape depending on the head-size of the second matmul
+  // Shape<_SequenceLengthOutputBLOCK, _HeadSizeout(NV), SequenceLengthKVBLOCK_KN/KV, HeadSizeQKBLOCK_KQK, HEADSIZEOutSlicerBlock>
+  //
+#if !defined(HEAD_DIM)
+  std::cerr << "HEAD_DIM must be defined" << std::endl;
+  return -1;
+#endif
+  if (options.head_size_vo != HEAD_DIM) {
+    std::cerr << "head_size_vo must be " << HEAD_DIM << ", but got " << options.head_size_vo << std::endl;
+    return -1;
+  }
+
+  using ElementInputQ = cutlass::float_e5m2_t;  // <- data type of elements in input matrix A
+  using ElementInputKV = cutlass::float_e5m2_t; // <- data type of elements in input matrix B
+  using MMAOperation = XE_8x16x16_F32F16F16F32_TT;
+  using GmemTiledCopyQ = XE_2D_U8x8x32_LD_N;
+  using GmemTiledCopyK = XE_2D_U8x16x16_LD_T; // _T designates a transposed block load operation
+  using GmemTiledCopyV = XE_2D_U8x32x32_LD_V;
+  constexpr int PipelineStages = 2;
+#if HEAD_DIM == 64
+  using ShapeQK = Shape<_128, _64, _64>;
+  using ShapePV = Shape<_128, _32, _64>;
+  using ShapeOutPut = Shape<_128, _64, _64>;
+  using SubgroupLayout = Layout, Stride<_1, _1, _1>>;
+#elif HEAD_DIM == 96
+  using ShapeQK = Shape<_128, _64, _32>;
+  using ShapePV = Shape<_128, _32, _64>;
+  using ShapeOutPut = Shape<_128, _96, _64>;
+  using SubgroupLayout = Layout, Stride<_1, _1, _1>>;
+#elif HEAD_DIM == 128
+  using ShapeQK = Shape<_128, _64, _64>;
+  using ShapePV = Shape<_128, _32, _64>;
+  using ShapeOutPut = Shape<_128, _128, _64>;
+  using SubgroupLayout = Layout, Stride<_1, _1, _1>>;
+#elif HEAD_DIM == 192
+  using ShapeQK = Shape<_256, _64, _64>;
+  using ShapePV = Shape<_256, _32, _64>;
+  using ShapeOutPut = Shape<_256, _192, _64>;
+  using SubgroupLayout = Layout, Stride<_1, _1, _1>>;
+#endif
+  return options.is_causal ?
+    FMHAConfig::run(options)
+    : FMHAConfig::run(options);
+
+}
diff --git a/examples/06_bmg_flash_attention/CMakeLists.txt b/examples/06_bmg_flash_attention/CMakeLists.txt
index 39752da4ed..fdd789ecf0 100644
--- a/examples/06_bmg_flash_attention/CMakeLists.txt
+++ b/examples/06_bmg_flash_attention/CMakeLists.txt
@@ -54,12 +54,21 @@ foreach(HEAD_DIM 64 96 128 192)
     TEST_NO_PAGED
     TEST_PAGED
   )
+
   cutlass_example_add_executable(
     06_bmg_prefill_attention_fp8_hdim${HEAD_DIM}
     06_bmg_prefill_attention_fp8.cpp
     TEST_COMMAND_OPTIONS
   )
+  cutlass_example_add_executable(
+    06_bmg_prefill_attention_cachedKV_fp8_hdim${HEAD_DIM}
+    06_bmg_prefill_attention_prefill_cachedKV_fp8.cpp
+    TEST_COMMAND_OPTIONS
+    TEST_NO_PAGED
+    TEST_PAGED
+  )
+
   cutlass_example_add_executable(
     06_bmg_decode_attention_fp8_hdim${HEAD_DIM}
     06_bmg_decode_attention_fp8.cpp
@@ -71,5 +80,6 @@ foreach(HEAD_DIM 64 96 128 192)
   target_compile_definitions(06_bmg_prefill_attention_cachedKV_hdim${HEAD_DIM} PRIVATE HEAD_DIM=${HEAD_DIM})
   target_compile_definitions(06_bmg_decode_attention_hdim${HEAD_DIM} PRIVATE HEAD_DIM=${HEAD_DIM})
   target_compile_definitions(06_bmg_prefill_attention_fp8_hdim${HEAD_DIM} PRIVATE HEAD_DIM=${HEAD_DIM})
+  target_compile_definitions(06_bmg_prefill_attention_cachedKV_fp8_hdim${HEAD_DIM} PRIVATE HEAD_DIM=${HEAD_DIM})
   target_compile_definitions(06_bmg_decode_attention_fp8_hdim${HEAD_DIM} PRIVATE HEAD_DIM=${HEAD_DIM})
 endforeach()
diff --git a/examples/06_bmg_flash_attention/bmg_flash_attn_prefill_cachedKV_runner.hpp b/examples/06_bmg_flash_attention/bmg_flash_attn_prefill_cachedKV_runner.hpp
index 0600b6ce0f..6a8334c3ff 100644
--- a/examples/06_bmg_flash_attention/bmg_flash_attn_prefill_cachedKV_runner.hpp
+++ b/examples/06_bmg_flash_attention/bmg_flash_attn_prefill_cachedKV_runner.hpp
@@ -203,6 +203,27 @@ template struct ExampleRunner {
   };
   PagedKVParams paged_kv_cache;

+  template <typename SrcT, typename DstT>
+  void convert_fp8_to_fp16(const SrcT* d_src, DstT* d_dst, size_t size) {
+    syclcompat::get_default_queue().parallel_for(size, [=](auto indx) {
+      d_dst[indx] = static_cast<DstT>(d_src[indx]);
+    }).wait();
+  }
+
+  template <typename T>
+  static constexpr bool is_fp8_v = cute::is_any_of_v<T, cutlass::float_e5m2_t, cutlass::float_e4m3_t>;
+
+  template <typename Tin> inline auto in_memory(cutlass::DeviceAllocation<Tin>& in) {
+    using outType = cutlass::DeviceAllocation<cute::conditional_t<is_fp8_v<Tin>, half_t, Tin>>;
+    if constexpr(is_fp8_v<Tin>) {
+      cutlass::DeviceAllocation<half_t> out(in.size());
+      convert_fp8_to_fp16(in.get(), out.get(), in.size());
+      return out;
+    } else {
+      return in;
+    };
+  }
+
   //
   // Methods
   //
@@ -221,6 +242,11 @@ template struct ExampleRunner {
     auto [batch, num_heads_q, num_heads_kv, head_size_qk, head_size_vo] = cute::select<0,1,2,6,7>(problem_size);
     int seq_len_qo, seq_len_kv, seq_len_kv_cache;
+    auto block_Q_ = in_memory(block_Q);
+    auto block_K_ = in_memory(block_K);
+    auto block_V_ = in_memory(block_V);
+    using ElementK_ = cute::conditional_t<is_fp8_v<ElementK>, half_t, ElementK>;
+    using ElementV_ = cute::conditional_t<is_fp8_v<ElementV>, half_t, ElementV>;
     int offset_q = 0;
     int offset_k = 0;
     int offset_v = 0;
@@ -248,45 +274,47 @@ template struct ExampleRunner {
     cutlass::DeviceAllocation block_S;
     block_S.reset(seq_len_qo * seq_len_kv_total);

-    ElementK* k_ptr;
-    ElementV* v_ptr;
+    ElementK_* k_ptr;
+    ElementV_* v_ptr;

     if (use_kv_cache) {
-      cutlass::DeviceAllocation<ElementK> block_K_concat(head_size_qk * seq_len_kv_total);
-      cutlass::DeviceAllocation<ElementV> block_V_concat(seq_len_kv_total * head_size_vo);
+      auto block_K_cache_ = in_memory(block_K_cache);
+      auto block_V_cache_ = in_memory(block_V_cache);
+      cutlass::DeviceAllocation<ElementK_> block_K_concat(head_size_qk * seq_len_kv_total);
+      cutlass::DeviceAllocation<ElementV_> block_V_concat(seq_len_kv_total * head_size_vo);

       // Concatenate K_cache and K
-      syclcompat::memcpy<ElementK>(
+      syclcompat::memcpy<ElementK_>(
         block_K_concat.get(),
-        block_K_cache.get() + offset_k_cache,
+        block_K_cache_.get() + offset_k_cache,
         seq_len_kv_cache * head_size_qk
       );
-      syclcompat::memcpy<ElementK>(
+      syclcompat::memcpy<ElementK_>(
         block_K_concat.get() + seq_len_kv_cache * head_size_qk,
-        block_K.get() + offset_k,
+        block_K_.get() + offset_k,
         seq_len_kv * head_size_qk
       );

       // Concatenate V_cache and V
-      syclcompat::memcpy<ElementV>(
+      syclcompat::memcpy<ElementV_>(
         block_V_concat.get(),
-        block_V_cache.get() + offset_v_cache,
+        block_V_cache_.get() + offset_v_cache,
         seq_len_kv_cache * head_size_vo
       );
-      syclcompat::memcpy<ElementV>(
+      syclcompat::memcpy<ElementV_>(
         block_V_concat.get() + seq_len_kv_cache * head_size_vo,
-        block_V.get() + offset_v,
+        block_V_.get() + offset_v,
         seq_len_kv * head_size_vo
       );

       k_ptr = block_K_concat.get();
       v_ptr = block_V_concat.get();
     } else {
-      k_ptr = block_K.get() + offset_k;
-      v_ptr = block_V.get() + offset_v;
+      k_ptr = block_K_.get() + offset_k;
+      v_ptr = block_V_.get() + offset_v;
     }

-    cutlass::TensorRef ref_Q(block_Q.get() + offset_q, LayoutQ::packed({seq_len_qo, head_size_qk}));
+    cutlass::TensorRef ref_Q(block_Q_.get() + offset_q, LayoutQ::packed({seq_len_qo, head_size_qk}));
     cutlass::TensorRef ref_K(k_ptr, LayoutK::packed({head_size_qk, seq_len_kv_total}));
     cutlass::TensorRef ref_V(v_ptr, LayoutV::packed({seq_len_kv_total, head_size_vo}));
     cutlass::TensorRef ref_S(block_S.get(), LayoutQ::packed({seq_len_qo, seq_len_kv_total}));
@@ -364,14 +392,14 @@ template struct ExampleRunner {
       }
     }

-    std::vector<ElementV> host_P(host_S.size());
+    std::vector<ElementV_> host_P(host_S.size());
     for (int p = 0; p < host_P.size(); p++)
-      host_P[p] = static_cast<ElementV>(host_S[p]);
+      host_P[p] = static_cast<ElementV_>(host_S[p]);

-    cutlass::DeviceAllocation<ElementV> block_P;
+    cutlass::DeviceAllocation<ElementV_> block_P;
     block_P.reset(host_P.size());

-    syclcompat::memcpy<ElementV>(block_P.get(), host_P.data(), host_P.size());
+    syclcompat::memcpy<ElementV_>(block_P.get(), host_P.data(), host_P.size());

     cutlass::TensorRef ref_P(block_P.get(), LayoutQ::packed({seq_len_qo, seq_len_kv_total}));
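
The FP8 path in the kernel changes above rests on two small compile-time pieces: an is_fp8_v trait that detects the FP8 element types, and a conditional_t selection that keeps FP8 operands in register fragments as raw bytes (uint8_t) until they are widened to FP16 immediately before the MMA. The following is a minimal, standalone sketch of that selection logic using standard-library equivalents; the float_e5m2_t / float_e4m3_t stand-ins and the FragmentStorage alias are illustrative names only, not part of the patch or of CUTLASS.

    #include <cstdint>
    #include <type_traits>

    // Illustrative stand-ins for the CUTLASS FP8 element types
    // (cutlass::float_e5m2_t / cutlass::float_e4m3_t in the patch).
    struct float_e5m2_t {};
    struct float_e4m3_t {};

    // Mirrors the is_fp8_v variable template added by the patch
    // (built there on cute::is_any_of_v and the real cutlass types).
    template <class T>
    inline constexpr bool is_fp8_v =
        std::is_same_v<T, float_e5m2_t> || std::is_same_v<T, float_e4m3_t>;

    // Mirrors the TCrQ_Type / TCrK_Type / TCrV_Type selection: FP8 operands are
    // staged in registers as bytes and converted to FP16 right before cute::gemm;
    // other element types keep their native representation.
    template <class Element>
    using FragmentStorage = std::conditional_t<is_fp8_v<Element>, std::uint8_t, Element>;

    static_assert(std::is_same_v<FragmentStorage<float_e5m2_t>, std::uint8_t>, "FP8 staged as bytes");
    static_assert(std::is_same_v<FragmentStorage<float>, float>, "non-FP8 types unchanged");

Staging the fragments as bytes lets the existing 8-bit block-load copies (the XE_2D_U8x... atoms selected in the example) fill the registers unchanged; only the per-tile convert_FP8_to_FP16 step and the if constexpr dispatch around cute::gemm are FP8-specific.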