[Performance] Introduce Marlin-based GEMM kernels for the calibration-free RTN-based quantization #23197
Conversation
Signed-off-by: Alex Kogan <[email protected]>
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
Could you explain the need for adding Marlin kernels specifically for this case? It seems you could get the same results by providing dummy scales to the existing Marlin impl, is that right?
Yes, I think you are right. In theory, there are a number of Marlin kernels in the vLLM code base that could be used for RTN, e.g., https://github.com/vllm-project/vllm/blob/main/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu. But I have not seen a version that could be used as-is; the one linked above, for example, seems to support only FP16. So I thought it would be better to create a separate version for RTN, which could also be tuned in the future without impacting other quantization schemes. I did try to reuse as many helper functions as possible, hence all the includes.
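For context, here is a minimal NumPy sketch (not the PR's CUDA implementation) of calibration-free RTN quantization, along with the "dummy scales" equivalence raised above: a scale-aware GEMM kernel fed the per-row RTN scales produces the same output as dequantizing the weights first. All names below are illustrative.

```python
import numpy as np

def rtn_quantize(w, num_bits=4):
    """Calibration-free round-to-nearest (RTN) quantization sketch:
    one scale per output row, weights rounded onto a signed integer
    grid. No calibration data is involved."""
    qmax = 2 ** (num_bits - 1) - 1                       # 7 for 4-bit signed
    scale = np.abs(w).max(axis=1, keepdims=True) / qmax  # per-row scale
    q = np.clip(np.round(w / scale), -qmax - 1, qmax)
    return q.astype(np.int8), scale.astype(np.float32)

rng = np.random.default_rng(0)
w = rng.standard_normal((8, 64)).astype(np.float32)      # weight matrix
x = rng.standard_normal((64, 4)).astype(np.float32)      # activations

q, scale = rtn_quantize(w)

# The "dummy scales" idea: a scale-aware kernel computes (q * scale) @ x.
# Since (diag(s) Q) x == diag(s) (Q x), applying the RTN scales inside the
# GEMM matches dequantizing the weights up front.
y_dequant_first = (q.astype(np.float32) * scale) @ x
y_scale_inside  = scale * (q.astype(np.float32) @ x)
```

This is why, in principle, an existing Marlin kernel could serve RTN by treating the RTN per-row scales as its group scales; the PR opts for a dedicated version mainly for dtype coverage and independent tunability.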
This pull request has merge conflicts that must be resolved before it can be merged.
Signed-off-by: Alex Kogan <[email protected]>
This PR enhances the work started in #18768 and #20766 by introducing Marlin-based kernels for the calibration-free RTN-based quantization.
These kernels substantially improve the performance of dense models quantized with RTN.
We ran `benchmark_latency` with several Llama models on a machine equipped with H100 GPUs. The exact command was:

```
[RTN_NUM_BITS=4] python benchmark_latency.py --model <model> --n 1 --num-iters-warmup 3 --num-iters 10 --input-len 256 --output-len 32 -tp <#GPUs> --batch-size <batch> [--quantization rtn]
```
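As one concrete illustration of the command above, a single-GPU run might look as follows. The model name and batch size here are placeholders chosen for the example, not values taken from the PR:

```shell
# Hypothetical invocation: 4-bit RTN-quantized 8B model on 1 GPU, batch size 8.
RTN_NUM_BITS=4 python benchmark_latency.py \
  --model meta-llama/Llama-3.1-8B \
  --n 1 --num-iters-warmup 3 --num-iters 10 \
  --input-len 256 --output-len 32 \
  -tp 1 --batch-size 8 \
  --quantization rtn
```

Omitting `--quantization rtn` gives the unquantized baseline measured against the same settings.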
Each data point is an average of 5 runs; the units are seconds of generation latency (lower is better).
Here are the results for Llama3.1-8B (ran on 1 GPU), for various batch sizes:
Here are the results for Llama3.3-70B (ran on 4 GPUs), for various batch sizes: