Skip to content

Commit 7d3553f

Browse files
committed
make SHMEM_PAD a spec constant
1 parent 95ee61a commit 7d3553f

File tree

2 files changed

+3
-5
lines changed

2 files changed

+3
-5
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3140,11 +3140,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
31403140
ggml_vk_create_pipeline(
31413141
device, device->pipeline_conv2d_f32[s], "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3,
31423142
sizeof(vk_op_conv2d_push_constants), { conv2d_BS_K, conv2d_BS_NPQ, 1 },
3143-
{ conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives }, 1, true, use_collectives);
3143+
{ conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives, conv2d_SHMEM_PAD }, 1, true, use_collectives);
31443144
ggml_vk_create_pipeline(
31453145
device, device->pipeline_conv2d_f16_f32[s], "conv2d_f16_f32", conv2d_f16_f32_len, conv2d_f16_f32_data, "main", 3,
31463146
sizeof(vk_op_conv2d_push_constants), { conv2d_BS_K, conv2d_BS_NPQ, 1 },
3147-
{ conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives }, 1, true, use_collectives);
3147+
{ conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives, conv2d_SHMEM_PAD }, 1, true, use_collectives);
31483148
}
31493149

31503150
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);

ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,6 @@
88

99
#include "types.comp"
1010

11-
// Make spec constant
12-
#define SHMEM_PAD 4
13-
1411
// shape notation: [dim(N), ..., dim(0)] -- stride(dim(j)) >= stride(dim(i)) if i > j
1512
layout(binding = 0) readonly buffer A {
1613
A_TYPE knl_data[];
@@ -76,6 +73,7 @@ layout(constant_id = 3) const uint BS_NPQ = 128;
7673
// Thread-tile sizes
7774
layout(constant_id = 4) const uint TS_K = 8;
7875
layout(constant_id = 5) const uint use_collectives = 1;
76+
layout(constant_id = 6) const uint SHMEM_PAD = 4;
7977

8078
uint32_t tid = gl_LocalInvocationID.x;
8179
const uint32_t WG_SIZE = gl_WorkGroupSize.x;

0 commit comments

Comments
 (0)