Skip to content

Commit 4a926fa

Browse files
Use dispatch for gufunc signature and (partially) implement Subtensor gufunc
1 parent 8ac7562 commit 4a926fa

File tree

5 files changed

+88
-34
lines changed

5 files changed

+88
-34
lines changed

aesara/tensor/__init__.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,20 @@ def _get_vector_length_Constant(op: Union[Op, Variable], var: Constant) -> int:
9898
return len(var.data)
9999

100100

101+
def get_gufunc_signature(op, blocked_inputs):
102+
sig = getattr(op, "gufunc_sig", None)
103+
104+
if sig is None:
105+
return _get_gufunc_signature(op, blocked_inputs)
106+
107+
return sig
108+
109+
110+
@singledispatch
111+
def _get_gufunc_signature(op, blocked_inputs):
112+
raise ValueError(f"'{op}' object has no attribute 'gufunc_sig'")
113+
114+
101115
import aesara.tensor.exceptions # noqa
102116
from aesara.gradient import consider_constant, grad, hessian, jacobian # noqa
103117

aesara/tensor/basic.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from aesara.scalar.basic import ScalarConstant, ScalarVariable
3939
from aesara.tensor import (
4040
_as_tensor_variable,
41+
_get_gufunc_signature,
4142
_get_vector_length,
4243
as_tensor_variable,
4344
get_vector_length,
@@ -3469,6 +3470,12 @@ def __setstate__(self, state):
34693470
self.axis2 = 1
34703471

34713472

3473+
@_get_gufunc_signature.register(ExtractDiag)
3474+
def _get_gufunc_signature_ExtractDiag(op, blocked_inputs):
3475+
# TODO:
3476+
raise NotImplementedError()
3477+
3478+
34723479
extract_diag = ExtractDiag()
34733480
# TODO: optimization to insert ExtractDiag with view=True
34743481

aesara/tensor/blockwise.py

Lines changed: 17 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,14 @@
88
from aesara.graph.null_type import NullType
99
from aesara.graph.op import Op
1010
from aesara.scalar.basic import constant as scalar_constant
11-
from aesara.tensor import get_scalar_constant_value
12-
from aesara.tensor.basic import ExtractDiag, atleast_Nd
13-
from aesara.tensor.elemwise import DimShuffle, Elemwise
11+
from aesara.scalar.basic import int64
12+
from aesara.tensor import get_gufunc_signature, get_scalar_constant_value
13+
from aesara.tensor.basic import atleast_Nd
14+
from aesara.tensor.elemwise import DimShuffle
1415
from aesara.tensor.exceptions import NotScalarConstantError
1516
from aesara.tensor.extra_ops import broadcast_shape
1617
from aesara.tensor.math import sum as at_sum
1718
from aesara.tensor.shape import shape_tuple
18-
from aesara.tensor.subtensor import Subtensor
1919
from aesara.tensor.type import TensorType
2020

2121

@@ -110,7 +110,15 @@ def get_dim_size(x):
110110
res = dim_sizes.get(x)
111111

112112
if res is None:
113-
return scalar_constant(int(x))
113+
try:
114+
return scalar_constant(int(x))
115+
except (ValueError, TypeError):
116+
# Return a symbolic placeholder for new dimension references
117+
# For example, a signature like `("m", "n") -> ("p",)` means
118+
# that there will be no `"p"` label to reference in `dim_sizes`
119+
# (i.e. pre-existing dimension labels that already have values
120+
# assigned to them).
121+
return int64(name=x)
114122

115123
return res
116124

@@ -292,35 +300,12 @@ def transform(var: "TensorVariable", client_node: Optional[Apply]) -> Variable:
292300

293301
return atleast_Nd(res, n=nd)
294302

295-
if isinstance(node.op, (Subtensor, ExtractDiag)):
296-
return var
297-
298303
blocked_inputs = [transform(ipt, node) for ipt in node.inputs]
299-
grad_signature = getattr(node.op, "gufunc_sig", None)
300-
op = node.op
301-
302-
if grad_signature is None:
303-
if isinstance(op, DimShuffle):
304-
# remove the extra dimensions that
305-
# we have added during op creation
306-
new_order = [i for i in op.new_order if i != "x"]
307-
308-
# derive gufunc signature for DimShuffle
309-
input_signature: Tuple[str, ...] = tuple(
310-
f"a{i}" for i in range(len(new_order))
311-
)
312-
output_signature: Tuple[str, ...] = tuple(
313-
f"a{i}" if i != "x" else "1" for i in op.new_order
314-
)
315-
grad_signature = ((input_signature,), (output_signature,))
316-
elif isinstance(op, Elemwise):
317-
op = op.scalar_op
318-
grad_signature = (((),) * len(blocked_inputs), ((),))
319-
else:
320-
raise ValueError(f"'{op}' object has no attribute 'gufunc_sig'")
321-
322-
new_r = Blockwise(op, signature=grad_signature)(*blocked_inputs)
304+
grad_signature = get_gufunc_signature(node.op, blocked_inputs)
305+
new_r = Blockwise(node.op, signature=grad_signature)(*blocked_inputs)
306+
323307
assert isinstance(new_r, Variable)
308+
324309
return new_r
325310

326311
ret = []

aesara/tensor/elemwise.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from aesara.scalar.basic import bool as scalar_bool
2020
from aesara.scalar.basic import identity as scalar_identity
2121
from aesara.scalar.basic import transfer_type, upcast
22-
from aesara.tensor import _get_vector_length, as_tensor_variable
22+
from aesara.tensor import _get_gufunc_signature, _get_vector_length, as_tensor_variable
2323
from aesara.tensor import elemwise_cgen as cgen
2424
from aesara.tensor import get_vector_length
2525
from aesara.tensor.type import (
@@ -275,6 +275,20 @@ def grad(self, inp, grads):
275275
]
276276

277277

278+
@_get_gufunc_signature.register(DimShuffle)
279+
def _get_gufunc_signature_DimShuffle(op, blocked_inputs):
280+
# remove the extra dimensions that
281+
# we have added during op creation
282+
new_order = [i for i in op.new_order if i != "x"]
283+
284+
# derive gufunc signature for DimShuffle
285+
input_signature: Tuple[str, ...] = tuple(f"a{i}" for i in range(len(new_order)))
286+
output_signature: Tuple[str, ...] = tuple(
287+
f"a{i}" if i != "x" else "1" for i in op.new_order
288+
)
289+
return ((input_signature,), (output_signature,))
290+
291+
278292
class DimShufflePrinter(Printer):
279293
def __p(self, new_order, pstate, r):
280294
if new_order != () and new_order[0] == "x":
@@ -1222,6 +1236,12 @@ def c_code_cache_version_apply(self, node):
12221236
return ()
12231237

12241238

1239+
@_get_gufunc_signature.register(Elemwise)
1240+
def _get_gufunc_signature_Elemwise(op, blocked_inputs):
1241+
op = op.scalar_op
1242+
return (((),) * len(blocked_inputs), ((),))
1243+
1244+
12251245
class CAReduce(COp):
12261246
"""Reduces a scalar operation along specified axes.
12271247

aesara/tensor/subtensor.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,12 @@
1919
from aesara.misc.safe_asarray import _asarray
2020
from aesara.printing import Printer, pprint, set_precedence
2121
from aesara.scalar.basic import ScalarConstant
22-
from aesara.tensor import _get_vector_length, as_tensor_variable, get_vector_length
22+
from aesara.tensor import (
23+
_get_gufunc_signature,
24+
_get_vector_length,
25+
as_tensor_variable,
26+
get_vector_length,
27+
)
2328
from aesara.tensor.basic import alloc, get_scalar_constant_value
2429
from aesara.tensor.elemwise import DimShuffle
2530
from aesara.tensor.exceptions import (
@@ -1200,6 +1205,29 @@ def R_op(self, inputs, eval_points):
12001205
return self(eval_points[0], *inputs[1:], return_list=True)
12011206

12021207

1208+
@_get_gufunc_signature.register(Subtensor)
1209+
def _get_gufunc_signature_Subtensor(op, blocked_inputs):
1210+
min_base_dims = len(op.idx_list)
1211+
index_input_types = get_slice_elements(
1212+
op.idx_list, lambda entry: isinstance(entry, Type)
1213+
)
1214+
1215+
indexed_input_sig = tuple(f"a{i}" for i in range(min_base_dims))
1216+
index_input_sig = tuple(
1217+
("1",) if typ.ndim == 0 else tuple(f"b{i}{j}" for j in range(typ.ndim))
1218+
for i, typ in enumerate(index_input_types)
1219+
)
1220+
1221+
# TODO: Compute the number of output dimensions
1222+
out_ndim = 1
1223+
output_sig = tuple(f"d{i}" for i in range(out_ndim))
1224+
1225+
input_signature: Tuple[str, ...] = (indexed_input_sig,) + index_input_sig
1226+
output_signature: Tuple[str, ...] = (output_sig,)
1227+
1228+
return (input_signature, output_signature)
1229+
1230+
12031231
class SubtensorPrinter(Printer):
12041232
def process(self, r, pstate):
12051233
return self._process(r.owner.op.idx_list, r.owner.inputs, pstate)

0 commit comments

Comments
 (0)