diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index b9d3639448500..651943fa92380 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -20,8 +20,8 @@ #define N_R0_Q5_1 4 #define N_SG_Q5_1 2 -#define N_R0_Q8_0 4 -#define N_SG_Q8_0 2 +#define N_R0_Q8_0 2 +#define N_SG_Q8_0 4 #define N_R0_MXFP4 2 #define N_SG_MXFP4 2 @@ -68,6 +68,11 @@ #define N_R0_IQ4_XS 2 #define N_SG_IQ4_XS 2 +// function constants offsets +#define FC_FLASH_ATTN_EXT 100 +#define FC_FLASH_ATTN_EXT_VEC 200 +#define FC_FLASH_ATTN_EXT_VEC_REDUCE 300 + // kernel argument structs // // - element counters (e.g. ne00) typically use int32_t to reduce register usage @@ -236,9 +241,11 @@ typedef struct { int32_t ne11; int32_t ne_12_2; // assume K and V are same shape int32_t ne_12_3; + int32_t ns10; uint64_t nb11; uint64_t nb12; uint64_t nb13; + int32_t ns20; uint64_t nb21; uint64_t nb22; uint64_t nb23; @@ -258,10 +265,43 @@ typedef struct { float logit_softcap; } ggml_metal_kargs_flash_attn_ext; +typedef struct { + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne11; + int32_t ne_12_2; // assume K and V are same shape + int32_t ne_12_3; + int32_t ns10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int32_t ns20; + uint64_t nb21; + uint64_t nb22; + uint64_t nb23; + int32_t ne32; + int32_t ne33; + uint64_t nb31; + uint64_t nb32; + uint64_t nb33; + int32_t ne1; + int32_t ne2; + int32_t ne3; + float scale; + float max_bias; + float m0; + float m1; + int32_t n_head_log2; + float logit_softcap; +} ggml_metal_kargs_flash_attn_ext_vec; + typedef struct { int32_t nrows; - int32_t ne20; -} ggml_metal_kargs_flash_attn_ext_reduce; +} ggml_metal_kargs_flash_attn_ext_vec_reduce; typedef struct { int32_t ne00; diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index c1a0a2bef171e..f2b19e0dd42ae 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -174,6 +174,19 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte id pipeline; }; +@interface ggml_metal_kernel_wrapper : NSObject + +@property (nonatomic, assign) struct ggml_metal_kernel kernel; + +@end + +@implementation ggml_metal_kernel_wrapper +- (void) dealloc { + [_kernel.pipeline release]; + [super dealloc]; +} +@end + enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_ADD, GGML_METAL_KERNEL_TYPE_ADD_FUSE_2, @@ -454,126 +467,6 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H40, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H40, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H40, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H40, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H40, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H40, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H40, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H192, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H192, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H192, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H192, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H192, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H192, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H192, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK192_HV128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK192_HV128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK192_HV128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK192_HV128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK192_HV128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK192_HV128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK192_HV128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_REDUCE, GGML_METAL_KERNEL_TYPE_SET_I32, GGML_METAL_KERNEL_TYPE_SET_F32, GGML_METAL_KERNEL_TYPE_CPY_F32_F32, @@ -882,8 +775,12 @@ static void ggml_metal_mem_pool_clear(struct ggml_metal_mem_pool * mem_pool) { dispatch_queue_t d_queue; + // the set of pre-compiled kernels for this context struct ggml_metal_kernel kernels[GGML_METAL_KERNEL_TYPE_COUNT]; + // additional, inference-time compiled kernels + NSMutableDictionary * kernels_ext; + // capture state bool capture_next_compute; bool capture_started; @@ -949,6 +846,8 @@ @implementation GGMLMetalClass // - if not found, load the source and compile it // - if that fails, return NULL static id ggml_metal_load_library(id device, bool use_bfloat) { + const int64_t t_start = ggml_time_us(); + id metal_library = nil; NSError * error = nil; NSString * src = nil; @@ -1072,6 +971,8 @@ @implementation GGMLMetalClass [src release]; #endif // GGML_METAL_EMBED_LIBRARY + GGML_LOG_INFO("%s: loaded in %.3f sec\n", __func__, (ggml_time_us() - t_start) / 1e6); + return metal_library; } @@ -1269,7 +1170,7 @@ @implementation GGMLMetalClass GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_MXFP4, get_rows_mxfp4, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_MXFP4, get_rows_mxfp4, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true); @@ -1487,126 +1388,6 @@ @implementation GGMLMetalClass GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H40, flash_attn_ext_f16_h40, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192, flash_attn_ext_f16_h192, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128, flash_attn_ext_f16_hk192_hv128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512, flash_attn_ext_f16_hk576_hv512, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H40, flash_attn_ext_bf16_h40, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64, flash_attn_ext_bf16_h64, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80, flash_attn_ext_bf16_h80, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96, flash_attn_ext_bf16_h96, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112, flash_attn_ext_bf16_h112, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128, flash_attn_ext_bf16_h128, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192, flash_attn_ext_bf16_h192, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128, flash_attn_ext_bf16_hk192_hv128, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256, flash_attn_ext_bf16_h256, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512, flash_attn_ext_bf16_hk576_hv512, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H40, flash_attn_ext_q4_0_h40, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, flash_attn_ext_q4_0_h64, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, flash_attn_ext_q4_0_h80, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, flash_attn_ext_q4_0_h96, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112, flash_attn_ext_q4_0_h112, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128, flash_attn_ext_q4_0_h128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192, flash_attn_ext_q4_0_h192, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128, flash_attn_ext_q4_0_hk192_hv128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256, flash_attn_ext_q4_0_h256, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512, flash_attn_ext_q4_0_hk576_hv512, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H40, flash_attn_ext_q4_1_h40, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64, flash_attn_ext_q4_1_h64, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80, flash_attn_ext_q4_1_h80, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96, flash_attn_ext_q4_1_h96, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112, flash_attn_ext_q4_1_h112, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128, flash_attn_ext_q4_1_h128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192, flash_attn_ext_q4_1_h192, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128, flash_attn_ext_q4_1_hk192_hv128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256, flash_attn_ext_q4_1_h256, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512, flash_attn_ext_q4_1_hk576_hv512, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H40, flash_attn_ext_q5_0_h40, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64, flash_attn_ext_q5_0_h64, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80, flash_attn_ext_q5_0_h80, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96, flash_attn_ext_q5_0_h96, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112, flash_attn_ext_q5_0_h112, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128, flash_attn_ext_q5_0_h128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192, flash_attn_ext_q5_0_h192, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128, flash_attn_ext_q5_0_hk192_hv128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256, flash_attn_ext_q5_0_h256, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512, flash_attn_ext_q5_0_hk576_hv512, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H40, flash_attn_ext_q5_1_h40, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64, flash_attn_ext_q5_1_h64, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80, flash_attn_ext_q5_1_h80, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96, flash_attn_ext_q5_1_h96, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112, flash_attn_ext_q5_1_h112, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128, flash_attn_ext_q5_1_h128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192, flash_attn_ext_q5_1_h192, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128, flash_attn_ext_q5_1_hk192_hv128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256, flash_attn_ext_q5_1_h256, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512, flash_attn_ext_q5_1_hk576_hv512, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H40, flash_attn_ext_q8_0_h40, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, flash_attn_ext_q8_0_h64, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, flash_attn_ext_q8_0_h80, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, flash_attn_ext_q8_0_h96, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112, flash_attn_ext_q8_0_h112, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, flash_attn_ext_q8_0_h128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192, flash_attn_ext_q8_0_h192, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, flash_attn_ext_q8_0_hk192_hv128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512, flash_attn_ext_q8_0_hk576_hv512, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64, flash_attn_ext_vec_f16_h64, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64, flash_attn_ext_vec_bf16_h64, has_simdgroup_reduction && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64, flash_attn_ext_vec_q4_0_h64, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H64, flash_attn_ext_vec_q4_1_h64, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H64, flash_attn_ext_vec_q5_0_h64, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H64, flash_attn_ext_vec_q5_1_h64, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64, flash_attn_ext_vec_q8_0_h64, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96, flash_attn_ext_vec_f16_h96, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96, flash_attn_ext_vec_bf16_h96, has_simdgroup_reduction && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96, flash_attn_ext_vec_q4_0_h96, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H96, flash_attn_ext_vec_q4_1_h96, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H96, flash_attn_ext_vec_q5_0_h96, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H96, flash_attn_ext_vec_q5_1_h96, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96, flash_attn_ext_vec_q8_0_h96, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128, flash_attn_ext_vec_bf16_h128, has_simdgroup_reduction && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, flash_attn_ext_vec_q4_0_h128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128, flash_attn_ext_vec_q4_1_h128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128, flash_attn_ext_vec_q5_0_h128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128, flash_attn_ext_vec_q5_1_h128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, flash_attn_ext_vec_q8_0_h128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H192, flash_attn_ext_vec_f16_h192, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H192, flash_attn_ext_vec_bf16_h192, has_simdgroup_reduction && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H192, flash_attn_ext_vec_q4_0_h192, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H192, flash_attn_ext_vec_q4_1_h192, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H192, flash_attn_ext_vec_q5_0_h192, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H192, flash_attn_ext_vec_q5_1_h192, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H192, flash_attn_ext_vec_q8_0_h192, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK192_HV128, flash_attn_ext_vec_f16_hk192_hv128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK192_HV128, flash_attn_ext_vec_bf16_hk192_hv128, has_simdgroup_reduction && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK192_HV128, flash_attn_ext_vec_q4_0_hk192_hv128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK192_HV128, flash_attn_ext_vec_q4_1_hk192_hv128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK192_HV128, flash_attn_ext_vec_q5_0_hk192_hv128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK192_HV128, flash_attn_ext_vec_q5_1_hk192_hv128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK192_HV128, flash_attn_ext_vec_q8_0_hk192_hv128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256, flash_attn_ext_vec_bf16_h256, has_simdgroup_reduction && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256, flash_attn_ext_vec_q4_0_h256, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256, flash_attn_ext_vec_q4_1_h256, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, flash_attn_ext_vec_q5_1_h256, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512, flash_attn_ext_vec_f16_hk576_hv512, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512, flash_attn_ext_vec_bf16_hk576_hv512, has_simdgroup_reduction && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512, flash_attn_ext_vec_q4_0_hk576_hv512, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512, flash_attn_ext_vec_q4_1_hk576_hv512, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512, flash_attn_ext_vec_q5_0_hk576_hv512, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512, flash_attn_ext_vec_q5_1_hk576_hv512, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512, flash_attn_ext_vec_q8_0_hk576_hv512, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_REDUCE, flash_attn_ext_reduce, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_F32, set_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_I32, set_i32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); @@ -1651,9 +1432,219 @@ @implementation GGMLMetalClass GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true); } + ctx->kernels_ext = [[NSMutableDictionary alloc] init]; + return ctx; } +static id ggml_metal_get_kernel(struct ggml_backend_metal_context * ctx, const char * name) { + NSString * key = [NSString stringWithUTF8String:name]; + + ggml_metal_kernel_wrapper * obj = [ctx->kernels_ext objectForKey:key]; + if (obj) { + return obj.kernel.pipeline; + } + + return nil; +} + +static id ggml_metal_compile_kernel(ggml_backend_t backend, const char * base, const char * name, MTLFunctionConstantValues * cv) { + struct ggml_backend_metal_context * ctx = backend->context; + struct ggml_backend_metal_device_context * ctx_dev = backend->device->context; + + id res = nil; + + @autoreleasepool { + NSError * error = nil; + + NSString * base_func = [NSString stringWithUTF8String:base]; + + GGML_LOG_DEBUG("%s: compiling kernel: base = '%s', name = '%s'\n", __func__, base, name); + + // TODO: make sure it is thread-safe to compile kernels in parallel + id metal_function = [ctx_dev->mtl_library newFunctionWithName:base_func constantValues:cv error:&error]; + if (!metal_function) { + GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); + + return nil; + } + + struct ggml_metal_kernel kernel = { + /*.pipeline =*/ [ctx_dev->mtl_device newComputePipelineStateWithFunction:metal_function error:&error], + }; + + ggml_metal_kernel_wrapper * obj = [[ggml_metal_kernel_wrapper alloc] init]; + obj.kernel = kernel; + + res = obj.kernel.pipeline; + + NSString * key = [NSString stringWithUTF8String:name]; + [ctx->kernels_ext setObject:obj forKey:key]; + + GGML_LOG_DEBUG("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, name, (void *) kernel.pipeline, + (int) kernel.pipeline.maxTotalThreadsPerThreadgroup, + (int) kernel.pipeline.threadExecutionWidth); + } + + return res; +} + +static id ggml_metal_get_pipeline_flash_attn_ext( + ggml_backend_t backend, struct ggml_tensor * op, + bool has_mask, + bool has_sinks, + bool has_bias, + bool has_scap, + int32_t nsg) { + struct ggml_backend_metal_context * ctx = backend->context; + + char base[256]; + char name[256]; + + @autoreleasepool { + MTLFunctionConstantValues * cv = [[MTLFunctionConstantValues alloc] init]; + + const int32_t dk = (int32_t) op->src[1]->ne[0]; + const int32_t dv = (int32_t) op->src[2]->ne[0]; + + const int32_t ns10 = op->src[1]->nb[1]/op->src[1]->nb[0]; + const int32_t ns20 = op->src[2]->nb[1]/op->src[2]->nb[0]; + + snprintf(base, 256, "kernel_%s_%s_dk%d_dv%d", + "flash_attn_ext", + ggml_type_name(op->src[1]->type), + dk, + dv); + + snprintf(name, 256, "kernel_%s_%s_dk%d_dv%d_mask=%d_sinks=%d_bias=%d_scap=%d_ns10=%d_ns20=%d_nsg=%d", + "flash_attn_ext", + ggml_type_name(op->src[1]->type), + dk, + dv, + has_mask, + has_sinks, + has_bias, + has_scap, + ns10, + ns20, + nsg); + + id res = ggml_metal_get_kernel(ctx, name); + if (res) { + // kernel found + return res; + } + + cv = [[MTLFunctionConstantValues alloc] init]; + + [cv setConstantValue:&has_mask type:MTLDataTypeBool atIndex:FC_FLASH_ATTN_EXT + 0]; + [cv setConstantValue:&has_sinks type:MTLDataTypeBool atIndex:FC_FLASH_ATTN_EXT + 1]; + [cv setConstantValue:&has_bias type:MTLDataTypeBool atIndex:FC_FLASH_ATTN_EXT + 2]; + [cv setConstantValue:&has_scap type:MTLDataTypeBool atIndex:FC_FLASH_ATTN_EXT + 3]; + + [cv setConstantValue:&ns10 type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT + 20]; + [cv setConstantValue:&ns20 type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT + 21]; + [cv setConstantValue:&nsg type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT + 22]; + + return ggml_metal_compile_kernel(backend, base, name, cv); + } +} + +static id ggml_metal_get_pipeline_flash_attn_ext_vec( + ggml_backend_t backend, struct ggml_tensor * op, + bool has_mask, + bool has_sinks, + bool has_bias, + bool has_scap, + int32_t nsg, + int32_t nwg) { + struct ggml_backend_metal_context * ctx = backend->context; + + char base[256]; + char name[256]; + + @autoreleasepool { + MTLFunctionConstantValues * cv = [[MTLFunctionConstantValues alloc] init]; + + const int32_t dk = (int32_t) op->src[1]->ne[0]; + const int32_t dv = (int32_t) op->src[2]->ne[0]; + + const int32_t ns10 = op->src[1]->nb[1]/op->src[1]->nb[0]; + const int32_t ns20 = op->src[2]->nb[1]/op->src[2]->nb[0]; + + snprintf(base, 256, "kernel_%s_%s_dk%d_dv%d", + "flash_attn_ext_vec", + ggml_type_name(op->src[1]->type), + dk, + dv); + + snprintf(name, 256, "kernel_%s_%s_dk%d_dv%d_mask=%d_sink=%d_bias=%d_softcap=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d", + "flash_attn_ext_vec", + ggml_type_name(op->src[1]->type), + dk, + dv, + has_mask, + has_sinks, + has_bias, + has_scap, + ns10, + ns20, + nsg, nwg); + + id res = ggml_metal_get_kernel(ctx, name); + if (res) { + // kernel found + return res; + } + + cv = [[MTLFunctionConstantValues alloc] init]; + + [cv setConstantValue:&has_mask type:MTLDataTypeBool atIndex:FC_FLASH_ATTN_EXT_VEC + 0]; + [cv setConstantValue:&has_sinks type:MTLDataTypeBool atIndex:FC_FLASH_ATTN_EXT_VEC + 1]; + [cv setConstantValue:&has_bias type:MTLDataTypeBool atIndex:FC_FLASH_ATTN_EXT_VEC + 2]; + [cv setConstantValue:&has_scap type:MTLDataTypeBool atIndex:FC_FLASH_ATTN_EXT_VEC + 3]; + + [cv setConstantValue:&ns10 type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT_VEC + 20]; + [cv setConstantValue:&ns20 type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT_VEC + 21]; + [cv setConstantValue:&nsg type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT_VEC + 22]; + [cv setConstantValue:&nwg type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT_VEC + 23]; + + return ggml_metal_compile_kernel(backend, base, name, cv); + } +} + +static id ggml_metal_get_pipeline_flash_attn_ext_vec_reduce( + ggml_backend_t backend, struct ggml_tensor * op, + int32_t dv, + int32_t nwg) { + struct ggml_backend_metal_context * ctx = backend->context; + + char base[256]; + char name[256]; + + @autoreleasepool { + MTLFunctionConstantValues * cv = [[MTLFunctionConstantValues alloc] init]; + + snprintf(base, 256, "kernel_flash_attn_ext_vec_reduce"); + snprintf(name, 256, "kernel_flash_attn_ext_vec_reduce_dv=%d_nwg=%d", dv, nwg); + + id res = ggml_metal_get_kernel(ctx, name); + if (res) { + // kernel found + return res; + } + + cv = [[MTLFunctionConstantValues alloc] init]; + + [cv setConstantValue:&dv type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT_VEC_REDUCE + 0]; + [cv setConstantValue:&nwg type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT_VEC_REDUCE + 1]; + + return ggml_metal_compile_kernel(backend, base, name, cv); + } + + GGML_UNUSED(op); +} + static void ggml_metal_free(struct ggml_backend_metal_context * ctx) { GGML_LOG_INFO("%s: deallocating\n", __func__); @@ -1661,6 +1652,11 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) { [ctx->kernels[i].pipeline release]; } + if (ctx->kernels_ext) { + [ctx->kernels_ext release]; + ctx->kernels_ext = nil; + } + Block_release(ctx->encode_async); [ctx->queue release]; @@ -3765,6 +3761,7 @@ static int ggml_metal_encode_node( { nsg = N_SG_Q8_0; nr0 = N_R0_Q8_0; + smem = 32*sizeof(float)*N_R0_Q8_0; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline; } break; case GGML_TYPE_MXFP4: @@ -3901,7 +3898,12 @@ static int ggml_metal_encode_node( if (smem > 0) { [encoder setThreadgroupMemoryLength:smem atIndex:0]; } - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0*nsg - 1)/(nr0*nsg), (ne11 + nr1 - 1)/nr1, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; + + if (src0t == GGML_TYPE_Q8_0) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0 - 1)/(nr0), (ne11 + nr1 - 1)/nr1, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; + } else { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0*nsg - 1)/(nr0*nsg), (ne11 + nr1 - 1)/nr1, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; + } } } break; case GGML_OP_MUL_MAT_ID: @@ -4122,6 +4124,7 @@ static int ggml_metal_encode_node( { nsg = N_SG_Q8_0; nr0 = N_R0_Q8_0; + smem = 32*sizeof(float)*N_R0_Q8_0; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline; } break; case GGML_TYPE_MXFP4: @@ -4267,7 +4270,12 @@ static int ggml_metal_encode_node( if (smem > 0) { [encoder setThreadgroupMemoryLength:smem atIndex:0]; } - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0*nsg - 1)/(nr0*nsg), (_ne1 + nr1 - 1)/nr1, ne123) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; + + if (src0t == GGML_TYPE_Q8_0) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0 - 1)/(nr0), (_ne1 + nr1 - 1)/nr1, ne123) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; + } else { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0*nsg - 1)/(nr0*nsg), (_ne1 + nr1 - 1)/nr1, ne123) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; + } } } break; case GGML_OP_GET_ROWS: @@ -5118,6 +5126,7 @@ static int ggml_metal_encode_node( float scale; float max_bias; float logit_softcap; + memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(scale)); memcpy(&max_bias, ((const int32_t *) dst->op_params) + 1, sizeof(max_bias)); memcpy(&logit_softcap, ((const int32_t *) dst->op_params) + 2, sizeof(logit_softcap)); @@ -5126,398 +5135,24 @@ static int ggml_metal_encode_node( scale /= logit_softcap; } + const bool has_mask = src3 != NULL; + const bool has_sinks = src4 != NULL; + const bool has_bias = max_bias != 0.0f; + const bool has_scap = logit_softcap != 0.0f; + const uint32_t n_head = src0->ne[2]; const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - id pipeline = nil; - - bool use_vec_kernel = false; + GGML_ASSERT(ne01 < 65536); // use non-vec kernel if the batch size is large or if the vec-kernel is not supported for this head size - if (ne01 >= 20 || (ne00 == 40 || ne00 == 80 || ne00 == 112)) { - switch (src1->type) { - case GGML_TYPE_F16: - { - if (ne00 == 192 && ne20 == 128) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128].pipeline; - } else if (ne00 == 576 && ne20 == 512) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512].pipeline; - } else { - switch (ne00) { - case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H40 ].pipeline; break; - case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break; - case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break; - case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break; - case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break; - case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break; - case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192].pipeline; break; - case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } - } - } - } break; - case GGML_TYPE_BF16: - { - if (ne00 == 192 && ne20 == 128) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128].pipeline; - } else if (ne00 == 576 && ne20 == 512) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512].pipeline; - } else { - switch (ne00) { - case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H40 ].pipeline; break; - case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64 ].pipeline; break; - case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80 ].pipeline; break; - case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96 ].pipeline; break; - case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112].pipeline; break; - case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128].pipeline; break; - case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192].pipeline; break; - case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } - } - } - } break; - case GGML_TYPE_Q4_0: - { - if (ne00 == 192 && ne20 == 128) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128].pipeline; - } else if (ne00 == 576 && ne20 == 512) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512].pipeline; - } else { - switch (ne00) { - case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H40 ].pipeline; break; - case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64 ].pipeline; break; - case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80 ].pipeline; break; - case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96 ].pipeline; break; - case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112].pipeline; break; - case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128].pipeline; break; - case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192].pipeline; break; - case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } - } - } - } break; - case GGML_TYPE_Q4_1: - { - if (ne00 == 192 && ne20 == 128) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128].pipeline; - } else if (ne00 == 576 && ne20 == 512) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512].pipeline; - } else { - switch (ne00) { - case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H40 ].pipeline; break; - case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64 ].pipeline; break; - case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80 ].pipeline; break; - case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96 ].pipeline; break; - case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112].pipeline; break; - case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128].pipeline; break; - case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192].pipeline; break; - case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } - } - } - } break; - case GGML_TYPE_Q5_0: - { - if (ne00 == 192 && ne20 == 128) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128].pipeline; - } else if (ne00 == 576 && ne20 == 512) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512].pipeline; - } else { - switch (ne00) { - case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H40 ].pipeline; break; - case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64 ].pipeline; break; - case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80 ].pipeline; break; - case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96 ].pipeline; break; - case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112].pipeline; break; - case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128].pipeline; break; - case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192].pipeline; break; - case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } - } - } - } break; - case GGML_TYPE_Q5_1: - { - if (ne00 == 192 && ne20 == 128) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128].pipeline; - } else if (ne00 == 576 && ne20 == 512) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512].pipeline; - } else { - switch (ne00) { - case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H40 ].pipeline; break; - case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64 ].pipeline; break; - case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80 ].pipeline; break; - case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96 ].pipeline; break; - case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112].pipeline; break; - case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128].pipeline; break; - case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192].pipeline; break; - case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } - } - } - } break; - case GGML_TYPE_Q8_0: - { - if (ne00 == 192 && ne20 == 128) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128].pipeline; - } else if (ne00 == 576 && ne20 == 512) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512].pipeline; - } else { - switch (ne00) { - case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H40 ].pipeline; break; - case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64 ].pipeline; break; - case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80 ].pipeline; break; - case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96 ].pipeline; break; - case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112].pipeline; break; - case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128].pipeline; break; - case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192].pipeline; break; - case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } - } - } - } break; - default: - { - GGML_LOG_ERROR("unsupported type: %d\n", src1->type); - GGML_LOG_ERROR("add template specialization for this type\n"); - GGML_ABORT("add template specialization for this type"); - } - } - } else { - use_vec_kernel = true; - - switch (ne00) { - case 64: - { - switch (src1->type) { - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64].pipeline; break; - case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64].pipeline; break; - case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64].pipeline; break; - case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H64].pipeline; break; - case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H64].pipeline; break; - case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H64].pipeline; break; - case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported type: %d\n", src1->type); - GGML_LOG_ERROR("add template specialization for this type\n"); - GGML_ABORT("add template specialization for this type"); - } - } - } break; - case 96: - { - switch (src1->type) { - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96].pipeline; break; - case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96].pipeline; break; - case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96].pipeline; break; - case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H96].pipeline; break; - case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H96].pipeline; break; - case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H96].pipeline; break; - case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported type: %d\n", src1->type); - GGML_LOG_ERROR("add template specialization for this type\n"); - GGML_ABORT("add template specialization for this type"); - } - } - } break; - case 128: - { - switch (src1->type) { - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break; - case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128].pipeline; break; - case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128].pipeline; break; - case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128].pipeline; break; - case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128].pipeline; break; - case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128].pipeline; break; - case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported type: %d\n", src1->type); - GGML_LOG_ERROR("add template specialization for this type\n"); - GGML_ABORT("add template specialization for this type"); - } - } - } break; - case 192: - { - if (ne20 == 128) { - switch (src1->type) { - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK192_HV128].pipeline; break; - case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK192_HV128].pipeline; break; - case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK192_HV128].pipeline; break; - case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK192_HV128].pipeline; break; - case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK192_HV128].pipeline; break; - case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK192_HV128].pipeline; break; - case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK192_HV128].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported type: %d\n", src1->type); - GGML_LOG_ERROR("add template specialization for this type\n"); - GGML_ABORT("add template specialization for this type"); - } - } - } else { - switch (src1->type) { - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H192].pipeline; break; - case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H192].pipeline; break; - case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H192].pipeline; break; - case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H192].pipeline; break; - case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H192].pipeline; break; - case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H192].pipeline; break; - case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H192].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported type: %d\n", src1->type); - GGML_LOG_ERROR("add template specialization for this type\n"); - GGML_ABORT("add template specialization for this type"); - } - } - } - } break; - case 256: - { - switch (src1->type) { - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break; - case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256].pipeline; break; - case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256].pipeline; break; - case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256].pipeline; break; - case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256].pipeline; break; - case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256].pipeline; break; - case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported type: %d\n", src1->type); - GGML_LOG_ERROR("add template specialization for this type\n"); - GGML_ABORT("add template specialization for this type"); - } - } - } break; - case 576: - { - if (ne20 == 512) { - switch (src1->type) { - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512].pipeline; break; - case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512].pipeline; break; - case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512].pipeline; break; - case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512].pipeline; break; - case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512].pipeline; break; - case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512].pipeline; break; - case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported type: %d\n", src1->type); - GGML_LOG_ERROR("add template specialization for this type\n"); - GGML_ABORT("add template specialization for this type"); - } - } - } else { - GGML_LOG_ERROR("unsupported size: %lld\n", ne20); - GGML_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } - } break; - default: - { - GGML_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } - } - } - - ggml_metal_kargs_flash_attn_ext args = { - /*.ne01 =*/ ne01, - /*.ne02 =*/ ne02, - /*.ne03 =*/ ne03, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.nb03 =*/ nb03, - /*.ne11 =*/ ne11, - /*.ne_12_2 =*/ ne12, - /*.ne_12_3 =*/ ne13, - /*.nb11 =*/ nb11, - /*.nb12 =*/ nb12, - /*.nb13 =*/ nb13, - /*.nb21 =*/ nb21, - /*.nb22 =*/ nb22, - /*.nb23 =*/ nb23, - /*.ne32 =*/ ne32, - /*.ne33 =*/ ne33, - /*.nb31 =*/ nb31, - /*.nb32 =*/ nb32, - /*.nb33 =*/ nb33, - /*.ne1 =*/ ne1, - /*.ne2 =*/ ne2, - /*.ne3 =*/ ne3, - /*.scale =*/ scale, - /*.max_bias =*/ max_bias, - /*.m0 =*/ m0, - /*.m1 =*/ m1, - /*.n_head_log2 =*/ n_head_log2, - /*.logit_softcap =*/ logit_softcap, - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBytes:&args length:sizeof(args) atIndex:0]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; - if (id_src3) { - [encoder setBuffer:id_src3 offset:offs_src3 atIndex:4]; - } else { - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:4]; - } - if (id_src4) { - [encoder setBuffer:id_src4 offset:offs_src4 atIndex:5]; - } else { - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:5]; - } - - if (!use_vec_kernel) { + if (ne01 >= 20 || (ne00 % 32 != 0)) { // half8x8 kernel const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! - const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! + const int64_t ncpsg = 64; // cache values per simdgroup !! sync with kernel template arguments !! GGML_ASSERT(nqptg <= 32); GGML_ASSERT(nqptg % 8 == 0); @@ -5525,34 +5160,90 @@ static int ggml_metal_encode_node( const int is_q = ggml_is_quantized(src1->type) ? 1 : 0; - // 2*(2*ncpsg + nqptg)*(nsg) - // ncpsg soft_max values + ncpsg mask values + a diagonal scaling matrix (in float) + // 2*(2*ncpsg) + // ncpsg soft_max values + ncpsg mask values // // 16*32*(nsg) // the shared memory needed for the simdgroups to load the KV cache // each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG // -#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(2*ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + is_q*(16*32*(nsg)))*(sizeof(float)/2), 16)) - - int64_t nsgmax = 2; +#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*GGML_PAD(ne20, 64) + 2*(2*ncpsg)) + is_q*(16*32*(nsg)))*(sizeof(float)/2), 16)) - while (true) { - const size_t smem = FATTN_SMEM(nsgmax); - if (smem > device.maxThreadgroupMemoryLength/2) { - break; - } - nsgmax *= 2; - } - nsgmax /= 2; + //int64_t nsgmax = 4; + // + //if (is_q) { + // nsgmax = 2; + // while (true) { + // const size_t smem = FATTN_SMEM(nsgmax); + // if (smem > device.maxThreadgroupMemoryLength/2) { + // break; + // } + // nsgmax *= 2; + // } + // nsgmax /= 2; + //} // simdgroups per threadgroup (a.k.a. warps) - const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4; + //nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4; + int32_t nsg = 4; const size_t smem = FATTN_SMEM(nsg); + ggml_metal_kargs_flash_attn_ext args = { + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne11 =*/ ne11, + /*.ne_12_2 =*/ ne12, + /*.ne_12_3 =*/ ne13, + /*.ns10 =*/ nb11/nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ns20 =*/ nb21/nb20, + /*.nb21 =*/ nb21, + /*.nb22 =*/ nb22, + /*.nb23 =*/ nb23, + /*.ne32 =*/ ne32, + /*.ne33 =*/ ne33, + /*.nb31 =*/ nb31, + /*.nb32 =*/ nb32, + /*.nb33 =*/ nb33, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.scale =*/ scale, + /*.max_bias =*/ max_bias, + /*.m0 =*/ m0, + /*.m1 =*/ m1, + /*.n_head_log2 =*/ n_head_log2, + /*.logit_softcap =*/ logit_softcap, + }; + + id pipeline = ggml_metal_get_pipeline_flash_attn_ext(backend, node, has_mask, has_sinks, has_bias, has_scap, nsg); + + [encoder setComputePipelineState:pipeline]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; + if (id_src3) { + [encoder setBuffer:id_src3 offset:offs_src3 atIndex:4]; + } else { + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:4]; + } + if (id_src4) { + [encoder setBuffer:id_src4 offset:offs_src4 atIndex:5]; + } else { + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:5]; + } + [encoder setBuffer:id_dst offset:offs_dst atIndex:6]; - //printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg); + //printf("smem: %zu, max: %zu, nsg = %d, ne02 = %d, ne12 = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg, ne02, ne12); GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength); [encoder setThreadgroupMemoryLength:smem atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; @@ -5561,7 +5252,7 @@ static int ggml_metal_encode_node( // half4x4 kernel const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !! const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! - const int64_t nkpsg = 1*ncpsg; // TODO: make adjustable + const int64_t nkpsg = 1*ncpsg; GGML_ASSERT(nqptg <= 32); GGML_ASSERT(nqptg % 1 == 0); @@ -5574,8 +5265,7 @@ static int ggml_metal_encode_node( // ne20*(nsg) // each simdgroup has a full f32 head vector in shared mem to accumulate results // -#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + 2*ne20*(nsg))*(sizeof(float)/2), 16)) -//#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)))*(sizeof(float)/2), 16)) +#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + 2*GGML_PAD(ne20, 128)*(nsg))*(sizeof(float)/2), 16)) int64_t nsgmax = 2; while (true) { @@ -5589,7 +5279,8 @@ static int ggml_metal_encode_node( nsgmax /= 2; // simdgroups per threadgroup (a.k.a. warps) - const int64_t nsgt = MAX(2, MIN(nsgmax, MIN((ne11 + nkpsg - 1)/(nkpsg), (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))); + //const int64_t nsgt = MAX(2, MIN(nsgmax, MIN((ne11 + nkpsg - 1)/(nkpsg), (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))); + const int64_t nsgt = MAX(2, MIN(nsgmax, MIN((ne11 + nkpsg - 1)/(nkpsg), (int64_t) 1024/32))); int64_t nsg = 1; while (nsg <= nsgt) { @@ -5599,28 +5290,86 @@ static int ggml_metal_encode_node( // workgroups // each workgroup handles nsg*nkpsg cache values - uint16_t nwg = 1; - if (4*nsg*nkpsg >= ne11) { - const size_t smem = FATTN_SMEM(nsg); + int32_t nwg = 1; + if (false) { + // for small KV caches, we could launch a single workgroup and write the results directly to dst/ + // however, this does not lead to significant improvement, so disabled + nwg = 1; + nsg = 4; + } else { + nwg = 32; + nsg = 1; + while (2*nwg*nsg*nkpsg < ne11 && nsg < 4) { + nsg *= 2; + } + } - //printf("smem: %zu, max: %zu, nsg = %d, nsgmax = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg, (int) nsgmax); - GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength); + ggml_metal_kargs_flash_attn_ext_vec args = { + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne11 =*/ ne11, + /*.ne_12_2 =*/ ne12, + /*.ne_12_3 =*/ ne13, + /*.ns10 =*/ nb11/nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ns20 =*/ nb21/nb20, + /*.nb21 =*/ nb21, + /*.nb22 =*/ nb22, + /*.nb23 =*/ nb23, + /*.ne32 =*/ ne32, + /*.ne33 =*/ ne33, + /*.nb31 =*/ nb31, + /*.nb32 =*/ nb32, + /*.nb33 =*/ nb33, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.scale =*/ scale, + /*.max_bias =*/ max_bias, + /*.m0 =*/ m0, + /*.m1 =*/ m1, + /*.n_head_log2 =*/ n_head_log2, + /*.logit_softcap =*/ logit_softcap, + }; - // using 1 workgroup -> write the result directly into dst - [encoder setBuffer:id_dst offset:offs_dst atIndex:6]; - [encoder setBytes:&nwg length:sizeof(uint16_t) atIndex:7]; + id pipeline = ggml_metal_get_pipeline_flash_attn_ext_vec(backend, node, has_mask, has_sinks, has_bias, has_scap, nsg, nwg); - [encoder setThreadgroupMemoryLength:smem atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; + GGML_ASSERT(nsg*32 <= (int) pipeline.maxTotalThreadsPerThreadgroup); + + [encoder setComputePipelineState:pipeline]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; + if (id_src3) { + [encoder setBuffer:id_src3 offset:offs_src3 atIndex:4]; } else { - nwg = 32; - nsg = MIN(4, nsg); + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:4]; + } + if (id_src4) { + [encoder setBuffer:id_src4 offset:offs_src4 atIndex:5]; + } else { + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:5]; + } - const size_t smem = FATTN_SMEM(nsg); + const size_t smem = FATTN_SMEM(nsg); - //printf("smem: %zu, max: %zu, nsg = %d, nsgmax = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg, (int) nsgmax); - GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength); + //printf("smem: %zu, max: %zu, nsg = %d, nsgmax = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg, (int) nsgmax); + GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength); + + if (nwg == 1) { + // using 1 workgroup -> write the result directly into dst + [encoder setBuffer:id_dst offset:offs_dst atIndex:6]; + [encoder setThreadgroupMemoryLength:smem atIndex:0]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; + } else { // sanity checks GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3); GGML_ASSERT(ne1*ne2*ne3 <= (1u << 31)); @@ -5640,20 +5389,18 @@ static int ggml_metal_encode_node( //printf("ne01 = %d, ne02 = %d, ne03 = %d, ne20 = %d\n", ne01, ne02, ne03, ne20); //printf("needed memory: %.3f MiB\n", (float) (ne01*ne02*ne03*ne20*sizeof(float))/1024.0f/1024.0f); - [encoder setBuffer:h_tmp offset:0 atIndex:6]; - [encoder setBytes:&nwg length:sizeof(uint16_t) atIndex:7]; + [encoder setBuffer:h_tmp offset:0 atIndex:6]; [encoder setThreadgroupMemoryLength:smem atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; // reduce the results from the workgroups { - ggml_metal_kargs_flash_attn_ext_reduce args0 = { + ggml_metal_kargs_flash_attn_ext_vec_reduce args0 = { nrows, - ne20, }; - id pipeline0 = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_REDUCE].pipeline; + id pipeline0 = ggml_metal_get_pipeline_flash_attn_ext_vec_reduce(backend, node, ne20, nwg); [encoder setComputePipelineState:pipeline0]; [encoder setBytes:&args0 length:sizeof(args0) atIndex:0]; @@ -5661,7 +5408,7 @@ static int ggml_metal_encode_node( [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; //printf("ne1 = %d, ne2 = %d, ne3 = %d, ne20 = %d\n", ne1, ne2, ne3, ne20); - [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(32*32, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(32*nwg, 1, 1)]; } } #undef FATTN_SMEM diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 2d56c62674c8e..a5b009dd9974d 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -15,6 +15,10 @@ using namespace metal; #define MIN(x, y) ((x) < (y) ? (x) : (y)) #define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; } +#define PAD2(x, n) (((x) + (n) - 1) & ~((n) - 1)) + +#define FOR_UNROLL(x) _Pragma("clang loop unroll(full)") for (x) + #define N_SIMDWIDTH 32 // assuming SIMD group size is 32 // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf @@ -2755,7 +2759,47 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m; } -template +template +static inline void helper_mv_reduce_and_write( + device float * dst_f32, + float sumf[NR0], + const int r0, + const int ne01, + ushort tiisg, + ushort sgitg, + threadgroup char * shmem) { + threadgroup float * shmem_f32[NR0]; + + for (short row = 0; row < NR0; ++row) { + shmem_f32[row] = (threadgroup float *) shmem + NW*row; + + if (sgitg == 0) { + shmem_f32[row][tiisg] = 0.0f; + } + + sumf[row] = simd_sum(sumf[row]); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (short row = 0; row < NR0; ++row) { + if (tiisg == 0) { + shmem_f32[row][sgitg] = sumf[row]; + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (short row = 0; row < NR0 && r0 + row < ne01; ++row) { + float tot = simd_sum(shmem_f32[row][tiisg]); + + if (tiisg == 0 && sgitg == 0) { + dst_f32[r0 + row] = tot; + } + } +} + +template void mul_vec_q_n_f32_impl( args_t args, device const char * src0, @@ -2765,45 +2809,51 @@ void mul_vec_q_n_f32_impl( uint3 tgpig, ushort tiisg, ushort sgitg) { - const int nb = args.ne00/QK4_0; + constexpr short NQ = 16; - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int im = tgpig.z; + const int nb = args.ne00/QK4_0; - const int first_row = (r0 * nsg + sgitg) * nr0; + const int r0 = (tgpig.x*NSG + sgitg)*NR0; + //const int r0 = tgpig.x*NR0; + const int r1 = tgpig.y; + const int im = tgpig.z; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - //const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; //device const block_q_type * x = (device const block_q_type *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); // pointers to src0 rows - device const block_q_type * ax[nr0]; - for (int row = 0; row < nr0; ++row) { - const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + device const block_q_type * ax[NR0]; + FOR_UNROLL (int row = 0; row < NR0; ++row) { + const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; ax[row] = (device const block_q_type *) ((device char *) src0 + offset0); } - float yl[16]; // src1 vector cache - float sumf[nr0] = {0.f}; + float sumf[NR0] = {0.f}; - const short ix = (tiisg/2); - const short il = (tiisg%2)*8; + const short ix = (tiisg/(NW/NQ)); + const short il = (tiisg%(NW/NQ))*8; - device const float * yb = y + ix*QK4_0 + il; + //const int ib0 = sgitg*NQ + ix; + const int ib0 = ix; + + float yl[16]; // src1 vector cache + + //device const float * yb = y + ix*QK4_0 + il; + device const float * yb = y + ib0*QK4_0 + il; // each thread in a SIMD group deals with half a block. - for (int ib = ix; ib < nb; ib += nw/2) { + //for (int ib = ib0; ib < nb; ib += NSG*NQ) { + for (int ib = ib0; ib < nb; ib += NQ) { float sumy[2] = { 0.f, 0.f }; -#pragma unroll - for (short i = 0; i < 8; i += 2) { + FOR_UNROLL (short i = 0; i < 8; i += 2) { sumy[0] += yb[i + 0] + yb[i + 1]; yl[i + 0] = yb[i + 0]; yl[i + 1] = yb[i + 1]/256.f; @@ -2813,21 +2863,23 @@ void mul_vec_q_n_f32_impl( yl[i + 9] = yb[i + 17]/4096.f; } -#pragma unroll - for (short row = 0; row < nr0; row++) { + FOR_UNROLL (short row = 0; row < NR0; row++) { sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy[0] + sumy[1], yl, il); } yb += QK4_0 * 16; + //yb += NSG*NQ*QK4_0; } device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; - for (int row = 0; row < nr0; ++row) { + //helper_mv_reduce_and_write(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem); + + for (int row = 0; row < NR0; ++row) { const float tot = simd_sum(sumf[row]); - if (tiisg == 0 && first_row + row < args.ne01) { - dst_f32[first_row + row] = tot; + if (tiisg == 0 && r0 + row < args.ne01) { + dst_f32[r0 + row] = tot; } } } @@ -2837,10 +2889,11 @@ kernel void kernel_mul_mv_q4_0_f32( device const char * src0, device const char * src1, device char * dst, + threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + mul_vec_q_n_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } kernel void kernel_mul_mv_q4_1_f32( @@ -2848,10 +2901,11 @@ kernel void kernel_mul_mv_q4_1_f32( device const char * src0, device const char * src1, device char * dst, + threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + mul_vec_q_n_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } kernel void kernel_mul_mv_q5_0_f32( @@ -2859,10 +2913,11 @@ kernel void kernel_mul_mv_q5_0_f32( device const char * src0, device const char * src1, device char * dst, + threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + mul_vec_q_n_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } kernel void kernel_mul_mv_q5_1_f32( @@ -2870,15 +2925,14 @@ kernel void kernel_mul_mv_q5_1_f32( device const char * src0, device const char * src1, device char * dst, + threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + mul_vec_q_n_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -#define NB_Q8_0 8 - -template +template void kernel_mul_mv_q8_0_f32_impl( args_t args, device const char * src0, @@ -2888,66 +2942,65 @@ void kernel_mul_mv_q8_0_f32_impl( uint3 tgpig, ushort tiisg, ushort sgitg) { + constexpr short NQ = 8; + const int nb = args.ne00/QK8_0; - const int r0 = tgpig.x; + const int r0 = tgpig.x*NR0; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * nsg + sgitg) * nr0; - const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - //const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; //device const block_q8_0 * x = (device const block_q8_0 *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); // pointers to src0 rows - device const block_q8_0 * ax[nr0]; - for (int row = 0; row < nr0; ++row) { - const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + device const block_q8_0 * ax[NR0]; + FOR_UNROLL (short row = 0; row < NR0; ++row) { + const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0); } - float yl[NB_Q8_0]; - float sumf[nr0] = { 0.f }; + float sumf[NR0] = { 0.f }; + + const short ix = tiisg/(NW/NQ); + const short il = tiisg%(NW/NQ); + + const int ib0 = sgitg*NQ + ix; - const short ix = tiisg/4; - const short il = tiisg%4; + float yl[NQ]; - device const float * yb = y + ix*QK8_0 + il*NB_Q8_0; + device const float * yb = y + ib0*QK8_0 + il*NQ; - // each thread in a SIMD group deals with NB_Q8_0 quants at a time - for (int ib = ix; ib < nb; ib += nw/4) { - for (short i = 0; i < NB_Q8_0; ++i) { + // each thread in a SIMD group deals with NQ quants at a time + for (int ib = ib0; ib < nb; ib += NSG*NQ) { + for (short i = 0; i < NQ; ++i) { yl[i] = yb[i]; } - for (short row = 0; row < nr0; row++) { - device const int8_t * qs = ax[row][ib].qs + il*NB_Q8_0; + for (short row = 0; row < NR0; row++) { + device const int8_t * qs = ax[row][ib].qs + il*NQ; + float sumq = 0.f; - for (short iq = 0; iq < NB_Q8_0; ++iq) { - sumq += qs[iq] * yl[iq]; + FOR_UNROLL (short i = 0; i < NQ; ++i) { + sumq += qs[i] * yl[i]; } + sumf[row] += sumq*ax[row][ib].d; } - yb += nw*NB_Q8_0; + yb += NSG*NQ*QK8_0; } device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < nr0; ++row) { - const float tot = simd_sum(sumf[row]); - - if (tiisg == 0 && first_row + row < args.ne01) { - dst_f32[first_row + row] = tot; - } - } + helper_mv_reduce_and_write(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem); } [[host_name("kernel_mul_mv_q8_0_f32")]] @@ -2956,10 +3009,11 @@ kernel void kernel_mul_mv_q8_0_f32( device const char * src0, device const char * src1, device char * dst, + threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q8_0_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q8_0_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } // mat-vec kernel processing in chunks of float4 @@ -4197,6 +4251,19 @@ kernel void kernel_leaky_relu_f32( dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * args.slope; } +constant bool FC_flash_attn_ext_has_mask [[function_constant(FC_FLASH_ATTN_EXT + 0)]]; +constant bool FC_flash_attn_ext_has_sinks [[function_constant(FC_FLASH_ATTN_EXT + 1)]]; +constant bool FC_flash_attn_ext_has_bias [[function_constant(FC_FLASH_ATTN_EXT + 2)]]; +constant bool FC_flash_attn_ext_has_scap [[function_constant(FC_FLASH_ATTN_EXT + 3)]]; + +//constant float FC_flash_attn_ext_scale [[function_constant(FC_FLASH_ATTN_EXT + 10)]]; +//constant float FC_flash_attn_ext_max_bias [[function_constant(FC_FLASH_ATTN_EXT + 11)]]; +//constant float FC_flash_attn_ext_logit_softcap [[function_constant(FC_FLASH_ATTN_EXT + 12)]]; + +constant int32_t FC_flash_attn_ext_ns10 [[function_constant(FC_FLASH_ATTN_EXT + 20)]]; +constant int32_t FC_flash_attn_ext_ns20 [[function_constant(FC_FLASH_ATTN_EXT + 21)]]; +constant int32_t FC_flash_attn_ext_nsg [[function_constant(FC_FLASH_ATTN_EXT + 22)]]; + // ref: https://arxiv.org/pdf/2307.08691.pdf template< typename q_t, // query types in shared memory @@ -4211,6 +4278,7 @@ template< typename qk_t, // Q*K types typename qk8x8_t, typename s_t, // soft-max types + typename s2_t, typename s8x8_t, typename o_t, // attention accumulation types typename o4_t, @@ -4221,12 +4289,12 @@ template< typename vd4x4_t, // value type in device memory short nl_v, void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &), - short DK, // K head size - short DV, // V head size - short Q = 8, // queries per threadgroup - short KV = 8, // key/value processed per each simdgroup - short C = 32> // cache items per threadgroup -kernel void kernel_flash_attn_ext( + short DK, // K head size + short DV, // V head size + short Q, // queries per threadgroup + short C, // cache items per threadgroup + short NSG> // number of simd groups +void kernel_flash_attn_ext_impl( constant ggml_metal_kargs_flash_attn_ext & args, device const char * q, device const char * k, @@ -4234,46 +4302,85 @@ kernel void kernel_flash_attn_ext( device const char * mask, device const char * sinks, device char * dst, - threadgroup half * shmem_f16 [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 ntg[[threads_per_threadgroup]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]]) { - const short nsg = ntg.y; // number of simdgroups - - const int iq3 = tgpig[2]; - const int iq2 = tgpig[1]; - const int iq1 = tgpig[0]*Q; + threadgroup half * shmem_f16, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + const ushort iq3 = tgpig[2]; + const ushort iq2 = tgpig[1]; + const ushort iq1 = tgpig[0]*Q; + +#define NS10 (FC_flash_attn_ext_ns10) +#define NS20 (FC_flash_attn_ext_ns20) + + // note: I had some concerns that using this instead of the ugly macros above was affecting performance + // need to re-check carefully and if no regressions are observerd - remove the macros + // the concerns is that maybe using const variables requires extra registers? but not sure if the compiler + // is clever enough to avoid this. unfortunately, using constexpr is not possible with FC + //const short NS10 = FC_flash_attn_ext_ns10; + //const short NS20 = FC_flash_attn_ext_ns20; + + constexpr short KV = 8; constexpr short DK4 = DK/4; constexpr short DK8 = DK/8; constexpr short DK16 = DK/16; constexpr short DV4 = DV/4; - constexpr short DV8 = DV/8; + //constexpr short DV8 = DV/8; constexpr short DV16 = DV/16; + constexpr short PV = PAD2(DV, 64); + constexpr short PV4 = PV/4; + constexpr short PV8 = PV/8; + //constexpr short PV16 = PV/16; + constexpr short NW = N_SIMDWIDTH; - constexpr short SH = (2*C + Q); // shared memory per simdgroup (s_t == float) + constexpr short NQ = Q/NSG; + constexpr short SH = 2*C; // shared memory per simdgroup (s_t == float) + + constexpr short TS = 2*SH; + constexpr short T = DK + 2*PV; // shared memory size per query in (half) + + threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*T); // holds the query data + threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*T); // same as above but in q4_t + threadgroup o_t * so = (threadgroup o_t *) (shmem_f16 + 0*T + Q*DK); // the result for all queries in 8x8 matrices (the O matrix from the paper) + threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0*T + Q*DK); + threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + Q*T); // scratch buffer for attention, mask and diagonal matrix + threadgroup s2_t * ss2 = (threadgroup s2_t *) (shmem_f16 + Q*T); // same as above but in s2_t + + threadgroup k_t * sk = (threadgroup k_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T + Q*TS); // scratch buffer to load K in shared memory + threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T + Q*TS); // same as above but in k4x4_t - const short TS = nsg*SH; // shared memory size per query in (s_t == float) - const short T = 2*DK + 2*TS; // shared memory size per query in (half) + threadgroup v_t * sv = (threadgroup v_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T + Q*TS); // scratch buffer to load V in shared memory + threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T + Q*TS); // same as above but in v4x4_t - threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data - threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t - threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2*sgitg*SH + 2*Q*DK); // scratch buffer for attention, mask and diagonal matrix + // mask storage in shared mem + threadgroup half2 * sm2 = (threadgroup half2 *) (shmem_f16 + Q*T + 2*C); - threadgroup k_t * sk = (threadgroup k_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory - threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t + // per-query mask pointers + device const half2 * pm2[NQ]; - threadgroup v_t * sv = (threadgroup v_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load V in shared memory - threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in v4x4_t + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { + const short j = jj*NSG + sgitg; - // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) - o8x8_t lo[DV8]; + pm2[jj] = (device const half2 *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33); + } + + { + q += iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03; + + const short ikv2 = iq2/(args.ne02/args.ne_12_2); + const short ikv3 = iq3/(args.ne03/args.ne_12_3); + + k += ikv2*args.nb12 + ikv3*args.nb13; + v += ikv2*args.nb22 + ikv3*args.nb23; + } // load heads from Q to shared memory - for (short j = sgitg; j < Q; j += nsg) { - device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*args.nb01 + iq2*args.nb02 + iq3*args.nb03)); + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { + const short j = jj*NSG + sgitg; + + device const float4 * q4 = (device const float4 *) ((device const char *) q + j*args.nb01); for (short i = tiisg; i < DK4; i += NW) { if (iq1 + j < args.ne01) { @@ -4284,43 +4391,30 @@ kernel void kernel_flash_attn_ext( } } - // zero out lo - for (short i = 0; i < DV8; ++i) { - lo[i] = make_filled_simdgroup_matrix((o_t) 0.0f); - } + // zero out + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { + const short j = jj*NSG + sgitg; + + for (short i = tiisg; i < DV4; i += NW) { + so4[j*PV4 + i] = 0; + } - // zero out shared memory SH - for (short j = 0; j < Q; ++j) { for (short i = tiisg; i < SH; i += NW) { - ss[j*TS + i] = 0.0f; + ss[j*SH + i] = 0.0f; } } threadgroup_barrier(mem_flags::mem_threadgroup); - { - float S[Q] = { [0 ... Q-1] = 0.0f }; - float M[Q] = { [0 ... Q-1] = -__FLT_MAX__/2 }; - - // thread indices inside the simdgroup - // TODO: see if we can utilize quad-group functions for better performance - // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (6.9.3) - const short tx = tiisg%4; - const short ty = tiisg/4; - - // broadcast kv - //const short rk2 = args.ne02/args.ne12; - //const short rk3 = args.ne03/args.ne13; - - const short ikv2 = iq2/(args.ne02/args.ne_12_2); - const short ikv3 = iq3/(args.ne03/args.ne_12_3); + float S[NQ] = { [0 ... NQ-1] = 0.0f }; - const bool has_mask = mask != q; + { + float M[NQ] = { [0 ... NQ-1] = -FLT_MAX/2 }; float slope = 1.0f; // ALiBi - if (args.max_bias > 0.0f) { + if (FC_flash_attn_ext_has_bias) { const short h = iq2; const float base = h < args.n_head_log2 ? args.m0 : args.m1; @@ -4331,177 +4425,277 @@ kernel void kernel_flash_attn_ext( // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns - for (int ic0 = 0; ic0 < args.ne11; ic0 += C*nsg) { - const int ic = ic0 + C*sgitg; - if (ic >= args.ne11) { - break; - } - - if (has_mask) { - // used to detect blocks full of -INF - float smax = -INFINITY; + for (int ic = 0; ic < args.ne11; ic += C) { + // read the mask into shared mem + if (FC_flash_attn_ext_has_mask) { + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { + const short j = jj*NSG + sgitg; + + sm2[j*SH + tiisg] = pm2[jj][tiisg]; + pm2[jj] += NW; + } - // load the mask in shared memory - #pragma unroll(Q) - for (short j = 0; j < Q; ++j) { - device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33); + threadgroup_barrier(mem_flags::mem_threadgroup); - const float m = pm[ic + tiisg]; + // used to detect blocks full of -INF + // skip only when the entire threadgroup is masked + half2 smax2(-MAXHALF/2, -MAXHALF/2); - ss[j*TS + C + tiisg] = m; - smax = max(smax, m); + FOR_UNROLL (short j = 0; j < Q; ++j) { + smax2 = max(smax2, sm2[j*SH + tiisg]); } - smax = simd_max(smax); + smax2 = simd_max(smax2); + + if (max(smax2[0], smax2[1]) <= -MAXHALF/2) { + // this barrier is important + threadgroup_barrier(mem_flags::mem_threadgroup); - if (smax == -INFINITY) { continue; } } // Q*K^T - { - for (short cc = 0; cc < C/8; ++cc) { + // this is compile-time check, so it does not have runtime overhead + if (is_same::value) { + // we can read directly from global memory + device const k_t * pk = (device const k_t *) ((device const char *) k + ic*args.nb11); + threadgroup const q_t * pq = sq; + threadgroup s_t * ps = ss; + + pk += sgitg*(8*NS10); + ps += sgitg*(8*1); + + static_assert((C/8) % NSG == 0, ""); + + constexpr short NC = (C/8)/NSG; + + // TODO: not good to unroll for large contexts - not sure why? + for (short cc = 0; cc < NC; ++cc) { qk8x8_t mqk = make_filled_simdgroup_matrix((qk_t) 0.0f); - // this is compile-time check, so it does not have runtime overhead - if (is_same::value) { - // we can read directly from global memory - device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13)); + if (DK8 % 16 != 0) { + k8x8_t mk; + q8x8_t mq; + + FOR_UNROLL (short i = 0; i < DK8; ++i) { + simdgroup_barrier(mem_flags::mem_none); + + simdgroup_load(mk, pk, NS10, 0, true); + simdgroup_load(mq, pq, DK); - #pragma unroll(DK8) - for (short i = 0; i < DK8; ++i) { - k8x8_t mk; - simdgroup_load(mk, pk + i*8, args.nb11/sizeof(k_t), 0, true); // transpose // TODO: use ne10 + simdgroup_barrier(mem_flags::mem_none); - q8x8_t mq; - simdgroup_load(mq, sq + i*8, DK); simdgroup_multiply_accumulate(mqk, mq, mk, mqk); + + pk += 8; + pq += 8; } } else { - for (short ii = 0; ii < DK16; ii += 4) { - device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13)); + k8x8_t mk[2]; + q8x8_t mq[2]; - if (DK16%4 == 0) { - // the head is evenly divisible by 4*16 = 64, so no need for bound checks - { - k4x4_t tmp; - deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp); - sk4x4[4*ty + tx] = tmp; - } + FOR_UNROLL (short i = 0; i < DK8/2; ++i) { + simdgroup_barrier(mem_flags::mem_none); - simdgroup_barrier(mem_flags::mem_threadgroup); + simdgroup_load(mk[0], pk + 0*8, NS10, 0, true); + simdgroup_load(mk[1], pk + 1*8, NS10, 0, true); - #pragma unroll(4) - for (short k = 0; k < 4; ++k) { - k8x8_t mk; - q8x8_t mq; + simdgroup_load(mq[0], pq + 0*8, DK); + simdgroup_load(mq[1], pq + 1*8, DK); - simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose - simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK); - simdgroup_multiply_accumulate(mqk, mq, mk, mqk); + simdgroup_barrier(mem_flags::mem_none); - simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose - simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK); - simdgroup_multiply_accumulate(mqk, mq, mk, mqk); - } - } else { - if (ii + tx < DK16) { - k4x4_t tmp; - deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp); - sk4x4[4*ty + tx] = tmp; - } + simdgroup_multiply_accumulate(mqk, mq[0], mk[0], mqk); + simdgroup_multiply_accumulate(mqk, mq[1], mk[1], mqk); - simdgroup_barrier(mem_flags::mem_threadgroup); + pk += 16; + pq += 16; + } + } - for (short k = 0; k < 4 && ii + k < DK16; ++k) { - k8x8_t mk; - q8x8_t mq; + simdgroup_store(mqk, ps, SH, 0, false); - simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose - simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK); - simdgroup_multiply_accumulate(mqk, mq, mk, mqk); + pk += 8*(NSG*NS10 - DK8); + pq += 8*(NSG*0 - DK8); + ps += 8*(NSG); + } + } else { + // TODO: this is the quantized K cache branch - not optimized yet + for (short ccc = 0; ccc < (C/8)/NSG; ++ccc) { + const short cc = ccc*NSG + sgitg; - simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose - simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK); - simdgroup_multiply_accumulate(mqk, mq, mk, mqk); - } + const short tx = tiisg%4; + const short ty = tiisg/4; + + qk8x8_t mqk = make_filled_simdgroup_matrix((qk_t) 0.0f); + + for (short ii = 0; ii < DK16; ii += 4) { + device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb11)); + + if (DK16%4 == 0) { + // the head is evenly divisible by 4*16 = 64, so no need for bound checks + { + k4x4_t tmp; + deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp); + sk4x4[4*ty + tx] = tmp; + } + + simdgroup_barrier(mem_flags::mem_threadgroup); + + FOR_UNROLL (short k = 0; k < 4; ++k) { + k8x8_t mk; + q8x8_t mq; + + simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose + simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK); + simdgroup_multiply_accumulate(mqk, mq, mk, mqk); + + simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose + simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK); + simdgroup_multiply_accumulate(mqk, mq, mk, mqk); + } + } else { + if (ii + tx < DK16) { + k4x4_t tmp; + deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp); + sk4x4[4*ty + tx] = tmp; + } + + simdgroup_barrier(mem_flags::mem_threadgroup); + + for (short k = 0; k < 4 && ii + k < DK16; ++k) { + k8x8_t mk; + q8x8_t mq; + + simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose + simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK); + simdgroup_multiply_accumulate(mqk, mq, mk, mqk); + + simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose + simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK); + simdgroup_multiply_accumulate(mqk, mq, mk, mqk); } } } - // cast qk_t -> s_t - //s8x8_t mqks(1.0f); - //simdgroup_multiply(mqks, mqk, mqks); - //simdgroup_store(mqks, ss + 8*cc, TS, 0, false); - - simdgroup_store(mqk, ss + 8*cc, TS, 0, false); + simdgroup_store(mqk, ss + 8*cc, SH, 0, false); } } + threadgroup_barrier(mem_flags::mem_threadgroup); + // online softmax - { - for (ushort j = 0; j < Q; ++j) { - const float m = M[j]; + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { + const short j = jj*NSG + sgitg; - // scale and apply the logitcap / mask - float s = ss[j*TS + tiisg]*args.scale; + const float m = M[jj]; - if (args.logit_softcap != 0.0f) { - s = args.logit_softcap*precise::tanh(s); - } + // scale and apply the logitcap / mask + float2 s2 = ss2[j*SH/2 + tiisg]*args.scale; + + if (FC_flash_attn_ext_has_scap) { + s2 = args.logit_softcap*precise::tanh(s2); + } - // mqk = mqk + mask*slope - s += slope*ss[j*TS + C + tiisg]; + // mqk = mqk + slope*mask + if (FC_flash_attn_ext_has_bias) { + s2 += s2_t(sm2[j*SH + tiisg])*slope; + } else { + s2 += s2_t(sm2[j*SH + tiisg]); + } - M[j] = simd_max(max(M[j], s)); + M[jj] = simd_max(max(M[jj], max(s2[0], s2[1]))); - const float ms = exp(m - M[j]); - const float vs = exp(s - M[j]); + const float ms = exp(m - M[jj]); + const float2 vs2 = exp(s2 - M[jj]); - S[j] = S[j]*ms + simd_sum(vs); + S[jj] = S[jj]*ms + simd_sum(vs2[0] + vs2[1]); - // the P matrix from the paper (Q rows, C columns) - ss[j*TS + tiisg] = vs; + // the P matrix from the paper (Q rows, C columns) + ss2[j*SH/2 + tiisg] = vs2; - // create a QxQ diagonal matrix for rescaling the output - if (tiisg == j) { - ss[j*TS + 2*C + j] = ms; + if (DV4 % NW == 0) { + FOR_UNROLL (short ii = 0; ii < DV4/NW; ++ii) { + const short i = ii*NW + tiisg; + + so4[j*PV4 + i] *= ms; + } + } else { + for (short i = tiisg; i < DV4; i += NW) { + so4[j*PV4 + i] *= ms; } } } - // O = diag(ms)*O - { - s8x8_t ms; - simdgroup_load(ms, ss + 2*C, TS, 0, false); - - #pragma unroll(DV8) - for (short i = 0; i < DV8; ++i) { - simdgroup_multiply(lo[i], ms, lo[i]); - } - } + threadgroup_barrier(mem_flags::mem_threadgroup); // O = O + (Q*K^T)*V { - for (short cc = 0; cc < C/8; ++cc) { - s8x8_t vs; - simdgroup_load(vs, ss + 8*cc, TS, 0, false); + // we can read directly from global memory + if (is_same::value) { + static_assert(PV8 % NSG == 0, ""); + + constexpr short NO = PV8/NSG; - if (is_same::value) { - // we can read directly from global memory - device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23)); + o8x8_t lo[NO]; - #pragma unroll(DV8) - for (short i = 0; i < DV8; ++i) { - v8x8_t mv; - simdgroup_load(mv, pv + i*8, args.nb21/sizeof(v_t), 0, false); // TODO: use ne20 + { + auto sot = so + 8*sgitg; - simdgroup_multiply_accumulate(lo[i], vs, mv, lo[i]); + FOR_UNROLL (short ii = 0; ii < NO; ++ii) { + simdgroup_load(lo[ii], sot, PV, 0, false); + + sot += 8*NSG; } - } else { - for (short ii = 0; ii < DV16; ii += 4) { - device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23)); + } + + { + auto sst = ss; + + device const v_t * pv = (device const v_t *) ((device const char *) v + ic*args.nb21); + + pv += 8*sgitg; + + FOR_UNROLL (short cc = 0; cc < C/8; ++cc) { + s8x8_t vs; + simdgroup_load(vs, sst, SH, 0, false); + + FOR_UNROLL (short ii = 0; ii < NO; ++ii) { + v8x8_t mv; + + simdgroup_load(mv, pv, NS20, 0, false); + simdgroup_multiply_accumulate(lo[ii], vs, mv, lo[ii]); + + pv += 8*NSG; + } + + pv += 8*(NS20 - NO*NSG); + sst += 8; + } + } + + { + auto sot = so + 8*sgitg; + + FOR_UNROLL (short ii = 0; ii < NO; ++ii) { + simdgroup_store(lo[ii], sot, PV, 0, false); + + sot += 8*NSG; + } + } + } else { + // TODO: this is the quantized V cache branch - not optimized yet + + const short tx = tiisg%4; + const short ty = tiisg/4; + + for (short cc = 0; cc < C/8; ++cc) { + s8x8_t vs; + simdgroup_load(vs, ss + 8*cc, SH, 0, false); + + for (short ii = 4*sgitg; ii < DV16; ii += 4*NSG) { + device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb21)); if (DV16%4 == 0) { // no need for bound checks @@ -4513,15 +4707,20 @@ kernel void kernel_flash_attn_ext( simdgroup_barrier(mem_flags::mem_threadgroup); - #pragma unroll(4) - for (short k = 0; k < 4; ++k) { - v8x8_t mv; + FOR_UNROLL (short k = 0; k < 4; ++k) { + v8x8_t mv[2]; + o8x8_t lo[2]; - simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false); - simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], vs, mv, lo[2*(ii + k) + 0]); + simdgroup_load(mv[0], sv + 16*k + 0*8, 4*16, 0, false); + simdgroup_load(mv[1], sv + 16*k + 1*8, 4*16, 0, false); + simdgroup_load(lo[0], so + 8*(2*(ii + k) + 0), PV, 0, false); + simdgroup_load(lo[1], so + 8*(2*(ii + k) + 1), PV, 0, false); - simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false); - simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], vs, mv, lo[2*(ii + k) + 1]); + simdgroup_multiply_accumulate(lo[0], vs, mv[0], lo[0]); + simdgroup_multiply_accumulate(lo[1], vs, mv[1], lo[1]); + + simdgroup_store(lo[0], so + 8*(2*(ii + k) + 0), PV, 0, false); + simdgroup_store(lo[1], so + 8*(2*(ii + k) + 1), PV, 0, false); } } else { if (ii + tx < DV16) { @@ -4533,243 +4732,249 @@ kernel void kernel_flash_attn_ext( simdgroup_barrier(mem_flags::mem_threadgroup); for (short k = 0; k < 4 && ii + k < DV16; ++k) { - v8x8_t mv; + v8x8_t mv[2]; + o8x8_t lo[2]; + + simdgroup_load(mv[0], sv + 16*k + 0*8, 4*16, 0, false); + simdgroup_load(mv[1], sv + 16*k + 1*8, 4*16, 0, false); + simdgroup_load(lo[0], so + 8*(2*(ii + k) + 0), PV, 0, false); + simdgroup_load(lo[1], so + 8*(2*(ii + k) + 1), PV, 0, false); - simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false); - simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], vs, mv, lo[2*(ii + k) + 0]); + simdgroup_multiply_accumulate(lo[0], vs, mv[0], lo[0]); + simdgroup_multiply_accumulate(lo[1], vs, mv[1], lo[1]); - simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false); - simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], vs, mv, lo[2*(ii + k) + 1]); + simdgroup_store(lo[0], so + 8*(2*(ii + k) + 0), PV, 0, false); + simdgroup_store(lo[1], so + 8*(2*(ii + k) + 1), PV, 0, false); } } } } } } - } - if (sinks != q && sgitg == 0) { - for (ushort j = 0; j < Q; ++j) { - const float m = M[j]; - const float s = tiisg == 0 ? ((device const float *) sinks)[iq2] : -FLT_MAX/2; + threadgroup_barrier(mem_flags::mem_threadgroup); + } - M[j] = simd_max(max(M[j], s)); + if (FC_flash_attn_ext_has_sinks) { + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { + const short j = jj*NSG + sgitg; - const float ms = exp(m - M[j]); - const float vs = exp(s - M[j]); + const float m = M[jj]; + const float s = tiisg == 0 ? ((device const float *) sinks)[iq2] : -FLT_MAX/2; - S[j] = S[j]*ms + simd_sum(vs); + M[jj] = simd_max(max(M[jj], s)); - if (tiisg == j) { - ss[j*TS + 2*C + j] = ms; - } - } + const float ms = exp(m - M[jj]); + const float vs = exp(s - M[jj]); - // O = diag(ms)*O - { - s8x8_t ms; - simdgroup_load(ms, ss + 2*C, TS, 0, false); + S[jj] = S[jj]*ms + simd_sum(vs); - #pragma unroll(DV8) - for (short i = 0; i < DV8; ++i) { - simdgroup_multiply(lo[i], ms, lo[i]); + for (short i = tiisg; i < DV4; i += NW) { + so4[j*PV4 + i] *= ms; } } } - - // these are needed for reducing the results from the simdgroups (reuse the ss buffer) - for (short j = tiisg; j < Q; j += NW) { - ss[j*TS + 0] = S[j]; - ss[j*TS + 1] = M[j]; - } } - threadgroup_barrier(mem_flags::mem_threadgroup); - - threadgroup float * so = (threadgroup float *) (shmem_f16 + 0*DK); // reuse query data for accumulation - threadgroup float4 * so4 = (threadgroup float4 *) (shmem_f16 + 0*DK); - - // store result to shared memory in F32 - if (sgitg == 0) { - for (short i = 0; i < DV8; ++i) { - //simdgroup_store(lo[i], so + i*8, DV, 0, false); - simdgroup_float8x8 t(1.0f); - simdgroup_multiply(t, lo[i], t); - simdgroup_store(t, so + i*8, DV, 0, false); + // store to global memory + for (short jj = 0; jj < NQ; ++jj) { + const short j = jj*NSG + sgitg; + if (iq1 + j >= args.ne01) { + break; } - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // reduce the warps sequentially - for (ushort sg = 1; sg < nsg; ++sg) { - if (sgitg == sg) { - for (short j = tiisg; j < Q; j += NW) { - const float S0 = ss[j*TS - 1*SH + 0]; - const float S1 = ss[j*TS + 0]; - const float M0 = ss[j*TS - 1*SH + 1]; - const float M1 = ss[j*TS + 1]; - - const float M = max(M0, M1); - - float ms0 = exp(M0 - M); - float ms1 = exp(M1 - M); + device float4 * dst4 = (device float4 *) dst + ((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4; - const float S = S0*ms0 + S1*ms1; + const float scale = 1.0f/S[jj]; - ss[j*TS + 0] = S; - ss[j*TS + 1] = M; + if (DV4 % NW == 0) { + FOR_UNROLL (short ii = 0; ii < DV4/NW; ++ii) { + const short i = ii*NW + tiisg; - ss[j*TS + 2*C + j - 1*SH] = ms0; - ss[j*TS + 2*C + j ] = ms1; + dst4[i] = (float4) so4[j*PV4 + i]*scale; } - - //simdgroup_barrier(mem_flags::mem_threadgroup); - - // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 - { - s8x8_t ms0; - s8x8_t ms1; - - simdgroup_load(ms0, ss + 2*C - 1*SH, TS, 0, false); - simdgroup_load(ms1, ss + 2*C, TS, 0, false); - - #pragma unroll(DV8) - for (short i = 0; i < DV8; ++i) { - simdgroup_float8x8 t; - - simdgroup_load (t, so + i*8, DV, 0, false); - simdgroup_multiply(t, ms0, t); - - simdgroup_multiply_accumulate(t, ms1, lo[i], t); - simdgroup_store(t, so + i*8, DV, 0, false); - } + } else { + for (short i = tiisg; i < DV4; i += NW) { + dst4[i] = (float4) so4[j*PV4 + i]*scale; } } - - threadgroup_barrier(mem_flags::mem_threadgroup); } - threadgroup s_t * sf = (threadgroup s_t *) (shmem_f16 + 2*(nsg-1)*SH + 2*Q*DK); - - // final rescale with 1/S and store to global memory - for (short j = sgitg; j < Q && iq1 + j < args.ne01; j += nsg) { - const float S = 1.0f/sf[j*TS + 0]; - - device float4 * dst4 = (device float4 *) dst + ((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4; +#undef NS10 +#undef NS20 +} - for (short i = tiisg; i < DV4; i += NW) { - dst4[i] = (float4) so4[j*DV4 + i]*S; - } +template< + typename q_t, // query types in shared memory + typename q4_t, + typename q8x8_t, + typename k_t, // key types in shared memory + typename k4x4_t, + typename k8x8_t, + typename v_t, // value types in shared memory + typename v4x4_t, + typename v8x8_t, + typename qk_t, // Q*K types + typename qk8x8_t, + typename s_t, // soft-max types + typename s2_t, + typename s8x8_t, + typename o_t, // attention accumulation types + typename o4_t, + typename o8x8_t, + typename kd4x4_t, // key type in device memory + short nl_k, + void (*deq_k)(device const kd4x4_t *, short, thread k4x4_t &), + typename vd4x4_t, // value type in device memory + short nl_v, + void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &), + short DK, // K head size + short DV, // V head size + short Q = 8, // queries per threadgroup + short C = 64> // cache items per threadgroup +kernel void kernel_flash_attn_ext( + constant ggml_metal_kargs_flash_attn_ext & args, + device const char * q, + device const char * k, + device const char * v, + device const char * mask, + device const char * sinks, + device char * dst, + threadgroup half * shmem_f16 [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { +#define FWD_TMPL q_t, q4_t, q8x8_t, k_t, k4x4_t, k8x8_t, v_t, v4x4_t, v8x8_t, qk_t, qk8x8_t, s_t, s2_t, s8x8_t, o_t, o4_t, o8x8_t, kd4x4_t, nl_k, deq_k, vd4x4_t, nl_v, deq_v, DK, DV, Q, C +#define FWD_ARGS args, q, k, v, mask, sinks, dst, shmem_f16, tgpig, tiisg, sgitg + switch (FC_flash_attn_ext_nsg) { + // note: disabled cases to reduce library load time + //case 1: kernel_flash_attn_ext_impl(FWD_ARGS); break; + //case 2: kernel_flash_attn_ext_impl(FWD_ARGS); break; + case 4: kernel_flash_attn_ext_impl(FWD_ARGS); break; } +#undef FWD_TMPL +#undef FWD_ARGS } // TODO: this is quite ugly. in the future these types will be hardcoded in the kernel, but for now keep them as // template to be able to explore different combinations // #define FA_TYPES \ - float, float4, simdgroup_float8x8, \ + half, half4, simdgroup_half8x8, \ half, half4x4, simdgroup_half8x8, \ half, half4x4, simdgroup_half8x8, \ float, simdgroup_float8x8, \ - float, simdgroup_float8x8, \ - half, half4, simdgroup_half8x8 - //float, float4, simdgroup_float8x8 + float, float2, simdgroup_float8x8, \ + float, float4, simdgroup_float8x8 + //half, half4, simdgroup_half8x8 #define FA_TYPES_BF \ bfloat, bfloat4, simdgroup_bfloat8x8, \ bfloat, bfloat4x4, simdgroup_bfloat8x8, \ bfloat, bfloat4x4, simdgroup_bfloat8x8, \ float, simdgroup_float8x8, \ - float, simdgroup_float8x8, \ + float, float2, simdgroup_float8x8, \ half, half4, simdgroup_half8x8 //float, float4, simdgroup_float8x8 typedef decltype(kernel_flash_attn_ext) flash_attn_ext_t; -template [[host_name("kernel_flash_attn_ext_f16_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_f16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_f16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_f16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; #if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_flash_attn_ext_bf16_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; #endif -template [[host_name("kernel_flash_attn_ext_q4_0_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; - -template [[host_name("kernel_flash_attn_ext_q4_1_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; - -template [[host_name("kernel_flash_attn_ext_q5_0_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; - -template [[host_name("kernel_flash_attn_ext_q5_1_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; - -template [[host_name("kernel_flash_attn_ext_q8_0_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; + +template [[host_name("kernel_flash_attn_ext_q4_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; + +template [[host_name("kernel_flash_attn_ext_q5_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; + +template [[host_name("kernel_flash_attn_ext_q5_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; + +template [[host_name("kernel_flash_attn_ext_q8_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; #undef FA_TYPES #undef FA_TYPES_BF +constant bool FC_flash_attn_ext_vec_has_mask [[function_constant(FC_FLASH_ATTN_EXT_VEC + 0)]]; +constant bool FC_flash_attn_ext_vec_has_sinks [[function_constant(FC_FLASH_ATTN_EXT_VEC + 1)]]; +constant bool FC_flash_attn_ext_vec_has_bias [[function_constant(FC_FLASH_ATTN_EXT_VEC + 2)]]; +constant bool FC_flash_attn_ext_vec_has_scap [[function_constant(FC_FLASH_ATTN_EXT_VEC + 3)]]; + +//constant float FC_flash_attn_ext_vec_scale [[function_constant(FC_FLASH_ATTN_EXT_VEC + 10)]]; +//constant float FC_flash_attn_ext_vec_max_bias [[function_constant(FC_FLASH_ATTN_EXT_VEC + 11)]]; +//constant float FC_flash_attn_ext_vec_logit_softcap [[function_constant(FC_FLASH_ATTN_EXT_VEC + 12)]]; + +constant int32_t FC_flash_attn_ext_vec_ns10 [[function_constant(FC_FLASH_ATTN_EXT_VEC + 20)]]; +constant int32_t FC_flash_attn_ext_vec_ns20 [[function_constant(FC_FLASH_ATTN_EXT_VEC + 21)]]; +constant int32_t FC_flash_attn_ext_vec_nsg [[function_constant(FC_FLASH_ATTN_EXT_VEC + 22)]]; +constant int32_t FC_flash_attn_ext_vec_nwg [[function_constant(FC_FLASH_ATTN_EXT_VEC + 23)]]; + template< typename q4_t, // query types in shared memory typename k4_t, // key types in shared memory @@ -4788,63 +4993,86 @@ template< short DV, // V head size short NE = 4, // head elements per thread short Q = 1, // queries per threadgroup - short C = 32> // cache items per threadgroup -kernel void kernel_flash_attn_ext_vec( - constant ggml_metal_kargs_flash_attn_ext & args, + short C = 32, // cache items per threadgroup + short NSG> // number of simd groups +void kernel_flash_attn_ext_vec_impl( + constant ggml_metal_kargs_flash_attn_ext_vec & args, device const char * q, device const char * k, device const char * v, device const char * mask, device const char * sinks, device char * dst, - constant uint16_t & nwg, threadgroup half * shmem_f16 [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 ntg[[threads_per_threadgroup]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { static_assert(DK % 32 == 0, "DK must be divisible by 32"); static_assert(DV % 32 == 0, "DV must be divisible by 32"); - const short nsg = ntg.y; // number of simdgroups - const short iwg = tgpig[2]%nwg; +#define NWG (FC_flash_attn_ext_vec_nwg) + +#define NS10 (FC_flash_attn_ext_vec_ns10) +#define NS20 (FC_flash_attn_ext_vec_ns20) + + const short iwg = tgpig[2]%NWG; - const int iq3 = tgpig[2]/nwg; - const int iq2 = tgpig[1]; - const int iq1 = tgpig[0]; + const ushort iq3 = tgpig[2]/NWG; + const ushort iq2 = tgpig[1]; + const ushort iq1 = tgpig[0]; constexpr short DK4 = DK/4; constexpr short DV4 = DV/4; + + constexpr short PK = PAD2(DK, 128); + constexpr short PK4 = PK/4; + + constexpr short PV = PAD2(DV, 128); + constexpr short PV4 = PV/4; + constexpr short NW = N_SIMDWIDTH; constexpr short NL = NW/NE; // note: this can be adjusted to support different head sizes and simdgroup work loads constexpr short SH = 4*C; // shared memory per simdgroup - const short T = DK + nsg*SH; // shared memory size per query in (half) + static_assert(DK4 % NL == 0, "DK4 must be divisible by NL"); + static_assert(DV4 % NL == 0, "DV4 must be divisible by NL"); + + const short T = PK + NSG*SH; // shared memory size per query in (half) - //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data - threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t - threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*DK); // scratch buffer for attention - threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*DK); // same as above but in s4_t - threadgroup float * sm = (threadgroup float *) (shmem_f16 + sgitg*SH + 2*C + Q*DK); // scratch buffer for mask - threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + 2*sgitg*DV + Q*T); // scratch buffer for the results + //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*PK); // holds the query data + threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*PK); // same as above but in q4_t + threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*PK); // scratch buffer for attention + threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*PK); // same as above but in s4_t + threadgroup half * sm = (threadgroup half *) (shmem_f16 + sgitg*SH + 2*C + Q*PK); // scratch buffer for mask + threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 2*sgitg*PV + Q*T); // scratch buffer for the results - // store the result for all queries in local memory (the O matrix from the paper) - o4_t lo[DV4/NL]; + // store the result for all queries in shared memory (the O matrix from the paper) + so4 += tiisg; + + { + q += iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03; + + const short ikv2 = iq2/(args.ne02/args.ne_12_2); + const short ikv3 = iq3/(args.ne03/args.ne_12_3); + + k += ikv2*args.nb12 + ikv3*args.nb13; + v += ikv2*args.nb22 + ikv3*args.nb23; + } // load heads from Q to shared memory - device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03)); + device const float4 * q4 = (device const float4 *) ((device const char *) q); - for (short i = tiisg; i < DK4; i += NW) { - if (iq1 < args.ne01) { + for (short i = tiisg; i < PK4; i += NW) { + if (iq1 < args.ne01 && i < DK4) { sq4[i] = (q4_t) q4[i]; } else { sq4[i] = (q4_t) 0.0f; } } - // zero out lo + // zero out so for (short i = 0; i < DV4/NL; ++i) { - lo[i] = (o4_t) 0.0f; + so4[i*NL] = (o4_t) 0.0f; } // zero out shared memory SH @@ -4856,28 +5084,19 @@ kernel void kernel_flash_attn_ext_vec( { float S = 0.0f; - float M = -__FLT_MAX__/2; + float M = -FLT_MAX/2; // thread indices inside the simdgroup const short tx = tiisg%NL; const short ty = tiisg/NL; - // broadcast kv - //const short rk2 = args.ne02/args.ne12; - //const short rk3 = args.ne03/args.ne13; - - const short ikv2 = iq2/(args.ne02/args.ne_12_2); - const short ikv3 = iq3/(args.ne03/args.ne_12_3); - - const bool has_mask = mask != q; - // pointer to the mask device const half * pm = (device const half *) (mask + iq1*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33); float slope = 1.0f; // ALiBi - if (args.max_bias > 0.0f) { + if (FC_flash_attn_ext_vec_has_bias) { const short h = iq2; const float base = h < args.n_head_log2 ? args.m0 : args.m1; @@ -4888,13 +5107,13 @@ kernel void kernel_flash_attn_ext_vec( // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns - for (int ic0 = (int) iwg*C*nsg; ic0 < args.ne11; ic0 += (int) nwg*C*nsg) { + for (int ic0 = (int) iwg*C*NSG; ic0 < args.ne11; ic0 += (int) NWG*C*NSG) { const int ic = ic0 + C*sgitg; if (ic >= args.ne11) { break; } - if (has_mask) { + if (FC_flash_attn_ext_vec_has_mask) { sm[tiisg] = pm[ic + tiisg]; } @@ -4905,69 +5124,81 @@ kernel void kernel_flash_attn_ext_vec( // Q*K^T { - // each simdgroup processes 1 query and NE (NW/NL) head elements - for (short cc = 0; cc < C/NE; ++cc) { - qk_t mqk = 0.0f; + device const k4_t * pk4 = (device const k4_t *) ((device const char *) k + ic*args.nb11); + threadgroup const q4_t * pq4 = sq4; - device const kd4_t * pk = (device const kd4_t *) ((device const char *) k + ((ic + NE*cc + ty)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13)); + pk4 += ty*NS10/4 + tx; + pq4 += tx; - #pragma unroll(DK4/NL) - for (short ii = 0; ii < DK4; ii += NL) { - const short i = ii + tx; + qk_t mqk[C/NE] = { [ 0 ... C/NE - 1] = 0.0f }; + + // each simdgroup processes 1 query and NE (NW/NL) cache elements + FOR_UNROLL (short cc = 0; cc < C/NE; ++cc) { + if (is_same::value) { + FOR_UNROLL (short ii = 0; ii < DK4/NL; ++ii) { + mqk[cc] += dot((float4) pk4[cc*NE*NS10/4 + ii*NL], (float4) pq4[ii*NL]); + } + } else { + device const kd4_t * pk = (device const kd4_t *) ((device const char *) k + ((ic + NE*cc + ty)*args.nb11)); k4_t mk; - deq_k_t4(pk + i/nl_k, i%nl_k, mk); - // note: this is less precise than the version below - //mqka[0] += dot(mq[0], mk[0]); - //mqka[1] += dot(mq[1], mk[1]); - //mqka[2] += dot(mq[2], mk[2]); - //mqka[3] += dot(mq[3], mk[3]); + FOR_UNROLL (short ii = 0; ii < DK4/NL; ++ii) { + const short i = ii*NL + tx; - //q4x4_t mq = sq4x4[i]; - //mqka[0] += dot((float4) mq[0], (float4) mk[0]); - //mqka[1] += dot((float4) mq[1], (float4) mk[1]); - //mqka[2] += dot((float4) mq[2], (float4) mk[2]); - //mqka[3] += dot((float4) mq[3], (float4) mk[3]); + deq_k_t4(pk + i/nl_k, i%nl_k, mk); - mqk += dot((float4) mk, (float4) sq4[i]); + mqk[cc] += dot((float4) mk, (float4) sq4[i]); + } } - static_assert(NE > 1, "NE must be > 1"); // note: not sure why NE == 1 fails + if (NE == 1) { + mqk[cc] = simd_sum(mqk[cc]); + } else { + // simdgroup reduce (NE = 4) + // [ 0 .. 7] -> [ 0] + // [ 8 .. 15] -> [ 8] + // [16 .. 23] -> [16] + // [24 .. 31] -> [24] + if (NE <= 1) { + mqk[cc] += simd_shuffle_down(mqk[cc], 16); + } + if (NE <= 2) { + mqk[cc] += simd_shuffle_down(mqk[cc], 8); + } + if (NE <= 4) { + mqk[cc] += simd_shuffle_down(mqk[cc], 4); + } + if (NE <= 8) { + mqk[cc] += simd_shuffle_down(mqk[cc], 2); + } + if (NE <= 16) { + mqk[cc] += simd_shuffle_down(mqk[cc], 1); + } - // simdgroup reduce (NE = 4) - // [ 0 .. 7] -> [ 0] - // [ 8 .. 15] -> [ 8] - // [16 .. 23] -> [16] - // [24 .. 31] -> [24] - if (NE <= 1) { - mqk += simd_shuffle_down(mqk, 16); - } - if (NE <= 2) { - mqk += simd_shuffle_down(mqk, 8); - } - if (NE <= 4) { - mqk += simd_shuffle_down(mqk, 4); - } - if (NE <= 8) { - mqk += simd_shuffle_down(mqk, 2); + // broadcast + mqk[cc] = simd_shuffle(mqk[cc], NL*ty); } - if (NE <= 16) { - mqk += simd_shuffle_down(mqk, 1); - } - - // mqk = mqk*scale + mask*slope - if (tx == 0) { - mqk *= args.scale; + } - if (args.logit_softcap != 0.0f) { - mqk = args.logit_softcap*precise::tanh(mqk); - } + if (FC_flash_attn_ext_vec_has_mask && + !FC_flash_attn_ext_vec_has_scap && + !FC_flash_attn_ext_vec_has_bias) { + ss[NE*tx + ty] = fma(mqk[tx], args.scale, (qk_t) sm[NE*tx + ty]); + } else { + mqk[tx] *= args.scale; - mqk += sm[NE*cc + ty]*slope; + if (FC_flash_attn_ext_vec_has_scap) { + mqk[tx] = args.logit_softcap*precise::tanh(mqk[tx]); + } - ss[NE*cc + ty] = mqk; + if (FC_flash_attn_ext_vec_has_bias) { + mqk[tx] += (qk_t) sm[NE*tx + ty]*slope; + } else { + mqk[tx] += (qk_t) sm[NE*tx + ty]; } + + ss[NE*tx + ty] = mqk[tx]; } } @@ -4989,9 +5220,10 @@ kernel void kernel_flash_attn_ext_vec( ss[tiisg] = vs; // O = diag(ms)*O - #pragma unroll(DV4/NL) - for (short ii = 0; ii < DV4; ii += NL) { - lo[ii/NL] *= ms; + if ((DV4/NL % NW == 0) || ty == 0) { + FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) { + so4[ii*NL] *= ms; + } } } @@ -4999,26 +5231,84 @@ kernel void kernel_flash_attn_ext_vec( // O = O + (Q*K^T)*V { - //#pragma unroll(C/NE) - for (short cc = 0; cc < C/NE; ++cc) { - device const vd4_t * pv4 = (device const vd4_t *) ((device const char *) v + ((ic + NE*cc + ty)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23)); + o4_t lo[DV4/NL]; + FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) { + lo[ii] = 0.0f; + } - const s4_t ms(ss[NE*cc + ty]); + if (is_same::value) { + device const v4_t * pv4 = (device const v4_t *) ((device const char *) v + ic*args.nb21); - #pragma unroll(DV4/NL) - for (short ii = 0; ii < DV4; ii += NL) { - const short i = ii + tx; + pv4 += ty*NS20/4 + tx; - v4_t mv; - deq_v_t4(pv4 + i/nl_v, i%nl_v, mv); + const auto sst = ss + ty; - lo[ii/NL] += o4_t(float4(mv)*float4(ms)); + FOR_UNROLL (short cc = 0; cc < C/NE; ++cc) { + FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) { + lo[ii] += o4_t(float4(pv4[cc*NE*NS20/4 + ii*NL])*float4(sst[cc*NE])); + } + } + } else { + FOR_UNROLL (short cc = 0; cc < C/NE; ++cc) { + device const vd4_t * pv4 = (device const vd4_t *) ((device const char *) v + ((ic + NE*cc + ty)*args.nb21)); + + FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) { + const short i = ii*NL + tx; + + v4_t mv; + deq_v_t4(pv4 + i/nl_v, i%nl_v, mv); + + lo[ii] += o4_t(float4(mv)*float4(ss[NE*cc + ty])); + } + } + } + + FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) { + if (NE > 1) { + lo[ii][0] += simd_shuffle_down(lo[ii][0], 16); + lo[ii][1] += simd_shuffle_down(lo[ii][1], 16); + lo[ii][2] += simd_shuffle_down(lo[ii][2], 16); + lo[ii][3] += simd_shuffle_down(lo[ii][3], 16); + } + + if (NE > 2) { + lo[ii][0] += simd_shuffle_down(lo[ii][0], 8); + lo[ii][1] += simd_shuffle_down(lo[ii][1], 8); + lo[ii][2] += simd_shuffle_down(lo[ii][2], 8); + lo[ii][3] += simd_shuffle_down(lo[ii][3], 8); + } + + if (NE > 4) { + lo[ii][0] += simd_shuffle_down(lo[ii][0], 4); + lo[ii][1] += simd_shuffle_down(lo[ii][1], 4); + lo[ii][2] += simd_shuffle_down(lo[ii][2], 4); + lo[ii][3] += simd_shuffle_down(lo[ii][3], 4); + } + + if (NE > 8) { + lo[ii][0] += simd_shuffle_down(lo[ii][0], 2); + lo[ii][1] += simd_shuffle_down(lo[ii][1], 2); + lo[ii][2] += simd_shuffle_down(lo[ii][2], 2); + lo[ii][3] += simd_shuffle_down(lo[ii][3], 2); + } + + if (NE > 16) { + lo[ii][0] += simd_shuffle_down(lo[ii][0], 1); + lo[ii][1] += simd_shuffle_down(lo[ii][1], 1); + lo[ii][2] += simd_shuffle_down(lo[ii][2], 1); + lo[ii][3] += simd_shuffle_down(lo[ii][3], 1); + } + } + + if ((DV4/NL % NW == 0) || ty == 0) { + FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) { + so4[ii*NL] += lo[ii]; } } } } - if (sinks != q && sgitg == 0 && iwg == 0) { + if (FC_flash_attn_ext_vec_has_sinks && sgitg == 0 && iwg == 0) { const float m = M; const float s = tiisg == 0 ? ((device const float *) sinks)[iq2] : -FLT_MAX/2; @@ -5029,9 +5319,10 @@ kernel void kernel_flash_attn_ext_vec( S = S*ms + simd_sum(vs); -#pragma unroll(DV4/NL) - for (short ii = 0; ii < DV4; ii += NL) { - lo[ii/NL] *= ms; + if ((DV4/NL % NW == 0) || ty == 0) { + FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) { + so4[ii*NL] *= ms; + } } } @@ -5042,63 +5333,12 @@ kernel void kernel_flash_attn_ext_vec( } } - // simdgroup reduce (NE = 4) - // [ 0, 8, 16, 24] -> [ 0] - // [ 1, 9, 17, 25] -> [ 1] - // [ 2, 10, 18, 26] -> [ 2] - // [ 3, 11, 19, 27] -> [ 3] - // [ 4, 12, 20, 28] -> [ 4] - // [ 5, 13, 21, 29] -> [ 5] - // [ 6, 14, 22, 30] -> [ 6] - // [ 7, 15, 23, 31] -> [ 7] - for (short ii = 0; ii < DV4; ii += NL) { - if (NE > 1) { - lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 16); - lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 16); - lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 16); - lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 16); - } - - if (NE > 2) { - lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 8); - lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 8); - lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 8); - lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 8); - } - - if (NE > 4) { - lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 4); - lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 4); - lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 4); - lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 4); - } - - if (NE > 8) { - lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 2); - lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 2); - lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 2); - lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 2); - } - - if (NE > 16) { - lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 1); - lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 1); - lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 1); - lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 1); - } - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // store results to shared memory - for (short i = tiisg; i < DV4; i += NL) { - sr4[i] = lo[i/NL]; - } + so4 -= tiisg; threadgroup_barrier(mem_flags::mem_threadgroup); // parallel reduce - for (short r = nsg/2; r > 0; r >>= 1) { + for (short r = NSG/2; r > 0; r >>= 1) { if (sgitg < r) { const float S0 = ss[ 0]; const float S1 = ss[r*(SH/2) + 0]; @@ -5120,7 +5360,7 @@ kernel void kernel_flash_attn_ext_vec( // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 for (short i = tiisg; i < DV4; i += NW) { - sr4[i] = sr4[i]*ms0 + sr4[i + r*DV4]*ms1; + so4[i] = so4[i]*ms0 + so4[i + r*PV4]*ms1; } } @@ -5133,21 +5373,73 @@ kernel void kernel_flash_attn_ext_vec( const int64_t rid = iq3*args.ne2*args.ne1 + iq2 + iq1*args.ne1; device float4 * dst4 = (device float4 *) dst; - device float * dst1 = (device float *) dst + nrows*DV*nwg; // the S and M are stored after the results + device float * dst1 = (device float *) dst + nrows*DV*NWG; // the S and M are stored after the results - const float S = nwg == 1 ? 1.0f/ss[0] : 1.0f; + const float S = NWG == 1 ? 1.0f/ss[0] : 1.0f; // interleave the workgroup data for (short i = tiisg; i < DV4; i += NW) { - dst4[rid*DV4*nwg + nwg*i + iwg] = (float4) sr4[i]*S; + dst4[rid*DV4*NWG + NWG*i + iwg] = (float4) so4[i]*S; } // store S and M - if (nwg > 1 && tiisg == 0) { - dst1[rid*(2*nwg) + 2*iwg + 0] = ss[0]; - dst1[rid*(2*nwg) + 2*iwg + 1] = ss[1]; + if (NWG > 1) { + if (tiisg == 0) { + dst1[rid*(2*NWG) + 2*iwg + 0] = ss[0]; + dst1[rid*(2*NWG) + 2*iwg + 1] = ss[1]; + } } } + +#undef NWG +#undef NS10 +#undef NS20 +} + +template< + typename q4_t, // query types in shared memory + typename k4_t, // key types in shared memory + typename v4_t, // value types in shared memory + typename qk_t, // Q*K types + typename s_t, // soft-max types + typename s4_t, + typename o4_t, // attention accumulation types + typename kd4_t, // key type in device memory + short nl_k, + void (*deq_k_t4)(device const kd4_t *, short, thread k4_t &), + typename vd4_t, // value type in device memory + short nl_v, + void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &), + short DK, // K head size + short DV, // V head size + short NE = 4, // head elements per thread + short Q = 1, // queries per threadgroup + short C = 32> // cache items per threadgroup +kernel void kernel_flash_attn_ext_vec( + constant ggml_metal_kargs_flash_attn_ext_vec & args, + device const char * q, + device const char * k, + device const char * v, + device const char * mask, + device const char * sinks, + device char * dst, + threadgroup half * shmem_f16 [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { +#define FWD_TMPL q4_t, k4_t, v4_t, qk_t, s_t, s4_t, o4_t, kd4_t, nl_k, deq_k_t4, vd4_t, nl_v, deq_v_t4, DK, DV, NE, Q, C +#define FWD_ARGS args, q, k, v, mask, sinks, dst, shmem_f16, tgpig, tiisg, sgitg + switch (FC_flash_attn_ext_vec_nsg) { + // note: disabled cases to reduce library load time + case 1: kernel_flash_attn_ext_vec_impl(FWD_ARGS); break; + case 2: kernel_flash_attn_ext_vec_impl(FWD_ARGS); break; + case 4: kernel_flash_attn_ext_vec_impl(FWD_ARGS); break; + //case 8: kernel_flash_attn_ext_vec_impl(FWD_ARGS); break; + //case 16: kernel_flash_attn_ext_vec_impl(FWD_ARGS); break; + //case 32: kernel_flash_attn_ext_vec_impl(FWD_ARGS); break; + } +#undef FWD_TMPL +#undef FWD_ARGS } // note: I think the s_t can be half instead of float, because the Q*K scaling is done before storing to shared mem @@ -5163,111 +5455,120 @@ kernel void kernel_flash_attn_ext_vec( typedef decltype(kernel_flash_attn_ext_vec) flash_attn_ext_vec_t; -template [[host_name("kernel_flash_attn_ext_vec_f16_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_flash_attn_ext_vec_bf16_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_bf16_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #endif -template [[host_name("kernel_flash_attn_ext_vec_q4_0_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_1_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_0_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_1_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q8_0_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_f16_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_flash_attn_ext_vec_bf16_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_bf16_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #endif -template [[host_name("kernel_flash_attn_ext_vec_q4_0_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_1_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_0_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_1_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q8_0_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_flash_attn_ext_vec_bf16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_bf16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #endif -template [[host_name("kernel_flash_attn_ext_vec_q4_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q8_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_f16_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_flash_attn_ext_vec_bf16_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #endif -template [[host_name("kernel_flash_attn_ext_vec_q4_0_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_1_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_0_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_1_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q8_0_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_f16_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_flash_attn_ext_vec_bf16_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #endif -template [[host_name("kernel_flash_attn_ext_vec_q4_0_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_1_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_0_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_1_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q8_0_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_flash_attn_ext_vec_bf16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_bf16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #endif -template [[host_name("kernel_flash_attn_ext_vec_q4_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_f16_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_flash_attn_ext_vec_bf16_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_bf16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #endif -template [[host_name("kernel_flash_attn_ext_vec_q4_0_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_1_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_0_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_1_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q8_0_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #undef FA_TYPES -kernel void kernel_flash_attn_ext_reduce( - constant ggml_metal_kargs_flash_attn_ext_reduce & args, +constant int32_t FC_flash_attn_ext_vec_reduce_DV [[function_constant(FC_FLASH_ATTN_EXT_VEC_REDUCE + 0)]]; +constant int32_t FC_flash_attn_ext_vec_reduce_NWG [[function_constant(FC_FLASH_ATTN_EXT_VEC_REDUCE + 1)]]; + +kernel void kernel_flash_attn_ext_vec_reduce( + constant ggml_metal_kargs_flash_attn_ext_vec_reduce & args, device const char * htmp, device char * dst, uint tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { +#define NWG (FC_flash_attn_ext_vec_reduce_NWG) +#define DV (FC_flash_attn_ext_vec_reduce_DV) + const uint64_t rid = tgpig; - const short nwg = 32; const short iwg = tiisg; - const short DV = args.ne20; - const short DV4 = DV/4; - device const float4 * htmp4 = (device const float4 *) htmp + rid*DV4*nwg; - device const float * ss = (device const float *) htmp + (uint64_t)args.nrows*DV*nwg; - device float4 * dst4 = (device float4 *) dst + rid*DV4; + device const float * ss = (device const float *) htmp + (uint64_t)args.nrows*DV*NWG; - float S = ss[rid*(2*nwg) + 2*iwg + 0]; - float M = ss[rid*(2*nwg) + 2*iwg + 1]; + float S = ss[rid*(2*NWG) + 2*iwg + 0]; + float M = ss[rid*(2*NWG) + 2*iwg + 1]; const float m = simd_max(M); const float ms = exp(M - m); S = 1.0f/simd_sum(S*ms); - for (int i = sgitg; i < DV4; i += nwg) { - const float4 v = simd_sum(htmp4[i*nwg + iwg]*ms); + const short DV4 = DV/4; + + device const float4 * htmp4 = (device const float4 *) htmp + rid*DV4*NWG; + device float4 * dst4 = (device float4 *) dst + rid*DV4; + + for (short i = sgitg; i < DV4; i += NWG) { + const float4 v = simd_sum(htmp4[i*NWG + iwg]*ms); if (iwg == 0) { dst4[i] = v*S; } } + +#undef NWG +#undef DV } template @@ -7395,7 +7696,7 @@ kernel void kernel_set_rows_f( const int32_t i10 = i01; const int64_t i1 = ((const device int64_t *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0]; - device T * dst_row = ( device T *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3); + device T * dst_row = ( device T *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3); const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03); for (int ind = tiitg%tptg.x; ind < args.nk0; ind += tptg.x) { @@ -7494,18 +7795,20 @@ kernel void kernel_mul_mm( #pragma unroll(4) for (short ik = 0; ik < BLOCK_SIZE_K/8; ik++) { + simdgroup_barrier(mem_flags::mem_none); + #pragma unroll(4) for (short i = 0; i < 4; i++) { simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i); } - simdgroup_barrier(mem_flags::mem_none); - #pragma unroll(2) for (short i = 0; i < 2; i++) { simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i); } + simdgroup_barrier(mem_flags::mem_none); + #pragma unroll(8) for (short i = 0; i < 8; i++){ simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index f4cd8f65cf8be..348b4ff5b44bf 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -6490,8 +6490,8 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {11, 22, 33, 44}, 1, 2, 3, 4, 5, 6, 7, 8, v)); } - for (int hsk : { 40, 64, 80, 128, 192, 256, 576 }) { - for (int hsv : { 40, 64, 80, 128, 192, 256, 512 }) { + for (int hsk : { 40, 64, 80, 96, 128, 192, 256, 576 }) { + for (int hsv : { 40, 64, 80, 96, 128, 192, 256, 512 }) { if (hsk != 192 && hsk != 576 && hsk != hsv) continue; if (hsk == 192 && (hsv != 128 && hsv != 192)) continue; if (hsk == 576 && hsv != 512) continue; // DeepSeek MLA