@@ -174,11 +174,15 @@ def local_dimshuffle_alloc(fgraph, node):
174174def local_dimshuffle_subtensor (fgraph , node ):
175175 """If a subtensor is inside a dimshuffle which only drop
176176 broadcastable dimensions, scrap the dimshuffle and index the
177- subtensor with 0
177+ subtensor in a way that avoids the degenerate dimension
178178
179179 x[i:j, :, k:l].dimshuffle(0, 2) =>
180180 x[i:j, 0, k:l] if x.broadcastable == (False, True, False)
181181
182+ x[i:j, k:l, :].dimshuffle(0, 2) => x[i:j, k, :]
183+ x[i:j, k:, :].dimshuffle(0, 2) => x[i:j, k, :]
184+ x[i:j, :l, :].dimshuffle(0, 2) => x[i:j, 0, :]
185+
182186 """
183187 if isinstance (node .op , DimShuffle ) and node .inputs [0 ].owner :
184188 # the dimshuffle can only drop dimensions (cannot reshape nor add 'x')
@@ -217,24 +221,40 @@ def local_dimshuffle_subtensor(fgraph, node):
217221 new_idx_list = list (input_ .owner .op .idx_list )
218222 new_inputs = [input_ .owner .inputs [0 ]]
219223 zero = constant (0 )
220- slice_attr_list = ["start" , "stop" , "step" ]
221224 j = 0
222225 slice_i = - 1
223226 subtensor_removed_dims = 0
224227 for i , idx in enumerate (input_ .owner .op .idx_list ):
225228 if isinstance (idx , slice ):
226- past_j = j
227229 slice_i += 1
228- for slice_attr in slice_attr_list :
229- if getattr (idx , slice_attr ) is not None :
230- new_inputs += [input_ .owner .inputs [1 + j ]]
231- j += 1
232- # if past_j == j indicates a slice(None, None, None),
233- # that's where we want to index with 0 if it is also at
234- # the same spot of a missing dim
235- if past_j == j and slice_i in missing_dims :
236- new_idx_list [i ] = zero
237- new_inputs += [zero ]
230+ if slice_i in missing_dims :
231+ # Missing dim is a slice(None), remove by indexing by 0
232+ if idx == slice (None ):
233+ new_idx_list [i ] = zero
234+ new_inputs += [zero ]
235+ # Missing dim is an ordinary slice with known output dim length of 1
236+ # Remove by indexing by start
237+ else :
238+ if idx .start is None :
239+ start = zero
240+ else :
241+ start = input_ .owner .inputs [1 + j ]
242+ j += 1
243+ new_idx_list [i ] = start
244+ new_inputs += [start ]
245+
246+ # Ignore useless stop and step input if there is one
247+ for slice_attr in ("stop" , "step" ):
248+ if getattr (idx , slice_attr ) is not None :
249+ j += 1
250+
251+ # Keep non-dropped slice inputs
252+ else :
253+ for slice_attr in ("start" , "stop" , "step" ):
254+ if getattr (idx , slice_attr ) is not None :
255+ new_inputs += [input_ .owner .inputs [1 + j ]]
256+ j += 1
257+ # Keep non-dropped non-slice inputs
238258 else :
239259 new_inputs += [input_ .owner .inputs [1 + j ]]
240260 j += 1
0 commit comments