-
Notifications
You must be signed in to change notification settings - Fork 56
Description
Which component has the problem?
CUTLASS C++
Bug Report
Describe the bug
With BF16 `A` and `B` matrices, try computing a Group GEMM with BFloat16 as the output dtype (converted in the epilogue).
This did not work with an epilogue created directly with `cutlass::epilogue::collective::CollectiveEpilogue`.
Steps/Code to reproduce bug
Please apply this small diff and compile the Group GEMM example (the same file):
diff --git a/examples/sycl/04_bmg_grouped_gemm/04_bmg_grouped_gemm.cpp b/examples/sycl/04_bmg_grouped_gemm/04_bmg_grouped_gemm.cpp
index bdda0536..860cbf0c 100644
--- a/examples/sycl/04_bmg_grouped_gemm/04_bmg_grouped_gemm.cpp
+++ b/examples/sycl/04_bmg_grouped_gemm/04_bmg_grouped_gemm.cpp
@@ -92,7 +92,7 @@ using ElementAccumulator = float; // <- data type of accumulator
using ElementComputeEpilogue = float; // <- data type of epilogue operations
using ElementA = bfloat16_t; // <- data type of elements in input matrix A
using ElementB = bfloat16_t; // <- data type of elements in input matrix B
-using ElementOutput = float; // <- data type of elements in output matrix D
+using ElementOutput = bfloat16_t; // <- data type of elements in output matrix D
///////////////////////////////////////////////////////////////////////////////////////////////////
@@ -198,8 +198,8 @@ struct ExampleRunner {
using LayoutD = typename Gemm::LayoutD;
using CollectiveEpilogue = typename Gemm::CollectiveEpilogue;
- using ElementOutput = typename CollectiveEpilogue::ElementOutput;
- using ElementAccumulator = ElementOutput;
+ using ElementOutput = bfloat16_t;
+ using ElementAccumulator = float_t;
using StrideA = typename Gemm::GemmKernel::InternalStrideA;
using StrideB = typename Gemm::GemmKernel::InternalStrideB;
@@ -361,7 +361,7 @@ void initialize(const Options &options) {
std::vector<ElementA *> ptr_A_host(options.groups);
std::vector<ElementB *> ptr_B_host(options.groups);
std::vector<ElementC *> ptr_C_host(options.groups);
- std::vector<ElementC *> ptr_D_host(options.groups);
+ std::vector<ElementOutput *> ptr_D_host(options.groups);
std::vector<ElementAccumulator *> ptr_alpha_host(options.groups);
std::vector<ElementAccumulator *> ptr_beta_host(options.groups);
@@ -599,7 +599,7 @@ int main(int argc, const char** argv)
FusionCallBacks,
XE_2D_U32x8x16_LD_N,
void, void,
- XE_2D_U32x8x16_ST_N,
+ XE_2D_U16x8x16_ST_N,
void, void>;
Expected behavior
Dtype conversion of the MMA output should be supported in the epilogue.
Environment details (please complete the following information):
PVC GPU
Additional context
Main branch