
Thanks for the repro! This is behaving as expected.

When you wrap a function in vmap, the function effectively operates on a single element of the batch: one slice of the input along the mapped axis (axis 0 by default). Your input x is two-dimensional, so each slice is logically a one-dimensional vector. When you write x[:, 0] within your function, you are attempting a two-dimensional indexing operation on a one-dimensional array, which leads to the error you're seeing.
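The original repro isn't included here. So this is a minimal sketch using a hypothetical function `first_column` and a hypothetical 4x3 input. It records the shape the function actually sees while vmap traces it, then shows that `x[:, 0]` fails on that one-dimensional value.

```python
import jax
import jax.numpy as jnp

# Hypothetical 2-D input: 4 rows (the mapped axis) of 3 values each.
x = jnp.arange(12.0).reshape(4, 3)

seen_shapes = []
error = None

def first_column(row):
    # Hypothetical stand-in for the repro's function.
    # Under vmap with the default in_axes=0, `row` is a single row of shape (3,).
    seen_shapes.append(row.shape)
    return row[:, 0]  # 2-D indexing on a 1-D value

try:
    jax.vmap(first_column)(x)
except IndexError as err:
    error = err
    print("IndexError:", err)

print(seen_shapes)  # shape seen inside the vmapped function
```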

To fix this, you'll need to either (1) not wrap your function in vmap, or (2) rewrite your function so that it accepts one-dimensional inputs, for example by indexing with x[0] instead of x[:, 0].
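Both fixes can be sketched as follows. The input and the function names (`first_column`, `first_element`) are hypothetical, chosen only for illustration; each approach extracts the first column of the same 4x3 array.

```python
import jax
import jax.numpy as jnp

# Hypothetical 2-D input: 4 rows of 3 values each.
x = jnp.arange(12.0).reshape(4, 3)

# Option 1: don't use vmap; the function receives the whole 2-D array.
def first_column(arr):
    return arr[:, 0]

# Option 2: keep vmap, but write the function for a single 1-D row.
def first_element(row):
    return row[0]

out1 = first_column(x)
out2 = jax.vmap(first_element)(x)  # vmap maps over axis 0 (the rows)

# Both give the first column: 0, 3, 6, 9
print(out1)
print(out2)
```

Option 2 is the usual pattern when a function is meant to be used with vmap: write it for one example and let vmap handle the batch axis.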

You can read more about vmap and automatic vectorization at https://docs.jax.dev/en/latest/automatic-vectorization.html.

Answer selected by syedzayyan