@@ -296,6 +296,34 @@ def wrapped(*args, **kwargs):
296296 return wrapped
297297
298298
def _block_dim_equal(
    b1: int | pallas_core.BlockDim | None, b2: int | pallas_core.BlockDim | None
) -> bool:
  """Returns whether two block-shape entries describe the same dimension.

  Rules, in order:
    * If either entry is `None`, the result is `b1 == b2`.
    * If both entries are an `int` or a `pallas_core.Blocked` (in any
      combination), only their block sizes are compared.
    * Otherwise the entries must have the same type and the same block size.

  Block sizes are obtained from `pallas_core.get_block_size`.
  """
  # Sizes are resolved up front for both entries, before any branching.
  size1 = pallas_core.get_block_size(b1)
  size2 = pallas_core.get_block_size(b2)

  if b1 is None or b2 is None:
    return b1 == b2

  # A plain int and a Blocked dim are treated as interchangeable here.
  size_only_types = (int, pallas_core.Blocked)
  if isinstance(b1, size_only_types) and isinstance(b2, size_only_types):
    return size1 == size2

  same_type = type(b1) == type(b2)
  return same_type and size1 == size2
316+
317+
318+ def _block_shapes_equal (
319+ bs1 : tuple [int | pallas_core .BlockDim | None ] | None ,
320+ bs2 : tuple [int | pallas_core .BlockDim | None ] | None ,
321+ ) -> bool :
322+ if bs1 is None or bs2 is None :
323+ return bs1 == bs2
324+ return all (_block_dim_equal (b1 , b2 ) for b1 , b2 in zip (bs1 , bs2 ))
325+
326+
299327def _pull_block_spec (
300328 jaxpr : core .Jaxpr ,
301329 out_block_specs : tuple [pallas_core .BlockSpec , ...],
@@ -393,7 +421,8 @@ def _scalar_prefetch_fn(jaxpr):
393421 if (
394422 not isinstance (v , core .Literal )
395423 and v in env
396- and env [v ].block_shape != in_block_spec .block_shape
424+ and not _block_shapes_equal (env [v ].block_shape ,
425+ in_block_spec .block_shape )
397426 ):
398427 in_block_spec = pallas_core .BlockSpec (_illegal , _illegal ) # pytype: disable=wrong-arg-types
399428 _write_block_spec (v , in_block_spec )
@@ -436,7 +465,7 @@ def make_kernel_function(
def _remove_nones(
    shape: tuple[pallas_core.BlockDim | int | None, ...] | None,
) -> tuple[int, ...]:
  """Maps each entry of `shape` through `_block_size` and drops `None` results.

  `shape` must be a tuple; this is asserted before it is iterated.
  """
  assert isinstance(shape, tuple)
  return tuple(
      size for size in map(_block_size, shape) if size is not None
  )
442471
0 commit comments