@@ -6059,29 +6059,29 @@ void ggml_compute_forward_im2col_back_f32(
60596059 }
60606060}
60616061
6062- static void ggml_call_mul_mat (
6063- const ggml_compute_params * params,
6064- int64_t m, int64_t n, int64_t k,
6065- void * a, void * b, void * c) {
6066-
6062+ static void ggml_call_mul_mat (ggml_type T, const ggml_compute_params * params, int64_t m, int64_t n, int64_t k,
6063+ void * a, void * b, void * c) {
6064+ const ggml_type_traits * traits = ggml_get_type_traits (T);
60676065 struct ggml_tensor src1 = {};
6066+ src1.type = T;
60686067 src1.ne [0 ] = k;
60696068 src1.ne [1 ] = m;
60706069 src1.ne [2 ] = 1 ;
60716070 src1.ne [3 ] = 1 ;
6072- src1.nb [0 ] = sizeof ( float ) ;
6073- src1.nb [1 ] = k * sizeof ( float ) ;
6071+ src1.nb [0 ] = traits-> type_size ;
6072+ src1.nb [1 ] = k * traits-> type_size ;
60746073 src1.nb [2 ] = src1.nb [1 ];
60756074 src1.nb [3 ] = src1.nb [2 ];
60766075 src1.data = a;
60776076
60786077 struct ggml_tensor src0 = {};
6078+ src0.type = T;
60796079 src0.ne [0 ] = k;
60806080 src0.ne [1 ] = n;
60816081 src0.ne [2 ] = 1 ;
60826082 src0.ne [3 ] = 1 ;
6083- src0.nb [0 ] = sizeof ( float ) ;
6084- src0.nb [1 ] = k * sizeof ( float ) ;
6083+ src0.nb [0 ] = traits-> type_size ;
6084+ src0.nb [1 ] = k * traits-> type_size ;
60856085 src0.nb [2 ] = src0.nb [1 ];
60866086 src0.nb [3 ] = src0.nb [2 ];
60876087 src0.data = b;
@@ -6102,17 +6102,18 @@ static void ggml_call_mul_mat(
61026102 ggml_compute_forward_mul_mat (params, &dst);
61036103}
61046104
6105-
61066105// ggml_compute_forward_conv_2d
61076106
6108- static void ggml_compute_forward_conv_2d_f32 (
6109- const ggml_compute_params * params,
6110- const ggml_tensor * kernel , // [KW, KH, IC, OC] - fp32
6111- const ggml_tensor * src , // [W, H, C , N]
6112- ggml_tensor * dst) { // [OW, OH, OC, N]
6107+ static void ggml_compute_forward_conv_2d_impl ( const ggml_compute_params * params,
6108+ const ggml_tensor * kernel, // [KW, KH, IC, OC]
6109+ const ggml_tensor * src , // [W, H, C, N]
6110+ ggml_tensor * dst , // [OW, OH, OC , N]
6111+ ggml_type kernel_type) {
61136112
61146113 GGML_ASSERT (ggml_is_contiguous (kernel));
6115- GGML_ASSERT (kernel->type == GGML_TYPE_F32);
6114+ GGML_ASSERT (kernel->type == kernel_type);
6115+
6116+ const ggml_type_traits * traits = ggml_get_type_traits (kernel_type);
61166117
61176118 const int32_t stride_x = dst->op_params [0 ];
61186119 const int32_t stride_y = dst->op_params [1 ];
@@ -6133,20 +6134,20 @@ static void ggml_compute_forward_conv_2d_f32(
61336134 const int64_t dst_h = dst->ne [1 ];
61346135
61356136 float * src_data = (float *) src->data ;
6136- float * knl_data = ( float *) kernel->data ;
6137+ void * knl_data = kernel->data ;
61376138 float * dst_data = (float *) dst->data ;
61386139
61396140 const int64_t knl_n = knl_w * knl_h * c_in;
61406141 const int64_t patch_total = dst->ne [3 ] * dst_w * dst_h;
61416142
6142- const int64_t space_per_patch = knl_n * sizeof ( float ) + c_out * sizeof (float );
6143+ const int64_t space_per_patch = knl_n * traits-> type_size + c_out * sizeof (float );
61436144 const int64_t batch_size = params->wsize / space_per_patch;
61446145 const int64_t patches_per_batch = batch_size > 8 ? (batch_size / 8 ) * 8 : batch_size;
61456146 const int64_t batch_n = (patch_total + patches_per_batch - 1 ) / patches_per_batch;
61466147
61476148 GGML_ASSERT (patches_per_batch > 0 && batch_size >= 1 );
61486149
6149- float * tmp = ( float *) params->wdata ;
6150+ void * tmp = params->wdata ;
61506151
61516152 for (int64_t batch_i = 0 ; batch_i < batch_n; ++batch_i) {
61526153
@@ -6166,7 +6167,7 @@ static void ggml_compute_forward_conv_2d_f32(
61666167 const int64_t src_y = p % dst_w;
61676168
61686169 float * src_base = (float *)((char *)src_data + batch_n * src->nb [3 ]);
6169- float * dst_row = tmp + (p % patches_per_batch) * knl_n;
6170+ char * dst_row = ( char *) tmp + (p % patches_per_batch) * knl_n * traits-> type_size ;
61706171
61716172 for (int64_t ic = 0 ; ic < c_in; ++ic) {
61726173 for (int64_t ky = 0 ; ky < knl_h; ++ky) {
@@ -6176,11 +6177,19 @@ static void ggml_compute_forward_conv_2d_f32(
61766177
61776178 int64_t dst_idx = ic * (knl_h * knl_w) + ky * knl_w + kx;
61786179
6180+ float src_val;
61796181 if (sy < 0 || sy >= src_h || sx < 0 || sx >= src_w) {
6180- dst_row[dst_idx] = 0 .0f ;
6182+ src_val = 0 .0f ;
61816183 } else {
61826184 float * src_ptr = (float *)((char *)src_base + sx * src->nb [0 ] + sy * src->nb [1 ] + ic * src->nb [2 ]);
6183- dst_row[dst_idx] = *src_ptr;
6185+ src_val = *src_ptr;
6186+ }
6187+
6188+ char * element_ptr = dst_row + dst_idx * traits->type_size ;
6189+ if (kernel_type == GGML_TYPE_F32) {
6190+ *(float *) element_ptr = src_val;
6191+ } else if (kernel_type == GGML_TYPE_F16) {
6192+ *(ggml_fp16_t *) element_ptr = GGML_FP32_TO_FP16 (src_val);
61846193 }
61856194 }
61866195 }
@@ -6189,11 +6198,10 @@ static void ggml_compute_forward_conv_2d_f32(
61896198
61906199 ggml_barrier (params->threadpool );
61916200
6192- float * gemm_output = tmp + patches_per_batch * knl_n;
6201+ float * gemm_output = ( float *) (( char *) tmp + patches_per_batch * knl_n * traits-> type_size ) ;
61936202
61946203 // GEMM: patches[patch_n, knl_n] × kernel[knl_n, c_out] = output[patch_n, c_out]
6195- ggml_call_mul_mat (params, patch_n, c_out, knl_n,
6196- tmp, knl_data, gemm_output);
6204+ ggml_call_mul_mat (kernel_type, params, patch_n, c_out, knl_n, tmp, knl_data, gemm_output);
61976205
61986206 ggml_barrier (params->threadpool );
61996207
@@ -6211,7 +6219,6 @@ static void ggml_compute_forward_conv_2d_f32(
62116219
62126220 for (int64_t oc = 0 ; oc < c_out; ++oc) {
62136221 const float value = gemm_output[i * c_out + oc];
6214- // Write to WHCN layout: dst[w, h, c, n]
62156222 float * dst_ptr = (float *)((char *)dst_data + dst_x * dst->nb [0 ] + dst_y * dst->nb [1 ] + oc * dst->nb [2 ] + batch_n * dst->nb [3 ]);
62166223 *dst_ptr = value;
62176224 }
@@ -6226,11 +6233,7 @@ void ggml_compute_forward_conv_2d(
62266233 const ggml_tensor * src0 = dst->src [0 ];
62276234 const ggml_tensor * src1 = dst->src [1 ];
62286235
6229- if (src0->type == GGML_TYPE_F16) {
6230- GGML_ASSERT (false && " F16 not supported yet" );
6231- } else {
6232- ggml_compute_forward_conv_2d_f32 (params, src0, src1, dst);
6233- }
6236+ ggml_compute_forward_conv_2d_impl (params, src0, src1, dst, src0->type );
62346237}
62356238
62366239// ggml_compute_forward_conv_transpose_2d
0 commit comments