muhammad-tanvir-1211 commented on Jun 25, 2025

This PR separates the output type and accumulator type for Flash Attention Prefill. Combinations supported are:

  • bf16 inputs, fp32 accumulator, bf16 | fp32 output
  • fp16 inputs, fp32 accumulator, fp16 | fp32 output
  • fp8 inputs, fp32 accumulator, fp8 | fp32 output

Tests added in: #446
Benchmarks added in: #447
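
To make the accumulator/output split concrete, here is a minimal host-side sketch in plain C++. It is not the kernel code from this PR: `dot_with_output_type` and `convert_accumulator` are hypothetical helpers, and ordinary built-in types stand in for the bf16/fp16/fp8/fp32 element types, which standard C++ does not provide portably. The idea it illustrates is only that arithmetic runs in `ElementAccumulator` and the narrowing to `ElementOutput` happens once, at the end.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (not from the PR): sums products in ElementAccumulator
// and converts to ElementOutput only once, when the result is returned.
template <class ElementOutput, class ElementAccumulator, class ElementInput>
ElementOutput dot_with_output_type(const std::vector<ElementInput>& a,
                                   const std::vector<ElementInput>& b) {
  assert(a.size() == b.size());  // both operands must have the same length
  ElementAccumulator acc{0};
  for (std::size_t i = 0; i < a.size(); ++i) {
    acc += static_cast<ElementAccumulator>(a[i]) *
           static_cast<ElementAccumulator>(b[i]);
  }
  return static_cast<ElementOutput>(acc);
}

// Hypothetical helper: element-wise conversion of an accumulator buffer to the
// output type, in the spirit of the static_cast loop quoted in the review.
template <class ElementOutput, class ElementAccumulator>
std::vector<ElementOutput> convert_accumulator(
    const std::vector<ElementAccumulator>& vec_acc) {
  std::vector<ElementOutput> vec_out(vec_acc.size());
  for (std::size_t i = 0; i < vec_acc.size(); ++i) {
    vec_out[i] = static_cast<ElementOutput>(vec_acc[i]);
  }
  return vec_out;
}
```

For instance, `dot_with_output_type<float, double>(a, b)` would accumulate in `double` and return a `float`, which mirrors (by analogy only) accumulating in fp32 and emitting bf16 or fp16.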

muhammad-tanvir-1211 force-pushed the flash_prefill_separate_out_type branch from e948f6d to f64c7b7 on June 25, 2025 22:25

joeatodd left a comment

LGTM - some nits.


std::vector<ElementAccumulator> vec_acc(block_acc.size());
syclcompat::memcpy<ElementAccumulator>(vec_acc.data(), block_acc.get(), vec_acc.size());
syclcompat::wait();

syclcompat::memcpy is synchronous (blocking), so the syclcompat::wait() calls after it are unnecessary.


std::vector<ElementAccumulator> vec_acc(block_acc.size());
syclcompat::memcpy<ElementAccumulator>(vec_acc.data(), block_acc.get(), vec_acc.size());
syclcompat::wait();

Suggested change (delete this line):
syclcompat::wait();

vec_out[i] = static_cast<ElementOutput>(vec_acc[i]);
}
syclcompat::memcpy<ElementOutput>(block_ref_O.get() + offset_o, vec_out.data(), vec_out.size());
syclcompat::wait();

Suggested change (delete this line):
syclcompat::wait();

syclcompat::memcpy<ElementOutput>(host_S.data(), block_S.get(), host_S.size());
std::vector<ElementAccumulator> host_S(block_S.size());
syclcompat::memcpy<ElementAccumulator>(host_S.data(), block_S.get(), host_S.size());
syclcompat::wait();

Suggested change (delete this line):
syclcompat::wait();


std::vector<ElementAccumulator> vec_acc(block_acc.size());
syclcompat::memcpy<ElementAccumulator>(vec_acc.data(), block_acc.get(), vec_acc.size());
syclcompat::wait();

Suggested change (delete this line):
syclcompat::wait();

vec_out[i] = static_cast<ElementOutput>(vec_acc[i]);
}
syclcompat::memcpy<ElementOutput>(block_ref_O.get() + offset_o, vec_out.data(), vec_out.size());
syclcompat::wait();

Suggested change (delete this line):
syclcompat::wait();


copy(params.xe_store_o, out_reg, tOgO);
Tensor final_out_reg = make_fragment_like<ElementOutput>(out_reg);
if constexpr (cute::is_any_of_v<ElementOutput, cute::float_e5m2_t, cute::float_e4m3_t> || cute::is_same_v<ElementOutput, ElementCompute>) {

Could you add a comment here to explain this if condition?

Added this comment, not sure why it doesn't show up here.
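
For readers unfamiliar with the shape of the condition under discussion, the following standalone sketch reproduces it with standard C++17 traits. Everything here is an assumption made for illustration: `is_any_of_v` is a hand-written stand-in for `cute::is_any_of_v`, `fake_e5m2` and `fake_e4m3` are placeholder tags for the two fp8 types, and `select_path` is a hypothetical function. The bodies of the two branches are not shown in this thread, so the sketch only reports which branch would be selected at compile time and says nothing about what the real code does in each.

```cpp
#include <type_traits>

// Stand-in for cute::is_any_of_v: true if T is the same as any type in Us.
template <class T, class... Us>
inline constexpr bool is_any_of_v = (std::is_same_v<T, Us> || ...);

// Placeholder tags standing in for cute::float_e5m2_t and cute::float_e4m3_t.
struct fake_e5m2 {};
struct fake_e4m3 {};

// Which branch of the if constexpr is taken (the real branch bodies are
// not visible in the conversation, so only the selection is modelled).
enum class Path { SpecialCase, Default };

template <class ElementOutput, class ElementCompute>
constexpr Path select_path() {
  // Same shape as the reviewed condition: the output type is one of the fp8
  // types, or the output type is identical to the compute type.
  if constexpr (is_any_of_v<ElementOutput, fake_e5m2, fake_e4m3> ||
                std::is_same_v<ElementOutput, ElementCompute>) {
    return Path::SpecialCase;
  } else {
    return Path::Default;
  }
}
```

Because the condition is evaluated with `if constexpr`, only the selected branch is instantiated for a given pair of types, which is presumably why the original code can place type-specific operations in each branch.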

aacostadiaz merged commit d5f1886 into intel:sycl-develop on Jun 28, 2025
26 of 48 checks passed
aacostadiaz pushed a commit that referenced this pull request Jun 30, 2025
This PR adds tests for all the different data types supported with Flash
Attention Prefill. It is a continuation of PR #443