From f258437c6400d988d7549628c438abfac94c25b2 Mon Sep 17 00:00:00 2001 From: taozha2 Date: Wed, 27 Aug 2025 08:20:23 +0800 Subject: [PATCH 1/2] improve mixed data type performance --- benchmarks/gemm/benchmark_runner.hpp | 19 +- benchmarks/gemm/benchmarks_sycl.hpp | 24 +- .../02_bmg_gemm_bf16_s8_bf16.cpp | 108 +++----- .../02_bmg_gemm_f16_u4_f16.cpp | 108 +++----- .../02_bmg_gemm_f16_u4_s8.cpp | 108 +++----- include/cute/arch/copy_xe_U8.hpp | 36 ++- include/cute/atom/copy_traits_xe.hpp | 46 +++- .../gemm/collective/xe_mma_mixed_input.hpp | 245 ++++++++++-------- test/unit/cute/intel_xe/copy_block.cpp | 1 + 9 files changed, 330 insertions(+), 365 deletions(-) diff --git a/benchmarks/gemm/benchmark_runner.hpp b/benchmarks/gemm/benchmark_runner.hpp index 96ebaabbe7..03839013f6 100644 --- a/benchmarks/gemm/benchmark_runner.hpp +++ b/benchmarks/gemm/benchmark_runner.hpp @@ -61,18 +61,16 @@ namespace cutlass::benchmark { /////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(SYCL_INTEL_TARGET) -template +template static constexpr auto is_mixed_dtype = false; +#if defined(SYCL_INTEL_TARGET) template static constexpr auto is_mixed_dtype> = true; -#else -template -static constexpr auto is_mixed_dtype = false; #endif -template +// ScaleType +template struct ScaleType { using type = int; }; @@ -81,7 +79,8 @@ struct ScaleType> { using type = typename T::ElementScale; }; -template +// ZeroType +template struct ZeroType { using type = int; }; @@ -90,7 +89,8 @@ struct ZeroType> { using type = typename T::ElementZero; }; -template +// ScaleStride +template struct ScaleStride { using type = int; }; @@ -99,7 +99,8 @@ struct ScaleStride> { using type = typename T::StrideScale; }; -template +// ZeroStride +template struct ZeroStride { using type = int; }; diff --git a/benchmarks/gemm/benchmarks_sycl.hpp b/benchmarks/gemm/benchmarks_sycl.hpp index ee28c1cfce..89f68ce2e7 100644 --- a/benchmarks/gemm/benchmarks_sycl.hpp +++ b/benchmarks/gemm/benchmarks_sycl.hpp @@ -357,8 +357,8 @@ using PvcMixedPrecisionGemmFP16U4FP16S8FP16S4_RCR_1 = cutlass::gemm::device::Mix typename TiledMMAHelper, Layout>, Layout, Stride<_4, _1, _0>>>::TiledMMA, XE_2D_Packed_U16x32x32_LD_N, XE_2D_U4x32x16_LD_T, XE_2D_U16x8x16_ST_N, - cutlass::epilogue::fusion::LinearCombination, + cutlass::epilogue::fusion::LinearCombination, 2 >; @@ -373,8 +373,8 @@ using PvcMixedPrecisionGemmFP16U4S8S8FP16S4_RCR_1 = cutlass::gemm::device::Mixed typename TiledMMAHelper, Layout>, Layout, Stride<_4, _1, _0>>>::TiledMMA, XE_2D_Packed_U16x32x32_LD_N, XE_2D_U4x32x16_LD_T, XE_2D_U8x8x16_ST_N, - cutlass::epilogue::fusion::LinearCombination, + cutlass::epilogue::fusion::LinearCombination, 2 >; @@ -389,8 +389,8 @@ using PvcMixedPrecisionGemmBF16U4BF16S8BF16S4_RCR_1 = cutlass::gemm::device::Mix typename TiledMMAHelper, Layout>, Layout, Stride<_4, _1, _0>>>::TiledMMA, XE_2D_Packed_U16x32x32_LD_N, XE_2D_U4x32x16_LD_T, XE_2D_U16x8x16_ST_N, - cutlass::epilogue::fusion::LinearCombination, + cutlass::epilogue::fusion::LinearCombination, 2 >; @@ -405,8 +405,8 @@ using PvcMixedPrecisionGemmBF16U4S8S8BF16S4_RCR_1 = cutlass::gemm::device::Mixed typename TiledMMAHelper, Layout>, Layout, Stride<_4, _1, _0>>>::TiledMMA, XE_2D_Packed_U16x32x32_LD_N, XE_2D_U4x32x16_LD_T, XE_2D_U8x8x16_ST_N, - cutlass::epilogue::fusion::LinearCombination, + cutlass::epilogue::fusion::LinearCombination, 2 >; @@ -421,8 +421,8 @@ using PvcMixedPrecisionGemmBF16S8BF16S8BF16S8_RCR_1 = cutlass::gemm::device::Mix typename TiledMMAHelper, Layout>, 
Layout, Stride<_4, _1, _0>>>::TiledMMA, XE_2D_Packed_U16x32x32_LD_N, XE_2D_U8x16x32_LD_T, XE_2D_U16x8x16_ST_N, - cutlass::epilogue::fusion::LinearCombination, + cutlass::epilogue::fusion::LinearCombination, 2 >; @@ -437,8 +437,8 @@ using PvcMixedPrecisionGemmFP16S8FP16S8FP16S8_RCR_1 = cutlass::gemm::device::Mix typename TiledMMAHelper, Layout>, Layout, Stride<_4, _1, _0>>>::TiledMMA, XE_2D_Packed_U16x32x32_LD_N, XE_2D_U8x16x32_LD_T, XE_2D_U16x8x16_ST_N, - cutlass::epilogue::fusion::LinearCombination, + cutlass::epilogue::fusion::LinearCombination, 2 >; diff --git a/examples/sycl/02_bmg_gemm_mixed_dtype/02_bmg_gemm_bf16_s8_bf16.cpp b/examples/sycl/02_bmg_gemm_mixed_dtype/02_bmg_gemm_bf16_s8_bf16.cpp index 195d44409a..0ad6a3e1d2 100755 --- a/examples/sycl/02_bmg_gemm_mixed_dtype/02_bmg_gemm_bf16_s8_bf16.cpp +++ b/examples/sycl/02_bmg_gemm_mixed_dtype/02_bmg_gemm_bf16_s8_bf16.cpp @@ -253,86 +253,38 @@ struct ExampleRunner { // Methods // - bool verify(const Options &options) { - - // - // Compute reference output (default gemm kernel w/ ElementA == ElementB) - // - - using GmemTiledCopyA = XE_2D_U16x32x32_LD_N; - using GmemTiledCopyB = XE_2D_U16x32x32_LD_V; - - // Workgroup-level tile - using TileShape = Shape<_256, _256, _32>; - - using TiledMma = - typename TiledMMAHelper, Layout, - Layout, Stride<_4, _1, _0>>>::TiledMMA; - - constexpr int PipelineStages = 3; - using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16; - using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16; - - using EpilogueOp = cutlass::epilogue::fusion::LinearCombination; - - using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; - - using CollectiveEpilogueRef = cutlass::epilogue::collective::CollectiveEpilogue< - EpilogueDispatchPolicy, - TileShape, - ElementAccumulator, - cutlass::gemm::TagToStrideC_t, - ElementOutput, - cutlass::gemm::TagToStrideC_t, - FusionCallBacks, - XE_2D_U32x8x16_LD_N, - void, void, - XE_2D_U32x8x16_ST_N, - void, void>; - - // Mainloop - using CollectiveMainloopRef = cutlass::gemm::collective::CollectiveMma< - GEMMDispatchPolicy, - TileShape, - ElementMMA, - cutlass::gemm::TagToStrideA_t, - ElementMMA, - cutlass::gemm::TagToStrideB_t, - TiledMma, - GmemTiledCopyA, void, void, cute::identity, // A - GmemTiledCopyB, void, void, cute::identity // B - >; - - using GemmKernelRef = cutlass::gemm::kernel::GemmUniversal< - Shape, - CollectiveMainloopRef, - CollectiveEpilogueRef - >; - - using GemmRef = cutlass::gemm::device::GemmUniversalAdapter; - - typename GemmRef::Arguments arguments{ - cutlass::gemm::GemmUniversalMode::kGemm, - {options.m, options.n, options.k, options.l}, - {block_A_dq.get(), stride_A, block_B_dq.get(), stride_B}, - {{options.alpha, options.beta}, block_C.get(), stride_C, block_ref_D.get(), stride_D} - }; - - // Run the gemm where the scaling is performed outside of the kernel. 
- GemmRef gemm_ref; - size_t workspace_size = GemmRef::get_workspace_size(arguments); - cutlass::device_memory::allocation workspace(workspace_size); - CUTLASS_CHECK(gemm_ref.can_implement(arguments)); - CUTLASS_CHECK(gemm_ref.initialize(arguments, workspace.get())); - CUTLASS_CHECK(gemm_ref.run()); + bool verify(const ProblemShapeType& problem_size, ElementCompute alpha, ElementCompute beta) { + auto [M, N, K, L] = problem_size; + + cutlass::TensorRef ref_A(block_A_dq.get(), LayoutA::packed({M, K})); + cutlass::TensorRef ref_B(block_B_dq.get(), LayoutB::packed({K, N})); + cutlass::TensorRef ref_C(block_C.get(), LayoutC::packed({M, N})); + cutlass::TensorRef ref_D(block_ref_D.get(), LayoutD::packed({M, N})); + + cutlass::reference::device::GemmComplex( + {M, N, K}, + alpha, + ref_A, + cutlass::ComplexTransform::kNone, + ref_B, + cutlass::ComplexTransform::kNone, + beta, + ref_C, + ref_D, + ElementAccumulator(0), + L, // batch_count + M * K, // batch_stride_A + K * N, // batch_stride_B + M * N, // batch_stride_C + M * N // batch_stride_D + ); + + // CUTLASS on SYCL uses the compatibility library syclcompat for e.g. default in-order queue + syclcompat::wait(); - // compare_reference ElementOutput const epsilon(1e-2f); ElementOutput const non_zero_floor(1e-4f); - bool passed = cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get(), block_D.get(), block_D.size(), epsilon, non_zero_floor); - return passed; + return cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get(), block_D.get(), block_D.size(), epsilon, non_zero_floor); } template @@ -462,7 +414,7 @@ struct ExampleRunner { syclcompat::wait(); // Verify that the result is correct - bool passed = verify(options); + bool passed = verify(problem_size, options.alpha, options.beta); std::cout << "Disposition: " << (passed ? 
"Passed" : "Failed") << std::endl; if(!passed) return cutlass::Status::kErrorInternal; diff --git a/examples/sycl/02_bmg_gemm_mixed_dtype/02_bmg_gemm_f16_u4_f16.cpp b/examples/sycl/02_bmg_gemm_mixed_dtype/02_bmg_gemm_f16_u4_f16.cpp index 5aa90672ae..cb22bb6b23 100755 --- a/examples/sycl/02_bmg_gemm_mixed_dtype/02_bmg_gemm_f16_u4_f16.cpp +++ b/examples/sycl/02_bmg_gemm_mixed_dtype/02_bmg_gemm_f16_u4_f16.cpp @@ -242,86 +242,38 @@ struct ExampleRunner { // Methods // - bool verify(const Options &options) { - - // - // Compute reference output (default gemm kernel w/ ElementA == ElementB) - // - - using GmemTiledCopyA = XE_2D_U16x32x32_LD_N; - using GmemTiledCopyB = XE_2D_U16x16x16_LD_T; - - // Workgroup-level tile - using TileShape = Shape<_256, _256, _32>; - - using TiledMma = - typename TiledMMAHelper, Layout, - Layout, Stride<_4, _1, _0>>>::TiledMMA; - - constexpr int PipelineStages = 3; - using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16; - using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16; - - using EpilogueOp = cutlass::epilogue::fusion::LinearCombination; - - using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; - - using CollectiveEpilogueRef = cutlass::epilogue::collective::CollectiveEpilogue< - EpilogueDispatchPolicy, - TileShape, - ElementAccumulator, - cutlass::gemm::TagToStrideC_t, - ElementOutput, - cutlass::gemm::TagToStrideC_t, - FusionCallBacks, - XE_2D_U32x8x16_LD_N, - void, void, - XE_2D_U16x8x16_ST_N, - void, void>; - - // Mainloop - using CollectiveMainloopRef = cutlass::gemm::collective::CollectiveMma< - GEMMDispatchPolicy, - TileShape, - ElementMMA, - cutlass::gemm::TagToStrideA_t, - ElementMMA, - cutlass::gemm::TagToStrideB_t, - TiledMma, - GmemTiledCopyA, void, void, cute::identity, // A - GmemTiledCopyB, void, void, cute::identity // B - >; - - using GemmKernelRef = cutlass::gemm::kernel::GemmUniversal< - Shape, - CollectiveMainloopRef, - CollectiveEpilogueRef - >; - - using GemmRef = cutlass::gemm::device::GemmUniversalAdapter; - - typename GemmRef::Arguments arguments{ - cutlass::gemm::GemmUniversalMode::kGemm, - {options.m, options.n, options.k, options.l}, - {block_A_dq.get(), stride_A, block_B_dq.get(), stride_B}, - {{options.alpha, options.beta}, block_C.get(), stride_C, block_ref_D.get(), stride_D} - }; - - // Run the gemm where the scaling is performed outside of the kernel. - GemmRef gemm_ref; - size_t workspace_size = GemmRef::get_workspace_size(arguments); - cutlass::device_memory::allocation workspace(workspace_size); - CUTLASS_CHECK(gemm_ref.can_implement(arguments)); - CUTLASS_CHECK(gemm_ref.initialize(arguments, workspace.get())); - CUTLASS_CHECK(gemm_ref.run()); + bool verify(const ProblemShapeType& problem_size, ElementCompute alpha, ElementCompute beta) { + auto [M, N, K, L] = problem_size; + + cutlass::TensorRef ref_A(block_A_dq.get(), LayoutA::packed({M, K})); + cutlass::TensorRef ref_B(block_B_dq.get(), LayoutB::packed({K, N})); + cutlass::TensorRef ref_C(block_C.get(), LayoutC::packed({M, N})); + cutlass::TensorRef ref_D(block_ref_D.get(), LayoutD::packed({M, N})); + + cutlass::reference::device::GemmComplex( + {M, N, K}, + alpha, + ref_A, + cutlass::ComplexTransform::kNone, + ref_B, + cutlass::ComplexTransform::kNone, + beta, + ref_C, + ref_D, + ElementAccumulator(0), + L, // batch_count + M * K, // batch_stride_A + K * N, // batch_stride_B + M * N, // batch_stride_C + M * N // batch_stride_D + ); + + // CUTLASS on SYCL uses the compatibility library syclcompat for e.g. 
default in-order queue + syclcompat::wait(); - // compare_reference ElementOutput const epsilon(1e-2f); ElementOutput const non_zero_floor(1e-4f); - bool passed = cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get(), block_D.get(), block_D.size(), epsilon, non_zero_floor); - return passed; + return cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get(), block_D.get(), block_D.size(), epsilon, non_zero_floor); } template @@ -553,7 +505,7 @@ struct ExampleRunner { syclcompat::wait(); // Verify that the result is correct - bool passed = verify(options); + bool passed = verify(problem_size, options.alpha, options.beta); std::cout << "Disposition: " << (passed ? "Passed" : "Failed") << std::endl; if(!passed) return cutlass::Status::kErrorInternal; diff --git a/examples/sycl/02_bmg_gemm_mixed_dtype/02_bmg_gemm_f16_u4_s8.cpp b/examples/sycl/02_bmg_gemm_mixed_dtype/02_bmg_gemm_f16_u4_s8.cpp index be18a9b170..249464e96e 100755 --- a/examples/sycl/02_bmg_gemm_mixed_dtype/02_bmg_gemm_f16_u4_s8.cpp +++ b/examples/sycl/02_bmg_gemm_mixed_dtype/02_bmg_gemm_f16_u4_s8.cpp @@ -237,86 +237,38 @@ struct ExampleRunner { // Methods // - bool verify(const Options &options) { - - // - // Compute reference output (default gemm kernel w/ ElementA == ElementB) - // - - using GmemTiledCopyA = XE_2D_Packed_U8x32x32_LD_N; - using GmemTiledCopyB = XE_2D_U8x16x32_LD_T; - - // Workgroup-level tile - using TileShape = Shape<_32, _64, _32>; - - using TiledMma = - typename TiledMMAHelper, Layout, - Layout, Stride<_2, _1, _0>>>::TiledMMA; - - constexpr int PipelineStages = 3; - using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16; - using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16; - - using EpilogueOp = cutlass::epilogue::fusion::LinearCombination; - - using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; - - using CollectiveEpilogueRef = cutlass::epilogue::collective::CollectiveEpilogue< - EpilogueDispatchPolicy, - TileShape, - ElementAccumulator, - cutlass::gemm::TagToStrideC_t, - ElementOutput, - cutlass::gemm::TagToStrideC_t, - FusionCallBacks, - XE_2D_U32x8x16_LD_N, - void, void, - XE_2D_U16x8x16_ST_N, - void, void>; - - // Mainloop - using CollectiveMainloopRef = cutlass::gemm::collective::CollectiveMma< - GEMMDispatchPolicy, - TileShape, - ElementMMA, - cutlass::gemm::TagToStrideA_t, - ElementMMA, - cutlass::gemm::TagToStrideB_t, - TiledMma, - GmemTiledCopyA, void, void, cute::identity, // A - GmemTiledCopyB, void, void, cute::identity // B - >; - - using GemmKernelRef = cutlass::gemm::kernel::GemmUniversal< - Shape, - CollectiveMainloopRef, - CollectiveEpilogueRef - >; - - using GemmRef = cutlass::gemm::device::GemmUniversalAdapter; - - typename GemmRef::Arguments arguments{ - cutlass::gemm::GemmUniversalMode::kGemm, - {options.m, options.n, options.k, options.l}, - {block_A_dq.get(), stride_A, block_B_dq.get(), stride_B}, - {{options.alpha, options.beta}, block_C.get(), stride_C, block_ref_D.get(), stride_D} - }; - - // Run the gemm where the scaling is performed outside of the kernel. 
- GemmRef gemm_ref; - size_t workspace_size = GemmRef::get_workspace_size(arguments); - cutlass::device_memory::allocation workspace(workspace_size); - CUTLASS_CHECK(gemm_ref.can_implement(arguments)); - CUTLASS_CHECK(gemm_ref.initialize(arguments, workspace.get())); - CUTLASS_CHECK(gemm_ref.run()); + bool verify(const ProblemShapeType& problem_size, ElementCompute alpha, ElementCompute beta) { + auto [M, N, K, L] = problem_size; + + cutlass::TensorRef ref_A(block_A_dq.get(), LayoutA::packed({M, K})); + cutlass::TensorRef ref_B(block_B_dq.get(), LayoutB::packed({K, N})); + cutlass::TensorRef ref_C(block_C.get(), LayoutC::packed({M, N})); + cutlass::TensorRef ref_D(block_ref_D.get(), LayoutD::packed({M, N})); + + cutlass::reference::device::GemmComplex( + {M, N, K}, + alpha, + ref_A, + cutlass::ComplexTransform::kNone, + ref_B, + cutlass::ComplexTransform::kNone, + beta, + ref_C, + ref_D, + ElementAccumulator(0), + L, // batch_count + M * K, // batch_stride_A + K * N, // batch_stride_B + M * N, // batch_stride_C + M * N // batch_stride_D + ); + + // CUTLASS on SYCL uses the compatibility library syclcompat for e.g. default in-order queue + syclcompat::wait(); - // compare_reference ElementOutput const epsilon(1e-2f); ElementOutput const non_zero_floor(1e-4f); - bool passed = cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get(), block_D.get(), block_D.size(), epsilon, non_zero_floor); - return passed; + return cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get(), block_D.get(), block_D.size(), epsilon, non_zero_floor); } template @@ -644,7 +596,7 @@ struct ExampleRunner { syclcompat::wait(); // Verify that the result is correct - bool passed = verify(options); + bool passed = verify(problem_size, options.alpha, options.beta); std::cout << "Disposition: " << (passed ? 
"Passed" : "Failed") << std::endl; if(!passed) return cutlass::Status::kErrorInternal; diff --git a/include/cute/arch/copy_xe_U8.hpp b/include/cute/arch/copy_xe_U8.hpp index 411e6cb5d8..ceb5b69bb8 100644 --- a/include/cute/arch/copy_xe_U8.hpp +++ b/include/cute/arch/copy_xe_U8.hpp @@ -141,6 +141,23 @@ struct XE_2D_Packed_U8x2x32_LD_N { }; }; +struct XE_2D_U8x1x32_ST_N { + using BlockShape = Shape<_1, _32>; + using inst_dtype = uint16_t; + + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord, + T *src) { +#if defined(CUTE_ARCH_COPY_XE_ENABLED) + static_assert(sizeof(T) == 1, "Expected T to have size 1"); + detail::XeSubgroup2DBlockStore<2, 16, 1, 1>{}(baseoffset, width, height, pitch, coord, src); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); +#endif + } +}; + struct XE_2D_U8x2x32_ST_N { using BlockShape = Shape<_2, _32>; @@ -157,6 +174,23 @@ struct XE_2D_U8x2x32_ST_N { } }; +struct XE_2D_U8x1x64_ST_N { + using BlockShape = Shape<_1, _64>; + using inst_dtype = uint32_t; + + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord, + T *src) { +#if defined(CUTE_ARCH_COPY_XE_ENABLED) + static_assert(sizeof(T) == 1, "Expected T to have size 1"); + detail::XeSubgroup2DBlockStore<4, 16, 1, 1>{}(baseoffset, width, height, pitch, coord, src); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); +#endif + } +}; + struct XE_2D_Packed_U8x4x32_LD_N { using BlockShape = Shape<_4, _32>; @@ -311,7 +345,7 @@ struct XE_2D_U8x8x32_LD_N { struct XE_2D_Packed_U8x1x64_LD_N { using BlockShape = Shape<_1, _64>; - + template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, diff --git a/include/cute/atom/copy_traits_xe.hpp b/include/cute/atom/copy_traits_xe.hpp index 5697d3e9a1..a78d0acd18 100644 --- a/include/cute/atom/copy_traits_xe.hpp +++ b/include/cute/atom/copy_traits_xe.hpp @@ -851,11 +851,11 @@ struct Copy_Traits_ : XE_2D_LD_Unpack { using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit - using SrcLayout = Layout, - Stride< _0,_1>>; + using SrcLayout = Layout>, + Stride< _0, Stride<_1, _8, _512>>>; // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_16,Stride< _1,_256>>>; + using DstLayout = Layout>, + Stride<_32, Stride<_1, _8, _512>>>; // Reference map from (thr,val) to bit using RefLayout = DstLayout; @@ -2206,6 +2206,24 @@ struct Copy_Traits_ : XE_2D_LD_Unpack(args...) {} }; +template +struct Copy_Traits_ + : XE_2D_ST_Unpack { + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride<_16,Stride< _1,_128,_256>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride< _0,Stride< _1,_128,_256>>>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + template + Copy_Traits_(ArgT... args) + : XE_2D_ST_Unpack(args...) {} +}; + template struct Copy_Traits_ : XE_2D_ST_Unpack { @@ -2224,6 +2242,24 @@ struct Copy_Traits_ : XE_2D_ST_Unpack(args...) 
{} }; +template +struct Copy_Traits_ + : XE_2D_ST_Unpack { + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride<_32,Stride< _1, _8, _512>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride< _0,Stride< _1, _8, _512>>>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + template + Copy_Traits_(ArgT... args) + : XE_2D_ST_Unpack(args...) {} +}; + template struct Copy_Traits_ : XE_2D_ST_Unpack { @@ -2667,7 +2703,9 @@ COPY_TRAIT_LD_DEF(XE_2D_U16x32x16_LD_N::PREFETCH) COPY_TRAIT_LD_DEF(XE_2D_U16x16x32_LD_N::PREFETCH) COPY_TRAIT_LD_DEF(XE_2D_U16x32x32_LD_N::PREFETCH) +COPY_TRAIT_ST_DEF(XE_2D_U8x1x32_ST_N) COPY_TRAIT_ST_DEF(XE_2D_U8x2x32_ST_N) +COPY_TRAIT_ST_DEF(XE_2D_U8x1x64_ST_N) COPY_TRAIT_ST_DEF(XE_2D_U8x1x16_ST_N) COPY_TRAIT_ST_DEF(XE_2D_U8x2x16_ST_N) COPY_TRAIT_ST_DEF(XE_2D_U8x4x16_ST_N) diff --git a/include/cutlass/gemm/collective/xe_mma_mixed_input.hpp b/include/cutlass/gemm/collective/xe_mma_mixed_input.hpp index aaaf74fb56..b10903762d 100644 --- a/include/cutlass/gemm/collective/xe_mma_mixed_input.hpp +++ b/include/cutlass/gemm/collective/xe_mma_mixed_input.hpp @@ -46,6 +46,26 @@ using namespace cute; ///////////////////////////////////////////////////////////////////////////////////////////////// +template +struct general_same_bits { + using type = T; +}; + +template +struct general_same_bits == 8>> { + using type = int8_t; +}; + +template +struct general_same_bits == 16>> { + using type = int16_t; +}; + +template +struct general_same_bits == 32>> { + using type = int32_t; +}; + template , class = void> struct scale_zero_copy_traits { static_assert(cute::dependent_false, Stride>>, "scale_zero_copy_traits not defined"); @@ -67,7 +87,7 @@ struct scale_zero_copy_traits struct scale_zero_copy_traits == 8 && N >= 32>> { - using type = XE_2D_U8x1x16_LD_N; // XE_2D_U8x1x32_LD_N not work, use this instead + using type = XE_2D_U8x1x32_LD_N; // XE_2D_U8x1x32_LD_N not work, use this instead }; // 16 bits @@ -191,8 +211,8 @@ struct CollectiveMma< using TransformB = TransformB_; using ArchTag = typename DispatchPolicy::ArchTag; - static_assert(std::is_same_v, "Transformation for A is not currently supported on Intel PVC"); - static_assert(std::is_same_v, "Transformation for B is not currently supported on Intel PVC"); + static_assert(cute::is_same_v, "Transformation for A is not currently supported on Intel PVC"); + static_assert(cute::is_same_v, "Transformation for B is not currently supported on Intel PVC"); private: @@ -413,20 +433,20 @@ struct CollectiveMma< transform_quant( Tensor const& in, Tensor& out, - Tensor& tCrS_input, - Tensor& tCrZ_input + Tensor const& tCrS_input, + Tensor const& tCrZ_input ) { static_assert(is_rmem::value, "Input tensor for conversion must come from registers"); static_assert(size_v == cosize_v); - static_assert(std::is_same_v); + static_assert(cute::is_same_v); using SrcType = typename EngineIn::value_type; using DstType = typename EngineOut::value_type; using ZeroType = typename EngineZeros::value_type; using ScaleType = typename EngineScales::value_type; - static constexpr bool is_quantization = !((cutlass::platform::numeric_limits::is_integer && cutlass::platform::numeric_limits::is_integer) - || (cutlass::platform::is_floating_point::value && cutlass::platform::is_floating_point::value)); + static constexpr bool is_quantization = cutlass::platform::numeric_limits::is_integer + ^ cutlass::platform::numeric_limits::is_integer; static constexpr auto DPAS = 
decltype(size<0>(in))::value; static constexpr auto N = decltype(size<1>(in))::value; @@ -443,9 +463,7 @@ struct CollectiveMma< static constexpr auto splits = loop_cnt / vec_size; static_assert(vec_size <= scalar); - if (std::is_same_v) { - return; - } + static_assert(!cute::is_same_v); // reshape tensors for easy access auto s_tensor = make_tensor((format_type*)(raw_pointer_cast(in.data())), Shape, Int>{}); @@ -453,27 +471,30 @@ struct CollectiveMma< CUTLASS_PRAGMA_UNROLL for (int n = 0; n < N; n++) { - const auto ts = tCrS_input(n); - const auto tz = [&](){ + const auto scale = tCrS_input(n); + const auto zero = [&]() { if constexpr (sizeof_bits_v >= 8) { return tCrZ_input(n); } else { - return tCrZ_input(n).get(); + return (int8_t)(tCrZ_input(n).get()); } }(); - auto& src = *(cute::array*)(s_tensor(_, n).data()); + auto& src = *(cute::intel::vector_t*)(s_tensor(_, n).data()); CUTLASS_PRAGMA_UNROLL for (int s = 0; s < splits; s++) { auto idx = vec_size * s / scalar; auto format_data = src[idx]; - auto& dst = *(cute::array*)(d_tensor(_, s, n).data()); + // for performance, _Float16 have better performance than half_t here + using vector_type = typename general_same_bits::type; + + auto& dst = *(cute::intel::vector_t*)(d_tensor(_, s, n).data()); CUTLASS_PRAGMA_UNROLL for (int i = 0; i < vec_size; i++) { - auto data = [&](){ + auto data = [&]() { if constexpr (cutlass::platform::numeric_limits::is_signed) { return static_cast((format_data >> (src_bits * i)) & 0xf); } else { @@ -485,15 +506,14 @@ struct CollectiveMma< if constexpr (IsATransformed) { static_assert(dependent_false && "ATransform not support now"); } else { - using ret_type = cute::conditional_t >= 8, ZeroType, int8_t>; - ret_type minus(data); - if constexpr (ModeScaleZero) { - minus = static_cast(data) - static_cast(tz); + if constexpr (ModeScale) { + dst[i] = cutlass::platform::bit_cast(static_cast(data * scale)); + } else if constexpr (ModeScaleZero) { + dst[i] = cutlass::platform::bit_cast(static_cast((static_cast(data) - zero) * scale)); } - dst[i] = (static_cast(minus)) * ts; } } else { - dst[i] = static_cast(data); + dst[i] = cutlass::platform::bit_cast(static_cast(data)); } } } @@ -514,25 +534,40 @@ struct CollectiveMma< transform_quant( Tensor const& in, Tensor& out, - Tensor& tCrS_input, - Tensor& tCrZ_input + Tensor const& tCrS_input, + Tensor const& tCrZ_input ) { static_assert(is_rmem::value, "Input tensor for A conversion must come from registers"); static_assert(size_v == cosize_v); - static_assert(std::is_same_v); + static_assert(cute::is_same_v); using SrcType = typename EngineIn::value_type; using DstType = typename EngineOut::value_type; using ZeroType = typename EngineZeros::value_type; using ScaleType = typename EngineScales::value_type; - using MmaType = DstType; - if constexpr (!std::is_same_v) { - if constexpr(cute::is_any_of_v - && cute::is_any_of_v) { - convert_FP8_to_FP16(make_tensor(reinterpret_cast(in.data()), in.layout()), out); - } else { + static_assert(!cute::is_same_v); + + if constexpr (cute::is_any_of_v + && cute::is_any_of_v) { + convert_FP8_to_FP16(make_tensor(reinterpret_cast(in.data()), in.layout()), out); + } else if constexpr (!ModeHasScales) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < decltype(size(in))::value; ++i) { + out[i] = static_cast(in[i]); + } + } else if constexpr (IsATransformed) { + // The current scale load atom (1x32) gives 2 scale values to + // each thread. 
All threads need access to all other threads + // scale values, and each scale value is reused twice (unrolled) + + static constexpr auto M = decltype(size<1>(in))::value; + static constexpr auto K = decltype(size(in))::value / 8 / M; + + static constexpr auto is_dst_int = cutlass::platform::numeric_limits::is_integer; + + if constexpr (!is_dst_int) { auto const& src = in(_, _, _); auto const& dst = out(_, _, _); auto pSrc = const_cast(raw_pointer_cast(src.data())); @@ -546,79 +581,62 @@ struct CollectiveMma< using DstArray = cutlass::Array; constexpr int iters = num_elements / pack; - if constexpr (!cutlass::platform::numeric_limits::is_integer) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < iters; ++i) { - SrcArray const* pSrcArr = reinterpret_cast(pSrc) + i; - DstArray* pDstArr = reinterpret_cast(pDst) + i; - *pDstArr = Converter::convert(*pSrcArr); - } - } else if constexpr (!ModeHasScales) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < decltype(size(in))::value; ++i) { - out[i] = static_cast(in[i]); - } + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < iters; ++i) { + SrcArray const* pSrcArr = reinterpret_cast(pSrc) + i; + DstArray* pDstArr = reinterpret_cast(pDst) + i; + *pDstArr = Converter::convert(*pSrcArr); } + } - if constexpr (ModeHasScales) { - if constexpr(IsATransformed){ - // The current scale load atom (1x32) gives 2 scale values to - // each thread. All threads need access to all other threads - // scale values, and each scale value is reused twice (unrolled) - - static constexpr auto M = decltype(size<1>(in))::value; - static constexpr auto K = decltype(size(in))::value / 8 / M; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 16 ; ++i) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < M / 2; ++m) { + const auto scale = shfl_sync(0xFFFFFFFF, tCrS_input(m), i); + const auto zero = [&]() { + if constexpr (sizeof_bits_v >= 8) { + return shfl_sync(0xFFFFFFFF, tCrZ_input(m), i); + } else { + return shfl_sync(0xFFFFFFFF, tCrZ_input(m).get(), i); + } + }(); + if constexpr (is_dst_int) { // quantization CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < 16 ; ++i) { - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < M / 2; ++m) { - auto scale = shfl_sync(0xFFFFFFFF, tCrS_input(m), i); - auto zero = - [&]() { - if constexpr (sizeof_bits_v >= 8) { - return shfl_sync(0xFFFFFFFF, tCrZ_input(m), i); - } else { - return shfl_sync(0xFFFFFFFF, tCrZ_input(m).get(), i); - } - }(); - - if constexpr (cutlass::platform::numeric_limits::is_integer) { // quantization - for (int k = 0; k < K; k++) { - out[2 * (m * 16 + i) + k] = in[2 * (m * 16 + i) + k] / scale; - if constexpr (ModeScaleZero) { - out[2 * (m * 16 + i) + k] += zero; - } - } - } else { // dequantization - for (int k = 0; k < K; k++) { - if constexpr (ModeScaleZero) { - out(_, _, k)[m * 16 + i] -= zero; - } - out(_, _, k)[m * 16 + i] *= scale; - } - } + for (int k = 0; k < K; k++) { + out[2 * (m * 16 + i) + k] = in[2 * (m * 16 + i) + k] / scale; + if constexpr (ModeScaleZero) { + out[2 * (m * 16 + i) + k] += zero; } } - } else { - static constexpr auto N = decltype(size<1>(in))::value; - + } else { // dequantization CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < N; ++n) { - auto [zero, scale] = (is_groupwise) ? 
cute::make_tuple(tCrZ_input(n), tCrS_input(n)) : cute::make_tuple(tCrZ_input(0), tCrS_input(0)); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < decltype(size(in))::value / N; ++i) { - ZeroType minus_zp = static_cast(in(_, n, _)[i]); - if constexpr (ModeScaleZero) { - minus_zp -= zero; - } - out(_, n, _)[i] = static_cast(minus_zp) * scale; + for (int k = 0; k < K; k++) { + if constexpr (ModeScaleZero) { + out(_, _, k)[m * 16 + i] -= zero; } + out(_, _, k)[m * 16 + i] *= scale; } } } } + } else { + static constexpr auto N = decltype(size<1>(in))::value; + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < N; ++n) { + auto [zero, scale] = (is_groupwise) ? cute::make_tuple(tCrZ_input(n), tCrS_input(n)) : cute::make_tuple(tCrZ_input(0), tCrS_input(0)); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < decltype(size(in))::value / N; ++i) { + ZeroType minus_zp = static_cast(in(_, n, _)[i]); + if constexpr (ModeScaleZero) { + minus_zp -= zero; + } + out(_, n, _)[i] = static_cast(minus_zp) * scale; + } + } } } @@ -721,7 +739,7 @@ struct CollectiveMma< Tensor quant_frag_B = make_tensor(mma_B.layout()); auto frag_copy_A = [&]() -> decltype(auto) { - if constexpr (std::is_same_v) { + if constexpr (cute::is_same_v) { return thr_copy_A.retile_D(mma_A); } else { return thr_copy_A.retile_D(quant_frag_A); @@ -729,7 +747,7 @@ struct CollectiveMma< }(); auto frag_copy_B = [&]() -> decltype(auto) { - if constexpr (std::is_same_v) { + if constexpr (cute::is_same_v) { return thr_copy_B.retile_D(mma_B); } else { return thr_copy_B.retile_D(quant_frag_B); @@ -822,16 +840,24 @@ struct CollectiveMma< const int k_start_idx = crd2idx((*k_tile_iter), make_shape(K_start)); int prefetch_k = k_start_idx; - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < DispatchPolicy::Stages; i++, prefetch_k++) { - prefetch(tiled_prefetch_a, pAgA(_,_,_,prefetch_k)); - prefetch(tiled_prefetch_b, pBgB(_,_,_,prefetch_k)); + // Prefetch does not always bring benefits in all scenarios, + // Use "DispatchPolicy::Stages" to control whether prefetching is needed. 
+ static constexpr auto prefetch_enabled = (DispatchPolicy::Stages > 0); + + if constexpr (prefetch_enabled) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < DispatchPolicy::Stages; i++, prefetch_k++) { + prefetch(tiled_prefetch_a, pAgA(_,_,_,prefetch_k)); + prefetch(tiled_prefetch_b, pBgB(_,_,_,prefetch_k)); + } } - for (int k_tile = k_start_idx; k_tile < k_tile_count + k_start_idx; k_tile++, prefetch_k++) { + for (int k_tile = k_start_idx; k_tile < k_tile_count + k_start_idx; k_tile++) { constexpr int barrier_scope = 2; - barrier_arrive(barrier_scope); + if constexpr (prefetch_enabled) { + barrier_arrive(barrier_scope); + } // Copy gmem to rmem for the first k_tile copy(mainloop.tiled_copy_a, tAgA(_,_,_,k_tile), frag_copy_A); @@ -848,9 +874,10 @@ struct CollectiveMma< } } - if(prefetch_k < k_tile_count) { + if constexpr (prefetch_enabled) { prefetch(tiled_prefetch_a, pAgA(_,_,_,prefetch_k)); prefetch(tiled_prefetch_b, pBgB(_,_,_,prefetch_k)); + prefetch_k++; } auto quant_zero = [&]() -> decltype(auto) { @@ -861,11 +888,19 @@ struct CollectiveMma< } }(); - transform_quant(quant_frag_A, mma_A, fragment_scale, quant_zero); - transform_quant(quant_frag_B, mma_B, fragment_scale, quant_zero); + if constexpr (!cute::is_same_v) { + transform_quant(quant_frag_A, mma_A, fragment_scale, quant_zero); + } + + if constexpr (!cute::is_same_v) { + transform_quant(quant_frag_B, mma_B, fragment_scale, quant_zero); + } cute::gemm(tiled_mma, mma_A, mma_B, accum); - barrier_wait(barrier_scope); + + if constexpr (prefetch_enabled) { + barrier_wait(barrier_scope); + } } } }; diff --git a/test/unit/cute/intel_xe/copy_block.cpp b/test/unit/cute/intel_xe/copy_block.cpp index 4f287617ce..59ef352e46 100644 --- a/test/unit/cute/intel_xe/copy_block.cpp +++ b/test/unit/cute/intel_xe/copy_block.cpp @@ -340,6 +340,7 @@ TEST(PVC_CuTe_Xe, block_2d_32bits_n) { TEST(PVC_CuTe_Xe, block_2d_8bits_n) { copy_op{}(); + copy_op{}(); copy_op{}(); copy_op{}(); copy_op{}(); From 949d3e79b351b98e2c6347db578c39f82181d51b Mon Sep 17 00:00:00 2001 From: taozha2 Date: Wed, 27 Aug 2025 12:28:16 +0800 Subject: [PATCH 2/2] fix issues --- include/cutlass/gemm/collective/xe_mma_mixed_input.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/cutlass/gemm/collective/xe_mma_mixed_input.hpp b/include/cutlass/gemm/collective/xe_mma_mixed_input.hpp index b10903762d..0505e94a96 100644 --- a/include/cutlass/gemm/collective/xe_mma_mixed_input.hpp +++ b/include/cutlass/gemm/collective/xe_mma_mixed_input.hpp @@ -87,7 +87,7 @@ struct scale_zero_copy_traits struct scale_zero_copy_traits == 8 && N >= 32>> { - using type = XE_2D_U8x1x32_LD_N; // XE_2D_U8x1x32_LD_N not work, use this instead + using type = XE_2D_U8x1x16_LD_N; // XE_2D_U8x1x32_LD_N not work, use this instead }; // 16 bits