From 3f79f1c4b857ac86942b07f459d5c9b6ab9c52e6 Mon Sep 17 00:00:00 2001 From: Pei Zhang Date: Wed, 23 Apr 2025 23:02:49 +0000 Subject: [PATCH 1/2] Compatibility for 2.7 release --- torchprime/torch_xla_models/train.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/torchprime/torch_xla_models/train.py b/torchprime/torch_xla_models/train.py index 944dbaa9..61c6a693 100644 --- a/torchprime/torch_xla_models/train.py +++ b/torchprime/torch_xla_models/train.py @@ -19,9 +19,9 @@ import transformers from datasets import load_dataset from omegaconf import DictConfig, OmegaConf +from packaging import version from torch import nn from torch.utils.data import DataLoader, Dataset, IterableDataset -from torch_xla._internal.jax_workarounds import jax_env_context from torch_xla.distributed.fsdp import checkpoint_module from torch_xla.distributed.spmd.xla_sharding import apply_xla_patch_to_nn_linear from transformers import ( @@ -44,6 +44,11 @@ from torchprime.torch_xla_models import offloading, remat_all, scan_layers from torchprime.torch_xla_models.topology import get_mesh, is_1d_sharding +if version.parse(torch_xla.__version__.split("+")[0]) >= version.parse("2.8.0"): + from torch_xla._internal.jax_workarounds import jax_env_context +else: + from torch_xla.experimental.custom_kernel import jax_env_context + check_min_version("4.39.3") logger = logging.getLogger(__name__) From e5f97da6d7e860ce5d8bb1dd92f12bbb074d2bc5 Mon Sep 17 00:00:00 2001 From: Pei Zhang Date: Thu, 24 Apr 2025 00:26:01 +0000 Subject: [PATCH 2/2] fix --- torchprime/torch_xla_models/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchprime/torch_xla_models/train.py b/torchprime/torch_xla_models/train.py index 61c6a693..7b85f6e7 100644 --- a/torchprime/torch_xla_models/train.py +++ b/torchprime/torch_xla_models/train.py @@ -47,7 +47,7 @@ if version.parse(torch_xla.__version__.split("+")[0]) >= version.parse("2.8.0"): from torch_xla._internal.jax_workarounds import jax_env_context else: - from torch_xla.experimental.custom_kernel import jax_env_context + from torch_xla.experimental.custom_kernel import _jax_env_context as jax_env_context check_min_version("4.39.3") logger = logging.getLogger(__name__)