|
17 | 17 |
|
18 | 18 | namespace custom_kernel {
|
19 | 19 |
|
| 20 | +/** |
| 21 | + * @brief Normalizes the slice interval [st, ed) with a given step and dimension |
| 22 | + * size. |
| 23 | + * |
| 24 | + * This function adjusts the interval [st, ed) to fit within the bounds defined |
| 25 | + * by the dimension size, taking into account the specified step. It handles |
| 26 | + * both positive and negative steps and accounts for negative indices by |
| 27 | + * converting them to equivalent positive indices within the dimension size. |
| 28 | + * |
| 29 | + * @tparam T The data type of the input parameters, which can be an integer or |
| 30 | + * floating-point type. |
| 31 | + * @param st The starting index of the interval. |
| 32 | + * @param ed The ending index of the interval (exclusive). |
| 33 | + * @param step The step size for iterating through the interval, which can be |
| 34 | + * positive or negative. |
| 35 | + * @param dim_size The size of the dimension, serving as the upper bound for |
| 36 | + * valid indices. |
| 37 | + * @param st_out Pointer to store the normalized starting index. |
| 38 | + * @param ed_out Pointer to store the normalized ending index. |
| 39 | + * @param zero_dim_out Pointer to a boolean flag that is set to true if the |
| 40 | + * resulting interval is empty. |
| 41 | + * |
| 42 | + * @details |
| 43 | + * - If `step > 0`, the function ensures that `st` and `ed` are adjusted to be |
| 44 | + * within the range [0, dim_size). |
| 45 | + * - If `step < 0`, the function adjusts `st` and `ed` to accommodate the |
| 46 | + * reverse traversal of the interval. |
| 47 | + * - Handles special cases where `st` and `ed` may be out of bounds or where |
| 48 | + * `dim_size` is zero. |
| 49 | + * - Uses pointer parameters for output to modify the values directly. |
| 50 | + * - The function also handles scenarios involving negative indices, converting |
| 51 | + * them appropriately. |
| 52 | + * |
| 53 | + * @example |
| 54 | + * T st_out, ed_out; |
| 55 | + * bool zero_dim; |
| 56 | + * normalize_interval(-3, -2, 1, 4, &st_out, &ed_out, &zero_dim); |
| 57 | + * // Results in: st_out = 1, ed_out = 2, zero_dim = false |
| 58 | + * |
| 59 | + * @note The function assumes that the pointers provided for output parameters |
| 60 | + * are valid and non-null. |
| 61 | + */ |
| 62 | +template <typename T> |
| 63 | +void normalize_interval( |
| 64 | + T st, T ed, T step, T dim_size, T* st_out, T* ed_out, bool* zero_dim_out) { |
| 65 | + /* Normalize slice interval [st, ed) with given step and dim_size. |
| 66 | + e.g. if given st = -3, ed = -2, step = 1, dim_size = 4, |
| 67 | + then normalized st_out = 1(-3+4), st_ed = 2(-2+4). |
| 68 | +
|
| 69 | + This function is general enough and applicable |
| 70 | + for both step > 0 and step < 0 scenarios. |
| 71 | +
|
| 72 | + Indicices dipicted as below: |
| 73 | +
|
| 74 | + =============================================================== |
| 75 | + | 0 1 2 3 ... D-1 | D D+1 ... |
| 76 | + ... -D-2 -D-1 | -D -D+1 -D+2 -D+3 ... -1 | |
| 77 | + =============================================================== |
| 78 | + */ |
| 79 | + // 0 dim size, just return |
| 80 | + if (dim_size <= 0) { |
| 81 | + *st_out = *ed_out = 0; |
| 82 | + *zero_dim_out = true; |
| 83 | + return; |
| 84 | + } |
| 85 | + |
| 86 | + if (step > 0) { |
| 87 | + /* positive step */ |
| 88 | + // 0 dim size case 1 |
| 89 | + if (st >= dim_size) { |
| 90 | + *st_out = *ed_out = 0; |
| 91 | + *zero_dim_out = true; |
| 92 | + return; |
| 93 | + } |
| 94 | + |
| 95 | + // 0 dim size case 2 |
| 96 | + if (ed <= -dim_size) { |
| 97 | + *st_out = *ed_out = 0; |
| 98 | + *zero_dim_out = true; |
| 99 | + return; |
| 100 | + } |
| 101 | + |
| 102 | + // make st belongs: (-inf, -D-1)∪[0, D) |
| 103 | + if (-dim_size <= st && st < 0) { |
| 104 | + st += dim_size; |
| 105 | + } |
| 106 | + // make st belongs: [0, D) |
| 107 | + st = std::max(st, static_cast<T>(0)); |
| 108 | + |
| 109 | + // make ed belongs: [0, +inf) |
| 110 | + if (-dim_size <= ed && ed < 0) { |
| 111 | + ed += dim_size; |
| 112 | + } |
| 113 | + // make ed belongs: [0, D] |
| 114 | + ed = std::min(ed, dim_size); |
| 115 | + |
| 116 | + // 0 dim size case 3 |
| 117 | + if (st >= ed) { |
| 118 | + *st_out = *ed_out = 0; |
| 119 | + *zero_dim_out = true; |
| 120 | + return; |
| 121 | + } |
| 122 | + *st_out = st; |
| 123 | + *ed_out = ed; |
| 124 | + return; |
| 125 | + |
| 126 | + } else { |
| 127 | + /* negative step */ |
| 128 | + // 0 dim size case 1 |
| 129 | + if (st <= -dim_size - 1) { |
| 130 | + *st_out = *ed_out = 0; |
| 131 | + *zero_dim_out = true; |
| 132 | + return; |
| 133 | + } |
| 134 | + |
| 135 | + // 0 dim size case 2 |
| 136 | + if (ed >= dim_size - 1) { |
| 137 | + *st_out = *ed_out = 0; |
| 138 | + *zero_dim_out = true; |
| 139 | + return; |
| 140 | + } |
| 141 | + |
| 142 | + // make st belongs: [0, D)∪[0, +inf) |
| 143 | + if (-dim_size <= st && st < 0) { |
| 144 | + st += dim_size; |
| 145 | + } |
| 146 | + // make st belongs: [0, D) |
| 147 | + st = std::min(st, dim_size - 1); |
| 148 | + |
| 149 | + // make ed belongs: [-inf, -D)∪[0, D) |
| 150 | + if (-dim_size <= ed && ed < 0) { |
| 151 | + ed += dim_size; |
| 152 | + } |
| 153 | + // make ed belongs: [-D-1, -D)∪[0, D) ==> {-D-1}∪[0, D) |
| 154 | + ed = std::max(ed, -dim_size - 1); |
| 155 | + |
| 156 | + if (ed == -dim_size - 1) { |
| 157 | + // When ed=-D-1, it is symmetrical to when step is greater than 0 and |
| 158 | + // ed=D. |
| 159 | + *st_out = st; |
| 160 | + *ed_out = ed; |
| 161 | + return; |
| 162 | + } |
| 163 | + |
| 164 | + // now only remain the case that ed belongs to: [0, D) |
| 165 | + // 0 dim size case 3 |
| 166 | + if (ed >= st) { |
| 167 | + *st_out = *ed_out = 0; |
| 168 | + *zero_dim_out = true; |
| 169 | + return; |
| 170 | + } |
| 171 | + |
| 172 | + *st_out = st; |
| 173 | + *ed_out = ed; |
| 174 | + return; |
| 175 | + } |
| 176 | +} |
| 177 | + |
20 | 178 | void UpdateAttr(const phi::DDim& in_dims,
|
21 | 179 | const std::vector<int> axes,
|
22 | 180 | const std::vector<int> starts,
|
@@ -76,47 +234,17 @@ inline void CheckAndUpdateSliceAttrs(const phi::DDim in_dims,
|
76 | 234 |
|
77 | 235 | if (dim_value > 0) {
|
78 | 236 | T step = steps == nullptr ? 1 : (*steps)[i];
|
79 |
| - PADDLE_ENFORCE_NE( |
80 |
| - step, |
81 |
| - 0, |
82 |
| - phi::errors::InvalidArgument( |
83 |
| - "Step should not be 0, but received step = %d.", step)); |
84 |
| - |
85 |
| - T start = (*starts)[i] < 0 ? ((*starts)[i] + dim_value) : (*starts)[i]; |
86 |
| - start = std::max(start, static_cast<T>(0)); |
87 |
| - |
88 |
| - T end = |
89 |
| - 0 < step && (*ends)[i] < 0 ? ((*ends)[i] + dim_value) : (*ends)[i]; |
90 |
| - end = std::min(end, dim_value); |
91 |
| - |
92 |
| - if (step > 0) { |
93 |
| - start = std::min(start, dim_value); |
94 |
| - end = std::max(end, static_cast<T>(0)); |
95 |
| - PADDLE_ENFORCE_GE( |
96 |
| - end, |
97 |
| - start, |
98 |
| - phi::errors::InvalidArgument( |
99 |
| - "When step > 0, end should be greater than start, but " |
100 |
| - "received end = %d, start = %d.", |
101 |
| - end, |
102 |
| - start)); |
103 |
| - } else { |
104 |
| - // NOTE(liym27): When step < 0, start should less and equal to |
105 |
| - // dim_value-1 |
106 |
| - // "end is -1" means contain the 0-th element of this axis. |
107 |
| - start = std::min(start, dim_value - 1); |
108 |
| - if (end < -1) { |
109 |
| - end += dim_value; |
110 |
| - } |
111 |
| - end = std::max(end, static_cast<T>(-1)); |
112 |
| - PADDLE_ENFORCE_GE( |
113 |
| - start, |
114 |
| - end, |
115 |
| - phi::errors::InvalidArgument( |
116 |
| - "When step < 0, start should be greater than end, but " |
117 |
| - "received start = %d, end = %d.", |
118 |
| - start, |
119 |
| - end)); |
| 237 | + T start, end; |
| 238 | + bool dummy_zero_out_dim = false; |
| 239 | + normalize_interval((*starts)[i], |
| 240 | + (*ends)[i], |
| 241 | + step, |
| 242 | + dim_value, |
| 243 | + &start, |
| 244 | + &end, |
| 245 | + &dummy_zero_out_dim); |
| 246 | + if (end == -dim_value - 1) { |
| 247 | + end = -1; |
120 | 248 | }
|
121 | 249 |
|
122 | 250 | (*starts)[i] = start;
|
|
0 commit comments