@@ -38,6 +38,7 @@
#include "cute/atom/mma_atom.hpp"
#include "cute/algorithm/gemm.hpp"
#include "fmha_fusion.hpp"
#include "xe_rotary.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

@@ -61,7 +62,7 @@ CUTLASS_DEVICE auto convert_type(Tensor<Engine, Layout> const &tensor) {

template <class DispatchPolicy, class ProblemShapeType_, class ElementQ_, class StrideQ_, class ElementK_, class StrideK_,
class ElementV_, class StrideV_, class MMAOperation_, class TileShapeQK_, class TileShapePV_, class SubgroupLayout_, class GmemTiledCopyQ_, class GmemTiledCopyK_,
class GmemTiledCopyV_, bool CausalMask_>
class GmemTiledCopyV_, bool CausalMask_, bool RopeMask_ = false>
struct FlashPrefillMma {
static_assert(cutlass::detail::dependent_false<ElementQ_>, "Could not find a mainloop specialization.");
};
@@ -70,9 +71,9 @@ struct FlashPrefillMma {

template <int Stages, class ProblemShapeType_, class ElementQ_, class StrideQ_, class ElementK_, class StrideK_,
class ElementV_, class StrideV_, class MMAOperation_, class TileShapeQK_, class TileShapePV_, class SubgroupLayout_, class GmemTiledCopyQ_, class GmemTiledCopyK_,
class GmemTiledCopyV_, bool CausalMask_>
class GmemTiledCopyV_, bool CausalMask_, bool RopeMask_>
struct FlashPrefillMma<gemm::MainloopIntelXeXMX16<Stages>, ProblemShapeType_, ElementQ_, StrideQ_, ElementK_, StrideK_, ElementV_,
StrideV_, MMAOperation_, TileShapeQK_, TileShapePV_, SubgroupLayout_, GmemTiledCopyQ_, GmemTiledCopyK_, GmemTiledCopyV_, CausalMask_> {
StrideV_, MMAOperation_, TileShapeQK_, TileShapePV_, SubgroupLayout_, GmemTiledCopyQ_, GmemTiledCopyK_, GmemTiledCopyV_, CausalMask_, RopeMask_> {
//
// Type Aliases
//
@@ -96,6 +97,7 @@ struct FlashPrefillMma<gemm::MainloopIntelXeXMX16<Stages>, ProblemShapeType_, El
using TiledMmaPV = typename TiledMMAHelper<MmaAtom, Layout<TileShapePV>, SubgroupLayout>::TiledMMA;
using ElementAccumulator = typename TiledMmaQK::ValTypeC;
static constexpr bool CausalMask = CausalMask_;
static constexpr bool rope_enabled = RopeMask_;
static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize;

using MmaAtomShape = typename MmaAtom::Shape_MNK;
@@ -157,12 +159,19 @@ struct FlashPrefillMma<gemm::MainloopIntelXeXMX16<Stages>, ProblemShapeType_, El
StrideK dK;
ElementV const *ptr_V;
StrideV dV;
// for RoPE case
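// ptr_cos / ptr_sin are assumed to point to precomputed rotary cos/sin tables laid
// out like Q, i.e. (seq_len_qo, head_size_qk, batch * num_heads_q); they are wrapped
// with Q's stride in to_underlying_arguments() below.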
ElementQ const *ptr_cos = nullptr;
ElementQ const *ptr_sin = nullptr;
};

struct Params {
XE_Copy_Q gmem_tiled_copy_q;
XE_Copy_K gmem_tiled_copy_k;
XE_Copy_V gmem_tiled_copy_v;
XE_Copy_Q gmem_tiled_copy_q_cos;
XE_Copy_Q gmem_tiled_copy_q_sin;
XE_Copy_K gmem_tiled_copy_k_cos;
XE_Copy_K gmem_tiled_copy_k_sin;
};

//
@@ -180,18 +189,87 @@ struct FlashPrefillMma<gemm::MainloopIntelXeXMX16<Stages>, ProblemShapeType_, El
auto tensorQ = make_tensor(make_gmem_ptr(args.ptr_Q), make_layout(make_shape(seq_len_qo, head_size_qk, batch * num_heads_q), args.dQ));
auto tensorK = make_tensor(make_gmem_ptr(args.ptr_K), make_layout(make_shape(seq_len_kv, head_size_qk, batch * num_heads_kv), args.dK));
auto tensorV = make_tensor(make_gmem_ptr(args.ptr_V), make_layout(make_shape(head_size_vo, seq_len_kv, batch * num_heads_kv), args.dV));

auto tensorCos = make_tensor(make_gmem_ptr(args.ptr_cos), make_layout(make_shape(seq_len_qo, head_size_qk, batch * num_heads_q), args.dQ));
auto tensorSin = make_tensor(make_gmem_ptr(args.ptr_sin), make_layout(make_shape(seq_len_qo, head_size_qk, batch * num_heads_q), args.dQ)); // sin table shares Q's shape, so it uses Q's stride as well

XE_Copy_Q copyQ{XE_Copy_Q{}.with(tensorQ)};
XE_Copy_K copyK{XE_Copy_K{}.with(tensorK)};
XE_Copy_V copyV{XE_Copy_V{}.with(tensorV)};

return Params{copyQ, copyK, copyV};

XE_Copy_Q copyQCos{XE_Copy_Q{}.with(tensorCos)};
XE_Copy_Q copyQSin{XE_Copy_Q{}.with(tensorSin)};

XE_Copy_K copyKCos{XE_Copy_K{}.with(tensorCos)};
XE_Copy_K copyKSin{XE_Copy_K{}.with(tensorSin)};

return Params{copyQ, copyK, copyV, copyQCos, copyQSin, copyKCos, copyKSin};
}

template <class FragQccum, class TensorQ, class TensorK, class FragSrc>
template <class FragQccum, class TensorQ, class TensorK, class FragSrc, class TensorCosQ, class TensorSinQ, class TensorCosK, class TensorSinK>
CUTLASS_DEVICE void mmaQK(FragQccum &accum, TensorQ gQ, TensorK gK, FragSrc const &frag_src,
int const &k_tile_count, Params const &params) {
int const &k_tile_count, TensorCosQ gQCos, TensorSinQ gQSin, TensorCosK gKCos, TensorSinK gKSin, Params const &params, ProblemShapeType const &problem_shape) {


int thread_idx = static_cast<int>(ThreadIdxX());
if constexpr (rope_enabled) {
// Compute the base pointer, linear offset, and tile layout for Q and K,
// then apply RoPE to the Q and K tiles in global memory accordingly.
auto [coord_q_x, coord_q_y, coord_q_z] = *gQ.data();
auto [coord_k_x, coord_k_y, coord_k_z] = *gK.data();
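// gQ / gK are coordinate tensors here, so dereferencing data() yields the global
// (row, col, batch*head) coordinate of this tile's first element; the offsets below
// turn that coordinate into a linear row-major offset from the base pointer.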

auto [batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, head_size_qk, head_size_vo] = problem_shape;

int offset_q = seq_len_qo*head_size_qk*coord_q_z + head_size_qk*coord_q_x + coord_q_y; // row major
// int offset_k = seq_len_kv*head_size_qk*coord_k_z + head_size_qk*coord_k_y + coord_k_x; // col major
int offset_k = seq_len_kv*head_size_qk*coord_k_z + head_size_qk*coord_k_x + coord_k_y; // row major

auto q_traits = static_cast<traits_load_Q const&>(params.gmem_tiled_copy_q);
ElementQ* base_ptr_q = (ElementQ*)q_traits.base_ptr;

auto q_traits_cos = static_cast<traits_load_Q const&>(params.gmem_tiled_copy_q_cos);
ElementQ* base_ptr_q_cos = (ElementQ*)q_traits_cos.base_ptr;

auto q_traits_sin = static_cast<traits_load_Q const&>(params.gmem_tiled_copy_q_sin);
ElementQ* base_ptr_q_sin = (ElementQ*)q_traits_sin.base_ptr;

// auto layout_q = gQ.layout();
constexpr auto static_shape_q = make_shape(size<0>(gQ), size<1>(gQ));
// constexpr auto layout_q = LayoutQ::packed({size<0>(gQ), size<1>(gQ)});
constexpr auto layout_q = make_layout(static_shape_q, LayoutRight{});
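// layout_q is a packed, row-major layout of one (QK_BLK_M x QK_BLK_K) tile; the
// per-tile tensors below are built from it at base_ptr + offset.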

auto k_traits = static_cast<traits_load_K const&>(params.gmem_tiled_copy_k);
ElementK* base_ptr_k = (ElementK*)k_traits.base_ptr;

auto k_traits_cos = static_cast<traits_load_K const&>(params.gmem_tiled_copy_k_cos);
ElementK* base_ptr_k_cos = (ElementK*)k_traits_cos.base_ptr;

auto k_traits_sin = static_cast<traits_load_K const&>(params.gmem_tiled_copy_k_sin);
ElementK* base_ptr_k_sin = (ElementK*)k_traits_sin.base_ptr;

// auto layout_k = gK.layout();
constexpr auto static_shape_k = make_shape(size<0>(gK), size<1>(gK));
constexpr auto layout_k = make_layout(static_shape_k, LayoutRight{});

for (int i = 0; i < size<2>(gQ) && thread_idx < size<0>(gQ); i++) {
auto tensorQ = make_tensor(make_gmem_ptr(base_ptr_q+offset_q), layout_q);
auto tensorCosQ = make_tensor(make_gmem_ptr(base_ptr_q_cos+offset_q), layout_q);
auto tensorSinQ = make_tensor(make_gmem_ptr(base_ptr_q_sin+offset_q), layout_q);
cutlass::flash_attention::collective::apply_rope_interleaved_gmem(thread_idx, tensorQ, tensorCosQ, tensorSinQ, tensorQ);
offset_q += QK_BLK_M*QK_BLK_K;
}

for (int i = 0; i < size<2>(gK) && thread_idx < size<0>(gK); i++) {
auto tensorK = make_tensor(make_gmem_ptr(base_ptr_k+offset_k), layout_k);
auto tensorCosK = make_tensor(make_gmem_ptr(base_ptr_k_cos+offset_k), layout_k);
auto tensorSinK = make_tensor(make_gmem_ptr(base_ptr_k_sin+offset_k), layout_k);
cutlass::flash_attention::collective::apply_rope_interleaved_gmem(thread_idx, tensorK, tensorCosK, tensorSinK, tensorK);
offset_k += QK_BLK_N*QK_BLK_K;
}
}


auto thr_copy_Q = params.gmem_tiled_copy_q.get_slice(thread_idx);
auto thr_copy_K = params.gmem_tiled_copy_k.get_slice(thread_idx);
// Instantiate the MMA object
89 changes: 89 additions & 0 deletions applications/flash_attention_v2/collective/xe_rotary.h
@@ -0,0 +1,89 @@
/***************************************************************************************************
* Copyright (c) 2025 Intel Corporation. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once

#include "cutlass/cutlass.h"

namespace cutlass::flash_attention::collective {
using namespace cute;


template <typename Tensor,
typename TensorCos, typename TensorSin, typename TensorOut>
CUTLASS_DEVICE void apply_rope_interleaved_gmem(
int thread_idx,
Tensor const &srcTensor,
TensorCos const &gCos,
TensorSin const &gSin, TensorOut &destTensor) {

// Apply RoPE to srcTensor and store the result in destTensor: each thread handles
// the row indexed by thread_idx and loops over the columns in interleaved pairs.
// The input tensor is assumed to be row-major.
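// Interleaved RoPE treats each consecutive pair (x[2i], x[2i+1]) as a complex number
// and rotates it by the angle whose cosine/sine are read from gCos/gSin at column 2i:
//   out[2i]   = x[2i] * cos - x[2i+1] * sin
//   out[2i+1] = x[2i] * sin + x[2i+1] * cos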
if (cute::thread(1, 0)){
print("before apply_rope_interleaved_gmem\n");
cute::print_tensor(srcTensor);
}
syncthreads();
if(thread_idx < size<0>(srcTensor)){
for (int j = 0; j < size<1>(gCos); j+=2) {
float real = static_cast<float>(srcTensor[make_coord(thread_idx, j)]);
float imag = static_cast<float>(srcTensor[make_coord(thread_idx, j + 1)]);


float cos_val = static_cast<float>(gCos[make_coord(thread_idx, j)]);
float sin_val = static_cast<float>(gSin[make_coord(thread_idx, j)]);
// syncthreads();
float new_real = real * cos_val - imag * sin_val;
float new_imag = real * sin_val + imag * cos_val;
// syncthreads();
destTensor[make_coord(thread_idx,j)] = static_cast<typename Tensor::value_type>(new_real);
destTensor[make_coord(thread_idx,j + 1)] = static_cast<typename Tensor::value_type>(new_imag);

if (cute::thread(1, 0)){
#define PRINT(x) print(#x ": "); print(x); print("\n");
PRINT(thread_idx);
PRINT(j);
PRINT(real);
PRINT(imag);
PRINT(cos_val);
PRINT(sin_val);
PRINT(new_real);
PRINT(new_imag);
}
}
}
if (cute::thread(1, 0)){
print("after apply_rope_interleaved_gmem\n");
cute::print_tensor(destTensor);
}
}

} // namespace cutlass::flash_attention::collective
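For reference, a minimal host-side sketch (not part of this change) of how cos/sin tables matching the interleaved indexing above could be generated; the function name, base frequency, and the choice to duplicate each value into the odd column are assumptions for illustration only.

#include <cmath>
#include <vector>

// Build a (seq_len x head_size) row-major table where column j (even) holds
// cos/sin(pos * base^(-j/head_size)), matching gCos/gSin[make_coord(row, j)] above.
std::vector<float> make_rope_table(int seq_len, int head_size, bool is_sin,
                                   float base = 10000.f) {
  std::vector<float> table(static_cast<size_t>(seq_len) * head_size);
  for (int pos = 0; pos < seq_len; ++pos) {
    for (int j = 0; j < head_size; j += 2) {
      float inv_freq = std::pow(base, -static_cast<float>(j) / head_size);
      float angle = static_cast<float>(pos) * inv_freq;
      float v = is_sin ? std::sin(angle) : std::cos(angle);
      table[static_cast<size_t>(pos) * head_size + j] = v;     // read by the kernel at column j
      table[static_cast<size_t>(pos) * head_size + j + 1] = v; // odd column unused but kept consistent
    }
  }
  return table;
}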
43 changes: 41 additions & 2 deletions applications/flash_attention_v2/kernel/xe_flash_attn_prefill.hpp
@@ -132,6 +132,7 @@ class FMHAPrefill {
using AccumeShape = decltype(make_shape(Int<Vec>{}, Int<FragsM>{}, get<1>(TileShapePV{})/get<1>(MmaAtomShape()), Int<VSlicer>{}));

static constexpr bool is_var_len = CollectiveMainloop::is_var_len;
static constexpr bool rope_enabled = CollectiveMainloop::rope_enabled;

// Kernel level shared memory storage
struct SharedStorage {
@@ -272,10 +273,24 @@ class FMHAPrefill {
Tensor mK_nk = mK_nkl(_, _, blk_l_coord/group_heads_q); // (n,k)
Tensor mV_nk = mV_nkl(_, _, blk_l_coord/group_heads_q); // (n,k)

Tensor mCosQ_mkl = cute::get_xe_tensor(make_shape(seq_len_qo, head_size_qk, (is_var_len ? 1 : batch) * num_heads_q)); // (m, k, l)
Tensor mSinQ_mkl = cute::get_xe_tensor(make_shape(seq_len_qo, head_size_qk, (is_var_len ? 1 : batch) * num_heads_q)); // (m, k, l)
Tensor mCosK_nkl = cute::get_xe_tensor(make_shape(seq_len_kv, head_size_qk, (is_var_len ? 1 : batch) * num_head_kv)); // (n, k, l)
Tensor mSinK_nkl = cute::get_xe_tensor(make_shape(seq_len_kv, head_size_qk, (is_var_len ? 1 : batch) * num_head_kv)); // (n, k, l)
Tensor mCosQ_mk = mCosQ_mkl(_, _, blk_l_coord); // (m,k)
Tensor mSinQ_mk = mSinQ_mkl(_, _, blk_l_coord); // (m,k)
Tensor mCosK_nk = mCosK_nkl(_, _, blk_l_coord/group_heads_q); // (n,k)
Tensor mSinK_nk = mSinK_nkl(_, _, blk_l_coord/group_heads_q); // (n,k)
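// The cos/sin coordinate tensors mirror the shapes of Q and K so that the same
// tiling, partitioning, and prefetch machinery below can be reused for them.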

auto gQ = local_tile(mQ_mk, TileShapeQK{}, make_coord(blk_m_coord, _, _), Step<_1, X, _1>{});
auto gK = local_tile(mK_nk, TileShapeQK{}, make_coord(_, _ , _), Step<X, _1, _1>{});
auto gV = local_tile(mV_nk, TileShapeOutput{}, make_coord(_, blk_n_coord, _), Step<X, _1, _1>{});

auto gCosQ = local_tile(mCosQ_mk, TileShapeQK{}, make_coord(blk_m_coord, _, _), Step<_1, X, _1>{});
auto gSinQ = local_tile(mSinQ_mk, TileShapeQK{}, make_coord(blk_m_coord, _, _), Step<_1, X, _1>{});
auto gCosK = local_tile(mCosK_nk, TileShapeQK{}, make_coord(_, _ , _), Step<X, _1, _1>{});
auto gSinK = local_tile(mSinK_nk, TileShapeQK{}, make_coord(_, _ , _), Step<X, _1, _1>{});

auto mainloop_params = CollectiveMainloop::get_updated_copies(params.mainloop, params.problem_shape, sequence_length_shape, batch_coord);
// we limit the horizontal size to two subgroups; the empirical results show that reading the two cache lines side by side gives better performance and
// anything beyond that has no effect on performance. // (64 here for float b float when possible and loop over to cover all the data needed)
@@ -289,6 +304,12 @@ class FMHAPrefill {
auto pKgK = thr_prefetch_K.partition_S(gK);
auto pVgV = thr_prefetch_V.partition_S(gV);

// RoPE coordinate tensor partitions
auto pCosQgCosQ = thr_prefetch_Q.partition_S(gCosQ);
auto pSinQgSinQ = thr_prefetch_Q.partition_S(gSinQ);
auto pCosKgCosK = thr_prefetch_K.partition_S(gCosK);
auto pSinKgSinK = thr_prefetch_K.partition_S(gSinK);

for (int i = 0; i < size<3>(pQgQ); i++) {
prefetch(tiled_prefetch_q, pQgQ(_, _, _, i));
}
@@ -299,6 +320,18 @@ class FMHAPrefill {
}
}

for (int i = 0; i < size<3>(pQgQ); i++) {
prefetch(tiled_prefetch_q, pCosQgCosQ(_, _, _, i));
prefetch(tiled_prefetch_q, pSinQgSinQ(_, _, _, i));
}
for (int j = 0; j < size<4>(pKgK); j++) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < DispatchPolicy::Stages; i++) {
prefetch(tiled_prefetch_k, pCosKgCosK(_, _, _ , i, j));
prefetch(tiled_prefetch_k, pSinKgSinK(_, _, _ , i, j));
}
}

// Allocate the tiled_mma and the accumulators for the (M,N) workgroup_shape
Tensor out_reg = make_tensor<ElementAccumulator>(AccumeShape{});

Expand All @@ -325,7 +358,7 @@ class FMHAPrefill {
clear(tSr);

// 3) Perform GEMM S = Q*K
collective_mma.mmaQK(tSr, gQ, gK(_, _, nblock, _), tSr, ceil_div(head_size_qk, QK_BLK_K), mainloop_params);
collective_mma.mmaQK(tSr, gQ, gK(_, _, nblock, _), tSr, ceil_div(head_size_qk, QK_BLK_K), gCosQ, gSinQ, gCosK(_, _, nblock, _), gSinK(_, _, nblock, _), mainloop_params, params.problem_shape);

// we only need one block ahead, there is enough gap to prefetch it while doing softmax. because the gap between the two MMA is big,
// prefetching it the same way as cutlass K matrix does not make sense
@@ -343,6 +376,12 @@ class FMHAPrefill {
for (int j = 0; j < size<4>(pKgK); j++) {
prefetch(tiled_prefetch_k, pKgK(_, _, _, nblock + DispatchPolicy::Stages, j));
}

for (int j = 0; j < size<4>(pKgK); j++) {
prefetch(tiled_prefetch_k, pCosKgCosK(_, _, _, nblock + DispatchPolicy::Stages, j));
prefetch(tiled_prefetch_k, pSinKgSinK(_, _, _, nblock + DispatchPolicy::Stages, j));
}

barrier_wait(barrier_scope);
}

Expand All @@ -351,7 +390,7 @@ class FMHAPrefill {
Tensor tSr = make_tensor<ElementAccumulator>(Shape<Int<Vec>, Int<FragsM>, Int<FragsN>>{});
clear(tSr);
// 3) Perform GEMM S = Q*K
collective_mma.mmaQK(tSr, gQ, gK(_, _, nblock_limit - 1, _), tSr, ceil_div(head_size_qk, QK_BLK_K), mainloop_params);
collective_mma.mmaQK(tSr, gQ, gK(_, _, nblock_limit - 1, _), tSr, ceil_div(head_size_qk, QK_BLK_K), gCosQ, gSinQ, gCosK(_, _, nblock_limit - 1, _), gSinK(_, _, nblock_limit - 1, _), mainloop_params, params.problem_shape);
// we only need one block ahead, there is enough gap to prefetch it while doing softmax. because the gap between the two MMA is big,
// prefetching it the same way as cutlass K matrix does not make sense
for(int i=0; i< size<1>(pVgV); i++) {