-
Notifications
You must be signed in to change notification settings - Fork 64
Separate output and accumulator type for Flash Attention Prefill #443
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
|
|
@@ -189,18 +189,17 @@ template <class FMHAPrefillConfiguration> struct BenchmarkRunnerFMHA { | |||
| } | ||||
| int kv_group_update=1; | ||||
| for (int h = 0; h < num_heads_q; h++) { | ||||
| cutlass::DeviceAllocation<ElementOutput> block_S; | ||||
| cutlass::DeviceAllocation<ElementAccumulator> block_S; | ||||
| block_S.reset(seq_len_qo * seq_len_kv); | ||||
|
|
||||
| cutlass::TensorRef ref_Q(block_Q[0].get() + offset_q, LayoutQ::packed({seq_len_qo, head_size_qk})); | ||||
| cutlass::TensorRef ref_K(block_K[0].get() + offset_k, LayoutK::packed({head_size_qk, seq_len_kv})); | ||||
| cutlass::TensorRef ref_V(block_V[0].get() + offset_v, LayoutV::packed({seq_len_kv, head_size_vo})); | ||||
| cutlass::TensorRef ref_S(block_S.get(), LayoutQ::packed({seq_len_qo, seq_len_kv})); | ||||
| cutlass::TensorRef ref_O(block_ref_O.get() + offset_o, LayoutO::packed({seq_len_qo, head_size_vo})); | ||||
|
|
||||
| cutlass::reference::device::GemmComplex({seq_len_qo, seq_len_kv, head_size_qk}, 1.f, ref_Q, | ||||
| cutlass::reference::device::GemmComplex({seq_len_qo, seq_len_kv, head_size_qk}, ElementAccumulator{1.f}, ref_Q, | ||||
| cutlass::ComplexTransform::kNone, ref_K, cutlass::ComplexTransform::kNone, | ||||
| 0.f, ref_S, ref_S, ElementAccumulator(0), | ||||
| ElementAccumulator{0}, ref_S, ref_S, ElementAccumulator{0}, | ||||
| 1, // batch_count | ||||
| seq_len_qo * head_size_qk, // batch_stride_Q | ||||
| seq_len_kv * head_size_qk, // batch_stride_K | ||||
|
|
@@ -210,8 +209,8 @@ template <class FMHAPrefillConfiguration> struct BenchmarkRunnerFMHA { | |||
|
|
||||
| syclcompat::wait(); | ||||
|
|
||||
| std::vector<ElementOutput> host_S(block_S.size()); | ||||
| syclcompat::memcpy<ElementOutput>(host_S.data(), block_S.get(), host_S.size()); | ||||
| std::vector<ElementAccumulator> host_S(block_S.size()); | ||||
| syclcompat::memcpy<ElementAccumulator>(host_S.data(), block_S.get(), host_S.size()); | ||||
| syclcompat::wait(); | ||||
|
||||
| syclcompat::wait(); |
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| syclcompat::wait(); |
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| syclcompat::wait(); |
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
|
|
@@ -230,14 +230,13 @@ template <class FMHAPrefillKernel, bool isVarLen> struct ExampleRunner { | |||
| } | ||||
| int kv_group_update=1; | ||||
| for (int h = 0; h < num_heads_q; h++) { | ||||
| cutlass::DeviceAllocation<ElementOutput> block_S; | ||||
| cutlass::DeviceAllocation<ElementAccumulator> block_S; | ||||
| block_S.reset(seq_len_qo * seq_len_kv); | ||||
|
|
||||
| cutlass::TensorRef ref_Q(block_Q_.get() + offset_q, LayoutQ::packed({seq_len_qo, head_size_qk})); | ||||
| cutlass::TensorRef ref_K(block_K_.get() + offset_k, LayoutK::packed({head_size_qk, seq_len_kv})); | ||||
| cutlass::TensorRef ref_V(block_V_.get() + offset_v, LayoutV::packed({seq_len_kv, head_size_vo})); | ||||
| cutlass::TensorRef ref_S(block_S.get(), LayoutQ::packed({seq_len_qo, seq_len_kv})); | ||||
| cutlass::TensorRef ref_O(block_ref_O.get() + offset_o, LayoutO::packed({seq_len_qo, head_size_vo})); | ||||
|
|
||||
| cutlass::reference::device::GemmComplex({seq_len_qo, seq_len_kv, head_size_qk}, 1.f, ref_Q, | ||||
| cutlass::ComplexTransform::kNone, ref_K, cutlass::ComplexTransform::kNone, | ||||
|
|
@@ -251,8 +250,8 @@ template <class FMHAPrefillKernel, bool isVarLen> struct ExampleRunner { | |||
|
|
||||
| syclcompat::wait(); | ||||
|
|
||||
| std::vector<ElementOutput> host_S(block_S.size()); | ||||
| syclcompat::memcpy<ElementOutput>(host_S.data(), block_S.get(), host_S.size()); | ||||
| std::vector<ElementAccumulator> host_S(block_S.size()); | ||||
| syclcompat::memcpy<ElementAccumulator>(host_S.data(), block_S.get(), host_S.size()); | ||||
| syclcompat::wait(); | ||||
|
|
||||
| // delete this memory as it is no longer needed | ||||
|
|
@@ -265,13 +264,13 @@ template <class FMHAPrefillKernel, bool isVarLen> struct ExampleRunner { | |||
| for (int row = 0; row < seq_len_qo; row++) { | ||||
| for (int col = 0; col < seq_len_kv; col++) { | ||||
| if ((col - full_tile_offset) > (row - discard_seq_coord)) | ||||
| host_S[col + row * seq_len_kv] = -INFINITY; | ||||
| host_S[col + row * seq_len_kv] = ElementAccumulator{-INFINITY}; | ||||
| } | ||||
| } | ||||
| } | ||||
|
|
||||
| // compute max element per row of S | ||||
| std::vector<ElementOutput> max_vec(seq_len_qo, -INFINITY); | ||||
| std::vector<ElementAccumulator> max_vec(seq_len_qo, ElementAccumulator{-INFINITY}); | ||||
| for (int row = 0; row < seq_len_qo; row++) { | ||||
| int idx = row * seq_len_kv; | ||||
| int max_idx = row; | ||||
|
|
@@ -287,12 +286,12 @@ template <class FMHAPrefillKernel, bool isVarLen> struct ExampleRunner { | |||
| int idx = row * seq_len_kv; | ||||
| int max_idx = row; | ||||
| for (int col = 0; col < seq_len_kv; col++, idx++) { | ||||
| host_S[idx] = expf((host_S[idx] - max_vec[max_idx]) / sqrt(static_cast<ElementOutput>((head_size_qk)))); | ||||
| host_S[idx] = expf((host_S[idx] - max_vec[max_idx]) / sqrt(static_cast<ElementAccumulator>((head_size_qk)))); | ||||
| } | ||||
| } | ||||
|
|
||||
| // compute sum per row of S | ||||
| std::vector<ElementOutput> sum_vec(seq_len_qo, ElementOutput{0}); | ||||
| std::vector<ElementAccumulator> sum_vec(seq_len_qo, ElementAccumulator{0}); | ||||
| for (int row = 0; row < seq_len_qo; row++) { | ||||
| int idx = row * seq_len_kv; | ||||
| int sum_idx = row; | ||||
|
|
@@ -324,9 +323,13 @@ template <class FMHAPrefillKernel, bool isVarLen> struct ExampleRunner { | |||
|
|
||||
| cutlass::TensorRef ref_P(block_P.get(), LayoutQ::packed({seq_len_qo, seq_len_kv})); | ||||
|
|
||||
| cutlass::reference::device::GemmComplex({seq_len_qo, head_size_vo, seq_len_kv}, 1.f, ref_P, | ||||
| cutlass::DeviceAllocation<ElementAccumulator> block_acc; | ||||
| block_acc.reset(seq_len_qo * head_size_vo); | ||||
| cutlass::TensorRef ref_acc(block_acc.get(), LayoutO::packed({seq_len_qo, head_size_vo})); | ||||
|
|
||||
| cutlass::reference::device::GemmComplex({seq_len_qo, head_size_vo, seq_len_kv}, ElementAccumulator{1}, ref_P, | ||||
| cutlass::ComplexTransform::kNone, ref_V, cutlass::ComplexTransform::kNone, | ||||
| 0.f, ref_O, ref_O, ElementAccumulator(0), | ||||
| ElementAccumulator{0}, ref_acc, ref_acc, ElementAccumulator{0}, | ||||
| 1, // batch_count | ||||
| seq_len_qo * seq_len_kv, // batch_stride_P | ||||
| seq_len_kv * head_size_vo, // batch_stride_V | ||||
|
|
@@ -338,6 +341,19 @@ template <class FMHAPrefillKernel, bool isVarLen> struct ExampleRunner { | |||
| // delete this memory as it is no longer needed | ||||
| block_P.reset(); | ||||
|
|
||||
| std::vector<ElementAccumulator> vec_acc(block_acc.size()); | ||||
| syclcompat::memcpy<ElementAccumulator>(vec_acc.data(), block_acc.get(), vec_acc.size()); | ||||
| syclcompat::wait(); | ||||
|
||||
| syclcompat::wait(); |
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| syclcompat::wait(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you add a comment here to explain this if condition?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added this comment, not sure why it doesn't show up here.