File tree Expand file tree Collapse file tree 2 files changed +4
-7
lines changed
include/cutlass/epilogue/collective Expand file tree Collapse file tree 2 files changed +4
-7
lines changed Original file line number Diff line number Diff line change @@ -90,9 +90,8 @@ class CollectiveEpilogue<
90
90
using DispatchPolicy = IntelXeXMX16Group;
91
91
using CtaTileMNK = CtaTileMNK_;
92
92
using FusionCallbacks = FusionCallbacks_;
93
- using ElementC = ElementD_;
94
- // simple use fp32 as accumulator dtype
95
- using ElementAccumulator = float ;
93
+ using ElementC = typename FusionCallbacks::ElementSource;
94
+ using ElementAccumulator = ElementC_;
96
95
using StrideC = StrideC_;
97
96
using InternalStrideC = cute::remove_pointer_t <StrideC>;
98
97
using ElementD = ElementD_;
Original file line number Diff line number Diff line change @@ -89,10 +89,8 @@ class CollectiveEpilogue<
89
89
using DispatchPolicy = IntelXeXMX16;
90
90
using CtaTileMNK = CtaTileMNK_;
91
91
using FusionCallbacks = FusionCallbacks_;
92
- // assume C and D have same dtype
93
- using ElementC = ElementD_;
94
- // simple use fp32 as accumulator dtype
95
- using ElementAccumulator = float ;
92
+ using ElementC = typename FusionCallbacks::ElementSource;;
93
+ using ElementAccumulator = ElementC_;
96
94
using StrideC = StrideC_;
97
95
using ElementD = ElementD_;
98
96
using StrideD = StrideD_;
You can’t perform that action at this time.
0 commit comments