@@ -31,6 +31,7 @@
#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/fp8_to_fp16.h"
#include "cutlass/gemm/dispatch_policy.hpp"

#include "cute/algorithm/functional.hpp"
@@ -147,6 +148,8 @@ struct FlashPrefillCachedMma<gemm::MainloopIntelXeXMX16<Stages>, ProblemShapeTyp
using atom_load_V = Copy_Atom<traits_load_V, ElementV>;
using val_layout_load_V = decltype(make_layout(shape_div(typename traits_load_V::BlockShape{}, CopyThreadShape{})));
using XE_Copy_V = decltype(make_tiled_copy(atom_load_V{}, Layout<CopyThreadShape>{}, val_layout_load_V{}));
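// Detect FP8 (e4m3 / e5m2) element types: FP8 operands are widened to FP16 in
// registers before each MMA call below.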
template <typename T>
static constexpr bool is_fp8_v = cute::is_any_of_v<T, float_e4m3_t, float_e5m2_t>;

// Host side kernel arguments
struct Arguments {
@@ -227,10 +230,11 @@ struct FlashPrefillCachedMma<gemm::MainloopIntelXeXMX16<Stages>, ProblemShapeTyp
Tensor tCgK = thread_mma_k.partition_B(gK);

// Create fragments
// TODO(Codeplay): fix this, this is probably not general
Tensor tCrQ = make_tensor<ElementQ>(make_fragment_layout(params.gmem_tiled_copy_q, take<0,3>(tCgQ.shape())));
Tensor tCrK = make_tensor<ElementK>(make_fragment_layout(gmem_tiled_copy_k, take<0,3>(tCgK.shape())));

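// For FP8 element types, the register fragments hold the raw bytes (uint8_t);
// the data is widened to FP16 just before the MMA.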
using TCrQ_Type = cute::conditional_t<is_fp8_v<ElementQ>, uint8_t, ElementQ>;
using TCrK_Type = cute::conditional_t<is_fp8_v<ElementK>, uint8_t, ElementK>;
Tensor tCrQ = make_tensor<TCrQ_Type>(make_fragment_layout(params.gmem_tiled_copy_q, take<0,3>(tCgQ.shape())));
Tensor tCrK = make_tensor<TCrK_Type>(make_fragment_layout(gmem_tiled_copy_k, take<0,3>(tCgK.shape())));

// Retile registers for copies
Tensor tQrQ = thr_copy_Q.retile_D(tCrQ);
Tensor tKrK = thr_copy_K.retile_D(tCrK);
@@ -270,7 +274,23 @@ struct FlashPrefillCachedMma<gemm::MainloopIntelXeXMX16<Stages>, ProblemShapeTyp
for (int k_tile = 0; k_tile < k_tile_count; ++k_tile) {
copy(params.gmem_tiled_copy_q, tQgQ(_,_,_,k_tile), tQrQ);
copy(gmem_tiled_copy_k, tKgK(_,_,_,k_tile), tKrK);
cute::gemm(tiled_mma, accum, tCrQ, tCrK, frag_src);
if constexpr (is_fp8_v<ElementQ> && is_fp8_v<ElementK>) {
auto tCrQ_fp16 = make_fragment_like<half_t>(tCrQ);
convert_FP8_to_FP16<ElementQ>(tCrQ, tCrQ_fp16);
auto tCrK_fp16 = make_fragment_like<half_t>(tCrK);
convert_FP8_to_FP16<ElementK>(tCrK, tCrK_fp16);
cute::gemm(tiled_mma, accum, tCrQ_fp16, tCrK_fp16, frag_src);
} else if constexpr (is_fp8_v<ElementQ> && !is_fp8_v<ElementK>) {
auto tCrQ_fp16 = make_fragment_like<half_t>(tCrQ);
convert_FP8_to_FP16<ElementQ>(tCrQ, tCrQ_fp16);
cute::gemm(tiled_mma, accum, tCrQ_fp16, tCrK, frag_src);
} else if constexpr (!is_fp8_v<ElementQ> && is_fp8_v<ElementK>) {
auto tCrK_fp16 = make_fragment_like<half_t>(tCrK);
convert_FP8_to_FP16<ElementK>(tCrK, tCrK_fp16);
cute::gemm(tiled_mma, accum, tCrQ, tCrK_fp16, frag_src);
} else {
cute::gemm(tiled_mma, accum, tCrQ, tCrK, frag_src);
}
}
}
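
For reference, the FP8 branches above rely on convert_FP8_to_FP16 from the newly included cutlass/fp8_to_fp16.h. A minimal sketch of the shape such a routine could take, assuming a plain per-element widening cast (the name convert_fp8_fragment_sketch and the scalar loop are illustrative assumptions, not the header's actual implementation, which may be vectorized):

// Hypothetical sketch only: widen an FP8 fragment (stored as raw uint8_t bytes)
// into an FP16 fragment, one element at a time.
template <typename FP8Type, class SrcTensor, class DstTensor>
CUTLASS_DEVICE void convert_fp8_fragment_sketch(SrcTensor const &src, DstTensor &dst) {
  CUTLASS_PRAGMA_UNROLL
  for (int i = 0; i < cute::size(src); ++i) {
    // Reinterpret the stored byte as the FP8 type, then widen via float.
    FP8Type e = reinterpret_cast<FP8Type const &>(src(i));
    dst(i) = static_cast<cute::half_t>(static_cast<float>(e));
  }
}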

@@ -289,7 +309,8 @@ struct FlashPrefillCachedMma<gemm::MainloopIntelXeXMX16<Stages>, ProblemShapeTyp
auto first_thread_in_sg_idx = sg.get_group_id()[0] * DispatchPolicy::SubgroupSize;
auto thread_mma = tiled_mma.get_slice(first_thread_in_sg_idx);
Tensor tCgV = thread_mma.partition_B(gV_);
Tensor tCrV = make_tensor<ElementV>(make_fragment_layout(gmem_tiled_copy_v, take<0,3>(tCgV.shape())));
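// As with Q and K, an FP8 V tile is held as raw bytes until it is widened.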
using TCrV_Type = cute::conditional_t<is_fp8_v<ElementV>, uint8_t, ElementV>;
Tensor tCrV = make_tensor<TCrV_Type>(make_fragment_layout(gmem_tiled_copy_v, take<0,3>(tCgV.shape())));

// Partition the copying of A and B tiles across the threads
auto gmem_thr_copy_V = gmem_tiled_copy_v.get_slice(thread_idx);
@@ -321,7 +342,13 @@ struct FlashPrefillCachedMma<gemm::MainloopIntelXeXMX16<Stages>, ProblemShapeTyp
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < tile_count; i++) {
copy(gmem_tiled_copy_v, tVgV(_,_,_,i), tVrV);
cute::gemm(tiled_mma, accum(_,_,_,i), tPr, tCrV, frag_src(_,_,_,i));
if constexpr (is_fp8_v<ElementV>) {
auto tCrV_fp16 = make_fragment_like<half_t>(tCrV);
convert_FP8_to_FP16<ElementV>(tCrV, tCrV_fp16);
cute::gemm(tiled_mma, accum(_,_,_,i), tPr, tCrV_fp16, frag_src(_,_,_,i));
} else {
cute::gemm(tiled_mma, accum(_,_,_,i), tPr, tCrV, frag_src(_,_,_,i));
}
}
}

122 changes: 122 additions & 0 deletions examples/06_bmg_flash_attention/06_bmg_prefill_attention_prefill_cachedKV_fp8.cpp
@@ -0,0 +1,122 @@
/***************************************************************************************************
* Copyright (c) 2025 - 2025 Codeplay Software Ltd. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Flash Attention V2 Prefill with KV cache (FP8) for Intel BMG

This example constructs and executes an FP8 (e5m2) Flash Attention Prefill with KV cache on Intel
BMG. The GEMM definition, options, etc. for this example are provided in the associated
bmg_flash_attn_prefill_cachedKV_runner.hpp header file.

See https://arxiv.org/pdf/2307.08691 for details of Flash Attention V2 algorithm

To run this example:
$ ./examples/sycl/06_bmg_flash_attention/06_bmg_prefill_attention_cachedKV_fp8_hdim128 --seq_len_qo=512
--seq_len_kv=512 --seq_len_kv_cache=512 --head_size_vo=128 --head_size_qk=128

Causal masking of the first matrix multiplication is supported (`--is_causal`)

To build & run this example (from your build dir):

$ ninja 06_bmg_prefill_attention_cachedKV_fp8_hdim128
$ ./examples/sycl/06_bmg_flash_attention/06_bmg_prefill_attention_cachedKV_fp8_hdim128

Call with `--help` for information about available options
*/

#include "bmg_flash_attn_prefill_cachedKV_runner.hpp"

int main(int argc, const char **argv) {
//
// Parse options
//

Options options;

options.parse(argc, argv);

if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}

if (options.error) {
std::cerr << "Aborting execution." << std::endl;
return -1;
}

// Define the work-group tile shape depending on the head size of the second matmul
// Shape<_SequenceLengthOutputBlock, _HeadSizeOut(N_V), _SequenceLengthKVBlock(N_K/V), _HeadSizeQKBlock(K_QK), _HeadSizeOutSlicerBlock>
//
#if !defined(HEAD_DIM)
std::cerr << "HEAD_DIM must be defined" << std::endl;
return -1;
#endif
if (options.head_size_vo != HEAD_DIM) {
std::cerr << "head_size_vo must be " << HEAD_DIM << ", but got " << options.head_size_vo << std::endl;
return -1;
}

using ElementInputQ = cutlass::float_e5m2_t; // <- data type of elements in query matrix Q
using ElementInputKV = cutlass::float_e5m2_t; // <- data type of elements in key/value matrices K and V
using MMAOperation = XE_8x16x16_F32F16F16F32_TT;
using GmemTiledCopyQ = XE_2D_U8x8x32_LD_N;
using GmemTiledCopyK = XE_2D_U8x16x16_LD_T; // _T designates a transposed block load operation
using GmemTiledCopyV = XE_2D_U8x32x32_LD_V;
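// FP8 inputs use 8-bit (U8) block-load copy atoms; the MMA atom itself computes
// in FP16 on tiles widened in-register by the mainloop.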
constexpr int PipelineStages = 2;
#if HEAD_DIM == 64
using ShapeQK = Shape<_128, _64, _64>;
using ShapePV = Shape<_128, _32, _64>;
using ShapeOutPut = Shape<_128, _64, _64>;
using SubgroupLayout = Layout<Shape<_8, _1, _1>, Stride<_1, _1, _1>>;
#elif HEAD_DIM == 96
using ShapeQK = Shape<_128, _64, _32>;
using ShapePV = Shape<_128, _32, _64>;
using ShapeOutPut = Shape<_128, _96, _64>;
using SubgroupLayout = Layout<Shape<_8, _1, _1>, Stride<_1, _1, _1>>;
#elif HEAD_DIM == 128
using ShapeQK = Shape<_128, _64, _64>;
using ShapePV = Shape<_128, _32, _64>;
using ShapeOutPut = Shape<_128, _128, _64>;
using SubgroupLayout = Layout<Shape<_16, _1, _1>, Stride<_1, _1, _1>>;
#elif HEAD_DIM == 192
using ShapeQK = Shape<_256, _64, _64>;
using ShapePV = Shape<_256, _32, _64>;
using ShapeOutPut = Shape<_256, _192, _64>;
using SubgroupLayout = Layout<Shape<_32, _1, _1>, Stride<_1, _1, _1>>;
#endif
return options.is_causal ? FMHAConfig<true, ShapeQK, ShapePV, ShapeOutPut, SubgroupLayout, PipelineStages,
ElementInputQ, ElementInputKV, MMAOperation,
GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV>::run(options)
: FMHAConfig<false, ShapeQK, ShapePV, ShapeOutPut, SubgroupLayout, PipelineStages,
ElementInputQ, ElementInputKV, MMAOperation,
GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV>::run(options);

}
10 changes: 10 additions & 0 deletions examples/06_bmg_flash_attention/CMakeLists.txt
@@ -54,12 +54,21 @@ foreach(HEAD_DIM 64 96 128 192)
TEST_NO_PAGED
TEST_PAGED
)

cutlass_example_add_executable(
06_bmg_prefill_attention_fp8_hdim${HEAD_DIM}
06_bmg_prefill_attention_fp8.cpp
TEST_COMMAND_OPTIONS
)

cutlass_example_add_executable(
06_bmg_prefill_attention_cachedKV_fp8_hdim${HEAD_DIM}
06_bmg_prefill_attention_prefill_cachedKV_fp8.cpp
TEST_COMMAND_OPTIONS
TEST_NO_PAGED
TEST_PAGED
)

cutlass_example_add_executable(
06_bmg_decode_attention_fp8_hdim${HEAD_DIM}
06_bmg_decode_attention_fp8.cpp
@@ -71,5 +80,6 @@ foreach(HEAD_DIM 64 96 128 192)
target_compile_definitions(06_bmg_prefill_attention_cachedKV_hdim${HEAD_DIM} PRIVATE HEAD_DIM=${HEAD_DIM})
target_compile_definitions(06_bmg_decode_attention_hdim${HEAD_DIM} PRIVATE HEAD_DIM=${HEAD_DIM})
target_compile_definitions(06_bmg_prefill_attention_fp8_hdim${HEAD_DIM} PRIVATE HEAD_DIM=${HEAD_DIM})
target_compile_definitions(06_bmg_prefill_attention_cachedKV_fp8_hdim${HEAD_DIM} PRIVATE HEAD_DIM=${HEAD_DIM})
target_compile_definitions(06_bmg_decode_attention_fp8_hdim${HEAD_DIM} PRIVATE HEAD_DIM=${HEAD_DIM})
endforeach()
@@ -203,6 +203,27 @@ template <class FMHAPrefillCachedKernel, bool isVarLen> struct ExampleRunner {
};
PagedKVParams paged_kv_cache;

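// Widens an FP8 device buffer to FP16 element-by-element, so the host-side
// reference computation can run on a widened copy of the data.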
template <typename SrcT, typename DstT>
void convert_fp8_to_fp16(const SrcT* d_src, DstT* d_dst, size_t size) {
syclcompat::get_default_queue().parallel_for(size, [=](auto indx) {
d_dst[indx] = static_cast<DstT>(d_src[indx]);
}).wait();
}

template <typename T>
static constexpr bool is_fp8_v = cute::is_any_of_v<T, cute::float_e5m2_t, cute::float_e4m3_t>;

template <typename Tin> inline auto in_memory(cutlass::DeviceAllocation<Tin>& in) {
using outType = cutlass::DeviceAllocation<cute::conditional_t<is_fp8_v<Tin>, half_t, Tin>>;
if constexpr(is_fp8_v<Tin>) {
cutlass::DeviceAllocation<half_t> out(in.size());
convert_fp8_to_fp16<Tin, half_t>(in.get(), out.get(), in.size());
return out;
} else {
return in;
}
}
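// in_memory lets the reference path treat FP8 and non-FP8 runs uniformly: an FP8
// allocation comes back as a fresh half_t copy, while any other element type is
// returned unchanged (see block_Q_, block_K_, block_V_ below).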

//
// Methods
//
@@ -221,6 +242,11 @@ template <class FMHAPrefillCachedKernel, bool isVarLen> struct ExampleRunner {
auto [batch, num_heads_q, num_heads_kv, head_size_qk, head_size_vo] = cute::select<0,1,2,6,7>(problem_size);
int seq_len_qo, seq_len_kv, seq_len_kv_cache;

auto block_Q_ = in_memory(block_Q);
auto block_K_ = in_memory(block_K);
auto block_V_ = in_memory(block_V);
using ElementK_ = cute::conditional_t<is_fp8_v<ElementK>, half_t, ElementK>;
using ElementV_ = cute::conditional_t<is_fp8_v<ElementV>, half_t, ElementV>;
int offset_q = 0;
int offset_k = 0;
int offset_v = 0;
@@ -248,45 +274,47 @@ template <class FMHAPrefillCachedKernel, bool isVarLen> struct ExampleRunner {
cutlass::DeviceAllocation<ElementAccumulator> block_S;
block_S.reset(seq_len_qo * seq_len_kv_total);

ElementK* k_ptr;
ElementV* v_ptr;
ElementK_* k_ptr;
ElementV_* v_ptr;

if (use_kv_cache) {
cutlass::DeviceAllocation<ElementK> block_K_concat(head_size_qk * seq_len_kv_total);
cutlass::DeviceAllocation<ElementV> block_V_concat(seq_len_kv_total * head_size_vo);
auto block_K_cache_ = in_memory(block_K_cache);
auto block_V_cache_ = in_memory(block_V_cache);
cutlass::DeviceAllocation<ElementK_> block_K_concat(head_size_qk * seq_len_kv_total);
cutlass::DeviceAllocation<ElementV_> block_V_concat(seq_len_kv_total * head_size_vo);

// Concatenate K_cache and K
syclcompat::memcpy<ElementK>(
syclcompat::memcpy<ElementK_>(
block_K_concat.get(),
block_K_cache.get() + offset_k_cache,
block_K_cache_.get() + offset_k_cache,
seq_len_kv_cache * head_size_qk
);
syclcompat::memcpy<ElementK>(
syclcompat::memcpy<ElementK_>(
block_K_concat.get() + seq_len_kv_cache * head_size_qk,
block_K.get() + offset_k,
block_K_.get() + offset_k,
seq_len_kv * head_size_qk
);

// Concatenate V_cache and V
syclcompat::memcpy<ElementV>(
syclcompat::memcpy<ElementV_>(
block_V_concat.get(),
block_V_cache.get() + offset_v_cache,
block_V_cache_.get() + offset_v_cache,
seq_len_kv_cache * head_size_vo
);
syclcompat::memcpy<ElementV>(
syclcompat::memcpy<ElementV_>(
block_V_concat.get() + seq_len_kv_cache * head_size_vo,
block_V.get() + offset_v,
block_V_.get() + offset_v,
seq_len_kv * head_size_vo
);

k_ptr = block_K_concat.get();
v_ptr = block_V_concat.get();
} else {
k_ptr = block_K.get() + offset_k;
v_ptr = block_V.get() + offset_v;
k_ptr = block_K_.get() + offset_k;
v_ptr = block_V_.get() + offset_v;
}

cutlass::TensorRef ref_Q(block_Q.get() + offset_q, LayoutQ::packed({seq_len_qo, head_size_qk}));
cutlass::TensorRef ref_Q(block_Q_.get() + offset_q, LayoutQ::packed({seq_len_qo, head_size_qk}));
cutlass::TensorRef ref_K(k_ptr, LayoutK::packed({head_size_qk, seq_len_kv_total}));
cutlass::TensorRef ref_V(v_ptr, LayoutV::packed({seq_len_kv_total, head_size_vo}));
cutlass::TensorRef ref_S(block_S.get(), LayoutQ::packed({seq_len_qo, seq_len_kv_total}));
@@ -364,14 +392,14 @@ template <class FMHAPrefillCachedKernel, bool isVarLen> struct ExampleRunner {
}
}

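// Cast the softmax result S to the (possibly widened) V element type before the
// reference P*V GEMM.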
std::vector<ElementV> host_P(host_S.size());
std::vector<ElementV_> host_P(host_S.size());
for (int p = 0; p < host_P.size(); p++)
host_P[p] = static_cast<ElementV>(host_S[p]);
host_P[p] = static_cast<ElementV_>(host_S[p]);

cutlass::DeviceAllocation<ElementV> block_P;
cutlass::DeviceAllocation<ElementV_> block_P;
block_P.reset(host_P.size());

syclcompat::memcpy<ElementV>(block_P.get(), host_P.data(), host_P.size());
syclcompat::memcpy<ElementV_>(block_P.get(), host_P.data(), host_P.size());

cutlass::TensorRef ref_P(block_P.get(), LayoutQ::packed({seq_len_qo, seq_len_kv_total}));
