1717from jax ._src import core
1818from jax ._src import dispatch
1919from jax ._src import linear_util as lu
20- from jax ._src .api_util import debug_info , flatten_fun_nokwargs
20+ from jax ._src .api_util import debug_info , flatten_fun
2121from jax ._src .util import (safe_map , safe_zip , weakref_lru_cache , unzip2 ,
2222 split_list )
2323from jax ._src .tree_util import tree_flatten , tree_unflatten
@@ -38,9 +38,9 @@ def xla_metadata_call(f=None, **meta):
3838
3939# TODO(yashkatariya): Figure out a way to reuse code with compute_on2_p, fused_p
4040def _xla_metadata_call (f , ** meta ):
41- def wrapped (* args ):
42- dbg = debug_info ('xla_metadata_call' , f , args , {} )
43- args_flat , in_tree = tree_flatten (args )
41+ def wrapped (* args , ** kwargs ):
42+ dbg = debug_info ('xla_metadata_call' , f , args , kwargs )
43+ args_flat , in_tree = tree_flatten (( args , kwargs ) )
4444 in_avals = tuple (core .shaped_abstractify (x ) for x in args_flat )
4545 jaxpr , out_tree = _trace_to_jaxpr (f , in_avals , in_tree , dbg )
4646 outs_flat = xla_metadata_call_p .bind (* args_flat , jaxpr = jaxpr , ** meta )
@@ -50,7 +50,7 @@ def wrapped(*args):
5050@weakref_lru_cache
5151def _trace_to_jaxpr (fun , in_avals , in_tree , dbg ):
5252 f = lu .wrap_init (fun , debug_info = dbg )
53- f , out_tree = flatten_fun_nokwargs (f , in_tree )
53+ f , out_tree = flatten_fun (f , in_tree )
5454 jaxpr , _ , consts = pe .trace_to_jaxpr_dynamic (f , in_avals )
5555 return core .ClosedJaxpr (jaxpr , consts ), out_tree ()
5656
0 commit comments