diff --git a/backends/vulkan/runtime/graph/ops/glsl/flash_attention.glsl b/backends/vulkan/runtime/graph/ops/glsl/flash_attention.glsl
new file mode 100644
index 00000000000..1b5f47f3f3c
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/flash_attention.glsl
@@ -0,0 +1,226 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#version 450 core
+
+#define PRECISION ${PRECISION}
+#define T ${buffer_scalar_type(DTYPE)}
+${define_required_extensions(DTYPE)}
+
+layout(std430) buffer;
+
+#include "indexing_utils.h"
+
+// Flash Attention tensors: output O, softmax statistics l and m, and the
+// Query, Key, Value inputs
+${layout_declare_tensor(B, "rw", "t_O", DTYPE, "buffer")}
+${layout_declare_tensor(B, "rw", "t_l", "float", "buffer")}
+${layout_declare_tensor(B, "rw", "t_m", "float", "buffer")}
+${layout_declare_tensor(B, "r", "t_Q", DTYPE, "buffer")}
+${layout_declare_tensor(B, "r", "t_K", DTYPE, "buffer")}
+${layout_declare_tensor(B, "r", "t_V", DTYPE, "buffer")}
+
+${layout_declare_ubo(B, "ivec4", "Q_sizes")} // [D, H, N, B] (sizes_ubo order; PyTorch shape is [B, N, H, D])
+${layout_declare_ubo(B, "ivec4", "K_sizes")}
+${layout_declare_ubo(B, "ivec4", "V_sizes")}
+${layout_declare_ubo(B, "ivec4", "O_sizes")}
+
+${layout_declare_ubo(B, "ivec3", "l_sizes")} // [B, H, N]
+${layout_declare_ubo(B, "ivec3", "m_sizes")} // [B, H, N]
+
+${layout_declare_ubo(B, "float", "scale")}
+${layout_declare_ubo(B, "int", "block_size_r")} // Br (num rows in Q block)
+${layout_declare_ubo(B, "int", "block_size_c")} // Bc (num cols in K/V block)
+${layout_declare_ubo(B, "int", "input_pos")} // Starting position for causal masking
+${layout_declare_ubo(B, "int", "num_heads")} // Number of query heads
+${layout_declare_ubo(B, "int", "num_kv_heads")} // Number of key/value heads
+layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
+
+// Maximum block sizes to prevent array overflow
+#define MAX_BR 64
+#define MAX_BC 128
+
+void main() {
+  // Each thread processes one row block
+  const int thread_id = int(gl_GlobalInvocationID.x);
+
+  // Tensor dimensions: Q_sizes = [D, H, N, B] from graph.sizes_ubo()
+  // The UBO layout is different from the PyTorch tensor layout
+  const int head_dim = Q_sizes.x; // D (head dim)
+  const int num_heads = Q_sizes.y; // H (num heads)
+  const int seq_len = Q_sizes.z; // N (sequence length)
+  const int batch_size = Q_sizes.w; // B (batch)
+
+  // Block sizes
+  const int Br = block_size_r;
+  const int Bc = block_size_c;
+
+  const int Tr = (seq_len + Br - 1) / Br; // Number of row blocks
+  const int total_row_blocks = batch_size * num_heads * Tr;
+
+  if (thread_id >= total_row_blocks) {
+    return;
+  }
+
+  // Decode thread_id to (batch, head, row_block)
+  const int batch = thread_id / (num_heads * Tr);
+  const int remaining = thread_id % (num_heads * Tr);
+  const int head = remaining / Tr;
+  const int row_block = remaining % Tr;
+
+  // Calculate row range for this block
+  const int row_start = row_block * Br;
+  const int row_end = min(row_start + Br, seq_len);
+  const int actual_Br = row_end - row_start;
+
+  // Base indices for this batch (K/V have num_kv_heads heads, not num_heads)
+  const int q_base = batch * (seq_len * num_heads * head_dim);
+  const int k_base = batch * (seq_len * num_kv_heads * head_dim);
+  const int v_base = batch * (seq_len * num_kv_heads * head_dim);
+  const int o_base = batch * (seq_len * num_heads * head_dim);
+  const int lm_base = batch * (seq_len * num_heads);
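+
+  // Indexing convention (this matches the index arithmetic below): Q and O are
+  // treated as contiguous [B, N, H, D] buffers, so element (b, n, h, d) lives at
+  //   b * (N * H * D) + n * (H * D) + h * D + d.
+  // K and V use the same layout with num_kv_heads in place of num_heads.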
+
+  // STEP 2: Initialize O = 0, l = 0, m = -inf for this row block
+  for (int r = 0; r < actual_Br; r++) {
+    const int seq_pos = row_start + r;
+    const int lm_idx = lm_base + head * seq_len + seq_pos;
+
+    t_l[lm_idx] = 0.0;
+    t_m[lm_idx] = -1.0 / 0.0; // -infinity
+
+    for (int dim = 0; dim < head_dim; dim++) {
+      const int o_idx = o_base + seq_pos * (num_heads * head_dim) + head * head_dim + dim;
+      t_O[o_idx] = T(0.0);
+    }
+  }
+
+  // STEP 5: Outer loop over column blocks (for the K, V tensors)
+  const int Tc = (seq_len + Bc - 1) / Bc; // Number of column blocks
+  for (int j = 0; j < Tc; j++) {
+    const int col_start = j * Bc;
+    const int col_end = min(col_start + Bc, seq_len);
+    const int actual_Bc = col_end - col_start;
+
+    // STEPS 6-8 are done implicitly below
+
+    // Load current statistics for all rows in this block
+    float m_i[MAX_BR];
+    float l_i[MAX_BR];
+    for (int r = 0; r < actual_Br; r++) {
+      const int seq_pos = row_start + r;
+      const int lm_idx = lm_base + head * seq_len + seq_pos;
+      m_i[r] = t_m[lm_idx];
+      l_i[r] = t_l[lm_idx];
+    }
+
+    // STEP 9: Compute Sij = Qi * Kj^T
+    T S_block[MAX_BR][MAX_BC]; // Use MAX_BR and MAX_BC constants
+    float m_tilde_ij[MAX_BR]; // Row maxes (float to match l/m)
+    float l_tilde_ij[MAX_BR]; // Row sums (float to match l/m)
+
+    // Initialize row statistics
+    for (int r = 0; r < actual_Br; r++) {
+      m_tilde_ij[r] = -1.0 / 0.0; // -infinity
+      l_tilde_ij[r] = 0.0;
+    }
+
+    // Compute attention scores Sij = Qi @ Kj^T
+    for (int r = 0; r < actual_Br; r++) {
+      const int global_row = row_start + r;
+      for (int c = 0; c < actual_Bc; c++) {
+        const int global_col = col_start + c;
+
+        // For multi-query attention: map query head to KV head
+        const int kv_head = (head * num_kv_heads) / num_heads;
+
+        // Dot product: Q[seq_pos, :] · K[col_pos, :]
+        T score = T(0.0);
+        for (int dim = 0; dim < head_dim; dim++) {
+          const int q_idx = q_base + global_row * (num_heads * head_dim) + head * head_dim + dim;
+          const int k_idx = k_base + global_col * (num_kv_heads * head_dim) + kv_head * head_dim + dim;
+          score += t_Q[q_idx] * t_K[k_idx];
+        }
+        score *= scale;
+
+        // Apply causal masking: mask if global_col > global_row + input_pos
+        if (global_col > global_row + input_pos) {
+          score = T(-1.0 / 0.0); // Set to negative infinity
+        }
+
+        S_block[r][c] = score;
+
+        // Track row maximum (after masking)
+        m_tilde_ij[r] = max(m_tilde_ij[r], float(score));
+      }
+    }
+
+    // STEP 10: Compute P'ij = exp(Sij - m'ij) and l'ij = rowsum(P'ij)
+    for (int r = 0; r < actual_Br; r++) {
+      // Handle the case where all scores are -inf (fully masked row)
+      if (isinf(m_tilde_ij[r]) && m_tilde_ij[r] < 0.0) {
+        // All scores are -inf, so all probabilities are 0
+        for (int c = 0; c < actual_Bc; c++) {
+          S_block[r][c] = T(0.0);
+        }
+        l_tilde_ij[r] = 0.0;
+      } else {
+        // Normal case: compute softmax numerator and row sum
+        for (int c = 0; c < actual_Bc; c++) {
+          S_block[r][c] = exp(S_block[r][c] - T(m_tilde_ij[r]));
+          l_tilde_ij[r] += float(S_block[r][c]);
+        }
+      }
+    }
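+
+    // Online softmax rescaling (FlashAttention Algorithm 1): the new running
+    // max is m_new = max(m_i, m'_ij), and running sums accumulated against the
+    // old max are rescaled so the totals stay consistent:
+    //   l_new = exp(m_i - m_new) * l_i + exp(m'_ij - m_new) * l'_ij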
+
+    // STEP 11: Softmax update
+    float m_new_i[MAX_BR];
+    float l_new_i[MAX_BR];
+    for (int r = 0; r < actual_Br; r++) {
+      m_new_i[r] = max(m_i[r], m_tilde_ij[r]);
+
+      l_new_i[r] = exp(m_i[r] - m_new_i[r]) * l_i[r] +
+          exp(m_tilde_ij[r] - m_new_i[r]) * l_tilde_ij[r];
+    }
+
+    // STEP 12: Update Oi
+    for (int r = 0; r < actual_Br; r++) {
+      const int global_row = row_start + r;
+      float alpha = exp(m_i[r] - m_new_i[r]);
+      float beta = exp(m_tilde_ij[r] - m_new_i[r]);
+
+      // For multi-query attention: map query head to KV head
+      const int kv_head = (head * num_kv_heads) / num_heads;
+
+      for (int dim = 0; dim < head_dim; dim++) {
+        const int o_idx = o_base + global_row * (num_heads * head_dim) + head * head_dim + dim;
+
+        // Compute P'ij @ Vj for this dimension
+        T pv_sum = T(0.0);
+        for (int c = 0; c < actual_Bc; c++) {
+          const int global_col = col_start + c;
+          const int v_idx = v_base + global_col * (num_kv_heads * head_dim) + kv_head * head_dim + dim;
+          pv_sum += S_block[r][c] * t_V[v_idx];
+        }
+
+        // Check for division by zero before updating output
+        if (l_new_i[r] <= 0.0) {
+          t_O[o_idx] = T(0.0); // Set to zero to avoid NaN
+        } else {
+          // Oi = (alpha * l_i * Oi + beta * P'ij @ Vj) / l_new_i
+          t_O[o_idx] = (T(alpha) * T(l_i[r]) * t_O[o_idx] + T(beta) * pv_sum) / T(l_new_i[r]);
+        }
+      }
+    }
+
+    // STEP 13: Update li, mi
+    for (int r = 0; r < actual_Br; r++) {
+      const int seq_pos = row_start + r;
+      const int lm_idx = lm_base + head * seq_len + seq_pos;
+      t_l[lm_idx] = l_new_i[r];
+      t_m[lm_idx] = m_new_i[r];
+    }
+  }
+}
diff --git a/backends/vulkan/runtime/graph/ops/glsl/flash_attention.yaml b/backends/vulkan/runtime/graph/ops/glsl/flash_attention.yaml
new file mode 100644
index 00000000000..2314b701040
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/flash_attention.yaml
@@ -0,0 +1,10 @@
+flash_attention:
+  parameter_names_with_default_values:
+    DTYPE: float
+    STORAGE: buffer
+  generate_variant_forall:
+    DTYPE:
+      - VALUE: float
+  shader_variants:
+    - NAME: flash_attention_buffer
+      STORAGE: buffer
diff --git a/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp b/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp
index 5ac8077d95f..6057f1e183a 100644
--- a/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp
@@ -19,10 +19,178 @@
 #include
+#include
 #include
 
 namespace vkcompute {
 
+void resize_sdpa_out(
+    ComputeGraph* graph,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& extra_args) {
+  (void)args;
+
+  int arg_idx = 0;
+  const ValueRef q_projected = extra_args[arg_idx++];
+  const ValueRef out = extra_args[arg_idx++];
+  graph->get_tensor(out)->virtual_resize(graph->sizes_of(q_projected));
+}
+
+void resize_flash_attention_out(
+    ComputeGraph* graph,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& resize_args) {
+  (void)resize_args;
+
+  // The output tensor is the first tensor in the first ArgGroup
+  const ValueRef out = args.at(0).refs.at(0);
+  // The query tensor is the first tensor in the second ArgGroup
+  const ValueRef q_projected = args.at(1).refs.at(0);
+
+  // Resize output to match query dimensions
+  graph->get_tensor(out)->virtual_resize(graph->sizes_of(q_projected));
+}
+
+// Flash Attention implementation using a single compute shader
+utils::uvec3 flash_attention_global_wg_size(
+    ComputeGraph* graph,
+    const vkapi::ShaderInfo& shader,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& resize_args) {
+  (void)shader;
+  (void)args;
+
+  const ValueRef q_projected = resize_args.at(0);
+  const ValueRef block_size_r = resize_args.at(1);
+
+  // Get tensor dimensions - the PyTorch shape is [B, N, H, D], addressed here
+  // with negative indexing: -4=B, -3=N, -2=H, -1=D
+  const int32_t B = graph->size_at<int32_t>(-4, q_projected); // batch
+  const int32_t N = graph->size_at<int32_t>(-3, q_projected); // sequence length
+  const int32_t H = graph->size_at<int32_t>(-2, q_projected); // num heads
+  const int32_t Br =
+      static_cast<int32_t>(graph->extract_scalar<int64_t>(block_size_r));
+
+  // Calculate number of row blocks
+  const int32_t Tr = (N + Br - 1) / Br;
+
+  // Dispatch size: (B * H * Tr, 1, 1)
+  return {static_cast<uint32_t>(B * H * Tr), 1, 1};
+}
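+
+// For example, with B=1, H=32, N=3 (the llama3-style test shapes) and Br=32,
+// Tr = ceil(3 / 32) = 1, so the global workgroup size is (32, 1, 1): one
+// invocation per (batch, head, row block).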
+
+void flash_attention_impl(
+    ComputeGraph& graph,
+    const std::vector<ValueRef>& args) {
+  int arg_idx = 0;
+  const ValueRef q_projected = args[arg_idx++];
+  const ValueRef k_cache = args[arg_idx++];
+  const ValueRef v_cache = args[arg_idx++];
+  const ValueRef input_pos_symint = args[arg_idx++];
+  const ValueRef attn_mask = args[arg_idx++];
+  const ValueRef dropout_p = args[arg_idx++];
+  const ValueRef is_causal = args[arg_idx++];
+  const ValueRef scale = args[arg_idx++];
+
+  const ValueRef out = args[arg_idx++];
+
+  // Extract input_pos value for causal masking
+  const int32_t input_pos_val = graph.read_symint(input_pos_symint);
+
+  const ValueRef k_cache_tensor = k_cache;
+  const ValueRef v_cache_tensor = v_cache;
+
+  // Validation checks (dims are indexed from the back: -4=B, -3=N, -2=H, -1=D)
+  VK_CHECK_COND(graph.size_at<int32_t>(-4, q_projected) == 1); // batch size = 1
+  VK_CHECK_COND(graph.size_at<int32_t>(-4, k_cache_tensor) == 1);
+  VK_CHECK_COND(graph.size_at<int32_t>(-4, v_cache_tensor) == 1);
+  VK_CHECK_COND(
+      graph.sizes_of(k_cache_tensor) == graph.sizes_of(v_cache_tensor));
+  VK_CHECK_COND(
+      graph.size_at<int32_t>(-1, q_projected) ==
+      graph.size_at<int32_t>(-1, k_cache_tensor)); // head_dim must match
+  VK_CHECK_COND(
+      graph.val_is_none(dropout_p) ||
+      graph.extract_scalar<double>(dropout_p) == 0);
+  VK_CHECK_COND(graph.val_is_none(scale));
+  VK_CHECK_COND(
+      graph.val_is_none(is_causal) || graph.extract_scalar<bool>(is_causal));
+  VK_CHECK_COND(graph.val_is_none(attn_mask));
+
+  // Ensure all tensors use buffer storage for Flash Attention
+  VK_CHECK_COND(graph.is_buffer_storage(q_projected));
+  VK_CHECK_COND(graph.is_buffer_storage(k_cache_tensor));
+  VK_CHECK_COND(graph.is_buffer_storage(v_cache_tensor));
+  VK_CHECK_COND(graph.is_buffer_storage(out));
+
+  // Calculate scale factor
+  const int32_t head_dim_size = graph.size_at<int32_t>(-1, q_projected);
+  const float scale_val = 1.0f / std::sqrt(static_cast<float>(head_dim_size));
+
+  // Get number of heads for multi-query attention support
+  const int32_t num_heads = graph.size_at<int32_t>(-2, q_projected);
+  const int32_t num_kv_heads = graph.size_at<int32_t>(-2, k_cache_tensor);
+
+  const int32_t block_size_r = 32; // Row block size
+  const int32_t block_size_c = 32; // Column block size
+
+  // l and m have shape [B, H, N]
+  std::vector<int64_t> lm_sizes = {
+      graph.size_at<int64_t>(-4, q_projected), // B (batch)
+      graph.size_at<int64_t>(-2, q_projected), // H (num heads)
+      graph.size_at<int64_t>(-3, q_projected) // N (sequence length)
+  };
+
+  // t_l stores row-wise normalization sums for softmax computation
+  // t_m stores row-wise maximum values for numerical stability in softmax
+  TmpTensor t_l(&graph, lm_sizes, vkapi::kFloat);
+  TmpTensor t_m(&graph, lm_sizes, vkapi::kFloat);
+
+  std::string kernel_name = "flash_attention";
+  add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
+  add_dtype_suffix(kernel_name, graph.dtype_of(out));
+
+  // Set up parameter buffers
+  vkapi::ParamsBindList param_ubos = {
+      graph.sizes_ubo(q_projected), // Q_sizes
+      graph.sizes_ubo(k_cache_tensor), // K_sizes
+      graph.sizes_ubo(v_cache_tensor), // V_sizes
+      graph.sizes_ubo(out), // O_sizes
+      graph.sizes_ubo(t_l), // l_sizes (3D)
+      graph.sizes_ubo(t_m), // m_sizes (3D)
+      graph.create_params_buffer(scale_val), // scale
+      graph.create_params_buffer(block_size_r), // block_size_r
+      graph.create_params_buffer(block_size_c), // block_size_c
+      graph.create_params_buffer(input_pos_val), // input_pos
+      graph.create_params_buffer(num_heads), // num_heads
+      graph.create_params_buffer(num_kv_heads) // num_kv_heads
+  };
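+
+  // Note: the UBO order above must match the order of the layout_declare_ubo
+  // declarations in flash_attention.glsl, since the shader's bindings are
+  // assigned in declaration order.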
+
+  // Create block size references for dispatch calculation
+  const ValueRef block_size_r_ref =
+      graph.add_scalar(static_cast<int64_t>(block_size_r));
+
+  graph.execute_nodes().emplace_back(new DynamicDispatchNode(
+      graph,
+      VK_KERNEL_FROM_STR(kernel_name),
+      flash_attention_global_wg_size,
+      default_pick_local_wg_size,
+      // Inputs and Outputs
+      {
+          {{out, t_l, t_m}, vkapi::kReadWrite},
+          {{q_projected, k_cache_tensor, v_cache_tensor}, vkapi::kRead},
+      },
+      // Shader param buffers
+      param_ubos,
+      // Push Constants
+      {},
+      // Specialization Constants
+      {},
+      // Resize Args
+      {q_projected, block_size_r_ref},
+      // Resizing Logic
+      resize_flash_attention_out));
+}
+
 utils::uvec3 kv_cache_update_global_wg_size(
     ComputeGraph* graph,
     const vkapi::ShaderInfo& shader,
@@ -192,18 +360,6 @@ void add_cache_slice_view_node(
       {cache, input_pos_symint, q_projected, cache_sliced}));
 }
 
-void resize_sdpa_out(
-    ComputeGraph* graph,
-    const std::vector<ArgGroup>& args,
-    const std::vector<ValueRef>& extra_args) {
-  (void)args;
-
-  int arg_idx = 0;
-  const ValueRef q_projected = extra_args[arg_idx++];
-  const ValueRef out = extra_args[arg_idx++];
-  graph->get_tensor(out)->virtual_resize(graph->sizes_of(q_projected));
-}
-
 void update_cache_impl(ComputeGraph& graph, const std::vector<ValueRef>& args) {
   int arg_idx = 0;
   const ValueRef value = args[arg_idx++];
@@ -409,6 +565,7 @@ REGISTER_OPERATORS {
   VK_REGISTER_OP(sdpa_with_kv_cache.default, sdpa_with_kv_cache_impl);
   VK_REGISTER_OP(update_cache.default, update_cache_impl);
   VK_REGISTER_OP(llama.custom_sdpa.default, sdpa_impl);
+  VK_REGISTER_OP(llama.flash_attention.default, flash_attention_impl);
 }
 
 } // namespace vkcompute
diff --git a/backends/vulkan/test/op_tests/sdpa_test.cpp b/backends/vulkan/test/op_tests/sdpa_test.cpp
index 9a3da49ddad..303dc9c85ec 100644
--- a/backends/vulkan/test/op_tests/sdpa_test.cpp
+++ b/backends/vulkan/test/op_tests/sdpa_test.cpp
@@ -497,3 +497,433 @@ TEST(VulkanSDPATest, test_reference_impl) {
       batch_size,
       max_seq_len);
 }
+
+void test_vulkan_flash_attention(
+    const int start_input_pos,
+    const int sequence_len,
+    const int embedding_dim,
+    const int num_heads,
+    const int num_kv_heads,
+    const int batch_size,
+    const int max_seq_len,
+    at::ScalarType dtype = at::kFloat) {
+  const int head_dim = embedding_dim / num_heads;
+
+  at::Tensor k_cache = at::zeros(
+      {batch_size, max_seq_len, num_kv_heads, head_dim},
+      at::device(at::kCPU).dtype(dtype));
+  at::Tensor v_cache = at::zeros_like(k_cache);
+
+  at::Tensor q = at::rand(
+      {batch_size, sequence_len, num_heads, head_dim},
+      at::device(at::kCPU).dtype(dtype));
+  at::Tensor k = at::rand(
+      {batch_size, sequence_len, num_kv_heads, head_dim},
+      at::device(at::kCPU).dtype(dtype));
+  at::Tensor v = at::rand_like(k);
+
+  // Get reference output using the existing SDPA reference implementation
+  at::Tensor reference_out = sdpa_reference_impl(
+      q,
+      k,
+      v,
+      k_cache,
+      v_cache,
+      start_input_pos,
+      sequence_len,
+      {},
+      0.0,
+      true,
+      {});
+
+  using namespace vkcompute;
+
+  GraphConfig config;
+  config.set_storage_type_override(
+      utils::kBuffer); // Flash Attention requires buffer storage
+  ComputeGraph graph(config);
+
+  // Create input references
+  IOValueRef r_q = graph.add_input_tensor(
+      q.sizes().vec(), from_at_scalartype(q.scalar_type()));
+  IOValueRef r_k = graph.add_input_tensor(
+      k.sizes().vec(), from_at_scalartype(k.scalar_type()));
+  IOValueRef r_v = graph.add_input_tensor(
+      v.sizes().vec(), from_at_scalartype(v.scalar_type()));
+
+  // Create cache tensors (these would be updated by cache update operations in
+  // practice)
+  ValueRef r_k_cache = graph.add_tensorref(
+      k_cache.sizes().vec(),
+      from_at_scalartype(k_cache.scalar_type()),
+      k_cache.const_data_ptr());
+  ValueRef r_v_cache = graph.add_tensorref(
+      v_cache.sizes().vec(),
+      from_at_scalartype(v_cache.scalar_type()),
+      v_cache.const_data_ptr());
+
+  const ValueRef r_input_pos_symint = graph.add_symint(start_input_pos);
+  const ValueRef r_out =
+      graph.add_tensor(q.sizes().vec(), from_at_scalartype(q.scalar_type()));
+
+  // Call Flash Attention implementation
+  VK_GET_OP_FN("llama.flash_attention.default")
+  (graph,
+   {
+       r_q.value,
+       r_k.value, // Use actual K tensor, not cache
+       r_v.value, // Use actual V tensor, not cache
+       r_input_pos_symint,
+       kDummyValueRef, // attn_mask
+       kDummyValueRef, // dropout_p
+       kDummyValueRef, // is_causal
+       kDummyValueRef, // scale
+       r_out,
+   });
+
+  ValueRef staging_out = graph.set_output_tensor(r_out);
+
+  graph.prepare();
+  graph.encode_prepack();
+  graph.prepack();
+  graph.encode_execute();
+
+  // Copy inputs and run
+  graph.copy_into_staging(r_q.staging, q.const_data_ptr(), q.numel());
+  graph.copy_into_staging(r_k.staging, k.const_data_ptr(), k.numel());
+  graph.copy_into_staging(r_v.staging, v.const_data_ptr(), v.numel());
+
+  graph.execute();
+
+  // Extract output
+  at::Tensor vk_out = at::zeros_like(q).contiguous();
+  graph.copy_from_staging(
+      staging_out, vk_out.mutable_data_ptr(), vk_out.numel());
+
+  // Compare results
+  const bool output_correct = at::allclose(reference_out, vk_out, 1e-3, 1e-3);
+
+  if (!output_correct) {
+    at::Tensor diffs = at::abs(reference_out - vk_out);
+    std::cout << "Flash Attention test failed!" << std::endl;
+    std::cout << "Maximum difference: " << at::max(diffs).item() << std::endl;
+    std::cout << "Maximum value observed: "
+              << at::max(at::abs(at::cat({reference_out, vk_out}, -1))).item()
+              << std::endl;
+  }
+  ASSERT_TRUE(output_correct);
+}
+
+TEST(VulkanSDPATest, test_flash_attention_small_params) {
+  // Tiny parameters that are easy to trace by hand
+  const int starting_input_pos = 0;
+  const int sequence_len = 2; // Very small sequence
+  const int embedding_dim = 4; // Very small embedding
+  const int num_heads = 2; // Just 2 heads
+  const int num_kv_heads = 2; // Match query heads (no multi-query complexity)
+  const int batch_size = 1; // Single batch
+  const int max_seq_len = 4; // Small cache
+
+  test_vulkan_flash_attention(
+      starting_input_pos,
+      sequence_len,
+      embedding_dim,
+      num_heads,
+      num_kv_heads,
+      batch_size,
+      max_seq_len);
+}
+
+TEST(VulkanSDPATest, test_flash_attention_multi_tile) {
+  // Multi-tile test: exercises the tiling algorithm with multiple blocks.
+  // With block_size_r=32, block_size_c=32 (from SDPA.cpp), and seq_len=48:
+  // - Tr = ceil(48/32) = 2 row tiles (blocks: 0-31, 32-47)
+  // - Tc = ceil(48/32) = 2 column tiles (blocks: 0-31, 32-47)
+  // - Total of 2x2 = 4 tile combinations to process per head
+  // - Memory usage: 48*2*16 = 1,536 elements per tensor
+  const int starting_input_pos = 0;
+  const int sequence_len = 48; // Moderate size to force multiple tiles
+  const int embedding_dim = 32; // head_dim = 32/2 = 16 per head
+  const int num_heads = 2; // 2 heads to keep manageable
+  const int num_kv_heads = 2; // Match query heads
+  const int batch_size = 1; // Single batch
+  const int max_seq_len = 64; // Reasonable cache size
+
+  test_vulkan_flash_attention(
+      starting_input_pos,
+      sequence_len,
+      embedding_dim,
+      num_heads,
+      num_kv_heads,
+      batch_size,
+      max_seq_len);
+}
+
+// Flash Attention tests corresponding to traditional SDPA tests
+
+TEST(VulkanSDPATest, test_flash_attention_op_small_params) {
+  // Corresponds to test_sdpa_op_small_params
+  const int starting_input_pos = 0;
+  const int sequence_len = 3;
+  const int embedding_dim = 18;
+  const int num_heads = 6;
+  const int num_kv_heads = 2;
+  const int batch_size = 1;
+  const int max_seq_len = 7;
+
+  test_vulkan_flash_attention(
+      starting_input_pos,
+      sequence_len,
+      embedding_dim,
+      num_heads,
+      num_kv_heads,
+      batch_size,
+      max_seq_len);
+}
+
+TEST(VulkanSDPATest, test_flash_attention_op_small_params_dynamic) {
+  // Corresponds to test_sdpa_op_small_params_dynamic
+  // Note: Flash attention doesn't support dynamic sequence lengths in the same
+  // way as traditional SDPA, so we test with the base sequence length
+  const int starting_input_pos = 0;
+  const int sequence_len = 3;
+  const int embedding_dim = 18;
+  const int num_heads = 6;
+  const int num_kv_heads = 2;
+  const int batch_size = 1;
+  const int max_seq_len = 12;
+
+  test_vulkan_flash_attention(
+      starting_input_pos,
+      sequence_len,
+      embedding_dim,
+      num_heads,
+      num_kv_heads,
+      batch_size,
+      max_seq_len);
+}
+
+TEST(VulkanSDPATest, test_flash_attention_op_llama3_params) {
+  // Corresponds to test_sdpa_op_llama3_params
+  // A larger test using the llama3-style head configuration
+  // (32 query heads, 8 KV heads)
+  const int starting_input_pos = 0;
+  const int sequence_len = 3;
+  const int embedding_dim = 2048;
+  const int num_heads = 32;
+  const int num_kv_heads = 8;
+  const int batch_size = 1;
+  const int max_seq_len = 128;
+
+  test_vulkan_flash_attention(
+      starting_input_pos,
+      sequence_len,
+      embedding_dim,
+      num_heads,
+      num_kv_heads,
+      batch_size,
+      max_seq_len);
+}
+
+TEST(VulkanSDPATest, test_flash_attention_op_llama3_params_dynamic) {
+  // Corresponds to test_sdpa_op_llama3_params_dynamic
+  // Test with varying sequence lengths to ensure flash attention works with
+  // different sizes
+  const int starting_input_pos = 0;
+  const int embedding_dim = 2048;
+  const int num_heads = 32;
+  const int num_kv_heads = 8;
+  const int batch_size = 1;
+  const int max_seq_len = 128;
+
+  // Test with different sequence lengths
+  std::vector<int> sequence_lengths = {1, 3, 5, 7, 16, 32};
+
+  for (int seq_len : sequence_lengths) {
+    if (seq_len < max_seq_len) {
+      test_vulkan_flash_attention(
+          starting_input_pos,
+          seq_len,
+          embedding_dim,
+          num_heads,
+          num_kv_heads,
+          batch_size,
+          max_seq_len);
+    }
+  }
+}
+
+void test_reference_flash_attention(
+    const int start_input_pos,
+    const int sequence_len,
+    const int embedding_dim,
+    const int num_heads,
+    const int num_kv_heads,
+    const int batch_size,
+    const int max_seq_len,
+    at::ScalarType dtype = at::kFloat) {
+  const int head_dim = embedding_dim / num_heads;
+
+  // For the flash attention reference test, we test single-shot attention
+  // rather than iterative cache updates, since flash attention processes
+  // the entire sequence at once
+
+  at::Tensor q = at::rand(
+      {batch_size, sequence_len, num_heads, head_dim},
+      at::device(at::kCPU).dtype(dtype));
+  at::Tensor k = at::rand(
+      {batch_size, sequence_len, num_kv_heads, head_dim},
+      at::device(at::kCPU).dtype(dtype));
+  at::Tensor v = at::rand_like(k);
+
+  // Create empty caches for the reference implementation
+  at::Tensor k_cache_ref = at::zeros(
+      {batch_size, max_seq_len, num_kv_heads, head_dim},
+      at::device(at::kCPU).dtype(dtype));
+  at::Tensor v_cache_ref = at::zeros_like(k_cache_ref);
+
+  // Get reference implementation output
+  at::Tensor reference_out = sdpa_reference_impl(
+      q,
+      k,
+      v,
+      k_cache_ref,
+      v_cache_ref,
+      start_input_pos,
+      sequence_len,
+      {},
+      0.0,
+      true,
+      {});
+
+  // Build the Vulkan Flash Attention graph
+  using namespace vkcompute;
+
+  GraphConfig config;
+  config.set_storage_type_override(utils::kBuffer);
+  ComputeGraph graph(config);
+
+  IOValueRef r_q = graph.add_input_tensor(
+      q.sizes().vec(), from_at_scalartype(q.scalar_type()));
+  IOValueRef r_k = graph.add_input_tensor(
+      k.sizes().vec(), from_at_scalartype(k.scalar_type()));
+  IOValueRef r_v = graph.add_input_tensor(
+      v.sizes().vec(), from_at_scalartype(v.scalar_type()));
+
+  // Create empty cache tensors for flash attention
+  at::Tensor k_cache_flash = at::zeros_like(k_cache_ref);
+  at::Tensor v_cache_flash = at::zeros_like(v_cache_ref);
+
+  ValueRef r_k_cache = graph.add_tensorref(
+      k_cache_flash.sizes().vec(),
+      from_at_scalartype(k_cache_flash.scalar_type()),
+      k_cache_flash.const_data_ptr());
+  ValueRef r_v_cache = graph.add_tensorref(
+      v_cache_flash.sizes().vec(),
+      from_at_scalartype(v_cache_flash.scalar_type()),
+      v_cache_flash.const_data_ptr());
+
+  const ValueRef r_input_pos_symint = graph.add_symint(start_input_pos);
+  const ValueRef r_out =
+      graph.add_tensor(q.sizes().vec(), from_at_scalartype(q.scalar_type()));
+
+  VK_GET_OP_FN("llama.flash_attention.default")
+  (graph,
+   {
+       r_q.value,
+       r_k.value,
+       r_v.value,
+       r_input_pos_symint,
+       kDummyValueRef, // attn_mask
+       kDummyValueRef, // dropout_p
+       kDummyValueRef, // is_causal
+       kDummyValueRef, // scale
+       r_out,
+   });
+
+  ValueRef staging_out = graph.set_output_tensor(r_out);
+
+  graph.prepare();
+  graph.encode_prepack();
+  graph.prepack();
+  graph.encode_execute();
+
+  graph.copy_into_staging(r_q.staging, q.const_data_ptr(), q.numel());
+  graph.copy_into_staging(r_k.staging, k.const_data_ptr(), k.numel());
+  graph.copy_into_staging(r_v.staging, v.const_data_ptr(), v.numel());
+
+  graph.execute();
+
+  at::Tensor flash_out = at::zeros_like(q).contiguous();
+  graph.copy_from_staging(
+      staging_out, flash_out.mutable_data_ptr(), flash_out.numel());
+
+  // Compare flash attention output with reference implementation
+  const bool output_correct =
+      at::allclose(reference_out, flash_out, 1e-3, 1e-3);
+
+  if (!output_correct) {
+    at::Tensor diffs = at::abs(reference_out - flash_out);
+    std::cout << "Flash Attention reference test failed" << std::endl;
+    std::cout << "Maximum difference: " << at::max(diffs).item() << std::endl;
+    std::cout
+        << "Maximum value observed: "
+        << at::max(at::abs(at::cat({reference_out, flash_out}, -1))).item()
+        << std::endl;
+  }
+  ASSERT_TRUE(output_correct);
+}
+
+TEST(VulkanSDPATest, test_flash_attention_reference_impl) {
+  const int starting_input_pos = 0;
+  const int sequence_len = 3;
+  const int embedding_dim = 2048;
+  const int num_heads = 32;
+  const int num_kv_heads = 8;
+  const int batch_size = 1;
+  const int max_seq_len = 128;
+
+  test_reference_flash_attention(
+      starting_input_pos,
+      sequence_len,
+      embedding_dim,
+      num_heads,
+      num_kv_heads,
+      batch_size,
+      max_seq_len);
+}
+
+TEST(VulkanSDPATest, test_flash_attention_reference_impl_small) {
+  const int starting_input_pos = 0;
+  const int sequence_len = 2;
+  const int embedding_dim = 32;
+  const int num_heads = 4;
+  const int num_kv_heads = 2;
+  const int batch_size = 1;
+  const int max_seq_len = 16;
+
+  test_reference_flash_attention(
+      starting_input_pos,
+      sequence_len,
+      embedding_dim,
+      num_heads,
+      num_kv_heads,
+      batch_size,
+      max_seq_len);
+}
+
+TEST(VulkanSDPATest, test_flash_attention_edge_cases) {
+  // Test with single head (no multi-query complexity)
+  test_vulkan_flash_attention(0, 1, 8, 1, 1, 1, 4);
+
+  // Test with equal heads (no multi-query complexity)
+  test_vulkan_flash_attention(0, 2, 16, 4, 4, 1, 8);
+
+  // Test with large head dimension
+  test_vulkan_flash_attention(0, 2, 128, 2, 1, 1, 8);
+
+  // Test with sequence length that exactly matches block size (32)
+  test_vulkan_flash_attention(0, 32, 64, 2, 1, 1, 64);
+
+  // Test with sequence length slightly larger than block size
+  test_vulkan_flash_attention(0, 33, 64, 2, 1, 1, 64);
+}