Merged
1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/webgpu/bert/attention_common.h
@@ -50,6 +50,7 @@ struct WebgpuAttentionParameters {
       v_hidden_size_(parameters.kv_hidden_size),
       v_head_size_(parameters.kv_hidden_size / parameters.kv_num_heads),
       num_heads_(parameters.num_heads),
+      is_unidirectional_(true),
       do_rotary_(parameters.do_rotary),
       scale_(parameters.scale),
       seqlen_past_kv_cache_(parameters.seqlen_past_kv_cache),
31 changes: 26 additions & 5 deletions onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc
@@ -275,8 +275,23 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const {

   var previous_max : q_element_t = min_value;
   var previous_denom : q_element_t = 0;
+)MAIN_FN";
 
-  for(var k_start = 0u; k_start < uniforms.total_sequence_length; k_start+=capped_sg_size)
+  if (is_unidirectional_) {
+    // If attention is unidirectional, set the loop bound to enforce causal masking.
+    shader.MainFunctionBody() << R"MAIN_FN(
+  let max_causal_len_for_workgroup = uniforms.past_sequence_length +
+                                     (workgroup_idx % uniforms.num_seq_tile + 1) * workgroup_size_x;
+  let loop_bound = min(uniforms.total_sequence_length, max_causal_len_for_workgroup);
+)MAIN_FN";
+  } else {
+    shader.MainFunctionBody() << R"MAIN_FN(
+  let loop_bound = uniforms.total_sequence_length;
+)MAIN_FN";
+  }
+
+  shader.MainFunctionBody() << R"MAIN_FN(
+  for(var k_start = 0u; k_start < loop_bound; k_start+=capped_sg_size)
 {
   workgroupBarrier();
   loadk(k_start, head_idx / uniforms.n_reps, local_idx, capped_sg_size);
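
Note (reviewer addition, not part of the diff): the new bound works because each workgroup covers workgroup_size_x consecutive new query rows, and under causal masking the last row of tile t can only attend to keys with index below past_sequence_length + (t + 1) * workgroup_size_x. A minimal host-side C++ sketch of the same arithmetic, with tile_idx standing in for (workgroup_idx % num_seq_tile) and all names assumed rather than taken from the ORT API:

#include <algorithm>
#include <cstdint>

// Sketch: scalar mirror of the WGSL loop bound above (assumed names).
uint32_t CausalLoopBound(uint32_t total_sequence_length,
                         uint32_t past_sequence_length,
                         uint32_t tile_idx,
                         uint32_t workgroup_size_x) {
  // The last query row of this tile sits at new-token index
  // (tile_idx + 1) * workgroup_size_x - 1; causally it may attend to keys
  // [0, past_sequence_length + that index], so later keys are never needed.
  const uint32_t max_causal_len_for_workgroup =
      past_sequence_length + (tile_idx + 1) * workgroup_size_x;
  return std::min(total_sequence_length, max_causal_len_for_workgroup);
}

Early tiles therefore exit the K/V loop long before total_sequence_length during prefill, while min() keeps the bidirectional bound as a ceiling.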
@@ -337,7 +352,7 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const {
   qk_4 = qk_4 + loadAttentionBias(q_idx_global, k_start+12, head_idx);
 }
 
-  let seq_causal_length = select(uniforms.total_sequence_length, uniforms.past_sequence_length + q_idx_global + 1, uniforms.is_gqa > 0);
+  let seq_causal_length = select(uniforms.total_sequence_length, uniforms.past_sequence_length + q_idx_global + 1, uniforms.is_unidirectional > 0);
   // Neuter qk values where K is out of bounds.
   qk_1[0] = select(min_value, qk_1[0], k_start+0 < seq_causal_length);
   qk_1[1] = select(min_value, qk_1[1], k_start+1 < seq_causal_length);
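
Note (reviewer addition, not part of the diff): inside the loop the masking is per key element. With is_unidirectional set, a query at new-token index q_idx_global attends only to keys strictly below past_sequence_length + q_idx_global + 1; the same select() also neuters padded keys past total_sequence_length in the bidirectional case. A scalar C++ sketch with assumed names:

#include <cstdint>

// Sketch: scalar version of the select() masking above (assumed names).
// min_value acts as -infinity so softmax gives masked keys ~zero weight.
float MaskScore(float qk, uint32_t k_idx, uint32_t q_idx_global,
                uint32_t past_sequence_length, uint32_t total_sequence_length,
                bool is_unidirectional, float min_value) {
  const uint32_t seq_causal_length =
      is_unidirectional ? past_sequence_length + q_idx_global + 1
                        : total_sequence_length;
  return (k_idx < seq_causal_length) ? qk : min_value;
}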
@@ -903,7 +918,13 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
   bool has_attention_bias = attention_bias != nullptr;
   bool is_qualcomm = context.AdapterInfo().vendor == std::string_view{"qualcomm"};
   bool is_fp16 = (Q->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16);
-  FlashAttentionProgram program{"FlashAttention", has_attention_bias, is_qualcomm, is_fp16, parameters.head_size_, parameters.num_heads_};
+  FlashAttentionProgram program{"FlashAttention",
+                                has_attention_bias,
+                                is_qualcomm,
+                                is_fp16,
+                                parameters.head_size_,
+                                parameters.num_heads_,
+                                parameters.is_unidirectional_};
   program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, 4},
                      {present_key, ProgramTensorMetadataDependency::TypeAndRank, 4},
                      {present_value, ProgramTensorMetadataDependency::TypeAndRank, 4}});
@@ -916,12 +937,12 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
   const uint32_t num_seq_tile = (parameters.sequence_length_ + tile_size - 1) / tile_size;
   program.SetDispatchGroupSize(parameters.num_heads_ * num_seq_tile)
       .SetWorkgroupSize(tile_size)
-      .CacheHint(has_attention_bias, parameters.head_size_, parameters.num_heads_, is_qualcomm)
+      .CacheHint(has_attention_bias, parameters.head_size_, parameters.num_heads_, parameters.is_unidirectional_, is_qualcomm)
       .AddUniformVariables({{static_cast<uint32_t>(parameters.sequence_length_)},
                             {static_cast<uint32_t>(parameters.total_sequence_length_)},
                             {static_cast<uint32_t>(present_sequence_length)},
                             {static_cast<uint32_t>(parameters.total_sequence_length_ - parameters.kv_sequence_length_)},
-                            {static_cast<uint32_t>(parameters.is_gqa_ ? 1 : 0)},
+                            {static_cast<uint32_t>(parameters.is_unidirectional_)},
                             {static_cast<uint32_t>(parameters.n_reps)},
                             {alpha},
                             {num_seq_tile}});
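
Note (reviewer addition, not part of the diff): the dispatch size num_heads * num_seq_tile is what lets workgroup_idx % uniforms.num_seq_tile in the shader recover a tile index, and since the generated WGSL now differs between causal and bidirectional runs, is_unidirectional_ must join the CacheHint so a cached shader of the wrong variant is never reused. A small sketch of the index split, under the assumed layout implied by the diff:

#include <cstdint>

// Sketch: how a flat workgroup index decomposes under a dispatch of
// num_heads * num_seq_tile workgroups (assumed layout, not the ORT API).
struct TileCoords {
  uint32_t head_idx;      // which attention head this workgroup serves
  uint32_t seq_tile_idx;  // which block of workgroup_size_x query rows
};

TileCoords SplitWorkgroupIdx(uint32_t workgroup_idx, uint32_t num_seq_tile) {
  return {workgroup_idx / num_seq_tile, workgroup_idx % num_seq_tile};
}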
9 changes: 6 additions & 3 deletions onnxruntime/contrib_ops/webgpu/bert/flash_attention.h
@@ -39,13 +39,15 @@ class FlashAttentionProgram final : public Program<FlashAttentionProgram> {
                         bool is_qualcomm,
                         bool is_fp16,
                         int qkv_head_size,
-                        int qkv_num_heads)
+                        int qkv_num_heads,
+                        bool is_unidirectional)
       : Program{kernel_name},
         has_attention_bias_(has_attention_bias),
         is_qualcomm_(is_qualcomm),
         is_fp16_(is_fp16),
         qkv_head_size_(qkv_head_size),
-        qkv_num_heads_(qkv_num_heads) {
+        qkv_num_heads_(qkv_num_heads),
+        is_unidirectional_(is_unidirectional) {
   }
 
   Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -54,7 +56,7 @@ class FlashAttentionProgram final : public Program<FlashAttentionProgram> {
{"total_sequence_length", ProgramUniformVariableDataType::Uint32},
{"present_sequence_length", ProgramUniformVariableDataType::Uint32},
{"past_sequence_length", ProgramUniformVariableDataType::Uint32},
{"is_gqa", ProgramUniformVariableDataType::Uint32},
{"is_unidirectional", ProgramUniformVariableDataType::Uint32},
{"n_reps", ProgramUniformVariableDataType::Uint32},
{"alpha", ProgramUniformVariableDataType::Float32},
{"num_seq_tile", ProgramUniformVariableDataType::Uint32});
@@ -65,6 +67,7 @@ class FlashAttentionProgram final : public Program<FlashAttentionProgram> {
   bool is_fp16_;
   int qkv_head_size_;
   int qkv_num_heads_;
+  bool is_unidirectional_;
 };
 
 class FlashAttentionDecodeQKTProgram final : public Program<FlashAttentionDecodeQKTProgram> {