Fix to allow all gemm tile shapes #738

Open
vidyasiv wants to merge 2 commits into intel:main from vidyasiv:fix_python_gemm_tile_shapes

vidyasiv commented Mar 6, 2026

Description

Previously, GEMM operations were supported for only one (default) tile size, 256x256x32. This PR removes that limitation.

Supported OPs list

Previous:
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_f32_f32_256x256x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_f32_f32_256x256x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_f32_f32_256x256x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_f32_f32_256x256x32_1x1x1_0_nnt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f32_df32_256x256x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f32_df32_256x256x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f32_df32_256x256x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f32_df32_256x256x32_1x1x1_0_nnt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_f32_f16_df16_256x256x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_f32_f16_df16_256x256x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_f32_f16_df16_256x256x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_f32_f16_df16_256x256x32_1x1x1_0_nnt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f16_df16_256x256x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f16_df16_256x256x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f16_df16_256x256x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f16_df16_256x256x32_1x1x1_0_nnt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f16_f16_f16_256x256x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f16_f16_f16_256x256x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f16_f16_f16_256x256x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f16_f16_f16_256x256x32_1x1x1_0_nnt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f16_void_f16_df16_256x256x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f16_void_f16_df16_256x256x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f16_void_f16_df16_256x256x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f16_void_f16_df16_256x256x32_1x1x1_0_nnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_f32_f32_256x256x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_f32_f32_256x256x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_f32_f32_256x256x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_f32_f32_256x256x32_1x1x1_0_nnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_void_f32_df32_256x256x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_void_f32_df32_256x256x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_void_f32_df32_256x256x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_void_f32_df32_256x256x32_1x1x1_0_nnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_f32_bf16_dbf16_256x256x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_f32_bf16_dbf16_256x256x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_f32_bf16_dbf16_256x256x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_f32_bf16_dbf16_256x256x32_1x1x1_0_nnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_void_bf16_dbf16_256x256x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_void_bf16_dbf16_256x256x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_void_bf16_dbf16_256x256x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_void_bf16_dbf16_256x256x32_1x1x1_0_nnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_bf16_bf16_bf16_256x256x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_bf16_bf16_bf16_256x256x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_bf16_bf16_bf16_256x256x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_bf16_bf16_bf16_256x256x32_1x1x1_0_nnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_bf16_void_bf16_dbf16_256x256x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_bf16_void_bf16_dbf16_256x256x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_bf16_void_bf16_dbf16_256x256x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_bf16_void_bf16_dbf16_256x256x32_1x1x1_0_nnt_align8
With this PR:
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_f32_f32_256x256x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_f32_f32_128x256x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_f32_f32_256x128x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_f32_f32_128x128x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_f32_f32_64x128x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_f32_f32_256x256x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_f32_f32_128x256x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_f32_f32_256x128x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_f32_f32_128x128x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_f32_f32_64x128x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_f32_f32_256x256x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_f32_f32_128x256x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_f32_f32_256x128x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_f32_f32_128x128x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_f32_f32_64x128x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_f32_f32_256x256x32_1x1x1_0_nnt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_f32_f32_128x256x32_1x1x1_0_nnt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_f32_f32_256x128x32_1x1x1_0_nnt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_f32_f32_128x128x32_1x1x1_0_nnt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_f32_f32_64x128x32_1x1x1_0_nnt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f32_df32_256x256x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f32_df32_128x256x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f32_df32_256x128x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f32_df32_128x128x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f32_df32_64x128x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f32_df32_256x256x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f32_df32_128x256x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f32_df32_256x128x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f32_df32_128x128x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f32_df32_64x128x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f32_df32_256x256x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f32_df32_128x256x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f32_df32_256x128x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f32_df32_128x128x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f32_df32_64x128x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f32_df32_256x256x32_1x1x1_0_nnt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f32_df32_128x256x32_1x1x1_0_nnt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f32_df32_256x128x32_1x1x1_0_nnt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f32_df32_128x128x32_1x1x1_0_nnt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f32_df32_64x128x32_1x1x1_0_nnt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_f32_f16_df16_256x256x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_f32_f16_df16_128x256x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_f32_f16_df16_256x128x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_f32_f16_df16_128x128x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_f32_f16_df16_64x128x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_f32_f16_df16_256x256x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_f32_f16_df16_128x256x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_f32_f16_df16_256x128x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_f32_f16_df16_128x128x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_f32_f16_df16_64x128x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_f32_f16_df16_256x256x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_f32_f16_df16_128x256x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_f32_f16_df16_256x128x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_f32_f16_df16_128x128x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_f32_f16_df16_64x128x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_f32_f16_df16_256x256x32_1x1x1_0_nnt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_f32_f16_df16_128x256x32_1x1x1_0_nnt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_f32_f16_df16_256x128x32_1x1x1_0_nnt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_f32_f16_df16_128x128x32_1x1x1_0_nnt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_f32_f16_df16_64x128x32_1x1x1_0_nnt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f16_df16_256x256x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f16_df16_128x256x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f16_df16_256x128x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f16_df16_128x128x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f16_df16_64x128x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f16_df16_256x256x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f16_df16_128x256x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f16_df16_256x128x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f16_df16_128x128x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f16_df16_64x128x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f16_df16_256x256x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f16_df16_128x256x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f16_df16_256x128x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f16_df16_128x128x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f16_df16_64x128x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f16_df16_256x256x32_1x1x1_0_nnt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f16_df16_128x256x32_1x1x1_0_nnt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f16_df16_256x128x32_1x1x1_0_nnt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f16_df16_128x128x32_1x1x1_0_nnt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f16_df16_64x128x32_1x1x1_0_nnt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f16_f16_f16_256x256x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f16_f16_f16_128x256x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f16_f16_f16_256x128x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f16_f16_f16_128x128x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f16_f16_f16_64x128x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f16_f16_f16_256x256x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f16_f16_f16_128x256x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f16_f16_f16_256x128x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f16_f16_f16_128x128x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f16_f16_f16_64x128x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f16_f16_f16_256x256x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f16_f16_f16_128x256x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f16_f16_f16_256x128x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f16_f16_f16_128x128x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f16_f16_f16_64x128x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f16_f16_f16_256x256x32_1x1x1_0_nnt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f16_f16_f16_128x256x32_1x1x1_0_nnt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f16_f16_f16_256x128x32_1x1x1_0_nnt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f16_f16_f16_128x128x32_1x1x1_0_nnt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f16_f16_f16_64x128x32_1x1x1_0_nnt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f16_void_f16_df16_256x256x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f16_void_f16_df16_128x256x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f16_void_f16_df16_256x128x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f16_void_f16_df16_128x128x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f16_void_f16_df16_64x128x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f16_void_f16_df16_256x256x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f16_void_f16_df16_128x256x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f16_void_f16_df16_256x128x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f16_void_f16_df16_128x128x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f16_void_f16_df16_64x128x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f16_void_f16_df16_256x256x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f16_void_f16_df16_128x256x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f16_void_f16_df16_256x128x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f16_void_f16_df16_128x128x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f16_void_f16_df16_64x128x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f16_void_f16_df16_256x256x32_1x1x1_0_nnt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f16_void_f16_df16_128x256x32_1x1x1_0_nnt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f16_void_f16_df16_256x128x32_1x1x1_0_nnt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f16_void_f16_df16_128x128x32_1x1x1_0_nnt_align8
cutlass3x_xe12_tensorop_gemm_f16_f16_f16_void_f16_df16_64x128x32_1x1x1_0_nnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_f32_f32_256x256x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_f32_f32_128x256x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_f32_f32_256x128x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_f32_f32_128x128x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_f32_f32_64x128x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_f32_f32_256x256x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_f32_f32_128x256x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_f32_f32_256x128x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_f32_f32_128x128x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_f32_f32_64x128x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_f32_f32_256x256x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_f32_f32_128x256x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_f32_f32_256x128x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_f32_f32_128x128x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_f32_f32_64x128x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_f32_f32_256x256x32_1x1x1_0_nnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_f32_f32_128x256x32_1x1x1_0_nnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_f32_f32_256x128x32_1x1x1_0_nnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_f32_f32_128x128x32_1x1x1_0_nnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_f32_f32_64x128x32_1x1x1_0_nnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_void_f32_df32_256x256x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_void_f32_df32_128x256x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_void_f32_df32_256x128x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_void_f32_df32_128x128x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_void_f32_df32_64x128x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_void_f32_df32_256x256x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_void_f32_df32_128x256x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_void_f32_df32_256x128x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_void_f32_df32_128x128x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_void_f32_df32_64x128x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_void_f32_df32_256x256x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_void_f32_df32_128x256x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_void_f32_df32_256x128x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_void_f32_df32_128x128x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_void_f32_df32_64x128x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_void_f32_df32_256x256x32_1x1x1_0_nnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_void_f32_df32_128x256x32_1x1x1_0_nnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_void_f32_df32_256x128x32_1x1x1_0_nnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_void_f32_df32_128x128x32_1x1x1_0_nnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_void_f32_df32_64x128x32_1x1x1_0_nnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_f32_bf16_dbf16_256x256x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_f32_bf16_dbf16_128x256x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_f32_bf16_dbf16_256x128x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_f32_bf16_dbf16_128x128x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_f32_bf16_dbf16_64x128x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_f32_bf16_dbf16_256x256x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_f32_bf16_dbf16_128x256x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_f32_bf16_dbf16_256x128x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_f32_bf16_dbf16_128x128x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_f32_bf16_dbf16_64x128x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_f32_bf16_dbf16_256x256x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_f32_bf16_dbf16_128x256x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_f32_bf16_dbf16_256x128x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_f32_bf16_dbf16_128x128x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_f32_bf16_dbf16_64x128x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_f32_bf16_dbf16_256x256x32_1x1x1_0_nnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_f32_bf16_dbf16_128x256x32_1x1x1_0_nnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_f32_bf16_dbf16_256x128x32_1x1x1_0_nnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_f32_bf16_dbf16_128x128x32_1x1x1_0_nnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_f32_bf16_dbf16_64x128x32_1x1x1_0_nnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_void_bf16_dbf16_256x256x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_void_bf16_dbf16_128x256x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_void_bf16_dbf16_256x128x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_void_bf16_dbf16_128x128x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_void_bf16_dbf16_64x128x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_void_bf16_dbf16_256x256x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_void_bf16_dbf16_128x256x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_void_bf16_dbf16_256x128x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_void_bf16_dbf16_128x128x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_void_bf16_dbf16_64x128x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_void_bf16_dbf16_256x256x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_void_bf16_dbf16_128x256x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_void_bf16_dbf16_256x128x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_void_bf16_dbf16_128x128x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_void_bf16_dbf16_64x128x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_void_bf16_dbf16_256x256x32_1x1x1_0_nnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_void_bf16_dbf16_128x256x32_1x1x1_0_nnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_void_bf16_dbf16_256x128x32_1x1x1_0_nnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_void_bf16_dbf16_128x128x32_1x1x1_0_nnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_f32_void_bf16_dbf16_64x128x32_1x1x1_0_nnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_bf16_bf16_bf16_256x256x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_bf16_bf16_bf16_128x256x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_bf16_bf16_bf16_256x128x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_bf16_bf16_bf16_128x128x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_bf16_bf16_bf16_64x128x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_bf16_bf16_bf16_256x256x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_bf16_bf16_bf16_128x256x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_bf16_bf16_bf16_256x128x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_bf16_bf16_bf16_128x128x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_bf16_bf16_bf16_64x128x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_bf16_bf16_bf16_256x256x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_bf16_bf16_bf16_128x256x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_bf16_bf16_bf16_256x128x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_bf16_bf16_bf16_128x128x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_bf16_bf16_bf16_64x128x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_bf16_bf16_bf16_256x256x32_1x1x1_0_nnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_bf16_bf16_bf16_128x256x32_1x1x1_0_nnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_bf16_bf16_bf16_256x128x32_1x1x1_0_nnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_bf16_bf16_bf16_128x128x32_1x1x1_0_nnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_bf16_bf16_bf16_64x128x32_1x1x1_0_nnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_bf16_void_bf16_dbf16_256x256x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_bf16_void_bf16_dbf16_128x256x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_bf16_void_bf16_dbf16_256x128x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_bf16_void_bf16_dbf16_128x128x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_bf16_void_bf16_dbf16_64x128x32_1x1x1_0_ttt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_bf16_void_bf16_dbf16_256x256x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_bf16_void_bf16_dbf16_128x256x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_bf16_void_bf16_dbf16_256x128x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_bf16_void_bf16_dbf16_128x128x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_bf16_void_bf16_dbf16_64x128x32_1x1x1_0_tnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_bf16_void_bf16_dbf16_256x256x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_bf16_void_bf16_dbf16_128x256x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_bf16_void_bf16_dbf16_256x128x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_bf16_void_bf16_dbf16_128x128x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_bf16_void_bf16_dbf16_64x128x32_1x1x1_0_ntt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_bf16_void_bf16_dbf16_256x256x32_1x1x1_0_nnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_bf16_void_bf16_dbf16_128x256x32_1x1x1_0_nnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_bf16_void_bf16_dbf16_256x128x32_1x1x1_0_nnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_bf16_void_bf16_dbf16_128x128x32_1x1x1_0_nnt_align8
cutlass3x_xe12_tensorop_gemm_bf16_bf16_bf16_void_bf16_dbf16_64x128x32_1x1x1_0_nnt_align8
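
The two lists above follow a regular pattern, so their sizes can be cross-checked with a small script. The sketch below is only an illustration and is not taken from the generator code in this PR: the helper names (`type_signatures`, `kernel_names`) are hypothetical, and the type-signature strings, layouts, and tile shapes are simply copied from the kernel names listed above.

```python
from itertools import product

# Type-signature strings copied from the kernel names above, for one input element type.
# (Hypothetical helper; the real generator may build these differently.)
def type_signatures(elem):
    return [
        f"{elem}_{elem}_f32_f32_f32",
        f"{elem}_{elem}_f32_void_f32_df32",
        f"{elem}_{elem}_f32_f32_{elem}_d{elem}",
        f"{elem}_{elem}_f32_void_{elem}_d{elem}",
        f"{elem}_{elem}_{elem}_{elem}_{elem}",
        f"{elem}_{elem}_{elem}_void_{elem}_d{elem}",
    ]

SIGNATURES = type_signatures("f16") + type_signatures("bf16")
LAYOUTS = ["ttt", "tnt", "ntt", "nnt"]
# Tile shapes (M, N, K) in the order they appear in the "With this PR" list.
TILES = [(256, 256, 32), (128, 256, 32), (256, 128, 32), (128, 128, 32), (64, 128, 32)]

def kernel_names(tiles):
    # Same nesting order as the lists above: signature, then layout, then tile.
    return [
        f"cutlass3x_xe12_tensorop_gemm_{sig}_{m}x{n}x{k}_1x1x1_0_{layout}_align8"
        for sig, layout, (m, n, k) in product(SIGNATURES, LAYOUTS, tiles)
    ]

before = kernel_names(TILES[:1])  # only the default 256x256x32 tile
after = kernel_names(TILES)       # all five tile shapes

print(len(before), len(after))  # 48 240
```

The counts match the lists: 12 type signatures x 4 layouts gives 48 ops with the single default tile, and 240 ops once all 5 tile shapes are allowed.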

torch mm operation output

Test code:
import time

import torch


def gemm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
	return a @ b

def main() -> None:

	torch._inductor.config.max_autotune_gemm_backends = "CUTLASS" # sycl-tla kernel use 38.112 ms
	# torch._inductor.config.max_autotune_gemm_backends = "TRITON" # triton kernel use 2.336 ms
	# torch._inductor.config.max_autotune_gemm_backends = "ATen" # onednn kernel use 1.957 ms

	torch.manual_seed(0)
	device = "xpu"
	dtype = torch.float16
	m, n, k = 4096, 4096, 4096

	a = torch.randn((m, k), device=device, dtype=dtype)
	b = torch.randn((k, n), device=device, dtype=dtype)

	compiled_gemm = torch.compile(gemm, mode="max-autotune")

	# Warmup
	for _ in range(5):
		compiled_gemm(a, b)
	torch.xpu.synchronize()

	# Timing
	iters = 20
	start = time.time()
	for _ in range(iters):
		compiled_gemm(a, b)
	torch.xpu.synchronize()
	elapsed = (time.time() - start) / iters

	# Correctness check
	ref = gemm(a, b)
	out = compiled_gemm(a, b)
	max_err = (out - ref).abs().max().item()

	print(f"avg latency: {elapsed * 1e3:.3f} ms")
	print(f"max error: {max_err:.6f}")

if __name__ == "__main__":
	main()

Previous output:

Autotune Choices Stats:
{"num_choices": 4, "num_triton_choices": 0, "best_kernel": "cutlass_50216987", "best_kernel_desc": "cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f16_df16_256x256x32_1x1x1_0_ttt_align8 swizzle=2", "best_time": 0.53504}
AUTOTUNE mm(4096x4096, 4096x4096)
strides: [4096, 1], [4096, 1]
dtypes: torch.float16, torch.float16
  cutlass_50216987 0.5350 ms 100.0% cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f16_df16_256x256x32_1x1x1_0_ttt_align8 swizzle=2
  cutlass_50216987 0.5380 ms 99.4% cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f16_df16_256x256x32_1x1x1_0_ttt_align8 swizzle=4
  cutlass_50216987 0.5605 ms 95.5% cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f16_df16_256x256x32_1x1x1_0_ttt_align8 swizzle=1
  cutlass_50216987 0.6414 ms 83.4% cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f16_df16_256x256x32_1x1x1_0_ttt_align8 swizzle=8
SingleProcess AUTOTUNE benchmarking takes 0.6772 seconds and 37.4367 seconds precompiling for 4 choices
avg latency: 0.506 ms
max error: 0.000000

Output with this PR:

Autotune Choices Stats:
{"num_choices": 20, "num_triton_choices": 0, "best_kernel": "cutlass_50216987", "best_kernel_desc": "cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f16_df16_256x256x32_1x1x1_0_ttt_align8 swizzle=8", "best_time": 0.61476}
AUTOTUNE mm(4096x4096, 4096x4096)
strides: [4096, 1], [4096, 1]
dtypes: torch.float16, torch.float16
  cutlass_50216987 0.6148 ms 100.0% cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f16_df16_256x256x32_1x1x1_0_ttt_align8 swizzle=8
  cutlass_50216987 0.6154 ms 99.9% cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f16_df16_256x256x32_1x1x1_0_ttt_align8 swizzle=1
  cutlass_50216987 0.6154 ms 99.9% cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f16_df16_256x256x32_1x1x1_0_ttt_align8 swizzle=4
  cutlass_50216987 0.6155 ms 99.9% cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f16_df16_256x256x32_1x1x1_0_ttt_align8 swizzle=2
  cutlass_6514dbfa 0.7586 ms 81.0% cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f16_df16_128x128x32_1x1x1_0_ttt_align8 swizzle=1
  cutlass_6514dbfa 0.7971 ms 77.1% cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f16_df16_128x128x32_1x1x1_0_ttt_align8 swizzle=4
  cutlass_6514dbfa 0.8026 ms 76.6% cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f16_df16_128x128x32_1x1x1_0_ttt_align8 swizzle=2
  cutlass_23b288b8 0.8098 ms 75.9% cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f16_df16_256x128x32_1x1x1_0_ttt_align8 swizzle=8
  cutlass_23b288b8 0.8102 ms 75.9% cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f16_df16_256x128x32_1x1x1_0_ttt_align8 swizzle=1
  cutlass_23b288b8 0.8107 ms 75.8% cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f16_df16_256x128x32_1x1x1_0_ttt_align8 swizzle=4
SingleProcess AUTOTUNE benchmarking takes 3.2547 seconds and 39.3417 seconds precompiling for 20 choices
avg latency: 0.507 ms
max error: 0.000000
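
To compare tile shapes at a glance, the autotune table can be grouped by tile size. The following is a rough sketch, not part of the PR: the regular expression and the `best_per_tile` helper are assumptions about the log layout shown above, and the log text is copied from the "Output with this PR" block. Note that the log prints only 10 of the 20 choices, so tile shapes that do not appear there (128x256x32 and 64x128x32) are missing from the summary.

```python
import re

# Autotune lines copied from the "Output with this PR" log above.
LOG = """\
  cutlass_50216987 0.6148 ms 100.0% cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f16_df16_256x256x32_1x1x1_0_ttt_align8 swizzle=8
  cutlass_50216987 0.6154 ms 99.9% cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f16_df16_256x256x32_1x1x1_0_ttt_align8 swizzle=1
  cutlass_50216987 0.6154 ms 99.9% cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f16_df16_256x256x32_1x1x1_0_ttt_align8 swizzle=4
  cutlass_50216987 0.6155 ms 99.9% cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f16_df16_256x256x32_1x1x1_0_ttt_align8 swizzle=2
  cutlass_6514dbfa 0.7586 ms 81.0% cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f16_df16_128x128x32_1x1x1_0_ttt_align8 swizzle=1
  cutlass_6514dbfa 0.7971 ms 77.1% cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f16_df16_128x128x32_1x1x1_0_ttt_align8 swizzle=4
  cutlass_6514dbfa 0.8026 ms 76.6% cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f16_df16_128x128x32_1x1x1_0_ttt_align8 swizzle=2
  cutlass_23b288b8 0.8098 ms 75.9% cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f16_df16_256x128x32_1x1x1_0_ttt_align8 swizzle=8
  cutlass_23b288b8 0.8102 ms 75.9% cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f16_df16_256x128x32_1x1x1_0_ttt_align8 swizzle=1
  cutlass_23b288b8 0.8107 ms 75.8% cutlass3x_xe12_tensorop_gemm_f16_f16_f32_void_f16_df16_256x128x32_1x1x1_0_ttt_align8 swizzle=4
"""

# Assumed line layout: "<kernel id> <time> ms <pct>% <kernel name> swizzle=<n>".
LINE_RE = re.compile(
    r"(cutlass_\w+)\s+([\d.]+) ms.*?_(\d+x\d+x\d+)_1x1x1_.*swizzle=(\d+)"
)

def best_per_tile(log):
    """Return {tile: (best time in ms, swizzle)} for the tiles present in the log."""
    best = {}
    for line in log.splitlines():
        match = LINE_RE.search(line)
        if match is None:
            continue  # skip lines that are not autotune rows
        _, time_ms, tile, swizzle = match.groups()
        candidate = (float(time_ms), int(swizzle))
        if tile not in best or candidate < best[tile]:
            best[tile] = candidate
    return best

summary = best_per_tile(LOG)
for tile, (time_ms, swizzle) in sorted(summary.items(), key=lambda kv: kv[1]):
    print(f"{tile}: {time_ms:.4f} ms (swizzle={swizzle})")
```

For this 4096x4096x4096 fp16 case the default 256x256x32 tile is still the fastest listed choice. The jump from 4 to 20 choices is consistent with 5 tile shapes each being tried with the 4 swizzle values seen in the log (1, 2, 4, 8), though that pairing is an inference from the numbers rather than something the log states.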

Type

  • Bug - [ ] Feature - [ ] Performance - [ ] Refactor

Testing

  • Tests pass - [ ] Xe12 - [ ] Xe20

Performance

Metric Before After

References

Fixes #

Checklist

  • Copyright - [ ] Co-pilot Review - [ ] Deprecated APIs not used

@Antonyvance

I am wondering if we need to add different tile description options for PVC and BMG.
