Skip to content

Conversation

@wbernoudy
Copy link
Member

@wbernoudy wbernoudy commented Nov 21, 2025

(still working on the symbol part)

MatrixMultiply symbol now handles vector-vector, matrix-vector, and ndarray-vector multiplication. The matmul method will also add in Broadcast symbols to one or both operands to handle the rest of the implicit broadcasting behavior that np.matmul supports.

@wbernoudy wbernoudy requested a review from arcondello November 21, 2025 23:29

namespace dwave::optimization {

class MatMulNode : public ArrayOutputMixin<ArrayNode> {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MatrixMultiplyNode

if (!n) throw std::invalid_argument("cannot divide by 0");
multiplier /= n;
offset /= n;
if (min.has_value()) min.value() /= n;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, should these also be fractions?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think so. The min/max should always be multiplies of any valid divisor here. Though I should assert that this is the case

sizeinfo_(get_sizeinfo(x_ptr, y_ptr)),
values_info_(get_values_info(x_ptr, y_ptr)) {}

class MatMulNodeData : public ArrayNodeStateData {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO clearer to put this above the MatMulNode constructor

auto x_data = x_ptr_->view(state);
auto y_data = y_ptr_->view(state);

ssize_t x_penultimate_axis_size = get_axis_size(x_ptr_->shape(state), -2, true);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A bunch of these could/should be const.

Copy link
Member

@arcondello arcondello left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Logic all seems good to me with the possible exception of the size_from_shape() method and dynamic shapes.

Otherwise just a few nits.

// (-1) and (-1) -> ()
// (-1) and (-1, 5) -> (5)

ssize_t size_from_shape(std::span<const ssize_t> shape) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

static ssize_t shape_to_size(const std::span<const ssize_t> shape) noexcept {

I also keep forgetting it exists.

Though I think maybe you're using the fact that it will return a negative number in the case of dynamic arrays?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am. And I'm using it in non-class methods, so not sure how I can call it given it's a protected method on Array.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I could easily avoid the negative number trick though. Should I make Array::shape_to_size() public?

SizeInfo sizeinfo() const override;

private:
void matmul(State& state, std::span<double> out, std::span<const ssize_t> out_shape) const;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we are adopting this convention: #398 so these need trailing underscores.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, forgot about that

std::array<double, 4> combos{x_ptr->min() * y_ptr->min(), x_ptr->min() * y_ptr->max(),
x_ptr->max() * y_ptr->min(), x_ptr->max() * y_ptr->max()};

double min_val = std::ranges::min(combos);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This and others could be marked const

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, how nice it would be to have const-by-default mode.

@arcondello arcondello added the enhancement New feature or request label Nov 26, 2025
@arcondello arcondello linked an issue Nov 26, 2025 that may be closed by this pull request
and use it in MatrixMultiplyNode, along with other small code
improvements. Also fixed a bug in the ValuesInfo for MatrixMultiplyNode
as it was not always calculating the contracted axis size correctly.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

enhancement New feature or request

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Consider adding matmul symbol or similar

2 participants