Skip to content

Commit 0d733e8

Browse files
committed
support different accumulator dtype combinations
1 parent 8b36b71 commit 0d733e8

File tree

2 files changed

+4
-7
lines changed

2 files changed

+4
-7
lines changed

include/cutlass/epilogue/collective/xe_array_epilogue.hpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,9 +90,8 @@ class CollectiveEpilogue<
9090
using DispatchPolicy = IntelXeXMX16Group;
9191
using CtaTileMNK = CtaTileMNK_;
9292
using FusionCallbacks = FusionCallbacks_;
93-
using ElementC = ElementD_;
94-
// simple use fp32 as accumulator dtype
95-
using ElementAccumulator = float;
93+
using ElementC = typename FusionCallbacks::ElementSource;
94+
using ElementAccumulator = ElementC_;
9695
using StrideC = StrideC_;
9796
using InternalStrideC = cute::remove_pointer_t<StrideC>;
9897
using ElementD = ElementD_;

include/cutlass/epilogue/collective/xe_epilogue.hpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,10 +89,8 @@ class CollectiveEpilogue<
8989
using DispatchPolicy = IntelXeXMX16;
9090
using CtaTileMNK = CtaTileMNK_;
9191
using FusionCallbacks = FusionCallbacks_;
92-
// assume C and D have same dtype
93-
using ElementC = ElementD_;
94-
// simple use fp32 as accumulator dtype
95-
using ElementAccumulator = float;
92+
using ElementC = typename FusionCallbacks::ElementSource;;
93+
using ElementAccumulator = ElementC_;
9694
using StrideC = StrideC_;
9795
using ElementD = ElementD_;
9896
using StrideD = StrideD_;

0 commit comments

Comments
 (0)