From 8a589317b6bfe60d732a1a0d1b9bb153145f9fd2 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Tue, 2 Sep 2025 22:47:41 -0400 Subject: [PATCH 1/7] Add implicit GEMM convolution operation for 2D tensors in CUDA --- ggml/include/ggml.h | 12 + ggml/src/ggml-cuda/conv2d-implicit.cu | 394 +++++++++++++++++++++++++ ggml/src/ggml-cuda/conv2d-implicit.cuh | 5 + ggml/src/ggml.c | 39 +++ 4 files changed, 450 insertions(+) create mode 100644 ggml/src/ggml-cuda/conv2d-implicit.cu create mode 100644 ggml/src/ggml-cuda/conv2d-implicit.cuh diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 7e9c3c8c7a096..d37a0a91ff35b 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -512,6 +512,7 @@ extern "C" { GGML_OP_IM2COL, GGML_OP_IM2COL_BACK, GGML_OP_CONV_2D, + GGML_OP_CONV_2D_IMPLICIT, GGML_OP_CONV_3D, GGML_OP_CONV_2D_DW, GGML_OP_CONV_TRANSPOSE_2D, @@ -1941,6 +1942,17 @@ extern "C" { int d0, // dilation dimension 0 int d1); // dilation dimension 1 + GGML_API struct ggml_tensor * ggml_conv_2d_implicitgemm( + struct ggml_context * ctx, + struct ggml_tensor * a, // convolution kernel [KW, KH, IC, OC] + struct ggml_tensor * b, // input data [W, H, C, N] + int s0, // stride dimension 0 + int s1, // stride dimension 1 + int p0, // padding dimension 0 + int p1, // padding dimension 1 + int d0, // dilation dimension 0 + int d1); // dilation dimension 1 + GGML_API struct ggml_tensor * ggml_conv_3d( struct ggml_context * ctx, struct ggml_tensor * a, // kernel [KW, KH, KD, IC * OC] diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu new file mode 100644 index 0000000000000..d1b1dc7d3cacc --- /dev/null +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -0,0 +1,394 @@ +#include "conv2d-implicit.cuh" +#include "convert.cuh" + +struct conv_params { + const int64_t IW, IH; + const int64_t OW, OH; + const int64_t KW, KH; + const int64_t ST_X, ST_Y; + const int64_t PD_X, PD_Y; + const int64_t DL_X, DL_Y; + const int64_t IC, OC; + const int64_t B; + const int64_t TOTAL; +}; + +struct kernel_bounds { + int64_t y_min, y_max; + int64_t x_min, x_max; +}; + +__device__ __forceinline__ int64_t max64(int64_t a, int64_t b) { + return (a > b) ? a : b; +} + +__device__ __forceinline__ int64_t min64(int64_t a, int64_t b) { + return (a < b) ? 
a : b; +} + +__device__ __forceinline__ kernel_bounds calculate_kernel_bounds(int64_t out_x, int64_t out_y, const conv_params & P) { + kernel_bounds bounds; + bounds.y_min = max64(0, (P.PD_Y - out_y * P.ST_Y + P.DL_Y - 1) / P.DL_Y); + bounds.y_max = min64(P.KH, (P.IH + P.PD_Y - out_y * P.ST_Y + P.DL_Y - 1) / P.DL_Y); + bounds.x_min = max64(0, (P.PD_X - out_x * P.ST_X + P.DL_X - 1) / P.DL_X); + bounds.x_max = min64(P.KW, (P.IW + P.PD_X - out_x * P.ST_X + P.DL_X - 1) / P.DL_X); + return bounds; +} + +__device__ __forceinline__ int calculate_input_coord(int64_t out_coord, + int64_t kern_coord, + int64_t stride, + int64_t dilation, + int64_t padding) { + return out_coord * stride + kern_coord * dilation - padding; +} + +struct whcn_layout { + __device__ static int64_t input_index(int64_t n, int64_t c, int64_t y, int64_t x, const conv_params & P) { + return n * (P.IC * P.IW * P.IH) + c * P.IW * P.IH + y * P.IW + x; + } + + __device__ static int64_t kernel_index(int64_t c_out, int64_t c_in, int64_t ky, int64_t kx, const conv_params & P) { + return c_out * (P.IC * P.KH * P.KW) + c_in * (P.KH * P.KW) + ky * P.KW + kx; + } + + __device__ static int64_t output_index(int64_t n, int64_t c, int64_t y, int64_t x, const conv_params & P) { + return n * (P.OC * P.OW * P.OH) + c * P.OW * P.OH + y * P.OW + x; + } + + __device__ static void unpack_indices(int64_t global_idx, + const conv_params & P, + int64_t & n, + int64_t & c, + int64_t & out_y, + int64_t & out_x) { + out_x = global_idx % P.OW; + out_y = (global_idx / P.OW) % P.OH; + c = (global_idx / (P.OW * P.OH)) % P.OC; + n = global_idx / (P.OW * P.OH * P.OC); + } +}; + +template +static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, + const T * __restrict__ kernel, + float * __restrict__ output, + const conv_params P) { + + __shared__ __align__(16 * 1024) char smem[24 * 1024]; + T *smemweight = reinterpret_cast(smem); + float *smeminput = reinterpret_cast(smem + 16 * 1024); + + int tx = threadIdx.x; + int bx = blockIdx.x; + int by = blockIdx.y; + + // Warp tile + const int lane_id = threadIdx.x % 32; + const int warp_id = threadIdx.x / 32; + const int mma_tid_x = (lane_id / 2) % 8; + const int mma_tid_y = (lane_id / 16) * 2 + (lane_id % 2); + // lds addr + int weight_lds_addr = (warp_id / 2) * 32 + mma_tid_y * 4; + int input_lds_addr = (warp_id % 2) * 64 + mma_tid_x * 4; + + int x = bx * 128 + input_lds_addr; + int y = by * 128 + weight_lds_addr; + int z = blockIdx.z; + + T weight_ldg_reg[4]; + float input_ldg_reg[4]; + + int posh_ori[4]; + int posw_ori[4]; +#pragma unroll + for (int i = 0; i < 4; ++i) + { + posh_ori[i] = ((bx * 128 + tx % 32 + i * 32) / param.Ow) * param.u - param.p; + posw_ori[i] = ((bx * 128 + tx % 32 + i * 32) % param.Ow) * param.v - param.q; + } + + int inOffset = z * param.c * param.h * param.w; + int weiOffset = (by * 128 + tx / 8 * 4) * param.c * param.r * param.s; + int inChannelOffset = param.h * param.w; + int weightChannelOffset = param.r * param.s; + int weightKOffset = param.c * param.r * param.s; + + // sts addr + int weight_sts_addr = (tx % 8) * 132 + + (tx / 8) * 4; + int input_sts_addr = (tx / 32) * 128 + (tx % 32); + + int write_flag = 1; + T weight_frag[2][8]; + float input_frag[2][8]; + float output_frag[8][8]; +#pragma unroll + for (int i = 0; i < 8; ++i) + { +#pragma unroll + for (int j = 0; j < 8; ++j) + { + output_frag[i][j] = 0; + } + } +// ldg +#pragma unroll + for (int i = 0; i < 4; ++i) + { + if (tx % 8 < weightKOffset && by * 128 + tx / 8 * 4 + i < param.k) + { + weight_ldg_reg[i] 
= kernel[weiOffset + tx % 8 + i * weightKOffset]; + } + else + { + weight_ldg_reg[i] = (T)0.f; + } + } + int curC = (tx / 32) / (param.r * param.s); // channel offset + int curR = ((tx / 32) % (param.r * param.s)) / param.s; // kernel r offset + int curS = ((tx / 32) % (param.r * param.s)) % param.s; // kernel s offset +#pragma unroll + for (int i = 0; i < 4; ++i) + { + int curH = posh_ori[i] + curR; // input h + int curW = posw_ori[i] + curS; // input w + int inOffsetTmp = curC * inChannelOffset + curH * param.w + curW; + if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h) + { + input_ldg_reg[i] = input[inOffset + inOffsetTmp]; + } + else + { + input_ldg_reg[i] = 0.0; + } + } + // sts + for (int i = 0; i < 4; ++i) + { + smemweight[weight_sts_addr + i] = weight_ldg_reg[i]; + } + for (int i = 0; i < 4; ++i) + { + smeminput[input_sts_addr + i * 32] = input_ldg_reg[i]; + } + + __syncthreads(); + // lds +#pragma unroll + for (int i = 0; i < 4; ++i) + { + weight_frag[0][i] = smemweight[weight_lds_addr + i]; + weight_frag[0][i + 4] = smemweight[weight_lds_addr + i + 16]; + } +#pragma unroll + for (int i = 0; i < 4; ++i) + { + input_frag[0][i] = smeminput[input_lds_addr + i]; + input_frag[0][i + 4] = smeminput[input_lds_addr + i + 32]; + } + for (int crs = 0; crs < param.r * param.s * param.c; crs += 8) + { + // ldg + int weiOffsetTmp = crs + 8 + tx % 8; +#pragma unroll + for (int i = 0; i < 4; ++i) + { + if (weiOffsetTmp < weightKOffset && by * 128 + tx / 8 * 4 + i < param.k) + { + weight_ldg_reg[i] = kernel[weiOffset + weiOffsetTmp + i * weightKOffset]; + } + else + { + weight_ldg_reg[i] = (T)0.f; + } + } + curC = (crs + 8 + tx / 32) / (param.r * param.s); // channel offset + curR = ((crs + 8 + tx / 32) % (param.r * param.s)) / param.s; // kernel r offset + curS = ((crs + 8 + tx / 32) % (param.r * param.s)) % param.s; // kernel s offset + +#pragma unroll + for (int i = 0; i < 4; ++i) + { + int curH = posh_ori[i] + curR; // input h + int curW = posw_ori[i] + curS; // input w + int inOffsetTmp = curC * inChannelOffset + curH * param.w + curW; + if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h) + { + input_ldg_reg[i] = input[inOffset + inOffsetTmp]; + } + else + { + input_ldg_reg[i] = 0.f; + } + } + int load_flag = write_flag ^ 1; +#pragma unroll + for (int subcrs = 0; subcrs < 8 - 1; ++subcrs) + { +#pragma unroll + for (int i = 0; i < 4; ++i) + { + weight_frag[(subcrs + 1) % 2][i] = smemweight[load_flag * 132 * 8 + weight_lds_addr + (subcrs + 1) * 132 + i]; + weight_frag[(subcrs + 1) % 2][i + 4] = smemweight[load_flag * 132 * 8 + weight_lds_addr + (subcrs + 1) * 132 + i + 16]; + } +#pragma unroll + for (int i = 0; i < 4; ++i) + { + input_frag[(subcrs + 1) % 2][i] = smeminput[load_flag * 128 * 8 + input_lds_addr + (subcrs + 1) * 128 + i]; + input_frag[(subcrs + 1) % 2][i + 4] = smeminput[load_flag * 128 * 8 + input_lds_addr + (subcrs + 1) * 128 + i + 32]; + } + +#pragma unroll + for (int i = 0; i < 8; ++i) + { +#pragma unroll + for (int j = 0; j < 8; ++j) + { + output_frag[i][j] += ggml_cuda_cast(weight_frag[subcrs % 2][i]) * input_frag[subcrs % 2][j]; + } + } + } + // sts + for (int i = 0; i < 4; ++i) + { + smemweight[write_flag * 132 * 8 + weight_sts_addr + i] = weight_ldg_reg[i]; + } + for (int i = 0; i < 4; ++i) + { + smeminput[write_flag * 128 * 8 + input_sts_addr + i * 32] = input_ldg_reg[i]; + } + __syncthreads(); + write_flag ^= 1; +#pragma unroll + for (int i = 0; i < 4; ++i) + { + weight_frag[0][i] = smemweight[(load_flag ^ 1) * 132 * 8 + weight_lds_addr + i]; + 
weight_frag[0][i + 4] = smemweight[(load_flag ^ 1) * 132 * 8 + weight_lds_addr + i + 16]; + } +#pragma unroll + for (int i = 0; i < 4; ++i) + { + input_frag[0][i] = smeminput[(load_flag ^ 1) * 128 * 8 + input_lds_addr + i]; + input_frag[0][i + 4] = smeminput[(load_flag ^ 1) * 128 * 8 + input_lds_addr + i + 32]; + } +#pragma unroll + for (int i = 0; i < 8; ++i) + { +#pragma unroll + for (int j = 0; j < 8; ++j) + { + output_frag[i][j] += ggml_cuda_cast(weight_frag[1][i]) * input_frag[1][j]; + } + } + } + + // reuse smem + float *smemoutput = reinterpret_cast(smem); + // float *smembias = reinterpret_cast(smem + 16 * 1024); + + // bias ldg/sts + // if (tx < 128) + // { + // smembias[tx] = param.bias[by * 128 + tx]; + // } + + uint32_t output_sts_addr = warp_id * 512 + mma_tid_y * 4 * 8 * 4 + mma_tid_x * 4; + uint32_t output_lds_addr = warp_id * 512 + lane_id; + // uint32_t bias_lds_addr = warp_id / 2 * 32; + + uint32_t m_idx = blockIdx.y * 128 + warp_id / 2 * 32; + uint32_t n_idx = blockIdx.x * 128 + warp_id % 2 * 64 + lane_id; + +#pragma unroll + for (int i = 0; i < 2; ++i) + { +#pragma unroll + for (int j = 0; j < 2; ++j) + { + __syncthreads(); + +#pragma unroll + for (int subi = 0; subi < 4; ++subi) + { +#pragma unroll + for (int subj = 0; subj < 4; ++subj) + { + // output sts + smemoutput[output_sts_addr + subi * 8 * 4 + subj] = output_frag[i * 4 + subi][j * 4 + subj]; + } + } + __syncthreads(); + +#pragma unroll + for (int subk = 0; subk < 16; ++subk) + { + int outOffset = z * param.k * param.Oh * param.Ow + (m_idx + i * 16 + subk) * param.Oh * param.Ow + n_idx + j * 32; + if ((m_idx + i * 16 + subk) < param.k && (n_idx + j * 32) < param.Oh * param.Ow) + // output[outOffset] = smemoutput[output_lds_addr + subk * 32] + smembias[bias_lds_addr + i * 16 + subk]; + output[outOffset] = smemoutput[output_lds_addr + subk * 32]; + } + } + } + +} + +template +static void conv2d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D, const conv_params P, cudaStream_t st) { + const int blocks = (P.TOTAL + CUDA_CONV2D_BLOCK_SIZE - 1) / CUDA_CONV2D_BLOCK_SIZE; + conv2d_implicit_kernel<<>>(X_D, K_D, Y_D, P); +} + +static void conv2d_implicit_cuda_f16(const float * X_D, const half * K_D, float * Y_D, const conv_params P, cudaStream_t st) { + conv2d_implicit_cuda(X_D, K_D, Y_D, P, st); +} + +static void conv2d_implicit_cuda_f32(const float * X_D, const float * K_D, float * Y_D, const conv_params P, cudaStream_t st) { + conv2d_implicit_cuda(X_D, K_D, Y_D, P, st); +} + +void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * kernel = dst->src[0]; + const ggml_tensor * input = dst->src[1]; + float * K_D = (float *) kernel->data; + const float * X_D = (const float *) input->data; + float * Y_D = (float *) dst->data; + + GGML_ASSERT(ggml_is_contiguous(kernel)); + GGML_ASSERT(kernel->type == GGML_TYPE_F16 || kernel->type == GGML_TYPE_F32); + + // same number of input channels + GGML_ASSERT(input->ne[2] == kernel->ne[2]); + + cudaStream_t st = ctx.stream(); + + const int32_t * p = (const int32_t *) dst->op_params; + const int ST_X = p[0]; // stride_x + const int ST_Y = p[1]; // stride_y + const int PD_X = p[2]; // padding_x + const int PD_Y = p[3]; // padding_y + const int DL_X = p[4]; // dilation_x + const int DL_Y = p[5]; // dilation_y + + // No cwhn + GGML_ASSERT(p[6] == false); + + const int IW = input->ne[0]; // input_w + const int IH = input->ne[1]; // input_h + const int OW = dst->ne[0]; // output_w + const int OH = dst->ne[1]; // output_h + 
const int KW = kernel->ne[0]; // kernel_w + const int KH = kernel->ne[1]; // kernel_h + const int IC = input->ne[2]; // input_channels + const int OC = kernel->ne[3]; // ouptut_chanles + const int B = input->ne[3]; // n_batches + + const int64_t total = B * OC * OH * OW; + conv_params params = { IW, IH, OW, OH, KW, KH, ST_X, ST_Y, PD_X, PD_Y, DL_X, DL_Y, IC, OC, B, total }; + + if (kernel->type == GGML_TYPE_F16) { + conv2d_implicit_cuda_f16(X_D, (half *) K_D, Y_D, params, st); + } else { + conv2d_implicit_cuda_f32(X_D, K_D, Y_D, params, st); + } +} diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cuh b/ggml/src/ggml-cuda/conv2d-implicit.cuh new file mode 100644 index 0000000000000..46161feb3c9fd --- /dev/null +++ b/ggml/src/ggml-cuda/conv2d-implicit.cuh @@ -0,0 +1,5 @@ +#pragma once +#include "common.cuh" + +#define CUDA_CONV2D_IMPLICT_BLOCK_SIZE 256 +void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index d76ea58f789e2..4e0fd672bd433 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -4482,6 +4482,45 @@ struct ggml_tensor * ggml_conv_2d_direct( return result; } + +// ggml_conv_2d_implicitgemm + +struct ggml_tensor * ggml_conv_2d_implicitgemm( + struct ggml_context * ctx, + struct ggml_tensor * a, // convolution kernel [KW, KH, IC, OC] + struct ggml_tensor * b, // input data [W, H, C, N] + int s0, // stride dimension 0 + int s1, // stride dimension 1 + int p0, // padding dimension 0 + int p1, // padding dimension 1 + int d0, // dilation dimension 0 + int d1) {// dilation dimension 1 + + GGML_ASSERT(a->ne[2] == b->ne[2]); + //GGML_ASSERT(a->type == b->type); + + int64_t ne[4]; + ne[0] = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0); + ne[1] = ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1); + ne[2] = a->ne[3]; + ne[3] = b->ne[3]; + + struct ggml_tensor * result = ggml_new_tensor(ctx, b->type, 4, ne); + + ggml_set_op_params_i32(result, 0, s0); + ggml_set_op_params_i32(result, 1, s1); + ggml_set_op_params_i32(result, 2, p0); + ggml_set_op_params_i32(result, 3, p1); + ggml_set_op_params_i32(result, 4, d0); + ggml_set_op_params_i32(result, 5, d1); + + result->op = GGML_OP_CONV_2D_IMPLICIT; + result->src[0] = a; + result->src[1] = b; + + return result; +} + // ggml_conv_3d struct ggml_tensor * ggml_conv_3d( From 4d772873b94641386a48f923cead6aca618e0d8e Mon Sep 17 00:00:00 2001 From: bssrdf Date: Wed, 3 Sep 2025 11:29:14 -0400 Subject: [PATCH 2/7] Add implicit convolution support for 2D tensors in CPU and CUDA implementations --- ggml/src/ggml-cpu/ggml-cpu.c | 6 ++ ggml/src/ggml-cuda/conv2d-implicit.cu | 118 +++++++++----------------- ggml/src/ggml-cuda/ggml-cuda.cu | 5 ++ ggml/src/ggml.c | 4 +- 4 files changed, 53 insertions(+), 80 deletions(-) diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 0d5d3a3440aaf..16d9f0204aa3d 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -1880,6 +1880,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_conv_2d(params, tensor); } break; + case GGML_OP_CONV_2D_IMPLICIT: + { + ggml_compute_forward_conv_2d(params, tensor); + } break; case GGML_OP_CONV_3D: { ggml_compute_forward_conv_3d(params, tensor); @@ -2256,6 +2260,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_OP_IM2COL: case GGML_OP_IM2COL_BACK: case GGML_OP_CONV_2D: + case GGML_OP_CONV_2D_IMPLICIT: case GGML_OP_CONV_3D: case GGML_OP_CONV_2D_DW: 
case GGML_OP_CONV_TRANSPOSE_1D: @@ -2778,6 +2783,7 @@ struct ggml_cplan ggml_graph_plan( } } break; case GGML_OP_CONV_2D: + case GGML_OP_CONV_2D_IMPLICIT: case GGML_OP_CONV_3D: { cur = GGML_IM2COL_WORK_SIZE; diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index d1b1dc7d3cacc..72f8d30baf4bf 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -1,81 +1,33 @@ #include "conv2d-implicit.cuh" #include "convert.cuh" -struct conv_params { - const int64_t IW, IH; - const int64_t OW, OH; - const int64_t KW, KH; - const int64_t ST_X, ST_Y; - const int64_t PD_X, PD_Y; - const int64_t DL_X, DL_Y; - const int64_t IC, OC; - const int64_t B; - const int64_t TOTAL; -}; - -struct kernel_bounds { - int64_t y_min, y_max; - int64_t x_min, x_max; -}; - -__device__ __forceinline__ int64_t max64(int64_t a, int64_t b) { - return (a > b) ? a : b; -} - -__device__ __forceinline__ int64_t min64(int64_t a, int64_t b) { - return (a < b) ? a : b; -} - -__device__ __forceinline__ kernel_bounds calculate_kernel_bounds(int64_t out_x, int64_t out_y, const conv_params & P) { - kernel_bounds bounds; - bounds.y_min = max64(0, (P.PD_Y - out_y * P.ST_Y + P.DL_Y - 1) / P.DL_Y); - bounds.y_max = min64(P.KH, (P.IH + P.PD_Y - out_y * P.ST_Y + P.DL_Y - 1) / P.DL_Y); - bounds.x_min = max64(0, (P.PD_X - out_x * P.ST_X + P.DL_X - 1) / P.DL_X); - bounds.x_max = min64(P.KW, (P.IW + P.PD_X - out_x * P.ST_X + P.DL_X - 1) / P.DL_X); - return bounds; -} - -__device__ __forceinline__ int calculate_input_coord(int64_t out_coord, - int64_t kern_coord, - int64_t stride, - int64_t dilation, - int64_t padding) { - return out_coord * stride + kern_coord * dilation - padding; -} +typedef struct{ + unsigned int n; //batch szie + unsigned int c; //channel number + unsigned int h; //height + unsigned int w; //width + unsigned int k; //number of filters + unsigned int r; //filter height + unsigned int s; //filter width + unsigned int u; //stride height + unsigned int v; //stride width + unsigned int p; //padding height + unsigned int q; //padding width + unsigned int d_h; //dilation height + unsigned int d_w; //dilation width + unsigned int Oh; //output height + unsigned int Ow; //output width +} param_t; -struct whcn_layout { - __device__ static int64_t input_index(int64_t n, int64_t c, int64_t y, int64_t x, const conv_params & P) { - return n * (P.IC * P.IW * P.IH) + c * P.IW * P.IH + y * P.IW + x; - } - - __device__ static int64_t kernel_index(int64_t c_out, int64_t c_in, int64_t ky, int64_t kx, const conv_params & P) { - return c_out * (P.IC * P.KH * P.KW) + c_in * (P.KH * P.KW) + ky * P.KW + kx; - } - __device__ static int64_t output_index(int64_t n, int64_t c, int64_t y, int64_t x, const conv_params & P) { - return n * (P.OC * P.OW * P.OH) + c * P.OW * P.OH + y * P.OW + x; - } - __device__ static void unpack_indices(int64_t global_idx, - const conv_params & P, - int64_t & n, - int64_t & c, - int64_t & out_y, - int64_t & out_x) { - out_x = global_idx % P.OW; - out_y = (global_idx / P.OW) % P.OH; - c = (global_idx / (P.OW * P.OH)) % P.OC; - n = global_idx / (P.OW * P.OH * P.OC); - } -}; - -template +template static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, const T * __restrict__ kernel, float * __restrict__ output, - const conv_params P) { + const param_t ¶m) { - __shared__ __align__(16 * 1024) char smem[24 * 1024]; + extern __shared__ __align__(16 * 1024) char smem[]; T *smemweight = reinterpret_cast(smem); float *smeminput = 
reinterpret_cast(smem + 16 * 1024); @@ -151,8 +103,8 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, #pragma unroll for (int i = 0; i < 4; ++i) { - int curH = posh_ori[i] + curR; // input h - int curW = posw_ori[i] + curS; // input w + int curH = posh_ori[i] + curR * param.d_h; // input h + int curW = posw_ori[i] + curS * param.d_w; // input w int inOffsetTmp = curC * inChannelOffset + curH * param.w + curW; if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h) { @@ -210,8 +162,8 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, #pragma unroll for (int i = 0; i < 4; ++i) { - int curH = posh_ori[i] + curR; // input h - int curW = posw_ori[i] + curS; // input w + int curH = posh_ori[i] + curR * param.d_h; // input h + int curW = posw_ori[i] + curS * param.d_w; // input w int inOffsetTmp = curC * inChannelOffset + curH * param.w + curW; if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h) { @@ -334,16 +286,25 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, } template -static void conv2d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D, const conv_params P, cudaStream_t st) { - const int blocks = (P.TOTAL + CUDA_CONV2D_BLOCK_SIZE - 1) / CUDA_CONV2D_BLOCK_SIZE; - conv2d_implicit_kernel<<>>(X_D, K_D, Y_D, P); +static void conv2d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D, const param_t &P, cudaStream_t st) { + // const int blocks = (P.TOTAL + CUDA_CONV2D_BLOCK_SIZE - 1) / CUDA_CONV2D_BLOCK_SIZE; + int blockx = ((P.Oh * P.Ow + 127) / 128); // blockx number + int blocky = (P.k + 127) / 128; // blocky number + int blockz = P.n; // blockz number + int threadx = CUDA_CONV2D_IMPLICT_BLOCK_SIZE; // threadx number per block + int thready = 1; // thready number per block + int threadz = 1; // threadz number per block + dim3 thblock(threadx, thready, threadz); + dim3 grid(blockx, blocky, blockz); + int smem_size = 24 * 1024; + conv2d_implicit_kernel<<>>(X_D, K_D, Y_D, P); } -static void conv2d_implicit_cuda_f16(const float * X_D, const half * K_D, float * Y_D, const conv_params P, cudaStream_t st) { +static void conv2d_implicit_cuda_f16(const float * X_D, const half * K_D, float * Y_D, const param_t &P, cudaStream_t st) { conv2d_implicit_cuda(X_D, K_D, Y_D, P, st); } -static void conv2d_implicit_cuda_f32(const float * X_D, const float * K_D, float * Y_D, const conv_params P, cudaStream_t st) { +static void conv2d_implicit_cuda_f32(const float * X_D, const float * K_D, float * Y_D, const param_t &P, cudaStream_t st) { conv2d_implicit_cuda(X_D, K_D, Y_D, P, st); } @@ -384,7 +345,8 @@ void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * const int B = input->ne[3]; // n_batches const int64_t total = B * OC * OH * OW; - conv_params params = { IW, IH, OW, OH, KW, KH, ST_X, ST_Y, PD_X, PD_Y, DL_X, DL_Y, IC, OC, B, total }; + // param_t params = { IW, IH, OW, OH, KW, KH, ST_X, ST_Y, PD_X, PD_Y, DL_X, DL_Y, IC, OC, B, total }; + param_t params = { B, IC, IH, IW, OC, KH, KW, ST_X, ST_Y, PD_X, PD_Y, DL_X, DL_Y, OH, OW }; if (kernel->type == GGML_TYPE_F16) { conv2d_implicit_cuda_f16(X_D, (half *) K_D, Y_D, params, st); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index e06f95f0819ed..0b799fbaf129e 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -13,6 +13,7 @@ #include "ggml-cuda/concat.cuh" #include "ggml-cuda/conv-transpose-1d.cuh" #include "ggml-cuda/conv2d.cuh" 
+#include "ggml-cuda/conv2d-implicit.cuh" #include "ggml-cuda/conv2d-dw.cuh" #include "ggml-cuda/conv2d-transpose.cuh" #include "ggml-cuda/convert.cuh" @@ -2455,6 +2456,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_CONV_2D: ggml_cuda_op_conv2d(ctx, dst); break; + case GGML_OP_CONV_2D_IMPLICIT: + ggml_cuda_op_conv2d_implicit(ctx, dst); + break; case GGML_OP_CONV_2D_DW: ggml_cuda_op_conv2d_dw(ctx, dst); break; @@ -3560,6 +3564,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g } case GGML_OP_IM2COL: case GGML_OP_CONV_2D: + case GGML_OP_CONV_2D_IMPLICIT: case GGML_OP_CONV_2D_DW: case GGML_OP_CONV_TRANSPOSE_2D: case GGML_OP_POOL_2D: diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 4e0fd672bd433..69003dfc5cf6b 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1018,7 +1018,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "GLU", }; -static_assert(GGML_OP_COUNT == 89, "GGML_OP_COUNT != 89"); +static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -1121,7 +1121,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "glu(x)", }; -static_assert(GGML_OP_COUNT == 89, "GGML_OP_COUNT != 89"); +static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); From 3877608dc05e86e824ada455b0cf36f759c04192 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Wed, 3 Sep 2025 12:45:19 -0400 Subject: [PATCH 3/7] fix passing param as reference --- ggml/src/ggml-cuda/conv2d-implicit.cu | 25 ++++--- ggml/src/ggml.c | 2 + tests/test-backend-ops.cpp | 99 +++++++++++++++++++++++++++ 3 files changed, 118 insertions(+), 8 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index 72f8d30baf4bf..a78720ecc60d1 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -25,9 +25,9 @@ template static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, const T * __restrict__ kernel, float * __restrict__ output, - const param_t ¶m) { + const param_t param) { - extern __shared__ __align__(16 * 1024) char smem[]; + extern __shared__ __align__(16 * 1024) char smem[]; T *smemweight = reinterpret_cast(smem); float *smeminput = reinterpret_cast(smem + 16 * 1024); @@ -35,6 +35,12 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, int bx = blockIdx.x; int by = blockIdx.y; + // if(tx == 0 && bx == 0 && by == 0 && blockIdx.z == 0){ + // printf("param.n=%d, param.c=%d, param.h=%d, param.w=%d, param.k=%d, param.r=%d, param.s=%d, param.u=%d, param.v=%d, param.p=%d, param.q=%d, param.d_h=%d, param.d_w=%d, param.Oh=%d, param.Ow=%d\n",param.n,param.c,param.h,param.w,param.k,param.r,param.s,param.u,param.v,param.p,param.q,param.d_h,param.d_w,param.Oh,param.Ow); + // // printf("param.n=%d\n",param.n); + // } + // __syncthreads(); + // Warp tile const int lane_id = threadIdx.x % 32; const int warp_id = threadIdx.x / 32; @@ -85,6 +91,10 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, } } // ldg + // if(tx == 0 && bx == 0 && by == 0 && blockIdx.z == 0){ + // printf("param.n=%d, param.c=%d, param.h=%d, param.w=%d, param.k=%d, param.r=%d, param.s=%d, param.u=%d, param.v=%d, param.p=%d, param.q=%d, param.d_h=%d, param.d_w=%d, param.Oh=%d, 
param.Ow=%d\n",param.n,param.c,param.h,param.w,param.k,param.r,param.s,param.u,param.v,param.p,param.q,param.d_h,param.d_w,param.Oh,param.Ow); + // } + // __syncthreads(); #pragma unroll for (int i = 0; i < 4; ++i) { @@ -282,11 +292,10 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, } } } - } template -static void conv2d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D, const param_t &P, cudaStream_t st) { +static void conv2d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D, const param_t P, cudaStream_t st) { // const int blocks = (P.TOTAL + CUDA_CONV2D_BLOCK_SIZE - 1) / CUDA_CONV2D_BLOCK_SIZE; int blockx = ((P.Oh * P.Ow + 127) / 128); // blockx number int blocky = (P.k + 127) / 128; // blocky number @@ -300,11 +309,11 @@ static void conv2d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D, conv2d_implicit_kernel<<>>(X_D, K_D, Y_D, P); } -static void conv2d_implicit_cuda_f16(const float * X_D, const half * K_D, float * Y_D, const param_t &P, cudaStream_t st) { +static void conv2d_implicit_cuda_f16(const float * X_D, const half * K_D, float * Y_D, const param_t P, cudaStream_t st) { conv2d_implicit_cuda(X_D, K_D, Y_D, P, st); } -static void conv2d_implicit_cuda_f32(const float * X_D, const float * K_D, float * Y_D, const param_t &P, cudaStream_t st) { +static void conv2d_implicit_cuda_f32(const float * X_D, const float * K_D, float * Y_D, const param_t P, cudaStream_t st) { conv2d_implicit_cuda(X_D, K_D, Y_D, P, st); } @@ -343,9 +352,9 @@ void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * const int IC = input->ne[2]; // input_channels const int OC = kernel->ne[3]; // ouptut_chanles const int B = input->ne[3]; // n_batches - + const int64_t total = B * OC * OH * OW; - // param_t params = { IW, IH, OW, OH, KW, KH, ST_X, ST_Y, PD_X, PD_Y, DL_X, DL_Y, IC, OC, B, total }; + param_t params = { B, IC, IH, IW, OC, KH, KW, ST_X, ST_Y, PD_X, PD_Y, DL_X, DL_Y, OH, OW }; if (kernel->type == GGML_TYPE_F16) { diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 69003dfc5cf6b..cdf13a1370aa9 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -975,6 +975,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "IM2COL", "IM2COL_BACK", "CONV_2D", + "CONV_2D_IMPLICIT", "CONV_3D", "CONV_2D_DW", "CONV_TRANSPOSE_2D", @@ -1078,6 +1079,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "im2col(x)", "im2col_back(x)", "conv_2d(x)", + "conv_2d_implicit(x)", "conv_3d(x)", "conv_2d_dw(x)", "conv_transpose_2d(x)", diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 3a58621094d17..9ab73434fedea 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -4116,6 +4116,94 @@ struct test_conv_2d : public test_case { } }; +// CONV_2D_IMPLICIT +struct test_conv_2d_implicit : public test_case { + const std::array ne_input; + const std::array ne_kernel; + const ggml_type type_kernel; + const int stride0; + const int stride1; + const int padding0; + const int padding1; + const int dilation0; + const int dilation1; + // Whether the inputs are contiguous in the channel dim or the width dim + const bool cwhn; + + + + std::string vars() override { + return VARS_TO_STR10(ne_input, ne_kernel, type_kernel, stride0, stride1, padding0, padding1, dilation0, dilation1, cwhn); + } + + double max_nmse_err() override { + return 5e-4; + } + + uint64_t op_flops(ggml_tensor * t) override { + GGML_UNUSED(t); + // Just counting matmul costs: + // KxCRS @ CRSxNPQ = KxNPQ --> KxNPQx(CRS+CRS-1) flops + + 
// Copied from ggml.c: int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d) + auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t { + return (ins + 2 * p - d * (ks - 1) - 1) / s + 1; + }; + + int64_t W = ne_input[0]; + int64_t H = ne_input[1]; + int64_t KW = ne_kernel[0]; + int64_t KH = ne_kernel[1]; + int64_t Cin = ne_kernel[2]; + int64_t Cout = ne_kernel[3]; + int64_t N = ne_input[3]; + int64_t OH = calc_conv_output_size(H, KH, stride0, padding0, dilation0); + int64_t OW = calc_conv_output_size(W, KW, stride0, padding0, dilation0); + + int64_t K = Cout; + int64_t CRS = Cin * KH * KW; + int64_t NPQ = N * OH * OW; + + return K * NPQ * (2 * CRS - 1); + } + + test_conv_2d_implicit(std::array ne_input = { 64, 64, 16, 1 }, + std::array ne_kernel = { 3, 3, 1, 16 }, ggml_type type_kernel = GGML_TYPE_F32, int stride0 = 1, + int stride1 = 1, int padding0 = 0, int padding1 = 0, int dilation0 = 1, int dilation1 = 1, bool cwhn = false) : + ne_input(ne_input), + ne_kernel(ne_kernel), + type_kernel(type_kernel), + stride0(stride0), + stride1(stride1), + padding0(padding0), + padding1(padding1), + dilation0(dilation0), + dilation1(dilation1), + cwhn(cwhn) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * input = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_input.data()); + ggml_set_name(input, "input"); + + ggml_tensor * kernel = ggml_new_tensor(ctx, type_kernel, 4, ne_kernel.data()); + ggml_set_name(kernel, "kernel"); + + if (cwhn) { + // change memory layout to channel-most-contiguous (CWHN), + // then permute it back so NE matches the original input + input = ggml_cont(ctx, ggml_permute(ctx, input, 1, 2, 0, 3)); + input = ggml_permute(ctx, input, 2, 0, 1, 3); + kernel = ggml_cont(ctx, ggml_permute(ctx, kernel, 2, 3, 1, 0)); + kernel = ggml_permute(ctx, kernel, 3, 2, 0, 1); + } + + ggml_tensor * out = + ggml_conv_2d_implicitgemm(ctx, kernel, input, stride0, stride1, padding0, padding1, dilation0, dilation1); + ggml_set_name(out, "out"); + return out; + } +}; + // GGML_OP_CONV_2D_DW struct test_conv_2d_dw : public test_case { const std::array ne_input; @@ -6454,6 +6542,17 @@ static std::vector> make_test_cases_perf() { } } + for (auto kernel_type : {GGML_TYPE_F32, GGML_TYPE_F16}) { + for (auto act_case : cases) { + // Direct CONV_2D + test_cases.emplace_back(new test_conv_2d_implicit( + { act_case[iwh_idx], act_case[iwh_idx], act_case[Cin_idx], act_case[B_idx] }, + { act_case[kwh_idx], act_case[kwh_idx], act_case[Cin_idx], act_case[Cout_idx] }, + kernel_type, 1, 1, 0, 0, 1, 1, false)); + } + } + + test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 1, 1, 1})); test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 512, 1, 1})); From 6d84cbb5abc2f7f3590c9ec3c5b01496543ec593 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Wed, 3 Sep 2025 15:45:09 -0400 Subject: [PATCH 4/7] Fix parameter order in conv2d_implicit and add comprehensive test cases for 2D convolution --- ggml/src/ggml-cuda/conv2d-implicit.cu | 2 +- tests/test-backend-ops.cpp | 24 ++++++++++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index a78720ecc60d1..4f452ab98bbf9 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -355,7 +355,7 @@ void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * const int64_t 
total = B * OC * OH * OW; - param_t params = { B, IC, IH, IW, OC, KH, KW, ST_X, ST_Y, PD_X, PD_Y, DL_X, DL_Y, OH, OW }; + param_t params = { B, IC, IH, IW, OC, KH, KW, ST_Y, ST_X, PD_Y, PD_X, DL_Y, DL_X, OH, OW }; if (kernel->type == GGML_TYPE_F16) { conv2d_implicit_cuda_f16(X_D, (half *) K_D, Y_D, params, st); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 9ab73434fedea..d5e1005d2fe94 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -5790,6 +5790,30 @@ static std::vector> make_test_cases_eval() { } } + for (uint32_t s0 : { 1, 3 }) { + for (uint32_t p1 : { 2, 5 }) { + for (uint32_t Cin : { 1, 25 }) { + for (uint32_t Cout : { 1, 12 }) { + for (uint32_t KH : { 1, 2, 3, 11 }) { + for (uint32_t KW : { 1, 2, 3, 11 }) { + for (uint32_t H : { 1, 133 }) { + for (uint32_t W : { 1, 141 }) { + if (calc_conv_output_size(W, KW, s0, p0, d0) > 0 && + calc_conv_output_size(H, KH, s1, p1, d1) > 0) { + for (auto kernel_type : {GGML_TYPE_F32, GGML_TYPE_F16}) { + test_cases.emplace_back(new test_conv_2d_implicit( + { W, H, Cin, 2 }, { KW, KH, Cin, Cout }, kernel_type, s0, s1, p0, p1, d0, d1, false)); + } + } + } + } + } + } + } + } + } + } + // sycl backend will limit task global_range < MAX_INT // test cases for 2D im2col with large input W and H (occurs in stable-diffusion) // however these cases need to alloc more memory which may fail in some devices (Intel Arc770, etc.) From 5ffe97be9c35169aea1e451426eb53e5430f7d24 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Thu, 4 Sep 2025 15:32:29 -0400 Subject: [PATCH 5/7] Fix boundary check in conv2d_implicit_kernel to include channel limits --- ggml/src/ggml-cuda/conv2d-implicit.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index 4f452ab98bbf9..d9fabd9657d8c 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -116,7 +116,7 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, int curH = posh_ori[i] + curR * param.d_h; // input h int curW = posw_ori[i] + curS * param.d_w; // input w int inOffsetTmp = curC * inChannelOffset + curH * param.w + curW; - if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h) + if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && curC < param.c) { input_ldg_reg[i] = input[inOffset + inOffsetTmp]; } @@ -175,7 +175,7 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, int curH = posh_ori[i] + curR * param.d_h; // input h int curW = posw_ori[i] + curS * param.d_w; // input w int inOffsetTmp = curC * inChannelOffset + curH * param.w + curW; - if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h) + if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && curC < param.c) { input_ldg_reg[i] = input[inOffset + inOffsetTmp]; } From 4b0f9d571f4166035ee72558e4710b6205893af7 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Fri, 5 Sep 2025 08:29:57 -0400 Subject: [PATCH 6/7] Refactor conv2d_implicit_kernel for improved readability and consistency; update parameter comments and remove unused code --- ggml/src/ggml-cuda/conv2d-implicit.cu | 143 ++++++++------------------ 1 file changed, 44 insertions(+), 99 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index d9fabd9657d8c..31205187c1d20 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -2,8 +2,8 @@ 
#include "convert.cuh" typedef struct{ - unsigned int n; //batch szie - unsigned int c; //channel number + unsigned int n; //batch size + unsigned int c; //number if channels unsigned int h; //height unsigned int w; //width unsigned int k; //number of filters @@ -23,23 +23,18 @@ typedef struct{ template static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, - const T * __restrict__ kernel, - float * __restrict__ output, - const param_t param) { + const T * __restrict__ kernel, + float * __restrict__ output, + const param_t param) { - extern __shared__ __align__(16 * 1024) char smem[]; + extern __shared__ unsigned char smem[]; T *smemweight = reinterpret_cast(smem); float *smeminput = reinterpret_cast(smem + 16 * 1024); int tx = threadIdx.x; int bx = blockIdx.x; int by = blockIdx.y; - - // if(tx == 0 && bx == 0 && by == 0 && blockIdx.z == 0){ - // printf("param.n=%d, param.c=%d, param.h=%d, param.w=%d, param.k=%d, param.r=%d, param.s=%d, param.u=%d, param.v=%d, param.p=%d, param.q=%d, param.d_h=%d, param.d_w=%d, param.Oh=%d, param.Ow=%d\n",param.n,param.c,param.h,param.w,param.k,param.r,param.s,param.u,param.v,param.p,param.q,param.d_h,param.d_w,param.Oh,param.Ow); - // // printf("param.n=%d\n",param.n); - // } - // __syncthreads(); + // Warp tile const int lane_id = threadIdx.x % 32; @@ -60,8 +55,7 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, int posh_ori[4]; int posw_ori[4]; #pragma unroll - for (int i = 0; i < 4; ++i) - { + for (int i = 0; i < 4; ++i){ posh_ori[i] = ((bx * 128 + tx % 32 + i * 32) / param.Ow) * param.u - param.p; posw_ori[i] = ((bx * 128 + tx % 32 + i * 32) % param.Ow) * param.v - param.q; } @@ -82,28 +76,19 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, float input_frag[2][8]; float output_frag[8][8]; #pragma unroll - for (int i = 0; i < 8; ++i) - { + for (int i = 0; i < 8; ++i){ #pragma unroll - for (int j = 0; j < 8; ++j) - { + for (int j = 0; j < 8; ++j){ output_frag[i][j] = 0; } } // ldg - // if(tx == 0 && bx == 0 && by == 0 && blockIdx.z == 0){ - // printf("param.n=%d, param.c=%d, param.h=%d, param.w=%d, param.k=%d, param.r=%d, param.s=%d, param.u=%d, param.v=%d, param.p=%d, param.q=%d, param.d_h=%d, param.d_w=%d, param.Oh=%d, param.Ow=%d\n",param.n,param.c,param.h,param.w,param.k,param.r,param.s,param.u,param.v,param.p,param.q,param.d_h,param.d_w,param.Oh,param.Ow); - // } - // __syncthreads(); #pragma unroll - for (int i = 0; i < 4; ++i) - { - if (tx % 8 < weightKOffset && by * 128 + tx / 8 * 4 + i < param.k) - { + for (int i = 0; i < 4; ++i){ + if (tx % 8 < weightKOffset && by * 128 + tx / 8 * 4 + i < param.k){ weight_ldg_reg[i] = kernel[weiOffset + tx % 8 + i * weightKOffset]; } - else - { + else{ weight_ldg_reg[i] = (T)0.f; } } @@ -111,57 +96,46 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, int curR = ((tx / 32) % (param.r * param.s)) / param.s; // kernel r offset int curS = ((tx / 32) % (param.r * param.s)) % param.s; // kernel s offset #pragma unroll - for (int i = 0; i < 4; ++i) - { + for (int i = 0; i < 4; ++i){ int curH = posh_ori[i] + curR * param.d_h; // input h int curW = posw_ori[i] + curS * param.d_w; // input w int inOffsetTmp = curC * inChannelOffset + curH * param.w + curW; - if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && curC < param.c) - { + if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && curC < param.c){ input_ldg_reg[i] = input[inOffset + inOffsetTmp]; } - else - { + else{ 
input_ldg_reg[i] = 0.0; } } // sts - for (int i = 0; i < 4; ++i) - { + for (int i = 0; i < 4; ++i){ smemweight[weight_sts_addr + i] = weight_ldg_reg[i]; } - for (int i = 0; i < 4; ++i) - { + for (int i = 0; i < 4; ++i){ smeminput[input_sts_addr + i * 32] = input_ldg_reg[i]; } __syncthreads(); // lds #pragma unroll - for (int i = 0; i < 4; ++i) - { + for (int i = 0; i < 4; ++i){ weight_frag[0][i] = smemweight[weight_lds_addr + i]; weight_frag[0][i + 4] = smemweight[weight_lds_addr + i + 16]; } #pragma unroll - for (int i = 0; i < 4; ++i) - { + for (int i = 0; i < 4; ++i){ input_frag[0][i] = smeminput[input_lds_addr + i]; input_frag[0][i + 4] = smeminput[input_lds_addr + i + 32]; } - for (int crs = 0; crs < param.r * param.s * param.c; crs += 8) - { + for (int crs = 0; crs < param.r * param.s * param.c; crs += 8){ // ldg int weiOffsetTmp = crs + 8 + tx % 8; #pragma unroll - for (int i = 0; i < 4; ++i) - { - if (weiOffsetTmp < weightKOffset && by * 128 + tx / 8 * 4 + i < param.k) - { + for (int i = 0; i < 4; ++i){ + if (weiOffsetTmp < weightKOffset && by * 128 + tx / 8 * 4 + i < param.k){ weight_ldg_reg[i] = kernel[weiOffset + weiOffsetTmp + i * weightKOffset]; } - else - { + else{ weight_ldg_reg[i] = (T)0.f; } } @@ -170,76 +144,62 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, curS = ((crs + 8 + tx / 32) % (param.r * param.s)) % param.s; // kernel s offset #pragma unroll - for (int i = 0; i < 4; ++i) - { + for (int i = 0; i < 4; ++i){ int curH = posh_ori[i] + curR * param.d_h; // input h int curW = posw_ori[i] + curS * param.d_w; // input w int inOffsetTmp = curC * inChannelOffset + curH * param.w + curW; - if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && curC < param.c) - { + if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && curC < param.c){ input_ldg_reg[i] = input[inOffset + inOffsetTmp]; } - else - { + else{ input_ldg_reg[i] = 0.f; } } int load_flag = write_flag ^ 1; #pragma unroll - for (int subcrs = 0; subcrs < 8 - 1; ++subcrs) - { + for (int subcrs = 0; subcrs < 8 - 1; ++subcrs){ #pragma unroll - for (int i = 0; i < 4; ++i) - { + for (int i = 0; i < 4; ++i){ weight_frag[(subcrs + 1) % 2][i] = smemweight[load_flag * 132 * 8 + weight_lds_addr + (subcrs + 1) * 132 + i]; weight_frag[(subcrs + 1) % 2][i + 4] = smemweight[load_flag * 132 * 8 + weight_lds_addr + (subcrs + 1) * 132 + i + 16]; } #pragma unroll - for (int i = 0; i < 4; ++i) - { + for (int i = 0; i < 4; ++i){ input_frag[(subcrs + 1) % 2][i] = smeminput[load_flag * 128 * 8 + input_lds_addr + (subcrs + 1) * 128 + i]; input_frag[(subcrs + 1) % 2][i + 4] = smeminput[load_flag * 128 * 8 + input_lds_addr + (subcrs + 1) * 128 + i + 32]; } #pragma unroll - for (int i = 0; i < 8; ++i) - { + for (int i = 0; i < 8; ++i){ #pragma unroll - for (int j = 0; j < 8; ++j) - { + for (int j = 0; j < 8; ++j){ output_frag[i][j] += ggml_cuda_cast(weight_frag[subcrs % 2][i]) * input_frag[subcrs % 2][j]; } } } // sts - for (int i = 0; i < 4; ++i) - { + for (int i = 0; i < 4; ++i){ smemweight[write_flag * 132 * 8 + weight_sts_addr + i] = weight_ldg_reg[i]; } - for (int i = 0; i < 4; ++i) - { + for (int i = 0; i < 4; ++i){ smeminput[write_flag * 128 * 8 + input_sts_addr + i * 32] = input_ldg_reg[i]; } __syncthreads(); write_flag ^= 1; #pragma unroll - for (int i = 0; i < 4; ++i) - { + for (int i = 0; i < 4; ++i){ weight_frag[0][i] = smemweight[(load_flag ^ 1) * 132 * 8 + weight_lds_addr + i]; weight_frag[0][i + 4] = smemweight[(load_flag ^ 1) * 132 * 8 + weight_lds_addr + i + 16]; } 
#pragma unroll - for (int i = 0; i < 4; ++i) - { + for (int i = 0; i < 4; ++i){ input_frag[0][i] = smeminput[(load_flag ^ 1) * 128 * 8 + input_lds_addr + i]; input_frag[0][i + 4] = smeminput[(load_flag ^ 1) * 128 * 8 + input_lds_addr + i + 32]; } #pragma unroll - for (int i = 0; i < 8; ++i) - { + for (int i = 0; i < 8; ++i){ #pragma unroll - for (int j = 0; j < 8; ++j) - { + for (int j = 0; j < 8; ++j){ output_frag[i][j] += ggml_cuda_cast(weight_frag[1][i]) * input_frag[1][j]; } } @@ -247,35 +207,23 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, // reuse smem float *smemoutput = reinterpret_cast(smem); - // float *smembias = reinterpret_cast(smem + 16 * 1024); - // bias ldg/sts - // if (tx < 128) - // { - // smembias[tx] = param.bias[by * 128 + tx]; - // } uint32_t output_sts_addr = warp_id * 512 + mma_tid_y * 4 * 8 * 4 + mma_tid_x * 4; uint32_t output_lds_addr = warp_id * 512 + lane_id; - // uint32_t bias_lds_addr = warp_id / 2 * 32; uint32_t m_idx = blockIdx.y * 128 + warp_id / 2 * 32; uint32_t n_idx = blockIdx.x * 128 + warp_id % 2 * 64 + lane_id; #pragma unroll - for (int i = 0; i < 2; ++i) - { + for (int i = 0; i < 2; ++i){ #pragma unroll - for (int j = 0; j < 2; ++j) - { + for (int j = 0; j < 2; ++j){ __syncthreads(); - #pragma unroll - for (int subi = 0; subi < 4; ++subi) - { + for (int subi = 0; subi < 4; ++subi){ #pragma unroll - for (int subj = 0; subj < 4; ++subj) - { + for (int subj = 0; subj < 4; ++subj){ // output sts smemoutput[output_sts_addr + subi * 8 * 4 + subj] = output_frag[i * 4 + subi][j * 4 + subj]; } @@ -283,11 +231,9 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, __syncthreads(); #pragma unroll - for (int subk = 0; subk < 16; ++subk) - { + for (int subk = 0; subk < 16; ++subk){ int outOffset = z * param.k * param.Oh * param.Ow + (m_idx + i * 16 + subk) * param.Oh * param.Ow + n_idx + j * 32; if ((m_idx + i * 16 + subk) < param.k && (n_idx + j * 32) < param.Oh * param.Ow) - // output[outOffset] = smemoutput[output_lds_addr + subk * 32] + smembias[bias_lds_addr + i * 16 + subk]; output[outOffset] = smemoutput[output_lds_addr + subk * 32]; } } @@ -295,8 +241,7 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, } template -static void conv2d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D, const param_t P, cudaStream_t st) { - // const int blocks = (P.TOTAL + CUDA_CONV2D_BLOCK_SIZE - 1) / CUDA_CONV2D_BLOCK_SIZE; +static void conv2d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D, const param_t P, cudaStream_t st) { int blockx = ((P.Oh * P.Ow + 127) / 128); // blockx number int blocky = (P.k + 127) / 128; // blocky number int blockz = P.n; // blockz number From 83a3b7d6a98727705ed04f29afdd587ea3c17c37 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Sat, 6 Sep 2025 17:26:19 -0400 Subject: [PATCH 7/7] Refactor conv2d_implicit_kernel for improved bitwise operations; add test for implicit convolution --- ggml/src/ggml-cuda/conv2d-implicit.cu | 72 +++-- tests/CMakeLists.txt | 1 + tests/test-conv2d-implicit.cpp | 390 ++++++++++++++++++++++++++ 3 files changed, 434 insertions(+), 29 deletions(-) create mode 100644 tests/test-conv2d-implicit.cpp diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index 31205187c1d20..1e2540f8ca4ba 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -37,16 +37,16 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ 
input, // Warp tile - const int lane_id = threadIdx.x % 32; - const int warp_id = threadIdx.x / 32; - const int mma_tid_x = (lane_id / 2) % 8; - const int mma_tid_y = (lane_id / 16) * 2 + (lane_id % 2); + const int lane_id = threadIdx.x & 31; + const int warp_id = threadIdx.x >> 5; + const int mma_tid_x = (lane_id >> 1) % 8; + const int mma_tid_y = (lane_id >> 4) * 2 + (lane_id & 1); // lds addr - int weight_lds_addr = (warp_id / 2) * 32 + mma_tid_y * 4; - int input_lds_addr = (warp_id % 2) * 64 + mma_tid_x * 4; + int weight_lds_addr = (warp_id >> 1) * 32 + mma_tid_y * 4; + int input_lds_addr = (warp_id & 1) * 64 + mma_tid_x * 4; - int x = bx * 128 + input_lds_addr; - int y = by * 128 + weight_lds_addr; + // int x = bx * 128 + input_lds_addr; + // int y = by * 128 + weight_lds_addr; int z = blockIdx.z; T weight_ldg_reg[4]; @@ -56,20 +56,20 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, int posw_ori[4]; #pragma unroll for (int i = 0; i < 4; ++i){ - posh_ori[i] = ((bx * 128 + tx % 32 + i * 32) / param.Ow) * param.u - param.p; - posw_ori[i] = ((bx * 128 + tx % 32 + i * 32) % param.Ow) * param.v - param.q; + posh_ori[i] = ((bx * 128 + lane_id + i * 32) / param.Ow) * param.u - param.p; + posw_ori[i] = ((bx * 128 + lane_id + i * 32) % param.Ow) * param.v - param.q; } int inOffset = z * param.c * param.h * param.w; - int weiOffset = (by * 128 + tx / 8 * 4) * param.c * param.r * param.s; + int weiOffset = (by * 128 + (tx >> 3) * 4) * param.c * param.r * param.s; int inChannelOffset = param.h * param.w; - int weightChannelOffset = param.r * param.s; + // int weightChannelOffset = param.r * param.s; int weightKOffset = param.c * param.r * param.s; // sts addr - int weight_sts_addr = (tx % 8) * 132 + - (tx / 8) * 4; - int input_sts_addr = (tx / 32) * 128 + (tx % 32); + int weight_sts_addr = (tx & 7) * 132 + + (tx >> 3) * 4; + int input_sts_addr = (warp_id) * 128 + (lane_id); int write_flag = 1; T weight_frag[2][8]; @@ -85,16 +85,16 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, // ldg #pragma unroll for (int i = 0; i < 4; ++i){ - if (tx % 8 < weightKOffset && by * 128 + tx / 8 * 4 + i < param.k){ - weight_ldg_reg[i] = kernel[weiOffset + tx % 8 + i * weightKOffset]; + if (tx % 8 < weightKOffset && by * 128 + (tx >> 3) * 4 + i < param.k){ + weight_ldg_reg[i] = kernel[weiOffset + (tx & 7) + i * weightKOffset]; } else{ weight_ldg_reg[i] = (T)0.f; } } - int curC = (tx / 32) / (param.r * param.s); // channel offset - int curR = ((tx / 32) % (param.r * param.s)) / param.s; // kernel r offset - int curS = ((tx / 32) % (param.r * param.s)) % param.s; // kernel s offset + int curC = (warp_id) / (param.r * param.s); // channel offset + int curR = ((warp_id) % (param.r * param.s)) / param.s; // kernel r offset + int curS = ((warp_id) % (param.r * param.s)) % param.s; // kernel s offset #pragma unroll for (int i = 0; i < 4; ++i){ int curH = posh_ori[i] + curR * param.d_h; // input h @@ -127,21 +127,23 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, input_frag[0][i] = smeminput[input_lds_addr + i]; input_frag[0][i + 4] = smeminput[input_lds_addr + i + 32]; } + + // main loop for (int crs = 0; crs < param.r * param.s * param.c; crs += 8){ // ldg - int weiOffsetTmp = crs + 8 + tx % 8; + int weiOffsetTmp = crs + 8 + (tx & 7); #pragma unroll for (int i = 0; i < 4; ++i){ - if (weiOffsetTmp < weightKOffset && by * 128 + tx / 8 * 4 + i < param.k){ + if (weiOffsetTmp < weightKOffset && by * 128 + (tx >> 3) * 4 + i < 
param.k){ weight_ldg_reg[i] = kernel[weiOffset + weiOffsetTmp + i * weightKOffset]; } else{ weight_ldg_reg[i] = (T)0.f; } } - curC = (crs + 8 + tx / 32) / (param.r * param.s); // channel offset - curR = ((crs + 8 + tx / 32) % (param.r * param.s)) / param.s; // kernel r offset - curS = ((crs + 8 + tx / 32) % (param.r * param.s)) % param.s; // kernel s offset + curC = (crs + 8 + warp_id) / (param.r * param.s); // channel offset + curR = ((crs + 8 + warp_id) % (param.r * param.s)) / param.s; // kernel r offset + curS = ((crs + 8 + warp_id) % (param.r * param.s)) % param.s; // kernel s offset #pragma unroll for (int i = 0; i < 4; ++i){ @@ -160,13 +162,25 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, for (int subcrs = 0; subcrs < 8 - 1; ++subcrs){ #pragma unroll for (int i = 0; i < 4; ++i){ - weight_frag[(subcrs + 1) % 2][i] = smemweight[load_flag * 132 * 8 + weight_lds_addr + (subcrs + 1) * 132 + i]; - weight_frag[(subcrs + 1) % 2][i + 4] = smemweight[load_flag * 132 * 8 + weight_lds_addr + (subcrs + 1) * 132 + i + 16]; + weight_frag[(subcrs + 1) & 1][i] = smemweight[load_flag * 132 * 8 + weight_lds_addr + (subcrs + 1) * 132 + i]; + weight_frag[(subcrs + 1) & 1][i + 4] = smemweight[load_flag * 132 * 8 + weight_lds_addr + (subcrs + 1) * 132 + i + 16]; } + // // compute base pointer once + // T* base_ptr = smemweight + load_flag * 132 * 8 + weight_lds_addr + (subcrs + 1) * 132; + + // // first 4 values -> weight_frag[...][0..3] + // float4 v0 = *reinterpret_cast(base_ptr); + + // // next 4 values (offset +16) -> weight_frag[...][4..7] + // float4 v1 = *reinterpret_cast(base_ptr + 16); + + // // unpack into weight_frag + // *reinterpret_cast(&weight_frag[(subcrs + 1) % 2][0]) = v0; + // *reinterpret_cast(&weight_frag[(subcrs + 1) % 2][4]) = v1; #pragma unroll for (int i = 0; i < 4; ++i){ - input_frag[(subcrs + 1) % 2][i] = smeminput[load_flag * 128 * 8 + input_lds_addr + (subcrs + 1) * 128 + i]; - input_frag[(subcrs + 1) % 2][i + 4] = smeminput[load_flag * 128 * 8 + input_lds_addr + (subcrs + 1) * 128 + i + 32]; + input_frag[(subcrs + 1) & 1][i] = smeminput[load_flag * 128 * 8 + input_lds_addr + (subcrs + 1) * 128 + i]; + input_frag[(subcrs + 1) & 1][i + 4] = smeminput[load_flag * 128 * 8 + input_lds_addr + (subcrs + 1) * 128 + i + 32]; } #pragma unroll diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 91719577564a9..7ce76f01058b3 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -198,6 +198,7 @@ if (NOT LLAMA_SANITIZE_ADDRESS) endif() llama_build_and_test(test-gguf.cpp) llama_build_and_test(test-backend-ops.cpp) +llama_build_and_test(test-conv2d-implicit.cpp) llama_build_and_test(test-model-load-cancel.cpp LABEL "model") llama_build_and_test(test-autorelease.cpp LABEL "model") diff --git a/tests/test-conv2d-implicit.cpp b/tests/test-conv2d-implicit.cpp new file mode 100644 index 0000000000000..b0efba2f1c843 --- /dev/null +++ b/tests/test-conv2d-implicit.cpp @@ -0,0 +1,390 @@ +#include "ggml.h" +#include "ggml-alloc.h" +#include "ggml-cpu.h" +#include "ggml-backend.h" + +#ifdef GGML_USE_CUDA +#include "ggml-cuda.h" +//#include +#endif + +#ifdef GGML_USE_METAL +#include "ggml-metal.h" +#endif + +#include +#include +#include +#include +#include +#include +#include +#include + +static void ggml_log_callback_default(ggml_log_level level, const char * text, void * user_data) { + (void) level; + (void) user_data; + fputs(text, stderr); + fflush(stderr); +} + +struct test_model { + struct ggml_tensor * a; + struct ggml_tensor * b; + 
ggml_backend_t backend = NULL; + ggml_backend_buffer_t buffer; + struct ggml_context * ctx; +}; + + + +void load_model(test_model & model, int ic, int oc, int iw, int ih, bool use_gpu = false ) { + // create data + int KW = 3, KH = 3, IC = ic, OC = oc; + int IW = iw, IH = ih, N = 1; + + // printf(" input: IC = %d, OC = %d, IW = %d, IH = %d \n ", IC, OC, IW, IH); + + // Initialize adata + std::vector adata(KW * KH * IC * OC); + for (int i = 0; i < KW * KH * IC * OC; i++) { + adata[i] = 2.5f; + } + + // Convert adata to fp16 format + // std::vector hadata(KW * KH * IC * OC); + // ggml_fp32_to_fp16_row(adata.data(), hadata.data(), KW * KH * IC * OC); + + // Initialize bdata + std::vector bdata(IW * IH * IC * N); + for (int i = 0; i < IW * IH * IC * N; i++) { + bdata[i] = 1.5f; + } + + size_t buffer_size = 0; + { + buffer_size += KW * KH * IC * OC * ggml_type_size(GGML_TYPE_F32); // tensor a + buffer_size += IW * IH * IC * N * ggml_type_size(GGML_TYPE_F32); // tensor b + buffer_size += 1024; // overhead + } + + // printf("%s: ggml tensor size = %d bytes\n", __func__, (int) sizeof(ggml_tensor)); + // printf("%s: backend buffer size = %0.2f MB\n", __func__, (buffer_size/ 1024.f/ 1024.f)); + + int num_tensors = 2; + struct ggml_init_params params { + /*.mem_size =*/ ggml_tensor_overhead() * num_tensors, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + // initialize the backend +#ifdef GGML_USE_CUDA + if (use_gpu) { + // fprintf(stderr, "%s: using CUDA backend\n", __func__); + model.backend = ggml_backend_cuda_init(0); + if (!model.backend) { + fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__); + } + } +#endif + +#ifdef GGML_USE_METAL + if (use_gpu) { + fprintf(stderr, "%s: using Metal backend\n", __func__); + ggml_backend_metal_log_set_callback(ggml_log_callback_default, nullptr); + model.backend = ggml_backend_metal_init(); + if (!model.backend) { + fprintf(stderr, "%s: ggml_backend_metal_init() failed\n", __func__); + } + } +#endif + + if(!model.backend) { + // fallback to CPU backend + model.backend = ggml_backend_cpu_init(); + } + + model.buffer = ggml_backend_alloc_buffer(model.backend, buffer_size); + + // create context + model.ctx = ggml_init(params); + + // create tensors + model.a = ggml_new_tensor_4d(model.ctx, GGML_TYPE_F32, KW, KH, IC, OC); + model.b = ggml_new_tensor_4d(model.ctx, GGML_TYPE_F32, IW, IH, IC, N); + + // create a allocator + struct ggml_tallocr alloc = ggml_tallocr_new(model.buffer); + + // alloc memory + ggml_tallocr_alloc(&alloc, model.a); + + // load data to buffer + if(ggml_backend_is_cpu(model.backend)) { + memcpy(model.a->data, adata.data(), ggml_nbytes(model.a)); + } else { + ggml_backend_tensor_set(model.a, adata.data(), 0, ggml_nbytes(model.a)); + } + + // alloc memory + ggml_tallocr_alloc(&alloc, model.b); + + if(ggml_backend_is_cpu(model.backend) +#ifdef GGML_USE_METAL + || ggml_backend_is_metal(model.backend) +#endif + ) { + memcpy(model.b->data, bdata.data(), ggml_nbytes(model.b)); + } else { + ggml_backend_tensor_set(model.b, bdata.data(), 0, ggml_nbytes(model.b)); + } +} + +typedef struct ggml_cgraph* (*build_graph_t)(const test_model& model); + +struct ggml_cgraph * build_graph_0(const test_model& model) { + static size_t buf_size = ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(); + static std::vector buf(buf_size); + + struct ggml_init_params params0 = { + /*.mem_size =*/ buf_size, + /*.mem_buffer =*/ buf.data(), + /*.no_alloc =*/ true, // the tensors will be allocated later by 
+    };
+
+    // create a temporary context to build the graph
+    struct ggml_context * ctx0 = ggml_init(params0);
+
+    struct ggml_cgraph * gf = ggml_new_graph(ctx0);
+
+    int s0 = 1;
+    int s1 = 1;
+    int p0 = 1;
+    int p1 = 1;
+    int d0 = 1;
+    int d1 = 1;
+
+
+
+    // recalculate to avoid fragmentation
+    struct ggml_tensor * conv2d_res = ggml_conv_2d(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1);
+    ggml_set_name(conv2d_res, "conv2d_res");
+    ggml_build_forward_expand(gf, conv2d_res);
+    // int64_t *ne = conv2d_res->ne;
+    // printf("conv2d: (%zu, %zu, %zu, %zu) \n", ne[0], ne[1], ne[2], ne[3]);
+
+
+    // struct ggml_tensor * wino_res = ggml_conv_2d_3x3(ctx0, model.a, model.b);
+    // ggml_set_name(wino_res, "wino_res");
+    // ggml_build_forward_expand(gf, wino_res);
+    // ne = wino_res->ne;
+    // printf("wino: (%zu, %zu, %zu, %zu) \n", ne[0], ne[1], ne[2], ne[3]);
+    ggml_free(ctx0);
+    return gf;
+}
+
+struct ggml_cgraph * build_graph_1(const test_model & model) {
+    static size_t buf_size = ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead();
+    static std::vector<uint8_t> buf(buf_size);
+
+    struct ggml_init_params params0 = {
+        /*.mem_size   =*/ buf_size,
+        /*.mem_buffer =*/ buf.data(),
+        /*.no_alloc   =*/ true, // the tensors will be allocated later by ggml_gallocr_alloc_graph()
+    };
+
+    // create a temporary context to build the graph
+    struct ggml_context * ctx0 = ggml_init(params0);
+
+    struct ggml_cgraph * gf = ggml_new_graph(ctx0);
+
+    int s0 = 1;
+    int s1 = 1;
+    int p0 = 1;
+    int p1 = 1;
+    int d0 = 1;
+    int d1 = 1;
+
+
+
+    // recalculate to avoid fragmentation
+    // struct ggml_tensor * conv2d_res = ggml_conv_2d(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1);
+    // ggml_set_name(conv2d_res, "conv2d_res");
+    // ggml_build_forward_expand(gf, conv2d_res);
+    // int64_t *ne = conv2d_res->ne;
+    // printf("conv2d: (%zu, %zu, %zu, %zu) \n", ne[0], ne[1], ne[2], ne[3]);
+
+
+    struct ggml_tensor * wino_res = ggml_conv_2d_implicitgemm(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1);
+    ggml_set_name(wino_res, "wino_res");
+    ggml_build_forward_expand(gf, wino_res);
+    // ne = wino_res->ne;
+    // printf("wino: (%zu, %zu, %zu, %zu) \n", ne[0], ne[1], ne[2], ne[3]);
+    ggml_free(ctx0);
+    return gf;
+}
+
+
+
+
+std::vector<float> compute_graph(const test_model & model, ggml_gallocr_t allocr,
+                                 build_graph_t build_graph, int iters, double * t) {
+    struct ggml_cgraph * gf = build_graph(model);
+
+    // allocate tensors
+    ggml_gallocr_alloc_graph(allocr, gf);
+    int n_threads = 1;
+
+    if (ggml_backend_is_cpu(model.backend)) {
+        ggml_backend_cpu_set_n_threads(model.backend, n_threads);
+    }
+
+#ifdef GGML_USE_METAL
+    if (ggml_backend_is_metal(model.backend)) {
+        ggml_backend_metal_set_n_cb(model.backend, n_threads);
+    }
+#endif
+
+    // warm-up run before timing
+    ggml_backend_graph_compute(model.backend, gf);
+
+    ggml_backend_synchronize(model.backend);
+
+    int64_t start_time = ggml_time_us();
+
+    for (int iter = 0; iter < iters; iter++) {
+        ggml_backend_graph_compute(model.backend, gf);
+    }
+
+    ggml_backend_synchronize(model.backend);
+
+    int64_t end_time = ggml_time_us();
+    double time_us = (double) (end_time - start_time);
+
+    // the result is the last node of the graph
+    struct ggml_tensor * res = ggml_graph_node(gf, -1);
+
+    std::vector<float> data(ggml_nelements(res));
+    ggml_backend_tensor_get(res, data.data(), 0, ggml_nbytes(res));
+
+    *t = time_us/1000; // total time for `iters` runs, in ms
+    return data;
+
+}
+
+
+int main(void)
+{
+    ggml_time_init();
+    std::vector<std::tuple<int, int, int, int>> configs = {
+        std::make_tuple(64,64,48,64),
+        std::make_tuple(320,320,104,152),
+        std::make_tuple(640,640,52,76),
+        std::make_tuple(640,640,104,152),
+        std::make_tuple(960,320,104,152),
+        std::make_tuple(1280,1280,26,38),
+        std::make_tuple(1280,640,52,76),
+        std::make_tuple(1920,1280,26,38),
+        std::make_tuple(2560,1280,26,38),
+        std::make_tuple(512,512,104,152),
+        std::make_tuple(512,512,208,304),
+        std::make_tuple(512,256,416,608),
+        std::make_tuple(256,128,832,1216),
+        std::make_tuple(256,256,832,1216),
+        std::make_tuple(320,256,1024,1920)
+    };
+
+    int k = 0;
+
+    for (auto c : configs){
+        test_model model;
+        load_model(model, std::get<0>(c), std::get<1>(c), std::get<2>(c), std::get<3>(c), true);
+
+        ggml_gallocr_t allocr = NULL;
+        allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(model.backend));
+
+        // create the worst case graph for memory usage estimation
+        struct ggml_cgraph * gf = build_graph_0(model);
+
+        // compute the required memory
+        ggml_gallocr_reserve(allocr, gf);
+        size_t mem_size0 = ggml_gallocr_get_buffer_size(allocr, 0);
+        // fprintf(stderr, "%s: compute buffer size: %.2f MB\n", __func__, mem_size0/1024.0f/1024.0f);
+
+        struct ggml_cgraph * gf_res_0 = NULL;
+        int iterations = 20;
+
+        double run_time0;
+        std::vector<float> conv2d_data = compute_graph(model, allocr, build_graph_0, iterations, &run_time0);
+
+        ggml_gallocr_free(allocr);
+
+        allocr = NULL;
+
+        allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(model.backend));
+
+        // create the worst case graph for memory usage estimation
+        gf = build_graph_1(model);
+
+        // compute the required memory
+        ggml_gallocr_reserve(allocr, gf);
+        size_t mem_size1 = ggml_gallocr_get_buffer_size(allocr, 0);
+        // fprintf(stderr, "%s: compute buffer size: %.2f MB\n", __func__, mem_size1/1024.0f/1024.0f);
+
+        struct ggml_cgraph * gf_res_1 = NULL;
+
+        double run_time1;
+        std::vector<float> wino_data = compute_graph(model, allocr, build_graph_1, iterations, &run_time1);
+
+        if (k == 0) {
+            k = 1;
+            fprintf(stderr, "| (IC, OC, IW, IH) | im2col+GEMM TIME | im2col+GEMM VRAM | implicit GEMM TIME | implicit GEMM VRAM |\n");
+            fprintf(stderr, "| --- | --- | --- | --- | --- |\n");
+        }
+
+        fprintf(stderr, "| (%d, %d, %d, %d) | %.2f ms | %.2f MB | %.2f ms | %.2f MB |\n",
+                std::get<0>(c), std::get<1>(c), std::get<2>(c), std::get<3>(c),
+                run_time0, mem_size0/1024.0f/1024.0f,
+                run_time1, mem_size1/1024.0f/1024.0f);
+
+        // for(int i = 0; i < ggml_nelements(wino_res); i++) {
+        // for(int i = 0; i < 3*28; i++) {
+        //     float diff = fabs(conv2d_data[i] - wino_data[i]);
+        //     // if(diff > 1.e-4) {
+        //     printf("(%f, %f, %f, %d) \n",
+        //             conv2d_data[i],
+        //             wino_data[i], diff, i);
+        //     //     break;
+        //     // }
+        // }
+
+        ggml_free(model.ctx);
+        ggml_backend_buffer_free(model.buffer);
+        ggml_backend_free(model.backend);
+        ggml_gallocr_free(allocr);
+    }
+
+    // printf("\nPerforming test:\n");
+
+    return 0;
+}
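
For readers who want to sanity-check the index arithmetic exercised by the kernel hunk above, the small host-side sketch below reproduces it: the reduction dimension of the implicit GEMM is the flattened crs index, unpacked into (channel, kernel row, kernel col) exactly as curC/curR/curS are, and the input coordinates follow the posh_ori/posw_ori pattern with stride and padding. This is a standalone illustration, not part of the patch; the helper name conv_ref_implicit is hypothetical, the parameter names (R, S, u, v, p, q) merely mirror the kernel's param fields, and the 1.5f/2.5f fill values match the test above.

// conv_ref_implicit: reference convolution written with the same index
// decomposition the CUDA kernel uses (hypothetical helper, illustration only).
#include <cstdio>
#include <vector>

static float conv_ref_implicit(const std::vector<float> & x,   // input  [N, C, H, W]
                               const std::vector<float> & w,   // weight [K, C, R, S]
                               int n, int k, int oh, int ow,
                               int C, int H, int W, int R, int S,
                               int u, int v, int p, int q) {    // stride (u, v), padding (p, q)
    float acc = 0.0f;
    for (int crs = 0; crs < C * R * S; ++crs) {
        const int c = crs / (R * S);            // same unpacking as curC
        const int r = (crs % (R * S)) / S;      // same unpacking as curR
        const int s = (crs % (R * S)) % S;      // same unpacking as curS
        const int ih = oh * u - p + r;          // cf. posh_ori[] + curR
        const int iw = ow * v - q + s;          // cf. posw_ori[] + curS
        if (ih < 0 || ih >= H || iw < 0 || iw >= W) {
            continue;                           // zero padding
        }
        acc += x[((n * C + c) * H + ih) * W + iw] * w[((k * C + c) * R + r) * S + s];
    }
    return acc;
}

int main() {
    // Tiny smoke test: one 4x4 input channel, one 3x3 filter, stride 1, pad 1,
    // with the same constant fills as the test above (input 1.5f, weights 2.5f).
    const int C = 1, H = 4, W = 4, R = 3, S = 3;
    std::vector<float> x(C * H * W, 1.5f);
    std::vector<float> w(C * R * S, 2.5f);
    // The interior pixel (oh=1, ow=1) sees the full 3x3 window: 9 * 1.5 * 2.5 = 33.75.
    printf("out(1,1) = %f\n", conv_ref_implicit(x, w, 0, 0, 1, 1, C, H, W, R, S, 1, 1, 1, 1));
    return 0;
}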