@Wangzheee commented Jun 19, 2025

Support grouped GEMM offset types: group_gemm_offset and group_gemm_offset_swapAB.

  • Performance:
    • Removed the group-size randomization: group_ms = [int(expected_m_per_group * random.uniform(1, 1)) for _ in range(num_groups)] (see the offset-construction sketch after the benchmark numbers below).
    • For m < 64, roughly 30%~50% kernel speedup:
      • Testing grouped masked GEMM:
        Perf (num_groups=2, expected_m_per_group= 16, n=4096, k=7168): 36 us | throughput: 53 TFLOPS, 1665 GB/s
        Perf (num_groups=4, expected_m_per_group= 16, n=4096, k=7168): 65 us | throughput: 58 TFLOPS, 1813 GB/s
        Perf (num_groups=2, expected_m_per_group= 32, n=4096, k=7168): 35 us | throughput: 106 TFLOPS, 1685 GB/s
        Perf (num_groups=9, expected_m_per_group= 32, n=4096, k=7168): 141 us | throughput: 120 TFLOPS, 1900 GB/s
        Perf (num_groups=2, expected_m_per_group= 32, n=4096, k=7168): 35 us | throughput: 106 TFLOPS, 1689 GB/s
        Perf (num_groups=4, expected_m_per_group= 32, n=4096, k=7168): 66 us | throughput: 115 TFLOPS, 1822 GB/s
        Perf (num_groups=32, expected_m_per_group= 64, n=4096, k=7168): 485 us | throughput: 248 TFLOPS, 2002 GB/s
      • Testing grouped offset GEMM:
        Perf (num_groups= 2, expected_m_per_group= 16, n=4096, k=7168): 27 us | throughput: 71 TFLOPS, 2226 GB/s
        Perf (num_groups= 4, expected_m_per_group= 16, n=4096, k=7168): 46 us | throughput: 82 TFLOPS, 2587 GB/s
        Perf (num_groups= 2, expected_m_per_group= 32, n=4096, k=7168): 28 us | throughput: 134 TFLOPS, 2136 GB/s
        Perf (num_groups= 9, expected_m_per_group= 32, n=4096, k=7168): 93 us | throughput: 183 TFLOPS, 2902 GB/s
        Perf (num_groups= 2, expected_m_per_group= 32, n=4096, k=7168): 28 us | throughput: 135 TFLOPS, 2143 GB/s
        Perf (num_groups= 4, expected_m_per_group= 32, n=4096, k=7168): 49 us | throughput: 152 TFLOPS, 2414 GB/s
        Perf (num_groups=32, expected_m_per_group= 64, n=4096, k=7168): 479 us | throughput: 251 TFLOPS, 2029 GB/s
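
For readers unfamiliar with the offset variant, here is a minimal sketch (in the test's Python) of how a per-group offset tensor could be derived from per-group row counts; the helper name build_group_offsets and the prefix-sum layout are illustrative assumptions, not code from this PR:

import torch

def build_group_offsets(group_ms):
    # Assumed layout: offsets[i] is the starting row of group i in the
    # concatenated LHS, i.e. a prefix sum of the per-group row counts,
    # with offsets[-1] equal to the total number of rows.
    counts = torch.tensor(group_ms, dtype=torch.int32)
    offsets = torch.zeros(len(group_ms) + 1, dtype=torch.int32)
    offsets[1:] = torch.cumsum(counts, dim=0)
    return offsets

# Example: four groups of 16 rows each -> offsets 0, 16, 32, 48, 64
print(build_group_offsets([16, 16, 16, 16]))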

@LyricZhao (Collaborator)

Thanks for your contribution! We will merge it after the refactor #112.

@Wangzheee (Author)

Thank you for your reply.
We are still working on W4Afp8 for NormalGEMM and GroupedGEMM. Does DeepGEMM have any plans to support W4Afp8?

@Chtholly-Boss

I reproduced the benchmark results on H20:

Testing grouped masked GEMM:
 > Perf (num_groups=2, expected_m_per_group=  16, n=4096, k=7168):   36 us | throughput:   52 TFLOPS, 1650 GB/s
 > Perf (num_groups=4, expected_m_per_group=  16, n=4096, k=7168):   66 us | throughput:   57 TFLOPS, 1791 GB/s
 > Perf (num_groups=2, expected_m_per_group=  32, n=4096, k=7168):   36 us | throughput:  105 TFLOPS, 1674 GB/s
 > Perf (num_groups=9, expected_m_per_group=  32, n=4096, k=7168):  142 us | throughput:  119 TFLOPS, 1894 GB/s
 > Perf (num_groups=2, expected_m_per_group=  32, n=4096, k=7168):   36 us | throughput:  105 TFLOPS, 1670 GB/s
 > Perf (num_groups=4, expected_m_per_group=  32, n=4096, k=7168):   66 us | throughput:  114 TFLOPS, 1806 GB/s
 > Perf (num_groups=32, expected_m_per_group=  64, n=4096, k=7168):  485 us | throughput:  248 TFLOPS, 2000 GB/s

Testing grouped offset GEMM:
 > Perf (num_groups= 2, expected_m_per_group=  16, n=4096, k=7168):   27 us | throughput:   70 TFLOPS, 2211 GB/s
 > Perf (num_groups= 4, expected_m_per_group=  16, n=4096, k=7168):   46 us | throughput:   82 TFLOPS, 2571 GB/s
 > Perf (num_groups= 2, expected_m_per_group=  32, n=4096, k=7168):   28 us | throughput:  134 TFLOPS, 2125 GB/s
 > Perf (num_groups= 9, expected_m_per_group=  32, n=4096, k=7168):   93 us | throughput:  182 TFLOPS, 2896 GB/s
 > Perf (num_groups= 2, expected_m_per_group=  32, n=4096, k=7168):   28 us | throughput:  133 TFLOPS, 2113 GB/s
 > Perf (num_groups= 4, expected_m_per_group=  32, n=4096, k=7168):   50 us | throughput:  150 TFLOPS, 2382 GB/s
 > Perf (num_groups=32, expected_m_per_group=  64, n=4096, k=7168):  481 us | throughput:  250 TFLOPS, 2018 GB/s

Note that the TFLOPS figures above were measured with all input elements set to 1. As discussed in strangely-matrix-multiplications, the initialization method influences the benchmark results (a small standalone illustration follows).
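
As a standalone illustration of that effect (not the DeepGEMM benchmark itself; the plain BF16 matmul, sizes, and the crude timing loop below are mine), one can contrast all-ones inputs against random inputs:

import time

import torch

def bench_ms(x, w, iters=20):
    # Crude timing helper: one warm-up call, then an averaged wall-clock loop.
    x @ w.t()
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(iters):
        x @ w.t()
    torch.cuda.synchronize()
    return (time.time() - start) / iters * 1e3

m, n, k = 4096, 4096, 7168
w = torch.randn(n, k, device='cuda', dtype=torch.bfloat16)
x_ones = torch.ones(m, k, device='cuda', dtype=torch.bfloat16)
x_rand = torch.randn(m, k, device='cuda', dtype=torch.bfloat16)
print(f'all-ones: {bench_ms(x_ones, w):.3f} ms | random: {bench_ms(x_rand, w):.3f} ms')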
So I also ran the original benchmark with random(0.7, 1.3); the results are shown below:

Testing grouped masked GEMM:
 > Perf (num_groups=2, expected_m_per_group=  16, n=4096, k=7168):   35 us | throughput:   61 TFLOPS, 1670 GB/s
 > Perf (num_groups=4, expected_m_per_group=  16, n=4096, k=7168):   66 us | throughput:   54 TFLOPS, 1790 GB/s
 > Perf (num_groups=2, expected_m_per_group=  32, n=4096, k=7168):   36 us | throughput:  128 TFLOPS, 1681 GB/s
 > Perf (num_groups=9, expected_m_per_group=  32, n=4096, k=7168):  142 us | throughput:  112 TFLOPS, 1884 GB/s
 > Perf (num_groups=2, expected_m_per_group=  32, n=4096, k=7168):   36 us | throughput:   89 TFLOPS, 1663 GB/s
 > Perf (num_groups=4, expected_m_per_group=  32, n=4096, k=7168):   66 us | throughput:  105 TFLOPS, 1810 GB/s
 > Perf (num_groups=32, expected_m_per_group=  64, n=4096, k=7168):  690 us | throughput:  169 TFLOPS, 1406 GB/s

Testing grouped offset GEMM:
 > Perf (num_groups= 2, expected_m_per_group=  16, n=4096, k=7168):   27 us | throughput:   61 TFLOPS, 2208 GB/s
 > Perf (num_groups= 4, expected_m_per_group=  16, n=4096, k=7168):   53 us | throughput:   70 TFLOPS, 2218 GB/s
 > Perf (num_groups= 2, expected_m_per_group=  32, n=4096, k=7168):   40 us | throughput:   93 TFLOPS, 1476 GB/s
 > Perf (num_groups= 9, expected_m_per_group=  32, n=4096, k=7168):  107 us | throughput:  162 TFLOPS, 2506 GB/s
 > Perf (num_groups= 2, expected_m_per_group=  32, n=4096, k=7168):   43 us | throughput:   88 TFLOPS, 1403 GB/s
 > Perf (num_groups= 4, expected_m_per_group=  32, n=4096, k=7168):   60 us | throughput:  138 TFLOPS, 2006 GB/s
 > Perf (num_groups=32, expected_m_per_group=  64, n=4096, k=7168):  687 us | throughput:  176 TFLOPS, 1413 GB/s

I also test on H100:

Testing grouped masked GEMM:
 > Perf (num_groups=2, expected_m_per_group=  16, n=4096, k=7168):   48 us | throughput:   45 TFLOPS, 1227 GB/s
 > Perf (num_groups=4, expected_m_per_group=  16, n=4096, k=7168):   82 us | throughput:   44 TFLOPS, 1438 GB/s
 > Perf (num_groups=2, expected_m_per_group=  32, n=4096, k=7168):   48 us | throughput:   95 TFLOPS, 1238 GB/s
 > Perf (num_groups=9, expected_m_per_group=  32, n=4096, k=7168):  166 us | throughput:   96 TFLOPS, 1618 GB/s
 > Perf (num_groups=2, expected_m_per_group=  32, n=4096, k=7168):   48 us | throughput:   66 TFLOPS, 1231 GB/s
 > Perf (num_groups=4, expected_m_per_group=  32, n=4096, k=7168):   82 us | throughput:   84 TFLOPS, 1447 GB/s
 > Perf (num_groups=32, expected_m_per_group=  64, n=4096, k=7168):  723 us | throughput:  161 TFLOPS, 1341 GB/s

Testing grouped offset GEMM:
 > Perf (num_groups= 2, expected_m_per_group=  16, n=4096, k=7168):   47 us | throughput:   35 TFLOPS, 1256 GB/s
 > Perf (num_groups= 4, expected_m_per_group=  16, n=4096, k=7168):   81 us | throughput:   47 TFLOPS, 1468 GB/s
 > Perf (num_groups= 2, expected_m_per_group=  32, n=4096, k=7168):   56 us | throughput:   67 TFLOPS, 1065 GB/s
 > Perf (num_groups= 9, expected_m_per_group=  32, n=4096, k=7168):  183 us | throughput:   95 TFLOPS, 1465 GB/s
 > Perf (num_groups= 2, expected_m_per_group=  32, n=4096, k=7168):   63 us | throughput:   60 TFLOPS,  955 GB/s
 > Perf (num_groups= 4, expected_m_per_group=  32, n=4096, k=7168):   86 us | throughput:   96 TFLOPS, 1392 GB/s
Traceback (most recent call last):
  File "/home/s_sunqianqi/mayibin/test/DeepGEMM/tests/test_core.py", line 414, in <module>
    test_m_grouped_gemm_offset()
  File "/home/s_sunqianqi/mayibin/test/DeepGEMM/tests/test_core.py", line 385, in test_m_grouped_gemm_offset
    deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_offset(pad_x_fp8, y_fp8_offset, offset, out_offset, expected_m_per_group)
  File "/home/s_sunqianqi/mayibin/test/DeepGEMM/deep_gemm/jit_kernels/m_grouped_gemm.py", line 264, in m_grouped_gemm_fp8_fp8_bf16_nt_offset
    assert m % block_m == 0, f'For GroupedWithOffset grouped GEMM, shape M should be multiple of the block M (current block M: {block_m})'
AssertionError: For GroupedWithOffset grouped GEMM, shape M should be multiple of the block M (current block M: 128)

The above results show that on H100 this approach does not bring any benefit, and there might be a bug related to the block-size assumption.
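
For context on the assertion in the traceback: the offset path requires the (padded) M to be a multiple of the kernel's block M, reported as 128 here. Below is a minimal sketch of the rounding that assertion expects; where the padding should actually be applied in the test is not established here:

BLOCK_M = 128  # block M reported by the assertion above

def pad_to_block(m, block_m=BLOCK_M):
    # Round a row count up to the next multiple of block_m.
    return (m + block_m - 1) // block_m * block_m

print(pad_to_block(64))   # 128: now satisfies m % block_m == 0
print(pad_to_block(256))  # 256: already a multiple, unchanged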

@Wangzheee (Author) commented Jul 28, 2025

I reproduced the benchmark results on H20:
Note that the TFLOPS figures above were measured with all input elements set to 1. As discussed in strangely-matrix-multiplications, the initialization method influences the benchmark results, so I also ran the original benchmark with random(0.7, 1.3).

This test already sets random(0.7, 1.3):
https://github.com/deepseek-ai/DeepGEMM/pull/116/files#diff-938de46e643ab091cff6a7c23e4e752e8907e2266f47889d988923352f7a1058R218
May I ask how you ran the individual test? What specific code was modified?

@Chtholly-Boss commented Jul 28, 2025

This test already sets random(0.7, 1.3): https://github.com/deepseek-ai/DeepGEMM/pull/116/files#diff-938de46e643ab091cff6a7c23e4e752e8907e2266f47889d988923352f7a1058R218 May I ask how you ran the individual test? What specific code was modified?

When I ran the original code with random(0.7, 1.3), the throughputs did not match the results in this PR, so I suspected the reported numbers came from fixed-value inputs. After applying random(1, 1) to the input construction function, all results matched those in the PR.
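
For reference, the line in question (also quoted in the PR description above) controls the per-group row counts. A side-by-side of the two settings, with illustrative values for num_groups and expected_m_per_group:

import random

num_groups, expected_m_per_group = 4, 32

# Original benchmark: each group's m varies by up to roughly +/-30% around the expectation.
group_ms_random = [int(expected_m_per_group * random.uniform(0.7, 1.3)) for _ in range(num_groups)]

# With uniform(1, 1) the factor is always 1.0, so every group gets exactly expected_m_per_group rows.
group_ms_fixed = [int(expected_m_per_group * random.uniform(1, 1)) for _ in range(num_groups)]

print(group_ms_random, group_ms_fixed)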

As for how I ran the test: just run python3 tests/test_core.py with only the following tests enabled:

if __name__ == '__main__':
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.manual_seed(0)
    random.seed(0)

    print('Library path:')
    print(f' > {deep_gemm.__path__}\n')

    # test_gemm()
    # test_m_grouped_gemm_contiguous()
    test_m_grouped_gemm_masked()
    test_m_grouped_gemm_offset()

    # test_wgrad_gemm()
    # test_k_grouped_wgrad_gemm()

@chengmengli06 commented Aug 13, 2025

How can the offset GEMM be integrated with the DeepEP low-latency kernels? There seems to be a huge gap. @Wangzheee

@LyricZhao (Collaborator)

Closing this as a duplicate of #192. Thanks!

@LyricZhao closed this Sep 15, 2025
