Indexing a Jax Tracer? #32408
-
Hi all! I am trying to index a JAX tracer. Initially I looked into converting a tracer to a normal JAX array, and then I stumbled upon this: https://docs.jax.dev/en/latest/faq.html#how-can-i-convert-a-jax-tracer-to-a-numpy-array . The main problem I am running into is that I need this to happen: And when I checked with The function works without any issues outside of grad, but runs into this tracer problem in grad mode, which is understandable. For more context, I am trying to implement a graph edge kernel in GPJax and I am running into indexing issues. Any help is much appreciated!
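For reference, here is a minimal sketch of the distinction the linked FAQ entry draws, assuming a recent JAX release. The function names are hypothetical. Indexing a traced array inside jax.grad works, while converting the tracer to a NumPy array raises jax.errors.TracerArrayConversionError.

```python
import jax
import jax.numpy as jnp
import numpy as np


def pick_second(x):
    # Hypothetical example: basic indexing on a traced array works under
    # jax.grad; no conversion to NumPy is needed.
    return x[1] ** 2


def pick_second_via_numpy(x):
    # Hypothetical example: converting the tracer to a NumPy array is the
    # step that fails while JAX is tracing the function.
    return np.asarray(x)[1] ** 2


x = jnp.array([1.0, 2.0, 3.0])
grad_ok = jax.grad(pick_second)(x)  # gradient is nonzero only at index 1

try:
    jax.grad(pick_second_via_numpy)(x)
    conversion_failed = False
except jax.errors.TracerArrayConversionError:
    conversion_failed = True
```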
-
From the error message, it sounds like … The fix here would be to avoid indexing a one-dimensional array or tracer with two indices.
This probably indicates that you're applying some sort of batching transformation (such as vmap) along with the gradient. We'll need more information to help further: could you include a minimal reproducible example of code that shows this behavior?
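A minimal sketch of the kind of error described above, using hypothetical example arrays: two indices are valid on a two-dimensional array but raise an IndexError on a one-dimensional one. The exact wording of the error message may vary between JAX versions.

```python
import jax.numpy as jnp

# Hypothetical example arrays, not taken from the original question.
matrix = jnp.arange(6.0).reshape(2, 3)  # 2-D: two indices are valid
vector = jnp.arange(3.0)                # 1-D: only one index is valid

first_column = matrix[:, 0]  # column 0 of [[0, 1, 2], [3, 4, 5]]

try:
    vector[:, 0]  # two indices on a 1-D array
    raised = False
except IndexError:
    # JAX reports something like "Too many indices: 1-dimensional array
    # indexed with 2 regular indices."
    raised = True
```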
-
Thank you for the explanation. This is making tonnes of sense now.
Thanks for the repro! This is behaving as expected.

When you wrap a function in `vmap`, the function is traced as if it operates on a single element of the batch: by default, `vmap` maps over the leading axis of each input. Your input `x` is two-dimensional, so each element the function sees is a one-dimensional vector (one row of `x`). When you write `x[:, 0]` within your function, you are attempting a two-dimensional indexing operation on a one-dimensional array, which leads to the error you're seeing.

To fix this, you'll either need to (1) not wrap your function in `vmap`, or (2) rewrite your function so that it accepts one-dimensional inputs, i.e. a single row of `x`.

You can read more about `vmap` and automatic vectorization at https://docs.jax.dev/en/latest/automatic-vectorization.html.
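To make the per-row view concrete, here is a minimal sketch assuming a small hypothetical 3-by-2 input. The function names `first_entry_batched` and `first_entry_per_row` are illustrative and not from the original repro. The sketch records the shape that `vmap` passes in, shows that `x[:, 0]` fails under `vmap`, and shows both fixes, including option (2) composed with `jax.grad`.

```python
import jax
import jax.numpy as jnp

# Hypothetical example data (not from the original repro): three rows.
x = jnp.array([[1.0, 2.0],
               [3.0, 4.0],
               [5.0, 6.0]])

seen_shapes = []


def first_entry_batched(x):
    # Hypothetical: written for the whole 2-D array, so it fails under vmap.
    return x[:, 0]


def first_entry_per_row(row):
    # Hypothetical: written for a single 1-D row, which is what vmap passes in.
    seen_shapes.append(row.shape)  # recorded each time the function is traced
    return row[0]


# Option (1): no vmap; index the full 2-D array directly.
no_vmap = first_entry_batched(x)

# Option (2): vmap a function that expects one-dimensional rows.
with_vmap = jax.vmap(first_entry_per_row)(x)

# Mixing the two: two indices on the 1-D per-row tracer raise an IndexError.
try:
    jax.vmap(first_entry_batched)(x)
    vmap_failed = False
except IndexError:
    vmap_failed = True


# Option (2) also composes with jax.grad, as in the original question.
def total(x):
    return jax.vmap(first_entry_per_row)(x).sum()


grads = jax.grad(total)(x)  # 1.0 in column 0, 0.0 elsewhere
```

Option (2) is usually the more natural choice when the function is meant to be reused per example, since `vmap` then handles the batch dimension for you.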