[Kernel] Fuse FP8 output quantization into merge_attn_states#36518

Merged
ProExpertProg merged 15 commits into vllm-project:main from carlyou:fuse-fp8-quant-merge-attn-states
Apr 3, 2026

Conversation

@carlyou
Contributor

@carlyou carlyou commented Mar 9, 2026

Purpose

Closes #33097

  • Add optional output_scale parameter to merge_attn_states (CUDA kernel, Triton kernel, Python bindings, dispatcher) for fused FP8 static per-tensor quantization
  • When output_scale is provided, the kernel quantizes merged attention output directly to FP8 during the final store, eliminating a separate quantization kernel launch and BF16 memory round-trip
  • Backward compatible — existing callers that omit output_scale get the original BF16/FP16/FP32 behaviour unchanged
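
As an unfused reference for what the fused kernel computes, the merge-then-quantize math can be sketched in NumPy. The function name, argument layout, and shapes below are illustrative only, not the actual vllm signature:

```python
import numpy as np

FP8_E4M3_MAX = 448.0  # largest finite FP8 E4M3 value

def merge_attn_states_ref(out_a, lse_a, out_b, lse_b, output_scale=None):
    """Unfused reference: merge two partial attention outputs by their
    log-sum-exp weights, then optionally apply static per-tensor FP8 scaling.

    out_a, out_b: [num_tokens, head_dim]; lse_a, lse_b: [num_tokens].
    """
    lse_max = np.maximum(lse_a, lse_b)
    w_a = np.exp(lse_a - lse_max)
    w_b = np.exp(lse_b - lse_max)
    merged = (w_a[:, None] * out_a + w_b[:, None] * out_b) / (w_a + w_b)[:, None]
    if output_scale is None:
        return merged  # original BF16/FP16/FP32 behaviour
    # Fused path: scale and clamp to the FP8 E4M3 range during the final
    # store (the actual cast to an FP8 dtype is omitted in this sketch).
    return np.clip(merged / output_scale, -FP8_E4M3_MAX, FP8_E4M3_MAX)
```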

CUDA kernel changes

  • Template on output_t and USE_FP8_OUTPUT constexpr bool; uses scaled_fp8_conversion for FP8 stores
  • Thread mapping stays based on input pack_size (128-bit loads); FP8 path stores smaller output packs (64-bit for BF16 input, 32-bit for float input)
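
The pack sizes follow directly from the element widths: a 128-bit load carries 8 BF16 or 4 FP32 elements, and re-storing that many FP8 bytes takes 64 or 32 bits respectively. A small sketch of that arithmetic (dtype names here are illustrative):

```python
ELEM_BITS = {"float32": 32, "bfloat16": 16, "fp8_e4m3": 8}
LOAD_BITS = 128  # vectorized 128-bit input loads drive the thread mapping

def output_pack_bits(input_dtype):
    # Elements carried per 128-bit input pack.
    elems_per_pack = LOAD_BITS // ELEM_BITS[input_dtype]
    # Storing the same elements as FP8 needs a proportionally smaller pack.
    return elems_per_pack, elems_per_pack * ELEM_BITS["fp8_e4m3"]
```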

Triton kernel changes

  • Pre-inverts scale on CPU (1.0 / scale) so the kernel does fast multiplication instead of division
  • USE_FP8: tl.constexpr flag gates clamp + cast before the final tl.store
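
The divide-free store can be sketched in plain Python; the clamp bound is the FP8 E4M3 finite maximum, and the actual cast to an FP8 dtype is omitted:

```python
FP8_E4M3_MAX = 448.0  # largest finite FP8 E4M3 value

def fp8_prepare_store(values, output_scale):
    # Invert the scale once so the per-element work is a multiply rather
    # than a divide, mirroring the kernel's USE_FP8 clamp-then-cast path.
    inv_scale = 1.0 / output_scale
    return [max(-FP8_E4M3_MAX, min(FP8_E4M3_MAX, v * inv_scale)) for v in values]
```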

Test Plan

  • Unit test (CI): test_merge_attn_states.py passes
  • Benchmark comparing fused vs. unfused performance
  • Smoke test: compile and serve on H100 / B200

Test Result

❯ python -m pytest tests/kernels/attention/test_merge_attn_states.py
=========================================================================================== test session starts ============================================================================================
platform linux -- Python 3.12.3, pytest-9.0.2, pluggy-1.6.0
rootdir: /home/carlyou/projects/vllm
configfile: pyproject.toml
plugins: anyio-4.12.0
collected 972 items

tests/kernels/attention/test_merge_attn_states.py .................................................................................................................................................. [ 15%]
.................................................................................................................................................................................................... [ 35%]
.................................................................................................................................................................................................... [ 55%]
.................................................................................................................................................................................................... [ 75%]
.................................................................................................................................................................................................... [ 95%]
..........................................                                                                                                                                                           [100%]

============================================================================================= warnings summary =============================================================================================
(hidden)
=============================================================================== 972 passed, 3 warnings in 165.46s (0:02:45) ================================================================================
sys:1: DeprecationWarning: builtin type swigvarlink has no __module__ attribute
=====================================================================================
  fuse-fp8-quant-merge-attn-states  —  H100 SXM 80GB
=====================================================================================

FUSED CUDA vs UNFUSED CUDA
-------------------------------------------------------------------------------------
  tokens    heads     fused (us)    unfused (us)          speedup
-------------------------------------------------------------------------------------
       1    4-128      1.62-1.90       2.66-2.99    1.41x - 1.73x
      16    4-128      1.78-2.11       2.91-3.52    1.57x - 1.70x
      64    4-128      1.81-3.27       3.08-5.43    1.53x - 1.70x
     256    4-128     1.95-10.08      3.23-22.54    1.46x - 2.24x
    1024    4-128     2.21-51.39      3.69-95.96    1.46x - 2.20x
    4096    4-128    3.83-197.55     5.74-370.38    1.48x - 2.21x

  Overall range: 1.41x - 2.24x
  Overall mean:  1.65x


FUSED TRITON vs UNFUSED TRITON
-------------------------------------------------------------------------------------
  tokens    heads     fused (us)    unfused (us)          speedup
-------------------------------------------------------------------------------------
       1    4-128      1.22-1.41       2.35-2.68    1.90x - 2.05x
      16    4-128      1.35-2.30       2.60-3.77    1.64x - 1.97x
      64    4-128      1.42-5.99       2.78-8.08    1.31x - 1.97x
     256    4-128     1.70-21.33      3.08-27.91    1.21x - 1.81x
    1024    4-128     3.53-81.06     5.10-110.13    1.17x - 1.48x
    4096    4-128   10.92-318.44    13.34-431.39    1.17x - 1.36x

  Overall range: 1.17x - 2.05x
  Overall mean:  1.58x

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results.
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces a valuable optimization by fusing FP8 output quantization into the merge_attn_states kernel. However, security vulnerabilities were identified, including a potential buffer overflow in the CUDA kernel due to missing data type and shape validation when the output is not quantized, and a potential denial-of-service vulnerability in the Triton implementation caused by a zero-valued output scale leading to a division-by-zero exception. These issues highlight a lack of robust input tensor validation. Additionally, an area of code duplication in the CUDA kernel could be refactored to improve maintainability.


@chatgpt-codex-connector chatgpt-codex-connector bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: ba229809ff

carlyou added 13 commits April 2, 2026 15:03
Signed-off-by: Carl You <4531192+carlyou@users.noreply.github.com>
- Rename out_t -> output_t, FP8_OUTPUT -> USE_FP8_OUTPUT for clarity
- Pre-invert FP8 scale in kernel (1.0f / *output_scale) and use
  scaled_fp8_conversion<true> (multiply) instead of <false> (divide),
  matching the established pattern in layernorm_quant_kernels.cu,
  common.cu, and activation_kernels.cu
- Restore the original -inf edge case comment explaining MLA chunked
  prefill behavior
- Parametrize FP8 tests on output_scale (0.5, 0.05) with
  scale-dependent tolerances for better coverage
- Parametrize FP8 tests on use_output_lse (True, False) to test
  combined FP8 output + LSE output path

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Signed-off-by: Carl You <4531192+carlyou@users.noreply.github.com>
- Add merge_attn_states benchmark script (fused FP8 vs unfused)
- Parametrize FP8 test over float32/float16/bfloat16 inputs to cover
  all output_pack_t paths (uint/uint2)
- Tighten FP8 test tolerances to match FP8 e4m3 representable spacing
- Fix output_lse shape docstring: [h,d] -> [h,n]

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Signed-off-by: Carl You <4531192+carlyou@users.noreply.github.com>
Signed-off-by: Carl You <4531192+carlyou@users.noreply.github.com>
Signed-off-by: Carl You <4531192+carlyou@users.noreply.github.com>
Benchmark:
- Use triton.testing.perf_report with do_bench_cudagraph
- Use torch.compiled QuantFP8 for unfused baseline
- Add --tp axis (1,2,4,8) dividing num_heads
- Add --dtype CLI arg for input dtype selection
- Remove non-FP8 benchmark (not relevant to this PR)
- Fix DeepSeek-V3 MLA head_size: 512 -> 128 (v_head_dim)

Triton kernel:
- Remove .item() call on output_scale (avoids GPU-CPU sync)
- Pre-compute scale inverse as tensor, load via tl.load()

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

Signed-off-by: Carl You <4531192+carlyou@users.noreply.github.com>
Avoid launching a separate PyTorch kernel for 1.0/output_scale
by computing the reciprocal inside the Triton kernel via tl.load.

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

Signed-off-by: Carl You <4531192+carlyou@users.noreply.github.com>
Signed-off-by: Carl You <4531192+carlyou@users.noreply.github.com>
Signed-off-by: Carl You <4531192+carlyou@users.noreply.github.com>
Signed-off-by: Carl You <4531192+carlyou@users.noreply.github.com>
Signed-off-by: Carl You <4531192+carlyou@users.noreply.github.com>
Signed-off-by: Carl You <4531192+carlyou@users.noreply.github.com>
Signed-off-by: Carl You <4531192+carlyou@users.noreply.github.com>
@carlyou carlyou force-pushed the fuse-fp8-quant-merge-attn-states branch from b91d7b8 to 57a8cba Compare April 2, 2026 22:09
@ProExpertProg ProExpertProg enabled auto-merge (squash) April 3, 2026 01:17
@ProExpertProg ProExpertProg merged commit 3bc2734 into vllm-project:main Apr 3, 2026
140 of 141 checks passed
@github-project-automation github-project-automation bot moved this from Todo to Done in AMD Apr 3, 2026
@github-project-automation github-project-automation bot moved this from Ready to Done in NVIDIA Apr 3, 2026
@carlyou carlyou deleted the fuse-fp8-quant-merge-attn-states branch April 4, 2026 02:35
HenryTangDev pushed a commit to HenryTangMain/vllm that referenced this pull request Apr 6, 2026
…oject#36518)

Signed-off-by: Carl You <4531192+carlyou@users.noreply.github.com>
askliar pushed a commit to netanel-haber/vllm that referenced this pull request Apr 7, 2026
…oject#36518)

Signed-off-by: Carl You <4531192+carlyou@users.noreply.github.com>
USTCKAY pushed a commit to USTCKAY/vllm that referenced this pull request Apr 7, 2026
…oject#36518)

Signed-off-by: Carl You <4531192+carlyou@users.noreply.github.com>
Signed-off-by: Song Kai <songkai05@baidu.com>
rishitdholakia13 pushed a commit to rishitdholakia13/vllm that referenced this pull request Apr 7, 2026
…oject#36518)

Signed-off-by: Carl You <4531192+carlyou@users.noreply.github.com>
Signed-off-by: rishitdholakia13 <rishit+github@cohere.com>
puririshi98 pushed a commit to puririshi98/vllm that referenced this pull request Apr 7, 2026
…oject#36518)

Signed-off-by: Carl You <4531192+carlyou@users.noreply.github.com>
Signed-off-by: Rishi Puri <riship@nvidia.com>
big-yellow-duck pushed a commit to EmbeddedLLM/vllm that referenced this pull request Apr 8, 2026
…oject#36518)

Signed-off-by: Carl You <4531192+carlyou@users.noreply.github.com>
mtparet pushed a commit to blackfuel-ai/vllm that referenced this pull request Apr 9, 2026
…oject#36518)

Signed-off-by: Carl You <4531192+carlyou@users.noreply.github.com>
jackcfwang pushed a commit to jackcfwang/vllm that referenced this pull request Apr 10, 2026
…oject#36518)

Signed-off-by: Carl You <4531192+carlyou@users.noreply.github.com>
Signed-off-by: jackcfwang <jackcfwang@tencent.com>
EricccYang pushed a commit to EricccYang/vllm that referenced this pull request Apr 10, 2026
…oject#36518)

Signed-off-by: Carl You <4531192+carlyou@users.noreply.github.com>

Labels

  • ci/build
  • documentation (Improvements or additions to documentation)
  • llama (Related to Llama models)
  • nvidia
  • performance (Performance-related issues)
  • qwen (Related to Qwen models)
  • ready (ONLY add when PR is ready to merge/full CI is needed)
  • rocm (Related to AMD ROCm)
  • v1

Projects

Status: Done
Status: Done

Development

Successfully merging this pull request may close these issues.

[Feature]: Fuse FP8 output quantization into merge_attn_states (DCP / cascade paths)

2 participants