@@ -189,65 +189,117 @@ def infer_shape_to_gufunc_sig(node: Apply, fgraph: Optional["FunctionGraph"] = N
     return (gufunc_inputs_sig, gufunc_outputs_sig)


+def safe_const_val(x):
+    try:
+        return get_scalar_constant_value(x)
+    except NotScalarConstantError:
+        return None
+
+
 class Blockwise(Op):
     __props__ = ("op", "signature")

     def __init__(self, op, signature=None):
         self.op = op
         self.signature = signature or self.op.gufunc_sig

-    def get_output_info(self, *inputs):
-        """Return the outputs dtype and broadcastable pattern and the
-        dimshuffled inputs.
+    def get_core_inputs_outputs(self, inputs: Sequence["TensorVariable"]):
+        """Get the core inputs and outputs for a given set of `inputs`.
+
+        Parameters
+        ==========
+        inputs
+            The normalized, blocked inputs (i.e. "broadcasted" inputs with all
+            the necessary dimensions added). They're needed for their dtype
+            and static shape information.

         """
-        # ensure that all inputs have the code dimensions
+
         core_inputs = []
-        for input, signature in zip(inputs, self.signature[0]):
-            core_dimension = len(signature)
-            if core_dimension > input.type.ndim:
-                difference = core_dimension - input.type.ndim
-                core_inputs.append(
-                    DimShuffle(
-                        input.type.broadcastable,
-                        list(range(input.type.ndim)) + ["x"] * difference,
-                    )(input)
-                )
+        for _inp, _inp_sig in zip(inputs, self.signature[0]):
+            curr_dtype = _inp.type.dtype
+            # Extract the static shape values of the core dimensions in the
+            # signature. Doing so will produce a much more precise
+            # `TensorType`.
+            curr_static_shape = _inp.type.shape[_inp.type.ndim - len(_inp_sig) :]
+            core_inputs.append(TensorType(curr_dtype, curr_static_shape)())
+
+        # TODO: This shouldn't be necessary; `Op.make_node` doesn't call
+        # `compute_test_value`, only `Op.__call__` does.
+        with aesara.config.change_flags(compute_test_value="off"):
+            core_outputs: Sequence[Variable] = self.op.make_node(*core_inputs).outputs
+
+        return core_inputs, core_outputs
+
+    def get_output_info(self, *inputs):
+        r"""Return the outputs dtype and broadcastable pattern and the `DimShuffle`\d inputs.
+
+        Parameters
+        ==========
+        inputs
+            The blocked inputs (i.e. "broadcasted" inputs).
+        """
+
+        # Ensure that all blocked inputs have the same number of core
+        # dimensions
+        blocked_inputs = []
+        for inp, signature in zip(inputs, self.signature[0]):
+            core_ndim = len(signature)
+            difference = core_ndim - inp.type.ndim
+
+            # Do we need to _add_ core dimensions?
+            if difference > 0:
+                core_inp = DimShuffle(
+                    inp.type.broadcastable,
+                    list(range(inp.type.ndim)) + ["x"] * difference,
+                )(inp)
             else:
-                core_inputs.append(input)
+                core_inp = inp

-        # remove the core dimension first the then broadcast the rest of the dimension
+            blocked_inputs.append(core_inp)
+
+        # Remove the core dimension first, then broadcast the rest of the
+        # dimensions
         max_loop_dimension = max(
-            core_inputs[i].type.ndim - len(self.signature[0][i])
-            for i in range(len(core_inputs))
+            blocked_inputs[i].type.ndim - len(self.signature[0][i])
+            for i in range(len(blocked_inputs))
         )

+        # Normalize the inputs by adding missing broadcast dimensions
         broadcasted_inputs = []
-        for input, signature in zip(core_inputs, self.signature[0]):
-            core_dimension = len(signature)
-            loop_dimension = input.type.ndim - core_dimension
+        for inp, signature in zip(blocked_inputs, self.signature[0]):
+            core_ndim = len(signature)
+            loop_dimension = inp.type.ndim - core_ndim
             difference = max_loop_dimension - loop_dimension
+            assert difference >= 0

-            if difference == 0:
-                broadcasted_inputs.append(input)
+            if difference > 0:
+                bcast_inp = DimShuffle(
+                    inp.type.broadcastable,
+                    ["x"] * difference + list(range(inp.type.ndim)),
+                )(inp)
             else:
-                broadcasted_inputs.append(
-                    DimShuffle(
-                        input.type.broadcastable,
-                        ["x"] * difference + list(range(input.type.ndim)),
-                    )(input)
-                )
-        inputs = broadcasted_inputs
-
-        shadow = self.op.make_node(*inputs)
-        out_dtypes = [o.type.dtype for o in shadow.outputs]
-
-        bcast_shape, dim_sizes = _parse_input_dimensions(inputs, self.signature[0])
+                bcast_inp = inp
+
+            broadcasted_inputs.append(bcast_inp)
+
+        _, core_outputs = self.get_core_inputs_outputs(broadcasted_inputs)
+        out_dtypes = [o.type.dtype for o in core_outputs]
+
+        bcast_shape, dim_sizes = _parse_input_dimensions(
+            broadcasted_inputs, self.signature[0]
+        )

         output_shapes = _calculate_shapes(bcast_shape, dim_sizes, self.signature[1])

-        return out_dtypes, output_shapes, inputs
+        return out_dtypes, output_shapes, broadcasted_inputs

     def make_node(self, *inputs):
+        """
+        Parameters
+        ==========
+        inputs
+            The blocked inputs (i.e. "broadcasted" inputs).
+        """
         num_expected_inps = len(self.signature[0])
         if len(inputs) != num_expected_inps:
             raise ValueError(
@@ -256,14 +308,10 @@ def make_node(self, *inputs):

         out_dtypes, output_shapes, inputs = self.get_output_info(*inputs)

-        def safe_const_val(x):
-            try:
-                return get_scalar_constant_value(x)
-            except NotScalarConstantError:
-                return None
-
         outputs = [
-            TensorType(out_dtypes[i], shape=tuple(safe_const_val(s) for s in output_shapes[i]))()
+            TensorType(
+                out_dtypes[i], shape=tuple(safe_const_val(s) for s in output_shapes[i])
+            )()
             for i in range(len(output_shapes))
         ]
         return Apply(self, list(inputs), outputs)
@@ -280,7 +328,8 @@ def L_op(self, inputs, outs, ograds):
         # Compute grad with respect to broadcasted input
         rval = self._bgrad(inputs, outs, ograds)

-        # sum out the broadcasted dimensions
+        # TODO: This is very broken. See #1089.
+        # Sum out the broadcasted dimensions
         for i, ipt in enumerate(inputs):
             if isinstance(rval[i].type, (NullType, DisconnectedType)):
                 continue
@@ -298,22 +347,19 @@ def L_op(self, inputs, outs, ograds):
             sr = at_sum(rval[i], axis=to_sum, keepdims=True)
             rval[i] = sr

+        for inp, grad in zip(inputs, rval):
+            assert inp.ndim == grad.ndim
+
         return rval

     def _bgrad(
         self,
-        inputs: Sequence[Variable],
-        outputs: Sequence[Variable],
-        ograds: Sequence[Variable],
+        inputs: Sequence["TensorVariable"],
+        outputs: Sequence["TensorVariable"],
+        ograds: Sequence["TensorVariable"],
     ):
-
         with aesara.config.change_flags(compute_test_value="off"):
-            core_inputs = []
-            for _inp, _inp_sig in zip(inputs, self.signature[0]):
-                curr_dtype = _inp.type.dtype
-                # extract the core dimensions
-                curr_static_shape = _inp.type.shape[-len(_inp_sig) :]
-                core_inputs.append(TensorType(curr_dtype, curr_static_shape)())
+            core_inputs, core_outputs = self.get_core_inputs_outputs(inputs)

             core_out_grads = []
             for _out_grad, _out_sig in zip(ograds, self.signature[1]):
@@ -322,24 +368,28 @@ def _bgrad(
                 curr_static_shape = _out_grad.type.shape[start_idx:]
                 core_out_grads.append(TensorType(curr_dtype, curr_static_shape)())

-            core_outputs: Sequence[Variable] = self.op.make_node(*core_inputs).outputs
             core_inp_grads = self.op.L_op(core_inputs, core_outputs, core_out_grads)

             for igrad in core_inp_grads:
                 assert igrad is not None, self.op

-        def transform(var: "TensorVariable", client_node: Optional[Apply]) -> Variable:
+        def transform(
+            var: "TensorVariable", client_node: Optional[Apply]
+        ) -> "TensorVariable":
             """Walk a graph and expand single gradient \"block\"s into their block-wise equivalents."""

             if isinstance(var.type, (NullType, DisconnectedType)):
                 return var

             if var in core_inputs:
-                return inputs[core_inputs.index(var)]
-            if var in core_outputs:
-                return outputs[core_outputs.index(var)]
-            if var in core_out_grads:
-                return ograds[core_out_grads.index(var)]
+                idx: int = core_inputs.index(var)
+                return inputs[idx]
+            elif var in core_outputs:
+                idx = core_outputs.index(var)
+                return outputs[idx]
+            elif var in core_out_grads:
+                idx = core_out_grads.index(var)
+                return ograds[idx]

             node = var.owner
             if node is None:
@@ -362,7 +412,7 @@ def transform(var: "TensorVariable", client_node: Optional[Apply]) -> Variable:

             assert isinstance(new_r, Variable)

-            return new_r
+            return cast("TensorVariable", new_r)

         ret = []
         for core_inp_grad, ipt in zip(core_inp_grads, inputs):
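
For readers following the shape bookkeeping in `get_output_info` above: the loop (batch) dimensions of the blocked inputs are broadcast against each other, and each output gets that broadcast loop shape plus its own core dimensions from the gufunc signature. The snippet below is an illustrative, standalone sketch of that computation in plain NumPy; it is not the Aesara code from this diff, and the helper name `blockwise_output_shape` and the example shapes are made up for illustration.

# Illustrative sketch only: mirrors the loop/core-dimension bookkeeping that
# `Blockwise.get_output_info` performs, using plain NumPy. The helper name
# and example shapes are hypothetical, not part of the PR.
import numpy as np


def blockwise_output_shape(input_shapes, in_core_dims, out_core_dims):
    # Split each input shape into leading loop dims and trailing core dims.
    dim_sizes = {}
    loop_shapes = []
    for shape, core in zip(input_shapes, in_core_dims):
        n_core = len(core)
        loop_shapes.append(shape[: len(shape) - n_core])
        # Record the size bound to each named core dimension (e.g. "m", "n").
        for name, size in zip(core, shape[len(shape) - n_core :]):
            dim_sizes.setdefault(name, size)
    # Broadcast the loop dimensions the same way element-wise ops do.
    loop_shape = np.broadcast_shapes(*loop_shapes)
    # Each output is the broadcast loop shape plus its own core dimensions.
    return [
        loop_shape + tuple(dim_sizes[name] for name in core)
        for core in out_core_dims
    ]


# A batched matrix product "(m,n),(n,p)->(m,p)": input 0 has loop shape (5,)
# and core shape (2, 3); input 1 has no loop dimensions.
print(blockwise_output_shape([(5, 2, 3), (3, 4)], [("m", "n"), ("n", "p")], [("m", "p")]))
# -> [(5, 2, 4)]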