OpenCL: add initial FA support #14987

Open · wants to merge 4 commits into base: master
3 changes: 3 additions & 0 deletions ggml/src/ggml-opencl/CMakeLists.txt
@@ -109,6 +109,9 @@ set(GGML_OPENCL_KERNELS
mul_mat_f16_f32
conv2d
conv2d_f16_f32
flash_attn_f32_f16
flash_attn_f16
flash_attn_f32
)

foreach (K ${GGML_OPENCL_KERNELS})
244 changes: 244 additions & 0 deletions ggml/src/ggml-opencl/ggml-opencl.cpp
@@ -25,6 +25,7 @@
#include <vector>
#include <string>
#include <cmath>
#include <map>
#include <memory>
#include <charconv>
#include <mutex>
@@ -420,6 +421,14 @@ struct ggml_backend_opencl_context {
cl_kernel kernel_diag_mask_inf, kernel_diag_mask_inf_8;
cl_kernel kernel_soft_max, kernel_soft_max_4;
cl_kernel kernel_soft_max_f16, kernel_soft_max_4_f16;
std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f16;
std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f16_q1;
std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32;
std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32_q1;
std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32_f16;
std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32_f16_q1;
std::map<std::pair<int, int>, int> kernels_flash_attn_bm;
std::map<std::pair<int, int>, int> kernels_flash_attn_bn;
cl_kernel kernel_get_rows_f32, kernel_get_rows_f16, kernel_get_rows_q4_0;
cl_kernel kernel_set_rows_f32, kernel_set_rows_f16;
cl_kernel kernel_rope_norm_f32, kernel_rope_norm_f16, kernel_rope_neox_f32, kernel_rope_neox_f16;
@@ -1263,6 +1272,73 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
GGML_LOG_CONT(".");
}

// flash_attn
{
#ifdef GGML_OPENCL_EMBED_KERNELS
const std::string kernel_src_f16 {
#include "flash_attn_f16.cl.h"
};
const std::string kernel_src_f32 {
#include "flash_attn_f32.cl.h"
};
const std::string kernel_src_f32_f16 {
#include "flash_attn_f32_f16.cl.h"
};
#else
const std::string kernel_src_f16 = read_file("flash_attn_f16.cl");
const std::string kernel_src_f32 = read_file("flash_attn_f32.cl");
const std::string kernel_src_f32_f16 = read_file("flash_attn_f32_f16.cl");
#endif

if (!kernel_src_f16.empty() && !kernel_src_f32.empty() && !kernel_src_f32_f16.empty()) {
const struct { int dk; int dv; int bm; int bn; } fa_dims[] = {
{ 64, 64, 64, 64}, { 80, 80, 64, 32}, { 96, 96, 64, 32},
{112, 112, 32, 32}, {128, 128, 32, 32}, {192, 128, 16, 16},
{192, 192, 16, 16}, {256, 256, 16, 16},
};

for (size_t i = 0; i < sizeof(fa_dims)/sizeof(fa_dims[0]); ++i) {
const int dk = fa_dims[i].dk;
const int dv = fa_dims[i].dv;
const int bm = fa_dims[i].bm;
const int bn = fa_dims[i].bn;
std::string OPTS = compile_opts +
" -D DK=" + std::to_string(dk) +
" -D DV=" + std::to_string(dv) +
" -D BLOCK_M=" + std::to_string(bm) +
" -D BLOCK_N=" + std::to_string(bn);

cl_program prog_f16 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f16.c_str(), OPTS);
cl_kernel k_f16, k_f16_q1;
CL_CHECK((k_f16 = clCreateKernel(prog_f16, "flash_attn_f16", &err), err));
CL_CHECK((k_f16_q1 = clCreateKernel(prog_f16, "flash_attn_f16_q1", &err), err));
backend_ctx->kernels_flash_attn_f16[{dk, dv}] = k_f16;
backend_ctx->kernels_flash_attn_f16_q1[{dk, dv}] = k_f16_q1;
CL_CHECK(clReleaseProgram(prog_f16));

cl_program prog_f32 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f32.c_str(), OPTS);
cl_kernel k_f32, k_f32_q1;
CL_CHECK((k_f32 = clCreateKernel(prog_f32, "flash_attn_f32", &err), err));
CL_CHECK((k_f32_q1 = clCreateKernel(prog_f32, "flash_attn_f32_q1", &err), err));
backend_ctx->kernels_flash_attn_f32[{dk, dv}] = k_f32;
backend_ctx->kernels_flash_attn_f32_q1[{dk, dv}] = k_f32_q1;
CL_CHECK(clReleaseProgram(prog_f32));

cl_program prog_f32_f16 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f32_f16.c_str(), OPTS);
cl_kernel k_f32_f16, k_f32_f16_q1;
CL_CHECK((k_f32_f16 = clCreateKernel(prog_f32_f16, "flash_attn_f32_f16", &err), err));
CL_CHECK((k_f32_f16_q1 = clCreateKernel(prog_f32_f16, "flash_attn_f32_f16_q1", &err), err));
backend_ctx->kernels_flash_attn_f32_f16[{dk, dv}] = k_f32_f16;
backend_ctx->kernels_flash_attn_f32_f16_q1[{dk, dv}] = k_f32_f16_q1;
CL_CHECK(clReleaseProgram(prog_f32_f16));

backend_ctx->kernels_flash_attn_bm[{dk, dv}] = bm;
backend_ctx->kernels_flash_attn_bn[{dk, dv}] = bn;
}
GGML_LOG_CONT(".");
}
}
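
> **Note (reviewer sketch, not part of the diff):** each `fa_dims` entry is built into its own OpenCL program because DK, DV and the BLOCK_M/BLOCK_N tile sizes are baked in as compile-time `-D` defines; the resulting kernels and tile sizes are then all keyed by the same `{dk, dv}` pair. A minimal C++ sketch of the option string the loop appends for one entry (`fa_build_opts` is a hypothetical helper, the real code builds the string inline):
>
> ```cpp
> #include <string>
>
> // Hypothetical helper, for illustration only: reproduces the per-entry option
> // string appended to compile_opts in the loop above.
> static std::string fa_build_opts(int dk, int dv, int bm, int bn) {
>     return " -D DK="      + std::to_string(dk) +
>            " -D DV="      + std::to_string(dv) +
>            " -D BLOCK_M=" + std::to_string(bm) +
>            " -D BLOCK_N=" + std::to_string(bn);
> }
>
> // fa_build_opts(128, 128, 32, 32)
> //   -> " -D DK=128 -D DV=128 -D BLOCK_M=32 -D BLOCK_N=32"
> ```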

// argsort
{
#ifdef GGML_OPENCL_EMBED_KERNELS
@@ -2553,6 +2629,41 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_SUM_ROWS:
return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]);
case GGML_OP_FLASH_ATTN_EXT:
{
const ggml_tensor * q = op->src[0];
const ggml_tensor * k = op->src[1];
const ggml_tensor * v = op->src[2];

const int dk = q->ne[0];
const int dv = v->ne[0];

const struct { int dk; int dv; } supported_dims[] = {
{ 64, 64}, { 80, 80}, { 96, 96},
{112, 112}, {128, 128}, {192, 128},
{192, 192}, {256, 256},
};

bool dims_supported = false;
for (size_t i = 0; i < sizeof(supported_dims)/sizeof(supported_dims[0]); ++i) {
if (supported_dims[i].dk == dk && supported_dims[i].dv == dv) {
dims_supported = true;
break;
}
}
if (!dims_supported) {
return false;
}

const bool is_f32_f32 = q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_F32 &&
v->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
const bool is_f16_f16 = q->type == GGML_TYPE_F16 && k->type == GGML_TYPE_F16 &&
v->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16;
const bool is_f32_f16 = q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_F16 &&
v->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F32;

return is_f32_f32 || is_f16_f16 || is_f32_f16;
}
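
> **Note (illustrative, not part of the diff):** the check above accepts eight (DK, DV) head-dimension pairs and three Q/K/V/dst type combinations (all-F32, all-F16, and F32 Q with F16 K/V). A minimal sketch of a graph node that would pass it, assuming a valid `ggml_context * ctx`, `<math.h>`, and the public `ggml_flash_attn_ext()` API from ggml.h; the sizes are example values only:
>
> ```cpp
> // DK = DV = 128, Q in F32 with K/V in F16 (the "is_f32_f16" combination), no mask.
> // Shapes: q [DK, n_q, n_head, n_batch], k [DK, n_kv, n_head_kv, n_batch],
> //         v [DV, n_kv, n_head_kv, n_batch].
> struct ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 128, 32,  8, 1);
> struct ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 128, 256, 8, 1);
> struct ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 128, 256, 8, 1);
>
> struct ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, /*mask=*/NULL,
>                                                /*scale=*/1.0f / sqrtf(128.0f),
>                                                /*max_bias=*/0.0f,
>                                                /*logit_softcap=*/0.0f);
> ```
>
> A node with an unsupported head size (e.g. DK = 72) fails the `supported_dims` scan and returns false, leaving the op to a backend that does support it.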
default:
return false;
}
@@ -5193,6 +5304,133 @@ static void ggml_cl_timestep_embedding(ggml_backend_t backend, const ggml_tensor
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, NULL, dst);
}

static void ggml_cl_flash_attn(ggml_backend_t backend, const ggml_tensor * q, const ggml_tensor * k, ggml_tensor * dst) {
const ggml_tensor * v = dst->src[2];
const ggml_tensor * mask = dst->src[3];
GGML_ASSERT(q->extra);
GGML_ASSERT(k->extra);
GGML_ASSERT(v->extra);
GGML_ASSERT(dst->extra);
if (mask) {
GGML_ASSERT(mask->extra);
}

ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;

const int n_q = q->ne[1];
const int n_kv = k->ne[1];
const int d_head_q = q->ne[0];
const int d_head_v = v->ne[0];
const int n_head = q->ne[2];
const int n_head_kv = k->ne[2];
const int n_batch = q->ne[3];

cl_kernel kernel = NULL;

const bool is_f16 = q->type == GGML_TYPE_F16;
const bool is_mixed = q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_F16;
const std::pair<int, int> dk_dv = {d_head_q, d_head_v};

if (n_q == 1) {
if (is_mixed) {
kernel = backend_ctx->kernels_flash_attn_f32_f16_q1.at(dk_dv);
} else if (is_f16) {
kernel = backend_ctx->kernels_flash_attn_f16_q1.at(dk_dv);
} else {
kernel = backend_ctx->kernels_flash_attn_f32_q1.at(dk_dv);
}
} else {
if (is_mixed) {
kernel = backend_ctx->kernels_flash_attn_f32_f16.at(dk_dv);
} else if (is_f16) {
kernel = backend_ctx->kernels_flash_attn_f16.at(dk_dv);
} else {
kernel = backend_ctx->kernels_flash_attn_f32.at(dk_dv);
}
}
GGML_ASSERT(kernel != NULL);

ggml_tensor_extra_cl * extra_q = (ggml_tensor_extra_cl *)q->extra;
ggml_tensor_extra_cl * extra_k = (ggml_tensor_extra_cl *)k->extra;
ggml_tensor_extra_cl * extra_v = (ggml_tensor_extra_cl *)v->extra;
ggml_tensor_extra_cl * extra_o = (ggml_tensor_extra_cl *)dst->extra;
ggml_tensor_extra_cl * extra_mask = mask ? (ggml_tensor_extra_cl *)mask->extra : NULL;

cl_ulong offset_q = extra_q->offset + q->view_offs;
cl_ulong offset_k = extra_k->offset + k->view_offs;
cl_ulong offset_v = extra_v->offset + v->view_offs;
cl_ulong offset_o = extra_o->offset + dst->view_offs;
cl_mem mask_buffer = extra_mask ? extra_mask->data_device : NULL;
cl_ulong offset_mask = extra_mask ? extra_mask->offset + mask->view_offs : 0;

const cl_ulong q_nb1 = q->nb[1], q_nb2 = q->nb[2], q_nb3 = q->nb[3];
const cl_ulong k_nb1 = k->nb[1], k_nb2 = k->nb[2], k_nb3 = k->nb[3];
const cl_ulong v_nb1 = v->nb[1], v_nb2 = v->nb[2], v_nb3 = v->nb[3];
const cl_ulong o_nb1 = dst->nb[1], o_nb2 = dst->nb[2], o_nb3 = dst->nb[3];
const cl_ulong mask_nb1 = mask ? mask->nb[1] : 0;
const cl_ulong mask_nb2 = mask ? mask->nb[2] : 0;
const cl_ulong mask_nb3 = mask ? mask->nb[3] : 0;
const int mask_ne2 = mask ? mask->ne[2] : 0;
const int mask_ne3 = mask ? mask->ne[3] : 0;

float scale, max_bias, logit_softcap;
const float * params = (const float *)dst->op_params;
scale = params[0];
max_bias = params[1];
logit_softcap = params[2];

const int is_causal = (mask == NULL && n_q > 1 && n_q == n_kv);

const int n_head_log2_val = n_head > 0 ? 1u << (int)floorf(log2f((float)n_head)) : 0;
const float n_head_log2_f = n_head_log2_val > 0 ? (float)n_head_log2_val : 1.0f;
const float m0 = powf(2.0f, -(max_bias) / n_head_log2_f);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2_f);

CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_q->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset_q));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra_k->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset_k));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra_v->data_device));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset_v));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extra_o->data_device));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offset_o));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(float), &scale));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &n_q));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &n_kv));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &is_causal));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &n_head));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &q_nb1)); CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &q_nb2)); CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &q_nb3));
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &k_nb1)); CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &k_nb2)); CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &k_nb3));
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &v_nb1)); CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &v_nb2)); CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &v_nb3));
CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &o_nb1)); CL_CHECK(clSetKernelArg(kernel, 23, sizeof(cl_ulong), &o_nb2)); CL_CHECK(clSetKernelArg(kernel, 24, sizeof(cl_ulong), &o_nb3));
CL_CHECK(clSetKernelArg(kernel, 25, sizeof(float), &max_bias));
CL_CHECK(clSetKernelArg(kernel, 26, sizeof(float), &m0));
CL_CHECK(clSetKernelArg(kernel, 27, sizeof(float), &m1));
CL_CHECK(clSetKernelArg(kernel, 28, sizeof(int), &n_head_log2_val));
CL_CHECK(clSetKernelArg(kernel, 29, sizeof(float), &logit_softcap));
CL_CHECK(clSetKernelArg(kernel, 30, sizeof(int), &n_head_kv));
CL_CHECK(clSetKernelArg(kernel, 31, sizeof(cl_mem), &mask_buffer));
CL_CHECK(clSetKernelArg(kernel, 32, sizeof(cl_ulong), &offset_mask));
CL_CHECK(clSetKernelArg(kernel, 33, sizeof(cl_ulong), &mask_nb1));
CL_CHECK(clSetKernelArg(kernel, 34, sizeof(cl_ulong), &mask_nb2));
CL_CHECK(clSetKernelArg(kernel, 35, sizeof(cl_ulong), &mask_nb3));
CL_CHECK(clSetKernelArg(kernel, 36, sizeof(int), &mask_ne2));
CL_CHECK(clSetKernelArg(kernel, 37, sizeof(int), &mask_ne3));

if (n_q == 1) {
const size_t wg_size = 64;
size_t local_work_size[] = { wg_size, 1 };
size_t global_work_size[] = { wg_size, (size_t)(n_head * n_batch) };
backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst);
} else {
const int block_m = backend_ctx->kernels_flash_attn_bm.at(dk_dv);
const size_t wg_size = block_m;
size_t local_work_size[] = { wg_size, 1 };
size_t global_work_size[] = { (size_t)((n_q + block_m - 1) / block_m) * wg_size, (size_t)(n_head * n_batch) };
backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst);
}
}
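
> **Note (worked example, not part of the diff):** for the `n_q > 1` prefill path, one work-group of `BLOCK_M` work-items covers `BLOCK_M` query rows for a single (head, batch) pair, while the `n_q == 1` decode path launches a fixed 64-wide group per (head, batch) pair. A small standalone C++ sketch of the NDRange arithmetic, with illustrative sizes:
>
> ```cpp
> #include <cstdio>
>
> int main() {
>     // Illustrative values: DK = DV = 128 selects BLOCK_M = 32 from the fa_dims table.
>     const int n_q = 512, n_head = 32, n_batch = 1;
>     const int block_m = 32;                               // kernels_flash_attn_bm[{128,128}]
>
>     const size_t wg     = (size_t) block_m;               // local size along dim 0
>     const size_t groups = (size_t)((n_q + block_m - 1) / block_m);
>     const size_t gws[2] = { groups * wg, (size_t)(n_head * n_batch) };
>     const size_t lws[2] = { wg, 1 };
>
>     // -> global = {512, 32}, local = {32, 1}:
>     //    one 32-wide work-group per (query tile, head*batch) pair.
>     std::printf("global = {%zu, %zu}, local = {%zu, %zu}\n", gws[0], gws[1], lws[0], lws[1]);
>     return 0;
> }
> ```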

static void ggml_cl_mul_mat_f16_f32_tiled(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;

@@ -7239,6 +7477,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
}
func = ggml_cl_sum_rows;
break;
case GGML_OP_FLASH_ATTN_EXT:
if (!any_on_device) {
return false;
}
ggml_cl_flash_attn(backend, tensor->src[0], tensor->src[1], tensor);
return true;
default:
return false;
}