Add MatMul symbol #434
| namespace dwave::optimization { | ||
| class MatMulNode : public ArrayOutputMixin<ArrayNode> { |
MatrixMultiplyNode
| if (!n) throw std::invalid_argument("cannot divide by 0"); | ||
| multiplier /= n; | ||
| offset /= n; | ||
| if (min.has_value()) min.value() /= n; |
Hmm, should these also be fractions?
I don't think so. The min/max should always be multiples of any valid divisor here, though I should assert that this is the case.
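To make the suggested assertion concrete, here is a minimal sketch of a division operator that checks the bounds are exact multiples of the divisor. `SizeInfoSketch` is a hypothetical stand-in, not the PR's `SizeInfo`: only the members visible in the diff are modelled, and `multiplier` and `offset` are plain doubles here, whereas the thread suggests the real ones are fractions.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <stdexcept>

// Hypothetical stand-in for SizeInfo; the names mirror the diff, the types are assumed.
struct SizeInfoSketch {
    double multiplier = 1;                // assumed: a fraction type in the real code
    double offset = 0;                    // assumed: a fraction type in the real code
    std::optional<std::int64_t> min;      // integral lower bound on the size, if known
    std::optional<std::int64_t> max;      // integral upper bound on the size, if known

    SizeInfoSketch& operator/=(std::int64_t n) {
        if (!n) throw std::invalid_argument("cannot divide by 0");

        // The claim from the review thread: a valid divisor always divides
        // the bounds exactly, so integer division loses nothing.
        assert(!min.has_value() || *min % n == 0);
        assert(!max.has_value() || *max % n == 0);

        multiplier /= n;
        offset /= n;
        if (min.has_value()) *min /= n;
        if (max.has_value()) *max /= n;
        return *this;
    }
};
```

With the asserts in place, a divisor that does not evenly divide a bound fails loudly in debug builds instead of silently truncating.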
| sizeinfo_(get_sizeinfo(x_ptr, y_ptr)), | ||
| values_info_(get_values_info(x_ptr, y_ptr)) {} | ||
| class MatMulNodeData : public ArrayNodeStateData { |
IMO clearer to put this above the MatMulNode constructor
| auto x_data = x_ptr_->view(state); | ||
| auto y_data = y_ptr_->view(state); | ||
| ssize_t x_penultimate_axis_size = get_axis_size(x_ptr_->shape(state), -2, true); |
A bunch of these could/should be const.
dwave/optimization/include/dwave-optimization/nodes/linear_algebra.hpp
and address review comments
and add asserts in SizeInfo::operator/
arcondello left a comment
Logic all seems good to me with the possible exception of the size_from_shape() method and dynamic shapes.
Otherwise just a few nits.
| // (-1) and (-1) -> () | ||
| // (-1) and (-1, 5) -> (5) | ||
| ssize_t size_from_shape(std::span<const ssize_t> shape) { |
| static ssize_t shape_to_size(const std::span<const ssize_t> shape) noexcept { |
I also keep forgetting it exists.
Though I think maybe you're using the fact that it will return a negative number in the case of dynamic arrays?
I am. And I'm using it in non-class methods, so not sure how I can call it given it's a protected method on Array.
I could easily avoid the negative number trick though. Should I make Array::shape_to_size() public?
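One way to avoid the negative-number trick is a free function that reports a dynamic shape explicitly. This is only a sketch: `static_size` is a hypothetical name, not the library's `Array::shape_to_size()`. It assumes a dynamic axis is encoded as a negative dimension, as in the `(-1)` shapes quoted above, and it uses `std::vector` and `std::ptrdiff_t` rather than `std::span<const ssize_t>` to stay portable.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <vector>

// Hypothetical helper: product of the dimensions, or std::nullopt when any
// axis is dynamic (assumed to be encoded as a negative dimension).
std::optional<std::ptrdiff_t> static_size(const std::vector<std::ptrdiff_t>& shape) {
    std::ptrdiff_t size = 1;  // the empty shape () is a scalar with one element
    for (const std::ptrdiff_t dim : shape) {
        if (dim < 0) return std::nullopt;  // dynamic axis: size unknown until runtime
        size *= dim;
    }
    return size;
}
```

Callers then branch on `has_value()` rather than on the sign of the result, which makes the dynamic case hard to overlook.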
| SizeInfo sizeinfo() const override; | ||
| private: | ||
| void matmul(State& state, std::span<double> out, std::span<const ssize_t> out_shape) const; |
I think we are adopting this convention: #398 so these need trailing underscores.
Thanks, forgot about that
| std::array<double, 4> combos{x_ptr->min() * y_ptr->min(), x_ptr->min() * y_ptr->max(), | ||
| x_ptr->max() * y_ptr->min(), x_ptr->max() * y_ptr->max()}; | ||
| double min_val = std::ranges::min(combos); |
This and others could be marked const
Oh, how nice it would be to have const-by-default mode.
and use it in MatrixMultiplyNode, along with other small code improvements. Also fixed a bug in the ValuesInfo for MatrixMultiplyNode as it was not always calculating the contracted axis size correctly.
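For context on why the contracted axis size matters to the value bounds, here is a rough sketch under stated assumptions. Each output element of a matrix product is a sum of `k` elementwise products, where `k` is the contracted axis size, so the four-way min/max of the operand bounds has to be scaled by `k`. `matmul_bounds` is a hypothetical helper, not the PR's `get_values_info()`, and it assumes `k` is known and non-negative.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <utility>

// Hypothetical helper: bounds on any element of x @ y, given elementwise
// bounds on x and y and the size k of the contracted axis.
std::pair<double, double> matmul_bounds(double x_min, double x_max,
                                        double y_min, double y_max,
                                        std::ptrdiff_t contracted_size) {
    assert(contracted_size >= 0);  // assumption: dynamic sizes are resolved elsewhere

    // The extreme value of a single product x_i * y_i is always attained at a
    // corner of the box [x_min, x_max] x [y_min, y_max].
    const double lo = std::min({x_min * y_min, x_min * y_max, x_max * y_min, x_max * y_max});
    const double hi = std::max({x_min * y_min, x_min * y_max, x_max * y_min, x_max * y_max});

    // Each output element sums contracted_size such products.
    const double k = static_cast<double>(contracted_size);
    return {k * lo, k * hi};
}
```

For a dynamic contracted axis, the bound would presumably need the largest possible size of that axis, which this sketch does not attempt to derive.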
(still working on the symbol part)
MatrixMultiply symbol now handles vector-vector, matrix-vector, and ndarray-vector multiplication. The matmul method will also add in Broadcast symbols to one or both operands to handle the rest of the implicit broadcasting behavior that np.matmul supports.
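As an illustration of the shape rules involved, here is a minimal sketch of the output-shape computation for the non-broadcast cases, following `np.matmul` semantics. A 1-D first operand contributes no leading axes, and a 1-D second operand contributes no trailing axis. `matmul_shape_sketch` and `Shape` are hypothetical names. The sketch assumes `-1` marks a dynamic axis and simply rejects a second operand with more than two dimensions, which the PR handles via `Broadcast` symbols.

```cpp
#include <cstddef>
#include <stdexcept>
#include <vector>

using Shape = std::vector<std::ptrdiff_t>;  // -1 is assumed to mark a dynamic axis

// Hypothetical helper: output shape of x @ y when y is 1-D or 2-D.
Shape matmul_shape_sketch(const Shape& x, const Shape& y) {
    if (x.empty() || y.empty()) {
        throw std::invalid_argument("operands must be at least 1-D");
    }
    if (y.size() > 2) {
        throw std::invalid_argument("batched second operand is not handled in this sketch");
    }

    // Contracted axes: the last axis of x, and the second-to-last axis of y
    // (or its only axis when y is a vector).
    const std::ptrdiff_t x_contracted = x.back();
    const std::ptrdiff_t y_contracted = (y.size() == 1) ? y.back() : y[y.size() - 2];
    if (x_contracted != y_contracted) {
        throw std::invalid_argument("contracted axes do not match");
    }

    // Keep every axis of x except the contracted one ...
    Shape out(x.begin(), x.end() - 1);
    // ... and append the last axis of y only when y is a matrix.
    if (y.size() == 2) out.push_back(y.back());
    return out;
}
```

This reproduces the two cases quoted in the diff comments: `(-1)` with `(-1)` gives `()`, and `(-1)` with `(-1, 5)` gives `(5)`.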