Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
#include "cute/atom/mma_atom.hpp"
#include "cute/algorithm/gemm.hpp"
#include "fmha_fusion.hpp"
#include "xe_rotary.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

Expand All @@ -62,7 +63,7 @@ CUTLASS_DEVICE auto convert_type(Tensor<Engine, Layout> const &tensor) {

template <class DispatchPolicy, class ProblemShapeType_, class ElementQ_, class StrideQ_, class ElementK_, class StrideK_,
class ElementV_, class StrideV_, class MMAOperation_, class TileShapeQK_, class TileShapePV_, class SubgroupLayout_, class GmemTiledCopyQ_, class GmemTiledCopyK_,
class GmemTiledCopyV_, bool CausalMask_>
class GmemTiledCopyV_, bool CausalMask_, bool RopeMask_ = false>
struct FlashPrefillMma {
static_assert(cutlass::detail::dependent_false<ElementQ_>, "Could not find a mainloop specialization.");
};
Expand All @@ -71,9 +72,9 @@ struct FlashPrefillMma {

template <int Stages, class ProblemShapeType_, class ElementQ_, class StrideQ_, class ElementK_, class StrideK_,
class ElementV_, class StrideV_, class MMAOperation_, class TileShapeQK_, class TileShapePV_, class SubgroupLayout_, class GmemTiledCopyQ_, class GmemTiledCopyK_,
class GmemTiledCopyV_, bool CausalMask_>
class GmemTiledCopyV_, bool CausalMask_, bool RopeMask_>
struct FlashPrefillMma<gemm::MainloopIntelXeXMX16<Stages>, ProblemShapeType_, ElementQ_, StrideQ_, ElementK_, StrideK_, ElementV_,
StrideV_, MMAOperation_, TileShapeQK_, TileShapePV_, SubgroupLayout_, GmemTiledCopyQ_, GmemTiledCopyK_, GmemTiledCopyV_, CausalMask_> {
StrideV_, MMAOperation_, TileShapeQK_, TileShapePV_, SubgroupLayout_, GmemTiledCopyQ_, GmemTiledCopyK_, GmemTiledCopyV_, CausalMask_, RopeMask_> {
//
// Type Aliases
//
Expand All @@ -97,6 +98,7 @@ struct FlashPrefillMma<gemm::MainloopIntelXeXMX16<Stages>, ProblemShapeType_, El
using TiledMmaPV = typename TiledMMAHelper<MmaAtom, Layout<TileShapePV>, SubgroupLayout>::TiledMMA;
using ElementAccumulator = typename TiledMmaQK::ValTypeC;
static constexpr bool CausalMask = CausalMask_;
static constexpr bool rope_enabled = RopeMask_;
static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize;

using MmaAtomShape = typename MmaAtom::Shape_MNK;
Expand Down Expand Up @@ -158,12 +160,19 @@ struct FlashPrefillMma<gemm::MainloopIntelXeXMX16<Stages>, ProblemShapeType_, El
StrideK dK;
ElementV const *ptr_V;
StrideV dV;
// for RoPE case
ElementQ const *ptr_cos = nullptr;
ElementQ const *ptr_sin = nullptr;
};

// Parameters built from Arguments: one tiled-copy object per tensor.
// Params is constructed by positional brace-initialization, so the member
// order below must stay in sync with the construction sites.
struct Params {
// Tiled copies for the Q, K and V operands.
XE_Copy_Q gmem_tiled_copy_q;
XE_Copy_K gmem_tiled_copy_k;
XE_Copy_V gmem_tiled_copy_v;
// Tiled copies for the RoPE cos/sin tables. They reuse the Q and K copy
// types because the tables are wrapped in tensors that take Q's and K's
// shape and stride.
XE_Copy_Q gmem_tiled_copy_q_cos;
XE_Copy_Q gmem_tiled_copy_q_sin;
XE_Copy_K gmem_tiled_copy_k_cos;
XE_Copy_K gmem_tiled_copy_k_sin;
};

//
Expand All @@ -181,11 +190,21 @@ struct FlashPrefillMma<gemm::MainloopIntelXeXMX16<Stages>, ProblemShapeType_, El
auto tensorQ = make_tensor(make_gmem_ptr(args.ptr_Q), make_layout(make_shape(seq_len_qo, head_size_qk, batch * num_heads_q), args.dQ));
auto tensorK = make_tensor(make_gmem_ptr(args.ptr_K), make_layout(make_shape(seq_len_kv, head_size_qk, batch * num_heads_kv), args.dK));
auto tensorV = make_tensor(make_gmem_ptr(args.ptr_V), make_layout(make_shape(head_size_vo, seq_len_kv, batch * num_heads_kv), args.dV));

auto tensorQCos = make_tensor(make_gmem_ptr(args.ptr_cos), make_layout(make_shape(seq_len_qo, head_size_qk, batch * num_heads_q), args.dQ));
auto tensorQSin = make_tensor(make_gmem_ptr(args.ptr_sin), make_layout(make_shape(seq_len_qo, head_size_qk, batch * num_heads_q), args.dQ));
auto tensorKCos = make_tensor(make_gmem_ptr(args.ptr_cos), make_layout(make_shape(seq_len_kv, head_size_qk, batch * num_heads_kv), args.dK));
auto tensorKSin = make_tensor(make_gmem_ptr(args.ptr_sin), make_layout(make_shape(seq_len_kv, head_size_qk, batch * num_heads_kv), args.dK));

XE_Copy_Q copyQ{XE_Copy_Q{}.with(tensorQ)};
XE_Copy_K copyK{XE_Copy_K{}.with(tensorK)};
XE_Copy_V copyV{XE_Copy_V{}.with(tensorV)};

return Params{copyQ, copyK, copyV};
XE_Copy_Q copyQCos{XE_Copy_Q{}.with(tensorQCos)};
XE_Copy_Q copyQSin{XE_Copy_Q{}.with(tensorQSin)};
XE_Copy_K copyKCos{XE_Copy_K{}.with(tensorKCos)};
XE_Copy_K copyKSin{XE_Copy_K{}.with(tensorKSin)};

return Params{copyQ, copyK, copyV, copyQCos, copyQSin, copyKCos, copyKSin};
}

template <class FragQccum, class TensorQ, class TensorK, class FragSrc>
Expand Down Expand Up @@ -372,11 +391,32 @@ struct FlashPrefillMma<gemm::MainloopIntelXeXMX16<Stages>, ProblemShapeType_, El
auto tensorK = make_tensor(make_gmem_ptr(k_ptr + offset_k), make_layout(shape_k, stride_k));
auto tensorV = make_tensor(make_gmem_ptr(v_ptr + offset_v), make_layout(shape_v, stride_v));

auto q_traits_cos = static_cast<traits_load_Q const&>(params.gmem_tiled_copy_q_cos);
ElementQ* base_ptr_q_cos = (ElementQ*)q_traits_cos.base_ptr;

auto q_traits_sin = static_cast<traits_load_Q const&>(params.gmem_tiled_copy_q_sin);
ElementQ* base_ptr_q_sin = (ElementQ*)q_traits_sin.base_ptr;

auto k_traits_cos = static_cast<traits_load_K const&>(params.gmem_tiled_copy_k_cos);
ElementK* base_ptr_k_cos = (ElementK*)k_traits_cos.base_ptr;

auto k_traits_sin = static_cast<traits_load_K const&>(params.gmem_tiled_copy_k_sin);
ElementK* base_ptr_k_sin = (ElementK*)k_traits_sin.base_ptr;

auto tensorQCos = make_tensor(make_gmem_ptr(base_ptr_q_cos + offset_q), make_layout(shape_q, stride_q));
auto tensorQSin = make_tensor(make_gmem_ptr(base_ptr_q_sin + offset_q), make_layout(shape_q, stride_q));
auto tensorKCos = make_tensor(make_gmem_ptr(base_ptr_k_cos + offset_k), make_layout(shape_k, stride_k));
auto tensorKSin = make_tensor(make_gmem_ptr(base_ptr_k_sin + offset_k), make_layout(shape_k, stride_k));

XE_Copy_Q copyQ{XE_Copy_Q{}.with(tensorQ)};
XE_Copy_K copyK{XE_Copy_K{}.with(tensorK)};
XE_Copy_V copyV{XE_Copy_V{}.with(tensorV)};
XE_Copy_Q copyQCos{XE_Copy_Q{}.with(tensorQCos)};
XE_Copy_Q copyQSin{XE_Copy_Q{}.with(tensorQSin)};
XE_Copy_K copyKCos{XE_Copy_K{}.with(tensorKCos)};
XE_Copy_K copyKSin{XE_Copy_K{}.with(tensorKSin)};

return Params{copyQ, copyK, copyV};
return Params{copyQ, copyK, copyV, copyQCos, copyQSin, copyKCos, copyKSin};
}
}
};
Expand Down
63 changes: 63 additions & 0 deletions applications/flash_attention_v2/collective/xe_rotary.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/***************************************************************************************************
* Copyright (c) 2025 Intel Corporation. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once

#include "cutlass/cutlass.h"

namespace cutlass::flash_attention::collective {
using namespace cute;

/// Applies interleaved rotary position embedding (RoPE) to one row of `srcTensor`.
///
/// A thread whose `thread_idx` is a valid mode-0 index of `srcTensor` walks that
/// row two columns at a time, treats columns (j, j + 1) as a (real, imag) pair
/// and rotates it by the cos/sin values read at (thread_idx, j):
///
///   dest[j]     = real * cos - imag * sin
///   dest[j + 1] = real * sin + imag * cos
///
/// The arithmetic is carried out in float; results are converted to
/// `Tensor::value_type` before being stored into `destTensor`.
///
/// \param thread_idx  Row handled by the calling thread. Threads whose index is
///                    not below size<0>(srcTensor) skip the rotation.
/// \param srcTensor   Input values, indexed as (row, column).
/// \param gCos        Cosine table; its mode-1 extent bounds the columns visited.
/// \param gSin        Sine table, indexed with the same coordinates as gCos.
/// \param destTensor  Output tensor, written at the coordinates that were read.
///
/// Every thread, including those that skipped the rotation, reaches the
/// trailing syncthreads() call.
template <typename Tensor,
          typename TensorCos, typename TensorSin, typename TensorOut>
CUTLASS_DEVICE void apply_rope_interleaved_gmem(
    int thread_idx,
    Tensor const &srcTensor,
    TensorCos const &gCos,
    TensorSin const &gSin, TensorOut &destTensor) {
  // NOTE(review): results are narrowed to the *source* element type, not
  // TensorOut's; presumably both tensors share one element type - confirm.
  using Element = typename Tensor::value_type;

  if (thread_idx < size<0>(srcTensor)) {
    // Loop-invariant column extent, hoisted out of the loop condition.
    auto const num_cols = size<1>(gCos);

    // Each iteration touches columns j and j + 1, so the bound is checked on
    // j + 1: with an odd extent the trailing unpaired column is left untouched
    // instead of being paired with an out-of-bounds element.
    for (int j = 0; j + 1 < num_cols; j += 2) {
      auto const even_coord = make_coord(thread_idx, j);
      auto const odd_coord = make_coord(thread_idx, j + 1);

      // Both inputs are read before either output is written, so the rotation
      // stays correct even if destTensor refers to the same storage as srcTensor.
      float const real = static_cast<float>(srcTensor[even_coord]);
      float const imag = static_cast<float>(srcTensor[odd_coord]);
      float const cos_val = static_cast<float>(gCos[even_coord]);
      float const sin_val = static_cast<float>(gSin[even_coord]);

      destTensor[even_coord] = static_cast<Element>(real * cos_val - imag * sin_val);
      destTensor[odd_coord] = static_cast<Element>(real * sin_val + imag * cos_val);
    }
  }
  // Kept outside the conditional so that all threads take part in the barrier.
  syncthreads();
}


} // namespace cutlass::flash_attention::collective
Loading