Skip to content

Commit d1b5555

Browse files
authored
Merge pull request #433 from arcondello/fix/BinaryOp-broadcasting
Fix shape when broadcasting with BinaryOpNode
2 parents 13ab460 + 9c6943d commit d1b5555

File tree

6 files changed

+81
-21
lines changed

6 files changed

+81
-21
lines changed

dwave/optimization/src/array.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,12 @@ std::vector<ssize_t> broadcast_shape(const std::span<const ssize_t> lhs,
264264
}
265265
assert(sit == shape.rend());
266266

267+
// Check that we haven't put a dynamic axis anywhere except axis 0
268+
if (std::ranges::any_of(shape | std::views::drop(1), [](const auto& val) { return val < 0; })) {
269+
throw std::invalid_argument("operands could not be broadcast together with shapes " +
270+
shape_to_string(lhs) + " " + shape_to_string(rhs));
271+
}
272+
267273
return shape;
268274
}
269275
std::vector<ssize_t> broadcast_shape(std::initializer_list<ssize_t> lhs,

dwave/optimization/src/nodes/binaryop.cpp

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -367,25 +367,40 @@ template <class BinaryOp>
367367
std::span<const ssize_t> BinaryOpNode<BinaryOp>::shape(const State& state) const {
368368
if (!this->dynamic()) return this->shape();
369369

370-
const ssize_t lhs_size = operands_[0]->size(state);
370+
const auto [lhs, rhs] = operands_;
371371

372-
if (lhs_size == operands_[1]->size(state)) return operands_[0]->shape(state);
372+
// If we're broadcasting we know which size we're looking at
373+
if (lhs->size() == 1) return rhs->shape(state);
374+
if (rhs->size() == 1) return lhs->shape(state);
373375

374-
return (lhs_size == 1) ? operands_[1]->shape(state) : operands_[0]->shape(state);
376+
// Dev note: it's very tempting to put an assert here that checks
377+
// that lhs.shape == rhs.shape, but that can make calling shape()
378+
// very expensive in some cases.
379+
// Having this commented out makes the second if-branch above redundant, but
380+
// I am leaving this in for clarity and future debugging.
381+
// assert(std::ranges::equal(lhs->shape(state), rhs->shape(state)));
382+
383+
return lhs->shape(state);
375384
}
376385

377386
template <class BinaryOp>
378387
ssize_t BinaryOpNode<BinaryOp>::size(const State& state) const {
379-
if (ssize_t size = this->size(); size >= 0) {
380-
return size;
381-
}
388+
if (const ssize_t size = this->size(); size >= 0) return size;
389+
390+
const auto [lhs, rhs] = operands_;
382391

383-
const ssize_t lhs_size = operands_[0]->size(state);
384-
const ssize_t rhs_size = operands_[1]->size(state);
392+
// If we're broadcasting we know which size we're looking at
393+
if (lhs->size() == 1) return rhs->size(state);
394+
if (rhs->size() == 1) return lhs->size(state);
385395

386-
if (lhs_size == rhs_size) return lhs_size;
396+
// Dev note: it's very tempting to put an assert here that checks
397+
// that lhs.size == rhs.size, but that can make calling size()
398+
// very expensive in some cases.
399+
// Having this commented out makes the second if-branch above redundant, but
400+
// I am leaving this in for clarity and future debugging.
401+
// assert(lhs->size(state) == rhs->size(state)));
387402

388-
return (lhs_size == 1) ? rhs_size : lhs_size;
403+
return lhs->size(state);
389404
}
390405

391406
template <class BinaryOp>
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
---
2+
fixes:
3+
- |
4+
Change C++ function ``broadcast_shape()`` correctly throw an error when
5+
an attempt to broadcast a dynamically sized shape would result in an invalid
6+
result.
7+
See `#429 <https://github.com/dwavesystems/dwave-optimization/issues/429>`_.
8+
- |
9+
Fix ``BinaryNode::shape(const State&)`` and ``BinaryNode::size(const State&)``
10+
so that they do not propagate the wrong shape/size when both predecessors
11+
have the same shape.

tests/cpp/nodes/test_binaryop.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -466,6 +466,22 @@ TEST_CASE("BinaryOpNode - MultiplyNode") {
466466
CHECK(y_ptr->integral());
467467
}
468468
}
469+
470+
GIVEN("x = SetNode(5), y = Constant(0), z = x * y") {
471+
// This test is for a specific bug we had where it would accidentlly
472+
// pick up the shape from the lhs if the size of the lhs and rhs were
473+
// both 1.
474+
475+
auto x_ptr = graph.emplace_node<SetNode>(5);
476+
auto y_ptr = graph.emplace_node<ConstantNode>(0);
477+
auto z_ptr = graph.emplace_node<MultiplyNode>(y_ptr, x_ptr);
478+
479+
auto state = graph.empty_state();
480+
x_ptr->initialize_state(state, {0});
481+
graph.initialize_state(state);
482+
483+
CHECK_THAT(z_ptr->shape(state), RangeEquals({1}));
484+
}
469485
}
470486

471487
TEST_CASE("BinaryOpNode - DivideNode") {

tests/cpp/test_array.cpp

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -481,32 +481,39 @@ TEST_CASE("Update") {
481481
}
482482
}
483483

484-
TEST_CASE("Test resulting_shape()") {
484+
TEST_CASE("Test broadcast_shape()") {
485485
SECTION("(256,256,3) x (3,) -> (256,256,3)") {
486-
CHECK(std::ranges::equal(broadcast_shape({256, 256, 3}, {3}), std::vector{256, 256, 3}));
486+
CHECK_THAT(broadcast_shape({256, 256, 3}, {3}), RangeEquals({256, 256, 3}));
487487
}
488488
SECTION("(8,1,6,1) x (7,1,5) -> (8,7,6,5)") {
489-
CHECK(std::ranges::equal(broadcast_shape({8, 1, 6, 1}, {7, 1, 5}),
490-
std::vector{8, 7, 6, 5}));
489+
CHECK_THAT(broadcast_shape({8, 1, 6, 1}, {7, 1, 5}), RangeEquals({8, 7, 6, 5}));
491490
}
492491
SECTION("(7,1,5) x (8,1,6,1) -> (8,7,6,5)") {
493-
CHECK(std::ranges::equal(broadcast_shape({7, 1, 5}, {8, 1, 6, 1}),
494-
std::vector{8, 7, 6, 5}));
492+
CHECK_THAT(broadcast_shape({7, 1, 5}, {8, 1, 6, 1}), RangeEquals({8, 7, 6, 5}));
495493
}
496494
SECTION("(5,4) x (1,) -> (5, 4)") {
497-
CHECK(std::ranges::equal(broadcast_shape({5, 4}, {1}), std::vector{5, 4}));
495+
CHECK_THAT(broadcast_shape({5, 4}, {1}), RangeEquals({5, 4}));
498496
}
499497
SECTION("(5,4) x (4,) -> (5, 4)") {
500-
CHECK(std::ranges::equal(broadcast_shape({5, 4}, {4}), std::vector{5, 4}));
498+
CHECK_THAT(broadcast_shape({5, 4}, {4}), RangeEquals({5, 4}));
501499
}
502500
SECTION("(15,3,5) x (15,1,5) -> (15,3,5)") {
503-
CHECK(std::ranges::equal(broadcast_shape({15, 3, 5}, {15, 1, 5}), std::vector{15, 3, 5}));
501+
CHECK_THAT(broadcast_shape({15, 3, 5}, {15, 1, 5}), RangeEquals({15, 3, 5}));
504502
}
505503
SECTION("(15,3,5) x (3,5) -> (15,3,5)") {
506-
CHECK(std::ranges::equal(broadcast_shape({15, 3, 5}, {3, 5}), std::vector{15, 3, 5}));
504+
CHECK_THAT(broadcast_shape({15, 3, 5}, {3, 5}), RangeEquals({15, 3, 5}));
507505
}
508506
SECTION("(15,3,5) x (3,1) -> (15,3,5)") {
509-
CHECK(std::ranges::equal(broadcast_shape({15, 3, 5}, {3, 1}), std::vector{15, 3, 5}));
507+
CHECK_THAT(broadcast_shape({15, 3, 5}, {3, 1}), RangeEquals({15, 3, 5}));
508+
}
509+
SECTION("(-1,) x (1,) -> (-1,)") { CHECK_THAT(broadcast_shape({-1}, {1}), RangeEquals({-1})); }
510+
SECTION("(1,) x (-1,) -> (-1,)") { CHECK_THAT(broadcast_shape({1}, {-1}), RangeEquals({-1})); }
511+
SECTION("(1,2) x (-1,1) -> (-1,2)") {
512+
CHECK_THAT(broadcast_shape({1, 2}, {-1, 1}), RangeEquals({-1, 2}));
513+
}
514+
SECTION("(1,1) x (-1,) -> invalid") {
515+
CHECK_THROWS_WITH(broadcast_shape({1, 1}, {-1}),
516+
"operands could not be broadcast together with shapes (1, 1) (-1,)");
510517
}
511518
SECTION("(3,) x (4,)") {
512519
CHECK_THROWS_WITH(broadcast_shape({3}, {4}),

tests/test_symbols.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,11 @@ def test_broadcasting(self):
164164
with self.assertRaises(ValueError):
165165
a + b
166166

167+
c = model.constant([[[1]]])
168+
d = model.set(5)
169+
with self.assertRaises(ValueError):
170+
c + d
171+
167172
def test_scalar_addition(self):
168173
model = Model()
169174
a = model.constant(5)

0 commit comments

Comments
 (0)