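Hi, thanks for the report. One debugging tool that's useful in cases like this is the jax_explain_cache_misses configuration option. Here's a quick example:

```python
import jax
import jax.numpy as jnp

# Log the reason for every jit cache miss (i.e. every retrace/recompile).
jax.config.update('jax_explain_cache_misses', True)

@jax.jit
def f(x):
  return x

x1 = f(jnp.arange(10))   # first call: compiles for an input of shape (10,)
x2 = f(jnp.arange(100))  # new input shape (100,): triggers a second compile
```

This is the log output: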
You can see that the first cache miss (i.e. the jit compilation) happens because the function has not yet been seen by the compiler, and the second cache miss happens because the input shape changed from (10,) to (100,). Hopefully you can enable that config in your code and find the answer in the logs!
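A complementary check, sketched here under the assumption that you can edit the jitted function: the Python body of a jitted function only runs while JAX traces it, so a plain counter incremented inside the body shows how many times each function was retraced. The names trace_counts and g below are hypothetical, for illustration only. The sketch reproduces the shape change from the example above and also shows static arguments (static_argnums), whose values are part of the jit cache key, as another common source of unexpected recompilation.

```python
import jax
import jax.numpy as jnp

# Hypothetical trace counters. The Python body of a jitted function runs only
# during tracing, so a cache hit leaves these numbers unchanged.
trace_counts = {"f": 0, "g": 0}

@jax.jit
def f(x):
    trace_counts["f"] += 1  # side effect: executes at trace time only
    return x * 2

f(jnp.arange(10))   # first call: trace + compile for shape (10,)
f(jnp.arange(10))   # same shape and dtype: cache hit, no retrace
f(jnp.arange(100))  # new shape (100,): cache miss, retrace

# Hypothetical second function: the value of a static argument is part of
# the cache key, so each distinct value of `n` causes a new trace.
def g(x, n):
    trace_counts["g"] += 1
    return x * n

g_jit = jax.jit(g, static_argnums=1)
g_jit(jnp.ones(3), 2)  # traces for n=2
g_jit(jnp.ones(3), 3)  # new static value: retrace
g_jit(jnp.ones(3), 2)  # n=2 seen before: cache hit

print(trace_counts)  # {'f': 2, 'g': 2}
```

If the counter grows on calls you expected to be cache hits, the jax_explain_cache_misses log for those calls should name the argument whose shape, dtype, or static value differed.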
Is there a way to enable this for internal functions as well? I don't see it, for instance, for things like …

To clarify, since I'm doing something very unorthodox: I'm trying to write a custom lowering, and at the moment, just as a test, I'm hacking and replacing …

To clarify further, the same Lowering object (as in, id(self) is the same) is returned from …
I'm seeing a strange behaviour of jax.jit recompiling functions in places I don't expect. Unfortunately, I don't have a simple repro to post, but roughly what I'm seeing: