3737#include " ggml-backend-impl.h"
3838
3939#include " ggml-sycl/backend.hpp"
40+ #include " ggml-sycl/common.hpp"
4041#include " ggml-sycl/presets.hpp"
4142#include " ggml-sycl/gemm.hpp"
4243#include " ggml-sycl/sycl_hw.hpp"
@@ -490,6 +491,23 @@ catch (sycl::exception const &exc) {
490491 std::exit (1 );
491492}
492493
494+ static void ggml_backend_sycl_buffer_memset_tensor (ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value,
495+ size_t offset, size_t size) {
496+ GGML_SYCL_DEBUG (" [SYCL] call %s\n " , __func__);
497+ ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *) buffer->context ;
498+ SYCL_CHECK (ggml_sycl_set_device (ctx->device ));
499+ auto stream = &(dpct::dev_mgr::instance ().get_device (ctx->device ).default_queue ());
500+ if (size == 0 ) {
501+ return ; // Nothing to do
502+ }
503+ if (tensor->data == nullptr ) {
504+ GGML_ABORT (" Error: Tensor data pointer is null.\n " );
505+ }
506+ void * target_ptr = static_cast <char *>(tensor->data ) + offset;
507+ SYCL_CHECK (CHECK_TRY_ERROR ((*stream).memset (target_ptr, value, size)));
508+ SYCL_CHECK (CHECK_TRY_ERROR ((*stream).wait ()));
509+ }
510+
493511static void ggml_backend_sycl_buffer_reset (ggml_backend_buffer_t buffer) {
494512 GGML_SYCL_DEBUG (" [SYCL] call %s\n " , __func__);
495513 if (buffer == nullptr ) {
@@ -510,7 +528,7 @@ static const ggml_backend_buffer_i ggml_backend_sycl_buffer_interface = {
510528 /* .free_buffer = */ ggml_backend_sycl_buffer_free_buffer,
511529 /* .get_base = */ ggml_backend_sycl_buffer_get_base,
512530 /* .init_tensor = */ ggml_backend_sycl_buffer_init_tensor,
513- /* .memset_tensor = */ NULL ,
531+ /* .memset_tensor = */ ggml_backend_sycl_buffer_memset_tensor ,
514532 /* .set_tensor = */ ggml_backend_sycl_buffer_set_tensor,
515533 /* .get_tensor = */ ggml_backend_sycl_buffer_get_tensor,
516534 /* .cpy_tensor = */ ggml_backend_sycl_buffer_cpy_tensor,
0 commit comments