
Improve jax.config.update experience in interactive environments #30474

@ryan112358

Description

In Colab, I might write the following code in the first cell of the notebook:

import jax

jax.config.update('jax_num_cpu_devices', 8)

Later, after writing and running some code, I add a new import and re-run the same cell without changing the value 8:

import jax
import numpy as np

jax.config.update('jax_num_cpu_devices', 8)

This raises an error that forces me to restart my Colab notebook:
RuntimeError: jax_num_cpu_devices config should be updated before backends are initialized i.e. before any JAX operation is executed. You should initialize this config immediately after import jax.

Since the value of jax_num_cpu_devices didn't change, it would be nice if this pattern worked. Raising an error when the value actually changes seems entirely reasonable to me. I'm not sure which other config options this might apply to, but this is one I use frequently.
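To make the proposed semantics concrete, here is a minimal pure-Python sketch. It does not touch JAX at all: `ToyConfig`, `initialize_backend`, and `read` are hypothetical stand-ins, and the only assumption is that the option becomes frozen once a backend exists. After initialization, re-setting an identical value is a no-op, while a different value still raises.

```python
# HYPOTHETICAL sketch of the requested behavior; this is NOT JAX's implementation.
# ToyConfig, initialize_backend, and read are illustrative stand-ins.


class ToyConfig:
    """Minimal config object whose options freeze once a backend is initialized."""

    def __init__(self):
        self._values = {"jax_num_cpu_devices": 1}
        # Options that may not change after the (simulated) backend exists.
        self._frozen_after_init = {"jax_num_cpu_devices"}
        self._backend_initialized = False

    def initialize_backend(self):
        # Stand-in for "any JAX operation is executed".
        self._backend_initialized = True

    def update(self, name, value):
        if name not in self._values:
            raise AttributeError(f"Unrecognized config option: {name}")
        if self._backend_initialized and name in self._frozen_after_init:
            if self._values[name] == value:
                # Proposed behavior: re-setting the same value is a no-op.
                return
            # A genuine change after initialization still fails, as today.
            raise RuntimeError(
                f"{name} config should be updated before backends are initialized"
            )
        self._values[name] = value

    def read(self, name):
        return self._values[name]


config = ToyConfig()
config.update("jax_num_cpu_devices", 8)  # first run of the setup cell
config.initialize_backend()              # later cells execute operations
config.update("jax_num_cpu_devices", 8)  # re-running the cell: now a no-op
print(config.read("jax_num_cpu_devices"))  # prints 8

try:
    config.update("jax_num_cpu_devices", 4)  # an actual change still raises
except RuntimeError as e:
    print("RuntimeError:", e)
```

The design choice is to compare against the stored value only for options that are frozen after initialization, so the existing error is preserved for every case where the setting would actually diverge from what the backend was built with.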

Labels: enhancement (New feature or request)
