Skip to content

Commit 1c0248e

Browse files
committed
improve mixed data type performance
1 parent ce061da commit 1c0248e

File tree

9 files changed

+333
-366
lines changed

9 files changed

+333
-366
lines changed

benchmarks/gemm/benchmark_runner.hpp

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -61,18 +61,16 @@ namespace cutlass::benchmark {
6161

6262
///////////////////////////////////////////////////////////////////////////////////////////////////
6363

64-
#if defined(SYCL_INTEL_TARGET)
65-
template <class T, int Stages = 0>
64+
template <class T>
6665
static constexpr auto is_mixed_dtype = false;
6766

67+
#if defined(SYCL_INTEL_TARGET)
6868
template <int Stages>
6969
static constexpr auto is_mixed_dtype<cutlass::gemm::MainloopIntelXeXMX16MixedPrecision<Stages>> = true;
70-
#else
71-
template <class T, int Stages = 0>
72-
static constexpr auto is_mixed_dtype = false;
7370
#endif
7471

75-
template <class T, class = void>
72+
// ScaleType
73+
template <class, class = void>
7674
struct ScaleType {
7775
using type = int;
7876
};
@@ -81,7 +79,8 @@ struct ScaleType<T, cute::void_t<typename T::ElementScale>> {
8179
using type = typename T::ElementScale;
8280
};
8381

84-
template <class T, class = void>
82+
// ZeroType
83+
template <class, class = void>
8584
struct ZeroType {
8685
using type = int;
8786
};
@@ -90,7 +89,8 @@ struct ZeroType<T, cute::void_t<typename T::ElementZero>> {
9089
using type = typename T::ElementZero;
9190
};
9291

93-
template <class T, class = void>
92+
// ScaleStride
93+
template <class, class = void>
9494
struct ScaleStride {
9595
using type = int;
9696
};
@@ -99,7 +99,8 @@ struct ScaleStride<T, cute::void_t<typename T::StrideScale>> {
9999
using type = typename T::StrideScale;
100100
};
101101

102-
template <class T, class = void>
102+
// ZeroStride
103+
template <class, class = void>
103104
struct ZeroStride {
104105
using type = int;
105106
};

benchmarks/gemm/benchmarks_sycl.hpp

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -357,8 +357,8 @@ using PvcMixedPrecisionGemmFP16U4FP16S8FP16S4_RCR_1 = cutlass::gemm::device::Mix
357357
typename TiledMMAHelper<MMA_Atom<XE_8x16x32_S32S8S8S32_TT>, Layout<Shape<_32, _128, _32>>,
358358
Layout<Shape<_1, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA,
359359
XE_2D_Packed_U16x32x32_LD_N, XE_2D_U4x32x16_LD_T, XE_2D_U16x8x16_ST_N,
360-
cutlass::epilogue::fusion::LinearCombination<int, int,
361-
int, int, cutlass::FloatRoundStyle::round_to_nearest>,
360+
cutlass::epilogue::fusion::LinearCombination<int32_t, int32_t,
361+
int32_t, int32_t, cutlass::FloatRoundStyle::round_to_nearest>,
362362
2
363363
>;
364364

@@ -373,8 +373,8 @@ using PvcMixedPrecisionGemmFP16U4S8S8FP16S4_RCR_1 = cutlass::gemm::device::Mixed
373373
typename TiledMMAHelper<MMA_Atom<XE_8x16x32_S32S8S8S32_TT>, Layout<Shape<_32, _128, _32>>,
374374
Layout<Shape<_1, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA,
375375
XE_2D_Packed_U16x32x32_LD_N, XE_2D_U4x32x16_LD_T, XE_2D_U8x8x16_ST_N,
376-
cutlass::epilogue::fusion::LinearCombination<int, int,
377-
int, int, cutlass::FloatRoundStyle::round_to_nearest>,
376+
cutlass::epilogue::fusion::LinearCombination<int32_t, int32_t,
377+
int32_t, int32_t, cutlass::FloatRoundStyle::round_to_nearest>,
378378
2
379379
>;
380380

@@ -389,8 +389,8 @@ using PvcMixedPrecisionGemmBF16U4BF16S8BF16S4_RCR_1 = cutlass::gemm::device::Mix
389389
typename TiledMMAHelper<MMA_Atom<XE_8x16x32_S32S8S8S32_TT>, Layout<Shape<_32, _128, _32>>,
390390
Layout<Shape<_1, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA,
391391
XE_2D_Packed_U16x32x32_LD_N, XE_2D_U4x32x16_LD_T, XE_2D_U16x8x16_ST_N,
392-
cutlass::epilogue::fusion::LinearCombination<int, int,
393-
int, int, cutlass::FloatRoundStyle::round_to_nearest>,
392+
cutlass::epilogue::fusion::LinearCombination<int32_t, int32_t,
393+
int32_t, int32_t, cutlass::FloatRoundStyle::round_to_nearest>,
394394
2
395395
>;
396396

@@ -405,8 +405,8 @@ using PvcMixedPrecisionGemmBF16U4S8S8BF16S4_RCR_1 = cutlass::gemm::device::Mixed
405405
typename TiledMMAHelper<MMA_Atom<XE_8x16x32_S32S8S8S32_TT>, Layout<Shape<_32, _128, _32>>,
406406
Layout<Shape<_1, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA,
407407
XE_2D_Packed_U16x32x32_LD_N, XE_2D_U4x32x16_LD_T, XE_2D_U8x8x16_ST_N,
408-
cutlass::epilogue::fusion::LinearCombination<int, int,
409-
int, int, cutlass::FloatRoundStyle::round_to_nearest>,
408+
cutlass::epilogue::fusion::LinearCombination<int32_t, int32_t,
409+
int32_t, int32_t, cutlass::FloatRoundStyle::round_to_nearest>,
410410
2
411411
>;
412412

@@ -421,8 +421,8 @@ using PvcMixedPrecisionGemmBF16S8BF16S8BF16S8_RCR_1 = cutlass::gemm::device::Mix
421421
typename TiledMMAHelper<MMA_Atom<XE_8x16x32_S32S8S8S32_TT>, Layout<Shape<_32, _128, _32>>,
422422
Layout<Shape<_1, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA,
423423
XE_2D_Packed_U16x32x32_LD_N, XE_2D_U8x16x32_LD_T, XE_2D_U16x8x16_ST_N,
424-
cutlass::epilogue::fusion::LinearCombination<int, int,
425-
int, int, cutlass::FloatRoundStyle::round_to_nearest>,
424+
cutlass::epilogue::fusion::LinearCombination<int32_t, int32_t,
425+
int32_t, int32_t, cutlass::FloatRoundStyle::round_to_nearest>,
426426
2
427427
>;
428428

@@ -437,8 +437,8 @@ using PvcMixedPrecisionGemmFP16S8FP16S8FP16S8_RCR_1 = cutlass::gemm::device::Mix
437437
typename TiledMMAHelper<MMA_Atom<XE_8x16x32_S32S8S8S32_TT>, Layout<Shape<_32, _128, _32>>,
438438
Layout<Shape<_1, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA,
439439
XE_2D_Packed_U16x32x32_LD_N, XE_2D_U8x16x32_LD_T, XE_2D_U16x8x16_ST_N,
440-
cutlass::epilogue::fusion::LinearCombination<int, int,
441-
int, int, cutlass::FloatRoundStyle::round_to_nearest>,
440+
cutlass::epilogue::fusion::LinearCombination<int32_t, int32_t,
441+
int32_t, int32_t, cutlass::FloatRoundStyle::round_to_nearest>,
442442
2
443443
>;
444444

examples/sycl/02_bmg_gemm_mixed_dtype/02_bmg_gemm_bf16_s8_bf16.cpp

Lines changed: 30 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -253,86 +253,38 @@ struct ExampleRunner {
253253
// Methods
254254
//
255255

256-
bool verify(const Options &options) {
257-
258-
//
259-
// Compute reference output (default gemm kernel w/ ElementA == ElementB)
260-
//
261-
262-
using GmemTiledCopyA = XE_2D_U16x32x32_LD_N;
263-
using GmemTiledCopyB = XE_2D_U16x32x32_LD_V;
264-
265-
// Workgroup-level tile
266-
using TileShape = Shape<_256, _256, _32>;
267-
268-
using TiledMma =
269-
typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>, Layout<TileShape>,
270-
Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA;
271-
272-
constexpr int PipelineStages = 3;
273-
using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16<PipelineStages>;
274-
using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16;
275-
276-
using EpilogueOp = cutlass::epilogue::fusion::LinearCombination<ElementOutput, ElementCompute,
277-
ElementAccumulator, ElementAccumulator, cutlass::FloatRoundStyle::round_to_nearest>;
278-
279-
using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks<EpilogueDispatchPolicy, EpilogueOp, TileShape,
280-
decltype(tile_shape(TiledMma()))>;
281-
282-
using CollectiveEpilogueRef = cutlass::epilogue::collective::CollectiveEpilogue<
283-
EpilogueDispatchPolicy,
284-
TileShape,
285-
ElementAccumulator,
286-
cutlass::gemm::TagToStrideC_t<LayoutC>,
287-
ElementOutput,
288-
cutlass::gemm::TagToStrideC_t<LayoutD>,
289-
FusionCallBacks,
290-
XE_2D_U32x8x16_LD_N,
291-
void, void,
292-
XE_2D_U32x8x16_ST_N,
293-
void, void>;
294-
295-
// Mainloop
296-
using CollectiveMainloopRef = cutlass::gemm::collective::CollectiveMma<
297-
GEMMDispatchPolicy,
298-
TileShape,
299-
ElementMMA,
300-
cutlass::gemm::TagToStrideA_t<LayoutA>,
301-
ElementMMA,
302-
cutlass::gemm::TagToStrideB_t<LayoutB>,
303-
TiledMma,
304-
GmemTiledCopyA, void, void, cute::identity, // A
305-
GmemTiledCopyB, void, void, cute::identity // B
306-
>;
307-
308-
using GemmKernelRef = cutlass::gemm::kernel::GemmUniversal<
309-
Shape<int, int, int, int>,
310-
CollectiveMainloopRef,
311-
CollectiveEpilogueRef
312-
>;
313-
314-
using GemmRef = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelRef>;
315-
316-
typename GemmRef::Arguments arguments{
317-
cutlass::gemm::GemmUniversalMode::kGemm,
318-
{options.m, options.n, options.k, options.l},
319-
{block_A_dq.get(), stride_A, block_B_dq.get(), stride_B},
320-
{{options.alpha, options.beta}, block_C.get(), stride_C, block_ref_D.get(), stride_D}
321-
};
322-
323-
// Run the gemm where the scaling is performed outside of the kernel.
324-
GemmRef gemm_ref;
325-
size_t workspace_size = GemmRef::get_workspace_size(arguments);
326-
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
327-
CUTLASS_CHECK(gemm_ref.can_implement(arguments));
328-
CUTLASS_CHECK(gemm_ref.initialize(arguments, workspace.get()));
329-
CUTLASS_CHECK(gemm_ref.run());
256+
bool verify(const ProblemShapeType& problem_size, ElementCompute alpha, ElementCompute beta) {
257+
auto [M, N, K, L] = problem_size;
258+
259+
cutlass::TensorRef ref_A(block_A_dq.get(), LayoutA::packed({M, K}));
260+
cutlass::TensorRef ref_B(block_B_dq.get(), LayoutB::packed({K, N}));
261+
cutlass::TensorRef ref_C(block_C.get(), LayoutC::packed({M, N}));
262+
cutlass::TensorRef ref_D(block_ref_D.get(), LayoutD::packed({M, N}));
263+
264+
cutlass::reference::device::GemmComplex(
265+
{M, N, K},
266+
alpha,
267+
ref_A,
268+
cutlass::ComplexTransform::kNone,
269+
ref_B,
270+
cutlass::ComplexTransform::kNone,
271+
beta,
272+
ref_C,
273+
ref_D,
274+
ElementAccumulator(0),
275+
L, // batch_count
276+
M * K, // batch_stride_A
277+
K * N, // batch_stride_B
278+
M * N, // batch_stride_C
279+
M * N // batch_stride_D
280+
);
281+
282+
// CUTLASS on SYCL uses the compatibility library syclcompat for e.g. default in-order queue
283+
syclcompat::wait();
330284

331-
// compare_reference
332285
ElementOutput const epsilon(1e-2f);
333286
ElementOutput const non_zero_floor(1e-4f);
334-
bool passed = cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get(), block_D.get(), block_D.size(), epsilon, non_zero_floor);
335-
return passed;
287+
return cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get(), block_D.get(), block_D.size(), epsilon, non_zero_floor);
336288
}
337289

338290
template <class Element>
@@ -462,7 +414,7 @@ struct ExampleRunner {
462414
syclcompat::wait();
463415

464416
// Verify that the result is correct
465-
bool passed = verify(options);
417+
bool passed = verify(problem_size, options.alpha, options.beta);
466418
std::cout << "Disposition: " << (passed ? "Passed" : "Failed") << std::endl;
467419

468420
if(!passed) return cutlass::Status::kErrorInternal;

examples/sycl/02_bmg_gemm_mixed_dtype/02_bmg_gemm_f16_u4_f16.cpp

Lines changed: 30 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -242,86 +242,38 @@ struct ExampleRunner {
242242
// Methods
243243
//
244244

245-
bool verify(const Options &options) {
246-
247-
//
248-
// Compute reference output (default gemm kernel w/ ElementA == ElementB)
249-
//
250-
251-
using GmemTiledCopyA = XE_2D_U16x32x32_LD_N;
252-
using GmemTiledCopyB = XE_2D_U16x16x16_LD_T;
253-
254-
// Workgroup-level tile
255-
using TileShape = Shape<_256, _256, _32>;
256-
257-
using TiledMma =
258-
typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32F16F16F32_TT>, Layout<TileShape>,
259-
Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA;
260-
261-
constexpr int PipelineStages = 3;
262-
using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16<PipelineStages>;
263-
using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16;
264-
265-
using EpilogueOp = cutlass::epilogue::fusion::LinearCombination<ElementAccumulator, ElementCompute,
266-
ElementAccumulator, ElementAccumulator, cutlass::FloatRoundStyle::round_to_nearest>;
267-
268-
using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks<EpilogueDispatchPolicy, EpilogueOp, TileShape,
269-
decltype(tile_shape(TiledMma()))>;
270-
271-
using CollectiveEpilogueRef = cutlass::epilogue::collective::CollectiveEpilogue<
272-
EpilogueDispatchPolicy,
273-
TileShape,
274-
ElementAccumulator,
275-
cutlass::gemm::TagToStrideC_t<LayoutC>,
276-
ElementOutput,
277-
cutlass::gemm::TagToStrideC_t<LayoutD>,
278-
FusionCallBacks,
279-
XE_2D_U32x8x16_LD_N,
280-
void, void,
281-
XE_2D_U16x8x16_ST_N,
282-
void, void>;
283-
284-
// Mainloop
285-
using CollectiveMainloopRef = cutlass::gemm::collective::CollectiveMma<
286-
GEMMDispatchPolicy,
287-
TileShape,
288-
ElementMMA,
289-
cutlass::gemm::TagToStrideA_t<LayoutA>,
290-
ElementMMA,
291-
cutlass::gemm::TagToStrideB_t<LayoutB>,
292-
TiledMma,
293-
GmemTiledCopyA, void, void, cute::identity, // A
294-
GmemTiledCopyB, void, void, cute::identity // B
295-
>;
296-
297-
using GemmKernelRef = cutlass::gemm::kernel::GemmUniversal<
298-
Shape<int, int, int, int>,
299-
CollectiveMainloopRef,
300-
CollectiveEpilogueRef
301-
>;
302-
303-
using GemmRef = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelRef>;
304-
305-
typename GemmRef::Arguments arguments{
306-
cutlass::gemm::GemmUniversalMode::kGemm,
307-
{options.m, options.n, options.k, options.l},
308-
{block_A_dq.get(), stride_A, block_B_dq.get(), stride_B},
309-
{{options.alpha, options.beta}, block_C.get(), stride_C, block_ref_D.get(), stride_D}
310-
};
311-
312-
// Run the gemm where the scaling is performed outside of the kernel.
313-
GemmRef gemm_ref;
314-
size_t workspace_size = GemmRef::get_workspace_size(arguments);
315-
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
316-
CUTLASS_CHECK(gemm_ref.can_implement(arguments));
317-
CUTLASS_CHECK(gemm_ref.initialize(arguments, workspace.get()));
318-
CUTLASS_CHECK(gemm_ref.run());
245+
bool verify(const ProblemShapeType& problem_size, ElementCompute alpha, ElementCompute beta) {
246+
auto [M, N, K, L] = problem_size;
247+
248+
cutlass::TensorRef ref_A(block_A_dq.get(), LayoutA::packed({M, K}));
249+
cutlass::TensorRef ref_B(block_B_dq.get(), LayoutB::packed({K, N}));
250+
cutlass::TensorRef ref_C(block_C.get(), LayoutC::packed({M, N}));
251+
cutlass::TensorRef ref_D(block_ref_D.get(), LayoutD::packed({M, N}));
252+
253+
cutlass::reference::device::GemmComplex(
254+
{M, N, K},
255+
alpha,
256+
ref_A,
257+
cutlass::ComplexTransform::kNone,
258+
ref_B,
259+
cutlass::ComplexTransform::kNone,
260+
beta,
261+
ref_C,
262+
ref_D,
263+
ElementAccumulator(0),
264+
L, // batch_count
265+
M * K, // batch_stride_A
266+
K * N, // batch_stride_B
267+
M * N, // batch_stride_C
268+
M * N // batch_stride_D
269+
);
270+
271+
// CUTLASS on SYCL uses the compatibility library syclcompat for e.g. default in-order queue
272+
syclcompat::wait();
319273

320-
// compare_reference
321274
ElementOutput const epsilon(1e-2f);
322275
ElementOutput const non_zero_floor(1e-4f);
323-
bool passed = cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get(), block_D.get(), block_D.size(), epsilon, non_zero_floor);
324-
return passed;
276+
return cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get(), block_D.get(), block_D.size(), epsilon, non_zero_floor);
325277
}
326278

327279
template <class Element>
@@ -553,7 +505,7 @@ struct ExampleRunner {
553505
syclcompat::wait();
554506

555507
// Verify that the result is correct
556-
bool passed = verify(options);
508+
bool passed = verify(problem_size, options.alpha, options.beta);
557509
std::cout << "Disposition: " << (passed ? "Passed" : "Failed") << std::endl;
558510

559511
if(!passed) return cutlass::Status::kErrorInternal;

0 commit comments

Comments
 (0)