Add Blockwise Op #1215
Conversation
Codecov Report
Additional details and impacted files

@@ Coverage Diff @@
## main #1215 +/- ##
==========================================
+ Coverage 75.02% 79.16% +4.14%
==========================================
Files 194 174 -20
Lines 50099 48677 -1422
Branches 12096 10359 -1737
==========================================
+ Hits 37586 38536 +950
+ Misses 10189 7640 -2549
- Partials 2324 2501 +177
Don't forget to rebase onto …
This looks great! The next step involves extending the number of `gufunc_sig`s we specify and adding the associated tests.

The big, open question is whether or not we can replace `Elemwise` with this new `Op`. We'll start exploring that question further once we've demonstrated that this `Op` can at least handle all the standard `Elemwise` cases, though. In other words, we don't want to start considering all the other changes (e.g. `Blockwise.c_code`, Numba/JAX transpilations, etc.) until we've demonstrated good test coverage (both for the `Elemwise`/scalar broadcasting cases and otherwise).
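For anyone reading along, the tuple-style `gufunc_sig` values used in this PR can be read like NumPy's generalized-ufunc signature strings. A minimal sketch of the correspondence (the helper below is hypothetical and only for illustration, not part of the PR):

```python
def to_numpy_gufunc_sig(gufunc_sig):
    """Render a tuple-style signature as a NumPy gufunc signature string."""
    inputs, outputs = gufunc_sig
    fmt = lambda entries: ",".join("(" + ",".join(dims) + ")" for dims in entries)
    return f"{fmt(inputs)}->{fmt(outputs)}"


# Matrix multiplication: two matrix core inputs, one matrix core output
assert to_numpy_gufunc_sig(((("m", "n"), ("n", "p")), (("m", "p"),))) == "(m,n),(n,p)->(m,p)"
```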
x = Blockwise(op)(*args)
x_fn = aesara.function(args, x)

x_fn(*arg_vals)
We're going to need to `assert` something about this output.
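One way to do that, assuming the wrapped core `Op` has a NumPy counterpart that the test can use as a reference (here `np_ref` is a hypothetical stand-in for whatever reference function the test parametrizes over):

```python
import numpy as np

res = x_fn(*arg_vals)
# Compare the compiled Blockwise result against a plain NumPy reference
# computed on the same (batched) input values.
expected = np_ref(*arg_vals)  # hypothetical NumPy reference for the wrapped Op
assert np.allclose(res, expected)
```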
gufunc_sig = ((("m", "n"), ("n", "p")), (("m", "p"),))

__props__ = ("gufunc_sig",)
FYI: We'll need to create these kinds of signatures for every applicable `Op`.
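For instance, a few of the kinds of signatures that come up in this conversation, spelled out in the same tuple form (illustrative examples only, not a final list):

```python
# Cholesky factorization: (m, m) -> (m, m)
cholesky_sig = ((("m", "m"),), (("m", "m"),))

# Matrix-vector product: (m, n), (n,) -> (m,)
matvec_sig = ((("m", "n"), ("n",)), (("m",),))

# Determinant: (m, m) -> () (scalar core output)
det_sig = ((("m", "m"),), ((),))
```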
What should be the signatures for the `Subtensor` and `Shape` `Op`s?
If you're talking about constructing symbolic graphs, the signatures are ultimately determined by their …
Yes, got it now.
Hello, @brandonwillard. You can reproduce the error using the following command. |
It looks like … I'm guessing … Regardless, we shouldn't need new …
aesara/tensor/blockwise.py (Outdated)
""" | ||
op = node.op | ||
in_shapes = tuple( | ||
tuple(lscalar(f"i{s}") for s in range(inp.type.ndim)) for inp in node.inputs |
Looks like we need to change the labels to reflect the inputs. For example:

Suggested change:
- tuple(lscalar(f"i{s}") for s in range(inp.type.ndim)) for inp in node.inputs
+ tuple(lscalar(f"i{n}{s}") for s in range(inp.type.ndim)) for n, inp in enumerate(node.inputs)
Co-authored-by: Brandon T. Willard <[email protected]>
Co-authored-by: Sayam Kumar <[email protected]>
Co-authored-by: Kaustubh <[email protected]>
I've added comments for some of the changes we made locally during the meeting.
    ),
    (("n", "m"),),
)
__props__ = ("dtype", "gufunc_sig")
__props__ = ("dtype", "gufunc_sig") | |
__props__ = ("dtype",) |
__props__ = ("offset", "axis1", "axis2")
gufunc_sig = (((),), (("m", "m"),))
__props__ = ("offset", "axis1", "axis2", "gufunc_sig")
__props__ = ("offset", "axis1", "axis2", "gufunc_sig") | |
__props__ = ("offset", "axis1", "axis2",) |
    return Apply(self, list(inputs), outputs)

def __str__(self):
    return f"{type(self).__name__}{{op={self.op}}}"
return f"{type(self).__name__}{{op={self.op}}}" | |
return f"{type(self).__name__}{{{self.op}, {self.signature}}}" |
# The gradient contains a constant
# res = aesara.tensor.basic.constant(
#     np.asarray(var.data), dtype=var.type.dtype
# )
res = var

# TODO FIXME: Use dimensions of relevant/appropriate inputs.
# What exactly are those in this case?
nd = inputs[0].type.ndim

return atleast_Nd(res, n=nd)
Suggested change:
- # The gradient contains a constant
- # res = aesara.tensor.basic.constant(
- #     np.asarray(var.data), dtype=var.type.dtype
- # )
- res = var
- # TODO FIXME: Use dimensions of relevant/appropriate inputs.
- # What exactly are those in this case?
- nd = inputs[0].type.ndim
- return atleast_Nd(res, n=nd)
+ return var
__props__ = ("lower", "destructive", "on_error")
gufunc_sig = ((("m", "m"),), (("m", "m"),))
__props__ = ("lower", "destructive", "on_error", "gufunc_sig")
__props__ = ("lower", "destructive", "on_error", "gufunc_sig") | |
__props__ = ("lower", "destructive", "on_error",) |
from aesara.tensor.basic import Tri

blk_op = Blockwise(op=Tri(dtype="float64"), signature=(((), (), ()), (("n", "m"),)))
Suggested change:
- from aesara.tensor.basic import Tri
- blk_op = Blockwise(op=Tri(dtype="float64"), signature=(((), (), ()), (("n", "m"),)))
+ blk_op = Blockwise(op=Tri(dtype="float64"))
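Presumably the intent of this suggestion is for `Blockwise` to fall back to the core `Op`'s own `gufunc_sig` when no explicit `signature` argument is passed, consistent with the signatures being added to the individual `Op`s above.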
blk_op = Blockwise(op=Tri(dtype="float64"), signature=(((), (), ()), (("n", "m"),)))
out_dtype, output_shapes, inputs = blk_op.get_output_info(a, b, c)

assert out_dtype == ["float64"]
We need to `assert` something about `output_shapes` (i.e. make sure they're correct in some way).
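A sketch of one such check, under the assumption that each entry of `output_shapes` is a sequence of symbolic scalars (the exact return format of `get_output_info` may differ), comparing against the shape NumPy's `tri` produces for the same arguments:

```python
import aesara
import numpy as np

# Evaluate the symbolic output shape for concrete Tri arguments and compare it
# with the shape of the corresponding NumPy result.
shape_fn = aesara.function([a, b, c], list(output_shapes[0]), on_unused_input="ignore")
assert tuple(int(s) for s in shape_fn(4, 3, 0)) == np.tri(4, 3, 0).shape
```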
This PR builds off of #757 and closes #695.

To #757 it adds:
- `get_output_info()`, which is the same as `Elemwise`'s `get_output_info()`, to make all inputs of the same dimension
- how the `grad` is computed

Differences with #757:
- `curr_static_shape` of `core_inp_grads` use the dimensions from the end
- `perform()` of `DimShuffle` (which can be removed later)
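As context for the first bullet, "making all inputs of the same dimension" is the same left-padding that `Elemwise` performs: inputs with fewer dimensions get broadcastable dimensions prepended until every input has the rank of the largest one (symbolically this is done with `DimShuffle`). A minimal NumPy-level sketch of the idea (illustrative only):

```python
import numpy as np

def pad_to_common_ndim(*arrays):
    """Prepend length-1 dimensions so all arrays share the largest ndim."""
    max_ndim = max(a.ndim for a in arrays)
    return [a.reshape((1,) * (max_ndim - a.ndim) + a.shape) for a in arrays]

x = np.ones((5, 3, 3))  # a batched matrix input
y = np.ones((3, 3))     # an unbatched matrix input
xb, yb = pad_to_common_ndim(x, y)
assert xb.shape == (5, 3, 3) and yb.shape == (1, 3, 3)
```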