-
Okay, I think I have managed to create a reproducible example of code that breaks. It appears that there are still issues with jax.lax.scan, because the following code:

    import jax
    import jax.numpy as jnp

    @jax.jit
    def bar(x, w):
        def scan_fn(x, _):
            c = jnp.array([])
            x = w[...] @ x
            x = jnp.concatenate([x, c], axis=-1)
            return x, None
        x, _ = jax.lax.scan(scan_fn, x, None, length=10)
        return x

    @jax.jit
    def foo(w):
        return bar(
            jnp.zeros((1,)),
            w,
        )

    foo(jax.array_ref(jnp.eye(1)))

gives the error:

    ValueError:

Curiously, using jnp.zeros((0,)) instead of jnp.array([]) fixes the issue. I will open a bug report.
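One way to see what differs between the two variants is to print the jaxpr of each scan. This is a sketch assuming a recent JAX; `jax.array_ref` is left out so it runs without it, and the helper names (`body_literal`, `body_zeros`, `make_scan`) are hypothetical. A plausible but unverified explanation is that `jnp.array([])` is built as a concrete array at trace time and captured as a constant, whereas `jnp.zeros((0,))` may be emitted as an operation inside the scan body; the printed jaxprs let you check this on your JAX version.

```python
import jax
import jax.numpy as jnp

# Variant from the reproducer: a concrete empty array created while tracing.
def body_literal(x, _):
    c = jnp.array([])
    return jnp.concatenate([x, c], axis=-1), None

# Workaround variant: an empty array built with jnp.zeros.
def body_zeros(x, _):
    c = jnp.zeros((0,))
    return jnp.concatenate([x, c], axis=-1), None

def make_scan(body):
    """Wrap a scan body into a function of the initial carry."""
    def f(x):
        out, _ = jax.lax.scan(body, x, None, length=10)
        return out
    return f

x0 = jnp.zeros((1,))
for name, body in [("jnp.array([])", body_literal), ("jnp.zeros((0,))", body_zeros)]:
    print(name)
    # Look for the empty array among the consts / scan operands vs. inside the body.
    print(jax.make_jaxpr(make_scan(body))(x0))
```

Both variants compute the same values; only the traced program differs, which is why comparing jaxprs (rather than outputs) is the useful check here.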
-
I have created #32399.
-
Thanks so much for flagging this!
-
Thanks for patching this so quickly! I can confirm that I no longer have issues with the example I posted above, but I am still getting the same error with the full model. The jaxpr there also involves consts, so I assume there is a similar bug elsewhere. I will track down the issue and open a new bug report once I can reproduce it with a simple example.
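To narrow down where consts appear in a large model, one option is to walk the jaxpr recursively and record, for every nested jaxpr (inside jit, scan, cond, ...), how many consts it closes over and which effects it carries. This is a sketch: `walk_jaxpr` is a hypothetical helper, and it relies on internal attributes (`consts`, `jaxpr`, `eqns`, `params`, `effects`, `primitive.name`) that may change between JAX versions. Scan can also pass closed-over values as leading operands (its `num_consts` parameter), so that is recorded as well.

```python
import jax
import jax.numpy as jnp

def walk_jaxpr(closed, path="top", num_consts_param=None, report=None):
    """Record consts and effects for a ClosedJaxpr and every nested ClosedJaxpr."""
    if report is None:
        report = []
    report.append({
        "path": path,
        "closed_consts": len(closed.consts),   # values baked into this ClosedJaxpr
        "num_consts_param": num_consts_param,  # e.g. scan's leading const operands
        "effects": set(closed.jaxpr.effects),
    })
    for i, eqn in enumerate(closed.jaxpr.eqns):
        for key, val in eqn.params.items():
            # Sub-jaxprs are stored directly (jit, scan) or in a tuple (cond branches).
            items = val if isinstance(val, (list, tuple)) else (val,)
            for j, item in enumerate(items):
                if hasattr(item, "jaxpr") and hasattr(item, "consts"):
                    walk_jaxpr(
                        item,
                        f"{path}/{eqn.primitive.name}#{i}.{key}[{j}]",
                        eqn.params.get("num_consts"),
                        report,
                    )
    return report

# Usage on the scan from the reproducer (without array_ref).
def body(x, _):
    c = jnp.array([])
    return jnp.concatenate([x, c], axis=-1), None

@jax.jit
def f(x):
    out, _ = jax.lax.scan(body, x, None, length=10)
    return out

report = walk_jaxpr(jax.make_jaxpr(f)(jnp.zeros((1,))))
for entry in report:
    print(entry)
```

On the full model, entries with a nonzero const count or an unexpected effect point at the nested jaxpr worth isolating first.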
-
When using nnx with the new array_ref API, there appear to be bugs in some JAX transformations, as I get

    ValueError: JaxprInputEffect Read<5> is invalid.

followed by an extremely large jaxpr. It doesn't happen with all of the nnx models I use, and I am struggling to produce a minimal reproducible example I can post here. I don't have any issues when using the old-style nnx API.
I also sometimes get a similar-looking error suggesting that a jaxpr receives too few inputs for a particular Read<> effect.
I suspect this is an internal bug in one of the JAX transformations I am using and how it interacts with array_refs (possibly something similar to #16370, although that issue has been closed).
Does anyone have suggestions for how to debug this in a manageable way? Is there a flag I can enable for extra-strict jaxpr validation that would pinpoint which transformation is causing the issue?
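A possible starting point, offered as a sketch rather than a guaranteed fix: JAX has a `jax_enable_checks` config option that turns on extra internal consistency checks (including type-checking jaxprs), which may surface the problem closer to where it is introduced. Whether it catches this particular effect error earlier is not certain. Combined with running each transformation separately, it can help attribute the failure to one of them. `model` and `step` below are hypothetical stand-ins for the real nnx model.

```python
import jax
import jax.numpy as jnp

# Enable JAX's internal consistency checks. They slow things down a lot and
# are meant for debugging only.
jax.config.update("jax_enable_checks", True)

# Hypothetical stand-in for the real model: a jitted scan.
def step(x, _):
    return jnp.sin(x), None

@jax.jit
def model(x):
    out, _ = jax.lax.scan(step, x, None, length=4)
    return out

x = jnp.ones((3,))

# Run each transformation on its own so a failure can be attributed to one of them.
stages = {
    "make_jaxpr": lambda: jax.make_jaxpr(model)(x),
    "jit": lambda: model(x),
    "grad": lambda: jax.grad(lambda v: model(v).sum())(x),
    "vmap": lambda: jax.vmap(model)(jnp.ones((2, 3))),
}

results = {}
for name, run in stages.items():
    try:
        run()
        results[name] = "ok"
    except Exception as e:  # record the failure and keep checking the other stages
        results[name] = f"{type(e).__name__}: {e}"
    print(f"{name}: {results[name]}")
```

With the real model, the first stage that fails (and the traceback produced with checks enabled) narrows down which transformation to reduce into a minimal example.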