CUDA: fix overflow in FA, tune performance #14840

Merged

Conversation

JohannesGaessler
Collaborator

Due to numerical overflows, the CUDA FlashAttention code on master does not work correctly for very long contexts (on the order of several million tokens across all sequences). This PR uses 64-bit math for the parts of the code susceptible to such problems: the K/V offsets between sequences and the calculation of K/V offsets within a sequence. For the vector kernel, simply casting the offsets to 64 bit caused a performance regression on Pascal; for this reason I'm instead adding a 32-bit offset after each iteration (which turns out to be faster on Pascal/AMD anyway). I am not seeing any performance differences for the other kernels, so I'm just casting their offsets to 64 bit. While working on this I noticed that at some point the tile FA kernels seem to have gotten faster than the vector kernels on my RX 6800, so I'm enabling them for AMD.
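
For context, here is a minimal sketch of the two offsetting strategies described above (illustrative only, not the actual kernel code; names like `K_base`, `nb11` for the within-sequence byte stride and `nb12` for the between-sequence byte stride are stand-ins for the real kernel arguments):

```cpp
// Hedged sketch, not the actual FlashAttention kernel code: it only illustrates
// the offset arithmetic. All names (K_base, nb11, nb12, seq, k_tokens, step)
// are illustrative stand-ins for the real kernel arguments.
#include <cstdint>

const char * k_block_simple(const char * K_base, int64_t nb12, int32_t nb11,
                            int seq, int k) {
    // The offset between sequences scales with the total number of tokens, so
    // the multiplication has to happen in 64 bit; in 32 bit it wraps around
    // once the K/V data of one sequence exceeds 2 GiB.
    const char * K_seq = K_base + (int64_t) seq*nb12;

    // Straightforward fix used for most kernels: also do the within-sequence
    // offset in 64 bit.
    return K_seq + (int64_t) k*nb11;
}

float process_blocks_ptr_increment(const char * K_seq, int32_t nb11,
                                   int k_tokens, int step) {
    // Variant used for the vector kernel: instead of a (potentially 64-bit)
    // multiplication per iteration, keep a pointer and advance it by a 32-bit
    // offset after each iteration (measured to be faster on Pascal/AMD).
    float acc = 0.0f;
    const char * K_ptr = K_seq;
    for (int k = 0; k < k_tokens; k += step) {
        acc += *(const float *) K_ptr; // placeholder for the actual FA work
        K_ptr += step*nb11;
    }
    return acc;
}
```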

Performance changes
| GPU | Model | Microbatch size | Test | t/s master | t/s ec05b08 | Speedup |
| --- | --- | ---: | --- | ---: | ---: | ---: |
| P40 | llama 8B Q4_0 | 1 | pp8192 | 45.19 | 46.21 | 1.02 |
| P40 | llama 8B Q4_0 | 2 | pp8192 | 79.03 | 81.57 | 1.03 |
| P40 | llama 8B Q4_0 | 4 | pp8192 | 100.38 | 103.55 | 1.03 |
| P40 | llama 8B Q4_0 | 8 | pp8192 | 111.95 | 115.22 | 1.03 |
| P40 | llama 8B Q4_0 | 16 | pp8192 | 326.43 | 327.61 | 1.00 |
| P40 | llama 8B Q4_0 | 32 | pp8192 | 458.59 | 459.75 | 1.00 |
| P40 | llama 8B Q4_0 | 64 | pp8192 | 519.87 | 521.12 | 1.00 |
| P40 | llama 8B Q4_0 | 128 | pp8192 | 560.72 | 564.46 | 1.01 |
| P40 | llama 8B Q4_0 | 256 | pp8192 | 595.92 | 598.70 | 1.00 |
| P40 | llama 8B Q4_0 | 512 | pp8192 | 608.42 | 610.60 | 1.00 |
| P40 | llama 8B Q4_0 | 1024 | pp8192 | 599.67 | 604.16 | 1.01 |
| P40 | llama 8B Q4_0 | 2048 | pp8192 | 581.61 | 580.96 | 1.00 |
| P40 | llama 8B Q4_0 | 4096 | pp8192 | 578.04 | 583.86 | 1.01 |
| P40 | llama 8B Q4_0 | 8192 | pp8192 | 578.77 | 582.35 | 1.01 |
| RTX 3090 | llama 8B Q4_0 | 1 | pp8192 | 140.13 | 139.89 | 1.00 |
| RTX 3090 | llama 8B Q4_0 | 2 | pp8192 | 249.16 | 250.09 | 1.00 |
| RTX 3090 | llama 8B Q4_0 | 4 | pp8192 | 425.18 | 426.44 | 1.00 |
| RTX 3090 | llama 8B Q4_0 | 8 | pp8192 | 529.60 | 528.76 | 1.00 |
| RTX 3090 | llama 8B Q4_0 | 16 | pp8192 | 1132.49 | 1131.78 | 1.00 |
| RTX 3090 | llama 8B Q4_0 | 32 | pp8192 | 1842.52 | 1841.17 | 1.00 |
| RTX 3090 | llama 8B Q4_0 | 64 | pp8192 | 2807.42 | 2798.67 | 1.00 |
| RTX 3090 | llama 8B Q4_0 | 128 | pp8192 | 3534.20 | 3516.30 | 0.99 |
| RTX 3090 | llama 8B Q4_0 | 256 | pp8192 | 4192.59 | 4172.36 | 1.00 |
| RTX 3090 | llama 8B Q4_0 | 512 | pp8192 | 4426.59 | 4397.17 | 0.99 |
| RTX 3090 | llama 8B Q4_0 | 1024 | pp8192 | 4529.86 | 4496.75 | 0.99 |
| RTX 3090 | llama 8B Q4_0 | 2048 | pp8192 | 4494.54 | 4483.05 | 1.00 |
| RTX 3090 | llama 8B Q4_0 | 4096 | pp8192 | 4509.37 | 4485.62 | 0.99 |
| RTX 3090 | llama 8B Q4_0 | 8192 | pp8192 | 4490.27 | 4481.76 | 1.00 |
| RTX 4090 | llama 8B Q4_0 | 1 | pp8192 | 168.07 | 167.63 | 1.00 |
| RTX 4090 | llama 8B Q4_0 | 2 | pp8192 | 302.88 | 303.42 | 1.00 |
| RTX 4090 | llama 8B Q4_0 | 4 | pp8192 | 591.02 | 592.69 | 1.00 |
| RTX 4090 | llama 8B Q4_0 | 8 | pp8192 | 1005.46 | 1006.44 | 1.00 |
| RTX 4090 | llama 8B Q4_0 | 16 | pp8192 | 1692.36 | 1690.25 | 1.00 |
| RTX 4090 | llama 8B Q4_0 | 32 | pp8192 | 3107.26 | 3105.39 | 1.00 |
| RTX 4090 | llama 8B Q4_0 | 64 | pp8192 | 5421.35 | 5433.31 | 1.00 |
| RTX 4090 | llama 8B Q4_0 | 128 | pp8192 | 7964.31 | 7995.78 | 1.00 |
| RTX 4090 | llama 8B Q4_0 | 256 | pp8192 | 10339.18 | 10333.54 | 1.00 |
| RTX 4090 | llama 8B Q4_0 | 512 | pp8192 | 11580.34 | 11574.64 | 1.00 |
| RTX 4090 | llama 8B Q4_0 | 1024 | pp8192 | 11811.83 | 11801.82 | 1.00 |
| RTX 4090 | llama 8B Q4_0 | 2048 | pp8192 | 11432.54 | 11399.22 | 1.00 |
| RTX 4090 | llama 8B Q4_0 | 4096 | pp8192 | 11411.55 | 11403.61 | 1.00 |
| RTX 4090 | llama 8B Q4_0 | 8192 | pp8192 | 11417.33 | 11400.58 | 1.00 |
| RX 6800 | llama 8B Q4_0 | 1 | pp8192 | 38.91 | 46.38 | 1.19 |
| RX 6800 | llama 8B Q4_0 | 2 | pp8192 | 68.44 | 67.47 | 0.99 |
| RX 6800 | llama 8B Q4_0 | 4 | pp8192 | 73.64 | 73.36 | 1.00 |
| RX 6800 | llama 8B Q4_0 | 8 | pp8192 | 75.06 | 75.16 | 1.00 |
| RX 6800 | llama 8B Q4_0 | 16 | pp8192 | 86.76 | 93.01 | 1.07 |
| RX 6800 | llama 8B Q4_0 | 32 | pp8192 | 94.86 | 115.82 | 1.22 |
| RX 6800 | llama 8B Q4_0 | 64 | pp8192 | 101.16 | 127.03 | 1.26 |
| RX 6800 | llama 8B Q4_0 | 128 | pp8192 | 113.84 | 153.90 | 1.35 |
| RX 6800 | llama 8B Q4_0 | 256 | pp8192 | 118.68 | 161.18 | 1.36 |
| RX 6800 | llama 8B Q4_0 | 512 | pp8192 | 118.96 | 159.27 | 1.34 |
| RX 6800 | llama 8B Q4_0 | 1024 | pp8192 | 116.08 | 148.57 | 1.28 |
| RX 6800 | llama 8B Q4_0 | 2048 | pp8192 | 106.82 | 134.90 | 1.26 |
| RX 6800 | llama 8B Q4_0 | 4096 | pp8192 | 106.94 | 134.76 | 1.26 |
| RX 6800 | llama 8B Q4_0 | 8192 | pp8192 | 106.93 | 135.25 | 1.26 |

@github-actions github-actions bot added the Nvidia GPU (Issues specific to Nvidia GPUs) and ggml (changes relating to the ggml tensor library for machine learning) labels Jul 23, 2025
Comment on lines 21 to 27
const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
const int32_t nb01, const int32_t nb02, const int32_t nb03,
const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
const int32_t nb11, const int32_t nb12, const int64_t nb13,
const int32_t nb21, const int32_t nb22, const int64_t nb23,
const int32_t ne31, const int32_t ne32, const int32_t ne33,
const int32_t nb31, const int32_t nb32, const int32_t nb33) {
Member

What is the logic for choosing between int32_t and int64_t here? For example, why is int64_t nb23, but int32_t nb33?

Collaborator Author

The mask is being broadcast across all attention heads, so it's simply smaller than K/V. I suppose you could also use 64 bit for nb33; it should still be fine in terms of register pressure.

Collaborator Author

More generally, the only offsets that are going to be really large are those that scale with the number of tokens, i.e. the offsets between sequences.
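
For intuition, a rough back-of-the-envelope calculation of when a byte offset that scales with the token count outgrows `int32_t` (the dimensions below are illustrative assumptions, not values from the PR):

```cpp
#include <cstdint>
#include <cstdio>

// Illustrative assumption: FP16 K cache with head_dim = 128 and 8 KV heads.
// The between-sequence K/V stride grows linearly with the tokens stored for
// one sequence, so it is the first offset to overflow a 32-bit integer.
int main() {
    const int64_t head_dim   = 128;
    const int64_t n_kv_heads = 8;
    const int64_t elem_size  = 2;                                  // bytes per FP16 value
    const int64_t bytes_per_token = head_dim*n_kv_heads*elem_size; // 2048 bytes

    // Roughly one million tokens: beyond this a 32-bit byte offset wraps around.
    printf("int32_t K-stride overflows after ~%lld tokens\n",
           (long long) (INT32_MAX/bytes_per_token));
    return 0;
}
```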

@JohannesGaessler JohannesGaessler force-pushed the cuda-fa-fix-overflow-2 branch from ec05b08 to d4209ee Compare July 23, 2025 18:45
@JohannesGaessler JohannesGaessler merged commit a86f52b into ggml-org:master Jul 23, 2025
47 checks passed
taronaeo pushed a commit to taronaeo/llama.cpp-s390x that referenced this pull request Jul 25, 2025
gabe-l-hart added a commit to gabe-l-hart/llama.cpp that referenced this pull request Jul 25, 2025
* origin/master:
docs : update HOWTO-add-model.md for ModelBase and new model classes (ggml-org#14874)
ggml : remove invalid portPos specifiers from dot files (ggml-org#14838)
context : restore preemptive sched reset when LLAMA_SET_ROWS=0 (ggml-org#14870)
mtmd : fix 32-bit narrowing issue in export-lora and mtmd clip (ggml-org#14503)
rpc : check for null buffers in get/set/copy tensor endpoints (ggml-org#14868)
sched : fix multiple evaluations of the same graph with pipeline parallelism (ggml-org#14855)
musa: upgrade musa sdk to rc4.2.0 (ggml-org#14498)
sync : ggml
cmake : fix usage issues (ggml/1257)
ggml-cpu : remove stdlib include from repack.cpp (ggml/1276)
context : perform output reorder lazily upon access after sync (ggml-org#14853)
chat : fix kimi-k2 chat template (ggml-org#14852)
sycl: fixed semantics of block offset calculation (ggml-org#14814)
llama : fix MiniCPM inference after Granite Four changes (ggml-org#14850)
docs: add libcurl-dev install hint for Linux distros (ggml-org#14801)
metal : fix fusion across different encoders (ggml-org#14849)
sycl: fix undefined variable in work group size check (ggml-org#14843)
convert : text-only support for GLM-4.1V-9B-Thinking (ggml-org#14823)
CUDA: fix overflow in FA, tune performance (ggml-org#14840)
CUDA: fix compilation with GGML_CUDA_F16 (ggml-org#14837)
@he29-net

he29-net commented Aug 1, 2025

Hi,

Not sure if this is expected, but I'm seeing a regression in PP performance with FA enabled on my RX 6800:

  Device 0: AMD Radeon RX 6800, gfx1030 (0x1030), VMM: no, Wave Size: 32
| model                          |       size |     params | backend    | ngl | n_batch | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ------: | -: | --------------: | -------------------: |
| qwen3moe 30B.A3B Q3_K - Small  |  12.37 GiB |    30.53 B | ROCm       |  99 |     512 |  0 |           pp512 |        892.16 ± 2.68 |
| qwen3moe 30B.A3B Q3_K - Small  |  12.37 GiB |    30.53 B | ROCm       |  99 |     512 |  0 |           tg128 |         59.76 ± 0.05 |
| qwen3moe 30B.A3B Q3_K - Small  |  12.37 GiB |    30.53 B | ROCm       |  99 |     512 |  1 |           pp512 |        595.24 ± 2.19 |
| qwen3moe 30B.A3B Q3_K - Small  |  12.37 GiB |    30.53 B | ROCm       |  99 |     512 |  1 |           tg128 |         56.65 ± 0.03 |

build: a86f52b2 (5973)

vs. previous commit:

  Device 0: AMD Radeon RX 6800, gfx1030 (0x1030), VMM: no, Wave Size: 32
| model                          |       size |     params | backend    | ngl | n_batch | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ------: | -: | --------------: | -------------------: |
| qwen3moe 30B.A3B Q3_K - Small  |  12.37 GiB |    30.53 B | ROCm       |  99 |     512 |  0 |           pp512 |        884.87 ± 3.40 |
| qwen3moe 30B.A3B Q3_K - Small  |  12.37 GiB |    30.53 B | ROCm       |  99 |     512 |  0 |           tg128 |         59.86 ± 0.02 |
| qwen3moe 30B.A3B Q3_K - Small  |  12.37 GiB |    30.53 B | ROCm       |  99 |     512 |  1 |           pp512 |        820.55 ± 4.46 |
| qwen3moe 30B.A3B Q3_K - Small  |  12.37 GiB |    30.53 B | ROCm       |  99 |     512 |  1 |           tg128 |         53.97 ± 0.02 |

build: b284197d (5972)

I also found another, unrelated regression, so keep in mind it may be influencing the benchmark results at the same time (see #14624).

@JohannesGaessler
Collaborator Author

In my testing the new build is consistently faster:

| Model | FlashAttention | Test | t/s b5972 | t/s b5973 | Speedup |
| --- | --- | --- | ---: | ---: | ---: |
| llama 8B Q4_0 | No | pp512 | 784.26 | 785.00 | 1.00 |
| llama 8B Q4_0 | No | tg128 | 62.24 | 61.93 | 0.99 |
| llama 8B Q4_0 | Yes | pp512 | 530.82 | 604.01 | 1.14 |
| llama 8B Q4_0 | Yes | tg128 | 56.86 | 59.19 | 1.04 |
| qwen3moe 30B.A3B Q3_K_S | No | pp512 | 515.26 | 520.13 | 1.01 |
| qwen3moe 30B.A3B Q3_K_S | No | tg128 | 40.12 | 39.88 | 0.99 |
| qwen3moe 30B.A3B Q3_K_S | Yes | pp512 | 350.03 | 391.21 | 1.12 |
| qwen3moe 30B.A3B Q3_K_S | Yes | tg128 | 36.17 | 37.59 | 1.04 |

@he29-net

he29-net commented Aug 1, 2025

Huh, interesting. I wonder if the ROCm version could explain the difference; I haven't updated in a while (the install directory says 6.0.0, though I'm not sure if that's the specific release or just 6.x.x). I'll try to update and see if it changes anything.

Great job on the other regression, btw!

@he29-net

he29-net commented Aug 1, 2025

After upgrading to ROCm 6.4.2, the PP speed with FA is now comparable between the two commits. It's still lower than ROCm 6.0.0 + b5972 (~620 t/s vs. ~820), but I suppose that means the regression isn't directly related to this PR, but rather to some version-specific "something" that just happened to be triggered by it on the older ROCm version.

  Device 0: AMD Radeon RX 6800, gfx1030 (0x1030), VMM: no, Wave Size: 32
| model                          |       size |     params | backend    | ngl | n_batch | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ------: | -: | --------------: | -------------------: |
| qwen3moe 30B.A3B Q3_K - Small  |  12.37 GiB |    30.53 B | ROCm       |  99 |     512 |  0 |           pp512 |        909.08 ± 3.47 |
| qwen3moe 30B.A3B Q3_K - Small  |  12.37 GiB |    30.53 B | ROCm       |  99 |     512 |  0 |           tg128 |         61.90 ± 0.06 |
| qwen3moe 30B.A3B Q3_K - Small  |  12.37 GiB |    30.53 B | ROCm       |  99 |     512 |  1 |           pp512 |        621.87 ± 2.23 |
| qwen3moe 30B.A3B Q3_K - Small  |  12.37 GiB |    30.53 B | ROCm       |  99 |     512 |  1 |           tg128 |         55.63 ± 0.03 |

build: b284197d (5972)

vs.

  Device 0: AMD Radeon RX 6800, gfx1030 (0x1030), VMM: no, Wave Size: 32
| model                          |       size |     params | backend    | ngl | n_batch | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ------: | -: | --------------: | -------------------: |
| qwen3moe 30B.A3B Q3_K - Small  |  12.37 GiB |    30.53 B | ROCm       |  99 |     512 |  0 |           pp512 |        903.51 ± 2.56 |
| qwen3moe 30B.A3B Q3_K - Small  |  12.37 GiB |    30.53 B | ROCm       |  99 |     512 |  0 |           tg128 |         61.69 ± 0.07 |
| qwen3moe 30B.A3B Q3_K - Small  |  12.37 GiB |    30.53 B | ROCm       |  99 |     512 |  1 |           pp512 |        618.86 ± 1.96 |
| qwen3moe 30B.A3B Q3_K - Small  |  12.37 GiB |    30.53 B | ROCm       |  99 |     512 |  1 |           tg128 |         58.37 ± 0.05 |

build: a86f52b2 (5973)
