Skip to content

Commit 1332bcf

Browse files
author
Gitty Burstein
committed
SYCL/SET: implement operator + wire-up; docs/ops updates; element_wise & ggml-sycl changes
1 parent 28c39da commit 1332bcf

File tree

4 files changed

+160
-1
lines changed

4 files changed

+160
-1
lines changed

docs/ops.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ Legend:
8585
| RWKV_WKV6 ||||||||||
8686
| RWKV_WKV7 ||||||||||
8787
| SCALE || 🟡 ||||||||
88-
| SET ||||||| |||
88+
| SET ||||||| |||
8989
| SET_ROWS ||| 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 ||
9090
| SGN |||| 🟡 | 🟡 || 🟡 |||
9191
| SIGMOID |||| 🟡 | 🟡 | 🟡 | 🟡 | 🟡 ||

ggml/src/ggml-sycl/element_wise.cpp

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include "ggml-sycl/presets.hpp"
33
#include "ggml.h"
44
#include "element_wise.hpp"
5+
#include <cstring>
56

67
#define SYCL_GLOBAL_ID_LOOP(K, ITEM) \
78
for (auto i = ITEM.get_global_id(0); i < (size_t)K; i += ITEM.get_global_range(0))
@@ -926,6 +927,135 @@ static inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor
926927
ggml_sycl_detail::pad_sycl(src, dst_ptr, ne00, ne01, ne02, ne0, ne1, ne2, stream);
927928
});
928929
}
930+
static inline void ggml_sycl_op_set(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
931+
const ggml_tensor * src0 = dst->src[0];
932+
GGML_ASSERT(dst->src[1] != nullptr);
933+
const ggml_tensor * src1 = dst->src[1];
934+
935+
GGML_ASSERT(src0->type == dst->type);
936+
GGML_ASSERT(src1->type == dst->type);
937+
#if defined(GGML_SYCL_F16)
938+
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_I32);
939+
#else
940+
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_I32);
941+
#endif
942+
const size_t ts = ggml_type_size(dst->type);
943+
944+
dpct::queue_ptr q = ctx.stream();
945+
{
946+
const bool same_type = (src0->type == dst->type);
947+
const bool src_cont = ggml_is_contiguous(src0);
948+
const bool dst_cont = ggml_is_contiguous(dst);
949+
950+
const void *p_src0 = src0->data;
951+
void *p_dst = dst->data;
952+
953+
auto pt_src0 = sycl::get_pointer_type((const char*)p_src0, q->get_context());
954+
auto pt_dst = sycl::get_pointer_type((char*)p_dst, q->get_context());
955+
956+
if (same_type && src_cont && dst_cont && ggml_nelements(src0) == ggml_nelements(dst)) {
957+
const size_t bytes = ggml_nbytes(dst);
958+
if (pt_src0 != sycl::usm::alloc::unknown && pt_dst != sycl::usm::alloc::unknown) {
959+
SYCL_CHECK(CHECK_TRY_ERROR(q->memcpy(p_dst, p_src0, bytes)));
960+
} else {
961+
std::memcpy(p_dst, p_src0, bytes);
962+
}
963+
} else {
964+
const int64_t ne0 = dst->ne[0], ne1 = dst->ne[1], ne2 = dst->ne[2], ne3 = dst->ne[3];
965+
const size_t db0 = dst->nb[0], db1 = dst->nb[1], db2 = dst->nb[2], db3 = dst->nb[3];
966+
const size_t sb0 = src0->nb[0], sb1 = src0->nb[1], sb2 = src0->nb[2], sb3 = src0->nb[3];
967+
968+
const size_t N = (size_t) ggml_nelements(dst);
969+
const size_t WG = 256;
970+
const size_t NG = ((N + WG - 1) / WG) * WG;
971+
972+
const size_t ge0 = (size_t) ne0;
973+
const size_t ge1 = ge0 * (size_t) ne1;
974+
const size_t ge2 = ge1 * (size_t) ne2;
975+
976+
q->parallel_for(
977+
sycl::nd_range<1>(sycl::range<1>(NG), sycl::range<1>(WG)),
978+
[=](sycl::nd_item<1> it) {
979+
size_t idx = it.get_global_linear_id();
980+
if (idx >= N) return;
981+
982+
size_t i3 = idx / ge2; size_t r2 = idx % ge2;
983+
size_t i2 = r2 / ge1; size_t r1 = r2 % ge1;
984+
size_t i1 = r1 / ge0; size_t i0 = r1 % ge0;
985+
986+
const char * s = (const char*)p_src0 + (i0*sb0 + i1*sb1 + i2*sb2 + i3*sb3);
987+
char * d = (char*)p_dst + (i0*db0 + i1*db1 + i2*db2 + i3*db3);
988+
989+
for (size_t b = 0; b < ts; ++b) d[b] = s[b];
990+
}
991+
);
992+
}
993+
}
994+
995+
{
996+
const int32_t *p = (const int32_t *) dst->op_params;
997+
const size_t nb1 = (size_t) p[0];
998+
const size_t nb2 = (size_t) p[1];
999+
const size_t nb3 = (size_t) p[2];
1000+
const size_t offset = (size_t) p[3];
1001+
1002+
const void *p_src1 = src1->data;
1003+
void *p_dst = dst->data;
1004+
1005+
const size_t sb0 = src1->nb[0], sb1 = src1->nb[1], sb2 = src1->nb[2], sb3 = src1->nb[3];
1006+
const size_t db0 = dst->nb[0];
1007+
const int64_t ne0 = src1->ne[0], ne1 = src1->ne[1], ne2 = src1->ne[2], ne3 = src1->ne[3];
1008+
1009+
if (ggml_is_contiguous(src1) && db0 == ts) {
1010+
const size_t row_bytes = (size_t) ne0 * ts;
1011+
const char *s_base = (const char*) p_src1;
1012+
char *d_base = (char*) p_dst + offset;
1013+
1014+
for (int64_t i3 = 0; i3 < ne3; ++i3) {
1015+
for (int64_t i2 = 0; i2 < ne2; ++i2) {
1016+
for (int64_t i1 = 0; i1 < ne1; ++i1) {
1017+
const char *s_row = s_base + i1*sb1 + i2*sb2 + i3*sb3;
1018+
char *d_row = d_base + i1*nb1 + i2*nb2 + i3*nb3;
1019+
1020+
auto pt_s = sycl::get_pointer_type(s_row, q->get_context());
1021+
auto pt_d = sycl::get_pointer_type(d_row, q->get_context());
1022+
if (pt_s != sycl::usm::alloc::unknown && pt_d != sycl::usm::alloc::unknown) {
1023+
SYCL_CHECK(CHECK_TRY_ERROR(q->memcpy(d_row, s_row, row_bytes)));
1024+
} else {
1025+
std::memcpy(d_row, s_row, row_bytes);
1026+
}
1027+
}
1028+
}
1029+
}
1030+
} else {
1031+
1032+
const size_t N = (size_t) (ne0 * ne1 * ne2 * ne3);
1033+
const size_t WG = 256;
1034+
const size_t NG = ((N + WG - 1) / WG) * WG;
1035+
1036+
const size_t ge0 = (size_t) ne0;
1037+
const size_t ge1 = ge0 * (size_t) ne1;
1038+
const size_t ge2 = ge1 * (size_t) ne2;
1039+
1040+
q->parallel_for(
1041+
sycl::nd_range<1>(sycl::range<1>(NG), sycl::range<1>(WG)),
1042+
[=](sycl::nd_item<1> it) {
1043+
size_t idx = it.get_global_linear_id();
1044+
if (idx >= N) return;
1045+
1046+
size_t i3 = idx / ge2; size_t r2 = idx % ge2;
1047+
size_t i2 = r2 / ge1; size_t r1 = r2 % ge1;
1048+
size_t i1 = r1 / ge0; size_t i0 = r1 % ge0;
1049+
1050+
const char * s = (const char*) p_src1 + (i0*sb0 + i1*sb1 + i2*sb2 + i3*sb3);
1051+
char * d = (char*) p_dst + offset + (i0*db0 + i1*nb1 + i2*nb2 + i3*nb3);
1052+
1053+
for (size_t b = 0; b < ts; ++b) d[b] = s[b];
1054+
}
1055+
);
1056+
}
1057+
}
1058+
}
9291059

9301060
static inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
9311061
float min_val;
@@ -1124,6 +1254,11 @@ void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
11241254
ggml_sycl_op_pad(ctx, dst);
11251255
}
11261256

1257+
// Public entry point for GGML_OP_SET on the SYCL backend: opens a debug-print
// scope covering dst and its two source tensors, then forwards to the
// file-local implementation ggml_sycl_op_set().
void ggml_sycl_set(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    // Scoped object; it stays alive for the duration of the op.
    scope_op_debug_print dbg_scope(__func__, dst, /*num_src=*/2);

    ggml_sycl_op_set(ctx, dst);
}
1261+
11271262
void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
11281263
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
11291264
ggml_sycl_op_clamp(ctx, dst);

ggml/src/ggml-sycl/element_wise.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,4 +83,6 @@ void ggml_sycl_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
8383
void ggml_sycl_geglu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
8484
void ggml_sycl_geglu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
8585

86+
void ggml_sycl_set(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
87+
8688
#endif // GGML_SYCL_ELEMENTWISE_HPP

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3564,6 +3564,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
35643564
case GGML_OP_GET_ROWS:
35653565
ggml_sycl_get_rows(ctx, dst);
35663566
break;
3567+
case GGML_OP_SET:
3568+
ggml_sycl_set(ctx, dst);
3569+
break;
35673570
case GGML_OP_SET_ROWS:
35683571
ggml_sycl_op_set_rows(ctx, dst);
35693572
break;
@@ -4164,6 +4167,25 @@ static ggml_backend_buffer_t ggml_backend_sycl_device_buffer_from_host_ptr(ggml_
41644167

41654168
static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
41664169
switch (op->op) {
4170+
case GGML_OP_SET: {
4171+
#if defined(GGML_SYCL_F16)
4172+
const bool types_ok =
4173+
(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_I32) &&
4174+
(op->src[0]->type == op->type) &&
4175+
(op->src[1] && op->src[1]->type == op->type);
4176+
#else
4177+
const bool types_ok =
4178+
(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_I32) &&
4179+
(op->src[0]->type == op->type) &&
4180+
(op->src[1] && op->src[1]->type == op->type);
4181+
#endif
4182+
4183+
const bool contiguous_ok =
4184+
ggml_is_contiguous(op->src[0]) &&
4185+
(!op->src[1] || ggml_is_contiguous(op->src[1]));
4186+
4187+
return types_ok && contiguous_ok;
4188+
}
41674189
case GGML_OP_CONV_TRANSPOSE_1D:
41684190
{
41694191
ggml_type src0_type = op->src[0]->type;

0 commit comments

Comments
 (0)