Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions ggml/src/ggml-opencl/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,9 @@ set(GGML_OPENCL_KERNELS
mul_mat_f16_f32
conv2d
conv2d_f16_f32
flash_attn_f32_f16
flash_attn_f16
flash_attn_f32
)

foreach (K ${GGML_OPENCL_KERNELS})
Expand Down
244 changes: 244 additions & 0 deletions ggml/src/ggml-opencl/ggml-opencl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <vector>
#include <string>
#include <cmath>
#include <map>
#include <memory>
#include <charconv>
#include <mutex>
Expand Down Expand Up @@ -420,6 +421,14 @@ struct ggml_backend_opencl_context {
cl_kernel kernel_diag_mask_inf, kernel_diag_mask_inf_8;
cl_kernel kernel_soft_max, kernel_soft_max_4;
cl_kernel kernel_soft_max_f16, kernel_soft_max_4_f16;
std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f16;
std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f16_q1;
std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32;
std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32_q1;
std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32_f16;
std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32_f16_q1;
std::map<std::pair<int, int>, int> kernels_flash_attn_bm;
std::map<std::pair<int, int>, int> kernels_flash_attn_bn;
cl_kernel kernel_get_rows_f32, kernel_get_rows_f16, kernel_get_rows_q4_0;
cl_kernel kernel_set_rows_f32, kernel_set_rows_f16;
cl_kernel kernel_rope_norm_f32, kernel_rope_norm_f16, kernel_rope_neox_f32, kernel_rope_neox_f16;
Expand Down Expand Up @@ -1263,6 +1272,73 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
GGML_LOG_CONT(".");
}

// flash_attn
{
#ifdef GGML_OPENCL_EMBED_KERNELS
const std::string kernel_src_f16 {
#include "flash_attn_f16.cl.h"
};
const std::string kernel_src_f32 {
#include "flash_attn_f32.cl.h"
};
const std::string kernel_src_f32_f16 {
#include "flash_attn_f32_f16.cl.h"
};
#else
const std::string kernel_src_f16 = read_file("flash_attn_f16.cl");
const std::string kernel_src_f32 = read_file("flash_attn_f32.cl");
const std::string kernel_src_f32_f16 = read_file("flash_attn_f32_f16.cl");
#endif

if (!kernel_src_f16.empty() && !kernel_src_f32.empty() && !kernel_src_f32_f16.empty()) {
const struct { int dk; int dv; int bm; int bn; } fa_dims[] = {
{ 64, 64, 64, 64}, { 80, 80, 64, 32}, { 96, 96, 64, 32},
{112, 112, 32, 32}, {128, 128, 32, 32}, {192, 128, 16, 16},
{192, 192, 16, 16}, {256, 256, 16, 16},
};

for (size_t i = 0; i < sizeof(fa_dims)/sizeof(fa_dims[0]); ++i) {
const int dk = fa_dims[i].dk;
const int dv = fa_dims[i].dv;
const int bm = fa_dims[i].bm;
const int bn = fa_dims[i].bn;
std::string OPTS = compile_opts +
" -D DK=" + std::to_string(dk) +
" -D DV=" + std::to_string(dv) +
" -D BLOCK_M=" + std::to_string(bm) +
" -D BLOCK_N=" + std::to_string(bn);

cl_program prog_f16 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f16.c_str(), OPTS);
cl_kernel k_f16, k_f16_q1;
CL_CHECK((k_f16 = clCreateKernel(prog_f16, "flash_attn_f16", &err), err));
CL_CHECK((k_f16_q1 = clCreateKernel(prog_f16, "flash_attn_f16_q1", &err), err));
backend_ctx->kernels_flash_attn_f16[{dk, dv}] = k_f16;
backend_ctx->kernels_flash_attn_f16_q1[{dk, dv}] = k_f16_q1;
CL_CHECK(clReleaseProgram(prog_f16));

cl_program prog_f32 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f32.c_str(), OPTS);
cl_kernel k_f32, k_f32_q1;
CL_CHECK((k_f32 = clCreateKernel(prog_f32, "flash_attn_f32", &err), err));
CL_CHECK((k_f32_q1 = clCreateKernel(prog_f32, "flash_attn_f32_q1", &err), err));
backend_ctx->kernels_flash_attn_f32[{dk, dv}] = k_f32;
backend_ctx->kernels_flash_attn_f32_q1[{dk, dv}] = k_f32_q1;
CL_CHECK(clReleaseProgram(prog_f32));

cl_program prog_f32_f16 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f32_f16.c_str(), OPTS);
cl_kernel k_f32_f16, k_f32_f16_q1;
CL_CHECK((k_f32_f16 = clCreateKernel(prog_f32_f16, "flash_attn_f32_f16", &err), err));
CL_CHECK((k_f32_f16_q1 = clCreateKernel(prog_f32_f16, "flash_attn_f32_f16_q1", &err), err));
backend_ctx->kernels_flash_attn_f32_f16[{dk, dv}] = k_f32_f16;
backend_ctx->kernels_flash_attn_f32_f16_q1[{dk, dv}] = k_f32_f16_q1;
CL_CHECK(clReleaseProgram(prog_f32_f16));

backend_ctx->kernels_flash_attn_bm[{dk, dv}] = bm;
backend_ctx->kernels_flash_attn_bn[{dk, dv}] = bn;
}
GGML_LOG_CONT(".");
}
}

// argsort
{
#ifdef GGML_OPENCL_EMBED_KERNELS
Expand Down Expand Up @@ -2553,6 +2629,41 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_SUM_ROWS:
return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]);
case GGML_OP_FLASH_ATTN_EXT:
{
const ggml_tensor * q = op->src[0];
const ggml_tensor * k = op->src[1];
const ggml_tensor * v = op->src[2];

const int dk = q->ne[0];
const int dv = v->ne[0];

const struct { int dk; int dv; } supported_dims[] = {
{ 64, 64}, { 80, 80}, { 96, 96},
{112, 112}, {128, 128}, {192, 128},
{192, 192}, {256, 256},
};

bool dims_supported = false;
for (size_t i = 0; i < sizeof(supported_dims)/sizeof(supported_dims[0]); ++i) {
if (supported_dims[i].dk == dk && supported_dims[i].dv == dv) {
dims_supported = true;
break;
}
}
if (!dims_supported) {
return false;
}

const bool is_f32_f32 = q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_F32 &&
v->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
const bool is_f16_f16 = q->type == GGML_TYPE_F16 && k->type == GGML_TYPE_F16 &&
v->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16;
const bool is_f32_f16 = q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_F16 &&
v->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F32;

return is_f32_f32 || is_f16_f16 || is_f32_f16;
}
default:
return false;
}
Expand Down Expand Up @@ -5193,6 +5304,133 @@ static void ggml_cl_timestep_embedding(ggml_backend_t backend, const ggml_tensor
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, NULL, dst);
}

static void ggml_cl_flash_attn(ggml_backend_t backend, const ggml_tensor * q, const ggml_tensor * k, ggml_tensor * dst) {
const ggml_tensor * v = dst->src[2];
const ggml_tensor * mask = dst->src[3];
GGML_ASSERT(q->extra);
GGML_ASSERT(k->extra);
GGML_ASSERT(v->extra);
GGML_ASSERT(dst->extra);
if (mask) {
GGML_ASSERT(mask->extra);
}

ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;

const int n_q = q->ne[1];
const int n_kv = k->ne[1];
const int d_head_q = q->ne[0];
const int d_head_v = v->ne[0];
const int n_head = q->ne[2];
const int n_head_kv = k->ne[2];
const int n_batch = q->ne[3];

cl_kernel kernel = NULL;

const bool is_f16 = q->type == GGML_TYPE_F16;
const bool is_mixed = q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_F16;
const std::pair<int, int> dk_dv = {d_head_q, d_head_v};

if (n_q == 1) {
if (is_mixed) {
kernel = backend_ctx->kernels_flash_attn_f32_f16_q1.at(dk_dv);
} else if (is_f16) {
kernel = backend_ctx->kernels_flash_attn_f16_q1.at(dk_dv);
} else {
kernel = backend_ctx->kernels_flash_attn_f32_q1.at(dk_dv);
}
} else {
if (is_mixed) {
kernel = backend_ctx->kernels_flash_attn_f32_f16.at(dk_dv);
} else if (is_f16) {
kernel = backend_ctx->kernels_flash_attn_f16.at(dk_dv);
} else {
kernel = backend_ctx->kernels_flash_attn_f32.at(dk_dv);
}
}
GGML_ASSERT(kernel != NULL);

ggml_tensor_extra_cl * extra_q = (ggml_tensor_extra_cl *)q->extra;
ggml_tensor_extra_cl * extra_k = (ggml_tensor_extra_cl *)k->extra;
ggml_tensor_extra_cl * extra_v = (ggml_tensor_extra_cl *)v->extra;
ggml_tensor_extra_cl * extra_o = (ggml_tensor_extra_cl *)dst->extra;
ggml_tensor_extra_cl * extra_mask = mask ? (ggml_tensor_extra_cl *)mask->extra : NULL;

cl_ulong offset_q = extra_q->offset + q->view_offs;
cl_ulong offset_k = extra_k->offset + k->view_offs;
cl_ulong offset_v = extra_v->offset + v->view_offs;
cl_ulong offset_o = extra_o->offset + dst->view_offs;
cl_mem mask_buffer = extra_mask ? extra_mask->data_device : NULL;
cl_ulong offset_mask = extra_mask ? extra_mask->offset + mask->view_offs : 0;

const cl_ulong q_nb1 = q->nb[1], q_nb2 = q->nb[2], q_nb3 = q->nb[3];
const cl_ulong k_nb1 = k->nb[1], k_nb2 = k->nb[2], k_nb3 = k->nb[3];
const cl_ulong v_nb1 = v->nb[1], v_nb2 = v->nb[2], v_nb3 = v->nb[3];
const cl_ulong o_nb1 = dst->nb[1], o_nb2 = dst->nb[2], o_nb3 = dst->nb[3];
const cl_ulong mask_nb1 = mask ? mask->nb[1] : 0;
const cl_ulong mask_nb2 = mask ? mask->nb[2] : 0;
const cl_ulong mask_nb3 = mask ? mask->nb[3] : 0;
const int mask_ne2 = mask ? mask->ne[2] : 0;
const int mask_ne3 = mask ? mask->ne[3] : 0;

float scale, max_bias, logit_softcap;
const float * params = (const float *)dst->op_params;
scale = params[0];
max_bias = params[1];
logit_softcap = params[2];

const int is_causal = (mask == NULL && n_q > 1 && n_q == n_kv);

const int n_head_log2_val = n_head > 0 ? 1u << (int)floorf(log2f((float)n_head)) : 0;
const float n_head_log2_f = n_head_log2_val > 0 ? (float)n_head_log2_val : 1.0f;
const float m0 = powf(2.0f, -(max_bias) / n_head_log2_f);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2_f);

CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_q->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset_q));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra_k->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset_k));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra_v->data_device));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset_v));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extra_o->data_device));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offset_o));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(float), &scale));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &n_q));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &n_kv));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &is_causal));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &n_head));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &q_nb1)); CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &q_nb2)); CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &q_nb3));
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &k_nb1)); CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &k_nb2)); CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &k_nb3));
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &v_nb1)); CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &v_nb2)); CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &v_nb3));
CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &o_nb1)); CL_CHECK(clSetKernelArg(kernel, 23, sizeof(cl_ulong), &o_nb2)); CL_CHECK(clSetKernelArg(kernel, 24, sizeof(cl_ulong), &o_nb3));
CL_CHECK(clSetKernelArg(kernel, 25, sizeof(float), &max_bias));
CL_CHECK(clSetKernelArg(kernel, 26, sizeof(float), &m0));
CL_CHECK(clSetKernelArg(kernel, 27, sizeof(float), &m1));
CL_CHECK(clSetKernelArg(kernel, 28, sizeof(int), &n_head_log2_val));
CL_CHECK(clSetKernelArg(kernel, 29, sizeof(float), &logit_softcap));
CL_CHECK(clSetKernelArg(kernel, 30, sizeof(int), &n_head_kv));
CL_CHECK(clSetKernelArg(kernel, 31, sizeof(cl_mem), &mask_buffer));
CL_CHECK(clSetKernelArg(kernel, 32, sizeof(cl_ulong), &offset_mask));
CL_CHECK(clSetKernelArg(kernel, 33, sizeof(cl_ulong), &mask_nb1));
CL_CHECK(clSetKernelArg(kernel, 34, sizeof(cl_ulong), &mask_nb2));
CL_CHECK(clSetKernelArg(kernel, 35, sizeof(cl_ulong), &mask_nb3));
CL_CHECK(clSetKernelArg(kernel, 36, sizeof(int), &mask_ne2));
CL_CHECK(clSetKernelArg(kernel, 37, sizeof(int), &mask_ne3));

if (n_q == 1) {
const size_t wg_size = 64;
size_t local_work_size[] = { wg_size, 1 };
size_t global_work_size[] = { wg_size, (size_t)(n_head * n_batch) };
backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst);
} else {
const int block_m = backend_ctx->kernels_flash_attn_bm.at(dk_dv);
const size_t wg_size = block_m;
size_t local_work_size[] = { wg_size, 1 };
size_t global_work_size[] = { (size_t)((n_q + block_m - 1) / block_m) * wg_size, (size_t)(n_head * n_batch) };
backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst);
}
}

static void ggml_cl_mul_mat_f16_f32_tiled(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;

Expand Down Expand Up @@ -7239,6 +7477,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
}
func = ggml_cl_sum_rows;
break;
case GGML_OP_FLASH_ATTN_EXT:
if (!any_on_device) {
return false;
}
ggml_cl_flash_attn(backend, tensor->src[0], tensor->src[1], tensor);
return true;
default:
return false;
}
Expand Down
Loading
Loading