diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt
index ce0a3e1285eb0..a21b1ac25ed09 100644
--- a/ggml/src/ggml-cpu/CMakeLists.txt
+++ b/ggml/src/ggml-cpu/CMakeLists.txt
@@ -70,12 +70,10 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
     if (GGML_OPENMP)
         find_package(OpenMP)
         if (OpenMP_FOUND)
-            set(GGML_OPENMP_ENABLED "ON" CACHE INTERNAL "")
             target_compile_definitions(${GGML_CPU_NAME} PRIVATE GGML_USE_OPENMP)
             target_link_libraries(${GGML_CPU_NAME} PRIVATE OpenMP::OpenMP_C OpenMP::OpenMP_CXX)
         else()
-            set(GGML_OPENMP_ENABLED "OFF" CACHE INTERNAL "")
             message(WARNING "OpenMP not found")
         endif()
     endif()

@@ -458,9 +456,8 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
                 list(APPEND ARCH_FLAGS -march=z16)
             elseif (${S390X_M} MATCHES "9175|9176")
                 # NOTE: Only available from GCC 15.1.0 onwards. Any z17 machine with compile issues must first verify their GCC version.
-                # binutils must also be updated to the latest for the -march=z17 flag to work. Otherwise, use -march=arch15.
                 message(STATUS "z17 target")
-                list(APPEND ARCH_FLAGS -march=arch15)
+                list(APPEND ARCH_FLAGS -march=z17)
             else()
                 message(STATUS "Unknown target")
                 message(WARNING "Unknown target. If you are compiling for z14 and earlier, you might have to add -DGGML_VXE=OFF.")
@@ -489,7 +486,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
         target_compile_definitions(${GGML_CPU_NAME} PRIVATE GGML_USE_CPU_REPACK)
     endif()

-    if (GGML_CPU_KLEIDIAI)
+    if (GGML_CPU_KLEIDIAI AND GGML_CPU_AARCH64 AND (CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64" OR CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64"))
         message(STATUS "Using KleidiAI optimized kernels if applicable")

         # Disable the KleidiAI tests
@@ -497,9 +494,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name)

         # Fetch KleidiAI sources:
         include(FetchContent)
-        set(KLEIDIAI_COMMIT_TAG "v1.11.0")
+        set(KLEIDIAI_COMMIT_TAG "v1.9.0")
         set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz")
-        set(KLEIDIAI_ARCHIVE_MD5  "3fe9e5ab964c375c53839296eb71eaa2")
+        set(KLEIDIAI_ARCHIVE_MD5  "2a8e1bb55d201557553545536489a017")

         if (POLICY CMP0135)
             cmake_policy(SET CMP0135 NEW)
diff --git a/ggml/src/ggml-cpu/kleidiai/kernels.cpp b/ggml/src/ggml-cpu/kleidiai/kernels.cpp
index ddd29d002d1ca..0b7dfa12f9a9e 100644
--- a/ggml/src/ggml-cpu/kleidiai/kernels.cpp
+++ b/ggml/src/ggml-cpu/kleidiai/kernels.cpp
@@ -22,94 +22,12 @@

 #include "kai_common.h"

-#include "simd-mappings.h"
-
 #include "kernels.h"

 #define NELEMS(x) sizeof(x) / sizeof(*x)

-static const size_t INT4_PER_BYTE  = 2;
-static const size_t INT4_BITS      = 4;
-static const int Q4_0_ZERO_POINT   = 8;
-const size_t INT4_PER_UINT16       = 4;
-
-static void dequantize_row_qsi4c32pscalef16(
-    const void *packed_data,
-    int32_t row_idx,
-    int64_t nc,
-    float *out,
-    size_t nr_pack,
-    size_t packed_row_stride,
-    size_t kr,
-    size_t bl,
-    size_t num_bytes_multiplier
-) {
-    size_t group_idx = row_idx / nr_pack;
-    size_t row_in_group = row_idx % nr_pack;
-    const uint8_t *packed_group = (const uint8_t *)packed_data + group_idx * packed_row_stride;
-    size_t num_blocks = nc / bl;
-    const uint8_t *block_ptr = packed_group;
-
-    for (size_t b = 0; b < num_blocks; ++b) {
-        uint16_t scale_f16 = *((const uint16_t *)(block_ptr + row_in_group * num_bytes_multiplier));
-        float scale = GGML_CPU_FP16_TO_FP32(scale_f16);
-
-        const uint8_t *segment_ptr = block_ptr + nr_pack * num_bytes_multiplier;
-        size_t num_segments = bl / kr;
-        size_t num_bytes_per_segment = kr / INT4_PER_BYTE;
-
-        for (size_t s = 0; s < num_segments; ++s) {
-            const uint8_t *seg_base = segment_ptr + s * nr_pack * num_bytes_per_segment;
-            const uint8_t *qbytes = seg_base + row_in_group * num_bytes_per_segment;
-            for (size_t k = 0; k < num_bytes_per_segment; ++k) {
-                uint8_t byte = qbytes[k] ^ 0x88;
-                int x0 = (byte & 0x0F) - Q4_0_ZERO_POINT;
-                int x1 = (byte >> INT4_BITS) - Q4_0_ZERO_POINT;
-                out[b * bl + s * num_bytes_per_segment + k] = x0 * scale;
-                out[b * bl + s * num_bytes_per_segment + k + bl/2] = x1 * scale;
-            }
-        }
-        block_ptr += nr_pack * num_bytes_multiplier + num_segments * nr_pack * num_bytes_per_segment;
-    }
-}
-
-static void dequantize_row_qsi4c32ps1s0scalef16(
-    const void *packed_data,
-    int32_t row_idx,
-    int64_t k,
-    float *out,
-    size_t nr,
-    size_t packed_row_stride,
-    size_t kr,
-    size_t bl,
-    size_t num_bytes_multiplier
-) {
-    const size_t num_blocks = k / bl;
-    const size_t bl4 = bl / INT4_PER_UINT16;
-
-    size_t group_idx = row_idx / nr;
-    size_t row_in_group = row_idx % nr;
-
-    const uint8_t *packed_group = (const uint8_t *)packed_data + group_idx * packed_row_stride;
-    const uint16_t *qdata = (const uint16_t *)packed_group;
-    const uint16_t *scales = (const uint16_t *)(packed_group + packed_row_stride - (nr * num_blocks * num_bytes_multiplier));
-
-    for (size_t block_idx = 0; block_idx < num_blocks; ++block_idx) {
-        uint16_t scale_f16 = scales[row_in_group + block_idx * nr];
-        float scale = GGML_CPU_FP16_TO_FP32(scale_f16);
-
-        for (size_t bl4_idx = 0; bl4_idx < bl4; ++bl4_idx) {
-            uint16_t q = qdata[(block_idx * bl4 + bl4_idx) * nr + row_in_group];
-
-            for (size_t qidx = 0; qidx < INT4_PER_UINT16; ++qidx) {
-                int v = ((q >> (qidx * 4)) & 0xF) - Q4_0_ZERO_POINT;
-                out[block_idx * bl + bl4_idx * INT4_BITS + qidx] = v * scale;
-            }
-        }
-    }
-    GGML_UNUSED(kr);
-}
-
+// Check if any ARM features are available
+#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8)
 static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
 #if defined(__ARM_FEATURE_SME)
     {
@@ -148,10 +66,8 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .pack_func   = */ kai_run_lhs_quant_pack_qsi8d32p_f32_neon,
         },
         /* .rhs_info = */ {
-            /* .packed_size   = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
-            /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
-            /* .pack_func     = */ kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
-            /* .to_float      = */ dequantize_row_qsi4c32ps1s0scalef16,
+            /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
+            /* .pack_func   = */ kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
         },
         /* .required_cpu = */ CPU_FEATURE_SME,
         /* .lhs_type     = */ GGML_TYPE_F32,
@@ -194,10 +110,8 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .pack_func = */ kai_run_lhs_pack_bf16p2vlx2_f32_sme,
         },
         /* .rhs_info = */ {
-            /* .packed_size   = */ kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
-            /* .packed_stride = */ NULL,
-            /* .pack_func     = */ kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
-            /* .to_float      = */ NULL,
+            /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
+            /* .pack_func   = */ kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
         },
         /* .required_cpu = */ CPU_FEATURE_SME,
         /* .lhs_type     = */ GGML_TYPE_F32,
@@ -243,10 +157,8 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
         },
         /* .rhs_info = */ {
-            /* .packed_size   = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
-            /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
-            /* .pack_func     = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
-            /* .to_float      = */ dequantize_row_qsi4c32pscalef16,
+            /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .pack_func   = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
         },
         /* .required_cpu = */ CPU_FEATURE_DOTPROD,
         /* .lhs_type     = */ GGML_TYPE_F32,
@@ -291,10 +203,8 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
         },
         /* .rhs_info = */ {
-            /* .packed_size   = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
-            /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
-            /* .pack_func     = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
-            /* .to_float      = */ dequantize_row_qsi4c32pscalef16,
+            /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .pack_func   = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
         },
         /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
         /* .lhs_type     = */ GGML_TYPE_F32,
@@ -340,10 +250,8 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
         },
         /* .rhs_info = */ {
-            /* .packed_size   = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
-            /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
-            /* .pack_func     = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
-            /* .to_float      = */ dequantize_row_qsi4c32pscalef16,
+            /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .pack_func   = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
         },
         /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
         /* .lhs_type     = */ GGML_TYPE_F32,
@@ -388,10 +296,8 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
         },
         /* .rhs_info = */ {
-            /* .packed_size   = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
-            /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
-            /* .pack_func     = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
-            /* .to_float      = */ dequantize_row_qsi4c32pscalef16,
+            /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .pack_func   = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
         },
         /* .required_cpu = */ CPU_FEATURE_DOTPROD,
         /* .lhs_type     = */ GGML_TYPE_F32,
@@ -401,10 +307,15 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
 #endif
 #endif
 };
+#else
+// Fallback for when no ARM features are available - provide an empty array
+static ggml_kleidiai_kernels gemm_gemv_kernels[1] = {};
+#endif

 ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, const ggml_tensor * tensor) {
     ggml_kleidiai_kernels * kernel = nullptr;

+#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8)
     if (tensor->op == GGML_OP_MUL_MAT && tensor->src[0] != nullptr && tensor->src[1] != nullptr) {
         for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) {
             if ((cpu_features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu &&
@@ -416,6 +327,7 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c
             }
         }
     }
+#endif

     return kernel;
 }

@@ -423,12 +335,14 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c
 ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features) {
     ggml_kleidiai_kernels * kernels = nullptr;

+#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8)
     for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) {
         if ((features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu) {
             kernels = &gemm_gemv_kernels[i];
             break;
         }
     }
+#endif

     return kernels;
 }

diff --git a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp
index dff8fa244a1c9..fafe45e6c5c51 100644
--- a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp
+++ b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp
@@ -40,17 +40,6 @@ struct ggml_kleidiai_context {
     ggml_kleidiai_kernels * kernels;
 } static ctx = { CPU_FEATURE_NONE, NULL };

-static const char* cpu_feature_to_string(cpu_feature f) {
-    switch (f) {
-        case CPU_FEATURE_NONE:    return "NONE";
-        case CPU_FEATURE_DOTPROD: return "DOTPROD";
-        case CPU_FEATURE_I8MM:    return "I8MM";
-        case CPU_FEATURE_SVE:     return "SVE";
-        case CPU_FEATURE_SME:     return "SME";
-        default:                  return "UNKNOWN";
-    }
-}
-
 static void init_kleidiai_context(void) {

     ggml_critical_section_start();
@@ -73,11 +62,6 @@ static void init_kleidiai_context(void) {
             ctx.features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE;
         }
         ctx.kernels = ggml_kleidiai_select_kernels_q4_0(ctx.features);
-#ifndef NDEBUG
-        if (ctx.kernels) {
-            GGML_LOG_DEBUG("kleidiai: using kernel with CPU feature %s\n", cpu_feature_to_string(ctx.kernels->required_cpu));
-        }
-#endif
     }
     ggml_critical_section_end();
 }
@@ -118,9 +102,6 @@ static void transpose_f32kxn_f16nxk(size_t n, size_t k, float * dst, const uint1

 class tensor_traits : public ggml::cpu::tensor_traits {
     bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
-        if (op->op != GGML_OP_MUL_MAT) {
-            return false;
-        }
         ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, op);
         GGML_ASSERT(kernels);
         kernel_info * kernel = op->src[1]->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;
@@ -154,10 +135,6 @@ class tensor_traits : public ggml::cpu::tensor_traits {
             } else if (dst->src[0]->type == GGML_TYPE_F16) {
                 return compute_forward_kv_cache(params, dst);
             }
-        } else if (dst->op == GGML_OP_GET_ROWS) {
-            if (dst->src[0]->type == GGML_TYPE_Q4_0) {
-                return compute_forward_get_rows(params, dst);
-            }
         }
         return false;
     }
@@ -259,10 +236,7 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         const int64_t m_start = 0;
         const int64_t n_step  = static_cast<int64_t>(kernel->get_n_step());

-        int64_t num_threads = KAI_MIN(n / n_step, nth);
-        if (num_threads <= 0) {
-            num_threads = 1;
-        }
+        const int64_t num_threads = KAI_MIN(n / n_step, nth);

         if (ith < num_threads) {
             const int64_t num_n_per_thread0 = round_down(n / num_threads, n_step);
@@ -296,8 +270,6 @@ class tensor_traits : public ggml::cpu::tensor_traits {
     }

     bool compute_forward_q4_0(struct ggml_compute_params * params, struct ggml_tensor * dst) {
-        GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0);
-
         const ggml_tensor * src0 = dst->src[0];
         const ggml_tensor * src1 = dst->src[1];

@@ -312,8 +284,7 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         GGML_ASSERT(kernel);

         const int ith = params->ith;
-        const int nth_raw = params->nth;
-        const int nth = nth_raw > 0 ? nth_raw : 1;
+        const int nth = params->nth;

         const size_t k = ne00;
         const size_t m = ne11;
@@ -331,12 +302,9 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step);
         const size_t n_start          = ith * num_n_per_thread;

-        size_t n_to_process = 0;
-        if (n_start < n) {
-            n_to_process = num_n_per_thread;
-            if ((n_start + n_to_process) > n) {
-                n_to_process = n - n_start;
-            }
+        size_t n_to_process = num_n_per_thread;
+        if ((n_start + n_to_process) > n) {
+            n_to_process = n - n_start;
         }

         // Calculate number of columns to be processed per thread
@@ -368,57 +336,14 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         const void* lhs_ptr = (const void*)((const char *)lhs_packed + lhs_packed_offset);
         float *dst_ptr = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);

-        if (n_to_process > 0) {
-            variant_call<void>(kernel->run_kernel, m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride,
-                               sizeof(float), -FLT_MAX, FLT_MAX);
-        }
-
-        return true;
-    }
-
-    bool compute_forward_get_rows(struct ggml_compute_params * params, struct ggml_tensor * dst) {
-        GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0);
-        GGML_ASSERT(ctx.kernels);
-
-        const ggml_tensor * src0 = dst->src[0];
-        const ggml_tensor * src1 = dst->src[1];
-
-        GGML_TENSOR_BINARY_OP_LOCALS
-
-        rhs_packing_info * rhs_info = &ctx.kernels->rhs_info;
-        kernel_info * kernel        = &ctx.kernels->gemm;
-
-        const int64_t nc = ne00;
-        const int64_t nr = ggml_nelements(src1);
-
-        const size_t block_rows = kernel->get_nr();
-        const size_t kr         = kernel->get_kr();
-
-        const size_t num_bytes_multiplier = sizeof(uint16_t);
-        const size_t packed_stride = rhs_info->packed_stride(nc, block_rows, kr, QK4_0);
-
-        const int ith = params->ith;
-        const int nth = params->nth;
-
-        const int dr  = (nr + nth - 1) / nth;
-        const int ir0 = dr * ith;
-        const int ir1 = MIN(ir0 + dr, nr);
-
-        for (int64_t i = ir0; i < ir1; ++i) {
-            GGML_ASSERT(src1->type == GGML_TYPE_I32);
-            int64_t row_idx = ((const int32_t *)src1->data)[i];
-            GGML_ASSERT(row_idx >= 0 && row_idx < src0->ne[1]);
-
-            float *out = (float *)((char *)dst->data + i * nb1);
-            rhs_info->to_float(src0->data, row_idx, nc, out, block_rows, packed_stride, kr, QK4_0, num_bytes_multiplier);
-        }
+        variant_call<void>(kernel->run_kernel, m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride,
+                           sizeof(float), -FLT_MAX, FLT_MAX);

         return true;
     }

 public:
     int repack(struct ggml_tensor * tensor, const void * data, size_t data_size) {
-        GGML_ASSERT(tensor->type == GGML_TYPE_Q4_0);
         GGML_ASSERT(ctx.kernels);
         const size_t n = tensor->ne[1];
         const size_t k = tensor->ne[0];
@@ -426,12 +351,17 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         size_t kr = ctx.kernels->gemm.get_kr();
         size_t sr = ctx.kernels->gemm.get_sr();

+#ifndef NDEBUG
+        const size_t repacked_size = variant_call<size_t>(ctx.kernels->rhs_info.packed_size, n, k, nr, kr, QK4_0);
+        GGML_ASSERT(repacked_size <= data_size && "repacked size larger than the packed size!");
+#endif
         struct kai_rhs_pack_qs4cxs1s0_param params;
         params.lhs_zero_point = 1;
         params.rhs_zero_point = 8;
         variant_call<void>(ctx.kernels->rhs_info.pack_func, 1, n, k, nr, kr, sr, QK4_0, (const uint8_t*)data, nullptr, tensor->data, 0, &params);

         return 0;
+        GGML_UNUSED(data_size);
     }
 };

@@ -445,8 +375,8 @@ static ggml::cpu::tensor_traits * get_tensor_traits(ggml_backend_buffer_t, struc
 static enum ggml_status ggml_backend_cpu_kleidiai_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
     tensor->extra = (void *) ggml::cpu::kleidiai::get_tensor_traits(buffer, tensor);
-    return GGML_STATUS_SUCCESS;

     GGML_UNUSED(buffer);
+    return GGML_STATUS_SUCCESS;
 }

 static void ggml_backend_cpu_kleidiai_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor,
@@ -488,35 +418,18 @@ static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alignment(ggml_backend_b
     GGML_UNUSED(buft);
 }

-static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
-    GGML_ASSERT(tensor->type == GGML_TYPE_Q4_0);
-    GGML_ASSERT(ctx.kernels);
-
-    const size_t n = tensor->ne[1];
-    const size_t k = tensor->ne[0];
-    const size_t nr = ctx.kernels->gemm.get_nr();
-    const size_t kr = ctx.kernels->gemm.get_kr();
-
-    return variant_call<size_t>(ctx.kernels->rhs_info.packed_size, n, k, nr, kr, QK4_0);
-
-    GGML_UNUSED(buft);
-}
-
 namespace ggml::cpu::kleidiai {
 class extra_buffer_type : ggml::cpu::extra_buffer_type {
     bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
-        if ((op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) &&
+        if (op->op == GGML_OP_MUL_MAT &&
             op->src[0]->type == GGML_TYPE_Q4_0 &&
             op->src[0]->buffer &&
             (ggml_n_dims(op->src[0]) == 2) &&
             op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels) {
-            if (op->op == GGML_OP_GET_ROWS && op->src[1]->ne[0] != 8) {
-                return false;
-            }
             if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
                 return false;
             }
-            if ((op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_I32) &&
+            if (op->src[1]->type == GGML_TYPE_F32 &&
                 ggml_ne(op->src[1], 2) == 1 && ggml_ne(op->src[1], 3) == 1) {
                 return true;
             }
@@ -525,7 +438,7 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
     }

     ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {
-        if (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) {
+        if (op->op == GGML_OP_MUL_MAT) {
            if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
                 return (ggml::cpu::tensor_traits *) op->src[0]->extra;
             }
@@ -556,7 +469,7 @@ ggml_backend_buffer_type_t ggml_backend_cpu_kleidiai_buffer_type(void) {
            /* .alloc_buffer     = */ ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer,
            /* .get_alignment    = */ ggml_backend_cpu_kleidiai_buffer_type_get_alignment,
            /* .get_max_size     = */ nullptr, // defaults to SIZE_MAX
-           /* .get_alloc_size   = */ ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size,
+           /* .get_alloc_size   = */ nullptr, // defaults to ggml_nbytes
            /* .is_host          = */ nullptr,
         },
         /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),