@@ -53,15 +53,14 @@ template <class DispatchPolicy, class MMAOperation_, class TileShapeOutput_, cla
static_assert(cutlass::detail::dependent_false<DispatchPolicy>, "Could not find an epilogue specialization.");
};

template <class MMAOperation_, class TileShapeOutput_, class SubgroupLayout_, class ElementO_, class StrideO_, class ElementLSE_, class CopyOpO_>
class FlashPrefillEpilogue<epilogue::IntelXeXMX16, MMAOperation_, TileShapeOutput_, SubgroupLayout_, ElementO_, StrideO_, ElementLSE_, CopyOpO_> {
template <class MMAOperation_, class TileShapeOutput_, class SubgroupLayout_, class ElementCompute_, class ElementO_, class StrideO_, class ElementLSE_, class CopyOpO_>
class FlashPrefillEpilogue<epilogue::IntelXeXMX16, MMAOperation_, TileShapeOutput_, SubgroupLayout_, ElementCompute_, ElementO_, StrideO_, ElementLSE_, CopyOpO_> {
public:
//
// Type Aliases
//
using DispatchPolicy = epilogue::IntelXeXMX16;
using ElementO = ElementO_;
using ElementAccumulator = ElementO_;
using StrideO = StrideO_;
using ElementLSE = ElementLSE_;
using CopyOpO = CopyOpO_;
@@ -70,7 +69,8 @@ class FlashPrefillEpilogue<epilogue::IntelXeXMX16, MMAOperation_, TileShapeOutp
using TiledMmaOutput = typename TiledMMAHelper<MMA_Atom<MMAOperation_>, Layout<TileShapeOutput>, SubgroupLayout>::TiledMMA;
using GmemTiledCopyO = CopyOpO;
using ElementOutput = ElementO_;
using ElementCompute = ElementO_;
using ElementCompute = ElementCompute_;
using ElementAccumulator = ElementCompute_;
using SubgroupTileShape = decltype(cute::shape_div(TileShapeOutput{}, (SubgroupLayout{}.shape())));

static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize;
@@ -196,7 +196,18 @@ class FlashPrefillEpilogue<epilogue::IntelXeXMX16, MMAOperation_, TileShapeOutp
auto thread_xe_store_o = params.xe_store_o.get_thread_slice(ThreadIdxX());
Tensor tOgO = thread_xe_store_o.partition_D(gO);

copy(params.xe_store_o, out_reg, tOgO);
Tensor final_out_reg = make_fragment_like<ElementOutput>(out_reg);
// If ElementOutput == ElementAccumulator, convert_type doesn't do the right conversion.
// If ElementOutput == fp8, there is no NumericConverter specialization available.
// For both of the above cases, we call copy(), which internally performs a static_cast on the data.
// For ElementOutput == bf16 | fp16, convert_type calls the relevant NumericConverter specialization.
if constexpr (cute::is_any_of_v<ElementOutput, cute::float_e5m2_t, cute::float_e4m3_t> || cute::is_same_v<ElementOutput, ElementCompute>) {

Review comment: Could you add a comment here to explain this if condition?
Author reply: Added this comment, not sure why it doesn't show up here.

copy(out_reg, final_out_reg);
} else {
Tensor temp = convert_type<ElementOutput>(out_reg);
copy(temp, final_out_reg);
}
copy(params.xe_store_o, final_out_reg, tOgO);
}

// SequenceLengthShapeType = Shape<int, int>
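Taken together, the new store path materializes a register fragment in the output type, converts into it, and only then issues the block store. A minimal standalone sketch of that dispatch (an illustration, not the literal CUTLASS code; convert_type is assumed to come from the flash-attention collective, everything else from cute):

```cpp
// Sketch of the epilogue's convert-then-store dispatch under the assumptions above.
#include <cutlass/cutlass.h>
#include <cute/tensor.hpp>

template <class ElementOutput, class ElementCompute,
          class CopyOp, class SrcTensor, class DstTensor>
CUTLASS_DEVICE void store_converted(CopyOp const &xe_store_o,
                                    SrcTensor const &out_reg, DstTensor &tOgO) {
  auto final_out_reg = cute::make_fragment_like<ElementOutput>(out_reg);
  if constexpr (cute::is_any_of_v<ElementOutput, cute::float_e5m2_t, cute::float_e4m3_t> ||
                cute::is_same_v<ElementOutput, ElementCompute>) {
    // fp8 has no NumericConverter specialization, and a same-type conversion is
    // a no-op, so copy() (an element-wise static_cast) covers both cases.
    cute::copy(out_reg, final_out_reg);
  } else {
    // bf16 / fp16: convert_type selects the matching NumericConverter.
    auto temp = convert_type<ElementOutput>(out_reg);
    cute::copy(temp, final_out_reg);
  }
  cute::copy(xe_store_o, final_out_reg, tOgO); // 2D block store to global memory
}
```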
@@ -370,7 +370,7 @@ class FMHAPrefill {
CUTLASS_PRAGMA_UNROLL
for (int row = 0; row < Vec; row++, row_idx++) { // 8
if (col_idx - full_tile_offset > row_idx - discard_seq_coord) {
tSr(row, m, n) = -INFINITY;
tSr(row, m, n) = ElementAccumulator{-INFINITY};
}
}
}
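The predicate above is the causal mask for unequal sequence lengths; the same rule reappears in the host-side references later in this diff. A self-contained sketch (plain float for brevity; full_tile_offset and discard_seq_coord are the offsets the kernel uses to align the causal diagonal when seq_len_kv and seq_len_qo differ):

```cpp
#include <cmath>
#include <vector>

// Sketch: mask out non-causal positions in a row-major score matrix S.
// Element (row, col) is excluded when it lies past the shifted diagonal.
void apply_causal_mask(std::vector<float> &S, int seq_len_qo, int seq_len_kv,
                       int full_tile_offset, int discard_seq_coord) {
  for (int row = 0; row < seq_len_qo; ++row)
    for (int col = 0; col < seq_len_kv; ++col)
      if ((col - full_tile_offset) > (row - discard_seq_coord))
        S[row * seq_len_kv + col] = -INFINITY;  // excluded from the softmax
}
```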
@@ -189,18 +189,17 @@ template <class FMHAPrefillConfiguration> struct BenchmarkRunnerFMHA {
}
int kv_group_update=1;
for (int h = 0; h < num_heads_q; h++) {
cutlass::DeviceAllocation<ElementOutput> block_S;
cutlass::DeviceAllocation<ElementAccumulator> block_S;
block_S.reset(seq_len_qo * seq_len_kv);

cutlass::TensorRef ref_Q(block_Q[0].get() + offset_q, LayoutQ::packed({seq_len_qo, head_size_qk}));
cutlass::TensorRef ref_K(block_K[0].get() + offset_k, LayoutK::packed({head_size_qk, seq_len_kv}));
cutlass::TensorRef ref_V(block_V[0].get() + offset_v, LayoutV::packed({seq_len_kv, head_size_vo}));
cutlass::TensorRef ref_S(block_S.get(), LayoutQ::packed({seq_len_qo, seq_len_kv}));
cutlass::TensorRef ref_O(block_ref_O.get() + offset_o, LayoutO::packed({seq_len_qo, head_size_vo}));

cutlass::reference::device::GemmComplex({seq_len_qo, seq_len_kv, head_size_qk}, 1.f, ref_Q,
cutlass::reference::device::GemmComplex({seq_len_qo, seq_len_kv, head_size_qk}, ElementAccumulator{1.f}, ref_Q,
cutlass::ComplexTransform::kNone, ref_K, cutlass::ComplexTransform::kNone,
0.f, ref_S, ref_S, ElementAccumulator(0),
ElementAccumulator{0}, ref_S, ref_S, ElementAccumulator{0},
1, // batch_count
seq_len_qo * head_size_qk, // batch_stride_Q
seq_len_kv * head_size_qk, // batch_stride_K
@@ -210,9 +209,8 @@ template <class FMHAPrefillConfiguration> struct BenchmarkRunnerFMHA {

syclcompat::wait();

std::vector<ElementOutput> host_S(block_S.size());
syclcompat::memcpy<ElementOutput>(host_S.data(), block_S.get(), host_S.size());
syclcompat::wait();
std::vector<ElementAccumulator> host_S(block_S.size());
syclcompat::memcpy<ElementAccumulator>(host_S.data(), block_S.get(), host_S.size());

// delete this memory as it is no longer needed
block_S.reset();
@@ -224,13 +222,13 @@ template <class FMHAPrefillConfiguration> struct BenchmarkRunnerFMHA {
for (int row = 0; row < seq_len_qo; row++) {
for (int col = 0; col < seq_len_kv; col++) {
if ((col - full_tile_offset) > (row - discard_seq_coord))
host_S[col + row * seq_len_kv] = -INFINITY;
host_S[col + row * seq_len_kv] = ElementAccumulator{-INFINITY};
}
}
}

// compute max element per row of S
std::vector<ElementOutput> max_vec(seq_len_qo, -INFINITY);
std::vector<ElementAccumulator> max_vec(seq_len_qo, ElementAccumulator{-INFINITY});
for (int row = 0; row < seq_len_qo; row++) {
int idx = row * seq_len_kv;
int max_idx = row;
@@ -246,12 +244,12 @@ template <class FMHAPrefillConfiguration> struct BenchmarkRunnerFMHA {
int idx = row * seq_len_kv;
int max_idx = row;
for (int col = 0; col < seq_len_kv; col++, idx++) {
host_S[idx] = expf((host_S[idx] - max_vec[max_idx]) / std::sqrt(static_cast<ElementOutput>((head_size_qk))));
host_S[idx] = expf((host_S[idx] - max_vec[max_idx]) / std::sqrt(static_cast<ElementAccumulator>((head_size_qk))));
}
}
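Together with the row-sum normalization that follows (partly elided in this diff), the reference computes the numerically stable softmax of the scaled scores, row by row. As a worked formula, with $d_{qk}$ denoting head_size_qk:

$$
P_{ij} = \frac{\exp\!\left((S_{ij} - m_i)/\sqrt{d_{qk}}\right)}
              {\sum_{l} \exp\!\left((S_{il} - m_i)/\sqrt{d_{qk}}\right)},
\qquad m_i = \max_k S_{ik}.
$$

Subtracting $m_i$ inside the exponential cancels in the ratio, so the result is unchanged while overflow for large scores is avoided.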

// compute sum per row of S
std::vector<ElementOutput> sum_vec(seq_len_qo, ElementOutput{0});
std::vector<ElementAccumulator> sum_vec(seq_len_qo, ElementAccumulator{0});
for (int row = 0; row < seq_len_qo; row++) {
int idx = row * seq_len_kv;
int sum_idx = row;
@@ -279,13 +277,16 @@ template <class FMHAPrefillConfiguration> struct BenchmarkRunnerFMHA {
block_P.reset(host_P.size());

syclcompat::memcpy<ElementV>(block_P.get(), host_P.data(), host_P.size());
syclcompat::wait();

cutlass::TensorRef ref_P(block_P.get(), LayoutQ::packed({seq_len_qo, seq_len_kv}));

cutlass::reference::device::GemmComplex({seq_len_qo, head_size_vo, seq_len_kv}, 1.f, ref_P,
cutlass::DeviceAllocation<ElementAccumulator> block_acc;
block_acc.reset(seq_len_qo * head_size_vo);
cutlass::TensorRef ref_acc(block_acc.get(), LayoutO::packed({seq_len_qo, head_size_vo}));

cutlass::reference::device::GemmComplex({seq_len_qo, head_size_vo, seq_len_kv}, ElementAccumulator{1}, ref_P,
cutlass::ComplexTransform::kNone, ref_V, cutlass::ComplexTransform::kNone,
0.f, ref_O, ref_O, ElementAccumulator(0),
ElementAccumulator{0}, ref_acc, ref_acc, ElementAccumulator{0},
1, // batch_count
seq_len_qo * seq_len_kv, // batch_stride_P
seq_len_kv * head_size_vo, // batch_stride_V
@@ -297,6 +298,17 @@ template <class FMHAPrefillConfiguration> struct BenchmarkRunnerFMHA {
// delete this memory as it is no longer needed
block_P.reset();

std::vector<ElementAccumulator> vec_acc(block_acc.size());
syclcompat::memcpy<ElementAccumulator>(vec_acc.data(), block_acc.get(), vec_acc.size());

// delete this memory as it is no longer needed
block_acc.reset();
std::vector<ElementOutput> vec_out(vec_acc.size());
for(int i = 0; i < vec_out.size(); i++) {
vec_out[i] = static_cast<ElementOutput>(vec_acc[i]);
}
syclcompat::memcpy<ElementOutput>(block_ref_O.get() + offset_o, vec_out.data(), vec_out.size());

offset_q += seq_len_qo * head_size_qk;
if(kv_group_update % q_group_size==0) {
offset_k += seq_len_kv * head_size_qk;
@@ -311,7 +323,7 @@ template <class FMHAPrefillConfiguration> struct BenchmarkRunnerFMHA {

// Check if output from CUTLASS kernel and reference kernel are equal or not
bool passed = cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_O.get(), block_O.get(),
block_O.size(), 0.5f, 0.5f);
block_O.size(), ElementOutput{0.5}, ElementOutput{0.5});

return passed;
}
@@ -67,7 +67,7 @@ struct FMHAPrefillConfig {
using MMAOperation = typename MMAOP<GEMMDispatchPolicy, ElementInputType,ElementAccumulator>::Type;
using CollectiveEpilogue = cutlass::flash_attention::collective::FlashPrefillEpilogue<
EpilogueDispatchPolicy, MMAOperation, TileShapeOutput,
SubgroupLayout, ElementAccumulator,
SubgroupLayout, ElementAccumulator, ElementOutputType,
cutlass::gemm::TagToStrideC_t<LayoutO>, ElementOutput,
GmemTiledCopyO>;

@@ -230,14 +230,13 @@ template <class FMHAPrefillKernel, bool isVarLen> struct ExampleRunner {
}
int kv_group_update=1;
for (int h = 0; h < num_heads_q; h++) {
cutlass::DeviceAllocation<ElementOutput> block_S;
cutlass::DeviceAllocation<ElementAccumulator> block_S;
block_S.reset(seq_len_qo * seq_len_kv);

cutlass::TensorRef ref_Q(block_Q_.get() + offset_q, LayoutQ::packed({seq_len_qo, head_size_qk}));
cutlass::TensorRef ref_K(block_K_.get() + offset_k, LayoutK::packed({head_size_qk, seq_len_kv}));
cutlass::TensorRef ref_V(block_V_.get() + offset_v, LayoutV::packed({seq_len_kv, head_size_vo}));
cutlass::TensorRef ref_S(block_S.get(), LayoutQ::packed({seq_len_qo, seq_len_kv}));
cutlass::TensorRef ref_O(block_ref_O.get() + offset_o, LayoutO::packed({seq_len_qo, head_size_vo}));

cutlass::reference::device::GemmComplex({seq_len_qo, seq_len_kv, head_size_qk}, 1.f, ref_Q,
cutlass::ComplexTransform::kNone, ref_K, cutlass::ComplexTransform::kNone,
@@ -251,9 +250,8 @@ template <class FMHAPrefillKernel, bool isVarLen> struct ExampleRunner {

syclcompat::wait();

std::vector<ElementOutput> host_S(block_S.size());
syclcompat::memcpy<ElementOutput>(host_S.data(), block_S.get(), host_S.size());
syclcompat::wait();
std::vector<ElementAccumulator> host_S(block_S.size());
syclcompat::memcpy<ElementAccumulator>(host_S.data(), block_S.get(), host_S.size());

// delete this memory as it is no longer needed
block_S.reset();
@@ -265,13 +263,13 @@ template <class FMHAPrefillKernel, bool isVarLen> struct ExampleRunner {
for (int row = 0; row < seq_len_qo; row++) {
for (int col = 0; col < seq_len_kv; col++) {
if ((col - full_tile_offset) > (row - discard_seq_coord))
host_S[col + row * seq_len_kv] = -INFINITY;
host_S[col + row * seq_len_kv] = ElementAccumulator{-INFINITY};
}
}
}

// compute max element per row of S
std::vector<ElementOutput> max_vec(seq_len_qo, -INFINITY);
std::vector<ElementAccumulator> max_vec(seq_len_qo, ElementAccumulator{-INFINITY});
for (int row = 0; row < seq_len_qo; row++) {
int idx = row * seq_len_kv;
int max_idx = row;
@@ -287,12 +285,12 @@ template <class FMHAPrefillKernel, bool isVarLen> struct ExampleRunner {
int idx = row * seq_len_kv;
int max_idx = row;
for (int col = 0; col < seq_len_kv; col++, idx++) {
host_S[idx] = expf((host_S[idx] - max_vec[max_idx]) / sqrt(static_cast<ElementOutput>((head_size_qk))));
host_S[idx] = expf((host_S[idx] - max_vec[max_idx]) / sqrt(static_cast<ElementAccumulator>((head_size_qk))));
}
}

// compute sum per row of S
std::vector<ElementOutput> sum_vec(seq_len_qo, ElementOutput{0});
std::vector<ElementAccumulator> sum_vec(seq_len_qo, ElementAccumulator{0});
for (int row = 0; row < seq_len_qo; row++) {
int idx = row * seq_len_kv;
int sum_idx = row;
@@ -320,13 +318,16 @@ template <class FMHAPrefillKernel, bool isVarLen> struct ExampleRunner {
block_P.reset(host_P.size());

syclcompat::memcpy<ElementV_>(block_P.get(), host_P.data(), host_P.size());
syclcompat::wait();

cutlass::TensorRef ref_P(block_P.get(), LayoutQ::packed({seq_len_qo, seq_len_kv}));

cutlass::reference::device::GemmComplex({seq_len_qo, head_size_vo, seq_len_kv}, 1.f, ref_P,
cutlass::DeviceAllocation<ElementAccumulator> block_acc;
block_acc.reset(seq_len_qo * head_size_vo);
cutlass::TensorRef ref_acc(block_acc.get(), LayoutO::packed({seq_len_qo, head_size_vo}));

cutlass::reference::device::GemmComplex({seq_len_qo, head_size_vo, seq_len_kv}, ElementAccumulator{1}, ref_P,
cutlass::ComplexTransform::kNone, ref_V, cutlass::ComplexTransform::kNone,
0.f, ref_O, ref_O, ElementAccumulator(0),
ElementAccumulator{0}, ref_acc, ref_acc, ElementAccumulator{0},
1, // batch_count
seq_len_qo * seq_len_kv, // batch_stride_P
seq_len_kv * head_size_vo, // batch_stride_V
@@ -338,6 +339,17 @@ template <class FMHAPrefillKernel, bool isVarLen> struct ExampleRunner {
// delete this memory as it is no longer needed
block_P.reset();

std::vector<ElementAccumulator> vec_acc(block_acc.size());
syclcompat::memcpy<ElementAccumulator>(vec_acc.data(), block_acc.get(), vec_acc.size());

// delete this memory as it is no longer needed
block_acc.reset();
std::vector<ElementOutput> vec_out(vec_acc.size());
for(int i = 0; i < vec_out.size(); i++) {
vec_out[i] = static_cast<ElementOutput>(vec_acc[i]);
}
syclcompat::memcpy<ElementOutput>(block_ref_O.get() + offset_o, vec_out.data(), vec_out.size());

offset_q += seq_len_qo * head_size_qk;
if(kv_group_update % q_group_size==0) {
offset_k += seq_len_kv * head_size_qk;
@@ -352,7 +364,7 @@ template <class FMHAPrefillKernel, bool isVarLen> struct ExampleRunner {

// Check if output from CUTLASS kernel and reference kernel are equal or not
bool passed = cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_O.get(), block_O.get(),
block_O.size(), 0.5f, 0.5f);
block_O.size(), ElementOutput{0.5}, ElementOutput{0.5});

return passed;
}
@@ -619,7 +631,7 @@ template <bool Causal,
using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16;
using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16<PipelineStages>;
using CollectiveEpilogue = cutlass::flash_attention::collective::FlashPrefillEpilogue<
EpilogueDispatchPolicy, MMAOperation, TileShapeOutput, SubgroupLayout, ElementAccumulator, cutlass::gemm::TagToStrideC_t<LayoutO>, ElementOutput,
EpilogueDispatchPolicy, MMAOperation, TileShapeOutput, SubgroupLayout, ElementComputeEpilogue, ElementOutput, cutlass::gemm::TagToStrideC_t<LayoutO>, ElementOutput,
GmemTiledCopyStore>;
using CollectiveSoftmaxEpilogue = cutlass::flash_attention::collective::FlashPrefillSoftmaxEpilogue<Causal, EpilogueDispatchPolicy, ElementAccumulator>;

1 change: 1 addition & 0 deletions include/cute/arch/copy_xe_U8.hpp
@@ -682,6 +682,7 @@ struct XE_2D_U8x4x16_ST_N {
};

struct XE_2D_U8x8x16_ST_N {
using BlockShape = Shape<_8, _16>;
template <class T>
CUTE_HOST_DEVICE static void copy(void *baseoffset, int width, int height,
int pitch, intel::coord_t coord,
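The single addition to copy_xe_U8.hpp gives XE_2D_U8x8x16_ST_N the BlockShape alias that neighboring Xe copy atoms carry, so generic tiled-copy code can query the atom's 8x16 block at compile time. A small assumed usage (not taken from the PR; the namespace and size<> helpers follow cute's conventions):

```cpp
#include <cute/arch/copy_xe_U8.hpp>
#include <cute/layout.hpp>

// Hypothetical compile-time query of the new alias.
using StoreAtom = cute::XE_2D_U8x8x16_ST_N;
static_assert(cute::size<0>(StoreAtom::BlockShape{}) == 8, "8 rows per block");
static_assert(cute::size<1>(StoreAtom::BlockShape{}) == 16, "16 columns per block");
```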