@@ -5232,14 +5232,17 @@ static void ggml_compute_forward_soft_max_f32(
     memcpy(&scale,    (float *) dst->op_params + 0, sizeof(float));
     memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
 
-    // TODO: handle transposed/permuted matrices
-
     const int ith = params->ith;
     const int nth = params->nth;
 
     GGML_TENSOR_UNARY_OP_LOCALS
 
-    //const int64_t ne11 = src1 ? src1->ne[1] : 1;
+    const int64_t nb11 = src1 ? src1->nb[1] : 1;
+    const int64_t nb12 = src1 ? src1->nb[2] : 1;
+    const int64_t nb13 = src1 ? src1->nb[3] : 1;
+
+    const int64_t ne12 = src1 ? src1->ne[2] : 1;
+    const int64_t ne13 = src1 ? src1->ne[3] : 1;
 
     // TODO: is this supposed to be ceil instead of floor?
     //   https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
@@ -5249,68 +5252,66 @@ static void ggml_compute_forward_soft_max_f32(
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
-    const int nc = src0->ne[0];
-    const int nr = ggml_nrows(src0);
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
+    float * wp = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
 
     const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
 
-    for (int i1 = ir0; i1 < ir1; i1++) {
-        // ALiBi
-        const uint32_t h = (i1/ne01)%ne02; // head
-        const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
-
-        float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
-        float * dp = (float *)((char *)  dst->data +  i1*dst->nb[1]);
-
-        // broadcast the mask across rows
-        ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
-        float       * mp_f32 = src1 ? (float       *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
-
-        ggml_vec_cpy_f32  (nc, wp, sp);
-        ggml_vec_scale_f32(nc, wp, scale);
-        if (mp_f32) {
-            if (use_f16) {
-                for (int i = 0; i < nc; ++i) {
-                    wp[i] += slope*GGML_CPU_FP16_TO_FP32(mp_f16[i]);
-                }
-            } else {
-                for (int i = 0; i < nc; ++i) {
-                    wp[i] += slope*mp_f32[i];
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
+                const int64_t i11 = i01;
+                const int64_t i12 = i02%ne12;
+                const int64_t i13 = i03%ne13;
+
+                // ALiBi
+                const uint32_t h = i02; // head
+                const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
+
+                float * sp = (float *)((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+                float * dp = (float *)((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3);
+
+                // broadcast the mask across rows
+                ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
+                float       * mp_f32 = src1 ? (float       *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
+
+                ggml_vec_cpy_f32  (ne00, wp, sp);
+                ggml_vec_scale_f32(ne00, wp, scale);
+                if (mp_f32) {
+                    if (use_f16) {
+                        for (int i = 0; i < ne00; ++i) {
+                            wp[i] += slope*GGML_CPU_FP16_TO_FP32(mp_f16[i]);
+                        }
+                    } else {
+                        for (int i = 0; i < ne00; ++i) {
+                            wp[i] += slope*mp_f32[i];
+                        }
+                    }
                 }
-            }
-        }
 
 #ifndef NDEBUG
-        for (int i = 0; i < nc; ++i) {
-            //printf("p[%d] = %f\n", i, p[i]);
-            assert(!isnan(wp[i]));
-        }
+                for (int i = 0; i < ne00; ++i) {
+                    //printf("p[%d] = %f\n", i, p[i]);
+                    assert(!isnan(wp[i]));
+                }
#endif
 
-        float max = -INFINITY;
-        ggml_vec_max_f32(nc, &max, wp);
+                float max = -INFINITY;
+                ggml_vec_max_f32(ne00, &max, wp);
 
-        ggml_float sum = ggml_vec_soft_max_f32(nc, dp, wp, max);
-        assert(sum > 0.0);
+                ggml_float sum = ggml_vec_soft_max_f32(ne00, dp, wp, max);
+                assert(sum > 0.0);
 
-        sum = 1.0/sum;
-        ggml_vec_scale_f32(nc, dp, sum);
+                sum = 1.0/sum;
+                ggml_vec_scale_f32(ne00, dp, sum);
 
 #ifndef NDEBUG
-        for (int i = 0; i < nc; ++i) {
-            assert(!isnan(dp[i]));
-            assert(!isinf(dp[i]));
-        }
+                for (int i = 0; i < ne00; ++i) {
+                    assert(!isnan(dp[i]));
+                    assert(!isinf(dp[i]));
+                }
 #endif
+            }
+        }
     }
 }
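Note on the loop rewrite: the old code flattened src0 into rows and gave each thread a contiguous chunk of dr rows; the new code walks the tensor as 4D and interleaves rows across threads (i01 starts at ith and advances by nth), while the mask indices i12/i13 use modulo to broadcast a smaller mask over dims 2 and 3. The ALiBi slope expression itself is unchanged; a self-contained sketch of it, with a hypothetical function name (illustration only, not part of the patch):

    #include <math.h>
    #include <stdint.h>

    // ALiBi slope for attention head h, as computed in the code above.
    // n_head_log2 is 2^floor(log2(n_head)); max_bias <= 0 disables the bias.
    static float alibi_slope(uint32_t h, uint32_t n_head_log2, float max_bias) {
        if (max_bias <= 0.0f) {
            return 1.0f; // no ALiBi
        }
        const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
        const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
        // the first n_head_log2 heads follow the m0 geometric sequence, the rest m1
        return h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1);
    }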
@@ -7766,7 +7767,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
-    ggml_type         const k_vec_dot_type = ggml_get_type_traits_cpu(k->type)->vec_dot_type;
+    ggml_type         const k_vec_dot_type   = ggml_get_type_traits_cpu(k->type)->vec_dot_type;
     ggml_from_float_t const q_to_vec_dot     = ggml_get_type_traits_cpu(k_vec_dot_type)->from_float;
     ggml_vec_dot_t    const kq_vec_dot       = ggml_get_type_traits_cpu(k->type)->vec_dot;
     ggml_to_float_t   const v_to_float       = ggml_get_type_traits(v->type)->to_float;
@@ -7798,7 +7799,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
             memset(VKQ32, 0, DV*sizeof(float));
         }
 
-        const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
+        const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1] + (iq3%mask->ne[2])*mask->nb[2]) : NULL;
 
         // k indices
         const int ik3 = iq3 / rk3;
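Note: the added (iq3%mask->ne[2])*mask->nb[2] term applies the same modulo-broadcast idea to the flash-attention mask: with mask->ne[2] == 1, iq3 % 1 == 0 for every batch, so all batches share one mask slice; with mask->ne[2] equal to the batch count, each batch gets its own slice. A two-line sketch with a hypothetical helper name:

    // Illustration only (not part of the patch): modulo broadcast over dim 2.
    static inline int64_t mask_i2(int64_t iq3, int64_t mask_ne2) {
        return iq3 % mask_ne2; // mirrors (iq3 % mask->ne[2]) above
    }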