From 9306dac487a2af5af1e55956456843ec43a1e9b7 Mon Sep 17 00:00:00 2001 From: William Bernoudy Date: Fri, 21 Nov 2025 09:38:54 -0800 Subject: [PATCH 01/10] Add MatMulNode --- .../include/dwave-optimization/array.hpp | 25 + .../nodes/linear_algebra.hpp | 76 +++ .../optimization/src/nodes/linear_algebra.cpp | 294 ++++++++++++ meson.build | 1 + tests/cpp/meson.build | 3 +- tests/cpp/nodes/test_linear_algebra.cpp | 440 ++++++++++++++++++ 6 files changed, 838 insertions(+), 1 deletion(-) create mode 100644 dwave/optimization/include/dwave-optimization/nodes/linear_algebra.hpp create mode 100644 dwave/optimization/src/nodes/linear_algebra.cpp create mode 100644 tests/cpp/nodes/test_linear_algebra.cpp diff --git a/dwave/optimization/include/dwave-optimization/array.hpp b/dwave/optimization/include/dwave-optimization/array.hpp index d9330426..af6b6631 100644 --- a/dwave/optimization/include/dwave-optimization/array.hpp +++ b/dwave/optimization/include/dwave-optimization/array.hpp @@ -76,6 +76,31 @@ struct SizeInfo { } bool operator==(const SizeInfo& other) const; + constexpr SizeInfo& operator*=(const std::integral auto n) { + multiplier *= n; + offset *= n; + if (min.has_value()) min.value() *= n; + if (max.has_value()) max.value() *= n; + return *this; + } + friend SizeInfo operator*(SizeInfo lhs, const std::integral auto rhs) { + lhs *= rhs; + return lhs; + } + + constexpr SizeInfo& operator/=(const std::integral auto n) { + if (!n) throw std::invalid_argument("cannot divide by 0"); + multiplier /= n; + offset /= n; + if (min.has_value()) min.value() /= n; + if (max.has_value()) max.value() /= n; + return *this; + } + friend SizeInfo operator/(SizeInfo lhs, const std::integral auto rhs) { + lhs /= rhs; + return lhs; + } + // SizeInfos are printable friend std::ostream& operator<<(std::ostream& os, const SizeInfo& sizeinfo); diff --git a/dwave/optimization/include/dwave-optimization/nodes/linear_algebra.hpp 
b/dwave/optimization/include/dwave-optimization/nodes/linear_algebra.hpp new file mode 100644 index 00000000..e79e8305 --- /dev/null +++ b/dwave/optimization/include/dwave-optimization/nodes/linear_algebra.hpp @@ -0,0 +1,76 @@ +// Copyright 2025 D-Wave Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "dwave-optimization/array.hpp" +#include "dwave-optimization/graph.hpp" + +namespace dwave::optimization { + +class MatMulNode : public ArrayOutputMixin { + public: + MatMulNode(ArrayNode* x_ptr, ArrayNode* y_ptr); + + /// @copydoc Array::buff() + double const* buff(const State& state) const override; + + /// @copydoc Node::commit() + void commit(State& state) const override; + + /// @copydoc Array::diff() + std::span diff(const State& state) const override; + + /// @copydoc Node::initialize_state() + void initialize_state(State& state) const override; + + /// @copydoc Array::integral() + bool integral() const override; + + /// @copydoc Array::max() + double max() const override; + + /// @copydoc Array::min() + double min() const override; + + /// @copydoc Node::propagate() + void propagate(State& state) const override; + + /// @copydoc Node::revert() + void revert(State& state) const override; + + using Array::shape; + + std::span shape(const State& state) const override; + + using Array::size; + + ssize_t size(const State& state) const override; + ssize_t size_diff(const State& state) const override; + SizeInfo 
sizeinfo() const override; + + private: + void matmul(State& state, std::span out, std::span out_shape) const; + void update_shape(State& state) const; + + const ArrayNode* x_ptr_; + const ArrayNode* y_ptr_; + + const SizeInfo sizeinfo_; + const ValuesInfo values_info_; +}; + +} // namespace dwave::optimization diff --git a/dwave/optimization/src/nodes/linear_algebra.cpp b/dwave/optimization/src/nodes/linear_algebra.cpp new file mode 100644 index 00000000..eb046c71 --- /dev/null +++ b/dwave/optimization/src/nodes/linear_algebra.cpp @@ -0,0 +1,294 @@ +// Copyright 2025 D-Wave Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "dwave-optimization/nodes/linear_algebra.hpp" + +#include "../functional_.hpp" +#include "_state.hpp" +#include "dwave-optimization/array.hpp" +#include "dwave-optimization/state.hpp" + +namespace dwave::optimization { + +////////////////////// MatMulNode + +// Valid shapes to multiply, and the resulting shape +// (-1, 2, 5, 3) and (-1, 2, 3, 7) -> (-1, 2, 5, 7) +// (-1, 3) and (3) -> (-1) +// (-1, 3) and (3, 7) -> (-1, 7) +// (-1) and (-1) -> () +// (-1) and (-1, 5) -> (5) + +ssize_t size_from_shape(std::span shape) { + return std::reduce(shape.begin(), shape.end(), 1, std::multiplies()); +} + +ssize_t get_axis_size(std::span shape, ssize_t index, bool vector_as_row) { + // If vector_as_row is true, treat vector as shape (1, size), else as shape (size, 1) + assert(index < 0); + if (shape.size() == 0) return 1; + if (shape.size() == 1 and index == -2 and vector_as_row) return 1; + if (shape.size() == 1 and index == -1 and not vector_as_row) return 1; + if (shape.size() == 1) return shape.back(); + return shape[shape.size() + index]; +} + +std::vector output_shape(const ArrayNode* x_ptr, const ArrayNode* y_ptr) { + if (x_ptr->ndim() == 0 or y_ptr->ndim() == 0) { + throw std::invalid_argument("operands cannot be scalar"); + } + + // Check that last dimension of x matches the second last dimension of y + ssize_t x_last_axis_size = get_axis_size(x_ptr->shape(), -1, true); + ssize_t y_penultimate_axis_size = get_axis_size(y_ptr->shape(), -2, false); + if (x_last_axis_size != y_penultimate_axis_size) { + throw std::invalid_argument( + "the last dimension of `x` is not the same size as the second to last dimension of " + "`y`"); + } else if (x_last_axis_size == -1) { + assert(x_ptr->dynamic() && y_ptr->dynamic()); + // Both are dynamic. We need to check that the dynamic dimension is + // always the same size. 
+ ssize_t x_subspace_size = -1 * size_from_shape(x_ptr->shape()); + ssize_t y_subspace_size = -1 * size_from_shape(y_ptr->shape()); + if (x_ptr->sizeinfo() / x_subspace_size != y_ptr->sizeinfo() / y_subspace_size) { + throw std::invalid_argument( + "the last dimension of `x` is not the same size as the second to last " + "dimension of `y`"); + } + } + + // Now check that the leading subspace shape is identical (no broadcasting for now) + if (x_ptr->ndim() > 2 && y_ptr->ndim() > 2) { + if (x_ptr->ndim() != y_ptr->ndim()) { + throw std::invalid_argument( + "operands have different dimensions (use BroadcastNode if you wish to " + "broadcast missing dimensions)"); + } + for (ssize_t i = 0, stop = x_ptr->ndim() - 2; i < stop; i++) { + if (x_ptr->shape()[i] != y_ptr->shape()[i]) { + throw std::invalid_argument( + "operands must have matching leading shape (up to the last two " + "dimensions)"); + } + } + } + + // Now we now the leading axes match, we can construct the output shape + std::vector shape; + for (ssize_t d : x_ptr->shape() | std::views::take(x_ptr->ndim() - 1)) { + shape.push_back(d); + } + if (y_ptr->ndim() >= 2) { + shape.push_back(y_ptr->shape().back()); + } + + return shape; +} + +SizeInfo get_sizeinfo(const ArrayNode* x_ptr, const ArrayNode* y_ptr) { + if (y_ptr->dynamic() and y_ptr->ndim() <= 2) { + // x must also be dynamic, and we must be contracting along the dynamic + // dimension, so the output is fixed size. 
+ std::vector shape = output_shape(x_ptr, y_ptr); + ssize_t size = size_from_shape(shape); + assert(size >= 1); + return SizeInfo(size); + } + assert(x_ptr->shape().back() != -1); + SizeInfo sizeinfo = x_ptr->sizeinfo() / x_ptr->shape().back(); + if (y_ptr->ndim() == 2 && y_ptr->dynamic()) { + assert(x_ptr->dynamic() && x_ptr->ndim() == 1); + } else if (y_ptr->ndim() >= 2) { + assert(y_ptr->shape().back() != -1); + sizeinfo *= y_ptr->shape().back(); + } + return sizeinfo; +} + +ValuesInfo get_values_info(const ArrayNode* x_ptr, const ArrayNode* y_ptr) { + // get all possible combinations of values + std::array combos{x_ptr->min() * y_ptr->min(), x_ptr->min() * y_ptr->max(), + x_ptr->max() * y_ptr->min(), x_ptr->max() * y_ptr->max()}; + + double min_val = std::ranges::min(combos); + double max_val = std::ranges::max(combos); + + ssize_t x_subspace_size = std::reduce(x_ptr->shape().begin(), x_ptr->shape().end() - 1, 1, + std::multiplies()); + SizeInfo contracted_axis_size = x_ptr->sizeinfo() / x_subspace_size; + + if (contracted_axis_size.max.has_value() and *contracted_axis_size.max == 0) { + // Output will always be empty, so we can return early + return ValuesInfo(0.0, 0.0, true); + } + + // Use default constructor to get default min/max + ValuesInfo values_info; + values_info.integral = x_ptr->integral() && y_ptr->integral(); + + if (contracted_axis_size.max.has_value()) { + if (max_val >= 0) values_info.max = max_val * contracted_axis_size.max.value(); + if (min_val <= 0) values_info.min = min_val * contracted_axis_size.max.value(); + } + + ssize_t min_size = contracted_axis_size.min.value_or(0); + if (max_val < 0) values_info.max = max_val * min_size; + if (min_val > 0) values_info.min = min_val * min_size; + + return values_info; +} + +std::vector atleast_2d_shape(std::span shape, bool as_row) { + if (shape.size() == 0) return {1, 1}; + if (shape.size() == 1 and as_row) return {1, shape[0]}; + if (shape.size() == 1 and not as_row) return {shape[0], 1}; + 
return {shape.begin(), shape.end()}; +} + +MatMulNode::MatMulNode(ArrayNode* x_ptr, ArrayNode* y_ptr) + : ArrayOutputMixin(output_shape(x_ptr, y_ptr)), + x_ptr_(x_ptr), + y_ptr_(y_ptr), + sizeinfo_(get_sizeinfo(x_ptr, y_ptr)), + values_info_(get_values_info(x_ptr, y_ptr)) {} + +class MatMulNodeData : public ArrayNodeStateData { + public: + explicit MatMulNodeData(std::vector&& values, std::span shape) + : ArrayNodeStateData(std::move(values)), shape(shape.begin(), shape.end()) {} + + std::vector output; + std::vector shape; +}; + +ssize_t get_leading_stride(std::span shape) { + ssize_t stride = 1; + if (shape.size() >= 1) stride *= shape.back(); + if (shape.size() >= 2) stride *= shape[shape.size() - 2]; + return stride; +} + +void MatMulNode::matmul(State& state, std::span out, + std::span out_shape) const { + auto x_data = x_ptr_->view(state); + auto y_data = y_ptr_->view(state); + + ssize_t x_penultimate_axis_size = get_axis_size(x_ptr_->shape(state), -2, true); + ssize_t x_penultimate_axis = std::max(x_ptr_->ndim() - 2, 0); + ssize_t leading_subspace_size = std::reduce(x_ptr_->shape(state).begin(), + x_ptr_->shape(state).begin() + x_penultimate_axis, + 1, std::multiplies()); + + ssize_t x_leading_stride = get_leading_stride(x_ptr_->shape(state)); + ssize_t y_leading_stride = get_leading_stride(y_ptr_->shape(state)); + ssize_t out_leading_stride = get_leading_stride(out_shape); + + ssize_t y_last_axis_size = get_axis_size(y_ptr_->shape(state), -1, false); + ssize_t y_penultimate_axis_size = get_axis_size(y_ptr_->shape(state), -2, false); + + // TODO: consider using the parent arrays' strides directly + ssize_t x_penultimate_stride = get_axis_size(x_ptr_->shape(state), -1, true); + ssize_t x_last_stride = 1; + + ssize_t y_penultimate_stride = y_last_axis_size; + ssize_t y_last_stride = 1; + + ssize_t out_following_stride = get_axis_size(out_shape, -1, false); + + for (ssize_t w = 0; w < leading_subspace_size; w++) { + for (ssize_t i = 0; i < 
x_penultimate_axis_size; i++) { + for (ssize_t j = 0; j < y_last_axis_size; j++) { + auto x = x_data.begin() + w * x_leading_stride + i * x_penultimate_stride; + auto y = y_data.begin() + w * y_leading_stride + j * y_last_stride; + double& out_val = out[w * out_leading_stride + i * out_following_stride + j]; + out_val = 0.0; + for (ssize_t k = 0; k < y_penultimate_axis_size; k++) { + out_val += *x * *y; + x += x_last_stride; + y += y_penultimate_stride; + } + } + } + } +} + +void MatMulNode::initialize_state(State& state) const { + ssize_t start_size = this->size(); + std::vector shape(this->shape().begin(), this->shape().end()); + if (this->dynamic()) { + shape[0] = x_ptr_->shape(state)[0]; + start_size = size_from_shape(shape); + } + + std::vector data(start_size); + matmul(state, data, shape); + emplace_data_ptr(state, std::move(data), shape); +} + +double const* MatMulNode::buff(const State& state) const { + return data_ptr(state)->buff(); +} + +void MatMulNode::commit(State& state) const { return data_ptr(state)->commit(); } + +std::span MatMulNode::diff(const State& state) const { + return data_ptr(state)->diff(); +} + +bool MatMulNode::integral() const { return values_info_.integral; } + +double MatMulNode::max() const { return values_info_.max; } + +double MatMulNode::min() const { return values_info_.min; } + +void MatMulNode::update_shape(State& state) const { + if (this->dynamic()) { + data_ptr(state)->shape[0] = x_ptr_->shape(state)[0]; + } +} + +void MatMulNode::propagate(State& state) const { + auto data = data_ptr(state); + this->update_shape(state); + ssize_t new_size = size_from_shape(data->shape); + data->output.resize(new_size); + this->matmul(state, data->output, data->shape); + data->assign(data->output); +} + +void MatMulNode::revert(State& state) const { + auto data = data_ptr(state); + data->revert(); + this->update_shape(state); +} + +std::span MatMulNode::shape(const State& state) const { + if (not this->dynamic()) return this->shape(); + 
return data_ptr(state)->shape; +} + +ssize_t MatMulNode::size(const State& state) const { + if (not this->dynamic()) return this->size(); + return data_ptr(state)->size(); +} + +ssize_t MatMulNode::size_diff(const State& state) const { + if (not this->dynamic()) return 0; + return data_ptr(state)->size_diff(); +} + +SizeInfo MatMulNode::sizeinfo() const { return sizeinfo_; } + +} // namespace dwave::optimization diff --git a/meson.build b/meson.build index 1fba2658..9c601685 100644 --- a/meson.build +++ b/meson.build @@ -35,6 +35,7 @@ dwave_optimization_src = [ 'dwave/optimization/src/nodes/inputs.cpp', 'dwave/optimization/src/nodes/interpolation.cpp', 'dwave/optimization/src/nodes/lambda.cpp', + 'dwave/optimization/src/nodes/linear_algebra.cpp', 'dwave/optimization/src/nodes/lp.cpp', 'dwave/optimization/src/nodes/manipulation.cpp', 'dwave/optimization/src/nodes/naryop.cpp', diff --git a/tests/cpp/meson.build b/tests/cpp/meson.build index 7e840d26..ef23f764 100644 --- a/tests/cpp/meson.build +++ b/tests/cpp/meson.build @@ -16,8 +16,9 @@ tests_all = executable( 'nodes/test_flow.cpp', 'nodes/test_inputs.cpp', 'nodes/test_interpolation.cpp', - 'nodes/test_lp.cpp', 'nodes/test_lambda.cpp', + 'nodes/test_linear_algebra.cpp', + 'nodes/test_lp.cpp', 'nodes/test_manipulation.cpp', 'nodes/test_naryop.cpp', 'nodes/test_numbers.cpp', diff --git a/tests/cpp/nodes/test_linear_algebra.cpp b/tests/cpp/nodes/test_linear_algebra.cpp new file mode 100644 index 00000000..44854b6c --- /dev/null +++ b/tests/cpp/nodes/test_linear_algebra.cpp @@ -0,0 +1,440 @@ +// Copyright 2025 D-Wave +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "dwave-optimization/nodes/binaryop.hpp" +#include "dwave-optimization/nodes/collections.hpp" +#include "dwave-optimization/nodes/constants.hpp" +#include "dwave-optimization/nodes/indexing.hpp" +#include "dwave-optimization/nodes/linear_algebra.hpp" +#include "dwave-optimization/nodes/manipulation.hpp" +#include "dwave-optimization/nodes/testing.hpp" + +using Catch::Matchers::RangeEquals; + +namespace dwave::optimization { + +TEST_CASE("MatMulNode") { + auto graph = Graph(); + + GIVEN("A dynamic testing array node and matmul on it") { + auto arr = DynamicArrayTestingNode(std::initializer_list{-1}, -10.0, -5.0, false); + auto constant = ConstantNode(15.0); + auto add = AddNode(&arr, &constant); + REQUIRE(add.min() == 5.0); + REQUIRE(add.max() == 10.0); + auto matmul = MatMulNode(&arr, &add); + THEN("MatMulNode reports correct min and max") { + CHECK(matmul.min() == ValuesInfo().min); + CHECK(matmul.max() == 0.0); + } + } + + GIVEN("A dynamic testing array node with minimum size and matmul on it") { + auto arr = DynamicArrayTestingNode(std::initializer_list{-1}, -10.0, -5.0, false, + 3, std::nullopt); + auto constant = ConstantNode(15.0); + auto add = AddNode(&arr, &constant); + REQUIRE(add.min() == 5.0); + REQUIRE(add.max() == 10.0); + auto matmul = MatMulNode(&arr, &add); + THEN("MatMulNode reports correct min and max") { + CHECK(matmul.min() == ValuesInfo().min); + CHECK(matmul.max() == -5.0 * 5.0 * 3); + } + } + + GIVEN("A dynamic testing array node with minimum and maximum size and matmul on 
it") { + auto arr = DynamicArrayTestingNode(std::initializer_list{-1}, -10.0, -5.0, false, + 3, 7); + auto constant = ConstantNode(15.0); + auto add = AddNode(&arr, &constant); + REQUIRE(add.min() == 5.0); + REQUIRE(add.max() == 10.0); + auto matmul = MatMulNode(&arr, &add); + THEN("MatMulNode reports correct min and max") { + CHECK(matmul.min() == -10.0 * 10.0 * 7); + CHECK(matmul.max() == -5.0 * 5.0 * 3); + } + } + + GIVEN("Two constant 1d nodes and a MatMulNode") { + auto c1_ptr = graph.emplace_node(std::vector{1.0, 2.0, 3.0}); + auto c2_ptr = graph.emplace_node(std::vector{4.0, 5.0, 6.0}); + + auto matmul_ptr = graph.emplace_node(c1_ptr, c2_ptr); + + graph.emplace_node(matmul_ptr); + + CHECK(matmul_ptr->size() == 1); + CHECK(matmul_ptr->ndim() == 0); + + CHECK(matmul_ptr->min() == 1.0 * 4.0 * 3); + CHECK(matmul_ptr->max() == 3.0 * 6.0 * 3); + + WHEN("We initialize a state") { + auto state = graph.initialize_state(); + + THEN("The initial MatMulNode state is correct") { + CHECK_THAT(matmul_ptr->view(state), RangeEquals({32})); + } + } + } + + GIVEN("Two constant 2d nodes and a MatMulNode") { + auto c1_ptr = graph.emplace_node(std::vector{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, + std::vector{2, 3}); + auto c2_ptr = graph.emplace_node(std::vector{7.0, 8.0, 9.0, 10.0, 11.0, 12.0}, + std::vector{3, 2}); + + auto matmul_ptr = graph.emplace_node(c1_ptr, c2_ptr); + + graph.emplace_node(matmul_ptr); + + CHECK(matmul_ptr->size() == 4); + CHECK(matmul_ptr->ndim() == 2); + CHECK_THAT(matmul_ptr->shape(), RangeEquals({2, 2})); + + CHECK(matmul_ptr->min() == 1.0 * 7.0 * 3); + CHECK(matmul_ptr->max() == 6.0 * 12.0 * 3); + + WHEN("We initialize a state") { + auto state = graph.initialize_state(); + + THEN("The initial MatMulNode state is correct") { + CHECK_THAT(matmul_ptr->view(state), RangeEquals({58, 64, 139, 154})); + } + } + } + + GIVEN("Two constant 2d nodes and a MatMulNode") { + auto c1_ptr = graph.emplace_node(std::vector{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, + std::vector{2, 3}); + 
auto c2_ptr = graph.emplace_node(std::vector{7.0, 8.0, 9.0, 10.0, 11.0, 12.0}, + std::vector{3, 2}); + + auto matmul_ptr = graph.emplace_node(c2_ptr, c1_ptr); + + graph.emplace_node(matmul_ptr); + + CHECK(matmul_ptr->size() == 9); + CHECK(matmul_ptr->ndim() == 2); + CHECK_THAT(matmul_ptr->shape(), RangeEquals({3, 3})); + + WHEN("We initialize a state") { + auto state = graph.initialize_state(); + + THEN("The initial MatMulNode state is correct") { + CHECK_THAT(matmul_ptr->view(state), + RangeEquals({39, 54, 69, 49, 68, 87, 59, 82, 105})); + } + } + } + + GIVEN("One constant 1d node and one constant 2d node") { + auto c1_ptr = graph.emplace_node(std::vector{1.0, 2.0, 3.0}, + std::vector{3}); + auto c2_ptr = graph.emplace_node( + std::vector{4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0}, + std::vector{3, 3}); + + AND_GIVEN("The 1d node @ the 2d node") { + auto matmul_ptr = graph.emplace_node(c1_ptr, c2_ptr); + + graph.emplace_node(matmul_ptr); + + CHECK(matmul_ptr->size() == 3); + CHECK(matmul_ptr->ndim() == 1); + + WHEN("We initialize a state") { + auto state = graph.initialize_state(); + + THEN("The initial MatMulNode state is correct") { + CHECK_THAT(matmul_ptr->view(state), RangeEquals({48, 54, 60})); + } + } + } + + AND_GIVEN("The 2d node @ the 1d node") { + auto matmul_ptr = graph.emplace_node(c2_ptr, c1_ptr); + + graph.emplace_node(matmul_ptr); + + CHECK(matmul_ptr->size() == 3); + CHECK(matmul_ptr->ndim() == 1); + + WHEN("We initialize a state") { + auto state = graph.initialize_state(); + + THEN("The initial MatMulNode state is correct") { + CHECK_THAT(matmul_ptr->view(state), RangeEquals({32, 50, 68})); + } + } + } + } + + GIVEN("Two 4d constant nodes and a MatMulNode") { + std::vector c1_shape{5, 3, 1, 7}; + ssize_t c1_size = 5 * 3 * 1 * 7; + std::vector c1_data(c1_size); + std::iota(c1_data.begin(), c1_data.end(), -20.0); + auto c1_ptr = graph.emplace_node(c1_data, c1_shape); + REQUIRE(c1_ptr->min() == -20.0); + REQUIRE(c1_ptr->max() == 84); + + 
std::vector c2_shape{5, 3, 7, 2}; + ssize_t c2_size = 5 * 3 * 7 * 2; + std::vector c2_data(c2_size); + std::iota(c2_data.begin(), c2_data.end(), -10.0); + auto c2_ptr = graph.emplace_node(c2_data, c2_shape); + REQUIRE(c2_ptr->min() == -10.0); + REQUIRE(c2_ptr->max() == 199.0); + + auto matmul_ptr = graph.emplace_node(c1_ptr, c2_ptr); + + graph.emplace_node(matmul_ptr); + + CHECK(matmul_ptr->size() == 5 * 3 * 1 * 2); + CHECK(matmul_ptr->ndim() == 4); + CHECK_THAT(matmul_ptr->shape(), RangeEquals({5, 3, 1, 2})); + + CHECK(matmul_ptr->min() == -20.0 * 199.0 * 7); + CHECK(matmul_ptr->max() == 84.0 * 199.0 * 7); + + WHEN("We initialize a state") { + auto state = graph.initialize_state(); + + THEN("The initial MatMulNode state is correct") { + CHECK_THAT(matmul_ptr->view(state), + RangeEquals({532, 413, -644, -714, -448, -469, 1120, 1148, + 4060, 4137, 8372, 8498, 14056, 14231, 21112, 21336, + 29540, 29813, 39340, 39662, 50512, 50883, 63056, 63476, + 76972, 77441, 92260, 92778, 108920, 109487})); + } + } + } + + GIVEN("A set node and MatMulNode representing self dot product") { + auto set_ptr = graph.emplace_node(10); + auto matmul_ptr = graph.emplace_node(set_ptr, set_ptr); + graph.emplace_node(matmul_ptr); + + CHECK(matmul_ptr->ndim() == 0); + + WHEN("We initialize a state") { + auto state = graph.initialize_state(); + REQUIRE(set_ptr->size(state) == 0); + + THEN("The initial MatMulNode state is correct") { + CHECK(matmul_ptr->size(state) == 1); + CHECK(matmul_ptr->shape(state).size() == 0); + CHECK_THAT(matmul_ptr->view(state), RangeEquals({0.0})); + } + + AND_WHEN("We grow the set and propagate") { + set_ptr->assign(state, {5, 7, 1}); + graph.propagate(state); + + THEN("The state is correct") { + CHECK_THAT(matmul_ptr->view(state), RangeEquals({5 * 5 + 7 * 7 + 1 * 1})); + } + + AND_WHEN("We revert") { + graph.revert(state); + + THEN("The state is correct") { + CHECK(matmul_ptr->size(state) == 1); + CHECK(matmul_ptr->shape(state).size() == 0); + 
CHECK_THAT(matmul_ptr->view(state), RangeEquals({0.0})); + } + } + + AND_WHEN("We commit") { + graph.commit(state); + + THEN("The state is correct") { + CHECK(matmul_ptr->size(state) == 1); + CHECK(matmul_ptr->shape(state).size() == 0); + CHECK_THAT(matmul_ptr->view(state), RangeEquals({5 * 5 + 7 * 7 + 1 * 1})); + } + } + } + } + } + + GIVEN("A 2d dynamic testing node and a 1d constant") { + auto arr_ptr = + graph.emplace_node(std::initializer_list{-1, 3}); + auto c_ptr = graph.emplace_node(std::vector{1, 2, 3}); + + CHECK_THROWS_AS(MatMulNode(c_ptr, arr_ptr), std::invalid_argument); + + auto matmul_ptr = graph.emplace_node(arr_ptr, c_ptr); + graph.emplace_node(matmul_ptr); + + CHECK(matmul_ptr->dynamic()); + CHECK(matmul_ptr->ndim() == 1); + + WHEN("We initialize a state") { + auto state = graph.initialize_state(); + REQUIRE(arr_ptr->size(state) == 0); + + THEN("The initial MatMulNode state is correct") { + CHECK(matmul_ptr->size(state) == 0); + CHECK_THAT(matmul_ptr->shape(state), RangeEquals({0})); + CHECK(matmul_ptr->view(state).size() == 0); + } + + AND_WHEN("We grow the set and propagate") { + arr_ptr->grow(state, {5.0, 7.0, 9.0, 6.0, 8.0, 10.0}); + graph.propagate(state); + + THEN("The state is correct") { + CHECK(matmul_ptr->size(state) == 2); + CHECK_THAT(matmul_ptr->shape(state), RangeEquals({2})); + CHECK_THAT(matmul_ptr->view(state), RangeEquals({46, 52})); + } + + AND_WHEN("We revert") { + graph.revert(state); + + THEN("The state is correct") { + CHECK(matmul_ptr->size(state) == 0); + CHECK_THAT(matmul_ptr->shape(state), RangeEquals({0})); + CHECK(matmul_ptr->view(state).size() == 0); + } + } + + AND_WHEN("We commit") { + graph.commit(state); + + THEN("The state is correct") { + CHECK(matmul_ptr->size(state) == 2); + CHECK_THAT(matmul_ptr->shape(state), RangeEquals({2})); + CHECK_THAT(matmul_ptr->view(state), RangeEquals({46, 52})); + } + } + } + } + } + + GIVEN("A 2d dynamic testing node and a 1d slice") { + auto arr_ptr = + 
graph.emplace_node(std::initializer_list{-1, 3}); + auto vec_ptr = graph.emplace_node(arr_ptr, Slice(), 0); + + CHECK_THROWS_AS(MatMulNode(arr_ptr, vec_ptr), std::invalid_argument); + + auto matmul_ptr = graph.emplace_node(vec_ptr, arr_ptr); + graph.emplace_node(matmul_ptr); + + CHECK(not matmul_ptr->dynamic()); + CHECK(matmul_ptr->ndim() == 1); + CHECK(matmul_ptr->size() == 3); + CHECK_THAT(matmul_ptr->shape(), RangeEquals({3})); + + WHEN("We initialize a state") { + auto state = graph.initialize_state(); + REQUIRE(arr_ptr->size(state) == 0); + + THEN("The initial MatMulNode state is correct") { + CHECK_THAT(matmul_ptr->view(state), RangeEquals({0, 0, 0})); + } + + AND_WHEN("We grow the set and propagate") { + arr_ptr->grow(state, {5.0, 7.0, 9.0, 6.0, 8.0, 10.0}); + graph.propagate(state); + + THEN("The state is correct") { + CHECK_THAT(matmul_ptr->view(state), RangeEquals({61, 83, 105})); + } + + AND_WHEN("We revert") { + graph.revert(state); + + THEN("The state is correct") { + CHECK_THAT(matmul_ptr->view(state), RangeEquals({0, 0, 0})); + } + } + + AND_WHEN("We commit") { + graph.commit(state); + + THEN("The state is correct") { + CHECK_THAT(matmul_ptr->view(state), RangeEquals({61, 83, 105})); + } + } + } + } + } + + GIVEN("A 4d dynamic testing node and a matmulable reshape") { + auto arr_ptr = graph.emplace_node( + std::initializer_list{-1, 3, 2, 7}); + auto reshape_ptr = + graph.emplace_node(arr_ptr, std::initializer_list{-1, 3, 7, 2}); + + auto matmul_ptr = graph.emplace_node(arr_ptr, reshape_ptr); + graph.emplace_node(matmul_ptr); + + CHECK(matmul_ptr->dynamic()); + CHECK(matmul_ptr->ndim() == 4); + CHECK_THAT(matmul_ptr->shape(), RangeEquals({-1, 3, 2, 2})); + + WHEN("We initialize a state") { + auto state = graph.initialize_state(); + REQUIRE(arr_ptr->size(state) == 0); + + THEN("The initial MatMulNode state is correct") { + CHECK(matmul_ptr->view(state).size() == 0); + } + + AND_WHEN("We grow the set and propagate") { + std::vector arr_data(2 * 3 * 2 * 
7); + std::iota(arr_data.begin(), arr_data.end(), 0.0); + arr_ptr->grow(state, arr_data); + graph.propagate(state); + + std::vector expected = {182, 203, 476, 546, 2436, 2555, + 3416, 3584, 7434, 7651, 9100, 9366, + 15176, 15491, 17528, 17892, 25662, 26075, + 28700, 29162, 38892, 39403, 42616, 43176}; + + THEN("The state is correct") { + CHECK_THAT(matmul_ptr->view(state), RangeEquals(expected)); + } + + AND_WHEN("We revert") { + graph.revert(state); + + THEN("The state is correct") { CHECK(matmul_ptr->view(state).size() == 0); } + } + + AND_WHEN("We commit") { + graph.commit(state); + + THEN("The state is correct") { + CHECK_THAT(matmul_ptr->view(state), RangeEquals(expected)); + } + } + } + } + } +} + +} // namespace dwave::optimization From f8f092f84fef447eea95c86325893e7d1095100f Mon Sep 17 00:00:00 2001 From: William Bernoudy Date: Fri, 21 Nov 2025 15:36:05 -0800 Subject: [PATCH 02/10] Use different ReshapeNode constructor to satisfy macos --- tests/cpp/nodes/test_linear_algebra.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/cpp/nodes/test_linear_algebra.cpp b/tests/cpp/nodes/test_linear_algebra.cpp index 44854b6c..95427264 100644 --- a/tests/cpp/nodes/test_linear_algebra.cpp +++ b/tests/cpp/nodes/test_linear_algebra.cpp @@ -387,7 +387,7 @@ TEST_CASE("MatMulNode") { auto arr_ptr = graph.emplace_node( std::initializer_list{-1, 3, 2, 7}); auto reshape_ptr = - graph.emplace_node(arr_ptr, std::initializer_list{-1, 3, 7, 2}); + graph.emplace_node(arr_ptr, std::vector{-1, 3, 7, 2}); auto matmul_ptr = graph.emplace_node(arr_ptr, reshape_ptr); graph.emplace_node(matmul_ptr); From ad064db908cac5387a26d978af6678179e675a06 Mon Sep 17 00:00:00 2001 From: William Bernoudy Date: Mon, 24 Nov 2025 15:30:19 -0800 Subject: [PATCH 03/10] Fully support broadcasting for matmul and address review comments --- .../nodes/linear_algebra.hpp | 12 +- .../libcpp/nodes/linear_algebra.pxd | 20 +++ dwave/optimization/mathematical.py | 67 +++++++ 
.../optimization/src/nodes/linear_algebra.cpp | 151 ++++++++++------ dwave/optimization/symbols/__init__.py | 2 + dwave/optimization/symbols/linear_algebra.pyi | 18 ++ dwave/optimization/symbols/linear_algebra.pyx | 43 +++++ meson.build | 1 + tests/cpp/nodes/test_linear_algebra.cpp | 168 ++++++++++++++---- tests/test_symbols.py | 106 +++++++++++ 10 files changed, 491 insertions(+), 97 deletions(-) create mode 100644 dwave/optimization/libcpp/nodes/linear_algebra.pxd create mode 100644 dwave/optimization/symbols/linear_algebra.pyi create mode 100644 dwave/optimization/symbols/linear_algebra.pyx diff --git a/dwave/optimization/include/dwave-optimization/nodes/linear_algebra.hpp b/dwave/optimization/include/dwave-optimization/nodes/linear_algebra.hpp index e79e8305..6dca38a7 100644 --- a/dwave/optimization/include/dwave-optimization/nodes/linear_algebra.hpp +++ b/dwave/optimization/include/dwave-optimization/nodes/linear_algebra.hpp @@ -21,9 +21,9 @@ namespace dwave::optimization { -class MatMulNode : public ArrayOutputMixin { +class MatrixMultiplyNode : public ArrayOutputMixin { public: - MatMulNode(ArrayNode* x_ptr, ArrayNode* y_ptr); + MatrixMultiplyNode(ArrayNode* x_ptr, ArrayNode* y_ptr); /// @copydoc Array::buff() double const* buff(const State& state) const override; @@ -52,14 +52,18 @@ class MatMulNode : public ArrayOutputMixin { /// @copydoc Node::revert() void revert(State& state) const override; + /// @copydoc Array::shape() using Array::shape; - std::span shape(const State& state) const override; + /// @copydoc Array::size() using Array::size; - ssize_t size(const State& state) const override; + + /// @copydoc Array::size_diff() ssize_t size_diff(const State& state) const override; + + /// @copydoc Array::size_info() SizeInfo sizeinfo() const override; private: diff --git a/dwave/optimization/libcpp/nodes/linear_algebra.pxd b/dwave/optimization/libcpp/nodes/linear_algebra.pxd new file mode 100644 index 00000000..5599ccae --- /dev/null +++ 
b/dwave/optimization/libcpp/nodes/linear_algebra.pxd @@ -0,0 +1,20 @@ +# Copyright 2025 D-Wave +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dwave.optimization.libcpp.graph cimport ArrayNode + + +cdef extern from "dwave-optimization/nodes/linear_algebra.hpp" namespace "dwave::optimization" nogil: + cdef cppclass MatrixMultiplyNode(ArrayNode): + pass diff --git a/dwave/optimization/mathematical.py b/dwave/optimization/mathematical.py index 74c3703f..77727cb0 100644 --- a/dwave/optimization/mathematical.py +++ b/dwave/optimization/mathematical.py @@ -39,6 +39,7 @@ LinearProgramSolution, Log, Logical, + MatrixMultiply, Maximum, Mean, Minimum, @@ -88,6 +89,7 @@ "logical_not", "logical_or", "logical_xor", + "matmul", "maximum", "mean", "minimum", @@ -1060,6 +1062,71 @@ def logical_xor(x1: ArraySymbol, x2: ArraySymbol) -> Xor: return Xor(x1, x2) +def matmul(x: ArraySymbol, y: ArraySymbol) -> MatrixMultiply: + r"""Compute the matrix product of two array symbols. + + Args: + x, y: Operand array symbols. The size of the last axis of `x` must be + equal to the size of the second to last axis of `y`. If `x` or `y` + are 1-d, they will treated as if they a row or column vector + respectively. If both are 1-d, this will produce a scalar (and the + operation is equivalent to the dot product of two vectors). + + Returns: + A MatrixMultiply symbol representing the matrix product. 
If `x` and + `y` have shapes `(..., n, k)` and `(..., k, m)`, then the output will + have shape `(..., n, m)`. + + Examples: + This example computes the dot product of two integer arrays. + + >>> from dwave.optimization import Model + >>> from dwave.optimization.mathematical import matmul + ... + >>> model = Model() + >>> i = model.integer(3) + >>> j = model.integer(3) + >>> m = matmul(i, j) + >>> with model.lock(): + ... model.states.resize(1) + ... i.set_state(0, [1, 2, 3]) + ... j.set_state(0, [4, 5, 6]) + ... print(m.state(0)) + 32.0 + + See Also: + :class:`~dwave.optimization.symbols.MatrixMultiply`: equivalent symbol. + + .. versionadded:: 0.6.9 + """ + + def broadcast_missing_axes(a, b): + a_shape = [1,] * (b.ndim() - a.ndim()) + list(a.shape()) + b_shape = [1,] * (a.ndim() - b.ndim()) + list(b.shape()) + + for i in range(len(a_shape) - 2): + if a_shape[i] == 1: + a_shape[i] = b_shape[i] + elif b_shape[i] == 1: + b_shape[i] = a_shape[i] + elif a_shape[i] != b_shape[i]: + raise ValueError("Could not broadcast operands") + + if tuple(a_shape) != a.shape(): + a = broadcast_to(a, a_shape) + if tuple(b_shape) != b.shape(): + b = broadcast_to(b, b_shape) + return a, b + + if x.ndim() == 0 or y.ndim() == 0: + raise ValueError("Operands must not be scalar") + + if not (x.ndim() == 1 or y.ndim() == 1) and x.shape()[:-2] != y.shape()[:-2]: + return MatrixMultiply(*broadcast_missing_axes(x, y)) + + return MatrixMultiply(x, y) + + @_op(Maximum, NaryMaximum, "max") def maximum(x1: ArraySymbol, x2: ArraySymbol, *xi: ArraySymbol, ) -> typing.Union[Maximum, NaryMaximum]: diff --git a/dwave/optimization/src/nodes/linear_algebra.cpp b/dwave/optimization/src/nodes/linear_algebra.cpp index eb046c71..6d925ea1 100644 --- a/dwave/optimization/src/nodes/linear_algebra.cpp +++ b/dwave/optimization/src/nodes/linear_algebra.cpp @@ -21,7 +21,7 @@ namespace dwave::optimization { -////////////////////// MatMulNode +////////////////////// MatrixMultiplyNode // Valid shapes to multiply, 
and the resulting shape // (-1, 2, 5, 3) and (-1, 2, 3, 7) -> (-1, 2, 5, 7) @@ -70,7 +70,7 @@ std::vector output_shape(const ArrayNode* x_ptr, const ArrayNode* y_ptr } // Now check that the leading subspace shape is identical (no broadcasting for now) - if (x_ptr->ndim() > 2 && y_ptr->ndim() > 2) { + if (x_ptr->ndim() >= 2 && y_ptr->ndim() >= 2) { if (x_ptr->ndim() != y_ptr->ndim()) { throw std::invalid_argument( "operands have different dimensions (use BroadcastNode if you wish to " @@ -87,6 +87,14 @@ std::vector output_shape(const ArrayNode* x_ptr, const ArrayNode* y_ptr // Now we now the leading axes match, we can construct the output shape std::vector shape; + // If x is being broadcast, we need to add the axes from the start of y + if (y_ptr->ndim() > 2 && y_ptr->ndim() > x_ptr->ndim()) { + const ssize_t num_x_leading = std::max(0l, x_ptr->ndim() - 2); + const ssize_t num_y_leading = std::max(0l, y_ptr->ndim() - 2); + for (ssize_t d : y_ptr->shape() | std::views::take(num_y_leading - num_x_leading)) { + shape.push_back(d); + } + } for (ssize_t d : x_ptr->shape() | std::views::take(x_ptr->ndim() - 1)) { shape.push_back(d); } @@ -118,7 +126,7 @@ SizeInfo get_sizeinfo(const ArrayNode* x_ptr, const ArrayNode* y_ptr) { } ValuesInfo get_values_info(const ArrayNode* x_ptr, const ArrayNode* y_ptr) { - // get all possible combinations of values + // Get all possible combinations of values std::array combos{x_ptr->min() * y_ptr->min(), x_ptr->min() * y_ptr->max(), x_ptr->max() * y_ptr->min(), x_ptr->max() * y_ptr->max()}; @@ -157,62 +165,84 @@ std::vector atleast_2d_shape(std::span shape, bool as_ro return {shape.begin(), shape.end()}; } -MatMulNode::MatMulNode(ArrayNode* x_ptr, ArrayNode* y_ptr) - : ArrayOutputMixin(output_shape(x_ptr, y_ptr)), - x_ptr_(x_ptr), - y_ptr_(y_ptr), - sizeinfo_(get_sizeinfo(x_ptr, y_ptr)), - values_info_(get_values_info(x_ptr, y_ptr)) {} - -class MatMulNodeData : public ArrayNodeStateData { +class MatrixMultiplyNodeData : public 
ArrayNodeStateData { public: - explicit MatMulNodeData(std::vector&& values, std::span shape) + explicit MatrixMultiplyNodeData(std::vector&& values, std::span shape) : ArrayNodeStateData(std::move(values)), shape(shape.begin(), shape.end()) {} std::vector output; std::vector shape; }; +MatrixMultiplyNode::MatrixMultiplyNode(ArrayNode* x_ptr, ArrayNode* y_ptr) + : ArrayOutputMixin(output_shape(x_ptr, y_ptr)), + x_ptr_(x_ptr), + y_ptr_(y_ptr), + sizeinfo_(get_sizeinfo(x_ptr, y_ptr)), + values_info_(get_values_info(x_ptr, y_ptr)) { + add_predecessor(x_ptr); + add_predecessor(y_ptr); +} + ssize_t get_leading_stride(std::span shape) { - ssize_t stride = 1; - if (shape.size() >= 1) stride *= shape.back(); - if (shape.size() >= 2) stride *= shape[shape.size() - 2]; - return stride; + if (shape.size() < 2) return 0; // handles broadcasting for the vector case + return shape.back() * shape[shape.size() - 2]; +} + +ssize_t get_stride(std::span shape, ssize_t index, bool as_row) { + assert(index < 0 && index >= -2); + if (get_axis_size(shape, index, as_row) == 1) return 0; + if (index + 1 == 0) return 1; + return get_axis_size(shape, index + 1, as_row); +} + +ssize_t get_leading_subspace_size(std::span x_shape, + std::span y_shape) { + auto shape = x_shape.size() > y_shape.size() ? 
x_shape : y_shape; + const ssize_t penultimate_axis = std::max(0, static_cast(shape.size()) - 2); + return std::reduce(shape.begin(), shape.begin() + penultimate_axis, 1, + std::multiplies()); } -void MatMulNode::matmul(State& state, std::span out, - std::span out_shape) const { +void MatrixMultiplyNode::matmul(State& state, std::span out, + std::span out_shape) const { auto x_data = x_ptr_->view(state); auto y_data = y_ptr_->view(state); - ssize_t x_penultimate_axis_size = get_axis_size(x_ptr_->shape(state), -2, true); - ssize_t x_penultimate_axis = std::max(x_ptr_->ndim() - 2, 0); - ssize_t leading_subspace_size = std::reduce(x_ptr_->shape(state).begin(), - x_ptr_->shape(state).begin() + x_penultimate_axis, - 1, std::multiplies()); + const ssize_t x_penultimate_axis_size = get_axis_size(x_ptr_->shape(state), -2, true); + const ssize_t leading_subspace_size = + get_leading_subspace_size(x_ptr_->shape(state), y_ptr_->shape(state)); - ssize_t x_leading_stride = get_leading_stride(x_ptr_->shape(state)); - ssize_t y_leading_stride = get_leading_stride(y_ptr_->shape(state)); - ssize_t out_leading_stride = get_leading_stride(out_shape); + const ssize_t x_leading_stride = get_leading_stride(x_ptr_->shape(state)); + const ssize_t y_leading_stride = get_leading_stride(y_ptr_->shape(state)); + const ssize_t out_leading_stride = [&]() -> ssize_t { + if (x_ptr_->ndim() >= 2 and y_ptr_->ndim() >= 2) return get_leading_stride(out_shape); + if (x_ptr_->ndim() == 1 and y_ptr_->ndim() == 1) return 0; + return out_shape.back(); + }(); - ssize_t y_last_axis_size = get_axis_size(y_ptr_->shape(state), -1, false); - ssize_t y_penultimate_axis_size = get_axis_size(y_ptr_->shape(state), -2, false); + const ssize_t y_last_axis_size = get_axis_size(y_ptr_->shape(state), -1, false); + const ssize_t y_penultimate_axis_size = get_axis_size(y_ptr_->shape(state), -2, false); // TODO: consider using the parent arrays' strides directly - ssize_t x_penultimate_stride = 
get_axis_size(x_ptr_->shape(state), -1, true); - ssize_t x_last_stride = 1; + // const ssize_t x_penultimate_stride = get_axis_size(x_ptr_->shape(state), -1, true); + const ssize_t x_penultimate_stride = get_stride(x_ptr_->shape(state), -2, true); + const ssize_t x_last_stride = 1; - ssize_t y_penultimate_stride = y_last_axis_size; - ssize_t y_last_stride = 1; + const ssize_t y_penultimate_stride = y_last_axis_size; + const ssize_t y_last_stride = y_ptr_->ndim() >= 2 ? 1 : 0; - ssize_t out_following_stride = get_axis_size(out_shape, -1, false); + const ssize_t out_penultimate_stride = [&]() -> ssize_t { + if (y_ptr_->ndim() == 1) return 1; + return get_axis_size(out_shape, -1, false); + }(); for (ssize_t w = 0; w < leading_subspace_size; w++) { for (ssize_t i = 0; i < x_penultimate_axis_size; i++) { for (ssize_t j = 0; j < y_last_axis_size; j++) { auto x = x_data.begin() + w * x_leading_stride + i * x_penultimate_stride; auto y = y_data.begin() + w * y_leading_stride + j * y_last_stride; - double& out_val = out[w * out_leading_stride + i * out_following_stride + j]; + double& out_val = out[w * out_leading_stride + i * out_penultimate_stride + j]; out_val = 0.0; for (ssize_t k = 0; k < y_penultimate_axis_size; k++) { out_val += *x * *y; @@ -224,7 +254,7 @@ void MatMulNode::matmul(State& state, std::span out, } } -void MatMulNode::initialize_state(State& state) const { +void MatrixMultiplyNode::initialize_state(State& state) const { ssize_t start_size = this->size(); std::vector shape(this->shape().begin(), this->shape().end()); if (this->dynamic()) { @@ -234,61 +264,68 @@ void MatMulNode::initialize_state(State& state) const { std::vector data(start_size); matmul(state, data, shape); - emplace_data_ptr(state, std::move(data), shape); + emplace_data_ptr(state, std::move(data), shape); } -double const* MatMulNode::buff(const State& state) const { - return data_ptr(state)->buff(); +double const* MatrixMultiplyNode::buff(const State& state) const { + return 
data_ptr(state)->buff(); } -void MatMulNode::commit(State& state) const { return data_ptr(state)->commit(); } +void MatrixMultiplyNode::commit(State& state) const { + return data_ptr(state)->commit(); +} -std::span MatMulNode::diff(const State& state) const { - return data_ptr(state)->diff(); +std::span MatrixMultiplyNode::diff(const State& state) const { + return data_ptr(state)->diff(); } -bool MatMulNode::integral() const { return values_info_.integral; } +bool MatrixMultiplyNode::integral() const { return values_info_.integral; } -double MatMulNode::max() const { return values_info_.max; } +double MatrixMultiplyNode::max() const { return values_info_.max; } -double MatMulNode::min() const { return values_info_.min; } +double MatrixMultiplyNode::min() const { return values_info_.min; } -void MatMulNode::update_shape(State& state) const { +void MatrixMultiplyNode::update_shape(State& state) const { if (this->dynamic()) { - data_ptr(state)->shape[0] = x_ptr_->shape(state)[0]; + data_ptr(state)->shape[0] = x_ptr_->shape(state)[0]; } } -void MatMulNode::propagate(State& state) const { - auto data = data_ptr(state); +void MatrixMultiplyNode::propagate(State& state) const { + if (x_ptr_->diff(state).size() == 0 and y_ptr_->diff(state).size() == 0) return; + + auto data = data_ptr(state); + this->update_shape(state); ssize_t new_size = size_from_shape(data->shape); + data->output.resize(new_size); + this->matmul(state, data->output, data->shape); data->assign(data->output); } -void MatMulNode::revert(State& state) const { - auto data = data_ptr(state); +void MatrixMultiplyNode::revert(State& state) const { + auto data = data_ptr(state); data->revert(); this->update_shape(state); } -std::span MatMulNode::shape(const State& state) const { +std::span MatrixMultiplyNode::shape(const State& state) const { if (not this->dynamic()) return this->shape(); - return data_ptr(state)->shape; + return data_ptr(state)->shape; } -ssize_t MatMulNode::size(const State& state) const { 
+ssize_t MatrixMultiplyNode::size(const State& state) const { if (not this->dynamic()) return this->size(); - return data_ptr(state)->size(); + return data_ptr(state)->size(); } -ssize_t MatMulNode::size_diff(const State& state) const { +ssize_t MatrixMultiplyNode::size_diff(const State& state) const { if (not this->dynamic()) return 0; - return data_ptr(state)->size_diff(); + return data_ptr(state)->size_diff(); } -SizeInfo MatMulNode::sizeinfo() const { return sizeinfo_; } +SizeInfo MatrixMultiplyNode::sizeinfo() const { return sizeinfo_; } } // namespace dwave::optimization diff --git a/dwave/optimization/symbols/__init__.py b/dwave/optimization/symbols/__init__.py index 6afdc3e5..2421e7f1 100644 --- a/dwave/optimization/symbols/__init__.py +++ b/dwave/optimization/symbols/__init__.py @@ -52,6 +52,7 @@ ) from dwave.optimization.symbols.inputs import Input from dwave.optimization.symbols.interpolation import BSpline +from dwave.optimization.symbols.linear_algebra import MatrixMultiply from dwave.optimization.symbols.lp import ( LinearProgram, LinearProgramFeasible, @@ -149,6 +150,7 @@ "ListVariable", "Log", "Logical", + "MatrixMultiply", "Max", "Maximum", "Mean", diff --git a/dwave/optimization/symbols/linear_algebra.pyi b/dwave/optimization/symbols/linear_algebra.pyi new file mode 100644 index 00000000..9e99d0bf --- /dev/null +++ b/dwave/optimization/symbols/linear_algebra.pyi @@ -0,0 +1,18 @@ +# Copyright 2025 D-Wave +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from dwave.optimization.model import ArraySymbol as _ArraySymbol + + +class MatrixMultiply(_ArraySymbol): ... diff --git a/dwave/optimization/symbols/linear_algebra.pyx b/dwave/optimization/symbols/linear_algebra.pyx new file mode 100644 index 00000000..29c551e8 --- /dev/null +++ b/dwave/optimization/symbols/linear_algebra.pyx @@ -0,0 +1,43 @@ +# cython: auto_pickle=False + +# Copyright 2025 D-Wave +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from cython.operator cimport typeid + +from dwave.optimization._model cimport _Graph, _register, ArraySymbol +from dwave.optimization.libcpp.nodes.linear_algebra cimport ( + MatrixMultiplyNode, +) + + +cdef class MatrixMultiply(ArraySymbol): + """MatrixMultiply symbol. + + See Also: + :func:`~dwave.optimization.mathematical.matmul`: equivalent function. + + .. 
versionadded:: 0.6.9 + """ + def __init__(self, ArraySymbol x, ArraySymbol y): + cdef _Graph model = x.model + if y.model is not model: + raise ValueError("operands must be from the same model") + + cdef MatrixMultiplyNode* ptr = model._graph.emplace_node[MatrixMultiplyNode]( + x.array_ptr, y.array_ptr, + ) + self.initialize_arraynode(model, ptr) + +_register(MatrixMultiply, typeid(MatrixMultiplyNode)) diff --git a/meson.build b/meson.build index 9c601685..8785c288 100644 --- a/meson.build +++ b/meson.build @@ -103,6 +103,7 @@ foreach name : [ 'indexing', 'inputs', 'interpolation', + 'linear_algebra', 'lp', 'manipulation', 'naryop', diff --git a/tests/cpp/nodes/test_linear_algebra.cpp b/tests/cpp/nodes/test_linear_algebra.cpp index 95427264..cb525148 100644 --- a/tests/cpp/nodes/test_linear_algebra.cpp +++ b/tests/cpp/nodes/test_linear_algebra.cpp @@ -23,23 +23,24 @@ #include "dwave-optimization/nodes/indexing.hpp" #include "dwave-optimization/nodes/linear_algebra.hpp" #include "dwave-optimization/nodes/manipulation.hpp" +#include "dwave-optimization/nodes/numbers.hpp" #include "dwave-optimization/nodes/testing.hpp" using Catch::Matchers::RangeEquals; namespace dwave::optimization { -TEST_CASE("MatMulNode") { +TEST_CASE("MatrixMultiplyNode") { auto graph = Graph(); - GIVEN("A dynamic testing array node and matmul on it") { + GIVEN("A dynamic testing array node and MatrixMultiply on it") { auto arr = DynamicArrayTestingNode(std::initializer_list{-1}, -10.0, -5.0, false); auto constant = ConstantNode(15.0); auto add = AddNode(&arr, &constant); REQUIRE(add.min() == 5.0); REQUIRE(add.max() == 10.0); - auto matmul = MatMulNode(&arr, &add); - THEN("MatMulNode reports correct min and max") { + auto matmul = MatrixMultiplyNode(&arr, &add); + THEN("MatrixMultiplyNode reports correct min and max") { CHECK(matmul.min() == ValuesInfo().min); CHECK(matmul.max() == 0.0); } @@ -52,8 +53,8 @@ TEST_CASE("MatMulNode") { auto add = AddNode(&arr, &constant); REQUIRE(add.min() == 5.0); 
REQUIRE(add.max() == 10.0); - auto matmul = MatMulNode(&arr, &add); - THEN("MatMulNode reports correct min and max") { + auto matmul = MatrixMultiplyNode(&arr, &add); + THEN("MatrixMultiplyNode reports correct min and max") { CHECK(matmul.min() == ValuesInfo().min); CHECK(matmul.max() == -5.0 * 5.0 * 3); } @@ -66,18 +67,31 @@ TEST_CASE("MatMulNode") { auto add = AddNode(&arr, &constant); REQUIRE(add.min() == 5.0); REQUIRE(add.max() == 10.0); - auto matmul = MatMulNode(&arr, &add); - THEN("MatMulNode reports correct min and max") { + auto matmul = MatrixMultiplyNode(&arr, &add); + THEN("MatrixMultiplyNode reports correct min and max") { CHECK(matmul.min() == -10.0 * 10.0 * 7); CHECK(matmul.max() == -5.0 * 5.0 * 3); } } - GIVEN("Two constant 1d nodes and a MatMulNode") { + SECTION("Higher order broadcasting") { + auto a = IntegerNode({5, 4, 3, 2}); + + auto b = IntegerNode({2, 7}); + CHECK_THROWS_AS(MatrixMultiplyNode(&a, &b), std::invalid_argument); + + auto c = IntegerNode({4, 2, 1}); + CHECK_THROWS_AS(MatrixMultiplyNode(&a, &c), std::invalid_argument); + + auto d = IntegerNode({5, 1, 2, 1}); + CHECK_THROWS_AS(MatrixMultiplyNode(&a, &d), std::invalid_argument); + } + + GIVEN("Two constant 1d nodes and a MatrixMultiplyNode") { auto c1_ptr = graph.emplace_node(std::vector{1.0, 2.0, 3.0}); auto c2_ptr = graph.emplace_node(std::vector{4.0, 5.0, 6.0}); - auto matmul_ptr = graph.emplace_node(c1_ptr, c2_ptr); + auto matmul_ptr = graph.emplace_node(c1_ptr, c2_ptr); graph.emplace_node(matmul_ptr); @@ -90,19 +104,19 @@ TEST_CASE("MatMulNode") { WHEN("We initialize a state") { auto state = graph.initialize_state(); - THEN("The initial MatMulNode state is correct") { + THEN("The initial MatrixMultiplyNode state is correct") { CHECK_THAT(matmul_ptr->view(state), RangeEquals({32})); } } } - GIVEN("Two constant 2d nodes and a MatMulNode") { + GIVEN("Two constant 2d nodes and a MatrixMultiplyNode") { auto c1_ptr = graph.emplace_node(std::vector{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, 
std::vector{2, 3}); auto c2_ptr = graph.emplace_node(std::vector{7.0, 8.0, 9.0, 10.0, 11.0, 12.0}, std::vector{3, 2}); - auto matmul_ptr = graph.emplace_node(c1_ptr, c2_ptr); + auto matmul_ptr = graph.emplace_node(c1_ptr, c2_ptr); graph.emplace_node(matmul_ptr); @@ -116,19 +130,19 @@ TEST_CASE("MatMulNode") { WHEN("We initialize a state") { auto state = graph.initialize_state(); - THEN("The initial MatMulNode state is correct") { + THEN("The initial MatrixMultiplyNode state is correct") { CHECK_THAT(matmul_ptr->view(state), RangeEquals({58, 64, 139, 154})); } } } - GIVEN("Two constant 2d nodes and a MatMulNode") { + GIVEN("Two constant 2d nodes and a MatrixMultiplyNode") { auto c1_ptr = graph.emplace_node(std::vector{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, std::vector{2, 3}); auto c2_ptr = graph.emplace_node(std::vector{7.0, 8.0, 9.0, 10.0, 11.0, 12.0}, std::vector{3, 2}); - auto matmul_ptr = graph.emplace_node(c2_ptr, c1_ptr); + auto matmul_ptr = graph.emplace_node(c2_ptr, c1_ptr); graph.emplace_node(matmul_ptr); @@ -139,7 +153,7 @@ TEST_CASE("MatMulNode") { WHEN("We initialize a state") { auto state = graph.initialize_state(); - THEN("The initial MatMulNode state is correct") { + THEN("The initial MatrixMultiplyNode state is correct") { CHECK_THAT(matmul_ptr->view(state), RangeEquals({39, 54, 69, 49, 68, 87, 59, 82, 105})); } @@ -154,7 +168,7 @@ TEST_CASE("MatMulNode") { std::vector{3, 3}); AND_GIVEN("The 1d node @ the 2d node") { - auto matmul_ptr = graph.emplace_node(c1_ptr, c2_ptr); + auto matmul_ptr = graph.emplace_node(c1_ptr, c2_ptr); graph.emplace_node(matmul_ptr); @@ -164,14 +178,14 @@ TEST_CASE("MatMulNode") { WHEN("We initialize a state") { auto state = graph.initialize_state(); - THEN("The initial MatMulNode state is correct") { + THEN("The initial MatrixMultiplyNode state is correct") { CHECK_THAT(matmul_ptr->view(state), RangeEquals({48, 54, 60})); } } } AND_GIVEN("The 2d node @ the 1d node") { - auto matmul_ptr = graph.emplace_node(c2_ptr, c1_ptr); + auto 
matmul_ptr = graph.emplace_node(c2_ptr, c1_ptr); graph.emplace_node(matmul_ptr); @@ -181,16 +195,98 @@ TEST_CASE("MatMulNode") { WHEN("We initialize a state") { auto state = graph.initialize_state(); - THEN("The initial MatMulNode state is correct") { + THEN("The initial MatrixMultiplyNode state is correct") { CHECK_THAT(matmul_ptr->view(state), RangeEquals({32, 50, 68})); } } } } - GIVEN("Two 4d constant nodes and a MatMulNode") { + GIVEN("One 1d and one 4d constant nodes and a MatrixMultiplyNode") { + std::vector c1_shape{7}; + const ssize_t c1_size = 7; + std::vector c1_data(c1_size); + std::iota(c1_data.begin(), c1_data.end(), 2.0); + auto c1_ptr = graph.emplace_node(c1_data, c1_shape); + REQUIRE(c1_ptr->min() == 2.0); + REQUIRE(c1_ptr->max() == 8.0); + + std::vector c2_shape{5, 3, 7, 2}; + ssize_t c2_size = 5 * 3 * 7 * 2; + std::vector c2_data(c2_size); + std::iota(c2_data.begin(), c2_data.end(), -10.0); + auto c2_ptr = graph.emplace_node(c2_data, c2_shape); + REQUIRE(c2_ptr->min() == -10.0); + REQUIRE(c2_ptr->max() == 199.0); + + auto matmul_ptr = graph.emplace_node(c1_ptr, c2_ptr); + + graph.emplace_node(matmul_ptr); + + CHECK(matmul_ptr->size() == 5 * 3 * 2); + CHECK(matmul_ptr->ndim() == 3); + CHECK_THAT(matmul_ptr->shape(), RangeEquals({5, 3, 2})); + + CHECK(matmul_ptr->min() == 8.0 * -10.0 * 7); + CHECK(matmul_ptr->max() == 8.0 * 199.0 * 7); + + WHEN("We initialize a state") { + auto state = graph.initialize_state(); + + THEN("The initial MatrixMultiplyNode state is correct") { + CHECK_THAT( + matmul_ptr->view(state), + RangeEquals({-84, -49, 406, 441, 896, 931, 1386, 1421, 1876, 1911, + 2366, 2401, 2856, 2891, 3346, 3381, 3836, 3871, 4326, 4361, + 4816, 4851, 5306, 5341, 5796, 5831, 6286, 6321, 6776, 6811})); + } + } + } + + GIVEN("One 4d and one 1d constant nodes and a MatrixMultiplyNode") { + std::vector c1_shape{5, 3, 2, 7}; + ssize_t c1_size = 5 * 3 * 2 * 7; + std::vector c1_data(c1_size); + std::iota(c1_data.begin(), c1_data.end(), -10.0); + auto 
c1_ptr = graph.emplace_node(c1_data, c1_shape); + REQUIRE(c1_ptr->min() == -10.0); + REQUIRE(c1_ptr->max() == 199.0); + + std::vector c2_shape{7}; + const ssize_t c2_size = 7; + std::vector c2_data(c2_size); + std::iota(c2_data.begin(), c2_data.end(), 2.0); + auto c2_ptr = graph.emplace_node(c2_data, c2_shape); + REQUIRE(c2_ptr->min() == 2.0); + REQUIRE(c2_ptr->max() == 8.0); + + auto matmul_ptr = graph.emplace_node(c1_ptr, c2_ptr); + + graph.emplace_node(matmul_ptr); + + CHECK(matmul_ptr->size() == 5 * 3 * 2); + CHECK(matmul_ptr->ndim() == 3); + CHECK_THAT(matmul_ptr->shape(), RangeEquals({5, 3, 2})); + + CHECK(matmul_ptr->min() == 8.0 * -10.0 * 7); + CHECK(matmul_ptr->max() == 8.0 * 199.0 * 7); + + WHEN("We initialize a state") { + auto state = graph.initialize_state(); + + THEN("The initial MatrixMultiplyNode state is correct") { + CHECK_THAT( + matmul_ptr->view(state), + RangeEquals({-217, 28, 273, 518, 763, 1008, 1253, 1498, 1743, 1988, + 2233, 2478, 2723, 2968, 3213, 3458, 3703, 3948, 4193, 4438, + 4683, 4928, 5173, 5418, 5663, 5908, 6153, 6398, 6643, 6888})); + } + } + } + + GIVEN("Two 4d constant nodes and a MatrixMultiplyNode") { std::vector c1_shape{5, 3, 1, 7}; - ssize_t c1_size = 5 * 3 * 1 * 7; + const ssize_t c1_size = 5 * 3 * 1 * 7; std::vector c1_data(c1_size); std::iota(c1_data.begin(), c1_data.end(), -20.0); auto c1_ptr = graph.emplace_node(c1_data, c1_shape); @@ -205,7 +301,7 @@ TEST_CASE("MatMulNode") { REQUIRE(c2_ptr->min() == -10.0); REQUIRE(c2_ptr->max() == 199.0); - auto matmul_ptr = graph.emplace_node(c1_ptr, c2_ptr); + auto matmul_ptr = graph.emplace_node(c1_ptr, c2_ptr); graph.emplace_node(matmul_ptr); @@ -219,7 +315,7 @@ TEST_CASE("MatMulNode") { WHEN("We initialize a state") { auto state = graph.initialize_state(); - THEN("The initial MatMulNode state is correct") { + THEN("The initial MatrixMultiplyNode state is correct") { CHECK_THAT(matmul_ptr->view(state), RangeEquals({532, 413, -644, -714, -448, -469, 1120, 1148, 4060, 4137, 8372, 
8498, 14056, 14231, 21112, 21336, @@ -229,9 +325,9 @@ TEST_CASE("MatMulNode") { } } - GIVEN("A set node and MatMulNode representing self dot product") { + GIVEN("A set node and MatrixMultiplyNode representing self dot product") { auto set_ptr = graph.emplace_node(10); - auto matmul_ptr = graph.emplace_node(set_ptr, set_ptr); + auto matmul_ptr = graph.emplace_node(set_ptr, set_ptr); graph.emplace_node(matmul_ptr); CHECK(matmul_ptr->ndim() == 0); @@ -240,7 +336,7 @@ TEST_CASE("MatMulNode") { auto state = graph.initialize_state(); REQUIRE(set_ptr->size(state) == 0); - THEN("The initial MatMulNode state is correct") { + THEN("The initial MatrixMultiplyNode state is correct") { CHECK(matmul_ptr->size(state) == 1); CHECK(matmul_ptr->shape(state).size() == 0); CHECK_THAT(matmul_ptr->view(state), RangeEquals({0.0})); @@ -282,9 +378,9 @@ TEST_CASE("MatMulNode") { graph.emplace_node(std::initializer_list{-1, 3}); auto c_ptr = graph.emplace_node(std::vector{1, 2, 3}); - CHECK_THROWS_AS(MatMulNode(c_ptr, arr_ptr), std::invalid_argument); + CHECK_THROWS_AS(MatrixMultiplyNode(c_ptr, arr_ptr), std::invalid_argument); - auto matmul_ptr = graph.emplace_node(arr_ptr, c_ptr); + auto matmul_ptr = graph.emplace_node(arr_ptr, c_ptr); graph.emplace_node(matmul_ptr); CHECK(matmul_ptr->dynamic()); @@ -294,7 +390,7 @@ TEST_CASE("MatMulNode") { auto state = graph.initialize_state(); REQUIRE(arr_ptr->size(state) == 0); - THEN("The initial MatMulNode state is correct") { + THEN("The initial MatrixMultiplyNode state is correct") { CHECK(matmul_ptr->size(state) == 0); CHECK_THAT(matmul_ptr->shape(state), RangeEquals({0})); CHECK(matmul_ptr->view(state).size() == 0); @@ -338,9 +434,9 @@ TEST_CASE("MatMulNode") { graph.emplace_node(std::initializer_list{-1, 3}); auto vec_ptr = graph.emplace_node(arr_ptr, Slice(), 0); - CHECK_THROWS_AS(MatMulNode(arr_ptr, vec_ptr), std::invalid_argument); + CHECK_THROWS_AS(MatrixMultiplyNode(arr_ptr, vec_ptr), std::invalid_argument); - auto matmul_ptr = 
graph.emplace_node(vec_ptr, arr_ptr); + auto matmul_ptr = graph.emplace_node(vec_ptr, arr_ptr); graph.emplace_node(matmul_ptr); CHECK(not matmul_ptr->dynamic()); @@ -352,7 +448,7 @@ TEST_CASE("MatMulNode") { auto state = graph.initialize_state(); REQUIRE(arr_ptr->size(state) == 0); - THEN("The initial MatMulNode state is correct") { + THEN("The initial MatrixMultiplyNode state is correct") { CHECK_THAT(matmul_ptr->view(state), RangeEquals({0, 0, 0})); } @@ -389,7 +485,7 @@ TEST_CASE("MatMulNode") { auto reshape_ptr = graph.emplace_node(arr_ptr, std::vector{-1, 3, 7, 2}); - auto matmul_ptr = graph.emplace_node(arr_ptr, reshape_ptr); + auto matmul_ptr = graph.emplace_node(arr_ptr, reshape_ptr); graph.emplace_node(matmul_ptr); CHECK(matmul_ptr->dynamic()); @@ -400,7 +496,7 @@ TEST_CASE("MatMulNode") { auto state = graph.initialize_state(); REQUIRE(arr_ptr->size(state) == 0); - THEN("The initial MatMulNode state is correct") { + THEN("The initial MatrixMultiplyNode state is correct") { CHECK(matmul_ptr->view(state).size() == 0); } diff --git a/tests/test_symbols.py b/tests/test_symbols.py index abb74cb2..bed3bc69 100644 --- a/tests/test_symbols.py +++ b/tests/test_symbols.py @@ -38,6 +38,7 @@ logical_or, logical_not, logical_xor, + matmul, mod, put, rint, @@ -1156,6 +1157,7 @@ def test_interning(self): self.assertEqual(z.shape(), (4, 2)) self.assertNotEqual(x.id(), z.id()) + class TestCopy(utils.SymbolTests): def generate_symbols(self): model = Model() @@ -2279,6 +2281,110 @@ def test_serialization_with_states(self): np.testing.assert_array_equal(lp.state(3), [1, 1]) +class TestMatrixMultiply(utils.SymbolTests): + def generate_symbols(self): + model = Model() + c = model.constant(self._shaped_range(3, 4)) + c_reshape = c.reshape((4, 3)) + mm = dwave.optimization.symbols.MatrixMultiply(c, c_reshape) + + with model.lock(): + yield mm + + def test_matmul(self): + model = Model() + c = model.constant(self._shaped_range(3, 4)) + c_reshape = c.reshape((4, 3)) + mm = 
matmul(c, c_reshape) + self.assertIsInstance(mm, dwave.optimization.symbols.MatrixMultiply) + + def test_matmul_scalar(self): + model = Model() + with self.assertRaises(ValueError): + matmul(model.constant(np.arange(4)), model.constant(2)) + with self.assertRaises(ValueError): + matmul(model.constant(2), model.constant(np.arange(4))) + with self.assertRaises(ValueError): + matmul(model.constant(self._shaped_range(2, 3)), model.constant(2)) + with self.assertRaises(ValueError): + matmul(model.constant(2), model.constant(self._shaped_range(2, 3))) + + def test_matmul_broadcast_x(self): + model = Model() + x_data = self._shaped_range(5, 2, 3, 4) + x = model.constant(x_data) + + for shape in [ + (4,), + (4, 6), + (1, 4, 3), + (5, 2, 4, 3), + (5, 1, 4, 3), + (2, 1, 1, 4, 3), + ]: + y_data = self._shaped_range(*shape) + np_res = np.matmul(x_data, y_data) + y = model.constant(y_data) + mm = matmul(x, y) + with model.lock(): + model.states.resize(1) + self.assertTrue(np.array_equal(mm.state(0), np_res)) + + with self.assertRaises(ValueError): + matmul(x, model.constant(np.ones((5, 7, 4, 3)))) + + def test_matmul_broadcast_y(self): + model = Model() + y_data = self._shaped_range(5, 2, 3, 4) + y = model.constant(y_data) + + for shape in [ + (3,), + (6, 3), + (1, 2, 3), + (5, 2, 4, 3), + (5, 1, 4, 3), + (2, 1, 1, 4, 3), + ]: + x_data = self._shaped_range(*shape) + np_res = np.matmul(x_data, y_data) + x = model.constant(x_data) + mm = matmul(x, y) + with model.lock(): + model.states.resize(1) + self.assertTrue(np.array_equal(mm.state(0), np_res)) + + with self.assertRaises(ValueError): + matmul(y, model.constant(np.ones((5, 7, 4, 3)))) + + def test_matmul_broadcast_both_operands(self): + model = Model() + + for x_shape, y_shape in [ + [(3,), (3,)], + [(7, 3, 2), (1, 2, 5)], + [(5, 7, 3, 2), (1, 1, 2, 5)], + [(1, 7, 3, 2), (5, 1, 2, 5)], + [(1, 7, 3, 2), (4, 5, 1, 2, 5)], + ]: + x_data = self._shaped_range(*x_shape) + x = model.constant(x_data) + + y_data = 
self._shaped_range(*y_shape) + y = model.constant(y_data) + + np_res = np.matmul(x_data, y_data) + + mm = matmul(x, y) + with model.lock(): + model.states.resize(1) + self.assertTrue(np.array_equal(mm.state(0), np_res)) + + @staticmethod + def _shaped_range(*shape): + return np.arange(np.prod(shape)).reshape(shape) + + class TestMax(utils.ReduceTests): empty_requires_initial = True From 3a75f3e8017249af02e00b0f3b51a83942a93b9c Mon Sep 17 00:00:00 2001 From: William Bernoudy Date: Mon, 24 Nov 2025 15:44:05 -0800 Subject: [PATCH 04/10] Fix std::max call for windows and add asserts in SizeInfo::operator/ --- .../include/dwave-optimization/array.hpp | 12 ++++++++++-- dwave/optimization/src/nodes/linear_algebra.cpp | 4 ++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/dwave/optimization/include/dwave-optimization/array.hpp b/dwave/optimization/include/dwave-optimization/array.hpp index af6b6631..34c35af4 100644 --- a/dwave/optimization/include/dwave-optimization/array.hpp +++ b/dwave/optimization/include/dwave-optimization/array.hpp @@ -92,8 +92,16 @@ struct SizeInfo { if (!n) throw std::invalid_argument("cannot divide by 0"); multiplier /= n; offset /= n; - if (min.has_value()) min.value() /= n; - if (max.has_value()) max.value() /= n; + if (min.has_value()) { + assert(min.value() % n == 0 and + "dividing SizeInfo with a divisor that does not evenly divide the minimum"); + min.value() /= n; + } + if (max.has_value()) { + assert(max.value() % n == 0 and + "dividing SizeInfo with a divisor that does not evenly divide the maximum"); + max.value() /= n; + } return *this; } friend SizeInfo operator/(SizeInfo lhs, const std::integral auto rhs) { diff --git a/dwave/optimization/src/nodes/linear_algebra.cpp b/dwave/optimization/src/nodes/linear_algebra.cpp index 6d925ea1..bf96433d 100644 --- a/dwave/optimization/src/nodes/linear_algebra.cpp +++ b/dwave/optimization/src/nodes/linear_algebra.cpp @@ -89,8 +89,8 @@ std::vector output_shape(const ArrayNode* x_ptr, 
const ArrayNode* y_ptr std::vector shape; // If x is being broadcast, we need to add the axes from the start of y if (y_ptr->ndim() > 2 && y_ptr->ndim() > x_ptr->ndim()) { - const ssize_t num_x_leading = std::max(0l, x_ptr->ndim() - 2); - const ssize_t num_y_leading = std::max(0l, y_ptr->ndim() - 2); + const ssize_t num_x_leading = std::max(0, x_ptr->ndim() - 2); + const ssize_t num_y_leading = std::max(0, y_ptr->ndim() - 2); for (ssize_t d : y_ptr->shape() | std::views::take(num_y_leading - num_x_leading)) { shape.push_back(d); } From d47f35a1612e8e6f01a61ccae9980b84a7f51f59 Mon Sep 17 00:00:00 2001 From: William Bernoudy Date: Mon, 24 Nov 2025 15:51:19 -0800 Subject: [PATCH 05/10] Add missing includes --- dwave/optimization/src/nodes/linear_algebra.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/dwave/optimization/src/nodes/linear_algebra.cpp b/dwave/optimization/src/nodes/linear_algebra.cpp index bf96433d..e27a8718 100644 --- a/dwave/optimization/src/nodes/linear_algebra.cpp +++ b/dwave/optimization/src/nodes/linear_algebra.cpp @@ -14,6 +14,10 @@ #include "dwave-optimization/nodes/linear_algebra.hpp" +#include +#include +#include + #include "../functional_.hpp" #include "_state.hpp" #include "dwave-optimization/array.hpp" From 0a80d2dde4b169e2206c89a49af4ecc5a1afa8fe Mon Sep 17 00:00:00 2001 From: William Bernoudy Date: Mon, 24 Nov 2025 16:43:01 -0800 Subject: [PATCH 06/10] Add release notes and some clarifying comments --- .../optimization/src/nodes/linear_algebra.cpp | 22 +++++++++++-------- ...ultiplication-symbol-ebf608660b82adfa.yaml | 7 ++++++ 2 files changed, 20 insertions(+), 9 deletions(-) create mode 100644 releasenotes/notes/add-matrix-multiplication-symbol-ebf608660b82adfa.yaml diff --git a/dwave/optimization/src/nodes/linear_algebra.cpp b/dwave/optimization/src/nodes/linear_algebra.cpp index e27a8718..f646c817 100644 --- a/dwave/optimization/src/nodes/linear_algebra.cpp +++ b/dwave/optimization/src/nodes/linear_algebra.cpp @@ -118,14 
+118,16 @@ SizeInfo get_sizeinfo(const ArrayNode* x_ptr, const ArrayNode* y_ptr) { assert(size >= 1); return SizeInfo(size); } + + // The size should be x's size, divided by the size of x's last dimension, + // multiplied by the size of y's last dimension (if matrix or higher dim). assert(x_ptr->shape().back() != -1); SizeInfo sizeinfo = x_ptr->sizeinfo() / x_ptr->shape().back(); - if (y_ptr->ndim() == 2 && y_ptr->dynamic()) { - assert(x_ptr->dynamic() && x_ptr->ndim() == 1); - } else if (y_ptr->ndim() >= 2) { + if (y_ptr->ndim() >= 2) { assert(y_ptr->shape().back() != -1); sizeinfo *= y_ptr->shape().back(); } + return sizeinfo; } @@ -162,10 +164,11 @@ ValuesInfo get_values_info(const ArrayNode* x_ptr, const ArrayNode* y_ptr) { return values_info; } -std::vector atleast_2d_shape(std::span shape, bool as_row) { +std::vector atleast_2d_shape(std::span shape, bool vector_as_row) { + // If vector_as_row is true, treat vector as shape (1, size), else as shape (size, 1) if (shape.size() == 0) return {1, 1}; - if (shape.size() == 1 and as_row) return {1, shape[0]}; - if (shape.size() == 1 and not as_row) return {shape[0], 1}; + if (shape.size() == 1 and vector_as_row) return {1, shape[0]}; + if (shape.size() == 1 and not vector_as_row) return {shape[0], 1}; return {shape.begin(), shape.end()}; } @@ -193,11 +196,12 @@ ssize_t get_leading_stride(std::span shape) { return shape.back() * shape[shape.size() - 2]; } -ssize_t get_stride(std::span shape, ssize_t index, bool as_row) { +ssize_t get_stride(std::span shape, ssize_t index, bool vector_as_row) { + // If vector_as_row is true, treat vector as shape (1, size), else as shape (size, 1) assert(index < 0 && index >= -2); - if (get_axis_size(shape, index, as_row) == 1) return 0; + if (get_axis_size(shape, index, vector_as_row) == 1) return 0; if (index + 1 == 0) return 1; - return get_axis_size(shape, index + 1, as_row); + return get_axis_size(shape, index + 1, vector_as_row); } ssize_t get_leading_subspace_size(std::span 
x_shape, diff --git a/releasenotes/notes/add-matrix-multiplication-symbol-ebf608660b82adfa.yaml b/releasenotes/notes/add-matrix-multiplication-symbol-ebf608660b82adfa.yaml new file mode 100644 index 00000000..bcf31629 --- /dev/null +++ b/releasenotes/notes/add-matrix-multiplication-symbol-ebf608660b82adfa.yaml @@ -0,0 +1,7 @@ +--- +features: + - | + Add the ``MatrixMultiply`` symbol and corresponding function + ``matmul``. The ``matmul`` function follows the behavior of NumPy's + ``matmul``, meaning that it works with matrices, vectors, and higher order + arrays. From f037cf07e2c213dac10f730cd2f4bebb015121b1 Mon Sep 17 00:00:00 2001 From: William Bernoudy Date: Tue, 25 Nov 2025 14:08:10 -0800 Subject: [PATCH 07/10] Make Array::shape_to_size public and use it in MatrixMultiplyNode, along with other small code improvements. Also fixed a bug in the ValuesInfo for MatrixMultiplyNode as it was not always calculating the contracted axis size correctly. --- .../include/dwave-optimization/array.hpp | 26 ++++---- .../nodes/linear_algebra.hpp | 4 +- .../optimization/src/nodes/linear_algebra.cpp | 61 ++++++++++--------- tests/cpp/nodes/test_linear_algebra.cpp | 9 ++- 4 files changed, 52 insertions(+), 48 deletions(-) diff --git a/dwave/optimization/include/dwave-optimization/array.hpp b/dwave/optimization/include/dwave-optimization/array.hpp index 34c35af4..2a9d8c42 100644 --- a/dwave/optimization/include/dwave-optimization/array.hpp +++ b/dwave/optimization/include/dwave-optimization/array.hpp @@ -483,6 +483,19 @@ class Array { return 0; } + // Determine the size by the shape. For a node with a fixed size, it is simply + // the product of the shape. + // Expects the shape to be stored in a C-style array of length ndim. 
+ static ssize_t shape_to_size(const ssize_t ndim, const ssize_t* const shape) noexcept { + if (ndim <= 0) return 1; + if (shape[0] < 0) return DYNAMIC_SIZE; + return std::reduce(shape, shape + ndim, 1, std::multiplies()); + } + + static ssize_t shape_to_size(const std::span shape) noexcept { + return shape_to_size(shape.size(), shape.data()); + } + protected: // Some utility methods that might be useful to subclasses @@ -509,19 +522,6 @@ class Array { return true; } - // Determine the size by the shape. For a node with a fixed size, it is simply - // the product of the shape. - // Expects the shape to be stored in a C-style array of length ndim. - static ssize_t shape_to_size(const ssize_t ndim, const ssize_t* shape) noexcept { - if (ndim <= 0) return 1; - if (shape[0] < 0) return DYNAMIC_SIZE; - return std::reduce(shape, shape + ndim, 1, std::multiplies()); - } - - static ssize_t shape_to_size(const std::span shape) noexcept { - return shape_to_size(shape.size(), shape.data()); - } - // Determine the strides from the shape. // Assumes itemsize = sizeof(double). // Expects the shape to be stored in a C-style array of length ndim. 
diff --git a/dwave/optimization/include/dwave-optimization/nodes/linear_algebra.hpp b/dwave/optimization/include/dwave-optimization/nodes/linear_algebra.hpp index 6dca38a7..da42fc21 100644 --- a/dwave/optimization/include/dwave-optimization/nodes/linear_algebra.hpp +++ b/dwave/optimization/include/dwave-optimization/nodes/linear_algebra.hpp @@ -67,8 +67,8 @@ class MatrixMultiplyNode : public ArrayOutputMixin { SizeInfo sizeinfo() const override; private: - void matmul(State& state, std::span out, std::span out_shape) const; - void update_shape(State& state) const; + void matmul_(State& state, std::span out, std::span out_shape) const; + void update_shape_(State& state) const; const ArrayNode* x_ptr_; const ArrayNode* y_ptr_; diff --git a/dwave/optimization/src/nodes/linear_algebra.cpp b/dwave/optimization/src/nodes/linear_algebra.cpp index f646c817..8ee8f95d 100644 --- a/dwave/optimization/src/nodes/linear_algebra.cpp +++ b/dwave/optimization/src/nodes/linear_algebra.cpp @@ -34,10 +34,6 @@ namespace dwave::optimization { // (-1) and (-1) -> () // (-1) and (-1, 5) -> (5) -ssize_t size_from_shape(std::span shape) { - return std::reduce(shape.begin(), shape.end(), 1, std::multiplies()); -} - ssize_t get_axis_size(std::span shape, ssize_t index, bool vector_as_row) { // If vector_as_row is true, treat vector as shape (1, size), else as shape (size, 1) assert(index < 0); @@ -54,8 +50,8 @@ std::vector output_shape(const ArrayNode* x_ptr, const ArrayNode* y_ptr } // Check that last dimension of x matches the second last dimension of y - ssize_t x_last_axis_size = get_axis_size(x_ptr->shape(), -1, true); - ssize_t y_penultimate_axis_size = get_axis_size(y_ptr->shape(), -2, false); + const ssize_t x_last_axis_size = get_axis_size(x_ptr->shape(), -1, true); + const ssize_t y_penultimate_axis_size = get_axis_size(y_ptr->shape(), -2, false); if (x_last_axis_size != y_penultimate_axis_size) { throw std::invalid_argument( "the last dimension of `x` is not the same size as the 
second to last dimension of " @@ -64,8 +60,10 @@ std::vector output_shape(const ArrayNode* x_ptr, const ArrayNode* y_ptr assert(x_ptr->dynamic() && y_ptr->dynamic()); // Both are dynamic. We need to check that the dynamic dimension is // always the same size. - ssize_t x_subspace_size = -1 * size_from_shape(x_ptr->shape()); - ssize_t y_subspace_size = -1 * size_from_shape(y_ptr->shape()); + const ssize_t x_subspace_size = Array::shape_to_size(x_ptr->shape().subspan(1)); + const ssize_t y_subspace_size = Array::shape_to_size(y_ptr->shape().subspan(1)); + assert(x_subspace_size != Array::DYNAMIC_SIZE); + assert(y_subspace_size != Array::DYNAMIC_SIZE); if (x_ptr->sizeinfo() / x_subspace_size != y_ptr->sizeinfo() / y_subspace_size) { throw std::invalid_argument( "the last dimension of `x` is not the same size as the second to last " @@ -114,7 +112,7 @@ SizeInfo get_sizeinfo(const ArrayNode* x_ptr, const ArrayNode* y_ptr) { // x must also be dynamic, and we must be contracting along the dynamic // dimension, so the output is fixed size. 
std::vector shape = output_shape(x_ptr, y_ptr); - ssize_t size = size_from_shape(shape); + ssize_t size = Array::shape_to_size(shape); assert(size >= 1); return SizeInfo(size); } @@ -133,17 +131,20 @@ SizeInfo get_sizeinfo(const ArrayNode* x_ptr, const ArrayNode* y_ptr) { ValuesInfo get_values_info(const ArrayNode* x_ptr, const ArrayNode* y_ptr) { // Get all possible combinations of values - std::array combos{x_ptr->min() * y_ptr->min(), x_ptr->min() * y_ptr->max(), - x_ptr->max() * y_ptr->min(), x_ptr->max() * y_ptr->max()}; + const std::array combos{x_ptr->min() * y_ptr->min(), x_ptr->min() * y_ptr->max(), + x_ptr->max() * y_ptr->min(), x_ptr->max() * y_ptr->max()}; - double min_val = std::ranges::min(combos); - double max_val = std::ranges::max(combos); + const double min_val = std::ranges::min(combos); + const double max_val = std::ranges::max(combos); - ssize_t x_subspace_size = std::reduce(x_ptr->shape().begin(), x_ptr->shape().end() - 1, 1, - std::multiplies()); - SizeInfo contracted_axis_size = x_ptr->sizeinfo() / x_subspace_size; + const SizeInfo contracted_axis_size = [&]() { + // If x is 1d, then the contracted axis size is equal to x's size + if (x_ptr->ndim() == 1) return x_ptr->sizeinfo(); + // Otherwise it's always the last axis of x (which definitionally is not dynamic) + return SizeInfo(x_ptr->shape().back()); + }(); - if (contracted_axis_size.max.has_value() and *contracted_axis_size.max == 0) { + if (contracted_axis_size.max.has_value() and contracted_axis_size.max.value() == 0) { // Output will always be empty, so we can return early return ValuesInfo(0.0, 0.0, true); } @@ -157,7 +158,7 @@ ValuesInfo get_values_info(const ArrayNode* x_ptr, const ArrayNode* y_ptr) { if (min_val <= 0) values_info.min = min_val * contracted_axis_size.max.value(); } - ssize_t min_size = contracted_axis_size.min.value_or(0); + const ssize_t min_size = contracted_axis_size.min.value_or(0); if (max_val < 0) values_info.max = max_val * min_size; if (min_val > 0) 
values_info.min = min_val * min_size; @@ -206,16 +207,16 @@ ssize_t get_stride(std::span shape, ssize_t index, bool vector_as ssize_t get_leading_subspace_size(std::span x_shape, std::span y_shape) { - auto shape = x_shape.size() > y_shape.size() ? x_shape : y_shape; + const auto shape = x_shape.size() > y_shape.size() ? x_shape : y_shape; const ssize_t penultimate_axis = std::max(0, static_cast(shape.size()) - 2); return std::reduce(shape.begin(), shape.begin() + penultimate_axis, 1, std::multiplies()); } -void MatrixMultiplyNode::matmul(State& state, std::span out, - std::span out_shape) const { - auto x_data = x_ptr_->view(state); - auto y_data = y_ptr_->view(state); +void MatrixMultiplyNode::matmul_(State& state, std::span out, + std::span out_shape) const { + const auto x_data = x_ptr_->view(state); + const auto y_data = y_ptr_->view(state); const ssize_t x_penultimate_axis_size = get_axis_size(x_ptr_->shape(state), -2, true); const ssize_t leading_subspace_size = @@ -267,11 +268,11 @@ void MatrixMultiplyNode::initialize_state(State& state) const { std::vector shape(this->shape().begin(), this->shape().end()); if (this->dynamic()) { shape[0] = x_ptr_->shape(state)[0]; - start_size = size_from_shape(shape); + start_size = Array::shape_to_size(shape); } std::vector data(start_size); - matmul(state, data, shape); + matmul_(state, data, shape); emplace_data_ptr(state, std::move(data), shape); } @@ -293,7 +294,7 @@ double MatrixMultiplyNode::max() const { return values_info_.max; } double MatrixMultiplyNode::min() const { return values_info_.min; } -void MatrixMultiplyNode::update_shape(State& state) const { +void MatrixMultiplyNode::update_shape_(State& state) const { if (this->dynamic()) { data_ptr(state)->shape[0] = x_ptr_->shape(state)[0]; } @@ -304,19 +305,19 @@ void MatrixMultiplyNode::propagate(State& state) const { auto data = data_ptr(state); - this->update_shape(state); - ssize_t new_size = size_from_shape(data->shape); + this->update_shape_(state); + 
const ssize_t new_size = Array::shape_to_size(data->shape); data->output.resize(new_size); - this->matmul(state, data->output, data->shape); + this->matmul_(state, data->output, data->shape); data->assign(data->output); } void MatrixMultiplyNode::revert(State& state) const { auto data = data_ptr(state); data->revert(); - this->update_shape(state); + this->update_shape_(state); } std::span MatrixMultiplyNode::shape(const State& state) const { diff --git a/tests/cpp/nodes/test_linear_algebra.cpp b/tests/cpp/nodes/test_linear_algebra.cpp index cb525148..c9836863 100644 --- a/tests/cpp/nodes/test_linear_algebra.cpp +++ b/tests/cpp/nodes/test_linear_algebra.cpp @@ -374,8 +374,8 @@ TEST_CASE("MatrixMultiplyNode") { } GIVEN("A 2d dynamic testing node and a 1d constant") { - auto arr_ptr = - graph.emplace_node(std::initializer_list{-1, 3}); + auto arr_ptr = graph.emplace_node( + std::initializer_list{-1, 3}, -3.0, 10.0, false); auto c_ptr = graph.emplace_node(std::vector{1, 2, 3}); CHECK_THROWS_AS(MatrixMultiplyNode(c_ptr, arr_ptr), std::invalid_argument); @@ -386,6 +386,9 @@ TEST_CASE("MatrixMultiplyNode") { CHECK(matmul_ptr->dynamic()); CHECK(matmul_ptr->ndim() == 1); + CHECK(matmul_ptr->min() == 3.0 * -3.0 * 3); + CHECK(matmul_ptr->max() == 3.0 * 10.0 * 3); + WHEN("We initialize a state") { auto state = graph.initialize_state(); REQUIRE(arr_ptr->size(state) == 0); @@ -429,7 +432,7 @@ TEST_CASE("MatrixMultiplyNode") { } } - GIVEN("A 2d dynamic testing node and a 1d slice") { + GIVEN("A 2d dynamic testing node and a 1d column slice") { auto arr_ptr = graph.emplace_node(std::initializer_list{-1, 3}); auto vec_ptr = graph.emplace_node(arr_ptr, Slice(), 0); From ee16139e143e830dd6c909e0dee2839cd4b33887 Mon Sep 17 00:00:00 2001 From: William Bernoudy Date: Wed, 26 Nov 2025 11:33:59 -0800 Subject: [PATCH 08/10] Use raw pointers/strides of predecessors in matmul --- .../optimization/src/nodes/linear_algebra.cpp | 53 ++++++++++++------- 1 file changed, 34 insertions(+), 19 
deletions(-) diff --git a/dwave/optimization/src/nodes/linear_algebra.cpp b/dwave/optimization/src/nodes/linear_algebra.cpp index 8ee8f95d..3e59168b 100644 --- a/dwave/optimization/src/nodes/linear_algebra.cpp +++ b/dwave/optimization/src/nodes/linear_algebra.cpp @@ -192,17 +192,22 @@ MatrixMultiplyNode::MatrixMultiplyNode(ArrayNode* x_ptr, ArrayNode* y_ptr) add_predecessor(y_ptr); } -ssize_t get_leading_stride(std::span shape) { - if (shape.size() < 2) return 0; // handles broadcasting for the vector case - return shape.back() * shape[shape.size() - 2]; +ssize_t get_leading_stride(std::span shape, std::span strides) { + assert(shape.size() >= 1); + assert(shape.size() == strides.size()); + if (shape.size() == 1) return 0; // handles broadcasting for the vector case + if (shape.size() == 2) return 1; + return strides[shape.size() - 3] / sizeof(double); } -ssize_t get_stride(std::span shape, ssize_t index, bool vector_as_row) { +ssize_t get_stride(std::span shape, std::span strides, ssize_t index, + bool vector_as_row) { // If vector_as_row is true, treat vector as shape (1, size), else as shape (size, 1) assert(index < 0 && index >= -2); if (get_axis_size(shape, index, vector_as_row) == 1) return 0; - if (index + 1 == 0) return 1; - return get_axis_size(shape, index + 1, vector_as_row); + if (index + 1 == 0) return strides.back() / sizeof(double); + assert(shape.size() > 1); + return strides[static_cast(strides.size()) + index] / sizeof(double); } ssize_t get_leading_subspace_size(std::span x_shape, @@ -215,17 +220,20 @@ ssize_t get_leading_subspace_size(std::span x_shape, void MatrixMultiplyNode::matmul_(State& state, std::span out, std::span out_shape) const { - const auto x_data = x_ptr_->view(state); - const auto y_data = y_ptr_->view(state); + assert(static_cast(out.size()) == Array::shape_to_size(out_shape)); + + // If out is empty (possible when predecessors have 0 size) there is nothing to do + if (out.size() == 0) return; const ssize_t 
x_penultimate_axis_size = get_axis_size(x_ptr_->shape(state), -2, true); const ssize_t leading_subspace_size = get_leading_subspace_size(x_ptr_->shape(state), y_ptr_->shape(state)); - const ssize_t x_leading_stride = get_leading_stride(x_ptr_->shape(state)); - const ssize_t y_leading_stride = get_leading_stride(y_ptr_->shape(state)); + const ssize_t x_leading_stride = get_leading_stride(x_ptr_->shape(state), x_ptr_->strides()); + const ssize_t y_leading_stride = get_leading_stride(y_ptr_->shape(state), y_ptr_->strides()); const ssize_t out_leading_stride = [&]() -> ssize_t { - if (x_ptr_->ndim() >= 2 and y_ptr_->ndim() >= 2) return get_leading_stride(out_shape); + if (x_ptr_->ndim() >= 2 and y_ptr_->ndim() >= 2) + return get_leading_stride(out_shape, this->strides()); if (x_ptr_->ndim() == 1 and y_ptr_->ndim() == 1) return 0; return out_shape.back(); }(); @@ -235,26 +243,33 @@ void MatrixMultiplyNode::matmul_(State& state, std::span out, // TODO: consider using the parent arrays' strides directly // const ssize_t x_penultimate_stride = get_axis_size(x_ptr_->shape(state), -1, true); - const ssize_t x_penultimate_stride = get_stride(x_ptr_->shape(state), -2, true); - const ssize_t x_last_stride = 1; + const ssize_t x_penultimate_stride = + get_stride(x_ptr_->shape(state), x_ptr_->strides(), -2, true); + const ssize_t x_last_stride = x_ptr_->strides().back() / sizeof(double); const ssize_t y_penultimate_stride = y_last_axis_size; - const ssize_t y_last_stride = y_ptr_->ndim() >= 2 ? 1 : 0; + const ssize_t y_last_stride = + y_ptr_->ndim() >= 2 ? 
y_ptr_->strides().back() / sizeof(double) : 0; const ssize_t out_penultimate_stride = [&]() -> ssize_t { if (y_ptr_->ndim() == 1) return 1; return get_axis_size(out_shape, -1, false); }(); + const double* const x_data = x_ptr_->buff(state); + const double* const y_data = y_ptr_->buff(state); + double* __restrict const out_data = out.data(); + for (ssize_t w = 0; w < leading_subspace_size; w++) { for (ssize_t i = 0; i < x_penultimate_axis_size; i++) { for (ssize_t j = 0; j < y_last_axis_size; j++) { - auto x = x_data.begin() + w * x_leading_stride + i * x_penultimate_stride; - auto y = y_data.begin() + w * y_leading_stride + j * y_last_stride; - double& out_val = out[w * out_leading_stride + i * out_penultimate_stride + j]; - out_val = 0.0; + const double* x = x_data + w * x_leading_stride + i * x_penultimate_stride; + const double* y = y_data + w * y_leading_stride + j * y_last_stride; + double* out_val = + out_data + w * out_leading_stride + i * out_penultimate_stride + j; + *out_val = 0.0; for (ssize_t k = 0; k < y_penultimate_axis_size; k++) { - out_val += *x * *y; + *out_val += *x * *y; x += x_last_stride; y += y_penultimate_stride; } From 13b61a8854b077e6d3b3209093bb7cc3975c0f75 Mon Sep 17 00:00:00 2001 From: William Bernoudy Date: Wed, 26 Nov 2025 13:33:42 -0800 Subject: [PATCH 09/10] Fix bug in matmul with leading subspace iteration --- .../optimization/src/nodes/linear_algebra.cpp | 27 ++++++++++--------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/dwave/optimization/src/nodes/linear_algebra.cpp b/dwave/optimization/src/nodes/linear_algebra.cpp index 3e59168b..16bbc3ac 100644 --- a/dwave/optimization/src/nodes/linear_algebra.cpp +++ b/dwave/optimization/src/nodes/linear_algebra.cpp @@ -192,12 +192,11 @@ MatrixMultiplyNode::MatrixMultiplyNode(ArrayNode* x_ptr, ArrayNode* y_ptr) add_predecessor(y_ptr); } -ssize_t get_leading_stride(std::span shape, std::span strides) { +ssize_t get_leading_stride(std::span shape) { assert(shape.size() >= 
1); - assert(shape.size() == strides.size()); if (shape.size() == 1) return 0; // handles broadcasting for the vector case if (shape.size() == 2) return 1; - return strides[shape.size() - 3] / sizeof(double); + return Array::shape_to_size(shape.subspan(shape.size() - 2)); } ssize_t get_stride(std::span shape, std::span strides, ssize_t index, @@ -229,11 +228,10 @@ void MatrixMultiplyNode::matmul_(State& state, std::span out, const ssize_t leading_subspace_size = get_leading_subspace_size(x_ptr_->shape(state), y_ptr_->shape(state)); - const ssize_t x_leading_stride = get_leading_stride(x_ptr_->shape(state), x_ptr_->strides()); - const ssize_t y_leading_stride = get_leading_stride(y_ptr_->shape(state), y_ptr_->strides()); + const ssize_t x_leading_stride = get_leading_stride(x_ptr_->shape(state)); + const ssize_t y_leading_stride = get_leading_stride(y_ptr_->shape(state)); const ssize_t out_leading_stride = [&]() -> ssize_t { - if (x_ptr_->ndim() >= 2 and y_ptr_->ndim() >= 2) - return get_leading_stride(out_shape, this->strides()); + if (x_ptr_->ndim() >= 2 and y_ptr_->ndim() >= 2) return get_leading_stride(out_shape); if (x_ptr_->ndim() == 1 and y_ptr_->ndim() == 1) return 0; return out_shape.back(); }(); @@ -241,8 +239,6 @@ void MatrixMultiplyNode::matmul_(State& state, std::span out, const ssize_t y_last_axis_size = get_axis_size(y_ptr_->shape(state), -1, false); const ssize_t y_penultimate_axis_size = get_axis_size(y_ptr_->shape(state), -2, false); - // TODO: consider using the parent arrays' strides directly - // const ssize_t x_penultimate_stride = get_axis_size(x_ptr_->shape(state), -1, true); const ssize_t x_penultimate_stride = get_stride(x_ptr_->shape(state), x_ptr_->strides(), -2, true); const ssize_t x_last_stride = x_ptr_->strides().back() / sizeof(double); @@ -256,15 +252,20 @@ void MatrixMultiplyNode::matmul_(State& state, std::span out, return get_axis_size(out_shape, -1, false); }(); - const double* const x_data = x_ptr_->buff(state); - const double* 
const y_data = y_ptr_->buff(state); double* __restrict const out_data = out.data(); for (ssize_t w = 0; w < leading_subspace_size; w++) { + // In order to avoid having to iterate over all leading dimensions + // and checking the strides of both x/y, we use ArrayIterators to + // get us to the correct subspace, and then use the strides of the + // predecessors to iterate through the last one or two dimensions. + const double* const x_data = &x_ptr_->view(state).begin()[w * x_leading_stride]; + const double* const y_data = &y_ptr_->view(state).begin()[w * y_leading_stride]; + // Now we do standard 2D matrix multiply for (ssize_t i = 0; i < x_penultimate_axis_size; i++) { for (ssize_t j = 0; j < y_last_axis_size; j++) { - const double* x = x_data + w * x_leading_stride + i * x_penultimate_stride; - const double* y = y_data + w * y_leading_stride + j * y_last_stride; + const double* x = x_data + i * x_penultimate_stride; + const double* y = y_data + j * y_last_stride; double* out_val = out_data + w * out_leading_stride + i * out_penultimate_stride + j; *out_val = 0.0; From 142cdd1954f2825d30c204bd7515c8ef51dc3b61 Mon Sep 17 00:00:00 2001 From: William Bernoudy Date: Wed, 26 Nov 2025 13:53:39 -0800 Subject: [PATCH 10/10] Clarify Python `matmul` method --- dwave/optimization/mathematical.py | 47 ++++++++++--------- dwave/optimization/symbols/linear_algebra.pyx | 2 +- 2 files changed, 27 insertions(+), 22 deletions(-) diff --git a/dwave/optimization/mathematical.py b/dwave/optimization/mathematical.py index 77727cb0..9ab379a3 100644 --- a/dwave/optimization/mathematical.py +++ b/dwave/optimization/mathematical.py @@ -1072,6 +1072,13 @@ def matmul(x: ArraySymbol, y: ArraySymbol) -> MatrixMultiply: respectively. If both are 1-d, this will produce a scalar (and the operation is equivalent to the dot product of two vectors). + Otherwise, if it is possible that the shapes can be broadcast + together, either by broadcasting missing leading axes (e.g. 
+ `(2, 5, 3, 4, 2)` and `(2, 6)` -> `(2, 5, 3, 4, 6)`) or by + broadcasting axes for which one of the operands has size 1 (e.g. + `(3, 1, 4, 2)` and `(1, 7, 2, 3)` -> `(3, 7, 4, 3)`), then this + will return the result after broadcasting. + Returns: A MatrixMultiply symbol representing the matrix product. If `x` and `y` have shapes `(..., n, k)` and `(..., k, m)`, then the output will @@ -1097,32 +1104,30 @@ def matmul(x: ArraySymbol, y: ArraySymbol) -> MatrixMultiply: See Also: :class:`~dwave.optimization.symbols.MatrixMultiply`: equivalent symbol. - .. versionadded:: 0.6.9 + .. versionadded:: 0.6.10 """ - def broadcast_missing_axes(a, b): - a_shape = [1,] * (b.ndim() - a.ndim()) + list(a.shape()) - b_shape = [1,] * (a.ndim() - b.ndim()) + list(b.shape()) - - for i in range(len(a_shape) - 2): - if a_shape[i] == 1: - a_shape[i] = b_shape[i] - elif b_shape[i] == 1: - b_shape[i] = a_shape[i] - elif a_shape[i] != b_shape[i]: + if not (x.ndim() == 1 or y.ndim() == 1) and x.shape()[:-2] != y.shape()[:-2]: + # The shapes don't match, but it may be possible to do a broadcast. The + # vector broadcast case is handled by MatrixMultiplyNode, so we need to + # handle all the other cases by adding one or two BroadcastNodes. 
+ x_shape = [1,] * (y.ndim() - x.ndim()) + list(x.shape()) + y_shape = [1,] * (x.ndim() - y.ndim()) + list(y.shape()) + + for i in range(len(x_shape) - 2): + if x_shape[i] == 1: + x_shape[i] = y_shape[i] + elif y_shape[i] == 1: + y_shape[i] = x_shape[i] + elif x_shape[i] != y_shape[i]: raise ValueError("Could not broadcast operands") - if tuple(a_shape) != a.shape(): - a = broadcast_to(a, a_shape) - if tuple(b_shape) != b.shape(): - b = broadcast_to(b, b_shape) - return a, b - - if x.ndim() == 0 or y.ndim() == 0: - raise ValueError("Operands must not be scalar") + if tuple(x_shape) != x.shape(): + x = broadcast_to(x, tuple(x_shape)) + if tuple(y_shape) != y.shape(): + y = broadcast_to(y, tuple(y_shape)) - if not (x.ndim() == 1 or y.ndim() == 1) and x.shape()[:-2] != y.shape()[:-2]: - return MatrixMultiply(*broadcast_missing_axes(x, y)) + return MatrixMultiply(x, y) return MatrixMultiply(x, y) diff --git a/dwave/optimization/symbols/linear_algebra.pyx b/dwave/optimization/symbols/linear_algebra.pyx index 29c551e8..49eca707 100644 --- a/dwave/optimization/symbols/linear_algebra.pyx +++ b/dwave/optimization/symbols/linear_algebra.pyx @@ -28,7 +28,7 @@ cdef class MatrixMultiply(ArraySymbol): See Also: :func:`~dwave.optimization.mathematical.matmul`: equivalent function. - .. versionadded:: 0.6.9 + .. versionadded:: 0.6.10 """ def __init__(self, ArraySymbol x, ArraySymbol y): cdef _Graph model = x.model