
Conversation

muhammad-tanvir-1211

This PR separates the output type from the accumulator type for Flash Attention Prefill Cached. The supported combinations are:

  • bf16 inputs, fp32 accumulator, bf16 | fp32 output
  • fp16 inputs, fp32 accumulator, fp16 | fp32 output

It also fixes PagedKV cache support when used with variable-length sequences.
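A minimal sketch of what the split means in the epilogue, based on the final_out_reg line quoted in the review below (out_reg, tOgO, and params.xe_store_o are names from the diff; the explicit conversion loop is an assumption, not the merged code):

// Device-side epilogue sketch: the attention accumulator out_reg stays in
// ElementAccumulator (fp32); the store to global memory goes out in the
// user-selected ElementOutput (bf16 | fp16 | fp32).
Tensor final_out_reg = make_fragment_like<ElementOutput>(out_reg);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(out_reg); ++i) {
  final_out_reg(i) = static_cast<ElementOutput>(out_reg(i)); // fp32 -> ElementOutput
}
copy(params.xe_store_o, final_out_reg, tOgO); // store the converted fragment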


@joeatodd left a comment


LGTM - some suggestions.

using TiledMmaOutput = typename TiledMMAHelper<MMA_Atom<MMAOperation_>, Layout<TileShapeOutput>, SubgroupLayout>::TiledMMA;
using GmemTiledCopyO = CopyOpO;
using ElementOutput = ElementO_;
using ElementCompute = ElementO_;


Suggested change
using ElementCompute = ElementO_;
using ElementCompute = ElementCompute_;

Tensor tOgO = thread_xe_store_o.partition_D(gO);

copy(params.xe_store_o, out_reg, tOgO);
Tensor final_out_reg = make_fragment_like<ElementOutput>(out_reg);


as in #443, I think a comment explaining this if/else would be useful.


Added the comment after this line.

bool mode_implementable = args.mode == gemm::GemmUniversalMode::kGemm or
    (args.mode == gemm::GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4);
return mode_implementable;
bool valid_page_size = !PagedKV ? true : args.mainloop.page_size >= QK_BLK_N && args.mainloop.page_size % QK_BLK_N == 0;


Suggested change
bool valid_page_size = !PagedKV ? true : args.mainloop.page_size >= QK_BLK_N && args.mainloop.page_size % QK_BLK_N == 0;
bool valid_page_size = !PagedKV || (args.mainloop.page_size >= QK_BLK_N && args.mainloop.page_size % QK_BLK_N == 0);
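The two forms are logically equivalent, but the suggested one avoids the degenerate ternary. A tiny standalone check of the page-size constraint, with hypothetical values (QK_BLK_N = 64 and page_size = 128 are made up for illustration):

#include <cassert>

int main() {
  constexpr bool PagedKV = true;
  constexpr int QK_BLK_N = 64; // hypothetical K/V block size
  int page_size = 128;         // hypothetical page size
  // A page must cover at least one K/V block and be a multiple of the block size.
  bool valid_page_size = !PagedKV || (page_size >= QK_BLK_N && page_size % QK_BLK_N == 0);
  assert(valid_page_size);
  return 0;
}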

Comment on lines +549 to +552
int seq_len_cache = isVarLen ? cumulative_seqlen_kv_cache[b + 1] - cumulative_seqlen_kv_cache[b] : seq_len_kv_cache;
int pages_per_seq = ceil_div(seq_len_cache, paged_kv_cache.page_size);
num_pages_per_seq.push_back(num_pages_per_seq.back() + pages_per_seq);
num_pages += pages_per_seq;


Am I understanding correctly that num_pages_per_seq actually stores the offset of the first page of each sequence? If so, I'd say it's misnamed; something like seq_page_offsets would fit better.


For variable-length sequences, seq_len_cache is not constant, so each batch has its own set of pages. num_pages_per_seq stores the indices of the pages across all of the batches; its size would be batch * num_pages_for_all_batches.
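A standalone sketch of what the quoted loop computes under variable-length sequences (the page size and per-batch lengths below are made up for illustration). num_pages_per_seq ends up as a running prefix sum, so num_pages_per_seq[b] is the offset of batch b's first page:

#include <cstdio>
#include <vector>

static int ceil_div(int a, int b) { return (a + b - 1) / b; }

int main() {
  int page_size = 64;                         // hypothetical page size
  std::vector<int> seq_len_cache = {96, 160}; // hypothetical per-batch cached lengths
  std::vector<int> num_pages_per_seq = {0};   // prefix sums: first-page offset per batch
  int num_pages = 0;
  for (int len : seq_len_cache) {
    int pages_per_seq = ceil_div(len, page_size);
    num_pages_per_seq.push_back(num_pages_per_seq.back() + pages_per_seq);
    num_pages += pages_per_seq;
  }
  // num_pages_per_seq == {0, 2, 5}, num_pages == 5
  for (int v : num_pages_per_seq) printf("%d ", v);
  printf("| total pages: %d\n", num_pages);
  return 0;
}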

Comment on lines +559 to +566
std::vector<int> physical_pages(num_pages_per_seq[b + 1] - num_pages_per_seq[b]);
std::iota(physical_pages.begin(), physical_pages.end(), 0);
// shuffle physical pages
std::shuffle(physical_pages.begin(), physical_pages.end(), std::mt19937{ std::random_device{}() });
for (int blk = 0; blk < physical_pages.size(); ++blk) {
  int logical_idx = num_pages_per_seq[b] + blk;
  page_mapping[logical_idx] = physical_pages[blk];
}


Maybe this is correct, but I think the page_mapping for each seq will contain the same indices (though shuffled differently).

In other words, a value such as 0 will appear cute::get<0>(problem_shape) times in the final page_mapping.

Is this expected?


Yes, this is true; we may get repeated indices. This is needed for variable-length sequences, where each batch can have a different number of pages and the physical page mapping for each batch can vary, so we need this bigger vector to hold the mapping information.
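To make the repetition concrete, here is a standalone sketch of the mapping construction, reusing the hypothetical offsets {0, 2, 5} from the sketch above. Each batch shuffles its own 0..n-1 range, so a physical index like 0 shows up once per batch:

#include <algorithm>
#include <cstdio>
#include <numeric>
#include <random>
#include <vector>

int main() {
  std::vector<int> num_pages_per_seq = {0, 2, 5}; // hypothetical first-page offsets per batch
  std::vector<int> page_mapping(num_pages_per_seq.back());
  std::mt19937 rng{std::random_device{}()};
  for (size_t b = 0; b + 1 < num_pages_per_seq.size(); ++b) {
    int n = num_pages_per_seq[b + 1] - num_pages_per_seq[b];
    std::vector<int> physical_pages(n);
    std::iota(physical_pages.begin(), physical_pages.end(), 0); // each batch gets 0..n-1
    std::shuffle(physical_pages.begin(), physical_pages.end(), rng);
    for (int blk = 0; blk < n; ++blk)
      page_mapping[num_pages_per_seq[b] + blk] = physical_pages[blk];
  }
  // Possible output: 1 0 2 0 1 -- index 0 appears once per batch.
  for (int v : page_mapping) printf("%d ", v);
  printf("\n");
  return 0;
}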

int max_idx = row;
for (int col = 0; col < seq_len_kv_total; col++, idx++) {
host_S[idx] = expf((host_S[idx] - max_vec[max_idx]) * softmax_scale);
host_S[idx] = expf((host_S[idx] - max_vec[max_idx]) / std::sqrt(static_cast<ElementAccumulator>((head_size_qk))));


This change leaves the softmax_scale argument of verify unused. There are also extra parentheses around head_size_qk.
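If softmax_scale is expected to equal 1 / sqrt(head_size_qk) here, one way to address both points would be to keep the argument in use and drop the doubled parentheses (a suggestion under that assumption, not the merged code):

// Assumes softmax_scale == 1.0f / std::sqrt(static_cast<ElementAccumulator>(head_size_qk)).
host_S[idx] = expf((host_S[idx] - max_vec[max_idx]) * softmax_scale);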

@aacostadiaz merged commit 5377d14 into intel:sycl-develop on Jun 30, 2025
15 of 21 checks passed