Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 44 additions & 4 deletions ggml/src/ggml-metal/ggml-metal-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
#define N_R0_Q5_1 4
#define N_SG_Q5_1 2

// Q8_0 parameters. Each macro is defined exactly once: redefining an
// object-like macro with a different replacement list is an incompatible
// redefinition in C and leaves the effective value dependent on definition order.
#define N_R0_Q8_0 2
#define N_SG_Q8_0 4

#define N_R0_MXFP4 2
#define N_SG_MXFP4 2
Expand Down Expand Up @@ -68,6 +68,11 @@
#define N_R0_IQ4_XS 2
#define N_SG_IQ4_XS 2

// function constant offsets
#define FC_FLASH_ATTN_EXT 100
#define FC_FLASH_ATTN_EXT_VEC 200
#define FC_FLASH_ATTN_EXT_VEC_REDUCE 300

// kernel argument structs
//
// - element counters (e.g. ne00) typically use int32_t to reduce register usage
Expand Down Expand Up @@ -236,9 +241,11 @@ typedef struct {
int32_t ne11;
int32_t ne_12_2; // assume K and V are same shape
int32_t ne_12_3;
int32_t ns10;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
int32_t ns20;
uint64_t nb21;
uint64_t nb22;
uint64_t nb23;
Expand All @@ -258,10 +265,43 @@ typedef struct {
float logit_softcap;
} ggml_metal_kargs_flash_attn_ext;

// Kernel argument struct whose name ties it to the flash-attention "ext_vec" variant.
//
// Conventions visible in this struct:
// - ne* fields are element counters and use int32_t
// - nb* fields use uint64_t
//   NOTE(review): presumably byte strides - confirm against the consuming kernel
// - ns* fields use int32_t
//   NOTE(review): presumably strides expressed in elements - confirm
//
// The field order and types below define the struct layout.
// NOTE(review): presumably this layout must match between the host code that fills the
// struct and the kernel that reads it, so do not reorder fields - TODO confirm.
typedef struct {
// fields with suffix 0x (ne01..ne03, nb01..nb03)
int32_t ne01;
int32_t ne02;
int32_t ne03;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
// fields with suffix 1x / 2x; the shared ne_12_* counters rely on the
// "K and V are same shape" assumption noted below
int32_t ne11;
int32_t ne_12_2; // assume K and V are same shape
int32_t ne_12_3;
int32_t ns10;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
int32_t ns20;
uint64_t nb21;
uint64_t nb22;
uint64_t nb23;
// fields with suffix 3x
int32_t ne32;
int32_t ne33;
uint64_t nb31;
uint64_t nb32;
uint64_t nb33;
// fields with a single-digit suffix
// NOTE(review): presumably the destination tensor dimensions - confirm
int32_t ne1;
int32_t ne2;
int32_t ne3;
// scalar parameters
// NOTE(review): max_bias/m0/m1/n_head_log2 look like ALiBi-style bias parameters - confirm
float scale;
float max_bias;
float m0;
float m1;
int32_t n_head_log2;
float logit_softcap;
} ggml_metal_kargs_flash_attn_ext_vec;

// Kernel argument struct named for the flash-attention "ext_vec_reduce" step
// (naming parallels the FC_FLASH_ATTN_EXT_VEC_REDUCE constant).
//
// The typedef is closed exactly once; a second closing "} name;" line would be
// a stray brace and make the header invalid C.
//
// NOTE(review): nrows is presumably the number of rows to reduce and ne20 an
// element count of a source tensor dimension - confirm against the kernel.
typedef struct {
    int32_t nrows;
    int32_t ne20;
} ggml_metal_kargs_flash_attn_ext_vec_reduce;

typedef struct {
int32_t ne00;
Expand Down
Loading
Loading