@@ -966,6 +966,12 @@ void MKLDNNDeformableConvolutionNode::initSupportedPrimitiveDescriptors() {
966966 if (!supportedPrimitiveDescriptors.empty ())
967967 return ;
968968
969+ const int simd_w = mayiuse (cpu::x64::avx512_common) ? 16 : 8 ;
970+ if (group != 1 && (((getParentEdgeAt (0 )->getDims ()[1 ] / group) % simd_w != 0 )
971+ || ((getChildEdgeAt (0 )->getDims ()[1 ] / group) % simd_w != 0 ))) {
972+ enforceRef = true ;
973+ }
974+
969975 size_t inputsNumber = getOriginalInputsNumber ();
970976 InferenceEngine::LayerConfig config;
971977 config.dynBatchSupport = false ;
@@ -986,19 +992,20 @@ void MKLDNNDeformableConvolutionNode::initSupportedPrimitiveDescriptors() {
986992 config.outConfs [0 ].inPlace = -1 ;
987993
988994 impl_desc_type impl_type;
989- // if (mayiuse(cpu::x64::avx512_common)) {
990- // impl_type = impl_desc_type::jit_avx512;
991- // } else if (mayiuse(cpu::x64::avx2)) {
992- // impl_type = impl_desc_type::jit_avx2;
993- // } else if (mayiuse(cpu::x64::sse41)) {
994- // impl_type = impl_desc_type::jit_sse42;
995- // } else {
996- // impl_type = impl_desc_type::ref;
997- // }
998- impl_type = impl_desc_type::ref;
999-
1000- if (false && mayiuse (cpu::x64::sse41)) {
1001- // optimzed implementation
995+ if (enforceRef) {
996+ impl_type = impl_desc_type::ref;
997+ } else if (mayiuse (cpu::x64::avx512_common)) {
998+ impl_type = impl_desc_type::jit_avx512;
999+ } else if (mayiuse (cpu::x64::avx2)) {
1000+ impl_type = impl_desc_type::jit_avx2;
1001+ } else if (mayiuse (cpu::x64::sse41)) {
1002+ impl_type = impl_desc_type::jit_sse42;
1003+ } else {
1004+ impl_type = impl_desc_type::ref;
1005+ }
1006+
1007+ if (!enforceRef && mayiuse (cpu::x64::sse41)) {
1008+ // optimized implementation
10021009 auto dataFormat = memory::format_tag::nhwc;
10031010 auto offFormat = memory::format_tag::nchw;
10041011 auto weiFormat = group > 1 ? mayiuse (avx512_common) ? memory::format_tag::gOIhw16i16o : memory::format_tag::gOIhw8i8o
@@ -1097,13 +1104,15 @@ void MKLDNNDeformableConvolutionNode::createPrimitive() {
10971104
10981105 jcp.nthr = dnnl_get_max_threads ();
10991106
1100- // if (mayiuse(cpu::x64::avx512_common)) {
1101- // def_conv_kernel.reset(new jit_uni_def_conv_kernel_f32<cpu::x64::avx512_common>(jcp));
1102- // } else if (mayiuse(cpu::x64::avx2)) {
1103- // def_conv_kernel.reset(new jit_uni_def_conv_kernel_f32<cpu::x64::avx2>(jcp));
1104- // } else if (mayiuse(cpu::x64::sse41)) {
1105- // def_conv_kernel.reset(new jit_uni_def_conv_kernel_f32<cpu::x64::sse41>(jcp));
1106- // }
1107+ if (enforceRef) {
1108+ return ;
1109+ } else if (mayiuse (cpu::x64::avx512_common)) {
1110+ def_conv_kernel.reset (new jit_uni_def_conv_kernel_f32<cpu::x64::avx512_common>(jcp));
1111+ } else if (mayiuse (cpu::x64::avx2)) {
1112+ def_conv_kernel.reset (new jit_uni_def_conv_kernel_f32<cpu::x64::avx2>(jcp));
1113+ } else if (mayiuse (cpu::x64::sse41)) {
1114+ def_conv_kernel.reset (new jit_uni_def_conv_kernel_f32<cpu::x64::sse41>(jcp));
1115+ }
11071116
11081117 if (def_conv_kernel)
11091118 def_conv_kernel->create_ker ();
0 commit comments