Skip to content

Commit 2c62fd5

Browse files
committed
vulkan : add fp16 support for the conv_2d kernel
1 parent c12bbde commit 2c62fd5

File tree

2 files changed

+19
-4
lines changed

2 files changed

+19
-4
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,7 @@ struct vk_device_struct {
484484
vk_pipeline pipeline_rwkv_wkv7_f32;
485485
vk_pipeline pipeline_opt_step_adamw_f32;
486486
vk_pipeline pipeline_conv2d_f32;
487+
vk_pipeline pipeline_conv2d_f16_f32;
487488
vk_pipeline pipeline_conv2d_dw_whcn_f32;
488489
vk_pipeline pipeline_conv2d_dw_cwhn_f32;
489490

@@ -3074,12 +3075,21 @@ static void ggml_vk_load_shaders(vk_device& device) {
30743075
device, device->pipeline_conv2d_f32, "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3,
30753076
sizeof(vk_op_conv2d_push_constants), { conv2d_BS_K, conv2d_BS_NPQ, 1 },
30763077
{ conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives }, 1, true, true);
3078+
ggml_vk_create_pipeline(
3079+
device, device->pipeline_conv2d_f16_f32, "conv2d_f16_f32", conv2d_f16_f32_len, conv2d_f16_f32_data, "main", 3,
3080+
sizeof(vk_op_conv2d_push_constants), { conv2d_BS_K, conv2d_BS_NPQ, 1 },
3081+
{ conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives }, 1, true, true);
30773082
} else {
30783083
ggml_vk_create_pipeline(
30793084
device, device->pipeline_conv2d_f32, "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3,
30803085
sizeof(vk_op_conv2d_push_constants), { conv2d_BS_K, conv2d_BS_NPQ, 1 },
30813086
{ conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives }, 1, true,
30823087
false);
3088+
ggml_vk_create_pipeline(
3089+
device, device->pipeline_conv2d_f16_f32, "conv2d_f16_f32", conv2d_f16_f32_len, conv2d_f16_f32_data, "main", 3,
3090+
sizeof(vk_op_conv2d_push_constants), { conv2d_BS_K, conv2d_BS_NPQ, 1 },
3091+
{ conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives }, 1, true,
3092+
false);
30833093
}
30843094

30853095
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
@@ -6958,9 +6968,13 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
69586968
}
69596969
return nullptr;
69606970
case GGML_OP_CONV_2D:
6961-
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
6971+
if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
69626972
ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
6963-
return ctx->device->pipeline_conv2d_f32;
6973+
if (src0->type == GGML_TYPE_F32) {
6974+
return ctx->device->pipeline_conv2d_f32;
6975+
} else if (src0->type == GGML_TYPE_F16) {
6976+
return ctx->device->pipeline_conv2d_f16_f32;
6977+
}
69646978
}
69656979
return nullptr;
69666980
case GGML_OP_CONV_2D_DW:
@@ -8178,13 +8192,13 @@ static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, c
81788192

81798193
static void ggml_vk_conv_2d(ggml_backend_vk_context * ctx, vk_context & subctx, const ggml_tensor * src0,
81808194
const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
8181-
GGML_ASSERT(src0->type == GGML_TYPE_F32);
8195+
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
81828196
GGML_ASSERT(src1->type == GGML_TYPE_F32);
81838197
GGML_ASSERT(dst->type == GGML_TYPE_F32);
81848198

81858199
GGML_TENSOR_BINARY_OP_LOCALS
81868200

8187-
GGML_ASSERT(nb00 == sizeof(float));
8201+
GGML_ASSERT(nb00 == sizeof(float) || nb00 == sizeof(ggml_fp16_t));
81888202
GGML_ASSERT(nb10 == sizeof(float));
81898203
GGML_ASSERT(nb0 == sizeof(float));
81908204

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -656,6 +656,7 @@ void process_shaders() {
656656
string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
657657

658658
string_to_spv("conv2d_f32", "conv2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}});
659+
string_to_spv("conv2d_f16_f32", "conv2d_mm.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}});
659660

660661
string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
661662
string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}}));

0 commit comments

Comments (0)