
[fmha] support fp8 query#153

Open
xinyu-intel wants to merge 3 commits into vllm-project:main from xinyu-intel:dev/fp8-query

Conversation

@xinyu-intel
Collaborator

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing a test command.
  • The test results, such as pasting a before/after comparison or e2e results.
  • (Optional) Any necessary documentation updates, such as updating supported_models.md and examples for a new model.

PLEASE FILL IN THE PR DESCRIPTION HERE ENSURING ALL CHECKLIST ITEMS ABOVE HAVE BEEN CONSIDERED.

Purpose

Added FP8 query support. Currently, FP8 Q/K/V tensors are dequantized to the output dtype before the MMA.

depends on #150

vllm change: xinyu-intel/vllm@97c2151

example:

VLLM_WORKER_MULTIPROC_METHOD=spawn python examples/offline_inference/data_parallel.py --model /workspace/Qwen3-0.6B/ --no-enable-expert-parallel --enforce-eager --kv-cache-dtype fp8 --calculate-kv-scales
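The "dequantize to the output dtype before mma" approach described above can be sketched in pure Python (hypothetical helper names; the real kernel does this inside the CUTLASS mainloop on tensor fragments, not on whole matrices):

```python
# Minimal sketch of "dequantize before mma": fp8-stored Q and K values are
# scaled back to the compute/output dtype first, then multiplied, rather
# than performing the matmul in fp8 directly.
def dequantize(quantized, descale):
    """Apply a per-tensor descale factor to a 2-D list of fp8-stored values."""
    return [[x * descale for x in row] for row in quantized]

def matmul(a, b):
    """Plain 2-D matrix multiply (stand-in for the fused MMA)."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)]
            for row in a]

def qk_scores(q_fp8, k_fp8, q_descale, k_descale):
    q = dequantize(q_fp8, q_descale)  # fp8 -> output dtype
    k = dequantize(k_fp8, k_descale)  # fp8 -> output dtype
    return matmul(q, k)               # K assumed pre-transposed in this sketch
```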

Test Plan

Test Result

(Optional) Documentation Update

BEFORE SUBMITTING, PLEASE READ https://docs.vllm.ai/en/latest/contributing (anything written below this line will be removed by GitHub Actions)

Copilot AI review requested due to automatic review settings on February 11, 2026 13:16
Contributor

Copilot AI left a comment


Pull request overview

This PR adds support for FP8 query tensors in the flash attention implementation. Currently, FP8 Q/K/V tensors are dequantized to the output dtype before matrix multiplication operations.

Changes:

  • Modified flash attention interface to accept optional FP8 query scaling parameters
  • Updated kernel implementations to handle FP8 query dequantization
  • Enhanced test coverage for FP8 query scenarios

Reviewed changes

Copilot reviewed 15 out of 15 changed files in this pull request and generated 3 comments.

Per-file summary:

  • vllm_xpu_kernels/flash_attn_interface.py: Added type hints for q/k/v_descale parameters and validation logic
  • tests/flash_attn/test_flash_attn_varlen_func.py: Extended tests to include FP8 query dtypes and updated the reference implementation
  • csrc/xpu/attn/xe_2/fmha_xe2.h: Changed k_scale and v_scale from float to optional Tensor
  • csrc/xpu/attn/xe_2/fmha_xe2.cpp: Updated to handle optional Tensor scaling parameters and FP8 query detection
  • csrc/xpu/attn/xe_2/fmha_utils.hpp: Extended CutlassQKType to CutlassQKOType to include the output dtype
  • csrc/xpu/attn/xe_2/collective/chunk_prefill_mainloop.hpp: Implemented FP8 query dequantization logic in the mainloop
  • csrc/xpu/attn/xe_2/collective/chunk_prefill_epilogue.hpp: Updated the sink element type to use the output dtype
  • csrc/xpu/attn/xe_2/chunk_prefill_utils.hpp: Updated type signatures for the new scaling parameter types
  • csrc/xpu/attn/xe_2/chunk_prefill_kernel_template.cpp.in: Updated the template parameter from CutlassQKType to CutlassQKOType
  • csrc/xpu/attn/xe_2/chunk_prefill_extern.hpp: Updated extern template declarations for the new type
  • csrc/xpu/attn/xe_2/chunk_prefill.hpp: Added FP8 query kernel configurations and updated type references
  • csrc/xpu/attn/attn_interface.h: Updated the interface signature for optional Tensor scaling
  • csrc/xpu/attn/attn_interface.cpp: Forwarded the new q_scale parameter to the implementation
  • csrc/flash_attn/flash_api.cpp: Modified dtype validation and added FP8 query support
  • .github/workflows/ut.yaml: Reduced MAX_JOB from 128 to 72
Comments suppressed due to low confidence (1)

tests/flash_attn/test_flash_attn_varlen_func.py:1

  • Line 267 checks q_descale is not None but should check v_descale is not None. This is a copy-paste error that will cause incorrect behavior when v_descale is provided without q_descale.
# SPDX-License-Identifier: Apache-2.0


Comment on lines +245 to +252
if is_fp8_query:
    q_descale = (torch.abs(query).max() / 200).to(torch.float32)
    maybe_quantized_query = (query / q_descale).to(q_dtype)
if is_fp8kv:
    k_descale = (torch.abs(key_cache).max() / 200).to(torch.float32)
    v_descale = (torch.abs(value_cache).max() / 200).to(torch.float32)
    maybe_quantized_key_cache = (key_cache / k_descale).to(fp8_dtype)
    maybe_quantized_value_cache = (value_cache / v_descale).to(fp8_dtype)

Copilot AI Feb 11, 2026


The magic number 200 appears three times without explanation. This should be extracted as a named constant (e.g., FP8_QUANTIZATION_DIVISOR) with a comment explaining its purpose in the FP8 scaling calculation.
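A pure-Python sketch of the refactor the comment suggests. The constant name, function names, and docstrings are illustrative and do not appear in the PR:

```python
# Hypothetical named constant replacing the magic number 200. Dividing the
# per-tensor amax by this value keeps quantized magnitudes at ~200, well
# inside the fp8 e4m3 representable range (max ~448), leaving headroom
# against rounding.
FP8_QUANT_DIVISOR = 200.0

def per_tensor_descale(values):
    """Per-tensor descale factor: amax / FP8_QUANT_DIVISOR.

    Pure-Python stand-in for `(torch.abs(t).max() / 200).to(torch.float32)`.
    """
    return max(abs(v) for v in values) / FP8_QUANT_DIVISOR

def quantize(values, descale):
    """Scale values down for fp8 storage (actual fp8 rounding omitted here)."""
    return [v / descale for v in values]
```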

Comment on lines +136 to +138
is_fp8_q ? q_scale.value().data_ptr() : nullptr,
is_fp8_kv ? k_scale.value().data_ptr() : nullptr,
is_fp8_kv ? v_scale.value().data_ptr() : nullptr,

Copilot AI Feb 11, 2026


Direct use of .value() without checking .has_value() first can lead to undefined behavior if the optional is empty. Add explicit checks or document the assumption that when is_fp8_q or is_fp8_kv is true, the corresponding scale tensors must have values.
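One way to make that precondition explicit, sketched with a stand-in `Tensor` type and a hypothetical `get_scale_ptr` helper (neither is from the PR; the real code would use `at::Tensor` and `TORCH_CHECK` rather than `assert`):

```cpp
#include <cassert>
#include <optional>

// Minimal stand-in for a tensor exposing data_ptr().
struct Tensor {
    float data;
    void* data_ptr() { return &data; }
};

// Guard the optional before dereferencing: when the fp8 path is taken,
// the caller must have supplied the corresponding scale tensor.
void* get_scale_ptr(bool is_fp8, std::optional<Tensor>& scale) {
    if (!is_fp8) return nullptr;
    assert(scale.has_value() && "fp8 path requires a scale tensor");
    return scale->data_ptr();
}
```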

Comment on lines +282 to +287
q_descale=q_descale.expand(scale_shape) if q_descale is not None else None,
k_descale=k_descale.expand(scale_shape) if k_descale is not None else None,
v_descale=v_descale.expand(scale_shape) if v_descale is not None else None,

Copilot AI Feb 11, 2026


Line 287 checks q_descale is not None but should check v_descale is not None. This appears to be a duplicate of the issue on line 267: the condition is checking the wrong variable.

Signed-off-by: Xinyu Chen <xinyu1.chen@intel.com>
Signed-off-by: Xinyu Chen <xinyu1.chen@intel.com>
Signed-off-by: Xinyu Chen <xinyu1.chen@intel.com>