
Commit 2b8074d

metal : make the backend async (wip)
ggml-ci
1 parent: 186415d


ggml/src/ggml-metal/ggml-metal.m

Lines changed: 126 additions & 50 deletions
@@ -6016,10 +6016,35 @@ static enum ggml_status ggml_metal_graph_compute(
         }
     }
 
+    // wait for any previous processing
+    // TODO: find a more canonical Metal way to do it
+    //       or maybe we can create a new set of command buffers and avoid the wait here. need to figure out how
+    //       to release old command buffers that have already completed. maybe MTLCommandBufferStatus?
+    {
+        {
+            id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[n_cb].obj;
+            if (cmd_buf) {
+                [cmd_buf waitUntilCompleted];
+                [cmd_buf release];
+                ctx->cmd_bufs[n_cb].obj = nil;
+            }
+        }
+
+        for (int i = 0; i < n_cb; ++i) {
+            id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[i].obj;
+            if (cmd_buf) {
+                [cmd_buf waitUntilCompleted];
+                [cmd_buf release];
+                ctx->cmd_bufs[i].obj = nil;
+            }
+        }
+    }
+
     // the main thread commits the first few commands immediately
     // cmd_buf[n_cb]
     {
         id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
+        [cmd_buf retain];
         ctx->cmd_bufs[n_cb].obj = cmd_buf;
 
         [cmd_buf enqueue];
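
The TODO above asks for a more canonical way to retire old command buffers without a blocking wait. A minimal sketch of the idea it hints at, assuming the same ctx->cmd_bufs layout (this helper is hypothetical, not part of the commit): poll MTLCommandBufferStatus and release only the buffers that have already finished.

    // hypothetical helper: release command buffers that are already done,
    // so the compute path does not need a blanket waitUntilCompleted
    static void ggml_metal_release_completed_cmd_bufs(struct ggml_backend_metal_context * ctx) {
        for (int i = 0; i <= ctx->n_cb; ++i) {
            id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[i].obj;
            if (!cmd_buf) {
                continue;
            }

            const MTLCommandBufferStatus status = [cmd_buf status];
            if (status == MTLCommandBufferStatusCompleted || status == MTLCommandBufferStatusError) {
                // finished (successfully or not), so it is safe to release and clear the slot
                [cmd_buf release];
                ctx->cmd_bufs[i].obj = nil;
            }
        }
    }

Buffers that are still scheduled or running would be left in place and waited on later (e.g. in synchronize).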
@@ -6030,6 +6055,7 @@ static enum ggml_status ggml_metal_graph_compute(
     // cmd_buf[0.. n_cb)
     for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
         id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
+        [cmd_buf retain];
         ctx->cmd_bufs[cb_idx].obj = cmd_buf;
 
         // always enqueue the first two command buffers
@@ -6043,52 +6069,52 @@ static enum ggml_status ggml_metal_graph_compute(
 
     // wait for completion and check status of each command buffer
     // needed to detect if the device ran out-of-memory for example (#1881)
-    {
-        id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[n_cb].obj;
-        [cmd_buf waitUntilCompleted];
-
-        MTLCommandBufferStatus status = [cmd_buf status];
-        if (status != MTLCommandBufferStatusCompleted) {
-            GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status);
-            if (status == MTLCommandBufferStatusError) {
-                GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
-            }
-
-            return GGML_STATUS_FAILED;
-        }
-    }
-
-    for (int i = 0; i < n_cb; ++i) {
-        id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[i].obj;
-        [cmd_buf waitUntilCompleted];
-
-        MTLCommandBufferStatus status = [cmd_buf status];
-        if (status != MTLCommandBufferStatusCompleted) {
-            GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
-            if (status == MTLCommandBufferStatusError) {
-                GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
-            }
-
-            return GGML_STATUS_FAILED;
-        }
-
-        id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->cmd_bufs[i + 1].obj : nil);
-        if (!next_buffer) {
-            continue;
-        }
-
-        const bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued);
-        if (next_queued) {
-            continue;
-        }
-
-        if (ctx->abort_callback && ctx->abort_callback(ctx->abort_callback_data)) {
-            GGML_LOG_INFO("%s: command buffer %d aborted", __func__, i);
-            return GGML_STATUS_ABORTED;
-        }
-
-        [next_buffer commit];
-    }
+    //{
+    //    id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[n_cb].obj;
+    //    [cmd_buf waitUntilCompleted];
+
+    //    MTLCommandBufferStatus status = [cmd_buf status];
+    //    if (status != MTLCommandBufferStatusCompleted) {
+    //        GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status);
+    //        if (status == MTLCommandBufferStatusError) {
+    //            GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
+    //        }
+
+    //        return GGML_STATUS_FAILED;
+    //    }
+    //}
+
+    //for (int i = 0; i < n_cb; ++i) {
+    //    id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[i].obj;
+    //    [cmd_buf waitUntilCompleted];
+
+    //    MTLCommandBufferStatus status = [cmd_buf status];
+    //    if (status != MTLCommandBufferStatusCompleted) {
+    //        GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
+    //        if (status == MTLCommandBufferStatusError) {
+    //            GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
+    //        }
+
+    //        return GGML_STATUS_FAILED;
+    //    }
+
+    //    id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->cmd_bufs[i + 1].obj : nil);
+    //    if (!next_buffer) {
+    //        continue;
+    //    }
+
+    //    const bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued);
+    //    if (next_queued) {
+    //        continue;
+    //    }
+
+    //    if (ctx->abort_callback && ctx->abort_callback(ctx->abort_callback_data)) {
+    //        GGML_LOG_INFO("%s: command buffer %d aborted", __func__, i);
+    //        return GGML_STATUS_ABORTED;
+    //    }
+
+    //    [next_buffer commit];
+    //}
 
     if (!should_capture && ctx->capture_started) {
         [ctx->capture_scope endScope];
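
The block commented out in this hunk is the synchronous wait-and-status-check path that the async rework has to replace. One hedged way to keep the error detection (e.g. the out-of-memory case from #1881) without blocking the caller is to attach a completion handler before committing each command buffer; a rough sketch, not part of this commit, where the failure would have to be recorded in the context and reported later:

    // sketch: per-command-buffer status check that runs on completion
    // instead of after a blocking waitUntilCompleted
    [cmd_buf addCompletedHandler:^(id<MTLCommandBuffer> cb) {
        const MTLCommandBufferStatus status = [cb status];
        if (status != MTLCommandBufferStatusCompleted) {
            GGML_LOG_INFO("ggml_metal_graph_compute: command buffer failed with status %lu\n", status);
            if (status == MTLCommandBufferStatusError) {
                GGML_LOG_INFO("error: %s\n", [[cb error].localizedDescription UTF8String]);
            }
            // a failure flag in the backend context would have to be set here and
            // surfaced from ggml_backend_metal_synchronize(), since there is no
            // caller to return GGML_STATUS_FAILED to at this point
        }
    }];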
@@ -6422,6 +6448,54 @@ static void ggml_backend_metal_free(ggml_backend_t backend) {
     free(backend);
 }
 
+// TODO: tmp impl to make the results good
+//       I think here we have to waitUntilCompleted on all existing MTLCommandBuffers in the backend's context
+static void ggml_backend_metal_synchronize(ggml_backend_t backend) {
+    struct ggml_backend_metal_context * ctx = backend->context;
+
+    const int n_cb = ctx->n_cb;
+
+    {
+        id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[n_cb].obj;
+        if (cmd_buf) {
+            [cmd_buf waitUntilCompleted];
+        }
+    }
+
+    for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
+        id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[cb_idx].obj;
+        if (cmd_buf) {
+            [cmd_buf waitUntilCompleted];
+        }
+    }
+}
+
+static void ggml_backend_metal_set_tensor_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+    // TODO: figure out how Metal does async copies
+    //       I think one way is:
+    //       - wrap the src and dst in MTLBuffers with newBufferWithBytesNoCopy (i.e. views)
+    //       - create an MTLCommandBuffer and encode a copy command in it
+    //       - commit the command buffer to the backend's queue
+    //       - keep the command buffer in the backend context so we can later wait for completion and release it?
+    ggml_backend_metal_synchronize(backend);
+
+    ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
+    ggml_backend_metal_buffer_set_tensor(buf, tensor, data, offset, size);
+}
+
+static void ggml_backend_metal_get_tensor_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+    // TODO: figure out how Metal does async copies (see above)
+    ggml_backend_metal_synchronize(backend);
+
+    ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
+    ggml_backend_metal_buffer_get_tensor(buf, tensor, data, offset, size);
+}
+
+//static bool ggml_backend_metal_cpy_tensor_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const struct ggml_tensor * src, struct ggml_tensor * dst) {
+//    // TODO: figure out how Metal does async copies (see above)
+//    return true;
+//}
+
 static enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
     return ggml_metal_graph_compute(backend, cgraph);
 }
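
The TODO inside ggml_backend_metal_set_tensor_async spells out the intended async copy in prose. A rough sketch of that recipe using a blit pass follows; the helper and its parameters are hypothetical, and device/queue handling plus command-buffer bookkeeping are left out:

    // hypothetical sketch of an async host -> device upload via a blit command
    static id<MTLCommandBuffer> ggml_metal_upload_async_sketch(
            id<MTLDevice> device, id<MTLCommandQueue> queue,
            id<MTLBuffer> dst, size_t dst_offset,
            const void * data, size_t size) {
        // stage the host data in a shared buffer; newBufferWithBytesNoCopy would avoid
        // this extra copy, but it requires page-aligned memory and a page-multiple length
        id<MTLBuffer> src = [device newBufferWithBytes:data length:size options:MTLResourceStorageModeShared];

        id<MTLCommandBuffer>      cmd_buf = [queue commandBuffer];
        id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder];

        [encoder copyFromBuffer:src sourceOffset:0 toBuffer:dst destinationOffset:dst_offset size:size];
        [encoder endEncoding];

        [cmd_buf commit];

        // returned so the backend context can keep it, wait on it in synchronize()
        // and release it afterwards, as the TODO suggests
        return cmd_buf;
    }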
@@ -6502,15 +6576,17 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
 static struct ggml_backend_i ggml_backend_metal_i = {
     /* .get_name           = */ ggml_backend_metal_name,
     /* .free               = */ ggml_backend_metal_free,
-    /* .set_tensor_async   = */ NULL,
-    /* .get_tensor_async   = */ NULL,
-    /* .cpy_tensor_async   = */ NULL,
-    /* .synchronize        = */ NULL,
+    /* .set_tensor_async   = */ ggml_backend_metal_set_tensor_async,
+    /* .get_tensor_async   = */ ggml_backend_metal_get_tensor_async,
+    /* .cpy_tensor_async   = */ /*ggml_backend_metal_cpy_tensor_async*/ NULL, // TODO
+    /* .synchronize        = */ ggml_backend_metal_synchronize,
     /* .graph_plan_create  = */ NULL,
     /* .graph_plan_free    = */ NULL,
     /* .graph_plan_update  = */ NULL,
     /* .graph_plan_compute = */ NULL,
     /* .graph_compute      = */ ggml_backend_metal_graph_compute,
+
+    // TODO: https://developer.apple.com/documentation/metal/mtlcommandbuffer#Synchronizing-Passes-with-Events
     /* .event_record       = */ NULL,
     /* .event_wait         = */ NULL,
 };
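
The TODO next to the interface table points at Metal events as the eventual basis for .event_record / .event_wait. A minimal sketch of that mechanism (not wired into ggml here; device and queue are assumed to be the backend's MTLDevice and MTLCommandQueue), ordering two command buffers with an MTLSharedEvent:

    // sketch: cross-command-buffer ordering with a shared event
    id<MTLSharedEvent> event = [device newSharedEvent];

    id<MTLCommandBuffer> producer = [queue commandBuffer];
    // ... encode work that must finish first ...
    [producer encodeSignalEvent:event value:1];   // roughly what .event_record would do
    [producer commit];

    id<MTLCommandBuffer> consumer = [queue commandBuffer];
    [consumer encodeWaitForEvent:event value:1];  // roughly what .event_wait would do
    // ... encode work that depends on the producer ...
    [consumer commit];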
