@@ -1223,7 +1223,9 @@ static __global__ void flash_attn_ext_f16(
1223
1223
const int ne12,
1224
1224
const int ne13,
1225
1225
const int ne31,
1226
+ const int ne32,
1226
1227
const int nb31,
1228
+ const int nb32,
1227
1229
const int nb01,
1228
1230
const int nb02,
1229
1231
const int nb03,
@@ -1288,7 +1290,8 @@ static __global__ void flash_attn_ext_f16(
1288
1290
1289
1291
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
1290
1292
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
1291
- const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof (half2))*jt*ncols1 : nullptr ;
1293
+ const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr :
1294
+ (const half2 *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1);
1292
1295
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2 );
1293
1296
1294
1297
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2 ) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
@@ -1327,7 +1330,8 @@ static __global__ void flash_attn_ext_f16(
1327
1330
1328
1331
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
1329
1332
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
1330
- const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof (half2))*jt*ncols1 : nullptr ;
1333
+ const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr :
1334
+ (const half2 *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1);
1331
1335
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2 );
1332
1336
1333
1337
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2 ) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
@@ -1348,8 +1352,8 @@ static __global__ void flash_attn_ext_f16(
1348
1352
GGML_UNUSED (max_bias); GGML_UNUSED (m0); GGML_UNUSED (m1);
1349
1353
GGML_UNUSED (n_head_log2); GGML_UNUSED (logit_softcap); GGML_UNUSED (ne00);
1350
1354
GGML_UNUSED (ne01); GGML_UNUSED (ne02); GGML_UNUSED (ne03); GGML_UNUSED (ne10);
1351
- GGML_UNUSED (ne11); GGML_UNUSED (ne12); GGML_UNUSED (ne13); GGML_UNUSED (ne31);
1352
- GGML_UNUSED (nb31); GGML_UNUSED (nb01); GGML_UNUSED (nb02); GGML_UNUSED (nb03);
1355
+ GGML_UNUSED (ne11); GGML_UNUSED (ne12); GGML_UNUSED (ne13); GGML_UNUSED (ne31); GGML_UNUSED (ne32);
1356
+ GGML_UNUSED (nb31); GGML_UNUSED (nb32); GGML_UNUSED ( nb01); GGML_UNUSED (nb02); GGML_UNUSED (nb03);
1353
1357
GGML_UNUSED (nb11); GGML_UNUSED (nb12); GGML_UNUSED (nb13); GGML_UNUSED (nb21);
1354
1358
GGML_UNUSED (nb22); GGML_UNUSED (nb23); GGML_UNUSED (ne0); GGML_UNUSED (ne1);
1355
1359
GGML_UNUSED (ne2); GGML_UNUSED (ne3);
0 commit comments