
Commit 2a20a7e

cuda : fix mask dim 2/3 (wip)

ggml-ci

1 parent 6036177

8 files changed, +46 -29 lines changed


ggml/src/ggml-cuda/fattn-common.cuh
Lines changed: 4 additions & 4 deletions

@@ -33,8 +33,10 @@ typedef void (* fattn_kernel_t)(
        const int ne13,
        const int ne31,
        const int ne32,
+        const int ne33,
        const int nb31,
        const int nb32,
+        const int nb33,
        const int nb01,
        const int nb02,
        const int nb03,
@@ -705,8 +707,6 @@ void launch_fattn(

    GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");

-    GGML_ASSERT(Q->ne[3] == 1);
-
    ggml_cuda_pool & pool = ctx.pool();
    cudaStream_t main_stream = ctx.stream();
    const int id = ggml_cuda_get_device();
@@ -853,8 +853,8 @@ void launch_fattn(
        scale, max_bias, m0, m1, n_head_log2, logit_softcap,
        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-        mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0,
-        mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0,
+        mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0,
+        mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, mask ? mask->nb[3] : 0,
        Q->nb[1], Q->nb[2], Q->nb[3],
        nb11, nb12, nb13,
        nb21, nb22, nb23,
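
Forwarding the mask's dim-3 extent and stride (ne33/nb33) through launch_fattn lets every kernel broadcast the mask along Q's dims 2 and 3 instead of assuming a single sequence. As a minimal sketch of that addressing rule (a hypothetical standalone helper, not code from this commit), the byte offset into the mask is each index modulo the mask's extent, times the corresponding stride:

#include <cstddef>

// Hypothetical helper mirroring the kernels' mask addressing after this change:
// a mask dimension of size 1 broadcasts, so the modulo collapses its index to 0.
static size_t mask_byte_offset(int i02, int i03, int ic0,
                               int ne32, int ne33,
                               size_t nb31, size_t nb32, size_t nb33) {
    return nb33*(size_t)(i03 % ne33)   // batch (dim 3), wraps if broadcast
         + nb32*(size_t)(i02 % ne32)   // head  (dim 2), wraps if broadcast
         + nb31*(size_t)ic0;           // first Q column handled by this block
}

With ne32 == 1 and ne33 == 1 (one mask shared by all heads and sequences) both modulo terms vanish and only nb31*ic0 remains, matching the previous single-sequence behaviour.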

ggml/src/ggml-cuda/fattn-mma-f16.cuh
Lines changed: 2 additions & 0 deletions

@@ -1224,8 +1224,10 @@ static __global__ void flash_attn_ext_f16(
        const int ne13,
        const int ne31,
        const int ne32,
+        const int ne33,
        const int nb31,
        const int nb32,
+        const int nb33,
        const int nb01,
        const int nb02,
        const int nb03,

ggml/src/ggml-cuda/fattn-tile-f16.cu
Lines changed: 8 additions & 4 deletions

@@ -31,8 +31,10 @@ static __global__ void flash_attn_tile_ext_f16(
        const int ne13,
        const int ne31,
        const int ne32,
+        const int ne33,
        const int nb31,
        const int nb32,
+        const int nb33,
        const int nb01,
        const int nb02,
        const int nb03,
@@ -61,12 +63,14 @@ static __global__ void flash_attn_tile_ext_f16(
    //In this kernel Q, K, V are matrices while i, j, k are matrix indices.

    const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
+    const int i02 = blockIdx.z % ne02;
+    const int i03 = blockIdx.z / ne02;

    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0);
-    const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio));
-    const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
-    const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
+    const float2 * Q_f2 = (const float2 *) (Q + nb03*i03 + nb02*i02 + nb01*ic0);
+    const half2 * K_h2 = (const half2 *) (K + nb13*i03 + nb12*(i02 / gqa_ratio));
+    const half2 * V_h2 = (const half2 *) (V + nb23*i03 + nb22*(i02 / gqa_ratio)); // K and V have same shape
+    const half * maskh = (const half *) (mask + nb33*(i03 % ne33) + nb32*(i02 % ne32) + nb31*ic0);

    const int stride_KV2 = nb11 / sizeof(half2);
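
In all of the updated kernels the flattened gridDim.z is now split into a head index i02 and a batch index i03 (this assumes, as the indexing above suggests, that the launcher packs Q->ne[2]*Q->ne[3] blocks into the z dimension). K and V remain shared across gqa_ratio query heads, and the mask indices wrap for broadcasting. A minimal sketch of that index arithmetic, with made-up struct and function names for illustration:

// Sketch of the shared index decomposition; i02/i03/gqa_ratio follow the kernel
// code, fattn_indices/split_block_z are hypothetical names.
struct fattn_indices {
    int i02; // Q head index  (dim 2)
    int i03; // Q batch index (dim 3)
    int ik2; // K/V head index under grouped-query attention
};

static fattn_indices split_block_z(int block_z, int ne02, int ne12) {
    const int gqa_ratio = ne02 / ne12;  // > 1 Q heads per K/V head with GQA
    fattn_indices idx;
    idx.i02 = block_z % ne02;           // fastest-varying: head
    idx.i03 = block_z / ne02;           // slowest-varying: batch
    idx.ik2 = idx.i02 / gqa_ratio;      // which K/V head this Q head maps to
    return idx;
}

fattn-tile-f32.cu, the two fattn-vec kernels, and fattn-wmma-f16.cu below apply the same pattern; only the pointer types and variable names differ.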

ggml/src/ggml-cuda/fattn-tile-f32.cu
Lines changed: 8 additions & 4 deletions

@@ -31,8 +31,10 @@ static __global__ void flash_attn_tile_ext_f32(
        const int ne13,
        const int ne31,
        const int ne32,
+        const int ne33,
        const int nb31,
        const int nb32,
+        const int nb33,
        const int nb01,
        const int nb02,
        const int nb03,
@@ -73,12 +75,14 @@ static __global__ void flash_attn_tile_ext_f32(
    // In this kernel Q, K, V are matrices while i, j, k are matrix indices.

    const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
+    const int i02 = blockIdx.z % ne02;
+    const int i03 = blockIdx.z / ne02;

    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0);
-    const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio));
-    const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
-    const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
+    const float2 * Q_f2 = (const float2 *) (Q + nb03*i03 + nb02*i02 + nb01*ic0);
+    const half2 * K_h2 = (const half2 *) (K + nb13*i03 + nb12*(i02 / gqa_ratio));
+    const half2 * V_h2 = (const half2 *) (V + nb23*i03 + nb22*(i02 / gqa_ratio)); // K and V have same shape
+    const half * maskh = (const half *) (mask + nb33*(i03 % ne33) + nb32*(i02 % ne32) + nb31*ic0);

    const int stride_KV2 = nb11 / sizeof(half2);

ggml/src/ggml-cuda/fattn-vec-f16.cuh
Lines changed: 8 additions & 4 deletions

@@ -28,8 +28,10 @@ static __global__ void flash_attn_vec_ext_f16(
        const int ne13,
        const int ne31,
        const int ne32,
+        const int ne33,
        const int nb31,
        const int nb32,
+        const int nb33,
        const int nb01,
        const int nb02,
        const int nb03,
@@ -64,13 +66,15 @@ static __global__ void flash_attn_vec_ext_f16(
    constexpr dequantize_1_f16_t dequantize_1_v = get_dequantize_1_f16(type_V);

    const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
+    const int i02 = blockIdx.z % ne02;
+    const int i03 = blockIdx.z / ne02;

    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    Q += nb02* blockIdx.z + nb01*ic0;
-    K += nb12*(blockIdx.z / gqa_ratio);
-    V += nb22*(blockIdx.z / gqa_ratio);
+    Q += nb03*i03 + nb02*i02 + nb01*ic0;
+    K += nb13*i03 + nb12*(i02 / gqa_ratio);
+    V += nb23*i03 + nb22*(i02 / gqa_ratio);

-    const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
+    const half * maskh = (const half *) (mask + nb33*(i03 % ne33) + nb32*(i02 % ne32) + nb31*ic0);

    const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
    const half slopeh = __float2half(slopef);

ggml/src/ggml-cuda/fattn-vec-f32.cuh
Lines changed: 8 additions & 4 deletions

@@ -28,8 +28,10 @@ static __global__ void flash_attn_vec_ext_f32(
        const int ne13,
        const int ne31,
        const int ne32,
+        const int ne33,
        const int nb31,
        const int nb32,
+        const int nb33,
        const int nb01,
        const int nb02,
        const int nb03,
@@ -76,13 +78,15 @@ static __global__ void flash_attn_vec_ext_f32(
    constexpr dequantize_1_f32_t dequantize_1_v = get_dequantize_1_f32(type_V);

    const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
+    const int i02 = blockIdx.z % ne02;
+    const int i03 = blockIdx.z / ne02;

    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    Q += nb02* blockIdx.z + nb01*ic0;
-    K += nb12*(blockIdx.z / gqa_ratio);
-    V += nb22*(blockIdx.z / gqa_ratio); // K and V have same shape
+    Q += nb03*i03 + nb02*i02 + nb01*ic0;
+    K += nb13*i03 + nb12*(i02 / gqa_ratio);
+    V += nb23*i03 + nb22*(i02 / gqa_ratio);

-    const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
+    const half * maskh = (const half *) (mask + nb33*(i03 % ne33) + nb32*(i02 % ne32) + nb31*ic0);

    const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);

ggml/src/ggml-cuda/fattn-wmma-f16.cu
Lines changed: 8 additions & 4 deletions

@@ -47,8 +47,10 @@ static __global__ void flash_attn_ext_f16(
        const int ne13,
        const int ne31,
        const int ne32,
+        const int ne33,
        const int nb31,
        const int nb32,
+        const int nb33,
        const int nb01,
        const int nb02,
        const int nb03,
@@ -74,6 +76,8 @@ static __global__ void flash_attn_ext_f16(
    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

    const int ic0 = ncols*blockIdx.x; // Index of the first Q/QKV column to work on.
+    const int i02 = blockIdx.z % ne02;
+    const int i03 = blockIdx.z / ne02;

    static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE.");
    static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16.");
@@ -96,10 +100,10 @@ static __global__ void flash_attn_ext_f16(
    constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);

    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float * Q_f = (const float *) (Q + nb02* blockIdx.z + nb01*ic0);
-    const half * K_h = (const half *) (K + nb12*(blockIdx.z / gqa_ratio));
-    const half * V_h = (const half *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
-    const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
+    const float * Q_f = (const float *) (Q + nb03*i03 + nb02*i02 + nb01*ic0);
+    const half * K_h = (const half *) (K + nb13*i03 + nb12*(i02 / gqa_ratio));
+    const half * V_h = (const half *) (V + nb23*i03 + nb22*(i02 / gqa_ratio)); // K and V have same shape
+    const half * maskh = (const half *) (mask + nb33*(i03 % ne33) + nb32*(i02 % ne32) + nb31*ic0);
    const half2 * mask2 = (const half2 *) maskh;

    const int stride_Q = nb01 / sizeof(float);

ggml/src/ggml-cuda/ggml-cuda.cu
Lines changed: 0 additions & 5 deletions

@@ -3376,11 +3376,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
            if (op->src[0]->ne[0] == 192) {
                return false;
            }
-            // TODO: support broadcast
-            // ref: https://github.com/ggml-org/llama.cpp/pull/14435
-            if (op->src[0]->ne[3] != 1) {
-                return false;
-            }
            if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) {
                return false;
            }
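
With the kernels now handling dim-3 offsets and mask broadcasting, the backend no longer rejects FLASH_ATTN_EXT ops whose Q has ne[3] > 1 (the broadcast TODO that referenced llama.cpp PR #14435). As a rough CPU-side reference of the semantics this enables (hypothetical layout and function, for illustration only; not how the CUDA kernels store data), the mask may have smaller extents than Q along dims 2/3, in which case its index wraps around:

#include <cstddef>
#include <vector>

// scores holds one slice of slice_size values per (batch i03, head i02);
// mask holds ne33*ne32 slices, where ne32/ne33 are either 1 (broadcast)
// or equal to ne02/ne03.
static void apply_mask(std::vector<float> & scores, const std::vector<float> & mask,
                       int ne02, int ne03, int ne32, int ne33, size_t slice_size) {
    for (int i03 = 0; i03 < ne03; ++i03) {
        for (int i02 = 0; i02 < ne02; ++i02) {
            const size_t m_base = ((size_t)(i03 % ne33)*ne32 + (i02 % ne32))*slice_size;
            const size_t s_base = ((size_t)i03*ne02 + i02)*slice_size;
            for (size_t r = 0; r < slice_size; ++r) {
                scores[s_base + r] += mask[m_base + r]; // additive attention mask
            }
        }
    }
}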
