Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions src/AssociativeOpsTable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,8 @@ void populate_ops_table_single_uint8_cast(const vector<Type> &types, vector<Asso

void populate_ops_table_single_uint8_select(const vector<Type> &types, vector<AssociativePattern> &table) {
declare_vars_single(types);
table.emplace_back(select(x0 > tmax_0 - y0, tmax_0, y0), zero_0, true); // Saturating add
table.emplace_back(select(x0 < -y0, y0, tmax_0), zero_0, true); // Saturating add
table.emplace_back(select(x0 > tmax_0 - y0, tmax_0, x0 + y0), zero_0, true); // Saturating add
table.emplace_back(select(x0 < -y0, x0 + y0, tmax_0), zero_0, true); // Saturating add
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand this condition: x0 < -y0 doesn't make sense to me with unsigned integer types.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

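In uint8, -y0 == 256 - y0 == (255 - y0) + 1, which only overflows when y0 is zero, so let's consider that case first. If y0 == 0, then -y0 wraps around to 0, so x0 < -y0 is always false and we return 255. That seems wrong: adding zero should return x0 unchanged. It's fine when y0 != 0, because then x0 < 256 - y0 is exactly the condition that x0 + y0 does not overflow.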

Maybe it's meant to be x0 < ~y0?

}

void populate_ops_table_single_uint16_cast(const vector<Type> &types, vector<AssociativePattern> &table) {
Expand All @@ -233,8 +233,8 @@ void populate_ops_table_single_uint16_cast(const vector<Type> &types, vector<Ass

// Appends select-based "saturating add" entries to `table`; going by the
// function name these are the single-operand uint16 variants.
//
// Each entry is built from an expression over x0/y0, the value zero_0, and the
// flag `true`. What the second and third constructor arguments mean is not
// visible here -- presumably identity and commutativity; confirm against
// AssociativePattern.
//
// NOTE(review): this text comes from a diff hunk whose +/- markers were lost.
// The first two entries (returning y0) look like the pre-change lines and the
// last two (returning x0 + y0) like their replacements -- confirm that only
// the x0 + y0 forms remain in the real file.
void populate_ops_table_single_uint16_select(const vector<Type> &types, vector<AssociativePattern> &table) {
// presumably brings x0, y0, tmax_0 and zero_0 into scope -- verify against the helper/macro
declare_vars_single(types);
table.emplace_back(select(x0 > tmax_0 - y0, tmax_0, y0), zero_0, true); // Saturating add
table.emplace_back(select(x0 < -y0, y0, tmax_0), zero_0, true); // Saturating add
// Overflow test first: yields tmax_0 when the condition holds, x0 + y0 otherwise.
table.emplace_back(select(x0 > tmax_0 - y0, tmax_0, x0 + y0), zero_0, true); // Saturating add
// NOTE(review): if x0/y0 are unsigned (as the name suggests), -y0 is 0 when
// y0 == 0, so `x0 < -y0` is false and this form yields tmax_0 rather than x0.
// A reviewer suggested `x0 < ~y0` may have been intended -- confirm.
table.emplace_back(select(x0 < -y0, x0 + y0, tmax_0), zero_0, true); // Saturating add
}

void populate_ops_table_single_uint32_cast(const vector<Type> &types, vector<AssociativePattern> &table) {
Expand All @@ -245,8 +245,8 @@ void populate_ops_table_single_uint32_cast(const vector<Type> &types, vector<Ass

// Appends select-based "saturating add" entries to `table`; going by the
// function name these are the single-operand uint32 variants.
//
// Each entry is built from an expression over x0/y0, the value zero_0, and the
// flag `true`. What the second and third constructor arguments mean is not
// visible here -- presumably identity and commutativity; confirm against
// AssociativePattern.
//
// NOTE(review): this text comes from a diff hunk whose +/- markers were lost.
// The first two entries (returning y0) look like the pre-change lines and the
// last two (returning x0 + y0) like their replacements -- confirm that only
// the x0 + y0 forms remain in the real file.
void populate_ops_table_single_uint32_select(const vector<Type> &types, vector<AssociativePattern> &table) {
// presumably brings x0, y0, tmax_0 and zero_0 into scope -- verify against the helper/macro
declare_vars_single(types);
table.emplace_back(select(x0 > tmax_0 - y0, tmax_0, y0), zero_0, true); // Saturating add
table.emplace_back(select(x0 < -y0, y0, tmax_0), zero_0, true); // Saturating add
// Overflow test first: yields tmax_0 when the condition holds, x0 + y0 otherwise.
table.emplace_back(select(x0 > tmax_0 - y0, tmax_0, x0 + y0), zero_0, true); // Saturating add
// NOTE(review): if x0/y0 are unsigned (as the name suggests), -y0 is 0 when
// y0 == 0, so `x0 < -y0` is false and this form yields tmax_0 rather than x0.
// A reviewer suggested `x0 < ~y0` may have been intended -- confirm.
table.emplace_back(select(x0 < -y0, x0 + y0, tmax_0), zero_0, true); // Saturating add
}

void populate_ops_table_single_float_select(const vector<Type> &types, vector<AssociativePattern> &table) {
Expand Down
4 changes: 2 additions & 2 deletions src/Associativity.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -544,8 +544,8 @@ void associativity_test() {
Expr f_call_0 = Call::make(t, "f", {x_idx}, Call::CallType::Halide, FunctionPtr(), 0);

for (const Expr &e : {cast<uint8_t>(min(cast<uint16_t>(x) + y, 255)),
select(x > 255 - y, cast<uint8_t>(255), y),
select(x < -y, y, cast<uint8_t>(255)),
select(x > 255 - y, cast<uint8_t>(255), x + y),
select(x < -y, x + y, cast<uint8_t>(255)),
saturating_add(x, y),
saturating_add(y, x),
saturating_cast<uint8_t>(widening_add(x, y))}) {
Expand Down
54 changes: 54 additions & 0 deletions test/correctness/rfactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -825,6 +825,59 @@ int argmin_rfactor_test() {
return 0;
}

int saturating_add_rfactor_test() {
Func f("f"), g("g"), ref("ref");
Var x("x"), y("y"), z("z");

f(x) = cast<uint8_t>(x);
f.compute_root();

Param<int> inner_extent;
RDom r(10, inner_extent);
inner_extent.set(6);
uint8_t max_int = 255;

g() = Tuple(cast<uint8_t>(0), cast<uint8_t>(0));
g() = Tuple(select(g()[0] > max_int - 3 * f(r.x), max_int, g()[0] + 3 * f(r.x)),
select(g()[1] > max_int - 9 * f(r.x), max_int, 9 * f(r.x) + g()[1]));

RVar rxi("rxi"), rxo("rxo");
g.update(0).split(r.x, rxo, rxi, 2);

Var u("u");
Func intm = g.update(0).rfactor(rxo, u);
intm.compute_root();
intm.update(0).vectorize(u, 2);

Realization rn = g.realize();
Buffer<uint8_t> im1(rn[0]);
Buffer<uint8_t> im2(rn[1]);

auto func1 = [](int x, int y, int z) {
int ret = 0;
for (int i = 10; i < 16; i++) {
ret += 3 * i;
}
return std::min(ret, 255);
};
if (check_image(im1, func1)) {
return 1;
}

auto func2 = [](int x, int y, int z) {
int ret = 0;
for (int i = 10; i < 16; i++) {
ret += 9 * i;
}
return std::min(ret, 255);
};
if (check_image(im2, func2)) {
return 1;
}

return 0;
}

int allocation_bound_test_trace(JITUserContext *user_context, const halide_trace_event_t *e) {
// The schedule implies that f will be stored from 0 to 1
if (e->event == 2 && std::string(e->func) == "f") {
Expand Down Expand Up @@ -1156,6 +1209,7 @@ int main(int argc, char **argv) {
{"rfactor tile reorder test: checking output img correctness...", rfactor_tile_reorder_test},
{"complex multiply rfactor test", complex_multiply_rfactor_test},
{"argmin rfactor test", argmin_rfactor_test},
{"saturating add rfactor test", saturating_add_rfactor_test},
{"inlined rfactor with disappearing rvar test", inlined_rfactor_with_disappearing_rvar_test},
{"rfactor bounds tests", rfactor_precise_bounds_test},
{"isnan max rfactor test (bitwise or)", isnan_max_rfactor_test<BitwiseOr>},
Expand Down
Loading