@@ -53,15 +53,14 @@ template <class DispatchPolicy, class MMAOperation_, class TileShapeOutput_, cla
static_assert(cutlass::detail::dependent_false<DispatchPolicy>, "Could not find an epilogue specialization.");
};

template <class MMAOperation_, class TileShapeOutput_, class SubgroupLayout_, class ElementO_, class StrideO_, class ElementLSE_, class CopyOpO_>
class FlashPrefillCachedEpilogue<epilogue::IntelXeXMX16, MMAOperation_, TileShapeOutput_, SubgroupLayout_, ElementO_, StrideO_, ElementLSE_, CopyOpO_> {
template <class MMAOperation_, class TileShapeOutput_, class SubgroupLayout_, class ElementCompute_, class ElementO_, class StrideO_, class ElementLSE_, class CopyOpO_>
class FlashPrefillCachedEpilogue<epilogue::IntelXeXMX16, MMAOperation_, TileShapeOutput_, SubgroupLayout_, ElementCompute_, ElementO_, StrideO_, ElementLSE_, CopyOpO_> {
public:
//
// Type Aliases
//
using DispatchPolicy = epilogue::IntelXeXMX16;
using ElementO = ElementO_;
using ElementAccumulator = ElementO_;
using StrideO = StrideO_;
using ElementLSE = ElementLSE_;
using CopyOpO = CopyOpO_;
@@ -70,7 +69,8 @@ class FlashPrefillCachedEpilogue<epilogue::IntelXeXMX16, MMAOperation_, TileSha
using TiledMmaOutput = typename TiledMMAHelper<MMA_Atom<MMAOperation_>, Layout<TileShapeOutput>, SubgroupLayout>::TiledMMA;
using GmemTiledCopyO = CopyOpO;
using ElementOutput = ElementO_;
using ElementCompute = ElementO_;
using ElementCompute = ElementCompute_;
using ElementAccumulator = ElementCompute_;
using SubgroupTileShape = decltype(cute::shape_div(TileShapeOutput{}, (SubgroupLayout{}.shape())));

static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize;
@@ -197,7 +197,17 @@ class FlashPrefillCachedEpilogue<epilogue::IntelXeXMX16, MMAOperation_, TileSha
auto thread_xe_store_o = params.xe_store_o.get_thread_slice(ThreadIdxX());
Tensor tOgO = thread_xe_store_o.partition_D(gO);

copy(params.xe_store_o, out_reg, tOgO);
Tensor final_out_reg = make_fragment_like<ElementOutput>(out_reg);

as in #443, I think a comment explaining this if/else would be useful.

Added the comment after this line.

// if ElementOutput == ElementAccumulator, convert_type doesn't perform the right conversion,
// so we call copy(), which internally applies a static_cast to the data.
// for ElementOutput == bf16 | fp16, convert_type calls the relevant NumericConverter specialization.
if constexpr (cute::is_same_v<ElementOutput, ElementCompute>) {
copy(out_reg, final_out_reg);
} else {
Tensor temp = convert_type<ElementOutput>(out_reg);
copy(temp, final_out_reg);
}
copy(params.xe_store_o, final_out_reg, tOgO);
}

// SequenceLengthShapeType = Shape<int, int, int>
@@ -164,7 +164,7 @@ struct FlashPrefillCachedMma<gemm::MainloopIntelXeXMX16<Stages>, ProblemShapeTyp
// Paged KV Cache
int const* ptr_page_table;
int page_size;
int num_pages_per_seq;
int const* num_pages_per_seq;
};

struct Params {
@@ -176,7 +176,7 @@ struct FlashPrefillCachedMma<gemm::MainloopIntelXeXMX16<Stages>, ProblemShapeTyp
// Paged KV Cache
int const* ptr_page_table;
int page_size;
int num_pages_per_seq;
int const* num_pages_per_seq;
};

//
@@ -178,7 +178,8 @@ class FMHAPrefillCached {
static bool can_implement(Arguments const &args) {
bool mode_implementable = args.mode == gemm::GemmUniversalMode::kGemm or
(args.mode == gemm::GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4);
return mode_implementable;
bool valid_page_size = !PagedKV || (args.mainloop.page_size >= QK_BLK_N && args.mainloop.page_size % QK_BLK_N == 0);
return mode_implementable && valid_page_size;
}
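For context, the new valid_page_size guard requires each KV-cache page to hold a whole number of K/V tiles, so a tile never straddles a page boundary and tiles_per_page = page_size / QK_BLK_N is exact. A minimal standalone sketch of the same condition (the QK_BLK_N value and helper name here are illustrative, not taken from the kernel):

// Hedged sketch: mirrors the valid_page_size condition in can_implement.
// Assumes QK_BLK_N is the K/V tile extent along the sequence dimension.
constexpr int QK_BLK_N = 64;                      // example value only
inline bool page_size_is_valid(int page_size) {
  return page_size >= QK_BLK_N && page_size % QK_BLK_N == 0;
}
// page_size_is_valid(128) -> true  (tiles_per_page = 2)
// page_size_is_valid(96)  -> false (96 % 64 != 0, a tile would span two pages)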

static int get_workspace_size(Arguments const &args) { return 0; }
@@ -314,10 +315,22 @@ class FMHAPrefillCached {
}
auto& prefetch_K = (seq_len_kv_cache == 0) ? tiled_prefetch_k: tiled_prefetch_k_cache;
auto& pKgK1_ = (seq_len_kv_cache == 0) ? pKgK: pKgK_cache;

int cached_nblock = 0;
if constexpr (PagedKV) {
if (seq_len_kv_cache != 0) {
int curr_batch_pages = is_var_len ? mainloop_params.num_pages_per_seq[batch_coord + 1] - mainloop_params.num_pages_per_seq[batch_coord]
: ceil_div(seq_len_kv_cache, mainloop_params.page_size);
int batch_offset = is_var_len ? mainloop_params.num_pages_per_seq[batch_coord] : batch_coord * curr_batch_pages;
cached_nblock = mainloop_params.ptr_page_table[
batch_offset // page table for this batch
] * tiles_per_page; // base block idx of physical page
}
}
// The head size is the same for both the cached and non-cached versions
for (int j = 0; j < size<4>(pKgK1_); j++) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < DispatchPolicy::Stages; i++) {
for (int i = cached_nblock; i < cached_nblock + DispatchPolicy::Stages; i++) {
prefetch(prefetch_K, pKgK1_(_, _, _ , i, j));
}
}
@@ -345,18 +358,6 @@ class FMHAPrefillCached {

bool is_KV_cache = nblock < nblock_cache;

int cached_nblock = nblock;
if constexpr (PagedKV) {
if (is_KV_cache) {
// get physical page idx from page table
cached_nblock = params.mainloop.ptr_page_table[
batch_coord * params.mainloop.num_pages_per_seq + // page table for this batch
nblock * QK_BLK_N / params.mainloop.page_size // nblock (tile idx) to logical page idx
] * tiles_per_page + // base block idx of physical page
nblock % tiles_per_page; // offset within page
}
}

// 1) Load KV (performed inside mmaQK)
auto gK_ = is_KV_cache ? gK_cache(_, _, cached_nblock, _) : gK(_, _, nblock - nblock_cache, _);
auto gV_ = is_KV_cache ? gV_cache(_, _, cached_nblock) : gV(_, _, nblock - nblock_cache);
@@ -372,8 +373,32 @@ class FMHAPrefillCached {
// prefetching it the same way as cutlass K matrix does not make sense
auto& tiled_prefetch_v_ = is_KV_cache ? tiled_prefetch_v_cache : tiled_prefetch_v;
auto& pVgV_ = is_KV_cache ? pVgV_cache : pVgV;
for(int i=0; i < size<1>(pVgV); i++) {
prefetch(tiled_prefetch_v_, pVgV_cache(_, i, _ , nblock - (!is_KV_cache) * nblock_cache));
int v_prefetch_idx = is_KV_cache ? PagedKV ? cached_nblock : nblock
: nblock - nblock_cache;
for(int i = 0; i < size<1>(pVgV_); i++) {
prefetch(tiled_prefetch_v_, pVgV_(_, i, _ , v_prefetch_idx));
}

int next_cached_nblock = nblock + 1;
bool is_next_KV_cache = next_cached_nblock < nblock_cache;
if constexpr (PagedKV) {
if (is_next_KV_cache) {
int curr_batch_pages = is_var_len ? mainloop_params.num_pages_per_seq[batch_coord + 1] - mainloop_params.num_pages_per_seq[batch_coord]
: ceil_div(seq_len_kv_cache, mainloop_params.page_size);
int next_page_logical_idx = next_cached_nblock * QK_BLK_N / params.mainloop.page_size;
int batch_offset = is_var_len ? mainloop_params.num_pages_per_seq[batch_coord] : batch_coord * curr_batch_pages;
bool valid_page = next_page_logical_idx < curr_batch_pages;
// get physical page idx from page table
if (valid_page) {
next_cached_nblock = params.mainloop.ptr_page_table[
batch_offset + // page table for this batch
next_page_logical_idx // nblock (tile idx) to logical page idx
] * tiles_per_page + // base block idx of physical page
next_cached_nblock % tiles_per_page; // offset within page
} else {
next_cached_nblock = curr_batch_pages * tiles_per_page; // push idx out of bounds to respect the boundary between batches
}
}
}
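For reference, the page-table lookup above (like the initial one before the main loop) boils down to translating a logical K/V tile index into a physical block index through the page table. A minimal standalone sketch, with illustrative names, assuming page_size is a multiple of QK_BLK_N:

// Hedged sketch of the logical-to-physical translation; not the kernel code itself.
// batch_offset is num_pages_per_seq[batch] for var-len, batch * pages_per_batch otherwise.
inline int logical_to_physical_block(int block_idx,         // logical tile index in this sequence
                                     int const* page_table,  // flat page table across batches
                                     int batch_offset,       // first page-table entry of this batch
                                     int tiles_per_page) {   // page_size / QK_BLK_N
  int logical_page = block_idx / tiles_per_page;             // which cache page the tile lives in
  int within_page  = block_idx % tiles_per_page;             // tile offset inside that page
  return page_table[batch_offset + logical_page] * tiles_per_page + within_page;
}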

// 4) Fused softmax
@@ -382,16 +407,26 @@ class FMHAPrefillCached {

// 5) Perform GEMM O = S*V
collective_mma.template mmaPV<VSlicer>(out_reg, tSr, gV_, out_reg, mainloop_params, is_KV_cache);


// Prefetch the next Q tile
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size<3>(pQgQ); i++) {
prefetch(tiled_prefetch_q, pQgQ(_, _, _, i));
}

is_KV_cache = is_next_KV_cache;
cached_nblock = next_cached_nblock;
// Prefetch the next K tile
// there is no need to guard it with an if statement as prefetch ignores out-of-bound reads

bool sel_prefetch_k = (nblock + DispatchPolicy::Stages) < nblock_cache;
auto& prefetch_k_selector = sel_prefetch_k ? tiled_prefetch_k_cache: tiled_prefetch_k;
auto& pKgK_ = sel_prefetch_k ? pKgK_cache : pKgK;
int k_prefetch_idx = sel_prefetch_k ? PagedKV ? cached_nblock : nblock + DispatchPolicy::Stages
: nblock + DispatchPolicy::Stages - nblock_cache;
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < size<4>(pKgK_); j++) {
prefetch(prefetch_k_selector, pKgK_(_, _, _, (nblock + DispatchPolicy::Stages) - (!sel_prefetch_k) * nblock_cache , j));
prefetch(prefetch_k_selector, pKgK_(_, _, _, k_prefetch_idx , j));
}
barrier_wait(barrier_scope);
}
@@ -406,8 +441,8 @@ class FMHAPrefillCached {
collective_mma.mmaQK(tSr, gQ, gK(_, _, nblock_new - 1, _), tSr, ceil_div(head_size_qk, QK_BLK_K), mainloop_params, false);
// we only need one block ahead, there is enough gap to prefetch it while doing softmax. because the gap between the two MMA is big,
// prefetching it the same way as cutlass K matrix does not make sense
for(int i=0; i< size<1>(pVgV); i++) {
prefetch(tiled_prefetch_v, pVgV(_, i, _ , nblock_new - 1));
for(int i = 0; i< size<1>(pVgV); i++) {
prefetch(tiled_prefetch_v, pVgV(_, i, _ , nblock_new - 1));
}
// mask the elements of each tile where j > i
const int item_id = thread_idx % SubgroupSize;
@@ -420,7 +455,7 @@ class FMHAPrefillCached {
CUTLASS_PRAGMA_UNROLL
for (int row = 0; row < Vec; row++, row_idx++) { // 8
if (col_idx - full_tile_offset > row_idx - discard_seq_coord) {
tSr(row, m, n) = -INFINITY;
tSr(row, m, n) = ElementAccumulator{-INFINITY};
}
}
}
@@ -207,7 +207,7 @@ template <class FMHAPrefillConfiguration> struct BenchmarkRunnerFMHA {
int kv_group_update=1;

for (int h = 0; h < num_heads_q; h++) {
cutlass::DeviceAllocation<ElementOutput> block_S;
cutlass::DeviceAllocation<ElementAccumulator> block_S;
block_S.reset(seq_len_qo * seq_len_kv_total);

ElementK* k_ptr;
@@ -254,11 +254,10 @@ template <class FMHAPrefillConfiguration> struct BenchmarkRunnerFMHA {
cutlass::TensorRef ref_K(k_ptr, LayoutK::packed({head_size_qk, seq_len_kv_total}));
cutlass::TensorRef ref_V(v_ptr, LayoutV::packed({seq_len_kv_total, head_size_vo}));
cutlass::TensorRef ref_S(block_S.get(), LayoutQ::packed({seq_len_qo, seq_len_kv_total}));
cutlass::TensorRef ref_O(block_ref_O.get() + offset_o, LayoutO::packed({seq_len_qo, head_size_vo}));

cutlass::reference::device::GemmComplex({seq_len_qo, seq_len_kv_total, head_size_qk}, 1.f, ref_Q,
cutlass::reference::device::GemmComplex({seq_len_qo, seq_len_kv_total, head_size_qk}, ElementAccumulator{1}, ref_Q,
cutlass::ComplexTransform::kNone, ref_K, cutlass::ComplexTransform::kNone,
0.f, ref_S, ref_S, ElementAccumulator(0),
ElementAccumulator{0}, ref_S, ref_S, ElementAccumulator{0},
1, // batch_count
seq_len_qo * head_size_qk, // batch_stride_Q
seq_len_kv_total * head_size_qk, // batch_stride_K
@@ -268,8 +267,8 @@ template <class FMHAPrefillConfiguration> struct BenchmarkRunnerFMHA {

syclcompat::wait();

std::vector<ElementOutput> host_S(block_S.size());
syclcompat::memcpy<ElementOutput>(host_S.data(), block_S.get(), host_S.size());
std::vector<ElementAccumulator> host_S(block_S.size());
syclcompat::memcpy<ElementAccumulator>(host_S.data(), block_S.get(), host_S.size());
syclcompat::wait();

// delete this memory as it is no longer needed
@@ -283,13 +282,13 @@ template <class FMHAPrefillConfiguration> struct BenchmarkRunnerFMHA {
for (int row = 0; row < seq_len_qo; row++) {
for (int col = start_col; col < seq_len_kv_total; col++) {
if (col - full_tile_offset > row + start_col - discard_seq_coord)
host_S[col + row * seq_len_kv_total] = -INFINITY;
host_S[col + row * seq_len_kv_total] = ElementAccumulator{-INFINITY};
}
}
}

// compute max element per row of S
std::vector<ElementOutput> max_vec(seq_len_qo, -INFINITY);
std::vector<ElementAccumulator> max_vec(seq_len_qo, ElementAccumulator{-INFINITY});
for (int row = 0; row < seq_len_qo; row++) {
int idx = row * seq_len_kv_total;
int max_idx = row;
@@ -305,12 +304,12 @@ template <class FMHAPrefillConfiguration> struct BenchmarkRunnerFMHA {
int idx = row * seq_len_kv_total;
int max_idx = row;
for (int col = 0; col < seq_len_kv_total; col++, idx++) {
host_S[idx] = expf((host_S[idx] - max_vec[max_idx]) / std::sqrt(static_cast<ElementOutput>((head_size_qk))));
host_S[idx] = expf((host_S[idx] - max_vec[max_idx]) / std::sqrt(static_cast<ElementAccumulator>((head_size_qk))));
}
}

// compute sum per row of S
std::vector<ElementOutput> sum_vec(seq_len_qo, ElementOutput{0});
std::vector<ElementAccumulator> sum_vec(seq_len_qo, ElementAccumulator{0});
for (int row = 0; row < seq_len_qo; row++) {
int idx = row * seq_len_kv_total;
int sum_idx = row;
@@ -342,9 +341,13 @@ template <class FMHAPrefillConfiguration> struct BenchmarkRunnerFMHA {

cutlass::TensorRef ref_P(block_P.get(), LayoutQ::packed({seq_len_qo, seq_len_kv_total}));

cutlass::reference::device::GemmComplex({seq_len_qo, head_size_vo, seq_len_kv_total}, 1.f, ref_P,
cutlass::DeviceAllocation<ElementAccumulator> block_acc;
block_acc.reset(seq_len_qo * head_size_vo);
cutlass::TensorRef ref_acc(block_acc.get(), LayoutO::packed({seq_len_qo, head_size_vo}));

cutlass::reference::device::GemmComplex({seq_len_qo, head_size_vo, seq_len_kv_total}, ElementAccumulator{1}, ref_P,
cutlass::ComplexTransform::kNone, ref_V, cutlass::ComplexTransform::kNone,
0.f, ref_O, ref_O, ElementAccumulator(0),
ElementAccumulator{0}, ref_acc, ref_acc, ElementAccumulator{0},
1, // batch_count
seq_len_qo * seq_len_kv_total, // batch_stride_P
seq_len_kv_total * head_size_vo, // batch_stride_V
@@ -356,6 +359,19 @@ template <class FMHAPrefillConfiguration> struct BenchmarkRunnerFMHA {
// delete this memory as it is no longer needed
block_P.reset();

std::vector<ElementAccumulator> vec_acc(block_acc.size());
syclcompat::memcpy<ElementAccumulator>(vec_acc.data(), block_acc.get(), vec_acc.size());
syclcompat::wait();

// delete this memory as it is no longer needed
block_acc.reset();
std::vector<ElementOutput> vec_out(vec_acc.size());
for(int i = 0; i < vec_out.size(); i++) {
vec_out[i] = static_cast<ElementOutput>(vec_acc[i]);
}
syclcompat::memcpy<ElementOutput>(block_ref_O.get() + offset_o, vec_out.data(), vec_out.size());
syclcompat::wait();
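Presumably the point of keeping the reference math in ElementAccumulator and converting to ElementOutput only in this final step is to avoid rounding every intermediate value to bf16/fp16. A tiny self-contained illustration of that effect (types and values chosen for the example only, not part of the benchmark):

// Hedged sketch: repeated rounding to bf16 drifts, while a float accumulator
// converted once at the end stays close to the exact sum.
#include <cutlass/bfloat16.h>
#include <cstdio>

int main() {
  float acc_f = 0.0f;
  cutlass::bfloat16_t acc_bf(0.0f);
  for (int i = 0; i < 4096; ++i) {
    acc_f = acc_f + 0.001f;                                // full-precision accumulation
    acc_bf = cutlass::bfloat16_t(float(acc_bf) + 0.001f);  // rounded to 8 mantissa bits each step
  }
  // acc_f ends near 4.096; float(acc_bf) stalls well below that because each partial sum is rounded
  std::printf("float acc: %f  bf16 acc: %f\n", acc_f, float(acc_bf));
  return 0;
}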

offset_q += seq_len_qo * head_size_qk;
if(kv_group_update % q_group_size==0) {
offset_k += seq_len_kv * head_size_qk;
@@ -372,7 +388,7 @@ template <class FMHAPrefillConfiguration> struct BenchmarkRunnerFMHA {

// Check if output from CUTLASS kernel and reference kernel are equal or not
bool passed = cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_O.get(), block_O.get(),
block_O.size(), 0.5f, 0.5f);
block_O.size(), ElementOutput{0.5}, ElementOutput{0.5});

return passed;
}
@@ -69,7 +69,7 @@ struct FMHAPrefillConfig {
using MMAOperation = typename MMAOP<GEMMDispatchPolicy, ElementInputType,ElementAccumulator>::Type;
using CollectiveEpilogue = cutlass::flash_attention::collective::FlashPrefillCachedEpilogue<
EpilogueDispatchPolicy, MMAOperation, TileShapeOutput,
SubgroupLayout, ElementAccumulator,
SubgroupLayout, ElementAccumulator, ElementOutput,
cutlass::gemm::TagToStrideC_t<LayoutO>, ElementOutput,
GmemTiledCopyO>;
