Commit c2470b5

Fix a core inputs computation bug and do some refactoring
1 parent f3f256d commit c2470b5

File tree

1 file changed: +112, -62 lines

1 file changed

+112
-62
lines changed

aesara/tensor/blockwise.py

Lines changed: 112 additions & 62 deletions
@@ -189,65 +189,117 @@ def infer_shape_to_gufunc_sig(node: Apply, fgraph: Optional["FunctionGraph"] = N
     return (gufunc_inputs_sig, gufunc_outputs_sig)
 
 
+def safe_const_val(x):
+    try:
+        return get_scalar_constant_value(x)
+    except NotScalarConstantError:
+        return None
+
+
 class Blockwise(Op):
     __props__ = ("op", "signature")
 
     def __init__(self, op, signature=None):
         self.op = op
         self.signature = signature or self.op.gufunc_sig
 
-    def get_output_info(self, *inputs):
-        """Return the outputs dtype and broadcastable pattern and the
-        dimshuffled inputs.
+    def get_core_inputs_outputs(self, inputs: Sequence["TensorVariable"]):
+        """Get the core inputs and outputs for a given set of `inputs`.
+
+        Parameters
+        ==========
+        inputs
+            The normalized, blocked inputs (i.e. "broadcasted" inputs with all
+            the necessary dimensions added). They're needed for their dtype
+            and static shape information.
 
         """
-        # ensure that all inputs have the code dimensions
+
         core_inputs = []
-        for input, signature in zip(inputs, self.signature[0]):
-            core_dimension = len(signature)
-            if core_dimension > input.type.ndim:
-                difference = core_dimension - input.type.ndim
-                core_inputs.append(
-                    DimShuffle(
-                        input.type.broadcastable,
-                        list(range(input.type.ndim)) + ["x"] * difference,
-                    )(input)
-                )
+        for _inp, _inp_sig in zip(inputs, self.signature[0]):
+            curr_dtype = _inp.type.dtype
+            # Extract the static shape values of the core dimensions in the
+            # signature. Doing so will produce a much more precise
+            # `TensorType`.
+            curr_static_shape = _inp.type.shape[_inp.type.ndim - len(_inp_sig) :]
+            core_inputs.append(TensorType(curr_dtype, curr_static_shape)())
+
+        # TODO: This shouldn't be necessary; `Op.make_node` doesn't call
+        # `compute_test_value`, only `Op.__call__` does.
+        with aesara.config.change_flags(compute_test_value="off"):
+            core_outputs: Sequence[Variable] = self.op.make_node(*core_inputs).outputs
+
+        return core_inputs, core_outputs
+
+    def get_output_info(self, *inputs):
+        r"""Return the outputs dtype and broadcastable pattern and the `DimShuffle`\d inputs.
+
+        Parameters
+        ==========
+        inputs
+            The blocked inputs (i.e. "broadcasted" inputs).
+        """
+
+        # Ensure that all blocked inputs have the same number of core
+        # dimensions
+        blocked_inputs = []
+        for inp, signature in zip(inputs, self.signature[0]):
+            core_ndim = len(signature)
+            difference = core_ndim - inp.type.ndim
+
+            # Do we need to _add_ core dimensions?
+            if difference > 0:
+                core_inp = DimShuffle(
+                    inp.type.broadcastable,
+                    list(range(inp.type.ndim)) + ["x"] * difference,
+                )(inp)
             else:
-                core_inputs.append(input)
+                core_inp = inp
 
-        # remove the core dimension first the then broadcast the rest of the dimension
+            blocked_inputs.append(core_inp)
+
+        # Remove the core dimension first, then broadcast the rest of the
+        # dimensions
         max_loop_dimension = max(
-            core_inputs[i].type.ndim - len(self.signature[0][i])
-            for i in range(len(core_inputs))
+            blocked_inputs[i].type.ndim - len(self.signature[0][i])
+            for i in range(len(blocked_inputs))
         )
 
+        # Normalize the inputs by adding missing broadcast dimensions
         broadcasted_inputs = []
-        for input, signature in zip(core_inputs, self.signature[0]):
-            core_dimension = len(signature)
-            loop_dimension = input.type.ndim - core_dimension
+        for inp, signature in zip(blocked_inputs, self.signature[0]):
+            core_ndim = len(signature)
+            loop_dimension = inp.type.ndim - core_ndim
             difference = max_loop_dimension - loop_dimension
+            assert difference >= 0
 
-            if difference == 0:
-                broadcasted_inputs.append(input)
+            if difference > 0:
+                bcast_inp = DimShuffle(
+                    inp.type.broadcastable,
+                    ["x"] * difference + list(range(inp.type.ndim)),
+                )(inp)
             else:
-                broadcasted_inputs.append(
-                    DimShuffle(
-                        input.type.broadcastable,
-                        ["x"] * difference + list(range(input.type.ndim)),
-                    )(input)
-                )
-        inputs = broadcasted_inputs
-
-        shadow = self.op.make_node(*inputs)
-        out_dtypes = [o.type.dtype for o in shadow.outputs]
-
-        bcast_shape, dim_sizes = _parse_input_dimensions(inputs, self.signature[0])
+                bcast_inp = inp
+
+            broadcasted_inputs.append(bcast_inp)
+
+        _, core_outputs = self.get_core_inputs_outputs(broadcasted_inputs)
+        out_dtypes = [o.type.dtype for o in core_outputs]
+
+        bcast_shape, dim_sizes = _parse_input_dimensions(
+            broadcasted_inputs, self.signature[0]
+        )
         output_shapes = _calculate_shapes(bcast_shape, dim_sizes, self.signature[1])
 
-        return out_dtypes, output_shapes, inputs
+        return out_dtypes, output_shapes, broadcasted_inputs
 
     def make_node(self, *inputs):
+        """
+        Parameters
+        ==========
+        inputs
+            The blocked inputs (i.e. "broadcasted" inputs).
+        """
         num_expected_inps = len(self.signature[0])
         if len(inputs) != num_expected_inps:
             raise ValueError(
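
Note on the hunk above: the loops over `self.signature[0]` treat each input's signature entry as a tuple of core dimension names, from which the loop (batch) dimension counts are derived. A minimal plain-Python sketch of that arithmetic, assuming a matmul-style signature "(m,n),(n,p)->(m,p)" stored as nested tuples (an illustrative guess at the `gufunc_sig` layout, not taken from this diff):

# Hypothetical signature layout for "(m,n),(n,p)->(m,p)".
signature = ((("m", "n"), ("n", "p")), (("m", "p"),))

# Blocked input ndims: a (batch, m, n) tensor and a plain (n, p) matrix.
input_ndims = [3, 2]

core_ndims = [len(sig) for sig in signature[0]]                        # [2, 2]
loop_ndims = [nd - core for nd, core in zip(input_ndims, core_ndims)]  # [1, 0]
max_loop_dimension = max(loop_ndims)                                   # 1

# The second input would get one leading broadcast ("x") dimension added,
# mirroring the second DimShuffle call in get_output_info above.
print(core_ndims, loop_ndims, max_loop_dimension)
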
@@ -256,14 +308,10 @@ def make_node(self, *inputs):
 
         out_dtypes, output_shapes, inputs = self.get_output_info(*inputs)
 
-        def safe_const_val(x):
-            try:
-                return get_scalar_constant_value(x)
-            except NotScalarConstantError:
-                return None
-
         outputs = [
-            TensorType(out_dtypes[i], shape=tuple(safe_const_val(s) for s in output_shapes[i]))()
+            TensorType(
+                out_dtypes[i], shape=tuple(safe_const_val(s) for s in output_shapes[i])
+            )()
             for i in range(len(output_shapes))
         ]
         return Apply(self, list(inputs), outputs)
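
This hunk hoists `safe_const_val` to module level (the addition in the first hunk) and reflows the `TensorType` construction. For reference, a hedged usage sketch of what the helper does: constant shape entries become concrete values in the output type, symbolic ones become `None`. The import paths below are assumptions for illustration, not taken from this diff:

import aesara.tensor as at
from aesara.tensor.basic import get_scalar_constant_value
from aesara.tensor.exceptions import NotScalarConstantError

def safe_const_val(x):
    try:
        return get_scalar_constant_value(x)
    except NotScalarConstantError:
        return None

print(safe_const_val(at.as_tensor_variable(3)))  # 3: a constant shape entry
print(safe_const_val(at.lscalar("n")))           # None: a symbolic shape entry
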
@@ -280,7 +328,8 @@ def L_op(self, inputs, outs, ograds):
         # Compute grad with respect to broadcasted input
         rval = self._bgrad(inputs, outs, ograds)
 
-        # sum out the broadcasted dimensions
+        # TODO: This is very broken. See #1089.
+        # Sum out the broadcasted dimensions
         for i, ipt in enumerate(inputs):
             if isinstance(rval[i].type, (NullType, DisconnectedType)):
                 continue
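
The TODO aside, the surrounding logic follows the usual broadcasting rule for gradients: dimensions that were broadcast in the forward pass must be summed out of the gradient. A small NumPy-only sketch of that rule, with shapes chosen purely for illustration:

import numpy as np

x = np.ones((1, 3))                       # input broadcast along the first axis
out_grad = np.arange(15.0).reshape(5, 3)  # gradient of the broadcast output

# Sum over the broadcast axis, keeping dims, to get back to x's shape,
# analogous to the at_sum(..., keepdims=True) call in L_op.
x_grad = out_grad.sum(axis=0, keepdims=True)
assert x_grad.shape == x.shape
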
@@ -298,22 +347,19 @@ def L_op(self, inputs, outs, ograds):
                 sr = at_sum(rval[i], axis=to_sum, keepdims=True)
                 rval[i] = sr
 
+        for inp, grad in zip(inputs, rval):
+            assert inp.ndim == grad.ndim
+
         return rval
 
     def _bgrad(
         self,
-        inputs: Sequence[Variable],
-        outputs: Sequence[Variable],
-        ograds: Sequence[Variable],
+        inputs: Sequence["TensorVariable"],
+        outputs: Sequence["TensorVariable"],
+        ograds: Sequence["TensorVariable"],
     ):
-
         with aesara.config.change_flags(compute_test_value="off"):
-            core_inputs = []
-            for _inp, _inp_sig in zip(inputs, self.signature[0]):
-                curr_dtype = _inp.type.dtype
-                # extract the core dimensions
-                curr_static_shape = _inp.type.shape[-len(_inp_sig) :]
-                core_inputs.append(TensorType(curr_dtype, curr_static_shape)())
+            core_inputs, core_outputs = self.get_core_inputs_outputs(inputs)
 
             core_out_grads = []
             for _out_grad, _out_sig in zip(ograds, self.signature[1]):
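
The removed `_inp.type.shape[-len(_inp_sig) :]` line is presumably the "core inputs computation bug" from the commit title: when an input's core signature is empty (a scalar core), `-len(_inp_sig)` is `-0`, so the slice keeps every dimension instead of none. The shared `get_core_inputs_outputs` helper slices from `ndim - len(sig)` instead. A plain-Python illustration (the names here are just for the example):

sig = ()            # empty core signature, i.e. a scalar core input
shape = (5, 3)      # static shape of the blocked input
ndim = len(shape)

old = shape[-len(sig):]        # shape[-0:] == shape[0:] -> (5, 3), too many dims
new = shape[ndim - len(sig):]  # shape[2:] -> (), the expected scalar core shape

assert old == (5, 3)
assert new == ()
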
@@ -322,24 +368,28 @@ def _bgrad(
                 curr_static_shape = _out_grad.type.shape[start_idx:]
                 core_out_grads.append(TensorType(curr_dtype, curr_static_shape)())
 
-            core_outputs: Sequence[Variable] = self.op.make_node(*core_inputs).outputs
             core_inp_grads = self.op.L_op(core_inputs, core_outputs, core_out_grads)
 
             for igrad in core_inp_grads:
                 assert igrad is not None, self.op
 
-        def transform(var: "TensorVariable", client_node: Optional[Apply]) -> Variable:
+        def transform(
+            var: "TensorVariable", client_node: Optional[Apply]
+        ) -> "TensorVariable":
             """Walk a graph and expand single gradient \"block\"s into their block-wise equivalents."""
 
             if isinstance(var.type, (NullType, DisconnectedType)):
                 return var
 
             if var in core_inputs:
-                return inputs[core_inputs.index(var)]
-            if var in core_outputs:
-                return outputs[core_outputs.index(var)]
-            if var in core_out_grads:
-                return ograds[core_out_grads.index(var)]
+                idx: int = core_inputs.index(var)
+                return inputs[idx]
+            elif var in core_outputs:
+                idx = core_outputs.index(var)
+                return outputs[idx]
+            elif var in core_out_grads:
+                idx = core_out_grads.index(var)
+                return ograds[idx]
 
             node = var.owner
             if node is None:
@@ -362,7 +412,7 @@ def transform(var: "TensorVariable", client_node: Optional[Apply]) -> Variable:
 
             assert isinstance(new_r, Variable)
 
-            return new_r
+            return cast("TensorVariable", new_r)
 
         ret = []
         for core_inp_grad, ipt in zip(core_inp_grads, inputs):
