Skip to content

Commit 02911cf

Browse files
Update pull/push rule API for call_hi_primitive_p
PiperOrigin-RevId: 831036071
1 parent d21998e commit 02911cf

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

jax/_src/pallas/fuser/block_spec.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1988,16 +1988,13 @@ def read_usage_env(_: core.Var):
19881988
def _custom_call_hi_primitive_pull_block_spec_rule(
19891989
ctx: PullRuleContext, out_block_specs, *, prim
19901990
):
1991-
del ctx
1992-
return prim.pull_block_spec_rule(out_block_specs)
1993-
1991+
return prim.pull_block_spec_rule(ctx, out_block_specs)
19941992

19951993
@register_eval_rule(hijax.call_hi_primitive_p)
19961994
def _custom_call_hi_primitive_eval_rule(
1997-
ctx: KernelEvalContext, x, *, prim
1995+
ctx: KernelEvalContext, *args, prim
19981996
):
1999-
del ctx
2000-
return prim.expand(x)
1997+
return jax.tree.leaves(prim.block_eval_rule(ctx, *args))
20011998

20021999

20032000
@functools.partial(api_boundary, repro_api_name="fuser.push_block_spec")
@@ -2230,6 +2227,13 @@ def _custom_vjp_call_push_rule(
22302227
del ctx, num_consts, fwd_jaxpr_thunk, bwd, out_trees, symbolic_zeros
22312228
return _push_block_spec_jaxpr(call_jaxpr.jaxpr, *block_specs)
22322229

2230+
@register_push_block_spec_rule(hijax.call_hi_primitive_p)
2231+
def _custom_call_hi_primitive_push_block_spec_rule(
2232+
ctx: PullRuleContext, *block_specs, prim
2233+
):
2234+
return prim.push_block_spec_rule(ctx, block_specs)
2235+
2236+
22332237

22342238
@register_push_block_spec_rule(pjit.jit_p)
22352239
def _pjit_push_rule(ctx, *block_specs, jaxpr: core.ClosedJaxpr, **_):

0 commit comments

Comments
 (0)