@@ -5661,8 +5661,12 @@ static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, siz
     ggml_vk_queue_command_pools_cleanup(dst->device);
 }
 
-static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, uint32_t m, uint32_t n, uint32_t k, const vk_pipeline& pipeline) {
-    VK_LOG_DEBUG("ggml_vk_guess_split_k(" << m << ", " << n << ", " << k << ")");
+static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, uint32_t m, uint32_t n, uint32_t k, bool disable_split_k, const vk_pipeline& pipeline) {
+    VK_LOG_DEBUG("ggml_vk_guess_split_k(" << m << ", " << n << ", " << k << ", " << disable_split_k << ")");
+
+    if (disable_split_k) {
+        return 1;
+    }
 
     uint32_t split_k = 1;
     if (ctx->device->shader_core_count != 0 && m >= pipeline->wg_denoms[0] && n >= pipeline->wg_denoms[1]) {
@@ -5987,7 +5991,7 @@ static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& sub
     ggml_vk_sync_buffers(ctx, subctx);
 }
 
-static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool disable_split_k, bool dryrun = false) {
     VK_LOG_DEBUG("ggml_vk_mul_mat_q_f16((" << src0 << ", name=" << src0->name << ", type=" << ggml_type_name(src0->type) << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
     std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << ggml_type_name(src1->type) << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
     std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << ggml_type_name(dst->type) << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
@@ -6005,8 +6009,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
     const uint64_t ne12 = src1->ne[2];
     const uint64_t ne13 = src1->ne[3];
 
-    const uint64_t ne20 = dst->ne[0];
     const uint64_t ne21 = dst->ne[1];
+    const uint32_t stride_d = dst->nb[1] / ggml_type_size(dst->type);
+    const uint32_t stride_batch_d = stride_d*ne21;
 
     const uint64_t r2 = ne12 / ne02;
     const uint64_t r3 = ne13 / ne03;
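The two added strides drop the assumption that dst is laid out contiguously: the row stride is now derived from the byte stride dst->nb[1] instead of ne01, so the output can be a strided view. A minimal standalone sketch of that derivation, using the hypothetical names dst_strides and elem_size in place of the ggml helpers:

#include <cassert>
#include <cstdint>

// Derive element strides for the output from its byte strides, mirroring the
// stride_d / stride_batch_d computation above (illustrative only).
static void dst_strides(uint64_t nb1_bytes,   // byte stride between rows of dst (dst->nb[1])
                        uint64_t ne1,         // rows per dst matrix (dst->ne[1])
                        uint64_t elem_size,   // bytes per element (ggml_type_size(dst->type))
                        uint32_t & stride_d, uint32_t & stride_batch_d) {
    assert(nb1_bytes % elem_size == 0);                   // rows must start on an element boundary
    stride_d       = (uint32_t)(nb1_bytes / elem_size);   // elements between consecutive rows
    stride_batch_d = (uint32_t)(stride_d * ne1);          // elements between consecutive matrices in the batch
}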
@@ -6075,7 +6080,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
     const int y_ne = padded_n * ne10;
     const int d_ne = ne11 * ne01;
 
-    const uint32_t split_k = ggml_vk_guess_split_k(ctx, ne01, ne11, ne10, pipeline);
+    const uint32_t split_k = ggml_vk_guess_split_k(ctx, ne01, ne11, ne10, disable_split_k, pipeline);
 
     const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
     const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
@@ -6234,13 +6239,16 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
         y_sz_total = CEIL_DIV(y_sz_total, 144) * 144;
     }
 
+    // No bounds checking is needed for dst. This is basically VK_WHOLE_SIZE but clamped to maxStorageBufferRange.
+    VkDeviceSize d_range = std::min(VkDeviceSize{d_D->size - d_buf_offset}, VkDeviceSize{ctx->device->properties.limits.maxStorageBufferRange});
+
     // compute
     ggml_vk_matmul(
         ctx, subctx, pipeline,
         { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz_total },
-        { d_D, d_buf_offset, d_sz * ne12 * ne13 }, { ctx->prealloc_split_k, 0, d_sz * ne12 * ne13 * split_k },
+        { d_D, d_buf_offset, d_range }, { ctx->prealloc_split_k, 0, d_sz * ne12 * ne13 * split_k },
         ne01, ne11, ne10,
-        ne10, ne10, ne01, stride_batch_x, stride_batch_y, ne20*ne21,
+        ne10, ne10, stride_d, stride_batch_x, stride_batch_y, stride_batch_d,
         split_k, ne12*ne13, ne02, ne12, r2, r3, padded_n
     );  // NOLINT
 
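The new d_range binds as much of the dst buffer as the device permits rather than the exact output size, which is what the comment means by "VK_WHOLE_SIZE but clamped to maxStorageBufferRange". A minimal sketch of that clamp under assumed names (clamped_range, buffer_size, bind_offset, max_storage_range):

#include <algorithm>
#include <cstdint>

// Range of a storage-buffer binding starting at bind_offset: everything up to the end
// of the buffer, but never more than the device's maxStorageBufferRange (illustrative only).
static uint64_t clamped_range(uint64_t buffer_size, uint64_t bind_offset, uint64_t max_storage_range) {
    const uint64_t remaining = buffer_size - bind_offset;   // caller guarantees bind_offset <= buffer_size
    return std::min(remaining, max_storage_range);
}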
@@ -6718,9 +6726,36 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
         { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, pc, { (uint32_t)ne03, (uint32_t)ne01, (uint32_t)ne12 });
 }
 
-static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
     VK_LOG_DEBUG("ggml_vk_mul_mat(" << src0 << ", " << src1 << ", " << dst << ")");
-    if (src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && dst->ne[1] == 1 &&
+
+    // Handle huge A matrix by splitting the M dimensions. This works well for convolution use cases
+    // where the M dimension is very large.
+    // Split_k doesn't work with M splitting.
+    const size_t nbytes = ggml_nbytes(src0);
+    const bool needs_split = nbytes > ctx->device->properties.limits.maxStorageBufferRange;
+    if (needs_split) {
+        // Choose the number of rows that can fit (and divide by two, to allow for any additional offsets)
+        const uint32_t M_split = ctx->device->properties.limits.maxStorageBufferRange / (2 * src0->nb[1]);
+        uint32_t m_offset = 0;
+        while (m_offset < dst->ne[0]) {
+            const uint32_t cur_M_size = std::min(M_split, (uint32_t)(dst->ne[0] - m_offset));
+            ggml_tensor dst2 = *dst;
+            ggml_tensor src02 = *src0;
+
+            dst2.view_src = dst->view_src ? dst->view_src : dst;
+            src02.view_src = src0->view_src ? src0->view_src : src0;
+
+            dst2.view_offs += m_offset * dst->nb[0];
+            src02.view_offs += m_offset * src0->nb[1];
+            dst2.ne[0] = cur_M_size;
+            src02.ne[1] = cur_M_size;
+
+            ggml_vk_mul_mat_q_f16(ctx, subctx, &src02, src1, &dst2, true, dryrun);
+
+            m_offset += cur_M_size;
+        }
+    } else if (src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && dst->ne[1] == 1 &&
         // detect 0213 permutation, and batch size of 1
         src0->nb[0] <= src0->nb[2] &&
         src0->nb[2] <= src0->nb[1] &&
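The new branch carves the M dimension into row slices small enough that each slice of A stays within maxStorageBufferRange, dispatching one matmul per slice with split_k disabled. A self-contained sketch of just the chunking arithmetic, under the hypothetical names split_m_dimension and row_stride_bytes and with made-up example sizes:

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Walk the M dimension in chunks that fit a single storage-buffer binding (illustrative only).
static void split_m_dimension(uint32_t M, uint64_t row_stride_bytes, uint64_t max_storage_range) {
    // Rows per chunk, halved to leave headroom for additional offsets, as in the patch.
    const uint32_t M_split = (uint32_t)(max_storage_range / (2 * row_stride_bytes));
    uint32_t m_offset = 0;
    while (m_offset < M) {
        const uint32_t cur_M = std::min(M_split, M - m_offset);
        // A real implementation would build views of src0/dst over rows
        // [m_offset, m_offset + cur_M) here and dispatch one matmul per chunk.
        std::printf("chunk: rows [%u, %u)\n", m_offset, m_offset + cur_M);
        m_offset += cur_M;
    }
}

int main() {
    // Example: 300000 rows, 16 KiB per row, ~4 GiB maxStorageBufferRange -> three chunks.
    split_m_dimension(300000u, 16384u, 0xFFFFFFFFull);
    return 0;
}

The last chunk simply takes whatever rows remain, so no padding of dst is required.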
@@ -6740,7 +6775,7 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, c
                (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || ggml_is_quantized(src0->type))) {
         ggml_vk_mul_mat_vec_q_f16(ctx, subctx, src0, src1, dst, dryrun);
     } else {
-        ggml_vk_mul_mat_q_f16(ctx, subctx, src0, src1, dst, dryrun);
+        ggml_vk_mul_mat_q_f16(ctx, subctx, src0, src1, dst, false, dryrun);
     }
 }
 
@@ -10675,10 +10710,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     VK_LOG_DEBUG("ggml_vk_build_graph(" << node << ", " << ggml_op_name(node->op) << ")");
     ctx->semaphore_idx = 0;
 
-    const ggml_tensor * src0 = node->src[0];
-    const ggml_tensor * src1 = node->src[1];
-    const ggml_tensor * src2 = node->src[2];
-    const ggml_tensor * src3 = node->src[3];
+    ggml_tensor * src0 = node->src[0];
+    ggml_tensor * src1 = node->src[1];
+    ggml_tensor * src2 = node->src[2];
+    ggml_tensor * src3 = node->src[3];
 
     switch (node->op) {
     // Return on empty ops to avoid generating a compute_ctx and setting exit_tensor