Skip to content

Commit 31af95f

Browse files
Update pull/push rule API for call_hi_primitive_p
PiperOrigin-RevId: 831036071
1 parent 79f5078 commit 31af95f

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

jax/_src/pallas/fuser/block_spec.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1959,16 +1959,13 @@ def read_usage_env(_: core.Var):
19591959
def _custom_call_hi_primitive_pull_block_spec_rule(
19601960
ctx: PullRuleContext, out_block_specs, *, prim
19611961
):
1962-
del ctx
1963-
return prim.pull_block_spec_rule(out_block_specs)
1964-
1962+
return prim.pull_block_spec_rule(ctx, out_block_specs)
19651963

19661964
@register_eval_rule(hijax.call_hi_primitive_p)
19671965
def _custom_call_hi_primitive_eval_rule(
1968-
ctx: KernelEvalContext, x, *, prim
1966+
ctx: KernelEvalContext, *args, prim
19691967
):
1970-
del ctx
1971-
return prim.expand(x)
1968+
return jax.tree.leaves(prim.block_eval_rule(ctx, *args))
19721969

19731970

19741971
@functools.partial(api_boundary, repro_api_name="fuser.push_block_spec")
@@ -2201,6 +2198,13 @@ def _custom_vjp_call_push_rule(
22012198
del ctx, num_consts, fwd_jaxpr_thunk, bwd, out_trees, symbolic_zeros
22022199
return _push_block_spec_jaxpr(call_jaxpr.jaxpr, *block_specs)
22032200

2201+
@register_push_block_spec_rule(hijax.call_hi_primitive_p)
2202+
def _custom_call_hi_primitive_push_block_spec_rule(
2203+
ctx: PullRuleContext, *block_specs, prim
2204+
):
2205+
return prim.push_block_spec_rule(ctx, block_specs)
2206+
2207+
22042208

22052209
@register_push_block_spec_rule(pjit.jit_p)
22062210
def _pjit_push_rule(ctx, *block_specs, jaxpr: core.ClosedJaxpr, **_):

0 commit comments

Comments
 (0)