From cb395d33e05e7cf1fae13121b8cdb73f6b3bbad1 Mon Sep 17 00:00:00 2001 From: Sanchit Jain Date: Mon, 29 Sep 2025 16:36:18 -0700 Subject: [PATCH 1/4] Support nullptr value of arg ptr_C for xe_array_epilogue --- .../04_bmg_grouped_gemm.cpp | 76 ++++++++++++------- .../collective/builders/xe_builder.inl | 8 +- .../epilogue/collective/xe_array_epilogue.hpp | 7 +- .../cutlass/epilogue/fusion/operations.hpp | 5 +- .../cutlass/epilogue/fusion/xe_callbacks.hpp | 14 ++-- 5 files changed, 68 insertions(+), 42 deletions(-) diff --git a/examples/04_bmg_grouped_gemm/04_bmg_grouped_gemm.cpp b/examples/04_bmg_grouped_gemm/04_bmg_grouped_gemm.cpp index ffc01d0825..dd015f30f4 100644 --- a/examples/04_bmg_grouped_gemm/04_bmg_grouped_gemm.cpp +++ b/examples/04_bmg_grouped_gemm/04_bmg_grouped_gemm.cpp @@ -421,7 +421,10 @@ void initialize(const Options &options) { } /// Populates a Gemm::Arguments structure from the given commandline options - typename Gemm::Arguments args_from_options(const Options &options, const cutlass::KernelHardwareInfo& hw_info, bool host_problem_shapes_available = true) + typename Gemm::Arguments args_from_options(const Options &options, + const cutlass::KernelHardwareInfo& hw_info, + bool host_problem_shapes_available = true, + bool use_nullptr_c = false) { typename Gemm::Arguments arguments; decltype(arguments.epilogue.thread) fusion_args; @@ -458,7 +461,7 @@ void initialize(const Options &options) { cutlass::gemm::GemmUniversalMode::kGrouped, {options.groups, problem_sizes.get(), options.problem_sizes_host.data()}, {ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()}, - {fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()}, + {fusion_args, use_nullptr_c ? nullptr : ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()}, hw_info, {1, RasterOrderOptions::AlongN} }; @@ -468,7 +471,7 @@ void initialize(const Options &options) { cutlass::gemm::GemmUniversalMode::kGrouped, {options.groups, problem_sizes.get(), nullptr}, {ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()}, - {fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()}, + {fusion_args, use_nullptr_c ? nullptr : ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()}, hw_info, {1, RasterOrderOptions::AlongN} }; @@ -477,13 +480,16 @@ void initialize(const Options &options) { return arguments; } - cutlass::Status run(const Options& options, const cutlass::KernelHardwareInfo& hw_info, bool host_problem_shapes_available = true) { + cutlass::Status run(const Options& options, + const cutlass::KernelHardwareInfo& hw_info, + bool host_problem_shapes_available = true, + bool use_nullptr_c = false) { allocate(options); initialize(options); Gemm gemm_op; - auto arguments = args_from_options(options, hw_info, host_problem_shapes_available); + auto arguments = args_from_options(options, hw_info, host_problem_shapes_available, use_nullptr_c); size_t workspace_size = Gemm::get_workspace_size(arguments); cutlass::device_memory::allocation workspace(workspace_size); @@ -530,26 +536,8 @@ void initialize(const Options &options) { }; -int main(int argc, const char** argv) -{ - // - // Parse options - // - - Options options; - - options.parse(argc, argv); - - if (options.help) { - options.print_usage(std::cout) << std::endl; - return 0; - } - - if (options.error) { - std::cerr << "Aborting execution." << std::endl; - return -1; - } - +template +void launcher(Options& options) { // // Run examples // @@ -584,8 +572,11 @@ int main(int argc, const char** argv) using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16Group; using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16Group; - using EpilogueOp = cutlass::epilogue::fusion::LinearCombination; + using EpilogueOp = cute::conditional_t, + cutlass::epilogue::fusion::LinearCombination>; using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; @@ -626,7 +617,36 @@ int main(int argc, const char** argv) ExampleRunner runner; - CUTLASS_CHECK(runner.run(options, hw_info)); + CUTLASS_CHECK(runner.run(options, hw_info, true, /* use_nullptr_c = */use_nullptr_c)); +} + + +int main(int argc, const char** argv) +{ + // + // Parse options + // + + Options options; + + options.parse(argc, argv); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + if (options.error) { + std::cerr << "Aborting execution." << std::endl; + return -1; + } + if (options.beta == 0.f) { + std::cout << "\n\nUse a nullptr as argument ptr_C of the group GEMM epilogue colective\n\n"; + launcher(options); + std::cout << "\n\nPass actual ptr_C as an argument to the group GEMM epilogue colective\n\n"; + } + launcher(options); return 0; + } diff --git a/include/cutlass/epilogue/collective/builders/xe_builder.inl b/include/cutlass/epilogue/collective/builders/xe_builder.inl index 809cede6f7..6b93ba76da 100644 --- a/include/cutlass/epilogue/collective/builders/xe_builder.inl +++ b/include/cutlass/epilogue/collective/builders/xe_builder.inl @@ -49,10 +49,12 @@ namespace detail { template < class ElementD, class ElementCompute, - class ElementC + class ElementC, + cutlass::FloatRoundStyle RoundStyle_, + bool supportSource_ > struct FusionOpInfo> { constexpr static bool HasBuilder = true; @@ -63,7 +65,7 @@ namespace detail { class> using FusionCallbacks = cutlass::epilogue::fusion::FusionCallbacks< DispatchPolicy, - cutlass::epilogue::fusion::LinearCombination, + cutlass::epilogue::fusion::LinearCombination, TileShape_MNK, EpilogueTile >; diff --git a/include/cutlass/epilogue/collective/xe_array_epilogue.hpp b/include/cutlass/epilogue/collective/xe_array_epilogue.hpp index 9b879bd14d..746ddf8fc5 100644 --- a/include/cutlass/epilogue/collective/xe_array_epilogue.hpp +++ b/include/cutlass/epilogue/collective/xe_array_epilogue.hpp @@ -114,8 +114,9 @@ class CollectiveEpilogue< using ElementScalar = typename FusionCallbacks::ElementScalar; static constexpr FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest; - static_assert(cute::is_same_v>, + static_assert(cute::is_any_of_v, + fusion::LinearCombination>, "Only Linear Combination Epilogue is supported for Grouped GEMM at the moment."); static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize; @@ -139,7 +140,7 @@ class CollectiveEpilogue< Layout{}, make_layout(shape_div(typename Trait_D::BlockShape{}, CopyThreadShape{})))); private: - constexpr static bool is_source_supported = not cute::is_void_v; + constexpr static bool is_source_supported = not cute::is_void_v && FusionCallbacks::Operation::IsSourceSupported; constexpr static bool is_destination_supported = not cute::is_void_v && not cute::is_void_v; public: diff --git a/include/cutlass/epilogue/fusion/operations.hpp b/include/cutlass/epilogue/fusion/operations.hpp index c7c94d18f9..37bb8882ec 100644 --- a/include/cutlass/epilogue/fusion/operations.hpp +++ b/include/cutlass/epilogue/fusion/operations.hpp @@ -111,12 +111,13 @@ template< class ElementCompute_, class ElementSource_ = ElementOutput_, class ElementScalar_ = ElementCompute_, - FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest, + bool supportSource_ = true > struct LinearCombination : ScaledAcc { using ElementSource = ElementSource_; - static constexpr bool IsSourceSupported = true; + static constexpr bool IsSourceSupported = supportSource_; }; // D = activation(alpha * acc + beta * C) diff --git a/include/cutlass/epilogue/fusion/xe_callbacks.hpp b/include/cutlass/epilogue/fusion/xe_callbacks.hpp index 5173d77000..447555c043 100644 --- a/include/cutlass/epilogue/fusion/xe_callbacks.hpp +++ b/include/cutlass/epilogue/fusion/xe_callbacks.hpp @@ -63,11 +63,12 @@ template < class ElementScalar_, FloatRoundStyle RoundStyle_, class CtaTileShapeMNK_, - class EpilogueTile_ + class EpilogueTile_, + bool supportSource_ > struct FusionCallbacks< epilogue::IntelXeXMX16, - fusion::LinearCombination, + fusion::LinearCombination, CtaTileShapeMNK_, EpilogueTile_ > : Sm90LinearCombination::type, ElementCompute_, ElementSource_, ElementScalar_, RoundStyle_> { @@ -77,7 +78,7 @@ struct FusionCallbacks< using ElementCompute = ElementCompute_; using ElementSource = ElementSource_; using ElementScalar = ElementScalar_; - using Operation = fusion::LinearCombination; + using Operation = fusion::LinearCombination; struct Arguments { ElementScalar alpha = ElementScalar(1); @@ -730,11 +731,12 @@ template < class ElementScalar_, FloatRoundStyle RoundStyle_, class CtaTileShapeMNK_, - class EpilogueTile_ + class EpilogueTile_, + bool supportSource_ > struct FusionCallbacks< epilogue::IntelXeXMX16Group, - fusion::LinearCombination, + fusion::LinearCombination, CtaTileShapeMNK_, EpilogueTile_ > : Sm90LinearCombinationPtrArray::type, ElementCompute_, ElementSource_, ElementScalar_, RoundStyle_> { @@ -744,7 +746,7 @@ struct FusionCallbacks< using ElementCompute = ElementCompute_; using ElementSource = ElementSource_; using ElementScalar = ElementScalar_; - using Operation = fusion::LinearCombination; + using Operation = fusion::LinearCombination; struct Arguments { ElementScalar alpha = ElementScalar(1); From d02b5a6a3f22d36273ee8dc680e1ebecf88d5f7f Mon Sep 17 00:00:00 2001 From: sanchitintel Date: Tue, 30 Sep 2025 14:59:45 -0700 Subject: [PATCH 2/4] Update comment in example --- examples/04_bmg_grouped_gemm/04_bmg_grouped_gemm.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/04_bmg_grouped_gemm/04_bmg_grouped_gemm.cpp b/examples/04_bmg_grouped_gemm/04_bmg_grouped_gemm.cpp index dd015f30f4..d1fbaba524 100644 --- a/examples/04_bmg_grouped_gemm/04_bmg_grouped_gemm.cpp +++ b/examples/04_bmg_grouped_gemm/04_bmg_grouped_gemm.cpp @@ -641,6 +641,8 @@ int main(int argc, const char** argv) return -1; } if (options.beta == 0.f) { + // the reference kernel doesn't accept nullptr for C, so we only test for nullptr ptr_C epilogue arg + // when beta is 0. std::cout << "\n\nUse a nullptr as argument ptr_C of the group GEMM epilogue colective\n\n"; launcher(options); std::cout << "\n\nPass actual ptr_C as an argument to the group GEMM epilogue colective\n\n"; From 51cf86d12aa58b29685e512b35148881472f4544 Mon Sep 17 00:00:00 2001 From: sanchitintel Date: Tue, 30 Sep 2025 17:16:17 -0700 Subject: [PATCH 3/4] Set default template parameters --- include/cutlass/epilogue/collective/builders/xe_builder.inl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/cutlass/epilogue/collective/builders/xe_builder.inl b/include/cutlass/epilogue/collective/builders/xe_builder.inl index 6b93ba76da..aca2f3c6ea 100644 --- a/include/cutlass/epilogue/collective/builders/xe_builder.inl +++ b/include/cutlass/epilogue/collective/builders/xe_builder.inl @@ -50,8 +50,8 @@ namespace detail { class ElementD, class ElementCompute, class ElementC, - cutlass::FloatRoundStyle RoundStyle_, - bool supportSource_ + cutlass::FloatRoundStyle RoundStyle_ = cutlass::FloatRoundStyle::round_to_nearest, + bool supportSource_ = true > struct FusionOpInfo Date: Tue, 30 Sep 2025 20:21:08 -0700 Subject: [PATCH 4/4] Undo change --- include/cutlass/epilogue/collective/builders/xe_builder.inl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/cutlass/epilogue/collective/builders/xe_builder.inl b/include/cutlass/epilogue/collective/builders/xe_builder.inl index aca2f3c6ea..6b93ba76da 100644 --- a/include/cutlass/epilogue/collective/builders/xe_builder.inl +++ b/include/cutlass/epilogue/collective/builders/xe_builder.inl @@ -50,8 +50,8 @@ namespace detail { class ElementD, class ElementCompute, class ElementC, - cutlass::FloatRoundStyle RoundStyle_ = cutlass::FloatRoundStyle::round_to_nearest, - bool supportSource_ = true + cutlass::FloatRoundStyle RoundStyle_, + bool supportSource_ > struct FusionOpInfo