|
| 1 | +/* |
| 2 | + * Copyright (c) Meta Platforms, Inc. and affiliates. |
| 3 | + * All rights reserved. |
| 4 | + * |
| 5 | + * This source code is licensed under the BSD-style license found in the |
| 6 | + * LICENSE file in the root directory of this source tree. |
| 7 | + */ |
| 8 | + |
| 9 | +#version 450 core |
| 10 | + |
| 11 | +#define PRECISION ${PRECISION} |
| 12 | +#define T ${buffer_scalar_type(DTYPE)} |
| 13 | +${define_required_extensions(DTYPE)} |
| 14 | + |
| 15 | +layout(std430) buffer; |
| 16 | + |
| 17 | +#include "indexing_utils.h" |
| 18 | + |
| 19 | +// Flash Attention inputs: Query, Key, Value tensors |
| 20 | +${layout_declare_tensor(B, "rw", "t_O", DTYPE, "buffer")} |
| 21 | +${layout_declare_tensor(B, "rw", "t_l", "float", "buffer")} |
| 22 | +${layout_declare_tensor(B, "rw", "t_m", "float", "buffer")} |
| 23 | +${layout_declare_tensor(B, "r", "t_Q", DTYPE, "buffer")} |
| 24 | +${layout_declare_tensor(B, "r", "t_K", DTYPE, "buffer")} |
| 25 | +${layout_declare_tensor(B, "r", "t_V", DTYPE, "buffer")} |
| 26 | + |
| 27 | +${layout_declare_ubo(B, "ivec4", "Q_sizes")} // [B, H, N, D] |
| 28 | +${layout_declare_ubo(B, "ivec4", "K_sizes")} |
| 29 | +${layout_declare_ubo(B, "ivec4", "V_sizes")} |
| 30 | +${layout_declare_ubo(B, "ivec4", "O_sizes")} |
| 31 | + |
| 32 | +${layout_declare_ubo(B, "ivec3", "l_sizes")} // [B, H, N] |
| 33 | +${layout_declare_ubo(B, "ivec3", "m_sizes")} // [B, H, N] |
| 34 | + |
| 35 | +${layout_declare_ubo(B, "float", "scale")} |
| 36 | +${layout_declare_ubo(B, "int", "block_size_r")} // Br (num rows in Q block) |
| 37 | +${layout_declare_ubo(B, "int", "block_size_c")} // Bc (num cols in K/V block) |
| 38 | +${layout_declare_ubo(B, "int", "input_pos")} // Starting position for causal masking |
| 39 | +${layout_declare_ubo(B, "int", "num_heads")} // Number of query heads |
| 40 | +${layout_declare_ubo(B, "int", "num_kv_heads")} // Number of key/value heads |
| 41 | +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; |
| 42 | + |
| 43 | +// Maximum block sizes to prevent array overflow |
| 44 | +#define MAX_BR 64 |
| 45 | +#define MAX_BC 128 |
| 46 | + |
| 47 | +void main() { |
| 48 | + // Each thread processes one row block |
| 49 | + const int thread_id = int(gl_GlobalInvocationID.x); |
| 50 | + |
| 51 | + // Tensor dimensions: Q_sizes = [D, H, N, B] from graph.sizes_ubo() |
| 52 | + // The UBO layout is different from the PyTorch tensor layout |
| 53 | + const int head_dim = Q_sizes.x; // D (head dim) |
| 54 | + const int num_heads = Q_sizes.y; // H (num heads) |
| 55 | + const int seq_len = Q_sizes.z; // N (sequence length) |
| 56 | + const int batch_size = Q_sizes.w; // B (batch) |
| 57 | + |
| 58 | + // Block sizes |
| 59 | + const int Br = block_size_r; |
| 60 | + const int Bc = block_size_c; |
| 61 | + |
| 62 | + const int Tr = (seq_len + Br - 1) / Br; // Number of row blocks |
| 63 | + const int total_row_blocks = batch_size * num_heads * Tr; |
| 64 | + |
| 65 | + if (thread_id >= total_row_blocks) { |
| 66 | + return; |
| 67 | + } |
| 68 | + |
| 69 | + // Decode thread_id to (batch, head, row_block) |
| 70 | + const int batch = thread_id / (num_heads * Tr); |
| 71 | + const int remaining = thread_id % (num_heads * Tr); |
| 72 | + const int head = remaining / Tr; |
| 73 | + const int row_block = remaining % Tr; |
| 74 | + |
| 75 | + // Calculate row range for this block |
| 76 | + const int row_start = row_block * Br; |
| 77 | + const int row_end = min(row_start + Br, seq_len); |
| 78 | + const int actual_Br = row_end - row_start; |
| 79 | + |
| 80 | + // Base indices for this batch |
| 81 | + const int q_base = batch * (seq_len * num_heads * head_dim); |
| 82 | + const int k_base = batch * (seq_len * num_heads * head_dim); |
| 83 | + const int v_base = batch * (seq_len * num_heads * head_dim); |
| 84 | + const int o_base = batch * (seq_len * num_heads * head_dim); |
| 85 | + const int lm_base = batch * (seq_len * num_heads); |
| 86 | + |
| 87 | + // STEP 2: Initialize O = 0, l = 0, m = -inf for this row block |
| 88 | + for (int r = 0; r < actual_Br; r++) { |
| 89 | + const int seq_pos = row_start + r; |
| 90 | + const int lm_idx = lm_base + head * seq_len + seq_pos; |
| 91 | + |
| 92 | + t_l[lm_idx] = 0.0; |
| 93 | + t_m[lm_idx] = -1.0 / 0.0; // -infinity |
| 94 | + |
| 95 | + for (int dim = 0; dim < head_dim; dim++) { |
| 96 | + const int o_idx = o_base + seq_pos * (num_heads * head_dim) + head * head_dim + dim; |
| 97 | + t_O[o_idx] = T(0.0); |
| 98 | + } |
| 99 | + } |
| 100 | + |
| 101 | + // STEP 5: Outer loop over column blocks (For K, V tensors) |
| 102 | + const int Tc = (seq_len + Bc - 1) / Bc; // Number of column blocks |
| 103 | + for (int j = 0; j < Tc; j++) { |
| 104 | + const int col_start = j * Bc; |
| 105 | + const int col_end = min(col_start + Bc, seq_len); |
| 106 | + const int actual_Bc = col_end - col_start; |
| 107 | + |
| 108 | + // STEP 6-8 done implicitly below |
| 109 | + |
| 110 | + // Load current statistics for all rows in this block |
| 111 | + float m_i[MAX_BR]; |
| 112 | + float l_i[MAX_BR]; |
| 113 | + for (int r = 0; r < actual_Br; r++) { |
| 114 | + const int seq_pos = row_start + r; |
| 115 | + const int lm_idx = lm_base + head * seq_len + seq_pos; |
| 116 | + m_i[r] = t_m[lm_idx]; |
| 117 | + l_i[r] = t_l[lm_idx]; |
| 118 | + } |
| 119 | + |
| 120 | + // STEP 9: Compute Sij = Qi * Kj^T |
| 121 | + T S_block[MAX_BR][MAX_BC]; // Use MAX_BR and MAX_BC constants |
| 122 | + float m_tilde_ij[MAX_BR]; // Row maxes (float to match l/m) |
| 123 | + float l_tilde_ij[MAX_BR]; // Row sums (float to match l/m) |
| 124 | + |
| 125 | + // Initialize row statistics |
| 126 | + for (int r = 0; r < actual_Br; r++) { |
| 127 | + m_tilde_ij[r] = -1.0 / 0.0; // -infinity |
| 128 | + l_tilde_ij[r] = 0.0; |
| 129 | + } |
| 130 | + |
| 131 | + // Compute attention scores Sij = Qi @ Kj^T |
| 132 | + for (int r = 0; r < actual_Br; r++) { |
| 133 | + const int global_row = row_start + r; |
| 134 | + for (int c = 0; c < actual_Bc; c++) { |
| 135 | + const int global_col = col_start + c; |
| 136 | + |
| 137 | + // For multi-query attention: map query head to KV head |
| 138 | + const int kv_head = (head * num_kv_heads) / num_heads; |
| 139 | + |
| 140 | + // Dot product: Q[seq_pos, :] · K[col_pos, :] |
| 141 | + T score = T(0.0); |
| 142 | + for (int dim = 0; dim < head_dim; dim++) { |
| 143 | + const int q_idx = q_base + global_row * (num_heads * head_dim) + head * head_dim + dim; |
| 144 | + const int k_idx = k_base + global_col * (num_kv_heads * head_dim) + kv_head * head_dim + dim; |
| 145 | + score += t_Q[q_idx] * t_K[k_idx]; |
| 146 | + } |
| 147 | + score *= scale; |
| 148 | + |
| 149 | + // Apply causal masking: mask if global_col > global_row + input_pos |
| 150 | + if (global_col > global_row + input_pos) { |
| 151 | + score = T(-1.0 / 0.0); // Set to negative infinity |
| 152 | + } |
| 153 | + |
| 154 | + S_block[r][c] = score; |
| 155 | + |
| 156 | + // Track row maximum (after masking) |
| 157 | + m_tilde_ij[r] = max(m_tilde_ij[r], float(score)); |
| 158 | + } |
| 159 | + } |
| 160 | + |
| 161 | + // STEP 10: Compute P'ij = exp(Sij − m'ij) and l'ij = rowsum(P'ij) |
| 162 | + for (int r = 0; r < actual_Br; r++) { |
| 163 | + // Handle the case where all scores are -inf (fully masked row) |
| 164 | + if (isinf(m_tilde_ij[r]) && m_tilde_ij[r] < 0.0) { |
| 165 | + // All scores are -inf, so all probabilities are 0 |
| 166 | + for (int c = 0; c < actual_Bc; c++) { |
| 167 | + S_block[r][c] = T(0.0); |
| 168 | + } |
| 169 | + l_tilde_ij[r] = 0.0; |
| 170 | + } else { |
| 171 | + // Normal case: compute softmax |
| 172 | + for (int c = 0; c < actual_Bc; c++) { |
| 173 | + S_block[r][c] = exp(S_block[r][c] - T(m_tilde_ij[r])); |
| 174 | + l_tilde_ij[r] += float(S_block[r][c]); |
| 175 | + } |
| 176 | + } |
| 177 | + } |
| 178 | + |
| 179 | + // STEP 11: Softmax update |
| 180 | + float m_new_i[MAX_BR]; |
| 181 | + float l_new_i[MAX_BR]; |
| 182 | + for (int r = 0; r < actual_Br; r++) { |
| 183 | + m_new_i[r] = max(m_i[r], m_tilde_ij[r]); |
| 184 | + |
| 185 | + l_new_i[r] = exp(m_i[r] - m_new_i[r]) * l_i[r] + exp(m_tilde_ij[r] - m_new_i[r]) * l_tilde_ij[r]; |
| 186 | + } |
| 187 | + |
| 188 | + // STEP 12: Update Oi |
| 189 | + for (int r = 0; r < actual_Br; r++) { |
| 190 | + const int global_row = row_start + r; |
| 191 | + float alpha = exp(m_i[r] - m_new_i[r]); |
| 192 | + float beta = exp(m_tilde_ij[r] - m_new_i[r]); |
| 193 | + |
| 194 | + // For multi-query attention: map query head to KV head |
| 195 | + const int kv_head = (head * num_kv_heads) / num_heads; |
| 196 | + |
| 197 | + for (int dim = 0; dim < head_dim; dim++) { |
| 198 | + const int o_idx = o_base + global_row * (num_heads * head_dim) + head * head_dim + dim; |
| 199 | + |
| 200 | + // Compute P'ij @ Vj for this dimension |
| 201 | + T pv_sum = T(0.0); |
| 202 | + for (int c = 0; c < actual_Bc; c++) { |
| 203 | + const int global_col = col_start + c; |
| 204 | + const int v_idx = v_base + global_col * (num_kv_heads * head_dim) + kv_head * head_dim + dim; |
| 205 | + pv_sum += S_block[r][c] * t_V[v_idx]; |
| 206 | + } |
| 207 | + |
| 208 | + // Check for division by zero before updating output |
| 209 | + if (l_new_i[r] <= 0.0) { |
| 210 | + t_O[o_idx] = T(0.0); // Set to zero to avoid NaN |
| 211 | + } else { |
| 212 | + // Oi = (alpha * l_i * Oi + beta * P'ij @ Vj) / l_new_i |
| 213 | + t_O[o_idx] = (T(alpha) * T(l_i[r]) * t_O[o_idx] + T(beta) * pv_sum) / T(l_new_i[r]); |
| 214 | + } |
| 215 | + } |
| 216 | + } |
| 217 | + |
| 218 | + // STEP 13: Update li, mi |
| 219 | + for (int r = 0; r < actual_Br; r++) { |
| 220 | + const int seq_pos = row_start + r; |
| 221 | + const int lm_idx = lm_base + head * seq_len + seq_pos; |
| 222 | + t_l[lm_idx] = l_new_i[r]; |
| 223 | + t_m[lm_idx] = m_new_i[r]; |
| 224 | + } |
| 225 | + } |
| 226 | +} |
0 commit comments