From 4736337e45443e05db6d1c62673aab9cd2110d5c Mon Sep 17 00:00:00 2001
From: Wuxun Zhang
Date: Wed, 27 Aug 2025 17:29:47 +0800
Subject: [PATCH 1/6] support fp32 accumulation for bf16 gemm and grouped gemm

Previously, a bf16-in/bf16-out GEMM used a bf16 accumulator; after this
patch an fp32 accumulator is used instead, for better accuracy.

This assumes that C and D have the same dtype.

Signed-off-by: Wuxun Zhang
---
 examples/sycl/00_bmg_gemm/00_bmg_gemm.cpp     |  4 +--
 .../04_bmg_grouped_gemm.cpp                   |  8 ++---
 .../epilogue/collective/xe_array_epilogue.hpp | 31 ++++++++++++++++---
 .../epilogue/collective/xe_epilogue.hpp       | 31 ++++++++++++++++---
 4 files changed, 58 insertions(+), 16 deletions(-)

diff --git a/examples/sycl/00_bmg_gemm/00_bmg_gemm.cpp b/examples/sycl/00_bmg_gemm/00_bmg_gemm.cpp
index 251a4d1f10..1dcb278a9d 100644
--- a/examples/sycl/00_bmg_gemm/00_bmg_gemm.cpp
+++ b/examples/sycl/00_bmg_gemm/00_bmg_gemm.cpp
@@ -156,7 +156,7 @@ struct ExampleRunner {
   using ElementAcc = typename Gemm::ElementAccumulator;
 
   using CollectiveEpilogue = typename Gemm::CollectiveEpilogue;
-  using ElementC = typename Gemm::ElementC;
+  using ElementC = typename CollectiveEpilogue::ElementOutput;
   using ElementOutput = typename CollectiveEpilogue::ElementOutput;
   using ElementCompute = typename CollectiveEpilogue::ElementCompute;
   using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator;
@@ -386,7 +386,7 @@ int main(int argc, const char** argv)
   using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue<
           EpilogueDispatchPolicy,
           TileShape,
-          ElementAccumulator,
+          ElementOutput,
           cutlass::gemm::TagToStrideC_t<LayoutC>,    // Converts CUTLASS 2.x to CUTLASS 3.x representation
           ElementOutput,
           cutlass::gemm::TagToStrideC_t<LayoutD>,    // Converts CUTLASS 2.x to CUTLASS 3.x representation
diff --git a/examples/sycl/04_bmg_grouped_gemm/04_bmg_grouped_gemm.cpp b/examples/sycl/04_bmg_grouped_gemm/04_bmg_grouped_gemm.cpp
index bdda0536d2..a77b6be0de 100644
--- a/examples/sycl/04_bmg_grouped_gemm/04_bmg_grouped_gemm.cpp
+++ b/examples/sycl/04_bmg_grouped_gemm/04_bmg_grouped_gemm.cpp
@@ -190,7 +190,6 @@ struct ExampleRunner {
   using ElementA = typename Gemm::ElementA;
   using ElementB = typename Gemm::ElementB;
-  using ElementC = typename Gemm::ElementC;
 
   using LayoutA = typename Gemm::LayoutA;
   using LayoutB = typename Gemm::LayoutB;
@@ -199,7 +198,8 @@ struct ExampleRunner {
   using CollectiveEpilogue = typename Gemm::CollectiveEpilogue;
   using ElementOutput = typename CollectiveEpilogue::ElementOutput;
-  using ElementAccumulator = ElementOutput;
+  using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator;
+  using ElementC = typename CollectiveEpilogue::ElementOutput;
 
   using StrideA = typename Gemm::GemmKernel::InternalStrideA;
   using StrideB = typename Gemm::GemmKernel::InternalStrideB;
@@ -585,14 +585,14 @@ int main(int argc, const char** argv)
   using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16Group;
   using EpilogueOp = cutlass::epilogue::fusion::LinearCombination<ElementOutput, ElementAccumulator,
-      ElementAccumulator, ElementAccumulator, cutlass::FloatRoundStyle::round_to_nearest>;
+      ElementOutput, ElementAccumulator, cutlass::FloatRoundStyle::round_to_nearest>;
 
   using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks<EpilogueDispatchPolicy, EpilogueOp,
           TileShape, decltype(tile_shape(TiledMma()))>;
   using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue<
           EpilogueDispatchPolicy,
           TileShape,
-          ElementAccumulator,
+          ElementOutput,
           cutlass::gemm::TagToStrideC_t<LayoutC>,
           ElementOutput,
           cutlass::gemm::TagToStrideC_t<LayoutD>,
diff --git a/include/cutlass/epilogue/collective/xe_array_epilogue.hpp b/include/cutlass/epilogue/collective/xe_array_epilogue.hpp
index e8b1709aad..1139d4a643 100644
--- a/include/cutlass/epilogue/collective/xe_array_epilogue.hpp
+++ b/include/cutlass/epilogue/collective/xe_array_epilogue.hpp
@@ -90,8 +90,10 @@ class CollectiveEpilogue<
   using DispatchPolicy = IntelXeXMX16Group;
   using CtaTileMNK = CtaTileMNK_;
   using FusionCallbacks = FusionCallbacks_;
-  using ElementC = ElementC_;
-  using ElementAccumulator = ElementC_;
+  using ElementC = ElementD_;
+  // simple heuristic to determine accumulator dtype
+  using ElementAccumulator = conditional_t<std::is_same_v<ElementD_, cutlass::bfloat16_t> ||
+      std::is_same_v<ElementD_, cutlass::half_t> || std::is_same_v<ElementD_, float>, float, float>;
   using StrideC = StrideC_;
   using InternalStrideC = cute::remove_pointer_t<StrideC>;
   using ElementD = ElementD_;
@@ -115,7 +117,7 @@ class CollectiveEpilogue<
   static constexpr FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest;
 
   static_assert(cute::is_same_v<typename FusionCallbacks::Operation,
-                fusion::LinearCombination<ElementC, ElementCompute, ElementC, ElementCompute, RoundStyle>>,
+                fusion::LinearCombination<ElementD, ElementCompute, ElementC, ElementCompute, RoundStyle>>,
       "Only Linear Combination Epilogue is supported for Grouped GEMM at the moment.");
 
   static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize;
@@ -421,6 +423,8 @@ class CollectiveEpilogue<
     constexpr int MN = get<0>(CtaTileMNK{}) * get<1>(CtaTileMNK{});
     static_assert(ValuesLoaded == MN, "the total elements loaded by all threads should be the same as MxN" );
 
+    constexpr bool is_same_dtype_accum_and_output = std::is_same_v<ElementAccumulator, ElementOutput>;
+
     auto synchronize = [&] () {};
     CUTLASS_PRAGMA_UNROLL
     for (int epi_n = 0; epi_n < FragsN; epi_n++) {
@@ -428,8 +432,19 @@ class CollectiveEpilogue<
       CUTLASS_PRAGMA_UNROLL
       for (int epi_m = 0; epi_m < FragsM; epi_m++) {
 
         if (is_C_load_needed) {
-          //cordinates for C and D are the same
+          if constexpr (is_same_dtype_accum_and_output) {
+            // coordinates for C and D are the same
           copy(params.xe_load_c.with(get<0>(load_store_tensors)), tCgD(_, epi_m, epi_n), trC);
+          } else {
+            Tensor trC_ori = make_tensor<ElementC>(Shape<Int<FragmentSize>>{});
+            copy(params.xe_load_c.with(get<0>(load_store_tensors)), tCgD(_, epi_m, epi_n), trC_ori);
+            // convert trC from original dtype to accumulator dtype
+            CUTLASS_PRAGMA_UNROLL
+            for (int copy_n = 0; copy_n < FragmentSize; copy_n++) {
+              trC(copy_n) = (typename TiledMma::ValTypeC)trC_ori(copy_n);
+            }
+          }
+
         }
 
         cst_callbacks.previsit(epi_m, epi_n, 0, is_C_load_needed);
@@ -438,7 +453,13 @@ class CollectiveEpilogue<
         CUTLASS_PRAGMA_UNROLL
         for (int epi_v = 0; epi_v < size<0>(trD_compute_frag); ++epi_v) {
-          trD_compute_frag(epi_v) = cst_callbacks.visit(acc_frag_mn(epi_v), epi_v, epi_m, epi_n);
+          if constexpr (is_same_dtype_accum_and_output) {
+            trD_compute_frag(epi_v) = cst_callbacks.visit(acc_frag_mn(epi_v), epi_v, epi_m, epi_n);
+          } else {
+            // align dtypes first
+            auto tmp = cst_callbacks.visit(acc_frag_mn(epi_v), epi_v, epi_m, epi_n);
+            trD_compute_frag(epi_v) = cutlass::NumericArrayConverter<ElementCompute, ElementAccumulator, FragmentSize>{}(tmp);
+          }
         }
 
         cst_callbacks.reduce(nullptr, synchronize, epi_m, epi_n, (epi_m == FragsM - 1 && epi_n == FragsN - 1), trD_compute_frag);
diff --git a/include/cutlass/epilogue/collective/xe_epilogue.hpp b/include/cutlass/epilogue/collective/xe_epilogue.hpp
index 003f5de776..8b0869b19f 100644
--- a/include/cutlass/epilogue/collective/xe_epilogue.hpp
+++ b/include/cutlass/epilogue/collective/xe_epilogue.hpp
@@ -89,8 +89,11 @@ class CollectiveEpilogue<
   using DispatchPolicy = IntelXeXMX16;
   using CtaTileMNK = CtaTileMNK_;
   using FusionCallbacks = FusionCallbacks_;
-  using ElementC = ElementC_;
-  using ElementAccumulator = ElementC_;
+  // assume C and D have same dtype
+  using ElementC = ElementD_;
+  // simple heuristic to determine accumulator dtype
+  using ElementAccumulator = conditional_t<std::is_same_v<ElementD_, cutlass::bfloat16_t> ||
+      std::is_same_v<ElementD_, cutlass::half_t> || std::is_same_v<ElementD_, float>, float, float>;
   using StrideC = StrideC_;
   using ElementD = ElementD_;
   using StrideD = StrideD_;
@@ -398,7 +401,9 @@ class CollectiveEpilogue<
         FragsM * FragsN * FragmentSize * SubgroupSize * ATOM_M * ATOM_N * ATOM_K;
     constexpr int MN = get<0>(CtaTileMNK{}) * get<1>(CtaTileMNK{});
     static_assert(ValuesLoaded == MN, "the total elements loaded by all threads should be the same as MxN" );
-
+
+    constexpr bool is_same_dtype_accum_and_output = std::is_same_v<ElementAccumulator, ElementOutput>;
+
     auto synchronize = [&] () {};
     CUTLASS_PRAGMA_UNROLL
     for (int epi_n = 0; epi_n < FragsN; epi_n++) {
@@ -407,7 +412,17 @@ class CollectiveEpilogue<
         cst_callbacks.begin_loop(epi_m, epi_n);
 
         if (is_C_load_needed) {
-          copy(params.xe_load_c, tCgC(_, epi_m, epi_n), trC);
+          if constexpr (is_same_dtype_accum_and_output) {
+            copy(params.xe_load_c, tCgC(_, epi_m, epi_n), trC);
+          } else {
+            Tensor trC_ori = make_tensor<ElementC>(Shape<Int<FragmentSize>>{});
+            copy(params.xe_load_c, tCgC(_, epi_m, epi_n), trC_ori);
+            // convert trC from original dtype to accumulator dtype
+            CUTLASS_PRAGMA_UNROLL
+            for (int copy_n = 0; copy_n < FragmentSize; copy_n++) {
+              trC(copy_n) = (typename TiledMma::ValTypeC)trC_ori(copy_n);
+            }
+          }
         }
 
         cst_callbacks.previsit(epi_m, epi_n, 0, is_C_load_needed);
@@ -416,7 +431,13 @@ class CollectiveEpilogue<
         CUTLASS_PRAGMA_UNROLL
         for (int epi_v = 0; epi_v < size<0>(trD_compute_frag); ++epi_v) {
-          trD_compute_frag(epi_v) = cst_callbacks.visit(acc_frag_mn(epi_v), epi_v, epi_m, epi_n);
+          if constexpr (is_same_dtype_accum_and_output) {
+            trD_compute_frag(epi_v) = cst_callbacks.visit(acc_frag_mn(epi_v), epi_v, epi_m, epi_n);
+          } else {
+            // align dtypes first
+            auto tmp = cst_callbacks.visit(acc_frag_mn(epi_v), epi_v, epi_m, epi_n);
+            trD_compute_frag(epi_v) = cutlass::NumericArrayConverter<ElementCompute, ElementAccumulator, FragmentSize>{}(tmp);
+          }
         }
 
         cst_callbacks.reduce(nullptr, synchronize, epi_m, epi_n, (epi_m == FragsM - 1 && epi_n == FragsN - 1), trD_compute_frag);
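Why the fp32 accumulator matters: bf16 carries only 8 significand bits, so a running bf16 sum stops growing once each addend drops below half a unit in the last place of the partial sum. The standalone sketch below is not part of the patch; it emulates bf16 rounding on the host (the helper to_bf16 and the values k = 4096 and 0.01 are illustrative) to make the failure mode concrete.

#include <cstdint>
#include <cstdio>
#include <cstring>

// Round a float to the nearest bf16 value by keeping the top 16 bits
// (round-to-nearest-even), which models bf16's truncated significand.
static float to_bf16(float x) {
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  bits += 0x7FFFu + ((bits >> 16) & 1u);
  bits &= 0xFFFF0000u;
  std::memcpy(&x, &bits, sizeof(x));
  return x;
}

int main() {
  const int k = 4096;      // reduction depth of a typical GEMM K-loop
  float acc_bf16 = 0.0f;   // accumulator rounded to bf16 after every add
  float acc_fp32 = 0.0f;   // fp32 accumulator, as used after this series
  for (int i = 0; i < k; ++i) {
    float term = to_bf16(0.01f);     // bf16 input operand in both cases
    acc_bf16 = to_bf16(acc_bf16 + term);
    acc_fp32 += term;
  }
  // The bf16 sum stalls near 4.0 (the addend falls below half a bf16 ulp
  // at that magnitude), while the fp32 sum reaches the expected ~41.0.
  std::printf("bf16 accumulation: %f\n", acc_bf16);
  std::printf("fp32 accumulation: %f\n", acc_fp32);
  return 0;
}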
From a8aefe093a045b29238f294541ee7a503f7fe204 Mon Sep 17 00:00:00 2001
From: Wuxun Zhang
Date: Fri, 29 Aug 2025 15:06:20 +0800
Subject: [PATCH 2/6] updates

---
 .../epilogue/collective/xe_array_epilogue.hpp       | 12 ++++++------
 include/cutlass/epilogue/collective/xe_epilogue.hpp | 12 ++++++------
 2 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/include/cutlass/epilogue/collective/xe_array_epilogue.hpp b/include/cutlass/epilogue/collective/xe_array_epilogue.hpp
index 1139d4a643..c60d8e756e 100644
--- a/include/cutlass/epilogue/collective/xe_array_epilogue.hpp
+++ b/include/cutlass/epilogue/collective/xe_array_epilogue.hpp
@@ -91,9 +91,8 @@ class CollectiveEpilogue<
   using CtaTileMNK = CtaTileMNK_;
   using FusionCallbacks = FusionCallbacks_;
   using ElementC = ElementD_;
-  // simple heuristic to determine accumulator dtype
-  using ElementAccumulator = conditional_t<std::is_same_v<ElementD_, cutlass::bfloat16_t> ||
-      std::is_same_v<ElementD_, cutlass::half_t> || std::is_same_v<ElementD_, float>, float, float>;
+  // simply use fp32 as accumulator dtype
+  using ElementAccumulator = float;
   using StrideC = StrideC_;
   using InternalStrideC = cute::remove_pointer_t<StrideC>;
   using ElementD = ElementD_;
@@ -374,6 +373,7 @@ class CollectiveEpilogue<
     Tensor tCgD = thread_xe_store_d.partition_D(gD);
 
     Tensor trC = make_tensor<typename TiledMma::ValTypeC>(Shape<Int<FragmentSize>>{});
+    auto trC_frag = recast<Array<typename TiledMma::ValTypeC, FragmentSize>>(trC);
     Tensor trD_compute = make_tensor<ElementCompute>(Shape<Int<FragmentSize>>{});
 
     // Because Sm90 uses shared memory, they are not tied to using the same accumulator values
@@ -438,10 +438,10 @@ class CollectiveEpilogue<
           } else {
             Tensor trC_ori = make_tensor<ElementC>(Shape<Int<FragmentSize>>{});
             copy(params.xe_load_c.with(get<0>(load_store_tensors)), tCgD(_, epi_m, epi_n), trC_ori);
-            // convert trC from original dtype to accumulator dtype
+            auto trC_ori_frag = recast<Array<ElementC, FragmentSize>>(trC_ori);
             CUTLASS_PRAGMA_UNROLL
-            for (int copy_n = 0; copy_n < FragmentSize; copy_n++) {
-              trC(copy_n) = (typename TiledMma::ValTypeC)trC_ori(copy_n);
+            for (int i = 0; i < size(trC_frag); ++i) {
+              trC_frag(i) = cutlass::NumericArrayConverter<typename TiledMma::ValTypeC, ElementC, FragmentSize>{}(trC_ori_frag(i));
             }
           }
 
diff --git a/include/cutlass/epilogue/collective/xe_epilogue.hpp b/include/cutlass/epilogue/collective/xe_epilogue.hpp
index 8b0869b19f..3647b6fad3 100644
--- a/include/cutlass/epilogue/collective/xe_epilogue.hpp
+++ b/include/cutlass/epilogue/collective/xe_epilogue.hpp
@@ -91,9 +91,8 @@ class CollectiveEpilogue<
   using FusionCallbacks = FusionCallbacks_;
   // assume C and D have same dtype
   using ElementC = ElementD_;
-  // simple heuristic to determine accumulator dtype
-  using ElementAccumulator = conditional_t<std::is_same_v<ElementD_, cutlass::bfloat16_t> ||
-      std::is_same_v<ElementD_, cutlass::half_t> || std::is_same_v<ElementD_, float>, float, float>;
+  // simply use fp32 as accumulator dtype
+  using ElementAccumulator = float;
   using StrideC = StrideC_;
   using ElementD = ElementD_;
   using StrideD = StrideD_;
@@ -353,6 +352,7 @@ class CollectiveEpilogue<
     Tensor tCgD = thread_xe_store_d.partition_D(gD);
 
     Tensor trC = make_tensor<typename TiledMma::ValTypeC>(Shape<Int<FragmentSize>>{});
+    auto trC_frag = recast<Array<typename TiledMma::ValTypeC, FragmentSize>>(trC);
     Tensor trD_compute = make_tensor<ElementCompute>(Shape<Int<FragmentSize>>{});
 
     // Because Sm90 uses shared memory, they are not tied to using the same accumulator values
@@ -417,10 +417,10 @@ class CollectiveEpilogue<
           } else {
             Tensor trC_ori = make_tensor<ElementC>(Shape<Int<FragmentSize>>{});
             copy(params.xe_load_c, tCgC(_, epi_m, epi_n), trC_ori);
-            // convert trC from original dtype to accumulator dtype
+            auto trC_ori_frag = recast<Array<ElementC, FragmentSize>>(trC_ori);
             CUTLASS_PRAGMA_UNROLL
-            for (int copy_n = 0; copy_n < FragmentSize; copy_n++) {
-              trC(copy_n) = (typename TiledMma::ValTypeC)trC_ori(copy_n);
+            for (int i = 0; i < size(trC_frag); ++i) {
+              trC_frag(i) = cutlass::NumericArrayConverter<typename TiledMma::ValTypeC, ElementC, FragmentSize>{}(trC_ori_frag(i));
             }
           }
         }
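Patch 2 swaps the scalar cast loop for cutlass::NumericArrayConverter applied to fragments reinterpreted with recast, so the whole Array is converted per call and the converter can pick a vectorized path where one exists. Below is a host-side sketch of that idiom, not taken from the kernel; kFragmentSize and the sample values are invented.

#include <cstdio>
#include "cutlass/array.h"
#include "cutlass/bfloat16.h"
#include "cutlass/numeric_conversion.h"

int main() {
  constexpr int kFragmentSize = 8;  // stand-in for the epilogue's FragmentSize
  cutlass::Array<cutlass::bfloat16_t, kFragmentSize> src;  // e.g. a loaded C fragment
  for (int i = 0; i < kFragmentSize; ++i) {
    src[i] = cutlass::bfloat16_t(0.5f * i);
  }
  // One converter call widens the whole fragment bf16 -> fp32, mirroring
  // trC_frag(i) = NumericArrayConverter<ValTypeC, ElementC, FragmentSize>{}(trC_ori_frag(i)).
  cutlass::NumericArrayConverter<float, cutlass::bfloat16_t, kFragmentSize> converter;
  cutlass::Array<float, kFragmentSize> dst = converter(src);
  for (int i = 0; i < kFragmentSize; ++i) {
    std::printf("dst[%d] = %f\n", i, dst[i]);
  }
  return 0;
}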
From 8b36b71d1eb01ff8d55a99bdfd4a413b3152b40d Mon Sep 17 00:00:00 2001
From: Wuxun Zhang
Date: Tue, 2 Sep 2025 16:40:22 +0800
Subject: [PATCH 3/6] address comment

---
 include/cutlass/epilogue/collective/xe_array_epilogue.hpp | 6 +-----
 include/cutlass/epilogue/collective/xe_epilogue.hpp       | 5 +----
 2 files changed, 2 insertions(+), 9 deletions(-)

diff --git a/include/cutlass/epilogue/collective/xe_array_epilogue.hpp b/include/cutlass/epilogue/collective/xe_array_epilogue.hpp
index c60d8e756e..ee630adec8 100644
--- a/include/cutlass/epilogue/collective/xe_array_epilogue.hpp
+++ b/include/cutlass/epilogue/collective/xe_array_epilogue.hpp
@@ -439,12 +439,8 @@ class CollectiveEpilogue<
             Tensor trC_ori = make_tensor<ElementC>(Shape<Int<FragmentSize>>{});
             copy(params.xe_load_c.with(get<0>(load_store_tensors)), tCgD(_, epi_m, epi_n), trC_ori);
             auto trC_ori_frag = recast<Array<ElementC, FragmentSize>>(trC_ori);
-            CUTLASS_PRAGMA_UNROLL
-            for (int i = 0; i < size(trC_frag); ++i) {
-              trC_frag(i) = cutlass::NumericArrayConverter<typename TiledMma::ValTypeC, ElementC, FragmentSize>{}(trC_ori_frag(i));
-            }
+            *(trC_frag.data()) = cutlass::NumericArrayConverter<typename TiledMma::ValTypeC, ElementC, FragmentSize>{}(*(trC_ori_frag.data()));
           }
-
         }
 
         cst_callbacks.previsit(epi_m, epi_n, 0, is_C_load_needed);
diff --git a/include/cutlass/epilogue/collective/xe_epilogue.hpp b/include/cutlass/epilogue/collective/xe_epilogue.hpp
index 3647b6fad3..0be0c2759e 100644
--- a/include/cutlass/epilogue/collective/xe_epilogue.hpp
+++ b/include/cutlass/epilogue/collective/xe_epilogue.hpp
@@ -418,10 +418,7 @@ class CollectiveEpilogue<
             Tensor trC_ori = make_tensor<ElementC>(Shape<Int<FragmentSize>>{});
             copy(params.xe_load_c, tCgC(_, epi_m, epi_n), trC_ori);
             auto trC_ori_frag = recast<Array<ElementC, FragmentSize>>(trC_ori);
-            CUTLASS_PRAGMA_UNROLL
-            for (int i = 0; i < size(trC_frag); ++i) {
-              trC_frag(i) = cutlass::NumericArrayConverter<typename TiledMma::ValTypeC, ElementC, FragmentSize>{}(trC_ori_frag(i));
-            }
+            *(trC_frag.data()) = cutlass::NumericArrayConverter<typename TiledMma::ValTypeC, ElementC, FragmentSize>{}(*(trC_ori_frag.data()));
           }
         }
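Patch 3 simplifies further: after recast<Array<T, FragmentSize>> the register tensor holds exactly one Array element, so a single dereference of data() converts the entire fragment in one converter call, leaving no loop to unroll. A standalone host sketch of the same pattern follows; it assumes CuTe's owning make_tensor and recast behave on the host as they do in the device code, and the fragment size is invented.

#include <cstdio>
#include <cute/tensor.hpp>
#include "cutlass/array.h"
#include "cutlass/bfloat16.h"
#include "cutlass/numeric_conversion.h"

int main() {
  using namespace cute;
  constexpr int kFragmentSize = 8;
  // Register-backed fragments, as in the epilogue.
  Tensor src = make_tensor<cutlass::bfloat16_t>(Shape<Int<kFragmentSize>>{});
  Tensor dst = make_tensor<float>(Shape<Int<kFragmentSize>>{});
  for (int i = 0; i < kFragmentSize; ++i) {
    src(i) = cutlass::bfloat16_t(float(i));
  }
  // Viewed as one Array of FragmentSize elements per fragment, a single
  // dereference of data() converts the whole fragment at once.
  auto src_frag = recast<cutlass::Array<cutlass::bfloat16_t, kFragmentSize>>(src);
  auto dst_frag = recast<cutlass::Array<float, kFragmentSize>>(dst);
  *(dst_frag.data()) = cutlass::NumericArrayConverter<float, cutlass::bfloat16_t, kFragmentSize>{}(*(src_frag.data()));
  for (int i = 0; i < kFragmentSize; ++i) {
    std::printf("%f\n", dst(i));
  }
  return 0;
}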
From 0d733e8a794fae5f86e1be6ee5ffa447cc965a48 Mon Sep 17 00:00:00 2001
From: Wuxun Zhang
Date: Wed, 3 Sep 2025 10:06:20 +0800
Subject: [PATCH 4/6] support different accumulator dtype combinations

---
 include/cutlass/epilogue/collective/xe_array_epilogue.hpp | 5 ++---
 include/cutlass/epilogue/collective/xe_epilogue.hpp       | 6 ++----
 2 files changed, 4 insertions(+), 7 deletions(-)

diff --git a/include/cutlass/epilogue/collective/xe_array_epilogue.hpp b/include/cutlass/epilogue/collective/xe_array_epilogue.hpp
index ee630adec8..0c2b2fd51a 100644
--- a/include/cutlass/epilogue/collective/xe_array_epilogue.hpp
+++ b/include/cutlass/epilogue/collective/xe_array_epilogue.hpp
@@ -90,9 +90,8 @@ class CollectiveEpilogue<
   using DispatchPolicy = IntelXeXMX16Group;
   using CtaTileMNK = CtaTileMNK_;
   using FusionCallbacks = FusionCallbacks_;
-  using ElementC = ElementD_;
-  // simply use fp32 as accumulator dtype
-  using ElementAccumulator = float;
+  using ElementC = typename FusionCallbacks::ElementSource;
+  using ElementAccumulator = ElementC_;
   using StrideC = StrideC_;
   using InternalStrideC = cute::remove_pointer_t<StrideC>;
   using ElementD = ElementD_;
diff --git a/include/cutlass/epilogue/collective/xe_epilogue.hpp b/include/cutlass/epilogue/collective/xe_epilogue.hpp
index 0be0c2759e..779a32b84f 100644
--- a/include/cutlass/epilogue/collective/xe_epilogue.hpp
+++ b/include/cutlass/epilogue/collective/xe_epilogue.hpp
@@ -89,10 +89,8 @@ class CollectiveEpilogue<
   using DispatchPolicy = IntelXeXMX16;
   using CtaTileMNK = CtaTileMNK_;
   using FusionCallbacks = FusionCallbacks_;
-  // assume C and D have same dtype
-  using ElementC = ElementD_;
-  // simply use fp32 as accumulator dtype
-  using ElementAccumulator = float;
+  using ElementC = typename FusionCallbacks::ElementSource;
+  using ElementAccumulator = ElementC_;
   using StrideC = StrideC_;
   using ElementD = ElementD_;
   using StrideD = StrideD_;
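Patch 4 drops the C-equals-D assumption: the collective now reads the source (C) dtype from the fusion callbacks, and its ElementC_ template slot is reinterpreted as the accumulator dtype. The stub sketch below isolates that trait-forwarding pattern; every name in it is an invented stand-in, not the CUTLASS implementation.

#include <cstdio>
#include <type_traits>

// Stub fusion operation: in CUTLASS this role is played by LinearCombination,
// whose ElementSource is the dtype of the C operand.
template <class Output, class Compute, class Source>
struct StubLinearCombination {
  using ElementOutput = Output;
  using ElementCompute = Compute;
  using ElementSource = Source;
};

// Stub collective epilogue: ElementC comes from the callbacks, while the
// accumulator dtype stays an independent parameter (ElementC_ in the real code).
template <class FusionCallbacks, class ElementAccumulator_>
struct StubCollectiveEpilogue {
  using ElementC = typename FusionCallbacks::ElementSource;
  using ElementAccumulator = ElementAccumulator_;
};

struct bf16_tag {};  // stand-in for cutlass::bfloat16_t

int main() {
  using Callbacks = StubLinearCombination<bf16_tag, float, bf16_tag>;
  using Epilogue = StubCollectiveEpilogue<Callbacks, float>;
  static_assert(std::is_same_v<Epilogue::ElementC, bf16_tag>);
  static_assert(std::is_same_v<Epilogue::ElementAccumulator, float>);
  std::printf("C is bf16, accumulator is fp32\n");
  return 0;
}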
From 04e3da07d3432713a4570041bb238fef7ea1f5e1 Mon Sep 17 00:00:00 2001
From: Wuxun Zhang
Date: Wed, 3 Sep 2025 10:16:31 +0800
Subject: [PATCH 5/6] fix example

---
 examples/sycl/00_bmg_gemm/00_bmg_gemm.cpp                 | 4 ++--
 examples/sycl/04_bmg_grouped_gemm/04_bmg_grouped_gemm.cpp | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/examples/sycl/00_bmg_gemm/00_bmg_gemm.cpp b/examples/sycl/00_bmg_gemm/00_bmg_gemm.cpp
index 1dcb278a9d..b7bb1d2753 100644
--- a/examples/sycl/00_bmg_gemm/00_bmg_gemm.cpp
+++ b/examples/sycl/00_bmg_gemm/00_bmg_gemm.cpp
@@ -375,7 +375,7 @@ int main(int argc, const char** argv)
   // aside from the (A*B), which is handled by the GEMM. See 05_bmg_gemm_with_epilogues for more
   // complex epilogue examples.
   using EpilogueOp = cutlass::epilogue::fusion::LinearCombination<ElementOutput, ElementComputeEpilogue,
-      ElementAccumulator, ElementAccumulator, cutlass::FloatRoundStyle::round_to_nearest>;
+      ElementOutput, ElementAccumulator, cutlass::FloatRoundStyle::round_to_nearest>;
 
   // FusionCallbacks ties the EpilogueOp to an implementation (based on the dispatch
   // policy/architecture) and defines the epilogue arguments.
@@ -386,7 +386,7 @@ int main(int argc, const char** argv)
   using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue<
           EpilogueDispatchPolicy,
           TileShape,
-          ElementOutput,
+          ElementAccumulator,
           cutlass::gemm::TagToStrideC_t<LayoutC>,    // Converts CUTLASS 2.x to CUTLASS 3.x representation
           ElementOutput,
           cutlass::gemm::TagToStrideC_t<LayoutD>,    // Converts CUTLASS 2.x to CUTLASS 3.x representation
diff --git a/examples/sycl/04_bmg_grouped_gemm/04_bmg_grouped_gemm.cpp b/examples/sycl/04_bmg_grouped_gemm/04_bmg_grouped_gemm.cpp
index a77b6be0de..1dbe943c7d 100644
--- a/examples/sycl/04_bmg_grouped_gemm/04_bmg_grouped_gemm.cpp
+++ b/examples/sycl/04_bmg_grouped_gemm/04_bmg_grouped_gemm.cpp
@@ -592,7 +592,7 @@ int main(int argc, const char** argv)
   using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue<
           EpilogueDispatchPolicy,
           TileShape,
-          ElementOutput,
+          ElementAccumulator,
           cutlass::gemm::TagToStrideC_t<LayoutC>,
           ElementOutput,
           cutlass::gemm::TagToStrideC_t<LayoutD>,
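With patches 1-5 applied, the examples' LinearCombination computes D = alpha * acc + beta * C, where C is loaded in bf16, widened to fp32, combined with the fp32 accumulator, and only the final D is narrowed back to bf16. The scalar model below is not the CUTLASS implementation, just a per-element illustration of that dataflow; the sample values are invented.

#include <cstdio>
#include "cutlass/bfloat16.h"

// All arithmetic happens in fp32 (the ElementCompute of the epilogue).
float linear_combination(float alpha, float acc /*fp32 accumulator*/,
                         float beta, cutlass::bfloat16_t c /*bf16 C operand*/) {
  return alpha * acc + beta * float(c);  // widen C, combine in fp32
}

int main() {
  float acc = 1.004f;               // accumulator value arriving from the MMA
  cutlass::bfloat16_t c(0.5f);      // source operand C
  float d_compute = linear_combination(1.0f, acc, 2.0f, c);
  cutlass::bfloat16_t d(d_compute); // single final rounding to the bf16 output
  std::printf("D = %f (rounded from %f)\n", float(d), d_compute);
  return 0;
}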
From a1a612172fb649e27fcaec44897cb233064b0a91 Mon Sep 17 00:00:00 2001
From: Wuxun Zhang
Date: Wed, 3 Sep 2025 11:31:43 +0800
Subject: [PATCH 6/6] fix tests

---
 include/cutlass/epilogue/fusion/xe_callbacks.hpp      | 11 ++++++-----
 .../gemm/device/default_gemm_group_configuration.hpp  |  4 ++--
 2 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/include/cutlass/epilogue/fusion/xe_callbacks.hpp b/include/cutlass/epilogue/fusion/xe_callbacks.hpp
index 5173d77000..f78f7b2862 100644
--- a/include/cutlass/epilogue/fusion/xe_callbacks.hpp
+++ b/include/cutlass/epilogue/fusion/xe_callbacks.hpp
@@ -343,7 +343,7 @@ template <
   class ElementOutput_,
   class ElementCompute_,
   class ElementAux,
-  class ElementSource,
+  class ElementSource_,
   class ElementScalar,
   int AlignmentAux,
   FloatRoundStyle RoundStyle,
@@ -355,28 +355,29 @@ struct FusionCallbacks<
     epilogue::IntelXeXMX16,
     fusion::LinCombDeEltAct<
       GmemLayoutTagAux, ActivationFn, ElementOutput_, ElementCompute_,
-      ElementAux, ElementSource, ElementScalar, AlignmentAux, RoundStyle
+      ElementAux, ElementSource_, ElementScalar, AlignmentAux, RoundStyle
    >,
    CtaTileShapeMNK,
    EpilogueTile,
    CopyOpG2R
> : XeLinCombDeEltAct<
      cutlass::gemm::TagToStrideC_t<GmemLayoutTagAux>, CopyOpG2R, ActivationFn, ElementOutput_,
-     ElementCompute_, ElementAux, ElementSource, ElementScalar, RoundStyle
+     ElementCompute_, ElementAux, ElementSource_, ElementScalar, RoundStyle
    > {
 
   using ElementOutput = ElementOutput_;
   using ElementCompute = ElementCompute_;
+  using ElementSource = ElementSource_;
 
   using Impl =
     XeLinCombDeEltAct<
      cutlass::gemm::TagToStrideC_t<GmemLayoutTagAux>, CopyOpG2R, ActivationFn, ElementOutput,
-     ElementCompute, ElementAux, ElementSource, ElementScalar, RoundStyle
+     ElementCompute, ElementAux, ElementSource_, ElementScalar, RoundStyle
    >;
   using Operation =
    fusion::LinCombDeEltAct<
      GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute,
-     ElementAux, ElementSource, ElementScalar, AlignmentAux, RoundStyle
+     ElementAux, ElementSource_, ElementScalar, AlignmentAux, RoundStyle
    >;
 
   struct Arguments {
diff --git a/test/unit/gemm/device/default_gemm_group_configuration.hpp b/test/unit/gemm/device/default_gemm_group_configuration.hpp
index 33003e2863..ee98a07ddf 100644
--- a/test/unit/gemm/device/default_gemm_group_configuration.hpp
+++ b/test/unit/gemm/device/default_gemm_group_configuration.hpp
@@ -87,7 +87,7 @@ struct DefaultGemmGroupConfiguration<
 
   using TiledMma = typename CollectiveMainloop::TiledMma;
 
-  using EpilogueOp = epilogue::fusion::LinearCombination<ElementOutput, float, float, float>;
+  using EpilogueOp = epilogue::fusion::LinearCombination<ElementOutput, float, ElementOutput, float>;
 
   using FusionCallBacks =
    epilogue::fusion::FusionCallbacks<
      epilogue::IntelXeXMX16Group,
@@ -101,7 +101,7 @@ struct DefaultGemmGroupConfiguration<
      TileShape, Shape<_1, _1, _1>,
      epilogue::collective::EpilogueTileAuto,
      float, float,
-     float, LayoutC, 1,
+     ElementOutput, LayoutC, 1,
      ElementOutput, LayoutC, 1,
      epilogue::IntelXeXMX16Group,
      EpilogueOp
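For reference, fusion::LinearCombination's parameters follow the order <ElementOutput, ElementCompute, ElementSource, ElementScalar, RoundStyle>, so the test-config change above moves only ElementSource, the C operand's dtype, from fp32 to the output dtype. The sketch below is a stub restating that before/after relationship, not the real CUTLASS template.

#include <type_traits>

// Stub with the same parameter order and defaults as CUTLASS's
// fusion::LinearCombination (names invented for this illustration).
template <class ElementOutput, class ElementCompute,
          class ElementSource = ElementOutput, class ElementScalar = ElementCompute>
struct LinearCombinationSketch {
  using Output = ElementOutput;
  using Compute = ElementCompute;
  using Source = ElementSource;
  using Scalar = ElementScalar;
};

struct bf16 {};  // stand-in for the test's ElementOutput (e.g. cutlass::bfloat16_t)

using Before = LinearCombinationSketch<bf16, float, float, float>;  // C was fp32
using After  = LinearCombinationSketch<bf16, float, bf16,  float>;  // C matches D
static_assert(std::is_same_v<After::Source, After::Output>, "C and D dtypes agree");

int main() { return 0; }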