@@ -895,9 +895,10 @@ struct jit_uni_def_conv_kernel_f32 : public jit_uni_def_conv_kernel, public jit_
895895
896896bool MKLDNNDeformableConvolutionNode::isSupportedOperation (const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept {
897897 try {
898- const auto defConvNode = ngraph::as_type_ptr<const ngraph::op::v8::DeformableConvolution>(op);
899- if (!defConvNode) {
900- errorMessage = " Node is not an instance of DeformableConvolution form the operation set v8." ;
898+ if (!one_of (op->get_type_info (),
899+ ngraph::op::v1::DeformableConvolution::type_info,
900+ ngraph::op::v8::DeformableConvolution::type_info)) {
901+ errorMessage = " Node is not an instance of DeformableConvolution form the operation set v1 or v8." ;
901902 return false ;
902903 }
903904 } catch (...) {
@@ -913,22 +914,28 @@ MKLDNNDeformableConvolutionNode::MKLDNNDeformableConvolutionNode(const std::shar
913914 if (!isSupportedOperation (op, errorMessage)) {
914915 IE_THROW (NotImplemented) << errorMessage;
915916 }
916- auto defConvNode = ngraph::as_type_ptr< const ngraph::op::v8::DeformableConvolution >(op);
917+ auto defConvNodeBase = std::dynamic_pointer_cast< ngraph::op::util::DeformableConvolutionBase >(op);
917918
918- group = defConvNode->get_group ();
919- deformable_group = defConvNode->get_deformable_group ();
920- with_bilinear_pad = defConvNode->get_use_bilinear_interpolation_padding ();
921- auto & strides = defConvNode->get_strides ();
919+ group = defConvNodeBase->get_group ();
920+ deformable_group = defConvNodeBase->get_deformable_group ();
921+ auto & strides = defConvNodeBase->get_strides ();
922922 for (int i = 0 ; i < strides.size (); i++) {
923923 stride.push_back (strides[i]);
924924 }
925925
926- auto & dilations = defConvNode ->get_dilations ();
926+ auto & dilations = defConvNodeBase ->get_dilations ();
927927 for (int i = 1 ; i <= dilations.size (); i++) {
928928 dilation.push_back (dilations[dilations.size () - i] - 1 );
929929 }
930930
931- paddingL = defConvNode->get_pads_begin ();
931+ paddingL = defConvNodeBase->get_pads_begin ();
932+
933+ if (op->get_type_info () == ngraph::op::v8::DeformableConvolution::type_info) {
934+ auto defConvNode = std::dynamic_pointer_cast<ngraph::op::v8::DeformableConvolution>(op);
935+ with_bilinear_pad = defConvNode->get_bilinear_interpolation_pad ();
936+ } else {
937+ with_bilinear_pad = false ;
938+ }
932939}
933940
934941void MKLDNNDeformableConvolutionNode::getSupportedDescriptors () {
@@ -999,7 +1006,17 @@ void MKLDNNDeformableConvolutionNode::initSupportedPrimitiveDescriptors() {
9991006
10001007 config.inConfs [0 ].desc = MKLDNNMemoryDesc (getParentEdgeAt (0 )->getDims (), memory::data_type::f32 , dataFormat);
10011008 config.inConfs [1 ].desc = MKLDNNMemoryDesc (getParentEdgeAt (1 )->getDims (), memory::data_type::f32 , offFormat);
1002- config.inConfs [2 ].desc = MKLDNNMemoryDesc (getParentEdgeAt (2 )->getDims (), memory::data_type::f32 , weiFormat);
1009+ auto & wDims = getParentEdgeAt (2 )->getDims ();
1010+ if (group > 1 && wDims.ndims () != 5 ) {
1011+ auto old_dims = wDims.ToSizeVector ();
1012+ auto new_dims = InferenceEngine::SizeVector ({group, div_up (old_dims[0 ], group)});
1013+ for (int i = 1 ; i < old_dims.size (); i++) {
1014+ new_dims.push_back (old_dims[i]);
1015+ }
1016+ config.inConfs [2 ].desc = MKLDNNMemoryDesc (MKLDNNDims (new_dims), memory::data_type::f32 , weiFormat);
1017+ } else {
1018+ config.inConfs [2 ].desc = MKLDNNMemoryDesc (getParentEdgeAt (2 )->getDims (), memory::data_type::f32 , weiFormat);
1019+ }
10031020 if (inputsNumber > 3 ) {
10041021 config.inConfs [3 ].desc = MKLDNNMemoryDesc (getParentEdgeAt (3 )->getDims (), memory::data_type::f32 , memory::format_tag::nchw);
10051022 }
0 commit comments