Skip to content
Open
25 changes: 14 additions & 11 deletions examples/00_bmg_gemm/00_bmg_gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -345,30 +345,33 @@ int main(int argc, const char** argv)
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;

// The 2D block copy operations used for the A and B matrices
using GmemTiledCopyA = XE_2D_U16x32x32_LD_N;
using GmemTiledCopyB = XE_2D_U16x32x32_LD_V;
// [New Copy Atom] When left unspecified (void), make_block_2d_copy_* automatically selects
// appropriate 2D block copy operations for matrices A and B. Alternatively, you can
// explicitly specify new copy atom operations such as XE_LOAD_2D, XE_LOAD_2D_VNNI
// (applicable only to matrix B), or XE_LOAD_2D_TRANSPOSE.
// Refer https://github.com/intel/sycl-tla/blob/petercad/rearchitecture/media/docs/cpp/xe_rearchitecture.md
using GmemTiledCopyA = void; //XE_LOAD_2D<16, 32, 32>;
using GmemTiledCopyB = void; //XE_LOAD_2D_VNNI<16, 32, 32>;

// Workgroup-level tile
using TileShape = Shape<_256, _256, _32>;

// A TiledMMA struct defines a tiling of an MMA atom over M, N and K, combining both additional
// hardware (sub-groups for Intel BMG) and iterations by each sub-group.
//
// The TiledMMAHelper struct defines a specific TiledMMA for a given MMA atom
// (XE_8x16x16_F32BF16BF16F32_TT), TileShape (<256, 256, 32>) and sub-group layout (8x4x1). The
// TiledMMA constructed using TiledMMAHelper has the property that each sub-group operates on a
// The TiledMMAHelper struct defines a specific TiledMMA for a given MMA atom. This example uses
// the XE_DPAS_TT<8, float, cute::bfloat16_t> atom, which represents an 8x16x16 DPAS operation with float32 accumulation and bfloat16 inputs, TileShape (<256, 256, 32>) and sub-group layout (8x4x1).
// The TiledMMA constructed using TiledMMAHelper has the property that each sub-group operates on a
// single contiguous chunk of the work-group TileShape. For this configuration, this implies that
// each sub-group operates on a contiguous 32x64x32 chunk (4x4x2 iterations). See
// 0t_mma_atom.md#TiledMMAs for more info. Sub-groups are arranged row-major (stride 4,1,0) for
// performance reasons.
using TiledMma = // M=8,N=16,K=16, D=f32,A=bf16,B=bf16,C=f32
typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>, Layout<TileShape>,
Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA;
using TiledMma = typename TiledMMAHelper<MMA_Atom<XE_DPAS_TT<8, float, cute::bfloat16_t>>, Layout<TileShape>, Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA;

// For Intel BMG, PipelineStages defines how many k-blocks ahead to prefetch from A and B.
constexpr int PipelineStages = 2;
using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16<PipelineStages>;
// For older version of copy/mma atom, use cutlass::gemm::MainloopIntelXeXMX16 as dispatch policy
using GEMMDispatchPolicy = cutlass::gemm::MainloopXeL1Staged<PipelineStages>;
using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16;

// This is the 'default' epilogue operation (Linear Combination) which performs everything in:
Expand Down
Loading