JAX works in single precision by default and won't even let you create double-precision arrays unless an environment variable is set (JAX_ENABLE_X64=True) or a special command is run when jax is imported (config.update("jax_enable_x64", True)). To test double precision, these commands are used in various places in the tests.
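A minimal sketch of the runtime switch, assuming a standard JAX install and that JAX_ENABLE_X64 is not already set in the environment. Before the update, new float arrays come out as float32. Afterwards they default to float64 for the rest of the process, which is the behavior described next.

```python
import jax
import jax.numpy as jnp

# With x64 disabled (the default), new float arrays are single precision.
x = jnp.ones(3)
print(x.dtype)  # float32, unless JAX_ENABLE_X64 was already set

# Flip the global config flag at runtime.
jax.config.update("jax_enable_x64", True)

# From now on, new float arrays in this process default to double precision.
y = jnp.ones(3)
print(y.dtype)  # float64
```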
Unfortunately, enabling double precision also makes it the default for new arrays, so a test can behave differently when run on its own than when run as part of the whole suite: the config is "sticky", and setting it in one test affects the tests that run after it in the same process.
All of this may change in a future JAX release (jax-ml/jax#8178), but for now, I propose running all tests with JAX_ENABLE_X64=True JAX_DEFAULT_DTYPE_BITS=32 and removing all config.update calls from the test files.
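One possible way to apply these flags without relying on every contributor to export them by hand is a pytest `conftest.py` at the test root. This is only a sketch. It assumes pytest loads this file before any test module imports jax, since JAX reads these variables when it is first imported. It also assumes the JAX version in use still honors JAX_DEFAULT_DTYPE_BITS. The file location is hypothetical.

```python
# conftest.py (hypothetical: placed at the root of the test suite)
import os

# Must run before jax is imported anywhere, because JAX reads these
# environment variables when its config is first initialized.
# setdefault keeps any value already exported on the command line.
os.environ.setdefault("JAX_ENABLE_X64", "True")       # allow float64 arrays
os.environ.setdefault("JAX_DEFAULT_DTYPE_BITS", "32")  # keep float32 as the default dtype
```

Running `JAX_ENABLE_X64=True JAX_DEFAULT_DTYPE_BITS=32 pytest` directly would have the same effect. The conftest approach just keeps the setting in the repository.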