
Conversation

@bssrdf (Contributor) commented Sep 4, 2025

This PR adds another CUDA conv_2d op using the implicit GEMM approach. It is optimized for CUDA cores only, and its performance is up to 10x that of the direct method currently in llama.cpp.
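To illustrate the approach (a conceptual, untiled sketch only; the actual kernel in this PR tiles the GEMM through shared memory): the convolution is mapped to a GEMM with M = N·OH·OW, N = OC and K = IC·KH·KW, and the im2col matrix is never materialized; its elements are reconstructed from indices on the fly inside the reduction loop.

```cuda
// Conceptual sketch of implicit GEMM convolution (stride 1, no padding, no
// dilation, contiguous tensors). Not the PR's kernel: no tiling, no shared
// memory; every name here is illustrative.
__global__ void conv2d_implicit_gemm_naive(
        const float * __restrict__ input,   // ne = [IW, IH, IC, N]
        const float * __restrict__ kernel,  // ne = [KW, KH, IC, OC]
        float       * __restrict__ output,  // ne = [OW, OH, OC, N]
        int N, int IC, int IH, int IW,
        int OC, int KH, int KW, int OH, int OW) {
    const int gemm_m = blockIdx.y*blockDim.y + threadIdx.y; // rows:    N*OH*OW
    const int gemm_n = blockIdx.x*blockDim.x + threadIdx.x; // columns: OC
    if (gemm_m >= N*OH*OW || gemm_n >= OC) {
        return;
    }
    const int n  =  gemm_m / (OH*OW);
    const int oh = (gemm_m / OW) % OH;
    const int ow =  gemm_m % OW;

    float acc = 0.0f;
    for (int gemm_k = 0; gemm_k < IC*KH*KW; ++gemm_k) { // GEMM reduction dim
        // im2col element (gemm_m, gemm_k) is computed implicitly from indices:
        const int ic =  gemm_k / (KH*KW);
        const int kh = (gemm_k / KW) % KH;
        const int kw =  gemm_k % KW;
        const int ih = oh + kh; // stride/padding/dilation would be applied here
        const int iw = ow + kw;
        acc += input [((n*IC + ic)*IH + ih)*IW + iw]
             * kernel[((gemm_n*IC + ic)*KH + kh)*KW + kw];
    }
    output[((n*OC + gemm_n)*OH + oh)*OW + ow] = acc;
}
```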

On an RTX 4090:

| Cases | Direct | Implicit GEMM |
| --- | --- | --- |
| ne_input=[19,19,256,16], ne_kernel=[4,4,256,4096] | 2.23 TFLOPS | 38.76 TFLOPS |
| ne_input=[19,19,8,16], ne_kernel=[4,4,8,128] | 1.85 TFLOPS | 9.12 TFLOPS |
| ne_input=[19,19,8,16], ne_kernel=[4,4,8,130] | 1.76 TFLOPS | 9.27 TFLOPS |
| ne_input=[19,19,4,16], ne_kernel=[2,2,4,4] | 147.71 GFLOPS | 150.00 GFLOPS |
| ne_input=[224,224,3,1], ne_kernel=[3,3,3,8] | 1.04 TFLOPS | 1.02 TFLOPS |
| ne_input=[224,224,1,1], ne_kernel=[2,2,1,8] | 255.40 GFLOPS | 238.21 GFLOPS |
| ne_input=[224,224,1,8], ne_kernel=[2,2,1,8] | 308.44 GFLOPS | 324.17 GFLOPS |
| ne_input=[58,58,32,1], ne_kernel=[3,3,32,64] | 1.49 TFLOPS | 3.98 TFLOPS |
| ne_input=[58,58,32,8], ne_kernel=[3,3,32,64] | 1.88 TFLOPS | 15.85 TFLOPS |
| ne_input=[16,16,128,8], ne_kernel=[3,3,128,512] | 1.98 TFLOPS | 16.90 TFLOPS |
| ne_input=[19,19,256,16], ne_kernel=[4,4,256,4096] | 2.27 TFLOPS | 38.00 TFLOPS |
| ne_input=[19,19,8,16], ne_kernel=[4,4,8,128] | 1.86 TFLOPS | 8.64 TFLOPS |
| ne_input=[19,19,8,16], ne_kernel=[4,4,8,130] | 1.80 TFLOPS | 8.78 TFLOPS |
| ne_input=[19,19,4,16], ne_kernel=[2,2,4,4] | 150.12 GFLOPS | 147.95 GFLOPS |
| ne_input=[224,224,3,1], ne_kernel=[3,3,3,8] | 1.01 TFLOPS | 980.39 GFLOPS |
| ne_input=[224,224,1,1], ne_kernel=[2,2,1,8] | 245.83 GFLOPS | 212.52 GFLOPS |
| ne_input=[224,224,1,8], ne_kernel=[2,2,1,8] | 305.41 GFLOPS | 317.95 GFLOPS |
| ne_input=[58,58,32,1], ne_kernel=[3,3,32,64] | 1.43 TFLOPS | 3.74 TFLOPS |
| ne_input=[58,58,32,8], ne_kernel=[3,3,32,64] | 1.81 TFLOPS | 14.96 TFLOPS |
| ne_input=[16,16,128,8], ne_kernel=[3,3,128,512] | 1.84 TFLOPS | 15.80 TFLOPS |

Comparison with im2col+gemm

FP16 filter, FP32 activation

| (IC, OC, IW, IH) | im2col+GEMM TIME | im2col+GEMM VRAM | implicit GEMM TIME | implicit GEMM VRAM |
| --- | --- | --- | --- | --- |
| (64, 64, 48, 64) | 0.03 ms | 4.12 MB | 0.07 ms | 0.75 MB |
| (320, 320, 104, 152) | 0.56 ms | 106.13 MB | 0.98 ms | 19.30 MB |
| (640, 640, 52, 76) | 0.32 ms | 53.07 MB | 1.24 ms | 9.65 MB |
| (640, 640, 104, 152) | 1.41 ms | 212.27 MB | 3.04 ms | 38.59 MB |
| (960, 320, 104, 152) | 1.48 ms | 279.80 MB | 2.68 ms | 19.30 MB |
| (1280, 1280, 26, 38) | 0.21 ms | 26.53 MB | 1.19 ms | 4.82 MB |
| (1280, 640, 52, 76) | 0.62 ms | 96.48 MB | 2.33 ms | 9.65 MB |
| (1920, 1280, 26, 38) | 0.30 ms | 37.39 MB | 1.79 ms | 4.82 MB |
| (2560, 1280, 26, 38) | 0.42 ms | 48.24 MB | 2.36 ms | 4.82 MB |
| (512, 512, 104, 152) | 0.91 ms | 169.81 MB | 1.88 ms | 30.88 MB |
| (512, 512, 208, 304) | 3.90 ms | 679.25 MB | 7.95 ms | 123.50 MB |
| (512, 256, 416, 608) | 12.55 ms | 2470.00 MB | 15.67 ms | 247.00 MB |
| (256, 128, 832, 1216) | 24.82 ms | 4940.00 MB | 15.67 ms | 494.00 MB |
| (256, 256, 832, 1216) | 27.43 ms | 5434.00 MB | 31.17 ms | 988.00 MB |
| (320, 256, 1024, 1920) | 66.56 ms | 12720.00 MB | 76.05 ms | 1920.00 MB |

FP32 filter, FP32 activation

| (IC, OC, IW, IH) | im2col+GEMM TIME | im2col+GEMM VRAM | implicit GEMM TIME | implicit GEMM VRAM |
| --- | --- | --- | --- | --- |
| (64, 64, 48, 64) | 0.04 ms | 7.50 MB | 0.07 ms | 0.75 MB |
| (320, 320, 104, 152) | 0.92 ms | 192.97 MB | 0.90 ms | 19.30 MB |
| (640, 640, 52, 76) | 0.68 ms | 96.48 MB | 1.19 ms | 9.65 MB |
| (640, 640, 104, 152) | 2.41 ms | 385.94 MB | 2.95 ms | 38.59 MB |
| (960, 320, 104, 152) | 2.38 ms | 540.31 MB | 2.56 ms | 19.30 MB |
| (1280, 1280, 26, 38) | 0.71 ms | 48.24 MB | 1.10 ms | 4.82 MB |
| (1280, 640, 52, 76) | 1.18 ms | 183.32 MB | 2.20 ms | 9.65 MB |
| (1920, 1280, 26, 38) | 0.72 ms | 69.95 MB | 1.83 ms | 4.82 MB |
| (2560, 1280, 26, 38) | 0.94 ms | 91.66 MB | 2.35 ms | 4.82 MB |
| (512, 512, 104, 152) | 1.57 ms | 308.75 MB | 1.79 ms | 30.88 MB |
| (512, 512, 208, 304) | 6.34 ms | 1235.00 MB | 7.61 ms | 123.50 MB |
| (512, 256, 416, 608) | 17.49 ms | 4693.00 MB | 15.00 ms | 247.00 MB |
| (256, 128, 832, 1216) | 32.16 ms | 9386.00 MB | 15.06 ms | 494.00 MB |
| (256, 256, 832, 1216) | 36.54 ms | 9880.00 MB | 30.23 ms | 988.00 MB |
| (320, 256, 1024, 1920) | 562.36 ms | 23520.00 MB | 73.56 ms | 1920.00 MB |

@bssrdf marked this pull request as draft on September 4, 2025, 20:17
@github-actions bot added the testing, Nvidia GPU, and ggml labels on Sep 4, 2025
@JohannesGaessler (Collaborator)

Why are you adding a new ggml op?

@bssrdf (Contributor, Author) commented Sep 4, 2025

> Why are you adding a new ggml op?

Because of #15669 (comment)

@leejet (Contributor) commented Sep 5, 2025

I think the implementation of implicit gemm can directly use ggml_conv2d_direct. There's really no need to provide so many conv2d functions.

@bssrdf (Contributor, Author) commented Sep 5, 2025

> I think the implementation of implicit gemm can directly use ggml_conv2d_direct. There's really no need to provide so many conv2d functions.

I can reuse ggml_conv2d_direct. TBH it is not a very good or intuitive name (the best one, ggml_conv_2d, is already taken). I do wish it had an additional argument (which ggml_conv_2d should have carried from the beginning) to select which method is used.

@leejet (Contributor) commented Sep 5, 2025

If the performance of implicit gemm is on par with or even better than that of im2col + gemm, I think ggml_conv_2d can also adopt the implementation of implicit gemm.

@JohannesGaessler (Collaborator)

What should be done regarding IM2COL vs. CONV2D is to construct the compute graph using CONV2D and then let each backend decide how to do the operation. If a backend lacks support for convolution, it should allocate a temporary buffer for IM2COL and use that as a workaround.

For kernel selection, please take a look at how e.g. FLASH_ATTN_EXT is being handled. There are multiple kernels that can be used; at runtime one is selected based on hardware capabilities and tensor shapes. All convolution kernels that do the exact same operation should use the same ggml op. If we have multiple kernels that could be used, we need to test which code paths are faster under which circumstances and write the selection logic accordingly. This is particularly relevant because there is a concurrent PR using tensor cores: #15813. cc @mnehete32
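Roughly, the selection could look like the sketch below; every name here is a placeholder rather than the actual ggml-cuda entry point, and the cut-offs would have to come from benchmarks like the ones in this PR.

```cpp
#include <cstdint>

// Placeholder declarations -- not the real ggml-cuda API.
struct conv2d_params {
    int64_t ic, oc, kw, kh; // channel counts and filter size
    bool    kernel_is_f16;  // filter stored as FP16?
};

void conv2d_direct       (const conv2d_params &); // existing kernel
void conv2d_implicit_gemm(const conv2d_params &); // this PR
void conv2d_mma          (const conv2d_params &); // tensor core kernel (#15813)

// Pick one kernel per call based on hardware capabilities and tensor shapes,
// analogous to how FLASH_ATTN_EXT is dispatched.
void launch_conv2d(const conv2d_params & p, int compute_capability) {
    const bool    has_tensor_cores = compute_capability >= 700; // Volta and newer
    const int64_t gemm_k           = p.kw * p.kh * p.ic;        // reduction length

    if (has_tensor_cores && p.kernel_is_f16) {
        conv2d_mma(p);            // tensor cores available and usable
    } else if (gemm_k >= 128) {   // threshold must be tuned from measurements
        conv2d_implicit_gemm(p);  // large reductions favor implicit GEMM
    } else {
        conv2d_direct(p);         // small problems: keep the direct kernel
    }
}
```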

@JohannesGaessler (Collaborator) left a comment

For this PR, try removing the current conv2d kernel and replacing it with this one. Chances are it will be universally faster since it uses shared memory and has (unless I misread the code) coalesced memory accesses. I'll test the performance using a P40, RTX 3090, and RTX 4090 for NVIDIA and an RX 6800 and Mi 50 for AMD.

#include "convert.cuh"

typedef struct{
unsigned int n; //batch szie
@JohannesGaessler (Collaborator)
Suggested change:
-    unsigned int n; //batch szie
+    unsigned int n; //batch size

@bssrdf (Contributor, Author)

Done


typedef struct{
unsigned int n; //batch szie
unsigned int c; //channel number
@JohannesGaessler (Collaborator)

Change to either "channel index" or "number of channels" depending on which this is.

@bssrdf (Contributor, Author)

Done

int threadz = 1; // threadz number per block
dim3 thblock(threadx, thready, threadz);
dim3 grid(blockx, blocky, blockz);
int smem_size = 24 * 1024;
@JohannesGaessler (Collaborator)

On some CUDA architectures shared memory comes out of the L1 cache, so if at all possible you should reserve only as much as will actually be used.
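For example (tile sizes and the kernel name below are made up for illustration, reusing `grid`/`thblock`/`param` from the snippet above), the dynamic shared memory could be sized from the actual tile footprint instead of the hard-coded 24 KiB:

```cuda
constexpr int BM = 128, BN = 128, BK = 8;                 // example GEMM tile sizes
const size_t smem_size = (BM*BK + BK*BN) * sizeof(float); // input tile + filter tile
// Only needed if the footprint can exceed the default 48 KiB per-block limit:
// cudaFuncSetAttribute(conv2d_implicit_kernel,
//                      cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
conv2d_implicit_kernel<<<grid, thblock, smem_size>>>(input, kernel, output, param);
```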

float * __restrict__ output,
const param_t param) {

extern __shared__ __align__(16 * 1024) char smem[];
@JohannesGaessler (Collaborator)

What is the purpose of __align__ here?

@bssrdf (Contributor, Author)

Removed; no difference in performance.

Comment on lines 63 to 64
for (int i = 0; i < 4; ++i)
{
@JohannesGaessler (Collaborator)

Suggested change:
-    for (int i = 0; i < 4; ++i)
-    {
+    for (int i = 0; i < 4; ++i) {

See contribution guidelines

@bssrdf (Contributor, Author)

Done. Corrected the style in all places.

@bssrdf (Contributor, Author)

Thanks, @JohannesGaessler, for taking the time to review. I agree with your idea of selecting kernels behind the scenes. Indeed, no single kernel is optimal for all input and filter shapes; that's why cuDNN provides many of them for the user to choose from. Previously I was not sure whether selecting kernels was possible; I'll look into the FLASH_ATTN_EXT example (thanks again).

Now that #15813 is adding tensor core support with shared memory, I don't want to step on its toes. This PR will be on hold for now. I may contribute to the current conv_2d_direct once the tensor core code is merged.

@JohannesGaessler (Collaborator)

Even if there is a kernel with tensor core support, a good kernel without tensor cores would still be extremely useful. P40s and Mi50s are very cheap options for 24/32 GB of VRAM, but they lack tensor cores. And from a ggml perspective it's much easier to squeeze out more performance than it is to compress the weights (without affecting quality).

@JohannesGaessler (Collaborator)

Speaking of P40s, you should be careful with FP16 arithmetic since that is massively gimped on Pascal. You can use the macro FAST_FP16_AVAILABLE to check whether FP16 would be fast and use FP32 as a workaround if not. You can look at e.g. mmvf.cu for an example.
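A sketch of the pattern (accumulator chosen at compile time; this is an illustration of how the macro is typically used, not a quote from mmvf.cu, and the variable names are made up):

```cuda
#ifdef FAST_FP16_AVAILABLE
    half2  sum = make_half2(0.0f, 0.0f);   // fast FP16 path: accumulate in half2
#else
    float2 sum = make_float2(0.0f, 0.0f);  // e.g. Pascal: fall back to FP32
#endif // FAST_FP16_AVAILABLE
    // ... fused multiply-adds into `sum`, converting operands on load if needed ...
```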

@bssrdf (Contributor, Author)

Will look into it. Thanks.

Labels: ggml, Nvidia GPU, testing
3 participants