Skip to content

Conversation

@carlosgmartin
Copy link
Contributor

Edit tree_min and tree_max to correctly handle pytrees that contain leafs that are zero-size arrays. For example:

Before:

$ py -c "import optax, jax; print(optax.tree.min([3, jax.numpy.zeros(0)]))"
ValueError: zero-size array to reduction operation min which has no identity

After:

$ py -c "import optax, jax; print(optax.tree.min([3, jax.numpy.zeros(0)]))"
3.0

@carlosgmartin carlosgmartin force-pushed the tree_min_zero_size_arrays branch 3 times, most recently from da7b694 to cea2e24 Compare July 11, 2025 19:12
@carlosgmartin carlosgmartin force-pushed the tree_min_zero_size_arrays branch 2 times, most recently from 49024a8 to 0086b75 Compare July 15, 2025 17:19
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants