Skip to content

Commit 9c05cd4

Browse files
Update pull/push rule API for call_hi_primitive_p
PiperOrigin-RevId: 831036071
1 parent 99a5e0d commit 9c05cd4

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

jax/_src/pallas/fuser/block_spec.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1987,16 +1987,13 @@ def read_usage_env(_: core.Var):
19871987
def _custom_call_hi_primitive_pull_block_spec_rule(
19881988
ctx: PullRuleContext, out_block_specs, *, prim
19891989
):
1990-
del ctx
1991-
return prim.pull_block_spec_rule(out_block_specs)
1992-
1990+
return prim.pull_block_spec_rule(ctx, out_block_specs)
19931991

19941992
@register_eval_rule(hijax.call_hi_primitive_p)
19951993
def _custom_call_hi_primitive_eval_rule(
1996-
ctx: KernelEvalContext, x, *, prim
1994+
ctx: KernelEvalContext, *args, prim
19971995
):
1998-
del ctx
1999-
return prim.expand(x)
1996+
return jax.tree.leaves(prim.block_eval_rule(ctx, *args))
20001997

20011998

20021999
@functools.partial(api_boundary, repro_api_name="fuser.push_block_spec")
@@ -2229,6 +2226,13 @@ def _custom_vjp_call_push_rule(
22292226
del ctx, num_consts, fwd_jaxpr_thunk, bwd, out_trees, symbolic_zeros
22302227
return _push_block_spec_jaxpr(call_jaxpr.jaxpr, *block_specs)
22312228

2229+
@register_push_block_spec_rule(hijax.call_hi_primitive_p)
2230+
def _custom_call_hi_primitive_push_block_spec_rule(
2231+
ctx: PullRuleContext, *block_specs, prim
2232+
):
2233+
return prim.push_block_spec_rule(ctx, block_specs)
2234+
2235+
22322236

22332237
@register_push_block_spec_rule(pjit.jit_p)
22342238
def _pjit_push_rule(ctx, *block_specs, jaxpr: core.ClosedJaxpr, **_):

0 commit comments

Comments
 (0)