OpenCL: add fused group_norm/norm, mul, add #15314

Open · wants to merge 4 commits into base: master
189 changes: 185 additions & 4 deletions ggml/src/ggml-opencl/ggml-opencl.cpp
@@ -416,9 +416,9 @@ struct ggml_backend_opencl_context {
cl_kernel kernel_clamp;
cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu, kernel_swiglu_oai, kernel_geglu_erf, kernel_geglu_quick,
kernel_geglu_f16, kernel_reglu_f16, kernel_swiglu_f16, kernel_geglu_erf_f16, kernel_geglu_quick_f16;
cl_kernel kernel_norm;
cl_kernel kernel_norm, kernel_norm_mul_add;
cl_kernel kernel_rms_norm, kernel_rms_norm_mul;
cl_kernel kernel_group_norm;
cl_kernel kernel_group_norm, kernel_group_norm_mul_add;
cl_kernel kernel_diag_mask_inf, kernel_diag_mask_inf_8;
cl_kernel kernel_soft_max, kernel_soft_max_4;
cl_kernel kernel_soft_max_f16, kernel_soft_max_4_f16;
@@ -1130,7 +1130,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
backend_ctx->program_norm =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);

CL_CHECK((backend_ctx->kernel_norm = clCreateKernel(backend_ctx->program_norm, "kernel_norm", &err), err));
CL_CHECK((backend_ctx->kernel_norm_mul_add = clCreateKernel(backend_ctx->program_norm, "kernel_norm_mul_add", &err), err));
GGML_LOG_CONT(".");
}

@@ -1389,7 +1390,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
backend_ctx->program_group_norm =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);

CL_CHECK((backend_ctx->kernel_group_norm = clCreateKernel(backend_ctx->program_group_norm, "kernel_group_norm", &err), err));
CL_CHECK((backend_ctx->kernel_group_norm_mul_add = clCreateKernel(backend_ctx->program_group_norm, "kernel_group_norm_mul_add", &err), err));
GGML_LOG_CONT(".");
}

@@ -2380,12 +2382,47 @@ static bool ggml_opencl_can_fuse(const struct ggml_cgraph * cgraph, int node_idx
if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
return false;
}
} else if (ops.size() == 3 && ops.begin()[0] == GGML_OP_NORM && ops.begin()[1] == GGML_OP_MUL && ops.begin()[2] == GGML_OP_ADD) {
const ggml_tensor *norm = cgraph->nodes[node_idx];
const ggml_tensor *mul = cgraph->nodes[node_idx+1];
const ggml_tensor *add = cgraph->nodes[node_idx+2];
const ggml_tensor *w = mul->src[0] == norm ? mul->src[1] : mul->src[0];
const ggml_tensor *b = add->src[0] == mul ? add->src[1] : add->src[0];

// norm fusion only supports F32
if (norm->src[0]->type != GGML_TYPE_F32 || w->type != GGML_TYPE_F32 || b->type != GGML_TYPE_F32) {
return false;
}

if (norm->src[0]->ne[0] % 4 != 0) {
return false;
}

if (!ggml_is_contiguous(norm->src[0]) || !ggml_is_contiguous(w) || !ggml_is_contiguous(b)) {
return false;
}
} else if (ops.size() == 3 && ops.begin()[0] == GGML_OP_GROUP_NORM && ops.begin()[1] == GGML_OP_MUL && ops.begin()[2] == GGML_OP_ADD) {
const ggml_tensor *gn = cgraph->nodes[node_idx];
const ggml_tensor *mul = cgraph->nodes[node_idx+1];
const ggml_tensor *add = cgraph->nodes[node_idx+2];
const ggml_tensor *w = mul->src[0] == gn ? mul->src[1] : mul->src[0];
const ggml_tensor *b = add->src[0] == mul ? add->src[1] : add->src[0];

if (gn->src[0]->type != GGML_TYPE_F32 || w->type != GGML_TYPE_F32 || b->type != GGML_TYPE_F32) {
return false;
}

if (!ggml_is_contiguous(gn->src[0]) || !ggml_is_contiguous(w) || !ggml_is_contiguous(b)) {
return false;
}
}

return true;
}

static void ggml_opencl_op_rms_norm_fused(ggml_backend_t backend, ggml_tensor * rms_norm_tensor, ggml_tensor * mul_tensor);
static void ggml_opencl_op_norm_fused(ggml_backend_t backend, ggml_tensor * norm_tensor, ggml_tensor * mul_tensor, ggml_tensor * add_tensor);
static void ggml_opencl_op_group_norm_fused(ggml_backend_t backend, ggml_tensor * gn_tensor, ggml_tensor * mul_tensor, ggml_tensor * add_tensor);

static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
@@ -2402,6 +2439,16 @@ static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggm
continue;
}

if (!backend_ctx->disable_fusion && ggml_opencl_can_fuse(cgraph, i, { GGML_OP_NORM, GGML_OP_MUL, GGML_OP_ADD })) {
ggml_opencl_op_norm_fused(backend, node, cgraph->nodes[i+1], cgraph->nodes[i+2]);
i += 2;
continue;
}
if (!backend_ctx->disable_fusion && ggml_opencl_can_fuse(cgraph, i, { GGML_OP_GROUP_NORM, GGML_OP_MUL, GGML_OP_ADD })) {
ggml_opencl_op_group_norm_fused(backend, node, cgraph->nodes[i+1], cgraph->nodes[i+2]);
i += 2;
continue;
}
if (!backend_ctx->disable_fusion && ggml_opencl_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
ggml_opencl_op_rms_norm_fused(backend, node, cgraph->nodes[i+1]);
i++;
@@ -4858,6 +4905,140 @@ static void ggml_opencl_op_rms_norm_fused(ggml_backend_t backend, ggml_tensor *
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
}

static void ggml_opencl_op_norm_fused(ggml_backend_t backend, ggml_tensor * norm_tensor, ggml_tensor * mul_tensor, ggml_tensor * add_tensor) {
GGML_ASSERT(norm_tensor && mul_tensor && add_tensor);

const ggml_tensor * src0 = norm_tensor->src[0];
const ggml_tensor * src1 = mul_tensor->src[0] == norm_tensor ? mul_tensor->src[1] : mul_tensor->src[0];
const ggml_tensor * src2 = add_tensor->src[0] == mul_tensor ? add_tensor->src[1] : add_tensor->src[0];
const ggml_tensor * dst = add_tensor;

ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
ggml_tensor_extra_cl * extra2 = (ggml_tensor_extra_cl *)src2->extra;
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;

cl_ulong offset0 = extra0->offset + src0->view_offs;
cl_ulong offset1 = extra1->offset + src1->view_offs;
cl_ulong offset2 = extra2->offset + src2->view_offs;
cl_ulong offsetd = extrad->offset + dst->view_offs;

ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;

float eps;
memcpy(&eps, norm_tensor->op_params, sizeof(float));

const int ne00 = src0->ne[0], ne01 = src0->ne[1], ne02 = src0->ne[2], ne03 = src0->ne[3];
const cl_ulong nb01 = src0->nb[1], nb02 = src0->nb[2], nb03 = src0->nb[3];
const int ne10 = src1->ne[0], ne11 = src1->ne[1], ne12 = src1->ne[2], ne13 = src1->ne[3];
const cl_ulong nb11 = src1->nb[1], nb12 = src1->nb[2], nb13 = src1->nb[3];
const int ne20 = src2->ne[0], ne21 = src2->ne[1], ne22 = src2->ne[2], ne23 = src2->ne[3];
const cl_ulong nb21 = src2->nb[1], nb22 = src2->nb[2], nb23 = src2->nb[3];
const cl_ulong nbd1 = dst->nb[1], nbd2 = dst->nb[2], nbd3 = dst->nb[3];

size_t sgs;
if (backend_ctx->gpu_family == ADRENO) sgs = 64;
else if (backend_ctx->gpu_family == INTEL) sgs = 32;
else GGML_ASSERT(false && "Unsupported GPU");

cl_kernel kernel = backend_ctx->kernel_norm_mul_add;

int nth = sgs;
int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel);
while (nth < ne00/4 && nth < max_workgroup_size) nth *= 2;
nth = MIN(nth, max_workgroup_size);
nth = MIN(nth, ne00/4);

size_t gws[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
size_t lws[] = {(size_t)nth, 1, 1};
size_t num_subgroups = (nth + sgs - 1) / sgs;

CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra2->data_device));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset2));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne01));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne02));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne03));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb01));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb02));
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb03));
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne10));
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne11));
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &ne12));
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &ne13));
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb11));
CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb12));
CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb13));
CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &ne20));
CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &ne21));
CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &ne22));
CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int), &ne23));
CL_CHECK(clSetKernelArg(kernel, 26, sizeof(cl_ulong), &nb21));
CL_CHECK(clSetKernelArg(kernel, 27, sizeof(cl_ulong), &nb22));
CL_CHECK(clSetKernelArg(kernel, 28, sizeof(cl_ulong), &nb23));
CL_CHECK(clSetKernelArg(kernel, 29, sizeof(cl_ulong), &nbd1));
CL_CHECK(clSetKernelArg(kernel, 30, sizeof(cl_ulong), &nbd2));
CL_CHECK(clSetKernelArg(kernel, 31, sizeof(cl_ulong), &nbd3));
CL_CHECK(clSetKernelArg(kernel, 32, sizeof(float), &eps));
CL_CHECK(clSetKernelArg(kernel, 33, sizeof(cl_float2) * num_subgroups, NULL));

backend_ctx->enqueue_ndrange_kernel(kernel, 3, gws, lws, dst);
}

static void ggml_opencl_op_group_norm_fused(ggml_backend_t backend, ggml_tensor * gn_tensor, ggml_tensor * mul_tensor, ggml_tensor * add_tensor) {
GGML_ASSERT(gn_tensor && mul_tensor && add_tensor);

const ggml_tensor * src0 = gn_tensor->src[0];
const ggml_tensor * src1 = mul_tensor->src[0] == gn_tensor ? mul_tensor->src[1] : mul_tensor->src[0];
const ggml_tensor * src2 = add_tensor->src[0] == mul_tensor ? add_tensor->src[1] : add_tensor->src[0];
const ggml_tensor * dst = add_tensor;

ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
ggml_tensor_extra_cl * extra2 = (ggml_tensor_extra_cl *)src2->extra;
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;

cl_ulong offset0 = extra0->offset + src0->view_offs;
cl_ulong offset1 = extra1->offset + src1->view_offs;
cl_ulong offset2 = extra2->offset + src2->view_offs;
cl_ulong offsetd = extrad->offset + dst->view_offs;

ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;

int groups;
float eps;
memcpy(&groups, gn_tensor->op_params, sizeof(int));
memcpy(&eps, (char *)gn_tensor->op_params + sizeof(int), sizeof(float));

cl_kernel kernel = backend_ctx->kernel_group_norm_mul_add;
int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel);
int ne = ggml_nelements(src0);
int group_size = ne / groups;

size_t lws[] = { (size_t)MIN(max_workgroup_size, group_size) };
size_t gws[] = { (size_t)groups * lws[0] };

CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra2->data_device));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset2));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &group_size));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(float), &eps));

backend_ctx->enqueue_ndrange_kernel(kernel, 1, gws, lws, dst);
}

static void ggml_cl_group_norm(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(src0);
GGML_ASSERT(src0->extra);
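For reviewers who want to sanity-check the host-side dispatch above, a minimal CPU reference of the NORM → MUL → ADD sequence the fused kernel replaces is sketched below. It is purely illustrative and not part of this PR: the function name `norm_mul_add_ref`, the flat row-major layout, and the per-row broadcast of `w`/`b` are assumptions that mirror the `(i01 % ne11)`-style indexing used in the kernel but ignore higher-dimensional broadcasting.

```cpp
// Illustrative CPU reference for the fused norm+mul+add path (not part of this PR).
// Layout assumption: x is [rows][ne00] row-major; w and b hold ne00 elements each
// and are broadcast across rows.
#include <cmath>
#include <cstdint>
#include <vector>

static void norm_mul_add_ref(const std::vector<float> & x,
                             const std::vector<float> & w,
                             const std::vector<float> & b,
                             std::vector<float> & y,
                             int64_t rows, int64_t ne00, float eps) {
    for (int64_t r = 0; r < rows; ++r) {
        const float * xr = x.data() + r*ne00;
        float *       yr = y.data() + r*ne00;

        // mean and variance over the row, using the same E[x^2] - mean^2 identity as the kernel
        double sum = 0.0, sum_sq = 0.0;
        for (int64_t i = 0; i < ne00; ++i) {
            sum    += xr[i];
            sum_sq += (double) xr[i] * xr[i];
        }
        const float mean  = (float)(sum / ne00);
        const float var   = (float)(sum_sq / ne00) - mean*mean;
        const float scale = 1.0f / std::sqrt(var + eps);

        // fused affine step: normalize, then multiply by w and add b
        for (int64_t i = 0; i < ne00; ++i) {
            yr[i] = (xr[i] - mean) * scale * w[i] + b[i];
        }
    }
}
```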
49 changes: 49 additions & 0 deletions ggml/src/ggml-opencl/kernels/group_norm.cl
@@ -70,3 +70,52 @@ kernel void kernel_group_norm(
dst[j] *= scale;
}
}

//------------------------------------------------------------------------------
// group_norm_mul_add
//------------------------------------------------------------------------------
#ifdef INTEL_GPU
REQD_SUBGROUP_SIZE_32
#elif defined (ADRENO_GPU)
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_group_norm_mul_add(
global float * src0, ulong offset0,
global float * src1, ulong offset1,
global float * src2, ulong offset2,
global float * dst, ulong offsetd,
int ne,
int group_size,
float eps
) {
src0 = (global float *)((global char *)src0 + offset0);
src1 = (global float *)((global char *)src1 + offset1);
src2 = (global float *)((global char *)src2 + offset2);
dst = (global float *)((global char *)dst + offsetd);

int start = get_group_id(0) * group_size;
int end = start + group_size;
if (end > ne) {
end = ne;
}

float sum = 0.0f;
float sum_sq = 0.0f;

for (int j = start + get_local_id(0); j < end; j += get_local_size(0)) {
float val = src0[j];
sum += val;
sum_sq += val*val;
}

sum = sub_group_reduce_add(sum);
sum_sq = sub_group_reduce_add(sum_sq);

const float mean = sum / group_size;
const float var = sum_sq / group_size - mean * mean;
const float scale = rsqrt(var + eps);

for (int j = start + get_local_id(0); j < end; j += get_local_size(0)) {
dst[j] = ((src0[j] - mean) * scale) * src1[j] + src2[j];
}
}
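A matching CPU reference for `kernel_group_norm_mul_add` is sketched below, again purely illustrative and not part of this PR. It treats `src0`/`src1`/`src2` as flat arrays of `ne` floats (no broadcasting), and, like the kernel, divides by the nominal `group_size` even when the last group is clipped at `ne`.

```cpp
// Illustrative CPU reference for the fused group_norm+mul+add kernel (not part of this PR).
// Assumption: src0, src1, src2 and dst are flat arrays of ne floats.
#include <algorithm>
#include <cmath>
#include <vector>

static void group_norm_mul_add_ref(const std::vector<float> & src0,
                                   const std::vector<float> & src1,
                                   const std::vector<float> & src2,
                                   std::vector<float> & dst,
                                   int ne, int groups, float eps) {
    const int group_size = ne / groups;
    for (int g = 0; g < groups; ++g) {
        const int start = g * group_size;
        const int end   = std::min(start + group_size, ne);

        float sum = 0.0f, sum_sq = 0.0f;
        for (int j = start; j < end; ++j) {
            sum    += src0[j];
            sum_sq += src0[j]*src0[j];
        }
        // divide by group_size (not end - start), mirroring the kernel
        const float mean  = sum / group_size;
        const float var   = sum_sq / group_size - mean*mean;
        const float scale = 1.0f / std::sqrt(var + eps);   // rsqrt(var + eps)

        for (int j = start; j < end; ++j) {
            dst[j] = (src0[j] - mean) * scale * src1[j] + src2[j];
        }
    }
}
```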
80 changes: 80 additions & 0 deletions ggml/src/ggml-opencl/kernels/norm.cl
@@ -79,3 +79,83 @@ kernel void kernel_norm(
y[i00] = y[i00] * scale;
}
}

//------------------------------------------------------------------------------
// norm_mul_add
//------------------------------------------------------------------------------
#ifdef INTEL_GPU
REQD_SUBGROUP_SIZE_32
#elif defined (ADRENO_GPU)
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_norm_mul_add(
global char * src0_ptr, ulong src0_offset,
global char * src1_ptr, ulong src1_offset,
global char * src2_ptr, ulong src2_offset,
global char * dst_ptr, ulong dst_offset,
int ne00, int ne01, int ne02, int ne03,
ulong nb01, ulong nb02, ulong nb03,
int ne10, int ne11, int ne12, int ne13,
ulong nb11, ulong nb12, ulong nb13,
int ne20, int ne21, int ne22, int ne23,
ulong nb21, ulong nb22, ulong nb23,
ulong nbd1, ulong nbd2, ulong nbd3,
float eps,
local float2 * sums
) {
const int i03 = get_group_id(2);
const int i02 = get_group_id(1);
const int i01 = get_group_id(0);

global float4 * x = (global float4 *)(src0_ptr + src0_offset + i01*nb01 + i02*nb02 + i03*nb03);
global float4 * w = (global float4 *)(src1_ptr + src1_offset + (i01%ne11)*nb11 + (i02%ne12)*nb12 + (i03%ne13)*nb13);
global float4 * b = (global float4 *)(src2_ptr + src2_offset + (i01%ne21)*nb21 + (i02%ne22)*nb22 + (i03%ne23)*nb23);
global float4 * y = (global float4 *)(dst_ptr + dst_offset + i01*nbd1 + i02*nbd2 + i03*nbd3);

float p_sum = 0.0f;
float p_sum_sq = 0.0f;

const int n_chunks = ne00 / 4;
for (int i00 = get_local_id(0); i00 < n_chunks; i00 += get_local_size(0)) {
float4 val = x[i00];
p_sum += val.x + val.y + val.z + val.w;
p_sum_sq += dot(val, val);
}

p_sum = sub_group_reduce_add(p_sum);
p_sum_sq = sub_group_reduce_add(p_sum_sq);

if (get_sub_group_local_id() == 0) {
sums[get_sub_group_id()] = (float2)(p_sum, p_sum_sq);
}
barrier(CLK_LOCAL_MEM_FENCE);

if (get_local_id(0) == 0) {
float sum = 0.0f;
float sum_sq = 0.0f;
for (uint i = 0; i < get_num_sub_groups(); ++i) {
float2 s = sums[i];
sum += s.x;
sum_sq += s.y;
}

const float inv_ne00 = 1.0f / (float)ne00;
const float mean = sum * inv_ne00;
const float variance = mad(-mean, mean, sum_sq * inv_ne00);

sums[0] = (float2)(mean, rsqrt(variance + eps));
}
barrier(CLK_LOCAL_MEM_FENCE);

const float2 mean_scale = sums[0];
const float mean = mean_scale.x;
const float scale = mean_scale.y;
const float neg_mean_scale = -mean * scale;

for (int i00 = get_local_id(0); i00 < n_chunks; i00 += get_local_size(0)) {
const int w_idx = ne10 > 1 ? i00 : 0;
const int b_idx = ne20 > 1 ? i00 : 0;
const float4 norm_x = mad(x[i00], (float4)scale, (float4)neg_mean_scale);
y[i00] = mad(norm_x, w[w_idx], b[b_idx]);
}
}
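Read end-to-end, `kernel_norm_mul_add` above reduces a partial sum and sum of squares per subgroup, stores one `float2` per subgroup in local memory, and lets work-item 0 fold them into the mean and inverse standard deviation before every work-item applies the affine step. Per row it computes

$$
\mu = \frac{1}{ne_{00}}\sum_i x_i,\qquad
\sigma^2 = \frac{1}{ne_{00}}\sum_i x_i^2 - \mu^2,\qquad
s = \frac{1}{\sqrt{\sigma^2 + \varepsilon}},\qquad
y_i = (x_i - \mu)\,s\,w_i + b_i
$$

The `mad(-mean, mean, sum_sq * inv_ne00)` form is the usual $E[x^2] - \mu^2$ identity, which is what lets the kernel gather the statistics in a single pass over `x` before the final vectorized write loop.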