[Kernel] Add FP8 support with FlashMLA backend #22668
Conversation
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
This pull request has merge conflicts that must be resolved before it can be merged.
Code Review
This pull request adds FP8 support for the FlashMLA attention backend. The changes include updating the FlashMLA sources in CMake, modifying the FlashMLA Python ops to handle FP8 scaling factors, and adding FP8 data types to the FlashMLA tests. The logic for enabling FP8 in the attention backend seems correct and is consistently applied across the codebase. I have one suggestion to improve code clarity in the test suite.
This pull request has merge conflicts that must be resolved before it can be merged.
CI build fails because the FlashMLA git tag hasn't been updated yet; will update prior to merge. EDIT: Done
@MatthewBonanni this is great work, FP8 MLA is an important feature to have. In general, the PR LGTM; left some minor comments. Some questions about perf:
- The benchmark you did used prompt/decode 512/512; it would be interesting to see perf numbers for 8192/1024 prompt/decode, which should exercise the prefill chunking more aggressively (see the sketch after this list).
- It would be interesting to profile the MLA code to see how much improvement the FP8 MLA-related code is actually making vs. the rest of the MLA code. Maybe there is some low-hanging fruit to improve in the code that surrounds the new FP8 MLA.
- Did you have a chance to compare perf of FP8 MLA vs. plain CUTLASS MLA on B200 (fp16)? Just to know that FP8 helps vs. CUTLASS MLA.
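A minimal sketch of such a run, assuming the same vllm bench throughput interface used in the test plan below; the author's chg run launcher wrapper is dropped and the prompt count is an illustrative value, not taken from the original benchmark:

```bash
# Hypothetical 8192/1024 prompt/decode throughput run, mirroring the 512/512 command in the test plan
VLLM_ATTENTION_BACKEND=FLASHMLA vllm bench throughput \
  --model=deepseek-ai/DeepSeek-V2-Lite-Chat \
  --dataset-name=random \
  --input-len=8192 --output-len=1024 \
  --num-prompts=1000 \
  --kv-cache-dtype=fp8
```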
Really great work! Overall seems very close to landable; left a couple comments
@alexm-redhat Thanks for your review!
LGTM thanks!
Essential Elements of an Effective PR Description Checklist
- The purpose of the PR
- The test plan
- The test results
- (Optional) Documentation update, such as updating supported_models.md and examples for a new model

Purpose
Enable FP8 KV cache with MLA
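From the user side, this is enabled by selecting the FlashMLA backend and an FP8 KV cache dtype. A minimal sketch of a server launch, assuming the vllm serve entry point and mirroring the environment variable, model, and flags used in the test plan below:

```bash
# Serve with the FlashMLA attention backend and an FP8 KV cache
# (model and flags mirror the test plan; adjust for your setup)
VLLM_ATTENTION_BACKEND=FLASHMLA vllm serve deepseek-ai/DeepSeek-V2-Lite-Chat \
  --trust-remote-code \
  --kv-cache-dtype fp8
```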
Test Plan
Correctness
Accuracy
With kv_cache_dtype = "auto":

VLLM_ATTENTION_BACKEND=FLASHMLA lm_eval --model vllm --model_args '{"pretrained": "deepseek-ai/DeepSeek-V2-Lite-Chat", "trust_remote_code": true, "kv_cache_dtype": "auto"}' --tasks gsm8k --batch_size auto

With kv_cache_dtype = "fp8":

VLLM_ATTENTION_BACKEND=FLASHMLA lm_eval --model vllm --model_args '{"pretrained": "deepseek-ai/DeepSeek-V2-Lite-Chat", "trust_remote_code": true, "kv_cache_dtype": "fp8"}' --tasks gsm8k --batch_size auto
Performance

With kv_cache_dtype = "auto":

VLLM_ATTENTION_BACKEND=FLASHMLA chg run --gpus 1 -- vllm bench throughput --model=deepseek-ai/DeepSeek-V2-Lite-Chat --dataset-name=random --input-len=512 --output-len=512 --num-prompts=10000 --kv-cache-dtype=auto

With kv_cache_dtype = "fp8":

VLLM_ATTENTION_BACKEND=FLASHMLA chg run --gpus 1 -- vllm bench throughput --model=deepseek-ai/DeepSeek-V2-Lite-Chat --dataset-name=random --input-len=512 --output-len=512 --num-prompts=10000 --kv-cache-dtype=fp8
Test Result
Correctness
Tests pass
Accuracy
With kv_cache_dtype = "auto":

With kv_cache_dtype = "fp8":

Performance
On 1x H100:
Here are the results for 512/512:
--kv-cache-dtype=auto: Throughput: 26.37 requests/s, 26975.81 total tokens/s, 13499.72 output tokens/s
--kv-cache-dtype=fp8: Throughput: 27.99 requests/s, 28635.39 total tokens/s, 14330.23 output tokens/s
Here are the results for 8192/1024:
--kv-cache-dtype=auto: Throughput: 2.40 requests/s, 22143.47 total tokens/s, 2460.56 output tokens/s
--kv-cache-dtype=fp8: Throughput: 3.25 requests/s, 29971.81 total tokens/s, 3330.44 output tokens/s
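For reference, the relative gains implied by these numbers: 27.99 / 26.37 ≈ 1.06x (about 6% higher throughput) for 512/512, and 3.25 / 2.40 ≈ 1.35x (about 35%) for 8192/1024, consistent with the longer-prompt workload being more heavily bound by KV cache reads.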
(Optional) Documentation Update