
Commit 0792e8a

Brandon T. Willard, Sayam753, and kc611 authored and committed

Add a Blockwise Op

Co-authored-by: Brandon T. Willard <[email protected]>
Co-authored-by: Sayam Kumar <[email protected]>
Co-authored-by: Kaustubh <[email protected]>

1 parent a6a5602 commit 0792e8a

File tree

5 files changed: +575 -5 lines changed

aesara/tensor/blockwise.py

Lines changed: 341 additions & 0 deletions
@@ -0,0 +1,341 @@
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, cast

import numpy as np

import aesara
from aesara.gradient import DisconnectedType
from aesara.graph.basic import Apply, Variable
from aesara.graph.null_type import NullType
from aesara.graph.op import Op
from aesara.tensor import get_scalar_constant_value
from aesara.tensor.basic import atleast_Nd
from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.extra_ops import broadcast_shape
from aesara.tensor.math import sum as at_sum
from aesara.tensor.shape import shape_tuple
from aesara.tensor.type import TensorType


if TYPE_CHECKING:
    from aesara.tensor.var import TensorVariable


def _update_dim_sizes(
    dim_sizes: Dict[str, "TensorVariable"],
    arg: "TensorVariable",
    core_dims: Tuple[str, ...],
):
    """Incrementally check and update core dimension sizes for a single argument.

    From `numpy.lib.function_base`.

    Parameters
    ----------
    dim_sizes
        Sizes of existing core dimensions. Will be updated in-place.
    arg
        Argument to examine.
    core_dims
        Core dimensions for this argument.
    """
    if not core_dims:
        return

    num_core_dims = len(core_dims)
    if arg.type.ndim < num_core_dims:
        raise ValueError(
            f"{arg.type.ndim}-dimensional argument does not have enough "
            f"dimensions for all core dimensions: {core_dims}"
        )

    core_shape = shape_tuple(arg)[-num_core_dims:]
    for dim, size in zip(core_dims, core_shape):
        if dim not in dim_sizes:
            dim_sizes[dim] = cast("TensorVariable", size)
        # else:
        #     # This check can't be done (sufficiently) at compile-time
        #     if size != dim_sizes[dim]:
        #         raise ValueError(
        #             f"Inconsistent size for core dimension {dim}: {size} vs {dim_sizes[dim]}"
        #         )


def _parse_input_dimensions(
    args: Tuple["TensorVariable", ...], input_core_dims: List[Tuple[str, ...]]
) -> Tuple[Tuple[Variable, ...], Dict[str, "TensorVariable"]]:
    """Parse broadcast and core dimensions for vectorize with a signature.

    From `numpy.lib.function_base`.

    Parameters
    ----------
    args
        Tuple of input arguments to examine.
    input_core_dims
        List of core dimensions corresponding to each input.

    Returns
    -------
    broadcast_shape
        Common shape to broadcast all non-core dimensions to.
    dim_sizes
        Common sizes for named core dimensions.
    """
    broadcast_args = []
    dim_sizes: Dict[str, "TensorVariable"] = {}
    for arg, core_dims in zip(args, input_core_dims):
        _update_dim_sizes(dim_sizes, arg, core_dims)
        ndim = arg.ndim - len(core_dims)
        arg_shape = shape_tuple(arg)
        broadcast_args.append(arg_shape[:ndim])
    bcast_shape = broadcast_shape(*broadcast_args, arrays_are_shapes=True)
    return bcast_shape, dim_sizes


def _calculate_shapes(
    broadcast_shape: Tuple[Variable, ...],
    dim_sizes: Dict[str, Variable],
    list_of_core_dims: List[Tuple[str, ...]],
) -> List[Tuple[Variable, ...]]:
    """Helper for calculating broadcast shapes with core dimensions.

    From `numpy.lib.function_base`.

    """
    return [
        broadcast_shape + tuple(dim_sizes[dim] for dim in core_dims)
        for core_dims in list_of_core_dims
    ]


def gufunc_sign_to_str(sign):
    in_sign = [f"({','.join(_sign)})" for _sign in sign[0]]
    out_sign = [f"({','.join(_sign)})" for _sign in sign[1]]
    return f"{','.join(in_sign)}->{','.join(out_sign)}"
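

# NOTE (illustrative addition, not part of the original commit): applied to
# the `gufunc_sig` tuples this commit adds to `Dot` and `Det`, the helper
# above yields the NumPy gufunc signature strings that `np.vectorize` expects:
#
#     >>> gufunc_sign_to_str(((("m", "n"), ("n", "p")), (("m", "p"),)))
#     '(m,n),(n,p)->(m,p)'
#     >>> gufunc_sign_to_str(((("m", "m"),), ((),)))
#     '(m,m)->()'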

class Blockwise(Op):
    __props__ = ("op", "signature")

    def __init__(self, op, signature=None):
        self.op = op
        self.signature = signature or self.op.gufunc_sig

    def get_output_info(self, *inputs):
        """Return the output dtype, the output shapes, and the
        dimshuffled inputs.

        """
        # ensure that all inputs have the core dimensions
        core_inputs = []
        for input, signature in zip(inputs, self.signature[0]):
            core_dimension = len(signature)
            if core_dimension > input.type.ndim:
                difference = core_dimension - input.type.ndim
                core_inputs.append(
                    DimShuffle(
                        input.type.broadcastable,
                        list(range(input.type.ndim)) + ["x"] * difference,
                    )(input)
                )
            else:
                core_inputs.append(input)

        # set the core dimensions aside first, then broadcast the remaining
        # loop dimensions
        max_loop_dimension = max(
            core_inputs[i].type.ndim - len(self.signature[0][i])
            for i in range(len(core_inputs))
        )

        broadcasted_inputs = []
        for input, signature in zip(core_inputs, self.signature[0]):
            core_dimension = len(signature)
            loop_dimension = input.type.ndim - core_dimension
            difference = max_loop_dimension - loop_dimension

            if difference == 0:
                broadcasted_inputs.append(input)
            else:
                broadcasted_inputs.append(
                    DimShuffle(
                        input.type.broadcastable,
                        ["x"] * difference + list(range(input.type.ndim)),
                    )(input)
                )
        inputs = broadcasted_inputs

        # TODO: Correct this
        out_dtype = inputs[0].dtype

        bcast_shape, dim_sizes = _parse_input_dimensions(inputs, self.signature[0])
        output_shapes = _calculate_shapes(bcast_shape, dim_sizes, self.signature[1])

        return out_dtype, output_shapes, inputs
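
    # NOTE (illustrative addition, not part of the original commit): with
    # `Dot`'s "(m,n),(n,p)->(m,p)" signature and hypothetical input shapes
    #     a: (5, 2, 3) -- one loop dimension, core dims ("m", "n")
    #     b: (3, 4)    -- no loop dimensions, core dims ("n", "p")
    # `max_loop_dimension` is max(3 - 2, 2 - 2) = 1, so `b` is dimshuffled
    # with ["x", 0, 1] to shape (1, 3, 4); its new loop dimension then
    # broadcasts against `a`'s loop dimension of size 5.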

    def make_node(self, *inputs):
        num_expected_inps = len(self.signature[0])
        if len(inputs) != num_expected_inps:
            raise ValueError(
                f"Expected {int(num_expected_inps)} inputs, got {len(inputs)}"
            )

        out_dtype, output_shapes, inputs = self.get_output_info(*inputs)

        def safe_const_val(x):
            try:
                return get_scalar_constant_value(x)
            except NotScalarConstantError:
                return None

        outputs = [
            TensorType(out_dtype, shape=tuple(safe_const_val(s) for s in shp))()
            for shp in output_shapes
        ]
        return Apply(self, list(inputs), outputs)

    def infer_shape(self, fgraph, node, shapes):
        bcast_shape, dim_sizes = _parse_input_dimensions(node.inputs, self.signature[0])
        output_shapes = _calculate_shapes(bcast_shape, dim_sizes, self.signature[1])
        return output_shapes

    def L_op(self, inputs, outs, ograds):
        # Compute the gradients with respect to the broadcasted inputs
        rval = self._bgrad(inputs, outs, ograds)

        # sum out the broadcasted dimensions
        for i, ipt in enumerate(inputs):
            if isinstance(rval[i].type, (NullType, DisconnectedType)):
                continue

            # List of all the dimensions that are broadcastable for input[i] so
            # we can sum over them
            # TODO: only count dimensions that were effectively broadcasted
            to_sum = [
                j
                for j, bcast in enumerate(ipt.type.broadcastable)
                if bcast and not outs[0].broadcastable[j]
            ]

            if to_sum:
                sr = at_sum(rval[i], axis=to_sum, keepdims=True)
                rval[i] = sr

        return rval

    def _bgrad(
        self,
        inputs: Sequence[Variable],
        outputs: Sequence[Variable],
        ograds: Sequence[Variable],
    ):
        with aesara.config.change_flags(compute_test_value="off"):
            core_inputs = []
            for _inp, _inp_sig in zip(inputs, self.signature[0]):
                curr_dtype = _inp.type.dtype
                # extract the core dimensions
                curr_static_shape = _inp.type.shape[-len(_inp_sig) :]
                core_inputs.append(TensorType(curr_dtype, curr_static_shape)())

            core_out_grads = []
            for _out_grad, _out_sig in zip(ograds, self.signature[1]):
                curr_dtype = _out_grad.type.dtype
                curr_static_shape = _out_grad.type.shape[-len(_out_sig) :]
                core_out_grads.append(TensorType(curr_dtype, curr_static_shape)())

            core_outputs: Sequence[Variable] = self.op.make_node(*core_inputs).outputs
            core_inp_grads = self.op.L_op(core_inputs, core_outputs, core_out_grads)

            for igrad in core_inp_grads:
                assert igrad is not None, self.op

        def transform(var: "TensorVariable", client_node: Optional[Apply]) -> Variable:
            """Walk a graph and expand single gradient "block"s into their block-wise equivalents."""
            if isinstance(var.type, (NullType, DisconnectedType)):
                return var

            if var in core_inputs:
                return inputs[core_inputs.index(var)]
            if var in core_outputs:
                return outputs[core_outputs.index(var)]
            if var in core_out_grads:
                return ograds[core_out_grads.index(var)]

            node = var.owner
            if node is None:
                # The gradient contains a constant
                # res = aesara.tensor.basic.constant(
                #     np.asarray(var.data), dtype=var.type.dtype
                # )
                res = var

                # TODO FIXME: Use dimensions of relevant/appropriate inputs.
                # What exactly are those in this case?
                nd = inputs[0].type.ndim

                return atleast_Nd(res, nd)

            blocked_inputs = [transform(ipt, node) for ipt in node.inputs]

            grad_signature = getattr(node.op, "gufunc_sig", None)

            if grad_signature is None:
                if isinstance(node.op, DimShuffle):
                    # remove the extra dimensions that
                    # we have added during op creation
                    new_order = [i for i in node.op.new_order if i != "x"]

                    # derive a gufunc signature for `DimShuffle`
                    input_signature = tuple(f"a{i}" for i in range(len(new_order)))
                    output_signature = tuple(f"a{i}" for i in new_order)
                    grad_signature = ((input_signature,), (output_signature,))
                elif isinstance(node.op, Elemwise):
                    input_len = len(blocked_inputs)
                    input_signature = ((),) * input_len
                    output_signature = ()
                    grad_signature = (input_signature, (output_signature,))
                else:
                    raise ValueError(
                        f"'{node.op}' object has no attribute 'gufunc_sig'"
                    )

            new_r = Blockwise(node.op, signature=grad_signature)(*blocked_inputs)
            assert isinstance(new_r, Variable)
            return new_r

        ret = []
        for core_inp_grad, ipt in zip(core_inp_grads, inputs):
            ret.append(transform(core_inp_grad, None))

        return ret

    def perform(self, node, inputs, outputs):
        def py_func(*inner_inputs):
            res = [[None]] * len(outputs)
            # TODO: This can be avoided by making a single dummy node.
            # But will that cover all cases?
            inner_node = self.op.make_node(*inner_inputs)
            if isinstance(self.op, DimShuffle):
                self.op.perform(inner_node, inner_inputs, res, params=None)
            else:
                self.op.perform(inner_node, inner_inputs, res)

            # NumPy always expects the outputs to be NumPy arrays, and we
            # have a variable number of outputs
            if len(res) == 1:
                return res[0][0]
            else:
                return tuple(_res[0] for _res in res)

        numpy_vec_func = np.vectorize(
            py_func, signature=gufunc_sign_to_str(self.signature)
        )
        res_variables = numpy_vec_func(*inputs)

        if isinstance(res_variables, tuple):
            for i, out in enumerate(outputs):
                outputs[i][0] = res_variables[i]
        else:
            outputs[0][0] = res_variables
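
Note (illustrative, not part of the commit): because `perform` dispatches
through `np.vectorize` with the signature string from `gufunc_sign_to_str`, a
`Blockwise`-wrapped Op loops over the leading ("loop") dimensions at runtime.
A minimal usage sketch, assuming this commit's `Blockwise` and the
`gufunc_sig` attribute it adds to `Det`:

    import numpy as np

    import aesara
    import aesara.tensor as at
    from aesara.tensor.blockwise import Blockwise
    from aesara.tensor.nlinalg import Det

    # One determinant per square block: signature "(m,m)->()"
    x = at.tensor3("x")
    y = Blockwise(Det())(x)

    f = aesara.function([x], y)
    batch = np.stack([np.eye(2), 2 * np.eye(2)]).astype(x.dtype)
    print(f(batch))  # approximately [1., 4.]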

aesara/tensor/math.py

Lines changed: 3 additions & 1 deletion
@@ -1888,7 +1888,9 @@ class Dot(Op):
 
     """
 
-    __props__ = ()
+    gufunc_sig = ((("m", "n"), ("n", "p")), (("m", "p"),))
+
+    __props__ = ("gufunc_sig",)
 
     # the rationale for Dot22 is related to getting GEMM Ops into the
     # graph. See Dot22 in tensor.blas for details.
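
With this signature, wrapping `Dot` in `Blockwise` should behave like NumPy's
`matmul` gufunc, looping over any leading dimensions. A hedged sketch
(hypothetical usage, not part of this commit):

    import aesara.tensor as at
    from aesara.tensor.blockwise import Blockwise
    from aesara.tensor.math import Dot

    a = at.tensor3("a")  # shape (batch, m, n)
    b = at.tensor3("b")  # shape (batch, n, p)
    c = Blockwise(Dot())(a, b)  # shape (batch, m, p), one dot per pair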

aesara/tensor/nlinalg.py

Lines changed: 7 additions & 3 deletions
@@ -110,8 +110,11 @@ class MatrixInverse(Op):
     of ``solve``.
 
     """
-
-    __props__ = ()
+    gufunc_sig = (
+        (("m", "m"),),
+        (("m", "m"),),
+    )
+    __props__ = ("gufunc_sig",)
 
     def __init__(self):
         pass
@@ -199,7 +202,8 @@ class Det(Op):
 
     """
 
-    __props__ = ()
+    gufunc_sig = ((("m", "m"),), ((),))
+    __props__ = ("gufunc_sig",)
 
     def make_node(self, x):
         x = as_tensor_variable(x)
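
Both signatures read naturally: `MatrixInverse` maps a square block to a
square block, and `Det` maps a square block to a scalar. A hedged sketch of
batched inversion (hypothetical usage, not part of this commit):

    import aesara.tensor as at
    from aesara.tensor.blockwise import Blockwise
    from aesara.tensor.nlinalg import MatrixInverse

    x = at.tensor3("x")                    # stack of (m, m) matrices
    x_inv = Blockwise(MatrixInverse())(x)  # inverts each block separately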

aesara/tensor/slinalg.py

Lines changed: 7 additions & 1 deletion
@@ -38,7 +38,6 @@ class Cholesky(Op):
     # TODO: inplace
     # TODO: for specific dtypes
     # TODO: LAPACK wrapper with in-place behavior, for solve also
-
     __props__ = ("lower", "destructive", "on_error")
 
     def __init__(self, lower=True, on_error="raise"):
@@ -430,6 +429,13 @@ class Solve(SolveBase):
     Solve a system of linear equations.
     """
 
+    gufunc_sig = (
+        (
+            ("m", "m"),
+            ("m", "k"),
+        ),
+        (("m", "k"),),
+    )
     __props__ = (
         "assume_a",
         "lower",
