support group_gemm_offset, group_gemm_offset_swapAB #116
Conversation
Thanks for your contribution! We will merge it after the refactor #112.
Thank you for your reply.
I reproduced the benchmark results on H20:
Note that the TFLOPS figures above were measured with all elements set to 1. As discussed in strangely-matrix-multiplications, the initialization method influences the benchmark results.
I also tested on H100:
The results above show that on H100 this approach does not bring any benefit, and there might be a bug related to the block assumption.
This test already uses random(0.7, 1.3) initialization.
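For context on the initialization discussion above, here is a minimal sketch (not from this PR; shapes, dtypes, and the plain bf16 matmul are illustrative assumptions) contrasting the two schemes, all-ones versus uniform values in [0.7, 1.3]:

```python
import torch

m, n, k = 4096, 4096, 7168
device = 'cuda'

# All-ones initialization (tends to report optimistic TFLOPS)
a_ones = torch.ones(m, k, device=device, dtype=torch.bfloat16)
b_ones = torch.ones(n, k, device=device, dtype=torch.bfloat16)

# Uniform initialization in [0.7, 1.3] (what this test uses)
a_rand = torch.empty(m, k, device=device, dtype=torch.bfloat16).uniform_(0.7, 1.3)
b_rand = torch.empty(n, k, device=device, dtype=torch.bfloat16).uniform_(0.7, 1.3)

def bench(a, b, iters=20):
    # Warm up, then time the matmul with CUDA events
    for _ in range(5):
        a @ b.t()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        a @ b.t()
    end.record()
    torch.cuda.synchronize()
    seconds = start.elapsed_time(end) / iters / 1e3
    return 2 * m * n * k / seconds / 1e12  # TFLOPS

print(f'all-ones: {bench(a_ones, b_ones):.0f} TFLOPS')
print(f'uniform : {bench(a_rand, b_rand):.0f} TFLOPS')
```

Constant inputs typically report higher TFLOPS than randomly initialized ones, which is why the initialization choice matters when comparing the numbers in this thread.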
When I run the original code with … As for how I run the test, I just run:

```python
import random

import torch

import deep_gemm

# __main__ block of the test script: only the masked and offset
# grouped-GEMM tests are enabled here.
if __name__ == '__main__':
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.manual_seed(0)
    random.seed(0)

    print('Library path:')
    print(f' > {deep_gemm.__path__}\n')

    # test_gemm()
    # test_m_grouped_gemm_contiguous()
    test_m_grouped_gemm_masked()
    test_m_grouped_gemm_offset()
    # test_wgrad_gemm()
    # test_k_grouped_wgrad_gemm()
```
How can offset GEMM be integrated with the DeepEP low-latency kernels? There seems to be a huge gap. @Wangzheee
Closing this as a duplicate of #192. Thanks!
Support grouped GEMM offset kernels: group_gemm_offset and group_gemm_offset_swapAB.
Perf (num_groups= 2, expected_m_per_group= 16, n=4096, k=7168): 36 us | throughput: 53 TFLOPS, 1665 GB/s
Perf (num_groups= 4, expected_m_per_group= 16, n=4096, k=7168): 65 us | throughput: 58 TFLOPS, 1813 GB/s
Perf (num_groups= 2, expected_m_per_group= 32, n=4096, k=7168): 35 us | throughput: 106 TFLOPS, 1685 GB/s
Perf (num_groups= 9, expected_m_per_group= 32, n=4096, k=7168): 141 us | throughput: 120 TFLOPS, 1900 GB/s
Perf (num_groups= 2, expected_m_per_group= 32, n=4096, k=7168): 35 us | throughput: 106 TFLOPS, 1689 GB/s
Perf (num_groups= 4, expected_m_per_group= 32, n=4096, k=7168): 66 us | throughput: 115 TFLOPS, 1822 GB/s
Perf (num_groups=32, expected_m_per_group= 64, n=4096, k=7168): 485 us | throughput: 248 TFLOPS, 2002 GB/s

Perf (num_groups= 2, expected_m_per_group= 16, n=4096, k=7168): 27 us | throughput: 71 TFLOPS, 2226 GB/s
Perf (num_groups= 4, expected_m_per_group= 16, n=4096, k=7168): 46 us | throughput: 82 TFLOPS, 2587 GB/s
Perf (num_groups= 2, expected_m_per_group= 32, n=4096, k=7168): 28 us | throughput: 134 TFLOPS, 2136 GB/s
Perf (num_groups= 9, expected_m_per_group= 32, n=4096, k=7168): 93 us | throughput: 183 TFLOPS, 2902 GB/s
Perf (num_groups= 2, expected_m_per_group= 32, n=4096, k=7168): 28 us | throughput: 135 TFLOPS, 2143 GB/s
Perf (num_groups= 4, expected_m_per_group= 32, n=4096, k=7168): 49 us | throughput: 152 TFLOPS, 2414 GB/s
Perf (num_groups=32, expected_m_per_group= 64, n=4096, k=7168): 479 us | throughput: 251 TFLOPS, 2029 GB/s
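For readers unfamiliar with the offset layout, below is a minimal pure-PyTorch reference of what an offset-style grouped GEMM computes, assuming each group's A rows are concatenated along M and delimited by a prefix-sum offsets tensor. The function name, argument layout, and offsets convention are illustrative assumptions, not the DeepGEMM API added by this PR.

```python
import torch

def grouped_gemm_offset_reference(a, b, offsets):
    """Illustrative reference for an offset-style grouped GEMM.

    a       : [sum_m, k]          all groups' rows concatenated along M
    b       : [num_groups, n, k]  one weight matrix per group
    offsets : [num_groups + 1]    prefix sums; group g owns rows offsets[g]:offsets[g + 1]
    returns : [sum_m, n]
    """
    out = torch.empty(a.shape[0], b.shape[1], device=a.device, dtype=a.dtype)
    for g in range(b.shape[0]):
        lo, hi = offsets[g].item(), offsets[g + 1].item()
        out[lo:hi] = a[lo:hi] @ b[g].t()
    return out

# Example: num_groups=2 with per-group m of 16 and 32, n=4096, k=7168
ms = torch.tensor([16, 32])
offsets = torch.cat([torch.zeros(1, dtype=torch.long), ms.cumsum(0)])
a = torch.randn(int(ms.sum()), 7168, device='cuda', dtype=torch.bfloat16)
b = torch.randn(2, 4096, 7168, device='cuda', dtype=torch.bfloat16)
out = grouped_gemm_offset_reference(a, b, offsets)
print(out.shape)  # torch.Size([48, 4096])
```

The _swapAB variant presumably swaps the roles of the A and B operands at the kernel level to improve tile utilization when the per-group m is small; that is a scheduling detail and does not change the math sketched above.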