Handle double precision in tests more carefully #288

@Michael-T-McCann

Description

JAX works in single precision by default and won't even let you create double-precision arrays unless the environment variable `JAX_ENABLE_X64=True` is set or `config.update("jax_enable_x64", True)` is called right after jax is imported. To test double precision, these settings are used in various places in the tests.
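A minimal sketch of the behavior described above, assuming a recent JAX release: with x64 disabled, a request for `float64` is downgraded to `float32` (JAX emits a warning), and `jax.config.update("jax_enable_x64", True)` lifts that restriction at runtime. The initial `update(..., False)` call is only there so the sketch starts in the default mode even if `JAX_ENABLE_X64` happens to be set.

```python
import warnings

import jax
import jax.numpy as jnp

# Start in JAX's default single-precision mode, even if JAX_ENABLE_X64=True
# happens to be set in the environment.
jax.config.update("jax_enable_x64", False)

with warnings.catch_warnings():
    # JAX warns that the requested float64 dtype will be truncated to float32.
    warnings.simplefilter("ignore")
    x32 = jnp.array([1.0, 2.0], dtype=jnp.float64)

# Opt in to double precision at runtime (the config.update route).
jax.config.update("jax_enable_x64", True)
x64 = jnp.array([1.0, 2.0], dtype=jnp.float64)

print(x32.dtype)  # float32: the float64 request was downgraded
print(x64.dtype)  # float64
```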

Unfortunately, enabling double precision also makes 64 bits the default for new arrays. Because the setting is "sticky" (enabling it in one test affects every test that runs after it in the same process), a test can behave differently when run on its own than when run as part of the whole suite.
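The ordering problem can be reproduced without pytest. In this sketch, `check_needs_double` and `check_default_dtype` are hypothetical stand-ins for two test functions; the second one returns a different dtype depending on whether the first has already run in the same process.

```python
import jax
import jax.numpy as jnp

# Begin from JAX's single-precision default.
jax.config.update("jax_enable_x64", False)


def check_needs_double():
    """Hypothetical test that enables x64 and never turns it back off."""
    jax.config.update("jax_enable_x64", True)
    return jnp.ones(3, dtype=jnp.float64).dtype


def check_default_dtype():
    """Hypothetical test that implicitly relies on the float32 default."""
    return jnp.zeros(3).dtype


alone = check_default_dtype()     # run on its own: float32
check_needs_double()              # an earlier test in the same session...
in_suite = check_default_dtype()  # ...leaks its setting: float64
print(alone, in_suite)
```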

All of this may change in a future JAX release (jax-ml/jax#8178), but for now I propose running all tests with `JAX_ENABLE_X64=True JAX_DEFAULT_DTYPE_BITS=32` and removing all `config.update` calls from the test files.
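One possible way to apply this proposal is a top-level `conftest.py`, which pytest loads before the test modules in its directory. This is a sketch under two assumptions: nothing imports JAX before the conftest runs (the environment variables are read when JAX is first imported), and `JAX_DEFAULT_DTYPE_BITS` only has an effect in JAX versions that still define the experimental `jax_default_dtype_bits` flag. The array creation at the end is a sanity check, not something a real conftest needs.

```python
# conftest.py (hypothetical sketch)
import os

# Must run before JAX is first imported anywhere in the test session;
# setdefault lets a value given on the command line take precedence.
os.environ.setdefault("JAX_ENABLE_X64", "True")
# Assumption: only honored by JAX versions that define the experimental
# jax_default_dtype_bits flag; it keeps dtype-less arrays at 32 bits.
os.environ.setdefault("JAX_DEFAULT_DTYPE_BITS", "32")

import jax.numpy as jnp  # noqa: E402 (must come after the environment is set)

# Sanity check (not needed in a real conftest.py): double precision is
# available without any config.update call in the test files.
x = jnp.array([1.0, 2.0], dtype=jnp.float64)
print(x.dtype)  # float64
```

Without a conftest, the same effect comes from prefixing the test command, e.g. `JAX_ENABLE_X64=True JAX_DEFAULT_DTYPE_BITS=32 pytest`.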

Metadata

Labels: tests (Pertaining to SCICO tests)