diff --git a/dwave/optimization/include/dwave-optimization/array.hpp b/dwave/optimization/include/dwave-optimization/array.hpp
index d9330426..2a9d8c42 100644
--- a/dwave/optimization/include/dwave-optimization/array.hpp
+++ b/dwave/optimization/include/dwave-optimization/array.hpp
@@ -76,6 +76,39 @@ struct SizeInfo {
     }
     bool operator==(const SizeInfo& other) const;
 
+    constexpr SizeInfo& operator*=(const std::integral auto n) {
+        multiplier *= n;
+        offset *= n;
+        if (min.has_value()) min.value() *= n;
+        if (max.has_value()) max.value() *= n;
+        return *this;
+    }
+    friend SizeInfo operator*(SizeInfo lhs, const std::integral auto rhs) {
+        lhs *= rhs;
+        return lhs;
+    }
+
+    constexpr SizeInfo& operator/=(const std::integral auto n) {
+        if (!n) throw std::invalid_argument("cannot divide by 0");
+        multiplier /= n;
+        offset /= n;
+        if (min.has_value()) {
+            assert(min.value() % n == 0 and
+                   "dividing SizeInfo with a divisor that does not evenly divide the minimum");
+            min.value() /= n;
+        }
+        if (max.has_value()) {
+            assert(max.value() % n == 0 and
+                   "dividing SizeInfo with a divisor that does not evenly divide the maximum");
+            max.value() /= n;
+        }
+        return *this;
+    }
+    friend SizeInfo operator/(SizeInfo lhs, const std::integral auto rhs) {
+        lhs /= rhs;
+        return lhs;
+    }
+
     // SizeInfos are printable
     friend std::ostream& operator<<(std::ostream& os, const SizeInfo& sizeinfo);
 
@@ -450,6 +483,19 @@ class Array {
         return 0;
     }
 
+    // Determine the size by the shape. For a node with a fixed size, it is simply
+    // the product of the shape.
+    // Expects the shape to be stored in a C-style array of length ndim.
+    static ssize_t shape_to_size(const ssize_t ndim, const ssize_t* const shape) noexcept {
+        if (ndim <= 0) return 1;
+        if (shape[0] < 0) return DYNAMIC_SIZE;
+        return std::reduce(shape, shape + ndim, 1, std::multiplies<ssize_t>());
+    }
+
+    static ssize_t shape_to_size(const std::span<const ssize_t> shape) noexcept {
+        return shape_to_size(shape.size(), shape.data());
+    }
+
 protected:
     // Some utility methods that might be useful to subclasses
 
@@ -476,19 +522,6 @@ class Array {
         return true;
     }
 
-    // Determine the size by the shape. For a node with a fixed size, it is simply
-    // the product of the shape.
-    // Expects the shape to be stored in a C-style array of length ndim.
-    static ssize_t shape_to_size(const ssize_t ndim, const ssize_t* shape) noexcept {
-        if (ndim <= 0) return 1;
-        if (shape[0] < 0) return DYNAMIC_SIZE;
-        return std::reduce(shape, shape + ndim, 1, std::multiplies<ssize_t>());
-    }
-
-    static ssize_t shape_to_size(const std::span<const ssize_t> shape) noexcept {
-        return shape_to_size(shape.size(), shape.data());
-    }
-
     // Determine the strides from the shape.
     // Assumes itemsize = sizeof(double).
     // Expects the shape to be stored in a C-style array of length ndim.
diff --git a/dwave/optimization/include/dwave-optimization/nodes/linear_algebra.hpp b/dwave/optimization/include/dwave-optimization/nodes/linear_algebra.hpp
new file mode 100644
index 00000000..da42fc21
--- /dev/null
+++ b/dwave/optimization/include/dwave-optimization/nodes/linear_algebra.hpp
@@ -0,0 +1,80 @@
+// Copyright 2025 D-Wave Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <span>
+
+#include "dwave-optimization/array.hpp"
+#include "dwave-optimization/graph.hpp"
+
+namespace dwave::optimization {
+
+class MatrixMultiplyNode : public ArrayOutputMixin<ArrayNode> {
+ public:
+    MatrixMultiplyNode(ArrayNode* x_ptr, ArrayNode* y_ptr);
+
+    /// @copydoc Array::buff()
+    double const* buff(const State& state) const override;
+
+    /// @copydoc Node::commit()
+    void commit(State& state) const override;
+
+    /// @copydoc Array::diff()
+    std::span<const Update> diff(const State& state) const override;
+
+    /// @copydoc Node::initialize_state()
+    void initialize_state(State& state) const override;
+
+    /// @copydoc Array::integral()
+    bool integral() const override;
+
+    /// @copydoc Array::max()
+    double max() const override;
+
+    /// @copydoc Array::min()
+    double min() const override;
+
+    /// @copydoc Node::propagate()
+    void propagate(State& state) const override;
+
+    /// @copydoc Node::revert()
+    void revert(State& state) const override;
+
+    /// @copydoc Array::shape()
+    using Array::shape;
+    std::span<const ssize_t> shape(const State& state) const override;
+
+    /// @copydoc Array::size()
+    using Array::size;
+    ssize_t size(const State& state) const override;
+
+    /// @copydoc Array::size_diff()
+    ssize_t size_diff(const State& state) const override;
+
+    /// @copydoc Array::sizeinfo()
+    SizeInfo sizeinfo() const override;
+
+ private:
+    void matmul_(State& state, std::span<double> out, std::span<const ssize_t> out_shape) const;
+    void update_shape_(State& state) const;
+
+    const ArrayNode* x_ptr_;
+    const ArrayNode* y_ptr_;
+
+    const SizeInfo sizeinfo_;
+    const ValuesInfo values_info_;
+};
+
+}  // namespace dwave::optimization
diff --git a/dwave/optimization/libcpp/nodes/linear_algebra.pxd b/dwave/optimization/libcpp/nodes/linear_algebra.pxd
new file mode 100644
index 00000000..5599ccae
--- /dev/null
+++ b/dwave/optimization/libcpp/nodes/linear_algebra.pxd
@@ -0,0 +1,20 @@
+# Copyright 2025 D-Wave
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dwave.optimization.libcpp.graph cimport ArrayNode
+
+
+cdef extern from "dwave-optimization/nodes/linear_algebra.hpp" namespace "dwave::optimization" nogil:
+    cdef cppclass MatrixMultiplyNode(ArrayNode):
+        pass
diff --git a/dwave/optimization/mathematical.py b/dwave/optimization/mathematical.py
index 74c3703f..9ab379a3 100644
--- a/dwave/optimization/mathematical.py
+++ b/dwave/optimization/mathematical.py
@@ -39,6 +39,7 @@
     LinearProgramSolution,
     Log,
     Logical,
+    MatrixMultiply,
     Maximum,
     Mean,
     Minimum,
@@ -88,6 +89,7 @@
     "logical_not",
     "logical_or",
     "logical_xor",
+    "matmul",
    "maximum",
    "mean",
    "minimum",
@@ -1060,6 +1062,76 @@ def logical_xor(x1: ArraySymbol, x2: ArraySymbol) -> Xor:
     return Xor(x1, x2)
 
 
+def matmul(x: ArraySymbol, y: ArraySymbol) -> MatrixMultiply:
+    r"""Compute the matrix product of two array symbols.
+
+    Args:
+        x, y: Operand array symbols. The size of the last axis of `x` must be
+            equal to the size of the second to last axis of `y`. If `x` or `y`
+            is 1-d, it will be treated as a row or column vector,
+            respectively. If both are 1-d, this will produce a scalar (and the
+            operation is equivalent to the dot product of two vectors).
+
+            Otherwise, if the shapes can be broadcast
+            together, either by broadcasting missing leading axes (e.g.
+            `(2, 5, 3, 4, 2)` and `(2, 6)` -> `(2, 5, 3, 4, 6)`) or by
+            broadcasting axes for which one of the operands has size 1 (e.g.
+            `(3, 1, 4, 2)` and `(1, 7, 2, 3)` -> `(3, 7, 4, 3)`), then this
+            will return the result after broadcasting.
+
+    Returns:
+        A MatrixMultiply symbol representing the matrix product. If `x` and
+        `y` have shapes `(..., n, k)` and `(..., k, m)`, then the output will
+        have shape `(..., n, m)`.
+
+    Examples:
+        This example computes the dot product of two integer arrays.
+
+        >>> from dwave.optimization import Model
+        >>> from dwave.optimization.mathematical import matmul
+        ...
+        >>> model = Model()
+        >>> i = model.integer(3)
+        >>> j = model.integer(3)
+        >>> m = matmul(i, j)
+        >>> with model.lock():
+        ...     model.states.resize(1)
+        ...     i.set_state(0, [1, 2, 3])
+        ...     j.set_state(0, [4, 5, 6])
+        ...     print(m.state(0))
+        32.0
+
+    See Also:
+        :class:`~dwave.optimization.symbols.MatrixMultiply`: equivalent symbol.
+
+    .. versionadded:: 0.6.10
+    """
+
+    if not (x.ndim() == 1 or y.ndim() == 1) and x.shape()[:-2] != y.shape()[:-2]:
+        # The shapes don't match, but it may be possible to do a broadcast. The
+        # vector broadcast case is handled by MatrixMultiplyNode, so we need to
+        # handle all the other cases by adding one or two BroadcastNodes.
+        x_shape = [1,] * (y.ndim() - x.ndim()) + list(x.shape())
+        y_shape = [1,] * (x.ndim() - y.ndim()) + list(y.shape())
+
+        for i in range(len(x_shape) - 2):
+            if x_shape[i] == 1:
+                x_shape[i] = y_shape[i]
+            elif y_shape[i] == 1:
+                y_shape[i] = x_shape[i]
+            elif x_shape[i] != y_shape[i]:
+                raise ValueError("Could not broadcast operands")
+
+        if tuple(x_shape) != x.shape():
+            x = broadcast_to(x, tuple(x_shape))
+        if tuple(y_shape) != y.shape():
+            y = broadcast_to(y, tuple(y_shape))
+
+        return MatrixMultiply(x, y)
+
+    return MatrixMultiply(x, y)
+
+
 @_op(Maximum, NaryMaximum, "max")
 def maximum(x1: ArraySymbol, x2: ArraySymbol,
             *xi: ArraySymbol,
             ) -> typing.Union[Maximum, NaryMaximum]:
diff --git a/dwave/optimization/src/nodes/linear_algebra.cpp b/dwave/optimization/src/nodes/linear_algebra.cpp
new file mode 100644
index 00000000..16bbc3ac
--- /dev/null
+++ b/dwave/optimization/src/nodes/linear_algebra.cpp
@@ -0,0 +1,356 @@
+// Copyright 2025 D-Wave Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "dwave-optimization/nodes/linear_algebra.hpp"
+
+#include <algorithm>
+#include <numeric>
+#include <ranges>
+
+#include "../functional_.hpp"
+#include "_state.hpp"
+#include "dwave-optimization/array.hpp"
+#include "dwave-optimization/state.hpp"
+
+namespace dwave::optimization {
+
+//////////////////////  MatrixMultiplyNode
+
+// Valid shapes to multiply, and the resulting shape
+// (-1, 2, 5, 3) and (-1, 2, 3, 7) -> (-1, 2, 5, 7)
+// (-1, 3) and (3) -> (-1)
+// (-1, 3) and (3, 7) -> (-1, 7)
+// (-1) and (-1) -> ()
+// (-1) and (-1, 5) -> (5)
+
+ssize_t get_axis_size(std::span<const ssize_t> shape, ssize_t index, bool vector_as_row) {
+    // If vector_as_row is true, treat vector as shape (1, size), else as shape (size, 1)
+    assert(index < 0);
+    if (shape.size() == 0) return 1;
+    if (shape.size() == 1 and index == -2 and vector_as_row) return 1;
+    if (shape.size() == 1 and index == -1 and not vector_as_row) return 1;
+    if (shape.size() == 1) return shape.back();
+    return shape[shape.size() + index];
+}
+
+std::vector<ssize_t> output_shape(const ArrayNode* x_ptr, const ArrayNode* y_ptr) {
+    if (x_ptr->ndim() == 0 or y_ptr->ndim() == 0) {
+        throw std::invalid_argument("operands cannot be scalar");
+    }
+
+    // Check that the last dimension of x matches the second last dimension of y
+    const ssize_t x_last_axis_size = get_axis_size(x_ptr->shape(), -1, true);
+    const ssize_t y_penultimate_axis_size = get_axis_size(y_ptr->shape(), -2, false);
+    if (x_last_axis_size != y_penultimate_axis_size) {
+        throw std::invalid_argument(
+                "the last dimension of `x` is not the same size as the second to last dimension of "
+                "`y`");
+    } else if (x_last_axis_size == -1) {
+        assert(x_ptr->dynamic() && y_ptr->dynamic());
+        // Both are dynamic. We need to check that the dynamic dimension is
+        // always the same size.
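+        // For example (illustrative): x of shape (-1,) and y of shape (-1, 5)
+        // are only compatible if their dynamic dimensions provably track each
+        // other, e.g. when both are derived from the same dynamic predecessor.
+        // That is what the per-row sizeinfo comparison below verifies.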
+        const ssize_t x_subspace_size = Array::shape_to_size(x_ptr->shape().subspan(1));
+        const ssize_t y_subspace_size = Array::shape_to_size(y_ptr->shape().subspan(1));
+        assert(x_subspace_size != Array::DYNAMIC_SIZE);
+        assert(y_subspace_size != Array::DYNAMIC_SIZE);
+        if (x_ptr->sizeinfo() / x_subspace_size != y_ptr->sizeinfo() / y_subspace_size) {
+            throw std::invalid_argument(
+                    "the last dimension of `x` is not the same size as the second to last "
+                    "dimension of `y`");
+        }
+    }
+
+    // Now check that the leading subspace shape is identical (no broadcasting for now)
+    if (x_ptr->ndim() >= 2 && y_ptr->ndim() >= 2) {
+        if (x_ptr->ndim() != y_ptr->ndim()) {
+            throw std::invalid_argument(
+                    "operands have different dimensions (use BroadcastNode if you wish to "
+                    "broadcast missing dimensions)");
+        }
+        for (ssize_t i = 0, stop = x_ptr->ndim() - 2; i < stop; i++) {
+            if (x_ptr->shape()[i] != y_ptr->shape()[i]) {
+                throw std::invalid_argument(
+                        "operands must have matching leading shape (up to the last two "
+                        "dimensions)");
+            }
+        }
+    }
+
+    // Now that we know the leading axes match, we can construct the output shape
+    std::vector<ssize_t> shape;
+    // If x is being broadcast, we need to add the axes from the start of y
+    if (y_ptr->ndim() > 2 && y_ptr->ndim() > x_ptr->ndim()) {
+        const ssize_t num_x_leading = std::max<ssize_t>(0, x_ptr->ndim() - 2);
+        const ssize_t num_y_leading = std::max<ssize_t>(0, y_ptr->ndim() - 2);
+        for (ssize_t d : y_ptr->shape() | std::views::take(num_y_leading - num_x_leading)) {
+            shape.push_back(d);
+        }
+    }
+    for (ssize_t d : x_ptr->shape() | std::views::take(x_ptr->ndim() - 1)) {
+        shape.push_back(d);
+    }
+    if (y_ptr->ndim() >= 2) {
+        shape.push_back(y_ptr->shape().back());
+    }
+
+    return shape;
+}
+
+SizeInfo get_sizeinfo(const ArrayNode* x_ptr, const ArrayNode* y_ptr) {
+    if (y_ptr->dynamic() and y_ptr->ndim() <= 2) {
+        // x must also be dynamic, and we must be contracting along the dynamic
+        // dimension, so the output is fixed size.
+        std::vector<ssize_t> shape = output_shape(x_ptr, y_ptr);
+        ssize_t size = Array::shape_to_size(shape);
+        assert(size >= 1);
+        return SizeInfo(size);
+    }
+
+    // The size should be x's size, divided by the size of x's last dimension,
+    // multiplied by the size of y's last dimension (if matrix or higher dim).
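+    // For example (illustrative): if x has shape (-1, 3), x's sizeinfo is
+    // 3 times the number of rows; dividing by x's last dimension (3) leaves
+    // the number of rows, and multiplying by y's last dimension, say 7,
+    // gives the sizeinfo of the resulting (-1, 7) array.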
+    assert(x_ptr->shape().back() != -1);
+    SizeInfo sizeinfo = x_ptr->sizeinfo() / x_ptr->shape().back();
+    if (y_ptr->ndim() >= 2) {
+        assert(y_ptr->shape().back() != -1);
+        sizeinfo *= y_ptr->shape().back();
+    }
+
+    return sizeinfo;
+}
+
+ValuesInfo get_values_info(const ArrayNode* x_ptr, const ArrayNode* y_ptr) {
+    // Get all possible combinations of values
+    const std::array combos{x_ptr->min() * y_ptr->min(), x_ptr->min() * y_ptr->max(),
+                            x_ptr->max() * y_ptr->min(), x_ptr->max() * y_ptr->max()};
+
+    const double min_val = std::ranges::min(combos);
+    const double max_val = std::ranges::max(combos);
+
+    const SizeInfo contracted_axis_size = [&]() {
+        // If x is 1d, then the contracted axis size is equal to x's size
+        if (x_ptr->ndim() == 1) return x_ptr->sizeinfo();
+        // Otherwise it's always the last axis of x (which definitionally is not dynamic)
+        return SizeInfo(x_ptr->shape().back());
+    }();
+
+    if (contracted_axis_size.max.has_value() and contracted_axis_size.max.value() == 0) {
+        // Output will always be empty, so we can return early
+        return ValuesInfo(0.0, 0.0, true);
+    }
+
+    // Use default constructor to get default min/max
+    ValuesInfo values_info;
+    values_info.integral = x_ptr->integral() && y_ptr->integral();
+
+    if (contracted_axis_size.max.has_value()) {
+        if (max_val >= 0) values_info.max = max_val * contracted_axis_size.max.value();
+        if (min_val <= 0) values_info.min = min_val * contracted_axis_size.max.value();
+    }
+
+    const ssize_t min_size = contracted_axis_size.min.value_or(0);
+    if (max_val < 0) values_info.max = max_val * min_size;
+    if (min_val > 0) values_info.min = min_val * min_size;
+
+    return values_info;
+}
+
+std::vector<ssize_t> atleast_2d_shape(std::span<const ssize_t> shape, bool vector_as_row) {
+    // If vector_as_row is true, treat vector as shape (1, size), else as shape (size, 1)
+    if (shape.size() == 0) return {1, 1};
+    if (shape.size() == 1 and vector_as_row) return {1, shape[0]};
+    if (shape.size() == 1 and not vector_as_row) return {shape[0], 1};
+    return {shape.begin(), shape.end()};
+}
+
+class MatrixMultiplyNodeData : public ArrayNodeStateData {
+ public:
+    explicit MatrixMultiplyNodeData(std::vector<double>&& values, std::span<const ssize_t> shape)
+            : ArrayNodeStateData(std::move(values)), shape(shape.begin(), shape.end()) {}
+
+    std::vector<double> output;
+    std::vector<ssize_t> shape;
+};
+
+MatrixMultiplyNode::MatrixMultiplyNode(ArrayNode* x_ptr, ArrayNode* y_ptr)
+        : ArrayOutputMixin(output_shape(x_ptr, y_ptr)),
+          x_ptr_(x_ptr),
+          y_ptr_(y_ptr),
+          sizeinfo_(get_sizeinfo(x_ptr, y_ptr)),
+          values_info_(get_values_info(x_ptr, y_ptr)) {
+    add_predecessor(x_ptr);
+    add_predecessor(y_ptr);
+}
+
+ssize_t get_leading_stride(std::span<const ssize_t> shape) {
+    assert(shape.size() >= 1);
+    if (shape.size() == 1) return 0;  // handles broadcasting for the vector case
+    if (shape.size() == 2) return 1;
+    return Array::shape_to_size(shape.subspan(shape.size() - 2));
+}
+
+ssize_t get_stride(std::span<const ssize_t> shape, std::span<const ssize_t> strides, ssize_t index,
+                   bool vector_as_row) {
+    // If vector_as_row is true, treat vector as shape (1, size), else as shape (size, 1)
+    assert(index < 0 && index >= -2);
+    if (get_axis_size(shape, index, vector_as_row) == 1) return 0;
+    if (index + 1 == 0) return strides.back() / sizeof(double);
+    assert(shape.size() > 1);
+    return strides[static_cast<ssize_t>(strides.size()) + index] / sizeof(double);
+}
+
+ssize_t get_leading_subspace_size(std::span<const ssize_t> x_shape,
+                                  std::span<const ssize_t> y_shape) {
+    const auto shape = x_shape.size() > y_shape.size() ? x_shape : y_shape;
+    const ssize_t penultimate_axis = std::max<ssize_t>(0, static_cast<ssize_t>(shape.size()) - 2);
+    return std::reduce(shape.begin(), shape.begin() + penultimate_axis, 1,
+                       std::multiplies<ssize_t>());
+}
+
+void MatrixMultiplyNode::matmul_(State& state, std::span<double> out,
+                                 std::span<const ssize_t> out_shape) const {
+    assert(static_cast<ssize_t>(out.size()) == Array::shape_to_size(out_shape));
+
+    // If out is empty (possible when predecessors have 0 size) there is nothing to do
+    if (out.size() == 0) return;
+
+    const ssize_t x_penultimate_axis_size = get_axis_size(x_ptr_->shape(state), -2, true);
+    const ssize_t leading_subspace_size =
+            get_leading_subspace_size(x_ptr_->shape(state), y_ptr_->shape(state));
+
+    const ssize_t x_leading_stride = get_leading_stride(x_ptr_->shape(state));
+    const ssize_t y_leading_stride = get_leading_stride(y_ptr_->shape(state));
+    const ssize_t out_leading_stride = [&]() -> ssize_t {
+        if (x_ptr_->ndim() >= 2 and y_ptr_->ndim() >= 2) return get_leading_stride(out_shape);
+        if (x_ptr_->ndim() == 1 and y_ptr_->ndim() == 1) return 0;
+        return out_shape.back();
+    }();
+
+    const ssize_t y_last_axis_size = get_axis_size(y_ptr_->shape(state), -1, false);
+    const ssize_t y_penultimate_axis_size = get_axis_size(y_ptr_->shape(state), -2, false);
+
+    const ssize_t x_penultimate_stride =
+            get_stride(x_ptr_->shape(state), x_ptr_->strides(), -2, true);
+    const ssize_t x_last_stride = x_ptr_->strides().back() / sizeof(double);
+
+    const ssize_t y_penultimate_stride = y_last_axis_size;
+    const ssize_t y_last_stride =
+            y_ptr_->ndim() >= 2 ? y_ptr_->strides().back() / sizeof(double) : 0;
+
+    const ssize_t out_penultimate_stride = [&]() -> ssize_t {
+        if (y_ptr_->ndim() == 1) return 1;
+        return get_axis_size(out_shape, -1, false);
+    }();
+
+    double* __restrict const out_data = out.data();
+
+    for (ssize_t w = 0; w < leading_subspace_size; w++) {
+        // In order to avoid having to iterate over all leading dimensions
+        // and check the strides of both x/y, we use ArrayIterators to
+        // get us to the correct subspace, and then use the strides of the
+        // predecessors to iterate through the last one or two dimensions.
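+        // For example (an illustrative case): with x of shape (5, 3, 2, 7)
+        // and y of shape (5, 3, 7, 4), leading_subspace_size is 5 * 3 = 15,
+        // x_leading_stride is 2 * 7 = 14, and y_leading_stride is 7 * 4 = 28,
+        // so subspace w starts w * 14 elements into x and w * 28 elements
+        // into y (in logical, i.e. iteration, order).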
+        const double* const x_data = &x_ptr_->view(state).begin()[w * x_leading_stride];
+        const double* const y_data = &y_ptr_->view(state).begin()[w * y_leading_stride];
+        // Now we do standard 2D matrix multiply
+        for (ssize_t i = 0; i < x_penultimate_axis_size; i++) {
+            for (ssize_t j = 0; j < y_last_axis_size; j++) {
+                const double* x = x_data + i * x_penultimate_stride;
+                const double* y = y_data + j * y_last_stride;
+                double* out_val =
+                        out_data + w * out_leading_stride + i * out_penultimate_stride + j;
+                *out_val = 0.0;
+                for (ssize_t k = 0; k < y_penultimate_axis_size; k++) {
+                    *out_val += *x * *y;
+                    x += x_last_stride;
+                    y += y_penultimate_stride;
+                }
+            }
+        }
+    }
+}
+
+void MatrixMultiplyNode::initialize_state(State& state) const {
+    ssize_t start_size = this->size();
+    std::vector<ssize_t> shape(this->shape().begin(), this->shape().end());
+    if (this->dynamic()) {
+        shape[0] = x_ptr_->shape(state)[0];
+        start_size = Array::shape_to_size(shape);
+    }
+
+    std::vector<double> data(start_size);
+    matmul_(state, data, shape);
+    emplace_data_ptr<MatrixMultiplyNodeData>(state, std::move(data), shape);
+}
+
+double const* MatrixMultiplyNode::buff(const State& state) const {
+    return data_ptr<MatrixMultiplyNodeData>(state)->buff();
+}
+
+void MatrixMultiplyNode::commit(State& state) const {
+    return data_ptr<MatrixMultiplyNodeData>(state)->commit();
+}
+
+std::span<const Update> MatrixMultiplyNode::diff(const State& state) const {
+    return data_ptr<MatrixMultiplyNodeData>(state)->diff();
+}
+
+bool MatrixMultiplyNode::integral() const { return values_info_.integral; }
+
+double MatrixMultiplyNode::max() const { return values_info_.max; }
+
+double MatrixMultiplyNode::min() const { return values_info_.min; }
+
+void MatrixMultiplyNode::update_shape_(State& state) const {
+    if (this->dynamic()) {
+        data_ptr<MatrixMultiplyNodeData>(state)->shape[0] = x_ptr_->shape(state)[0];
+    }
+}
+
+void MatrixMultiplyNode::propagate(State& state) const {
+    if (x_ptr_->diff(state).size() == 0 and y_ptr_->diff(state).size() == 0) return;
+
+    auto data = data_ptr<MatrixMultiplyNodeData>(state);
+
+    this->update_shape_(state);
+    const ssize_t new_size = Array::shape_to_size(data->shape);
+
+    data->output.resize(new_size);
+
+    this->matmul_(state, data->output, data->shape);
+    data->assign(data->output);
+}
+
+void MatrixMultiplyNode::revert(State& state) const {
+    auto data = data_ptr<MatrixMultiplyNodeData>(state);
+    data->revert();
+    this->update_shape_(state);
+}
+
+std::span<const ssize_t> MatrixMultiplyNode::shape(const State& state) const {
+    if (not this->dynamic()) return this->shape();
+    return data_ptr<MatrixMultiplyNodeData>(state)->shape;
+}
+
+ssize_t MatrixMultiplyNode::size(const State& state) const {
+    if (not this->dynamic()) return this->size();
+    return data_ptr<MatrixMultiplyNodeData>(state)->size();
+}
+
+ssize_t MatrixMultiplyNode::size_diff(const State& state) const {
+    if (not this->dynamic()) return 0;
+    return data_ptr<MatrixMultiplyNodeData>(state)->size_diff();
+}
+
+SizeInfo MatrixMultiplyNode::sizeinfo() const { return sizeinfo_; }
+
+}  // namespace dwave::optimization
diff --git a/dwave/optimization/symbols/__init__.py b/dwave/optimization/symbols/__init__.py
index 6afdc3e5..2421e7f1 100644
--- a/dwave/optimization/symbols/__init__.py
+++ b/dwave/optimization/symbols/__init__.py
@@ -52,6 +52,7 @@
 )
 from dwave.optimization.symbols.inputs import Input
 from dwave.optimization.symbols.interpolation import BSpline
+from dwave.optimization.symbols.linear_algebra import MatrixMultiply
 from dwave.optimization.symbols.lp import (
     LinearProgram,
     LinearProgramFeasible,
@@ -149,6 +150,7 @@
     "ListVariable",
     "Log",
     "Logical",
+    "MatrixMultiply",
     "Max",
     "Maximum",
     "Mean",
diff --git a/dwave/optimization/symbols/linear_algebra.pyi b/dwave/optimization/symbols/linear_algebra.pyi
new file mode 100644
index 00000000..9e99d0bf
--- /dev/null
+++ b/dwave/optimization/symbols/linear_algebra.pyi
@@ -0,0 +1,18 @@
+# Copyright 2025 D-Wave
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dwave.optimization.model import ArraySymbol as _ArraySymbol
+
+
+class MatrixMultiply(_ArraySymbol): ...
diff --git a/dwave/optimization/symbols/linear_algebra.pyx b/dwave/optimization/symbols/linear_algebra.pyx
new file mode 100644
index 00000000..49eca707
--- /dev/null
+++ b/dwave/optimization/symbols/linear_algebra.pyx
@@ -0,0 +1,43 @@
+# cython: auto_pickle=False
+
+# Copyright 2025 D-Wave
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from cython.operator cimport typeid
+
+from dwave.optimization._model cimport _Graph, _register, ArraySymbol
+from dwave.optimization.libcpp.nodes.linear_algebra cimport (
+    MatrixMultiplyNode,
+)
+
+
+cdef class MatrixMultiply(ArraySymbol):
+    """MatrixMultiply symbol.
+
+    See Also:
+        :func:`~dwave.optimization.mathematical.matmul`: equivalent function.
+
+    .. versionadded:: 0.6.10
+    """
+    def __init__(self, ArraySymbol x, ArraySymbol y):
+        cdef _Graph model = x.model
+        if y.model is not model:
+            raise ValueError("operands must be from the same model")
+
+        cdef MatrixMultiplyNode* ptr = model._graph.emplace_node[MatrixMultiplyNode](
+            x.array_ptr, y.array_ptr,
+        )
+        self.initialize_arraynode(model, ptr)
+
+_register(MatrixMultiply, typeid(MatrixMultiplyNode))
diff --git a/meson.build b/meson.build
index 1fba2658..8785c288 100644
--- a/meson.build
+++ b/meson.build
@@ -35,6 +35,7 @@ dwave_optimization_src = [
    'dwave/optimization/src/nodes/inputs.cpp',
    'dwave/optimization/src/nodes/interpolation.cpp',
    'dwave/optimization/src/nodes/lambda.cpp',
+    'dwave/optimization/src/nodes/linear_algebra.cpp',
    'dwave/optimization/src/nodes/lp.cpp',
    'dwave/optimization/src/nodes/manipulation.cpp',
    'dwave/optimization/src/nodes/naryop.cpp',
@@ -102,6 +103,7 @@ foreach name : [
    'indexing',
    'inputs',
    'interpolation',
+    'linear_algebra',
    'lp',
    'manipulation',
    'naryop',
diff --git a/releasenotes/notes/add-matrix-multiplication-symbol-ebf608660b82adfa.yaml b/releasenotes/notes/add-matrix-multiplication-symbol-ebf608660b82adfa.yaml
new file mode 100644
index 00000000..bcf31629
--- /dev/null
+++ b/releasenotes/notes/add-matrix-multiplication-symbol-ebf608660b82adfa.yaml
@@ -0,0 +1,7 @@
+---
+features:
+  - |
+    Add the ``MatrixMultiply`` symbol and corresponding function
+    ``matmul``.
The ``matmul`` method follows the behavior of NumPy's + ``matmul``, meaning that it works with matrices, vectors, and higher order + arrays. diff --git a/tests/cpp/meson.build b/tests/cpp/meson.build index 7e840d26..ef23f764 100644 --- a/tests/cpp/meson.build +++ b/tests/cpp/meson.build @@ -16,8 +16,9 @@ tests_all = executable( 'nodes/test_flow.cpp', 'nodes/test_inputs.cpp', 'nodes/test_interpolation.cpp', - 'nodes/test_lp.cpp', 'nodes/test_lambda.cpp', + 'nodes/test_linear_algebra.cpp', + 'nodes/test_lp.cpp', 'nodes/test_manipulation.cpp', 'nodes/test_naryop.cpp', 'nodes/test_numbers.cpp', diff --git a/tests/cpp/nodes/test_linear_algebra.cpp b/tests/cpp/nodes/test_linear_algebra.cpp new file mode 100644 index 00000000..c9836863 --- /dev/null +++ b/tests/cpp/nodes/test_linear_algebra.cpp @@ -0,0 +1,539 @@ +// Copyright 2025 D-Wave +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "dwave-optimization/nodes/binaryop.hpp" +#include "dwave-optimization/nodes/collections.hpp" +#include "dwave-optimization/nodes/constants.hpp" +#include "dwave-optimization/nodes/indexing.hpp" +#include "dwave-optimization/nodes/linear_algebra.hpp" +#include "dwave-optimization/nodes/manipulation.hpp" +#include "dwave-optimization/nodes/numbers.hpp" +#include "dwave-optimization/nodes/testing.hpp" + +using Catch::Matchers::RangeEquals; + +namespace dwave::optimization { + +TEST_CASE("MatrixMultiplyNode") { + auto graph = Graph(); + + GIVEN("A dynamic testing array node and MatrixMultiply on it") { + auto arr = DynamicArrayTestingNode(std::initializer_list{-1}, -10.0, -5.0, false); + auto constant = ConstantNode(15.0); + auto add = AddNode(&arr, &constant); + REQUIRE(add.min() == 5.0); + REQUIRE(add.max() == 10.0); + auto matmul = MatrixMultiplyNode(&arr, &add); + THEN("MatrixMultiplyNode reports correct min and max") { + CHECK(matmul.min() == ValuesInfo().min); + CHECK(matmul.max() == 0.0); + } + } + + GIVEN("A dynamic testing array node with minimum size and matmul on it") { + auto arr = DynamicArrayTestingNode(std::initializer_list{-1}, -10.0, -5.0, false, + 3, std::nullopt); + auto constant = ConstantNode(15.0); + auto add = AddNode(&arr, &constant); + REQUIRE(add.min() == 5.0); + REQUIRE(add.max() == 10.0); + auto matmul = MatrixMultiplyNode(&arr, &add); + THEN("MatrixMultiplyNode reports correct min and max") { + CHECK(matmul.min() == ValuesInfo().min); + CHECK(matmul.max() == -5.0 * 5.0 * 3); + } + } + + GIVEN("A dynamic testing array node with minimum and maximum size and matmul on it") { + auto arr = DynamicArrayTestingNode(std::initializer_list{-1}, -10.0, -5.0, false, + 3, 7); + auto constant = ConstantNode(15.0); + auto add = AddNode(&arr, &constant); + REQUIRE(add.min() == 5.0); + REQUIRE(add.max() == 10.0); + auto matmul = MatrixMultiplyNode(&arr, &add); + THEN("MatrixMultiplyNode reports correct min and max") { + CHECK(matmul.min() == -10.0 * 10.0 * 7); + CHECK(matmul.max() == -5.0 * 5.0 * 3); + } + } + + SECTION("Higher order 
broadcasting") { + auto a = IntegerNode({5, 4, 3, 2}); + + auto b = IntegerNode({2, 7}); + CHECK_THROWS_AS(MatrixMultiplyNode(&a, &b), std::invalid_argument); + + auto c = IntegerNode({4, 2, 1}); + CHECK_THROWS_AS(MatrixMultiplyNode(&a, &c), std::invalid_argument); + + auto d = IntegerNode({5, 1, 2, 1}); + CHECK_THROWS_AS(MatrixMultiplyNode(&a, &d), std::invalid_argument); + } + + GIVEN("Two constant 1d nodes and a MatrixMultiplyNode") { + auto c1_ptr = graph.emplace_node(std::vector{1.0, 2.0, 3.0}); + auto c2_ptr = graph.emplace_node(std::vector{4.0, 5.0, 6.0}); + + auto matmul_ptr = graph.emplace_node(c1_ptr, c2_ptr); + + graph.emplace_node(matmul_ptr); + + CHECK(matmul_ptr->size() == 1); + CHECK(matmul_ptr->ndim() == 0); + + CHECK(matmul_ptr->min() == 1.0 * 4.0 * 3); + CHECK(matmul_ptr->max() == 3.0 * 6.0 * 3); + + WHEN("We initialize a state") { + auto state = graph.initialize_state(); + + THEN("The initial MatrixMultiplyNode state is correct") { + CHECK_THAT(matmul_ptr->view(state), RangeEquals({32})); + } + } + } + + GIVEN("Two constant 2d nodes and a MatrixMultiplyNode") { + auto c1_ptr = graph.emplace_node(std::vector{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, + std::vector{2, 3}); + auto c2_ptr = graph.emplace_node(std::vector{7.0, 8.0, 9.0, 10.0, 11.0, 12.0}, + std::vector{3, 2}); + + auto matmul_ptr = graph.emplace_node(c1_ptr, c2_ptr); + + graph.emplace_node(matmul_ptr); + + CHECK(matmul_ptr->size() == 4); + CHECK(matmul_ptr->ndim() == 2); + CHECK_THAT(matmul_ptr->shape(), RangeEquals({2, 2})); + + CHECK(matmul_ptr->min() == 1.0 * 7.0 * 3); + CHECK(matmul_ptr->max() == 6.0 * 12.0 * 3); + + WHEN("We initialize a state") { + auto state = graph.initialize_state(); + + THEN("The initial MatrixMultiplyNode state is correct") { + CHECK_THAT(matmul_ptr->view(state), RangeEquals({58, 64, 139, 154})); + } + } + } + + GIVEN("Two constant 2d nodes and a MatrixMultiplyNode") { + auto c1_ptr = graph.emplace_node(std::vector{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, + std::vector{2, 3}); + auto c2_ptr = graph.emplace_node(std::vector{7.0, 8.0, 9.0, 10.0, 11.0, 12.0}, + std::vector{3, 2}); + + auto matmul_ptr = graph.emplace_node(c2_ptr, c1_ptr); + + graph.emplace_node(matmul_ptr); + + CHECK(matmul_ptr->size() == 9); + CHECK(matmul_ptr->ndim() == 2); + CHECK_THAT(matmul_ptr->shape(), RangeEquals({3, 3})); + + WHEN("We initialize a state") { + auto state = graph.initialize_state(); + + THEN("The initial MatrixMultiplyNode state is correct") { + CHECK_THAT(matmul_ptr->view(state), + RangeEquals({39, 54, 69, 49, 68, 87, 59, 82, 105})); + } + } + } + + GIVEN("One constant 1d node and one constant 2d node") { + auto c1_ptr = graph.emplace_node(std::vector{1.0, 2.0, 3.0}, + std::vector{3}); + auto c2_ptr = graph.emplace_node( + std::vector{4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0}, + std::vector{3, 3}); + + AND_GIVEN("The 1d node @ the 2d node") { + auto matmul_ptr = graph.emplace_node(c1_ptr, c2_ptr); + + graph.emplace_node(matmul_ptr); + + CHECK(matmul_ptr->size() == 3); + CHECK(matmul_ptr->ndim() == 1); + + WHEN("We initialize a state") { + auto state = graph.initialize_state(); + + THEN("The initial MatrixMultiplyNode state is correct") { + CHECK_THAT(matmul_ptr->view(state), RangeEquals({48, 54, 60})); + } + } + } + + AND_GIVEN("The 2d node @ the 1d node") { + auto matmul_ptr = graph.emplace_node(c2_ptr, c1_ptr); + + graph.emplace_node(matmul_ptr); + + CHECK(matmul_ptr->size() == 3); + CHECK(matmul_ptr->ndim() == 1); + + WHEN("We initialize a state") { + auto state = graph.initialize_state(); + + THEN("The 
initial MatrixMultiplyNode state is correct") { + CHECK_THAT(matmul_ptr->view(state), RangeEquals({32, 50, 68})); + } + } + } + } + + GIVEN("One 1d and one 4d constant nodes and a MatrixMultiplyNode") { + std::vector c1_shape{7}; + const ssize_t c1_size = 7; + std::vector c1_data(c1_size); + std::iota(c1_data.begin(), c1_data.end(), 2.0); + auto c1_ptr = graph.emplace_node(c1_data, c1_shape); + REQUIRE(c1_ptr->min() == 2.0); + REQUIRE(c1_ptr->max() == 8.0); + + std::vector c2_shape{5, 3, 7, 2}; + ssize_t c2_size = 5 * 3 * 7 * 2; + std::vector c2_data(c2_size); + std::iota(c2_data.begin(), c2_data.end(), -10.0); + auto c2_ptr = graph.emplace_node(c2_data, c2_shape); + REQUIRE(c2_ptr->min() == -10.0); + REQUIRE(c2_ptr->max() == 199.0); + + auto matmul_ptr = graph.emplace_node(c1_ptr, c2_ptr); + + graph.emplace_node(matmul_ptr); + + CHECK(matmul_ptr->size() == 5 * 3 * 2); + CHECK(matmul_ptr->ndim() == 3); + CHECK_THAT(matmul_ptr->shape(), RangeEquals({5, 3, 2})); + + CHECK(matmul_ptr->min() == 8.0 * -10.0 * 7); + CHECK(matmul_ptr->max() == 8.0 * 199.0 * 7); + + WHEN("We initialize a state") { + auto state = graph.initialize_state(); + + THEN("The initial MatrixMultiplyNode state is correct") { + CHECK_THAT( + matmul_ptr->view(state), + RangeEquals({-84, -49, 406, 441, 896, 931, 1386, 1421, 1876, 1911, + 2366, 2401, 2856, 2891, 3346, 3381, 3836, 3871, 4326, 4361, + 4816, 4851, 5306, 5341, 5796, 5831, 6286, 6321, 6776, 6811})); + } + } + } + + GIVEN("One 4d and one 1d constant nodes and a MatrixMultiplyNode") { + std::vector c1_shape{5, 3, 2, 7}; + ssize_t c1_size = 5 * 3 * 2 * 7; + std::vector c1_data(c1_size); + std::iota(c1_data.begin(), c1_data.end(), -10.0); + auto c1_ptr = graph.emplace_node(c1_data, c1_shape); + REQUIRE(c1_ptr->min() == -10.0); + REQUIRE(c1_ptr->max() == 199.0); + + std::vector c2_shape{7}; + const ssize_t c2_size = 7; + std::vector c2_data(c2_size); + std::iota(c2_data.begin(), c2_data.end(), 2.0); + auto c2_ptr = graph.emplace_node(c2_data, c2_shape); + REQUIRE(c2_ptr->min() == 2.0); + REQUIRE(c2_ptr->max() == 8.0); + + auto matmul_ptr = graph.emplace_node(c1_ptr, c2_ptr); + + graph.emplace_node(matmul_ptr); + + CHECK(matmul_ptr->size() == 5 * 3 * 2); + CHECK(matmul_ptr->ndim() == 3); + CHECK_THAT(matmul_ptr->shape(), RangeEquals({5, 3, 2})); + + CHECK(matmul_ptr->min() == 8.0 * -10.0 * 7); + CHECK(matmul_ptr->max() == 8.0 * 199.0 * 7); + + WHEN("We initialize a state") { + auto state = graph.initialize_state(); + + THEN("The initial MatrixMultiplyNode state is correct") { + CHECK_THAT( + matmul_ptr->view(state), + RangeEquals({-217, 28, 273, 518, 763, 1008, 1253, 1498, 1743, 1988, + 2233, 2478, 2723, 2968, 3213, 3458, 3703, 3948, 4193, 4438, + 4683, 4928, 5173, 5418, 5663, 5908, 6153, 6398, 6643, 6888})); + } + } + } + + GIVEN("Two 4d constant nodes and a MatrixMultiplyNode") { + std::vector c1_shape{5, 3, 1, 7}; + const ssize_t c1_size = 5 * 3 * 1 * 7; + std::vector c1_data(c1_size); + std::iota(c1_data.begin(), c1_data.end(), -20.0); + auto c1_ptr = graph.emplace_node(c1_data, c1_shape); + REQUIRE(c1_ptr->min() == -20.0); + REQUIRE(c1_ptr->max() == 84); + + std::vector c2_shape{5, 3, 7, 2}; + ssize_t c2_size = 5 * 3 * 7 * 2; + std::vector c2_data(c2_size); + std::iota(c2_data.begin(), c2_data.end(), -10.0); + auto c2_ptr = graph.emplace_node(c2_data, c2_shape); + REQUIRE(c2_ptr->min() == -10.0); + REQUIRE(c2_ptr->max() == 199.0); + + auto matmul_ptr = graph.emplace_node(c1_ptr, c2_ptr); + + graph.emplace_node(matmul_ptr); + + CHECK(matmul_ptr->size() == 5 * 3 * 1 * 
2); + CHECK(matmul_ptr->ndim() == 4); + CHECK_THAT(matmul_ptr->shape(), RangeEquals({5, 3, 1, 2})); + + CHECK(matmul_ptr->min() == -20.0 * 199.0 * 7); + CHECK(matmul_ptr->max() == 84.0 * 199.0 * 7); + + WHEN("We initialize a state") { + auto state = graph.initialize_state(); + + THEN("The initial MatrixMultiplyNode state is correct") { + CHECK_THAT(matmul_ptr->view(state), + RangeEquals({532, 413, -644, -714, -448, -469, 1120, 1148, + 4060, 4137, 8372, 8498, 14056, 14231, 21112, 21336, + 29540, 29813, 39340, 39662, 50512, 50883, 63056, 63476, + 76972, 77441, 92260, 92778, 108920, 109487})); + } + } + } + + GIVEN("A set node and MatrixMultiplyNode representing self dot product") { + auto set_ptr = graph.emplace_node(10); + auto matmul_ptr = graph.emplace_node(set_ptr, set_ptr); + graph.emplace_node(matmul_ptr); + + CHECK(matmul_ptr->ndim() == 0); + + WHEN("We initialize a state") { + auto state = graph.initialize_state(); + REQUIRE(set_ptr->size(state) == 0); + + THEN("The initial MatrixMultiplyNode state is correct") { + CHECK(matmul_ptr->size(state) == 1); + CHECK(matmul_ptr->shape(state).size() == 0); + CHECK_THAT(matmul_ptr->view(state), RangeEquals({0.0})); + } + + AND_WHEN("We grow the set and propagate") { + set_ptr->assign(state, {5, 7, 1}); + graph.propagate(state); + + THEN("The state is correct") { + CHECK_THAT(matmul_ptr->view(state), RangeEquals({5 * 5 + 7 * 7 + 1 * 1})); + } + + AND_WHEN("We revert") { + graph.revert(state); + + THEN("The state is correct") { + CHECK(matmul_ptr->size(state) == 1); + CHECK(matmul_ptr->shape(state).size() == 0); + CHECK_THAT(matmul_ptr->view(state), RangeEquals({0.0})); + } + } + + AND_WHEN("We commit") { + graph.commit(state); + + THEN("The state is correct") { + CHECK(matmul_ptr->size(state) == 1); + CHECK(matmul_ptr->shape(state).size() == 0); + CHECK_THAT(matmul_ptr->view(state), RangeEquals({5 * 5 + 7 * 7 + 1 * 1})); + } + } + } + } + } + + GIVEN("A 2d dynamic testing node and a 1d constant") { + auto arr_ptr = graph.emplace_node( + std::initializer_list{-1, 3}, -3.0, 10.0, false); + auto c_ptr = graph.emplace_node(std::vector{1, 2, 3}); + + CHECK_THROWS_AS(MatrixMultiplyNode(c_ptr, arr_ptr), std::invalid_argument); + + auto matmul_ptr = graph.emplace_node(arr_ptr, c_ptr); + graph.emplace_node(matmul_ptr); + + CHECK(matmul_ptr->dynamic()); + CHECK(matmul_ptr->ndim() == 1); + + CHECK(matmul_ptr->min() == 3.0 * -3.0 * 3); + CHECK(matmul_ptr->max() == 3.0 * 10.0 * 3); + + WHEN("We initialize a state") { + auto state = graph.initialize_state(); + REQUIRE(arr_ptr->size(state) == 0); + + THEN("The initial MatrixMultiplyNode state is correct") { + CHECK(matmul_ptr->size(state) == 0); + CHECK_THAT(matmul_ptr->shape(state), RangeEquals({0})); + CHECK(matmul_ptr->view(state).size() == 0); + } + + AND_WHEN("We grow the set and propagate") { + arr_ptr->grow(state, {5.0, 7.0, 9.0, 6.0, 8.0, 10.0}); + graph.propagate(state); + + THEN("The state is correct") { + CHECK(matmul_ptr->size(state) == 2); + CHECK_THAT(matmul_ptr->shape(state), RangeEquals({2})); + CHECK_THAT(matmul_ptr->view(state), RangeEquals({46, 52})); + } + + AND_WHEN("We revert") { + graph.revert(state); + + THEN("The state is correct") { + CHECK(matmul_ptr->size(state) == 0); + CHECK_THAT(matmul_ptr->shape(state), RangeEquals({0})); + CHECK(matmul_ptr->view(state).size() == 0); + } + } + + AND_WHEN("We commit") { + graph.commit(state); + + THEN("The state is correct") { + CHECK(matmul_ptr->size(state) == 2); + CHECK_THAT(matmul_ptr->shape(state), RangeEquals({2})); + 
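+                        // Expected values (derived from the grown data above):
+                        // 46 == 5*1 + 7*2 + 9*3 and 52 == 6*1 + 8*2 + 10*3, i.e. each
+                        // row of the grown (2, 3) array dotted with the constant [1, 2, 3].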
CHECK_THAT(matmul_ptr->view(state), RangeEquals({46, 52})); + } + } + } + } + } + + GIVEN("A 2d dynamic testing node and a 1d column slice") { + auto arr_ptr = + graph.emplace_node(std::initializer_list{-1, 3}); + auto vec_ptr = graph.emplace_node(arr_ptr, Slice(), 0); + + CHECK_THROWS_AS(MatrixMultiplyNode(arr_ptr, vec_ptr), std::invalid_argument); + + auto matmul_ptr = graph.emplace_node(vec_ptr, arr_ptr); + graph.emplace_node(matmul_ptr); + + CHECK(not matmul_ptr->dynamic()); + CHECK(matmul_ptr->ndim() == 1); + CHECK(matmul_ptr->size() == 3); + CHECK_THAT(matmul_ptr->shape(), RangeEquals({3})); + + WHEN("We initialize a state") { + auto state = graph.initialize_state(); + REQUIRE(arr_ptr->size(state) == 0); + + THEN("The initial MatrixMultiplyNode state is correct") { + CHECK_THAT(matmul_ptr->view(state), RangeEquals({0, 0, 0})); + } + + AND_WHEN("We grow the set and propagate") { + arr_ptr->grow(state, {5.0, 7.0, 9.0, 6.0, 8.0, 10.0}); + graph.propagate(state); + + THEN("The state is correct") { + CHECK_THAT(matmul_ptr->view(state), RangeEquals({61, 83, 105})); + } + + AND_WHEN("We revert") { + graph.revert(state); + + THEN("The state is correct") { + CHECK_THAT(matmul_ptr->view(state), RangeEquals({0, 0, 0})); + } + } + + AND_WHEN("We commit") { + graph.commit(state); + + THEN("The state is correct") { + CHECK_THAT(matmul_ptr->view(state), RangeEquals({61, 83, 105})); + } + } + } + } + } + + GIVEN("A 4d dynamic testing node and a matmulable reshape") { + auto arr_ptr = graph.emplace_node( + std::initializer_list{-1, 3, 2, 7}); + auto reshape_ptr = + graph.emplace_node(arr_ptr, std::vector{-1, 3, 7, 2}); + + auto matmul_ptr = graph.emplace_node(arr_ptr, reshape_ptr); + graph.emplace_node(matmul_ptr); + + CHECK(matmul_ptr->dynamic()); + CHECK(matmul_ptr->ndim() == 4); + CHECK_THAT(matmul_ptr->shape(), RangeEquals({-1, 3, 2, 2})); + + WHEN("We initialize a state") { + auto state = graph.initialize_state(); + REQUIRE(arr_ptr->size(state) == 0); + + THEN("The initial MatrixMultiplyNode state is correct") { + CHECK(matmul_ptr->view(state).size() == 0); + } + + AND_WHEN("We grow the set and propagate") { + std::vector arr_data(2 * 3 * 2 * 7); + std::iota(arr_data.begin(), arr_data.end(), 0.0); + arr_ptr->grow(state, arr_data); + graph.propagate(state); + + std::vector expected = {182, 203, 476, 546, 2436, 2555, + 3416, 3584, 7434, 7651, 9100, 9366, + 15176, 15491, 17528, 17892, 25662, 26075, + 28700, 29162, 38892, 39403, 42616, 43176}; + + THEN("The state is correct") { + CHECK_THAT(matmul_ptr->view(state), RangeEquals(expected)); + } + + AND_WHEN("We revert") { + graph.revert(state); + + THEN("The state is correct") { CHECK(matmul_ptr->view(state).size() == 0); } + } + + AND_WHEN("We commit") { + graph.commit(state); + + THEN("The state is correct") { + CHECK_THAT(matmul_ptr->view(state), RangeEquals(expected)); + } + } + } + } + } +} + +} // namespace dwave::optimization diff --git a/tests/test_symbols.py b/tests/test_symbols.py index abb74cb2..bed3bc69 100644 --- a/tests/test_symbols.py +++ b/tests/test_symbols.py @@ -38,6 +38,7 @@ logical_or, logical_not, logical_xor, + matmul, mod, put, rint, @@ -1156,6 +1157,7 @@ def test_interning(self): self.assertEqual(z.shape(), (4, 2)) self.assertNotEqual(x.id(), z.id()) + class TestCopy(utils.SymbolTests): def generate_symbols(self): model = Model() @@ -2279,6 +2281,110 @@ def test_serialization_with_states(self): np.testing.assert_array_equal(lp.state(3), [1, 1]) +class TestMatrixMultiply(utils.SymbolTests): + def generate_symbols(self): + model 
= Model() + c = model.constant(self._shaped_range(3, 4)) + c_reshape = c.reshape((4, 3)) + mm = dwave.optimization.symbols.MatrixMultiply(c, c_reshape) + + with model.lock(): + yield mm + + def test_matmul(self): + model = Model() + c = model.constant(self._shaped_range(3, 4)) + c_reshape = c.reshape((4, 3)) + mm = matmul(c, c_reshape) + self.assertIsInstance(mm, dwave.optimization.symbols.MatrixMultiply) + + def test_matmul_scalar(self): + model = Model() + with self.assertRaises(ValueError): + matmul(model.constant(np.arange(4)), model.constant(2)) + with self.assertRaises(ValueError): + matmul(model.constant(2), model.constant(np.arange(4))) + with self.assertRaises(ValueError): + matmul(model.constant(self._shaped_range(2, 3)), model.constant(2)) + with self.assertRaises(ValueError): + matmul(model.constant(2), model.constant(self._shaped_range(2, 3))) + + def test_matmul_broadcast_x(self): + model = Model() + x_data = self._shaped_range(5, 2, 3, 4) + x = model.constant(x_data) + + for shape in [ + (4,), + (4, 6), + (1, 4, 3), + (5, 2, 4, 3), + (5, 1, 4, 3), + (2, 1, 1, 4, 3), + ]: + y_data = self._shaped_range(*shape) + np_res = np.matmul(x_data, y_data) + y = model.constant(y_data) + mm = matmul(x, y) + with model.lock(): + model.states.resize(1) + self.assertTrue(np.array_equal(mm.state(0), np_res)) + + with self.assertRaises(ValueError): + matmul(x, model.constant(np.ones((5, 7, 4, 3)))) + + def test_matmul_broadcast_y(self): + model = Model() + y_data = self._shaped_range(5, 2, 3, 4) + y = model.constant(y_data) + + for shape in [ + (3,), + (6, 3), + (1, 2, 3), + (5, 2, 4, 3), + (5, 1, 4, 3), + (2, 1, 1, 4, 3), + ]: + x_data = self._shaped_range(*shape) + np_res = np.matmul(x_data, y_data) + x = model.constant(x_data) + mm = matmul(x, y) + with model.lock(): + model.states.resize(1) + self.assertTrue(np.array_equal(mm.state(0), np_res)) + + with self.assertRaises(ValueError): + matmul(y, model.constant(np.ones((5, 7, 4, 3)))) + + def test_matmul_broadcast_both_operands(self): + model = Model() + + for x_shape, y_shape in [ + [(3,), (3,)], + [(7, 3, 2), (1, 2, 5)], + [(5, 7, 3, 2), (1, 1, 2, 5)], + [(1, 7, 3, 2), (5, 1, 2, 5)], + [(1, 7, 3, 2), (4, 5, 1, 2, 5)], + ]: + x_data = self._shaped_range(*x_shape) + x = model.constant(x_data) + + y_data = self._shaped_range(*y_shape) + y = model.constant(y_data) + + np_res = np.matmul(x_data, y_data) + + mm = matmul(x, y) + with model.lock(): + model.states.resize(1) + self.assertTrue(np.array_equal(mm.state(0), np_res)) + + @staticmethod + def _shaped_range(*shape): + return np.arange(np.prod(shape)).reshape(shape) + + class TestMax(utils.ReduceTests): empty_requires_initial = True