-
In theory, EDIT: I went with sprinkling
-
Right now, nested jit calls will be preserved as function calls in the IR that JAX generates, but will be flattened by XLA. In the future, this may not be true any more! At some point, XLA (or another compiler) may inline less aggressively.
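One way to see this, as a rough sketch (my own example with placeholder names inner/outer, not from the thread), is to print the IR that jit lowering emits and look for the nested function call:

import jax
import jax.numpy as jnp

@jax.jit
def inner(x):
    return jnp.sin(x) * 2.0

@jax.jit
def outer(x):
    return inner(x) + 1.0

# In recent JAX versions the nested jit appears as its own function in the
# lowered StableHLO, which XLA is then free to inline (or not).
print(outer.lower(1.0).as_text())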
-
For anyone coming back to this, for example:
>>> f = lambda x: sum(i * x for i in range(256)) # common function
>>> jf = jax.jit(f)
>>> %timeit jax.jit(lambda x: sum(f(x) for _ in range(16)))(3.14)
2.79 s ± 133 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
>>> %timeit jax.jit(lambda x: sum(jf(x) for _ in range(16)))(3.14)
36.5 ms ± 9.6 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
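A rough sketch building on the snippet above (same f and jf; the exact equation counts will vary): comparing the traced jaxprs shows where the time goes, since the plain f is re-traced inline for each of the 16 calls while jf becomes a single cached call per use.

import jax

f = lambda x: sum(i * x for i in range(256))  # common function, as above
jf = jax.jit(f)

inlined = jax.make_jaxpr(lambda x: sum(f(x) for _ in range(16)))(3.14)
nested = jax.make_jaxpr(lambda x: sum(jf(x) for _ in range(16)))(3.14)

print(len(inlined.jaxpr.eqns))  # thousands of primitive equations
print(len(nested.jaxpr.eqns))   # roughly a few dozen: one call per jf use plus the sums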
-
Revisiting this again in 2025, and I don't think this is the case anymore with the recent pulling of pjit.cc across into jaxlib. Notably, if I define:

def f(in_vars):
    @jax.jit
    def sub_fun(vars):
        return jnp.some_op(vars)
    x = jnp.some_op(in_vars)
    y = sub_fun(x)
    return y

and turn on verbose logging, you end up with many of these (function name annotation mine):
Notably, you get these errors on JAX's built-in functions as well. Which, looking at lines 641 to 657 in 8ff5d33, you can produce this behaviour with this:

import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_platform_name", "gpu")

@jax.jit
def simple_stuff(x, y):
    z = jnp.matmul(x, y)
    z = jnp.abs(z)
    z = z**2
    return jnp.sum(z)

batch_dim = 8  # batch size; not defined in the original snippet
many_x = jnp.array(np.random.random((batch_dim, 1000, 1000)))
many_y = jnp.array(np.random.random((batch_dim, 1000, 1000)))

def do_many(x, y):
    res = simple_stuff(x, y) - 1
    return res

v_matmul = jax.jit(jax.vmap(do_many, in_axes=(0, 0)))
res = v_matmul(many_x, many_y)

Which seems to indicate to me that true nested jit should be avoided, unless of course you use some of the nested functions independently elsewhere, where they won't be passed Tracers as input. From https://docs.jax.dev/en/latest/faq.html#benchmarking-jax-code :
If my understanding of all of this is correct, I'd love it if someone with authority could confirm/correct!
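As a rough illustration of that FAQ section's benchmarking advice (my sketch, not a quote from it; g and the shapes are placeholders), the AOT API can separate compile time from run time:

import time
import jax
import jax.numpy as jnp

@jax.jit
def g(x):
    return jnp.sum(jnp.tanh(x) ** 2)

x = jnp.ones((1000, 1000))

lowered = g.lower(x)              # trace + lower only, no compilation yet
t0 = time.perf_counter()
compiled = lowered.compile()      # XLA compilation happens here
t1 = time.perf_counter()
compiled(x).block_until_ready()   # pure execution, excluding compile time
t2 = time.perf_counter()
print(f"compile: {t1 - t0:.3f}s, run: {t2 - t1:.3f}s")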
-
Revisiting at the end of 2025, I'd like to point out these realizations:
My test for this was to have f(x) = exp(x), g(x) = log(x).
Here is a minimal example:
from functools import partial
import jax
import jax.numpy as jnp
# Enable logging to see when compilations happen
jax.config.update("jax_log_compiles", True)
# 1. Define the inner function that does not change
@partial(jax.jit)
def inner_function(x):
print("--- Tracing and Compiling inner_function ---")
return jnp.tanh(x)
# 2. Define the outer function whose input dimensions might change
@partial(jax.jit)
def outer_function(z):
    # Assume z is (N, 5) where N can vary
    print("--- Tracing and Compiling outer_function ---")
    # z_reduced is always (5,) regardless of N
    z_reduced = jnp.sum(z, axis=0)
    # Call the inner function with a fixed-size input
    inner_result = inner_function(z_reduced)
    # Combine with the dynamic-shaped input
    return z_reduced + inner_result
# --- First Call ---
# Input shape is (20,5). Trigger both outer_function and inner_function to compile.
print("First call with shape (20,5)")
key = jax.random.PRNGKey(0)
arr1 = jnp.arange(100.0).reshape(20, 5)
result1 = outer_function(arr1)
result1.block_until_ready()
# On this first run, BOTH outer_function and inner_function will be compiled.
print("\n" + "=" * 40 + "\n")
# --- Second Call ---
# Input shape is (30,5). Only outer_function should recompile.
print("Second call with shape (30,5)")
arr2 = jnp.arange(150.0).reshape(30, 5)
result2 = outer_function(arr2)
result2.block_until_ready()
# --- Third Call ---
# Let's try (30, 6) to see if both recompile
print("\n" + "=" * 40 + "\n")
print("Third call with shape (30,6)")
arr3 = jnp.arange(180.0).reshape(30, 6)
result3 = outer_function(arr3)
result3.block_until_ready()
Results in this:
Note that in the second call, inner_function is not re-traced!
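A rough sketch (mine, with placeholder names) of confirming the re-tracing behaviour programmatically: Python side effects in a jitted function's body only run while JAX traces it, so a plain counter records how many times each function was traced.

import jax
import jax.numpy as jnp

trace_counts = {"inner": 0, "outer": 0}

@jax.jit
def inner(x):
    trace_counts["inner"] += 1  # runs only when inner is (re)traced
    return jnp.tanh(x)

@jax.jit
def outer(z):
    trace_counts["outer"] += 1  # runs only when outer is (re)traced
    z_reduced = jnp.sum(z, axis=0)  # always shape (5,) below
    return z_reduced + inner(z_reduced)

outer(jnp.ones((20, 5))).block_until_ready()
outer(jnp.ones((30, 5))).block_until_ready()  # new outer shape, same inner shape
print(trace_counts)  # expected: {'inner': 1, 'outer': 2}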