Separate output and accumulator type for Flash Attention Prefill Cached #448
Conversation
LGTM - some suggestions.
using TiledMmaOutput = typename TiledMMAHelper<MMA_Atom<MMAOperation_>, Layout<TileShapeOutput>, SubgroupLayout>::TiledMMA;
using GmemTiledCopyO = CopyOpO;
using ElementOutput = ElementO_;
using ElementCompute = ElementO_;
Suggested change:
- using ElementCompute = ElementO_;
+ using ElementCompute = ElementCompute_;
Tensor tOgO = thread_xe_store_o.partition_D(gO);

copy(params.xe_store_o, out_reg, tOgO);
Tensor final_out_reg = make_fragment_like<ElementOutput>(out_reg);
as in #443, I think a comment explaining this if/else would be useful.
Added the comment after this line.
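For readers of this thread, a minimal standalone sketch of the idea behind that if/else (plain C++, not the actual CuTe epilogue code; the helper name convert_for_store is illustrative only): when ElementOutput differs from the accumulator type, the register fragment is converted element-wise right before the global-memory store.

#include <cstddef>
#include <vector>

// Analogue of make_fragment_like<ElementOutput>(out_reg): allocate a fragment with the
// same extent but the (possibly narrower) output element type, convert into it, store that.
template <class ElementOutput, class ElementAccumulator>
std::vector<ElementOutput> convert_for_store(const std::vector<ElementAccumulator>& out_reg) {
  std::vector<ElementOutput> final_out_reg(out_reg.size());
  for (std::size_t i = 0; i < out_reg.size(); ++i) {
    final_out_reg[i] = static_cast<ElementOutput>(out_reg[i]); // narrow only at the very end
  }
  return final_out_reg; // this is the fragment the tiled copy would write to global memory
}

Presumably the other branch of the if/else stores out_reg directly when the output and accumulator types coincide, which is why the explanatory comment is helpful.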
bool mode_implementable = args.mode == gemm::GemmUniversalMode::kGemm or
    (args.mode == gemm::GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4);
return mode_implementable;
bool valid_page_size = !PagedKV ? true : args.mainloop.page_size >= QK_BLK_N && args.mainloop.page_size % QK_BLK_N == 0;
Suggested change:
- bool valid_page_size = !PagedKV ? true : args.mainloop.page_size >= QK_BLK_N && args.mainloop.page_size % QK_BLK_N == 0;
+ bool valid_page_size = !PagedKV || (args.mainloop.page_size >= QK_BLK_N && args.mainloop.page_size % QK_BLK_N == 0);
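As a side note on what that condition encodes, here is a hedged standalone sketch (the helper name is hypothetical, and the page-boundary reasoning is an inference rather than something stated in the PR): with paged KV, the page size must cover at least one K/V tile (QK_BLK_N) and be an exact multiple of it.

// Hypothetical free-standing version of the check, for illustration only.
constexpr bool page_size_is_valid(bool paged_kv, int page_size, int qk_blk_n) {
  return !paged_kv || (page_size >= qk_blk_n && page_size % qk_blk_n == 0);
}

static_assert(page_size_is_valid(false, 0, 64), "non-paged KV: always valid");
static_assert(page_size_is_valid(true, 128, 64), "128 is a multiple of 64");
static_assert(!page_size_is_valid(true, 96, 64), "96 is not a multiple of 64");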
int seq_len_cache = isVarLen ? cumulative_seqlen_kv_cache[b + 1] - cumulative_seqlen_kv_cache[b] : seq_len_kv_cache;
int pages_per_seq = ceil_div(seq_len_cache, paged_kv_cache.page_size);
num_pages_per_seq.push_back(num_pages_per_seq.back() + pages_per_seq);
num_pages += pages_per_seq;
Am I understanding correctly that num_pages_per_seq actually stores the offset of the first page of each sequence? If so, I'd say that it is misnamed. It should be called seq_page_offsets or something like that.
For variable-length sequences, seq_len_cache is not constant, so each batch has its own set of pages. num_pages_per_seq is storing all the indices of the pages in all of the batches. Its size would be batch * num_pages_for_all_batches.
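To illustrate the two readings above, here is a small self-contained example built the same way as the snippet under review (all values are made up): the vector starts at 0 and accumulates ceil_div(seq_len_cache, page_size) per batch, so entry b ends up being the flat index of batch b's first page, i.e. it behaves like a per-sequence page offset.

#include <cassert>
#include <vector>

int main() {
  // Per-batch KV-cache lengths for a variable-length run (made-up values).
  std::vector<int> seq_len_cache = {200, 50, 130};
  int page_size = 64;

  // Built like the snippet above: start at 0, then keep a running sum of
  // ceil_div(seq_len_cache[b], page_size) per batch.
  std::vector<int> num_pages_per_seq = {0};
  int num_pages = 0;
  for (int len : seq_len_cache) {
    int pages_per_seq = (len + page_size - 1) / page_size; // ceil_div
    num_pages_per_seq.push_back(num_pages_per_seq.back() + pages_per_seq);
    num_pages += pages_per_seq;
  }

  // Result: {0, 4, 5, 8}. Entry b is where batch b's pages begin in the flat
  // page table (hence the naming question raised above).
  assert(num_pages_per_seq[1] - num_pages_per_seq[0] == 4); // batch 0 owns pages [0, 4)
  assert(num_pages == num_pages_per_seq.back());
  return 0;
}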
std::vector<int> physical_pages(num_pages_per_seq[b + 1] - num_pages_per_seq[b]);
std::iota(physical_pages.begin(), physical_pages.end(), 0);
// shuffle physical pages
std::shuffle(physical_pages.begin(), physical_pages.end(), std::mt19937{ std::random_device{}() });
for (int blk = 0; blk < physical_pages.size(); ++blk) {
  int logical_idx = num_pages_per_seq[b] + blk;
  page_mapping[logical_idx] = physical_pages[blk];
}
Maybe this is correct, but I think the page_mapping for each seq will contain the same indices (though shuffled differently). In other words, the value e.g. 0 will appear cute::get<0>(problem_shape) times in the final page_mapping. Is this expected?
Yes, this is true. We may get repeated indices. This is needed for variable-length sequences, where each batch can have a different number of pages and the physical page mapping for each batch can vary, so we need this bigger vector to hold the mapping information.
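A small standalone rework of the snippet under discussion makes the repetition concrete (the offsets reuse the made-up values from the previous sketch): each batch's slice of page_mapping is filled with batch-local physical indices starting at 0, so small values repeat once per batch, exactly as noted above.

#include <algorithm>
#include <numeric>
#include <random>
#include <vector>

int main() {
  // Offsets as built above (made-up values): batch 0 owns logical pages [0, 4),
  // batch 1 owns [4, 5), batch 2 owns [5, 8).
  std::vector<int> num_pages_per_seq = {0, 4, 5, 8};
  std::vector<int> page_mapping(num_pages_per_seq.back());

  for (int b = 0; b + 1 < (int)num_pages_per_seq.size(); ++b) {
    // Physical pages are numbered from 0 within each batch, then shuffled, so the
    // same small indices (0, 1, ...) reappear for every batch in page_mapping.
    std::vector<int> physical_pages(num_pages_per_seq[b + 1] - num_pages_per_seq[b]);
    std::iota(physical_pages.begin(), physical_pages.end(), 0);
    std::shuffle(physical_pages.begin(), physical_pages.end(), std::mt19937{std::random_device{}()});
    for (int blk = 0; blk < (int)physical_pages.size(); ++blk) {
      page_mapping[num_pages_per_seq[b] + blk] = physical_pages[blk];
    }
  }
  // One possible result: {2, 0, 3, 1,  0,  1, 0, 2} -- the value 0 appears once per batch.
  return 0;
}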
int max_idx = row;
for (int col = 0; col < seq_len_kv_total; col++, idx++) {
  host_S[idx] = expf((host_S[idx] - max_vec[max_idx]) * softmax_scale);
  host_S[idx] = expf((host_S[idx] - max_vec[max_idx]) / std::sqrt(static_cast<ElementAccumulator>((head_size_qk))));
This change leaves the softmax_scale argument to verify unused. Also, there are extra () brackets around head_size_qk.
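One possible way to address this, sketched here as a suggestion rather than the PR's actual fix (the helper name and signature are hypothetical): compute 1/sqrt(head_size_qk) once on the caller side, pass it in as softmax_scale, and keep the original multiply so the argument stays in use.

#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical helper mirroring the row update in the host reference.
void scale_and_exp_row(std::vector<float>& host_S, std::size_t row_begin,
                       std::size_t seq_len_kv_total, float row_max, float softmax_scale) {
  for (std::size_t col = 0; col < seq_len_kv_total; ++col) {
    std::size_t idx = row_begin + col;
    host_S[idx] = std::exp((host_S[idx] - row_max) * softmax_scale);
  }
}

// Caller side (assumption): float softmax_scale = 1.0f / std::sqrt(static_cast<float>(head_size_qk));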
This PR separates the output type and accumulator type for Flash Attention Prefill Cached. Combinations supported are:
It also fixes PagedKV cache support when used with variable-length sequences.