@@ -2106,10 +2106,16 @@ static bool ggml_cuda_should_fuse_mul_mat_vec_f(const ggml_tensor * tensor) {
21062106    const  int  cc      = ggml_cuda_info ().devices [ggml_cuda_get_device ()].cc ;
21072107    use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf (src0->type , cc, src0->ne , is_mul_mat_id ? src1->ne [2 ] : src1->ne [1 ]);
21082108
2109-     if  (tensor->op  == GGML_OP_MUL_MAT_ID) {
2110-         use_mul_mat_vec_f = use_mul_mat_vec_f && dst->ne [2 ] == 1 ;
2109+     // we only support fusion for ncols_dst = 1
2110+     if  (tensor->op  == GGML_OP_MUL_MAT && dst->ne [1 ] != 1 ) {
2111+         return  false ;
21112112    }
21122113
2114+     if  (tensor->op  == GGML_OP_MUL_MAT_ID && dst->ne [2 ] != 1 ) {
2115+         return  false ;
2116+     }
2117+ 
2118+ 
21132119    return  use_mul_mat_vec_f;
21142120}
21152121
@@ -2125,8 +2131,13 @@ static bool ggml_cuda_should_fuse_mul_mat_vec_q(const ggml_tensor * tensor) {
21252131    bool  use_mul_mat_vec_q = ggml_is_quantized (src0->type ) && !bad_padding_clear && src1->type  == GGML_TYPE_F32 &&
21262132                             dst->type  == GGML_TYPE_F32 && src1->ne [1 ] <= MMVQ_MAX_BATCH_SIZE;
21272133
2128-     if  (tensor->op  == GGML_OP_MUL_MAT_ID) {
2129-         use_mul_mat_vec_q = use_mul_mat_vec_q && dst->ne [2 ] == 1 ;
2134+     // we only support fusion for ncols_dst = 1
2135+     if  (tensor->op  == GGML_OP_MUL_MAT && dst->ne [1 ] != 1 ) {
2136+         return  false ;
2137+     }
2138+ 
2139+     if  (tensor->op  == GGML_OP_MUL_MAT_ID && dst->ne [2 ] != 1 ) {
2140+         return  false ;
21302141    }
21312142
21322143    return  use_mul_mat_vec_q;
@@ -2979,12 +2990,11 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
29792990        }
29802991    }
29812992
2982-     std::initializer_list<enum  ggml_op> mul_mat_bias_glu_ops    = { GGML_OP_MUL_MAT, GGML_OP_ADD,    GGML_OP_MUL_MAT,    GGML_OP_ADD,    GGML_OP_GLU };
2993+     std::initializer_list<enum  ggml_op> mul_mat_bias_glu_ops    = { GGML_OP_MUL_MAT,     GGML_OP_ADD,    GGML_OP_MUL_MAT,    GGML_OP_ADD,    GGML_OP_GLU };
29832994    std::initializer_list<enum  ggml_op> mul_mat_id_bias_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_GLU };
29842995
29852996    std::initializer_list<enum  ggml_op> mul_mat_id_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_MUL_MAT_ID, GGML_OP_GLU };
2986- 
2987-     std::initializer_list<enum  ggml_op> mul_mat_glu_ops = { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT, GGML_OP_GLU };
2997+     std::initializer_list<enum  ggml_op> mul_mat_glu_ops    = { GGML_OP_MUL_MAT,    GGML_OP_MUL_MAT,    GGML_OP_GLU };
29882998
29892999    if  (ops.size () == 5  && (ggml_can_fuse_subgraph (cgraph, node_idx, ops, {node_idx + 4 }) ||
29903000                            ggml_can_fuse_subgraph (cgraph, node_idx, ops, {node_idx + 4 }))) {
0 commit comments