Skip to content

Commit 9b95697

Browse files
committed
only fuse ncols_dst=1
1 parent 8371c0c commit 9b95697

38 files changed

+1024
-1246
lines changed

ggml/src/ggml-cuda/CMakeLists.txt

Lines changed: 0 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -50,10 +50,6 @@ if (CUDAToolkit_FOUND)
5050
list(APPEND GGML_SOURCES_CUDA ${SRCS})
5151
file(GLOB SRCS "template-instances/mmq*.cu")
5252
list(APPEND GGML_SOURCES_CUDA ${SRCS})
53-
file(GLOB SRCS "template-instances/mmvq*.cu")
54-
list(APPEND GGML_SOURCES_CUDA ${SRCS})
55-
file(GLOB SRCS "template-instances/mmvf*.cu")
56-
list(APPEND GGML_SOURCES_CUDA ${SRCS})
5753
file(GLOB SRCS "template-instances/mmf*.cu")
5854
list(APPEND GGML_SOURCES_CUDA ${SRCS})
5955

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 17 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -2106,10 +2106,16 @@ static bool ggml_cuda_should_fuse_mul_mat_vec_f(const ggml_tensor * tensor) {
21062106
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
21072107
use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, is_mul_mat_id ? src1->ne[2] : src1->ne[1]);
21082108

2109-
if (tensor->op == GGML_OP_MUL_MAT_ID) {
2110-
use_mul_mat_vec_f = use_mul_mat_vec_f && dst->ne[2] == 1;
2109+
//we only support fusion for ncols_dst = 1
2110+
if (tensor->op == GGML_OP_MUL_MAT && dst->ne[1] != 1) {
2111+
return false;
21112112
}
21122113

2114+
if (tensor->op == GGML_OP_MUL_MAT_ID && dst->ne[2] != 1) {
2115+
return false;
2116+
}
2117+
2118+
21132119
return use_mul_mat_vec_f;
21142120
}
21152121

@@ -2125,8 +2131,13 @@ static bool ggml_cuda_should_fuse_mul_mat_vec_q(const ggml_tensor * tensor) {
21252131
bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear && src1->type == GGML_TYPE_F32 &&
21262132
dst->type == GGML_TYPE_F32 && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
21272133

2128-
if (tensor->op == GGML_OP_MUL_MAT_ID) {
2129-
use_mul_mat_vec_q = use_mul_mat_vec_q && dst->ne[2] == 1;
2134+
//we only support fusion for ncols_dst = 1
2135+
if (tensor->op == GGML_OP_MUL_MAT && dst->ne[1] != 1) {
2136+
return false;
2137+
}
2138+
2139+
if (tensor->op == GGML_OP_MUL_MAT_ID && dst->ne[2] != 1) {
2140+
return false;
21302141
}
21312142

21322143
return use_mul_mat_vec_q;
@@ -2979,12 +2990,11 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
29792990
}
29802991
}
29812992

2982-
std::initializer_list<enum ggml_op> mul_mat_bias_glu_ops = { GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_GLU };
2993+
std::initializer_list<enum ggml_op> mul_mat_bias_glu_ops = { GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_GLU };
29832994
std::initializer_list<enum ggml_op> mul_mat_id_bias_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_GLU };
29842995

29852996
std::initializer_list<enum ggml_op> mul_mat_id_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_MUL_MAT_ID, GGML_OP_GLU };
2986-
2987-
std::initializer_list<enum ggml_op> mul_mat_glu_ops = { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT, GGML_OP_GLU };
2997+
std::initializer_list<enum ggml_op> mul_mat_glu_ops = { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT, GGML_OP_GLU };
29882998

29892999
if (ops.size() == 5 && (ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 4}) ||
29903000
ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 4}))) {

0 commit comments

Comments
 (0)