Skip to content

Commit 1c872f7

Browse files
authored
opencl: add f16 for add, sub, mul, div (#14984)
1 parent baad948 commit 1c872f7

File tree

5 files changed: +414 -70 lines changed

5 files changed: +414 -70 lines changed

ggml/src/ggml-opencl/ggml-opencl.cpp

Lines changed: 136 additions & 70 deletions
Original file line number | Diff line number | Diff line change
@@ -400,10 +400,10 @@ struct ggml_backend_opencl_context {
400400
cl_program program_mul_mm_f32_f32_l4_lm;
401401
cl_program program_mul_mm_f16_f32_l4_lm;
402402

403-
cl_kernel kernel_add, kernel_add_row;
404-
cl_kernel kernel_mul, kernel_mul_row;
405-
cl_kernel kernel_div, kernel_div_row;
406-
cl_kernel kernel_sub, kernel_sub_row;
403+
cl_kernel kernel_add, kernel_add_row, kernel_add_f16, kernel_add_row_f16;
404+
cl_kernel kernel_mul, kernel_mul_row, kernel_mul_f16, kernel_mul_row_f16;
405+
cl_kernel kernel_div, kernel_div_row, kernel_div_f16, kernel_div_row_f16;
406+
cl_kernel kernel_sub, kernel_sub_row, kernel_sub_f16, kernel_sub_row_f16;
407407
cl_kernel kernel_scale;
408408
cl_kernel kernel_silu, kernel_silu_4;
409409
cl_kernel kernel_gelu, kernel_gelu_4;
@@ -674,8 +674,10 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
674674
backend_ctx->program_add =
675675
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
676676

677-
CL_CHECK((backend_ctx->kernel_add = clCreateKernel(backend_ctx->program_add, "kernel_add", &err), err));
678-
CL_CHECK((backend_ctx->kernel_add_row = clCreateKernel(backend_ctx->program_add, "kernel_add_row", &err), err));
677+
CL_CHECK((backend_ctx->kernel_add = clCreateKernel(backend_ctx->program_add, "kernel_add", &err), err));
678+
CL_CHECK((backend_ctx->kernel_add_row = clCreateKernel(backend_ctx->program_add, "kernel_add_row", &err), err));
679+
CL_CHECK((backend_ctx->kernel_add_f16 = clCreateKernel(backend_ctx->program_add, "kernel_add_f16", &err), err));
680+
CL_CHECK((backend_ctx->kernel_add_row_f16 = clCreateKernel(backend_ctx->program_add, "kernel_add_row_f16", &err), err));
679681
GGML_LOG_CONT(".");
680682
}
681683

@@ -1089,8 +1091,10 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
10891091
backend_ctx->program_mul =
10901092
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
10911093

1092-
CL_CHECK((backend_ctx->kernel_mul = clCreateKernel(backend_ctx->program_mul, "kernel_mul", &err), err));
1093-
CL_CHECK((backend_ctx->kernel_mul_row = clCreateKernel(backend_ctx->program_mul, "kernel_mul_row", &err), err));
1094+
CL_CHECK((backend_ctx->kernel_mul = clCreateKernel(backend_ctx->program_mul, "kernel_mul", &err), err));
1095+
CL_CHECK((backend_ctx->kernel_mul_row = clCreateKernel(backend_ctx->program_mul, "kernel_mul_row", &err), err));
1096+
CL_CHECK((backend_ctx->kernel_mul_f16 = clCreateKernel(backend_ctx->program_mul, "kernel_mul_f16", &err), err));
1097+
CL_CHECK((backend_ctx->kernel_mul_row_f16 = clCreateKernel(backend_ctx->program_mul, "kernel_mul_row_f16", &err), err));
10941098
GGML_LOG_CONT(".");
10951099
}
10961100

@@ -1288,11 +1292,16 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
12881292
#else
12891293
const std::string kernel_src = read_file("div.cl");
12901294
#endif
1295+
std::string compile_opts = std::string("-cl-std=") + opencl_c_std +
1296+
" -cl-mad-enable -cl-finite-math-only ";
1297+
12911298
backend_ctx->program_div =
12921299
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
12931300

1294-
CL_CHECK((backend_ctx->kernel_div = clCreateKernel(backend_ctx->program_div, "kernel_div", &err), err));
1295-
CL_CHECK((backend_ctx->kernel_div_row = clCreateKernel(backend_ctx->program_div, "kernel_div_row", &err), err));
1301+
CL_CHECK((backend_ctx->kernel_div = clCreateKernel(backend_ctx->program_div, "kernel_div", &err), err));
1302+
CL_CHECK((backend_ctx->kernel_div_row = clCreateKernel(backend_ctx->program_div, "kernel_div_row", &err), err));
1303+
CL_CHECK((backend_ctx->kernel_div_f16 = clCreateKernel(backend_ctx->program_div, "kernel_div_f16", &err), err));
1304+
CL_CHECK((backend_ctx->kernel_div_row_f16 = clCreateKernel(backend_ctx->program_div, "kernel_div_row_f16", &err), err));
12961305
GGML_LOG_CONT(".");
12971306
}
12981307

@@ -1308,8 +1317,10 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
13081317
backend_ctx->program_sub =
13091318
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
13101319

1311-
CL_CHECK((backend_ctx->kernel_sub = clCreateKernel(backend_ctx->program_sub, "kernel_sub", &err), err));
1312-
CL_CHECK((backend_ctx->kernel_sub_row = clCreateKernel(backend_ctx->program_sub, "kernel_sub_row", &err), err));
1320+
CL_CHECK((backend_ctx->kernel_sub = clCreateKernel(backend_ctx->program_sub, "kernel_sub", &err), err));
1321+
CL_CHECK((backend_ctx->kernel_sub_row = clCreateKernel(backend_ctx->program_sub, "kernel_sub_row", &err), err));
1322+
CL_CHECK((backend_ctx->kernel_sub_f16 = clCreateKernel(backend_ctx->program_sub, "kernel_sub_f16", &err), err));
1323+
CL_CHECK((backend_ctx->kernel_sub_row_f16 = clCreateKernel(backend_ctx->program_sub, "kernel_sub_row_f16", &err), err));
13131324
GGML_LOG_CONT(".");
13141325
}
13151326

@@ -2447,12 +2458,15 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
24472458
default:
24482459
return false;
24492460
}
2450-
case GGML_OP_ADD:
24512461
case GGML_OP_SCALE:
2462+
return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]);
2463+
case GGML_OP_ADD:
24522464
case GGML_OP_MUL:
24532465
case GGML_OP_DIV:
24542466
case GGML_OP_SUB:
2455-
return op->src[0]->type == GGML_TYPE_F32;
2467+
return (op->src[0]->type == op->src[1]->type) &&
2468+
(op->src[0]->type == op->type) &&
2469+
(op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16);
24562470
case GGML_OP_UNARY:
24572471
switch (ggml_get_unary_op(op)) {
24582472
case GGML_UNARY_OP_GELU:
@@ -3680,35 +3694,39 @@ static void ggml_cl_add(ggml_backend_t backend, const ggml_tensor * src0, const
36803694
GGML_ASSERT(dst);
36813695
GGML_ASSERT(dst->extra);
36823696

3683-
const int ne00 = src0 ? src0->ne[0] : 0;
3684-
const int ne01 = src0 ? src0->ne[1] : 0;
3685-
const int ne02 = src0 ? src0->ne[2] : 0;
3686-
const int ne03 = src0 ? src0->ne[3] : 0;
3697+
GGML_ASSERT(src0->type == src1->type);
3698+
GGML_ASSERT(src0->type == dst->type);
3699+
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
36873700

3688-
const cl_ulong nb00 = src0 ? src0->nb[0] : 0;
3689-
const cl_ulong nb01 = src0 ? src0->nb[1] : 0;
3690-
const cl_ulong nb02 = src0 ? src0->nb[2] : 0;
3691-
const cl_ulong nb03 = src0 ? src0->nb[3] : 0;
3701+
const int ne00 = src0->ne[0];
3702+
const int ne01 = src0->ne[1];
3703+
const int ne02 = src0->ne[2];
3704+
const int ne03 = src0->ne[3];
36923705

3693-
const int ne10 = src1 ? src1->ne[0] : 0;
3694-
const int ne11 = src1 ? src1->ne[1] : 0;
3695-
const int ne12 = src1 ? src1->ne[2] : 0;
3696-
const int ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13);
3706+
const cl_ulong nb00 = src0->nb[0];
3707+
const cl_ulong nb01 = src0->nb[1];
3708+
const cl_ulong nb02 = src0->nb[2];
3709+
const cl_ulong nb03 = src0->nb[3];
36973710

3698-
const cl_ulong nb10 = src1 ? src1->nb[0] : 0;
3699-
const cl_ulong nb11 = src1 ? src1->nb[1] : 0;
3700-
const cl_ulong nb12 = src1 ? src1->nb[2] : 0;
3701-
const cl_ulong nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13);
3711+
const int ne10 = src1->ne[0];
3712+
const int ne11 = src1->ne[1];
3713+
const int ne12 = src1->ne[2];
3714+
const int ne13 = src1->ne[3]; UNUSED(ne13);
37023715

3703-
const int ne0 = dst ? dst->ne[0] : 0;
3704-
const int ne1 = dst ? dst->ne[1] : 0;
3705-
const int ne2 = dst ? dst->ne[2] : 0;
3706-
const int ne3 = dst ? dst->ne[3] : 0;
3716+
const cl_ulong nb10 = src1->nb[0];
3717+
const cl_ulong nb11 = src1->nb[1];
3718+
const cl_ulong nb12 = src1->nb[2];
3719+
const cl_ulong nb13 = src1->nb[3]; UNUSED(nb13);
37073720

3708-
const cl_ulong nb0 = dst ? dst->nb[0] : 0;
3709-
const cl_ulong nb1 = dst ? dst->nb[1] : 0;
3710-
const cl_ulong nb2 = dst ? dst->nb[2] : 0;
3711-
const cl_ulong nb3 = dst ? dst->nb[3] : 0;
3721+
const int ne0 = dst->ne[0];
3722+
const int ne1 = dst->ne[1];
3723+
const int ne2 = dst->ne[2];
3724+
const int ne3 = dst->ne[3];
3725+
3726+
const cl_ulong nb0 = dst->nb[0];
3727+
const cl_ulong nb1 = dst->nb[1];
3728+
const cl_ulong nb2 = dst->nb[2];
3729+
const cl_ulong nb3 = dst->nb[3];
37123730

37133731
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
37143732

@@ -3731,7 +3749,12 @@ static void ggml_cl_add(ggml_backend_t backend, const ggml_tensor * src0, const
37313749

37323750
bcast_row = true;
37333751
int ne = ne00 / 4;
3734-
kernel = backend_ctx->kernel_add_row;
3752+
3753+
if (src0->type == GGML_TYPE_F32) {
3754+
kernel = backend_ctx->kernel_add_row;
3755+
} else {
3756+
kernel = backend_ctx->kernel_add_row_f16;
3757+
}
37353758

37363759
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
37373760
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
@@ -3741,7 +3764,11 @@ static void ggml_cl_add(ggml_backend_t backend, const ggml_tensor * src0, const
37413764
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
37423765
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne));
37433766
} else {
3744-
kernel = backend_ctx->kernel_add;
3767+
if (src0->type == GGML_TYPE_F32) {
3768+
kernel = backend_ctx->kernel_add;
3769+
} else {
3770+
kernel = backend_ctx->kernel_add_f16;
3771+
}
37453772

37463773
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
37473774
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
@@ -3803,35 +3830,39 @@ static void ggml_cl_mul(ggml_backend_t backend, const ggml_tensor * src0, const
38033830
GGML_ASSERT(dst);
38043831
GGML_ASSERT(dst->extra);
38053832

3806-
const int ne00 = src0 ? src0->ne[0] : 0;
3807-
const int ne01 = src0 ? src0->ne[1] : 0;
3808-
const int ne02 = src0 ? src0->ne[2] : 0;
3809-
const int ne03 = src0 ? src0->ne[3] : 0;
3833+
GGML_ASSERT(src0->type == src1->type);
3834+
GGML_ASSERT(src0->type == dst->type);
3835+
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
38103836

3811-
const cl_ulong nb00 = src0 ? src0->nb[0] : 0;
3812-
const cl_ulong nb01 = src0 ? src0->nb[1] : 0;
3813-
const cl_ulong nb02 = src0 ? src0->nb[2] : 0;
3814-
const cl_ulong nb03 = src0 ? src0->nb[3] : 0;
3837+
const int ne00 = src0->ne[0];
3838+
const int ne01 = src0->ne[1];
3839+
const int ne02 = src0->ne[2];
3840+
const int ne03 = src0->ne[3];
38153841

3816-
const int ne10 = src1 ? src1->ne[0] : 0;
3817-
const int ne11 = src1 ? src1->ne[1] : 0;
3818-
const int ne12 = src1 ? src1->ne[2] : 0;
3819-
const int ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13);
3842+
const cl_ulong nb00 = src0->nb[0];
3843+
const cl_ulong nb01 = src0->nb[1];
3844+
const cl_ulong nb02 = src0->nb[2];
3845+
const cl_ulong nb03 = src0->nb[3];
38203846

3821-
const cl_ulong nb10 = src1 ? src1->nb[0] : 0;
3822-
const cl_ulong nb11 = src1 ? src1->nb[1] : 0;
3823-
const cl_ulong nb12 = src1 ? src1->nb[2] : 0;
3824-
const cl_ulong nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13);
3847+
const int ne10 = src1->ne[0];
3848+
const int ne11 = src1->ne[1];
3849+
const int ne12 = src1->ne[2];
3850+
const int ne13 = src1->ne[3]; UNUSED(ne13);
3851+
3852+
const cl_ulong nb10 = src1->nb[0];
3853+
const cl_ulong nb11 = src1->nb[1];
3854+
const cl_ulong nb12 = src1->nb[2];
3855+
const cl_ulong nb13 = src1->nb[3]; UNUSED(nb13);
38253856

3826-
const int ne0 = dst ? dst->ne[0] : 0;
3827-
const int ne1 = dst ? dst->ne[1] : 0;
3828-
const int ne2 = dst ? dst->ne[2] : 0;
3829-
const int ne3 = dst ? dst->ne[3] : 0;
3857+
const int ne0 = dst->ne[0];
3858+
const int ne1 = dst->ne[1];
3859+
const int ne2 = dst->ne[2];
3860+
const int ne3 = dst->ne[3];
38303861

3831-
const cl_ulong nb0 = dst ? dst->nb[0] : 0;
3832-
const cl_ulong nb1 = dst ? dst->nb[1] : 0;
3833-
const cl_ulong nb2 = dst ? dst->nb[2] : 0;
3834-
const cl_ulong nb3 = dst ? dst->nb[3] : 0;
3862+
const cl_ulong nb0 = dst->nb[0];
3863+
const cl_ulong nb1 = dst->nb[1];
3864+
const cl_ulong nb2 = dst->nb[2];
3865+
const cl_ulong nb3 = dst->nb[3];
38353866

38363867
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
38373868

@@ -3854,7 +3885,12 @@ static void ggml_cl_mul(ggml_backend_t backend, const ggml_tensor * src0, const
38543885

38553886
bcast_row = true;
38563887
int ne = ne00 / 4;
3857-
kernel = backend_ctx->kernel_mul_row;
3888+
3889+
if (src0->type == GGML_TYPE_F32) {
3890+
kernel = backend_ctx->kernel_mul_row;
3891+
} else {
3892+
kernel = backend_ctx->kernel_mul_row_f16;
3893+
}
38583894

38593895
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
38603896
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
@@ -3864,7 +3900,11 @@ static void ggml_cl_mul(ggml_backend_t backend, const ggml_tensor * src0, const
38643900
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
38653901
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne));
38663902
} else {
3867-
kernel = backend_ctx->kernel_mul;
3903+
if (src0->type == GGML_TYPE_F32) {
3904+
kernel = backend_ctx->kernel_mul;
3905+
} else {
3906+
kernel = backend_ctx->kernel_mul_f16;
3907+
}
38683908

38693909
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
38703910
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
@@ -3926,6 +3966,10 @@ static void ggml_cl_div(ggml_backend_t backend, const ggml_tensor * src0, const
39263966
GGML_ASSERT(dst);
39273967
GGML_ASSERT(dst->extra);
39283968

3969+
GGML_ASSERT(src0->type == src1->type);
3970+
GGML_ASSERT(src0->type == dst->type);
3971+
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
3972+
39293973
const int ne00 = src0->ne[0];
39303974
const int ne01 = src0->ne[1];
39313975
const int ne02 = src0->ne[2];
@@ -3974,7 +4018,12 @@ static void ggml_cl_div(ggml_backend_t backend, const ggml_tensor * src0, const
39744018

39754019
bcast_row = true;
39764020
int ne = ne00 / 4;
3977-
kernel = backend_ctx->kernel_div_row;
4021+
4022+
if (src0->type == GGML_TYPE_F32) {
4023+
kernel = backend_ctx->kernel_div_row;
4024+
} else {
4025+
kernel = backend_ctx->kernel_div_row_f16;
4026+
}
39784027

39794028
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
39804029
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
@@ -3984,7 +4033,11 @@ static void ggml_cl_div(ggml_backend_t backend, const ggml_tensor * src0, const
39844033
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
39854034
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne));
39864035
} else {
3987-
kernel = backend_ctx->kernel_div;
4036+
if (src0->type == GGML_TYPE_F32) {
4037+
kernel = backend_ctx->kernel_div;
4038+
} else {
4039+
kernel = backend_ctx->kernel_div_f16;
4040+
}
39884041

39894042
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
39904043
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
@@ -4034,6 +4087,10 @@ static void ggml_cl_sub(ggml_backend_t backend, const ggml_tensor * src0, const
40344087
GGML_ASSERT(dst);
40354088
GGML_ASSERT(dst->extra);
40364089

4090+
GGML_ASSERT(src0->type == src1->type);
4091+
GGML_ASSERT(src0->type == dst->type);
4092+
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
4093+
40374094
const int ne00 = src0->ne[0];
40384095
const int ne01 = src0->ne[1];
40394096
const int ne02 = src0->ne[2];
@@ -4082,7 +4139,12 @@ static void ggml_cl_sub(ggml_backend_t backend, const ggml_tensor * src0, const
40824139

40834140
bcast_row = true;
40844141
int ne = ne00 / 4;
4085-
kernel = backend_ctx->kernel_sub_row;
4142+
4143+
if (src0->type == GGML_TYPE_F32) {
4144+
kernel = backend_ctx->kernel_sub_row;
4145+
} else {
4146+
kernel = backend_ctx->kernel_sub_row_f16;
4147+
}
40864148

40874149
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
40884150
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
@@ -4092,7 +4154,11 @@ static void ggml_cl_sub(ggml_backend_t backend, const ggml_tensor * src0, const
40924154
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
40934155
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne));
40944156
} else {
4095-
kernel = backend_ctx->kernel_sub;
4157+
if (src0->type == GGML_TYPE_F32) {
4158+
kernel = backend_ctx->kernel_sub;
4159+
} else {
4160+
kernel = backend_ctx->kernel_sub_f16;
4161+
}
40964162

40974163
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
40984164
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));

0 commit comments

Comments
 (0)