Skip to content
2 changes: 1 addition & 1 deletion backends/mlu/kernels/slice_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ inline void CheckAndUpdateSliceAttrs(const phi::DDim in_dims,
&start,
&end,
&dummy_zero_out_dim);
if (end == -dim_value - 1) {
if (step < 0 && end == -dim_value - 1) {
end = -1;
}

Expand Down
75 changes: 34 additions & 41 deletions backends/mlu/kernels/strided_slice_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include "kernels/funcs/mlu_baseop.h"
#include "kernels/funcs/mlu_funcs.h"
#include "paddle/phi/kernels/funcs/slice_utils.h"

namespace custom_kernel {
static void StridedSliceOutDims(const std::vector<int64_t>& starts,
Expand All @@ -40,6 +41,8 @@ static void StridedSliceOutDims(const std::vector<int64_t>& starts,
auto ret = std::find(decrease_axis.begin(), decrease_axis.end(), axes[i]);
if (ret != decrease_axis.end()) {
decrease_axis_affect = true;
start_index = in_dims[axes_index] - 1;
end_index = in_dims[axes_index];
}
}
if (decrease_axis_affect) {
Expand All @@ -61,39 +64,27 @@ static void StridedSliceOutDims(const std::vector<int64_t>& starts,
continue;
}

if (start_index < 0) {
start_index = start_index + axis_size;
start_index = std::max<int64_t>(start_index, 0);
}
if (end_index < 0) {
if (!(end_index == -1 && stride_index < 0)) { // skip None stop condition
end_index = end_index + axis_size;
if (end_index < 0) {
end_index = 0;
}
}
bool neg_dim_condition = false;
phi::funcs::normalize_interval(start_index,
end_index,
stride_index,
axis_size,
&start_index,
&end_index,
&neg_dim_condition);
if (stride_index < 0 && end_index == -axis_size - 1) {
end_index = -1;
}

if (stride_index < 0) {
start_index = start_index + 1;
end_index = end_index + 1;
int64_t out_dims_index;
if (neg_dim_condition) {
out_dims_index = 0;
} else {
int64_t step_size = std::abs(stride_index);
out_dims_index =
(std::abs(end_index - start_index) + step_size - 1) / step_size;
}

bool neg_dim_condition = ((stride_index < 0 && (start_index < end_index)) ||
(stride_index > 0 && (start_index > end_index)));
PADDLE_ENFORCE_EQ(neg_dim_condition,
false,
phi::errors::InvalidArgument(
"The start index and end index are invalid for their "
"corresponding stride."));

int64_t left =
std::max(static_cast<int64_t>(0), std::min(start_index, end_index));
int64_t right = std::min(axis_size, std::max(start_index, end_index));
int64_t step = std::abs(stride_index);

auto out_dims_index = (std::abs(right - left) + step - 1) / step;

out_dims_vector[axes_index] = out_dims_index;
}
}
Expand Down Expand Up @@ -122,21 +113,23 @@ static void StridedSliceFunctor(int64_t* starts,
decrease_axis.begin(), decrease_axis.end(), axes[axis_index]);
if (ret != decrease_axis.end()) {
decrease_axis_affect = true;
starts[axis_index] = axis_size - 1;
ends[axis_index] = axis_size;
}
}
// stride must not be zero
if (starts[axis_index] < 0) {
starts[axis_index] = starts[axis_index] + axis_size;
starts[axis_index] = std::max<int64_t>(starts[axis_index], 0);
}
if (ends[axis_index] < 0) {
if (!(ends[axis_index] == -1 &&
strides[axis_index] < 0)) { // skip None stop condition
ends[axis_index] = ends[axis_index] + axis_size;
if (ends[axis_index] < 0) {
ends[axis_index] = 0;
}
}
bool dummy_zero_dim_out = false;
phi::funcs::normalize_interval(starts[axis_index],
ends[axis_index],
strides[axis_index],
axis_size,
&starts[axis_index],
&ends[axis_index],
&dummy_zero_dim_out);
if (strides[axis_index] < 0 && ends[axis_index] == -axis_size - 1) {
// manually set the end to -1 when step < 0,
// which indicates that it can extend to the left endpoint.
ends[axis_index] = -1;
}
if (decrease_axis_affect) {
if (strides[axis_index] < 0) {
Expand Down
1 change: 0 additions & 1 deletion backends/mlu/tools/disable_ut_mlu
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,3 @@ test_rms_norm_op_mlu
test_sync_batch_norm_op_mlu
test_unsqueeze_op_mlu
test_LeNet_MNIST
test_strided_slice_op_mlu