Insideyyy commented Sep 5, 2025

This PR adds a split-k optimization for sm90 that reduces the partitioned D through DSMEM.
It currently supports fp8 and bf16 Normal, MGroupedContiguous, and MGroupedMasked GEMMs on sm90.

fp8_gemm_1d2d on H20:

| m x n x k | TFLOPS w/o split-k | TFLOPS w/ split-k (optional) |
|---|---|---|
| 128 x 64 x 8192 | 12 | 21 |
| 128 x 128 x 8192 | 24 | 35 |
| 128 x 256 x 8192 | 47 | 64 |
| 128 x 1024 x 8192 | 137 | 137 |
| 128 x 1280 x 8192 | 137 | 151 |
| 256 x 64 x 8192 | 24 | 32 |
| 256 x 128 x 8192 | 47 | 64 |
| 256 x 256 x 8192 | 93 | 93 |
| 256 x 1024 x 8192 | 181 | 180 |
| 256 x 1280 x 8192 | 190 | 198 |

bf16_gemm on H20:

| m x n x k | TFLOPS w/o split-k | TFLOPS w/ split-k (optional) |
|---|---|---|
| 128 x 64 x 8192 | 7 | 15 |
| 128 x 128 x 8192 | 13 | 25 |
| 128 x 256 x 8192 | 26 | 42 |
| 128 x 1024 x 8192 | 76 | 76 |
| 128 x 1280 x 8192 | 76 | 90 |
| 256 x 64 x 8192 | 13 | 21 |
| 256 x 128 x 8192 | 26 | 41 |
| 256 x 256 x 8192 | 52 | 51 |
| 256 x 1024 x 8192 | 99 | 99 |
| 256 x 1280 x 8192 | 104 | 112 |

Notes:

  • Split-k is enabled automatically whenever it is applicable, to improve SM utilization.
  • The k_slices partitions of the same (m_block_idx, n_block_idx) tile are assigned to k_slices SMs within one thread block cluster, so that the intermediate results can be reduced through DSMEM.
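
To make the two notes concrete, here is a minimal plain-Python sketch of the idea. It is a numerical model only, not the CUDA kernel from this PR: the function names `num_ctas` and `split_k_matmul` are hypothetical, and any block sizes passed to `num_ctas` are illustrative assumptions rather than the values the kernel selects. Each K chunk stands in for one SM of a cluster, and the final elementwise sum stands in for the reduction through DSMEM.

```python
# Illustrative model of split-k; names and block sizes are assumptions,
# not taken from the PR's implementation.

def ceil_div(a, b):
    return (a + b - 1) // b


def num_ctas(m, n, block_m, block_n, k_slices=1):
    """Thread blocks launched: one per (m_block_idx, n_block_idx, k_slice)."""
    return ceil_div(m, block_m) * ceil_div(n, block_n) * k_slices


def split_k_matmul(a, b, k_slices):
    """Compute a @ b by splitting K into k_slices chunks and summing partials."""
    m, k, n = len(a), len(b), len(b[0])
    chunk = ceil_div(k, k_slices)
    partials = []
    for s in range(k_slices):
        # K range owned by slice s (may be empty for the last slices)
        k0, k1 = min(s * chunk, k), min((s + 1) * chunk, k)
        partials.append([
            [sum(a[i][kk] * b[kk][j] for kk in range(k0, k1)) for j in range(n)]
            for i in range(m)
        ])
    # Elementwise reduction of the partial tiles (the DSMEM step in the kernel)
    return [[sum(p[i][j] for p in partials) for j in range(n)] for i in range(m)]
```

With an assumed 128 x 64 output tile, a 128 x 64 x 8192 problem yields `num_ctas(128, 64, 128, 64) == 1` thread block without split-k, so most SMs would sit idle; raising `k_slices` multiplies the number of blocks working on that tile, which is consistent with the small-n rows in the tables benefiting most.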

@LyricZhao
Collaborator

Great point for some shapes, may take some time to merge. Thanks!
