Skip to content

Commit 7774f56

Browse files
sharadmvGoogle-ML-Automation
authored andcommitted
[Pallas/Fuser] Relax block spec shape check by doing a better comparison
Fixes false positive fusion failures PiperOrigin-RevId: 831907364
1 parent 79f5078 commit 7774f56

File tree

1 file changed

+31
-2
lines changed

1 file changed

+31
-2
lines changed

jax/_src/pallas/fuser/block_spec.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,34 @@ def wrapped(*args, **kwargs):
296296
return wrapped
297297

298298

299+
def _block_dim_equal(
300+
b1: int | pallas_core.BlockDim | None, b2: int | pallas_core.BlockDim | None
301+
) -> bool:
302+
block_size1 = pallas_core.get_block_size(b1)
303+
block_size2 = pallas_core.get_block_size(b2)
304+
match (b1, b2):
305+
case (None, _) | (_, None):
306+
return b1 == b2
307+
case (
308+
(pallas_core.Blocked(), int())
309+
| (int(), pallas_core.Blocked())
310+
| (pallas_core.Blocked(), pallas_core.Blocked())
311+
| (int(), int())
312+
):
313+
return block_size1 == block_size2
314+
case _:
315+
return type(b1) == type(b2) and (block_size1 == block_size2)
316+
317+
318+
def _block_shapes_equal(
319+
bs1: tuple[int | pallas_core.BlockDim | None] | None,
320+
bs2: tuple[int | pallas_core.BlockDim | None] | None,
321+
) -> bool:
322+
if bs1 is None or bs2 is None:
323+
return bs1 == bs2
324+
return all(_block_dim_equal(b1, b2) for b1, b2 in zip(bs1, bs2))
325+
326+
299327
def _pull_block_spec(
300328
jaxpr: core.Jaxpr,
301329
out_block_specs: tuple[pallas_core.BlockSpec, ...],
@@ -393,7 +421,8 @@ def _scalar_prefetch_fn(jaxpr):
393421
if (
394422
not isinstance(v, core.Literal)
395423
and v in env
396-
and env[v].block_shape != in_block_spec.block_shape
424+
and not _block_shapes_equal(env[v].block_shape,
425+
in_block_spec.block_shape)
397426
):
398427
in_block_spec = pallas_core.BlockSpec(_illegal, _illegal) # pytype: disable=wrong-arg-types
399428
_write_block_spec(v, in_block_spec)
@@ -436,7 +465,7 @@ def make_kernel_function(
436465
def _remove_nones(
437466
shape: tuple[pallas_core.BlockDim | int | None, ...] | None,
438467
) -> tuple[int, ...]:
439-
assert shape is not None
468+
assert isinstance(shape, tuple)
440469
new_shape = tuple(_block_size(s) for s in shape)
441470
return tuple(s for s in new_shape if s is not None)
442471

0 commit comments

Comments
 (0)