Skip to content

Commit f2adda1

Browse files
leafs1facebook-github-bot
authored andcommitted
Flash Attention Buffer Compute Shader for Vulkan Backend Delegate
Summary: Adds a flash attention compute shader for the Vulkan backend delegate. The current implementation supports only buffer storage and is not yet fully optimized, but it is functional. The shader should speed up scaled dot-product attention (SDPA) in the attention block during transformer inference, because the previous implementation performed many I/O operations. The implementation includes proper multi-query attention support for models such as LLaMA, uses tiled block processing to reduce memory usage, and replaces multiple separate operations (matmul, softmax, masking) with a single efficient compute shader. Differential Revision: D78586517
1 parent 7e603f8 commit f2adda1

File tree

4 files changed

+824
-12
lines changed

4 files changed

+824
-12
lines changed
Lines changed: 227 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,227 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core

#define PRECISION ${PRECISION}
#define T ${buffer_scalar_type(DTYPE)}
${define_required_extensions(DTYPE)}

layout(std430) buffer;

#include "indexing_utils.h"

// Storage buffers. Binding slots are assigned in declaration order, so the
// order below must stay in sync with the host-side argument list.
//   t_O        attention output (written, then read back while accumulating)
//   t_Q/K/V    query / key / value inputs (read-only)
//   t_l, t_m   per-row softmax statistics kept in float regardless of DTYPE:
//              running denominator (l) and running row maximum (m)
${layout_declare_tensor(B, "rw", "t_O", DTYPE, "buffer")}
${layout_declare_tensor(B, "r", "t_Q", DTYPE, "buffer")}
${layout_declare_tensor(B, "r", "t_K", DTYPE, "buffer")}
${layout_declare_tensor(B, "r", "t_V", DTYPE, "buffer")}
${layout_declare_tensor(B, "rw", "t_l", "float", "buffer")}
${layout_declare_tensor(B, "rw", "t_m", "float", "buffer")}

// Tensor sizes. main() reads Q_sizes as x = D (head dim), y = H (heads),
// z = N (sequence length), w = B (batch), i.e. the reverse of the buffer's
// [B, N, H, D] element order.
// NOTE(review): K/V/O sizes are presumably delivered in the same component
// order -- confirm against the host code that fills these UBOs.
${layout_declare_ubo(B, "ivec4", "Q_sizes")}
${layout_declare_ubo(B, "ivec4", "K_sizes")}
${layout_declare_ubo(B, "ivec4", "V_sizes")}
${layout_declare_ubo(B, "ivec4", "O_sizes")}

// Sizes of the statistics buffers. main() indexes t_l / t_m as
// (batch * H + head) * N + row, i.e. a [B, H, N] element order.
${layout_declare_ubo(B, "ivec3", "l_sizes")}
${layout_declare_ubo(B, "ivec3", "m_sizes")}

${layout_declare_ubo(B, "float", "scale")} // Multiplier applied to every Q.K score
${layout_declare_ubo(B, "int", "block_size_r")} // Br (num rows in Q block)
${layout_declare_ubo(B, "int", "block_size_c")} // Bc (num cols in K/V block)
${layout_declare_ubo(B, "int", "input_pos")} // Starting position for causal masking
${layout_declare_ubo(B, "int", "num_heads")} // Number of query heads
${layout_declare_ubo(B, "int", "num_kv_heads")} // Number of key/value heads

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

// Compile-time capacities of the per-invocation scratch arrays used by
// main(); runtime block sizes must not index past these.
#define MAX_BR 64
#define MAX_BC 128
47+
48+
/*
 * Causal flash attention over buffer storage.
 *
 * Each invocation owns one (batch, query head, row block) triple, selected by
 * gl_GlobalInvocationID.x, and processes the rows of that block one at a
 * time. For every row it walks the visible keys in tiles of at most MAX_BC
 * columns, maintaining the online-softmax statistics (running max m, running
 * denominator l) and rescaling the partially accumulated output in t_O after
 * each tile. On exit t_O holds the normalised attention output and t_l / t_m
 * hold the final statistics for each processed row.
 *
 * Rows are independent of each other, so no per-block arrays of size Br are
 * needed; the only private scratch is a single tile of probabilities.
 */
void main() {
  const int thread_id = int(gl_GlobalInvocationID.x);

  // Q_sizes components arrive as (D, H, N, B).
  const int head_dim = Q_sizes.x;
  // Kept under a distinct name so it does not shadow the num_heads UBO.
  const int q_heads = Q_sizes.y;
  const int seq_len = Q_sizes.z;
  const int batch_size = Q_sizes.w;

  // Br only decides how rows are split across invocations; a non-positive
  // value would divide by zero below, so bail out instead.
  const int Br = block_size_r;
  if (Br <= 0) {
    return;
  }
  // Bc only controls internal tiling, so it can be clamped to the scratch
  // capacity without changing the result.
  const int Bc = clamp(block_size_c, 1, MAX_BC);

  const int Tr = (seq_len + Br - 1) / Br; // Number of row blocks
  const int total_row_blocks = batch_size * q_heads * Tr;
  if (thread_id >= total_row_blocks) {
    return;
  }

  // Decode thread_id -> (batch, head, row_block).
  const int blocks_per_batch = q_heads * Tr;
  const int batch = thread_id / blocks_per_batch;
  const int remaining = thread_id % blocks_per_batch;
  const int head = remaining / Tr;
  const int row_block = remaining % Tr;

  const int row_start = row_block * Br;
  const int row_end = min(row_start + Br, seq_len);

  // Grouped / multi-query attention: several query heads share one KV head.
  const int kv_head = (head * num_kv_heads) / q_heads;

  // Element strides between consecutive sequence positions.
  const int q_row_stride = q_heads * head_dim;
  const int kv_row_stride = num_kv_heads * head_dim;

  // NOTE(review): assumes K_sizes / V_sizes use the same (D, H, N, B)
  // component order as Q_sizes, so .z is the key/value sequence capacity.
  const int k_capacity = K_sizes.z;
  const int v_capacity = V_sizes.z;
  const int kv_capacity = min(k_capacity, v_capacity);

  // Base offsets of this (batch, head) slice. K and V use their own sequence
  // capacity and head count for the batch stride, matching kv_row_stride.
  const int q_base = batch * (seq_len * q_row_stride) + head * head_dim;
  const int o_base = q_base; // O shares Q's layout
  const int k_base = batch * (k_capacity * kv_row_stride) + kv_head * head_dim;
  const int v_base = batch * (v_capacity * kv_row_stride) + kv_head * head_dim;
  const int lm_base = batch * (seq_len * q_heads) + head * seq_len;

  const float neg_inf = -1.0 / 0.0;

  // Scratch for one tile: raw scores first, then exp(score - tile max).
  float p_tile[MAX_BC];

  for (int row = row_start; row < row_end; row++) {
    const int q_row = q_base + row * q_row_stride;
    const int o_row = o_base + row * q_row_stride;

    // STEP 2: O = 0, l = 0, m = -inf for this row.
    for (int d = 0; d < head_dim; d++) {
      t_O[o_row + d] = T(0.0);
    }
    float m_run = neg_inf;
    float l_run = 0.0;

    // Causal masking: query row `row` sits at absolute position
    // row + input_pos and may attend to keys 0 .. row + input_pos. Bounding
    // the loop (instead of writing -inf scores) means masked keys are never
    // read, and keys beyond Q's own length are reached when input_pos > 0.
    const int visible_cols = min(row + input_pos + 1, kv_capacity);

    // STEP 5: walk the visible keys tile by tile.
    for (int col_start = 0; col_start < visible_cols; col_start += Bc) {
      const int tile_cols = min(Bc, visible_cols - col_start);

      // STEP 9: S = scale * (q . k) for every key in the tile, tracking the
      // tile maximum.
      float m_tile = neg_inf;
      for (int c = 0; c < tile_cols; c++) {
        const int k_row = k_base + (col_start + c) * kv_row_stride;
        float score = 0.0;
        for (int d = 0; d < head_dim; d++) {
          score += float(t_Q[q_row + d]) * float(t_K[k_row + d]);
        }
        score *= scale;
        p_tile[c] = score;
        m_tile = max(m_tile, score);
      }

      // A tile whose every score is -inf contributes nothing; skipping it
      // also avoids exp(-inf - (-inf)) = NaN in the update below.
      if (isinf(m_tile) && m_tile < 0.0) {
        continue;
      }

      // STEP 10: P = exp(S - m_tile), l_tile = rowsum(P).
      float l_tile = 0.0;
      for (int c = 0; c < tile_cols; c++) {
        p_tile[c] = exp(p_tile[c] - m_tile);
        l_tile += p_tile[c];
      }

      // STEP 11: merge the tile statistics into the running ones.
      const float m_new = max(m_run, m_tile);
      const float alpha = exp(m_run - m_new); // rescales what is already in O
      const float beta = exp(m_tile - m_new); // rescales this tile
      const float l_new = alpha * l_run + beta * l_tile;

      // STEP 12: O = (alpha * l * O + beta * P @ V) / l_new.
      for (int d = 0; d < head_dim; d++) {
        float pv_sum = 0.0;
        for (int c = 0; c < tile_cols; c++) {
          const int v_idx = v_base + (col_start + c) * kv_row_stride + d;
          pv_sum += p_tile[c] * float(t_V[v_idx]);
        }

        if (l_new <= 0.0) {
          t_O[o_row + d] = T(0.0); // Avoid dividing by zero
        } else {
          t_O[o_row + d] =
              T((alpha * l_run * float(t_O[o_row + d]) + beta * pv_sum) / l_new);
        }
      }

      // STEP 13: carry the statistics to the next tile.
      m_run = m_new;
      l_run = l_new;
    }

    // Publish the final statistics for this row.
    t_l[lm_base + row] = l_run;
    t_m[lm_base + row] = m_run;
  }
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# Shader codegen configuration for the flash attention template.
flash_attention:
  # Default values substituted into the template parameters.
  parameter_names_with_default_values:
    DTYPE: float
    STORAGE: buffer
  # One variant is generated per listed DTYPE value.
  generate_variant_forall:
    DTYPE:
      - VALUE: float
  # Named shader variants; only buffer storage is provided.
  shader_variants:
    - NAME: flash_attention_buffer
      STORAGE: buffer

0 commit comments

Comments
 (0)