Skip to content

Commit 24e80c4

Browse files
yashk2810Google-ML-Automation
authored andcommitted
Add kwargs support to xla_metadata_call
PiperOrigin-RevId: 831033450
1 parent e9b7720 commit 24e80c4

File tree

2 files changed

+21
-5
lines changed

2 files changed

+21
-5
lines changed

jax/experimental/scheduling_groups.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from jax._src import core
1818
from jax._src import dispatch
1919
from jax._src import linear_util as lu
20-
from jax._src.api_util import debug_info, flatten_fun_nokwargs
20+
from jax._src.api_util import debug_info, flatten_fun
2121
from jax._src.util import (safe_map, safe_zip, weakref_lru_cache, unzip2,
2222
split_list)
2323
from jax._src.tree_util import tree_flatten, tree_unflatten
@@ -38,9 +38,9 @@ def xla_metadata_call(f=None, **meta):
3838

3939
# TODO(yashkatariya): Figure out a way to reuse code with compute_on2_p, fused_p
4040
def _xla_metadata_call(f, **meta):
41-
def wrapped(*args):
42-
dbg = debug_info('xla_metadata_call', f, args, {})
43-
args_flat, in_tree = tree_flatten(args)
41+
def wrapped(*args, **kwargs):
42+
dbg = debug_info('xla_metadata_call', f, args, kwargs)
43+
args_flat, in_tree = tree_flatten((args, kwargs))
4444
in_avals = tuple(core.shaped_abstractify(x) for x in args_flat)
4545
jaxpr, out_tree = _trace_to_jaxpr(f, in_avals, in_tree, dbg)
4646
outs_flat = xla_metadata_call_p.bind(*args_flat, jaxpr=jaxpr, **meta)
@@ -50,7 +50,7 @@ def wrapped(*args):
5050
@weakref_lru_cache
5151
def _trace_to_jaxpr(fun, in_avals, in_tree, dbg):
5252
f = lu.wrap_init(fun, debug_info=dbg)
53-
f, out_tree = flatten_fun_nokwargs(f, in_tree)
53+
f, out_tree = flatten_fun(f, in_tree)
5454
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(f, in_avals)
5555
return core.ClosedJaxpr(jaxpr, consts), out_tree()
5656

tests/scheduling_groups_test.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,22 @@ def f(x):
158158
compiled = lowered.compile()
159159
compiled(inp) # doesn't crash
160160

161+
@jtu.run_on_devices('cpu')
162+
def test_xla_metadata_call_deduplication_kwargs(self):
163+
inp = jnp.arange(8.)
164+
165+
@xla_metadata_call(inlineable='false')
166+
@jax.jit
167+
def g(x):
168+
return x * 2
169+
170+
def f(x):
171+
y = g(x=x)
172+
z = g(x=y)
173+
return z.sum()
174+
175+
f(inp) # doesn't crash
176+
161177

162178
if __name__ == '__main__':
163179
absltest.main(testLoader=jtu.JaxTestLoader())

0 commit comments

Comments
 (0)