Skip to content

Commit 0e1f838

Browse files
authored
vulkan: Fix FA coopmat1 invalid array indexing (#16365)
When computing sinks, the cm1 shader was looping r from 0 to Br rather than to rows_per_thread. I must have copied this from the scalar path (where it is correct), and somehow it wasn't causing failures on current drivers.
1 parent ad12647 commit 0e1f838

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -358,8 +358,8 @@ void main() {
358358
}
359359

360360
if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) {
361-
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
362-
float sink = perElemOpGetSink(r, 0u, ACC_TYPE(0), iq2);
361+
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
362+
float sink = perElemOpGetSink(tile_row(r), 0u, ACC_TYPE(0), iq2);
363363

364364
float ms = 1.0f;
365365
float vs = 1.0f;

0 commit comments

Comments
 (0)