Conversation
Pull request overview
This PR adds support for FP8 query tensors in the flash attention implementation. Currently, FP8 Q/K/V tensors are dequantized to the output dtype before matrix multiplication operations.
Changes:
- Modified flash attention interface to accept optional FP8 query scaling parameters
- Updated kernel implementations to handle FP8 query dequantization
- Enhanced test coverage for FP8 query scenarios
Reviewed changes
Copilot reviewed 15 out of 15 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| vllm_xpu_kernels/flash_attn_interface.py | Added type hints for q/k/v_descale parameters and validation logic |
| tests/flash_attn/test_flash_attn_varlen_func.py | Extended tests to include FP8 query dtypes and updated reference implementation |
| csrc/xpu/attn/xe_2/fmha_xe2.h | Changed k_scale and v_scale from float to optional Tensor |
| csrc/xpu/attn/xe_2/fmha_xe2.cpp | Updated to handle optional Tensor scaling parameters and FP8 query detection |
| csrc/xpu/attn/xe_2/fmha_utils.hpp | Extended CutlassQKType to CutlassQKOType to include output dtype |
| csrc/xpu/attn/xe_2/collective/chunk_prefill_mainloop.hpp | Implemented FP8 query dequantization logic in mainloop |
| csrc/xpu/attn/xe_2/collective/chunk_prefill_epilogue.hpp | Updated sink element type to use output dtype |
| csrc/xpu/attn/xe_2/chunk_prefill_utils.hpp | Updated type signatures for new scaling parameter types |
| csrc/xpu/attn/xe_2/chunk_prefill_kernel_template.cpp.in | Updated template parameter from CutlassQKType to CutlassQKOType |
| csrc/xpu/attn/xe_2/chunk_prefill_extern.hpp | Updated extern template declarations for new type |
| csrc/xpu/attn/xe_2/chunk_prefill.hpp | Added FP8 query kernel configurations and updated type references |
| csrc/xpu/attn/attn_interface.h | Updated interface signature for optional Tensor scaling |
| csrc/xpu/attn/attn_interface.cpp | Forwarded new q_scale parameter to implementation |
| csrc/flash_attn/flash_api.cpp | Modified dtype validation and added FP8 query support |
| .github/workflows/ut.yaml | Reduced MAX_JOB from 128 to 72 |
Comments suppressed due to low confidence (1)
tests/flash_attn/test_flash_attn_varlen_func.py:1
- Line 267 checks `q_descale is not None` but should check `v_descale is not None`. This is a copy-paste error that will cause incorrect behavior when v_descale is provided without q_descale.
```python
if is_fp8_query:
    q_descale = (torch.abs(query).max() / 200).to(torch.float32)
    maybe_quantized_query = (query / q_descale).to(q_dtype)
if is_fp8kv:
    k_descale = (torch.abs(key_cache).max() / 200).to(torch.float32)
    v_descale = (torch.abs(value_cache).max() / 200).to(torch.float32)
    maybe_quantized_key_cache = (key_cache / k_descale).to(fp8_dtype)
    maybe_quantized_value_cache = (value_cache / v_descale).to(fp8_dtype)
```
The magic number 200 appears three times without explanation. This should be extracted as a named constant (e.g., FP8_QUANTIZATION_DIVISOR) with a comment explaining its purpose in the FP8 scaling calculation.
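A minimal, torch-free sketch of the suggested refactor (the constant name and helper functions are suggestions, not code from the PR; the real code operates on torch tensors and casts to an FP8 dtype):

```python
# Suggested named constant (not in the PR): the tensor's max |x| is mapped
# to ~200, keeping quantized values well inside the FP8 E4M3 representable
# range (max finite value is roughly 448).
FP8_QUANTIZATION_DIVISOR = 200.0

def compute_descale(values):
    """Per-tensor symmetric scale factor used to quantize to FP8."""
    return max(abs(v) for v in values) / FP8_QUANTIZATION_DIVISOR

def quantize(values, descale):
    """Divide by the descale; the real code then casts to the FP8 dtype."""
    return [v / descale for v in values]
```

With a named constant, all three call sites in the test share one documented definition instead of repeating the bare literal.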
```cpp
is_fp8_q ? q_scale.value().data_ptr() : nullptr,
is_fp8_kv ? k_scale.value().data_ptr() : nullptr,
is_fp8_kv ? v_scale.value().data_ptr() : nullptr,
```
Direct use of .value() without checking .has_value() first can lead to undefined behavior if the optional is empty. Add explicit checks or document the assumption that when is_fp8_q or is_fp8_kv is true, the corresponding scale tensors must have values.
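In Python terms, the suggested guard amounts to validating the optionals up front before they reach the kernel call (a hypothetical helper mirroring the C++ check, not code from the PR):

```python
def check_fp8_scales(is_fp8_q, is_fp8_kv, q_scale, k_scale, v_scale):
    # Fail fast with a clear error instead of dereferencing an empty
    # optional later (the hazard the review comment describes).
    if is_fp8_q and q_scale is None:
        raise ValueError("FP8 query requires a q_scale tensor")
    if is_fp8_kv and (k_scale is None or v_scale is None):
        raise ValueError("FP8 KV cache requires k_scale and v_scale tensors")
```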
```python
q_descale=q_descale.expand(scale_shape) if q_descale is not None else None,
k_descale=k_descale.expand(scale_shape) if k_descale is not None else None,
v_descale=v_descale.expand(scale_shape) if v_descale is not None else None,
```
Line 287 checks `q_descale is not None` but should check `v_descale is not None`. This appears to be a duplicate of the issue on line 267: the condition tests the wrong variable.
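One way to make this class of copy-paste error structurally impossible is a small helper that guards on the same object it expands (a sketch; the helper name is hypothetical and not from the PR):

```python
def maybe_expand(descale, scale_shape):
    # The None-check and the expansion act on the same variable, so a
    # q/v copy-paste mismatch cannot occur at the call sites.
    return descale.expand(scale_shape) if descale is not None else None
```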
Force-pushed from 2c8a646 to 30af0cf
Force-pushed from 30af0cf to c3df9d2
Signed-off-by: Xinyu Chen <xinyu1.chen@intel.com>
Force-pushed from c8bf0cf to c0f0d16
Essential Elements of an Effective PR Description Checklist
- `supported_models.md` and `examples` for a new model.
Purpose
Added FP8 query support. Currently, FP8 Q/K/V tensors are dequantized to the output dtype before the MMA operations.
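The dequantize-before-MMA behavior can be sketched in plain Python (illustrative names only; the real kernels operate on FP8 tensors and hardware matrix units):

```python
def dequantize(mat, descale):
    """Rescale quantized values back toward the original range."""
    return [[x * descale for x in row] for row in mat]

def mma(a, b_t):
    """Q @ K^T, with K given row-major so b_t[j] is one key row."""
    return [[sum(x * y for x, y in zip(row, key)) for key in b_t]
            for row in a]

def scores(q_quant, k_quant, q_descale, k_descale):
    # Dequantize both operands first, then perform the matmul --
    # the "dequantized to the output dtype before mma" path above.
    return mma(dequantize(q_quant, q_descale),
               dequantize(k_quant, k_descale))
```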
depends on #150
vllm change: xinyu-intel/vllm@97c2151
example:
Test Plan
Test Result
(Optional) Documentation Update