77#include " executors/x64/interpolate.hpp"
88#include " executors/common/interpolate.hpp"
99
10- #include < cpu/x64/xbyak/xbyak.h>
1110
12- #include < algorithm>
1311#include < cassert>
14- #include < cmath>
1512#include < common/c_types_map.hpp>
16- #include < common/primitive_attr.hpp>
1713#include < common/primitive_hashing_utils.hpp>
18- #include < common/utils.hpp>
1914#include < cpu/x64/cpu_isa_traits.hpp>
2015#include < cstddef>
2116#include < cstdint>
2419#include < oneapi/dnnl/dnnl.hpp>
2520#include < oneapi/dnnl/dnnl_common.hpp>
2621#include < string>
27- #include < unordered_map>
2822#include < utility>
2923#include < vector>
3024
3125#include " common/cpu_memcpy.h"
32- #include " cpu/x64/injectors/jit_uni_depthwise_injector.hpp"
3326#include " cpu/x64/injectors/jit_uni_eltwise_injector.hpp"
34- #include " cpu/x64/injectors/jit_uni_quantization_injector.hpp"
3527#include " cpu/x64/jit_generator.hpp"
3628#include " cpu_types.h"
3729#include " dnnl_extension_utils.h"
3830#include " eltwise.h"
39- #include " emitters/plugin/x64/jit_emitter.hpp"
4031#include " emitters/plugin/x64/jit_load_store_emitters.hpp"
4132#include " fake_quantize.h"
4233#include " graph_context.h"
4334#include " memory_desc/cpu_memory_desc.h"
4435#include " node.h"
4536#include " nodes/common/blocked_desc_creator.h"
4637#include " nodes/executors/executor.hpp"
47- #include " nodes/executors/interpolate.hpp"
4838#include " nodes/executors/interpolate_list.hpp"
4939#include " nodes/node_config.h"
5040#include " onednn/iml_type_mapper.h"
5848#include " openvino/op/interpolate.hpp"
5949#include " shape_inference/shape_inference.hpp"
6050#include " shape_inference/shape_inference_cpu.hpp"
61- #include " utils/bfloat16.hpp"
6251#include " utils/general_utils.h"
6352#include " utils/ngraph_utils.hpp"
6453#include " utils/precision_support.h"
@@ -115,9 +104,15 @@ using ngInterpShapeCalcMode = ov::op::v4::Interpolate::ShapeCalcMode;
115104
116105bool Interpolate::isSupportedOperation (const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept {
117106 try {
107+ constexpr size_t DATA_ID = 0 ;
108+ constexpr size_t SCALES_ID = 2 ;
109+ constexpr size_t AXES_ID = 3 ;
110+ constexpr size_t SIZE_OR_SCALE_ID_V11 = 1 ;
111+ constexpr size_t AXES_ID_V11 = 2 ;
112+
118113 if (const auto interp = ov::as_type_ptr<const ov::op::v4::Interpolate>(op)) {
119- const auto & interpAttr = interp->get_attrs ();
120- const auto & interpMode = interpAttr .mode ;
114+ const auto & tmpInterpAttr = interp->get_attrs ();
115+ const auto & interpMode = tmpInterpAttr .mode ;
121116 if (!one_of (interpMode,
122117 ngInterpMode::NEAREST,
123118 ngInterpMode::LINEAR,
@@ -127,7 +122,7 @@ bool Interpolate::isSupportedOperation(const std::shared_ptr<const ov::Node>& op
127122 return false ;
128123 }
129124
130- const auto & interpCoordTransMode = interpAttr .coordinate_transformation_mode ;
125+ const auto & interpCoordTransMode = tmpInterpAttr .coordinate_transformation_mode ;
131126 if (!one_of (interpCoordTransMode,
132127 ngInterpCoordTransf::HALF_PIXEL,
133128 ngInterpCoordTransf::PYTORCH_HALF_PIXEL,
@@ -140,7 +135,7 @@ bool Interpolate::isSupportedOperation(const std::shared_ptr<const ov::Node>& op
140135 }
141136
142137 if (interpMode == ngInterpMode::NEAREST) {
143- const auto & interpNearestMode = interpAttr .nearest_mode ;
138+ const auto & interpNearestMode = tmpInterpAttr .nearest_mode ;
144139 if (!one_of (interpNearestMode,
145140 ngInterpNearMode::ROUND_PREFER_FLOOR,
146141 ngInterpNearMode::ROUND_PREFER_CEIL,
@@ -153,7 +148,7 @@ bool Interpolate::isSupportedOperation(const std::shared_ptr<const ov::Node>& op
153148 }
154149 }
155150
156- const auto & interpShapeCalcMode = interpAttr .shape_calculation_mode ;
151+ const auto & interpShapeCalcMode = tmpInterpAttr .shape_calculation_mode ;
157152 if (!one_of (interpShapeCalcMode, ngInterpShapeCalcMode::SCALES, ngInterpShapeCalcMode::SIZES)) {
158153 errorMessage =
159154 " Interpolate-4 does not support shape_calculation_mode: " + ov::as_string (interpShapeCalcMode);
@@ -183,20 +178,20 @@ bool Interpolate::isSupportedOperation(const std::shared_ptr<const ov::Node>& op
183178 errorMessage = " Only const 'axes' input is supported in Interpolate-4" ;
184179 return false ;
185180 }
186- } else if (const auto interp = ov::as_type_ptr<const ov::op::v11::Interpolate>(op)) {
187- const auto & interpAttr = interp ->get_attrs ();
188- const auto & interpMode = interpAttr .mode ;
181+ } else if (const auto interp_v11 = ov::as_type_ptr<const ov::op::v11::Interpolate>(op)) {
182+ const auto & tmpInterpAttr = interp_v11 ->get_attrs ();
183+ const auto & interpMode = tmpInterpAttr .mode ;
189184 if (!one_of (interpMode, ngInterpMode::BILINEAR_PILLOW, ngInterpMode::BICUBIC_PILLOW)) {
190185 errorMessage = " Interpolate-11 does not support interpolate mode: " + ov::as_string (interpMode);
191186 return false ;
192187 }
193- const auto & interpShapeCalcMode = interpAttr .shape_calculation_mode ;
188+ const auto & interpShapeCalcMode = tmpInterpAttr .shape_calculation_mode ;
194189 if (!one_of (interpShapeCalcMode, ngInterpShapeCalcMode::SCALES, ngInterpShapeCalcMode::SIZES)) {
195190 errorMessage =
196191 " Interpolate-11 does not support shape_calculation_mode: " + ov::as_string (interpShapeCalcMode);
197192 return false ;
198193 }
199- const size_t dataRank = interp ->get_input_partial_shape (DATA_ID).rank ().get_length ();
194+ const size_t dataRank = interp_v11 ->get_input_partial_shape (DATA_ID).rank ().get_length ();
200195 if (dataRank < 2 || dataRank > 4 ) {
201196 // pillow only resize on H and W. resize on D(depth) is not defined.
202197 errorMessage = " Interpolate-11 does not support input tensor of rank : " + std::to_string (dataRank);
@@ -207,8 +202,8 @@ bool Interpolate::isSupportedOperation(const std::shared_ptr<const ov::Node>& op
207202 errorMessage = " Only const 'scales_or_sizes' input is supported for static shapes in Interpolate-11" ;
208203 return false ;
209204 }
210- if (interp ->get_input_size () > 2 && ov::as_type_ptr<const ov::op::v0::Constant>(
211- interp ->get_input_node_shared_ptr (AXES_ID_V11)) == nullptr ) {
205+ if (interp_v11 ->get_input_size () > 2 && ov::as_type_ptr<const ov::op::v0::Constant>(
206+ interp_v11 ->get_input_node_shared_ptr (AXES_ID_V11)) == nullptr ) {
212207 errorMessage = " Only const 'axes' input is supported in Interpolate-11" ;
213208 return false ;
214209 }
@@ -257,8 +252,8 @@ Interpolate::Interpolate(const std::shared_ptr<ov::Node>& op, const GraphContext
257252 : Node(op, context, InterpolateShapeInferFactory(op)) {
258253 std::string errorMessage;
259254 if (isSupportedOperation (op, errorMessage)) {
260- dataRank = getInputShapeAtPort (DATA_ID).getRank ();
261- if (const auto interp = ov::as_type_ptr<const ov::op::v4::Interpolate>(op)) {
255+ dataRank = getInputShapeAtPort (interpAttrs. DATA_ID ).getRank ();
256+ if (const auto interp_v4 = ov::as_type_ptr<const ov::op::v4::Interpolate>(op)) {
262257 is_version11 = false ;
263258 const auto numInputs = inputShapes.size ();
264259 if (numInputs != 3 && numInputs != 4 ) {
@@ -269,7 +264,7 @@ Interpolate::Interpolate(const std::shared_ptr<ov::Node>& op, const GraphContext
269264 }
270265 isAxesSpecified = numInputs != 3 ;
271266
272- const auto & interpAttr = interp ->get_attrs ();
267+ const auto & interpAttr = interp_v4 ->get_attrs ();
273268
274269 const auto & interpMode = interpAttr.mode ;
275270 if (interpMode == ngInterpMode::NEAREST) {
@@ -351,14 +346,14 @@ Interpolate::Interpolate(const std::shared_ptr<ov::Node>& op, const GraphContext
351346 }
352347
353348 const auto scalesNode =
354- ov::as_type_ptr<const ov::op::v0::Constant>(interp ->get_input_node_shared_ptr (SCALES_ID));
349+ ov::as_type_ptr<const ov::op::v0::Constant>(interp_v4 ->get_input_node_shared_ptr (interpAttrs. SCALES_ID ));
355350 if (scalesNode) {
356351 scales = scalesNode->cast_vector <float >();
357352 isScaleConstant = true ;
358353 }
359354
360355 if (isAxesSpecified) {
361- axes = ov::as_type_ptr<const ov::op::v0::Constant>(interp ->get_input_node_shared_ptr (AXES_ID))
356+ axes = ov::as_type_ptr<const ov::op::v0::Constant>(interp_v4 ->get_input_node_shared_ptr (interpAttrs. AXES_ID ))
362357 ->cast_vector <int >();
363358 } else {
364359 axes.resize (dataRank);
@@ -396,7 +391,7 @@ Interpolate::Interpolate(const std::shared_ptr<ov::Node>& op, const GraphContext
396391 if (interpShapeCalcMode == ngInterpShapeCalcMode::SCALES) {
397392 interpAttrs.shapeCalcMode = InterpolateShapeCalcMode::scales;
398393 const auto scalesNode = ov::as_type_ptr<const ov::op::v0::Constant>(
399- interp->get_input_node_shared_ptr (SIZE_OR_SCALE_ID_V11));
394+ interp->get_input_node_shared_ptr (interpAttrs. SIZE_OR_SCALE_ID_V11 ));
400395 if (scalesNode) {
401396 scales = scalesNode->cast_vector <float >();
402397 isScaleConstant = true ;
@@ -426,7 +421,7 @@ Interpolate::Interpolate(const std::shared_ptr<ov::Node>& op, const GraphContext
426421 }
427422
428423 if (isAxesSpecified) {
429- axes = ov::as_type_ptr<const ov::op::v0::Constant>(interp->get_input_node_shared_ptr (AXES_ID_V11))
424+ axes = ov::as_type_ptr<const ov::op::v0::Constant>(interp->get_input_node_shared_ptr (interpAttrs. AXES_ID_V11 ))
430425 ->cast_vector <int >();
431426 if (dataRank == 4 && axes.size () == 2 && axes[0 ] == 1 && axes[1 ] == 2 ) {
432427 interpAttrs.NCHWAsNHWC = true ;
@@ -496,7 +491,7 @@ void Interpolate::initSupportedPrimitiveDescriptors() {
496491 return ;
497492 }
498493
499- ov::element::Type inputPrecision = getOriginalInputPrecisionAtPort (DATA_ID);
494+ ov::element::Type inputPrecision = getOriginalInputPrecisionAtPort (interpAttrs. DATA_ID );
500495
501496#if defined(OV_CPU_WITH_ACL)
502497 bool isInputPrecisionSupported = one_of (inputPrecision, ov::element::i8 , ov::element::u8 , ov::element::f16 );
@@ -519,7 +514,7 @@ void Interpolate::initSupportedPrimitiveDescriptors() {
519514 ov::element::Type outputPrecision = inputPrecision;
520515
521516 if (!fusedWith.empty ()) {
522- outputPrecision = fusedWith[fusedWith.size () - 1 ]->getOriginalOutputPrecisionAtPort (DATA_ID);
517+ outputPrecision = fusedWith[fusedWith.size () - 1 ]->getOriginalOutputPrecisionAtPort (interpAttrs. DATA_ID );
523518 }
524519
525520#if !defined(OV_CPU_WITH_ACL)
@@ -550,29 +545,29 @@ void Interpolate::initSupportedPrimitiveDescriptors() {
550545 auto & creatorsMap = BlockedDescCreator::getCommonCreators ();
551546 auto pushDesc = [&](LayoutType dataFormat,
552547 impl_desc_type implDetail,
553- bool is_version11 ,
548+ bool is_version11_desc ,
554549 bool useAclExecutor = false ) {
555- config.inConfs [DATA_ID].setMemDesc (
556- creatorsMap.at (dataFormat)->createSharedDesc (inputPrecision, getInputShapeAtPort (DATA_ID)));
557- if (is_version11 ) {
550+ config.inConfs [interpAttrs. DATA_ID ].setMemDesc (
551+ creatorsMap.at (dataFormat)->createSharedDesc (inputPrecision, getInputShapeAtPort (interpAttrs. DATA_ID )));
552+ if (is_version11_desc ) {
558553 if (interpAttrs.shapeCalcMode == InterpolateShapeCalcMode::sizes) {
559- config.inConfs [SIZE_OR_SCALE_ID_V11].setMemDesc (
554+ config.inConfs [interpAttrs. SIZE_OR_SCALE_ID_V11 ].setMemDesc (
560555 creatorsMap.at (LayoutType::ncsp)
561- ->createSharedDesc (targetShapeType, getInputShapeAtPort (SIZE_OR_SCALE_ID_V11)));
556+ ->createSharedDesc (targetShapeType, getInputShapeAtPort (interpAttrs. SIZE_OR_SCALE_ID_V11 )));
562557 } else {
563- config.inConfs [SIZE_OR_SCALE_ID_V11].setMemDesc (
558+ config.inConfs [interpAttrs. SIZE_OR_SCALE_ID_V11 ].setMemDesc (
564559 creatorsMap.at (LayoutType::ncsp)
565- ->createSharedDesc (scalesType, getInputShapeAtPort (SIZE_OR_SCALE_ID_V11)));
560+ ->createSharedDesc (scalesType, getInputShapeAtPort (interpAttrs. SIZE_OR_SCALE_ID_V11 )));
566561 }
567562
568563 if (isAxesSpecified) {
569- config.inConfs [AXES_ID_V11].setMemDesc (
570- creatorsMap.at (LayoutType::ncsp)->createSharedDesc (axesType, getInputShapeAtPort (AXES_ID_V11)));
564+ config.inConfs [interpAttrs. AXES_ID_V11 ].setMemDesc (
565+ creatorsMap.at (LayoutType::ncsp)->createSharedDesc (axesType, getInputShapeAtPort (interpAttrs. AXES_ID_V11 )));
571566 }
572567 } else {
573- config.inConfs [TARGET_SHAPE_ID].setMemDesc (
568+ config.inConfs [interpAttrs. TARGET_SHAPE_ID ].setMemDesc (
574569 creatorsMap.at (LayoutType::ncsp)
575- ->createSharedDesc (targetShapeType, getInputShapeAtPort (TARGET_SHAPE_ID)));
570+ ->createSharedDesc (targetShapeType, getInputShapeAtPort (interpAttrs. TARGET_SHAPE_ID )));
576571 config.inConfs [get_scale_id ()].setMemDesc (
577572 creatorsMap.at (LayoutType::ncsp)->createSharedDesc (scalesType, getInputShapeAtPort (get_scale_id ())));
578573
@@ -644,7 +639,7 @@ void Interpolate::initSupportedPrimitiveDescriptors() {
644639 }
645640 pushDesc (LayoutType::ncsp, ref, true );
646641 } else {
647- const auto & dataMinDims = getInputShapeAtPort (DATA_ID).getMinDims ();
642+ const auto & dataMinDims = getInputShapeAtPort (interpAttrs. DATA_ID ).getMinDims ();
648643 bool isBlkApplied = dataRank > 1 && dataMinDims[1 ] != Shape::UNDEFINED_DIM && dataMinDims[1 ] > 1 ;
649644
650645#if defined(OV_CPU_WITH_ACL)
@@ -703,17 +698,17 @@ bool Interpolate::needShapeInfer() const {
703698 if (lastScales.empty ()) {
704699 return true ;
705700 }
706- const auto * scales = getSrcDataAtPortAs<const float >(get_scale_id ());
701+ const auto * scales_inf = getSrcDataAtPortAs<const float >(get_scale_id ());
707702 for (size_t i = 0 ; i < lastScales.size (); i++) {
708- if (lastScales[i] != scales [i]) {
703+ if (lastScales[i] != scales_inf [i]) {
709704 return true ;
710705 }
711706 }
712707 } else {
713708 if (lastSizes.empty ()) {
714709 return true ;
715710 }
716- const auto * sizes = getSrcDataAtPortAs<const int32_t >(TARGET_SHAPE_ID);
711+ const auto * sizes = getSrcDataAtPortAs<const int32_t >(interpAttrs. TARGET_SHAPE_ID );
717712 for (size_t i = 0 ; i < lastSizes.size (); i++) {
718713 if (sizes[i] != lastSizes[i]) {
719714 return true ;
@@ -726,11 +721,11 @@ bool Interpolate::needShapeInfer() const {
726721void Interpolate::executeDynamicImpl (const dnnl::stream& strm) {
727722 execute (strm);
728723
729- const size_t port = interpAttrs.shapeCalcMode == InterpolateShapeCalcMode::sizes ? TARGET_SHAPE_ID : get_scale_id ();
724+ const size_t port = interpAttrs.shapeCalcMode == InterpolateShapeCalcMode::sizes ? interpAttrs. TARGET_SHAPE_ID : get_scale_id ();
730725 const auto & memory = getParentEdgeAt (port)->getMemory ();
731726 if (interpAttrs.shapeCalcMode == InterpolateShapeCalcMode::scales) {
732- const auto * scales = memory.getDataAs <const float >();
733- lastScales.assign (scales, scales + memory.getDesc ().getShape ().getElementsCount ());
727+ const auto * scales_dyn = memory.getDataAs <const float >();
728+ lastScales.assign (scales_dyn, scales_dyn + memory.getDesc ().getShape ().getElementsCount ());
734729 } else {
735730 const auto * sizes = memory.getDataAs <const int32_t >();
736731 lastSizes.assign (sizes, sizes + memory.getDesc ().getShape ().getElementsCount ());
@@ -743,15 +738,15 @@ bool Interpolate::needPrepareParams() const {
743738
744739inline int Interpolate::get_scale_id () const {
745740 if (is_version11) {
746- return SIZE_OR_SCALE_ID_V11;
741+ return interpAttrs. SIZE_OR_SCALE_ID_V11 ;
747742 }
748- return SCALES_ID;
743+ return interpAttrs. SCALES_ID ;
749744}
750745inline int Interpolate::get_axis_id () const {
751746 if (is_version11) {
752- return AXES_ID_V11;
747+ return interpAttrs. AXES_ID_V11 ;
753748 }
754- return AXES_ID;
749+ return interpAttrs. AXES_ID ;
755750}
756751
757752void Interpolate::prepareParams () {
@@ -764,13 +759,13 @@ void Interpolate::prepareParams() {
764759 THROW_CPU_NODE_ERR (" has undefined destination memory" );
765760 }
766761
767- auto srcMemPtr = getSrcMemoryAtPort (DATA_ID);
762+ auto srcMemPtr = getSrcMemoryAtPort (interpAttrs. DATA_ID );
768763 if (!srcMemPtr || !srcMemPtr->isDefined ()) {
769764 THROW_CPU_NODE_ERR (" has undefined input memory" );
770765 }
771766
772767 if (interpAttrs.shapeCalcMode == InterpolateShapeCalcMode::sizes) {
773- auto tsMemPtr = getSrcMemoryAtPort (TARGET_SHAPE_ID);
768+ auto tsMemPtr = getSrcMemoryAtPort (interpAttrs. TARGET_SHAPE_ID );
774769 if (!tsMemPtr || !tsMemPtr->isDefined ()) {
775770 THROW_CPU_NODE_ERR (" has undefined target shape memory" );
776771 }
@@ -884,7 +879,7 @@ void Interpolate::prepareParams() {
884879}
885880
886881void Interpolate::createPrimitive () {
887- auto srcMemPtr = getSrcMemoryAtPort (DATA_ID);
882+ auto srcMemPtr = getSrcMemoryAtPort (interpAttrs. DATA_ID );
888883 auto dstMemPtr = getDstMemoryAtPort (0 );
889884 if (!srcMemPtr) {
890885 THROW_CPU_NODE_ERR (" has null input memory" );
@@ -978,7 +973,7 @@ std::vector<float> Interpolate::getScales(const VectorDims& srcDimPad, const Vec
978973
979974void Interpolate::execute ([[maybe_unused]] const dnnl::stream& strm) {
980975 auto dstMemPtr = getDstMemoryAtPort (0 );
981- auto srcMemPtr = getSrcMemoryAtPort (DATA_ID);
976+ auto srcMemPtr = getSrcMemoryAtPort (interpAttrs. DATA_ID );
982977
983978 if (execPtr) {
984979 auto * dst_data = dstMemPtr->getDataAs <uint8_t >();
0 commit comments