
Commit 275ec3d

leafs1 authored and facebook-github-bot committed
Flash Attention Buffer Compute Shader for Vulkan Backend Delegate (#12654)
Summary: Built a flash attention compute shader for the Vulkan backend delegate. The current implementation only supports buffer storage and is not fully optimized, but it is functional. This shader should speed up the SDPA step in the attention block during transformer inference, since the previous implementation relied on many separate I/O operations. The implementation includes proper multi-query attention support for models like LLaMA, uses tiled block processing to reduce memory usage, and replaces multiple separate operations (matmul, softmax, masking) with a single efficient compute shader.

Reviewed By: SS-JIA

Differential Revision: D78586517
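For reference, the per-block update the shader performs (STEP 11-12 in the shader below) is the standard FlashAttention online-softmax merge. With running row max $m_i$ and row sum $l_i$, block-local row max $\tilde m_{ij}$, $\tilde P_{ij} = \exp(S_{ij} - \tilde m_{ij})$, and $\tilde l_{ij} = \mathrm{rowsum}(\tilde P_{ij})$, each column block $j$ is folded into the running output $O_i$ as:

$$
m_i^{\text{new}} = \max(m_i, \tilde m_{ij}), \qquad
l_i^{\text{new}} = e^{m_i - m_i^{\text{new}}}\, l_i + e^{\tilde m_{ij} - m_i^{\text{new}}}\, \tilde l_{ij},
$$

$$
O_i \leftarrow \frac{e^{m_i - m_i^{\text{new}}}\, l_i\, O_i + e^{\tilde m_{ij} - m_i^{\text{new}}}\, \tilde P_{ij} V_j}{l_i^{\text{new}}}.
$$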
1 parent 9236a68 commit 275ec3d

File tree

4 files changed (+835, -12 lines)

Lines changed: 226 additions & 0 deletions
@@ -0,0 +1,226 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#version 450 core

#define PRECISION ${PRECISION}
#define T ${buffer_scalar_type(DTYPE)}
${define_required_extensions(DTYPE)}

layout(std430) buffer;

#include "indexing_utils.h"

// Flash Attention inputs: Query, Key, Value tensors
${layout_declare_tensor(B, "rw", "t_O", DTYPE, "buffer")}
${layout_declare_tensor(B, "rw", "t_l", "float", "buffer")}
${layout_declare_tensor(B, "rw", "t_m", "float", "buffer")}
${layout_declare_tensor(B, "r", "t_Q", DTYPE, "buffer")}
${layout_declare_tensor(B, "r", "t_K", DTYPE, "buffer")}
${layout_declare_tensor(B, "r", "t_V", DTYPE, "buffer")}

${layout_declare_ubo(B, "ivec4", "Q_sizes")} // [B, H, N, D]
${layout_declare_ubo(B, "ivec4", "K_sizes")}
${layout_declare_ubo(B, "ivec4", "V_sizes")}
${layout_declare_ubo(B, "ivec4", "O_sizes")}

${layout_declare_ubo(B, "ivec3", "l_sizes")} // [B, H, N]
${layout_declare_ubo(B, "ivec3", "m_sizes")} // [B, H, N]

${layout_declare_ubo(B, "float", "scale")}
${layout_declare_ubo(B, "int", "block_size_r")} // Br (num rows in Q block)
${layout_declare_ubo(B, "int", "block_size_c")} // Bc (num cols in K/V block)
${layout_declare_ubo(B, "int", "input_pos")} // Starting position for causal masking
${layout_declare_ubo(B, "int", "num_heads")} // Number of query heads
${layout_declare_ubo(B, "int", "num_kv_heads")} // Number of key/value heads
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

// Maximum block sizes to prevent array overflow
#define MAX_BR 64
#define MAX_BC 128

void main() {
  // Each thread processes one row block
  const int thread_id = int(gl_GlobalInvocationID.x);

  // Tensor dimensions: Q_sizes = [D, H, N, B] from graph.sizes_ubo()
  // The UBO layout is different from the PyTorch tensor layout
  const int head_dim = Q_sizes.x;   // D (head dim)
  const int num_heads = Q_sizes.y;  // H (num heads)
  const int seq_len = Q_sizes.z;    // N (sequence length)
  const int batch_size = Q_sizes.w; // B (batch)

  // Block sizes
  const int Br = block_size_r;
  const int Bc = block_size_c;

  const int Tr = (seq_len + Br - 1) / Br; // Number of row blocks
  const int total_row_blocks = batch_size * num_heads * Tr;

  if (thread_id >= total_row_blocks) {
    return;
  }

  // Decode thread_id to (batch, head, row_block)
  const int batch = thread_id / (num_heads * Tr);
  const int remaining = thread_id % (num_heads * Tr);
  const int head = remaining / Tr;
  const int row_block = remaining % Tr;

  // Calculate row range for this block
  const int row_start = row_block * Br;
  const int row_end = min(row_start + Br, seq_len);
  const int actual_Br = row_end - row_start;

  // Base indices for this batch
  const int q_base = batch * (seq_len * num_heads * head_dim);
  const int k_base = batch * (seq_len * num_heads * head_dim);
  const int v_base = batch * (seq_len * num_heads * head_dim);
  const int o_base = batch * (seq_len * num_heads * head_dim);
  const int lm_base = batch * (seq_len * num_heads);

  // STEP 2: Initialize O = 0, l = 0, m = -inf for this row block
  for (int r = 0; r < actual_Br; r++) {
    const int seq_pos = row_start + r;
    const int lm_idx = lm_base + head * seq_len + seq_pos;

    t_l[lm_idx] = 0.0;
    t_m[lm_idx] = -1.0 / 0.0; // -infinity

    for (int dim = 0; dim < head_dim; dim++) {
      const int o_idx = o_base + seq_pos * (num_heads * head_dim) + head * head_dim + dim;
      t_O[o_idx] = T(0.0);
    }
  }

  // STEP 5: Outer loop over column blocks (For K, V tensors)
  const int Tc = (seq_len + Bc - 1) / Bc; // Number of column blocks
  for (int j = 0; j < Tc; j++) {
    const int col_start = j * Bc;
    const int col_end = min(col_start + Bc, seq_len);
    const int actual_Bc = col_end - col_start;

    // STEP 6-8 done implicitly below

    // Load current statistics for all rows in this block
    float m_i[MAX_BR];
    float l_i[MAX_BR];
    for (int r = 0; r < actual_Br; r++) {
      const int seq_pos = row_start + r;
      const int lm_idx = lm_base + head * seq_len + seq_pos;
      m_i[r] = t_m[lm_idx];
      l_i[r] = t_l[lm_idx];
    }

    // STEP 9: Compute Sij = Qi * Kj^T
    T S_block[MAX_BR][MAX_BC]; // Use MAX_BR and MAX_BC constants
    float m_tilde_ij[MAX_BR];  // Row maxes (float to match l/m)
    float l_tilde_ij[MAX_BR];  // Row sums (float to match l/m)

    // Initialize row statistics
    for (int r = 0; r < actual_Br; r++) {
      m_tilde_ij[r] = -1.0 / 0.0; // -infinity
      l_tilde_ij[r] = 0.0;
    }

    // Compute attention scores Sij = Qi @ Kj^T
    for (int r = 0; r < actual_Br; r++) {
      const int global_row = row_start + r;
      for (int c = 0; c < actual_Bc; c++) {
        const int global_col = col_start + c;

        // For multi-query attention: map query head to KV head
        const int kv_head = (head * num_kv_heads) / num_heads;

        // Dot product: Q[seq_pos, :] · K[col_pos, :]
        T score = T(0.0);
        for (int dim = 0; dim < head_dim; dim++) {
          const int q_idx = q_base + global_row * (num_heads * head_dim) + head * head_dim + dim;
          const int k_idx = k_base + global_col * (num_kv_heads * head_dim) + kv_head * head_dim + dim;
          score += t_Q[q_idx] * t_K[k_idx];
        }
        score *= scale;

        // Apply causal masking: mask if global_col > global_row + input_pos
        if (global_col > global_row + input_pos) {
          score = T(-1.0 / 0.0); // Set to negative infinity
        }

        S_block[r][c] = score;

        // Track row maximum (after masking)
        m_tilde_ij[r] = max(m_tilde_ij[r], float(score));
      }
    }

    // STEP 10: Compute P'ij = exp(Sij − m'ij) and l'ij = rowsum(P'ij)
    for (int r = 0; r < actual_Br; r++) {
      // Handle the case where all scores are -inf (fully masked row)
      if (isinf(m_tilde_ij[r]) && m_tilde_ij[r] < 0.0) {
        // All scores are -inf, so all probabilities are 0
        for (int c = 0; c < actual_Bc; c++) {
          S_block[r][c] = T(0.0);
        }
        l_tilde_ij[r] = 0.0;
      } else {
        // Normal case: compute softmax
        for (int c = 0; c < actual_Bc; c++) {
          S_block[r][c] = exp(S_block[r][c] - T(m_tilde_ij[r]));
          l_tilde_ij[r] += float(S_block[r][c]);
        }
      }
    }

    // STEP 11: Softmax update
    float m_new_i[MAX_BR];
    float l_new_i[MAX_BR];
    for (int r = 0; r < actual_Br; r++) {
      m_new_i[r] = max(m_i[r], m_tilde_ij[r]);

      l_new_i[r] = exp(m_i[r] - m_new_i[r]) * l_i[r] + exp(m_tilde_ij[r] - m_new_i[r]) * l_tilde_ij[r];
    }

    // STEP 12: Update Oi
    for (int r = 0; r < actual_Br; r++) {
      const int global_row = row_start + r;
      float alpha = exp(m_i[r] - m_new_i[r]);
      float beta = exp(m_tilde_ij[r] - m_new_i[r]);

      // For multi-query attention: map query head to KV head
      const int kv_head = (head * num_kv_heads) / num_heads;

      for (int dim = 0; dim < head_dim; dim++) {
        const int o_idx = o_base + global_row * (num_heads * head_dim) + head * head_dim + dim;

        // Compute P'ij @ Vj for this dimension
        T pv_sum = T(0.0);
        for (int c = 0; c < actual_Bc; c++) {
          const int global_col = col_start + c;
          const int v_idx = v_base + global_col * (num_kv_heads * head_dim) + kv_head * head_dim + dim;
          pv_sum += S_block[r][c] * t_V[v_idx];
        }

        // Check for division by zero before updating output
        if (l_new_i[r] <= 0.0) {
          t_O[o_idx] = T(0.0); // Set to zero to avoid NaN
        } else {
          // Oi = (alpha * l_i * Oi + beta * P'ij @ Vj) / l_new_i
          t_O[o_idx] = (T(alpha) * T(l_i[r]) * t_O[o_idx] + T(beta) * pv_sum) / T(l_new_i[r]);
        }
      }
    }

    // STEP 13: Update li, mi
    for (int r = 0; r < actual_Br; r++) {
      const int seq_pos = row_start + r;
      const int lm_idx = lm_base + head * seq_len + seq_pos;
      t_l[lm_idx] = l_new_i[r];
      t_m[lm_idx] = m_new_i[r];
    }
  }
}
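A minimal NumPy sketch of the same blocked recurrence, for a single (batch, head) slice, can be handy for checking the shader's arithmetic on the CPU. It is not part of this commit; the function name, the [seq_len, head_dim] layout, and the Br/Bc/input_pos defaults are assumptions of the sketch, and the query-head-to-KV-head mapping is omitted.

import numpy as np

# Hypothetical CPU reference (not part of this commit): mirrors the shader's
# blocked update for one (batch, head) slice with contiguous q, k, v arrays.
def flash_attention_reference(q, k, v, scale, Br=64, Bc=128, input_pos=0):
    seq_len, head_dim = q.shape
    O = np.zeros((seq_len, head_dim), dtype=np.float32)
    l = np.zeros(seq_len, dtype=np.float32)          # running row sums
    m = np.full(seq_len, -np.inf, dtype=np.float32)  # running row maxes

    for row_start in range(0, seq_len, Br):           # row blocks (Tr)
        rows = slice(row_start, min(row_start + Br, seq_len))
        r_idx = np.arange(rows.start, rows.stop)[:, None]
        for col_start in range(0, seq_len, Bc):       # column blocks (Tc)
            cols = slice(col_start, min(col_start + Bc, seq_len))
            c_idx = np.arange(cols.start, cols.stop)[None, :]

            # STEP 9: scaled scores for this tile, with causal masking
            S = (q[rows] @ k[cols].T) * scale
            S[c_idx > r_idx + input_pos] = -np.inf

            # STEP 10: tile-local softmax stats (fully masked rows stay all-zero)
            m_tilde = S.max(axis=1)
            P = np.exp(S - np.where(np.isinf(m_tilde), 0.0, m_tilde)[:, None])
            l_tilde = P.sum(axis=1)

            # STEP 11-13: merge tile statistics into the running ones
            m_new = np.maximum(m[rows], m_tilde)
            alpha = np.exp(m[rows] - m_new)
            beta = np.exp(m_tilde - m_new)
            l_new = alpha * l[rows] + beta * l_tilde
            denom = np.where(l_new > 0.0, l_new, 1.0)  # guard fully masked rows
            O[rows] = (alpha[:, None] * l[rows][:, None] * O[rows]
                       + beta[:, None] * (P @ v[cols])) / denom[:, None]
            m[rows], l[rows] = m_new, l_new
    return O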
Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
flash_attention:
  parameter_names_with_default_values:
    DTYPE: float
    STORAGE: buffer
  generate_variant_forall:
    DTYPE:
      - VALUE: float
  shader_variants:
    - NAME: flash_attention_buffer
      STORAGE: buffer
