@@ -1959,16 +1959,13 @@ def read_usage_env(_: core.Var):
19591959def _custom_call_hi_primitive_pull_block_spec_rule (
19601960 ctx : PullRuleContext , out_block_specs , * , prim
19611961):
1962- del ctx
1963- return prim .pull_block_spec_rule (out_block_specs )
1964-
1962+ return prim .pull_block_spec_rule (ctx , out_block_specs )
19651963
19661964@register_eval_rule (hijax .call_hi_primitive_p )
19671965def _custom_call_hi_primitive_eval_rule (
1968- ctx : KernelEvalContext , x , * , prim
1966+ ctx : KernelEvalContext , * args , prim
19691967):
1970- del ctx
1971- return prim .expand (x )
1968+ return jax .tree .leaves (prim .block_eval_rule (ctx , * args ))
19721969
19731970
19741971@functools .partial (api_boundary , repro_api_name = "fuser.push_block_spec" )
@@ -2201,6 +2198,13 @@ def _custom_vjp_call_push_rule(
22012198 del ctx , num_consts , fwd_jaxpr_thunk , bwd , out_trees , symbolic_zeros
22022199 return _push_block_spec_jaxpr (call_jaxpr .jaxpr , * block_specs )
22032200
2201+ @register_push_block_spec_rule (hijax .call_hi_primitive_p )
2202+ def _custom_call_hi_primitive_push_block_spec_rule (
2203+ ctx : PullRuleContext , * block_specs , prim
2204+ ):
2205+ return prim .push_block_spec_rule (ctx , block_specs )
2206+
2207+
22042208
22052209@register_push_block_spec_rule (pjit .jit_p )
22062210def _pjit_push_rule (ctx , * block_specs , jaxpr : core .ClosedJaxpr , ** _ ):
0 commit comments