99#include " intel_gpu/primitives/scaled_dot_product_attention.hpp"
1010
// Mirrors the ENABLE_ONEDNN_FOR_GPU build flag as a 0/1 value so that it can
// be tested with a plain `#if OV_GPU_WITH_ONEDNN`.
#ifdef ENABLE_ONEDNN_FOR_GPU
#    define OV_GPU_WITH_ONEDNN 1
#else
#    define OV_GPU_WITH_ONEDNN 0
#endif
1616
// SYCL support defaults to "off" unless OV_GPU_WITH_SYCL was already defined
// before this point.
#if !defined(OV_GPU_WITH_SYCL)
#    define OV_GPU_WITH_SYCL 0
#endif
2020
// Backend switches that are unconditionally enabled in this header.
#define OV_GPU_WITH_OCL    1
#define OV_GPU_WITH_COMMON 1
#define OV_GPU_WITH_CPU    1
#define OV_GPU_WITH_CM     1
2525
// Variadic-argument helpers.
//
// COUNT(...)   expands to the number of arguments passed; the result is only
//              meaningful for 1 to 5 arguments, because COUNT_N selects the
//              sixth entry of "<args>, 5, 4, 3, 2, 1".
// CAT(a, b)    pastes two tokens into one. The operands are NOT macro-expanded
//              first, which is why callers route through another macro when
//              they need an expanded value pasted.
// EXPAND(N)    identity macro; forces one more expansion pass over N.
//
// NOTE: there must be no whitespace between a macro name and its '(' —
// otherwise the preprocessor defines an object-like macro instead.
#define COUNT_N(_1, _2, _3, _4, _5, N, ...) N
#define COUNT(...)                          EXPAND(COUNT_N(__VA_ARGS__, 5, 4, 3, 2, 1))
#define CAT(a, b)                           a##b

#define EXPAND(N) N
3131
// Each IMPL_TYPE_<BACKEND>_<D|S> expands to the two-argument list
// "<impl type>, <shape type>" (D = dynamic_shape, S = static_shape).
// The impl type is deliberately left unqualified: INSTANTIATE_1 prefixes the
// pasted macro name with `cldnn::`, which ends up qualifying the first token.
#define IMPL_TYPE_CPU_D    impl_types::cpu, cldnn::shape_types::dynamic_shape
#define IMPL_TYPE_CPU_S    impl_types::cpu, cldnn::shape_types::static_shape
#define IMPL_TYPE_OCL_D    impl_types::ocl, cldnn::shape_types::dynamic_shape
#define IMPL_TYPE_OCL_S    impl_types::ocl, cldnn::shape_types::static_shape
#define IMPL_TYPE_COMMON_D impl_types::common, cldnn::shape_types::dynamic_shape
#define IMPL_TYPE_COMMON_S impl_types::common, cldnn::shape_types::static_shape
3838
// INSTANTIATE(prim, S1[, S2[, S3[, S4]]]) expands to a comma-separated list of
//     cldnn::implementation_map<cldnn::prim>::get(<IMPL_TYPE_Sx expansion>)
// calls, one per suffix, in the order the suffixes were given.
//
// INSTANTIATE_1 pastes IMPL_TYPE_ and the suffix and prefixes the result with
// `cldnn::`; INSTANTIATE_2..4 peel one suffix at a time. FOR_EACH_ picks the
// INSTANTIATE_<N> variant from the argument count produced by COUNT, so at
// most four suffixes are supported.
#define INSTANTIATE_1(prim, suffix)      cldnn::implementation_map<cldnn::prim>::get(cldnn::CAT(IMPL_TYPE_, suffix))
#define INSTANTIATE_2(prim, suffix, ...) INSTANTIATE_1(prim, suffix), INSTANTIATE_1(prim, __VA_ARGS__)
#define INSTANTIATE_3(prim, suffix, ...) INSTANTIATE_1(prim, suffix), INSTANTIATE_2(prim, __VA_ARGS__)
#define INSTANTIATE_4(prim, suffix, ...) INSTANTIATE_1(prim, suffix), INSTANTIATE_3(prim, __VA_ARGS__)

#define FOR_EACH_(N, prim, ...) EXPAND(CAT(INSTANTIATE_, N)(prim, __VA_ARGS__))
#define INSTANTIATE(prim, ...)  EXPAND(FOR_EACH_(COUNT(__VA_ARGS__), prim, __VA_ARGS__))
4646
// Building blocks for implementation-list entries. Every entry macro ends with
// a trailing comma so entries can simply be written one after another inside a
// braced initializer list.
//
// CREATE_INSTANCE(Type, args...)  -> std::make_shared<Type>(args...),
// GET_INSTANCE(Type, args...)     -> cldnn::implementation_map<cldnn::Type>::get(args...)
//                                    (no trailing comma; used by the macros below)
#define CREATE_INSTANCE(Type, ...) std::make_shared<Type>(__VA_ARGS__),
#define GET_INSTANCE(Type, ...)    cldnn::implementation_map<cldnn::Type>::get(__VA_ARGS__)

// OV_GPU_GET_INSTANCE_1: entry that is the manager returned by GET_INSTANCE.
// OV_GPU_GET_INSTANCE_2: entry that is a new ImplementationManagerLegacy built
// from the raw pointer of the manager returned by GET_INSTANCE plus
// verify_callback.
// NOTE(review): the dynamic_pointer_cast result is dereferenced via .get()
// without a null check — confirm the ImplementationManagerLegacy constructor
// tolerates (or never receives) a null pointer.
#define OV_GPU_GET_INSTANCE_1(prim, impl_type, shape_types) GET_INSTANCE(prim, impl_type, shape_types),
#define OV_GPU_GET_INSTANCE_2(prim, impl_type, shape_types, verify_callback) \
    std::make_shared<cldnn::ImplementationManagerLegacy<cldnn::prim>>( \
        std::dynamic_pointer_cast<cldnn::ImplementationManagerLegacy<cldnn::prim>>(GET_INSTANCE(prim, impl_type, shape_types)).get(), verify_callback),

// Dispatches to OV_GPU_GET_INSTANCE_<N>, forwarding the remaining arguments.
#define SELECT(N, ...) EXPAND(CAT(OV_GPU_GET_INSTANCE_, N)(__VA_ARGS__))
5756
7473#endif
7574
// OCL backend entry macros. When OV_GPU_WITH_OCL is off both macros expand to
// nothing, so entries written with them vanish from the implementation list.
// OV_GPU_GET_INSTANCE_OCL(prim, shape[, verify_callback]) selects
// OV_GPU_GET_INSTANCE_1 or _2 from the number of trailing arguments.
#if OV_GPU_WITH_OCL
#    define OV_GPU_CREATE_INSTANCE_OCL(...) EXPAND(CREATE_INSTANCE(__VA_ARGS__))
#    define OV_GPU_GET_INSTANCE_OCL(prim, ...) EXPAND(SELECT(COUNT(__VA_ARGS__), prim, impl_types::ocl, __VA_ARGS__))
#else
#    define OV_GPU_CREATE_INSTANCE_OCL(...)
#    define OV_GPU_GET_INSTANCE_OCL(...)
#endif
8382
8483#if OV_GPU_WITH_COMMON
// Enabled-branch entry macros for the common backend. The getter must look up
// impl_types::common (matching IMPL_TYPE_COMMON_D / IMPL_TYPE_COMMON_S); it
// previously passed impl_types::ocl, which would fetch the OCL implementation.
#    define OV_GPU_CREATE_INSTANCE_COMMON(...) EXPAND(CREATE_INSTANCE(__VA_ARGS__))
#    define OV_GPU_GET_INSTANCE_COMMON(prim, ...) EXPAND(SELECT(COUNT(__VA_ARGS__), prim, impl_types::common, __VA_ARGS__))
8786#else
8887# define OV_GPU_CREATE_INSTANCE_COMMON (...)
9594# define OV_GPU_GET_INSTANCE_CPU (...)
9695#endif
9796
// REGISTER_DEFAULT_IMPLS(prim, S1[, S2...])
//
// Forward-declares cldnn::prim and fully specializes
// ov::intel_gpu::Registry<cldnn::prim>. The specialization's
// get_implementations() returns a function-local static vector whose elements
// are produced by INSTANTIATE(prim, S1, ...), i.e. one implementation_map
// lookup per IMPL_TYPE_ suffix, in the order given.
// The invocation site supplies the terminating ';'.
#define REGISTER_DEFAULT_IMPLS(prim, ...) \
    namespace cldnn { struct prim; } \
    template <> struct ov::intel_gpu::Registry<cldnn::prim> { \
        static const std::vector<std::shared_ptr<cldnn::ImplementationManager>>& get_implementations() { \
            static const std::vector<std::shared_ptr<cldnn::ImplementationManager>> impls = { \
                INSTANTIATE(prim, __VA_ARGS__) \
            }; \
            return impls; \
        } \
    }
109107
// REGISTER_IMPLS(prim)
//
// Forward-declares cldnn::prim and fully specializes
// ov::intel_gpu::Registry<cldnn::prim>, but only DECLARES
// get_implementations(); the definition has to be provided elsewhere.
// The invocation site supplies the terminating ';'.
#define REGISTER_IMPLS(prim) \
    namespace cldnn { struct prim; } \
    template <> struct ov::intel_gpu::Registry<cldnn::prim> { \
        static const std::vector<std::shared_ptr<cldnn::ImplementationManager>>& get_implementations(); \
    }
118113
@@ -121,7 +116,7 @@ namespace ov::intel_gpu {
// Global list of implementations for a given primitive type.
// The list must be sorted by implementation priority.
// The same implementation may appear multiple times with different configurations.
124- template <typename PrimitiveType>
119+ template <typename PrimitiveType>
125120struct Registry {
126121 static const std::vector<std::shared_ptr<cldnn::ImplementationManager>>& get_implementations () {
127122 static_assert (cldnn::meta::always_false<PrimitiveType>::value, " Only specialization instantiations are allowed" );
@@ -165,6 +160,7 @@ REGISTER_IMPLS(scaled_dot_product_attention);
165160REGISTER_IMPLS (scatter_update);
166161REGISTER_IMPLS (scatter_elements_update);
167162REGISTER_IMPLS (scatter_nd_update);
163+ REGISTER_IMPLS (slice_scatter);
168164REGISTER_IMPLS (softmax);
169165REGISTER_IMPLS (shape_of);
170166REGISTER_IMPLS (strided_slice);
@@ -224,7 +220,6 @@ REGISTER_DEFAULT_IMPLS(roi_pooling, OCL_S);
224220REGISTER_DEFAULT_IMPLS (roll, OCL_S);
225221REGISTER_DEFAULT_IMPLS (shuffle_channels, OCL_S);
226222REGISTER_DEFAULT_IMPLS (slice, OCL_S, OCL_D);
227- REGISTER_IMPLS (slice_scatter);
228223REGISTER_DEFAULT_IMPLS (space_to_batch, OCL_S);
229224REGISTER_DEFAULT_IMPLS (space_to_depth, OCL_S);
230225REGISTER_DEFAULT_IMPLS (swiglu, OCL_S, OCL_D);
0 commit comments