Skip to content

Commit c4df49a

Browse files
authored
kleidiai: generalize compute_forward_kv_cache to compute_forward_fp16 (#15817)
1 parent 3c3635d commit c4df49a

File tree

1 file changed

+4
-9
lines changed

1 file changed

+4
-9
lines changed

ggml/src/ggml-cpu/kleidiai/kleidiai.cpp

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ class tensor_traits : public ggml::cpu::tensor_traits {
154154
if (dst->src[0]->type == GGML_TYPE_Q4_0) {
155155
return compute_forward_q4_0(params, dst);
156156
} else if (dst->src[0]->type == GGML_TYPE_F16) {
157-
return compute_forward_kv_cache(params, dst);
157+
return compute_forward_fp16(params, dst);
158158
}
159159
} else if (dst->op == GGML_OP_GET_ROWS) {
160160
if (dst->src[0]->type == GGML_TYPE_Q4_0) {
@@ -164,7 +164,7 @@ class tensor_traits : public ggml::cpu::tensor_traits {
164164
return false;
165165
}
166166

167-
bool compute_forward_kv_cache(ggml_compute_params * params, struct ggml_tensor * dst) {
167+
bool compute_forward_fp16(ggml_compute_params * params, struct ggml_tensor * dst) {
168168
static std::atomic_flag first_to_arrive = ATOMIC_FLAG_INIT;
169169

170170
const ggml_tensor * src0 = dst->src[0];
@@ -534,13 +534,8 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
534534
if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
535535
return (ggml::cpu::tensor_traits *) op->src[0]->extra;
536536
}
537-
else if (ggml_kleidiai_select_kernels(ctx.features, op) &&
538-
op->src[0]->op == GGML_OP_VIEW &&
539-
(op->src[1]->op == GGML_OP_PERMUTE || op->src[1]->op == GGML_OP_SOFT_MAX) &&
540-
op->src[1]->ne[1] > 1) {
541-
if ((op->src[0]->nb[0] != 2) ||
542-
(op->src[1]->nb[0] != 4) ||
543-
(op->src[0]->nb[1] * op->src[0]->ne[1] != op->src[0]->nb[2]) ||
537+
else if (ggml_kleidiai_select_kernels(ctx.features, op) && op->src[1]->ne[1] > 1) {
538+
if ((op->src[0]->nb[1] * op->src[0]->ne[1] != op->src[0]->nb[2]) ||
544539
(op->src[1]->nb[1] * op->src[1]->ne[1] != op->src[1]->nb[2])) {
545540
return nullptr;
546541
}

0 commit comments

Comments
 (0)