Skip to content
210 changes: 169 additions & 41 deletions backends/mlu/kernels/slice_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,164 @@

namespace custom_kernel {

/**
 * @brief Clamps a Python-style slice interval [st, ed) to a dimension of
 *        length dim_size, for either traversal direction.
 *
 * Index layout for a dimension of size D:
 *
 *   ... -D-2  -D-1 | -D  -D+1  ...  -1 |
 *                  |  0    1   ... D-1 | D  D+1 ...
 *
 * An index in [-D, 0) names the same element as index + D; anything further
 * out is clamped or makes the slice empty, depending on the direction.
 *
 * @tparam T Arithmetic index type. The code negates dim_size, so a signed
 *           type is expected.
 * @param st        Requested start index (inclusive); may be negative or out
 *                  of range.
 * @param ed        Requested end index (exclusive); may be negative or out of
 *                  range.
 * @param step      Traversal step. Only its sign is inspected: step > 0 takes
 *                  the forward branch, every other value (this function does
 *                  not reject 0) takes the reverse branch.
 * @param dim_size  Length D of the dimension being sliced.
 * @param st_out    Receives the normalized start.
 * @param ed_out    Receives the normalized end.
 * @param zero_dim_out  Set to true when the slice selects nothing. It is never
 *                  written on the non-empty path, so the caller must
 *                  initialize it (normally to false) before the call.
 *
 * @details
 * - Empty result: *st_out == *ed_out == 0 and *zero_dim_out == true. This
 *   covers dim_size <= 0 as well as intervals that select no element.
 * - Forward, non-empty: 0 <= *st_out < *ed_out <= D.
 * - Reverse, non-empty: *st_out lies in [0, D-1]; *ed_out is either in
 *   [0, *st_out) or the sentinel -D-1, which means "run through element 0"
 *   (the mirror image of ed == D in the forward case).
 *
 * @example
 *   int st_out = 0, ed_out = 0;
 *   bool zero_dim = false;  // must be initialized by the caller
 *   normalize_interval(-3, -2, 1, 4, &st_out, &ed_out, &zero_dim);
 *   // st_out == 1 (-3 + 4), ed_out == 2 (-2 + 4), zero_dim == false
 *
 * @note All output pointers are dereferenced unconditionally and must be
 *       non-null.
 */
template <typename T>
void normalize_interval(
    T st, T ed, T step, T dim_size, T* st_out, T* ed_out, bool* zero_dim_out) {
  // Single exit used by every "nothing selected" outcome.
  const auto report_empty = [&]() {
    *st_out = *ed_out = 0;
    *zero_dim_out = true;
  };

  // Folds an index from [-D, 0) onto [0, D); all other values pass through.
  const auto wrap = [&](T idx) -> T {
    if (idx < 0 && idx >= -dim_size) {
      idx += dim_size;
    }
    return idx;
  };

  // Nothing can be selected from an empty (or invalid) dimension.
  if (dim_size <= 0) {
    report_empty();
    return;
  }

  if (step > 0) {
    // Forward traversal: a start at or past D, or an end at or before -D,
    // cannot reach any element.
    if (st >= dim_size || ed <= -dim_size) {
      report_empty();
      return;
    }

    // A start still negative after wrapping was below -D: clamp it to 0.
    const T first = std::max(wrap(st), static_cast<T>(0));
    // An end beyond the dimension is clamped to D.
    const T stop = std::min(wrap(ed), dim_size);

    if (first >= stop) {
      report_empty();
      return;
    }

    *st_out = first;
    *ed_out = stop;
    return;
  }

  // Reverse traversal (step is not positive).
  const T last_index = dim_size - 1;
  const T run_off_front = -dim_size - 1;  // sentinel end: include element 0

  // A start at or below -D-1, or an end at or past D-1, cannot reach any
  // element when walking backwards.
  if (st <= run_off_front || ed >= last_index) {
    report_empty();
    return;
  }

  // A start beyond the dimension is clamped to the last element.
  const T first = std::min(wrap(st), last_index);
  // An end below -D collapses onto the sentinel -D-1.
  const T stop = std::max(wrap(ed), run_off_front);

  // With the sentinel the slice always reaches element 0; otherwise stop is
  // an ordinary index in [0, D) and must lie strictly before first.
  if (stop != run_off_front && stop >= first) {
    report_empty();
    return;
  }

  *st_out = first;
  *ed_out = stop;
}

void UpdateAttr(const phi::DDim& in_dims,
const std::vector<int> axes,
const std::vector<int> starts,
Expand Down Expand Up @@ -76,47 +234,17 @@ inline void CheckAndUpdateSliceAttrs(const phi::DDim in_dims,

if (dim_value > 0) {
T step = steps == nullptr ? 1 : (*steps)[i];
PADDLE_ENFORCE_NE(
step,
0,
phi::errors::InvalidArgument(
"Step should not be 0, but received step = %d.", step));

T start = (*starts)[i] < 0 ? ((*starts)[i] + dim_value) : (*starts)[i];
start = std::max(start, static_cast<T>(0));

T end =
0 < step && (*ends)[i] < 0 ? ((*ends)[i] + dim_value) : (*ends)[i];
end = std::min(end, dim_value);

if (step > 0) {
start = std::min(start, dim_value);
end = std::max(end, static_cast<T>(0));
PADDLE_ENFORCE_GE(
end,
start,
phi::errors::InvalidArgument(
"When step > 0, end should be greater than start, but "
"received end = %d, start = %d.",
end,
start));
} else {
// NOTE(liym27): When step < 0, start should less and equal to
// dim_value-1
// "end is -1" means contain the 0-th element of this axis.
start = std::min(start, dim_value - 1);
if (end < -1) {
end += dim_value;
}
end = std::max(end, static_cast<T>(-1));
PADDLE_ENFORCE_GE(
start,
end,
phi::errors::InvalidArgument(
"When step < 0, start should be greater than end, but "
"received start = %d, end = %d.",
start,
end));
T start, end;
bool dummy_zero_out_dim = false;
normalize_interval((*starts)[i],
(*ends)[i],
step,
dim_value,
&start,
&end,
&dummy_zero_out_dim);
if (end == -dim_value - 1) {
end = -1;
}

(*starts)[i] = start;
Expand Down
5 changes: 0 additions & 5 deletions backends/mlu/tests/unittests/test_multinomial_op_mlu.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,11 +287,6 @@ def test_dim_less_than_1():

self.assertRaises(ValueError, test_dim_less_than_1)

with self.assertRaises(ValueError):
prob = paddle.rand([20, 1000])
prob[1:0] = 0
out = paddle.multinomial(prob)


if __name__ == "__main__":
unittest.main()
5 changes: 0 additions & 5 deletions backends/npu/tests/unittests/test_multinomial_op_npu.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,11 +287,6 @@ def test_dim_less_than_1():

self.assertRaises(ValueError, test_dim_less_than_1)

with self.assertRaises(ValueError):
prob = paddle.rand([20, 1000])
prob[1:0] = 0
out = paddle.multinomial(prob)


if __name__ == "__main__":
unittest.main()