@@ -1987,16 +1987,13 @@ def read_usage_env(_: core.Var):
19871987def _custom_call_hi_primitive_pull_block_spec_rule (
19881988 ctx : PullRuleContext , out_block_specs , * , prim
19891989):
1990- del ctx
1991- return prim .pull_block_spec_rule (out_block_specs )
1992-
1990+ return prim .pull_block_spec_rule (ctx , out_block_specs )
19931991
19941992@register_eval_rule (hijax .call_hi_primitive_p )
19951993def _custom_call_hi_primitive_eval_rule (
1996- ctx : KernelEvalContext , x , * , prim
1994+ ctx : KernelEvalContext , * args , prim
19971995):
1998- del ctx
1999- return prim .expand (x )
1996+ return jax .tree .leaves (prim .block_eval_rule (ctx , * args ))
20001997
20011998
20021999@functools .partial (api_boundary , repro_api_name = "fuser.push_block_spec" )
@@ -2229,6 +2226,13 @@ def _custom_vjp_call_push_rule(
22292226 del ctx , num_consts , fwd_jaxpr_thunk , bwd , out_trees , symbolic_zeros
22302227 return _push_block_spec_jaxpr (call_jaxpr .jaxpr , * block_specs )
22312228
2229+ @register_push_block_spec_rule (hijax .call_hi_primitive_p )
2230+ def _custom_call_hi_primitive_push_block_spec_rule (
2231+ ctx : PullRuleContext , * block_specs , prim
2232+ ):
2233+ return prim .push_block_spec_rule (ctx , block_specs )
2234+
2235+
22322236
22332237@register_push_block_spec_rule (pjit .jit_p )
22342238def _pjit_push_rule (ctx , * block_specs , jaxpr : core .ClosedJaxpr , ** _ ):
0 commit comments