From a445e4875c0fb6a0cb3dad2297b0769351269485 Mon Sep 17 00:00:00 2001 From: Aditya Chatterjee Date: Fri, 17 Oct 2025 09:26:21 +0000 Subject: [PATCH 01/11] Support of FP8 chunk prefill Signed-off-by: Aditya Chatterjee --- src/sycl/chunked_prefill.cpp | 129 +++- src/sycl/kernels/chunk_prefill/fp8_descale.h | 140 +++++ .../chunk_prefill/xe_chunk_prefill.hpp | 570 ++++++++++-------- .../xe_flash_attn_chunk_prefill_mma.hpp | 491 +++++++++------ 4 files changed, 897 insertions(+), 433 deletions(-) create mode 100644 src/sycl/kernels/chunk_prefill/fp8_descale.h diff --git a/src/sycl/chunked_prefill.cpp b/src/sycl/chunked_prefill.cpp index d29733f..294c073 100644 --- a/src/sycl/chunked_prefill.cpp +++ b/src/sycl/chunked_prefill.cpp @@ -64,6 +64,9 @@ struct Flash_fwd_params { // The scaling factors for the kernel. float scale_softmax; float softcap; + float* __restrict__ q_scale_ptr; + float* __restrict__ k_scale_ptr; + float* __restrict__ v_scale_ptr; // array of length b+1 holding starting offset of each sequence. 
int* __restrict__ cu_seqlens_q; @@ -113,6 +116,7 @@ struct Flash_fwd_params { // Paged KV cache int* __restrict__ page_table; + int* __restrict__ num_pages_per_seq_ptr; int max_num_pages_per_seq; index_t page_table_batch_stride; int page_size; @@ -136,7 +140,7 @@ struct Flash_fwd_params { bool is_bf16; bool is_fp32; - bool is_e4m3; + bool is_fp8; bool is_causal; bool is_local; @@ -311,24 +315,26 @@ struct KernelRunner { typename FMHAChunkPrefillKernel::Arguments arguments{ cutlass::gemm::GemmUniversalMode::kGemm, problem_size, - {// static_cast(params.q_ptr), - static_cast(params.q_ptr), + {static_cast(params.q_ptr), stride_Q, - // static_cast(params.knew_ptr), - // stride_K, - // static_cast(params.vnew_ptr), - // stride_V, - static_cast(params.k_ptr), + static_cast(params.knew_ptr), + stride_K, + static_cast(params.vnew_ptr), + stride_V, + params.q_scale_ptr, + params.k_scale_ptr, + params.v_scale_ptr, + static_cast(params.k_ptr), stride_K_cache, static_cast(params.v_ptr), stride_V_cache, params.page_table, params.page_size, - params.max_num_pages_per_seq, + params.num_pages_per_seq_ptr, -1, -1}, - {(ElementQ)params.scale_softmax}, - {static_cast(params.o_ptr), stride_O}, + {params.scale_softmax}, + {static_cast(params.o_ptr), stride_O}, hw_info}; // Define device-global scratch memory @@ -496,8 +502,9 @@ std::vector mha_fwd( auto q_type = q.scalar_type(); TORCH_CHECK( - q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16, - "SGL Kernel XPU only supports fp16 and bf16 type"); + q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16 || q_type == at::ScalarType::Float8_e4m3fn || + q_type == at::ScalarType::Float8_e5m2, + "SGL Kernel XPU only supports fp16, bf16 and fp8 types"); TORCH_CHECK(k.scalar_type() == q_type, "query and key must have the same dtype"); TORCH_CHECK(v.scalar_type() == q_type, "query and value must have the same dtype"); @@ -639,6 +646,7 @@ std::vector mha_fwd( // align with FA3 Flash_fwd_params params; 
params.is_bf16 = q.dtype() == torch::kBFloat16; + params.is_fp8 = q.dtype() == at::ScalarType::Float8_e4m3fn || q.dtype() == at::ScalarType::Float8_e5m2; // Set the pointers and strides. params.q_ptr = q.data_ptr(); @@ -656,6 +664,12 @@ std::vector mha_fwd( params.o_row_stride = out.stride(-3); params.o_head_stride = out.stride(-2); + if (params.is_fp8) { + params.q_scale_ptr = static_cast(q_descale_.value().data_ptr()); + params.k_scale_ptr = static_cast(k_descale_.value().data_ptr()); + params.v_scale_ptr = static_cast(v_descale_.value().data_ptr()); + } + if (!is_varlen_q) { params.q_batch_stride = q.stride(0); params.o_batch_stride = out.stride(0); @@ -708,12 +722,15 @@ std::vector mha_fwd( params.total_k = total_k; params.b_k = batch_size_k; params.dv = head_size_v; + at::Tensor num_pages_per_seq; if (paged_KV) { params.page_table = page_table.data_ptr(); params.page_table_batch_stride = page_table.stride(0); params.max_num_pages_per_seq = max_num_pages_per_seq; params.page_size = page_size; params.num_pages = num_pages; + num_pages_per_seq = torch::zeros({batch_size_k + 1}, torch::dtype(torch::kInt32).device(q.device())); + params.num_pages_per_seq_ptr = num_pages_per_seq.data_ptr(); } if (q_v_.has_value()) { @@ -787,7 +804,91 @@ std::vector mha_fwd( auto outaccum_type = at::ScalarType::Float; constexpr int PipelineStages = 2; - if (params.is_causal) { + + if (params.is_fp8) { + using ElementInputQ = cutlass::float_e4m3_t; + using ElementInputKV = cutlass::float_e4m3_t; + using MMAOperation = XE_8x16x16_F32BF16BF16F32_TT; + using GmemTiledCopyQ = XE_2D_U8x8x32_LD_N; + using GmemTiledCopyK = XE_2D_U8x16x16_LD_T; + using GmemTiledCopyV = XE_2D_U8x32x32_LD_V; + + if (params.is_causal) { + switch (params.d) { + case 64: + FMHAConfig< + true, + cute::Shape<_128, _64, _64>, + cute::Shape<_128, _32, _64>, + cute::Shape<_128, _64, _64>, + cute::Layout, cute::Stride<_1, _1, _1>>, + PipelineStages, + false, + ElementInputQ, + ElementInputKV, + MMAOperation, + 
GmemTiledCopyQ, + GmemTiledCopyK, + GmemTiledCopyV>::run(params); + break; + case 128: + FMHAConfig< + true, + cute::Shape<_128, _64, _64>, + cute::Shape<_128, _32, _64>, + cute::Shape<_128, _128, _64>, + cute::Layout, cute::Stride<_1, _1, _1>>, + PipelineStages, + false, + ElementInputQ, + ElementInputKV, + MMAOperation, + GmemTiledCopyQ, + GmemTiledCopyK, + GmemTiledCopyV>::run(params); + break; + default: + TORCH_CHECK(false, "Unsupported head size for FP8 causal attention"); + } + } else { + switch (params.d) { + case 64: + FMHAConfig< + false, + cute::Shape<_128, _64, _64>, + cute::Shape<_128, _32, _64>, + cute::Shape<_128, _64, _64>, + cute::Layout, cute::Stride<_1, _1, _1>>, + PipelineStages, + false, + ElementInputQ, + ElementInputKV, + MMAOperation, + GmemTiledCopyQ, + GmemTiledCopyK, + GmemTiledCopyV>::run(params); + break; + case 128: + FMHAConfig< + false, + cute::Shape<_128, _64, _64>, + cute::Shape<_128, _32, _64>, + cute::Shape<_128, _128, _64>, + cute::Layout, cute::Stride<_1, _1, _1>>, + PipelineStages, + false, + ElementInputQ, + ElementInputKV, + MMAOperation, + GmemTiledCopyQ, + GmemTiledCopyK, + GmemTiledCopyV>::run(params); + break; + default: + TORCH_CHECK(false, "Unsupported head size for FP8 attention"); + } + } + } else if (params.is_causal) { switch (params.d) { case 64: FMHAConfig< diff --git a/src/sycl/kernels/chunk_prefill/fp8_descale.h b/src/sycl/kernels/chunk_prefill/fp8_descale.h new file mode 100644 index 0000000..b53ff2f --- /dev/null +++ b/src/sycl/kernels/chunk_prefill/fp8_descale.h @@ -0,0 +1,140 @@ +/*************************************************************************************************** + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// Helper device function for E4M3 -> BFLOAT16 bitwise conversion +CUTLASS_DEVICE uint16_t +fp8_e4m3_to_fp16_bitwise(uint8_t const& src) { + // E4M3 (1-4-3) constants + constexpr uint32_t e4m3_exp_bias = 7; + // BFLOAT16 (1-8-7) constants + constexpr uint32_t bf16_exp_bias = 127; + + // Unpack FP8 bits + uint16_t sign = static_cast(src & 0x80); + uint16_t exponent = static_cast(src & 0x78) >> 3; + uint16_t mantissa = static_cast(src & 0x07); + + // Reconstruct BFLOAT16 bits + uint16_t bf16_sign = sign << 8; + // Re-bias exponent and shift to BFLOAT16 position + uint16_t bf16_exponent = (exponent - e4m3_exp_bias + bf16_exp_bias) << 7; + // Shift mantissa to BFLOAT16 position + uint16_t bf16_mantissa = mantissa << 4; + + return bf16_sign | bf16_exponent | bf16_mantissa; +} + +// Helper device function for E5M2 -> BFLOAT16 bitwise conversion +CUTLASS_DEVICE uint16_t +fp8_e5m2_to_fp16_bitwise(uint8_t const& src) { + // E5M2 (1-5-2) constants + constexpr uint32_t e5m2_exp_bias = 15; + // BFLOAT16 (1-8-7) constants + constexpr uint32_t bf16_exp_bias = 127; + + // Unpack FP8 bits + uint16_t sign = static_cast(src & 0x80); + uint16_t exponent = static_cast(src & 0x7C) >> 2; + uint16_t mantissa = static_cast(src & 0x03); + + // Reconstruct BFLOAT16 bits + uint16_t bf16_sign = sign << 8; + // Re-bias exponent and shift to BFLOAT16 position + uint16_t bf16_exponent = (exponent - e5m2_exp_bias + bf16_exp_bias) << 7; + // Shift mantissa to BFLOAT16 position + uint16_t bf16_mantissa = mantissa << 5; + + return bf16_sign | bf16_exponent | bf16_mantissa; +} + + +template < + typename Encoding, + int VectorizeSize = 8, + typename SrcTensor, + typename DstTensor +> +CUTLASS_DEVICE void +convert_and_descale( + SrcTensor const& src, + DstTensor& dst, + float scale) { + + 
using SrcVec_u8 = sycl::vec; + using DstVec_u16 = sycl::vec; + + auto src_ptr = reinterpret_cast(src.data()); + auto dst_ptr = reinterpret_cast(dst.data()); + + // Create a SCALAR bfloat16_t for scaling + const cutlass::bfloat16_t scale_bf16 = static_cast(scale); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < cute::size(src) / VectorizeSize; ++i) { + SrcVec_u8 const src_vec_u8 = src_ptr[i]; + DstVec_u16 result_vec_u16; + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < VectorizeSize; ++j) { + // 1. Convert FP8 bits to BFLOAT16 bits + uint16_t val_bf16_bits; + if constexpr (std::is_same_v) { + val_bf16_bits = fp8_e4m3_to_fp16_bitwise(src_vec_u8[j]); + } else { + val_bf16_bits = fp8_e5m2_to_fp16_bitwise(src_vec_u8[j]); + } + + // 2. Reinterpret bits as bfloat16_t to perform math + cutlass::bfloat16_t val_bf16 = reinterpret_cast(val_bf16_bits); + + // 3. Apply scaling + val_bf16 *= scale_bf16; + + // 4. Reinterpret back to bits for storage + result_vec_u16[j] = reinterpret_cast(val_bf16); + } + + // 5. 
Store the final vector of bits + dst_ptr[i] = result_vec_u16; + } +} diff --git a/src/sycl/kernels/chunk_prefill/xe_chunk_prefill.hpp b/src/sycl/kernels/chunk_prefill/xe_chunk_prefill.hpp index 1384bca..02d5584 100644 --- a/src/sycl/kernels/chunk_prefill/xe_chunk_prefill.hpp +++ b/src/sycl/kernels/chunk_prefill/xe_chunk_prefill.hpp @@ -38,22 +38,17 @@ namespace cutlass::flash_attention::kernel { -template < - class ProblemShape_, - class CollectiveMainloop_, - class CollectiveSoftmaxEpilogue_, - class CollectiveEpilogue_, - class TileScheduler_ = void> +template class FMHAPrefillChunk; /////////////////////////////////////////////////////////////////////////////// -template < - class ProblemShape_, - class CollectiveMainloop_, - class CollectiveSoftmaxEpilogue_, - class CollectiveEpilogue_, - class TileScheduler_> +template class FMHAPrefillChunk { - public: + +public: // // Type Aliases // @@ -87,12 +82,14 @@ class FMHAPrefillChunk { using SoftmaxArguments = typename CollectiveSoftmaxEpilogue::Arguments; using SoftmaxParams = typename CollectiveSoftmaxEpilogue::Params; - static_assert( - cute::is_void_v or cute::is_same_v or - cute::is_same_v, - "Unsupported TileScheduler for Intel Xe."); + static_assert(cute::is_void_v or + cute::is_same_v or + cute::is_same_v, + "Unsupported TileScheduler for Intel Xe."); using TileSchedulerTag = TileScheduler_; - using TileScheduler = typename detail::TileSchedulerSelector::Scheduler; + using TileScheduler = + typename detail::TileSchedulerSelector::Scheduler; using TileSchedulerParams = typename TileScheduler::Params; // Epilogue derived types @@ -106,7 +103,8 @@ class FMHAPrefillChunk { using TiledMmaOutput = typename CollectiveEpilogue::TiledMmaOutput; static_assert( - cute::is_same_v, + cute::is_same_v, "Mainloop and epilogue do not agree on accumulator value type."); // MSVC requires the cast to fix a warning-as-error. 
static constexpr int SharedStorageSize = 0; @@ -117,9 +115,12 @@ class FMHAPrefillChunk { static_assert(!(CausalMask && LocalMask), "Cannot be both causal and local"); static constexpr bool PagedKV = CollectiveMainloop::PagedKV; - static constexpr int SubgroupSize = CollectiveMainloop::SubgroupSize; // sub_group size - static constexpr uint32_t MaxThreadsPerBlock = CollectiveMainloop::MaxThreadsPerBlock; - using MmaAtomShape = typename CollectiveMainloop::MmaAtomShape; // 8,16,16 + + static constexpr int SubgroupSize = + CollectiveMainloop::SubgroupSize; // sub_group size + static constexpr uint32_t MaxThreadsPerBlock = + CollectiveMainloop::MaxThreadsPerBlock; + using MmaAtomShape = typename CollectiveMainloop::MmaAtomShape; // 8,16,16 static constexpr int QK_BLK_M = CollectiveMainloop::QK_BLK_M; static constexpr int QK_BLK_N = CollectiveMainloop::QK_BLK_N; @@ -145,9 +146,11 @@ class FMHAPrefillChunk { static constexpr int FragsN = CollectiveMainloop::FragsNS; static constexpr int VSlicer = - get<1>(TileShapeOutput{}) / (get<1>(TileShapePV{}) * PV_ATOM_N); // ceil_div(FragsNOut,FragsNS); - using AccumeShape = - decltype(make_shape(Int{}, Int{}, get<1>(TileShapePV{}) / get<1>(MmaAtomShape()), Int{})); + get<1>(TileShapeOutput{}) / + (get<1>(TileShapePV{}) * PV_ATOM_N); // ceil_div(FragsNOut,FragsNS); + using AccumeShape = decltype(make_shape( + Int{}, Int{}, get<1>(TileShapePV{}) / get<1>(MmaAtomShape()), + Int{})); static constexpr bool is_var_len = CollectiveMainloop::is_var_len; // Kernel level shared memory storage @@ -182,55 +185,58 @@ class FMHAPrefillChunk { // Convert to underlying arguments. In this case, a simple copy for the // aliased type. 
- static Params to_underlying_arguments(Arguments const& args, void* workspace) { + static Params to_underlying_arguments(Arguments const &args, + void *workspace) { (void)workspace; - return { - args.mode, - args.problem_shape, - CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace), - CollectiveSoftmaxEpilogue::to_underlying_arguments(args.softmax), - CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace), - TileScheduler::to_underlying_arguments(args.problem_shape, args.hw_info, TileShapeOutput{})}; + return {args.mode, + args.problem_shape, + CollectiveMainloop::to_underlying_arguments( + args.problem_shape, args.mainloop, workspace), + CollectiveSoftmaxEpilogue::to_underlying_arguments(args.softmax), + CollectiveEpilogue::to_underlying_arguments( + args.problem_shape, args.epilogue, workspace), + TileScheduler::to_underlying_arguments( + args.problem_shape, args.hw_info, TileShapeOutput{})}; } - static bool can_implement(Arguments const& args) { + static bool can_implement(Arguments const &args) { bool mode_implementable = args.mode == gemm::GemmUniversalMode::kGemm or - (args.mode == gemm::GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4); + (args.mode == gemm::GemmUniversalMode::kBatched && + rank(ProblemShape{}) == 4); return mode_implementable; } - static int get_workspace_size(Arguments const& args) { - return 0; - } + static int get_workspace_size(Arguments const &args) { return 0; } - static cutlass::Status initialize_workspace( - Arguments const& args, - void* workspace = nullptr, - cudaStream_t stream = nullptr, - CudaHostAdapter* cuda_adapter = nullptr) { + static cutlass::Status + initialize_workspace(Arguments const &args, void *workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr) { return Status::kSuccess; } - static dim3 get_grid_shape(Params const& params) { + static dim3 get_grid_shape(Params const ¶ms) { return 
TileScheduler::template get_grid_shape(params.scheduler); } - static dim3 get_block_shape() { - return dim3(MaxThreadsPerBlock, 1, 1); - } + static dim3 get_block_shape() { return dim3(MaxThreadsPerBlock, 1, 1); } CUTLASS_DEVICE - Shape get_sequence_length_shape(ProblemShape const& problem_shape, int const& batch) { + Shape + get_sequence_length_shape(ProblemShape const &problem_shape, + int const &batch) { if constexpr (is_var_len) { - return cutlass::fmha::collective::apply_variable_length(select<3, 5>(problem_shape), batch); + return cutlass::fmha::collective::apply_variable_length( + select<3, 4, 5>(problem_shape), batch); } else { - return select<3, 5>(problem_shape); + return select<3, 4, 5>(problem_shape); } } CUTLASS_DEVICE - void operator()(Params const& params, char* smem_buf) { - SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + void operator()(Params const ¶ms, char *smem_buf) { + SharedStorage &shared_storage = + *reinterpret_cast(smem_buf); // Preconditions CUTE_STATIC_ASSERT(is_static::value); CUTE_STATIC_ASSERT(is_static::value); @@ -242,37 +248,34 @@ class FMHAPrefillChunk { auto num_heads_q = get<1>(params.problem_shape); auto num_heads_kv = get<2>(params.problem_shape); - auto& head_size_qk = get<6>(params.problem_shape); - auto& head_size_vo = get<7>(params.problem_shape); + auto &head_size_qk = get<6>(params.problem_shape); + auto &head_size_vo = get<7>(params.problem_shape); // Preconditions - static_assert( - cute::rank(StrideQ{}) == 3, - "StrideQ must be rank-3: [seq_len_qo, head_size_qk, batch * " - "num_heads_q]."); - static_assert( - cute::rank(StrideK{}) == 3, - "StrideK must be rank-3: [head_size_qk, seq_len_kv, batch * " - "num_heads_kv]."); - static_assert( - cute::rank(StrideV{}) == 3, - "StrideV must be rank-3: [seq_len_kv, head_size_vo, batch * " - "num_heads_kv]."); + static_assert(cute::rank(StrideQ{}) == 3, + "StrideQ must be rank-3: [seq_len_qo, head_size_qk, batch * " + "num_heads_q]."); + 
static_assert(cute::rank(StrideK{}) == 3, + "StrideK must be rank-3: [head_size_qk, seq_len_kv, batch * " + "num_heads_kv]."); + static_assert(cute::rank(StrideV{}) == 3, + "StrideV must be rank-3: [seq_len_kv, head_size_vo, batch * " + "num_heads_kv]."); int thread_idx = int(ThreadIdxX()); - // int sub_group_id = thread_idx / SubgroupSize; - auto sub_group_id = get_sub_group_id(); - auto local_id = get_sub_group_local_id(); + int sub_group_id = thread_idx / SubgroupSize; TileScheduler tile_scheduler{params.scheduler}; CUTLASS_PRAGMA_NO_UNROLL for (; tile_scheduler.is_valid(); ++tile_scheduler) { - auto blk_coord = tile_scheduler.get_block_coord(); // head_size_blk_idx, seq_len_blk_idx, - // batch_blk_idx, num_heads_blk_idx + auto blk_coord = + tile_scheduler + .get_block_coord(); // head_size_blk_idx, seq_len_blk_idx, + // batch_blk_idx, num_heads_blk_idx - auto blk_m_coord = get<0>(blk_coord); // seq_len_blk_idx + auto blk_m_coord = get<0>(blk_coord); // seq_len_blk_idx auto blk_n_coord = 0; // nums_head_blk_idx - auto q_head_coord = get<1>(blk_coord); // q_heads_idx - auto batch_coord = get<2>(blk_coord); // batch_blk_idx + auto q_head_coord = get<1>(blk_coord); // q_heads_idx + auto batch_coord = get<2>(blk_coord); // batch_blk_idx // For variable sequence length case, batch is considered to be 1 (same // as group gemm). For fixed sequence length case, the l_coord is the @@ -290,9 +293,10 @@ class FMHAPrefillChunk { // 5>(params.problem_shape). 
sequence_length_shape = [batch, // num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, seq_len_kv_cache, // head_size_qk, head_size_vo] - auto sequence_length_shape = get_sequence_length_shape(params.problem_shape, batch_coord); + auto sequence_length_shape = + get_sequence_length_shape(params.problem_shape, batch_coord); - auto [seq_len_qo, seq_len_kv_cache] = sequence_length_shape; + auto [seq_len_qo, seq_len_kv, seq_len_kv_cache] = sequence_length_shape; // int seq_len_kv_total = seq_len_kv_cache + seq_len_kv; // For variable sequence length case, batch is considered to be 1 (same // as group gemm). For fixed sequence length case, the l_coord is the @@ -305,91 +309,146 @@ class FMHAPrefillChunk { // Calculate the seq_len_idx (blk_m_coord * get<0>(TileShapeOutput{})) // and check if it is still within bounds of the actual seq_len_qo // (get<0>(sequence_length_shape)). - if (blk_m_coord * get<0>(TileShapeOutput{}) >= seq_len_qo) { + if (blk_m_coord * get<0>(TileShapeOutput{}) >= + seq_len_qo) { continue; } const int seq_coord = - cute::min(seq_len_qo, (blk_m_coord * QK_BLK_M + (sub_group_id / PV_ATOM_N) * QK_SG_M) % seq_len_qo); - // auto offset = cute::min(seq_len_qo, seq_len_kv); //(2048, 1024) - // auto discard_seq_coord = seq_len_qo - offset; // 1024 - // auto full_tile_offset = seq_len_kv - offset; // 0 - - // const int seq_len = seq_len_kv; - // CausalMask - // ? full_tile_offset + - // cute::min(seq_len_kv, seq_coord - discard_seq_coord) + - // QK_SG_M - // : seq_len_kv; - + cute::min(seq_len_qo, (blk_m_coord * QK_BLK_M + (sub_group_id / PV_ATOM_N) * QK_SG_M) % + seq_len_qo); + auto offset = cute::min(seq_len_qo, seq_len_kv); //(2048, 1024) + auto discard_seq_coord = seq_len_qo - offset; // 1024 + auto full_tile_offset = seq_len_kv - offset; // 0 + + const int seq_len = + CausalMask + ? 
full_tile_offset + + cute::min(seq_len_kv, seq_coord - discard_seq_coord) + + QK_SG_M + : seq_len_kv; + + const int kv_splits_new = cute::ceil_div(seq_len, QK_BLK_N); const int kv_splits_cache = cute::ceil_div(seq_len_kv_cache, QK_BLK_N); - const int kv_splits = kv_splits_cache; + const int kv_splits = kv_splits_cache + kv_splits_new; int tiles_per_page = params.mainloop.page_size / QK_BLK_N; - Tensor mQ_mkl = cute::get_xe_tensor(make_shape(seq_len_qo, head_size_qk, 1)); //(m,k,l) + if (CausalMask && seq_coord < discard_seq_coord) { // 1024 =0 + continue; + } - Tensor mK_cache_nkl = cute::get_xe_tensor(make_shape(seq_len_kv_cache, head_size_qk, 1)); // (n_cache,k,l) - Tensor mV_cache_nkl = cute::get_xe_tensor(make_shape(head_size_vo, seq_len_kv_cache, 1)); // (n_cache,k,l) + auto q_group_size = num_heads_q / num_heads_kv; + auto kv_head_coord = q_head_coord / q_group_size; + + // Descale tensors are shaped (batch size * # heads) + // Each head has a separate scale factor + // Q, K, V tensors have separate scaling factors + const float q_scale_val = params.mainloop.ptr_q_scale == nullptr + ? 1.f + : params.mainloop.ptr_q_scale[batch_coord * num_heads_q + q_head_coord]; + const float k_scale_val = params.mainloop.ptr_k_scale == nullptr + ? 1.f + : params.mainloop.ptr_k_scale[batch_coord * num_heads_kv + kv_head_coord]; + const float v_scale_val = params.mainloop.ptr_v_scale == nullptr + ?
1.f + : params.mainloop.ptr_v_scale[batch_coord * num_heads_kv + kv_head_coord]; + + Tensor mQ_mkl = cute::get_xe_tensor( + make_shape(seq_len_qo, head_size_qk, 1)); //(m,k,l) + + Tensor mK_nkl = cute::get_xe_tensor( + make_shape(seq_len_kv, head_size_qk, 1)); //(n,k,l) + Tensor mV_nkl = cute::get_xe_tensor( + make_shape(head_size_vo, seq_len_kv, 1)); //(n,k,l) + Tensor mK_cache_nkl = cute::get_xe_tensor( + make_shape(seq_len_kv_cache, head_size_qk, 1)); // (n_cache,k,l) + Tensor mV_cache_nkl = cute::get_xe_tensor( + make_shape(head_size_vo, seq_len_kv_cache, 1)); // (n_cache,k,l) // block_size and head_size are the same size. So no coord is needed. Tensor mQ_mk = mQ_mkl(_, _, 0); - Tensor mK_cache_nk = mK_cache_nkl(_, _, 0); // (n_cache, k) - Tensor mV_cache_nk = mV_cache_nkl(_, _, 0); // (n_cache, k) + Tensor mK_nk = mK_nkl(_, _, 0); // (n,k) + Tensor mV_nk = mV_nkl(_, _, 0); - auto gQ = local_tile(mQ_mk, TileShapeQK{}, make_coord(blk_m_coord, _, _), Step<_1, X, _1>{}); + Tensor mK_cache_nk = mK_cache_nkl(_, _, 0); // (n_cache, k) + Tensor mV_cache_nk = mV_cache_nkl(_, _, 0); // (n_cache, k) - auto gK_cache = local_tile(mK_cache_nk, TileShapeQK{}, make_coord(_, _, _), Step{}); - auto gV_cache = local_tile(mV_cache_nk, TileShapeOutput{}, make_coord(_, blk_n_coord, _), Step{}); + auto gQ = local_tile(mQ_mk, TileShapeQK{}, make_coord(blk_m_coord, _, _), + Step<_1, X, _1>{}); + auto gK = local_tile(mK_nk, TileShapeQK{}, make_coord(_, _, _), + Step{}); + + auto gV = local_tile(mV_nk, TileShapeOutput{}, + make_coord(_, blk_n_coord, _), Step{}); + auto gK_cache = local_tile(mK_cache_nk, TileShapeQK{}, + make_coord(_, _, _), Step{}); + auto gV_cache = + local_tile(mV_cache_nk, TileShapeOutput{}, + make_coord(_, blk_n_coord, _), Step{}); auto mainloop_params = CollectiveMainloop::get_updated_copies( - params.mainloop, params.problem_shape, sequence_length_shape, batch_coord, q_head_coord); + params.mainloop, params.problem_shape, sequence_length_shape, + batch_coord, 
q_head_coord); + - // we limit the horizontal size to two subgroup, the empirical results + // we limit the horizontal size to two subgroup, the empirical results // show that reading the two cacheline side by side in gives better // performance and anything after that does not have an effect on // performance. // (64 here for float b float when possible and loop over // to cover all the data needed) - auto tiled_prefetch_q = - cute::prefetch_selector, Int>, Num_SGs>( - mainloop_params.gmem_tiled_copy_q); - - auto tiled_prefetch_k_cache = - cute::prefetch_selector, Int>, Num_SGs>( - mainloop_params.gmem_tiled_copy_k_cache); - auto tiled_prefetch_v_cache = cute:: - prefetch_selector, Int>, Num_SGs>( - mainloop_params.gmem_tiled_copy_v_cache); + auto tiled_prefetch_q = cute::prefetch_selector< + Shape, Int>, + Num_SGs>(mainloop_params.gmem_tiled_copy_q); + auto tiled_prefetch_k = cute::prefetch_selector< + Shape, Int>, + Num_SGs>(mainloop_params.gmem_tiled_copy_k); + auto tiled_prefetch_v = cute::prefetch_selector< + Shape, + Int>, + Num_SGs>(mainloop_params.gmem_tiled_copy_v); + auto tiled_prefetch_k_cache = cute::prefetch_selector< + Shape, Int>, + Num_SGs>(mainloop_params.gmem_tiled_copy_k_cache); + auto tiled_prefetch_v_cache = cute::prefetch_selector< + Shape, + Int>, + Num_SGs>(mainloop_params.gmem_tiled_copy_v_cache); auto thr_prefetch_Q = tiled_prefetch_q.get_slice(thread_idx); - auto thr_prefetch_K_cache = tiled_prefetch_k_cache.get_slice(thread_idx); - auto thr_prefetch_V_cache = tiled_prefetch_v_cache.get_slice(thread_idx); + auto thr_prefetch_K = tiled_prefetch_k.get_slice(thread_idx); + auto thr_prefetch_V = tiled_prefetch_v.get_slice(thread_idx); auto pQgQ = thr_prefetch_Q.partition_S(gQ); - + auto pKgK = thr_prefetch_K.partition_S(gK); + auto pVgV = thr_prefetch_V.partition_S(gV); // assuming the copy function is the same otherwise this need to have its // own tile_prefetch - auto pKgK_cache = thr_prefetch_K_cache.partition_S(gK_cache); - auto
pVgV_cache = thr_prefetch_V_cache.partition_S(gV_cache); + auto pKgK_cache = thr_prefetch_K.partition_S(gK_cache); + auto pVgV_cache = thr_prefetch_V.partition_S(gV_cache); CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size<3>(pQgQ); i++) { prefetch(tiled_prefetch_q, pQgQ(_, _, _, i)); } - auto& prefetch_K = tiled_prefetch_k_cache; - auto& pKgK1_ = pKgK_cache; + auto &prefetch_K = + (seq_len_kv_cache == 0) ? tiled_prefetch_k : tiled_prefetch_k_cache; + auto &pKgK1_ = (seq_len_kv_cache == 0) ? pKgK : pKgK_cache; int cached_nblock = 0; if constexpr (PagedKV) { - // int curr_batch_pages = ceil_div(seq_len_kv_cache, mainloop_params.page_size);// max_page_size_per_seq - // int batch_offset = is_var_len ? mainloop_params.num_pages_per_seq[batch_coord] : batch_coord * - // curr_batch_pages; - int batch_offset = batch_coord * mainloop_params.max_num_pages_per_seq; - cached_nblock = mainloop_params.ptr_page_table[batch_offset // page table for this batch - ] * tiles_per_page; // base block idx of physical page + int curr_batch_pages = ceil_div(seq_len_kv_cache, mainloop_params.page_size); + int batch_offset = + is_var_len ? 
mainloop_params.num_pages_per_seq[batch_coord] + : batch_coord * curr_batch_pages; + cached_nblock = + mainloop_params + .ptr_page_table[batch_offset // page table for this batch + ] * tiles_per_page; // base block idx of physical page } // The headsize for both cached and non-cached version is the same for (int j = 0; j < size<4>(pKgK1_); j++) { CUTLASS_PRAGMA_UNROLL - for (int i = cached_nblock; i < cached_nblock + DispatchPolicy::Stages; i++) { + for (int i = cached_nblock; i < cached_nblock + DispatchPolicy::Stages; + i++) { prefetch(prefetch_K, pKgK1_(_, _, _, i, j)); } } @@ -398,12 +457,13 @@ class FMHAPrefillChunk { // workgroup_shape Tensor out_reg = make_tensor(AccumeShape{}); - // There are 16 workitem and 16 max per subgroup, each worktime contains 1 + // There are 16 workitem and 16 max per subgroup, each worktime contains 1 // max and cumulatively, they calculate the max per subgroup ElementAccumulator max_reg{-INFINITY}; - // The sum reg each contains a 2d tensor for 8 x 2 This is number of - // sequence length process per subgroup - Tensor sum_reg = make_tensor(Shape, Int>{}); + // The sum reg each contains a 2d tensor for 8 x 2 This is number of + // sequence length process per subgroup + Tensor sum_reg = + make_tensor(Shape, Int>{}); clear(sum_reg); clear(out_reg); @@ -414,145 +474,111 @@ class FMHAPrefillChunk { // different for each subgroup due to triangular nature of causal based // operation static constexpr int barrier_scope = CausalMask ?
3 : 2; - - int q_start_coord = blk_m_coord * QK_BLK_M; - int q_end_coord = cute::min(q_start_coord + QK_BLK_M, seq_len_qo); - int seq_diff = seq_len_kv_cache - seq_len_qo; - CUTLASS_PRAGMA_UNROLL - for (int split = 0; split < kv_splits; split++) { + for (int split = 0; split < kv_splits - static_cast(CausalMask); split++) { barrier_arrive(barrier_scope); - int kv_start_coord = split * QK_BLK_N; - - if constexpr (CausalMask) { - if (kv_start_coord >= q_end_coord + seq_diff) break; - } - - // // = 0, all KV is kv_cache + bool is_KV_cache = split < kv_splits_cache; // 1) Load KV (performed inside mmaQK) - auto gK_ = gK_cache(_, _, cached_nblock, _); - auto gV_ = gV_cache(_, _, cached_nblock); + auto gK_ = is_KV_cache ? gK_cache(_, _, cached_nblock, _) + : gK(_, _, split - kv_splits_cache, _); + auto gV_ = is_KV_cache ? gV_cache(_, _, cached_nblock) + : gV(_, _, split - kv_splits_cache); // 2) Create Tensor S - Tensor tSr = make_tensor(Shape, Int, Int>{}); + Tensor tSr = make_tensor( + Shape, Int, Int>{}); clear(tSr); // 3) Perform GEMM S = Q*K // Then modify layout to LayoutQ = ((seq_leq_q, group_head_q), // head_size_qk, batch* num_heads_q / group_head_q), which can be merged // into one gemm for (int i = 0; i < q_group_size; ++i) { - collective_mma.mmaQK(tSr, gQ, gK_, tSr, ceil_div(head_size_qk, QK_BLK_K), mainloop_params); + collective_mma.mmaQK(tSr, gQ, gK_, tSr, + ceil_div(head_size_qk, QK_BLK_K), mainloop_params, + is_KV_cache, q_scale_val, k_scale_val); if constexpr (LocalMask) { // Sliding windows // mask the elements of each tile where j - left > i || j + right < i const int item_id = thread_idx % SubgroupSize; - int col_idx = item_id + split * cute::min(QK_BLK_N, seq_len_kv_cache); + int col_idx; + if (split < kv_splits_cache) { + col_idx = item_id + split * cute::min(QK_BLK_N, seq_len_kv_cache) ; + } else { + col_idx = item_id + seq_len_kv_cache + (split - kv_splits_cache) * cute::min(QK_BLK_N, seq_len_kv); + } CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < 
FragsN; n++, col_idx += get<1>(MmaAtomShape())) { // 4 - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < FragsM; m++) { // 2 - int row_idx = m * Vec + seq_coord; - int col_ref = seq_len_kv_cache - seq_len_qo; - // int col_ref = seq_len_kv_cache + seq_len_kv - seq_len_qo; + for (int n = 0; n < FragsN; + n++, col_idx += get<1>(MmaAtomShape())) { // 4 CUTLASS_PRAGMA_UNROLL - for (int row = 0; row < Vec; row++) { // 8 - bool left_mask = col_idx < cute::max(0, row + row_idx + col_ref - mainloop_params.window_left); - bool right_mask = - col_idx > cute::min(seq_len_kv_cache, row + row_idx + col_ref + mainloop_params.window_right); - if (left_mask || right_mask) { - tSr(row, m, n) = ElementAccumulator{-INFINITY}; + for (int m = 0; m < FragsM; m++) { // 2 + int row_idx = m * Vec + seq_coord; + int col_ref = seq_len_kv_cache + seq_len_kv - seq_len_qo; + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < Vec; row++) { // 8 + bool left_mask = col_idx < cute::max(0, row + row_idx + col_ref - mainloop_params.window_left); + bool right_mask = col_idx > cute::min(seq_len_kv_cache + seq_len_kv, row + row_idx + col_ref + mainloop_params.window_right); + if (left_mask || right_mask) { + tSr(row, m, n) = ElementAccumulator{-INFINITY}; + } } - } } } } - if constexpr (PagedKV) { - // // if constexpr(!(CausalMask || LocalMask) && PagedKV) { - // // Processing Not divisible, mask padding - // const int item_id = thread_idx % SubgroupSize; - // int col_idx = item_id + split * cute::min(QK_BLK_N, - // seq_len_kv_cache + seq_len_kv); - // CUTLASS_PRAGMA_UNROLL - // for (int n = 0; n < FragsN; n++, col_idx += - // get<1>(MmaAtomShape())) { // 4 - // CUTLASS_PRAGMA_UNROLL - // for (int m = 0; m < FragsM; m++) { // 2 - // int row_idx = m * Vec + seq_coord; - // CUTLASS_PRAGMA_UNROLL - // for (int row = 0; row < Vec; row++) { // 8 - // if (col_idx >= seq_len_kv_cache + seq_len_kv || row_idx + - // row >= seq_len_qo) { - // tSr(row, m, n) = ElementAccumulator{-INFINITY}; - // } - // } - // } - // } - - 
int col_start = local_id + kv_start_coord; - int col_end = col_start + (FragsN - 1) * get<1>(MmaAtomShape()); - if (col_end >= seq_len_kv_cache) { - int col_idx = col_start; + if constexpr(!(CausalMask || LocalMask) && PagedKV) { + // Processing Not divisible, mask padding + const int item_id = thread_idx % SubgroupSize; + int col_idx = item_id + split * cute::min(QK_BLK_N, seq_len_kv_cache + seq_len_kv); CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < FragsN; n++, col_idx += get<1>(MmaAtomShape())) { // 4 - if (col_idx >= seq_len_kv_cache) { + for (int n = 0; n < FragsN; n++, col_idx += get<1>(MmaAtomShape())) { // 4 + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < FragsM; m++) { // 2 + int row_idx = m * Vec + seq_coord; CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < FragsM; m++) { // 2 - CUTLASS_PRAGMA_UNROLL - for (int row = 0; row < Vec; row++) { // 8 + for (int row = 0; row < Vec; row++) { // 8 + if (col_idx >= seq_len_kv_cache + seq_len_kv || row_idx + row >= seq_len_qo) { tSr(row, m, n) = ElementAccumulator{-INFINITY}; - } - } - } - } - } - if constexpr (CausalMask) { - int row_start = q_start_coord + sub_group_id * QK_SG_M; - if (row_start + seq_diff < col_end) { - int col_idx = col_start; - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < FragsN; n++, col_idx += get<1>(MmaAtomShape())) { // 4 - if (col_idx > row_start + seq_diff) { - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < FragsM; m++) { // 2 - CUTLASS_PRAGMA_UNROLL - for (int row = 0; row < Vec; row++) { // 8 - int row_idx = row_start + m * Vec + row; - if (row_idx + seq_diff < col_idx) tSr(row, m, n) = ElementAccumulator{-INFINITY}; - } - } } } } } } - auto& tiled_prefetch_v_ = tiled_prefetch_v_cache; - auto& pVgV_ = pVgV_cache; - int v_prefetch_idx = cached_nblock; + auto &tiled_prefetch_v_ = + is_KV_cache ? tiled_prefetch_v_cache + : tiled_prefetch_v; + auto &pVgV_ = is_KV_cache ? pVgV_cache : pVgV; + int v_prefetch_idx = is_KV_cache ? PagedKV ? 
cached_nblock : split + : split - kv_splits_cache; for (int i = 0; i < size<1>(pVgV_); i++) { prefetch(tiled_prefetch_v_, pVgV_(_, i, _, v_prefetch_idx)); } int next_cached_nblock = split + 1; + bool is_next_KV_cache = next_cached_nblock < kv_splits_cache; if constexpr (PagedKV) { - // int curr_batch_pages = ceil_div(seq_len_kv_cache, mainloop_params.page_size); - // int batch_offset = - // is_var_len ? mainloop_params.num_pages_per_seq[batch_coord] : batch_coord * curr_batch_pages; - int curr_batch_pages = mainloop_params.max_num_pages_per_seq; // max_page_size_per_seq - int batch_offset = batch_coord * curr_batch_pages; - int next_page_logical_idx = next_cached_nblock * QK_BLK_N / params.mainloop.page_size; - bool valid_page = next_page_logical_idx < curr_batch_pages; - // get physical page idx from page table - if (valid_page) { - next_cached_nblock = params.mainloop.ptr_page_table - [batch_offset + // page table for this batch - next_page_logical_idx // split (tile idx) to logical - // page idx - ] * tiles_per_page + // base block idx of physical page - next_cached_nblock % tiles_per_page; // offset within page - } else { - next_cached_nblock = curr_batch_pages * tiles_per_page; // push idx out of bounds to respect the - // boundary between batches + if (is_next_KV_cache) { + int curr_batch_pages = ceil_div(seq_len_kv_cache, mainloop_params.page_size); + int next_page_logical_idx = + next_cached_nblock * QK_BLK_N / params.mainloop.page_size; + int batch_offset = + is_var_len ? 
mainloop_params.num_pages_per_seq[batch_coord] + : batch_coord * curr_batch_pages; + bool valid_page = next_page_logical_idx < curr_batch_pages; + // get physical page idx from page table + if (valid_page) { + next_cached_nblock = + params.mainloop.ptr_page_table + [batch_offset + // page table for this batch + next_page_logical_idx // split (tile idx) to logical + // page idx + ] * tiles_per_page + // base block idx of physical page + next_cached_nblock % tiles_per_page; // offset within page + } else { + next_cached_nblock = + curr_batch_pages * + tiles_per_page; // push idx out of bounds to respect the + // boundary between batches + } } } @@ -561,7 +587,8 @@ class FMHAPrefillChunk { softmax(split == 0, tSr, max_reg, sum_reg, out_reg); // 5) Perform GEMM O = S*V - collective_mma.template mmaPV(out_reg, tSr, gV_, out_reg, mainloop_params); + collective_mma.template mmaPV(out_reg, tSr, gV_, out_reg, + mainloop_params, is_KV_cache, v_scale_val); // ... prefetch next tile ... // Prefetch the next Q tile CUTLASS_PRAGMA_UNROLL @@ -569,27 +596,90 @@ class FMHAPrefillChunk { prefetch(tiled_prefetch_q, pQgQ(_, _, _, i)); } + is_KV_cache = is_next_KV_cache; cached_nblock = next_cached_nblock; // Prefetch the next K tile - // there is no need to guard it with if statement as prefetch will + // there is no need to gaurd it with if statememt as prefetch will // ignore out of bound reading - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < size<4>(pKgK_cache); j++) { - prefetch(tiled_prefetch_k_cache, pKgK_cache(_, _, _, cached_nblock, j)); + if constexpr (PagedKV) { + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size<4>(pKgK_cache); j++) { + prefetch(tiled_prefetch_k_cache, pKgK_cache(_, _, _, cached_nblock, j)); + } + } else { + bool sel_prefetch_k = + (split + DispatchPolicy::Stages) < kv_splits_cache; + auto &prefetch_k_selector = + sel_prefetch_k ? tiled_prefetch_k_cache : tiled_prefetch_k; + auto &pKgK_ = sel_prefetch_k ? 
pKgK_cache : pKgK; + int k_prefetch_idx = + sel_prefetch_k + ? PagedKV ? cached_nblock : split + DispatchPolicy::Stages + : split + DispatchPolicy::Stages - kv_splits_cache; + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size<4>(pKgK_); j++) { + prefetch(prefetch_k_selector, pKgK_(_, _, _, k_prefetch_idx, j)); + } } barrier_wait(barrier_scope); } + if constexpr (CausalMask) { + // BAND Matrix + // 1) Load K (performed inside mmaQK) + // 2) Create Tensor S + Tensor tSr = make_tensor( + Shape, Int, Int>{}); + clear(tSr); + // 3) Perform GEMM S = Q*K + collective_mma.mmaQK(tSr, gQ, gK(_, _, kv_splits_new - 1, _), tSr, + ceil_div(head_size_qk, QK_BLK_K), mainloop_params, + false, q_scale_val, k_scale_val); + // we only need one block ahead, there is enough gap to prefetch it + // while doing softmax. because the gap between the two MMA is big, + // prefetching it the same way as cutlass K matrix does not make sense + for (int i = 0; i < size<1>(pVgV); i++) { + prefetch(tiled_prefetch_v, pVgV(_, i, _, kv_splits_new - 1)); + } + // mask the elements of each tile where j > i + const int item_id = thread_idx % SubgroupSize; + int col_idx = item_id + (kv_splits_new - 1) * QK_BLK_N; + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < FragsN; + n++, col_idx += get<1>(MmaAtomShape())) { // 4 + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < FragsM; m++) { // 2 + int row_idx = m * Vec + seq_coord; + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < Vec; row++, row_idx++) { // 8 + if (col_idx - full_tile_offset > row_idx - discard_seq_coord) { + tSr(row, m, n) = ElementAccumulator{-INFINITY}; + } + } + } + } + + CollectiveSoftmaxEpilogue softmax(params.softmax); + softmax((kv_splits - 1) == 0, tSr, max_reg, sum_reg, out_reg); + collective_mma.template mmaPV(out_reg, tSr, + gV(_, _, kv_splits_new - 1), + out_reg, mainloop_params, false, v_scale_val); + } + + // Epilogue - auto epilogue_params = CollectiveEpilogue::template get_updated_copies( - params.epilogue, params.problem_shape, 
sequence_length_shape, batch_coord, q_head_coord); + auto epilogue_params = + CollectiveEpilogue::template get_updated_copies( + params.epilogue, params.problem_shape, sequence_length_shape, + batch_coord, q_head_coord); CollectiveEpilogue epilogue{epilogue_params, shared_storage.epilogue}; auto blk_coord_mnkl = make_coord(blk_m_coord, blk_n_coord, _, 0); - epilogue(params.problem_shape, sequence_length_shape, blk_coord_mnkl, out_reg, max_reg, sum_reg); + epilogue(params.problem_shape, sequence_length_shape, blk_coord_mnkl, + out_reg, max_reg, sum_reg); } } }; /////////////////////////////////////////////////////////////////////////////// -} // namespace cutlass::flash_attention::kernel +} // namespace cutlass::flash_attention::kernel diff --git a/src/sycl/kernels/chunk_prefill/xe_flash_attn_chunk_prefill_mma.hpp b/src/sycl/kernels/chunk_prefill/xe_flash_attn_chunk_prefill_mma.hpp index 4c21c3b..2002dfc 100644 --- a/src/sycl/kernels/chunk_prefill/xe_flash_attn_chunk_prefill_mma.hpp +++ b/src/sycl/kernels/chunk_prefill/xe_flash_attn_chunk_prefill_mma.hpp @@ -36,9 +36,12 @@ #include "cutlass/cutlass.h" #include "cutlass/gemm/dispatch_policy.hpp" #include "fmha_fusion.hpp" +#include "fp8_descale.h" //////////////////////////////////////////////////////////// -namespace {} +namespace { + +} ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -48,69 +51,29 @@ using namespace cute; ///////////////////////////////////////////////////////////////////////////////////////////////// -template < - class DispatchPolicy, - class ProblemShapeType_, - class ElementQ_, - class StrideQ_, - class ElementK_, - class StrideK_, - class ElementV_, - class StrideV_, - class MMAOperation_, - class TileShapeQK_, - class TileShapePV_, - class SubgroupLayout_, - class GmemTiledCopyQ_, - class GmemTiledCopyK_, - class GmemTiledCopyV_, - bool CausalMask_, - bool LocalMask_, - bool PagedKV_> +template struct FlashChunkPrefillMma { - 
static_assert(cutlass::detail::dependent_false, "Could not find a mainloop specialization."); + static_assert(cutlass::detail::dependent_false, + "Could not find a mainloop specialization."); }; ///////////////////////////////////////////////////////////////////////////////////////////////// -template < - int Stages, - class ProblemShapeType_, - class ElementQ_, - class StrideQ_, - class ElementK_, - class StrideK_, - class ElementV_, - class StrideV_, - class MMAOperation_, - class TileShapeQK_, - class TileShapePV_, - class SubgroupLayout_, - class GmemTiledCopyQ_, - class GmemTiledCopyK_, - class GmemTiledCopyV_, - bool CausalMask_, - bool LocalMask_, - bool PagedKV_> +template struct FlashChunkPrefillMma< - gemm::MainloopIntelXeXMX16, - ProblemShapeType_, - ElementQ_, - StrideQ_, - ElementK_, - StrideK_, - ElementV_, - StrideV_, - MMAOperation_, - TileShapeQK_, - TileShapePV_, - SubgroupLayout_, - GmemTiledCopyQ_, - GmemTiledCopyK_, - GmemTiledCopyV_, - CausalMask_, - LocalMask_, - PagedKV_> { + gemm::MainloopIntelXeXMX16, ProblemShapeType_, ElementQ_, StrideQ_, + ElementK_, StrideK_, ElementV_, StrideV_, MMAOperation_, TileShapeQK_, + TileShapePV_, SubgroupLayout_, GmemTiledCopyQ_, GmemTiledCopyK_, + GmemTiledCopyV_, CausalMask_, LocalMask_, PagedKV_> { // // Type Aliases // @@ -131,9 +94,11 @@ struct FlashChunkPrefillMma< using ArchTag = typename DispatchPolicy::ArchTag; using MmaAtom = MMA_Atom; - using TiledMmaQK = typename TiledMMAHelper, SubgroupLayout>::TiledMMA; + using TiledMmaQK = typename TiledMMAHelper, + SubgroupLayout>::TiledMMA; - using TiledMmaPV = typename TiledMMAHelper, SubgroupLayout>::TiledMMA; + using TiledMmaPV = typename TiledMMAHelper, + SubgroupLayout>::TiledMMA; using ElementAccumulator = typename TiledMmaQK::ValTypeC; static constexpr bool CausalMask = CausalMask_; static constexpr bool LocalMask = LocalMask_; @@ -143,11 +108,15 @@ struct FlashChunkPrefillMma< using MmaAtomShape = typename MmaAtom::Shape_MNK; - static constexpr auto 
PV_ATOM_M = decltype(get<0>(SubgroupLayout{}.shape()))::value; - static constexpr auto PV_ATOM_N = decltype(get<1>(SubgroupLayout{}.shape()))::value; - static constexpr auto PV_ATOM_K = decltype(get<2>(SubgroupLayout{}.shape()))::value; + static constexpr auto PV_ATOM_M = + decltype(get<0>(SubgroupLayout{}.shape()))::value; + static constexpr auto PV_ATOM_N = + decltype(get<1>(SubgroupLayout{}.shape()))::value; + static constexpr auto PV_ATOM_K = + decltype(get<2>(SubgroupLayout{}.shape()))::value; - using SubgroupTileShapePV = decltype(cute::shape_div(TileShapePV{}, (SubgroupLayout{}.shape()))); + using SubgroupTileShapePV = + decltype(cute::shape_div(TileShapePV{}, (SubgroupLayout{}.shape()))); static constexpr auto QK_BLK_M = get<0>(TileShapeQK{}); static constexpr auto QK_BLK_N = get<1>(TileShapeQK{}); static constexpr auto QK_BLK_K = get<2>(TileShapeQK{}); @@ -155,67 +124,94 @@ struct FlashChunkPrefillMma< // This TiledMma is only required to serve the specific tiling requirements // for matrix K. This is due to the consumption of matrix K by all subgroups // within a workgroup. 
- static constexpr auto QK_ATOM_M = PV_ATOM_M; // 8 - static constexpr auto QK_ATOM_N = PV_ATOM_N; // 1 - static constexpr auto QK_ATOM_K = PV_ATOM_K; // 1 + static constexpr auto QK_ATOM_M = PV_ATOM_M; // 8 + static constexpr auto QK_ATOM_N = PV_ATOM_N; // 1 + static constexpr auto QK_ATOM_K = PV_ATOM_K; // 1 - using SubgroupTileShapeQK = - decltype(cute::shape_div(TileShapeQK{}, SubgroupLayout{}.shape())); // 128, 64, 32 / 16, 1, 1 = (8, 64, 32 ) + using SubgroupTileShapeQK = decltype(cute::shape_div( + TileShapeQK{}, + SubgroupLayout{}.shape())); // 128, 64, 32 / 16, 1, 1 = (8, 64, 32 ) static constexpr auto QK_SG_M = get<0>(SubgroupTileShapeQK{}); static constexpr auto QK_SG_N = get<1>(SubgroupTileShapeQK{}); static constexpr auto QK_SG_K = get<2>(SubgroupTileShapeQK{}); static constexpr bool is_var_len = - cutlass::fmha::collective::is_variable_length_v>; + cutlass::fmha::collective::is_variable_length_v< + tuple_element_t<3, ProblemShapeType>>; using FragsShapeS = decltype(cute::shape_div( - take<0, 2>(SubgroupTileShapeQK{}), take<0, 2>(MmaAtomShape()))); // 8, 64, 32 / 8, 16, 16 (1, 4) - static constexpr int Vec = (get<0>(MmaAtomShape()) * get<1>(MmaAtomShape())) / SubgroupSize; // 8 + take<0, 2>(SubgroupTileShapeQK{}), + take<0, 2>(MmaAtomShape()))); // 8, 64, 32 / 8, 16, 16 (1, 4) + static constexpr int Vec = + (get<0>(MmaAtomShape()) * get<1>(MmaAtomShape())) / SubgroupSize; // 8 static constexpr int FragsM = get<0>(FragsShapeS{}); - static constexpr int FragsNS = get<1>(FragsShapeS{}); // 4 + static constexpr int FragsNS = get<1>(FragsShapeS{}); // 4 - static constexpr uint32_t MaxThreadsPerBlock = size(SubgroupLayout{}) * SubgroupSize; + static constexpr uint32_t MaxThreadsPerBlock = + size(SubgroupLayout{}) * SubgroupSize; using CopyThreadShape = Shape<_1, Int>; using traits_load_Q = Copy_Traits; using atom_load_Q = Copy_Atom; - using val_layout_load_Q = decltype(make_layout(shape_div(typename traits_load_Q::BlockShape{}, CopyThreadShape{}))); - using 
XE_Copy_Q = decltype(make_tiled_copy(atom_load_Q{}, Layout{}, val_layout_load_Q{})); + using val_layout_load_Q = decltype(make_layout( + shape_div(typename traits_load_Q::BlockShape{}, CopyThreadShape{}))); + using XE_Copy_Q = decltype(make_tiled_copy( + atom_load_Q{}, Layout{}, val_layout_load_Q{})); using traits_load_K = Copy_Traits; using atom_load_K = Copy_Atom; - using val_layout_load_K = decltype(make_layout(shape_div(typename traits_load_K::BlockShape{}, CopyThreadShape{}))); - using XE_Copy_K = decltype(make_tiled_copy(atom_load_K{}, Layout{}, val_layout_load_K{})); + using val_layout_load_K = decltype(make_layout( + shape_div(typename traits_load_K::BlockShape{}, CopyThreadShape{}))); + using XE_Copy_K = decltype(make_tiled_copy( + atom_load_K{}, Layout{}, val_layout_load_K{})); using traits_load_V = Copy_Traits; using atom_load_V = Copy_Atom; - using val_layout_load_V = decltype(make_layout(shape_div(typename traits_load_V::BlockShape{}, CopyThreadShape{}))); - using XE_Copy_V = decltype(make_tiled_copy(atom_load_V{}, Layout{}, val_layout_load_V{})); + using val_layout_load_V = decltype(make_layout( + shape_div(typename traits_load_V::BlockShape{}, CopyThreadShape{}))); + using XE_Copy_V = decltype(make_tiled_copy( + atom_load_V{}, Layout{}, val_layout_load_V{})); + + template + static constexpr bool is_fp8_v = cute::is_same_v || cute::is_same_v; // Host side kernel arguments struct Arguments { - ElementQ const* ptr_Q; + ElementQ const *ptr_Q; StrideQ dQ; - ElementK const* ptr_K_cache; + ElementK const *ptr_K; + StrideK dK; + ElementV const *ptr_V; + StrideV dV; + float const *ptr_q_scale; + float const *ptr_k_scale; + float const *ptr_v_scale; + ElementK const *ptr_K_cache; StrideK dK_cache; - ElementV const* ptr_V_cache; + ElementV const *ptr_V_cache; StrideV dV_cache; // Paged KV Cache - int const* ptr_page_table; + int const *ptr_page_table; int page_size; - int max_num_pages_per_seq; + int const *num_pages_per_seq; int window_left; int window_right; 
}; struct Params { XE_Copy_Q gmem_tiled_copy_q; + XE_Copy_K gmem_tiled_copy_k; + XE_Copy_V gmem_tiled_copy_v; + float const *ptr_q_scale; + float const *ptr_k_scale; + float const *ptr_v_scale; XE_Copy_K gmem_tiled_copy_k_cache; XE_Copy_V gmem_tiled_copy_v_cache; - int const* ptr_page_table; + // Paged KV Cache + int const *ptr_page_table; int page_size; - int max_num_pages_per_seq; + int const *num_pages_per_seq; int window_left; int window_right; }; @@ -227,45 +223,67 @@ struct FlashChunkPrefillMma< FlashChunkPrefillMma() = default; static constexpr Params - to_underlying_arguments(ProblemShapeType const& problem_shape, Arguments const& args, void* workspace) { + to_underlying_arguments(ProblemShapeType const &problem_shape, + Arguments const &args, void *workspace) { (void)workspace; - auto [batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, seq_len_kv_cache, head_size_qk, head_size_vo] = - problem_shape; + auto [batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, + seq_len_kv_cache, head_size_qk, head_size_vo] = problem_shape; auto tensorQ = make_tensor( - make_gmem_ptr(args.ptr_Q), make_layout(make_shape(seq_len_qo, num_heads_q * head_size_qk, batch), args.dQ)); - auto tensorK_cache = make_tensor( - make_gmem_ptr(args.ptr_K_cache), - make_layout(make_shape(seq_len_kv_cache, num_heads_kv * head_size_qk, batch), args.dK_cache)); + make_gmem_ptr(args.ptr_Q), + make_layout(make_shape(seq_len_qo, num_heads_q * head_size_qk, batch), + args.dQ)); + auto tensorK = make_tensor( + make_gmem_ptr(args.ptr_K), + make_layout(make_shape(seq_len_kv, num_heads_kv * head_size_qk, batch), + args.dK)); + auto tensorV = make_tensor( + make_gmem_ptr(args.ptr_V), + make_layout(make_shape(num_heads_kv * head_size_vo, seq_len_kv, batch), + args.dV)); + auto tensorK_cache = + make_tensor(make_gmem_ptr(args.ptr_K_cache), + make_layout(make_shape(seq_len_kv_cache, + num_heads_kv * head_size_qk, batch), + args.dK_cache)); auto tensorV_cache = make_tensor( 
make_gmem_ptr(args.ptr_V_cache), - make_layout(make_shape(num_heads_kv * head_size_vo, seq_len_kv_cache, batch), args.dV_cache)); + make_layout( + make_shape(num_heads_kv * head_size_vo, seq_len_kv_cache, batch), + args.dV_cache)); XE_Copy_Q copyQ{XE_Copy_Q{}.with(tensorQ)}; + XE_Copy_K copyK{XE_Copy_K{}.with(tensorK)}; + XE_Copy_V copyV{XE_Copy_V{}.with(tensorV)}; XE_Copy_K copyK_cache{XE_Copy_K{}.with(tensorK_cache)}; XE_Copy_V copyV_cache{XE_Copy_V{}.with(tensorV_cache)}; - return Params{ - copyQ, - copyK_cache, - copyV_cache, - args.ptr_page_table, - args.page_size, - args.max_num_pages_per_seq, - args.window_left, - args.window_right}; + return Params{copyQ, + copyK, + copyV, + args.ptr_q_scale, + args.ptr_k_scale, + args.ptr_v_scale, + copyK_cache, + copyV_cache, + args.ptr_page_table, + args.page_size, + args.num_pages_per_seq, + args.window_left, + args.window_right}; } + // FP8 Q and FP8 K tensors are converted to BF16 tensors using descale factors + // GEMM is computed in BF16 precision (FP8 not supported in BMG) template - CUTLASS_DEVICE void mmaQK( - FragQccum& accum, - TensorQ gQ, - TensorK gK, - FragSrc const& frag_src, - int const& k_tile_count, - Params const& params) { - auto& gmem_tiled_copy_k = params.gmem_tiled_copy_k_cache; + CUTLASS_DEVICE void mmaQK(FragQccum &accum, TensorQ gQ, TensorK gK, + FragSrc const &frag_src, int const &k_tile_count, + Params const ¶ms, bool is_KV_cache, + float q_scale, float k_scale) { + + auto &gmem_tiled_copy_k = + is_KV_cache ? 
params.gmem_tiled_copy_k_cache : params.gmem_tiled_copy_k; int thread_idx = static_cast(ThreadIdxX()); auto thr_copy_Q = params.gmem_tiled_copy_q.get_slice(thread_idx); @@ -275,7 +293,8 @@ struct FlashChunkPrefillMma< // To make all threads in a warp have the same global tensors pass in the // index of thread 0 in each warp auto sg = compat::get_nd_item<1>().get_sub_group(); - auto first_thread_in_sg_idx = sg.get_group_id()[0] * DispatchPolicy::SubgroupSize; + auto first_thread_in_sg_idx = + sg.get_group_id()[0] * DispatchPolicy::SubgroupSize; auto thread_mma_q = tiled_mma.get_slice(first_thread_in_sg_idx); auto thread_mma_k = tiled_mma.get_slice(0); @@ -284,8 +303,10 @@ struct FlashChunkPrefillMma< // Create fragments // TODO(Codeplay): fix this, this is probably not general - Tensor tCrQ = make_tensor(make_fragment_layout(params.gmem_tiled_copy_q, take<0, 3>(tCgQ.shape()))); - Tensor tCrK = make_tensor(make_fragment_layout(gmem_tiled_copy_k, take<0, 3>(tCgK.shape()))); + using TCrQ_Type = cute::conditional_t, uint8_t, ElementQ>; + using TCrK_Type = cute::conditional_t, uint8_t, ElementK>; + Tensor tCrQ = make_tensor(make_fragment_layout(params.gmem_tiled_copy_q, take<0,3>(tCgQ.shape()))); + Tensor tCrK = make_tensor(make_fragment_layout(gmem_tiled_copy_k, take<0,3>(tCgK.shape()))); // Retile registers for copies Tensor tQrQ = thr_copy_Q.retile_D(tCrQ); @@ -298,15 +319,39 @@ struct FlashChunkPrefillMma< // // Mainloop // - for (int k_tile = 0; k_tile < k_tile_count; ++k_tile) { copy(params.gmem_tiled_copy_q, tQgQ(_, _, _, k_tile), tQrQ); copy(gmem_tiled_copy_k, tKgK(_, _, _, k_tile), tKrK); - cute::gemm(tiled_mma, accum, tCrQ, tCrK, frag_src); + + // FP8 path: Convert FP8 fragments to BF16 + if constexpr (is_fp8_v || is_fp8_v) { + auto tCrQ_fp16 = make_fragment_like(tCrQ); + auto tCrK_fp16 = make_fragment_like(tCrK); + + if constexpr (is_fp8_v) { + convert_and_descale(tCrQ, tCrQ_fp16, q_scale); + } else { + // If Q is already FP16, copy it. 
+ copy(tCrQ, tCrQ_fp16); + } + + if constexpr (is_fp8_v) { + convert_and_descale(tCrK, tCrK_fp16, k_scale); + } else { + copy(tCrK, tCrK_fp16); + } + + // GEMM is computed on the BF16 tensors + cute::gemm(tiled_mma, accum, tCrQ_fp16, tCrK_fp16, frag_src); + } else { + // BF16 path + cute::gemm(tiled_mma, accum, tCrQ, tCrK, frag_src); + } + #if 0 -#define PRINT(x) \ - print(#x ": "); \ - print(x); \ +#define PRINT(x) \ + print(#x ": "); \ + print(x); \ print("\n"); if (cute::thread(0, 0)) { print("======================= Q: \n"); @@ -333,30 +378,42 @@ struct FlashChunkPrefillMma< } template - CUTLASS_DEVICE auto convert_type(Tensor const& tensor) { - using From_type = typename Engine::value_type; - constexpr int numel = decltype(size(tensor))::value; - cutlass::NumericArrayConverter convert_op; - auto frag = convert_op(*reinterpret_cast*>(tensor.data())); - return make_tensor(make_rmem_ptr(&frag), tensor.layout()); + CUTLASS_DEVICE auto convert_type(Tensor const &tensor) { + using From_type = typename Engine::value_type; + constexpr int numel = decltype(size(tensor))::value; + cutlass::NumericArrayConverter convert_op; + auto frag = + convert_op(*reinterpret_cast *>( + tensor.data())); + return make_tensor(make_rmem_ptr(&frag), tensor.layout()); } - template - CUTLASS_DEVICE void - mmaPV(FragQccum& accum, FragS const& tSr, TensorV gV, FragSrc const& frag_src, Params const& params) { - auto& gmem_tiled_copy_v = params.gmem_tiled_copy_v_cache; + // FP8 V tensor is converted to BF16 tensor using descale factor + // P tensor (softmax output) is in FP32 precision (converted to BF16) + // GEMM is computed in BF16 precision (FP8 not supported in BMG) + template + CUTLASS_DEVICE void mmaPV(FragQccum &accum, FragS const &tSr, TensorV gV, + FragSrc const &frag_src, Params const ¶ms, + bool is_KV_cache, float v_scale) { + + auto &gmem_tiled_copy_v = + is_KV_cache ? 
params.gmem_tiled_copy_v_cache : params.gmem_tiled_copy_v; int thread_idx = static_cast(ThreadIdxX()); // Instantiate the MMA object TiledMmaPV tiled_mma; // Tile GV to the shape of <64,64> and loop over the HeadSize/64 to avoid // Register spill - Tensor gV_ = take<0, 3>(local_tile(gV, select<1, 2>(TileShapePV{}), make_coord(_, _))); + Tensor gV_ = take<0, 3>( + local_tile(gV, select<1, 2>(TileShapePV{}), make_coord(_, _))); auto sg = compat::get_nd_item<1>().get_sub_group(); - auto first_thread_in_sg_idx = sg.get_group_id()[0] * DispatchPolicy::SubgroupSize; + auto first_thread_in_sg_idx = + sg.get_group_id()[0] * DispatchPolicy::SubgroupSize; auto thread_mma = tiled_mma.get_slice(first_thread_in_sg_idx); Tensor tCgV = thread_mma.partition_B(gV_); - Tensor tCrV = make_tensor(make_fragment_layout(gmem_tiled_copy_v, take<0, 3>(tCgV.shape()))); + using TCrV_Type = cute::conditional_t, uint8_t, ElementV>; + Tensor tCrV = make_tensor(make_fragment_layout(gmem_tiled_copy_v, take<0,3>(tCgV.shape()))); // Partition the copying of A and B tiles across the threads auto gmem_thr_copy_V = gmem_tiled_copy_v.get_slice(thread_idx); @@ -364,9 +421,9 @@ struct FlashChunkPrefillMma< Tensor tVgV = gmem_thr_copy_V.retile_S(tCgV); #if CUTLASS_ENABLE_DEBUG_PRINTS -#define PRINT(x) \ - print(#x ": "); \ - print(x); \ +#define PRINT(x) \ + print(#x ": "); \ + print(x); \ print("\n"); if (cute::thread(LOG_THREAD, LOG_GROUP)) { print("===================== V :\n"); @@ -391,7 +448,15 @@ struct FlashChunkPrefillMma< CUTLASS_PRAGMA_UNROLL for (int i = 0; i < tile_count; i++) { copy(gmem_tiled_copy_v, tVgV(_, _, _, i), tVrV); - cute::gemm(tiled_mma, accum(_, _, _, i), tPr, tCrV, frag_src(_, _, _, i)); + + if constexpr (is_fp8_v) { + auto tCrV_fp16 = make_fragment_like(tCrV); + convert_and_descale(tCrV, tCrV_fp16, v_scale); + + cute::gemm(tiled_mma, accum(_,_,_,i), tPr, tCrV_fp16, frag_src(_,_,_,i)); + } else { + cute::gemm(tiled_mma, accum(_,_,_,i), tPr, tCrV, frag_src(_,_,_,i)); + } } } @@ 
-400,67 +465,135 @@ struct FlashChunkPrefillMma< // int, int, int> For Variable Sequence Length, ProblemShape = Shape template - CUTLASS_DEVICE static constexpr Params get_updated_copies( - Params const& params, - ProblemShape const& problem_shape, - SequenceLengthShape const& sequence_length_shape, - int const& l_coord, - int const& q_head_coord = 0) { - auto [batch, num_heads_q, num_heads_kv, head_size_qk, head_size_vo] = select<0, 1, 2, 6, 7>(problem_shape); - auto [seq_len_qo, seq_len_kv_cache] = sequence_length_shape; + CUTLASS_DEVICE static constexpr Params + get_updated_copies(Params const ¶ms, ProblemShape const &problem_shape, + SequenceLengthShape const &sequence_length_shape, + int const &l_coord, int const &q_head_coord = 0) { + auto [batch, num_heads_q, num_heads_kv, head_size_qk, head_size_vo] = + select<0, 1, 2, 6, 7>(problem_shape); + auto [seq_len_qo, seq_len_kv, seq_len_kv_cache] = sequence_length_shape; auto q_group_size = num_heads_q / num_heads_kv; auto kv_head_coord = q_head_coord / q_group_size; - int offset_q = 0, offset_k = 0, offset_v = 0, offset_k_cache = 0, offset_v_cache = 0; + int offset_q = 0, offset_k = 0, offset_v = 0, offset_k_cache = 0, + offset_v_cache = 0; int total_seq_len_kv_cache = 0; if constexpr (is_var_len) { auto qo_cumulative_length = get<3>(problem_shape).cumulative_length; - auto kv_cached_cumulative_length = get<5>(problem_shape).cumulative_length; - - offset_q = num_heads_q * head_size_qk * qo_cumulative_length[l_coord] + q_head_coord * head_size_qk; - - offset_k_cache = kv_head_coord * head_size_qk; - offset_v_cache = kv_head_coord * head_size_vo; + auto kv_cumulative_length = get<4>(problem_shape).cumulative_length; + auto kv_cached_cumulative_length = + get<5>(problem_shape).cumulative_length; + + offset_q = num_heads_q * head_size_qk * qo_cumulative_length[l_coord] + + q_head_coord * head_size_qk; + + offset_k = num_heads_kv * head_size_qk * kv_cumulative_length[l_coord] + + kv_head_coord * head_size_qk; + 
offset_v = num_heads_kv * head_size_vo * kv_cumulative_length[l_coord] + + kv_head_coord * head_size_vo; + offset_k_cache = seq_len_kv_cache == 0 + ? 0 + : PagedKV? // For page_kv, there is no batch dimension. + kv_head_coord * head_size_qk + : num_heads_kv * head_size_qk * kv_cached_cumulative_length[l_coord] + kv_head_coord * head_size_qk; + offset_v_cache = seq_len_kv_cache == 0 + ? 0 + : PagedKV? // For page_kv, there is no batch dimension. + kv_head_coord * head_size_vo + : num_heads_kv * head_size_vo * kv_cached_cumulative_length[l_coord] + kv_head_coord * head_size_vo; total_seq_len_kv_cache = get<5>(problem_shape).total_length; } else { + offset_q = num_heads_q * head_size_qk * seq_len_qo * l_coord + + q_head_coord * head_size_qk; + + offset_k = num_heads_kv * head_size_qk * seq_len_kv * l_coord + + kv_head_coord * head_size_qk; + offset_v = num_heads_kv * head_size_vo * seq_len_kv * l_coord + + kv_head_coord * head_size_vo; + offset_k_cache = + seq_len_kv_cache == 0 + ? 0 : + PagedKV? + kv_head_coord * head_size_qk + : num_heads_kv * head_size_qk * seq_len_kv_cache * l_coord + kv_head_coord * head_size_qk; + offset_v_cache = + seq_len_kv_cache == 0 + ? 0 : + PagedKV? 
+ kv_head_coord * head_size_vo + : num_heads_kv * head_size_vo * seq_len_kv_cache * l_coord + kv_head_coord * head_size_vo; + total_seq_len_kv_cache = batch * seq_len_kv_cache; } - auto q_traits = static_cast(params.gmem_tiled_copy_q); - const ElementQ* q_ptr = (const ElementQ*)q_traits.base_ptr; - auto k_traits_cache = static_cast(params.gmem_tiled_copy_k_cache); - const ElementK* k_cache_ptr = (const ElementK*)k_traits_cache.base_ptr; - auto v_traits_cache = static_cast(params.gmem_tiled_copy_v_cache); - const ElementV* v_cache_ptr = (const ElementV*)v_traits_cache.base_ptr; + auto q_traits = + static_cast(params.gmem_tiled_copy_q); + const ElementQ *q_ptr = (const ElementQ *)q_traits.base_ptr; + auto k_traits = + static_cast(params.gmem_tiled_copy_k); + const ElementK *k_ptr = (const ElementK *)k_traits.base_ptr; + auto v_traits = + static_cast(params.gmem_tiled_copy_v); + const ElementV *v_ptr = (const ElementV *)v_traits.base_ptr; + auto k_traits_cache = + static_cast(params.gmem_tiled_copy_k_cache); + const ElementK *k_cache_ptr = (const ElementK *)k_traits_cache.base_ptr; + auto v_traits_cache = + static_cast(params.gmem_tiled_copy_v_cache); + const ElementV *v_cache_ptr = (const ElementV *)v_traits_cache.base_ptr; // NHD format{batch, seq_len, head, dim_head} // stride {seq_len*head*dim_head, head*dim_head, dim_head, 1} - auto shape_q = make_shape(static_cast(seq_len_qo), head_size_qk * num_heads_q, 1); + auto shape_q = + make_shape(static_cast(seq_len_qo), head_size_qk * num_heads_q, 1); StrideQ stride_q = cutlass::make_cute_packed_stride(StrideQ{}, shape_q); - - auto shape_k_cache = make_shape( - static_cast(PagedKV ? total_seq_len_kv_cache : seq_len_kv_cache), head_size_qk * num_heads_kv, 1); - StrideK stride_k_cache = cutlass::make_cute_packed_stride(StrideK{}, shape_k_cache); - auto shape_v_cache = make_shape( - head_size_vo * num_heads_kv, static_cast(PagedKV ? 
total_seq_len_kv_cache : seq_len_kv_cache), 1); - StrideV stride_v_cache = cutlass::make_cute_packed_stride(StrideV{}, shape_v_cache); - auto tensorQ = make_tensor(make_gmem_ptr(q_ptr + offset_q), make_layout(shape_q, stride_q)); + auto shape_k = make_shape(static_cast(seq_len_kv), + num_heads_kv * head_size_qk, 1); + StrideK stride_k = cutlass::make_cute_packed_stride(StrideK{}, shape_k); + + auto shape_v = make_shape(head_size_vo * num_heads_kv, + static_cast(seq_len_kv), 1); + StrideV stride_v = cutlass::make_cute_packed_stride(StrideV{}, shape_v); + + auto shape_k_cache = make_shape(static_cast(PagedKV? total_seq_len_kv_cache : seq_len_kv_cache), + head_size_qk * num_heads_kv, 1); + StrideK stride_k_cache = + cutlass::make_cute_packed_stride(StrideK{}, shape_k_cache); + auto shape_v_cache = make_shape(head_size_vo * num_heads_kv, + static_cast(PagedKV? total_seq_len_kv_cache : seq_len_kv_cache), 1); + StrideV stride_v_cache = + cutlass::make_cute_packed_stride(StrideV{}, shape_v_cache); + auto tensorQ = make_tensor(make_gmem_ptr(q_ptr + offset_q), + make_layout(shape_q, stride_q)); + auto tensorK = make_tensor(make_gmem_ptr(k_ptr + offset_k), + make_layout(shape_k, stride_k)); + auto tensorV = make_tensor(make_gmem_ptr(v_ptr + offset_v), + make_layout(shape_v, stride_v)); auto tensorK_cache = - make_tensor(make_gmem_ptr(k_cache_ptr + offset_k_cache), make_layout(shape_k_cache, stride_k_cache)); + make_tensor(make_gmem_ptr(k_cache_ptr + offset_k_cache), + make_layout(shape_k_cache, stride_k_cache)); auto tensorV_cache = - make_tensor(make_gmem_ptr(v_cache_ptr + offset_v_cache), make_layout(shape_v_cache, stride_v_cache)); + make_tensor(make_gmem_ptr(v_cache_ptr + offset_v_cache), + make_layout(shape_v_cache, stride_v_cache)); XE_Copy_Q copyQ{XE_Copy_Q{}.with(tensorQ)}; + XE_Copy_K copyK{XE_Copy_K{}.with(tensorK)}; + XE_Copy_V copyV{XE_Copy_V{}.with(tensorV)}; XE_Copy_K copyK_cache{XE_Copy_K{}.with(tensorK_cache)}; XE_Copy_V 
copyV_cache{XE_Copy_V{}.with(tensorV_cache)}; - return Params{ - copyQ, - copyK_cache, - copyV_cache, - params.ptr_page_table, - params.page_size, - params.max_num_pages_per_seq, - params.window_left, - params.window_right}; + + return Params{copyQ, + copyK, + copyV, + params.ptr_q_scale, + params.ptr_k_scale, + params.ptr_v_scale, + copyK_cache, + copyV_cache, + params.ptr_page_table, + params.page_size, + params.num_pages_per_seq, + params.window_left, + params.window_right}; } }; -} // namespace cutlass::flash_attention::collective +} // namespace cutlass::flash_attention::collective ///////////////////////////////////////////////////////////////////////////////////////////////// From 06ae0d85987df3d7bfa8504cc2b456d5af5a697c Mon Sep 17 00:00:00 2001 From: Aditya Chatterjee Date: Mon, 27 Oct 2025 07:05:41 +0000 Subject: [PATCH 02/11] Rebased restructured code Signed-off-by: Aditya Chatterjee --- src/sycl/chunked_prefill.cpp | 27 +- .../chunk_prefill/xe_chunk_prefill.hpp | 521 +++++++----------- .../xe_flash_attn_chunk_prefill_mma.hpp | 447 +++++++-------- 3 files changed, 411 insertions(+), 584 deletions(-) diff --git a/src/sycl/chunked_prefill.cpp b/src/sycl/chunked_prefill.cpp index 878dcdc..ac5bee0 100644 --- a/src/sycl/chunked_prefill.cpp +++ b/src/sycl/chunked_prefill.cpp @@ -66,6 +66,7 @@ struct Flash_fwd_params { float scale_softmax; void* sink_softmax; float softcap; + float* __restrict__ q_scale_ptr; float* __restrict__ k_scale_ptr; float* __restrict__ v_scale_ptr; @@ -118,7 +119,6 @@ struct Flash_fwd_params { // Paged KV cache int* __restrict__ page_table; - int* __restrict__ num_pages_per_seq_ptr; int max_num_pages_per_seq; index_t page_table_batch_stride; int page_size; @@ -318,19 +318,20 @@ struct KernelRunner { typename FMHAChunkPrefillKernel::Arguments arguments{ cutlass::gemm::GemmUniversalMode::kGemm, problem_size, - {static_cast(params.q_ptr), + {// static_cast(params.q_ptr), + static_cast(params.q_ptr), stride_Q, - static_cast(params.knew_ptr), 
- stride_K, - static_cast(params.vnew_ptr), - stride_V, - params.q_scale_ptr, - params.k_scale_ptr, - params.v_scale_ptr, - static_cast(params.k_ptr), + // static_cast(params.knew_ptr), + // stride_K, + // static_cast(params.vnew_ptr), + // stride_V, + static_cast(params.k_ptr), stride_K_cache, static_cast(params.v_ptr), stride_V_cache, + params.q_scale_ptr, + params.k_scale_ptr, + params.v_scale_ptr, params.page_table, params.page_size, params.max_num_pages_per_seq, @@ -638,15 +639,15 @@ std::vector mha_fwd( params.v_scale_ptr = static_cast(v_descale_.value().data_ptr()); } - if (!is_varlen_q) { + /*if (!is_varlen_q) { params.q_batch_stride = q.stride(0); params.o_batch_stride = out.stride(0); } if (!is_varlen_k) { params.k_batch_stride = k.stride(0); params.v_batch_stride = v.stride(0); - } - + }*/ + params.cu_seqlens_q = cu_seqlens_q.data_ptr(); params.cu_seqlens_k = cu_seqlens_k.data_ptr(); diff --git a/src/sycl/kernels/chunk_prefill/xe_chunk_prefill.hpp b/src/sycl/kernels/chunk_prefill/xe_chunk_prefill.hpp index 4a40147..b50e2be 100644 --- a/src/sycl/kernels/chunk_prefill/xe_chunk_prefill.hpp +++ b/src/sycl/kernels/chunk_prefill/xe_chunk_prefill.hpp @@ -38,17 +38,22 @@ namespace cutlass::flash_attention::kernel { -template +template < + class ProblemShape_, + class CollectiveMainloop_, + class CollectiveSoftmaxEpilogue_, + class CollectiveEpilogue_, + class TileScheduler_ = void> class FMHAPrefillChunk; /////////////////////////////////////////////////////////////////////////////// -template +template < + class ProblemShape_, + class CollectiveMainloop_, + class CollectiveSoftmaxEpilogue_, + class CollectiveEpilogue_, + class TileScheduler_> class FMHAPrefillChunk { - -public: + public: // // Type Aliases // @@ -82,14 +87,12 @@ class FMHAPrefillChunk { using SoftmaxArguments = typename CollectiveSoftmaxEpilogue::Arguments; using SoftmaxParams = typename CollectiveSoftmaxEpilogue::Params; - static_assert(cute::is_void_v or - cute::is_same_v or - 
cute::is_same_v, - "Unsupported TileScheduler for Intel Xe."); + static_assert( + cute::is_void_v or cute::is_same_v or + cute::is_same_v, + "Unsupported TileScheduler for Intel Xe."); using TileSchedulerTag = TileScheduler_; - using TileScheduler = - typename detail::TileSchedulerSelector::Scheduler; + using TileScheduler = typename detail::TileSchedulerSelector::Scheduler; using TileSchedulerParams = typename TileScheduler::Params; // Epilogue derived types @@ -106,8 +109,7 @@ class FMHAPrefillChunk { static constexpr bool Sink = CollectiveEpilogue::Sink; static_assert( - cute::is_same_v, + cute::is_same_v, "Mainloop and epilogue do not agree on accumulator value type."); // MSVC requires the cast to fix a warning-as-error. static constexpr int SharedStorageSize = 0; @@ -118,12 +120,9 @@ class FMHAPrefillChunk { static_assert(!(CausalMask && LocalMask), "Cannot be both causal and local"); static constexpr bool PagedKV = CollectiveMainloop::PagedKV; - - static constexpr int SubgroupSize = - CollectiveMainloop::SubgroupSize; // sub_group size - static constexpr uint32_t MaxThreadsPerBlock = - CollectiveMainloop::MaxThreadsPerBlock; - using MmaAtomShape = typename CollectiveMainloop::MmaAtomShape; // 8,16,16 + static constexpr int SubgroupSize = CollectiveMainloop::SubgroupSize; // sub_group size + static constexpr uint32_t MaxThreadsPerBlock = CollectiveMainloop::MaxThreadsPerBlock; + using MmaAtomShape = typename CollectiveMainloop::MmaAtomShape; // 8,16,16 static constexpr int QK_BLK_M = CollectiveMainloop::QK_BLK_M; static constexpr int QK_BLK_N = CollectiveMainloop::QK_BLK_N; @@ -149,11 +148,9 @@ class FMHAPrefillChunk { static constexpr int FragsN = CollectiveMainloop::FragsNS; static constexpr int VSlicer = - get<1>(TileShapeOutput{}) / - (get<1>(TileShapePV{}) * PV_ATOM_N); // ceil_div(FragsNOut,FragsNS); - using AccumeShape = decltype(make_shape( - Int{}, Int{}, get<1>(TileShapePV{}) / get<1>(MmaAtomShape()), - Int{})); + get<1>(TileShapeOutput{}) / 
(get<1>(TileShapePV{}) * PV_ATOM_N); // ceil_div(FragsNOut,FragsNS); + using AccumeShape = + decltype(make_shape(Int{}, Int{}, get<1>(TileShapePV{}) / get<1>(MmaAtomShape()), Int{})); static constexpr bool is_var_len = CollectiveMainloop::is_var_len; // Kernel level shared memory storage @@ -188,58 +185,55 @@ class FMHAPrefillChunk { // Convert to underlying arguments. In this case, a simple copy for the // aliased type. - static Params to_underlying_arguments(Arguments const &args, - void *workspace) { + static Params to_underlying_arguments(Arguments const& args, void* workspace) { (void)workspace; - return {args.mode, - args.problem_shape, - CollectiveMainloop::to_underlying_arguments( - args.problem_shape, args.mainloop, workspace), - CollectiveSoftmaxEpilogue::to_underlying_arguments(args.softmax), - CollectiveEpilogue::to_underlying_arguments( - args.problem_shape, args.epilogue, workspace), - TileScheduler::to_underlying_arguments( - args.problem_shape, args.hw_info, TileShapeOutput{})}; + return { + args.mode, + args.problem_shape, + CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace), + CollectiveSoftmaxEpilogue::to_underlying_arguments(args.softmax), + CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace), + TileScheduler::to_underlying_arguments(args.problem_shape, args.hw_info, TileShapeOutput{})}; } - static bool can_implement(Arguments const &args) { + static bool can_implement(Arguments const& args) { bool mode_implementable = args.mode == gemm::GemmUniversalMode::kGemm or - (args.mode == gemm::GemmUniversalMode::kBatched && - rank(ProblemShape{}) == 4); + (args.mode == gemm::GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4); return mode_implementable; } - static int get_workspace_size(Arguments const &args) { return 0; } + static int get_workspace_size(Arguments const& args) { + return 0; + } - static cutlass::Status - initialize_workspace(Arguments const &args, void 
*workspace = nullptr, - cudaStream_t stream = nullptr, - CudaHostAdapter *cuda_adapter = nullptr) { + static cutlass::Status initialize_workspace( + Arguments const& args, + void* workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { return Status::kSuccess; } - static dim3 get_grid_shape(Params const ¶ms) { + static dim3 get_grid_shape(Params const& params) { return TileScheduler::template get_grid_shape(params.scheduler); } - static dim3 get_block_shape() { return dim3(MaxThreadsPerBlock, 1, 1); } + static dim3 get_block_shape() { + return dim3(MaxThreadsPerBlock, 1, 1); + } CUTLASS_DEVICE - Shape - get_sequence_length_shape(ProblemShape const &problem_shape, - int const &batch) { + Shape get_sequence_length_shape(ProblemShape const& problem_shape, int const& batch) { if constexpr (is_var_len) { - return cutlass::fmha::collective::apply_variable_length( - select<3, 4, 5>(problem_shape), batch); + return cutlass::fmha::collective::apply_variable_length(select<3, 5>(problem_shape), batch); } else { - return select<3, 4, 5>(problem_shape); + return select<3, 5>(problem_shape); } } CUTLASS_DEVICE - void operator()(Params const ¶ms, char *smem_buf) { - SharedStorage &shared_storage = - *reinterpret_cast(smem_buf); + void operator()(Params const& params, char* smem_buf) { + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); // Preconditions CUTE_STATIC_ASSERT(is_static::value); CUTE_STATIC_ASSERT(is_static::value); @@ -251,34 +245,37 @@ class FMHAPrefillChunk { auto num_heads_q = get<1>(params.problem_shape); auto num_heads_kv = get<2>(params.problem_shape); - auto &head_size_qk = get<6>(params.problem_shape); - auto &head_size_vo = get<7>(params.problem_shape); + auto& head_size_qk = get<6>(params.problem_shape); + auto& head_size_vo = get<7>(params.problem_shape); // Preconditions - static_assert(cute::rank(StrideQ{}) == 3, - "StrideQ must be rank-3: [seq_len_qo, head_size_qk, batch * " - "num_heads_q]."); - 
static_assert(cute::rank(StrideK{}) == 3, - "StrideK must be rank-3: [head_size_qk, seq_len_kv, batch * " - "num_heads_kv]."); - static_assert(cute::rank(StrideV{}) == 3, - "StrideV must be rank-3: [seq_len_kv, head_size_vo, batch * " - "num_heads_kv]."); + static_assert( + cute::rank(StrideQ{}) == 3, + "StrideQ must be rank-3: [seq_len_qo, head_size_qk, batch * " + "num_heads_q]."); + static_assert( + cute::rank(StrideK{}) == 3, + "StrideK must be rank-3: [head_size_qk, seq_len_kv, batch * " + "num_heads_kv]."); + static_assert( + cute::rank(StrideV{}) == 3, + "StrideV must be rank-3: [seq_len_kv, head_size_vo, batch * " + "num_heads_kv]."); int thread_idx = int(ThreadIdxX()); - int sub_group_id = thread_idx / SubgroupSize; + // int sub_group_id = thread_idx / SubgroupSize; + auto sub_group_id = get_sub_group_id(); + auto local_id = get_sub_group_local_id(); TileScheduler tile_scheduler{params.scheduler}; CUTLASS_PRAGMA_NO_UNROLL for (; tile_scheduler.is_valid(); ++tile_scheduler) { - auto blk_coord = - tile_scheduler - .get_block_coord(); // head_size_blk_idx, seq_len_blk_idx, - // batch_blk_idx, num_heads_blk_idx + auto blk_coord = tile_scheduler.get_block_coord(); // head_size_blk_idx, seq_len_blk_idx, + // batch_blk_idx, num_heads_blk_idx - auto blk_m_coord = get<0>(blk_coord); // seq_len_blk_idx + auto blk_m_coord = get<0>(blk_coord); // seq_len_blk_idx auto blk_n_coord = 0; // nums_head_blk_idx - auto q_head_coord = get<1>(blk_coord); // q_heads_idx - auto batch_coord = get<2>(blk_coord); // batch_blk_idx + auto q_head_coord = get<1>(blk_coord); // q_heads_idx + auto batch_coord = get<2>(blk_coord); // batch_blk_idx // For variable sequence length case, batch is considered to be 1 (same // as group gemm). For fixed sequence length case, the l_coord is the @@ -296,10 +293,9 @@ class FMHAPrefillChunk { // 5>(params.problem_shape). 
sequence_length_shape = [batch, // num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, seq_len_kv_cache, // head_size_qk, head_size_vo] - auto sequence_length_shape = - get_sequence_length_shape(params.problem_shape, batch_coord); + auto sequence_length_shape = get_sequence_length_shape(params.problem_shape, batch_coord); - auto [seq_len_qo, seq_len_kv, seq_len_kv_cache] = sequence_length_shape; + auto [seq_len_qo, seq_len_kv_cache] = sequence_length_shape; // int seq_len_kv_total = seq_len_kv_cache + seq_len_kv; // For variable sequence length case, batch is considered to be 1 (same // as group gemm). For fixed sequence length case, the l_coord is the @@ -312,35 +308,28 @@ class FMHAPrefillChunk { // Calculate the seq_len_idx (blk_m_coord * get<0>(TileShapeOutput{})) // and check if it is still within bounds of the actual seq_len_qo // (get<0>(sequence_length_shape)). - if (blk_m_coord * get<0>(TileShapeOutput{}) >= - seq_len_qo) { + if (blk_m_coord * get<0>(TileShapeOutput{}) >= seq_len_qo) { continue; } const int seq_coord = - cute::min(seq_len_qo, (blk_m_coord * QK_BLK_M + (sub_group_id / PV_ATOM_N) * QK_SG_M) % - seq_len_qo); - auto offset = cute::min(seq_len_qo, seq_len_kv); //(2048, 1024) - auto discard_seq_coord = seq_len_qo - offset; // 1024 - auto full_tile_offset = seq_len_kv - offset; // 0 - - const int seq_len = - CausalMask - ? full_tile_offset + - cute::min(seq_len_kv, seq_coord - discard_seq_coord) + - QK_SG_M - : seq_len_kv; - - const int kv_splits_new = cute::ceil_div(seq_len, QK_BLK_N); + cute::min(seq_len_qo, (blk_m_coord * QK_BLK_M + (sub_group_id / PV_ATOM_N) * QK_SG_M) % seq_len_qo); + // auto offset = cute::min(seq_len_qo, seq_len_kv); //(2048, 1024) + // auto discard_seq_coord = seq_len_qo - offset; // 1024 + // auto full_tile_offset = seq_len_kv - offset; // 0 + + // const int seq_len = seq_len_kv; + // CausalMask + // ? 
full_tile_offset + + // cute::min(seq_len_kv, seq_coord - discard_seq_coord) + + // QK_SG_M + // : seq_len_kv; + const int kv_splits_cache = cute::ceil_div(seq_len_kv_cache, QK_BLK_N); - const int kv_splits = kv_splits_cache + kv_splits_new; + const int kv_splits = kv_splits_cache; int tiles_per_page = params.mainloop.page_size / QK_BLK_N; - if (CausalMask && seq_coord < discard_seq_coord) { // 1024 =0 - continue; - } - auto q_group_size = num_heads_q / num_heads_kv; auto kv_head_coord = q_head_coord / q_group_size; @@ -357,101 +346,69 @@ class FMHAPrefillChunk { ? 1.f : params.mainloop.ptr_v_scale[batch_coord * num_heads_kv + kv_head_coord]; - Tensor mQ_mkl = cute::get_xe_tensor( - make_shape(seq_len_qo, head_size_qk, 1)); //(m,k,l) + Tensor mQ_mkl = cute::get_xe_tensor(make_shape(seq_len_qo, head_size_qk, 1)); //(m,k,l) - Tensor mK_nkl = cute::get_xe_tensor( - make_shape(seq_len_kv, head_size_qk, 1)); //(n,k,l) - Tensor mV_nkl = cute::get_xe_tensor( - make_shape(head_size_vo, seq_len_kv, 1)); //(n,k,l) - Tensor mK_cache_nkl = cute::get_xe_tensor( - make_shape(seq_len_kv_cache, head_size_qk, 1)); // (n_cache,k,l) - Tensor mV_cache_nkl = cute::get_xe_tensor( - make_shape(head_size_vo, seq_len_kv_cache, 1)); // (n_cache,k,l) + Tensor mK_cache_nkl = cute::get_xe_tensor(make_shape(seq_len_kv_cache, head_size_qk, 1)); // (n_cache,k,l) + Tensor mV_cache_nkl = cute::get_xe_tensor(make_shape(head_size_vo, seq_len_kv_cache, 1)); // (n_cache,k,l) // block_size and head_size are the same size. So no coord is needed. 
Tensor mQ_mk = mQ_mkl(_, _, 0); - Tensor mK_nk = mK_nkl(_, _, 0); // (n,k) - Tensor mV_nk = mV_nkl(_, _, 0); + Tensor mK_cache_nk = mK_cache_nkl(_, _, 0); // (n_cache, k) + Tensor mV_cache_nk = mV_cache_nkl(_, _, 0); // (n_cache, k) - Tensor mK_cache_nk = mK_cache_nkl(_, _, 0); // (n_cache, k) - Tensor mV_cache_nk = mV_cache_nkl(_, _, 0); // (n_cache, k) + auto gQ = local_tile(mQ_mk, TileShapeQK{}, make_coord(blk_m_coord, _, _), Step<_1, X, _1>{}); - auto gQ = local_tile(mQ_mk, TileShapeQK{}, make_coord(blk_m_coord, _, _), - Step<_1, X, _1>{}); - auto gK = local_tile(mK_nk, TileShapeQK{}, make_coord(_, _, _), - Step{}); - - auto gV = local_tile(mV_nk, TileShapeOutput{}, - make_coord(_, blk_n_coord, _), Step{}); - auto gK_cache = local_tile(mK_cache_nk, TileShapeQK{}, - make_coord(_, _, _), Step{}); - auto gV_cache = - local_tile(mV_cache_nk, TileShapeOutput{}, - make_coord(_, blk_n_coord, _), Step{}); + auto gK_cache = local_tile(mK_cache_nk, TileShapeQK{}, make_coord(_, _, _), Step{}); + auto gV_cache = local_tile(mV_cache_nk, TileShapeOutput{}, make_coord(_, blk_n_coord, _), Step{}); auto mainloop_params = CollectiveMainloop::get_updated_copies( - params.mainloop, params.problem_shape, sequence_length_shape, - batch_coord, q_head_coord); - + params.mainloop, params.problem_shape, sequence_length_shape, batch_coord, q_head_coord); - // we limit the horisontal size to two subgroup, the empirical resutls + // we limit the horizontal size to two subgroup, the empirical results // show that reading the two cacheline side by side in gives better // performance and anything after that does not have an effect on // performance. 
// (64 here for float b float when possible and loop over // to cover all the data needed) - auto tiled_prefetch_q = cute::prefetch_selector< - Shape, Int>, - Num_SGs>(mainloop_params.gmem_tiled_copy_q); - auto tiled_prefetch_k = cute::prefetch_selector< - Shape, Int>, - Num_SGs>(mainloop_params.gmem_tiled_copy_k); - auto tiled_prefetch_v = cute::prefetch_selector< - Shape, - Int>, - Num_SGs>(mainloop_params.gmem_tiled_copy_v); - auto tiled_prefetch_k_cache = cute::prefetch_selector< - Shape, Int>, - Num_SGs>(mainloop_params.gmem_tiled_copy_k_cache); - auto tiled_prefetch_v_cache = cute::prefetch_selector< - Shape, - Int>, - Num_SGs>(mainloop_params.gmem_tiled_copy_v_cache); + auto tiled_prefetch_q = + cute::prefetch_selector, Int>, Num_SGs>( + mainloop_params.gmem_tiled_copy_q); + + auto tiled_prefetch_k_cache = + cute::prefetch_selector, Int>, Num_SGs>( + mainloop_params.gmem_tiled_copy_k_cache); + auto tiled_prefetch_v_cache = cute:: + prefetch_selector, Int>, Num_SGs>( + mainloop_params.gmem_tiled_copy_v_cache); auto thr_prefetch_Q = tiled_prefetch_q.get_slice(thread_idx); - auto thr_prefetch_K = tiled_prefetch_k.get_slice(thread_idx); - auto thr_prefetch_V = tiled_prefetch_v.get_slice(thread_idx); + auto thr_prefetch_K_cache = tiled_prefetch_k_cache.get_slice(thread_idx); + auto thr_prefetch_V_cache = tiled_prefetch_v_cache.get_slice(thread_idx); auto pQgQ = thr_prefetch_Q.partition_S(gQ); - auto pKgK = thr_prefetch_K.partition_S(gK); - auto pVgV = thr_prefetch_V.partition_S(gV); + // assuming the copy function is the same otherwise this need to have its // own tile_prefetch - auto pKgK_cache = thr_prefetch_K.partition_S(gK_cache); - auto pVgV_cache = thr_prefetch_V.partition_S(gV_cache); + auto pKgK_cache = thr_prefetch_K_cache.partition_S(gK_cache); + auto pVgV_cache = thr_prefetch_V_cache.partition_S(gV_cache); CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size<3>(pQgQ); i++) { prefetch(tiled_prefetch_q, pQgQ(_, _, _, i)); } - auto &prefetch_K = - 
(seq_len_kv_cache == 0) ? tiled_prefetch_k : tiled_prefetch_k_cache; - auto &pKgK1_ = (seq_len_kv_cache == 0) ? pKgK : pKgK_cache; + auto& prefetch_K = tiled_prefetch_k_cache; + auto& pKgK1_ = pKgK_cache; int cached_nblock = 0; if constexpr (PagedKV) { - int curr_batch_pages = ceil_div(seq_len_kv_cache, mainloop_params.page_size); - int batch_offset = - is_var_len ? mainloop_params.num_pages_per_seq[batch_coord] - : batch_coord * curr_batch_pages; - cached_nblock = - mainloop_params - .ptr_page_table[batch_offset // page table for this batch - ] * tiles_per_page; // base block idx of physical page + // int curr_batch_pages = ceil_div(seq_len_kv_cache, mainloop_params.page_size);// max_page_size_per_seq + // int batch_offset = is_var_len ? mainloop_params.num_pages_per_seq[batch_coord] : batch_coord * + // curr_batch_pages; + int batch_offset = batch_coord * mainloop_params.max_num_pages_per_seq; + cached_nblock = mainloop_params.ptr_page_table[batch_offset // page table for this batch + ] * tiles_per_page; // base block idx of physical page } // The headsize for both cached and non-cached version is the same for (int j = 0; j < size<4>(pKgK1_); j++) { CUTLASS_PRAGMA_UNROLL - for (int i = cached_nblock; i < cached_nblock + DispatchPolicy::Stages; - i++) { + for (int i = cached_nblock; i < cached_nblock + DispatchPolicy::Stages; i++) { prefetch(prefetch_K, pKgK1_(_, _, _, i, j)); } } @@ -460,13 +417,12 @@ class FMHAPrefillChunk { // workgroup_shape Tensor out_reg = make_tensor(AccumeShape{}); - // There are 16 workitem and 16 max per subgroup, each worktime containt 1 + // There are 16 workitem and 16 max per subgroup, each worktime contains 1 // max and cumulatively, they calculate the max per subgroup ElementAccumulator max_reg{-INFINITY}; - // The sum reg each contains a 2d tesnor for 8 x 2 This is number of - // sequence lenght process per subgroup - Tensor sum_reg = - make_tensor(Shape, Int>{}); + // The sum reg each contains a 2d tensor for 8 x 2 This is number 
of + // sequence length process per subgroup + Tensor sum_reg = make_tensor(Shape, Int>{}); clear(sum_reg); clear(out_reg); @@ -477,54 +433,56 @@ class FMHAPrefillChunk { // different for each subgroup due to triangular nature of causal based // operation static constexpr int barrier_scope = CausalMask ? 3 : 2; + + int q_start_coord = blk_m_coord * QK_BLK_M; + int q_end_coord = cute::min(q_start_coord + QK_BLK_M, seq_len_qo); + int seq_diff = seq_len_kv_cache - seq_len_qo; + CUTLASS_PRAGMA_UNROLL - for (int split = 0; split < kv_splits - static_cast(CausalMask); split++) { + for (int split = 0; split < kv_splits; split++) { barrier_arrive(barrier_scope); - bool is_KV_cache = split < kv_splits_cache; + int kv_start_coord = split * QK_BLK_N; + + if constexpr (CausalMask) { + if (kv_start_coord >= q_end_coord + seq_diff) break; + } + + // // = 0, all KV is kv_cache // 1) Load KV (performed inside mmaQK) - auto gK_ = is_KV_cache ? gK_cache(_, _, cached_nblock, _) - : gK(_, _, split - kv_splits_cache, _); - auto gV_ = is_KV_cache ? 
gV_cache(_, _, cached_nblock) - : gV(_, _, split - kv_splits_cache); + auto gK_ = gK_cache(_, _, cached_nblock, _); + auto gV_ = gV_cache(_, _, cached_nblock); // 2) Create Tensor S - Tensor tSr = make_tensor( - Shape, Int, Int>{}); + Tensor tSr = make_tensor(Shape, Int, Int>{}); clear(tSr); // 3) Perform GEMM S = Q*K // Then modify layout to LayoutQ = ((seq_leq_q, group_head_q), // head_size_qk, batch* num_heads_q / group_head_q), which can be merged // into one gemm for (int i = 0; i < q_group_size; ++i) { - collective_mma.mmaQK(tSr, gQ, gK_, tSr, - ceil_div(head_size_qk, QK_BLK_K), mainloop_params, - is_KV_cache, q_scale_val, k_scale_val); + collective_mma.mmaQK(tSr, gQ, gK_, tSr, ceil_div(head_size_qk, QK_BLK_K), mainloop_params, true, q_scale_val, k_scale_val); if constexpr (LocalMask) { // Sliding windows // mask the elements of each tile where j - left > i || j + right < i const int item_id = thread_idx % SubgroupSize; - int col_idx; - if (split < kv_splits_cache) { - col_idx = item_id + split * cute::min(QK_BLK_N, seq_len_kv_cache) ; - } else { - col_idx = item_id + seq_len_kv_cache + (split - kv_splits_cache) * cute::min(QK_BLK_N, seq_len_kv); - } + int col_idx = item_id + split * cute::min(QK_BLK_N, seq_len_kv_cache); CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < FragsN; - n++, col_idx += get<1>(MmaAtomShape())) { // 4 + for (int n = 0; n < FragsN; n++, col_idx += get<1>(MmaAtomShape())) { // 4 + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < FragsM; m++) { // 2 + int row_idx = m * Vec + seq_coord; + int col_ref = seq_len_kv_cache - seq_len_qo; + // int col_ref = seq_len_kv_cache + seq_len_kv - seq_len_qo; CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < FragsM; m++) { // 2 - int row_idx = m * Vec + seq_coord; - int col_ref = seq_len_kv_cache + seq_len_kv - seq_len_qo; - CUTLASS_PRAGMA_UNROLL - for (int row = 0; row < Vec; row++) { // 8 - bool left_mask = col_idx < cute::max(0, row + row_idx + col_ref - mainloop_params.window_left); - bool right_mask = col_idx > 
cute::min(seq_len_kv_cache + seq_len_kv, row + row_idx + col_ref + mainloop_params.window_right); - if (left_mask || right_mask) { - tSr(row, m, n) = ElementAccumulator{-INFINITY}; - } + for (int row = 0; row < Vec; row++) { // 8 + bool left_mask = col_idx < cute::max(0, row + row_idx + col_ref - mainloop_params.window_left); + bool right_mask = + col_idx > cute::min(seq_len_kv_cache, row + row_idx + col_ref + mainloop_params.window_right); + if (left_mask || right_mask) { + tSr(row, m, n) = ElementAccumulator{-INFINITY}; } + } } } } @@ -534,54 +492,64 @@ class FMHAPrefillChunk { if (col_end >= seq_len_kv_cache) { int col_idx = col_start; CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < FragsN; n++, col_idx += get<1>(MmaAtomShape())) { // 4 - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < FragsM; m++) { // 2 - int row_idx = m * Vec + seq_coord; + for (int n = 0; n < FragsN; n++, col_idx += get<1>(MmaAtomShape())) { // 4 + if (col_idx >= seq_len_kv_cache) { CUTLASS_PRAGMA_UNROLL - for (int row = 0; row < Vec; row++) { // 8 - if (col_idx >= seq_len_kv_cache + seq_len_kv || row_idx + row >= seq_len_qo) { + for (int m = 0; m < FragsM; m++) { // 2 + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < Vec; row++) { // 8 tSr(row, m, n) = ElementAccumulator{-INFINITY}; + } + } + } + } + } + if constexpr (CausalMask) { + int row_start = q_start_coord + sub_group_id * QK_SG_M; + if (row_start + seq_diff < col_end) { + int col_idx = col_start; + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < FragsN; n++, col_idx += get<1>(MmaAtomShape())) { // 4 + if (col_idx > row_start + seq_diff) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < FragsM; m++) { // 2 + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < Vec; row++) { // 8 + int row_idx = row_start + m * Vec + row; + if (row_idx + seq_diff < col_idx) tSr(row, m, n) = ElementAccumulator{-INFINITY}; + } + } } } } } } - auto &tiled_prefetch_v_ = - is_KV_cache ? tiled_prefetch_v_cache - : tiled_prefetch_v; - auto &pVgV_ = is_KV_cache ? 
pVgV_cache : pVgV; - int v_prefetch_idx = is_KV_cache ? PagedKV ? cached_nblock : split - : split - kv_splits_cache; + auto& tiled_prefetch_v_ = tiled_prefetch_v_cache; + auto& pVgV_ = pVgV_cache; + int v_prefetch_idx = cached_nblock; for (int i = 0; i < size<1>(pVgV_); i++) { prefetch(tiled_prefetch_v_, pVgV_(_, i, _, v_prefetch_idx)); } int next_cached_nblock = split + 1; - bool is_next_KV_cache = next_cached_nblock < kv_splits_cache; if constexpr (PagedKV) { - if (is_next_KV_cache) { - int curr_batch_pages = ceil_div(seq_len_kv_cache, mainloop_params.page_size); - int next_page_logical_idx = - next_cached_nblock * QK_BLK_N / params.mainloop.page_size; - int batch_offset = - is_var_len ? mainloop_params.num_pages_per_seq[batch_coord] - : batch_coord * curr_batch_pages; - bool valid_page = next_page_logical_idx < curr_batch_pages; - // get physical page idx from page table - if (valid_page) { - next_cached_nblock = - params.mainloop.ptr_page_table - [batch_offset + // page table for this batch - next_page_logical_idx // split (tile idx) to logical - // page idx - ] * tiles_per_page + // base block idx of physical page - next_cached_nblock % tiles_per_page; // offset within page - } else { - next_cached_nblock = - curr_batch_pages * - tiles_per_page; // push idx out of bounds to respect the - // boundary between batches - } + // int curr_batch_pages = ceil_div(seq_len_kv_cache, mainloop_params.page_size); + // int batch_offset = + // is_var_len ? 
mainloop_params.num_pages_per_seq[batch_coord] : batch_coord * curr_batch_pages; + int curr_batch_pages = mainloop_params.max_num_pages_per_seq; // max_page_size_per_seq + int batch_offset = batch_coord * curr_batch_pages; + int next_page_logical_idx = next_cached_nblock * QK_BLK_N / params.mainloop.page_size; + bool valid_page = next_page_logical_idx < curr_batch_pages; + // get physical page idx from page table + if (valid_page) { + next_cached_nblock = params.mainloop.ptr_page_table + [batch_offset + // page table for this batch + next_page_logical_idx // split (tile idx) to logical + // page idx + ] * tiles_per_page + // base block idx of physical page + next_cached_nblock % tiles_per_page; // offset within page + } else { + next_cached_nblock = curr_batch_pages * tiles_per_page; // push idx out of bounds to respect the + // boundary between batches } } @@ -590,8 +558,7 @@ class FMHAPrefillChunk { softmax(split == 0, tSr, max_reg, sum_reg, out_reg); // 5) Perform GEMM O = S*V - collective_mma.template mmaPV(out_reg, tSr, gV_, out_reg, - mainloop_params, is_KV_cache, v_scale_val); + collective_mma.template mmaPV(out_reg, tSr, gV_, out_reg, mainloop_params, true, v_scale_val); // ... prefetch next tile ... // Prefetch the next Q tile CUTLASS_PRAGMA_UNROLL @@ -599,82 +566,20 @@ class FMHAPrefillChunk { prefetch(tiled_prefetch_q, pQgQ(_, _, _, i)); } - is_KV_cache = is_next_KV_cache; cached_nblock = next_cached_nblock; // Prefetch the next K tile - // there is no need to gaurd it with if statememt as prefetch will + // there is no need to guard it with if statement as prefetch will // ignore out of bound reading - if constexpr (PagedKV) { - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < size<4>(pKgK_cache); j++) { - prefetch(tiled_prefetch_k_cache, pKgK_cache(_, _, _, cached_nblock, j)); - } - } else { - bool sel_prefetch_k = - (split + DispatchPolicy::Stages) < kv_splits_cache; - auto &prefetch_k_selector = - sel_prefetch_k ? 
tiled_prefetch_k_cache : tiled_prefetch_k; - auto &pKgK_ = sel_prefetch_k ? pKgK_cache : pKgK; - int k_prefetch_idx = - sel_prefetch_k - ? PagedKV ? cached_nblock : split + DispatchPolicy::Stages - : split + DispatchPolicy::Stages - kv_splits_cache; - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < size<4>(pKgK_); j++) { - prefetch(prefetch_k_selector, pKgK_(_, _, _, k_prefetch_idx, j)); - } - } - barrier_wait(barrier_scope); - } - - if constexpr (CausalMask) { - // BAND Matrix - // 1) Load K (performed inside mmaQK) - // 2) Create Tensor S - Tensor tSr = make_tensor( - Shape, Int, Int>{}); - clear(tSr); - // 3) Perform GEMM S = Q*K - collective_mma.mmaQK(tSr, gQ, gK(_, _, kv_splits_new - 1, _), tSr, - ceil_div(head_size_qk, QK_BLK_K), mainloop_params, - false, q_scale_val, k_scale_val); - // we only need one block ahead, there is enough gap to prefetch it - // while doing softmax. because the gap between the two MMA is big, - // prefetching it the same way as cutlass K matrix does not make sense - for (int i = 0; i < size<1>(pVgV); i++) { - prefetch(tiled_prefetch_v, pVgV(_, i, _, kv_splits_new - 1)); - } - // mask the elements of each tile where j > i - const int item_id = thread_idx % SubgroupSize; - int col_idx = item_id + (kv_splits_new - 1) * QK_BLK_N; CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < FragsN; - n++, col_idx += get<1>(MmaAtomShape())) { // 4 - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < FragsM; m++) { // 2 - int row_idx = m * Vec + seq_coord; - CUTLASS_PRAGMA_UNROLL - for (int row = 0; row < Vec; row++, row_idx++) { // 8 - if (col_idx - full_tile_offset > row_idx - discard_seq_coord) { - tSr(row, m, n) = ElementAccumulator{-INFINITY}; - } - } - } + for (int j = 0; j < size<4>(pKgK_cache); j++) { + prefetch(tiled_prefetch_k_cache, pKgK_cache(_, _, _, cached_nblock, j)); } - - CollectiveSoftmaxEpilogue softmax(params.softmax); - softmax((kv_splits - 1) == 0, tSr, max_reg, sum_reg, out_reg); - collective_mma.template mmaPV(out_reg, tSr, - gV(_, _, 
kv_splits_new - 1), - out_reg, mainloop_params, false, v_scale_val); + barrier_wait(barrier_scope); } - // Epilogue - auto epilogue_params = - CollectiveEpilogue::template get_updated_copies( - params.epilogue, params.problem_shape, sequence_length_shape, - batch_coord, q_head_coord); + auto epilogue_params = CollectiveEpilogue::template get_updated_copies( + params.epilogue, params.problem_shape, sequence_length_shape, batch_coord, q_head_coord); CollectiveEpilogue epilogue{epilogue_params, shared_storage.epilogue}; auto blk_coord_mnkl = make_coord(blk_m_coord, blk_n_coord, _, 0); if constexpr (Sink) { @@ -696,4 +601,4 @@ class FMHAPrefillChunk { /////////////////////////////////////////////////////////////////////////////// -} // namespace cutlass::flash_attention::kernel +} // namespace cutlass::flash_attention::kernel diff --git a/src/sycl/kernels/chunk_prefill/xe_flash_attn_chunk_prefill_mma.hpp b/src/sycl/kernels/chunk_prefill/xe_flash_attn_chunk_prefill_mma.hpp index 2002dfc..97e5e57 100644 --- a/src/sycl/kernels/chunk_prefill/xe_flash_attn_chunk_prefill_mma.hpp +++ b/src/sycl/kernels/chunk_prefill/xe_flash_attn_chunk_prefill_mma.hpp @@ -39,9 +39,7 @@ #include "fp8_descale.h" //////////////////////////////////////////////////////////// -namespace { - -} +namespace {} ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -51,29 +49,69 @@ using namespace cute; ///////////////////////////////////////////////////////////////////////////////////////////////// -template +template < + class DispatchPolicy, + class ProblemShapeType_, + class ElementQ_, + class StrideQ_, + class ElementK_, + class StrideK_, + class ElementV_, + class StrideV_, + class MMAOperation_, + class TileShapeQK_, + class TileShapePV_, + class SubgroupLayout_, + class GmemTiledCopyQ_, + class GmemTiledCopyK_, + class GmemTiledCopyV_, + bool CausalMask_, + bool LocalMask_, + bool PagedKV_> struct FlashChunkPrefillMma { - 
static_assert(cutlass::detail::dependent_false, - "Could not find a mainloop specialization."); + static_assert(cutlass::detail::dependent_false, "Could not find a mainloop specialization."); }; ///////////////////////////////////////////////////////////////////////////////////////////////// -template +template < + int Stages, + class ProblemShapeType_, + class ElementQ_, + class StrideQ_, + class ElementK_, + class StrideK_, + class ElementV_, + class StrideV_, + class MMAOperation_, + class TileShapeQK_, + class TileShapePV_, + class SubgroupLayout_, + class GmemTiledCopyQ_, + class GmemTiledCopyK_, + class GmemTiledCopyV_, + bool CausalMask_, + bool LocalMask_, + bool PagedKV_> struct FlashChunkPrefillMma< - gemm::MainloopIntelXeXMX16, ProblemShapeType_, ElementQ_, StrideQ_, - ElementK_, StrideK_, ElementV_, StrideV_, MMAOperation_, TileShapeQK_, - TileShapePV_, SubgroupLayout_, GmemTiledCopyQ_, GmemTiledCopyK_, - GmemTiledCopyV_, CausalMask_, LocalMask_, PagedKV_> { + gemm::MainloopIntelXeXMX16, + ProblemShapeType_, + ElementQ_, + StrideQ_, + ElementK_, + StrideK_, + ElementV_, + StrideV_, + MMAOperation_, + TileShapeQK_, + TileShapePV_, + SubgroupLayout_, + GmemTiledCopyQ_, + GmemTiledCopyK_, + GmemTiledCopyV_, + CausalMask_, + LocalMask_, + PagedKV_> { // // Type Aliases // @@ -94,11 +132,9 @@ struct FlashChunkPrefillMma< using ArchTag = typename DispatchPolicy::ArchTag; using MmaAtom = MMA_Atom; - using TiledMmaQK = typename TiledMMAHelper, - SubgroupLayout>::TiledMMA; + using TiledMmaQK = typename TiledMMAHelper, SubgroupLayout>::TiledMMA; - using TiledMmaPV = typename TiledMMAHelper, - SubgroupLayout>::TiledMMA; + using TiledMmaPV = typename TiledMMAHelper, SubgroupLayout>::TiledMMA; using ElementAccumulator = typename TiledMmaQK::ValTypeC; static constexpr bool CausalMask = CausalMask_; static constexpr bool LocalMask = LocalMask_; @@ -108,15 +144,11 @@ struct FlashChunkPrefillMma< using MmaAtomShape = typename MmaAtom::Shape_MNK; - static constexpr auto 
PV_ATOM_M = - decltype(get<0>(SubgroupLayout{}.shape()))::value; - static constexpr auto PV_ATOM_N = - decltype(get<1>(SubgroupLayout{}.shape()))::value; - static constexpr auto PV_ATOM_K = - decltype(get<2>(SubgroupLayout{}.shape()))::value; + static constexpr auto PV_ATOM_M = decltype(get<0>(SubgroupLayout{}.shape()))::value; + static constexpr auto PV_ATOM_N = decltype(get<1>(SubgroupLayout{}.shape()))::value; + static constexpr auto PV_ATOM_K = decltype(get<2>(SubgroupLayout{}.shape()))::value; - using SubgroupTileShapePV = - decltype(cute::shape_div(TileShapePV{}, (SubgroupLayout{}.shape()))); + using SubgroupTileShapePV = decltype(cute::shape_div(TileShapePV{}, (SubgroupLayout{}.shape()))); static constexpr auto QK_BLK_M = get<0>(TileShapeQK{}); static constexpr auto QK_BLK_N = get<1>(TileShapeQK{}); static constexpr auto QK_BLK_K = get<2>(TileShapeQK{}); @@ -124,94 +156,76 @@ struct FlashChunkPrefillMma< // This TiledMma is only required to serve the specific tiling requirements // for matrix K. This is due to the consumption of matrix K by all subgroups // within a workgroup. 
- static constexpr auto QK_ATOM_M = PV_ATOM_M; // 8 - static constexpr auto QK_ATOM_N = PV_ATOM_N; // 1 - static constexpr auto QK_ATOM_K = PV_ATOM_K; // 1 + static constexpr auto QK_ATOM_M = PV_ATOM_M; // 8 + static constexpr auto QK_ATOM_N = PV_ATOM_N; // 1 + static constexpr auto QK_ATOM_K = PV_ATOM_K; // 1 - using SubgroupTileShapeQK = decltype(cute::shape_div( - TileShapeQK{}, - SubgroupLayout{}.shape())); // 128, 64, 32 / 16, 1, 1 = (8, 64, 32 ) + using SubgroupTileShapeQK = + decltype(cute::shape_div(TileShapeQK{}, SubgroupLayout{}.shape())); // 128, 64, 32 / 16, 1, 1 = (8, 64, 32 ) static constexpr auto QK_SG_M = get<0>(SubgroupTileShapeQK{}); static constexpr auto QK_SG_N = get<1>(SubgroupTileShapeQK{}); static constexpr auto QK_SG_K = get<2>(SubgroupTileShapeQK{}); static constexpr bool is_var_len = - cutlass::fmha::collective::is_variable_length_v< - tuple_element_t<3, ProblemShapeType>>; + cutlass::fmha::collective::is_variable_length_v>; using FragsShapeS = decltype(cute::shape_div( - take<0, 2>(SubgroupTileShapeQK{}), - take<0, 2>(MmaAtomShape()))); // 8, 64, 32 / 8, 16, 16 (1, 4) - static constexpr int Vec = - (get<0>(MmaAtomShape()) * get<1>(MmaAtomShape())) / SubgroupSize; // 8 + take<0, 2>(SubgroupTileShapeQK{}), take<0, 2>(MmaAtomShape()))); // 8, 64, 32 / 8, 16, 16 (1, 4) + static constexpr int Vec = (get<0>(MmaAtomShape()) * get<1>(MmaAtomShape())) / SubgroupSize; // 8 static constexpr int FragsM = get<0>(FragsShapeS{}); - static constexpr int FragsNS = get<1>(FragsShapeS{}); // 4 + static constexpr int FragsNS = get<1>(FragsShapeS{}); // 4 - static constexpr uint32_t MaxThreadsPerBlock = - size(SubgroupLayout{}) * SubgroupSize; + static constexpr uint32_t MaxThreadsPerBlock = size(SubgroupLayout{}) * SubgroupSize; using CopyThreadShape = Shape<_1, Int>; using traits_load_Q = Copy_Traits; using atom_load_Q = Copy_Atom; - using val_layout_load_Q = decltype(make_layout( - shape_div(typename traits_load_Q::BlockShape{}, CopyThreadShape{}))); - 
using XE_Copy_Q = decltype(make_tiled_copy( - atom_load_Q{}, Layout{}, val_layout_load_Q{})); + using val_layout_load_Q = decltype(make_layout(shape_div(typename traits_load_Q::BlockShape{}, CopyThreadShape{}))); + using XE_Copy_Q = decltype(make_tiled_copy(atom_load_Q{}, Layout{}, val_layout_load_Q{})); using traits_load_K = Copy_Traits; using atom_load_K = Copy_Atom; - using val_layout_load_K = decltype(make_layout( - shape_div(typename traits_load_K::BlockShape{}, CopyThreadShape{}))); - using XE_Copy_K = decltype(make_tiled_copy( - atom_load_K{}, Layout{}, val_layout_load_K{})); + using val_layout_load_K = decltype(make_layout(shape_div(typename traits_load_K::BlockShape{}, CopyThreadShape{}))); + using XE_Copy_K = decltype(make_tiled_copy(atom_load_K{}, Layout{}, val_layout_load_K{})); using traits_load_V = Copy_Traits; using atom_load_V = Copy_Atom; - using val_layout_load_V = decltype(make_layout( - shape_div(typename traits_load_V::BlockShape{}, CopyThreadShape{}))); - using XE_Copy_V = decltype(make_tiled_copy( - atom_load_V{}, Layout{}, val_layout_load_V{})); + using val_layout_load_V = decltype(make_layout(shape_div(typename traits_load_V::BlockShape{}, CopyThreadShape{}))); + using XE_Copy_V = decltype(make_tiled_copy(atom_load_V{}, Layout{}, val_layout_load_V{})); template static constexpr bool is_fp8_v = cute::is_same_v || cute::is_same_v; // Host side kernel arguments struct Arguments { - ElementQ const *ptr_Q; + ElementQ const* ptr_Q; StrideQ dQ; - ElementK const *ptr_K; - StrideK dK; - ElementV const *ptr_V; - StrideV dV; + ElementK const* ptr_K_cache; + StrideK dK_cache; + ElementV const* ptr_V_cache; + StrideV dV_cache; float const *ptr_q_scale; float const *ptr_k_scale; float const *ptr_v_scale; - ElementK const *ptr_K_cache; - StrideK dK_cache; - ElementV const *ptr_V_cache; - StrideV dV_cache; // Paged KV Cache - int const *ptr_page_table; + int const* ptr_page_table; int page_size; - int const *num_pages_per_seq; + int max_num_pages_per_seq; 
int window_left; int window_right; }; struct Params { XE_Copy_Q gmem_tiled_copy_q; - XE_Copy_K gmem_tiled_copy_k; - XE_Copy_V gmem_tiled_copy_v; + XE_Copy_K gmem_tiled_copy_k_cache; + XE_Copy_V gmem_tiled_copy_v_cache; float const *ptr_q_scale; float const *ptr_k_scale; float const *ptr_v_scale; - XE_Copy_K gmem_tiled_copy_k_cache; - XE_Copy_V gmem_tiled_copy_v_cache; - // Paged KV Cache - int const *ptr_page_table; + int const* ptr_page_table; int page_size; - int const *num_pages_per_seq; + int max_num_pages_per_seq; int window_left; int window_right; }; @@ -223,67 +237,50 @@ struct FlashChunkPrefillMma< FlashChunkPrefillMma() = default; static constexpr Params - to_underlying_arguments(ProblemShapeType const &problem_shape, - Arguments const &args, void *workspace) { + to_underlying_arguments(ProblemShapeType const& problem_shape, Arguments const& args, void* workspace) { (void)workspace; - auto [batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, - seq_len_kv_cache, head_size_qk, head_size_vo] = problem_shape; + auto [batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, seq_len_kv_cache, head_size_qk, head_size_vo] = + problem_shape; auto tensorQ = make_tensor( - make_gmem_ptr(args.ptr_Q), - make_layout(make_shape(seq_len_qo, num_heads_q * head_size_qk, batch), - args.dQ)); - auto tensorK = make_tensor( - make_gmem_ptr(args.ptr_K), - make_layout(make_shape(seq_len_kv, num_heads_kv * head_size_qk, batch), - args.dK)); - auto tensorV = make_tensor( - make_gmem_ptr(args.ptr_V), - make_layout(make_shape(num_heads_kv * head_size_vo, seq_len_kv, batch), - args.dV)); - auto tensorK_cache = - make_tensor(make_gmem_ptr(args.ptr_K_cache), - make_layout(make_shape(seq_len_kv_cache, - num_heads_kv * head_size_qk, batch), - args.dK_cache)); + make_gmem_ptr(args.ptr_Q), make_layout(make_shape(seq_len_qo, num_heads_q * head_size_qk, batch), args.dQ)); + auto tensorK_cache = make_tensor( + make_gmem_ptr(args.ptr_K_cache), + make_layout(make_shape(seq_len_kv_cache, 
num_heads_kv * head_size_qk, batch), args.dK_cache)); auto tensorV_cache = make_tensor( make_gmem_ptr(args.ptr_V_cache), - make_layout( - make_shape(num_heads_kv * head_size_vo, seq_len_kv_cache, batch), - args.dV_cache)); + make_layout(make_shape(num_heads_kv * head_size_vo, seq_len_kv_cache, batch), args.dV_cache)); XE_Copy_Q copyQ{XE_Copy_Q{}.with(tensorQ)}; - XE_Copy_K copyK{XE_Copy_K{}.with(tensorK)}; - XE_Copy_V copyV{XE_Copy_V{}.with(tensorV)}; XE_Copy_K copyK_cache{XE_Copy_K{}.with(tensorK_cache)}; XE_Copy_V copyV_cache{XE_Copy_V{}.with(tensorV_cache)}; - return Params{copyQ, - copyK, - copyV, - args.ptr_q_scale, - args.ptr_k_scale, - args.ptr_v_scale, - copyK_cache, - copyV_cache, - args.ptr_page_table, - args.page_size, - args.num_pages_per_seq, - args.window_left, - args.window_right}; + return Params{ + copyQ, + copyK_cache, + copyV_cache, + args.ptr_q_scale, + args.ptr_k_scale, + args.ptr_v_scale, + args.ptr_page_table, + args.page_size, + args.max_num_pages_per_seq, + args.window_left, + args.window_right}; } // FP8 Q and FP8 K tensors are converted to BF16 tensors using descale factors // GEMM is computed in BF16 precision (FP8 not supported in BMG) template - CUTLASS_DEVICE void mmaQK(FragQccum &accum, TensorQ gQ, TensorK gK, - FragSrc const &frag_src, int const &k_tile_count, - Params const ¶ms, bool is_KV_cache, - float q_scale, float k_scale) { - - auto &gmem_tiled_copy_k = - is_KV_cache ? 
params.gmem_tiled_copy_k_cache : params.gmem_tiled_copy_k; + CUTLASS_DEVICE void mmaQK( + FragQccum& accum, + TensorQ gQ, + TensorK gK, + FragSrc const& frag_src, + int const& k_tile_count, + Params const& params, bool is_KV_cache, float q_scale, float k_scale) { + auto& gmem_tiled_copy_k = params.gmem_tiled_copy_k_cache; int thread_idx = static_cast(ThreadIdxX()); auto thr_copy_Q = params.gmem_tiled_copy_q.get_slice(thread_idx); @@ -293,8 +290,7 @@ struct FlashChunkPrefillMma< // To make all threads in a warp have the same global tensors pass in the // index of thread 0 in each warp auto sg = compat::get_nd_item<1>().get_sub_group(); - auto first_thread_in_sg_idx = - sg.get_group_id()[0] * DispatchPolicy::SubgroupSize; + auto first_thread_in_sg_idx = sg.get_group_id()[0] * DispatchPolicy::SubgroupSize; auto thread_mma_q = tiled_mma.get_slice(first_thread_in_sg_idx); auto thread_mma_k = tiled_mma.get_slice(0); @@ -319,10 +315,10 @@ struct FlashChunkPrefillMma< // // Mainloop // + for (int k_tile = 0; k_tile < k_tile_count; ++k_tile) { copy(params.gmem_tiled_copy_q, tQgQ(_, _, _, k_tile), tQrQ); copy(gmem_tiled_copy_k, tKgK(_, _, _, k_tile), tKrK); - // FP8 path: Convert FP8 fragments to BF16 if constexpr (is_fp8_v || is_fp8_v) { auto tCrQ_fp16 = make_fragment_like(tCrQ); @@ -347,11 +343,10 @@ struct FlashChunkPrefillMma< // BF16 path cute::gemm(tiled_mma, accum, tCrQ, tCrK, frag_src); } - #if 0 -#define PRINT(x) \ - print(#x ": "); \ - print(x); \ +#define PRINT(x) \ + print(#x ": "); \ + print(x); \ print("\n"); if (cute::thread(0, 0)) { print("======================= Q: \n"); @@ -378,38 +373,30 @@ struct FlashChunkPrefillMma< } template - CUTLASS_DEVICE auto convert_type(Tensor const &tensor) { - using From_type = typename Engine::value_type; - constexpr int numel = decltype(size(tensor))::value; - cutlass::NumericArrayConverter convert_op; - auto frag = - convert_op(*reinterpret_cast *>( - tensor.data())); - return make_tensor(make_rmem_ptr(&frag), 
tensor.layout()); + CUTLASS_DEVICE auto convert_type(Tensor const& tensor) { + using From_type = typename Engine::value_type; + constexpr int numel = decltype(size(tensor))::value; + cutlass::NumericArrayConverter convert_op; + auto frag = convert_op(*reinterpret_cast*>(tensor.data())); + return make_tensor(make_rmem_ptr(&frag), tensor.layout()); } // FP8 V tensor is converted to BF16 tensor using descale factor // P tensor (softmax output) is in FP32 precision (converted to BF16) // GEMM is computed in BF16 precision (FP8 not supported in BMG) - template - CUTLASS_DEVICE void mmaPV(FragQccum &accum, FragS const &tSr, TensorV gV, - FragSrc const &frag_src, Params const ¶ms, - bool is_KV_cache, float v_scale) { - - auto &gmem_tiled_copy_v = - is_KV_cache ? params.gmem_tiled_copy_v_cache : params.gmem_tiled_copy_v; + template + CUTLASS_DEVICE void + mmaPV(FragQccum& accum, FragS const& tSr, TensorV gV, FragSrc const& frag_src, Params const& params, bool is_KV_cache, float v_scale) { + auto& gmem_tiled_copy_v = params.gmem_tiled_copy_v_cache; int thread_idx = static_cast(ThreadIdxX()); // Instantiate the MMA object TiledMmaPV tiled_mma; // Tile GV to the shape of <64,64> and loop over the HeadSize/64 to avoid // Register spill - Tensor gV_ = take<0, 3>( - local_tile(gV, select<1, 2>(TileShapePV{}), make_coord(_, _))); + Tensor gV_ = take<0, 3>(local_tile(gV, select<1, 2>(TileShapePV{}), make_coord(_, _))); auto sg = compat::get_nd_item<1>().get_sub_group(); - auto first_thread_in_sg_idx = - sg.get_group_id()[0] * DispatchPolicy::SubgroupSize; + auto first_thread_in_sg_idx = sg.get_group_id()[0] * DispatchPolicy::SubgroupSize; auto thread_mma = tiled_mma.get_slice(first_thread_in_sg_idx); Tensor tCgV = thread_mma.partition_B(gV_); using TCrV_Type = cute::conditional_t, uint8_t, ElementV>; @@ -421,9 +408,9 @@ struct FlashChunkPrefillMma< Tensor tVgV = gmem_thr_copy_V.retile_S(tCgV); #if CUTLASS_ENABLE_DEBUG_PRINTS -#define PRINT(x) \ - print(#x ": "); \ - print(x); \ 
+#define PRINT(x) \ + print(#x ": "); \ + print(x); \ print("\n"); if (cute::thread(LOG_THREAD, LOG_GROUP)) { print("===================== V :\n"); @@ -448,7 +435,6 @@ struct FlashChunkPrefillMma< CUTLASS_PRAGMA_UNROLL for (int i = 0; i < tile_count; i++) { copy(gmem_tiled_copy_v, tVgV(_, _, _, i), tVrV); - if constexpr (is_fp8_v) { auto tCrV_fp16 = make_fragment_like(tCrV); convert_and_descale(tCrV, tCrV_fp16, v_scale); @@ -465,135 +451,70 @@ struct FlashChunkPrefillMma< // int, int, int> For Variable Sequence Length, ProblemShape = Shape template - CUTLASS_DEVICE static constexpr Params - get_updated_copies(Params const ¶ms, ProblemShape const &problem_shape, - SequenceLengthShape const &sequence_length_shape, - int const &l_coord, int const &q_head_coord = 0) { - auto [batch, num_heads_q, num_heads_kv, head_size_qk, head_size_vo] = - select<0, 1, 2, 6, 7>(problem_shape); - auto [seq_len_qo, seq_len_kv, seq_len_kv_cache] = sequence_length_shape; + CUTLASS_DEVICE static constexpr Params get_updated_copies( + Params const& params, + ProblemShape const& problem_shape, + SequenceLengthShape const& sequence_length_shape, + int const& l_coord, + int const& q_head_coord = 0) { + auto [batch, num_heads_q, num_heads_kv, head_size_qk, head_size_vo] = select<0, 1, 2, 6, 7>(problem_shape); + auto [seq_len_qo, seq_len_kv_cache] = sequence_length_shape; auto q_group_size = num_heads_q / num_heads_kv; auto kv_head_coord = q_head_coord / q_group_size; - int offset_q = 0, offset_k = 0, offset_v = 0, offset_k_cache = 0, - offset_v_cache = 0; + int offset_q = 0, offset_k = 0, offset_v = 0, offset_k_cache = 0, offset_v_cache = 0; int total_seq_len_kv_cache = 0; if constexpr (is_var_len) { auto qo_cumulative_length = get<3>(problem_shape).cumulative_length; - auto kv_cumulative_length = get<4>(problem_shape).cumulative_length; - auto kv_cached_cumulative_length = - get<5>(problem_shape).cumulative_length; - - offset_q = num_heads_q * head_size_qk * qo_cumulative_length[l_coord] + - 
q_head_coord * head_size_qk; - - offset_k = num_heads_kv * head_size_qk * kv_cumulative_length[l_coord] + - kv_head_coord * head_size_qk; - offset_v = num_heads_kv * head_size_vo * kv_cumulative_length[l_coord] + - kv_head_coord * head_size_vo; - offset_k_cache = seq_len_kv_cache == 0 - ? 0 - : PagedKV? // For page_kv, there is no batch dimension. - kv_head_coord * head_size_qk - : num_heads_kv * head_size_qk * kv_cached_cumulative_length[l_coord] + kv_head_coord * head_size_qk; - offset_v_cache = seq_len_kv_cache == 0 - ? 0 - : PagedKV? // For page_kv, there is no batch dimension. - kv_head_coord * head_size_vo - : num_heads_kv * head_size_vo * kv_cached_cumulative_length[l_coord] + kv_head_coord * head_size_vo; + auto kv_cached_cumulative_length = get<5>(problem_shape).cumulative_length; + + offset_q = num_heads_q * head_size_qk * qo_cumulative_length[l_coord] + q_head_coord * head_size_qk; + + offset_k_cache = kv_head_coord * head_size_qk; + offset_v_cache = kv_head_coord * head_size_vo; total_seq_len_kv_cache = get<5>(problem_shape).total_length; } else { - offset_q = num_heads_q * head_size_qk * seq_len_qo * l_coord + - q_head_coord * head_size_qk; - - offset_k = num_heads_kv * head_size_qk * seq_len_kv * l_coord + - kv_head_coord * head_size_qk; - offset_v = num_heads_kv * head_size_vo * seq_len_kv * l_coord + - kv_head_coord * head_size_vo; - offset_k_cache = - seq_len_kv_cache == 0 - ? 0 : - PagedKV? - kv_head_coord * head_size_qk - : num_heads_kv * head_size_qk * seq_len_kv_cache * l_coord + kv_head_coord * head_size_qk; - offset_v_cache = - seq_len_kv_cache == 0 - ? 0 : - PagedKV? 
- kv_head_coord * head_size_vo - : num_heads_kv * head_size_vo * seq_len_kv_cache * l_coord + kv_head_coord * head_size_vo; - total_seq_len_kv_cache = batch * seq_len_kv_cache; } - auto q_traits = - static_cast(params.gmem_tiled_copy_q); - const ElementQ *q_ptr = (const ElementQ *)q_traits.base_ptr; - auto k_traits = - static_cast(params.gmem_tiled_copy_k); - const ElementK *k_ptr = (const ElementK *)k_traits.base_ptr; - auto v_traits = - static_cast(params.gmem_tiled_copy_v); - const ElementV *v_ptr = (const ElementV *)v_traits.base_ptr; - auto k_traits_cache = - static_cast(params.gmem_tiled_copy_k_cache); - const ElementK *k_cache_ptr = (const ElementK *)k_traits_cache.base_ptr; - auto v_traits_cache = - static_cast(params.gmem_tiled_copy_v_cache); - const ElementV *v_cache_ptr = (const ElementV *)v_traits_cache.base_ptr; + auto q_traits = static_cast(params.gmem_tiled_copy_q); + const ElementQ* q_ptr = (const ElementQ*)q_traits.base_ptr; + auto k_traits_cache = static_cast(params.gmem_tiled_copy_k_cache); + const ElementK* k_cache_ptr = (const ElementK*)k_traits_cache.base_ptr; + auto v_traits_cache = static_cast(params.gmem_tiled_copy_v_cache); + const ElementV* v_cache_ptr = (const ElementV*)v_traits_cache.base_ptr; // NHD format{batch, seq_len, head, dim_head} // stride {seq_len*head*dim_head, head*dim_head, dim_head, 1} - auto shape_q = - make_shape(static_cast(seq_len_qo), head_size_qk * num_heads_q, 1); + auto shape_q = make_shape(static_cast(seq_len_qo), head_size_qk * num_heads_q, 1); StrideQ stride_q = cutlass::make_cute_packed_stride(StrideQ{}, shape_q); - auto shape_k = make_shape(static_cast(seq_len_kv), - num_heads_kv * head_size_qk, 1); - StrideK stride_k = cutlass::make_cute_packed_stride(StrideK{}, shape_k); - - auto shape_v = make_shape(head_size_vo * num_heads_kv, - static_cast(seq_len_kv), 1); - StrideV stride_v = cutlass::make_cute_packed_stride(StrideV{}, shape_v); - - auto shape_k_cache = make_shape(static_cast(PagedKV? 
total_seq_len_kv_cache : seq_len_kv_cache), - head_size_qk * num_heads_kv, 1); - StrideK stride_k_cache = - cutlass::make_cute_packed_stride(StrideK{}, shape_k_cache); - auto shape_v_cache = make_shape(head_size_vo * num_heads_kv, - static_cast(PagedKV? total_seq_len_kv_cache : seq_len_kv_cache), 1); - StrideV stride_v_cache = - cutlass::make_cute_packed_stride(StrideV{}, shape_v_cache); - auto tensorQ = make_tensor(make_gmem_ptr(q_ptr + offset_q), - make_layout(shape_q, stride_q)); - auto tensorK = make_tensor(make_gmem_ptr(k_ptr + offset_k), - make_layout(shape_k, stride_k)); - auto tensorV = make_tensor(make_gmem_ptr(v_ptr + offset_v), - make_layout(shape_v, stride_v)); + + auto shape_k_cache = make_shape( + static_cast(PagedKV ? total_seq_len_kv_cache : seq_len_kv_cache), head_size_qk * num_heads_kv, 1); + StrideK stride_k_cache = cutlass::make_cute_packed_stride(StrideK{}, shape_k_cache); + auto shape_v_cache = make_shape( + head_size_vo * num_heads_kv, static_cast(PagedKV ? total_seq_len_kv_cache : seq_len_kv_cache), 1); + StrideV stride_v_cache = cutlass::make_cute_packed_stride(StrideV{}, shape_v_cache); + auto tensorQ = make_tensor(make_gmem_ptr(q_ptr + offset_q), make_layout(shape_q, stride_q)); auto tensorK_cache = - make_tensor(make_gmem_ptr(k_cache_ptr + offset_k_cache), - make_layout(shape_k_cache, stride_k_cache)); + make_tensor(make_gmem_ptr(k_cache_ptr + offset_k_cache), make_layout(shape_k_cache, stride_k_cache)); auto tensorV_cache = - make_tensor(make_gmem_ptr(v_cache_ptr + offset_v_cache), - make_layout(shape_v_cache, stride_v_cache)); + make_tensor(make_gmem_ptr(v_cache_ptr + offset_v_cache), make_layout(shape_v_cache, stride_v_cache)); XE_Copy_Q copyQ{XE_Copy_Q{}.with(tensorQ)}; - XE_Copy_K copyK{XE_Copy_K{}.with(tensorK)}; - XE_Copy_V copyV{XE_Copy_V{}.with(tensorV)}; XE_Copy_K copyK_cache{XE_Copy_K{}.with(tensorK_cache)}; XE_Copy_V copyV_cache{XE_Copy_V{}.with(tensorV_cache)}; - - return Params{copyQ, - copyK, - copyV, - params.ptr_q_scale, 
- params.ptr_k_scale, - params.ptr_v_scale, - copyK_cache, - copyV_cache, - params.ptr_page_table, - params.page_size, - params.num_pages_per_seq, - params.window_left, - params.window_right}; + return Params{ + copyQ, + copyK_cache, + copyV_cache, + params.ptr_q_scale, + params.ptr_k_scale, + params.ptr_v_scale, + params.ptr_page_table, + params.page_size, + params.max_num_pages_per_seq, + params.window_left, + params.window_right}; } }; -} // namespace cutlass::flash_attention::collective +} // namespace cutlass::flash_attention::collective ///////////////////////////////////////////////////////////////////////////////////////////////// From 4efefc1d7d6d0f48b8c3037f4da12d80c014e771 Mon Sep 17 00:00:00 2001 From: Aditya Chatterjee Date: Tue, 28 Oct 2025 03:04:55 +0000 Subject: [PATCH 03/11] initial fix Signed-off-by: Aditya Chatterjee --- src/sycl/chunked_prefill.cpp | 328 +++++++++++------- .../chunk_prefill/xe_chunk_prefill.hpp | 2 +- .../xe_flash_attn_chunk_prefill_mma.hpp | 25 +- tests/test_flash_attention.py | 43 ++- 4 files changed, 255 insertions(+), 143 deletions(-) diff --git a/src/sycl/chunked_prefill.cpp b/src/sycl/chunked_prefill.cpp index ac5bee0..ccc1ae8 100644 --- a/src/sycl/chunked_prefill.cpp +++ b/src/sycl/chunked_prefill.cpp @@ -337,7 +337,7 @@ struct KernelRunner { params.max_num_pages_per_seq, params.window_size_left, params.window_size_right}, - {(ElementQ)params.scale_softmax}, + {(ElementAccumulator)params.scale_softmax}, {static_cast(params.o_ptr), stride_O, static_cast(params.sink_softmax)}, @@ -761,122 +761,216 @@ std::vector mha_fwd( auto outaccum_type = at::ScalarType::Float; constexpr int PipelineStages = 2; - switch (params.d) { - case 64: - AT_DISPATCH_BOOL_NO_RETURN(use_sink, Sink, { - if (params.is_causal) { - FMHAConfig< - cute::Shape<_128, _64, _64>, - cute::Shape<_128, _32, _64>, - cute::Shape<_128, _64, _64>, - cute::Layout, cute::Stride<_1, _1, _1>>, - PipelineStages, - true, - false, - Sink>::run(params); - } else { - 
AT_DISPATCH_BOOL_NO_RETURN( - params.is_local, - LocalMask, - FMHAConfig< - cute::Shape<_128, _64, _64>, - cute::Shape<_128, _32, _64>, - cute::Shape<_128, _64, _64>, - cute::Layout, cute::Stride<_1, _1, _1>>, - PipelineStages, - false, - LocalMask, - Sink>::run(params)) - } - }) - break; - case 96: - AT_DISPATCH_BOOL_NO_RETURN(use_sink, Sink, { - if (params.is_causal) { - FMHAConfig< - cute::Shape<_128, _64, _32>, - cute::Shape<_128, _32, _64>, - cute::Shape<_128, _96, _64>, - cute::Layout, cute::Stride<_1, _1, _1>>, - PipelineStages, - true, - false, - Sink>::run(params); - - } else { - AT_DISPATCH_BOOL_NO_RETURN( - params.is_local, - LocalMask, - FMHAConfig< - cute::Shape<_128, _64, _32>, - cute::Shape<_128, _32, _64>, - cute::Shape<_128, _96, _64>, - cute::Layout, cute::Stride<_1, _1, _1>>, - PipelineStages, - false, - LocalMask, - Sink>::run(params)) - } - }) - break; - case 128: - AT_DISPATCH_BOOL_NO_RETURN(use_sink, Sink, { - if (params.is_causal) { - FMHAConfig< - cute::Shape<_128, _64, _64>, - cute::Shape<_128, _32, _64>, - cute::Shape<_128, _128, _64>, - cute::Layout, cute::Stride<_1, _1, _1>>, - PipelineStages, - true, - false, - Sink>::run(params); - } else { - AT_DISPATCH_BOOL_NO_RETURN( - params.is_local, - LocalMask, - FMHAConfig< - cute::Shape<_128, _64, _64>, - cute::Shape<_128, _32, _64>, - cute::Shape<_128, _128, _64>, - cute::Layout, cute::Stride<_1, _1, _1>>, - PipelineStages, - false, - LocalMask, - Sink>::run(params)) - } - }) - break; - case 192: - AT_DISPATCH_BOOL_NO_RETURN(use_sink, Sink, { - if (params.is_causal) { - FMHAConfig< - cute::Shape<_256, _64, _64>, - cute::Shape<_256, _32, _64>, - cute::Shape<_256, _192, _64>, - cute::Layout, cute::Stride<_1, _1, _1>>, - PipelineStages, - true, - false, - Sink>::run(params); - } else { - AT_DISPATCH_BOOL_NO_RETURN( - params.is_local, - LocalMask, - FMHAConfig< - cute::Shape<_256, _64, _64>, - cute::Shape<_256, _32, _64>, - cute::Shape<_256, _192, _64>, - cute::Layout, cute::Stride<_1, _1, _1>>, 
- PipelineStages, - false, - LocalMask, - Sink>::run(params)) - } - }) - break; - default: - TORCH_CHECK(false, "Unsupported head size for causal attention"); + if (params.is_fp8) { + switch (params.d) { + case 64: + AT_DISPATCH_BOOL_NO_RETURN(use_sink, Sink, { + if (params.is_causal) { + FMHAConfig< + cute::Shape<_128, _64, _64>, + cute::Shape<_128, _32, _64>, + cute::Shape<_128, _64, _64>, + cute::Layout, cute::Stride<_1, _1, _1>>, + PipelineStages, + true, + false, + Sink, + float_e4m3_t, + float_e4m3_t, + XE_8x16x16_F32BF16BF16F32_TT, + XE_2D_U8x8x32_LD_N, + XE_2D_U8x16x16_LD_T, + XE_2D_U8x16x32_LD_T, + float, + float>::run(params); + } else { + AT_DISPATCH_BOOL_NO_RETURN( + params.is_local, + LocalMask, + FMHAConfig< + cute::Shape<_128, _64, _64>, + cute::Shape<_128, _32, _64>, + cute::Shape<_128, _64, _64>, + cute::Layout, cute::Stride<_1, _1, _1>>, + PipelineStages, + false, + LocalMask, + Sink, + float_e4m3_t, + float_e4m3_t, + XE_8x16x16_F32BF16BF16F32_TT, + XE_2D_U8x8x32_LD_N, + XE_2D_U8x16x16_LD_T, + XE_2D_U8x16x32_LD_T, + float, + float>::run(params)) + } + }) + break; + case 128: + AT_DISPATCH_BOOL_NO_RETURN(use_sink, Sink, { + if (params.is_causal) { + FMHAConfig< + cute::Shape<_128, _64, _64>, + cute::Shape<_128, _32, _64>, + cute::Shape<_128, _128, _64>, + cute::Layout, cute::Stride<_1, _1, _1>>, + PipelineStages, + true, + false, + Sink, + float_e4m3_t, + float_e4m3_t, + XE_8x16x16_F32BF16BF16F32_TT, + XE_2D_U8x8x32_LD_N, + XE_2D_U8x16x16_LD_T, + XE_2D_U8x16x32_LD_T, + float, + float>::run(params); + } else { + AT_DISPATCH_BOOL_NO_RETURN( + params.is_local, + LocalMask, + FMHAConfig< + cute::Shape<_128, _64, _64>, + cute::Shape<_128, _32, _64>, + cute::Shape<_128, _128, _64>, + cute::Layout, cute::Stride<_1, _1, _1>>, + PipelineStages, + false, + LocalMask, + Sink, + float_e4m3_t, + float_e4m3_t, + XE_8x16x16_F32BF16BF16F32_TT, + XE_2D_U8x8x32_LD_N, + XE_2D_U8x16x16_LD_T, + XE_2D_U8x16x32_LD_T, + float, + float>::run(params)) + } + }) + break; + 
default: TORCH_CHECK(false, "Unsupported head size for FP8"); + } + } else { // BF16 + switch (params.d) { + case 64: + AT_DISPATCH_BOOL_NO_RETURN(use_sink, Sink, { + if (params.is_causal) { + FMHAConfig< + cute::Shape<_128, _64, _64>, + cute::Shape<_128, _32, _64>, + cute::Shape<_128, _64, _64>, + cute::Layout, cute::Stride<_1, _1, _1>>, + PipelineStages, + true, + false, + Sink>::run(params); + } else { + AT_DISPATCH_BOOL_NO_RETURN( + params.is_local, + LocalMask, + FMHAConfig< + cute::Shape<_128, _64, _64>, + cute::Shape<_128, _32, _64>, + cute::Shape<_128, _64, _64>, + cute::Layout, cute::Stride<_1, _1, _1>>, + PipelineStages, + false, + LocalMask, + Sink>::run(params)) + } + }) + break; + case 96: + AT_DISPATCH_BOOL_NO_RETURN(use_sink, Sink, { + if (params.is_causal) { + FMHAConfig< + cute::Shape<_128, _64, _32>, + cute::Shape<_128, _32, _64>, + cute::Shape<_128, _96, _64>, + cute::Layout, cute::Stride<_1, _1, _1>>, + PipelineStages, + true, + false, + Sink>::run(params); + + } else { + AT_DISPATCH_BOOL_NO_RETURN( + params.is_local, + LocalMask, + FMHAConfig< + cute::Shape<_128, _64, _32>, + cute::Shape<_128, _32, _64>, + cute::Shape<_128, _96, _64>, + cute::Layout, cute::Stride<_1, _1, _1>>, + PipelineStages, + false, + LocalMask, + Sink>::run(params)) + } + }) + break; + case 128: + AT_DISPATCH_BOOL_NO_RETURN(use_sink, Sink, { + if (params.is_causal) { + FMHAConfig< + cute::Shape<_128, _64, _64>, + cute::Shape<_128, _32, _64>, + cute::Shape<_128, _128, _64>, + cute::Layout, cute::Stride<_1, _1, _1>>, + PipelineStages, + true, + false, + Sink>::run(params); + } else { + AT_DISPATCH_BOOL_NO_RETURN( + params.is_local, + LocalMask, + FMHAConfig< + cute::Shape<_128, _64, _64>, + cute::Shape<_128, _32, _64>, + cute::Shape<_128, _128, _64>, + cute::Layout, cute::Stride<_1, _1, _1>>, + PipelineStages, + false, + LocalMask, + Sink>::run(params)) + } + }) + break; + case 192: + AT_DISPATCH_BOOL_NO_RETURN(use_sink, Sink, { + if (params.is_causal) { + FMHAConfig< + 
cute::Shape<_256, _64, _64>, + cute::Shape<_256, _32, _64>, + cute::Shape<_256, _192, _64>, + cute::Layout, cute::Stride<_1, _1, _1>>, + PipelineStages, + true, + false, + Sink>::run(params); + } else { + AT_DISPATCH_BOOL_NO_RETURN( + params.is_local, + LocalMask, + FMHAConfig< + cute::Shape<_256, _64, _64>, + cute::Shape<_256, _32, _64>, + cute::Shape<_256, _192, _64>, + cute::Layout, cute::Stride<_1, _1, _1>>, + PipelineStages, + false, + LocalMask, + Sink>::run(params)) + } + }) + break; + default: + TORCH_CHECK(false, "Unsupported head size for causal attention"); + } } return {out, softmax_lse, out_accum, softmax_lse_accum}; } diff --git a/src/sycl/kernels/chunk_prefill/xe_chunk_prefill.hpp b/src/sycl/kernels/chunk_prefill/xe_chunk_prefill.hpp index b50e2be..d3b8f32 100644 --- a/src/sycl/kernels/chunk_prefill/xe_chunk_prefill.hpp +++ b/src/sycl/kernels/chunk_prefill/xe_chunk_prefill.hpp @@ -338,7 +338,7 @@ class FMHAPrefillChunk { // Q, K, V tensors have seperate scaling factors const float q_scale_val = params.mainloop.ptr_q_scale == nullptr ? 1.f - : params.mainloop.ptr_q_scale[batch_coord * num_heads_q + q_head_coord]; + : params.mainloop.ptr_q_scale[batch_coord * num_heads_kv + q_head_coord]; const float k_scale_val = params.mainloop.ptr_k_scale == nullptr ? 
1.f : params.mainloop.ptr_k_scale[batch_coord * num_heads_kv + kv_head_coord]; diff --git a/src/sycl/kernels/chunk_prefill/xe_flash_attn_chunk_prefill_mma.hpp b/src/sycl/kernels/chunk_prefill/xe_flash_attn_chunk_prefill_mma.hpp index 97e5e57..3620d55 100644 --- a/src/sycl/kernels/chunk_prefill/xe_flash_attn_chunk_prefill_mma.hpp +++ b/src/sycl/kernels/chunk_prefill/xe_flash_attn_chunk_prefill_mma.hpp @@ -298,9 +298,8 @@ struct FlashChunkPrefillMma< Tensor tCgK = thread_mma_k.partition_B(gK); // Create fragments - // TODO(Codeplay): fix this, this is probably not general - using TCrQ_Type = cute::conditional_t, uint8_t, ElementQ>; - using TCrK_Type = cute::conditional_t, uint8_t, ElementK>; + using TCrQ_Type = cute::conditional_t, uint8_t, bfloat16_t>; + using TCrK_Type = cute::conditional_t, uint8_t, bfloat16_t>; Tensor tCrQ = make_tensor(make_fragment_layout(params.gmem_tiled_copy_q, take<0,3>(tCgQ.shape()))); Tensor tCrK = make_tensor(make_fragment_layout(gmem_tiled_copy_k, take<0,3>(tCgK.shape()))); @@ -321,24 +320,24 @@ struct FlashChunkPrefillMma< copy(gmem_tiled_copy_k, tKgK(_, _, _, k_tile), tKrK); // FP8 path: Convert FP8 fragments to BF16 if constexpr (is_fp8_v || is_fp8_v) { - auto tCrQ_fp16 = make_fragment_like(tCrQ); - auto tCrK_fp16 = make_fragment_like(tCrK); + auto tCrQ_bf16 = make_fragment_like(tCrQ); + auto tCrK_bf16 = make_fragment_like(tCrK); if constexpr (is_fp8_v) { - convert_and_descale(tCrQ, tCrQ_fp16, q_scale); + convert_and_descale(tCrQ, tCrQ_bf16, q_scale); } else { // If Q is already FP16, copy it. 
- copy(tCrQ, tCrQ_fp16); + copy(tCrQ, tCrQ_bf16); } if constexpr (is_fp8_v) { - convert_and_descale(tCrK, tCrK_fp16, k_scale); + convert_and_descale(tCrK, tCrK_bf16, k_scale); } else { - copy(tCrK, tCrK_fp16); + copy(tCrK, tCrK_bf16); } // GEMM is computed on the BF16 tensors - cute::gemm(tiled_mma, accum, tCrQ_fp16, tCrK_fp16, frag_src); + cute::gemm(tiled_mma, accum, tCrQ_bf16, tCrK_bf16, frag_src); } else { // BF16 path cute::gemm(tiled_mma, accum, tCrQ, tCrK, frag_src); @@ -399,7 +398,7 @@ struct FlashChunkPrefillMma< auto first_thread_in_sg_idx = sg.get_group_id()[0] * DispatchPolicy::SubgroupSize; auto thread_mma = tiled_mma.get_slice(first_thread_in_sg_idx); Tensor tCgV = thread_mma.partition_B(gV_); - using TCrV_Type = cute::conditional_t, uint8_t, ElementV>; + using TCrV_Type = cute::conditional_t, uint8_t, bfloat16_t>; Tensor tCrV = make_tensor(make_fragment_layout(gmem_tiled_copy_v, take<0,3>(tCgV.shape()))); // Partition the copying of A and B tiles across the threads @@ -428,7 +427,7 @@ struct FlashChunkPrefillMma< #endif // 7) Convert S to P (FP32 -> BF16) - Tensor tPr = convert_type(tSr); + Tensor tPr = convert_type(tSr); // // Mainloop // @@ -516,5 +515,3 @@ struct FlashChunkPrefillMma< }; } // namespace cutlass::flash_attention::collective - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tests/test_flash_attention.py b/tests/test_flash_attention.py index cfed227..22329b2 100644 --- a/tests/test_flash_attention.py +++ b/tests/test_flash_attention.py @@ -1,4 +1,3 @@ -# Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/hopper/test_flash_attn.py import itertools import math import os @@ -63,7 +62,7 @@ def is_fa3_supported(device=None) -> bool: DISABLE_SOFTCAP = True DISABLE_PACKGQA = True DISABLE_FP16 = True -DISABLE_FP8 = True +DISABLE_FP8 = False # Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/hopper/padding.py @@ -595,13 +594,13 @@ def 
test_flash_attn_kvcache( has_qv = d == 64 and dv >= 256 softmax_scale = 1.0 / math.sqrt(d if has_qv is None else d + dv) q = ( - torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) + torch.ones(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) .to(dtype) .to(dtype_ref) ) if has_qv: qv = ( - torch.randn( + torch.ones( batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref ) .to(dtype) @@ -639,14 +638,14 @@ def test_flash_attn_kvcache( key_new_padding_mask = None if new_kv: k = ( - torch.randn( + torch.ones( batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype_ref ) .to(dtype) .to(dtype_ref) ) v = ( - torch.randn( + torch.ones( batch_size, seqlen_new, nheads_k, dv, device=device, dtype=dtype_ref ) .to(dtype) @@ -666,7 +665,7 @@ def test_flash_attn_kvcache( k, v, k_unpad, v_unpad = None, None, None, None if page_size is None: k_cache = ( - torch.randn( + torch.ones( batch_size_cache, seqlen_k, nheads_k, @@ -678,7 +677,7 @@ def test_flash_attn_kvcache( .to(dtype_ref) ) v_cache = ( - torch.randn( + torch.ones( batch_size_cache, seqlen_k, nheads_k, @@ -708,6 +707,7 @@ def test_flash_attn_kvcache( device, dtype, dtype_ref, + use_ones=True, ) cache_seqlens = torch.randint( seqlen_q, @@ -826,6 +826,13 @@ def test_flash_attn_kvcache( v_cache_rep = repeat( v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k ) + if dtype == torch.float8_e4m3fn: + q_descale, k_descale, v_descale = [ + torch.ones(batch_size, nheads_k, device=device, dtype=torch.float32) + for _ in range(3) + ] + else: + q_descale, k_descale, v_descale = None, None, None out_ref, _ = attention_ref( q_ro, k_cache_rep, @@ -836,6 +843,9 @@ def test_flash_attn_kvcache( key_padding_mask, causal=causal, qv=qv, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, window_size=window_size, key_leftpad=cache_leftpad, ) @@ -849,6 +859,9 @@ def test_flash_attn_kvcache( key_padding_mask, causal=causal, qv=qv, + q_descale=q_descale, + 
k_descale=k_descale, + v_descale=v_descale, window_size=window_size, upcast=False, reorder_ops=True, @@ -901,6 +914,9 @@ def test_flash_attn_kvcache( cu_seqlens_q=cu_seqlens_q, cu_seqlens_k_new=cu_seqlens_k_new, max_seqlen_q=max_seqlen_q, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, rotary_seqlens=rotary_seqlens, causal=causal, window_size=window_size, @@ -924,9 +940,13 @@ def test_flash_attn_kvcache( # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1) # probs = torch.softmax(qk, dim=-1) torch.xpu.synchronize() + out = out.to(dtype_ref) out = out.flatten() out_ref = out_ref.flatten() out_pt = out_pt.flatten() + print(f"out = {out}") + print("-----------------------------") + print(f"out_pt = {out_pt}") print(f"Output max diff: {(out - out_ref).abs().max().item()}") print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") @@ -1004,16 +1024,17 @@ def test_flash_attn_kvcache( def _generate_block_kvcache( - seqlen_k, page_size, batch_size, nheads_k, d, dv, device, dtype, dtype_ref + seqlen_k, page_size, batch_size, nheads_k, d, dv, device, dtype, dtype_ref, use_ones=False ): num_blocks = math.ceil(seqlen_k / page_size) * batch_size + create_fn = torch.ones if use_ones else torch.randn k_cache_paged = ( - torch.randn(num_blocks, page_size, nheads_k, d, device=device, dtype=dtype_ref) + create_fn(num_blocks, page_size, nheads_k, d, device=device, dtype=dtype_ref) .to(dtype) .to(dtype_ref) ) v_cache_paged = ( - torch.randn(num_blocks, page_size, nheads_k, dv, device=device, dtype=dtype_ref) + create_fn(num_blocks, page_size, nheads_k, dv, device=device, dtype=dtype_ref) .to(dtype) .to(dtype_ref) ) From 6a54149521fc2b965008817452cc0ef2b29aa90d Mon Sep 17 00:00:00 2001 From: Aditya Chatterjee Date: Tue, 28 Oct 2025 05:07:41 +0000 Subject: [PATCH 04/11] fixed fp8 accuracy Signed-off-by: Aditya Chatterjee --- src/sycl/chunked_prefill.cpp | 65 +++++++++++++------ 
.../chunk_prefill/xe_chunk_prefill.hpp | 4 +- .../xe_flash_attn_chunk_prefill_mma.hpp | 15 +++-- tests/test_flash_attention.py | 32 +++------ 4 files changed, 66 insertions(+), 50 deletions(-) diff --git a/src/sycl/chunked_prefill.cpp b/src/sycl/chunked_prefill.cpp index ccc1ae8..4e9e448 100644 --- a/src/sycl/chunked_prefill.cpp +++ b/src/sycl/chunked_prefill.cpp @@ -1,3 +1,33 @@ +/*************************************************************************************************** + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ #include #include #include @@ -325,7 +355,7 @@ struct KernelRunner { // stride_K, // static_cast(params.vnew_ptr), // stride_V, - static_cast(params.k_ptr), + static_cast(params.k_ptr), stride_K_cache, static_cast(params.v_ptr), stride_V_cache, @@ -337,7 +367,7 @@ struct KernelRunner { params.max_num_pages_per_seq, params.window_size_left, params.window_size_right}, - {(ElementAccumulator)params.scale_softmax}, + {(ElementQ)params.scale_softmax}, {static_cast(params.o_ptr), stride_O, static_cast(params.sink_softmax)}, @@ -597,7 +627,13 @@ std::vector mha_fwd( auto opts = q.options(); at::Tensor out; + // out = torch::empty({total_q, num_heads, head_size_v}, opts); + if (q.dtype() == at::ScalarType::Float8_e4m3fn || q.dtype() == at::ScalarType::Float8_e5m2) { + // Internal math & epilogue producing BF16 + out = torch::empty({total_q, num_heads, head_size_v}, opts.dtype(torch::kBFloat16)); +} else { out = torch::empty({total_q, num_heads, head_size_v}, opts); +} auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; int const head_size_rounded = round_up_headdim(head_size); @@ -639,15 +675,6 @@ std::vector mha_fwd( params.v_scale_ptr = static_cast(v_descale_.value().data_ptr()); } - /*if (!is_varlen_q) { - params.q_batch_stride = q.stride(0); - params.o_batch_stride = out.stride(0); - } - if (!is_varlen_k) { - 
params.k_batch_stride = k.stride(0); - params.v_batch_stride = v.stride(0); - }*/ - params.cu_seqlens_q = cu_seqlens_q.data_ptr(); params.cu_seqlens_k = cu_seqlens_k.data_ptr(); @@ -780,9 +807,9 @@ std::vector mha_fwd( XE_8x16x16_F32BF16BF16F32_TT, XE_2D_U8x8x32_LD_N, XE_2D_U8x16x16_LD_T, - XE_2D_U8x16x32_LD_T, + XE_2D_U8x32x32_LD_V, float, - float>::run(params); + float, bfloat16_t, bfloat16_t, XE_2D_U16x8x16_ST_N>::run(params); } else { AT_DISPATCH_BOOL_NO_RETURN( params.is_local, @@ -801,9 +828,9 @@ std::vector mha_fwd( XE_8x16x16_F32BF16BF16F32_TT, XE_2D_U8x8x32_LD_N, XE_2D_U8x16x16_LD_T, - XE_2D_U8x16x32_LD_T, + XE_2D_U8x32x32_LD_V, float, - float>::run(params)) + float, bfloat16_t, bfloat16_t, XE_2D_U16x8x16_ST_N>::run(params)) } }) break; @@ -824,9 +851,9 @@ std::vector mha_fwd( XE_8x16x16_F32BF16BF16F32_TT, XE_2D_U8x8x32_LD_N, XE_2D_U8x16x16_LD_T, - XE_2D_U8x16x32_LD_T, + XE_2D_U8x32x32_LD_V, float, - float>::run(params); + float, bfloat16_t, bfloat16_t, XE_2D_U16x8x16_ST_N>::run(params); } else { AT_DISPATCH_BOOL_NO_RETURN( params.is_local, @@ -845,9 +872,9 @@ std::vector mha_fwd( XE_8x16x16_F32BF16BF16F32_TT, XE_2D_U8x8x32_LD_N, XE_2D_U8x16x16_LD_T, - XE_2D_U8x16x32_LD_T, + XE_2D_U8x32x32_LD_V, float, - float>::run(params)) + float, bfloat16_t, bfloat16_t, XE_2D_U16x8x16_ST_N>::run(params)) } }) break; diff --git a/src/sycl/kernels/chunk_prefill/xe_chunk_prefill.hpp b/src/sycl/kernels/chunk_prefill/xe_chunk_prefill.hpp index d3b8f32..381d062 100644 --- a/src/sycl/kernels/chunk_prefill/xe_chunk_prefill.hpp +++ b/src/sycl/kernels/chunk_prefill/xe_chunk_prefill.hpp @@ -459,7 +459,7 @@ class FMHAPrefillChunk { // Then modify layout to LayoutQ = ((seq_leq_q, group_head_q), // head_size_qk, batch* num_heads_q / group_head_q), which can be merged // into one gemm for (int i = 0; i < q_group_size; ++i) { - collective_mma.mmaQK(tSr, gQ, gK_, tSr, ceil_div(head_size_qk, QK_BLK_K), mainloop_params, true, q_scale_val, k_scale_val); + collective_mma.mmaQK(tSr, gQ, 
gK_, tSr, ceil_div(head_size_qk, QK_BLK_K), mainloop_params, q_scale_val, k_scale_val); if constexpr (LocalMask) { // Sliding windows @@ -558,7 +558,7 @@ class FMHAPrefillChunk { softmax(split == 0, tSr, max_reg, sum_reg, out_reg); // 5) Perform GEMM O = S*V - collective_mma.template mmaPV(out_reg, tSr, gV_, out_reg, mainloop_params, true, v_scale_val); + collective_mma.template mmaPV(out_reg, tSr, gV_, out_reg, mainloop_params, v_scale_val); // ... prefetch next tile ... // Prefetch the next Q tile CUTLASS_PRAGMA_UNROLL diff --git a/src/sycl/kernels/chunk_prefill/xe_flash_attn_chunk_prefill_mma.hpp b/src/sycl/kernels/chunk_prefill/xe_flash_attn_chunk_prefill_mma.hpp index 3620d55..ba1cf5f 100644 --- a/src/sycl/kernels/chunk_prefill/xe_flash_attn_chunk_prefill_mma.hpp +++ b/src/sycl/kernels/chunk_prefill/xe_flash_attn_chunk_prefill_mma.hpp @@ -279,7 +279,7 @@ struct FlashChunkPrefillMma< TensorK gK, FragSrc const& frag_src, int const& k_tile_count, - Params const& params, bool is_KV_cache, float q_scale, float k_scale) { + Params const& params, float q_scale, float k_scale) { auto& gmem_tiled_copy_k = params.gmem_tiled_copy_k_cache; int thread_idx = static_cast(ThreadIdxX()); @@ -298,8 +298,8 @@ struct FlashChunkPrefillMma< Tensor tCgK = thread_mma_k.partition_B(gK); // Create fragments - using TCrQ_Type = cute::conditional_t, uint8_t, bfloat16_t>; - using TCrK_Type = cute::conditional_t, uint8_t, bfloat16_t>; + using TCrQ_Type = cute::conditional_t, uint8_t, ElementQ>; + using TCrK_Type = cute::conditional_t, uint8_t, ElementK>; Tensor tCrQ = make_tensor(make_fragment_layout(params.gmem_tiled_copy_q, take<0,3>(tCgQ.shape()))); Tensor tCrK = make_tensor(make_fragment_layout(gmem_tiled_copy_k, take<0,3>(tCgK.shape()))); @@ -385,9 +385,8 @@ struct FlashChunkPrefillMma< // GEMM is computed in BF16 precision (FP8 not supported in BMG) template CUTLASS_DEVICE void - mmaPV(FragQccum& accum, FragS const& tSr, TensorV gV, FragSrc const& frag_src, Params const& params, bool 
is_KV_cache, float v_scale) { + mmaPV(FragQccum& accum, FragS const& tSr, TensorV gV, FragSrc const& frag_src, Params const& params, float v_scale) { auto& gmem_tiled_copy_v = params.gmem_tiled_copy_v_cache; - int thread_idx = static_cast(ThreadIdxX()); // Instantiate the MMA object TiledMmaPV tiled_mma; @@ -398,7 +397,7 @@ struct FlashChunkPrefillMma< auto first_thread_in_sg_idx = sg.get_group_id()[0] * DispatchPolicy::SubgroupSize; auto thread_mma = tiled_mma.get_slice(first_thread_in_sg_idx); Tensor tCgV = thread_mma.partition_B(gV_); - using TCrV_Type = cute::conditional_t, uint8_t, bfloat16_t>; + using TCrV_Type = cute::conditional_t, uint8_t, ElementV>; Tensor tCrV = make_tensor(make_fragment_layout(gmem_tiled_copy_v, take<0,3>(tCgV.shape()))); // Partition the copying of A and B tiles across the threads @@ -427,7 +426,7 @@ struct FlashChunkPrefillMma< #endif // 7) Convert S to P (FP32 -> BF16) - Tensor tPr = convert_type(tSr); + Tensor tPr = convert_type(tSr); // // Mainloop // @@ -515,3 +514,5 @@ struct FlashChunkPrefillMma< }; } // namespace cutlass::flash_attention::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tests/test_flash_attention.py b/tests/test_flash_attention.py index 22329b2..11c57a8 100644 --- a/tests/test_flash_attention.py +++ b/tests/test_flash_attention.py @@ -1,3 +1,4 @@ +# Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/hopper/test_flash_attn.py import itertools import math import os @@ -594,13 +595,13 @@ def test_flash_attn_kvcache( has_qv = d == 64 and dv >= 256 softmax_scale = 1.0 / math.sqrt(d if has_qv is None else d + dv) q = ( - torch.ones(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) + torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) .to(dtype) .to(dtype_ref) ) if has_qv: qv = ( - torch.ones( + torch.randn( batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref ) .to(dtype) @@ 
-638,14 +639,14 @@ def test_flash_attn_kvcache( key_new_padding_mask = None if new_kv: k = ( - torch.ones( + torch.randn( batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype_ref ) .to(dtype) .to(dtype_ref) ) v = ( - torch.ones( + torch.randn( batch_size, seqlen_new, nheads_k, dv, device=device, dtype=dtype_ref ) .to(dtype) @@ -665,7 +666,7 @@ def test_flash_attn_kvcache( k, v, k_unpad, v_unpad = None, None, None, None if page_size is None: k_cache = ( - torch.ones( + torch.randn( batch_size_cache, seqlen_k, nheads_k, @@ -677,7 +678,7 @@ def test_flash_attn_kvcache( .to(dtype_ref) ) v_cache = ( - torch.ones( + torch.randn( batch_size_cache, seqlen_k, nheads_k, @@ -707,7 +708,6 @@ def test_flash_attn_kvcache( device, dtype, dtype_ref, - use_ones=True, ) cache_seqlens = torch.randint( seqlen_q, @@ -828,7 +828,7 @@ def test_flash_attn_kvcache( ) if dtype == torch.float8_e4m3fn: q_descale, k_descale, v_descale = [ - torch.ones(batch_size, nheads_k, device=device, dtype=torch.float32) + torch.randn(batch_size, nheads_k, device=device, dtype=torch.float32) for _ in range(3) ] else: @@ -944,13 +944,6 @@ def test_flash_attn_kvcache( out = out.flatten() out_ref = out_ref.flatten() out_pt = out_pt.flatten() - print(f"out = {out}") - print("-----------------------------") - print(f"out_pt = {out_pt}") - print(f"Output max diff: {(out - out_ref).abs().max().item()}") - print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") - print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") - print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") # breakpoint() # Check that FlashAttention's numerical error is at most twice the numerical error @@ -1024,10 +1017,10 @@ def test_flash_attn_kvcache( def _generate_block_kvcache( - seqlen_k, page_size, batch_size, nheads_k, d, dv, device, dtype, dtype_ref, use_ones=False + seqlen_k, page_size, batch_size, nheads_k, d, dv, device, dtype, dtype_ref ): num_blocks = math.ceil(seqlen_k / page_size) * 
batch_size - create_fn = torch.ones if use_ones else torch.randn + create_fn = torch.randn k_cache_paged = ( create_fn(num_blocks, page_size, nheads_k, d, device=device, dtype=dtype_ref) .to(dtype) @@ -1274,9 +1267,6 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, ) - print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") - print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") - if query_unused_mask is not None: q_zero_masking = rearrange(query_unused_mask, "b s -> b s 1 1") @@ -1308,8 +1298,6 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): out = output_pad_fn(out_unpad) if query_unused_mask is not None: out.masked_fill_(q_zero_masking, 0.0) - print(f"Output max diff: {(out - out_ref).abs().max().item()}") - print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") # Check that FlashAttention's numerical error is at most 3x the numerical error # of a Pytorch implementation. 
From 49c2212137eddfcf12e752392b3d70c4a71a6480 Mon Sep 17 00:00:00 2001 From: Aditya Chatterjee Date: Wed, 29 Oct 2025 08:48:02 +0000 Subject: [PATCH 05/11] update test code Signed-off-by: Aditya Chatterjee --- tests/test_flash_attention.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/test_flash_attention.py b/tests/test_flash_attention.py index 11c57a8..06ee1d1 100644 --- a/tests/test_flash_attention.py +++ b/tests/test_flash_attention.py @@ -944,6 +944,10 @@ def test_flash_attn_kvcache( out = out.flatten() out_ref = out_ref.flatten() out_pt = out_pt.flatten() + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") # breakpoint() # Check that FlashAttention's numerical error is at most twice the numerical error @@ -1267,6 +1271,9 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, ) + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + if query_unused_mask is not None: q_zero_masking = rearrange(query_unused_mask, "b s -> b s 1 1") @@ -1299,6 +1306,9 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): if query_unused_mask is not None: out.masked_fill_(q_zero_masking, 0.0) + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + # Check that FlashAttention's numerical error is at most 3x the numerical error # of a Pytorch implementation. 
assert (out - out_ref).abs().max().item() <= rtol * ( From 4cfd0403d9888831d4e23b2bbefed9a01aae63c2 Mon Sep 17 00:00:00 2001 From: Aditya Chatterjee Date: Wed, 5 Nov 2025 05:13:20 +0000 Subject: [PATCH 06/11] trigger CI From 7a123cda2bc18682b047fe63b73eb5b69ba5d4d6 Mon Sep 17 00:00:00 2001 From: Aditya Chatterjee Date: Wed, 5 Nov 2025 07:18:22 +0000 Subject: [PATCH 07/11] Fix format Signed-off-by: Aditya Chatterjee --- src/sycl/chunked_prefill.cpp | 37 ++-- src/sycl/kernels/chunk_prefill/fp8_descale.h | 160 ++++++++---------- .../chunk_prefill/xe_chunk_prefill.hpp | 3 +- .../xe_flash_attn_chunk_prefill_mma.hpp | 38 +++-- 4 files changed, 121 insertions(+), 117 deletions(-) diff --git a/src/sycl/chunked_prefill.cpp b/src/sycl/chunked_prefill.cpp index 4e9e448..e1aeb73 100644 --- a/src/sycl/chunked_prefill.cpp +++ b/src/sycl/chunked_prefill.cpp @@ -359,7 +359,7 @@ struct KernelRunner { stride_K_cache, static_cast(params.v_ptr), stride_V_cache, - params.q_scale_ptr, + params.q_scale_ptr, params.k_scale_ptr, params.v_scale_ptr, params.page_table, @@ -629,11 +629,11 @@ std::vector mha_fwd( at::Tensor out; // out = torch::empty({total_q, num_heads, head_size_v}, opts); if (q.dtype() == at::ScalarType::Float8_e4m3fn || q.dtype() == at::ScalarType::Float8_e5m2) { - // Internal math & epilogue producing BF16 - out = torch::empty({total_q, num_heads, head_size_v}, opts.dtype(torch::kBFloat16)); -} else { - out = torch::empty({total_q, num_heads, head_size_v}, opts); -} + // Internal math & epilogue producing BF16 + out = torch::empty({total_q, num_heads, head_size_v}, opts.dtype(torch::kBFloat16)); + } else { + out = torch::empty({total_q, num_heads, head_size_v}, opts); + } auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; int const head_size_rounded = round_up_headdim(head_size); @@ -809,7 +809,10 @@ std::vector mha_fwd( XE_2D_U8x16x16_LD_T, XE_2D_U8x32x32_LD_V, float, - float, bfloat16_t, bfloat16_t, XE_2D_U16x8x16_ST_N>::run(params); + float, + 
bfloat16_t, + bfloat16_t, + XE_2D_U16x8x16_ST_N>::run(params); } else { AT_DISPATCH_BOOL_NO_RETURN( params.is_local, @@ -830,7 +833,10 @@ std::vector mha_fwd( XE_2D_U8x16x16_LD_T, XE_2D_U8x32x32_LD_V, float, - float, bfloat16_t, bfloat16_t, XE_2D_U16x8x16_ST_N>::run(params)) + float, + bfloat16_t, + bfloat16_t, + XE_2D_U16x8x16_ST_N>::run(params)) } }) break; @@ -853,7 +859,10 @@ std::vector mha_fwd( XE_2D_U8x16x16_LD_T, XE_2D_U8x32x32_LD_V, float, - float, bfloat16_t, bfloat16_t, XE_2D_U16x8x16_ST_N>::run(params); + float, + bfloat16_t, + bfloat16_t, + XE_2D_U16x8x16_ST_N>::run(params); } else { AT_DISPATCH_BOOL_NO_RETURN( params.is_local, @@ -874,13 +883,17 @@ std::vector mha_fwd( XE_2D_U8x16x16_LD_T, XE_2D_U8x32x32_LD_V, float, - float, bfloat16_t, bfloat16_t, XE_2D_U16x8x16_ST_N>::run(params)) + float, + bfloat16_t, + bfloat16_t, + XE_2D_U16x8x16_ST_N>::run(params)) } }) break; - default: TORCH_CHECK(false, "Unsupported head size for FP8"); + default: + TORCH_CHECK(false, "Unsupported head size for FP8"); } - } else { // BF16 + } else { // BF16 switch (params.d) { case 64: AT_DISPATCH_BOOL_NO_RETURN(use_sink, Sink, { diff --git a/src/sycl/kernels/chunk_prefill/fp8_descale.h b/src/sycl/kernels/chunk_prefill/fp8_descale.h index b53ff2f..645df98 100644 --- a/src/sycl/kernels/chunk_prefill/fp8_descale.h +++ b/src/sycl/kernels/chunk_prefill/fp8_descale.h @@ -31,6 +31,9 @@ #pragma once +#include +#include + #include #include #include @@ -38,103 +41,88 @@ #include #include #include -#include -#include // Helper device function for E4M3 -> BFLOAT16 bitwise conversion -CUTLASS_DEVICE uint16_t -fp8_e4m3_to_fp16_bitwise(uint8_t const& src) { - // E4M3 (1-4-3) constants - constexpr uint32_t e4m3_exp_bias = 7; - // BFLOAT16 (1-8-7) constants - constexpr uint32_t bf16_exp_bias = 127; - - // Unpack FP8 bits - uint16_t sign = static_cast(src & 0x80); - uint16_t exponent = static_cast(src & 0x78) >> 3; - uint16_t mantissa = static_cast(src & 0x07); - - // Reconstruct BFLOAT16 
bits - uint16_t bf16_sign = sign << 8; - // Re-bias exponent and shift to BFLOAT16 position - uint16_t bf16_exponent = (exponent - e4m3_exp_bias + bf16_exp_bias) << 7; - // Shift mantissa to BFLOAT16 position - uint16_t bf16_mantissa = mantissa << 4; - - return bf16_sign | bf16_exponent | bf16_mantissa; +CUTLASS_DEVICE uint16_t fp8_e4m3_to_fp16_bitwise(uint8_t const& src) { + // E4M3 (1-4-3) constants + constexpr uint32_t e4m3_exp_bias = 7; + // BFLOAT16 (1-8-7) constants + constexpr uint32_t bf16_exp_bias = 127; + + // Unpack FP8 bits + uint16_t sign = static_cast(src & 0x80); + uint16_t exponent = static_cast(src & 0x78) >> 3; + uint16_t mantissa = static_cast(src & 0x07); + + // Reconstruct BFLOAT16 bits + uint16_t bf16_sign = sign << 8; + // Re-bias exponent and shift to BFLOAT16 position + uint16_t bf16_exponent = (exponent - e4m3_exp_bias + bf16_exp_bias) << 7; + // Shift mantissa to BFLOAT16 position + uint16_t bf16_mantissa = mantissa << 4; + + return bf16_sign | bf16_exponent | bf16_mantissa; } // Helper device function for E5M2 -> BFLOAT16 bitwise conversion -CUTLASS_DEVICE uint16_t -fp8_e5m2_to_fp16_bitwise(uint8_t const& src) { - // E5M2 (1-5-2) constants - constexpr uint32_t e5m2_exp_bias = 15; - // BFLOAT16 (1-8-7) constants - constexpr uint32_t bf16_exp_bias = 127; - - // Unpack FP8 bits - uint16_t sign = static_cast(src & 0x80); - uint16_t exponent = static_cast(src & 0x7C) >> 2; - uint16_t mantissa = static_cast(src & 0x03); - - // Reconstruct BFLOAT16 bits - uint16_t bf16_sign = sign << 8; - // Re-bias exponent and shift to BFLOAT16 position - uint16_t bf16_exponent = (exponent - e5m2_exp_bias + bf16_exp_bias) << 7; - // Shift mantissa to BFLOAT16 position - uint16_t bf16_mantissa = mantissa << 5; - - return bf16_sign | bf16_exponent | bf16_mantissa; +CUTLASS_DEVICE uint16_t fp8_e5m2_to_fp16_bitwise(uint8_t const& src) { + // E5M2 (1-5-2) constants + constexpr uint32_t e5m2_exp_bias = 15; + // BFLOAT16 (1-8-7) constants + constexpr uint32_t 
bf16_exp_bias = 127; + + // Unpack FP8 bits + uint16_t sign = static_cast(src & 0x80); + uint16_t exponent = static_cast(src & 0x7C) >> 2; + uint16_t mantissa = static_cast(src & 0x03); + + // Reconstruct BFLOAT16 bits + uint16_t bf16_sign = sign << 8; + // Re-bias exponent and shift to BFLOAT16 position + uint16_t bf16_exponent = (exponent - e5m2_exp_bias + bf16_exp_bias) << 7; + // Shift mantissa to BFLOAT16 position + uint16_t bf16_mantissa = mantissa << 5; + + return bf16_sign | bf16_exponent | bf16_mantissa; } +template +CUTLASS_DEVICE void convert_and_descale(SrcTensor const& src, DstTensor& dst, float scale) { + using SrcVec_u8 = sycl::vec; + using DstVec_u16 = sycl::vec; -template < - typename Encoding, - int VectorizeSize = 8, - typename SrcTensor, - typename DstTensor -> -CUTLASS_DEVICE void -convert_and_descale( - SrcTensor const& src, - DstTensor& dst, - float scale) { - - using SrcVec_u8 = sycl::vec; - using DstVec_u16 = sycl::vec; + auto src_ptr = reinterpret_cast(src.data()); + auto dst_ptr = reinterpret_cast(dst.data()); - auto src_ptr = reinterpret_cast(src.data()); - auto dst_ptr = reinterpret_cast(dst.data()); + // Create a SCALAR bfloat16_t for scaling + const cutlass::bfloat16_t scale_bf16 = static_cast(scale); - // Create a SCALAR bfloat16_t for scaling - const cutlass::bfloat16_t scale_bf16 = static_cast(scale); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < cute::size(src) / VectorizeSize; ++i) { + SrcVec_u8 const src_vec_u8 = src_ptr[i]; + DstVec_u16 result_vec_u16; CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < cute::size(src) / VectorizeSize; ++i) { - SrcVec_u8 const src_vec_u8 = src_ptr[i]; - DstVec_u16 result_vec_u16; - - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < VectorizeSize; ++j) { - // 1. Convert FP8 bits to BFLOAT16 bits - uint16_t val_bf16_bits; - if constexpr (std::is_same_v) { - val_bf16_bits = fp8_e4m3_to_fp16_bitwise(src_vec_u8[j]); - } else { - val_bf16_bits = fp8_e5m2_to_fp16_bitwise(src_vec_u8[j]); - } - - // 2. 
Reinterpret bits as bfloat16_t to perform math - cutlass::bfloat16_t val_bf16 = reinterpret_cast(val_bf16_bits); - - // 3. Apply scaling - val_bf16 *= scale_bf16; - - // 4. Reinterpret back to bits for storage - result_vec_u16[j] = reinterpret_cast(val_bf16); - } - - // 5. Store the final vector of bits - dst_ptr[i] = result_vec_u16; + for (int j = 0; j < VectorizeSize; ++j) { + // 1. Convert FP8 bits to BFLOAT16 bits + uint16_t val_bf16_bits; + if constexpr (std::is_same_v) { + val_bf16_bits = fp8_e4m3_to_fp16_bitwise(src_vec_u8[j]); + } else { + val_bf16_bits = fp8_e5m2_to_fp16_bitwise(src_vec_u8[j]); + } + + // 2. Reinterpret bits as bfloat16_t to perform math + cutlass::bfloat16_t val_bf16 = reinterpret_cast(val_bf16_bits); + + // 3. Apply scaling + val_bf16 *= scale_bf16; + + // 4. Reinterpret back to bits for storage + result_vec_u16[j] = reinterpret_cast(val_bf16); } + + // 5. Store the final vector of bits + dst_ptr[i] = result_vec_u16; + } } diff --git a/src/sycl/kernels/chunk_prefill/xe_chunk_prefill.hpp b/src/sycl/kernels/chunk_prefill/xe_chunk_prefill.hpp index 381d062..5f1167e 100644 --- a/src/sycl/kernels/chunk_prefill/xe_chunk_prefill.hpp +++ b/src/sycl/kernels/chunk_prefill/xe_chunk_prefill.hpp @@ -459,7 +459,8 @@ class FMHAPrefillChunk { // Then modify layout to LayoutQ = ((seq_leq_q, group_head_q), // head_size_qk, batch* num_heads_q / group_head_q), which can be merged // into one gemm for (int i = 0; i < q_group_size; ++i) { - collective_mma.mmaQK(tSr, gQ, gK_, tSr, ceil_div(head_size_qk, QK_BLK_K), mainloop_params, q_scale_val, k_scale_val); + collective_mma.mmaQK( + tSr, gQ, gK_, tSr, ceil_div(head_size_qk, QK_BLK_K), mainloop_params, q_scale_val, k_scale_val); if constexpr (LocalMask) { // Sliding windows diff --git a/src/sycl/kernels/chunk_prefill/xe_flash_attn_chunk_prefill_mma.hpp b/src/sycl/kernels/chunk_prefill/xe_flash_attn_chunk_prefill_mma.hpp index ba1cf5f..73e155e 100644 --- 
a/src/sycl/kernels/chunk_prefill/xe_flash_attn_chunk_prefill_mma.hpp +++ b/src/sycl/kernels/chunk_prefill/xe_flash_attn_chunk_prefill_mma.hpp @@ -195,7 +195,7 @@ struct FlashChunkPrefillMma< using XE_Copy_V = decltype(make_tiled_copy(atom_load_V{}, Layout{}, val_layout_load_V{})); template - static constexpr bool is_fp8_v = cute::is_same_v || cute::is_same_v; + static constexpr bool is_fp8_v = cute::is_same_v || cute::is_same_v; // Host side kernel arguments struct Arguments { @@ -205,9 +205,9 @@ struct FlashChunkPrefillMma< StrideK dK_cache; ElementV const* ptr_V_cache; StrideV dV_cache; - float const *ptr_q_scale; - float const *ptr_k_scale; - float const *ptr_v_scale; + float const* ptr_q_scale; + float const* ptr_k_scale; + float const* ptr_v_scale; // Paged KV Cache int const* ptr_page_table; int page_size; @@ -220,9 +220,9 @@ struct FlashChunkPrefillMma< XE_Copy_Q gmem_tiled_copy_q; XE_Copy_K gmem_tiled_copy_k_cache; XE_Copy_V gmem_tiled_copy_v_cache; - float const *ptr_q_scale; - float const *ptr_k_scale; - float const *ptr_v_scale; + float const* ptr_q_scale; + float const* ptr_k_scale; + float const* ptr_v_scale; int const* ptr_page_table; int page_size; int max_num_pages_per_seq; @@ -260,7 +260,7 @@ struct FlashChunkPrefillMma< copyQ, copyK_cache, copyV_cache, - args.ptr_q_scale, + args.ptr_q_scale, args.ptr_k_scale, args.ptr_v_scale, args.ptr_page_table, @@ -279,7 +279,9 @@ struct FlashChunkPrefillMma< TensorK gK, FragSrc const& frag_src, int const& k_tile_count, - Params const& params, float q_scale, float k_scale) { + Params const& params, + float q_scale, + float k_scale) { auto& gmem_tiled_copy_k = params.gmem_tiled_copy_k_cache; int thread_idx = static_cast(ThreadIdxX()); @@ -300,8 +302,8 @@ struct FlashChunkPrefillMma< // Create fragments using TCrQ_Type = cute::conditional_t, uint8_t, ElementQ>; using TCrK_Type = cute::conditional_t, uint8_t, ElementK>; - Tensor tCrQ = make_tensor(make_fragment_layout(params.gmem_tiled_copy_q, 
take<0,3>(tCgQ.shape()))); - Tensor tCrK = make_tensor(make_fragment_layout(gmem_tiled_copy_k, take<0,3>(tCgK.shape()))); + Tensor tCrQ = make_tensor(make_fragment_layout(params.gmem_tiled_copy_q, take<0, 3>(tCgQ.shape()))); + Tensor tCrK = make_tensor(make_fragment_layout(gmem_tiled_copy_k, take<0, 3>(tCgK.shape()))); // Retile registers for copies Tensor tQrQ = thr_copy_Q.retile_D(tCrQ); @@ -320,8 +322,8 @@ struct FlashChunkPrefillMma< copy(gmem_tiled_copy_k, tKgK(_, _, _, k_tile), tKrK); // FP8 path: Convert FP8 fragments to BF16 if constexpr (is_fp8_v || is_fp8_v) { - auto tCrQ_bf16 = make_fragment_like(tCrQ); - auto tCrK_bf16 = make_fragment_like(tCrK); + auto tCrQ_bf16 = make_fragment_like(tCrQ); + auto tCrK_bf16 = make_fragment_like(tCrK); if constexpr (is_fp8_v) { convert_and_descale(tCrQ, tCrQ_bf16, q_scale); @@ -398,7 +400,7 @@ struct FlashChunkPrefillMma< auto thread_mma = tiled_mma.get_slice(first_thread_in_sg_idx); Tensor tCgV = thread_mma.partition_B(gV_); using TCrV_Type = cute::conditional_t, uint8_t, ElementV>; - Tensor tCrV = make_tensor(make_fragment_layout(gmem_tiled_copy_v, take<0,3>(tCgV.shape()))); + Tensor tCrV = make_tensor(make_fragment_layout(gmem_tiled_copy_v, take<0, 3>(tCgV.shape()))); // Partition the copying of A and B tiles across the threads auto gmem_thr_copy_V = gmem_tiled_copy_v.get_slice(thread_idx); @@ -434,12 +436,12 @@ struct FlashChunkPrefillMma< for (int i = 0; i < tile_count; i++) { copy(gmem_tiled_copy_v, tVgV(_, _, _, i), tVrV); if constexpr (is_fp8_v) { - auto tCrV_fp16 = make_fragment_like(tCrV); + auto tCrV_fp16 = make_fragment_like(tCrV); convert_and_descale(tCrV, tCrV_fp16, v_scale); - cute::gemm(tiled_mma, accum(_,_,_,i), tPr, tCrV_fp16, frag_src(_,_,_,i)); + cute::gemm(tiled_mma, accum(_, _, _, i), tPr, tCrV_fp16, frag_src(_, _, _, i)); } else { - cute::gemm(tiled_mma, accum(_,_,_,i), tPr, tCrV, frag_src(_,_,_,i)); + cute::gemm(tiled_mma, accum(_, _, _, i), tPr, tCrV, frag_src(_, _, _, i)); } } } @@ -502,7 +504,7 
@@ struct FlashChunkPrefillMma< copyQ, copyK_cache, copyV_cache, - params.ptr_q_scale, + params.ptr_q_scale, params.ptr_k_scale, params.ptr_v_scale, params.ptr_page_table, From 7a6c59b5b02ec23d6c7b9d3cb94fbcbe6544c8b2 Mon Sep 17 00:00:00 2001 From: Aditya Chatterjee Date: Thu, 6 Nov 2025 02:30:44 +0000 Subject: [PATCH 08/11] fixed typo Signed-off-by: Aditya Chatterjee --- src/sycl/kernels/chunk_prefill/xe_chunk_prefill.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sycl/kernels/chunk_prefill/xe_chunk_prefill.hpp b/src/sycl/kernels/chunk_prefill/xe_chunk_prefill.hpp index 5f1167e..79b14e4 100644 --- a/src/sycl/kernels/chunk_prefill/xe_chunk_prefill.hpp +++ b/src/sycl/kernels/chunk_prefill/xe_chunk_prefill.hpp @@ -334,8 +334,8 @@ class FMHAPrefillChunk { auto kv_head_coord = q_head_coord / q_group_size; // Descale tensors are shaped (batch size * # heads) - // Each head has a seperate scale factor - // Q, K, V tensors have seperate scaling factors + // Each head has a separate scale factor + // Q, K, V tensors have separate scaling factors const float q_scale_val = params.mainloop.ptr_q_scale == nullptr ? 
1.f : params.mainloop.ptr_q_scale[batch_coord * num_heads_kv + q_head_coord]; From 73449a0e7c1391ffc4be0d2c7af7d0d67c47bed9 Mon Sep 17 00:00:00 2001 From: Aditya Chatterjee Date: Fri, 7 Nov 2025 03:38:39 +0000 Subject: [PATCH 09/11] Minor changes Signed-off-by: Aditya Chatterjee --- .../chunk_prefill => comm}/fp8_descale.h | 8 ++--- .../xe_flash_attn_chunk_prefill_mma.hpp | 2 +- tests/test_flash_attention.py | 33 +++++++++++-------- 3 files changed, 24 insertions(+), 19 deletions(-) rename src/sycl/{kernels/chunk_prefill => comm}/fp8_descale.h (95%) diff --git a/src/sycl/kernels/chunk_prefill/fp8_descale.h b/src/sycl/comm/fp8_descale.h similarity index 95% rename from src/sycl/kernels/chunk_prefill/fp8_descale.h rename to src/sycl/comm/fp8_descale.h index 645df98..37ca11d 100644 --- a/src/sycl/kernels/chunk_prefill/fp8_descale.h +++ b/src/sycl/comm/fp8_descale.h @@ -43,7 +43,7 @@ #include // Helper device function for E4M3 -> BFLOAT16 bitwise conversion -CUTLASS_DEVICE uint16_t fp8_e4m3_to_fp16_bitwise(uint8_t const& src) { +CUTLASS_DEVICE uint16_t fp8_e4m3_to_bf16_bitwise(uint8_t const& src) { // E4M3 (1-4-3) constants constexpr uint32_t e4m3_exp_bias = 7; // BFLOAT16 (1-8-7) constants @@ -65,7 +65,7 @@ CUTLASS_DEVICE uint16_t fp8_e4m3_to_fp16_bitwise(uint8_t const& src) { } // Helper device function for E5M2 -> BFLOAT16 bitwise conversion -CUTLASS_DEVICE uint16_t fp8_e5m2_to_fp16_bitwise(uint8_t const& src) { +CUTLASS_DEVICE uint16_t fp8_e5m2_to_bf16_bitwise(uint8_t const& src) { // E5M2 (1-5-2) constants constexpr uint32_t e5m2_exp_bias = 15; // BFLOAT16 (1-8-7) constants @@ -107,9 +107,9 @@ CUTLASS_DEVICE void convert_and_descale(SrcTensor const& src, DstTensor& dst, fl // 1. 
Convert FP8 bits to BFLOAT16 bits uint16_t val_bf16_bits; if constexpr (std::is_same_v) { - val_bf16_bits = fp8_e4m3_to_fp16_bitwise(src_vec_u8[j]); + val_bf16_bits = fp8_e4m3_to_bf16_bitwise(src_vec_u8[j]); } else { - val_bf16_bits = fp8_e5m2_to_fp16_bitwise(src_vec_u8[j]); + val_bf16_bits = fp8_e5m2_to_bf16_bitwise(src_vec_u8[j]); } // 2. Reinterpret bits as bfloat16_t to perform math diff --git a/src/sycl/kernels/chunk_prefill/xe_flash_attn_chunk_prefill_mma.hpp b/src/sycl/kernels/chunk_prefill/xe_flash_attn_chunk_prefill_mma.hpp index 73e155e..a89bcbf 100644 --- a/src/sycl/kernels/chunk_prefill/xe_flash_attn_chunk_prefill_mma.hpp +++ b/src/sycl/kernels/chunk_prefill/xe_flash_attn_chunk_prefill_mma.hpp @@ -30,13 +30,13 @@ **************************************************************************************************/ #pragma once +#include "../../comm/fp8_descale.h" #include "cute/algorithm/functional.hpp" #include "cute/algorithm/gemm.hpp" #include "cute/atom/mma_atom.hpp" #include "cutlass/cutlass.h" #include "cutlass/gemm/dispatch_policy.hpp" #include "fmha_fusion.hpp" -#include "fp8_descale.h" //////////////////////////////////////////////////////////// namespace {} diff --git a/tests/test_flash_attention.py b/tests/test_flash_attention.py index 06ee1d1..dfb6bb3 100644 --- a/tests/test_flash_attention.py +++ b/tests/test_flash_attention.py @@ -475,7 +475,7 @@ def generate_qkv( ) # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) @pytest.mark.parametrize( - "dtype", [torch.bfloat16] + ([torch.float8_e4m3fn] if not DISABLE_FP8 else []) + "dtype", [torch.bfloat16] + ([torch.float8_e4m3fn, torch.float8_e5m2] if not DISABLE_FP8 else []) ) # @pytest.mark.parametrize("dtype", [torch.bfloat16]) # @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) @@ -584,11 +584,11 @@ def test_flash_attn_kvcache( rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16 nheads_k = nheads if mha_type == "mha" else (1 if mha_type == 
"mqa" else 3) assert nheads % nheads_k == 0 - dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype + dtype_ref = torch.bfloat16 if (dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2) else dtype dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) if use_softmax_sink: softmax_sink = torch.randn(nheads, device=device, dtype=dtype_ref) - if dtype == torch.float8_e4m3fn or not is_hopper(): + if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2 or not is_hopper(): # for fp8 and ampere arch, we not support v head dim != qk head dim dv_vals = [d] for dv in dv_vals: @@ -826,7 +826,7 @@ def test_flash_attn_kvcache( v_cache_rep = repeat( v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k ) - if dtype == torch.float8_e4m3fn: + if (dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2): q_descale, k_descale, v_descale = [ torch.randn(batch_size, nheads_k, device=device, dtype=torch.float32) for _ in range(3) @@ -866,7 +866,7 @@ def test_flash_attn_kvcache( upcast=False, reorder_ops=True, key_leftpad=cache_leftpad, - intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, + intermediate_dtype=dtype if (dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2) else None, ) q = q.to(dtype) q_unpad = q_unpad.to(dtype) if varlen_q else None @@ -989,7 +989,7 @@ def test_flash_attn_kvcache( )[:, :seqlen_k].to(dtype_ref) k_cache_ref = k_cache_ref.to(dtype).to(dtype_ref) v_cache_ref = v_cache_ref.to(dtype).to(dtype_ref) - if dtype is not torch.float8_e4m3fn: + if dtype is not torch.float8_e4m3fn and dtype is not torch.float8_e5m2: assert torch.equal(v_cache_select, v_cache_ref) else: assert torch.allclose( @@ -1002,7 +1002,7 @@ def test_flash_attn_kvcache( else: # if not torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3): # breakpoint() - if dtype is not torch.float8_e4m3fn: + if dtype is not torch.float8_e4m3fn and dtype is not torch.float8_e5m2: assert torch.allclose( 
k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3 ) @@ -1010,11 +1010,16 @@ def test_flash_attn_kvcache( assert torch.allclose( k_cache_select, k_cache_ref, rtol=1e-1, atol=1e-1 ) + # E5M2 has large dynamic range and low precision so error range is large (standard) mult = 4 if dtype == torch.float8_e4m3fn else 2 + if dtype == torch.float8_e5m2: + mult = 90 assert (out - out_ref).abs().max().item() <= mult * ( out_pt - out_ref ).abs().max().item() + 1e-5 mult_mean = 3 if dtype == torch.float8_e4m3fn else 1.5 + if dtype == torch.float8_e5m2: + mult_mean = 40 assert (out - out_ref).abs().mean().item() <= mult_mean * ( out_pt - out_ref ).abs().mean().item() @@ -1059,7 +1064,7 @@ def _generate_block_kvcache( reason="flash_attn at sgl-kernel-xpu only supports paged cache", ) @pytest.mark.parametrize( - "dtype", [torch.bfloat16] + ([torch.float8_e4m3fn] if not DISABLE_FP8 else []) + "dtype", [torch.bfloat16] + ([torch.float8_e4m3fn, torch.float8_e5m2] if not DISABLE_FP8 else []) ) # @pytest.mark.parametrize("dtype", [torch.bfloat16]) # @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) @@ -1134,9 +1139,9 @@ def test_flash_attn_varlen_output( # batch_size = 2 # nheads = 1 nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1) - dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype + dtype_ref = torch.bfloat16 if (dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2) else dtype dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) - if dtype == torch.float8_e4m3fn: + if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2: dv_vals = [d] for dv in dv_vals: q_ref = torch.randn( @@ -1174,7 +1179,7 @@ def test_flash_attn_varlen_output( qv_ref = None # Put window_size after QKV randn so that window_size changes from test to test window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) - if dtype == torch.float8_e4m3fn: + if dtype == torch.float8_e4m3fn or dtype == 
torch.float8_e5m2: q_descale, k_descale, v_descale = [ torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 @@ -1268,7 +1273,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): softcap=softcap, upcast=False, reorder_ops=True, - intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, + intermediate_dtype=dtype if (dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2) else None, ) print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") @@ -1315,7 +1320,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): out_pt - out_ref ).abs().max().item() + fwd_atol - if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not has_qv: + if not DISABLE_BACKWARD and (dtype != torch.float8_e4m3fn and dtype != torch.float8_e5m2) and not has_qv: g_unpad = torch.randn_like(out_unpad) do_o = ((g_unpad.float() * out_unpad.float()).sum(-1)).transpose(-1, -2) dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad( @@ -1351,7 +1356,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") - if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not has_qv: + if not DISABLE_BACKWARD and (dtype != torch.float8_e4m3fn and dtype != torch.float8_e5m2) and not has_qv: dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + ( 0 if softcap == 0 else 3e-4 ) From be8ae5ed1a86e9cbd0f5b3eb043bb0700a817678 Mon Sep 17 00:00:00 2001 From: Aditya Chatterjee Date: Mon, 10 Nov 2025 06:01:10 +0000 Subject: [PATCH 10/11] fix format Signed-off-by: Aditya Chatterjee --- tests/test_flash_attention.py | 56 ++++++++++++++++++++++++++++------- 1 file changed, 45 insertions(+), 11 deletions(-) diff --git a/tests/test_flash_attention.py b/tests/test_flash_attention.py index 2df720b..f1810d7 100644 --- a/tests/test_flash_attention.py +++ 
b/tests/test_flash_attention.py @@ -475,7 +475,9 @@ def generate_qkv( ) # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) @pytest.mark.parametrize( - "dtype", [torch.bfloat16] + ([torch.float8_e4m3fn, torch.float8_e5m2] if not DISABLE_FP8 else []) + "dtype", + [torch.bfloat16] + + ([torch.float8_e4m3fn, torch.float8_e5m2] if not DISABLE_FP8 else []), ) # @pytest.mark.parametrize("dtype", [torch.bfloat16]) # @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) @@ -584,7 +586,11 @@ def test_flash_attn_kvcache( rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16 nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3) assert nheads % nheads_k == 0 - dtype_ref = torch.bfloat16 if (dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2) else dtype + dtype_ref = ( + torch.bfloat16 + if (dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2) + else dtype + ) dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) if use_sinks: sinks = torch.randn(nheads, device=device, dtype=dtype_ref) @@ -826,7 +832,7 @@ def test_flash_attn_kvcache( v_cache_rep = repeat( v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k ) - if (dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2): + if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2: q_descale, k_descale, v_descale = [ torch.randn(batch_size, nheads_k, device=device, dtype=torch.float32) for _ in range(3) @@ -866,7 +872,11 @@ def test_flash_attn_kvcache( upcast=False, reorder_ops=True, key_leftpad=cache_leftpad, - intermediate_dtype=dtype if (dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2) else None, + intermediate_dtype=( + dtype + if (dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2) + else None + ), ) q = q.to(dtype) q_unpad = q_unpad.to(dtype) if varlen_q else None @@ -989,7 +999,10 @@ def test_flash_attn_kvcache( )[:, :seqlen_k].to(dtype_ref) k_cache_ref = 
k_cache_ref.to(dtype).to(dtype_ref) v_cache_ref = v_cache_ref.to(dtype).to(dtype_ref) - if dtype is not torch.float8_e4m3fn and dtype is not torch.float8_e5m2: + if ( + dtype is not torch.float8_e4m3fn + and dtype is not torch.float8_e5m2 + ): assert torch.equal(v_cache_select, v_cache_ref) else: assert torch.allclose( @@ -1002,7 +1015,10 @@ def test_flash_attn_kvcache( else: # if not torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3): # breakpoint() - if dtype is not torch.float8_e4m3fn and dtype is not torch.float8_e5m2: + if ( + dtype is not torch.float8_e4m3fn + and dtype is not torch.float8_e5m2 + ): assert torch.allclose( k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3 ) @@ -1064,7 +1080,9 @@ def _generate_block_kvcache( reason="flash_attn at sgl-kernel-xpu only supports paged cache", ) @pytest.mark.parametrize( - "dtype", [torch.bfloat16] + ([torch.float8_e4m3fn, torch.float8_e5m2] if not DISABLE_FP8 else []) + "dtype", + [torch.bfloat16] + + ([torch.float8_e4m3fn, torch.float8_e5m2] if not DISABLE_FP8 else []), ) # @pytest.mark.parametrize("dtype", [torch.bfloat16]) # @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) @@ -1139,7 +1157,11 @@ def test_flash_attn_varlen_output( # batch_size = 2 # nheads = 1 nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1) - dtype_ref = torch.bfloat16 if (dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2) else dtype + dtype_ref = ( + torch.bfloat16 + if (dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2) + else dtype + ) dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2: dv_vals = [d] @@ -1273,7 +1295,11 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): softcap=softcap, upcast=False, reorder_ops=True, - intermediate_dtype=dtype if (dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2) else None, + intermediate_dtype=( + dtype + if 
(dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2) + else None + ), ) print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") @@ -1320,7 +1346,11 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): out_pt - out_ref ).abs().max().item() + fwd_atol - if not DISABLE_BACKWARD and (dtype != torch.float8_e4m3fn and dtype != torch.float8_e5m2) and not has_qv: + if ( + not DISABLE_BACKWARD + and (dtype != torch.float8_e4m3fn and dtype != torch.float8_e5m2) + and not has_qv + ): g_unpad = torch.randn_like(out_unpad) do_o = ((g_unpad.float() * out_unpad.float()).sum(-1)).transpose(-1, -2) dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad( @@ -1356,7 +1386,11 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") - if not DISABLE_BACKWARD and (dtype != torch.float8_e4m3fn and dtype != torch.float8_e5m2) and not has_qv: + if ( + not DISABLE_BACKWARD + and (dtype != torch.float8_e4m3fn and dtype != torch.float8_e5m2) + and not has_qv + ): dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + ( 0 if softcap == 0 else 3e-4 ) From 05ee4f4b6cd512527a81cf66c086ba5f469d3c62 Mon Sep 17 00:00:00 2001 From: Aditya Chatterjee Date: Wed, 12 Nov 2025 07:27:32 +0000 Subject: [PATCH 11/11] address review comments Signed-off-by: Aditya Chatterjee --- src/sycl/comm/fp8_descale.h | 16 ++++++++++------ .../xe_flash_attn_chunk_prefill_mma.hpp | 6 +++--- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/src/sycl/comm/fp8_descale.h b/src/sycl/comm/fp8_descale.h index 37ca11d..c2402c4 100644 --- a/src/sycl/comm/fp8_descale.h +++ b/src/sycl/comm/fp8_descale.h @@ -94,8 +94,8 @@ CUTLASS_DEVICE void convert_and_descale(SrcTensor const& src, DstTensor& dst, fl auto src_ptr = reinterpret_cast(src.data()); auto dst_ptr = reinterpret_cast(dst.data()); - // Create a SCALAR 
bfloat16_t for scaling - const cutlass::bfloat16_t scale_bf16 = static_cast(scale); + // Keep scale as FLOAT to maintain precision for small values + const float scale_f32 = scale; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < cute::size(src) / VectorizeSize; ++i) { @@ -115,11 +115,15 @@ CUTLASS_DEVICE void convert_and_descale(SrcTensor const& src, DstTensor& dst, fl // 2. Reinterpret bits as bfloat16_t to perform math cutlass::bfloat16_t val_bf16 = reinterpret_cast(val_bf16_bits); - // 3. Apply scaling - val_bf16 *= scale_bf16; + // 3. Apply scaling in FLOAT precision (not bfloat16) + float val_f32 = static_cast(val_bf16); + val_f32 *= scale_f32; - // 4. Reinterpret back to bits for storage - result_vec_u16[j] = reinterpret_cast(val_bf16); + // 4. Convert back to bfloat16 + cutlass::bfloat16_t scaled_bf16 = static_cast(val_f32); + + // 5. Store as bits + result_vec_u16[j] = reinterpret_cast(scaled_bf16); } // 5. Store the final vector of bits diff --git a/src/sycl/kernels/chunk_prefill/xe_flash_attn_chunk_prefill_mma.hpp b/src/sycl/kernels/chunk_prefill/xe_flash_attn_chunk_prefill_mma.hpp index a89bcbf..05ae1ac 100644 --- a/src/sycl/kernels/chunk_prefill/xe_flash_attn_chunk_prefill_mma.hpp +++ b/src/sycl/kernels/chunk_prefill/xe_flash_attn_chunk_prefill_mma.hpp @@ -436,10 +436,10 @@ struct FlashChunkPrefillMma< for (int i = 0; i < tile_count; i++) { copy(gmem_tiled_copy_v, tVgV(_, _, _, i), tVrV); if constexpr (is_fp8_v) { - auto tCrV_fp16 = make_fragment_like(tCrV); - convert_and_descale(tCrV, tCrV_fp16, v_scale); + auto tCrV_bf16 = make_fragment_like(tCrV); + convert_and_descale(tCrV, tCrV_bf16, v_scale); - cute::gemm(tiled_mma, accum(_, _, _, i), tPr, tCrV_fp16, frag_src(_, _, _, i)); + cute::gemm(tiled_mma, accum(_, _, _, i), tPr, tCrV_bf16, frag_src(_, _, _, i)); } else { cute::gemm(tiled_mma, accum(_, _, _, i), tPr, tCrV, frag_src(_, _, _, i)); }