|
8 | 8 | from aesara.graph.null_type import NullType
|
9 | 9 | from aesara.graph.op import Op
|
10 | 10 | from aesara.scalar.basic import constant as scalar_constant
|
11 |
| -from aesara.tensor import get_scalar_constant_value |
12 |
| -from aesara.tensor.basic import ExtractDiag, atleast_Nd |
13 |
| -from aesara.tensor.elemwise import DimShuffle, Elemwise |
| 11 | +from aesara.scalar.basic import int64 |
| 12 | +from aesara.tensor import get_gufunc_signature, get_scalar_constant_value |
| 13 | +from aesara.tensor.basic import atleast_Nd |
| 14 | +from aesara.tensor.elemwise import DimShuffle |
14 | 15 | from aesara.tensor.exceptions import NotScalarConstantError
|
15 | 16 | from aesara.tensor.extra_ops import broadcast_shape
|
16 | 17 | from aesara.tensor.math import sum as at_sum
|
17 | 18 | from aesara.tensor.shape import shape_tuple
|
18 |
| -from aesara.tensor.subtensor import Subtensor |
19 | 19 | from aesara.tensor.type import TensorType
|
20 | 20 |
|
21 | 21 |
|
@@ -110,7 +110,15 @@ def get_dim_size(x):
|
110 | 110 | res = dim_sizes.get(x)
|
111 | 111 |
|
112 | 112 | if res is None:
|
113 |
| - return scalar_constant(int(x)) |
| 113 | + try: |
| 114 | + return scalar_constant(int(x)) |
| 115 | + except (ValueError, TypeError): |
| 116 | + # Return a symbolic placeholder for new dimension references |
| 117 | + # For example, a signature like `("m", "n") -> ("p",)` means |
| 118 | + # that there will be no `"p"` label to reference in `dim_sizes` |
| 119 | + # (i.e. pre-existing dimension labels that already have values |
| 120 | + # assigned to them). |
| 121 | + return int64(name=x) |
114 | 122 |
|
115 | 123 | return res
|
116 | 124 |
|
@@ -292,35 +300,12 @@ def transform(var: "TensorVariable", client_node: Optional[Apply]) -> Variable:
|
292 | 300 |
|
293 | 301 | return atleast_Nd(res, n=nd)
|
294 | 302 |
|
295 |
| - if isinstance(node.op, (Subtensor, ExtractDiag)): |
296 |
| - return var |
297 |
| - |
298 | 303 | blocked_inputs = [transform(ipt, node) for ipt in node.inputs]
|
299 |
| - grad_signature = getattr(node.op, "gufunc_sig", None) |
300 |
| - op = node.op |
301 |
| - |
302 |
| - if grad_signature is None: |
303 |
| - if isinstance(op, DimShuffle): |
304 |
| - # remove the extra dimensions that |
305 |
| - # we have added during op creation |
306 |
| - new_order = [i for i in op.new_order if i != "x"] |
307 |
| - |
308 |
| - # derive gufunc signature for DimShuffle |
309 |
| - input_signature: Tuple[str, ...] = tuple( |
310 |
| - f"a{i}" for i in range(len(new_order)) |
311 |
| - ) |
312 |
| - output_signature: Tuple[str, ...] = tuple( |
313 |
| - f"a{i}" if i != "x" else "1" for i in op.new_order |
314 |
| - ) |
315 |
| - grad_signature = ((input_signature,), (output_signature,)) |
316 |
| - elif isinstance(op, Elemwise): |
317 |
| - op = op.scalar_op |
318 |
| - grad_signature = (((),) * len(blocked_inputs), ((),)) |
319 |
| - else: |
320 |
| - raise ValueError(f"'{op}' object has no attribute 'gufunc_sig'") |
321 |
| - |
322 |
| - new_r = Blockwise(op, signature=grad_signature)(*blocked_inputs) |
| 304 | + grad_signature = get_gufunc_signature(node.op, blocked_inputs) |
| 305 | + new_r = Blockwise(node.op, signature=grad_signature)(*blocked_inputs) |
| 306 | + |
323 | 307 | assert isinstance(new_r, Variable)
|
| 308 | + |
324 | 309 | return new_r
|
325 | 310 |
|
326 | 311 | ret = []
|
|
0 commit comments