@@ -6016,10 +6016,35 @@ static enum ggml_status ggml_metal_graph_compute(
     }
 }
 
+    // wait for any previous processing
+    // TODO: find a more canonical Metal way to do it
+    //       or maybe we can create a new set of command buffers and avoid the wait here. need to figure out how
+    //       to release old command buffers that have already completed. maybe MTLCommandBufferStatus?
+    {
+        {
+            id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[n_cb].obj;
+            if (cmd_buf) {
+                [cmd_buf waitUntilCompleted];
+                [cmd_buf release];
+                ctx->cmd_bufs[n_cb].obj = nil;
+            }
+        }
+
+        for (int i = 0; i < n_cb; ++i) {
+            id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[i].obj;
+            if (cmd_buf) {
+                [cmd_buf waitUntilCompleted];
+                [cmd_buf release];
+                ctx->cmd_bufs[i].obj = nil;
+            }
+        }
+    }
+
     // the main thread commits the first few commands immediately
     // cmd_buf[n_cb]
     {
         id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
+        [cmd_buf retain];
 
         ctx->cmd_bufs[n_cb].obj = cmd_buf;
 
         [cmd_buf enqueue];
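
The TODO at the top of this hunk asks for a more canonical alternative to the blocking wait, possibly via MTLCommandBufferStatus. A minimal sketch of that idea, written as a hypothetical helper that is not part of this patch, would block only when the buffer is still in flight:

    // hypothetical helper (not in the patch): free a slot's command buffer,
    // waiting only if the GPU has not finished it yet; completed, errored, or
    // never-enqueued buffers are released directly
    static void ggml_metal_release_cmd_buf(struct ggml_backend_metal_context * ctx, int idx) {
        id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[idx].obj;
        if (cmd_buf == nil) {
            return;
        }

        const MTLCommandBufferStatus status = [cmd_buf status];
        if (status == MTLCommandBufferStatusEnqueued  ||
            status == MTLCommandBufferStatusCommitted ||
            status == MTLCommandBufferStatusScheduled) {
            // still in flight - a wait is unavoidable for this slot
            [cmd_buf waitUntilCompleted];
        }

        [cmd_buf release];
        ctx->cmd_bufs[idx].obj = nil;
    }
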
@@ -6030,6 +6055,7 @@ static enum ggml_status ggml_metal_graph_compute(
     // cmd_buf[0..n_cb)
     for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
         id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
+        [cmd_buf retain];
 
         ctx->cmd_bufs[cb_idx].obj = cmd_buf;
 
         // always enqueue the first two command buffers
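
For context, the explicit `retain` added in this hunk and the previous one is needed because `commandBufferWithUnretainedReferences` returns an autoreleased object, and the buffer is now stored in `ctx->cmd_bufs` across `graph_compute` calls. The ownership lifecycle implied by the patch, shown in one place for illustration:

    // illustrative only - lifecycle of a stored command buffer under this patch
    id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences]; // autoreleased
    [cmd_buf retain];                     // keep it alive past the autorelease pool
    ctx->cmd_bufs[cb_idx].obj = cmd_buf;  // stored across graph_compute calls
    // ... on the next graph_compute, in the wait block at the top:
    [cmd_buf waitUntilCompleted];
    [cmd_buf release];                    // balances the retain above
    ctx->cmd_bufs[cb_idx].obj = nil;
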
@@ -6043,52 +6069,52 @@ static enum ggml_status ggml_metal_graph_compute(
 
     // wait for completion and check status of each command buffer
     // needed to detect if the device ran out-of-memory for example (#1881)
-    {
-        id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[n_cb].obj;
-        [cmd_buf waitUntilCompleted];
-
-        MTLCommandBufferStatus status = [cmd_buf status];
-        if (status != MTLCommandBufferStatusCompleted) {
-            GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status);
-            if (status == MTLCommandBufferStatusError) {
-                GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
-            }
-
-            return GGML_STATUS_FAILED;
-        }
-    }
-
-    for (int i = 0; i < n_cb; ++i) {
-        id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[i].obj;
-        [cmd_buf waitUntilCompleted];
-
-        MTLCommandBufferStatus status = [cmd_buf status];
-        if (status != MTLCommandBufferStatusCompleted) {
-            GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
-            if (status == MTLCommandBufferStatusError) {
-                GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
-            }
-
-            return GGML_STATUS_FAILED;
-        }
-
-        id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->cmd_bufs[i + 1].obj : nil);
-        if (!next_buffer) {
-            continue;
-        }
-
-        const bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued);
-        if (next_queued) {
-            continue;
-        }
-
-        if (ctx->abort_callback && ctx->abort_callback(ctx->abort_callback_data)) {
-            GGML_LOG_INFO("%s: command buffer %d aborted", __func__, i);
-            return GGML_STATUS_ABORTED;
-        }
-
-        [next_buffer commit];
-    }
+    // {
+    //     id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[n_cb].obj;
+    //     [cmd_buf waitUntilCompleted];
+
+    //     MTLCommandBufferStatus status = [cmd_buf status];
+    //     if (status != MTLCommandBufferStatusCompleted) {
+    //         GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status);
+    //         if (status == MTLCommandBufferStatusError) {
+    //             GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
+    //         }
+
+    //         return GGML_STATUS_FAILED;
+    //     }
+    // }
+
+    // for (int i = 0; i < n_cb; ++i) {
+    //     id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[i].obj;
+    //     [cmd_buf waitUntilCompleted];
+
+    //     MTLCommandBufferStatus status = [cmd_buf status];
+    //     if (status != MTLCommandBufferStatusCompleted) {
+    //         GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
+    //         if (status == MTLCommandBufferStatusError) {
+    //             GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
+    //         }
+
+    //         return GGML_STATUS_FAILED;
+    //     }
+
+    //     id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->cmd_bufs[i + 1].obj : nil);
+    //     if (!next_buffer) {
+    //         continue;
+    //     }
+
+    //     const bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued);
+    //     if (next_queued) {
+    //         continue;
+    //     }
+
+    //     if (ctx->abort_callback && ctx->abort_callback(ctx->abort_callback_data)) {
+    //         GGML_LOG_INFO("%s: command buffer %d aborted", __func__, i);
+    //         return GGML_STATUS_ABORTED;
+    //     }
+
+    //     [next_buffer commit];
+    // }
 
     if (!should_capture && ctx->capture_started) {
         [ctx->capture_scope endScope];
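
With the blocking status check commented out above, an out-of-memory failure (#1881) would currently go unreported from `graph_compute`. One possible alternative, not implemented in this patch, is to attach a completion handler before each commit. Note the handler must be registered before `[cmd_buf commit]`, and it runs on a Metal-owned thread, so it can only log the failure asynchronously rather than return `GGML_STATUS_FAILED`:

    // sketch: asynchronous status reporting instead of waitUntilCompleted
    [cmd_buf addCompletedHandler:^(id<MTLCommandBuffer> cb) {
        const MTLCommandBufferStatus status = [cb status];
        if (status != MTLCommandBufferStatusCompleted) {
            GGML_LOG_INFO("%s: command buffer failed with status %lu\n", __func__, status);
            if (status == MTLCommandBufferStatusError) {
                GGML_LOG_INFO("error: %s\n", [[cb error].localizedDescription UTF8String]);
            }
        }
    }];
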
@@ -6422,6 +6448,54 @@ static void ggml_backend_metal_free(ggml_backend_t backend) {
     free(backend);
 }
 
+// TODO: temporary implementation to get correct results
+//       I think here we have to waitUntilCompleted on all existing MTLCommandBuffers in the backend's context
+static void ggml_backend_metal_synchronize(ggml_backend_t backend) {
+    struct ggml_backend_metal_context * ctx = backend->context;
+
+    const int n_cb = ctx->n_cb;
+
+    {
+        id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[n_cb].obj;
+        if (cmd_buf) {
+            [cmd_buf waitUntilCompleted];
+        }
+    }
+
+    for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
+        id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[cb_idx].obj;
+        if (cmd_buf) {
+            [cmd_buf waitUntilCompleted];
+        }
+    }
+}
+
+static void ggml_backend_metal_set_tensor_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+    // TODO: figure out how Metal does async copies
+    //       I think one way is:
+    //       - wrap the src and dst in MTLBuffers with newBufferWithBytesNoCopy (i.e. views)
+    //       - create an MTLCommandBuffer and encode a copy command in it
+    //       - commit the command buffer to the backend's queue
+    //       - keep the command buffer in the backend context so we can later wait for completion and release it?
+    ggml_backend_metal_synchronize(backend);
+
+    ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
+    ggml_backend_metal_buffer_set_tensor(buf, tensor, data, offset, size);
+}
+
+static void ggml_backend_metal_get_tensor_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+    // TODO: figure out how Metal does async copies (see above)
+    ggml_backend_metal_synchronize(backend);
+
+    ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
+    ggml_backend_metal_buffer_get_tensor(buf, tensor, data, offset, size);
+}
+
+// static bool ggml_backend_metal_cpy_tensor_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const struct ggml_tensor * src, struct ggml_tensor * dst) {
+//     // TODO: figure out how Metal does async copies (see above)
+//     return true;
+// }
+
 static enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
     return ggml_metal_graph_compute(backend, cgraph);
 }
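
The TODO in `ggml_backend_metal_set_tensor_async` lists the steps for a truly asynchronous upload. A rough sketch of those steps follows; it assumes `data` and `size` satisfy the page-alignment rules of `newBufferWithBytesNoCopy`, and the names `device`, `queue`, `dst_buf`, and `dst_off` are stand-ins for whatever the backend context and the tensor's backing buffer actually provide:

    // wrap the host pointer in a zero-copy MTLBuffer view (page alignment required)
    id<MTLBuffer> src_buf = [device newBufferWithBytesNoCopy:(void *) data
                                                      length:size
                                                     options:MTLResourceStorageModeShared
                                                 deallocator:nil];

    // encode a GPU-side copy into the tensor's backing buffer
    id<MTLCommandBuffer> cmd_buf = [queue commandBuffer];
    id<MTLBlitCommandEncoder> blit = [cmd_buf blitCommandEncoder];
    [blit copyFromBuffer:src_buf sourceOffset:0
                toBuffer:dst_buf destinationOffset:dst_off
                    size:size];
    [blit endEncoding];
    [cmd_buf commit];

    // per the TODO: keep cmd_buf (and src_buf) in the backend context, then
    // waitUntilCompleted and release them in ggml_backend_metal_synchronize
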
@@ -6502,15 +6576,17 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
 static struct ggml_backend_i ggml_backend_metal_i = {
     /* .get_name                = */ ggml_backend_metal_name,
     /* .free                    = */ ggml_backend_metal_free,
-    /* .set_tensor_async        = */ NULL,
-    /* .get_tensor_async        = */ NULL,
-    /* .cpy_tensor_async        = */ NULL,
-    /* .synchronize             = */ NULL,
+    /* .set_tensor_async        = */ ggml_backend_metal_set_tensor_async,
+    /* .get_tensor_async        = */ /* ggml_backend_metal_cpy_tensor_async */ ggml_backend_metal_get_tensor_async,
+    /* .cpy_tensor_async        = */ NULL, // TODO
+    /* .synchronize             = */ ggml_backend_metal_synchronize,
     /* .graph_plan_create       = */ NULL,
     /* .graph_plan_free         = */ NULL,
     /* .graph_plan_update       = */ NULL,
     /* .graph_plan_compute      = */ NULL,
     /* .graph_compute           = */ ggml_backend_metal_graph_compute,
+
+    // TODO: https://developer.apple.com/documentation/metal/mtlcommandbuffer#Synchronizing-Passes-with-Events
     /* .event_record            = */ NULL,
     /* .event_wait              = */ NULL,
 };
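
On the event TODO: the linked page describes synchronizing passes with MTLEvent/MTLSharedEvent, which would be the natural backing for the still-NULL `event_record`/`event_wait` hooks. A sketch of the signal/wait pairing, where `event`, `value`, and the two command buffers are assumed stand-ins rather than anything in this patch:

    // one-time setup, e.g. when creating a ggml event object
    id<MTLSharedEvent> event = [device newSharedEvent];

    // event_record: signal `value` once the producer's encoded work finishes
    [cmd_buf_producer encodeSignalEvent:event value:value];
    [cmd_buf_producer commit];

    // event_wait: commands encoded after this point wait for the signal
    [cmd_buf_consumer encodeWaitForEvent:event value:value];
    [cmd_buf_consumer commit];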