Skip to content

Commit 8ac7562

Browse files
purna135brandonwillard
authored andcommitted
manage the Ops which support nd inputs
1 parent 71cf816 commit 8ac7562

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

aesara/tensor/blockwise.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,13 @@
99
from aesara.graph.op import Op
1010
from aesara.scalar.basic import constant as scalar_constant
1111
from aesara.tensor import get_scalar_constant_value
12-
from aesara.tensor.basic import atleast_Nd
12+
from aesara.tensor.basic import ExtractDiag, atleast_Nd
1313
from aesara.tensor.elemwise import DimShuffle, Elemwise
1414
from aesara.tensor.exceptions import NotScalarConstantError
1515
from aesara.tensor.extra_ops import broadcast_shape
1616
from aesara.tensor.math import sum as at_sum
1717
from aesara.tensor.shape import shape_tuple
18+
from aesara.tensor.subtensor import Subtensor
1819
from aesara.tensor.type import TensorType
1920

2021

@@ -291,6 +292,9 @@ def transform(var: "TensorVariable", client_node: Optional[Apply]) -> Variable:
291292

292293
return atleast_Nd(res, n=nd)
293294

295+
if isinstance(node.op, (Subtensor, ExtractDiag)):
296+
return var
297+
294298
blocked_inputs = [transform(ipt, node) for ipt in node.inputs]
295299
grad_signature = getattr(node.op, "gufunc_sig", None)
296300
op = node.op

0 commit comments

Comments
 (0)