@@ -1988,16 +1988,13 @@ def read_usage_env(_: core.Var):
19881988def _custom_call_hi_primitive_pull_block_spec_rule (
19891989 ctx : PullRuleContext , out_block_specs , * , prim
19901990):
1991- del ctx
1992- return prim .pull_block_spec_rule (out_block_specs )
1993-
1991+ return prim .pull_block_spec_rule (ctx , out_block_specs )
19941992
19951993@register_eval_rule (hijax .call_hi_primitive_p )
19961994def _custom_call_hi_primitive_eval_rule (
1997- ctx : KernelEvalContext , x , * , prim
1995+ ctx : KernelEvalContext , * args , prim
19981996):
1999- del ctx
2000- return prim .expand (x )
1997+ return jax .tree .leaves (prim .block_eval_rule (ctx , * args ))
20011998
20021999
20032000@functools .partial (api_boundary , repro_api_name = "fuser.push_block_spec" )
@@ -2230,6 +2227,13 @@ def _custom_vjp_call_push_rule(
22302227 del ctx , num_consts , fwd_jaxpr_thunk , bwd , out_trees , symbolic_zeros
22312228 return _push_block_spec_jaxpr (call_jaxpr .jaxpr , * block_specs )
22322229
2230+ @register_push_block_spec_rule (hijax .call_hi_primitive_p )
2231+ def _custom_call_hi_primitive_push_block_spec_rule (
2232+ ctx : PullRuleContext , * block_specs , prim
2233+ ):
2234+ return prim .push_block_spec_rule (ctx , block_specs )
2235+
2236+
22332237
22342238@register_push_block_spec_rule (pjit .jit_p )
22352239def _pjit_push_rule (ctx , * block_specs , jaxpr : core .ClosedJaxpr , ** _ ):
0 commit comments