@@ -636,6 +636,7 @@ struct vk_flash_attn_push_constants {
636
636
uint32_t nev3;
637
637
uint32_t nem1;
638
638
uint32_t nem2;
639
+ uint32_t nem3;
639
640
640
641
uint32_t nb01;
641
642
uint32_t nb02;
@@ -651,8 +652,7 @@ struct vk_flash_attn_push_constants {
651
652
float max_bias;
652
653
float logit_softcap;
653
654
654
- uint32_t mask;
655
- uint32_t n_head_log2;
655
+ uint32_t mask_n_head_log2;
656
656
float m0;
657
657
float m1;
658
658
@@ -6114,6 +6114,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
6114
6114
6115
6115
const uint32_t nem1 = mask ? mask->ne[1] : 0;
6116
6116
const uint32_t nem2 = mask ? mask->ne[2] : 0;
6117
+ const uint32_t nem3 = mask ? mask->ne[3] : 0;
6117
6118
6118
6119
const uint32_t HSK = nek0;
6119
6120
const uint32_t HSV = nev0;
@@ -6181,7 +6182,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
6181
6182
}
6182
6183
6183
6184
if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa &&
6184
- qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 = = 1) {
6185
+ qk_ratio * nek2 == neq2 && nek2 == nev2 && nem2 < = 1) {
6185
6186
// grouped query attention - make the N dimension equal to gqa_ratio, reduce
6186
6187
// workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
6187
6188
// and change addressing calculations to index Q's dimension 2.
@@ -6351,17 +6352,19 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
6351
6352
}
6352
6353
}
6353
6354
6355
+ uint32_t mask_n_head_log2 = ((mask != nullptr) << 16) | n_head_log2;
6356
+
6354
6357
const vk_flash_attn_push_constants pc = { N, KV,
6355
6358
(uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3,
6356
6359
(uint32_t)neq2, (uint32_t)neq3,
6357
6360
(uint32_t)nek2, (uint32_t)nek3,
6358
6361
(uint32_t)nev2, (uint32_t)nev3,
6359
- nem1, nem2,
6362
+ nem1, nem2, nem3,
6360
6363
q_stride, (uint32_t)nbq2, (uint32_t)nbq3,
6361
6364
k_stride, (uint32_t)nbk2, (uint32_t)nbk3,
6362
6365
v_stride, (uint32_t)nbv2, (uint32_t)nbv3,
6363
6366
scale, max_bias, logit_softcap,
6364
- mask != nullptr, n_head_log2 , m0, m1,
6367
+ mask_n_head_log2 , m0, m1,
6365
6368
gqa_ratio, split_kv, split_k };
6366
6369
6367
6370
ggml_vk_sync_buffers(subctx);
@@ -10306,12 +10309,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
10306
10309
if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) {
10307
10310
return false;
10308
10311
}
10309
- // TODO: support broadcast
10310
- // note: this was initially implemented in https://github.com/ggml-org/llama.cpp/pull/14449, but
10311
- // the interface of ggml_flash_attn_ext() changed in https://github.com/ggml-org/llama.cpp/pull/14505
10312
- if (op->src[0]->ne[3] != 1 || (op->src[3] && op->src[3]->ne[2] != 1)) {
10313
- return false;
10314
- }
10315
10312
// It's straightforward to support different K/V dequant, but would
10316
10313
// significantly increase the number of pipelines
10317
10314
if (op->src[1]->type != op->src[2]->type) {
0 commit comments