diff --git a/torchprime/torch_xla_models/train.py b/torchprime/torch_xla_models/train.py index 944dbaa9..7b85f6e7 100644 --- a/torchprime/torch_xla_models/train.py +++ b/torchprime/torch_xla_models/train.py @@ -19,9 +19,9 @@ import transformers from datasets import load_dataset from omegaconf import DictConfig, OmegaConf +from packaging import version from torch import nn from torch.utils.data import DataLoader, Dataset, IterableDataset -from torch_xla._internal.jax_workarounds import jax_env_context from torch_xla.distributed.fsdp import checkpoint_module from torch_xla.distributed.spmd.xla_sharding import apply_xla_patch_to_nn_linear from transformers import ( @@ -44,6 +44,11 @@ from torchprime.torch_xla_models import offloading, remat_all, scan_layers from torchprime.torch_xla_models.topology import get_mesh, is_1d_sharding +if version.parse(torch_xla.__version__.split("+")[0]) >= version.parse("2.8.0"): + from torch_xla._internal.jax_workarounds import jax_env_context +else: + from torch_xla.experimental.custom_kernel import _jax_env_context as jax_env_context + check_min_version("4.39.3") logger = logging.getLogger(__name__)