-
Notifications
You must be signed in to change notification settings - Fork 58
Separate output and accumulator type for Flash Attention Prefill #443
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Separate output and accumulator type for Flash Attention Prefill #443
Conversation
e948f6d
to
f64c7b7
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM - some nits.
|
||
std::vector<ElementAccumulator> vec_acc(block_acc.size()); | ||
syclcompat::memcpy<ElementAccumulator>(vec_acc.data(), block_acc.get(), vec_acc.size()); | ||
syclcompat::wait(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
syclcompat::memcpy
is synchronous/blocking so the wait
s here are unneeded.
|
||
std::vector<ElementAccumulator> vec_acc(block_acc.size()); | ||
syclcompat::memcpy<ElementAccumulator>(vec_acc.data(), block_acc.get(), vec_acc.size()); | ||
syclcompat::wait(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
syclcompat::wait(); |
vec_out[i] = static_cast<ElementOutput>(vec_acc[i]); | ||
} | ||
syclcompat::memcpy<ElementOutput>(block_ref_O.get() + offset_o, vec_out.data(), vec_out.size()); | ||
syclcompat::wait(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
syclcompat::wait(); |
syclcompat::memcpy<ElementOutput>(host_S.data(), block_S.get(), host_S.size()); | ||
std::vector<ElementAccumulator> host_S(block_S.size()); | ||
syclcompat::memcpy<ElementAccumulator>(host_S.data(), block_S.get(), host_S.size()); | ||
syclcompat::wait(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
syclcompat::wait(); |
|
||
std::vector<ElementAccumulator> vec_acc(block_acc.size()); | ||
syclcompat::memcpy<ElementAccumulator>(vec_acc.data(), block_acc.get(), vec_acc.size()); | ||
syclcompat::wait(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
syclcompat::wait(); |
vec_out[i] = static_cast<ElementOutput>(vec_acc[i]); | ||
} | ||
syclcompat::memcpy<ElementOutput>(block_ref_O.get() + offset_o, vec_out.data(), vec_out.size()); | ||
syclcompat::wait(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
syclcompat::wait(); |
|
||
copy(params.xe_store_o, out_reg, tOgO); | ||
Tensor final_out_reg = make_fragment_like<ElementOutput>(out_reg); | ||
if constexpr (cute::is_any_of_v<ElementOutput, cute::float_e5m2_t, cute::float_e4m3_t> || cute::is_same_v<ElementOutput, ElementCompute>) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you add a comment here to explain this if condition?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added this comment, not sure why it doesn't show up here.
* Add comment on final output type conversion
This PR adds tests for all the different data types supported with Flash Attention Prefill. It is a continuation of PR #443
This PR separates the output type and accumulator type for Flash Attention Prefill. Combinations supported are:
Tests added in: #446
Benchmarks added in: #447