diff --git a/applications/flash_attention_v2/collective/xe_flash_attn_prefill_mma.hpp b/applications/flash_attention_v2/collective/xe_flash_attn_prefill_mma.hpp index 6dcfe4bfa8..eaf5f3652c 100644 --- a/applications/flash_attention_v2/collective/xe_flash_attn_prefill_mma.hpp +++ b/applications/flash_attention_v2/collective/xe_flash_attn_prefill_mma.hpp @@ -39,6 +39,7 @@ #include "cute/atom/mma_atom.hpp" #include "cute/algorithm/gemm.hpp" #include "fmha_fusion.hpp" +#include "xe_rotary.h" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -62,7 +63,7 @@ CUTLASS_DEVICE auto convert_type(Tensor const &tensor) { template + class GmemTiledCopyV_, bool CausalMask_, bool RopeMask_ = false> struct FlashPrefillMma { static_assert(cutlass::detail::dependent_false, "Could not find a mainloop specialization."); }; @@ -71,9 +72,9 @@ struct FlashPrefillMma { template + class GmemTiledCopyV_, bool CausalMask_, bool RopeMask_> struct FlashPrefillMma, ProblemShapeType_, ElementQ_, StrideQ_, ElementK_, StrideK_, ElementV_, - StrideV_, MMAOperation_, TileShapeQK_, TileShapePV_, SubgroupLayout_, GmemTiledCopyQ_, GmemTiledCopyK_, GmemTiledCopyV_, CausalMask_> { + StrideV_, MMAOperation_, TileShapeQK_, TileShapePV_, SubgroupLayout_, GmemTiledCopyQ_, GmemTiledCopyK_, GmemTiledCopyV_, CausalMask_, RopeMask_> { // // Type Aliases // @@ -97,6 +98,7 @@ struct FlashPrefillMma, ProblemShapeType_, El using TiledMmaPV = typename TiledMMAHelper, SubgroupLayout>::TiledMMA; using ElementAccumulator = typename TiledMmaQK::ValTypeC; static constexpr bool CausalMask = CausalMask_; + static constexpr bool rope_enabled = RopeMask_; static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize; using MmaAtomShape = typename MmaAtom::Shape_MNK; @@ -158,12 +160,19 @@ struct FlashPrefillMma, ProblemShapeType_, El StrideK dK; ElementV const *ptr_V; StrideV dV; + // for RoPE case + ElementQ const *ptr_cos = nullptr; + ElementQ const *ptr_sin = nullptr; }; struct Params { XE_Copy_Q gmem_tiled_copy_q; XE_Copy_K gmem_tiled_copy_k; XE_Copy_V gmem_tiled_copy_v; + XE_Copy_Q gmem_tiled_copy_q_cos; + XE_Copy_Q gmem_tiled_copy_q_sin; + XE_Copy_K gmem_tiled_copy_k_cos; + XE_Copy_K gmem_tiled_copy_k_sin; }; // @@ -181,11 +190,21 @@ struct FlashPrefillMma, ProblemShapeType_, El auto tensorQ = make_tensor(make_gmem_ptr(args.ptr_Q), make_layout(make_shape(seq_len_qo, head_size_qk, batch * num_heads_q), args.dQ)); auto tensorK = make_tensor(make_gmem_ptr(args.ptr_K), make_layout(make_shape(seq_len_kv, head_size_qk, batch * num_heads_kv), args.dK)); auto tensorV = make_tensor(make_gmem_ptr(args.ptr_V), make_layout(make_shape(head_size_vo, seq_len_kv, batch * num_heads_kv), args.dV)); + + auto tensorQCos = make_tensor(make_gmem_ptr(args.ptr_cos), make_layout(make_shape(seq_len_qo, head_size_qk, batch * num_heads_q), args.dQ)); + auto tensorQSin = make_tensor(make_gmem_ptr(args.ptr_sin), make_layout(make_shape(seq_len_qo, head_size_qk, batch * num_heads_q), args.dQ)); + auto tensorKCos = make_tensor(make_gmem_ptr(args.ptr_cos), make_layout(make_shape(seq_len_kv, head_size_qk, batch * num_heads_kv), args.dK)); + auto tensorKSin = make_tensor(make_gmem_ptr(args.ptr_sin), make_layout(make_shape(seq_len_kv, head_size_qk, batch * num_heads_kv), args.dK)); + XE_Copy_Q copyQ{XE_Copy_Q{}.with(tensorQ)}; XE_Copy_K copyK{XE_Copy_K{}.with(tensorK)}; XE_Copy_V copyV{XE_Copy_V{}.with(tensorV)}; - - return Params{copyQ, copyK, copyV}; + XE_Copy_Q copyQCos{XE_Copy_Q{}.with(tensorQCos)}; + XE_Copy_Q 
copyQSin{XE_Copy_Q{}.with(tensorQSin)}; + XE_Copy_K copyKCos{XE_Copy_K{}.with(tensorKCos)}; + XE_Copy_K copyKSin{XE_Copy_K{}.with(tensorKSin)}; + + return Params{copyQ, copyK, copyV, copyQCos, copyQSin, copyKCos, copyKSin}; } template @@ -372,11 +391,32 @@ struct FlashPrefillMma, ProblemShapeType_, El auto tensorK = make_tensor(make_gmem_ptr(k_ptr + offset_k), make_layout(shape_k, stride_k)); auto tensorV = make_tensor(make_gmem_ptr(v_ptr + offset_v), make_layout(shape_v, stride_v)); + auto q_traits_cos = static_cast(params.gmem_tiled_copy_q_cos); + ElementQ* base_ptr_q_cos = (ElementQ*)q_traits_cos.base_ptr; + + auto q_traits_sin = static_cast(params.gmem_tiled_copy_q_sin); + ElementQ* base_ptr_q_sin = (ElementQ*)q_traits_sin.base_ptr; + + auto k_traits_cos = static_cast(params.gmem_tiled_copy_k_cos); + ElementK* base_ptr_k_cos = (ElementK*)k_traits_cos.base_ptr; + + auto k_traits_sin = static_cast(params.gmem_tiled_copy_k_sin); + ElementK* base_ptr_k_sin = (ElementK*)k_traits_sin.base_ptr; + + auto tensorQCos = make_tensor(make_gmem_ptr(base_ptr_q_cos + offset_q), make_layout(shape_q, stride_q)); + auto tensorQSin = make_tensor(make_gmem_ptr(base_ptr_q_sin + offset_q), make_layout(shape_q, stride_q)); + auto tensorKCos = make_tensor(make_gmem_ptr(base_ptr_k_cos + offset_k), make_layout(shape_k, stride_k)); + auto tensorKSin = make_tensor(make_gmem_ptr(base_ptr_k_sin + offset_k), make_layout(shape_k, stride_k)); + XE_Copy_Q copyQ{XE_Copy_Q{}.with(tensorQ)}; XE_Copy_K copyK{XE_Copy_K{}.with(tensorK)}; XE_Copy_V copyV{XE_Copy_V{}.with(tensorV)}; + XE_Copy_Q copyQCos{XE_Copy_Q{}.with(tensorQCos)}; + XE_Copy_Q copyQSin{XE_Copy_Q{}.with(tensorQSin)}; + XE_Copy_K copyKCos{XE_Copy_K{}.with(tensorKCos)}; + XE_Copy_K copyKSin{XE_Copy_K{}.with(tensorKSin)}; - return Params{copyQ, copyK, copyV}; + return Params{copyQ, copyK, copyV, copyQCos, copyQSin, copyKCos, copyKSin}; } } }; diff --git a/applications/flash_attention_v2/collective/xe_rotary.h b/applications/flash_attention_v2/collective/xe_rotary.h new file mode 100644 index 0000000000..4f65e14d2e --- /dev/null +++ b/applications/flash_attention_v2/collective/xe_rotary.h @@ -0,0 +1,61 @@ +/*************************************************************************************************** + * Copyright (c) 2025 Intel Corporation. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" + +namespace cutlass::flash_attention::collective { +using namespace cute; + +template +CUTLASS_DEVICE void apply_rope_interleaved_gmem( + int thread_idx, + Tensor const &srcTensor, + TensorCos const &gCos, + TensorSin const &gSin, TensorOut &destTensor) { + if(thread_idx < size<0>(srcTensor)){ + for (int j = 0; j < size<1>(gCos); j+=2) { + auto real = static_cast(srcTensor[make_coord(thread_idx, j)]); + auto imag = static_cast(srcTensor[make_coord(thread_idx, j + 1)]); + auto cos_val = static_cast(gCos[make_coord(thread_idx, j)]); + auto sin_val = static_cast(gSin[make_coord(thread_idx, j)]); + + auto new_real = real * cos_val - imag * sin_val; + auto new_imag = real * sin_val + imag * cos_val; + + destTensor[make_coord(thread_idx,j)] = static_cast(new_real); + destTensor[make_coord(thread_idx,j + 1)] = static_cast(new_imag); + } + } + syncthreads(); +} +} // namespace cutlass::flash_attention::collective \ No newline at end of file diff --git a/applications/flash_attention_v2/kernel/xe_flash_attn_prefill.hpp b/applications/flash_attention_v2/kernel/xe_flash_attn_prefill.hpp index 88c1b2042c..461e1f0a0d 100644 --- a/applications/flash_attention_v2/kernel/xe_flash_attn_prefill.hpp +++ b/applications/flash_attention_v2/kernel/xe_flash_attn_prefill.hpp @@ -72,6 +72,8 @@ class FMHAPrefill { using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; using MainloopArguments = typename CollectiveMainloop::Arguments; using MainloopParams = typename CollectiveMainloop::Params; + using traits_load_Q = typename CollectiveMainloop::traits_load_Q; + using traits_load_K = typename CollectiveMainloop::traits_load_K; using CollectiveSoftmaxEpilogue = CollectiveSoftmaxEpilogue_; using SoftmaxArguments = typename CollectiveSoftmaxEpilogue::Arguments; @@ -132,6 +134,10 @@ class FMHAPrefill { using AccumeShape = decltype(make_shape(Int{}, Int{}, get<1>(TileShapePV{})/get<1>(MmaAtomShape()), Int{})); static constexpr bool is_var_len = CollectiveMainloop::is_var_len; + static constexpr bool rope_enabled = CollectiveMainloop::rope_enabled; + + template + static constexpr bool is_fp8_v = cute::is_same_v || cute::is_same_v; // Kernel level shared memory storage struct SharedStorage { @@ -272,10 +278,24 @@ class FMHAPrefill { Tensor mK_nk = mK_nkl(_, _, blk_l_coord/group_heads_q); // (n,k) Tensor mV_nk = mV_nkl(_, _, blk_l_coord/group_heads_q); // (n,k) + Tensor mCosQ_mkl = cute::get_xe_tensor(make_shape(seq_len_qo, head_size_qk, (is_var_len ? 1 : batch) * num_heads_q)); // (m, k, l) + Tensor mSinQ_mkl = cute::get_xe_tensor(make_shape(seq_len_qo, head_size_qk, (is_var_len ? 1 : batch) * num_heads_q)); // (m, k, l) + Tensor mCosK_nkl = cute::get_xe_tensor(make_shape(seq_len_kv, head_size_qk, (is_var_len ? 
1 : batch) * num_head_kv)); // (n, k, l)
+    Tensor mSinK_nkl = cute::get_xe_tensor(make_shape(seq_len_kv, head_size_qk, (is_var_len ? 1 : batch) * num_head_kv)); // (n, k, l)
+    Tensor mCosQ_mk = mCosQ_mkl(_, _, blk_l_coord); // (m,k)
+    Tensor mSinQ_mk = mSinQ_mkl(_, _, blk_l_coord); // (m,k)
+    Tensor mCosK_nk = mCosK_nkl(_, _, blk_l_coord/group_heads_q); // (n,k)
+    Tensor mSinK_nk = mSinK_nkl(_, _, blk_l_coord/group_heads_q);
+
    auto gQ = local_tile(mQ_mk, TileShapeQK{}, make_coord(blk_m_coord, _, _), Step<_1, X, _1>{});
    auto gK = local_tile(mK_nk, TileShapeQK{}, make_coord(_, _ , _), Step<X, _1, _1>{});
    auto gV = local_tile(mV_nk, TileShapeOutput{}, make_coord(_, blk_n_coord, _), Step<X, _1, _1>{});
+    auto gCosQ = local_tile(mCosQ_mk, TileShapeQK{}, make_coord(blk_m_coord, _, _), Step<_1, X, _1>{});
+    auto gSinQ = local_tile(mSinQ_mk, TileShapeQK{}, make_coord(blk_m_coord, _, _), Step<_1, X, _1>{});
+    auto gCosK = local_tile(mCosK_nk, TileShapeQK{}, make_coord(_, _ , _), Step<X, _1, _1>{});
+    auto gSinK = local_tile(mSinK_nk, TileShapeQK{}, make_coord(_, _ , _), Step<X, _1, _1>{});
+
    auto mainloop_params = CollectiveMainloop::get_updated_copies(params.mainloop, params.problem_shape, sequence_length_shape, batch_coord);
    // we limit the horisontal size to two subgroup, the empirical resutls show that reading the two cacheline side by side in gives better performance and
    // anything after that does not have an effect on performance. // (64 here for float b float when possible and loop over to cover all the data needed)
@@ -289,6 +309,109 @@ class FMHAPrefill {
    auto pKgK = thr_prefetch_K.partition_S(gK);
    auto pVgV = thr_prefetch_V.partition_S(gV);

+    // currently RoPE is not supported for fp8.
+    if constexpr (rope_enabled && !is_fp8_v<ElementQ>) {
+      int block_idx = static_cast<int>(BlockIdxX());
+      int block_idy = static_cast<int>(BlockIdxY());
+      int block_idz = static_cast<int>(BlockIdxZ());
+      int block_dimx = static_cast<int>(BlockDimX());
+      int block_dimy = static_cast<int>(BlockDimY());
+      int block_dimz = static_cast<int>(BlockDimZ());
+      int thread_idx = static_cast<int>(ThreadIdxX());
+      int thread_idy = static_cast<int>(ThreadIdxY());
+      int thread_idz = static_cast<int>(ThreadIdxZ());
+      int grid_dimx = static_cast<int>(GridDimX());
+      int grid_dimy = static_cast<int>(GridDimY());
+      int grid_dimz = static_cast<int>(GridDimZ());
+      int block_id = block_idx + block_idy * grid_dimx + block_idz * grid_dimx * grid_dimy;
+      int thread_id = block_id * block_dimx * block_dimy * block_dimz + thread_idz * block_dimx * block_dimy + thread_idy * block_dimx + thread_idx;
+
+      // calculate the base_ptr and offset for Q, K.
+      // also calculate the layout for Q, K.
+      // then apply RoPE on Q, K accordingly
+      auto [coord_q_x, coord_q_y, coord_q_z] = *gQ.data();
+      auto [coord_k_x, coord_k_y, coord_k_z] = *gK.data();
+      auto [batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, head_size_qk, head_size_vo] = params.problem_shape;
+
+      int offset_q = seq_len_qo*head_size_qk*coord_q_z + head_size_qk*coord_q_x + coord_q_y; // row major
+      int offset_k = seq_len_kv*head_size_qk*coord_k_z + head_size_qk*coord_k_x + coord_k_y; // row major
+
+      // calculate Q/cosQ/sinQ ptr
+      auto q_traits = static_cast<traits_load_Q>(mainloop_params.gmem_tiled_copy_q);
+      ElementQ* base_ptr_q = (ElementQ*)q_traits.base_ptr;
+
+      auto q_traits_cos = static_cast<traits_load_Q>(mainloop_params.gmem_tiled_copy_q_cos);
+      ElementQ* base_ptr_q_cos = (ElementQ*)q_traits_cos.base_ptr;
+
+      auto q_traits_sin = static_cast<traits_load_Q>(mainloop_params.gmem_tiled_copy_q_sin);
+      ElementQ* base_ptr_q_sin = (ElementQ*)q_traits_sin.base_ptr;
+
+      auto static_shape_q = make_shape(size<0>(gQ), size<1>(gQ)*size<2>(gQ));
+      auto layout_q = make_layout(static_shape_q, LayoutRight{});
+
+      // calculate K/cosK/sinK ptr
+      auto k_traits = static_cast<traits_load_K>(mainloop_params.gmem_tiled_copy_k);
+      ElementK* base_ptr_k = (ElementK*)k_traits.base_ptr;
+
+      auto k_traits_cos = static_cast<traits_load_K>(mainloop_params.gmem_tiled_copy_k_cos);
+      ElementK* base_ptr_k_cos = (ElementK*)k_traits_cos.base_ptr;
+
+      auto k_traits_sin = static_cast<traits_load_K>(mainloop_params.gmem_tiled_copy_k_sin);
+      ElementK* base_ptr_k_sin = (ElementK*)k_traits_sin.base_ptr;
+
+      auto static_shape_k = make_shape(size<0>(gK), size<1>(gK)*size<3>(gK));
+      auto layout_k = make_layout(static_shape_k, LayoutRight{});
+      auto gK_dim3 = size<3>(gK);
+
+      // calculating rope for Q
+      auto tensorQ = make_tensor(make_gmem_ptr(base_ptr_q+offset_q), layout_q);
+      auto tensorCosQ = make_tensor(make_gmem_ptr(base_ptr_q_cos+offset_q), layout_q);
+      auto tensorSinQ = make_tensor(make_gmem_ptr(base_ptr_q_sin+offset_q), layout_q);
+      cutlass::flash_attention::collective::apply_rope_interleaved_gmem(thread_idx, tensorQ, tensorCosQ, tensorSinQ, tensorQ);
+
+      // calculating rope for K
+      // need to consider the case when there are multiple blocks in y direction
+      // each block in y direction will handle a different set of K
+      // so need to adjust the base pointer of K accordingly.
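+      // offset_k below starts at this work-group's share of the K blocks (selected by block_id % gridDim.y),
+      // and the loops then stride over the remaining blocks gridDim.y at a time
+      // (one block spans QK_BLK_N * QK_BLK_K * gK_dim3 elements), so the K blocks are divided
+      // round-robin across the work-groups in the y dimension.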
+ if(grid_dimy == 4){ + if (block_id%4==1){ + offset_k += QK_BLK_N*QK_BLK_K*gK_dim3; + } else if (block_id%4==2){ + offset_k += 2*QK_BLK_N*QK_BLK_K*gK_dim3; + } else if (block_id%4==3){ + offset_k += 3*QK_BLK_N*QK_BLK_K*gK_dim3; + } + + auto new_offset_k = offset_k; + for (int i =0 ;i< size<2>(gK); i+=4){ + auto tensorK = make_tensor(make_gmem_ptr(base_ptr_k+new_offset_k), layout_k); + auto tensorCosK = make_tensor(make_gmem_ptr(base_ptr_k_cos+new_offset_k), layout_k); + auto tensorSinK = make_tensor(make_gmem_ptr(base_ptr_k_sin+new_offset_k), layout_k); + cutlass::flash_attention::collective::apply_rope_interleaved_gmem(thread_idx, tensorK, tensorCosK, tensorSinK, tensorK); + new_offset_k += 4*QK_BLK_N*QK_BLK_K*gK_dim3; + } + } else if (grid_dimy ==2){ + if (block_id%2==1){ + offset_k += QK_BLK_N*QK_BLK_K*gK_dim3; + } + auto new_offset_k = offset_k; + for (int i =0 ;i< size<2>(gK); i+=2){ + auto tensorK = make_tensor(make_gmem_ptr(base_ptr_k+new_offset_k), layout_k); + auto tensorCosK = make_tensor(make_gmem_ptr(base_ptr_k_cos+new_offset_k), layout_k); + auto tensorSinK = make_tensor(make_gmem_ptr(base_ptr_k_sin+new_offset_k), layout_k); + cutlass::flash_attention::collective::apply_rope_interleaved_gmem(thread_idx, tensorK, tensorCosK, tensorSinK, tensorK); + new_offset_k += 2*QK_BLK_N*QK_BLK_K*gK_dim3; + } + } + + barrier_arrive(2); + for(int i=0;i< 10000;i++){ + + } + barrier_wait(2); + } + for (int i = 0; i < size<3>(pQgQ); i++) { prefetch(tiled_prefetch_q, pQgQ(_, _, _, i)); } diff --git a/test/unit/flash_attention/flash_attention_prefill/flash_prefill_testbed_3x.hpp b/test/unit/flash_attention/flash_attention_prefill/flash_prefill_testbed_3x.hpp index a60cc7c385..28fd954993 100644 --- a/test/unit/flash_attention/flash_attention_prefill/flash_prefill_testbed_3x.hpp +++ b/test/unit/flash_attention/flash_attention_prefill/flash_prefill_testbed_3x.hpp @@ -49,6 +49,7 @@ #include #include +#include #include "cutlass/util/command_line.h" #include "cutlass/util/device_memory.h" @@ -130,7 +131,7 @@ struct Shape_h192 { template + typename MMAOperation, bool HasCausalMask, bool isVarLen, int PipelineStages, bool rope_enabled=false> struct XE_Flash_Attention_Prefill { using LayoutQ = cutlass::layout::RowMajor; using LayoutK = cutlass::layout::ColumnMajor; @@ -171,7 +172,7 @@ struct XE_Flash_Attention_Prefill { GmemTiledCopyQ, // Q GmemTiledCopyK, // K GmemTiledCopyV, // V, - HasCausalMask>; + HasCausalMask, rope_enabled>; using Kernel = cutlass::flash_attention::kernel::FMHAPrefill; @@ -222,6 +223,13 @@ struct TestbedImpl { cutlass::DeviceAllocation block_V; cutlass::DeviceAllocation block_O; cutlass::DeviceAllocation block_ref_O; + cutlass::DeviceAllocation block_ref_Q; + cutlass::DeviceAllocation block_ref_K; + + // RoPE support + cutlass::DeviceAllocation block_cos; + cutlass::DeviceAllocation block_sin; + static constexpr bool rope_enabled = CollectiveMainloop::rope_enabled; // // Methods @@ -248,6 +256,57 @@ struct TestbedImpl { }; } + /// Initialize RoPE cos/sin tensors + void initialize_rope_tensors(int max_seq_len, int head_dim, int num_heads_q, int batch) { + std::vector cos_vals(max_seq_len * head_dim * num_heads_q * batch); + std::vector sin_vals(max_seq_len * head_dim * num_heads_q * batch); + + // fill data row-major wise + for(int b = 0; b< num_heads_q*batch; b++){ + for (int pos = 0; pos < max_seq_len; ++pos) { + for (int i = 0; i < head_dim/2 ; ++i) { + int idx = b*max_seq_len*head_dim + pos*head_dim + 2*i; + int idx1 = b*max_seq_len*head_dim + pos*head_dim + 2*i + 1; + float 
theta = static_cast<float>(pos / std::pow(10000.0f, (2.0f * i) / head_dim));
+          // float theta = i;
+          cos_vals[idx] = static_cast<ElementQ>(std::cos(theta));
+          cos_vals[idx1] = static_cast<ElementQ>(std::cos(theta));
+          sin_vals[idx] = static_cast<ElementQ>(std::sin(theta));
+          sin_vals[idx1] = static_cast<ElementQ>(std::sin(theta));
+        }
+      }
+    }
+    compat::memcpy(block_cos.get(), cos_vals.data(), cos_vals.size() * sizeof(ElementQ));
+    compat::memcpy(block_sin.get(), sin_vals.data(), sin_vals.size() * sizeof(ElementQ));
+    compat::wait();
+  }
+
+  /// Apply RoPE transformation to a tensor
+  template <typename Element>
+  void apply_rope_on_host(std::vector<Element>& tensor, int seq_len, int head_dim, int batch, int head,
+                          const std::vector<Element>& cos_vals, const std::vector<Element>& sin_vals) {
+    for (int seq_pos = 0; seq_pos < seq_len; ++seq_pos) {
+      for (int dim_pair = 0; dim_pair < head_dim/2; ++dim_pair) {
+        int cos_sin_idx = seq_pos * head_dim + dim_pair * 2;
+        auto cos_val = static_cast<float>(cos_vals[cos_sin_idx]);
+        auto sin_val = static_cast<float>(sin_vals[cos_sin_idx]);
+
+        int x_idx = seq_pos * head_dim + dim_pair * 2;
+        int y_idx = seq_pos * head_dim + dim_pair * 2 + 1;
+
+        auto x = static_cast<float>(tensor[x_idx]);
+        auto y = static_cast<float>(tensor[y_idx]);
+
+        auto new_x = x * cos_val - y * sin_val;
+        auto new_y = x * sin_val + y * cos_val;
+
+        tensor[x_idx] = static_cast<Element>(new_x);
+        tensor[y_idx] = static_cast<Element>(new_y);
+      }
+
+    }
+  }
+
  /// Initializes data structures
  template <typename ProblemShape>
  ProblemShapeType initialize(ProblemShape problem_shape_in) {
@@ -279,10 +338,25 @@ struct TestbedImpl {
    block_V.reset(batch * num_heads_kv * seq_len_kv * head_size_vo);
    block_O.reset(batch * num_heads_q * seq_len_qo * head_size_vo);
    block_ref_O.reset(batch * num_heads_q * seq_len_qo * head_size_vo);
+    block_ref_Q.reset(batch * num_heads_q * seq_len_qo * head_size_qk);
+    block_ref_K.reset(batch * num_heads_kv * seq_len_kv * head_size_qk);
+
+    // Initialize RoPE tensors if enabled
+    if constexpr (rope_enabled) {
+      int max_seq_len = std::max(seq_len_qo, seq_len_kv);
+      block_cos.reset(max_seq_len * head_size_qk * num_heads_q * batch);
+      block_sin.reset(max_seq_len * head_size_qk * num_heads_q * batch);
+      initialize_rope_tensors(max_seq_len, head_size_qk, num_heads_q, batch);
+    }
    initialize_block(block_Q, seed + 2023);
    initialize_block(block_K, seed + 2022);
    initialize_block(block_V, seed + 2021);
+    compat::wait();
+    // reference copy of Q and K for verification
+    compat::memcpy(block_ref_Q.get(), block_Q.get(), batch * num_heads_q * seq_len_qo * head_size_qk);
+    compat::memcpy(block_ref_K.get(), block_K.get(), batch * num_heads_kv * seq_len_kv * head_size_qk);
+    compat::wait();
    if (!cumulative_seqlen_q.empty()) {
      device_cumulative_seqlen_q.reset(cumulative_seqlen_q.size());
@@ -381,9 +455,11 @@ struct TestbedImpl {
    auto [batch, num_heads_q, num_heads_kv, head_size_qk, head_size_vo] = cute::select<0,1,2,5,6>(problem_size);
    int seq_len_qo, seq_len_kv;
-    auto block_Q_ = in_memory(block_Q);
-    auto block_K_ = in_memory(block_K);
+    auto block_Q_ = in_memory(block_ref_Q);
+    auto block_K_ = in_memory(block_ref_K);
    auto block_V_ = in_memory(block_V);
+    auto block_cos_ = in_memory(block_cos);
+    auto block_sin_ = in_memory(block_sin);
    using ElementV_ = cute::conditional_t, half_t, ElementV>;
    int offset_q = 0;
@@ -411,6 +487,43 @@ struct TestbedImpl {
    cutlass::TensorRef ref_K(block_K_.get() + offset_k, LayoutK::packed({head_size_qk, seq_len_kv}));
    cutlass::TensorRef ref_V(block_V_.get() + offset_v, LayoutV::packed({seq_len_kv, head_size_vo}));
    cutlass::TensorRef ref_S(block_S.get(), LayoutQ::packed({seq_len_qo, seq_len_kv}));
+
+    // Apply
RoPE to Q and K if enabled on host + // Currently RoPE is not supported for fp8. + if constexpr (rope_enabled && !is_fp8_v) { + cutlass::TensorRef ref_Q_cos(block_cos_.get() + offset_q, LayoutQ::packed({seq_len_qo, head_size_qk})); + cutlass::TensorRef ref_Q_sin(block_sin_.get() + offset_q, LayoutQ::packed({seq_len_qo, head_size_qk})); + cutlass::TensorRef ref_K_cos(block_cos_.get() + offset_k, LayoutK::packed({head_size_qk, seq_len_kv})); + cutlass::TensorRef ref_K_sin(block_sin_.get() + offset_k, LayoutK::packed({head_size_qk, seq_len_kv})); + + std::vector host_Q(seq_len_qo* head_size_qk); + std::vector host_K(head_size_qk* seq_len_kv); + std::vector host_Q_cos(seq_len_qo* head_size_qk); + std::vector host_Q_sin(seq_len_qo* head_size_qk); + std::vector host_K_cos(head_size_qk* seq_len_kv); + std::vector host_K_sin(head_size_qk* seq_len_kv); + + compat::wait(); + + compat::memcpy(host_Q.data(), ref_Q.data(), seq_len_qo* head_size_qk); + compat::memcpy(host_K.data(), ref_K.data(), head_size_qk* seq_len_kv); + compat::memcpy(host_Q_cos.data(), ref_Q_cos.data(), seq_len_qo* head_size_qk); + compat::memcpy(host_Q_sin.data(), ref_Q_sin.data(), seq_len_qo* head_size_qk); + compat::memcpy(host_K_cos.data(), ref_K_cos.data(), head_size_qk* seq_len_kv); + compat::memcpy(host_K_sin.data(), ref_K_sin.data(), head_size_qk* seq_len_kv); + compat::wait(); + + apply_rope_on_host(host_Q, seq_len_qo, head_size_qk, b, h, host_Q_cos, host_Q_sin); + apply_rope_on_host(host_K, seq_len_kv, head_size_qk, b, h, host_K_cos, host_K_sin); + compat::wait(); + + // Update tensor references to use RoPE-transformed tensors + ref_Q.reset(ref_Q.data(), LayoutQ::packed({seq_len_qo, head_size_qk})); + ref_K.reset(ref_K.data(), LayoutK::packed({head_size_qk, seq_len_kv})); + compat::memcpy(ref_Q.data(), host_Q.data(), seq_len_qo * head_size_qk * sizeof(ElementQ)); + compat::memcpy(ref_K.data(), host_K.data(), seq_len_kv * head_size_qk * sizeof(ElementK)); + compat::wait(); + } cutlass::reference::device::GemmComplex({seq_len_qo, seq_len_kv, head_size_qk}, ElementAccumulator{1}, ref_Q, cutlass::ComplexTransform::kNone, ref_K, cutlass::ComplexTransform::kNone, @@ -535,6 +648,15 @@ struct TestbedImpl { } compat::wait(); + std::vector host_ref_o(block_O.size()); + std::vector host_o(block_O.size()); + compat::wait(); + compat::memcpy(host_ref_o.data(), block_ref_O.get(), batch * num_heads_q * seq_len_qo * head_size_vo); + compat::memcpy(host_o.data(), block_O.get(), batch * num_heads_q * seq_len_qo * head_size_vo); + compat::wait(); + // for(int i = 0; i < host_o.size(); i++) { + // std::cout << "O[" << i << "] = " << host_o[i] << ", ref_O[" << i << "] = " << host_ref_o[i] << ", diff : " << (host_o[i] - host_ref_o[i]) << std::endl; + // } // Check if output from CUTLASS kernel and reference kernel are equal or not bool passed = cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_O.get(), block_O.get(), @@ -576,10 +698,33 @@ struct TestbedImpl { // Initialize the Flash attention operator // cutlass::KernelHardwareInfo hw_info; + + // Prepare mainloop arguments with RoPE tensors if enabled + auto mainloop_args = [&]() { + if constexpr (rope_enabled) { + return typename FlashAttention::CollectiveMainloop::Arguments{ + block_Q.get(), stride_Q, + block_K.get(), stride_K, + block_V.get(), stride_V, + block_cos.get(), + block_sin.get(), + // stride_Q_cs, + // stride_K_cs, + }; + } else { + return typename FlashAttention::CollectiveMainloop::Arguments{ + block_Q.get(), stride_Q, + block_K.get(), stride_K, + 
block_V.get(), stride_V + }; + } + }(); + + typename FlashAttention::Arguments arguments{ cutlass::gemm::GemmUniversalMode::kGemm, problem_size, - {block_Q.get(), stride_Q, block_K.get(), stride_K, block_V.get(), stride_V}, + mainloop_args, {softmax_scale}, {block_O.get(), stride_O}, hw_info}; diff --git a/test/unit/flash_attention/flash_attention_prefill/xe_flash_prefill.cpp b/test/unit/flash_attention/flash_attention_prefill/xe_flash_prefill.cpp index 4a8005a948..f4c400d338 100644 --- a/test/unit/flash_attention/flash_attention_prefill/xe_flash_prefill.cpp +++ b/test/unit/flash_attention/flash_attention_prefill/xe_flash_prefill.cpp @@ -42,7 +42,13 @@ using Shape_h = test::flash_attention::SHAPE_H; TEST(TEST_NAME, causal) { using Kernel = test::flash_attention::XE_Flash_Attention_Prefill::Kernel; + typename Shape_h::ShapeOutput, typename Shape_h::SubgroupLayout, MMAOperation, true, false, 2, false>::Kernel; + EXPECT_TRUE(test::flash_attention::TestFlashPrefillAll(HEAD_DIM)); +} + +TEST(TEST_NAME, causal_rope) { + using Kernel = test::flash_attention::XE_Flash_Attention_Prefill::Kernel; EXPECT_TRUE(test::flash_attention::TestFlashPrefillAll(HEAD_DIM)); } @@ -52,6 +58,12 @@ TEST(TEST_NAME, noncausal) { EXPECT_TRUE(test::flash_attention::TestFlashPrefillAll(HEAD_DIM)); } +TEST(TEST_NAME, noncausal_rope) { + using Kernel = test::flash_attention::XE_Flash_Attention_Prefill::Kernel; + EXPECT_TRUE(test::flash_attention::TestFlashPrefillAll(HEAD_DIM)); +} + TEST(TEST_NAME, varlen_causal) { using Kernel = test::flash_attention::XE_Flash_Attention_Prefill::Kernel;
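For reference, both the device path (apply_rope_interleaved_gmem) and the host check (apply_rope_on_host) use the interleaved RoPE convention: adjacent elements (2*i, 2*i+1) of each row are rotated by theta = pos / 10000^(2*i/head_dim), the same formula used in initialize_rope_tensors. Below is a minimal standalone host sketch of that convention; the function and variable names are illustrative only and are not part of this patch.

// Standalone host sketch of interleaved RoPE (illustrative, not part of the patch).
#include <cmath>
#include <cstdio>
#include <vector>

// Rotate each even/odd pair of a row-major [seq_len, head_dim] buffer in place.
void apply_rope_interleaved(std::vector<float> &x, int seq_len, int head_dim) {
  for (int pos = 0; pos < seq_len; ++pos) {
    for (int i = 0; i < head_dim / 2; ++i) {
      // Same angle as initialize_rope_tensors: pos / 10000^(2*i/head_dim).
      float theta = pos / std::pow(10000.0f, (2.0f * i) / head_dim);
      float c = std::cos(theta), s = std::sin(theta);
      float &re = x[pos * head_dim + 2 * i];      // even element of the pair
      float &im = x[pos * head_dim + 2 * i + 1];  // odd element of the pair
      float new_re = re * c - im * s;
      float new_im = re * s + im * c;
      re = new_re;
      im = new_im;
    }
  }
}

int main() {
  int seq_len = 4, head_dim = 8;
  std::vector<float> q(seq_len * head_dim, 1.0f);
  apply_rope_interleaved(q, seq_len, head_dim);
  std::printf("q[1][0..1] = %f %f\n", q[head_dim], q[head_dim + 1]);
  return 0;
}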