
Conversation

@daijh (Contributor) commented Jul 15, 2025

Description

This PR enhances unidirectional `FlashAttention` by applying causal masking inside the main loop. The optimization eliminates unnecessary memory loads by skipping future entries in the KV cache.

Testing on Lunar Lake shows up to a 20% prefill performance improvement for `phi-4-mini-accuracy4` with a 4096-token prompt. Similar gains were observed for other models, including `Qwen3-0.6B-accuracy4`.

This PR also switches from the `is_gqa` flag to the more readable `unidirectional` attribute to control causal masking.
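
A minimal sketch of the idea, assuming a simplified single-head prefill in Python/NumPy (illustrative only, not the actual onnxruntime WebGPU shader this PR modifies; the function name `flash_prefill_causal` and the block size are made up): for a causal query block ending at row `qe`, the KV loop stops at `qe`, so K/V blocks beyond the query positions are never loaded, and only the diagonal block needs element-wise masking.

```python
# Sketch of causal masking applied inside the main KV loop of a
# flash-attention-style prefill. Names and block sizes are illustrative.
import numpy as np

def flash_prefill_causal(q, k, v, block=64):
    """Causal (unidirectional) attention for prefill, computed block-wise."""
    seq_len, head_dim = q.shape
    scale = 1.0 / np.sqrt(head_dim)
    out = np.zeros_like(q)
    for qs in range(0, seq_len, block):
        qe = min(qs + block, seq_len)
        q_blk = q[qs:qe] * scale
        m = np.full(qe - qs, -np.inf)        # running row max
        l = np.zeros(qe - qs)                # running softmax denominator
        acc = np.zeros((qe - qs, head_dim))  # running weighted sum of V
        # The KV loop stops at qe: future keys/values are never loaded.
        for ks in range(0, qe, block):
            ke = min(ks + block, qe)
            s = q_blk @ k[ks:ke].T
            if ke > qs:  # only the diagonal block needs masking
                rows = np.arange(qs, qe)[:, None]
                cols = np.arange(ks, ke)[None, :]
                s = np.where(cols <= rows, s, -np.inf)
            m_new = np.maximum(m, s.max(axis=1))
            p = np.exp(s - m_new[:, None])
            corr = np.exp(m - m_new)
            l = l * corr + p.sum(axis=1)
            acc = acc * corr[:, None] + p @ v[ks:ke]
            m = m_new
        out[qs:qe] = acc / l[:, None]
    return out
```

Compared with iterating over the full sequence and masking afterwards, roughly half of the KV entries fall above the causal diagonal during prefill, which is consistent with the larger gains reported below for longer prompts.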

Motivation and Context

See above.

@daijh (Contributor, Author) commented Jul 15, 2025

Lunar Lake, Phi-4-mini-accuracy4:

| Prompt | Default Prefill Speed (tps) | Opt Prefill Speed (tps) | Improvement |
|---|---|---|---|
| Prompt-1024 | 561.40 | 593.76 | 5.76% |
| Prompt-2048 | 498.60 | 549.59 | 10.23% |
| Prompt-3072 | 465.98 | 537.64 | 15.38% |
| Prompt-4096 | 430.61 | 513.40 | 19.23% |
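
(Reading of the table: the improvement column is the relative prefill-speed gain, e.g. for Prompt-4096: (513.40 − 430.61) / 430.61 ≈ 19.23%.)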

@daijh (Contributor, Author) commented Jul 15, 2025

@sushraja-msft @qjia7 please take a look.

cc @jchen10 @xhcao

qjia7 previously approved these changes Jul 15, 2025

@qjia7 (Contributor) left a comment:

LGTM with nits.

@daijh (Contributor, Author) commented Jul 16, 2025

@guschmue @fs-eire
Please take a look.

@fs-eire (Contributor) commented Jul 22, 2025

/azp run Linux QNN CI Pipeline, Win_TRT_Minimal_CUDA_Test_CI, Windows ARM64 QNN CI Pipeline, Windows GPU Doc Gen CI Pipeline, Windows x64 QNN CI Pipeline


Azure Pipelines successfully started running 5 pipeline(s).

@guschmue added the ep:WebGPU (ort-web webgpu provider) label Jul 29, 2025
@guschmue merged commit 2bd00ec into microsoft:main Jul 29, 2025
89 of 92 checks passed
@daijh deleted the optimize-flash-attention-for-prefill branch July 30, 2025 01:05
sanketkaleoss pushed a commit to sanketkaleoss/onnxruntime that referenced this pull request Aug 11, 2025
