diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention_common.h b/onnxruntime/contrib_ops/webgpu/bert/attention_common.h
index 9d4740ede7143..da70b790586cf 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/attention_common.h
+++ b/onnxruntime/contrib_ops/webgpu/bert/attention_common.h
@@ -50,6 +50,7 @@ struct WebgpuAttentionParameters {
         v_hidden_size_(parameters.kv_hidden_size),
         v_head_size_(parameters.kv_hidden_size / parameters.kv_num_heads),
         num_heads_(parameters.num_heads),
+        is_unidirectional_(true),
         do_rotary_(parameters.do_rotary),
         scale_(parameters.scale),
         seqlen_past_kv_cache_(parameters.seqlen_past_kv_cache),
diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc
index c9e182bf10f2f..dbe2614099be1 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc
+++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc
@@ -275,8 +275,23 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const {
 
   var previous_max : q_element_t = min_value;
   var previous_denom : q_element_t = 0;
+)MAIN_FN";
 
-  for(var k_start = 0u; k_start < uniforms.total_sequence_length; k_start+=capped_sg_size)
+  if (is_unidirectional_) {
+    // If attention is unidirectional, set the loop bound to enforce causal masking.
+    shader.MainFunctionBody() << R"MAIN_FN(
+  let max_causal_len_for_workgroup = uniforms.past_sequence_length +
+                                     (workgroup_idx % uniforms.num_seq_tile + 1) * workgroup_size_x;
+  let loop_bound = min(uniforms.total_sequence_length, max_causal_len_for_workgroup);
+)MAIN_FN";
+  } else {
+    shader.MainFunctionBody() << R"MAIN_FN(
+  let loop_bound = uniforms.total_sequence_length;
+)MAIN_FN";
+  }
+
+  shader.MainFunctionBody() << R"MAIN_FN(
+  for(var k_start = 0u; k_start < loop_bound; k_start+=capped_sg_size)
   {
     workgroupBarrier();
     loadk(k_start, head_idx / uniforms.n_reps, local_idx, capped_sg_size);
@@ -337,7 +352,7 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const {
       qk_4 = qk_4 + loadAttentionBias(q_idx_global, k_start+12, head_idx);
     }
 
-    let seq_causal_length = select(uniforms.total_sequence_length, uniforms.past_sequence_length + q_idx_global + 1, uniforms.is_gqa > 0);
+    let seq_causal_length = select(uniforms.total_sequence_length, uniforms.past_sequence_length + q_idx_global + 1, uniforms.is_unidirectional > 0);
     // Neuter qk values where K is out of bounds.
     qk_1[0] = select(min_value, qk_1[0], k_start+0 < seq_causal_length);
     qk_1[1] = select(min_value, qk_1[1], k_start+1 < seq_causal_length);
@@ -903,7 +918,13 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
   bool has_attention_bias = attention_bias != nullptr;
   bool is_qualcomm = context.AdapterInfo().vendor == std::string_view{"qualcomm"};
   bool is_fp16 = (Q->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16);
-  FlashAttentionProgram program{"FlashAttention", has_attention_bias, is_qualcomm, is_fp16, parameters.head_size_, parameters.num_heads_};
+  FlashAttentionProgram program{"FlashAttention",
+                                has_attention_bias,
+                                is_qualcomm,
+                                is_fp16,
+                                parameters.head_size_,
+                                parameters.num_heads_,
+                                parameters.is_unidirectional_};
   program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, 4},
                      {present_key, ProgramTensorMetadataDependency::TypeAndRank, 4},
                      {present_value, ProgramTensorMetadataDependency::TypeAndRank, 4}});
@@ -916,12 +937,12 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
   const uint32_t num_seq_tile = (parameters.sequence_length_ + tile_size - 1) / tile_size;
   program.SetDispatchGroupSize(parameters.num_heads_ * num_seq_tile)
       .SetWorkgroupSize(tile_size)
-      .CacheHint(has_attention_bias, parameters.head_size_, parameters.num_heads_, is_qualcomm)
+      .CacheHint(has_attention_bias, parameters.head_size_, parameters.num_heads_, parameters.is_unidirectional_, is_qualcomm)
       .AddUniformVariables({{static_cast<uint32_t>(parameters.sequence_length_)},
                             {static_cast<uint32_t>(parameters.total_sequence_length_)},
                             {static_cast<uint32_t>(present_sequence_length)},
                             {static_cast<uint32_t>(parameters.total_sequence_length_ - parameters.kv_sequence_length_)},
-                            {static_cast<uint32_t>(parameters.is_gqa_ ? 1 : 0)},
+                            {static_cast<uint32_t>(parameters.is_unidirectional_)},
                             {static_cast<uint32_t>(parameters.n_reps)},
                             {alpha},
                             {num_seq_tile}});
diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h
index 3f79b80fb73bc..9908b33a38372 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h
+++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h
@@ -39,13 +39,15 @@ class FlashAttentionProgram final : public Program<FlashAttentionProgram> {
                         bool is_qualcomm,
                         bool is_fp16,
                         int qkv_head_size,
-                        int qkv_num_heads)
+                        int qkv_num_heads,
+                        bool is_unidirectional)
       : Program{kernel_name},
         has_attention_bias_(has_attention_bias),
         is_qualcomm_(is_qualcomm),
         is_fp16_(is_fp16),
         qkv_head_size_(qkv_head_size),
-        qkv_num_heads_(qkv_num_heads) {
+        qkv_num_heads_(qkv_num_heads),
+        is_unidirectional_(is_unidirectional) {
   }
 
   Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -54,7 +56,7 @@
                                           {"total_sequence_length", ProgramUniformVariableDataType::Uint32},
                                           {"present_sequence_length", ProgramUniformVariableDataType::Uint32},
                                           {"past_sequence_length", ProgramUniformVariableDataType::Uint32},
-                                          {"is_gqa", ProgramUniformVariableDataType::Uint32},
+                                          {"is_unidirectional", ProgramUniformVariableDataType::Uint32},
                                           {"n_reps", ProgramUniformVariableDataType::Uint32},
                                           {"alpha", ProgramUniformVariableDataType::Float32},
                                           {"num_seq_tile", ProgramUniformVariableDataType::Uint32});
@@ -65,6 +67,7 @@
   bool is_fp16_;
   int qkv_head_size_;
   int qkv_num_heads_;
+  bool is_unidirectional_;
 };
 
 class FlashAttentionDecodeQKTProgram final : public Program<FlashAttentionDecodeQKTProgram> {
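
Note on the new causal loop bound: each query row q_idx_global may attend to key positions [0, past_sequence_length + q_idx_global + 1), so a workgroup covering one sequence tile of queries never needs keys beyond the causal limit of its last row. The sketch below is a minimal host-side mirror of that indexing math in plain C++ (helper names are illustrative only and not part of this change), assuming the tile covers query rows [seq_tile_idx * tile_size, (seq_tile_idx + 1) * tile_size).

#include <algorithm>
#include <cstdint>

// Causal visibility limit for one query row: keys [0, limit) may be attended.
uint32_t SeqCausalLength(uint32_t past_sequence_length, uint32_t q_idx_global) {
  return past_sequence_length + q_idx_global + 1;
}

// Per-workgroup K loop bound: the causal limit of the tile's last query row,
// capped at the total KV length (mirrors max_causal_len_for_workgroup above).
uint32_t WorkgroupLoopBound(uint32_t total_sequence_length,
                            uint32_t past_sequence_length,
                            uint32_t seq_tile_idx,  // workgroup_idx % num_seq_tile
                            uint32_t tile_size) {   // workgroup_size_x
  const uint32_t max_causal_len = past_sequence_length + (seq_tile_idx + 1) * tile_size;
  return std::min(total_sequence_length, max_causal_len);
}

Because WorkgroupLoopBound equals SeqCausalLength of the tile's last query row (up to the cap at total_sequence_length), the tighter loop only skips keys that the per-element seq_causal_length masking would have neutered anyway; in the bidirectional case the bound stays at total_sequence_length.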