[Perf] Optimize FP8 gemm of sm120. #34424
wenshuai-xiaomi wants to merge 2 commits into vllm-project:main
Conversation
Signed-off-by: wenshuai <wenshuai@xiaomi.com>
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.
Just a reminder: PRs would not trigger a full CI run by default. Instead, they only run `fastcheck` CI, which covers a small and essential subset of CI tests to quickly catch errors. You can ask your reviewers to trigger select CI tests on top of `fastcheck` CI.
Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can either add the `ready` label to the PR or enable auto-merge.
If you have any questions, please reach out to us on Slack at https://slack.vllm.ai. 🚀
Code Review
This pull request introduces significant performance optimizations for FP8 GEMM on sm120 by adding specialized kernels for smaller M dimensions, which is well-supported by the provided benchmarks. The changes are clear and well-commented. I have one suggestion to improve the robustness of the new custom kernel definition by adding an architecture guard, consistent with other parts of the codebase.
csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8_dispatch.cuh
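For context, an architecture guard of the kind suggested above might look like the minimal sketch below. The build macro `ENABLE_SCALED_MM_SM120` and the struct name are illustrative assumptions for this example, not the actual identifiers used in the file or elsewhere in the codebase.

```cpp
// Sketch of an architecture guard around an sm120-only kernel definition.
// ENABLE_SCALED_MM_SM120 and Sm120SmallMKernelConfig are assumed names
// chosen for illustration; the real guard and identifiers may differ.
#if defined(ENABLE_SCALED_MM_SM120) && ENABLE_SCALED_MM_SM120

struct Sm120SmallMKernelConfig {
  // sm120-specific tile shape / cluster shape / schedule selections
  // would be defined here, so they are only compiled when the
  // architecture is actually enabled in the build.
};

#endif  // ENABLE_SCALED_MM_SM120
```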
Signed-off-by: wenshuai <wenshuai@xiaomi.com>
Purpose
Optimize FP8 GEMM on sm120 across all input shapes by adding kernel configurations specialized for smaller M dimensions.
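The general shape of such a change, selecting a specialized kernel configuration when M is small and keeping the existing default otherwise, is sketched below. This is a minimal, self-contained illustration: the config structs, the launcher, and the M thresholds (64/256) are assumptions for this example, not the identifiers or cut-offs used in this PR.

```cpp
// Minimal sketch of M-based dispatch for an sm120 FP8 GEMM.
// All names below are stand-ins (assumptions), not the ones in this PR.
#include <cstdint>
#include <utility>

// Stand-in tile configurations; the real ones would select CUTLASS
// tile shapes / schedules tuned for each M range.
struct Sm120Fp8ConfigM64 {};      // hypothetical: tuned for M <= 64
struct Sm120Fp8ConfigM256 {};     // hypothetical: tuned for M <= 256
struct Sm120Fp8ConfigDefault {};  // hypothetical: existing large-M config

// Stand-in launcher; the real one would instantiate and run the
// CUTLASS kernel built from Config.
template <typename Config, typename... Args>
void launch_sm120_fp8_gemm(Args&&... /*args*/) {}

// Dispatch on the M (batch/token) dimension: small M gets specialized
// configs, large M keeps the default path.
template <typename... Args>
void cutlass_scaled_mm_sm120_fp8_dispatch(std::int64_t m, Args&&... args) {
  if (m <= 64) {
    launch_sm120_fp8_gemm<Sm120Fp8ConfigM64>(std::forward<Args>(args)...);
  } else if (m <= 256) {
    launch_sm120_fp8_gemm<Sm120Fp8ConfigM256>(std::forward<Args>(args)...);
  } else {
    launch_sm120_fp8_gemm<Sm120Fp8ConfigDefault>(std::forward<Args>(args)...);
  }
}
```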
Test Plan
benchmarks/kernels# python3 bench_fp8_gemm.py
Run bench_fp8_gemm.py and compare the results with and without the patch.
Test Result
The current results without the patch:
root@de-22309-vllm-0-15-1-5090-0211172246-79ccd9bbf9-6jfzz:/vllm/benchmarks/kernels# python3 bench_fp8_gemm.py
meta-llama/Llama-3.1-8B-Instruct, N=6144 K=4096, BF16 vs FP8 GEMMs TFLOP/s:
BF16 vs FP8 GEMMs:
batch_size torch-bf16 (TFLOP/s (larger is better)) fp8-channel-w-token-a-noquant (TFLOP/s (larger is better))
0 1.0 3.344568 1.024770
1 16.0 40.995152 16.455449
2 64.0 45.736745 66.282656
3 128.0 166.934308 131.791062
4 256.0 95.930135 262.376881
5 512.0 130.277803 260.550933
6 1024.0 135.348890 346.748240
7 2048.0 181.197632 423.735653
8 4096.0 217.213004 421.863460
9 8192.0 217.972362 448.194999
10 16384.0 230.377563 455.878270
meta-llama/Llama-3.1-8B-Instruct, N=4096 K=4096, BF16 vs FP8 GEMMs TFLOP/s:
BF16 vs FP8 GEMMs:
batch_size torch-bf16 (TFLOP/s (larger is better)) fp8-channel-w-token-a-noquant (TFLOP/s (larger is better))
0 1.0 2.263270 0.687744
1 16.0 29.630899 11.015585
2 64.0 86.362940 44.033500
3 128.0 168.295722 87.967286
4 256.0 160.279088 175.910714
5 512.0 180.310192 346.710305
6 1024.0 176.772548 349.077112
7 2048.0 180.936688 345.857536
8 4096.0 180.823892 398.767823
9 8192.0 207.076365 433.350413
10 16384.0 221.810967 451.221271
meta-llama/Llama-3.1-8B-Instruct, N=28672 K=4096, BF16 vs FP8 GEMMs TFLOP/s:
BF16 vs FP8 GEMMs:
batch_size torch-bf16 (TFLOP/s (larger is better)) fp8-channel-w-token-a-noquant (TFLOP/s (larger is better))
0 1.0 1.672496 1.669070
1 16.0 24.180134 28.221204
2 64.0 52.315158 136.948860
3 128.0 103.835064 269.587963
4 256.0 155.564739 400.063653
5 512.0 209.379791 400.541531
6 1024.0 210.884048 446.680304
7 2048.0 230.708248 442.389577
8 4096.0 232.133293 456.806899
9 8192.0 234.854875 462.058209
10 16384.0 237.644796 460.937606
meta-llama/Llama-3.1-8B-Instruct, N=4096 K=14336, BF16 vs FP8 GEMMs TFLOP/s:
BF16 vs FP8 GEMMs:
batch_size torch-bf16 (TFLOP/s (larger is better)) fp8-channel-w-token-a-noquant (TFLOP/s (larger is better))
0 1.0 1.518173 0.720043
1 16.0 19.784419 11.522944
2 64.0 68.747089 46.091139
3 128.0 151.251493 92.150711
4 256.0 193.489921 184.144740
5 512.0 214.527360 364.379841
6 1024.0 181.161535 362.360832
7 2048.0 184.422913 359.474919
8 4096.0 184.317055 415.080517
9 8192.0 211.196123 444.822998
10 16384.0 227.567523 456.602628
Benchmark finished!
With the patch:
root@de-22309-vllm-0-15-1-5090-0211172246-79ccd9bbf9-6jfzz:/vllm/benchmarks/kernels# python3 bench_fp8_gemm.py
meta-llama/Llama-3.1-8B-Instruct, N=6144 K=4096, BF16 vs FP8 GEMMs TFLOP/s:
BF16 vs FP8 GEMMs:
batch_size torch-bf16 (TFLOP/s (larger is better)) fp8-channel-w-token-a-noquant (TFLOP/s (larger is better))
0 1.0 3.342850 4.128756
1 16.0 40.915532 89.481096
2 64.0 45.726899 218.686183
3 128.0 166.963724 227.250401
4 256.0 95.772855 319.665674
5 512.0 130.486043 260.787669
6 1024.0 135.389482 346.598435
7 2048.0 181.199159 423.918463
8 4096.0 217.099188 422.501862
9 8192.0 217.915771 445.575639
10 16384.0 230.359433 455.356509
meta-llama/Llama-3.1-8B-Instruct, N=4096 K=4096, BF16 vs FP8 GEMMs TFLOP/s:
BF16 vs FP8 GEMMs:
batch_size torch-bf16 (TFLOP/s (larger is better)) fp8-channel-w-token-a-noquant (TFLOP/s (larger is better))
0 1.0 2.274765 3.716156
1 16.0 29.568187 63.239882
2 64.0 86.566854 149.832803
3 128.0 168.023253 286.921626
4 256.0 160.296810 301.421506
5 512.0 179.728678 345.488658
6 1024.0 176.916451 348.885494
7 2048.0 181.024093 345.692769
8 4096.0 180.814668 398.704495
9 8192.0 207.152592 433.540198
10 16384.0 224.049211 450.394314
meta-llama/Llama-3.1-8B-Instruct, N=28672 K=4096, BF16 vs FP8 GEMMs TFLOP/s:
BF16 vs FP8 GEMMs:
batch_size torch-bf16 (TFLOP/s (larger is better)) fp8-channel-w-token-a-noquant (TFLOP/s (larger is better))
0 1.0 1.674452 3.214388
1 16.0 24.146398 50.854301
2 64.0 52.276680 197.029327
3 128.0 103.893205 324.303896
4 256.0 155.527548 372.306472
5 512.0 209.104991 399.643764
6 1024.0 210.812200 446.255617
7 2048.0 230.768658 442.181018
8 4096.0 231.991775 457.899588
9 8192.0 234.925537 461.821327
10 16384.0 237.397110 463.450713
meta-llama/Llama-3.1-8B-Instruct, N=4096 K=14336, BF16 vs FP8 GEMMs TFLOP/s:
BF16 vs FP8 GEMMs:
batch_size torch-bf16 (TFLOP/s (larger is better)) fp8-channel-w-token-a-noquant (TFLOP/s (larger is better))
0 1.0 1.517410 4.878684
1 16.0 19.801098 80.358029
2 64.0 68.683802 171.915276
3 128.0 151.241149 337.289947
4 256.0 193.546784 322.995322
5 512.0 214.325228 363.278695
6 1024.0 181.201461 362.528802
7 2048.0 184.573575 359.762664
8 4096.0 184.299785 415.034690
9 8192.0 211.104738 444.967840
10 16384.0 227.424222 454.878541
Benchmark finished!
Batch sizes from 1 to 256 are significantly optimized by this patch.