@@ -1178,22 +1178,25 @@ void MKLDNNDeformableConvolutionNode::executeReference(const float* src, const f
11781178 h_im < IH);
11791179 }
11801180 if (!skip_compute) {
1181- // if (h_im >= 0 && w_im >= 0 && h_im < IH && w_im < IW) {
1182- const int cur_height = IH - h_in;
1183- const int cur_width = IW - w_in;
1184- int h_low = std::max (static_cast <int >(floorf (map_h)), 0 );
1185- int w_low = std::max (static_cast <int >(floorf (map_w)), 0 );
1186- int h_high = with_bi_pad ? h_low + 1 : std::min (static_cast <int >(ceilf (map_h)), cur_height - 1 );
1187- int w_high = with_bi_pad ? w_low + 1 : std::min (static_cast <int >(ceilf (map_w)), cur_width - 1 );
1181+ const int cur_h_end = IH - h_in;
1182+ const int cur_w_end = IW - w_in;
1183+ int h_low = with_bi_pad ? static_cast <int >(floorf (map_h)) :
1184+ std::max (static_cast <int >(floorf (map_h)), 0 );
1185+ int w_low = with_bi_pad ? static_cast <int >(floorf (map_w)) :
1186+ std::max (static_cast <int >(floorf (map_w)), 0 );
1187+ const int cur_h_start = h_low + h_in;
1188+ const int cur_w_start = w_low + w_in;
1189+ int h_high = with_bi_pad ? h_low + 1 : std::min (static_cast <int >(ceilf (map_h)), cur_h_end - 1 );
1190+ int w_high = with_bi_pad ? w_low + 1 : std::min (static_cast <int >(ceilf (map_w)), cur_w_end - 1 );
11881191
11891192 float lh = map_h - h_low;
11901193 float lw = map_w - w_low;
11911194 float hh = 1 - lh, hw = 1 - lw;
11921195
1193- float v1 = (w_low >= 0 && h_low >= 0 ) ? data_im_ptr[h_low * src_strides[2 ] + w_low * src_strides[3 ]] : 0 .0f ;
1194- float v2 = (w_high < cur_width && h_low >= 0 ) ? data_im_ptr[h_low * src_strides[2 ] + w_high * src_strides[3 ]] : 0 .0f ;
1195- float v3 = (w_low >= 0 && h_high < cur_height ) ? data_im_ptr[h_high * src_strides[2 ] + w_low * src_strides[3 ]] : 0 .0f ;
1196- float v4 = (w_high < cur_width && h_high < cur_height ) ? data_im_ptr[h_high * src_strides[2 ] + w_high * src_strides[3 ]] : 0 .0f ;
1196+ float v1 = (cur_w_start >= 0 && cur_h_start >= 0 ) ? data_im_ptr[h_low * src_strides[2 ] + w_low * src_strides[3 ]] : 0 .0f ;
1197+ float v2 = (w_high < cur_w_end && cur_h_start >= 0 ) ? data_im_ptr[h_low * src_strides[2 ] + w_high * src_strides[3 ]] : 0 .0f ;
1198+ float v3 = (cur_w_start >= 0 && h_high < cur_h_end ) ? data_im_ptr[h_high * src_strides[2 ] + w_low * src_strides[3 ]] : 0 .0f ;
1199+ float v4 = (w_high < cur_w_end && h_high < cur_h_end ) ? data_im_ptr[h_high * src_strides[2 ] + w_high * src_strides[3 ]] : 0 .0f ;
11971200 float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
11981201
11991202 float val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
0 commit comments