|
2 | 2 | #include "ggml-sycl/presets.hpp"
|
3 | 3 | #include "ggml.h"
|
4 | 4 | #include "element_wise.hpp"
|
| 5 | +#include <cstring> |
5 | 6 |
|
// Loop header that declares `i`, initialised to ITEM.get_global_id(0), runs while
// `i < (size_t)(K)` and advances `i` by ITEM.get_global_range(0) on every iteration.
// Both arguments are parenthesised so that expression arguments (e.g. `a + b`,
// `cond ? x : y`, `*item_ptr`) are evaluated as a whole before the cast / member access.
#define SYCL_GLOBAL_ID_LOOP(K, ITEM) \
    for (auto i = (ITEM).get_global_id(0); i < (size_t)(K); i += (ITEM).get_global_range(0))
|
@@ -926,6 +927,135 @@ static inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor
|
926 | 927 | ggml_sycl_detail::pad_sycl(src, dst_ptr, ne00, ne01, ne02, ne0, ne1, ne2, stream);
|
927 | 928 | });
|
928 | 929 | }
|
| 930 | +static inline void ggml_sycl_op_set(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { |
| 931 | + const ggml_tensor * src0 = dst->src[0]; |
| 932 | + GGML_ASSERT(dst->src[1] != nullptr); |
| 933 | + const ggml_tensor * src1 = dst->src[1]; |
| 934 | + |
| 935 | + GGML_ASSERT(src0->type == dst->type); |
| 936 | + GGML_ASSERT(src1->type == dst->type); |
| 937 | +#if defined(GGML_SYCL_F16) |
| 938 | + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_I32); |
| 939 | +#else |
| 940 | + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_I32); |
| 941 | +#endif |
| 942 | + const size_t ts = ggml_type_size(dst->type); |
| 943 | + |
| 944 | + dpct::queue_ptr q = ctx.stream(); |
| 945 | + { |
| 946 | + const bool same_type = (src0->type == dst->type); |
| 947 | + const bool src_cont = ggml_is_contiguous(src0); |
| 948 | + const bool dst_cont = ggml_is_contiguous(dst); |
| 949 | + |
| 950 | + const void *p_src0 = src0->data; |
| 951 | + void *p_dst = dst->data; |
| 952 | + |
| 953 | + auto pt_src0 = sycl::get_pointer_type((const char*)p_src0, q->get_context()); |
| 954 | + auto pt_dst = sycl::get_pointer_type((char*)p_dst, q->get_context()); |
| 955 | + |
| 956 | + if (same_type && src_cont && dst_cont && ggml_nelements(src0) == ggml_nelements(dst)) { |
| 957 | + const size_t bytes = ggml_nbytes(dst); |
| 958 | + if (pt_src0 != sycl::usm::alloc::unknown && pt_dst != sycl::usm::alloc::unknown) { |
| 959 | + SYCL_CHECK(CHECK_TRY_ERROR(q->memcpy(p_dst, p_src0, bytes))); |
| 960 | + } else { |
| 961 | + std::memcpy(p_dst, p_src0, bytes); |
| 962 | + } |
| 963 | + } else { |
| 964 | + const int64_t ne0 = dst->ne[0], ne1 = dst->ne[1], ne2 = dst->ne[2], ne3 = dst->ne[3]; |
| 965 | + const size_t db0 = dst->nb[0], db1 = dst->nb[1], db2 = dst->nb[2], db3 = dst->nb[3]; |
| 966 | + const size_t sb0 = src0->nb[0], sb1 = src0->nb[1], sb2 = src0->nb[2], sb3 = src0->nb[3]; |
| 967 | + |
| 968 | + const size_t N = (size_t) ggml_nelements(dst); |
| 969 | + const size_t WG = 256; |
| 970 | + const size_t NG = ((N + WG - 1) / WG) * WG; |
| 971 | + |
| 972 | + const size_t ge0 = (size_t) ne0; |
| 973 | + const size_t ge1 = ge0 * (size_t) ne1; |
| 974 | + const size_t ge2 = ge1 * (size_t) ne2; |
| 975 | + |
| 976 | + q->parallel_for( |
| 977 | + sycl::nd_range<1>(sycl::range<1>(NG), sycl::range<1>(WG)), |
| 978 | + [=](sycl::nd_item<1> it) { |
| 979 | + size_t idx = it.get_global_linear_id(); |
| 980 | + if (idx >= N) return; |
| 981 | + |
| 982 | + size_t i3 = idx / ge2; size_t r2 = idx % ge2; |
| 983 | + size_t i2 = r2 / ge1; size_t r1 = r2 % ge1; |
| 984 | + size_t i1 = r1 / ge0; size_t i0 = r1 % ge0; |
| 985 | + |
| 986 | + const char * s = (const char*)p_src0 + (i0*sb0 + i1*sb1 + i2*sb2 + i3*sb3); |
| 987 | + char * d = (char*)p_dst + (i0*db0 + i1*db1 + i2*db2 + i3*db3); |
| 988 | + |
| 989 | + for (size_t b = 0; b < ts; ++b) d[b] = s[b]; |
| 990 | + } |
| 991 | + ); |
| 992 | + } |
| 993 | + } |
| 994 | + |
| 995 | + { |
| 996 | + const int32_t *p = (const int32_t *) dst->op_params; |
| 997 | + const size_t nb1 = (size_t) p[0]; |
| 998 | + const size_t nb2 = (size_t) p[1]; |
| 999 | + const size_t nb3 = (size_t) p[2]; |
| 1000 | + const size_t offset = (size_t) p[3]; |
| 1001 | + |
| 1002 | + const void *p_src1 = src1->data; |
| 1003 | + void *p_dst = dst->data; |
| 1004 | + |
| 1005 | + const size_t sb0 = src1->nb[0], sb1 = src1->nb[1], sb2 = src1->nb[2], sb3 = src1->nb[3]; |
| 1006 | + const size_t db0 = dst->nb[0]; |
| 1007 | + const int64_t ne0 = src1->ne[0], ne1 = src1->ne[1], ne2 = src1->ne[2], ne3 = src1->ne[3]; |
| 1008 | + |
| 1009 | + if (ggml_is_contiguous(src1) && db0 == ts) { |
| 1010 | + const size_t row_bytes = (size_t) ne0 * ts; |
| 1011 | + const char *s_base = (const char*) p_src1; |
| 1012 | + char *d_base = (char*) p_dst + offset; |
| 1013 | + |
| 1014 | + for (int64_t i3 = 0; i3 < ne3; ++i3) { |
| 1015 | + for (int64_t i2 = 0; i2 < ne2; ++i2) { |
| 1016 | + for (int64_t i1 = 0; i1 < ne1; ++i1) { |
| 1017 | + const char *s_row = s_base + i1*sb1 + i2*sb2 + i3*sb3; |
| 1018 | + char *d_row = d_base + i1*nb1 + i2*nb2 + i3*nb3; |
| 1019 | + |
| 1020 | + auto pt_s = sycl::get_pointer_type(s_row, q->get_context()); |
| 1021 | + auto pt_d = sycl::get_pointer_type(d_row, q->get_context()); |
| 1022 | + if (pt_s != sycl::usm::alloc::unknown && pt_d != sycl::usm::alloc::unknown) { |
| 1023 | + SYCL_CHECK(CHECK_TRY_ERROR(q->memcpy(d_row, s_row, row_bytes))); |
| 1024 | + } else { |
| 1025 | + std::memcpy(d_row, s_row, row_bytes); |
| 1026 | + } |
| 1027 | + } |
| 1028 | + } |
| 1029 | + } |
| 1030 | + } else { |
| 1031 | + |
| 1032 | + const size_t N = (size_t) (ne0 * ne1 * ne2 * ne3); |
| 1033 | + const size_t WG = 256; |
| 1034 | + const size_t NG = ((N + WG - 1) / WG) * WG; |
| 1035 | + |
| 1036 | + const size_t ge0 = (size_t) ne0; |
| 1037 | + const size_t ge1 = ge0 * (size_t) ne1; |
| 1038 | + const size_t ge2 = ge1 * (size_t) ne2; |
| 1039 | + |
| 1040 | + q->parallel_for( |
| 1041 | + sycl::nd_range<1>(sycl::range<1>(NG), sycl::range<1>(WG)), |
| 1042 | + [=](sycl::nd_item<1> it) { |
| 1043 | + size_t idx = it.get_global_linear_id(); |
| 1044 | + if (idx >= N) return; |
| 1045 | + |
| 1046 | + size_t i3 = idx / ge2; size_t r2 = idx % ge2; |
| 1047 | + size_t i2 = r2 / ge1; size_t r1 = r2 % ge1; |
| 1048 | + size_t i1 = r1 / ge0; size_t i0 = r1 % ge0; |
| 1049 | + |
| 1050 | + const char * s = (const char*) p_src1 + (i0*sb0 + i1*sb1 + i2*sb2 + i3*sb3); |
| 1051 | + char * d = (char*) p_dst + offset + (i0*db0 + i1*nb1 + i2*nb2 + i3*nb3); |
| 1052 | + |
| 1053 | + for (size_t b = 0; b < ts; ++b) d[b] = s[b]; |
| 1054 | + } |
| 1055 | + ); |
| 1056 | + } |
| 1057 | + } |
| 1058 | +} |
929 | 1059 |
|
930 | 1060 | static inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
931 | 1061 | float min_val;
|
@@ -1124,6 +1254,11 @@ void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
1124 | 1254 | ggml_sycl_op_pad(ctx, dst);
|
1125 | 1255 | }
|
1126 | 1256 |
|
| 1257 | +void ggml_sycl_set(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { |
| 1258 | + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); |
| 1259 | + ggml_sycl_op_set(ctx, dst); |
| 1260 | +} |
| 1261 | + |
1127 | 1262 | void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
1128 | 1263 | scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
|
1129 | 1264 | ggml_sycl_op_clamp(ctx, dst);
|
|
0 commit comments