diff --git a/src/sycl/chunked_prefill.cpp b/src/sycl/chunked_prefill.cpp index 05ae343..e4a9572 100644 --- a/src/sycl/chunked_prefill.cpp +++ b/src/sycl/chunked_prefill.cpp @@ -1,3 +1,33 @@ +/*************************************************************************************************** + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ #include #include #include @@ -67,6 +97,10 @@ struct Flash_fwd_params { void* sink_softmax; float softcap; + float* __restrict__ q_scale_ptr; + float* __restrict__ k_scale_ptr; + float* __restrict__ v_scale_ptr; + // array of length b+1 holding starting offset of each sequence. int* __restrict__ cu_seqlens_q; int* __restrict__ cu_seqlens_k; @@ -138,7 +172,7 @@ struct Flash_fwd_params { bool is_bf16; bool is_fp32; - bool is_e4m3; + bool is_fp8; bool is_causal; bool is_local; @@ -321,10 +355,13 @@ struct KernelRunner { // stride_K, // static_cast(params.vnew_ptr), // stride_V, - static_cast(params.k_ptr), + static_cast(params.k_ptr), stride_K_cache, static_cast(params.v_ptr), stride_V_cache, + params.q_scale_ptr, + params.k_scale_ptr, + params.v_scale_ptr, params.page_table, params.page_size, params.max_num_pages_per_seq, @@ -505,8 +542,9 @@ std::vector mha_fwd( auto q_type = q.scalar_type(); TORCH_CHECK( - q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16, - "SGL Kernel XPU only supports fp16 and bf16 type"); + q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16 || q_type == at::ScalarType::Float8_e4m3fn || + q_type == at::ScalarType::Float8_e5m2, + "SGL Kernel XPU only supports fp16, bf16 and fp8 types"); TORCH_CHECK(k.scalar_type() == q_type, "query and key must have the same dtype"); TORCH_CHECK(v.scalar_type() == q_type, "query and value must have the same dtype"); @@ -589,7 +627,13 @@ std::vector mha_fwd( auto opts = q.options(); at::Tensor out; - out = torch::empty({total_q, num_heads, head_size_v}, opts); + // out = torch::empty({total_q, num_heads, head_size_v}, opts); + if (q.dtype() == at::ScalarType::Float8_e4m3fn || q.dtype() == at::ScalarType::Float8_e5m2) { + // Internal math & epilogue producing BF16 + out = torch::empty({total_q, num_heads, head_size_v}, opts.dtype(torch::kBFloat16)); + } else { + out = torch::empty({total_q, num_heads, head_size_v}, opts); + } auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; int const head_size_rounded = round_up_headdim(head_size); @@ -607,6 +651,7 @@ std::vector mha_fwd( // align with FA3 Flash_fwd_params params; params.is_bf16 = q.dtype() == torch::kBFloat16; + params.is_fp8 = q.dtype() == at::ScalarType::Float8_e4m3fn || q.dtype() == at::ScalarType::Float8_e5m2; // Set the pointers and strides. params.q_ptr = q.data_ptr(); @@ -624,6 +669,12 @@ std::vector mha_fwd( params.o_row_stride = out.stride(-3); params.o_head_stride = out.stride(-2); + if (params.is_fp8) { + params.q_scale_ptr = static_cast(q_descale_.value().data_ptr()); + params.k_scale_ptr = static_cast(k_descale_.value().data_ptr()); + params.v_scale_ptr = static_cast(v_descale_.value().data_ptr()); + } + params.cu_seqlens_q = cu_seqlens_q.data_ptr(); params.cu_seqlens_k = cu_seqlens_k.data_ptr(); @@ -737,122 +788,229 @@ std::vector mha_fwd( auto outaccum_type = at::ScalarType::Float; constexpr int PipelineStages = 2; - switch (params.d) { - case 64: - AT_DISPATCH_BOOL_NO_RETURN(use_sink, Sink, { - if (params.is_causal) { - FMHAConfig< - cute::Shape<_128, _64, _64>, - cute::Shape<_128, _32, _64>, - cute::Shape<_128, _64, _64>, - cute::Layout, cute::Stride<_1, _1, _1>>, - PipelineStages, - true, - false, - Sink>::run(params); - } else { - AT_DISPATCH_BOOL_NO_RETURN( - params.is_local, - LocalMask, - FMHAConfig< - cute::Shape<_128, _64, _64>, - cute::Shape<_128, _32, _64>, - cute::Shape<_128, _64, _64>, - cute::Layout, cute::Stride<_1, _1, _1>>, - PipelineStages, - false, - LocalMask, - Sink>::run(params)) - } - }) - break; - case 96: - AT_DISPATCH_BOOL_NO_RETURN(use_sink, Sink, { - if (params.is_causal) { - FMHAConfig< - cute::Shape<_128, _64, _32>, - cute::Shape<_128, _32, _64>, - cute::Shape<_128, _96, _64>, - cute::Layout, cute::Stride<_1, _1, _1>>, - PipelineStages, - true, - false, - Sink>::run(params); - - } else { - AT_DISPATCH_BOOL_NO_RETURN( - params.is_local, - LocalMask, - FMHAConfig< - cute::Shape<_128, _64, _32>, - cute::Shape<_128, _32, _64>, - cute::Shape<_128, _96, _64>, - cute::Layout, cute::Stride<_1, _1, _1>>, - PipelineStages, - false, - LocalMask, - Sink>::run(params)) - } - }) - break; - case 128: - AT_DISPATCH_BOOL_NO_RETURN(use_sink, Sink, { - if (params.is_causal) { - FMHAConfig< - cute::Shape<_128, _64, _64>, - cute::Shape<_128, _32, _64>, - cute::Shape<_128, _128, _64>, - cute::Layout, cute::Stride<_1, _1, _1>>, - PipelineStages, - true, - false, - Sink>::run(params); - } else { - AT_DISPATCH_BOOL_NO_RETURN( - params.is_local, - LocalMask, - FMHAConfig< - cute::Shape<_128, _64, _64>, - cute::Shape<_128, _32, _64>, - cute::Shape<_128, _128, _64>, - cute::Layout, cute::Stride<_1, _1, _1>>, - PipelineStages, - false, - LocalMask, - Sink>::run(params)) - } - }) - break; - case 192: - AT_DISPATCH_BOOL_NO_RETURN(use_sink, Sink, { - if (params.is_causal) { - FMHAConfig< - cute::Shape<_256, _64, _64>, - cute::Shape<_256, _32, _64>, - cute::Shape<_256, _192, _64>, - cute::Layout, cute::Stride<_1, _1, _1>>, - PipelineStages, - true, - false, - Sink>::run(params); - } else { - AT_DISPATCH_BOOL_NO_RETURN( - params.is_local, - LocalMask, - FMHAConfig< - cute::Shape<_256, _64, _64>, - cute::Shape<_256, _32, _64>, - cute::Shape<_256, _192, _64>, - cute::Layout, cute::Stride<_1, _1, _1>>, - PipelineStages, - false, - LocalMask, - Sink>::run(params)) - } - }) - break; - default: - TORCH_CHECK(false, "Unsupported head size for causal attention"); + if (params.is_fp8) { + switch (params.d) { + case 64: + AT_DISPATCH_BOOL_NO_RETURN(use_sink, Sink, { + if (params.is_causal) { + FMHAConfig< + cute::Shape<_128, _64, _64>, + cute::Shape<_128, _32, _64>, + cute::Shape<_128, _64, _64>, + cute::Layout, cute::Stride<_1, _1, _1>>, + PipelineStages, + true, + false, + Sink, + float_e4m3_t, + float_e4m3_t, + XE_8x16x16_F32BF16BF16F32_TT, + XE_2D_U8x8x32_LD_N, + XE_2D_U8x16x16_LD_T, + XE_2D_U8x32x32_LD_V, + float, + float, + bfloat16_t, + bfloat16_t, + XE_2D_U16x8x16_ST_N>::run(params); + } else { + AT_DISPATCH_BOOL_NO_RETURN( + params.is_local, + LocalMask, + FMHAConfig< + cute::Shape<_128, _64, _64>, + cute::Shape<_128, _32, _64>, + cute::Shape<_128, _64, _64>, + cute::Layout, cute::Stride<_1, _1, _1>>, + PipelineStages, + false, + LocalMask, + Sink, + float_e4m3_t, + float_e4m3_t, + XE_8x16x16_F32BF16BF16F32_TT, + XE_2D_U8x8x32_LD_N, + XE_2D_U8x16x16_LD_T, + XE_2D_U8x32x32_LD_V, + float, + float, + bfloat16_t, + bfloat16_t, + XE_2D_U16x8x16_ST_N>::run(params)) + } + }) + break; + case 128: + AT_DISPATCH_BOOL_NO_RETURN(use_sink, Sink, { + if (params.is_causal) { + FMHAConfig< + cute::Shape<_128, _64, _64>, + cute::Shape<_128, _32, _64>, + cute::Shape<_128, _128, _64>, + cute::Layout, cute::Stride<_1, _1, _1>>, + PipelineStages, + true, + false, + Sink, + float_e4m3_t, + float_e4m3_t, + XE_8x16x16_F32BF16BF16F32_TT, + XE_2D_U8x8x32_LD_N, + XE_2D_U8x16x16_LD_T, + XE_2D_U8x32x32_LD_V, + float, + float, + bfloat16_t, + bfloat16_t, + XE_2D_U16x8x16_ST_N>::run(params); + } else { + AT_DISPATCH_BOOL_NO_RETURN( + params.is_local, + LocalMask, + FMHAConfig< + cute::Shape<_128, _64, _64>, + cute::Shape<_128, _32, _64>, + cute::Shape<_128, _128, _64>, + cute::Layout, cute::Stride<_1, _1, _1>>, + PipelineStages, + false, + LocalMask, + Sink, + float_e4m3_t, + float_e4m3_t, + XE_8x16x16_F32BF16BF16F32_TT, + XE_2D_U8x8x32_LD_N, + XE_2D_U8x16x16_LD_T, + XE_2D_U8x32x32_LD_V, + float, + float, + bfloat16_t, + bfloat16_t, + XE_2D_U16x8x16_ST_N>::run(params)) + } + }) + break; + default: + TORCH_CHECK(false, "Unsupported head size for FP8"); + } + } else { // BF16 + switch (params.d) { + case 64: + AT_DISPATCH_BOOL_NO_RETURN(use_sink, Sink, { + if (params.is_causal) { + FMHAConfig< + cute::Shape<_128, _64, _64>, + cute::Shape<_128, _32, _64>, + cute::Shape<_128, _64, _64>, + cute::Layout, cute::Stride<_1, _1, _1>>, + PipelineStages, + true, + false, + Sink>::run(params); + } else { + AT_DISPATCH_BOOL_NO_RETURN( + params.is_local, + LocalMask, + FMHAConfig< + cute::Shape<_128, _64, _64>, + cute::Shape<_128, _32, _64>, + cute::Shape<_128, _64, _64>, + cute::Layout, cute::Stride<_1, _1, _1>>, + PipelineStages, + false, + LocalMask, + Sink>::run(params)) + } + }) + break; + case 96: + AT_DISPATCH_BOOL_NO_RETURN(use_sink, Sink, { + if (params.is_causal) { + FMHAConfig< + cute::Shape<_128, _64, _32>, + cute::Shape<_128, _32, _64>, + cute::Shape<_128, _96, _64>, + cute::Layout, cute::Stride<_1, _1, _1>>, + PipelineStages, + true, + false, + Sink>::run(params); + + } else { + AT_DISPATCH_BOOL_NO_RETURN( + params.is_local, + LocalMask, + FMHAConfig< + cute::Shape<_128, _64, _32>, + cute::Shape<_128, _32, _64>, + cute::Shape<_128, _96, _64>, + cute::Layout, cute::Stride<_1, _1, _1>>, + PipelineStages, + false, + LocalMask, + Sink>::run(params)) + } + }) + break; + case 128: + AT_DISPATCH_BOOL_NO_RETURN(use_sink, Sink, { + if (params.is_causal) { + FMHAConfig< + cute::Shape<_128, _64, _64>, + cute::Shape<_128, _32, _64>, + cute::Shape<_128, _128, _64>, + cute::Layout, cute::Stride<_1, _1, _1>>, + PipelineStages, + true, + false, + Sink>::run(params); + } else { + AT_DISPATCH_BOOL_NO_RETURN( + params.is_local, + LocalMask, + FMHAConfig< + cute::Shape<_128, _64, _64>, + cute::Shape<_128, _32, _64>, + cute::Shape<_128, _128, _64>, + cute::Layout, cute::Stride<_1, _1, _1>>, + PipelineStages, + false, + LocalMask, + Sink>::run(params)) + } + }) + break; + case 192: + AT_DISPATCH_BOOL_NO_RETURN(use_sink, Sink, { + if (params.is_causal) { + FMHAConfig< + cute::Shape<_256, _64, _64>, + cute::Shape<_256, _32, _64>, + cute::Shape<_256, _192, _64>, + cute::Layout, cute::Stride<_1, _1, _1>>, + PipelineStages, + true, + false, + Sink>::run(params); + } else { + AT_DISPATCH_BOOL_NO_RETURN( + params.is_local, + LocalMask, + FMHAConfig< + cute::Shape<_256, _64, _64>, + cute::Shape<_256, _32, _64>, + cute::Shape<_256, _192, _64>, + cute::Layout, cute::Stride<_1, _1, _1>>, + PipelineStages, + false, + LocalMask, + Sink>::run(params)) + } + }) + break; + default: + TORCH_CHECK(false, "Unsupported head size for causal attention"); + } } return {out, softmax_lse, out_accum, softmax_lse_accum}; } diff --git a/src/sycl/comm/fp8_descale.h b/src/sycl/comm/fp8_descale.h new file mode 100644 index 0000000..c2402c4 --- /dev/null +++ b/src/sycl/comm/fp8_descale.h @@ -0,0 +1,132 @@ +/*************************************************************************************************** + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +// Helper device function for E4M3 -> BFLOAT16 bitwise conversion +CUTLASS_DEVICE uint16_t fp8_e4m3_to_bf16_bitwise(uint8_t const& src) { + // E4M3 (1-4-3) constants + constexpr uint32_t e4m3_exp_bias = 7; + // BFLOAT16 (1-8-7) constants + constexpr uint32_t bf16_exp_bias = 127; + + // Unpack FP8 bits + uint16_t sign = static_cast(src & 0x80); + uint16_t exponent = static_cast(src & 0x78) >> 3; + uint16_t mantissa = static_cast(src & 0x07); + + // Reconstruct BFLOAT16 bits + uint16_t bf16_sign = sign << 8; + // Re-bias exponent and shift to BFLOAT16 position + uint16_t bf16_exponent = (exponent - e4m3_exp_bias + bf16_exp_bias) << 7; + // Shift mantissa to BFLOAT16 position + uint16_t bf16_mantissa = mantissa << 4; + + return bf16_sign | bf16_exponent | bf16_mantissa; +} + +// Helper device function for E5M2 -> BFLOAT16 bitwise conversion +CUTLASS_DEVICE uint16_t fp8_e5m2_to_bf16_bitwise(uint8_t const& src) { + // E5M2 (1-5-2) constants + constexpr uint32_t e5m2_exp_bias = 15; + // BFLOAT16 (1-8-7) constants + constexpr uint32_t bf16_exp_bias = 127; + + // Unpack FP8 bits + uint16_t sign = static_cast(src & 0x80); + uint16_t exponent = static_cast(src & 0x7C) >> 2; + uint16_t mantissa = static_cast(src & 0x03); + + // Reconstruct BFLOAT16 bits + uint16_t bf16_sign = sign << 8; + // Re-bias exponent and shift to BFLOAT16 position + uint16_t bf16_exponent = (exponent - e5m2_exp_bias + bf16_exp_bias) << 7; + // Shift mantissa to BFLOAT16 position + uint16_t bf16_mantissa = mantissa << 5; + + return bf16_sign | bf16_exponent | bf16_mantissa; +} + +template +CUTLASS_DEVICE void convert_and_descale(SrcTensor const& src, DstTensor& dst, float scale) { + using SrcVec_u8 = sycl::vec; + using DstVec_u16 = sycl::vec; + + auto src_ptr = reinterpret_cast(src.data()); + auto dst_ptr = reinterpret_cast(dst.data()); + + // Keep scale as FLOAT to maintain precision for small values + const float scale_f32 = scale; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < cute::size(src) / VectorizeSize; ++i) { + SrcVec_u8 const src_vec_u8 = src_ptr[i]; + DstVec_u16 result_vec_u16; + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < VectorizeSize; ++j) { + // 1. Convert FP8 bits to BFLOAT16 bits + uint16_t val_bf16_bits; + if constexpr (std::is_same_v) { + val_bf16_bits = fp8_e4m3_to_bf16_bitwise(src_vec_u8[j]); + } else { + val_bf16_bits = fp8_e5m2_to_bf16_bitwise(src_vec_u8[j]); + } + + // 2. Reinterpret bits as bfloat16_t to perform math + cutlass::bfloat16_t val_bf16 = reinterpret_cast(val_bf16_bits); + + // 3. Apply scaling in FLOAT precision (not bfloat16) + float val_f32 = static_cast(val_bf16); + val_f32 *= scale_f32; + + // 4. Convert back to bfloat16 + cutlass::bfloat16_t scaled_bf16 = static_cast(val_f32); + + // 5. Store as bits + result_vec_u16[j] = reinterpret_cast(scaled_bf16); + } + + // 5. Store the final vector of bits + dst_ptr[i] = result_vec_u16; + } +} diff --git a/src/sycl/kernels/chunk_prefill/xe_chunk_prefill.hpp b/src/sycl/kernels/chunk_prefill/xe_chunk_prefill.hpp index 18d525b..79b14e4 100644 --- a/src/sycl/kernels/chunk_prefill/xe_chunk_prefill.hpp +++ b/src/sycl/kernels/chunk_prefill/xe_chunk_prefill.hpp @@ -330,6 +330,22 @@ class FMHAPrefillChunk { int tiles_per_page = params.mainloop.page_size / QK_BLK_N; + auto q_group_size = num_heads_q / num_heads_kv; + auto kv_head_coord = q_head_coord / q_group_size; + + // Descale tensors are shaped (batch size * # heads) + // Each head has a separate scale factor + // Q, K, V tensors have separate scaling factors + const float q_scale_val = params.mainloop.ptr_q_scale == nullptr + ? 1.f + : params.mainloop.ptr_q_scale[batch_coord * num_heads_kv + q_head_coord]; + const float k_scale_val = params.mainloop.ptr_k_scale == nullptr + ? 1.f + : params.mainloop.ptr_k_scale[batch_coord * num_heads_kv + kv_head_coord]; + const float v_scale_val = params.mainloop.ptr_v_scale == nullptr + ? 1.f + : params.mainloop.ptr_v_scale[batch_coord * num_heads_kv + kv_head_coord]; + Tensor mQ_mkl = cute::get_xe_tensor(make_shape(seq_len_qo, head_size_qk, 1)); //(m,k,l) Tensor mK_cache_nkl = cute::get_xe_tensor(make_shape(seq_len_kv_cache, head_size_qk, 1)); // (n_cache,k,l) @@ -443,7 +459,8 @@ class FMHAPrefillChunk { // Then modify layout to LayoutQ = ((seq_leq_q, group_head_q), // head_size_qk, batch* num_heads_q / group_head_q), which can be merged // into one gemm for (int i = 0; i < q_group_size; ++i) { - collective_mma.mmaQK(tSr, gQ, gK_, tSr, ceil_div(head_size_qk, QK_BLK_K), mainloop_params); + collective_mma.mmaQK( + tSr, gQ, gK_, tSr, ceil_div(head_size_qk, QK_BLK_K), mainloop_params, q_scale_val, k_scale_val); if constexpr (LocalMask) { // Sliding windows @@ -542,7 +559,7 @@ class FMHAPrefillChunk { softmax(split == 0, tSr, max_reg, sum_reg, out_reg); // 5) Perform GEMM O = S*V - collective_mma.template mmaPV(out_reg, tSr, gV_, out_reg, mainloop_params); + collective_mma.template mmaPV(out_reg, tSr, gV_, out_reg, mainloop_params, v_scale_val); // ... prefetch next tile ... // Prefetch the next Q tile CUTLASS_PRAGMA_UNROLL diff --git a/src/sycl/kernels/chunk_prefill/xe_flash_attn_chunk_prefill_mma.hpp b/src/sycl/kernels/chunk_prefill/xe_flash_attn_chunk_prefill_mma.hpp index 4c21c3b..05ae1ac 100644 --- a/src/sycl/kernels/chunk_prefill/xe_flash_attn_chunk_prefill_mma.hpp +++ b/src/sycl/kernels/chunk_prefill/xe_flash_attn_chunk_prefill_mma.hpp @@ -30,6 +30,7 @@ **************************************************************************************************/ #pragma once +#include "../../comm/fp8_descale.h" #include "cute/algorithm/functional.hpp" #include "cute/algorithm/gemm.hpp" #include "cute/atom/mma_atom.hpp" @@ -193,6 +194,9 @@ struct FlashChunkPrefillMma< using val_layout_load_V = decltype(make_layout(shape_div(typename traits_load_V::BlockShape{}, CopyThreadShape{}))); using XE_Copy_V = decltype(make_tiled_copy(atom_load_V{}, Layout{}, val_layout_load_V{})); + template + static constexpr bool is_fp8_v = cute::is_same_v || cute::is_same_v; + // Host side kernel arguments struct Arguments { ElementQ const* ptr_Q; @@ -201,6 +205,9 @@ struct FlashChunkPrefillMma< StrideK dK_cache; ElementV const* ptr_V_cache; StrideV dV_cache; + float const* ptr_q_scale; + float const* ptr_k_scale; + float const* ptr_v_scale; // Paged KV Cache int const* ptr_page_table; int page_size; @@ -213,6 +220,9 @@ struct FlashChunkPrefillMma< XE_Copy_Q gmem_tiled_copy_q; XE_Copy_K gmem_tiled_copy_k_cache; XE_Copy_V gmem_tiled_copy_v_cache; + float const* ptr_q_scale; + float const* ptr_k_scale; + float const* ptr_v_scale; int const* ptr_page_table; int page_size; int max_num_pages_per_seq; @@ -250,6 +260,9 @@ struct FlashChunkPrefillMma< copyQ, copyK_cache, copyV_cache, + args.ptr_q_scale, + args.ptr_k_scale, + args.ptr_v_scale, args.ptr_page_table, args.page_size, args.max_num_pages_per_seq, @@ -257,6 +270,8 @@ struct FlashChunkPrefillMma< args.window_right}; } + // FP8 Q and FP8 K tensors are converted to BF16 tensors using descale factors + // GEMM is computed in BF16 precision (FP8 not supported in BMG) template CUTLASS_DEVICE void mmaQK( FragQccum& accum, @@ -264,7 +279,9 @@ struct FlashChunkPrefillMma< TensorK gK, FragSrc const& frag_src, int const& k_tile_count, - Params const& params) { + Params const& params, + float q_scale, + float k_scale) { auto& gmem_tiled_copy_k = params.gmem_tiled_copy_k_cache; int thread_idx = static_cast(ThreadIdxX()); @@ -283,9 +300,10 @@ struct FlashChunkPrefillMma< Tensor tCgK = thread_mma_k.partition_B(gK); // Create fragments - // TODO(Codeplay): fix this, this is probably not general - Tensor tCrQ = make_tensor(make_fragment_layout(params.gmem_tiled_copy_q, take<0, 3>(tCgQ.shape()))); - Tensor tCrK = make_tensor(make_fragment_layout(gmem_tiled_copy_k, take<0, 3>(tCgK.shape()))); + using TCrQ_Type = cute::conditional_t, uint8_t, ElementQ>; + using TCrK_Type = cute::conditional_t, uint8_t, ElementK>; + Tensor tCrQ = make_tensor(make_fragment_layout(params.gmem_tiled_copy_q, take<0, 3>(tCgQ.shape()))); + Tensor tCrK = make_tensor(make_fragment_layout(gmem_tiled_copy_k, take<0, 3>(tCgK.shape()))); // Retile registers for copies Tensor tQrQ = thr_copy_Q.retile_D(tCrQ); @@ -302,7 +320,30 @@ struct FlashChunkPrefillMma< for (int k_tile = 0; k_tile < k_tile_count; ++k_tile) { copy(params.gmem_tiled_copy_q, tQgQ(_, _, _, k_tile), tQrQ); copy(gmem_tiled_copy_k, tKgK(_, _, _, k_tile), tKrK); - cute::gemm(tiled_mma, accum, tCrQ, tCrK, frag_src); + // FP8 path: Convert FP8 fragments to BF16 + if constexpr (is_fp8_v || is_fp8_v) { + auto tCrQ_bf16 = make_fragment_like(tCrQ); + auto tCrK_bf16 = make_fragment_like(tCrK); + + if constexpr (is_fp8_v) { + convert_and_descale(tCrQ, tCrQ_bf16, q_scale); + } else { + // If Q is already FP16, copy it. + copy(tCrQ, tCrQ_bf16); + } + + if constexpr (is_fp8_v) { + convert_and_descale(tCrK, tCrK_bf16, k_scale); + } else { + copy(tCrK, tCrK_bf16); + } + + // GEMM is computed on the BF16 tensors + cute::gemm(tiled_mma, accum, tCrQ_bf16, tCrK_bf16, frag_src); + } else { + // BF16 path + cute::gemm(tiled_mma, accum, tCrQ, tCrK, frag_src); + } #if 0 #define PRINT(x) \ print(#x ": "); \ @@ -341,11 +382,13 @@ struct FlashChunkPrefillMma< return make_tensor(make_rmem_ptr(&frag), tensor.layout()); } + // FP8 V tensor is converted to BF16 tensor using descale factor + // P tensor (softmax output) is in FP32 precision (converted to BF16) + // GEMM is computed in BF16 precision (FP8 not supported in BMG) template CUTLASS_DEVICE void - mmaPV(FragQccum& accum, FragS const& tSr, TensorV gV, FragSrc const& frag_src, Params const& params) { + mmaPV(FragQccum& accum, FragS const& tSr, TensorV gV, FragSrc const& frag_src, Params const& params, float v_scale) { auto& gmem_tiled_copy_v = params.gmem_tiled_copy_v_cache; - int thread_idx = static_cast(ThreadIdxX()); // Instantiate the MMA object TiledMmaPV tiled_mma; @@ -356,7 +399,8 @@ struct FlashChunkPrefillMma< auto first_thread_in_sg_idx = sg.get_group_id()[0] * DispatchPolicy::SubgroupSize; auto thread_mma = tiled_mma.get_slice(first_thread_in_sg_idx); Tensor tCgV = thread_mma.partition_B(gV_); - Tensor tCrV = make_tensor(make_fragment_layout(gmem_tiled_copy_v, take<0, 3>(tCgV.shape()))); + using TCrV_Type = cute::conditional_t, uint8_t, ElementV>; + Tensor tCrV = make_tensor(make_fragment_layout(gmem_tiled_copy_v, take<0, 3>(tCgV.shape()))); // Partition the copying of A and B tiles across the threads auto gmem_thr_copy_V = gmem_tiled_copy_v.get_slice(thread_idx); @@ -391,7 +435,14 @@ struct FlashChunkPrefillMma< CUTLASS_PRAGMA_UNROLL for (int i = 0; i < tile_count; i++) { copy(gmem_tiled_copy_v, tVgV(_, _, _, i), tVrV); - cute::gemm(tiled_mma, accum(_, _, _, i), tPr, tCrV, frag_src(_, _, _, i)); + if constexpr (is_fp8_v) { + auto tCrV_bf16 = make_fragment_like(tCrV); + convert_and_descale(tCrV, tCrV_bf16, v_scale); + + cute::gemm(tiled_mma, accum(_, _, _, i), tPr, tCrV_bf16, frag_src(_, _, _, i)); + } else { + cute::gemm(tiled_mma, accum(_, _, _, i), tPr, tCrV, frag_src(_, _, _, i)); + } } } @@ -453,6 +504,9 @@ struct FlashChunkPrefillMma< copyQ, copyK_cache, copyV_cache, + params.ptr_q_scale, + params.ptr_k_scale, + params.ptr_v_scale, params.ptr_page_table, params.page_size, params.max_num_pages_per_seq, diff --git a/tests/test_flash_attention.py b/tests/test_flash_attention.py index 3b03b53..f1810d7 100644 --- a/tests/test_flash_attention.py +++ b/tests/test_flash_attention.py @@ -63,7 +63,7 @@ def is_fa3_supported(device=None) -> bool: DISABLE_SOFTCAP = True DISABLE_PACKGQA = True DISABLE_FP16 = True -DISABLE_FP8 = True +DISABLE_FP8 = False # Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/hopper/padding.py @@ -475,7 +475,9 @@ def generate_qkv( ) # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) @pytest.mark.parametrize( - "dtype", [torch.bfloat16] + ([torch.float8_e4m3fn] if not DISABLE_FP8 else []) + "dtype", + [torch.bfloat16] + + ([torch.float8_e4m3fn, torch.float8_e5m2] if not DISABLE_FP8 else []), ) # @pytest.mark.parametrize("dtype", [torch.bfloat16]) # @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) @@ -584,11 +586,15 @@ def test_flash_attn_kvcache( rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16 nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3) assert nheads % nheads_k == 0 - dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype + dtype_ref = ( + torch.bfloat16 + if (dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2) + else dtype + ) dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) if use_sinks: sinks = torch.randn(nheads, device=device, dtype=dtype_ref) - if dtype == torch.float8_e4m3fn or not is_hopper(): + if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2 or not is_hopper(): # for fp8 and ampere arch, we not support v head dim != qk head dim dv_vals = [d] for dv in dv_vals: @@ -826,6 +832,13 @@ def test_flash_attn_kvcache( v_cache_rep = repeat( v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k ) + if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2: + q_descale, k_descale, v_descale = [ + torch.randn(batch_size, nheads_k, device=device, dtype=torch.float32) + for _ in range(3) + ] + else: + q_descale, k_descale, v_descale = None, None, None out_ref, _ = attention_ref( q_ro, k_cache_rep, @@ -836,6 +849,9 @@ def test_flash_attn_kvcache( key_padding_mask, causal=causal, qv=qv, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, window_size=window_size, key_leftpad=cache_leftpad, ) @@ -849,11 +865,18 @@ def test_flash_attn_kvcache( key_padding_mask, causal=causal, qv=qv, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, window_size=window_size, upcast=False, reorder_ops=True, key_leftpad=cache_leftpad, - intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, + intermediate_dtype=( + dtype + if (dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2) + else None + ), ) q = q.to(dtype) q_unpad = q_unpad.to(dtype) if varlen_q else None @@ -901,6 +924,9 @@ def test_flash_attn_kvcache( cu_seqlens_q=cu_seqlens_q, cu_seqlens_k_new=cu_seqlens_k_new, max_seqlen_q=max_seqlen_q, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, rotary_seqlens=rotary_seqlens, causal=causal, window_size=window_size, @@ -924,6 +950,7 @@ def test_flash_attn_kvcache( # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1) # probs = torch.softmax(qk, dim=-1) torch.xpu.synchronize() + out = out.to(dtype_ref) out = out.flatten() out_ref = out_ref.flatten() out_pt = out_pt.flatten() @@ -972,7 +999,10 @@ def test_flash_attn_kvcache( )[:, :seqlen_k].to(dtype_ref) k_cache_ref = k_cache_ref.to(dtype).to(dtype_ref) v_cache_ref = v_cache_ref.to(dtype).to(dtype_ref) - if dtype is not torch.float8_e4m3fn: + if ( + dtype is not torch.float8_e4m3fn + and dtype is not torch.float8_e5m2 + ): assert torch.equal(v_cache_select, v_cache_ref) else: assert torch.allclose( @@ -985,7 +1015,10 @@ def test_flash_attn_kvcache( else: # if not torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3): # breakpoint() - if dtype is not torch.float8_e4m3fn: + if ( + dtype is not torch.float8_e4m3fn + and dtype is not torch.float8_e5m2 + ): assert torch.allclose( k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3 ) @@ -993,11 +1026,16 @@ def test_flash_attn_kvcache( assert torch.allclose( k_cache_select, k_cache_ref, rtol=1e-1, atol=1e-1 ) + # E5M2 has large dynamic range and low precision so error range is large (standard) mult = 4 if dtype == torch.float8_e4m3fn else 2 + if dtype == torch.float8_e5m2: + mult = 90 assert (out - out_ref).abs().max().item() <= mult * ( out_pt - out_ref ).abs().max().item() + 1e-5 mult_mean = 3 if dtype == torch.float8_e4m3fn else 1.5 + if dtype == torch.float8_e5m2: + mult_mean = 40 assert (out - out_ref).abs().mean().item() <= mult_mean * ( out_pt - out_ref ).abs().mean().item() @@ -1007,13 +1045,14 @@ def _generate_block_kvcache( seqlen_k, page_size, batch_size, nheads_k, d, dv, device, dtype, dtype_ref ): num_blocks = math.ceil(seqlen_k / page_size) * batch_size + create_fn = torch.randn k_cache_paged = ( - torch.randn(num_blocks, page_size, nheads_k, d, device=device, dtype=dtype_ref) + create_fn(num_blocks, page_size, nheads_k, d, device=device, dtype=dtype_ref) .to(dtype) .to(dtype_ref) ) v_cache_paged = ( - torch.randn(num_blocks, page_size, nheads_k, dv, device=device, dtype=dtype_ref) + create_fn(num_blocks, page_size, nheads_k, dv, device=device, dtype=dtype_ref) .to(dtype) .to(dtype_ref) ) @@ -1041,7 +1080,9 @@ def _generate_block_kvcache( reason="flash_attn at sgl-kernel-xpu only supports paged cache", ) @pytest.mark.parametrize( - "dtype", [torch.bfloat16] + ([torch.float8_e4m3fn] if not DISABLE_FP8 else []) + "dtype", + [torch.bfloat16] + + ([torch.float8_e4m3fn, torch.float8_e5m2] if not DISABLE_FP8 else []), ) # @pytest.mark.parametrize("dtype", [torch.bfloat16]) # @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) @@ -1116,9 +1157,13 @@ def test_flash_attn_varlen_output( # batch_size = 2 # nheads = 1 nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1) - dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype + dtype_ref = ( + torch.bfloat16 + if (dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2) + else dtype + ) dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) - if dtype == torch.float8_e4m3fn: + if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2: dv_vals = [d] for dv in dv_vals: q_ref = torch.randn( @@ -1156,7 +1201,7 @@ def test_flash_attn_varlen_output( qv_ref = None # Put window_size after QKV randn so that window_size changes from test to test window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) - if dtype == torch.float8_e4m3fn: + if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2: q_descale, k_descale, v_descale = [ torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 @@ -1250,7 +1295,11 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): softcap=softcap, upcast=False, reorder_ops=True, - intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, + intermediate_dtype=( + dtype + if (dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2) + else None + ), ) print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") @@ -1287,6 +1336,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): out = output_pad_fn(out_unpad) if query_unused_mask is not None: out.masked_fill_(q_zero_masking, 0.0) + print(f"Output max diff: {(out - out_ref).abs().max().item()}") print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") @@ -1296,7 +1346,11 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): out_pt - out_ref ).abs().max().item() + fwd_atol - if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not has_qv: + if ( + not DISABLE_BACKWARD + and (dtype != torch.float8_e4m3fn and dtype != torch.float8_e5m2) + and not has_qv + ): g_unpad = torch.randn_like(out_unpad) do_o = ((g_unpad.float() * out_unpad.float()).sum(-1)).transpose(-1, -2) dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad( @@ -1332,7 +1386,11 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") - if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not has_qv: + if ( + not DISABLE_BACKWARD + and (dtype != torch.float8_e4m3fn and dtype != torch.float8_e5m2) + and not has_qv + ): dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + ( 0 if softcap == 0 else 3e-4 )