Skip to content

Commit f401352

Browse files
Remove redundant unitest and fix wrong endpoint computation of slice (#1505)
1 parent c24a59c commit f401352

File tree

3 files changed

+169
-51
lines changed

3 files changed

+169
-51
lines changed

backends/mlu/kernels/slice_kernel.cc

Lines changed: 169 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,164 @@
1717

1818
namespace custom_kernel {
1919

20+
/**
21+
* @brief Normalizes the slice interval [st, ed) with a given step and dimension
22+
* size.
23+
*
24+
* This function adjusts the interval [st, ed) to fit within the bounds defined
25+
* by the dimension size, taking into account the specified step. It handles
26+
* both positive and negative steps and accounts for negative indices by
27+
* converting them to equivalent positive indices within the dimension size.
28+
*
29+
* @tparam T The data type of the input parameters, which can be an integer or
30+
* floating-point type.
31+
* @param st The starting index of the interval.
32+
* @param ed The ending index of the interval (exclusive).
33+
* @param step The step size for iterating through the interval, which can be
34+
* positive or negative.
35+
* @param dim_size The size of the dimension, serving as the upper bound for
36+
* valid indices.
37+
* @param st_out Pointer to store the normalized starting index.
38+
* @param ed_out Pointer to store the normalized ending index.
39+
* @param zero_dim_out Pointer to a boolean flag that is set to true if the
40+
* resulting interval is empty.
41+
*
42+
* @details
43+
* - If `step > 0`, the function ensures that `st` and `ed` are adjusted to be
44+
* within the range [0, dim_size).
45+
* - If `step < 0`, the function adjusts `st` and `ed` to accommodate the
46+
* reverse traversal of the interval.
47+
* - Handles special cases where `st` and `ed` may be out of bounds or where
48+
* `dim_size` is zero.
49+
* - Uses pointer parameters for output to modify the values directly.
50+
* - The function also handles scenarios involving negative indices, converting
51+
* them appropriately.
52+
*
53+
* @example
54+
* T st_out, ed_out;
55+
* bool zero_dim;
56+
* normalize_interval(-3, -2, 1, 4, &st_out, &ed_out, &zero_dim);
57+
* // Results in: st_out = 1, ed_out = 2, zero_dim = false
58+
*
59+
* @note The function assumes that the pointers provided for output parameters
60+
* are valid and non-null.
61+
*/
62+
template <typename T>
63+
void normalize_interval(
64+
T st, T ed, T step, T dim_size, T* st_out, T* ed_out, bool* zero_dim_out) {
65+
/* Normalize slice interval [st, ed) with given step and dim_size.
66+
e.g. if given st = -3, ed = -2, step = 1, dim_size = 4,
67+
then normalized st_out = 1(-3+4), st_ed = 2(-2+4).
68+
69+
This function is general enough and applicable
70+
for both step > 0 and step < 0 scenarios.
71+
72+
Indicices dipicted as below:
73+
74+
===============================================================
75+
| 0 1 2 3 ... D-1 | D D+1 ...
76+
... -D-2 -D-1 | -D -D+1 -D+2 -D+3 ... -1 |
77+
===============================================================
78+
*/
79+
// 0 dim size, just return
80+
if (dim_size <= 0) {
81+
*st_out = *ed_out = 0;
82+
*zero_dim_out = true;
83+
return;
84+
}
85+
86+
if (step > 0) {
87+
/* positive step */
88+
// 0 dim size case 1
89+
if (st >= dim_size) {
90+
*st_out = *ed_out = 0;
91+
*zero_dim_out = true;
92+
return;
93+
}
94+
95+
// 0 dim size case 2
96+
if (ed <= -dim_size) {
97+
*st_out = *ed_out = 0;
98+
*zero_dim_out = true;
99+
return;
100+
}
101+
102+
// make st belongs: (-inf, -D-1)∪[0, D)
103+
if (-dim_size <= st && st < 0) {
104+
st += dim_size;
105+
}
106+
// make st belongs: [0, D)
107+
st = std::max(st, static_cast<T>(0));
108+
109+
// make ed belongs: [0, +inf)
110+
if (-dim_size <= ed && ed < 0) {
111+
ed += dim_size;
112+
}
113+
// make ed belongs: [0, D]
114+
ed = std::min(ed, dim_size);
115+
116+
// 0 dim size case 3
117+
if (st >= ed) {
118+
*st_out = *ed_out = 0;
119+
*zero_dim_out = true;
120+
return;
121+
}
122+
*st_out = st;
123+
*ed_out = ed;
124+
return;
125+
126+
} else {
127+
/* negative step */
128+
// 0 dim size case 1
129+
if (st <= -dim_size - 1) {
130+
*st_out = *ed_out = 0;
131+
*zero_dim_out = true;
132+
return;
133+
}
134+
135+
// 0 dim size case 2
136+
if (ed >= dim_size - 1) {
137+
*st_out = *ed_out = 0;
138+
*zero_dim_out = true;
139+
return;
140+
}
141+
142+
// make st belongs: [0, D)∪[0, +inf)
143+
if (-dim_size <= st && st < 0) {
144+
st += dim_size;
145+
}
146+
// make st belongs: [0, D)
147+
st = std::min(st, dim_size - 1);
148+
149+
// make ed belongs: [-inf, -D)∪[0, D)
150+
if (-dim_size <= ed && ed < 0) {
151+
ed += dim_size;
152+
}
153+
// make ed belongs: [-D-1, -D)∪[0, D) ==> {-D-1}∪[0, D)
154+
ed = std::max(ed, -dim_size - 1);
155+
156+
if (ed == -dim_size - 1) {
157+
// When ed=-D-1, it is symmetrical to when step is greater than 0 and
158+
// ed=D.
159+
*st_out = st;
160+
*ed_out = ed;
161+
return;
162+
}
163+
164+
// now only remain the case that ed belongs to: [0, D)
165+
// 0 dim size case 3
166+
if (ed >= st) {
167+
*st_out = *ed_out = 0;
168+
*zero_dim_out = true;
169+
return;
170+
}
171+
172+
*st_out = st;
173+
*ed_out = ed;
174+
return;
175+
}
176+
}
177+
20178
void UpdateAttr(const phi::DDim& in_dims,
21179
const std::vector<int> axes,
22180
const std::vector<int> starts,
@@ -76,47 +234,17 @@ inline void CheckAndUpdateSliceAttrs(const phi::DDim in_dims,
76234

77235
if (dim_value > 0) {
78236
T step = steps == nullptr ? 1 : (*steps)[i];
79-
PADDLE_ENFORCE_NE(
80-
step,
81-
0,
82-
phi::errors::InvalidArgument(
83-
"Step should not be 0, but received step = %d.", step));
84-
85-
T start = (*starts)[i] < 0 ? ((*starts)[i] + dim_value) : (*starts)[i];
86-
start = std::max(start, static_cast<T>(0));
87-
88-
T end =
89-
0 < step && (*ends)[i] < 0 ? ((*ends)[i] + dim_value) : (*ends)[i];
90-
end = std::min(end, dim_value);
91-
92-
if (step > 0) {
93-
start = std::min(start, dim_value);
94-
end = std::max(end, static_cast<T>(0));
95-
PADDLE_ENFORCE_GE(
96-
end,
97-
start,
98-
phi::errors::InvalidArgument(
99-
"When step > 0, end should be greater than start, but "
100-
"received end = %d, start = %d.",
101-
end,
102-
start));
103-
} else {
104-
// NOTE(liym27): When step < 0, start should less and equal to
105-
// dim_value-1
106-
// "end is -1" means contain the 0-th element of this axis.
107-
start = std::min(start, dim_value - 1);
108-
if (end < -1) {
109-
end += dim_value;
110-
}
111-
end = std::max(end, static_cast<T>(-1));
112-
PADDLE_ENFORCE_GE(
113-
start,
114-
end,
115-
phi::errors::InvalidArgument(
116-
"When step < 0, start should be greater than end, but "
117-
"received start = %d, end = %d.",
118-
start,
119-
end));
237+
T start, end;
238+
bool dummy_zero_out_dim = false;
239+
normalize_interval((*starts)[i],
240+
(*ends)[i],
241+
step,
242+
dim_value,
243+
&start,
244+
&end,
245+
&dummy_zero_out_dim);
246+
if (end == -dim_value - 1) {
247+
end = -1;
120248
}
121249

122250
(*starts)[i] = start;

backends/mlu/tests/unittests/test_multinomial_op_mlu.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -287,11 +287,6 @@ def test_dim_less_than_1():
287287

288288
self.assertRaises(ValueError, test_dim_less_than_1)
289289

290-
with self.assertRaises(ValueError):
291-
prob = paddle.rand([20, 1000])
292-
prob[1:0] = 0
293-
out = paddle.multinomial(prob)
294-
295290

296291
if __name__ == "__main__":
297292
unittest.main()

backends/npu/tests/unittests/test_multinomial_op_npu.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -287,11 +287,6 @@ def test_dim_less_than_1():
287287

288288
self.assertRaises(ValueError, test_dim_less_than_1)
289289

290-
with self.assertRaises(ValueError):
291-
prob = paddle.rand([20, 1000])
292-
prob[1:0] = 0
293-
out = paddle.multinomial(prob)
294-
295290

296291
if __name__ == "__main__":
297292
unittest.main()

0 commit comments

Comments
 (0)