Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 46 additions & 13 deletions dwave/optimization/include/dwave-optimization/array.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,39 @@ struct SizeInfo {
}
bool operator==(const SizeInfo& other) const;

// Scale every component of the size expression (multiplier, offset, and
// the optional min/max bounds) by the integer n.
constexpr SizeInfo& operator*=(const std::integral auto n) {
    multiplier *= n;
    offset *= n;
    if (min) *min *= n;
    if (max) *max *= n;
    return *this;
}
// Non-mutating multiply: operate on the by-value copy and return it.
friend SizeInfo operator*(SizeInfo lhs, const std::integral auto rhs) {
    return lhs *= rhs;
}

// Divide every component of the size expression by the integer n.
// The caller is responsible for choosing an n that evenly divides the
// min/max bounds; this is asserted in debug builds.
constexpr SizeInfo& operator/=(const std::integral auto n) {
    // Surface division by zero eagerly rather than invoking UB.
    if (n == 0) throw std::invalid_argument("cannot divide by 0");
    multiplier /= n;
    offset /= n;
    if (min) {
        assert(min.value() % n == 0 and
               "dividing SizeInfo with a divisor that does not evenly divide the minimum");
        *min /= n;
    }
    if (max) {
        assert(max.value() % n == 0 and
               "dividing SizeInfo with a divisor that does not evenly divide the maximum");
        *max /= n;
    }
    return *this;
}
// Non-mutating divide: operate on the by-value copy and return it.
friend SizeInfo operator/(SizeInfo lhs, const std::integral auto rhs) {
    return lhs /= rhs;
}

// SizeInfos are printable
friend std::ostream& operator<<(std::ostream& os, const SizeInfo& sizeinfo);

Expand Down Expand Up @@ -450,6 +483,19 @@ class Array {
return 0;
}

// Determine the size by the shape. For a node with a fixed size, it is simply
// the product of the shape.
// Expects the shape to be stored in a C-style array of length ndim.
static ssize_t shape_to_size(const ssize_t ndim, const ssize_t* const shape) noexcept {
    if (ndim <= 0) return 1;  // a 0-dimensional (scalar) array holds one value
    if (shape[0] < 0) return DYNAMIC_SIZE;  // negative leading axis marks a dynamic size
    // The init value must be an ssize_t: std::reduce deduces its accumulator
    // type from the init argument, so a plain `1` (int) would truncate
    // products larger than INT_MAX.
    return std::reduce(shape, shape + ndim, ssize_t{1}, std::multiplies<ssize_t>());
}

// Convenience overload: delegate to the pointer overload, making the
// size_t -> ssize_t conversion explicit.
static ssize_t shape_to_size(const std::span<const ssize_t> shape) noexcept {
    return shape_to_size(static_cast<ssize_t>(shape.size()), shape.data());
}

protected:
// Some utility methods that might be useful to subclasses

Expand All @@ -476,19 +522,6 @@ class Array {
return true;
}

// Determine the size by the shape. For a node with a fixed size, it is simply
// the product of the shape.
// Expects the shape to be stored in a C-style array of length ndim.
static ssize_t shape_to_size(const ssize_t ndim, const ssize_t* shape) noexcept {
    if (ndim <= 0) return 1;  // a 0-dimensional (scalar) array holds one value
    if (shape[0] < 0) return DYNAMIC_SIZE;  // negative leading axis marks a dynamic size
    // The init value must be an ssize_t: std::reduce deduces its accumulator
    // type from the init argument, so a plain `1` (int) would truncate
    // products larger than INT_MAX.
    return std::reduce(shape, shape + ndim, ssize_t{1}, std::multiplies<ssize_t>());
}

// Convenience overload: delegate to the pointer overload, making the
// size_t -> ssize_t conversion explicit.
static ssize_t shape_to_size(const std::span<const ssize_t> shape) noexcept {
    return shape_to_size(static_cast<ssize_t>(shape.size()), shape.data());
}

// Determine the strides from the shape.
// Assumes itemsize = sizeof(double).
// Expects the shape to be stored in a C-style array of length ndim.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
// Copyright 2025 D-Wave Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <vector>

#include "dwave-optimization/array.hpp"
#include "dwave-optimization/graph.hpp"

namespace dwave::optimization {

// Array node representing the matrix product of two operand array nodes.
// Declarations only; definitions live in the corresponding .cpp file.
class MatrixMultiplyNode : public ArrayOutputMixin<ArrayNode> {
 public:
    // Construct from the two operand nodes to be multiplied.
    MatrixMultiplyNode(ArrayNode* x_ptr, ArrayNode* y_ptr);

    /// @copydoc Array::buff()
    double const* buff(const State& state) const override;

    /// @copydoc Node::commit()
    void commit(State& state) const override;

    /// @copydoc Array::diff()
    std::span<const Update> diff(const State& state) const override;

    /// @copydoc Node::initialize_state()
    void initialize_state(State& state) const override;

    /// @copydoc Array::integral()
    bool integral() const override;

    /// @copydoc Array::max()
    double max() const override;

    /// @copydoc Array::min()
    double min() const override;

    /// @copydoc Node::propagate()
    void propagate(State& state) const override;

    /// @copydoc Node::revert()
    void revert(State& state) const override;

    /// @copydoc Array::shape()
    using Array::shape;
    std::span<const ssize_t> shape(const State& state) const override;

    /// @copydoc Array::size()
    using Array::size;
    ssize_t size(const State& state) const override;

    /// @copydoc Array::size_diff()
    ssize_t size_diff(const State& state) const override;

    /// @copydoc Array::sizeinfo()
    SizeInfo sizeinfo() const override;

 private:
    // Write the product into `out`, which has shape `out_shape`.
    // NOTE(review): implementation not visible here -- confirm in the .cpp.
    void matmul_(State& state, std::span<double> out, std::span<const ssize_t> out_shape) const;
    // Recompute the state-dependent output shape after operand changes.
    // NOTE(review): semantics inferred from the name -- confirm in the .cpp.
    void update_shape_(State& state) const;

    // Non-owning pointers to the operand array nodes.
    const ArrayNode* x_ptr_;
    const ArrayNode* y_ptr_;

    // Size and value metadata, fixed at construction.
    const SizeInfo sizeinfo_;
    const ValuesInfo values_info_;
};

} // namespace dwave::optimization
20 changes: 20 additions & 0 deletions dwave/optimization/libcpp/nodes/linear_algebra.pxd
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Copyright 2025 D-Wave
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dwave.optimization.libcpp.graph cimport ArrayNode


# Expose the C++ MatrixMultiplyNode class to Cython; only the symbol is
# needed, so no members are declared.
cdef extern from "dwave-optimization/nodes/linear_algebra.hpp" namespace "dwave::optimization" nogil:
    cdef cppclass MatrixMultiplyNode(ArrayNode):
        pass
72 changes: 72 additions & 0 deletions dwave/optimization/mathematical.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
LinearProgramSolution,
Log,
Logical,
MatrixMultiply,
Maximum,
Mean,
Minimum,
Expand Down Expand Up @@ -88,6 +89,7 @@
"logical_not",
"logical_or",
"logical_xor",
"matmul",
"maximum",
"mean",
"minimum",
Expand Down Expand Up @@ -1060,6 +1062,76 @@ def logical_xor(x1: ArraySymbol, x2: ArraySymbol) -> Xor:
return Xor(x1, x2)


def matmul(x: ArraySymbol, y: ArraySymbol) -> MatrixMultiply:
    r"""Compute the matrix product of two array symbols.

    Args:
        x, y: Operand array symbols. The size of the last axis of `x` must be
            equal to the size of the second to last axis of `y`. If `x` or `y`
            are 1-d, they will be treated as if they were a row or column
            vector respectively. If both are 1-d, this will produce a scalar
            (and the operation is equivalent to the dot product of two
            vectors).

            Otherwise, if it is possible that the shapes can be broadcast
            together, either by broadcasting missing leading axes (e.g.
            `(2, 5, 3, 4, 2)` and `(2, 6)` -> `(2, 5, 3, 4, 6)`) or by
            broadcasting axes for which one of the operands has size 1 (e.g.
            `(3, 1, 4, 2)` and `(1, 7, 2, 3)` -> `(3, 7, 4, 3)`), then this
            will return the result after broadcasting.

    Returns:
        A MatrixMultiply symbol representing the matrix product. If `x` and
        `y` have shapes `(..., n, k)` and `(..., k, m)`, then the output will
        have shape `(..., n, m)`.

    Examples:
        This example computes the dot product of two integer arrays.

        >>> from dwave.optimization import Model
        >>> from dwave.optimization.mathematical import matmul
        ...
        >>> model = Model()
        >>> i = model.integer(3)
        >>> j = model.integer(3)
        >>> m = matmul(i, j)
        >>> with model.lock():
        ...     model.states.resize(1)
        ...     i.set_state(0, [1, 2, 3])
        ...     j.set_state(0, [4, 5, 6])
        ...     print(m.state(0))
        32.0

    See Also:
        :class:`~dwave.optimization.symbols.MatrixMultiply`: equivalent symbol.

    .. versionadded:: 0.6.10
    """
    # The vector (1-d) cases are handled by MatrixMultiplyNode itself, as is
    # the case where the batch (leading) dimensions already match exactly.
    if x.ndim() == 1 or y.ndim() == 1 or x.shape()[:-2] == y.shape()[:-2]:
        return MatrixMultiply(x, y)

    # The batch dimensions differ, but it may be possible to broadcast them
    # together. Pad missing leading axes with 1s so both shape lists have the
    # same length, then stretch size-1 axes to match the other operand.
    x_shape = [1] * (y.ndim() - x.ndim()) + list(x.shape())
    y_shape = [1] * (x.ndim() - y.ndim()) + list(y.shape())

    for i in range(len(x_shape) - 2):
        if x_shape[i] == 1:
            x_shape[i] = y_shape[i]
        elif y_shape[i] == 1:
            y_shape[i] = x_shape[i]
        elif x_shape[i] != y_shape[i]:
            raise ValueError("Could not broadcast operands")

    # Only insert BroadcastNodes for operands whose shape actually changed.
    if tuple(x_shape) != x.shape():
        x = broadcast_to(x, tuple(x_shape))
    if tuple(y_shape) != y.shape():
        y = broadcast_to(y, tuple(y_shape))

    return MatrixMultiply(x, y)


@_op(Maximum, NaryMaximum, "max")
def maximum(x1: ArraySymbol, x2: ArraySymbol, *xi: ArraySymbol,
) -> typing.Union[Maximum, NaryMaximum]:
Expand Down
Loading