[Kernel] Fuse FP8 output quantization into merge_attn_states#36518
ProExpertProg merged 15 commits into vllm-project:main
Conversation
Code Review
This pull request introduces a valuable optimization by fusing FP8 output quantization into the merge_attn_states kernel. However, security vulnerabilities were identified, including a potential buffer overflow in the CUDA kernel due to missing data type and shape validation when the output is not quantized, and a potential denial-of-service vulnerability in the Triton implementation caused by a zero-valued output scale leading to a division-by-zero exception. These issues highlight a lack of robust input tensor validation. Additionally, an area of code duplication in the CUDA kernel could be refactored to improve maintainability.
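The division-by-zero concern above comes from the kernel pre-inverting the scale as `1.0 / output_scale`. A minimal sketch of the kind of guard the review is asking for (the helper name and exact checks are ours, not vLLM's actual validation code):

```python
def validate_output_scale(output_scale):
    """Hypothetical guard: reject scales that make 1.0 / scale undefined.

    None means the unquantized path (original BF16/FP16/FP32 behaviour).
    """
    if output_scale is None:
        return  # no FP8 quantization requested
    if output_scale == 0.0:
        raise ValueError(
            "output_scale must be non-zero: the kernel pre-inverts it "
            "as 1.0 / output_scale"
        )
    if output_scale < 0.0:
        raise ValueError("output_scale must be positive for FP8 quantization")
```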
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: ba229809ff
Signed-off-by: Carl You <4531192+carlyou@users.noreply.github.com>
- Rename `out_t` -> `output_t`, `FP8_OUTPUT` -> `USE_FP8_OUTPUT` for clarity
- Pre-invert FP8 scale in kernel (`1.0f / *output_scale`) and use `scaled_fp8_conversion<true>` (multiply) instead of `<false>` (divide), matching the established pattern in layernorm_quant_kernels.cu, common.cu, and activation_kernels.cu
- Restore the original -inf edge case comment explaining MLA chunked prefill behavior
- Parametrize FP8 tests on output_scale (0.5, 0.05) with scale-dependent tolerances for better coverage
- Parametrize FP8 tests on use_output_lse (True, False) to test the combined FP8 output + LSE output path

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Carl You <4531192+carlyou@users.noreply.github.com>
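The multiply-vs-divide change above is a standard strength reduction: invert the scale once, then every element does a multiply. A pure-Python model of scaled FP8 (e4m3) conversion showing the two forms agree (function names are illustrative, not vLLM's; the real code is CUDA's `scaled_fp8_conversion`):

```python
FP8_E4M3_MAX = 448.0  # largest finite value representable in FP8 e4m3

def quantize_divide(x, scale):
    # "divide" form: one division per element
    return max(-FP8_E4M3_MAX, min(FP8_E4M3_MAX, x / scale))

def quantize_multiply(x, inv_scale):
    # "multiply" form: the scale is inverted once, outside the hot loop
    return max(-FP8_E4M3_MAX, min(FP8_E4M3_MAX, x * inv_scale))

scale = 0.5
inv_scale = 1.0 / scale  # computed once, as the kernel now does
vals = [0.25, -3.0, 1000.0]  # 1000.0 exercises the clamp to ±448
assert all(
    quantize_divide(v, scale) == quantize_multiply(v, inv_scale) for v in vals
)
```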
- Add merge_attn_states benchmark script (fused FP8 vs unfused)
- Parametrize FP8 test over float32/float16/bfloat16 inputs to cover all output_pack_t paths (uint/uint2)
- Tighten FP8 test tolerances to match FP8 e4m3 representable spacing
- Fix output_lse shape docstring: [h,d] -> [h,n]

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Carl You <4531192+carlyou@users.noreply.github.com>
Benchmark:
- Use triton.testing.perf_report with do_bench_cudagraph
- Use torch.compile'd QuantFP8 for unfused baseline
- Add --tp axis (1,2,4,8) dividing num_heads
- Add --dtype CLI arg for input dtype selection
- Remove non-FP8 benchmark (not relevant to this PR)
- Fix DeepSeek-V3 MLA head_size: 512 -> 128 (v_head_dim)

Triton kernel:
- Remove .item() call on output_scale (avoids GPU-CPU sync)
- Pre-compute scale inverse as tensor, load via tl.load()

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Carl You <4531192+carlyou@users.noreply.github.com>
Avoid launching a separate PyTorch kernel for 1.0/output_scale by computing the reciprocal inside the Triton kernel via tl.load.

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Carl You <4531192+carlyou@users.noreply.github.com>
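Calling `.item()` on a device tensor forces a GPU-to-CPU synchronization before the kernel can even launch; reading the scalar inside the kernel (via `tl.load` in the real Triton code) keeps everything on the device. A pure-Python analogy of the pattern, with the scale kept in a buffer rather than materialized as a host float (all names here are our own illustration):

```python
def merge_store_kernel(out, vals, scale_buf):
    """Analogy for the fused store path: the 'kernel' reads the scale from
    a device-side buffer (tl.load analogue) and inverts it itself, so the
    host never calls .item() and never blocks on a GPU->CPU copy."""
    inv_scale = 1.0 / scale_buf[0]   # reciprocal computed in-kernel
    for i, v in enumerate(vals):
        out[i] = v * inv_scale       # fused multiply on the store path

scale_buf = [0.5]                    # stays "on device"; no host sync needed
out = [0.0] * 3
merge_store_kernel(out, [1.0, 2.0, 3.0], scale_buf)
# out == [2.0, 4.0, 6.0]
```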
Purpose
Closes #33097
- Adds an optional `output_scale` parameter to `merge_attn_states` (CUDA kernel, Triton kernel, Python bindings, dispatcher) for fused FP8 static per-tensor quantization
- When `output_scale` is provided, the kernel quantizes the merged attention output directly to FP8 during the final store, eliminating a separate quantization kernel launch and the BF16 memory round-trip
- Callers that do not pass `output_scale` get the original BF16/FP16/FP32 behaviour unchanged

CUDA kernel changes
- Templated on `output_t` and a `USE_FP8_OUTPUT` constexpr bool; uses `scaled_fp8_conversion` for FP8 stores

Triton kernel changes
- Pre-computes the inverse scale (`1.0 / scale`) so the kernel does a fast multiplication instead of a division
- A `USE_FP8: tl.constexpr` flag gates the clamp + cast before the final `tl.store`

Test Plan
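As a pure-Python sanity model of what the tests check: merge two partial attention outputs by their log-sum-exp weights, then quantize with the static scale. This is our own reduced sketch (one scalar per "head"); the real tests compare the CUDA/Triton kernels against an unfused torch reference:

```python
import math

FP8_E4M3_MAX = 448.0  # largest finite FP8 e4m3 value

def merge_then_quantize(o1, lse1, o2, lse2, scale):
    """Unfused reference: LSE-weighted merge of two partial attention
    outputs, followed by static per-tensor FP8 quantization (clamp only;
    rounding to the e4m3 grid is omitted for brevity)."""
    m = max(lse1, lse2)                       # stabilize the exponentials
    w1, w2 = math.exp(lse1 - m), math.exp(lse2 - m)
    merged = (o1 * w1 + o2 * w2) / (w1 + w2)  # softmax-consistent merge
    return max(-FP8_E4M3_MAX, min(FP8_E4M3_MAX, merged / scale))

# Equal LSEs -> plain average; with scale 0.5 the stored value doubles.
assert merge_then_quantize(1.0, 0.0, 3.0, 0.0, 0.5) == 4.0
```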
- `test_merge_attn_states.py` passes

Test Result