From 7bf1a6cec568751bd01f4cd17b5e9f261ca5e377 Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Mon, 19 May 2025 21:19:40 +0000 Subject: [PATCH 1/9] Deprecate `runtime.xla_device` in favor of `xla_model.xla_device` --- test/spmd/test_xla_sharding.py | 4 ++-- torch_xla/core/xla_model.py | 15 ++++++++++++--- torch_xla/experimental/scan.py | 4 ++-- torch_xla/experimental/scan_layers.py | 4 ++-- torch_xla/runtime.py | 26 +------------------------- 5 files changed, 19 insertions(+), 34 deletions(-) diff --git a/test/spmd/test_xla_sharding.py b/test/spmd/test_xla_sharding.py index 7fa438e2e420..f0a75dff4d52 100644 --- a/test/spmd/test_xla_sharding.py +++ b/test/spmd/test_xla_sharding.py @@ -1484,10 +1484,10 @@ def test_xla_patched_linear(self): """ from torch_xla.distributed.spmd.xla_sharding import XLAPatchedLinear - import torch_xla.runtime + import torch_xla.core.xla_model as xm import torch.nn.functional as F - with torch_xla.runtime.xla_device(): + with xm.xla_device(): torch_xla.manual_seed(42) x0 = torch.randn(2, 3, requires_grad=True) w0 = torch.randn(4, 3, requires_grad=True) diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py index 536b9c4115b6..9a1d0edc40bd 100644 --- a/torch_xla/core/xla_model.py +++ b/torch_xla/core/xla_model.py @@ -142,12 +142,12 @@ def xla_device(n: Optional[int] = None, Args: n (int, optional): The specific instance (ordinal) to be returned. If specified, the specific XLA device instance will be returned. Otherwise - the first device of `devkind` will be returned. + the first device (default 0) will be returned. devkind (string..., optional): If specified, device type such as `TPU`, `CUDA`, `CPU`, or custom PJRT device. Deprecated. Returns: - A `torch.device` with the requested instance. + A `torch.device` with the requested instance of an XLA device. """ # When SPMD is enabled, we always return `xla:0` to the user, and # under the hood we use virtual device logic for every xla tensor @@ -156,7 +156,16 @@ def xla_device(n: Optional[int] = None, torch_xla._XLAC._xla_set_default_device(device) return torch.device(device) - return runtime.xla_device(n, devkind) + if n is None: + return torch.device(torch_xla._XLAC._xla_get_default_device()) + + devices = xm.get_xla_supported_devices(devkind=devkind) + if n > len(devices): + raise IndexError('Device index {} out of range in {}'.format(n, devices)) + + device = devices[n] + torch_xla._XLAC._xla_set_default_device(device) + return torch.device(device) def _xla_real_device(device: torch.device) -> Any: diff --git a/torch_xla/experimental/scan.py b/torch_xla/experimental/scan.py index 894ed1baa92d..d11141db58db 100644 --- a/torch_xla/experimental/scan.py +++ b/torch_xla/experimental/scan.py @@ -120,7 +120,7 @@ def scan(fn, init, xs): Example: >>> # Example of using `scan` to implement `torch.cumsum`. - >>> import torch_xla.runtime + >>> import torch_xla.core.xla_model as xm >>> import torch >>> from torch_xla.experimental.scan import scan >>> @@ -129,7 +129,7 @@ def scan(fn, init, xs): >>> y = new_carry >>> return new_carry, y >>> - >>> with torch_xla.runtime.xla_device(): + >>> with xm.xla_device(): >>> init = torch.tensor([0.0, 0.0], requires_grad=True) >>> xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], >>> requires_grad=True) diff --git a/torch_xla/experimental/scan_layers.py b/torch_xla/experimental/scan_layers.py index 140a95312c03..8ddafed82024 100644 --- a/torch_xla/experimental/scan_layers.py +++ b/torch_xla/experimental/scan_layers.py @@ -47,11 +47,11 @@ def scan_layers(layers: Iterable[torch.nn.Module], Example: - >>> import torch_xla.runtime + >>> import torch_xla.core.xla_model as xm >>> import torch >>> import torch.nn as nn >>> from torch_xla.experimental.scan_layers import scan_layers - >>> with torch_xla.runtime.xla_device(): + >>> with xm.xla_device(): >>> layers = [nn.Linear(16, 16) for i in range(10)] >>> input = torch.randn(16) >>> output = scan_layers(layers, input) diff --git a/torch_xla/runtime.py b/torch_xla/runtime.py index b1285c268e82..e9dbb9d48241 100644 --- a/torch_xla/runtime.py +++ b/torch_xla/runtime.py @@ -103,30 +103,6 @@ def is_bf16_supported(): return False -def xla_device(n: Optional[int] = None, - devkind: Optional[str] = None) -> torch.device: - """Returns an XLA device. - - Args: - n: Index of XLA device within visibible devices. If not set, use local - ordinal (default 0) to select an addressable device. - devkind: Type of device to return. Should match `device_type()`. - - Returns: - A `torch.device` representing an XLA device. - """ - if n is None: - return torch.device(torch_xla._XLAC._xla_get_default_device()) - - devices = xm.get_xla_supported_devices(devkind=devkind) - if n > len(devices): - raise IndexError('Device index {} out of range in {}'.format(n, devices)) - - device = devices[n] - torch_xla._XLAC._xla_set_default_device(device) - return torch.device(device) - - def local_process_count() -> int: """Returns the number of processes running on this host.""" return xu.getenv_as(xenv.PJRT_LOCAL_PROCESS_COUNT, int, defval=1) @@ -180,7 +156,7 @@ def local_ordinal() -> int: Local ordinal is in range [0, local_device_count).""" local_rank = xu.getenv_as(xenv.PJRT_LOCAL_PROCESS_RANK, int, 0) devices_per_process = addressable_device_count() - return local_rank * devices_per_process + xla_device().index + return local_rank * devices_per_process + xm.xla_device().index def process_index() -> int: From 3a9a776efb3c642d3bdffec5406f93dff9b9b6a2 Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Wed, 21 May 2025 21:49:22 +0000 Subject: [PATCH 2/9] Migrate to use `torch_xla.device()` from `xm.xla_devices()` --- torch_xla/core/xla_model.py | 20 +++----------------- torch_xla/torch_xla.py | 23 +++++++++++++++++++---- 2 files changed, 22 insertions(+), 21 deletions(-) diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py index 9a1d0edc40bd..6dcc5afe7bc7 100644 --- a/torch_xla/core/xla_model.py +++ b/torch_xla/core/xla_model.py @@ -135,6 +135,7 @@ def master_print(*args: Any, print(*args, file=fd, flush=flush) +@deprecated("Use torch_xla.device instead") def xla_device(n: Optional[int] = None, devkind: Optional[str] = None) -> torch.device: """Returns a given instance of an XLA device. @@ -149,23 +150,8 @@ def xla_device(n: Optional[int] = None, Returns: A `torch.device` with the requested instance of an XLA device. """ - # When SPMD is enabled, we always return `xla:0` to the user, and - # under the hood we use virtual device logic for every xla tensor - if xu.check_env_flag('XLA_USE_SPMD'): - device = 'xla:0' - torch_xla._XLAC._xla_set_default_device(device) - return torch.device(device) - - if n is None: - return torch.device(torch_xla._XLAC._xla_get_default_device()) - - devices = xm.get_xla_supported_devices(devkind=devkind) - if n > len(devices): - raise IndexError('Device index {} out of range in {}'.format(n, devices)) - - device = devices[n] - torch_xla._XLAC._xla_set_default_device(device) - return torch.device(device) + del devkind + return torch_xla.device(n) def _xla_real_device(device: torch.device) -> Any: diff --git a/torch_xla/torch_xla.py b/torch_xla/torch_xla.py index e4486a8cd0b5..3b2b327ff5c9 100644 --- a/torch_xla/torch_xla.py +++ b/torch_xla/torch_xla.py @@ -19,18 +19,33 @@ def device(index: int = None) -> torch.device: """Returns a given instance of an XLA device. - If SPMD enables, returns a virtual device that wraps all devices available + If SPMD is enabled, returns a virtual device that wraps all devices available to this process. Args: index: index of the XLA device to be returned. Corresponds to index in - `torch_xla.devices()`. + `torch_xla.devices()`. By default, get the first device. Returns: An XLA `torch.device`. """ - - return xm.xla_device(index) + # When SPMD is enabled, we always return `xla:0` to the user, and + # under the hood we use virtual device logic for every xla tensor + if xu.check_env_flag('XLA_USE_SPMD'): + device = 'xla:0' + torch_xla._XLAC._xla_set_default_device(device) + return torch.device(device) + + if n is None: + return torch.device(torch_xla._XLAC._xla_get_default_device()) + + devices = xm.get_xla_supported_devices() + if n > len(devices): + raise IndexError('Device index {} out of range in {}'.format(n, devices)) + + device = devices[n] + torch_xla._XLAC._xla_set_default_device(device) + return torch.device(device) def devices() -> List[torch.device]: From 079a78c3dbcffae31c35eb5de3a068bb49c5adeb Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Wed, 21 May 2025 21:49:54 +0000 Subject: [PATCH 3/9] Replace existing uses of `xm.xla_devices()` --- API_GUIDE.md | 24 +- README.md | 2 +- benchmarks/benchmark_experiment.py | 2 +- benchmarks/experiment_runner.py | 2 +- benchmarks/matmul_bench.py | 4 +- ...ributed-pytorch-xla-basics-with-pjrt.ipynb | 10 +- .../kaggle/pytorch-xla-2-0-on-kaggle.ipynb | 2 +- docs/source/learn/_pjrt.md | 4 +- docs/source/learn/pytorch-on-xla-devices.md | 22 +- docs/source/learn/troubleshoot.md | 4 +- docs/source/learn/xla-overview.md | 8 +- docs/source/perf/amp.md | 14 +- docs/source/perf/ddp.md | 2 +- docs/source/perf/dynamo.md | 8 +- docs/source/perf/fori_loop.md | 4 +- docs/source/perf/quantized_ops.md | 2 +- docs/source/perf/spmd_basic.md | 2 +- examples/train_resnet_amp.py | 2 +- plugins/cpu/README.md | 2 +- plugins/cuda/README.md | 2 +- test/bench.py | 2 +- test/debug_tool/test_mp_pt_xla_debug.py | 2 +- test/debug_tool/test_pt_xla_debug.py | 18 +- test/distributed_util.py | 2 +- test/ds/test_dynamic_shape_models.py | 2 +- test/ds/test_dynamic_shapes.py | 10 +- test/dynamo/test_bridge.py | 14 +- test/dynamo/test_dynamo.py | 30 +- test/dynamo/test_dynamo_aliasing.py | 16 +- test/dynamo/test_dynamo_graph_dump.py | 2 +- test/dynamo/test_dynamo_integrations_util.py | 12 +- test/dynamo/test_graph_input_matcher.py | 2 +- test/dynamo/test_num_output.py | 2 +- test/dynamo/test_traceable_collectives.py | 2 +- test/metrics_compare_utils_test.py | 2 +- test/neuron/test_neuron_data_types.py | 4 +- test/pjrt/test_collective_ops_tpu.py | 20 +- test/pjrt/test_ddp.py | 2 +- test/pjrt/test_dtypes.py | 6 +- test/pjrt/test_metrics.py | 2 +- test/pjrt/test_profiler.py | 4 +- test/pjrt/test_runtime.py | 2 +- test/pjrt/test_runtime_multi_cpu.py | 8 +- test/pjrt/test_runtime_multi_gpu.py | 12 +- test/pjrt/test_runtime_tpu.py | 12 +- test/pjrt/test_torchrun.py | 9 +- test/pjrt/test_train_hf_transformer.py | 2 +- test/pytorch_test_base.py | 4 +- test/quantized_ops/test_quantized_matmul.py | 2 +- test/scan/test_scan_pallas.py | 2 +- test/spmd/test_dtensor_integration.py | 14 +- test/spmd/test_dtensor_integration2.py | 4 +- test/spmd/test_dynamo_spmd.py | 16 +- test/spmd/test_fsdp_v2.py | 26 +- test/spmd/test_mp_input_sharding.py | 12 +- test/spmd/test_sharding_strategies.py | 2 +- test/spmd/test_spmd_debugging.py | 6 +- test/spmd/test_spmd_graph_dump.py | 2 +- test/spmd/test_spmd_lowering_context.py | 4 +- test/spmd/test_train_spmd_imagenet.py | 10 +- test/spmd/test_xla_auto_sharding.py | 14 +- test/spmd/test_xla_distributed_checkpoint.py | 12 +- test/spmd/test_xla_sharding.py | 170 ++++++------ test/spmd/test_xla_sharding_hlo.py | 4 +- .../test_xla_spmd_python_api_interaction.py | 12 +- test/spmd/test_xla_virtual_device.py | 30 +- test/stablehlo/test_composite.py | 2 +- test/stablehlo/test_implicit_broadcasting.py | 2 +- test/stablehlo/test_pt2e_qdq.py | 4 +- test/stablehlo/test_stablehlo_compile.py | 2 +- test/stablehlo/test_stablehlo_custom_call.py | 6 +- test/stablehlo/test_stablehlo_inference.py | 4 +- test/stablehlo/test_stablehlo_save_load.py | 6 +- test/stablehlo/test_unbounded_dynamism.py | 2 +- test/stablehlo/test_xla_export_interpreter.py | 2 +- test/test_as_stride_use_slice.py | 6 +- test/test_autocast.py | 16 +- test/test_autocast_xla.py | 2 +- test/test_compilation_cache_utils.py | 2 +- test/test_core_aten_ops.py | 4 +- test/test_data_type.py | 4 +- test/test_env_var_mapper.py | 2 +- test/test_fsdp_auto_wrap.py | 4 +- test/test_grad_checkpoint.py | 2 +- test/test_gradient_accumulation.py | 2 +- test/test_hlo_metadata.py | 6 +- test/test_inplace_update.py | 10 +- test/test_input_output_aliases.py | 34 +-- test/test_jax_interop.py | 22 +- test/test_metrics.py | 34 +-- test/test_mp_all_gather.py | 2 +- test/test_mp_all_to_all.py | 2 +- test/test_mp_collective_matmul.py | 2 +- test/test_mp_collective_permute.py | 2 +- test/test_mp_distributed_mm.py | 2 +- test/test_mp_early_exit.py | 2 +- test/test_mp_reduce_scatter.py | 2 +- test/test_mp_replication.py | 2 +- test/test_mp_save.py | 2 +- test/test_mp_sync_batch_norm.py | 8 +- test/test_operations.py | 261 +++++++++--------- test/test_operations_hlo.py | 16 +- test/test_persistent_cache.py | 10 +- test/test_profile_mp_mnist.py | 2 +- test/test_python_ops.py | 6 +- test/test_syncfree_optimizers.py | 2 +- ...st_torch_distributed_fsdp_frozen_weight.py | 4 +- test/test_torch_distributed_xla_backend.py | 26 +- test/test_train_mp_imagenet.py | 2 +- test/test_train_mp_imagenet_amp.py | 4 +- test/test_train_mp_imagenet_fsdp.py | 2 +- test/test_train_mp_mnist.py | 2 +- test/test_train_mp_mnist_amp.py | 2 +- test/test_train_mp_mnist_fsdp_with_ckpt.py | 2 +- test/test_train_mp_mnist_zero1.py | 2 +- test/test_user_computation_debug_cache.py | 2 +- test/test_utils.py | 2 +- test/test_while_loop.py | 8 +- test/test_xla_graph_execution.py | 4 +- test/test_zero1.py | 6 +- test/torch_distributed/test_ddp.py | 2 +- ...orch_distributed_all_gather_xla_backend.py | 2 +- ...orch_distributed_all_reduce_xla_backend.py | 2 +- ...ributed_bucketed_all_reduce_xla_backend.py | 2 +- .../test_torch_distributed_fsdp_meta.py | 12 +- ...istributed_multi_all_reduce_xla_backend.py | 2 +- ..._distributed_reduce_scatter_xla_backend.py | 2 +- test/utils/train_spmd_linear_model.py | 4 +- .../utils/train_spmd_linear_model_grad_acc.py | 4 +- torch_xla/_dynamo/dynamo_bridge.py | 9 +- torch_xla/_internal/pjrt.py | 4 +- torch_xla/_internal/tpu.py | 2 +- torch_xla/amp/syncfree/adam.py | 2 +- torch_xla/amp/syncfree/adamw.py | 2 +- torch_xla/core/xla_op_registry.py | 2 +- torch_xla/csrc/init_python_bindings.cpp | 2 +- .../fsdp/xla_fully_sharded_data_parallel.py | 8 +- torch_xla/distributed/spmd/api.py | 4 +- torch_xla/distributed/spmd/xla_sharding.py | 8 +- torch_xla/distributed/xla_multiprocessing.py | 2 +- torch_xla/experimental/fori_loop.py | 2 +- torch_xla/experimental/scan.py | 2 +- torch_xla/experimental/scan_layers.py | 2 +- .../spmd_fully_sharded_data_parallel.py | 2 +- torch_xla/runtime.py | 4 +- torch_xla/stablehlo.py | 4 +- 146 files changed, 668 insertions(+), 665 deletions(-) diff --git a/API_GUIDE.md b/API_GUIDE.md index 47f0e674b798..bb8895c1774b 100644 --- a/API_GUIDE.md +++ b/API_GUIDE.md @@ -15,14 +15,14 @@ import torch import torch_xla import torch_xla.core.xla_model as xm -t = torch.randn(2, 2, device=xm.xla_device()) +t = torch.randn(2, 2, device=torch_xla.device()) print(t.device) print(t) ``` This code should look familiar. PyTorch/XLA uses the same interface as regular PyTorch with a few additions. Importing `torch_xla` initializes PyTorch/XLA, and -`xm.xla_device()` returns the current XLA device. This may be a CPU or TPU +`torch_xla.device()` returns the current XLA device. This may be a CPU or TPU depending on your environment. ## XLA Tensors are PyTorch Tensors @@ -32,8 +32,8 @@ PyTorch operations can be performed on XLA tensors just like CPU or CUDA tensors For example, XLA tensors can be added together: ```python -t0 = torch.randn(2, 2, device=xm.xla_device()) -t1 = torch.randn(2, 2, device=xm.xla_device()) +t0 = torch.randn(2, 2, device=torch_xla.device()) +t1 = torch.randn(2, 2, device=torch_xla.device()) print(t0 + t1) ``` @@ -46,8 +46,8 @@ print(t0.mm(t1)) Or used with neural network modules: ```python -l_in = torch.randn(10, device=xm.xla_device()) -linear = torch.nn.Linear(10, 20).to(xm.xla_device()) +l_in = torch.randn(10, device=torch_xla.device()) +linear = torch.nn.Linear(10, 20).to(torch_xla.device()) l_out = linear(l_in) print(l_out) ``` @@ -56,7 +56,7 @@ Like other device types, XLA tensors only work with other XLA tensors on the same device. So code like ```python -l_in = torch.randn(10, device=xm.xla_device()) +l_in = torch.randn(10, device=torch_xla.device()) linear = torch.nn.Linear(10, 20) l_out = linear(l_in) print(l_out) @@ -109,10 +109,10 @@ class MNIST(nn.Module): batch_size = 128 train_loader = xu.SampleGenerator( data=(torch.zeros(batch_size, 1, 28, 28), - torch.zeros(batch_size, dtype=torch.int64)), + torch.zeros(batch_size, dtype=torch.int64)), sample_count=60000 // batch_size // xr.world_size()) -device = xm.xla_device() # Get the XLA device (TPU). +device = torch_xla.device() # Get the XLA device (TPU). model = MNIST().train().to(device) # Create a model and move it to the device. loss_fn = nn.NLLLoss() optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5) @@ -169,7 +169,7 @@ def _mp_fn(index): index: Index of the process. """ - device = xm.xla_device() # Get the device assigned to this process. + device = torch_xla.device() # Get the device assigned to this process. # Wrap the loader for multi-device. mp_device_loader = pl.MpDeviceLoader(train_loader, device) @@ -197,7 +197,7 @@ single device snippet. Let's go over then one by one. - `torch_xla.launch()` - Creates the processes that each run an XLA device. - This function is a wrapper of multithreading spawn to allow user run the script with torchrun command line also. Each process will only be able to access the device assigned to the current process. For example on a TPU v4-8, there will be 4 processes being spawn up and each process will own a TPU device. - - Note that if you print the `xm.xla_device()` on each process you will see `xla:0` on all devices. This is because each process can only see one device. This does not mean multi-process is not functioning. The only exeption is with PJRT runtime on TPU v2 and TPU v3 since there will be `#devices/2` processes and each process will have 2 threads (check this [doc](https://github.com/pytorch/xla/blob/master/docs/pjrt.md#tpus-v2v3-vs-v4) for more details). + - Note that if you print the `torch_xla.device()` on each process you will see `xla:0` on all devices. This is because each process can only see one device. This does not mean multi-process is not functioning. The only exeption is with PJRT runtime on TPU v2 and TPU v3 since there will be `#devices/2` processes and each process will have 2 threads (check this [doc](https://github.com/pytorch/xla/blob/master/docs/pjrt.md#tpus-v2v3-vs-v4) for more details). - `MpDeviceLoader` - Loads the training data onto each device. - `MpDeviceLoader` can wrap on a torch dataloader. It can preload the data to the device and overlap the dataloading with device execution to improve the performance. @@ -290,7 +290,7 @@ import torch import torch_xla import torch_xla.core.xla_model as xm -device = xm.xla_device() +device = torch_xla.device() t0 = torch.randn(2, 2, device=device) t1 = torch.randn(2, 2, device=device) diff --git a/README.md b/README.md index d43eac651a5c..a991cfb30342 100644 --- a/README.md +++ b/README.md @@ -196,7 +196,7 @@ If you're using `DistributedDataParallel`, make the following changes: + # Rank and world size are inferred from the XLA device runtime + dist.init_process_group("xla", init_method='xla://') + -+ model.to(xm.xla_device()) ++ model.to(torch_xla.device()) + ddp_model = DDP(model, gradient_as_bucket_view=True) - model = model.to(rank) diff --git a/benchmarks/benchmark_experiment.py b/benchmarks/benchmark_experiment.py index a82490d373b1..e1fab48334a8 100644 --- a/benchmarks/benchmark_experiment.py +++ b/benchmarks/benchmark_experiment.py @@ -208,7 +208,7 @@ def update_process_env(self, process_env: Dict[str, str]): def get_device(self): if self.torch_xla2: # Initiate the model in CPU first for xla2. We will move the model to jax device later. - # This is because we don't have xm.xla_device() function in torch_xla2. + # This is because we don't have torch_xla.device() function in torch_xla2. return torch.device("cpu") if self.xla: return xm.xla_device(devkind=self.accelerator.upper()) diff --git a/benchmarks/experiment_runner.py b/benchmarks/experiment_runner.py index a3b09c7cd7e0..b784af68e47b 100644 --- a/benchmarks/experiment_runner.py +++ b/benchmarks/experiment_runner.py @@ -255,7 +255,7 @@ def _default_iter_fn(self, benchmark_experiment: BenchmarkExperiment, def _pure_wall_time_iter_fn(self, benchmark_experiment: BenchmarkExperiment, benchmark_model: BenchmarkModel, input_tensor): - device = xm.xla_device() if benchmark_experiment.xla else 'cuda' + device = torch_xla.device() if benchmark_experiment.xla else 'cuda' sync_fn = xm.wait_device_ops if benchmark_experiment.xla else torch.cuda.synchronize timing, output = bench.do_bench( lambda: benchmark_model.model_iter_fn( diff --git a/benchmarks/matmul_bench.py b/benchmarks/matmul_bench.py index af518f355ca2..661963bc358c 100644 --- a/benchmarks/matmul_bench.py +++ b/benchmarks/matmul_bench.py @@ -42,7 +42,7 @@ def main(): fn, return_mode='min', sync_fn=lambda: xm.wait_device_ops(), - device=xm.xla_device()) + device=torch_xla.device()) ind_bench_fn = lambda fn: do_bench( fn, return_mode='min', @@ -53,7 +53,7 @@ def main(): for dtype in dtypes: for inductor_matmul, xla_matmul in zip( get_matmuls(device='cuda', dtype=dtype, backend='inductor'), - get_matmuls(device=xm.xla_device(), dtype=dtype, backend='openxla')): + get_matmuls(device=torch_xla.device(), dtype=dtype, backend='openxla')): ind_lhs_shape, ind_rhs_shape, ind_fn = inductor_matmul xla_lhs_shape, xla_rhs_shape, xla_fn = xla_matmul assert ind_lhs_shape == xla_lhs_shape, f"Expect matmul shapes to match for benchmarking. Mismatch lhs: {ind_lhs_shape}, rhs: {xla_rhs_shape}" diff --git a/contrib/kaggle/distributed-pytorch-xla-basics-with-pjrt.ipynb b/contrib/kaggle/distributed-pytorch-xla-basics-with-pjrt.ipynb index d4d676f745e5..6f0f06d7c146 100644 --- a/contrib/kaggle/distributed-pytorch-xla-basics-with-pjrt.ipynb +++ b/contrib/kaggle/distributed-pytorch-xla-basics-with-pjrt.ipynb @@ -188,7 +188,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "To get the current process/thread's default XLA device, use `xm.xla_device()`. XLA devices are numbered as `xla:i`, where `i` is the index of the device within the current process. Since each process has two devices on a TPU v3, this will be `xla:0` or `xla:1`." + "To get the current process/thread's default XLA device, use `torch_xla.device()`. XLA devices are numbered as `xla:i`, where `i` is the index of the device within the current process. Since each process has two devices on a TPU v3, this will be `xla:0` or `xla:1`." ] }, { @@ -210,7 +210,7 @@ "lock = mp.Manager().Lock()\n", "\n", "def print_device(i, lock):\n", - " device = xm.xla_device()\n", + " device = torch_xla.device()\n", " with lock:\n", " print('process', i, device)" ] @@ -318,7 +318,7 @@ ], "source": [ "def add_ones(i, lock):\n", - " x = torch.ones((3, 3), device=xm.xla_device())\n", + " x = torch.ones((3, 3), device=torch_xla.device())\n", " y = x + x\n", " \n", " # Run graph to compute `y` before printing\n", @@ -378,7 +378,7 @@ "source": [ "def gather_ids(i, lock):\n", " # Create a tensor on each device with the device ID\n", - " t = torch.tensor([i], device=xm.xla_device())\n", + " t = torch.tensor([i], device=torch_xla.device())\n", " with lock:\n", " print(i, t)\n", " \n", @@ -454,7 +454,7 @@ "import torch_xla.experimental.pjrt_backend # Required for torch.distributed on TPU v2 and v3\n", "\n", "def toy_model(index, lock):\n", - " device = xm.xla_device()\n", + " device = torch_xla.device()\n", " dist.init_process_group('xla', init_method='xla://')\n", "\n", " # Initialize a basic toy model\n", diff --git a/contrib/kaggle/pytorch-xla-2-0-on-kaggle.ipynb b/contrib/kaggle/pytorch-xla-2-0-on-kaggle.ipynb index a2c2f0d099e2..a0c5d6d3d769 100644 --- a/contrib/kaggle/pytorch-xla-2-0-on-kaggle.ipynb +++ b/contrib/kaggle/pytorch-xla-2-0-on-kaggle.ipynb @@ -172,7 +172,7 @@ "\n", "pipeline = DiffusionPipeline.from_pretrained(\"runwayml/stable-diffusion-v1-5\")\n", "# Move the model to the first TPU core\n", - "pipeline = pipeline.to(xm.xla_device())" + "pipeline = pipeline.to(torch_xla.device())" ] }, { diff --git a/docs/source/learn/_pjrt.md b/docs/source/learn/_pjrt.md index 5531ce8824a0..91917b115a5d 100644 --- a/docs/source/learn/_pjrt.md +++ b/docs/source/learn/_pjrt.md @@ -73,7 +73,7 @@ import torch_xla.distributed.xla_backend def _mp_fn(index): - device = xm.xla_device() + device = torch_xla.device() - dist.init_process_group('xla', rank=xr.global_ordinal(), world_size=xr.world_size()) + dist.init_process_group('xla', init_method='xla://') @@ -377,7 +377,7 @@ def _all_gather(index: int): # No need to pass in `rank` or `world_size` dist.init_process_group('xla', init_method='xla://') - t = torch.tensor([index], dtype=torch.int32, device=xm.xla_device()) + t = torch.tensor([index], dtype=torch.int32, device=torch_xla.device()) output = [torch.zeros_like(t) for _ in range(dist.get_world_size())] dist.all_gather(output, t) diff --git a/docs/source/learn/pytorch-on-xla-devices.md b/docs/source/learn/pytorch-on-xla-devices.md index de3cd3c69409..25328f0ab0a0 100644 --- a/docs/source/learn/pytorch-on-xla-devices.md +++ b/docs/source/learn/pytorch-on-xla-devices.md @@ -14,14 +14,14 @@ import torch import torch_xla import torch_xla.core.xla_model as xm -t = torch.randn(2, 2, device=xm.xla_device()) +t = torch.randn(2, 2, device=torch_xla.device()) print(t.device) print(t) ``` This code should look familiar. PyTorch/XLA uses the same interface as regular PyTorch with a few additions. Importing `torch_xla` initializes -PyTorch/XLA, and `xm.xla_device()` returns the current XLA device. This +PyTorch/XLA, and `torch_xla.device()` returns the current XLA device. This may be a CPU or TPU depending on your environment. ## XLA Tensors are PyTorch Tensors @@ -32,8 +32,8 @@ tensors. For example, XLA tensors can be added together: ``` python -t0 = torch.randn(2, 2, device=xm.xla_device()) -t1 = torch.randn(2, 2, device=xm.xla_device()) +t0 = torch.randn(2, 2, device=torch_xla.device()) +t1 = torch.randn(2, 2, device=torch_xla.device()) print(t0 + t1) ``` @@ -46,8 +46,8 @@ print(t0.mm(t1)) Or used with neural network modules: ``` python -l_in = torch.randn(10, device=xm.xla_device()) -linear = torch.nn.Linear(10, 20).to(xm.xla_device()) +l_in = torch.randn(10, device=torch_xla.device()) +linear = torch.nn.Linear(10, 20).to(torch_xla.device()) l_out = linear(l_in) print(l_out) ``` @@ -56,7 +56,7 @@ Like other device types, XLA tensors only work with other XLA tensors on the same device. So code like ``` python -l_in = torch.randn(10, device=xm.xla_device()) +l_in = torch.randn(10, device=torch_xla.device()) linear = torch.nn.Linear(10, 20) l_out = linear(l_in) print(l_out) @@ -81,7 +81,7 @@ The following snippet shows a network training on a single XLA device: ``` python import torch_xla.core.xla_model as xm -device = xm.xla_device() +device = torch_xla.device() model = MNIST().train().to(device) loss_fn = nn.NLLLoss() optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum) @@ -120,7 +120,7 @@ import torch_xla.core.xla_model as xm import torch_xla.distributed.parallel_loader as pl def _mp_fn(index): - device = xm.xla_device() + device = torch_xla.device() mp_device_loader = pl.MpDeviceLoader(train_loader, device) model = MNIST().train().to(device) @@ -148,7 +148,7 @@ previous single device snippet. Let's go over then one by one. will only be able to access the device assigned to the current process. For example on a TPU v4-8, there will be 4 processes being spawn up and each process will own a TPU device. - - Note that if you print the `xm.xla_device()` on each process you + - Note that if you print the `torch_xla.device()` on each process you will see `xla:0` on all devices. This is because each process can only see one device. This does not mean multi-process is not functioning. The only execution is with PJRT runtime on TPU v2 @@ -283,7 +283,7 @@ import torch import torch_xla import torch_xla.core.xla_model as xm -device = xm.xla_device() +device = torch_xla.device() t0 = torch.randn(2, 2, device=device) t1 = torch.randn(2, 2, device=device) diff --git a/docs/source/learn/troubleshoot.md b/docs/source/learn/troubleshoot.md index fab620a22110..67497bfa5f09 100644 --- a/docs/source/learn/troubleshoot.md +++ b/docs/source/learn/troubleshoot.md @@ -32,8 +32,8 @@ vm:~$ export PJRT_DEVICE=TPU vm:~$ python3 >>> import torch >>> import torch_xla.core.xla_model as xm ->>> t1 = torch.tensor(100, device=xm.xla_device()) ->>> t2 = torch.tensor(200, device=xm.xla_device()) +>>> t1 = torch.tensor(100, device=torch_xla.device()) +>>> t2 = torch.tensor(200, device=torch_xla.device()) >>> print(t1 + t2) tensor(300, device='xla:0') ``` diff --git a/docs/source/learn/xla-overview.md b/docs/source/learn/xla-overview.md index 987d7b9629f2..f6b0761fd69a 100644 --- a/docs/source/learn/xla-overview.md +++ b/docs/source/learn/xla-overview.md @@ -184,7 +184,7 @@ repo. contains examples for training and serving many LLM and diffusion models. General guidelines to modify your code: -- Replace `cuda` with `xm.xla_device()` +- Replace `cuda` with `torch_xla.device()` - Remove progress bar, printing that would access the XLA tensor values - Reduce logging and callbacks that would access the XLA tensor values @@ -227,7 +227,7 @@ tutorial, but you can pass the `device` value to the function as well. ``` python import torch_xla.core.xla_model as xm - self.device = xm.xla_device() + self.device = torch_xla.device() ``` Another place in the code that has cuda specific code is DDIM scheduler. @@ -244,7 +244,7 @@ if attr.device != torch.device("cuda"): with ``` python -device = xm.xla_device() +device = torch_xla.device() attr = attr.to(torch.device(device)) ``` @@ -339,7 +339,7 @@ with the following lines: ``` python import torch_xla.core.xla_model as xm -device = xm.xla_device() +device = torch_xla.device() pipe.to(device) ``` diff --git a/docs/source/perf/amp.md b/docs/source/perf/amp.md index 0b2ef1fc428d..0832c14d3fc2 100644 --- a/docs/source/perf/amp.md +++ b/docs/source/perf/amp.md @@ -19,7 +19,7 @@ from torch_xla.amp import syncfree import torch_xla.core.xla_model as xm # Creates model and optimizer in default precision -model = Net().to(xm.xla_device()) +model = Net().to(torch_xla.device()) # Pytorch/XLA provides sync-free optimizers for improved performance optimizer = syncfree.SGD(model.parameters(), ...) @@ -27,7 +27,7 @@ for input, target in data: optimizer.zero_grad() # Enables autocasting for the forward pass - with autocast(xm.xla_device()): + with autocast(torch_xla.device()): output = model(input) loss = loss_fn(output, target) @@ -36,7 +36,7 @@ for input, target in data: xm.optimizer_step.(optimizer) ``` -`autocast(xm.xla_device())` aliases `torch.autocast('xla')` when the XLA +`autocast(torch_xla.device())` aliases `torch.autocast('xla')` when the XLA Device is a TPU. Alternatively, if a script is only used with TPUs, then `torch.autocast('xla', dtype=torch.bfloat16)` can be directly used. @@ -106,7 +106,7 @@ from torch_xla.amp import syncfree import torch_xla.core.xla_model as xm # Creates model and optimizer in default precision -model = Net().to(xm.xla_device()) +model = Net().to(torch_xla.device()) # Pytorch/XLA provides sync-free optimizers for improved performance optimizer = syncfree.SGD(model.parameters(), ...) scaler = GradScaler() @@ -115,7 +115,7 @@ for input, target in data: optimizer.zero_grad() # Enables autocasting for the forward pass - with autocast(xm.xla_device()): + with autocast(torch_xla.device()): output = model(input) loss = loss_fn(output, target) @@ -127,12 +127,12 @@ for input, target in data: scaler.update() ``` -`autocast(xm.xla_device())` aliases `torch.cuda.amp.autocast()` when the +`autocast(torch_xla.device())` aliases `torch.cuda.amp.autocast()` when the XLA Device is a CUDA device (XLA:GPU). Alternatively, if a script is only used with CUDA devices, then `torch.cuda.amp.autocast` can be directly used, but requires `torch` is compiled with `cuda` support for datatype of `torch.bfloat16`. We recommend using -`autocast(xm.xla_device())` on XLA:GPU as it does not require +`autocast(torch_xla.device())` on XLA:GPU as it does not require `torch.cuda` support for any datatypes, including `torch.bfloat16`. ### AMP for XLA:GPU Best Practices diff --git a/docs/source/perf/ddp.md b/docs/source/perf/ddp.md index eade410cdb58..efc4071d648d 100644 --- a/docs/source/perf/ddp.md +++ b/docs/source/perf/ddp.md @@ -105,7 +105,7 @@ def demo_basic(rank): setup(rank, world_size) # create model and move it to XLA device - device = xm.xla_device() + device = torch_xla.device() model = ToyModel().to(device) ddp_model = DDP(model, gradient_as_bucket_view=True) diff --git a/docs/source/perf/dynamo.md b/docs/source/perf/dynamo.md index a34a162d89bc..2ef86a2a5f9d 100644 --- a/docs/source/perf/dynamo.md +++ b/docs/source/perf/dynamo.md @@ -23,8 +23,8 @@ import torch import torch_xla.core.xla_model as xm def add(a, b): - a_xla = a.to(xm.xla_device()) - b_xla = b.to(xm.xla_device()) + a_xla = a.to(torch_xla.device()) + b_xla = b.to(torch_xla.device()) return a_xla + b_xla compiled_code = torch.compile(add, backend='openxla') @@ -41,7 +41,7 @@ import torchvision import torch_xla.core.xla_model as xm def eval_model(loader): - device = xm.xla_device() + device = torch_xla.device() xla_resnet18 = torchvision.models.resnet18().to(device) xla_resnet18.eval() dynamo_resnet18 = torch.compile( @@ -129,7 +129,7 @@ def train_model(model, data, target, optimizer): return pred def train_model_main(loader): - device = xm.xla_device() + device = torch_xla.device() xla_resnet18 = torchvision.models.resnet18().to(device) xla_resnet18.train() dynamo_train_model = torch.compile( diff --git a/docs/source/perf/fori_loop.md b/docs/source/perf/fori_loop.md index 1fb80be8a3f9..bfdd2bf318ab 100644 --- a/docs/source/perf/fori_loop.md +++ b/docs/source/perf/fori_loop.md @@ -30,7 +30,7 @@ result = while_loop(cond_fn, body_fn, init) >>> from torch._higher_order_ops.while_loop import while_loop >>> import torch_xla.core.xla_model as xm >>> ->>> device = xm.xla_device() +>>> device = torch_xla.device() >>> >>> def cond_fn(iteri, x): ... return iteri > 0 @@ -60,7 +60,7 @@ with similar logic: cumulative plus 1 for ten times: >>> import torch_xla >>> import torch_xla.core.xla_model as xm >>> ->>> device = xm.xla_device() +>>> device = torch_xla.device() >>> >>> init_val = torch.tensor(1, device=device) >>> iteri = torch.tensor(50, device=device) diff --git a/docs/source/perf/quantized_ops.md b/docs/source/perf/quantized_ops.md index 41e37f6a0708..6d44b05e433b 100644 --- a/docs/source/perf/quantized_ops.md +++ b/docs/source/perf/quantized_ops.md @@ -48,7 +48,7 @@ scaler = torch.randn((N_OUTPUT_FEATURES,), dtype=torch.bfloat16) # Call with torch CPU tensor (For debugging purpose) matmul_output = torch.ops.xla.quantized_matmul(x, w_int, scaler) -device = xm.xla_device() +device = torch_xla.device() x_xla = x.to(device) w_int_xla = w_int.to(device) scaler_xla = scaler.to(device) diff --git a/docs/source/perf/spmd_basic.md b/docs/source/perf/spmd_basic.md index 182533babcd3..4f343106529d 100644 --- a/docs/source/perf/spmd_basic.md +++ b/docs/source/perf/spmd_basic.md @@ -41,7 +41,7 @@ mesh_shape = (num_devices, 1) device_ids = np.array(range(num_devices)) mesh = Mesh(device_ids, mesh_shape, ('data', 'model')) -t = torch.randn(8, 4).to(xm.xla_device()) +t = torch.randn(8, 4).to(torch_xla.device()) # Mesh partitioning, each device holds 1/8-th of the input partition_spec = ('data', 'model') diff --git a/examples/train_resnet_amp.py b/examples/train_resnet_amp.py index ae541705d717..8082d01524e9 100644 --- a/examples/train_resnet_amp.py +++ b/examples/train_resnet_amp.py @@ -17,7 +17,7 @@ def train_loop_fn(self, loader, epoch): for step, (data, target) in enumerate(loader): self.optimizer.zero_grad() # Enables autocasting for the forward pass - with autocast(xm.xla_device()): + with autocast(torch_xla.device()): output = self.model(data) loss = self.loss_fn(output, target) # TPU amp uses bf16 hence gradient scaling is not necessary. If runnign with XLA:GPU diff --git a/plugins/cpu/README.md b/plugins/cpu/README.md index c46a315e31c3..76c9d0b7c88e 100644 --- a/plugins/cpu/README.md +++ b/plugins/cpu/README.md @@ -38,5 +38,5 @@ plugins.use_dynamic_plugins() plugins.register_plugin('CPU', torch_xla_cpu_plugin.CpuPlugin()) xr.set_device_type('CPU') -print(xm.xla_device()) +print(torch_xla.device()) ``` diff --git a/plugins/cuda/README.md b/plugins/cuda/README.md index f86caf8e0d27..45a002e06f6c 100644 --- a/plugins/cuda/README.md +++ b/plugins/cuda/README.md @@ -35,5 +35,5 @@ plugins.use_dynamic_plugins() plugins.register_plugin('CUDA', torch_xla_cuda_plugin.CudaPlugin()) xr.set_device_type('CUDA') -print(xm.xla_device()) +print(torch_xla.device()) ``` diff --git a/test/bench.py b/test/bench.py index 97fb24f0a5ae..e5eff86a34d5 100644 --- a/test/bench.py +++ b/test/bench.py @@ -29,7 +29,7 @@ class BaseBench(object): def __init__(self, args): self.args = args - self.device = xm.xla_device() + self.device = torch_xla.device() self.test_time = xu.getenv_as('BENCH_TEST_TIME', float, 5.0) torch.manual_seed(42) diff --git a/test/debug_tool/test_mp_pt_xla_debug.py b/test/debug_tool/test_mp_pt_xla_debug.py index 45a9502af795..785554657b14 100644 --- a/test/debug_tool/test_mp_pt_xla_debug.py +++ b/test/debug_tool/test_mp_pt_xla_debug.py @@ -16,7 +16,7 @@ def _mp_fn(index): assert False, "This test should be run with PT_XLA_DEBUG_FILE" if index == 0: open(debug_file_name, 'w').close() - device = xm.xla_device() + device = torch_xla.device() t1 = torch.randn(10, 10, device=device) t2 = t1 * 100 torch_xla.sync() diff --git a/test/debug_tool/test_pt_xla_debug.py b/test/debug_tool/test_pt_xla_debug.py index 57864cd74657..4ebcb2cd1bb9 100644 --- a/test/debug_tool/test_pt_xla_debug.py +++ b/test/debug_tool/test_pt_xla_debug.py @@ -31,7 +31,7 @@ def setUpClass(cls): def test_eager_sync(self): with torch_xla.experimental.eager_mode_context(True): - device = xm.xla_device() + device = torch_xla.device() t1 = torch.randn(5, 9, device=device) torch_xla.sync() with open(self.debug_file_name, 'rb') as f: @@ -41,7 +41,7 @@ def test_eager_sync(self): open(self.debug_file_name, 'w').close() def test_user_sync(self): - device = xm.xla_device() + device = torch_xla.device() t1 = torch.randn(2, 2, device=device) torch_xla.sync() with open(self.debug_file_name, 'rb') as f: @@ -79,7 +79,7 @@ def test_user_sync(self): open(self.debug_file_name, 'w').close() def test_step_trace(self): - device = xm.xla_device() + device = torch_xla.device() with xp.StepTrace('train_pt_xla_debug'): t1 = torch.randn(3, 3, device=device) with open(self.debug_file_name, 'rb') as f: @@ -111,7 +111,7 @@ def test_step_trace(self): open(self.debug_file_name, 'w').close() def test_dynamo(self): - device = xm.xla_device() + device = torch_xla.device() t1 = torch.randn(4, 4, device=device) def toy_program(t1): @@ -161,7 +161,7 @@ def toy_program(t1): open(self.debug_file_name, 'w').close() def test_torch_xla_compile(self): - device = xm.xla_device() + device = torch_xla.device() t1 = torch.randn(12, 4, device=device) def toy_program(t1): @@ -209,7 +209,7 @@ def toy_program(t1): open(self.debug_file_name, 'w').close() def test_torch_xla_compile_custom_name(self): - device = xm.xla_device() + device = torch_xla.device() t1 = torch.randn(18, 4, device=device) def toy_program2(t1): @@ -239,7 +239,7 @@ def toy_program2(t1): open(self.debug_file_name, 'w').close() def test_parallel_loader(self): - device = xm.xla_device() + device = torch_xla.device() train_dataset_len = 100 batch_size = 10 @@ -287,7 +287,7 @@ def test_parallel_loader(self): open(self.debug_file_name, 'w').close() def test_print(self): - device = xm.xla_device() + device = torch_xla.device() t1 = torch.randn(5, 5, device=device) print(t1) with open(self.debug_file_name, 'rb') as f: @@ -315,7 +315,7 @@ def test_print(self): open(self.debug_file_name, 'w').close() def test_frame(self): - device = xm.xla_device() + device = torch_xla.device() t1 = torch.randn(6, 6, device=device) torch_xla.sync() with open(self.debug_file_name, 'rb') as f: diff --git a/test/distributed_util.py b/test/distributed_util.py index 040ea451b4e2..85069aaabc82 100644 --- a/test/distributed_util.py +++ b/test/distributed_util.py @@ -101,7 +101,7 @@ def ddp_correctness(init_method: str = 'env://', dist.init_process_group("xla", init_method=init_method) rank, world_size = dist.get_rank(), dist.get_world_size() - device = xm.xla_device() + device = torch_xla.device() # Module initialization is not thread safe. Force threads to initialize one # at a time with the same seed diff --git a/test/ds/test_dynamic_shape_models.py b/test/ds/test_dynamic_shape_models.py index b6d06ea65f7f..2c1c827e7fac 100644 --- a/test/ds/test_dynamic_shape_models.py +++ b/test/ds/test_dynamic_shape_models.py @@ -17,7 +17,7 @@ # It enables us to run python implementations of CompositeAutogradImplicit ops. # CompositeAutogradImplicit means we don't have an explicit backward formula for an op instead an op is composed of a bunch of ops that do have backward formulas and combines this formulas is equivalent to differentiating the op explicitly. pd = torch._C._EnablePythonDispatcher() -xla_dev = xm.xla_device() +xla_dev = torch_xla.device() class Feedforward(torch.nn.Module): diff --git a/test/ds/test_dynamic_shapes.py b/test/ds/test_dynamic_shapes.py index 1c1a62e2724c..3d1f8bb28fbd 100644 --- a/test/ds/test_dynamic_shapes.py +++ b/test/ds/test_dynamic_shapes.py @@ -9,7 +9,7 @@ import test_utils pd = torch._C._EnablePythonDispatcher() -dev = xm.xla_device() +dev = torch_xla.device() class TestDynamicShapes(test_utils.XlaTestCase): @@ -163,7 +163,7 @@ def test_t_copy(self): self.assertEqual(t2_t.shape[1], 7) def test_nonzero_shape(self): - x = torch.tensor((0, 1, 2, 0, 3, 4), device=xm.xla_device()) + x = torch.tensor((0, 1, 2, 0, 3, 4), device=torch_xla.device()) x_dim0_shape = torch_xla._XLAC._get_xla_tensor_dimension_size( torch.nonzero(x, as_tuple=False), 0) self.assertEqual(x_dim0_shape.item(), 4) @@ -176,14 +176,14 @@ def test_nonzero_correctness(self): self.assertEqual(t2.cpu(), t2_aten) def test_masked_select_shape(self): - x = torch.tensor((0, 1, 2, 0, 3, 4), device=xm.xla_device()) + x = torch.tensor((0, 1, 2, 0, 3, 4), device=torch_xla.device()) mask = x.ge(2) x_dim0_shape = torch_xla._XLAC._get_xla_tensor_dimension_size( torch.masked_select(x, mask), 0) self.assertEqual(x_dim0_shape.item(), 3) def test_nonzero_cast(self): - t1 = torch.ones(5, 2, device=xm.xla_device()) + t1 = torch.ones(5, 2, device=torch_xla.device()) # Result of the nonzero should be the index type. Currently # index type is s64 on cpu and gpu, but s32 on TPU. We should be # able to cast it to any other type without error. @@ -191,7 +191,7 @@ def test_nonzero_cast(self): torch_xla.sync() def test_expand_symint_correctness(self): - dev = xm.xla_device() + dev = torch_xla.device() size1 = 5 size2 = 2 t1 = torch.ones([size1, size2]) diff --git a/test/dynamo/test_bridge.py b/test/dynamo/test_bridge.py index 52c9cbf1053b..da36031759a2 100644 --- a/test/dynamo/test_bridge.py +++ b/test/dynamo/test_bridge.py @@ -116,7 +116,7 @@ def unwrap(cont): def make_reuse_graph_test(module_class, niter=100): def test_wrapper(self): - xla_dev = xm.xla_device() + xla_dev = torch_xla.device() xla_module = module_class().to(device=xla_dev) inputs = tuple(x.to(device=xla_dev) for x in xla_module.get_random_inputs()) metrics.clear_counters() @@ -187,7 +187,7 @@ def make_training_test(model_cls): def test_wrapper(self): import torch_xla.core.xla_model as xm - xla_dev = xm.xla_device() + xla_dev = torch_xla.device() model = model_cls() inputs = model.get_random_inputs() @@ -240,7 +240,7 @@ class Emb(torch.nn.Embedding): def __init__(self): super().__init__(num_embeddings=10, embedding_dim=10, padding_idx=0) - device = xm.xla_device() + device = torch_xla.device() module = Emb() module.to(device) @@ -255,7 +255,7 @@ def test_inputs_not_computed(self): def foo(x): return x * 2 - device = xm.xla_device() + device = torch_xla.device() x = torch.rand(5, device=device) x = x.unsqueeze(dim=-1) self._compile_and_check(foo, (x,)) @@ -265,7 +265,7 @@ def test_factory_copy(self): def foo(device): return torch.arange(5, device="cpu").to(device) - self._compile_and_check(foo, (xm.xla_device(),)) + self._compile_and_check(foo, (torch_xla.device(),)) def test_index_flag_unsupported(self): # The indices of the index operation are represented as @@ -277,7 +277,7 @@ def test_index_flag_unsupported(self): def foo(xt, t): return xt[t] - device = xm.xla_device() + device = torch_xla.device() xt = torch.rand(5, device=device) t = torch.randint(0, 5, (3,)) self._compile_and_check(foo, (xt, t)) @@ -299,7 +299,7 @@ def test_cpu_flag_unsupported(self): def foo(t): return t.cpu() - device = xm.xla_device() + device = torch_xla.device() t = torch.randint(0, 5, (3,), device=device) self._compile_and_check(foo, (t,)) diff --git a/test/dynamo/test_dynamo.py b/test/dynamo/test_dynamo.py index 0bf4b9110f2a..01b8085e02dd 100644 --- a/test/dynamo/test_dynamo.py +++ b/test/dynamo/test_dynamo.py @@ -49,7 +49,7 @@ def inplace_update(self, a): def test_inplace_update_correctness(self, backend): dynamo_inplace = torch.compile( self.inplace_update, backend=backend, fullgraph=True) - t = torch.tensor([0, 1, 2], device=xm.xla_device()) + t = torch.tensor([0, 1, 2], device=torch_xla.device()) for i in range(10): t = dynamo_inplace(t) self.assertTrue(torch.all(torch.eq(t.cpu(), torch.tensor([10, 11, 12])))) @@ -66,7 +66,7 @@ def test_random_op_different_result_each_run(self, backend): met.clear_all() dynamo_random_op = torch.compile( self.random_op, backend=backend, fullgraph=True) - t = torch.randn(5, 5).to(xm.xla_device()) + t = torch.randn(5, 5).to(torch_xla.device()) dynamo_res_1 = dynamo_random_op(t) dynamo_res_2 = dynamo_random_op(t) dynamo_res_3 = dynamo_random_op(t) @@ -89,7 +89,7 @@ def test_sync_after_dynamo(self): head_dim = 128 running = 16 - device = xm.xla_device() + device = torch_xla.device() cache = torch.rand((cache_len, kv_heads, head_dim)).to(device) update_indices = torch.randint( 0, cache_len, (running,), dtype=torch.long).to(device) @@ -131,7 +131,7 @@ def dummy_fn(self, a): def test_dynamo_with_trace(self): dynamo_dummy = torch.compile( self.dummy_fn, backend="openxla", fullgraph=True) - t = torch.randn(2, 3, 4, device=xm.xla_device()) + t = torch.randn(2, 3, 4, device=torch_xla.device()) for i in range(10): with xp.Trace('build_graph'): t = dynamo_dummy(t) @@ -150,7 +150,7 @@ def fn_simple(self, x, y): def _choose_proper_device(self, initialize_on_cuda): if not initialize_on_cuda: - return xm.xla_device() + return torch_xla.device() assert initialize_on_cuda if xr.device_type() != "CUDA" or not torch.cuda.is_available(): @@ -164,7 +164,7 @@ def _choose_proper_device(self, initialize_on_cuda): @skipOnNeuron def test_simple_model(self): - device = xm.xla_device() + device = torch_xla.device() x = torch.tensor(100.0) y = torch.tensor(200.0) xla_x = x.to(device) @@ -448,7 +448,7 @@ def fn_fallback(t): torch._dynamo.reset() met.clear_all() - device = xm.xla_device() + device = torch_xla.device() # Initial tracing dynamo_fn = torch.compile(fn_fallback, backend="openxla") @@ -488,7 +488,7 @@ def fn_fallback(t): torch._dynamo.reset() met.clear_all() - device = xm.xla_device() + device = torch_xla.device() # Initial tracing dynamo_fn = torch.compile(fn_fallback, backend="openxla") @@ -541,7 +541,7 @@ def train_model(self, model, data, target): def test_simple_model(self): torch._dynamo.reset() - device = xm.xla_device() + device = torch_xla.device() input = torch.randn(3, 5, requires_grad=True) xla_input = input.detach().to(device) xla_input.requires_grad = True @@ -577,7 +577,7 @@ def test_simple_model(self): def test_resnet18(self): torch._dynamo.reset() met.clear_counters() - device = xm.xla_device() + device = torch_xla.device() batch_size = xu.getenv_as('BATCH_SIZE', int, defval=4) sample_count = xu.getenv_as('SAMPLE_COUNT', int, defval=10) loader = xu.SampleGenerator( @@ -650,7 +650,7 @@ def train_model(self, model, data, target, optimizer): def test_simple_model(self): torch._dynamo.reset() - device = xm.xla_device() + device = torch_xla.device() input = torch.randn(3, 5, requires_grad=True) saved_input = input.detach().to(device).cpu() xla_input = input.detach().to(device) @@ -673,7 +673,7 @@ def test_simple_model(self): def test_resnet18(self): torch._dynamo.reset() met.clear_counters() - device = xm.xla_device() + device = torch_xla.device() batch_size = xu.getenv_as('BATCH_SIZE', int, defval=4) sample_count = xu.getenv_as('SAMPLE_COUNT', int, defval=10) loader = xu.SampleGenerator( @@ -732,7 +732,7 @@ def test_resnet18(self): class DynamoErrorMessageTest(parameterized.TestCase): def test_mixed_cpu_tensor(self): - device = xm.xla_device() + device = torch_xla.device() input = torch.randn(4, 3, 224, 224) input_xla = input.clone().to(device) resnet18 = torchvision.models.resnet18() @@ -783,7 +783,7 @@ def foo(x): optfoo = torch.compile(backend=backend)(foo) t = torch.arange(9) - Xt = t.to(xm.xla_device()) + Xt = t.to(torch_xla.device()) expected = foo(t) actual = optfoo(Xt).cpu() @@ -803,7 +803,7 @@ def foo(x): optfoo = torch.compile(backend=backend)(foo) t = torch.arange(10) - Xt = t.to(xm.xla_device()) + Xt = t.to(torch_xla.device()) expected = foo(t) actual = optfoo(Xt) diff --git a/test/dynamo/test_dynamo_aliasing.py b/test/dynamo/test_dynamo_aliasing.py index a28567f42c25..36bfb5744bd4 100644 --- a/test/dynamo/test_dynamo_aliasing.py +++ b/test/dynamo/test_dynamo_aliasing.py @@ -11,7 +11,7 @@ class TestBufferDonationUtil(unittest.TestCase): def test_hash_with_buffer_donor(self): - device = xm.xla_device() + device = torch_xla.device() input = torch.randn(5, 5).to(device) res = torch.cos(input) hash_no_donor = torch_xla._XLAC._get_graph_hash([res]) @@ -40,7 +40,7 @@ def dummy_mul(self, input): return input * 1.1 def test_manual_buffer_donation(self): - device = xm.xla_device() + device = torch_xla.device() input = torch.randn(5, 5).to(device) input_cloned = input.cpu().to(device) dummy_inplace_mul_compiled = torch.compile( @@ -55,7 +55,7 @@ def test_manual_buffer_donation(self): torch.allclose(input_cloned.cpu() * 1.1, input.cpu()) def test_manual_buffer_donation_for_non_inplce_op(self): - device = xm.xla_device() + device = torch_xla.device() input = torch.randn(5, 5).to(device) input_cloned = input.cpu().to(device) dummy_mul_compiled = torch.compile(self.dummy_mul, backend='openxla') @@ -81,7 +81,7 @@ def dummy_inplace(input): torch.ops.xla.dynamo_set_buffer_donor_(input, True) input += (0.5 * torch.sin(input)) - device = xm.xla_device() + device = torch_xla.device() input = torch.randn(5, 5).to(device) input_cloned = input.cpu().to(device) dummy_inplace_add_compiled = torch.compile(dummy_inplace, backend='openxla') @@ -109,7 +109,7 @@ def dummy_add(self, input): return input + 1 def test_manual_buffer_donation(self): - device = xm.xla_device() + device = torch_xla.device() input = torch.randn(5, 5).to(device) input_cloned = input.cpu().to(device) dummy_inplace_add_compiled = torch.compile( @@ -127,7 +127,7 @@ def test_manual_buffer_donation(self): self.assertFalse(torch_xla._XLAC._get_buffer_donation(input)) def test_manual_buffer_donation_for_non_inplce_op(self): - device = xm.xla_device() + device = torch_xla.device() input = torch.randn(5, 5).to(device) input_cloned = input.cpu().to(device) dummy_add_compiled = torch.compile(self.dummy_add, backend='openxla') @@ -152,7 +152,7 @@ def test_manual_buffer_donation_for_inplce_op_repeat(self): def dummy_inplace(input): input += (0.3 * torch.cos(input)) - device = xm.xla_device() + device = torch_xla.device() input = torch.randn(5, 5).to(device) input_cloned = input.cpu().to(device) dummy_inplace_add_compiled = torch.compile(dummy_inplace, backend='openxla') @@ -174,7 +174,7 @@ def dummy_inplace(input): self.assertEqual(met.metric_data('CompileTime')[0], 1) def test_buffer_donation_on_non_data_tensor(self): - device = xm.xla_device() + device = torch_xla.device() input = torch.randn(5, 5).to(device) res = input + 1 diff --git a/test/dynamo/test_dynamo_graph_dump.py b/test/dynamo/test_dynamo_graph_dump.py index 6ce95dcbeff3..ae0383a47963 100644 --- a/test/dynamo/test_dynamo_graph_dump.py +++ b/test/dynamo/test_dynamo_graph_dump.py @@ -27,7 +27,7 @@ def test_dump_graph_with_dynamo_execution(self): if not save_file: assert False, "This test should be run with XLA_SAVE_TENSORS_FILE" save_file += '.0' - device = xm.xla_device() + device = torch_xla.device() xla_x = torch.tensor(100.0).to(device) xla_y = torch.tensor(200.0).to(device) res_xla_dynamo = self.fn_simple_dynamo(xla_x, xla_y) diff --git a/test/dynamo/test_dynamo_integrations_util.py b/test/dynamo/test_dynamo_integrations_util.py index b18262fc113a..293bef17ec05 100644 --- a/test/dynamo/test_dynamo_integrations_util.py +++ b/test/dynamo/test_dynamo_integrations_util.py @@ -20,7 +20,7 @@ class PybindTest(unittest.TestCase): def test_get_tensors_xla_device_data_node(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() t1 = torch.randn(20, 5).to(xla_device) t2 = torch.randn(20, 5).to(xla_device) t3 = t2 + t1 @@ -42,7 +42,7 @@ def test_get_tensors_xla_device_data_node(self): assert (expected_tensor_ids == sorted(res_pair[0])) def test_get_base_seed_as_tensor(self): - device = xm.xla_device() + device = torch_xla.device() xm.set_rng_state(23, str(device)) base_seed = torch_xla._XLAC._get_base_seed_as_tensor(str(device)).item() self.assertEqual(23, base_seed) @@ -51,7 +51,7 @@ def test_get_seed_info_id(self): self.assertEqual(torch_xla._XLAC._get_seed_info_id(), -127389) def test_check_tensor_need_materialization(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() t1 = torch.randn(20, 5) assert (torch_xla._XLAC._check_tensor_need_materialization([t1]) == [False]) t1 = t1.to(xla_device) @@ -67,7 +67,7 @@ def test_check_tensor_need_materialization(self): assert (torch_xla._XLAC._check_tensor_need_materialization([t1]) == [True]) def test_get_graph_hash(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() xla_input = torch.randn(64, 256, 14, 14).to(xla_device) xla_dummy_model = dummy_model.to(xla_device) xla_out = xla_dummy_model(xla_input) @@ -85,7 +85,7 @@ def test_get_graph_hash(self): assert (hash == torch_xla._XLAC._get_graph_hash([xla_out_2])) def test_clear_pending_irs(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() torch_xla.sync() t1 = torch.randn(20, 5).to(xla_device) t2 = torch.randn(20, 5).to(xla_device) @@ -104,7 +104,7 @@ def test_clear_pending_irs(self): self.assertEqual(met.metric_data('ExecuteTime')[0], 1) def test_run_cached_graph(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() xla_input = torch.randn(64, 256, 14, 14).to(xla_device) xla_dummy_model = dummy_model.to(xla_device) xla_out = xla_dummy_model(xla_input) diff --git a/test/dynamo/test_graph_input_matcher.py b/test/dynamo/test_graph_input_matcher.py index 9060248d32de..70dd0be73f57 100644 --- a/test/dynamo/test_graph_input_matcher.py +++ b/test/dynamo/test_graph_input_matcher.py @@ -24,7 +24,7 @@ def get_example_inputs(self): class TestGraphInputMatcher(unittest.TestCase): def test_no_cache_fx_gragh_inputs(self): - xla_dev = xm.xla_device() + xla_dev = torch_xla.device() model = M().to(device=xla_dev) inputs = tree_map_only(torch.Tensor, lambda x: x.to(device=xla_dev), model.get_example_inputs()) diff --git a/test/dynamo/test_num_output.py b/test/dynamo/test_num_output.py index ab86df8a6c50..b540e0691643 100644 --- a/test/dynamo/test_num_output.py +++ b/test/dynamo/test_num_output.py @@ -59,7 +59,7 @@ def get_example_inputs(self): class TestNumOutput(unittest.TestCase): def do_test(self, model_class, expected_num_output): - xla_dev = xm.xla_device() + xla_dev = torch_xla.device() model = model_class().to(device=xla_dev) inputs = tree_map_only(torch.Tensor, lambda x: x.to(device=xla_dev), model.get_example_inputs()) diff --git a/test/dynamo/test_traceable_collectives.py b/test/dynamo/test_traceable_collectives.py index 948491bddcc4..45bd89266604 100644 --- a/test/dynamo/test_traceable_collectives.py +++ b/test/dynamo/test_traceable_collectives.py @@ -18,7 +18,7 @@ def collective_broadcast_and_cos(input, src): def _mp_fn(index): - device = xm.xla_device() + device = torch_xla.device() world_size = xr.world_size() if xm.xla_device_hw(device) not in ('TPU', 'CUDA', 'NEURON'): print(f'skip this test for hw {xm.xla_device_hw(device)}') diff --git a/test/metrics_compare_utils_test.py b/test/metrics_compare_utils_test.py index 69a942646a84..cf2ed9beae73 100644 --- a/test/metrics_compare_utils_test.py +++ b/test/metrics_compare_utils_test.py @@ -275,7 +275,7 @@ def test_compare_metrics_reports_new_counters(self): def test_parse_real_metrics(self): print( 'Testing against TPU. If this hangs, check that $XRT_TPU_CONFIG is set') - x = torch.rand(3, 5, device=xm.xla_device()) + x = torch.rand(3, 5, device=torch_xla.device()) x = torch.flatten(x, 1) x = torch.roll(x, 1, 0) x = torch.flip(x, [0, 1]) diff --git a/test/neuron/test_neuron_data_types.py b/test/neuron/test_neuron_data_types.py index 4b8fb76c001b..326b3857794e 100644 --- a/test/neuron/test_neuron_data_types.py +++ b/test/neuron/test_neuron_data_types.py @@ -9,8 +9,8 @@ class NeuronXlaDataTypeTest(unittest.TestCase): def _test_datatypes(self, dtype, op_xla_dtype, op): - t1 = torch.tensor([2, 3], dtype=dtype, device=xm.xla_device()) - t2 = torch.tensor([2, 3], dtype=dtype, device=xm.xla_device()) + t1 = torch.tensor([2, 3], dtype=dtype, device=torch_xla.device()) + t2 = torch.tensor([2, 3], dtype=dtype, device=torch_xla.device()) t3 = op(t1, t2) diff --git a/test/pjrt/test_collective_ops_tpu.py b/test/pjrt/test_collective_ops_tpu.py index ccb020b8bd8a..614040a81dc7 100644 --- a/test/pjrt/test_collective_ops_tpu.py +++ b/test/pjrt/test_collective_ops_tpu.py @@ -17,7 +17,7 @@ class TestXMCollectiveOpsTpu(parameterized.TestCase): @staticmethod def _broadcast(sync): torch.manual_seed(xr.global_ordinal()) - device = xm.xla_device() + device = torch_xla.device() model = nn.Linear(5, 5).to(device) if sync: xm.broadcast_master_param(model) @@ -41,7 +41,7 @@ def test_broadcast_master_param(self, sync): @staticmethod def _all_reduce(pin_layout): - device = xm.xla_device() + device = torch_xla.device() # Prevent 0 and 1 from being converted to constants ordinal = xm.send_cpu_data_to_device( torch.tensor( @@ -63,7 +63,7 @@ def test_all_reduce(self, pin_layout): @staticmethod def _all_gather(pin_layout): - device = xm.xla_device() + device = torch_xla.device() ordinal = torch.tensor([xr.global_ordinal()], device=device) out = xm.all_gather(ordinal, pin_layout=pin_layout) torch_xla.sync() @@ -80,7 +80,7 @@ def test_all_gather(self, pin_layout): @staticmethod def _reduce_scatter(pin_layout): - device = xm.xla_device() + device = torch_xla.device() world_size = xr.world_size() tensor = -torch.arange(world_size, dtype=torch.float32).to(device) @@ -105,7 +105,7 @@ def test_reduce_scatter(self, pin_layout): @staticmethod def _all_to_all(pin_layout): - device = xm.xla_device() + device = torch_xla.device() world_size = xr.world_size() tensor = torch.cat( @@ -151,7 +151,7 @@ def callable(input): return input dist.init_process_group("xla", init_method='xla://') - device = xm.xla_device() + device = torch_xla.device() input = torch.tensor([xr.global_ordinal()], dtype=torch.float, device=device) @@ -175,7 +175,7 @@ def callable(output, input): return output_tensor dist.init_process_group("xla", init_method='xla://') - device = xm.xla_device() + device = torch_xla.device() input = torch.tensor([xr.global_ordinal()], dtype=torch.float, device=device) @@ -194,7 +194,7 @@ def callable(output, input): def _all_gather(use_dynamo: bool): met.clear_all() dist.init_process_group("xla", init_method='xla://') - device = xm.xla_device() + device = torch_xla.device() def callable(input): output_tensor = [ @@ -223,7 +223,7 @@ def callable(input): def _reduce_scatter(use_dynamo: bool): met.clear_all() dist.init_process_group("xla", init_method='xla://') - device = xm.xla_device() + device = torch_xla.device() def callable(output, input): dist.reduce_scatter_tensor(output, input) @@ -248,7 +248,7 @@ def callable(output, input): def _all_to_all_single(use_dynamo: bool, split_size: int = 1): met.clear_all() dist.init_process_group("xla", init_method='xla://') - device = xm.xla_device() + device = torch_xla.device() def callable(output, input): dist.all_to_all_single(output, input) diff --git a/test/pjrt/test_ddp.py b/test/pjrt/test_ddp.py index 0be8835ddb36..d236b8e11ea1 100644 --- a/test/pjrt/test_ddp.py +++ b/test/pjrt/test_ddp.py @@ -25,7 +25,7 @@ class TestPjRtDistributedDataParallel(parameterized.TestCase): @staticmethod def _ddp_init(index: int = ...): dist.init_process_group('xla', init_method='xla://') - device = xm.xla_device() + device = torch_xla.device() model = nn.Linear(10, 10).to(device) ddp_model = DDP(model) diff --git a/test/pjrt/test_dtypes.py b/test/pjrt/test_dtypes.py index ebac882efdf4..dd6a4344c94b 100644 --- a/test/pjrt/test_dtypes.py +++ b/test/pjrt/test_dtypes.py @@ -10,7 +10,7 @@ class TestDtypes(parameterized.TestCase): torch.bfloat16, torch.complex64) def test_float_round_trip(self, dtype: torch.dtype): t = torch.randn((3, 3), dtype=dtype) - xt = t.to(xm.xla_device()) + xt = t.to(torch_xla.device()) torch.testing.assert_close(xt.cpu(), t) @parameterized.parameters( @@ -22,12 +22,12 @@ def test_float_round_trip(self, dtype: torch.dtype): ) def test_int_round_trip(self, dtype: torch.dtype): t = torch.randint(0, 128, (3, 3), dtype=dtype) - xt = t.to(xm.xla_device()) + xt = t.to(torch_xla.device()) torch.testing.assert_close(xt.cpu(), t) def test_bool_round_trip(self): t = torch.randint(0, 2, (3, 3), dtype=torch.bool) - xt = t.to(xm.xla_device()) + xt = t.to(torch_xla.device()) torch.testing.assert_close(xt.cpu(), t) diff --git a/test/pjrt/test_metrics.py b/test/pjrt/test_metrics.py index 5cee1b7ea5da..3ff7563f7c15 100644 --- a/test/pjrt/test_metrics.py +++ b/test/pjrt/test_metrics.py @@ -27,7 +27,7 @@ def test_metrics_report(self): self.assertEmpty(met.metrics_report()) # Move a tensor to the XLA device and back - torch.rand(3, 3, device=xm.xla_device()).cpu() + torch.rand(3, 3, device=torch_xla.device()).cpu() metrics = met.metrics_report() self.assertNotEmpty(metrics) diff --git a/test/pjrt/test_profiler.py b/test/pjrt/test_profiler.py index 3be3d4a06c40..15e799473b3d 100644 --- a/test/pjrt/test_profiler.py +++ b/test/pjrt/test_profiler.py @@ -32,12 +32,12 @@ class TestPjRtProfiler(absltest.TestCase): def setUp(self): # HACK: ensure libtpu is loaded if using TPU - xm.xla_device() + torch_xla.device() def test_profiler_output(self): tempdir = self.create_tempdir().full_path - device = xm.xla_device() + device = torch_xla.device() ones = torch.ones([5]) with _profile(tempdir): xones = ones.to(device) diff --git a/test/pjrt/test_runtime.py b/test/pjrt/test_runtime.py index fcb44e2cb939..6529b5e826e1 100644 --- a/test/pjrt/test_runtime.py +++ b/test/pjrt/test_runtime.py @@ -59,7 +59,7 @@ def test_num_global_devices(self): def test_xla_device_error(self): with self.assertRaises(IndexError): - xm.xla_device(10) + torch_xla.device(10) @parameterized.named_parameters(('default', {}, True), ('no_default', { 'PJRT_SELECT_DEFAULT_DEVICE': '0' diff --git a/test/pjrt/test_runtime_multi_cpu.py b/test/pjrt/test_runtime_multi_cpu.py index 54da40346ff5..1b61f57d47bd 100644 --- a/test/pjrt/test_runtime_multi_cpu.py +++ b/test/pjrt/test_runtime_multi_cpu.py @@ -65,10 +65,10 @@ def forward(ctx, x): def backward(ctx, grad_output): results['forward_ordinal'] = ctx.forward_ordinal results['backward_ordinal'] = xr.global_ordinal() - results['device'] = str(xm.xla_device()) + results['device'] = str(torch_xla.device()) return grad_output - x = torch.ones(1, requires_grad=True, device=xm.xla_device()) + x = torch.ones(1, requires_grad=True, device=torch_xla.device()) y = _CustomBackwards.apply(x) y.backward() torch_xla.sync() @@ -110,7 +110,7 @@ def _hlo_dump(tmpdir: str): os.environ['XLA_SAVE_TENSORS_FMT'] = 'hlo' os.environ['XLA_SAVE_TENSORS_FILE'] = os.path.join(tmpdir, 'save.hlo') - x = torch.randn((3, 3), device=xm.xla_device()) + x = torch.randn((3, 3), device=torch_xla.device()) torch_xla.sync() x.cpu() @@ -124,7 +124,7 @@ def test_hlo_dump(self): @staticmethod def _all_reduce_hlo(): - ones = torch.ones((3, 3), device=xm.xla_device()) + ones = torch.ones((3, 3), device=torch_xla.device()) torch_xla.sync() reduced = xm.all_reduce(xm.REDUCE_SUM, ones) diff --git a/test/pjrt/test_runtime_multi_gpu.py b/test/pjrt/test_runtime_multi_gpu.py index 6609bc39d282..e48185af0c0d 100644 --- a/test/pjrt/test_runtime_multi_gpu.py +++ b/test/pjrt/test_runtime_multi_gpu.py @@ -122,10 +122,10 @@ def forward(ctx, x): def backward(ctx, grad_output): results['forward_ordinal'] = ctx.forward_ordinal results['backward_ordinal'] = xr.global_ordinal() - results['device'] = str(xm.xla_device()) + results['device'] = str(torch_xla.device()) return grad_output - x = torch.ones(1, requires_grad=True, device=xm.xla_device()) + x = torch.ones(1, requires_grad=True, device=torch_xla.device()) y = _CustomBackwards.apply(x) y.backward() torch_xla.sync() @@ -166,7 +166,7 @@ def test_spawn(self, spawn): @staticmethod def _broadcast(sync): torch.manual_seed(xr.global_ordinal()) - device = xm.xla_device() + device = torch_xla.device() model = nn.Linear(5, 5).to(device) if sync: xm.broadcast_master_param(model) @@ -188,7 +188,7 @@ def test_broadcast_master_param(self, sync): @staticmethod def _all_gather(pin_layout): - device = xm.xla_device() + device = torch_xla.device() ordinal = torch.tensor([xr.global_ordinal()], device=device) out = xm.all_gather(ordinal, pin_layout=pin_layout) torch_xla.sync() @@ -205,7 +205,7 @@ def test_all_gather(self, pin_layout): @staticmethod def _reduce_scatter(pin_layout): - device = xm.xla_device() + device = torch_xla.device() world_size = xr.world_size() tensor = -torch.arange(world_size, dtype=torch.float32).to(device) @@ -231,7 +231,7 @@ def test_reduce_scatter(self, pin_layout): @staticmethod def _all_to_all(pin_layout): - device = xm.xla_device() + device = torch_xla.device() world_size = xr.world_size() tensor = torch.cat( diff --git a/test/pjrt/test_runtime_tpu.py b/test/pjrt/test_runtime_tpu.py index 021de719adb6..a19e2323c4de 100644 --- a/test/pjrt/test_runtime_tpu.py +++ b/test/pjrt/test_runtime_tpu.py @@ -172,7 +172,7 @@ def test_local_ordinal_with_discontiguous_global_ordinal_v4_threaded(self): @staticmethod def _spawn_threads() -> Dict[int, torch.device]: results = {} - pjrt.spawn_threads(lambda i: results.setdefault(i, xm.xla_device())) + pjrt.spawn_threads(lambda i: results.setdefault(i, torch_xla.device())) return results @@ -187,7 +187,7 @@ def test_spawn_threads(self): @staticmethod def _spawn_error(): # Initialize the client in the parent process - xm.xla_device() + torch_xla.device() torch_xla.launch(xm.xla_device) @@ -199,7 +199,7 @@ def test_spawn_error(self): @staticmethod def _runtime_device_attributes(): - return xr.runtime_device_attributes(str(xm.xla_device())) + return xr.runtime_device_attributes(str(torch_xla.device())) def test_runtime_device_attributes(self): result = pjrt.run_multiprocess(self._runtime_device_attributes) @@ -226,12 +226,12 @@ def test_global_runtime_device_attributes(self): @staticmethod def _execute_time_metric(): # Initialize the client before starting the timer. - xm.xla_device() + torch_xla.device() begin = time.perf_counter_ns() value = ( - torch.randn(10000, 10000, device=xm.xla_device()) * - torch.randn(10000, 10000, device=xm.xla_device())) + torch.randn(10000, 10000, device=torch_xla.device()) * + torch.randn(10000, 10000, device=torch_xla.device())) value_mean = value.mean() torch_xla.sync() cpu_value = value_mean.cpu() diff --git a/test/pjrt/test_torchrun.py b/test/pjrt/test_torchrun.py index 3939f7f6c582..0024c189aa75 100644 --- a/test/pjrt/test_torchrun.py +++ b/test/pjrt/test_torchrun.py @@ -28,7 +28,7 @@ def test_all_gather(self): rank = torch.tensor([dist.get_rank()], dtype=torch.float32, - device=xm.xla_device()) + device=torch_xla.device()) output = [rank.clone() for _ in range(expected_world_size)] dist.all_gather(output, rank) result = torch.concat(output) @@ -52,7 +52,8 @@ def test_all_reduce(self): expected = sum(tensors) xla_tensor = torch.arange( - 2, dtype=torch.int64, device=xm.xla_device()) + 1 + 2 * dist.get_rank() + 2, dtype=torch.int64, + device=torch_xla.device()) + 1 + 2 * dist.get_rank() dist.all_reduce(xla_tensor, op=dist.ReduceOp.SUM) torch_xla.sync() @@ -70,9 +71,9 @@ def test_reduce_scatter(self): expected = torch.split(tensor, world_size)[dist.get_rank()] tensor_out = torch.zeros( - world_size, dtype=torch.int64, device=xm.xla_device()) + world_size, dtype=torch.int64, device=torch_xla.device()) tensor_in = torch.arange( - world_size * world_size, dtype=torch.int64, device=xm.xla_device()) + world_size * world_size, dtype=torch.int64, device=torch_xla.device()) dist.reduce_scatter(tensor_out, [tensor_in], op=dist.ReduceOp.SUM) torch_xla.sync() diff --git a/test/pjrt/test_train_hf_transformer.py b/test/pjrt/test_train_hf_transformer.py index d2c113e9b5eb..d484edc0a6ce 100644 --- a/test/pjrt/test_train_hf_transformer.py +++ b/test/pjrt/test_train_hf_transformer.py @@ -55,7 +55,7 @@ def finetune(rank, train_dataset, test_dataset, tokenizer, flags): drop_last=True, generator=rng) - device = xm.xla_device() + device = torch_xla.device() model = AutoModelForSequenceClassification.from_pretrained( 'google-bert/bert-base-cased', num_labels=5) model.to(device) diff --git a/test/pytorch_test_base.py b/test/pytorch_test_base.py index b47ae3f3de6d..3355f8efba99 100644 --- a/test/pytorch_test_base.py +++ b/test/pytorch_test_base.py @@ -559,7 +559,7 @@ def _alt_lookup(d, keys, defval): def instantiate_test(cls, name, test, *, generic_cls): test_name = name + '_' + cls.device_type class_name = cls.__name__ - real_device_type = xm.xla_device_hw(str(xm.xla_device())) + real_device_type = xm.xla_device_hw(str(torch_xla.device())) assert real_device_type in DISABLED_TORCH_TESTS, 'Unsupported device type:' + real_device_type disabled_torch_tests = DISABLED_TORCH_TESTS[real_device_type] @@ -632,7 +632,7 @@ def get_primary_device(cls): @classmethod def setUpClass(cls): # Sets the primary test device to the xla_device (CPU or TPU) - cls.primary_device = str(xm.xla_device()) + cls.primary_device = str(torch_xla.device()) torch_xla._XLAC._xla_set_mat_mul_precision('highest') def setUp(self): diff --git a/test/quantized_ops/test_quantized_matmul.py b/test/quantized_ops/test_quantized_matmul.py index b7f415a82b60..88a34c69a4ae 100644 --- a/test/quantized_ops/test_quantized_matmul.py +++ b/test/quantized_ops/test_quantized_matmul.py @@ -12,7 +12,7 @@ torch.manual_seed(123456) -device = xm.xla_device() +device = torch_xla.device() class M(torch.nn.Module): diff --git a/test/scan/test_scan_pallas.py b/test/scan/test_scan_pallas.py index 2613cc66f217..a267886cd3f7 100644 --- a/test/scan/test_scan_pallas.py +++ b/test/scan/test_scan_pallas.py @@ -72,7 +72,7 @@ def fake_fa_wrapper(self, has_model_weight, use_scan): torch.manual_seed(12) torch_xla.manual_seed(12) hidden_states = torch.randn((8, 4, 256, 256)).requires_grad_().to('xla') - with xm.xla_device(): + with torch_xla.device(): attention_layers = AttentionLayers( has_model_weight, num_layer=3, use_scan=use_scan) hidden_states.retain_grad() diff --git a/test/spmd/test_dtensor_integration.py b/test/spmd/test_dtensor_integration.py index d7d03a536899..402da525c96f 100644 --- a/test/spmd/test_dtensor_integration.py +++ b/test/spmd/test_dtensor_integration.py @@ -33,7 +33,7 @@ def test_xla_distribute_tensor(self): 3 * device_count, 3, requires_grad=requires_grad, - device=xm.xla_device()) + device=torch_xla.device()) dist_tensor = distribute_tensor(tensor_to_shard, device_mesh, shard_spec) # TODO(yeounoh) switch to DTensor API when XLAShardedTensor inherits DTensor assert type(dist_tensor).__name__ == "XLAShardedTensor" @@ -49,7 +49,7 @@ def test_xla_distribute_tensor(self): def test_optimizer_step_with_sharding(self): # Use simple linear model to test model parameter sharding - model = self.SimpleLinear().to(xm.xla_device()) + model = self.SimpleLinear().to(torch_xla.device()) # Running the same mark_sharding test with xla_distribute_tensor instead device_count = xr.global_runtime_device_count() @@ -60,8 +60,8 @@ def test_optimizer_step_with_sharding(self): model.train() optimizer = optim.SGD(model.parameters(), lr=0.1) - data = torch.randn(128, 128).to(xm.xla_device()) - target = torch.zeros(128).to(xm.xla_device()) + data = torch.randn(128, 128).to(torch_xla.device()) + target = torch.zeros(128).to(torch_xla.device()) loss_fn = nn.CrossEntropyLoss() for _ in range(3): optimizer.zero_grad() @@ -76,7 +76,7 @@ def test_optimizer_step_with_sharding(self): torch_xla._XLAC._get_xla_sharding_spec(model.fc1.weight)) def test_xla_distribute_module(self): - model = self.SimpleLinear().to(xm.xla_device()) + model = self.SimpleLinear().to(torch_xla.device()) device_count = xr.global_runtime_device_count() device_mesh = init_device_mesh("xla", mesh_shape=(device_count,)) @@ -96,8 +96,8 @@ def shard_params(mod_name, mod, mesh): sharded_model.train() optimizer = optim.SGD(sharded_model.parameters(), lr=0.1) - data = torch.randn(128, 128).to(xm.xla_device()) - target = torch.zeros(128).to(xm.xla_device()) + data = torch.randn(128, 128).to(torch_xla.device()) + target = torch.zeros(128).to(torch_xla.device()) loss_fn = nn.CrossEntropyLoss() for _ in range(3): optimizer.zero_grad() diff --git a/test/spmd/test_dtensor_integration2.py b/test/spmd/test_dtensor_integration2.py index 2d1329cdf4dc..0c729fbb91c7 100644 --- a/test/spmd/test_dtensor_integration2.py +++ b/test/spmd/test_dtensor_integration2.py @@ -38,8 +38,8 @@ def test_xla_distribute_module_auto(self): self.assertTrue(torch_xla._XLAC._xla_get_auto_sharding()) optimizer = optim.SGD(sharded_model.parameters(), lr=0.1) - data = torch.randn(128, 128).to(xm.xla_device()) - target = torch.zeros(128).to(xm.xla_device()) + data = torch.randn(128, 128).to(torch_xla.device()) + target = torch.zeros(128).to(torch_xla.device()) loss_fn = nn.CrossEntropyLoss() for _ in range(5): optimizer.zero_grad() diff --git a/test/spmd/test_dynamo_spmd.py b/test/spmd/test_dynamo_spmd.py index b4375e145c85..518e4203b459 100644 --- a/test/spmd/test_dynamo_spmd.py +++ b/test/spmd/test_dynamo_spmd.py @@ -42,7 +42,7 @@ def setUpClass(cls): super().setUpClass() def test_dynamo_spmd_basic(self): - device = xm.xla_device() + device = torch_xla.device() linear = SimpleLinear().to(device) linear.eval() xla_x = torch.randn(1, 128, device=device) @@ -58,7 +58,7 @@ def test_dynamo_spmd_basic(self): # a ExecuteMetric. def test_dynamo_spmd_output_sharding_spec(self): - device = xm.xla_device() + device = torch_xla.device() linear = SimpleLinear().to(device) linear.eval() xla_x = torch.randn(1, 128, device=device) @@ -74,7 +74,7 @@ def test_dynamo_spmd_output_sharding_spec(self): ) def test_dynamo_spmd_output_sharding_cache(self): met.clear_all() - device = xm.xla_device() + device = torch_xla.device() linear = SimpleLinear().to(device) linear.eval() xla_x = torch.randn(1, 128, device=device) @@ -90,7 +90,7 @@ def test_dynamo_spmd_output_sharding_cache(self): self.assertEqual(met.counter_value('UncachedOutputSharding'), 1) def test_dynamo_sharded_input(self): - device = xm.xla_device() + device = torch_xla.device() linear = SimpleLinear().to(device) linear.eval() xla_x = torch.randn(8, 128, device=device) @@ -103,7 +103,7 @@ def test_dynamo_sharded_input(self): torch.allclose(xla_res.cpu(), dynamo_res.cpu()) def test_dynamo_input_sharding_changed(self): - device = xm.xla_device() + device = torch_xla.device() linear = SimpleLinear().to(device) linear.eval() xla_x = torch.randn(8, 128, device=device) @@ -142,7 +142,7 @@ def test_dynamo_input_sharding_changed(self): @unittest.skipIf(xr.global_runtime_device_count() == 1, "Multiple devices needed to test the mesh change") def test_dynamo_input_sharding_threashold(self): - device = xm.xla_device() + device = torch_xla.device() linear = SimpleLinear().to(device) linear.eval() xla_x = torch.randn(8, 128, device=device) @@ -183,7 +183,7 @@ def test_dynamo_input_sharding_threashold(self): del os.environ['XLA_DYNAMO_INPUT_SHARDING_CHECK_THRESHOLD'] def test_dynamo_spmd_basic_with_dynamo_mark_sharding(self): - device = xm.xla_device() + device = torch_xla.device() linear = SimpleLinear().to(device) linear.eval() xla_x = torch.randn(1, 128, device=device) @@ -202,7 +202,7 @@ def test_dynamo_spmd_basic_with_dynamo_mark_sharding(self): torch.allclose(xla_res.cpu(), dynamo_res.cpu()) def test_dynamo_spmd_activation_sharding_with_dynamo_mark_sharding(self): - device = xm.xla_device() + device = torch_xla.device() mesh = self._get_mesh((1, self.n_devices)) device_ids = mesh.device_ids.tolist() mesh_shape = list(mesh.mesh_shape) diff --git a/test/spmd/test_fsdp_v2.py b/test/spmd/test_fsdp_v2.py index d4a85f531a31..0f8a4d088ef2 100644 --- a/test/spmd/test_fsdp_v2.py +++ b/test/spmd/test_fsdp_v2.py @@ -24,7 +24,7 @@ def setUpClass(cls): super().setUpClass() def test_fsdp_v2_basic(self): - model = self.SimpleLinear().to(xm.xla_device()) + model = self.SimpleLinear().to(torch_xla.device()) mesh = self._get_mesh((self.n_devices, 1), None, ('fsdp', 'tensor')) model.fc1 = FSDPv2(model.fc1, mesh=mesh) model.fc2 = FSDPv2(model.fc2, mesh=mesh) @@ -39,7 +39,7 @@ def test_fsdp_v2_basic(self): self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(model.fc2.weight)) - x = torch.randn(16, 128).to(xm.xla_device()) + x = torch.randn(16, 128).to(torch_xla.device()) xs.mark_sharding(x, mesh, ('fsdp', None)) output = model(x) # Make sure output are sharded. @@ -63,7 +63,7 @@ def test_fsdp_v2_basic(self): xm.wait_device_ops() def test_fsdp_v2_output_correctness(self): - model_expected = self.SimpleLinear().to(xm.xla_device()) + model_expected = self.SimpleLinear().to(torch_xla.device()) model = copy.deepcopy(model_expected) mesh = self._get_mesh((self.n_devices, 1), None, ('fsdp', 'tensor')) @@ -71,7 +71,7 @@ def test_fsdp_v2_output_correctness(self): model.fc2 = FSDPv2(model.fc2, mesh=mesh) model = FSDPv2(model, mesh=mesh) - x_expected = torch.randn(16, 128).to(xm.xla_device()) + x_expected = torch.randn(16, 128).to(torch_xla.device()) x = copy.deepcopy(x_expected) xs.mark_sharding(x, mesh, ('fsdp', None)) @@ -81,7 +81,7 @@ def test_fsdp_v2_output_correctness(self): self.assertTrue(torch.allclose(output_expected.cpu(), output.cpu())) def test_fsdp_v2_auto_wrap_basic(self): - model = self.SimpleLinear().to(xm.xla_device()) + model = self.SimpleLinear().to(torch_xla.device()) mesh = self._get_mesh((self.n_devices, 1), None, ('fsdp', 'tensor')) auto_wrap_policy = functools.partial( transformer_auto_wrap_policy, @@ -93,7 +93,7 @@ def test_fsdp_v2_auto_wrap_basic(self): self.assertTrue(isinstance(model.fc2, FSDPv2)) def test_fsdp_v2_auto_wrap_callable(self): - model = self.SimpleLinear().to(xm.xla_device()) + model = self.SimpleLinear().to(torch_xla.device()) mesh = self._get_mesh((self.n_devices, 1), None, ('fsdp', 'tensor')) auto_wrap_policy = functools.partial( transformer_auto_wrap_policy, @@ -115,7 +115,7 @@ def auto_wrapper_callable(m, *args, **kwargs): self.assertFalse(isinstance(model.fc2, FSDPv2)) def test_fsdp_v2_global_mesh(self): - model = self.SimpleLinear().to(xm.xla_device()) + model = self.SimpleLinear().to(torch_xla.device()) mesh = self._get_mesh((self.n_devices, 1), None, ('fsdp', 'tensor')) xs.set_global_mesh(mesh) @@ -123,7 +123,7 @@ def test_fsdp_v2_global_mesh(self): self.assertEqual(id(model._mesh), id(mesh)) def test_fsdp_v2_global_mesh_error(self): - model = self.SimpleLinear().to(xm.xla_device()) + model = self.SimpleLinear().to(torch_xla.device()) xs.set_global_mesh(None) with self.assertRaises(ValueError): @@ -141,7 +141,7 @@ def test_fsdp_v2_cpu_model(self): @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.") def test_fsdp_v2_multi_slice(self): - model = self.SimpleLinear().to(xm.xla_device()) + model = self.SimpleLinear().to(torch_xla.device()) mesh = self._get_mesh((2, self.n_devices // 2, 1), None, ('data', 'fsdp', 'tensor')) model = FSDPv2(model, mesh=mesh, extra_data_axis="data") @@ -155,7 +155,7 @@ def test_fsdp_v2_multi_slice(self): self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(model.fc2.weight)) - x = torch.randn(16, 128).to(xm.xla_device()) + x = torch.randn(16, 128).to(torch_xla.device()) xs.mark_sharding(x, mesh, (('data', 'fsdp'), None)) output = model(x) # Make sure output are sharded. @@ -171,14 +171,14 @@ def test_fsdp_v2_multi_slice(self): @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.") def test_fsdp_v2_multi_slice_output_correctness(self): - model_expected = self.SimpleLinear().to(xm.xla_device()) + model_expected = self.SimpleLinear().to(torch_xla.device()) model = copy.deepcopy(model_expected) mesh = self._get_mesh((2, self.n_devices // 2, 1), None, ('data', 'fsdp', 'tensor')) model = FSDPv2(model, mesh=mesh, extra_data_axis="data") - x_expected = torch.randn(16, 128).to(xm.xla_device()) + x_expected = torch.randn(16, 128).to(torch_xla.device()) x = copy.deepcopy(x_expected) xs.mark_sharding(x, mesh, (('data', 'fsdp'), None)) @@ -188,7 +188,7 @@ def test_fsdp_v2_multi_slice_output_correctness(self): self.assertTrue(torch.allclose(output_expected.cpu(), output.cpu())) def test_fsdp_v2_multi_slice_error(self): - model = self.SimpleLinear().to(xm.xla_device()) + model = self.SimpleLinear().to(torch_xla.device()) xs.set_global_mesh( self._get_mesh((2, self.n_devices // 2, 1), None, ('data', 'fsdp', 'tensor'))) diff --git a/test/spmd/test_mp_input_sharding.py b/test/spmd/test_mp_input_sharding.py index 6b78a3714e79..dc1e4aba12b0 100644 --- a/test/spmd/test_mp_input_sharding.py +++ b/test/spmd/test_mp_input_sharding.py @@ -34,7 +34,7 @@ def __next__(self): @unittest.skipUnless(xr.global_runtime_device_count() > 1, "Multiple devices required for tupled partition spec") def test_multiple_inputs(self): - device = xm.xla_device() + device = torch_xla.device() batch = {'x': torch.randn((16, 128)), 'y': torch.randn((16, 128, 128))} train_loader = self.fake_dataloader(batch) num_devices = xr.global_runtime_device_count() @@ -61,7 +61,7 @@ def test_multiple_inputs(self): @unittest.skipUnless(xr.global_runtime_device_count() > 1, "Multiple devices required for tupled partition spec") def test_single_tensor(self): - device = xm.xla_device() + device = torch_xla.device() batch = torch.randn((16, 128)) train_loader = self.fake_dataloader(batch) num_devices = xr.global_runtime_device_count() @@ -78,7 +78,7 @@ def test_single_tensor(self): @unittest.skipUnless(xr.global_runtime_device_count() > 1, "Multiple devices required for tupled partition spec") def test_error_single_tensor_with_input_sharding_dict(self): - device = xm.xla_device() + device = torch_xla.device() batch = torch.randn((16, 128)) train_loader = self.fake_dataloader(batch) num_devices = xr.global_runtime_device_count() @@ -95,7 +95,7 @@ def test_error_single_tensor_with_input_sharding_dict(self): @unittest.skipUnless(xr.global_runtime_device_count() > 1, "Multiple devices required for tupled partition spec") def test_input_sharding_none(self): - device = xm.xla_device() + device = torch_xla.device() batch = {'x': torch.randn((16, 128)), 'y': torch.randn((16, 128, 128))} train_loader = self.fake_dataloader(batch) num_devices = xr.global_runtime_device_count() @@ -112,7 +112,7 @@ def test_input_sharding_none(self): @unittest.skipUnless(xr.global_runtime_device_count() > 1, "Multiple devices required for tupled partition spec") def test_error_missing_keys(self): - device = xm.xla_device() + device = torch_xla.device() batch = {'x': torch.randn((16, 128)), 'y': torch.randn((16, 128, 128))} train_loader = self.fake_dataloader(batch) mesh = xs.get_1d_mesh('x') @@ -127,7 +127,7 @@ def test_error_missing_keys(self): @unittest.skipUnless(xr.global_runtime_device_count() > 1, "Multiple devices required for tupled partition spec") def test_input_sharding_not_dict(self): - device = xm.xla_device() + device = torch_xla.device() num_devices = xr.global_runtime_device_count() batch = {'x': torch.randn((16, 128)), 'y': torch.randn((16, 128))} train_loader = self.fake_dataloader(batch) diff --git a/test/spmd/test_sharding_strategies.py b/test/spmd/test_sharding_strategies.py index d6bc4221b811..2dd09580a5a6 100644 --- a/test/spmd/test_sharding_strategies.py +++ b/test/spmd/test_sharding_strategies.py @@ -146,7 +146,7 @@ def training_step(data): torch.manual_seed(42) tries = 5 -device = xm.xla_device() +device = torch_xla.device() if args.profile: print("Profiler server started at port 9012") server = xp.start_server(9012) diff --git a/test/spmd/test_spmd_debugging.py b/test/spmd/test_spmd_debugging.py index a8113b2ae532..34221d375e9c 100644 --- a/test/spmd/test_spmd_debugging.py +++ b/test/spmd/test_spmd_debugging.py @@ -209,7 +209,7 @@ def test_single_host_replicated_tpu(self): f"Requires PJRT_DEVICE set to `CPU`.") def test_debugging_spmd_single_host_tiled_cpu(self): from torch_xla.distributed.spmd.debugging import visualize_sharding - device = xm.xla_device() + device = torch_xla.device() num_devices = self.n_devices mesh_shape = (1, num_devices) device_ids = np.array(range(num_devices)) @@ -252,7 +252,7 @@ def test_debugging_spmd_single_host_tiled_cpu(self): f"Requires PJRT_DEVICE set to `CPU`.") def test_single_host_partial_replication_cpu(self): from torch_xla.distributed.spmd.debugging import visualize_sharding - device = xm.xla_device() + device = torch_xla.device() num_devices = self.n_devices mesh_shape = (1, num_devices) device_ids = np.array(range(num_devices)) @@ -295,7 +295,7 @@ def test_single_host_partial_replication_cpu(self): f"Requires PJRT_DEVICE set to `CPU`.") def test_single_host_replicated_cpu(self): from torch_xla.distributed.spmd.debugging import visualize_sharding - device = xm.xla_device() + device = torch_xla.device() num_devices = self.n_devices mesh_shape = (1, num_devices) device_ids = np.array(range(num_devices)) diff --git a/test/spmd/test_spmd_graph_dump.py b/test/spmd/test_spmd_graph_dump.py index 2d1c2f84a4cf..45af3b154934 100644 --- a/test/spmd/test_spmd_graph_dump.py +++ b/test/spmd/test_spmd_graph_dump.py @@ -26,7 +26,7 @@ def test_dump_with_output_sharding(self): assert save_file, "This test should be run with XLA_SAVE_TENSORS_FILE" should_dump_output_sharding = (save_format == 'hlo') save_file += '.0' - device = xm.xla_device() + device = torch_xla.device() xla_x = torch.randn(8, 32).to(device) xla_y = torch.randn(8, 32).to(device) # shard one of the input tensor diff --git a/test/spmd/test_spmd_lowering_context.py b/test/spmd/test_spmd_lowering_context.py index 5cc0ac464bda..9bc80194318f 100644 --- a/test/spmd/test_spmd_lowering_context.py +++ b/test/spmd/test_spmd_lowering_context.py @@ -38,7 +38,7 @@ def test_basic(self): mesh_shape = (data_axis, model_axis) spmd_mesh = self._get_mesh(mesh_shape, axis_names=('x', 'y')) - device = xm.xla_device() + device = torch_xla.device() a = torch.zeros(2048, device=device, requires_grad=True) xs.mark_sharding(a, spmd_mesh, ('x',)) b = torch.randn([32, 2048], device=device, requires_grad=True) @@ -108,7 +108,7 @@ def test_device_parameter_id_tensor_mapping(self): mesh_shape = (data_axis, model_axis) spmd_mesh = self._get_mesh(mesh_shape, axis_names=('x', 'y')) - device = xm.xla_device() + device = torch_xla.device() a = torch.randn([32, 2048]).to(device) xs.mark_sharding(a, spmd_mesh, ('x', 'y')) b = torch.ones(2048).to(device) diff --git a/test/spmd/test_train_spmd_imagenet.py b/test/spmd/test_train_spmd_imagenet.py index 727103586d1e..f814810b2eb9 100644 --- a/test/spmd/test_train_spmd_imagenet.py +++ b/test/spmd/test_train_spmd_imagenet.py @@ -206,7 +206,7 @@ def train_imagenet(): torch.manual_seed(42) - device = xm.xla_device() + device = torch_xla.device() model = get_model_property('model_fn')().to(device) if FLAGS.use_gradient_checkpointing: @@ -313,8 +313,8 @@ def train_loop_fn(loader, epoch): tracker = xm.RateTracker() model.train() for step, (data, target) in enumerate(loader): - x = data.to(xm.xla_device()) - y = target.to(xm.xla_device()) + x = data.to(torch_xla.device()) + y = target.to(torch_xla.device()) with xp.StepTrace('train_imagenet'): with xp.Trace('build_graph'): optimizer.zero_grad() @@ -344,8 +344,8 @@ def test_loop_fn(loader, epoch): total_samples, correct = 0, 0 model.eval() for step, (data, target) in enumerate(loader): - data = data.to(xm.xla_device()) - target = target.to(xm.xla_device()) + data = data.to(torch_xla.device()) + target = target.to(torch_xla.device()) output = model(data) pred = output.max(1, keepdim=True)[1] correct += pred.eq(target.view_as(pred)).sum() diff --git a/test/spmd/test_xla_auto_sharding.py b/test/spmd/test_xla_auto_sharding.py index 40b0566f8b28..b30fc0c0e88e 100644 --- a/test/spmd/test_xla_auto_sharding.py +++ b/test/spmd/test_xla_auto_sharding.py @@ -39,11 +39,11 @@ def setUpClass(cls): xr.use_spmd(auto=True) def init_test_variables(cls): - xt_no_auto = torch.ones(2, 2).to(xm.xla_device()) + xt_no_auto = torch.ones(2, 2).to(torch_xla.device()) cls.hash_no_auto = torch_xla._XLAC._get_graph_hash([xt_no_auto + 0]) def test_auto_sharding_hashing(self): - xt = torch.ones(2, 2).to(xm.xla_device()) + xt = torch.ones(2, 2).to(torch_xla.device()) assert torch_xla._XLAC._xla_get_auto_sharding() hash_auto_spmd = torch_xla._XLAC._get_graph_hash([xt + 0]) self.assertNotEqual(hash_auto_spmd, self.hash_no_auto) @@ -60,8 +60,8 @@ def test_matmul(self): t2 = torch.ones(128, 256) t3 = (t1 @ t2).sum() - xt1 = t1.to(xm.xla_device()) - xt2 = t2.to(xm.xla_device()) + xt1 = t1.to(torch_xla.device()) + xt2 = t2.to(torch_xla.device()) xt3 = (xt1 @ xt2).sum() torch_xla.sync() self.assertEqual(met.counter_value("CompileWithAutoSharding"), 1) @@ -72,11 +72,11 @@ def test_matmul(self): def test_simple_linear_training(self): met.clear_counters() - model = self.SimpleLinear().to(xm.xla_device()) + model = self.SimpleLinear().to(torch_xla.device()) model.train() optimizer = optim.SGD(model.parameters(), lr=0.1) - data = torch.randn(128, 128).to(xm.xla_device()) - target = torch.zeros(128).to(xm.xla_device()) + data = torch.randn(128, 128).to(torch_xla.device()) + target = torch.zeros(128).to(torch_xla.device()) loss_fn = nn.CrossEntropyLoss() for i in range(5): optimizer.zero_grad() diff --git a/test/spmd/test_xla_distributed_checkpoint.py b/test/spmd/test_xla_distributed_checkpoint.py index 3096d8b6d9dc..380470d5c8c1 100644 --- a/test/spmd/test_xla_distributed_checkpoint.py +++ b/test/spmd/test_xla_distributed_checkpoint.py @@ -50,7 +50,7 @@ def _get_sharded_model(self, mesh_shape=None): # Return a sharded SimpleLinear model with fc1.weight sharded and # fc2.weight explicitly replicated mesh_shape = mesh_shape or (1, self.n_devices) - model = self.SimpleLinear().to(xm.xla_device()) + model = self.SimpleLinear().to(torch_xla.device()) mesh = self._get_mesh(mesh_shape) xs.mark_sharding(model.fc1.weight, mesh, (0, 1)) xs.mark_sharding(model.fc2.weight, mesh, (None, None)) @@ -76,7 +76,7 @@ def _assert_same_state_dict(self, sd1, sd2, keypath=""): if isinstance(sd1, torch.Tensor): assert sd1.device == sd2.device, f"Tensors on different devices at {keypath}: {sd1} vs {sd2}" - if sd1.device == xm.xla_device(): + if sd1.device == torch_xla.device(): sharding1 = torch_xla._XLAC._get_xla_sharding_spec(sd1) sharding2 = torch_xla._XLAC._get_xla_sharding_spec(sd2) assert sharding1 == sharding2, f"Different sharding on tensors at {keypath}: {sharding1} vs {sharding2}" @@ -145,14 +145,14 @@ def _save_and_restore(self, def test_resharding_unsharded_to_sharded(self): # Save an unsharded model using the DefaultSavePlanner and load into a # sharded model using the SPMDLoadPlanner - model = self.SimpleLinear().to(xm.xla_device()) + model = self.SimpleLinear().to(torch_xla.device()) sharded_model = self._get_sharded_model() self._save_and_restore(model, sharded_model, load_planner=SPMDLoadPlanner()) def test_resharding_sharded_to_unsharded(self): for chkpt_on_cpu in [True, False]: with self.subTest(chkpt_on_cpu): - model = self.SimpleLinear().to(xm.xla_device()) + model = self.SimpleLinear().to(torch_xla.device()) sharded_model = self._get_sharded_model() self._save_and_restore( sharded_model, @@ -338,7 +338,7 @@ def test_save_state_dict_with_cpu_shards(self): def test_cpu_state_dict_flattening(self): # In the case of a nested state_dict with fully sharded parameters, # _CpuShards should be treated as terminal nodes. - t = torch.randn(128, 128).to(xm.xla_device()) + t = torch.randn(128, 128).to(torch_xla.device()) mesh = self._get_mesh((self.n_devices, 1)) xs.mark_sharding(t, mesh, (0, 1)) state_dict = _sharded_cpu_state_dict({'model': {'weight': t}}) @@ -395,7 +395,7 @@ def test_resolve_shard_data(self): class DistributedCheckpointHelpersTest(DistributedCheckpointTestBase): def test_sharded_cpu_state_dict(self): - model = self.SimpleLinear().to(xm.xla_device()) + model = self.SimpleLinear().to(torch_xla.device()) state_dict = model.state_dict() sharded_cpu_state_dict = _sharded_cpu_state_dict(state_dict) self.assertCountEqual(sharded_cpu_state_dict, diff --git a/test/spmd/test_xla_sharding.py b/test/spmd/test_xla_sharding.py index f0a75dff4d52..81525faabed8 100644 --- a/test/spmd/test_xla_sharding.py +++ b/test/spmd/test_xla_sharding.py @@ -36,14 +36,14 @@ def test_xla_sharded_tensor(self): partition_spec = (0, 1) xt1 = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]], dtype=torch.float, - device=xm.xla_device()) + device=torch_xla.device()) xst1 = xs.mark_sharding(xt1, self._get_mesh((1, self.n_devices)), partition_spec) self.assertTrue(isinstance(xst1, XLAShardedTensor)) def test_xla_sharded_tensor_repr(self): - xt = torch.randn(128, 128).to(xm.xla_device()) - model = self.SimpleLinear().to(xm.xla_device()) + xt = torch.randn(128, 128).to(torch_xla.device()) + model = self.SimpleLinear().to(torch_xla.device()) mesh = self._get_mesh((1, self.n_devices)) partition_spec = (0, 1) @@ -59,7 +59,7 @@ def test_sharded_tensor_debug_info(self): partition_spec = (0, 1) xt1 = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]], dtype=torch.float, - device=xm.xla_device()) + device=torch_xla.device()) xst1 = xs.mark_sharding(xt1, self._get_mesh((1, self.n_devices)), partition_spec) @@ -73,7 +73,7 @@ def test_xla_shards(self): num_element = self.n_devices mesh = self._get_mesh((self.n_devices,)) t = torch.arange(num_element, dtype=torch.float32) - xt = xs.mark_sharding(t.to(xm.xla_device()), mesh, (0,)) + xt = xs.mark_sharding(t.to(torch_xla.device()), mesh, (0,)) shards = xt.local_shards self.assertEqual(len(shards), self.n_devices) @@ -97,7 +97,7 @@ def test_padded_xla_shards(self): num_element = self.n_devices + 1 # Ensure padding with two or more devices mesh = self._get_mesh((self.n_devices,)) t = torch.arange(num_element, dtype=torch.float32) - xt = xs.mark_sharding(t.to(xm.xla_device()), mesh, (0,)) + xt = xs.mark_sharding(t.to(torch_xla.device()), mesh, (0,)) shards = xt.local_shards self.assertEqual(len(shards), self.n_devices) shard_len = math.ceil(num_element / self.n_devices) @@ -127,7 +127,7 @@ def test_replicated_xla_shards(self): num_element = self.n_devices mesh = self._get_mesh((self.n_devices,)) t = torch.arange(num_element, dtype=torch.float32) - xt = xs.mark_sharding(t.to(xm.xla_device()), mesh, (None,)) + xt = xs.mark_sharding(t.to(torch_xla.device()), mesh, (None,)) shards = xt.local_shards self.assertEqual(len(shards), self.n_devices) for i, shard in enumerate(shards): @@ -147,7 +147,7 @@ def test_partially_replicated_xla_shards(self): mesh = self._get_mesh((self.n_devices // 2, 2)) t = torch.arange(num_element, dtype=torch.float32).reshape((16, 16)) # Partial replication along the 0th tensor axis, shard 2-way on the 1st - xt = xs.mark_sharding(t.to(xm.xla_device()), mesh, (None, 1)) + xt = xs.mark_sharding(t.to(torch_xla.device()), mesh, (None, 1)) shard_len = t.shape[1] // 2 shards = xt.local_shards @@ -172,7 +172,7 @@ def test_load_local_shards(self): num_element = self.n_devices mesh = self._get_mesh((self.n_devices,)) t = torch.arange(num_element, dtype=torch.float32) + 1 - xt = xs.mark_sharding(t.to(xm.xla_device()), mesh, (0,)) + xt = xs.mark_sharding(t.to(torch_xla.device()), mesh, (0,)) local_shards = xt.local_shards self.assertTrue(len(local_shards) == self.n_devices) @@ -197,13 +197,13 @@ def test_load_local_shards(self): xt.load_local_shards_(local_shards) # Replicated shards should fail - rt = xs.mark_sharding(t.to(xm.xla_device()), mesh, (None,)) + rt = xs.mark_sharding(t.to(torch_xla.device()), mesh, (None,)) local_shards = rt.local_shards with self.assertRaises(RuntimeError): rt.load_local_shards_(local_shards) def test_xla_sharding_type(self): - t = torch.randn(10, 20).to(xm.xla_device()) + t = torch.randn(10, 20).to(torch_xla.device()) self.assertEqual(torch_xla._XLAC._get_xla_sharding_type(t), None) x_dim = 2 if self.n_devices >= 2 else 1 @@ -229,7 +229,7 @@ def test_xla_sharding_type(self): self.assertEqual(xt.sharding_type, xs.ShardingType.REPLICATED) def test_custom_tile_assignment(self): - xt = torch.randn(10, 20).to(device=xm.xla_device()) + xt = torch.randn(10, 20).to(device=torch_xla.device()) mesh_shape = (1, self.n_devices) device_ids = np.flip(self.device_ids) mesh = self._get_mesh(mesh_shape, device_ids) @@ -245,8 +245,8 @@ def test_mark_sharding_2d(self): t2 = torch.randn(1, 128, device='cpu') expected = t1 + t2 - xt1 = t1.to(xm.xla_device()) - xt2 = t2.to(xm.xla_device()) + xt1 = t1.to(torch_xla.device()) + xt2 = t2.to(torch_xla.device()) xs.mark_sharding(xt1, self._get_mesh((1, self.n_devices)), (0, 1)) if self.n_devices > 1: @@ -261,7 +261,7 @@ def test_mark_sharding_4d(self): t = torch.randn(2, 4, 8, 16, device='cpu') expected = t + t - xt = t.to(xm.xla_device()) + xt = t.to(torch_xla.device()) # Shard along two axes if four or more devices are available z_dim = 2 if self.n_devices >= 4 else 1 xs.mark_sharding(xt, self._get_mesh((1, 1, z_dim, self.n_devices // z_dim)), @@ -277,7 +277,7 @@ def test_mark_sharding_4d(self): self.assertTrue(torch.allclose(expected, actual)) def test_mark_sharding_not_ordered_sharding_spec_2d(self): - device = xm.xla_device() + device = torch_xla.device() t1 = torch.randn(8, 16, device='cpu') expected = t1 + t1 @@ -290,7 +290,7 @@ def test_mark_sharding_not_ordered_sharding_spec_2d(self): self.assertTrue(torch.allclose(expected, (xt1 + xt1).cpu())) def test_mark_sharding_not_ordered_sharding_spec_3d(self): - device = xm.xla_device() + device = torch_xla.device() t1 = torch.randn(4, 8, 16, device='cpu') expected = t1 + t1 @@ -307,7 +307,7 @@ def test_mark_sharding_not_ordered_sharding_spec_3d(self): self.assertTrue(torch.allclose(expected, (xt1 + xt1).cpu())) def test_mark_sharding_not_ordered_sharding_spec_4d(self): - device = xm.xla_device() + device = torch_xla.device() t1 = torch.randn(32, 4, 8, 16, device='cpu') expected = t1 + t1 @@ -326,7 +326,7 @@ def test_mark_sharding_not_ordered_sharding_spec_4d(self): self.assertTrue(torch.allclose(expected, (xt1 + xt1).cpu())) def test_mark_sharding_partial(self): - device = xm.xla_device() + device = torch_xla.device() t1 = torch.randn(4, 4).to(device) t2 = torch.randn(4, 4).to(device) # Somehow the eager cpu result is different from the xla result. @@ -356,7 +356,7 @@ def test_mark_sharding_partial(self): self.assertTrue(torch.allclose(expected, actual)) def test_propagate_replicated_sharding(self): - device = xm.xla_device() + device = torch_xla.device() t1 = torch.randn(4, 4).to(device) t2 = torch.randn(4, 4).to(device) t3 = t1 @ t2 @@ -368,7 +368,7 @@ def test_propagate_replicated_sharding(self): self.assertIn("replicated", torch_xla._XLAC._get_xla_sharding_spec(t3)) def test_mark_sharding_partial_unordered(self): - device = xm.xla_device() + device = torch_xla.device() t1 = torch.randn(4, 3, 4).to(device) t2 = torch.randn(4, 3, 4).to(device) expected = t1 + t2 @@ -401,7 +401,7 @@ def test_mark_sharding_partial_unordered(self): "Multiple devices required for tupled partition spec") def test_tupled_partition_spec(self): mesh = self._get_mesh((2, self.n_devices // 2)) - t = torch.randn(16).to(xm.xla_device()) + t = torch.randn(16).to(torch_xla.device()) xs.mark_sharding(t, mesh, ((0, 1),)) self.assertEqual( torch_xla._XLAC._get_xla_sharding_spec(t), "{devices=[%d]%s}" % @@ -413,7 +413,7 @@ def test_named_partial_tupled_partition_spec(self): mesh = xs.Mesh( range(self.n_devices), (1, 2, self.n_devices // 2), ('r', 'b', 'm')) # Shard the first dimension on `r` and `b`, replicate the second dimension - t = torch.randn(16, 16).to(xm.xla_device()) + t = torch.randn(16, 16).to(torch_xla.device()) xs.mark_sharding(t, mesh, (('r', 'b'), None)) self.assertEqual( torch_xla._XLAC._get_xla_sharding_spec(t), @@ -421,14 +421,14 @@ def test_named_partial_tupled_partition_spec(self): (self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices)))) # Replicate the first dimension, shard the second on `b` and `m` - u = torch.randn(16, 16).to(xm.xla_device()) + u = torch.randn(16, 16).to(torch_xla.device()) xs.mark_sharding(u, mesh, (None, ('b', 'm'))) self.assertEqual( torch_xla._XLAC._get_xla_sharding_spec(u), "{devices=[1,%d]%s}" % (self.n_devices, ','.join(str(x) for x in range(self.n_devices)))) # Replicate the first dimension, shard the second on `r` and `m` - v = torch.randn(16, 16).to(xm.xla_device()) + v = torch.randn(16, 16).to(torch_xla.device()) xs.mark_sharding(v, mesh, (None, ('r', 'm'))) device_order = mesh.get_logical_mesh().transpose((0, 2, 1)).flatten() self.assertEqual( @@ -437,7 +437,7 @@ def test_named_partial_tupled_partition_spec(self): (self.n_devices // 2, ','.join(str(x) for x in device_order))) # Replicate the first dimension, shard the second on `m` and `b` - v = torch.randn(16, 16).to(xm.xla_device()) + v = torch.randn(16, 16).to(torch_xla.device()) xs.mark_sharding(v, mesh, (None, ('m', 'b'))) device_order = mesh.get_logical_mesh().transpose((2, 1, 0)).flatten() self.assertEqual( @@ -450,7 +450,7 @@ def test_multiple_tuples_in_spec(self): mesh = xs.Mesh( range(self.n_devices), (1, 2, self.n_devices // 2, 1), ('a', 'b', 'c', 'd')) - t = torch.randn(2, 2).to(xm.xla_device()) + t = torch.randn(2, 2).to(torch_xla.device()) xs.mark_sharding(t, mesh, (('a', 'b'), ('c', 'd'))) self.assertEqual( torch_xla._XLAC._get_xla_sharding_spec(t), "{devices=[2,%d]%s}" % @@ -460,14 +460,14 @@ def test_multiple_tuples_in_spec(self): 'At least 2 devices needed for 2D mesh') def test_3d_tensor_2d_mesh(self): mesh = self._get_mesh((2, self.n_devices // 2)) - t = torch.randn(16, 16, 16).to(xm.xla_device()) + t = torch.randn(16, 16, 16).to(torch_xla.device()) xs.mark_sharding(t, mesh, (None, 0, 1)) self.assertEqual( torch_xla._XLAC._get_xla_sharding_spec(t), '{devices=[1,2,%d]%s}' % (self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices)))) def test_partial_replication_addmm(self): - device = xm.xla_device() + device = torch_xla.device() z_dim = 2 if self.n_devices >= 4 else 1 mesh = self._get_mesh((z_dim, self.n_devices // z_dim)) @@ -495,7 +495,7 @@ def test_partial_replication_addmm(self): self.assertTrue(torch.allclose(expected, actual, atol=1e-5)) def test_clear_sharding(self): - xt = torch.randn(2, 4, 8, 16).to(xm.xla_device()) + xt = torch.randn(2, 4, 8, 16).to(torch_xla.device()) xs.mark_sharding(xt, self._get_mesh((1, 1, 1, self.n_devices)), (0, 1, 2, 3)) self.assertTrue(torch_xla._XLAC._get_xla_sharding_spec(xt)) @@ -503,7 +503,7 @@ def test_clear_sharding(self): self.assertFalse(torch_xla._XLAC._get_xla_sharding_spec(xt)) def test_replication_with_no_clear_sharding(self): - xt = torch.randn(2, 4).to(xm.xla_device()) + xt = torch.randn(2, 4).to(torch_xla.device()) # replication xs.mark_sharding(xt, self._get_mesh((1, self.n_devices)), (None, None)) # sharding annotation over an existing replication sharding is permitted. @@ -513,7 +513,7 @@ def test_replication_with_no_clear_sharding(self): "replicated" in torch_xla._XLAC._get_xla_sharding_spec(xt)) def test_deep_copy(self): - xt = torch.randn(2, 4, 8, 16).to(xm.xla_device()) + xt = torch.randn(2, 4, 8, 16).to(torch_xla.device()) xs.mark_sharding(xt, self._get_mesh((1, 1, 1, self.n_devices)), (0, 1, 2, 3)) xt2 = copy.deepcopy(xt) @@ -522,7 +522,7 @@ def test_deep_copy(self): torch_xla._XLAC._get_xla_sharding_spec(xt2)) def test_clone(self): - xt = torch.randn(2, 4, 8, 16).to(xm.xla_device()) + xt = torch.randn(2, 4, 8, 16).to(torch_xla.device()) xs.mark_sharding(xt, self._get_mesh((1, 1, 1, self.n_devices)), (0, 1, 2, 3)) sharding_spec = torch_xla._XLAC._get_xla_sharding_spec(xt) @@ -537,7 +537,7 @@ def test_clone(self): torch_xla._XLAC._get_xla_sharding_spec(xt2)) def test_sync_with_sharding(self): - xt = torch.ones(2, 2).to(xm.xla_device()) + xt = torch.ones(2, 2).to(torch_xla.device()) xs.mark_sharding(xt, self._get_mesh((1, self.n_devices)), (0, 1)) sharding_spec = torch_xla._XLAC._get_xla_sharding_spec(xt) torch_xla.sync() # `torch_xla.sync()` should preserve the sharding @@ -545,7 +545,7 @@ def test_sync_with_sharding(self): def test_execute_replicated_metrics(self): met.clear_all() - xt = torch.ones(2, 2).to(xm.xla_device()) + xt = torch.ones(2, 2).to(torch_xla.device()) xs.mark_sharding(xt, self._get_mesh((1, self.n_devices)), (0, 1)) xt += 2 torch_xla.sync() @@ -554,15 +554,15 @@ def test_execute_replicated_metrics(self): def test_optimizer_step_with_sharding(self): # Use simple linear model to test model parameter sharding - model = self.SimpleLinear().to(xm.xla_device()) + model = self.SimpleLinear().to(torch_xla.device()) xs.mark_sharding(model.fc1.weight, self._get_mesh((1, self.n_devices)), (0, 1)) sharding_spec = torch_xla._XLAC._get_xla_sharding_spec(model.fc1.weight) model.train() optimizer = optim.SGD(model.parameters(), lr=0.1) - data = torch.randn(128, 128).to(xm.xla_device()) - target = torch.zeros(128).to(xm.xla_device()) + data = torch.randn(128, 128).to(torch_xla.device()) + target = torch.zeros(128).to(torch_xla.device()) loss_fn = nn.CrossEntropyLoss() for i in range(3): optimizer.zero_grad() @@ -581,7 +581,7 @@ def test_sharding_propagation(self): self.assertFalse(met.counter_value("ReplicateShardedData")) # Linear model with two linear layers and only one is annotated. - model = self.SimpleLinear().to(xm.xla_device()) + model = self.SimpleLinear().to(torch_xla.device()) xs.mark_sharding(model.fc1.weight, self._get_mesh((1, self.n_devices)), (0, 1)) self.assertTrue(torch_xla._XLAC._get_xla_sharding_spec(model.fc1.weight)) @@ -589,8 +589,8 @@ def test_sharding_propagation(self): model.train() optimizer = optim.SGD(model.parameters(), lr=0.1) - data = torch.randn(128, 128).to(xm.xla_device()) - target = torch.zeros(128).to(xm.xla_device()) + data = torch.randn(128, 128).to(torch_xla.device()) + target = torch.zeros(128).to(torch_xla.device()) loss_fn = nn.CrossEntropyLoss() for i in range(3): optimizer.zero_grad() @@ -606,7 +606,7 @@ def test_sharding_propagation(self): self.assertEqual(met.counter_value("ReplicateShardedData"), 2) def test_inplace_add_with_sharding(self): - xt = torch.ones(2, 2).to(xm.xla_device()) + xt = torch.ones(2, 2).to(torch_xla.device()) xs.mark_sharding(xt, self._get_mesh((1, self.n_devices)), (0, 1)) sharding_spec = torch_xla._XLAC._get_xla_sharding_spec(xt) xt.add_(1) # inplace update should preserve the sharding @@ -622,8 +622,8 @@ def test_inplace_add_with_sharding(self): xr.device_type() == 'CPU', "sharding will be the same for both tensors on single device") def test_shard_hashing(self): - xt1 = torch.ones(2, 2).to(xm.xla_device()) - xt2 = torch.ones(2, 2).to(xm.xla_device()) + xt1 = torch.ones(2, 2).to(torch_xla.device()) + xt2 = torch.ones(2, 2).to(torch_xla.device()) # Add sharding to xt1, this should result in the hashes being different for # xt1 and xt2 @@ -639,7 +639,7 @@ def test_shard_hashing(self): self.assertNotEqual(hash1, hash2) def test_transfer_sharded_data_to_host(self): - xt1 = torch.ones(16, 16).to(xm.xla_device()) + xt1 = torch.ones(16, 16).to(torch_xla.device()) xs.mark_sharding(xt1, self._get_mesh((1, self.n_devices)), (0, 1)) t1 = xt1.cpu() self.assertTrue(torch.allclose(t1, torch.ones(16, 16))) @@ -657,7 +657,7 @@ def test_send_cpu_data_to_device_with_sharding(self): sharding_spec = xs.ShardingSpec(mesh, (0, 1)) self.assertTrue(sharding_spec.can_apply(tensor)) xtensors = xm.send_cpu_data_to_device([tensor], - xm.xla_device(), + torch_xla.device(), input_sharding=sharding_spec) self.assertEqual(len(xtensors), 1) outbound = met.metric_data("OutboundData")[1] @@ -666,7 +666,7 @@ def test_send_cpu_data_to_device_with_sharding(self): # Verify the resulting sharding annotation matches an explicit # `mark_sharding` call. xt = xtensors[0] - explicit_xt = tensor.to(xm.xla_device()) + explicit_xt = tensor.to(torch_xla.device()) xs.mark_sharding(explicit_xt, mesh, (0, 1)) self.assertEqual( torch_xla._XLAC._get_xla_sharding_spec(xt), @@ -676,8 +676,8 @@ def test_multiple_operations(self): t1 = torch.randn(2, 2) t2 = torch.randn(2, 2) expected_1 = t1 + t2 - xt1 = t1.to(xm.xla_device()) - xt2 = t2.to(xm.xla_device()) + xt1 = t1.to(torch_xla.device()) + xt2 = t2.to(torch_xla.device()) xs.mark_sharding(xt1, self._get_mesh((1, self.n_devices)), (0, 1)) xt3 = xt1 + xt2 self.assertTrue(torch.allclose(expected_1, xt3.cpu())) @@ -685,8 +685,8 @@ def test_multiple_operations(self): t4 = torch.randn(2, 2) t5 = torch.randn(2, 2) expected_2 = t4 + t5 - xt4 = t4.to(xm.xla_device()) - xt5 = t5.to(xm.xla_device()) + xt4 = t4.to(torch_xla.device()) + xt5 = t5.to(torch_xla.device()) xs.mark_sharding(xt4, self._get_mesh((1, self.n_devices)), (0, 1)) xs.mark_sharding(xt5, self._get_mesh((1, self.n_devices)), (0, 1)) xt6 = xt4 + xt5 @@ -696,10 +696,10 @@ def test_no_sharding(self): partition_spec = (0, 1) t1 = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]], dtype=torch.float, - device=xm.xla_device()) + device=torch_xla.device()) t2 = torch.tensor([[8, 7, 6, 5, 4, 3, 2, 1]], dtype=torch.float, - device=xm.xla_device()) + device=torch_xla.device()) t3 = t1 + t2 t3_expected = [9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0] self.assertEqual(t3.tolist()[0], t3_expected) @@ -708,7 +708,7 @@ def test_xla_sharded_hlo_dump(self): partition_spec = (0, 1) xt1 = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]], dtype=torch.float, - device=xm.xla_device()) + device=torch_xla.device()) xst1 = xs.mark_sharding(xt1, self._get_mesh((1, self.n_devices)), partition_spec) xst2 = xst1 + 5 @@ -724,8 +724,8 @@ def test_2d_tensor_3d_mesh(self): ct2 = torch.randn(16, 16, device='cpu') expected = ct1 + ct2 - t1 = ct1.to(xm.xla_device()) - t2 = ct2.to(xm.xla_device()) + t1 = ct1.to(torch_xla.device()) + t2 = ct2.to(torch_xla.device()) # Meaningful test for higher-order mesh with extra replication # requires multiple devices. Otherwise, this should defaults back to @@ -821,8 +821,8 @@ def test_mark_sharding_ir(self): t2 = torch.randn(1, 128, device='cpu') expected = t1 + t2 - xt1 = t1.to(xm.xla_device()) - xt2 = t2.to(xm.xla_device()) + xt1 = t1.to(torch_xla.device()) + xt2 = t2.to(torch_xla.device()) actual = xt1 + xt2 actual = xs.mark_sharding(actual, self._get_mesh((1, self.n_devices)), (0, 1)) @@ -912,7 +912,7 @@ def test_sharded_tensor_aliasing(self): partition_spec = (0, 1) xt1 = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]], dtype=torch.float, - device=xm.xla_device()) + device=torch_xla.device()) xst1 = xs.mark_sharding(xt1, self._get_mesh((1, self.n_devices)), partition_spec) xst1 += 1 @@ -921,7 +921,7 @@ def test_sharded_tensor_aliasing(self): def test_mark_sharding_ir_with_multiple_output(self): partition_spec = (0,) - xt1 = torch.randn(8, 8).to(xm.xla_device()) + xt1 = torch.randn(8, 8).to(torch_xla.device()) # max return 2 tensors `value` and `indices`. They are the output # of the same IR Node `MaxInDim` (xt_val, xt_index) = torch.max(xt1, 1) @@ -937,13 +937,13 @@ def test_mark_sharding_ir_with_multiple_output(self): def test_sharded_tensor_to_cpu_int_type(self): partition_spec = (0, 1) t1 = torch.arange(64).reshape(8, 8) - xt1 = t1.clone().to(xm.xla_device()) + xt1 = t1.clone().to(torch_xla.device()) xst1 = xs.mark_sharding(xt1, self._get_mesh((self.n_devices, 1)), partition_spec) self.assertTrue(torch.allclose(t1, xst1.cpu())) def test_named_partition_spec(self): - xt1 = torch.arange(64).reshape(8, 8).to(xm.xla_device()) + xt1 = torch.arange(64).reshape(8, 8).to(torch_xla.device()) mesh = xs.Mesh( list(range(self.n_devices)), (1, self.n_devices), ('data', 'model')) partition_spec = ('model', 'data') @@ -955,7 +955,7 @@ def test_named_partition_spec(self): self.assertTrue("replicated" in sharding_spec) def test_shard_device_data_ir(self): - device = xm.xla_device() + device = torch_xla.device() xla_x = torch.randn(8, 128, device=device) # xla_x now becomes a device data IR xla_y = xla_x * 5 @@ -967,7 +967,7 @@ def test_shard_device_data_ir(self): self.assertTrue(torch.allclose(xla_y.cpu(), xla_x.cpu() * 5)) def test_shard_device_data_ir_after_sync(self): - device = xm.xla_device() + device = torch_xla.device() xla_x = torch.randn(8, 128, device=device) x = xla_x.cpu() # xla_x now becomes a device data IR without XLAData @@ -981,18 +981,18 @@ def test_op_sharding_cache(self): met.clear_all() mesh = self._get_mesh((1, self.n_devices)) - t = torch.randn(1, self.n_devices).to(xm.xla_device()) + t = torch.randn(1, self.n_devices).to(torch_xla.device()) xs.mark_sharding(t, mesh, (0, 1)) self.assertIn("CreateOpSharding", met.counter_names()) self.assertEqual(met.counter_value("CreateOpSharding"), 1) # Sharding with the same partition spec should not result in another call - u = torch.randn(1, self.n_devices).to(xm.xla_device()) + u = torch.randn(1, self.n_devices).to(torch_xla.device()) xs.mark_sharding(u, mesh, (0, 1)) self.assertEqual(met.counter_value("CreateOpSharding"), 1) # Changing the partition spec will result in another CreateOpSharding - v = torch.randn(1, self.n_devices).to(xm.xla_device()) + v = torch.randn(1, self.n_devices).to(torch_xla.device()) xs.mark_sharding(v, mesh, (0, None)) self.assertEqual(met.counter_value("CreateOpSharding"), 2) @@ -1130,11 +1130,11 @@ def test_from_cpu_shards_global_shape(self): from_cpu_shards(shards, op_sharding, torch.Size((1,))) def test_backward_optimization_barrier(self): - model = self.SimpleLinear().to(xm.xla_device()) + model = self.SimpleLinear().to(torch_xla.device()) # The first layer won't have gradients in the hook. Not sure why. xs.xla_sharding.apply_backward_optimization_barrier(model.fc2) - x = torch.randn(2, 128).to(xm.xla_device()) + x = torch.randn(2, 128).to(torch_xla.device()) y = model(x) loss = y.sum() loss.backward() @@ -1145,7 +1145,7 @@ def test_backward_optimization_barrier(self): hlo) def test_mark_shard_scalar(self): - x = torch.tensor(1.0).to(xm.xla_device()) + x = torch.tensor(1.0).to(torch_xla.device()) self.assertEqual(len(x.shape), 0) xt = xs.mark_sharding(x, self._get_mesh((1, self.n_devices)), ()) @@ -1174,7 +1174,7 @@ def test_global_mesh(self): self.assertEqual(id(mesh), id(expected_mesh)) def test_mark_manual_sharding(self): - x = torch.zeros(3, 2).to(xm.xla_device()) + x = torch.zeros(3, 2).to(torch_xla.device()) with self.assertRaises(RuntimeError): xt = xs._mark_manual_sharding(x) @@ -1192,7 +1192,7 @@ def test_mark_manual_sharding(self): # xt.global_tensor.cpu() def test_spmd_full_to_shard_shape(self): - x = torch.zeros(8, 8).to(xm.xla_device()) + x = torch.zeros(8, 8).to(torch_xla.device()) with self.assertRaises(RuntimeError): x = torch_xla._XLAC._spmd_full_to_shard_shape(x) @@ -1213,7 +1213,7 @@ def test_spmd_full_to_shard_shape(self): # xx.cpu() # Replicated shape - x = torch.zeros(8, 4).to(xm.xla_device()) + x = torch.zeros(8, 4).to(torch_xla.device()) xt = xs.mark_sharding(x, self._get_mesh((self.n_devices, 1)), (None, None)) xx = torch_xla._XLAC._spmd_full_to_shard_shape(xt.global_tensor) @@ -1225,7 +1225,7 @@ def test_spmd_full_to_shard_shape(self): self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(xx), "{manual}") def test_spmd_shard_to_full_shape(self): - x = torch.zeros(8, 8).to(xm.xla_device()) + x = torch.zeros(8, 8).to(torch_xla.device()) x += 1 # No sharding spec attached. with self.assertRaises(RuntimeError): @@ -1256,7 +1256,7 @@ def test_spmd_shard_to_full_shape(self): self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(xx), "{replicated}") def test_manual_sharding_e2e(self): - x = torch.zeros(8, 8).to(xm.xla_device()) + x = torch.zeros(8, 8).to(torch_xla.device()) mesh = self._get_mesh((1, self.n_devices)) partition_spec = (0, 1) xt = xs.mark_sharding(x, mesh, partition_spec) @@ -1275,7 +1275,7 @@ def test_manual_sharding_e2e(self): def test_manual_sharding_api_e2e(self): xs.set_global_mesh(self._get_mesh((1, self.n_devices))) - x = torch.zeros(8, 8).to(xm.xla_device()) + x = torch.zeros(8, 8).to(torch_xla.device()) partition_spec = (0, 1) xx = xs.enable_manual_sharding(x, partition_spec) @@ -1290,7 +1290,7 @@ def test_manual_sharding_api_e2e(self): "Only runs on TPUv4") def test_spmd_reduce_scatter(self): xs.set_global_mesh(self._get_mesh((1, self.n_devices))) - x = torch.ones(8, 8).to(xm.xla_device()) + x = torch.ones(8, 8).to(torch_xla.device()) # Reduce scatter x = xs.enable_manual_sharding(x, (None, None)).global_tensor @@ -1311,7 +1311,7 @@ def test_spmd_reduce_scatter(self): "Only runs on TPUv4") def test_spmd_reduce_scatter_canonical_index(self): xs.set_global_mesh(self._get_mesh((1, self.n_devices))) - x = torch.ones(8, 8).to(xm.xla_device()) + x = torch.ones(8, 8).to(torch_xla.device()) # Reduce scatter x = xs.enable_manual_sharding(x, (None, None)).global_tensor @@ -1332,7 +1332,7 @@ def test_spmd_reduce_scatter_canonical_index(self): "Only runs on TPUv4") def test_spmd_all_reduce(self): xs.set_global_mesh(self._get_mesh((1, self.n_devices))) - x = torch.ones(8, 8).to(xm.xla_device()) + x = torch.ones(8, 8).to(torch_xla.device()) # all reduce x = xs.enable_manual_sharding(x, (None, None)).global_tensor @@ -1352,7 +1352,7 @@ def test_spmd_all_reduce(self): "Only runs on TPUv4") def test_spmd_all_reduce_scale(self): xs.set_global_mesh(self._get_mesh((1, self.n_devices))) - x = torch.ones(8, 8).to(xm.xla_device()) + x = torch.ones(8, 8).to(torch_xla.device()) scale = 0.25 # all reduce @@ -1487,7 +1487,7 @@ def test_xla_patched_linear(self): import torch_xla.core.xla_model as xm import torch.nn.functional as F - with xm.xla_device(): + with torch_xla.device(): torch_xla.manual_seed(42) x0 = torch.randn(2, 3, requires_grad=True) w0 = torch.randn(4, 3, requires_grad=True) @@ -1536,7 +1536,7 @@ def test_mark_sharding_with_gradients_basic(self): partition_spec = (0, 1) xt1 = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]], dtype=torch.float, - device=xm.xla_device(), + device=torch_xla.device(), requires_grad=True) mesh = self._get_mesh((1, self.n_devices)) xst1 = xs.mark_sharding_with_gradients(xt1, mesh, partition_spec) @@ -1550,7 +1550,7 @@ def test_mark_sharding_with_gradients_annotation(self): partition_spec = (0,) x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=torch.float, - device=xm.xla_device(), + device=torch_xla.device(), requires_grad=True) # Notice that the function does not modify in-place. y = xs.mark_sharding_with_gradients(x, mesh, partition_spec) @@ -1671,11 +1671,11 @@ def test_shard_as(self): partition_spec = (0,) x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=torch.float, - device=xm.xla_device()) + device=torch_xla.device()) x = xs.mark_sharding_with_gradients(x, mesh, partition_spec) y = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=torch.float, - device=xm.xla_device()) + device=torch_xla.device()) x, y = xs.shard_as(x, y) torch_xla.sync() diff --git a/test/spmd/test_xla_sharding_hlo.py b/test/spmd/test_xla_sharding_hlo.py index a5a1159aa9e4..9a1653a7ef93 100644 --- a/test/spmd/test_xla_sharding_hlo.py +++ b/test/spmd/test_xla_sharding_hlo.py @@ -22,8 +22,8 @@ def setUpClass(cls): @patch.dict(os.environ, {"XLA_DUMP_POST_OPTIMIZATIONS": "1"}) def test_xla_sharded_hlo_dump_post_optimizations(self): - t1 = torch.randn(1, 128).to(xm.xla_device()) - t2 = torch.randn(128, 1).to(xm.xla_device()) + t1 = torch.randn(1, 128).to(torch_xla.device()) + t2 = torch.randn(128, 1).to(torch_xla.device()) xs.mark_sharding(t1, self._get_mesh((1, self.n_devices)), (0, 1)) t3 = t1 @ t2 diff --git a/test/spmd/test_xla_spmd_python_api_interaction.py b/test/spmd/test_xla_spmd_python_api_interaction.py index 8530ec3e7e4e..bc89c535608e 100644 --- a/test/spmd/test_xla_spmd_python_api_interaction.py +++ b/test/spmd/test_xla_spmd_python_api_interaction.py @@ -39,22 +39,22 @@ def test_is_master_ordinal(self): self.assertTrue(xm.is_master_ordinal()) def test_xla_device(self): - device = xm.xla_device() + device = torch_xla.device() self.assertEqual(device, torch.device('xla:0')) def test_xla_real_devices(self): - device = xm.xla_device() + device = torch_xla.device() device_type = os.environ['PJRT_DEVICE'] self.assertEqual(xm.xla_real_devices([device]), [device_type + ':0']) def test_xla_device_hw(self): - device = xm.xla_device() + device = torch_xla.device() device_type = os.environ['PJRT_DEVICE'] replication_devices = xm.xla_replication_devices([device]) self.assertEqual(xm.xla_device_hw(device), device_type) def test_xla_replication_devices(self): - device = xm.xla_device() + device = torch_xla.device() device_type = os.environ['PJRT_DEVICE'] replication_devices = xm.xla_replication_devices([device]) self.assertEqual(xm.xla_real_devices([device]), [device_type + ':0']) @@ -127,7 +127,7 @@ def test_runtime_spmd_api(self): # unittest process can persist XLA_USE_SPMD from other test suites, # so t may be on a SPMD or non-SPMD device. If this test is run independently # outside unittest, then it lives on a non-SPMD device. - t = torch.ones(2, 2).to(xm.xla_device()) + t = torch.ones(2, 2).to(torch_xla.device()) # Should enable SPMD without crashing. xr.use_spmd() @@ -149,7 +149,7 @@ def setUpClass(cls): @unittest.skipIf(xr.device_type() not in ['TPU', 'CUDA'], f"TPU/GPU autocast test.") def test_xla_autocast_api(self): - device = xm.xla_device() + device = torch_xla.device() t1 = torch.ones([2, 3], device=device, dtype=torch.float32) t2 = torch.ones([3, 2], device=device, dtype=torch.float32) with autocast(device, dtype=torch.bfloat16): diff --git a/test/spmd/test_xla_virtual_device.py b/test/spmd/test_xla_virtual_device.py index 38d04ca7a95d..5be7f95daf98 100644 --- a/test/spmd/test_xla_virtual_device.py +++ b/test/spmd/test_xla_virtual_device.py @@ -23,21 +23,21 @@ def test_mark_sharding(self): partition_spec = (0, 1) xt1 = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]], dtype=torch.float, - device=xm.xla_device()) + device=torch_xla.device()) xs.mark_sharding(xt1, self._get_mesh((1, self.n_devices)), partition_spec) self.assertTrue( torch.allclose( xt1 + 0, torch.tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=torch.float, - device=xm.xla_device()))) + device=torch_xla.device()))) def test_metrics_recorded(self): met.clear_counters() partition_spec = (0, 1) xt1 = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]], dtype=torch.float, - device=xm.xla_device()) + device=torch_xla.device()) xs.mark_sharding(xt1, self._get_mesh((1, self.n_devices)), partition_spec) self.assertIn("VirtualDeviceUsage", met.counter_names()) self.assertNotEqual(met.counter_value("VirtualDeviceUsage"), 0) @@ -45,7 +45,7 @@ def test_metrics_recorded(self): def test_model_weight_metrics(self): met.clear_counters() partition_spec = (0, 1) - model = nn.Linear(128, 64).to(xm.xla_device()) + model = nn.Linear(128, 64).to(torch_xla.device()) xs.mark_sharding(model.weight, self._get_mesh((1, self.n_devices)), partition_spec) self.assertIn("VirtualDeviceUsage", met.counter_names()) @@ -54,17 +54,17 @@ def test_model_weight_metrics(self): def test_no_sharding(self): t1 = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]], dtype=torch.float, - device=xm.xla_device()) + device=torch_xla.device()) t2 = torch.tensor([[8, 7, 6, 5, 4, 3, 2, 1]], dtype=torch.float, - device=xm.xla_device()) + device=torch_xla.device()) t3 = t1 + t2 t3_expected = [9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0] self.assertEqual(t3.tolist()[0], t3_expected) def test_no_sharding_1d(self): - t1 = torch.arange(9, dtype=torch.float, device=xm.xla_device()) - t2 = torch.arange(9, dtype=torch.float, device=xm.xla_device()) + t1 = torch.arange(9, dtype=torch.float, device=torch_xla.device()) + t2 = torch.arange(9, dtype=torch.float, device=torch_xla.device()) t3 = t1 + t2 t3_expected = list(range(0, 18, 2)) self.assertEqual(t3.tolist(), t3_expected) @@ -75,7 +75,7 @@ def test_outbound_data_metrics(self): met.clear_all() xt1 = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]], dtype=torch.float, - device=xm.xla_device()) + device=torch_xla.device()) xs.mark_sharding(xt1, self._get_mesh((1, self.n_devices)), partition_spec) outbound_with_virtual_device = met.metric_data("OutboundData")[1] @@ -88,7 +88,7 @@ def test_non_tensor_scalar(self): sharding_spec = xs.ShardingSpec(self._get_mesh((1, self.n_devices)), (0, 1)) # tensor will have device as `SPMD:0` in c++ xt1 = xm.send_cpu_data_to_device([torch.randn(3, 3)], - xm.xla_device(), + torch_xla.device(), input_sharding=sharding_spec)[0] # we will transfer 0.5 as a device_data to the 'SPMD:0' device, need to make sure # that virtual device can handle this case. @@ -101,7 +101,7 @@ def test_sync_on_virtual_device(self): sharding_spec = xs.ShardingSpec(self._get_mesh((1, self.n_devices)), (0, 1)) # tensor will have device as `SPMD:0` in c++ xt1 = xm.send_cpu_data_to_device([torch.randn(3, 3)], - xm.xla_device(), + torch_xla.device(), input_sharding=sharding_spec)[0] xt2 = xt1 / 0.5 torch_xla.sync(wait=True) @@ -111,7 +111,7 @@ def test_sync_on_virtual_device(self): def test_virtual_device_no_upload(self): met.clear_all() - device = xm.xla_device() + device = torch_xla.device() t1 = torch.randn(5, 5).to(device) t1_debug_info = torch_xla._XLAC._get_xla_tensor_debug_info(t1) # t1's upload to device should be deferred @@ -125,7 +125,7 @@ def test_virtual_device_no_upload(self): def test_virtual_device_upload_after_mark_sharding(self): met.clear_all() partition_spec = (0, 1) - device = xm.xla_device() + device = torch_xla.device() t1 = torch.randn(8, 8).to(device) t1_debug_info = torch_xla._XLAC._get_xla_tensor_debug_info(t1) self.assertIn("Tensor on host: with size [8, 8]", t1_debug_info) @@ -139,7 +139,7 @@ def test_virtual_device_upload_after_mark_sharding(self): def test_virtual_device_upload_after_tracing(self): met.clear_all() - device = xm.xla_device() + device = torch_xla.device() t1 = torch.randn(8, 8).to(device) t1_debug_info = torch_xla._XLAC._get_xla_tensor_debug_info(t1) self.assertIn("Tensor on host: with size [8, 8]", t1_debug_info) @@ -152,7 +152,7 @@ def test_virtual_device_upload_after_tracing(self): def test_virtual_device_upload_for_sharded_dataloader(self): met.clear_counters() - device = xm.xla_device() + device = torch_xla.device() sharding_spec = xs.ShardingSpec(self._get_mesh((1, self.n_devices)), (0, 1)) # tensor will have device as `SPMD:0` in c++ t1 = xm.send_cpu_data_to_device([torch.randn(8, 8)], diff --git a/test/stablehlo/test_composite.py b/test/stablehlo/test_composite.py index 64d08e4879fb..6e9521c79d5c 100644 --- a/test/stablehlo/test_composite.py +++ b/test/stablehlo/test_composite.py @@ -70,7 +70,7 @@ class XlaMarkPatternTest(unittest.TestCase): def run_func_get_stablehlo(self, f, input_args): - device = xm.xla_device() + device = torch_xla.device() input_args = pytree.tree_map_only(torch.Tensor, lambda x: x.to(device=device), input_args) exported = torch.export.export(AsModule(f), input_args) diff --git a/test/stablehlo/test_implicit_broadcasting.py b/test/stablehlo/test_implicit_broadcasting.py index 24c8a80c77d7..10fbe5789981 100644 --- a/test/stablehlo/test_implicit_broadcasting.py +++ b/test/stablehlo/test_implicit_broadcasting.py @@ -10,7 +10,7 @@ # The following tests cover the implcit-broadcasting for static and bounded # dynamic shapes. -device = xm.xla_device() +device = torch_xla.device() class ImplicitBroadcasting(unittest.TestCase): diff --git a/test/stablehlo/test_pt2e_qdq.py b/test/stablehlo/test_pt2e_qdq.py index 1f6ad974f203..3fc1276ec612 100644 --- a/test/stablehlo/test_pt2e_qdq.py +++ b/test/stablehlo/test_pt2e_qdq.py @@ -54,7 +54,7 @@ def count_qdq_ops(g: torch.fx.Graph): class PT2EExportTest(unittest.TestCase): def test_per_tensor_qdq(self): - device = xm.xla_device() + device = torch_xla.device() x = torch.randn(2, 3, 4, 5).to(device) x = torch.ops.quantized_decomposed.quantize_per_tensor( x, 0.4, 2, -128, 127, torch.int8) @@ -68,7 +68,7 @@ def test_per_tensor_qdq(self): self.assertEqual(stablehlo_txt.count("stablehlo.uniform_dequantize"), 1) def test_per_channel_qdq(self): - device = xm.xla_device() + device = torch_xla.device() x = torch.randn(2, 3, 4, 5).to(device) scale = torch.tensor([3.2, 5.3, 0.1, 10]).to(device) zero_point = torch.tensor([1, 2, -1, -2], dtype=torch.int64).to(device) diff --git a/test/stablehlo/test_stablehlo_compile.py b/test/stablehlo/test_stablehlo_compile.py index 243f48fbff7c..a57faf7ff5f2 100644 --- a/test/stablehlo/test_stablehlo_compile.py +++ b/test/stablehlo/test_stablehlo_compile.py @@ -21,7 +21,7 @@ def test_resnet18_stablehlo_compile(self): torch_input = torch.tensor(np_input).float() cpu_output = resnet18(torch_input) # Run ResNet on XLA device. - device = xm.xla_device() + device = torch_xla.device() # materalize the fake data for test purpose torch_xla.sync() xm.wait_device_ops() diff --git a/test/stablehlo/test_stablehlo_custom_call.py b/test/stablehlo/test_stablehlo_custom_call.py index 112abb2135aa..a315bbc230db 100644 --- a/test/stablehlo/test_stablehlo_custom_call.py +++ b/test/stablehlo/test_stablehlo_custom_call.py @@ -118,7 +118,7 @@ def forward(self, x): # self.assertTrue("api_version = 1" in shlo_text) def test_place_to_host_device(self): - dev = xm.xla_device() + dev = torch_xla.device() a = torch.ones(10, device=dev) b = place_to_host(a) shlo_text = xm.get_stablehlo([b]) @@ -137,7 +137,7 @@ def test_place_to_host_device(self): def test_place_to_host_device_autograd(self): # Test that gradient can flow through place_to_host and place_to_device ops. - dev = xm.xla_device() + dev = torch_xla.device() a = torch.ones(10, device=dev, requires_grad=True) b = place_to_host(a) c = b.sum() @@ -155,7 +155,7 @@ def test_place_to_host_device_aot_autograd(self): # specifically `aot_function`. from functorch.compile import aot_function, make_boxed_func # type: ignore - dev = xm.xla_device() + dev = torch_xla.device() a = torch.ones(10, device=dev, requires_grad=True) def my_fn(x): diff --git a/test/stablehlo/test_stablehlo_inference.py b/test/stablehlo/test_stablehlo_inference.py index 311aed0c4b54..a29b66ebceaa 100644 --- a/test/stablehlo/test_stablehlo_inference.py +++ b/test/stablehlo/test_stablehlo_inference.py @@ -67,7 +67,7 @@ def forward(self, x, y): output = m(*data) exported = export_torch_model(m, data) - device = xm.xla_device() + device = torch_xla.device() data = pytree.tree_map_only(torch.Tensor, lambda x: x.to(device), data) output2 = exported(*data).cpu() @@ -91,7 +91,7 @@ def forward(self, inputs): output = m(*data) exported = export_torch_model(m, data) - device = xm.xla_device() + device = torch_xla.device() data = pytree.tree_map_only(torch.Tensor, lambda x: x.to(device), data) output2 = exported(*data) self.assertEqual(len(output2), 2) diff --git a/test/stablehlo/test_stablehlo_save_load.py b/test/stablehlo/test_stablehlo_save_load.py index 1e1e41c3513a..71ff463578cb 100644 --- a/test/stablehlo/test_stablehlo_save_load.py +++ b/test/stablehlo/test_stablehlo_save_load.py @@ -17,7 +17,7 @@ class StableHloDumpTest(unittest.TestCase): def test_simple(self): - device = xm.xla_device() + device = torch_xla.device() x = torch.tensor([3], device=device) y = torch.tensor([3], device=device) z = x + y @@ -26,7 +26,7 @@ def test_simple(self): self.assertEqual(stablehlo.count("stablehlo.add"), 1) def test_resnet18(self): - device = xm.xla_device() + device = torch_xla.device() xla_resnet18 = torchvision.models.resnet18() xla_resnet18.eval() xla_resnet18 = xla_resnet18.to(device) @@ -66,7 +66,7 @@ class SimpleExportTest(unittest.TestCase): def export_stable_hlo(self, model, args, kwargs=None): if kwargs is None: kwargs = {} - device = xm.xla_device() + device = torch_xla.device() model.eval() model = model.to(device) args = tuple(i.to(device) for i in args if hasattr(i, 'to')) diff --git a/test/stablehlo/test_unbounded_dynamism.py b/test/stablehlo/test_unbounded_dynamism.py index 3cd17a7fe340..aa33a6533437 100644 --- a/test/stablehlo/test_unbounded_dynamism.py +++ b/test/stablehlo/test_unbounded_dynamism.py @@ -19,7 +19,7 @@ compare_exported_program_and_saved_model_result, has_tf_package, load_save_model_and_inference, wrap_func_as_nn_module) -device = xm.xla_device() +device = torch_xla.device() os.environ['EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM'] = '1' diff --git a/test/stablehlo/test_xla_export_interpreter.py b/test/stablehlo/test_xla_export_interpreter.py index 85a4607d3a85..51a73a402703 100644 --- a/test/stablehlo/test_xla_export_interpreter.py +++ b/test/stablehlo/test_xla_export_interpreter.py @@ -7,7 +7,7 @@ import torch_xla.core.xla_model as xm from torch_xla.stablehlo import exported_program_to_stablehlo -device = xm.xla_device() +device = torch_xla.device() class XLAExportInterpreterTest(unittest.TestCase): diff --git a/test/test_as_stride_use_slice.py b/test/test_as_stride_use_slice.py index 48c65bb80f69..454ff78caeb2 100644 --- a/test/test_as_stride_use_slice.py +++ b/test/test_as_stride_use_slice.py @@ -100,8 +100,8 @@ def pure_strided_wrapper(self, use_xla, use_aten_slice): ss = StridedAndSlice().to("cpu") input = torch.randn((2, 4, 256, 256), device="cpu").requires_grad_() if use_xla: - ss.to(xm.xla_device()) - input = input.to(xm.xla_device()) + ss.to(torch_xla.device()) + input = input.to(torch_xla.device()) return ss(input, use_aten_slice) @parameterized.named_parameters( @@ -137,7 +137,7 @@ def compiler(gm, _): cpu_output = compiler_func(input_cpu, use_aten_slice=use_aten_slice) torch_xla.sync() - input_xla = input_xla.to(xm.xla_device()) + input_xla = input_xla.to(torch_xla.device()) xla_output = compiler_func(input_xla, use_aten_slice=use_aten_slice) torch_xla.sync() torch.testing.assert_close(cpu_output, xla_output.cpu()) diff --git a/test/test_autocast.py b/test/test_autocast.py index 1e0a82a2870c..7c743aab7cc1 100644 --- a/test/test_autocast.py +++ b/test/test_autocast.py @@ -282,7 +282,7 @@ def cast(val, to_type): add_kwargs = {} self.assertFalse(self.is_autocast_enabled()) - with autocast(xm.xla_device(), dtype=autocast_dtype): + with autocast(torch_xla.device(), dtype=autocast_dtype): self.assertTrue(self.is_autocast_enabled()) out_type = out_type if out_type is not None else run_as_type @@ -332,7 +332,7 @@ def compare(first, second): # Compare numerics to Python-side "autocasting" that (we expect) does the same thing # as the C++-side autocasting, and should be bitwise accurate. output_to_compare = output if output is not None else output_method - with autocast(xm.xla_device(), enabled=False): + with autocast(torch_xla.device(), enabled=False): self.assertFalse(self.is_autocast_enabled()) if module is not None and hasattr(module, op): @@ -355,9 +355,9 @@ class TestAutocastCuda(TestAutocastBase): @classmethod def setUpClass(cls): super().setUpClass() - cls.autocast_lists = AutocastTestLists(torch.device(xm.xla_device())) + cls.autocast_lists = AutocastTestLists(torch.device(torch_xla.device())) cls.autocast_lists_extra = AutocastCudaTestExtraLists( - torch.device(xm.xla_device())) + torch.device(torch_xla.device())) cls.autocast_unsupported_lists = AutocastCudaTestUnsupportedLists() def setUp(self): @@ -439,7 +439,7 @@ class TestAutocastTPU(TestAutocastBase): @classmethod def setUpClass(cls): super().setUpClass() - cls.autocast_lists = AutocastTPUTestLists(torch.device(xm.xla_device())) + cls.autocast_lists = AutocastTPUTestLists(torch.device(torch_xla.device())) def setUp(self): super(TestAutocastTPU, self).setUp() @@ -481,7 +481,7 @@ def test_autocast_methods_expect_builtin_promote(self): op, args, torch.float32, module=None, out_type=out_type) def test_autocast_tpu_check_dtype(self): - with autocast(xm.xla_device(), dtype=torch.float16): + with autocast(torch_xla.device(), dtype=torch.float16): assert not torch.is_autocast_xla_enabled() @@ -491,7 +491,7 @@ class TestOtherOps(unittest.TestCase): not xm.get_xla_supported_devices("GPU"), "the behavior of batch_norm autocast on GPU is different from others") def test_batch_norm_gpu(self): - device = xm.xla_device() + device = torch_xla.device() data = torch.randn(4, 16, 32, 32, device=device, dtype=torch.bfloat16) batch_norm = torch.nn.BatchNorm2d(16) with autocast(device, dtype=torch.bfloat16): @@ -504,7 +504,7 @@ def test_batch_norm_gpu(self): not xm.get_xla_supported_devices("TPU"), "the behavior of batch_norm autocast on TPU is different from others") def test_batch_norm_tpu(self): - device = xm.xla_device() + device = torch_xla.device() data = torch.randn(4, 16, 32, 32, device=device, dtype=torch.bfloat16) batch_norm = torch.nn.BatchNorm2d(16) with autocast(device, dtype=torch.bfloat16): diff --git a/test/test_autocast_xla.py b/test/test_autocast_xla.py index 529219a98c65..e287cb1bae55 100644 --- a/test/test_autocast_xla.py +++ b/test/test_autocast_xla.py @@ -6,7 +6,7 @@ import torch_xla.distributed.spmd.xla_sharding as xs -device = xm.xla_device() +device = torch_xla.device() class TestAutocastXla(unittest.TestCase): diff --git a/test/test_compilation_cache_utils.py b/test/test_compilation_cache_utils.py index 0fb12c32ae5b..0ac8a013d814 100644 --- a/test/test_compilation_cache_utils.py +++ b/test/test_compilation_cache_utils.py @@ -31,7 +31,7 @@ def _test_spawn(fn, args): class TestGraphHash(parameterized.TestCase): def _test_num_graph_hash(self, use_dynamo, use_persistent): - xla_dev = xm.xla_device() + xla_dev = torch_xla.device() model = M().to(device=xla_dev) input_shape = (10, 5) if use_dynamo: diff --git a/test/test_core_aten_ops.py b/test/test_core_aten_ops.py index 6e2ac67e4f52..036ecf2d6c53 100644 --- a/test/test_core_aten_ops.py +++ b/test/test_core_aten_ops.py @@ -46,7 +46,7 @@ def run_export_and_compare(testcase, atol=1e-3, rtol=1e-5, equal_nan=True): - device = xm.xla_device() + device = torch_xla.device() with testcase.subTest('torch_eval'): res = func(*args, **kwargs) with testcase.subTest('torch_xla_eval'): @@ -2932,7 +2932,7 @@ def test_aten_randperm_0(self): kwargs = dict() pytorch = torch.randperm(20) - xla = torch.randperm(20, device=xm.xla_device()) + xla = torch.randperm(20, device=torch_xla.device()) xla_detached = xla.detach().cpu() # Check equal lengths and that the sorted sets are equal. Since these numbers are randomly diff --git a/test/test_data_type.py b/test/test_data_type.py index ecd554c187ec..46beea6d5115 100644 --- a/test/test_data_type.py +++ b/test/test_data_type.py @@ -29,8 +29,8 @@ def _set_env(self, **kwargs): os.environ[key] = value def _test_datatype(self, dtype, expected_type, op): - t1 = torch.tensor([2, 3], dtype=dtype, device=xm.xla_device()) - t2 = torch.tensor([2, 3], dtype=dtype, device=xm.xla_device()) + t1 = torch.tensor([2, 3], dtype=dtype, device=torch_xla.device()) + t2 = torch.tensor([2, 3], dtype=dtype, device=torch_xla.device()) t3 = op(t1, t2) self.assertEqual(t3.dtype, dtype) diff --git a/test/test_env_var_mapper.py b/test/test_env_var_mapper.py index 26cb2f0870eb..e4dcef2ba8cb 100644 --- a/test/test_env_var_mapper.py +++ b/test/test_env_var_mapper.py @@ -15,7 +15,7 @@ def check_env_flag(name, default=''): class EnvVarMapperTest(unittest.TestCase): def test_xla_ir_debug_(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() with xp.Trace('test_xla_ir_debug'): t = torch.tensor([2.0, 3.0], dtype=torch.float, device=xla_device) diff --git a/test/test_fsdp_auto_wrap.py b/test/test_fsdp_auto_wrap.py index 3ed373f98dcc..019612899697 100644 --- a/test/test_fsdp_auto_wrap.py +++ b/test/test_fsdp_auto_wrap.py @@ -35,7 +35,7 @@ def forward(self, x): "This test fails only on GPU with 03/30 TF-pin update (https://github.com/pytorch/xla/pull/4840)" ) def test(self): - dev = xm.xla_device() + dev = torch_xla.device() input = torch.zeros([16, 16], device=dev) model = self.MyModel(input_size=16, hidden_size=4) model = XlaFullyShardedDataParallel( @@ -48,7 +48,7 @@ def test(self): def _mp_fn(index): - device = xm.xla_device() + device = torch_xla.device() if xm.xla_device_hw(device) in ('TPU', 'CUDA'): test = unittest.main(exit=False) sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/test/test_grad_checkpoint.py b/test/test_grad_checkpoint.py index 57826bbc6748..e4e318ba8310 100644 --- a/test/test_grad_checkpoint.py +++ b/test/test_grad_checkpoint.py @@ -11,7 +11,7 @@ def run(): - device = xm.xla_device() + device = torch_xla.device() model = torch.nn.ModuleList([ torch.nn.Sequential( torch.nn.Conv2d(1024, 1024, 1), diff --git a/test/test_gradient_accumulation.py b/test/test_gradient_accumulation.py index 6e431a4237d9..62ecfc431132 100644 --- a/test/test_gradient_accumulation.py +++ b/test/test_gradient_accumulation.py @@ -23,7 +23,7 @@ def forward(self, x): class GradAccumulationTest(XlaTestCase): def setUp(self): - self.device = xm.xla_device() + self.device = torch_xla.device() torch.manual_seed(123) def test_basic(self): diff --git a/test/test_hlo_metadata.py b/test/test_hlo_metadata.py index 82eebd9f3ada..da9d24bb2342 100644 --- a/test/test_hlo_metadata.py +++ b/test/test_hlo_metadata.py @@ -78,10 +78,10 @@ def test_metadata(self): model = torch.nn.Sequential(layer1, nl1, layer2, nl2) with CustomOpNameLowering() as c: - model = model.to(device=xm.xla_device()) - inp = torch.rand(4, 4, device=xm.xla_device()) + model = model.to(device=torch_xla.device()) + inp = torch.rand(4, 4, device=torch_xla.device()) #inp = torch.rand(4, 4) - #inp = inp.to(device=xm.xla_device()) + #inp = inp.to(device=torch_xla.device()) out = model(inp) # Get outer frames diff --git a/test/test_inplace_update.py b/test/test_inplace_update.py index f68811ecd0de..704888d4f6e7 100644 --- a/test/test_inplace_update.py +++ b/test/test_inplace_update.py @@ -11,7 +11,7 @@ class InplaceUpdateTest(unittest.TestCase): def test_aten_op_after_full_update(self): - device = xm.xla_device() + device = torch_xla.device() t = torch.ones(2, 1, device=device) w = torch.ones(1, 2, device=device) t.zero_() @@ -21,7 +21,7 @@ def test_aten_op_after_full_update(self): self.assertTrue(torch.all(torch.eq(y, expected))) def test_aten_op_after_partial_update(self): - device = xm.xla_device() + device = torch_xla.device() t = torch.ones(2, 1, device=device) w = torch.ones(1, 2, device=device) t[0][0] = 0 @@ -31,7 +31,7 @@ def test_aten_op_after_partial_update(self): self.assertTrue(torch.all(torch.eq(y, expected))) def test_non_aten_op_after_full_update(self): - device = xm.xla_device() + device = torch_xla.device() t = torch.ones(2, 1, device=device) w = torch.ones(1, 2, device=device) t.zero_() @@ -41,7 +41,7 @@ def test_non_aten_op_after_full_update(self): self.assertTrue(torch.all(torch.eq(y, expected))) def test_non_aten_op_after_partial_update(self): - device = xm.xla_device() + device = torch_xla.device() t = torch.ones(2, 1, device=device) w = torch.ones(1, 2, device=device) t[0][0] = 0 @@ -53,7 +53,7 @@ def test_non_aten_op_after_partial_update(self): def test_xm_save(self): with temporary_env( XLA_DISABLE_FUNCTIONALIZATION="0", XLA_ENABLE_PARAM_ALIASING="0"): - xla_device = xm.xla_device() + xla_device = torch_xla.device() t1 = torch.tensor([1], device=xla_device) t2 = t1.detach() torch_xla.sync() diff --git a/test/test_input_output_aliases.py b/test/test_input_output_aliases.py index 906ffc326834..3f20f9d25c97 100644 --- a/test/test_input_output_aliases.py +++ b/test/test_input_output_aliases.py @@ -38,7 +38,7 @@ def config_context(value): class InputOutputAliasesTest(parameterized.TestCase): def test_non_view(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() t1 = torch.randn(4, 2, 2).to(xla_device) t2 = torch.randn(4, 2, 2).to(xla_device) torch_xla.sync() @@ -53,7 +53,7 @@ def test_non_view(self): self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 2.0) def test_aliasing_with_cloned(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() met.clear_all() t1 = torch.randn(4, 2, 2).to(xla_device) # t1_cloned share the same storage as t1 @@ -66,7 +66,7 @@ def test_aliasing_with_cloned(self): self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 1.0) def test_aliasing_across_custom_inplace(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() met.clear_all() t1 = torch.randn(4, 5).to(xla_device) t1 *= t1 @@ -78,7 +78,7 @@ def test_aliasing_across_custom_inplace(self): self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 2.0) def test_aliasing_across_sync(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() met.clear_all() t1 = torch.randn(4, 5).to(xla_device) t1 += 1 @@ -96,7 +96,7 @@ def test_aliasing_with_multiple_inplace_update(self): BLOCK_SIZE = 16 DTYPE = torch.bfloat16 num_blocks = 1024 - device = xm.xla_device() + device = torch_xla.device() key = torch.randn( BATCH_SIZE * SEQ_LEN, NUM_KV_HEADS, @@ -145,7 +145,7 @@ def try_grad_accum(model, device, train_x, train_label, accum_steps): torch_xla.sync() return [p.grad.to('cpu').numpy() for p in model.parameters()] - dev = xm.xla_device() + dev = torch_xla.device() train_x_sample = torch.rand((1, 28 * 28)) train_label_sample = torch.tensor([5]) c_model = MLP().to('cpu') @@ -171,7 +171,7 @@ def test_separate_graphs(self): """ Test that paramater aliasing differences should produce different graphs. """ - xla_device = xm.xla_device() + xla_device = torch_xla.device() t0 = torch.tensor([1], device=xla_device) t1 = torch.tensor([2], device=xla_device) torch_xla.sync() @@ -190,7 +190,7 @@ def test_xm_save_no_aliasing(self): """ Test that xm.save() does not perform aliasing. """ - xla_device = xm.xla_device() + xla_device = torch_xla.device() t0 = torch.tensor([1], device=xla_device) t1 = torch.tensor([2], device=xla_device) torch_xla.sync() @@ -212,7 +212,7 @@ def test_device_data_cache_no_aliasing(self): """ Test that device data in DataCache are not aliased. """ - xla_device = xm.xla_device() + xla_device = torch_xla.device() t0 = torch.tensor(42, device=xla_device) # drops the read-only bit on t0's device_data @@ -235,7 +235,7 @@ def test_device_data_cache_no_aliasing(self): def test_user_config_donation_with_ltc_donation(self): met.clear_all() - xla_device = xm.xla_device() + xla_device = torch_xla.device() t0 = torch.randn(4, 2, 2).to(xla_device) t1 = torch.randn(4, 2, 2).to(xla_device) self.assertTrue(torch_xla._XLAC._set_buffer_donation(t0, True)) @@ -255,7 +255,7 @@ def test_user_config_donation_with_ltc_donation_graph_sync( self, enable_buffer_donor_config): with alias_with_buffer_donor_config_context(enable_buffer_donor_config): met.clear_all() - xla_device = xm.xla_device() + xla_device = torch_xla.device() t0 = torch.randn(4, 2, 2).to(xla_device) t1 = torch.randn(4, 2, 2).to(xla_device) self.assertTrue(torch_xla._XLAC._set_buffer_donation(t0, True)) @@ -279,7 +279,7 @@ def test_user_config_donation_with_ltc_donation_graph_sync( def test_user_config_donation_with_ltc_donation_overlap(self): met.clear_all() - xla_device = xm.xla_device() + xla_device = torch_xla.device() t0 = torch.randn(4, 2, 2).to(xla_device) self.assertTrue(torch_xla._XLAC._set_buffer_donation(t0, True)) self.assertTrue(torch_xla._XLAC._get_buffer_donation(t0)) @@ -291,7 +291,7 @@ def test_user_config_donation_with_ltc_donation_overlap(self): def test_user_config_donation(self): with alias_with_buffer_donor_config_context(True): met.clear_all() - xla_device = xm.xla_device() + xla_device = torch_xla.device() t0 = torch.randn(4, 2, 2).to(xla_device) self.assertTrue(torch_xla._XLAC._set_buffer_donation(t0, True)) self.assertTrue(torch_xla._XLAC._get_buffer_donation(t0)) @@ -308,7 +308,7 @@ def test_user_config_donation(self): def test_user_config_donation_inplace_aliasing(self): with alias_with_buffer_donor_config_context(True): met.clear_all() - xla_device = xm.xla_device() + xla_device = torch_xla.device() t0 = torch.randn(4, 2, 2).to(xla_device) self.assertTrue(torch_xla._XLAC._set_buffer_donation(t0, True)) self.assertTrue(torch_xla._XLAC._get_buffer_donation(t0)) @@ -322,7 +322,7 @@ def test_user_config_donation_inplace_aliasing(self): def test_user_config_donation_no_op_sync(self): with alias_with_buffer_donor_config_context(True): - xla_device = xm.xla_device() + xla_device = torch_xla.device() t0 = torch.randn(4, 2, 2).to(xla_device) self.assertTrue(torch_xla._XLAC._set_buffer_donation(t0, True)) torch_xla.sync() @@ -331,7 +331,7 @@ def test_user_config_donation_no_op_sync(self): self.assertTrue(torch_xla._XLAC._get_buffer_donation(t0)) def test_no_op_sync_keep_buffer_donation(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() input = torch.randn(5, 5).to(xla_device) self.assertTrue(torch_xla._XLAC._set_buffer_donation(input, True)) torch_xla.sync() @@ -346,7 +346,7 @@ def test_device_data_node_tracing_aliasing(self): for a given set of unmutated input tensor during its tracing. This helps ensure that aliasings can be retained if using the binding for tracing purposes. """ - xla_device = xm.xla_device() + xla_device = torch_xla.device() t0 = torch.tensor(10).to(xla_device) t1 = t0 + 5 diff --git a/test/test_jax_interop.py b/test/test_jax_interop.py index e69821cfe219..5016462b982e 100644 --- a/test/test_jax_interop.py +++ b/test/test_jax_interop.py @@ -14,7 +14,7 @@ def setUp(self): def test_call_jax(self): """Test that we can call a JAX function from PyTorch/XLA lazy tensor tracing.""" - dev = xm.xla_device() + dev = torch_xla.device() a = torch.ones((3, 3), device=dev) def f(a, b): @@ -29,7 +29,7 @@ def f(a, b): def test_call_jax_input_pytree(self): """Test that call_jax works with PyTree inputs.""" - dev = xm.xla_device() + dev = torch_xla.device() a = torch.ones((2, 2), device=dev) b = torch.ones((2, 2), device=dev) * 2 @@ -55,7 +55,7 @@ def f(inputs): def test_call_jax_output_pytree(self): """Test that call_jax works with PyTree outputs.""" - dev = xm.xla_device() + dev = torch_xla.device() a = torch.ones((2, 2), device=dev) def f(a): @@ -89,7 +89,7 @@ def f(a): def test_call_jax_some_arg_unused(self): """Test when the jax function doesn't use some input arguments.""" - dev = xm.xla_device() + dev = torch_xla.device() a = torch.randn((3, 3), device=dev) b = torch.randn((3, 3), device=dev) c = torch.randn((3, 3), device=dev) @@ -106,7 +106,7 @@ def f(a, b, c, d): def test_call_jax_grad(self): """Test calling a simple jax.grad transformed function.""" - dev = xm.xla_device() + dev = torch_xla.device() a = torch.randn((3, 3), device=dev, requires_grad=True) b = torch.randn((3, 3), device=dev, requires_grad=True) torch_xla.sync() @@ -143,7 +143,7 @@ def f_jax(a, b): def test_call_jax_non_tensor_args(self): """Test that call_jax works with non-tensor arguments.""" - dev = xm.xla_device() + dev = torch_xla.device() a = torch.ones((3, 3), device=dev) def f(a, num: float, string: str, dictionary: dict, none): @@ -173,7 +173,7 @@ def test_call_jax_cache_hlo(self): starting_cache_misses = xb._jax_to_xla_computation_cache_elements() # Let's trace two different jax functions a couple of times. - dev = xm.xla_device() + dev = torch_xla.device() a = torch.ones((3, 3), device=dev) def f(a, b): @@ -198,7 +198,7 @@ def test_call_jax_cache_by_shape(self): starting_cache_misses = xb._jax_to_xla_computation_cache_elements() # Let's trace the same jax function with different shapes. - dev = xm.xla_device() + dev = torch_xla.device() a = torch.ones((3, 3), device=dev) b = torch.ones((2, 2), device=dev) @@ -217,7 +217,7 @@ def test_call_jax_cache_by_tree_spec(self): starting_cache_misses = xb._jax_to_xla_computation_cache_elements() # Let's trace the same jax function with different tree specs. - dev = xm.xla_device() + dev = torch_xla.device() a = torch.ones((3, 3), device=dev) b = torch.ones((3, 2), device=dev) @@ -237,7 +237,7 @@ def test_call_jax_cache_by_static_args(self): starting_cache_misses = xb._jax_to_xla_computation_cache_elements() # Let's trace the same jax function with different static args. - dev = xm.xla_device() + dev = torch_xla.device() a = torch.ones((3, 3), device=dev) def f(a, num: float): @@ -255,7 +255,7 @@ def test_call_jax_with_different_jax_config(self): import jax starting_cache_misses = xb._jax_to_xla_computation_cache_elements() - dev = xm.xla_device() + dev = torch_xla.device() a = torch.ones((3, 3), device=dev) def f(a, b): diff --git a/test/test_metrics.py b/test/test_metrics.py index f124784fbfee..098b516079d6 100644 --- a/test/test_metrics.py +++ b/test/test_metrics.py @@ -24,7 +24,7 @@ def check_metrics_file(): class MetricsTest(unittest.TestCase): def test_clear_counters(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() t1 = torch.tensor(100, device=xla_device) t1 += 2 self.assertIn("xla::add", met.metrics_report()) @@ -39,7 +39,7 @@ def test_clear_counters(self): assert (len(met.counter_names()) > 0) def test_clear_metrics(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() t1 = torch.tensor(156, device=xla_device) self.assertIn("TensorToData", met.metrics_report()) assert (len(met.metric_names()) > 0) @@ -52,7 +52,7 @@ def test_clear_metrics(self): assert (len(met.metric_names()) > 0) def test_tracing_time_metrics(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() met.clear_all() t1 = torch.tensor(156, device=xla_device) t2 = t1 + 100 @@ -61,7 +61,7 @@ def test_tracing_time_metrics(self): def test_eager_metrics(self): with torch_xla.experimental.eager_mode_context(True): - xla_device = xm.xla_device() + xla_device = torch_xla.device() met.clear_all() t1 = torch.tensor(156, device=xla_device) t2 = t1 + 100 @@ -78,7 +78,7 @@ def test_eager_metrics(self): self.assertNotIn('ExecuteTime', met.metric_names()) def test_short_metrics_report_default_list(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() t1 = torch.tensor(1456, device=xla_device) t2 = t1 * 2 torch_xla.sync() @@ -100,7 +100,7 @@ def test_short_metrics_report_default_list(self): assert check_metrics_file() def test_short_metrics_report_custom_list(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() t1 = torch.tensor(100, device=xla_device) t2 = t1 * 2 t1 += 2 @@ -120,7 +120,7 @@ def test_short_metrics_report_custom_list(self): self.assertIn('InputOutputAliasCount', short_report) def test_short_metrics_fallback_counter(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() t1 = torch.tensor(100, device=xla_device) t2 = t1 * 2 # this will trigger a aten::_local_scalar_dense which is the same as fallback counter @@ -135,7 +135,7 @@ def test_short_metrics_fallback_counter(self): def test_metrics_report(self): # TODO(jwtan): Add test to cover TrimIrGraph, SyncTensorsToData, TransferToDeviceAsync, IrValueTensorToXlaData - xla_device = xm.xla_device() + xla_device = torch_xla.device() t1 = torch.tensor(2077, device=xla_device) t2 = t1 * 2 torch_xla.sync() @@ -207,12 +207,12 @@ def test_metrics_report(self): @unittest.skipIf(xr.device_type() != "CPU", f"This test only works on CPU.") def test_execute_time_metric(self): # Initialize the client before starting the timer. - xm.xla_device() + torch_xla.device() begin = time.perf_counter_ns() value = torch.randn( - 10000, 10000, device=xm.xla_device()) * torch.randn( - 10000, 10000, device=xm.xla_device()) + 10000, 10000, device=torch_xla.device()) * torch.randn( + 10000, 10000, device=torch_xla.device()) value_mean = value.mean() torch_xla.sync() cpu_value = value_mean.cpu() @@ -226,7 +226,7 @@ def test_execute_time_metric(self): def test_pybind_increment_counter(self): met.clear_all() - xla_device = xm.xla_device() + xla_device = torch_xla.device() t1 = torch.tensor(2077, device=xla_device) self.assertEqual(met.counter_value('CreateXlaTensor'), 1) torch_xla._XLAC._xla_increment_counter('CreateXlaTensor', 3) @@ -254,10 +254,10 @@ def getAndAssertFallbackOpsLenEquals(count): # Create N boxes in the format XYXY. # This should not run any fallback ops. N = 10 - x = torch.rand(N, 1).to(xm.xla_device()) - y = torch.rand(N, 1).to(xm.xla_device()) - width = torch.rand(N, 1).to(xm.xla_device()) - height = torch.rand(N, 1).to(xm.xla_device()) + x = torch.rand(N, 1).to(torch_xla.device()) + y = torch.rand(N, 1).to(torch_xla.device()) + width = torch.rand(N, 1).to(torch_xla.device()) + height = torch.rand(N, 1).to(torch_xla.device()) xys = torch.cat((x, x + width, y, y - height), dim=1) getAndAssertFallbackOpsLenEquals(0) @@ -274,7 +274,7 @@ def getAndAssertFallbackOpsLenEquals(count): if not XLAExperimentalContains("nms"): # Run torchvision operations as fallback. import torchvision - scores = torch.rand(N).to(xm.xla_device()) + scores = torch.rand(N).to(torch_xla.device()) # NMS doesn't have a PyTorch/XLA implementation without dynamic shapes. torchvision.ops.nms(xys, scores, 0.5) # remove_small_boxes is not implemented in C++. It calls other PyTorch diff --git a/test/test_mp_all_gather.py b/test/test_mp_all_gather.py index 8cf8a7a92170..93d64f47ef3e 100644 --- a/test/test_mp_all_gather.py +++ b/test/test_mp_all_gather.py @@ -11,7 +11,7 @@ def all_gather(tensor, dim): def _mp_fn(index): - device = xm.xla_device() + device = torch_xla.device() world_size = xr.world_size() input_list_size = 5 if xm.xla_device_hw(device) in ('TPU', 'CUDA', 'NEURON'): diff --git a/test/test_mp_all_to_all.py b/test/test_mp_all_to_all.py index f7e4a2f0c084..9761507dea13 100644 --- a/test/test_mp_all_to_all.py +++ b/test/test_mp_all_to_all.py @@ -6,7 +6,7 @@ def _mp_fn(index): - device = xm.xla_device() + device = torch_xla.device() if xm.xla_device_hw(device) in ('TPU', 'NEURON'): slots_per_device = 4 size = slots_per_device * xr.world_size() diff --git a/test/test_mp_collective_matmul.py b/test/test_mp_collective_matmul.py index 7ebfd7d80f89..29f115c986cd 100644 --- a/test/test_mp_collective_matmul.py +++ b/test/test_mp_collective_matmul.py @@ -8,7 +8,7 @@ def _mp_fn(index): os.environ["ENABLE_COLLECTIVE_MATMUL_IN_MP"] = "1" - device = xm.xla_device() + device = torch_xla.device() world_size = xr.world_size() groups = [[i for i in range(world_size)]] scale = 1 / world_size diff --git a/test/test_mp_collective_permute.py b/test/test_mp_collective_permute.py index 79c7196ac5ab..81a1eb771bcd 100644 --- a/test/test_mp_collective_permute.py +++ b/test/test_mp_collective_permute.py @@ -6,7 +6,7 @@ def _mp_fn(index): - device = xm.xla_device() + device = torch_xla.device() if xm.xla_device_hw(device) in ['TPU', 'NEURON']: world_size = xr.world_size() ordinal = xr.global_ordinal() diff --git a/test/test_mp_distributed_mm.py b/test/test_mp_distributed_mm.py index fd90398a7158..7d6c7982cb2f 100644 --- a/test/test_mp_distributed_mm.py +++ b/test/test_mp_distributed_mm.py @@ -7,7 +7,7 @@ def _mp_fn(index): - device = xm.xla_device() + device = torch_xla.device() if xm.xla_device_hw(device) in ('TPU', 'CUDA'): world_size = xr.world_size() diff --git a/test/test_mp_early_exit.py b/test/test_mp_early_exit.py index e8f411b9abab..89e46722e232 100644 --- a/test/test_mp_early_exit.py +++ b/test/test_mp_early_exit.py @@ -12,7 +12,7 @@ def _mp_fn(): dist.init_process_group('xla', init_method='xla://') - device = xm.xla_device() + device = torch_xla.device() if xm.xla_device_hw(device) in ['TPU', 'CUDA']: train_loader = xu.SampleGenerator( data=torch.zeros(1, 12), sample_count=1024) diff --git a/test/test_mp_reduce_scatter.py b/test/test_mp_reduce_scatter.py index 2b6d55bab596..bba65cde1ee8 100644 --- a/test/test_mp_reduce_scatter.py +++ b/test/test_mp_reduce_scatter.py @@ -6,7 +6,7 @@ def _mp_fn(index): - device = xm.xla_device() + device = torch_xla.device() world_size = xr.world_size() scale = 1 / world_size scatter_dim = 1 diff --git a/test/test_mp_replication.py b/test/test_mp_replication.py index 5b3392f3c487..61a302a65784 100644 --- a/test/test_mp_replication.py +++ b/test/test_mp_replication.py @@ -10,7 +10,7 @@ def all_reduce(tensor): def _mp_fn(index): - device = xm.xla_device() + device = torch_xla.device() world_size = xr.world_size() if world_size > 1: ones = torch.ones((2, 3)) diff --git a/test/test_mp_save.py b/test/test_mp_save.py index 1a3696f9e76b..ae9f46df120a 100644 --- a/test/test_mp_save.py +++ b/test/test_mp_save.py @@ -35,7 +35,7 @@ def _get_data_str(data): def _mp_fn(index, temp_file): - device = xm.xla_device() + device = torch_xla.device() dd = _create_state_dict(device) xm.save(dd, temp_file) # User needs to manually rendezvous since only master process diff --git a/test/test_mp_sync_batch_norm.py b/test/test_mp_sync_batch_norm.py index 561b7976a83b..fa4f18ad00d2 100644 --- a/test/test_mp_sync_batch_norm.py +++ b/test/test_mp_sync_batch_norm.py @@ -47,7 +47,7 @@ def _sync_bn1d_no_channel(rank): t_global = torch.rand((xr.world_size() * bsz, length)) # XLA SyncBatchNorm - device = xm.xla_device() + device = torch_xla.device() t_xla = t_global[bsz * rank:bsz * (rank + 1), ...].to(device) sbn_xla = xf.SyncBatchNorm(length).to(device) result = run_step(sbn_xla, t_xla) @@ -72,7 +72,7 @@ def _sync_bn1d_multi_channel(rank): t_global = torch.rand((xr.world_size() * bsz, features, length)) # XLA SyncBatchNorm - device = xm.xla_device() + device = torch_xla.device() t_xla = t_global[bsz * rank:bsz * (rank + 1), ...].to(device) sbn_xla = xf.SyncBatchNorm(features).to(device) result = run_step(sbn_xla, t_xla) @@ -97,7 +97,7 @@ def _sync_bn2d(rank): t_global = torch.rand((xr.world_size() * bsz, features, h, w)) # XLA SyncBatchNorm - device = xm.xla_device() + device = torch_xla.device() t_xla = t_global[bsz * rank:bsz * (rank + 1), ...].to(device) sbn_xla = xf.SyncBatchNorm(features).to(device) result = run_step(sbn_xla, t_xla) @@ -122,7 +122,7 @@ def _sync_bn3d(rank): t_global = torch.rand((xr.world_size() * bsz, features, d, h, w)) # XLA SyncBatchNorm - device = xm.xla_device() + device = torch_xla.device() t_xla = t_global[bsz * rank:bsz * (rank + 1), ...].to(device) sbn_xla = xf.SyncBatchNorm(features).to(device) result = run_step(sbn_xla, t_xla) diff --git a/test/test_operations.py b/test/test_operations.py index b3f31e8a0f3b..1fc90898f13e 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -179,7 +179,7 @@ def onlyIfPJRTDeviceIsCUDA(fn): class TestToXlaTensorArena(test_utils.XlaTestCase): def test(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() kdata = [_gen_tensor(2, 3), _gen_tensor(3, 4)] kdata.append([_gen_tensor(2, 5), _gen_tensor(3, 6)]) @@ -307,7 +307,7 @@ def loop_fn(model, loader, device, context): class TestLongGraphChain(test_utils.XlaTestCase): def test(self): - device = xm.xla_device() + device = torch_xla.device() orig_x = torch.Tensor([[1, 2], [3, 4]]) orig_y = torch.Tensor([[0.1, 0.2], [0.3, 0.4]]) x = orig_x @@ -328,7 +328,7 @@ def test(self): class TestSelect(test_utils.XlaTestCase): def test_get_xla_tensor(self): - x = _gen_tensor(14, 24, 8, device=xm.xla_device()) + x = _gen_tensor(14, 24, 8, device=torch_xla.device()) t = x.data.cpu() sx = x.select(1, 12) tx = t.select(1, 12) @@ -343,7 +343,7 @@ def fn(tensor): # Call masked_fill. return tensor.masked_fill(mask, 10) - x = _gen_tensor(2, 2, device=xm.xla_device()) + x = _gen_tensor(2, 2, device=torch_xla.device()) x_cpu = x.cpu() self.assertEqual(fn(x_cpu), fn(x)) @@ -352,7 +352,7 @@ class TestRandom(test_utils.XlaTestCase): def test_random_from_to_bool(self): for from_val, to_val in [[0, 1], [0, 2], [1, 2]]: - x = _gen_tensor(10, device=xm.xla_device()) + x = _gen_tensor(10, device=torch_xla.device()) x.random_(from_val, to_val) delta = 1 self.assertTrue(from_val <= x.to(torch.int).min() < (from_val + delta)) @@ -416,20 +416,20 @@ def test_fn(x): class TestDynamicShape(test_utils.XlaTestCase): def test_nonzero_shape(self): - x = torch.tensor((0, 1, 2, 0, 3, 4), device=xm.xla_device()) + x = torch.tensor((0, 1, 2, 0, 3, 4), device=torch_xla.device()) x_dim0_shape = torch_xla._XLAC._get_xla_tensor_dimension_size( torch.nonzero(x, as_tuple=False), 0) self.assertEqual(x_dim0_shape.item(), 4) def test_masked_select_shape(self): - x = torch.tensor((0, 1, 2, 0, 3, 4), device=xm.xla_device()) + x = torch.tensor((0, 1, 2, 0, 3, 4), device=torch_xla.device()) mask = x.ge(2) x_dim0_shape = torch_xla._XLAC._get_xla_tensor_dimension_size( torch.masked_select(x, mask), 0) self.assertEqual(x_dim0_shape.item(), 3) def test_nonzero_cast(self): - t1 = torch.ones(5, 2, device=xm.xla_device()) + t1 = torch.ones(5, 2, device=torch_xla.device()) # Result of the nonzero should be the index type. Currently # index type is s64 on cpu and gpu, but s32 on TPU. We should be # able to cast it to any other type without error. @@ -440,7 +440,7 @@ def test_nonzero_cast(self): class TestOptimizationBarrier(test_utils.XlaTestCase): def test_optimization_barrier_correctness(self): - device = xm.xla_device() + device = torch_xla.device() # only test optimization_barrier on TPU if xm.xla_device_hw(device) != 'TPU': return @@ -459,7 +459,7 @@ def op_fn(a): return xb.Op.tuple((a, a.cast(xb.Type.BF16))) op = xor.register('test_mixed_dtype_tuple', op_fn) - xla_device = xm.xla_device() + xla_device = torch_xla.device() a_tensor = torch.randn([2, 3]).to(xla_device) a_result, a_cast = op(a_tensor) self.assertEqual(a_result.dtype, torch.float) @@ -476,14 +476,14 @@ def test_get_real_xla_devices(self): def test_negative_slice(self): t = _gen_tensor(32, 24, 32) - x = t.to(xm.xla_device()) + x = t.to(torch_xla.device()) t_slice = t[:, :, -1] x_slice = x[:, :, -1] self.assertEqual(t_slice.data, x_slice.data.cpu()) def test_negative_cat(self): t = _gen_tensor(2, 5, 3) - x = t.to(xm.xla_device()) + x = t.to(torch_xla.device()) t_cat = torch.cat([t, t], -1) x_cat = torch.cat([x, x], -1) self.assertEqual(t_cat.data, x_cat.data.cpu()) @@ -491,8 +491,8 @@ def test_negative_cat(self): def test_cat_empty_tensor(self): t = _gen_tensor(2, 5, 3) empty_tensor = torch.Tensor() - x = t.to(xm.xla_device()) - empty_tensor_xla = empty_tensor.to(xm.xla_device()) + x = t.to(torch_xla.device()) + empty_tensor_xla = empty_tensor.to(torch_xla.device()) t_cat = torch.cat([t, empty_tensor], 0) x_cat = torch.cat([x, empty_tensor_xla], 0) self.assertEqual(t_cat.data, x_cat.data.cpu()) @@ -530,7 +530,7 @@ def test_amp_foreach_non_finite_check_and_unscale_(self): found_inf_output0 = torch.tensor(0, dtype=torch.float32) found_inf_output1 = torch.tensor(1, dtype=torch.float32) - xla_device = xm.xla_device() + xla_device = torch_xla.device() xla_grads0 = grads0.to(xla_device) xla_inv_scale = inv_scale.to(xla_device) xla_found_inf = found_inf.to(xla_device) @@ -550,9 +550,9 @@ def test_masked_fill_with_tensor(self): input = _gen_tensor(2, 5, 4, 3) mask = _gen_mask(input.size()) value = torch.tensor(42) - xla_input = input.to(xm.xla_device()) - xla_mask = mask.to(xm.xla_device()) - xla_value = value.to(xm.xla_device()) + xla_input = input.to(torch_xla.device()) + xla_mask = mask.to(torch_xla.device()) + xla_value = value.to(torch_xla.device()) result = torch.masked_fill(input, mask, value) xla_result = torch.masked_fill(xla_input, xla_mask, xla_value) self.assertEqual(input.data, xla_input.data.cpu()) @@ -571,63 +571,63 @@ def test_fn(a, b, m): def test_add_mixed_device(self): input = _gen_tensor(3, 800, 1066) - xla_input = input.to(xm.xla_device()) + xla_input = input.to(torch_xla.device()) output = input + 2 xla_output = xla_input + 2 self.assertEqual(output.data, xla_output.data.cpu()) def test_mul_mixed_device(self): input = _gen_tensor(3, 800, 1066) - xla_input = input.to(xm.xla_device()) + xla_input = input.to(torch_xla.device()) output = input * 2 xla_output = xla_input * 2 self.assertEqual(output.data, xla_output.data.cpu()) def test_sub_mixed_device(self): input = _gen_tensor(3, 800, 1066) - xla_input = input.to(xm.xla_device()) + xla_input = input.to(torch_xla.device()) output = input - 2 xla_output = xla_input - 2 self.assertEqual(output.data, xla_output.data.cpu()) def test_div_mixed_device(self): input = _gen_tensor(3, 800, 1066) - xla_input = input.to(xm.xla_device()) + xla_input = input.to(torch_xla.device()) output = input / 2 xla_output = xla_input / 2 self.assertEqual(output.data, xla_output.data.cpu()) def test_rand(self): - x = torch.rand(3, 5, device=xm.xla_device()) + x = torch.rand(3, 5, device=torch_xla.device()) self.assertEqual(x.device.type, 'xla') def test_randperm(self): - x = torch.randperm(3, device=xm.xla_device(), dtype=torch.int32) + x = torch.randperm(3, device=torch_xla.device(), dtype=torch.int32) self.assertEqual(x.device.type, 'xla') def test_randn_like(self): shape = (5, 1, 1) - x = torch.randn_like(torch.zeros(shape, device=xm.xla_device())) + x = torch.randn_like(torch.zeros(shape, device=torch_xla.device())) self.assertEqual(x.device.type, 'xla') def test_rand_like(self): shape = (5, 1, 1) - x = torch.rand_like(torch.zeros(shape, device=xm.xla_device())) + x = torch.rand_like(torch.zeros(shape, device=torch_xla.device())) self.assertEqual(x.device.type, 'xla') def test_randint_like(self): shape = (5, 1, 1) x = torch.randint_like( - torch.zeros(shape, device=xm.xla_device(), dtype=torch.uint8), 6, 10) + torch.zeros(shape, device=torch_xla.device(), dtype=torch.uint8), 6, 10) self.assertEqual(x.device.type, 'xla') def test_no_storage(self): - x = torch.randn(5, device=xm.xla_device()) + x = torch.randn(5, device=torch_xla.device()) self.assertRaises(Exception, x.device) def test_slice_copy(self): a = torch.rand(3, 3, 3) - xla_device = xm.xla_device() + xla_device = torch_xla.device() xla_a = a.to(xla_device) shape = (4, 4, 4) b = a.new(*shape).zero_() @@ -638,7 +638,7 @@ def test_slice_copy(self): def test_slice_assign(self): a = torch.rand(3, 3, 3) - xla_device = xm.xla_device() + xla_device = torch_xla.device() xla_a = a.to(xla_device) shape = (4, 4, 4) b = a.new(*shape).zero_() @@ -649,7 +649,7 @@ def test_slice_assign(self): def test_slice_stepped_assign(self): a = torch.ones((10, 4)) - xla_device = xm.xla_device() + xla_device = torch_xla.device() xla_a = a.to(xla_device) a[:, 0::2] = 2 xla_a[:, 0::2] = 2 @@ -657,14 +657,14 @@ def test_slice_stepped_assign(self): def test_slice_stepped_other_assign(self): a = torch.ones((10, 4)) - xla_device = xm.xla_device() + xla_device = torch_xla.device() xla_a = a.to(xla_device) a[:, 1::4] = 2 xla_a[:, 1::4] = 2 self.assertEqual(a.data, xla_a.data.cpu()) def test_ailing_slice(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() a = torch.ones((1000, 324)).to(xla_device) xla_a = a.to(xla_device) w = a[:, 2::4] @@ -674,7 +674,7 @@ def test_ailing_slice(self): self.assertEqual(w.data, xla_w.data.cpu()) def test_slice_rnd_stepped_assign(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() size = 10 for s in range(0, size - 1): for e in range(1, size - s): @@ -686,12 +686,12 @@ def test_slice_rnd_stepped_assign(self): def test_arange_nan(self): with self.assertRaisesRegex(RuntimeError, r'unsupported range'): - a = torch.arange(-5, float('nan'), device=xm.xla_device()) + a = torch.arange(-5, float('nan'), device=torch_xla.device()) with self.assertRaisesRegex(RuntimeError, r'unsupported range'): - a = torch.arange(float('nan'), 5, device=xm.xla_device()) + a = torch.arange(float('nan'), 5, device=torch_xla.device()) def test_empty_advanced_indexing(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() base = torch.randn(2, 3, 4, 5) xla_base = base.to(device=xla_device) result = base[:, torch.empty(0, 6, dtype=torch.int64)] @@ -702,7 +702,7 @@ def test_empty_advanced_indexing(self): "grad_input produces wrong results after functionalization. pytorch/pytorch#91199" ) def test_empty_strided(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() m = nn.Conv1d(4, 6, kernel_size=3, groups=2) a = torch.rand(2, 4, 6, requires_grad=True) xla_m = copy.deepcopy(m).to(xla_device) @@ -730,13 +730,13 @@ def test_empty_strided(self): def test_clamp(self): a = torch.randn(3, 3) - xla_a = a.to(xm.xla_device()) + xla_a = a.to(torch_xla.device()) b = torch.clamp(a, max=3.4) xla_b = torch.clamp(xla_a, max=3.4) self.assertEqual(b.data, xla_b.data.cpu()) def test_rrelu_module(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() a = torch.rand(1, 2, 2, requires_grad=True) xla_a = a.to(xla_device).detach() xla_a.requires_grad = True @@ -753,7 +753,7 @@ def test_rrelu_module(self): self.assertEqual(a.grad, xla_a.grad.cpu()) def test_max_broadcast(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() a = torch.rand(3, 1, 2) b = torch.rand(4, 2) c = torch.max(a, b) @@ -763,7 +763,7 @@ def test_max_broadcast(self): self.assertEqual(c.data, xla_c.data.cpu()) def test_sgn(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() t = torch.randn(2, 3, dtype=torch.cfloat) # Generate inf+infj t[0][0].real.div_(0) @@ -847,7 +847,7 @@ def test_view_as_complex_f64(self): torch_xla._XLAC._get_xla_tensors_text([complex]).split('\n')[-3]) def test_index_put(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() a = torch.tensor([1, 1, 1, 1]).to(xla_device).to(dtype=torch.float32) b = torch.rand(4) > 0.1 a[b] = 10 @@ -891,7 +891,7 @@ def test_baddmm_integer_types(self): def test_view_empty(self): # These used to throw floating point exception. - empty = torch.empty(0, device=xm.xla_device()) + empty = torch.empty(0, device=torch_xla.device()) with self.assertRaisesRegex( RuntimeError, r'unspecified dimension size -1 can be any value'): empty.view(-1, 0) @@ -912,12 +912,12 @@ def test_fn(device): return loss, linear.weight.grad cpu_loss, cpu_weight_grad = test_fn('cpu') - xla_loss, xla_weight_grad = test_fn(xm.xla_device()) + xla_loss, xla_weight_grad = test_fn(torch_xla.device()) self.assertEqual(cpu_loss, xla_loss) self.assertEqual(cpu_weight_grad, xla_weight_grad) def test_inplace_view_backprop_base(self): - root = torch.randn(2, 2, device=xm.xla_device(), requires_grad=True) + root = torch.randn(2, 2, device=torch_xla.device(), requires_grad=True) x = root.clone() v1 = x.narrow(0, 0, 1) v1.mul_(2) @@ -925,7 +925,7 @@ def test_inplace_view_backprop_base(self): self.assertEqual(root.grad.tolist(), [[2, 2], [1, 1]]) def test_inplace_view_backprop_view_of_view(self): - root = torch.randn(2, 2, device=xm.xla_device(), requires_grad=True) + root = torch.randn(2, 2, device=torch_xla.device(), requires_grad=True) x = root.clone() v1 = x.narrow(0, 0, 1) v2 = x.narrow(0, 0, 1) @@ -935,7 +935,7 @@ def test_inplace_view_backprop_view_of_view(self): def test_inplace_view_of_view(self): # modify view-of-view and backprop through base - root = torch.randn(2, 2, device=xm.xla_device(), requires_grad=True) + root = torch.randn(2, 2, device=torch_xla.device(), requires_grad=True) x = root.clone() v1 = x.narrow(0, 0, 1) v2 = v1.narrow(1, 1, 1) @@ -945,7 +945,7 @@ def test_inplace_view_of_view(self): def test_inplace_view_multiple_outputs(self): root = torch.arange( - 9., device=xm.xla_device()).reshape(3, 3).requires_grad_() + 9., device=torch_xla.device()).reshape(3, 3).requires_grad_() x = root.clone() v1 = x.unbind() with self.assertRaises(RuntimeError): @@ -986,7 +986,7 @@ def func(root, b): def test_inplace_view_backprop_view(self): # modify view and backprop through view - xla_device = xm.xla_device() + xla_device = torch_xla.device() a = torch.tensor([2., 5.], device=xla_device, requires_grad=False) b = torch.tensor([3.], device=xla_device, requires_grad=True) res = a.narrow(0, 1, 1).mul_(b) @@ -1040,7 +1040,8 @@ def func(root, b): def test_inplace_view_non_contig(self): root = torch.ones( - 2, 3, 2, device=xm.xla_device()).select(2, 1).t().requires_grad_(True) + 2, 3, 2, device=torch_xla.device()).select(2, + 1).t().requires_grad_(True) x = root.clone() v1 = x.narrow(0, 0, 1) v2 = v1.narrow(1, 1, 1) @@ -1079,12 +1080,12 @@ def func(x): def test_set(self): met.clear_all() - t1 = torch.zeros(50, device=xm.xla_device()) + t1 = torch.zeros(50, device=torch_xla.device()) t1 += 1 torch_xla.sync() self.assertEqual(met.counter_value('DestroyXlaTensor'), 3) - t2 = torch.zeros(10, device=xm.xla_device()) + t2 = torch.zeros(10, device=torch_xla.device()) self.assertEqual(met.counter_value('DestroyXlaTensor'), 4) t1.set_(t2) @@ -1097,12 +1098,12 @@ def test_set(self): def test_replace_xla_tensor(self): met.clear_all() - t1 = torch.zeros(50, device=xm.xla_device()) + t1 = torch.zeros(50, device=torch_xla.device()) t1 += 1 torch_xla.sync() self.assertEqual(met.counter_value('DestroyXlaTensor'), 3) - t2 = torch.zeros(10, device=xm.xla_device()) + t2 = torch.zeros(10, device=torch_xla.device()) self.assertEqual(met.counter_value('DestroyXlaTensor'), 4) torch_xla._XLAC._replace_xla_tensor(t1, t2) self.assertEqual(met.counter_value('DestroyXlaTensor'), 5) @@ -1111,7 +1112,7 @@ def test_replace_xla_tensor(self): self.assertTrue(torch.allclose(t2.cpu(), torch.zeros(10))) def test_pred_type(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() a = torch.rand(4) b = torch.rand(4) xla_a = a.to(xla_device) @@ -1133,7 +1134,7 @@ def test_pred_type(self): self.runAtenTest(c, lambda x: x ^ x.byte()) def test_bitwise_and_not(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() a = torch.randint(255, (4,), dtype=torch.long) xla_a = a.to(xla_device) @@ -1143,27 +1144,27 @@ def test_fn(a): self.runAtenTest(a, test_fn) def test_s_copy_dtype(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() a = torch.rand(10).to(xla_device).to(dtype=torch.uint8) b = torch.tensor([0, 1, 2, 3]).to(xla_device) self.assertEqual(a[b].dtype, torch.uint8) def test_slice_zero_sized_dim(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() v = torch.randn(2, 3, 4, 5).to(xla_device) y = v[:, :, :, 1] z = y[:, 1:1, :] self.assertEqual(z.size()[1], 0) def test_byte_dtype(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() x = torch.ByteTensor([0, 1]).to(xla_device) y = torch.ByteTensor([0, 1]).to(xla_device) z = x + y self.assertEqual(z.dtype, torch.uint8) def test_frac_negative(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() a = torch.tensor(-3.2) b = a.frac() xla_a = a.to(xla_device) @@ -1171,7 +1172,7 @@ def test_frac_negative(self): self.assertEqual(b, xla_b) def test_flip(self): - device = xm.xla_device() + device = torch_xla.device() data = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8], device=device).view(2, 2, 2) self.assertEqual( torch.tensor([5, 6, 7, 8, 1, 2, 3, 4]).view(2, 2, 2), data.flip(0)) @@ -1194,7 +1195,7 @@ def test_flip(self): torch.tensor([6, 5, 8, 7, 2, 1, 4, 3]).view(2, 2, 2), data.flip(2, 0)) def test_flip_check_throws(self): - device = xm.xla_device() + device = torch_xla.device() data = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8], device=device).view(2, 2, 2) # not allow flip on the same dim more than once self.assertRaises(RuntimeError, lambda: data.flip(0, 1, 1)) @@ -1206,7 +1207,7 @@ def test_flip_check_throws(self): self.assertRaises(RuntimeError, lambda: data.flip(3)) def test_flip_expand(self): - device = xm.xla_device() + device = torch_xla.device() data = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8], device=device).view(2, 2, 2) expanded_data = torch.arange(1, 4, device=device).view(3, 1).expand(3, 2) transposed_data = torch.arange( @@ -1218,7 +1219,7 @@ def test_flip_expand(self): transposed_data.flip(0, 1, 2)) def test_flip_shape(self): - device = xm.xla_device() + device = torch_xla.device() data = torch.randn(2, 3, 4, device=device) size = [2, 3, 4] test_dims = [] @@ -1228,7 +1229,7 @@ def test_flip_shape(self): self.assertEqual(size, list(data.flip(ds).size())) def test_flip_rectangular(self): - device = xm.xla_device() + device = torch_xla.device() data = torch.tensor([1, 2, 3, 4, 5, 6]).view(2, 3).to(device) flip0_result = torch.tensor([[4, 5, 6], [1, 2, 3]]).to(device) flip1_result = torch.tensor([[3, 2, 1], [6, 5, 4]]).to(device) @@ -1237,13 +1238,13 @@ def test_flip_rectangular(self): self.assertEqual(flip1_result, data.flip(1)) def test_flip_empty_tensor(self): - device = xm.xla_device() + device = torch_xla.device() data = torch.tensor([]) self.assertEqual(data, data.flip(0)) def test_norm_p0(self): # p = 0 is equivalent to nonzero - xla_device = xm.xla_device() + xla_device = torch_xla.device() a = torch.randn(3, 2) xla_a = a.to(xla_device) norm = a.norm(p=0) @@ -1289,7 +1290,7 @@ def test_fn(input, src): self.runAtenTest([torch.zeros(3, 3), torch.ones(3)], test_fn) def test_scatter_add_bool(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() a = torch.tensor([[True, True, True, True, True], [True, True, True, True, True]]) b = torch.zeros(3, 5, dtype=torch.bool) @@ -1334,7 +1335,7 @@ def test_reduction_0dim(self): self.runAtenTest(torch.rand(2, 0, 4), lambda x: torch.mean(x)) self.runAtenTest(torch.rand(2, 0, 4), lambda x: torch.prod(x)) # min & max throws - xla_device = xm.xla_device() + xla_device = torch_xla.device() a = torch.rand(2, 0, 4) xla_a = a.to(xla_device) self.assertRaises(IndexError, lambda: torch.max(a, dim=1)) @@ -1470,11 +1471,11 @@ def check(device): d = a xm.check_view_sharing([a, d]) - check(xm.xla_device()) + check(torch_xla.device()) check(torch.device('cpu')) def test_save(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() x = torch.randn(5, device=xla_device) with tempfile.NamedTemporaryFile() as tf: torch.save(x, tf) @@ -1482,7 +1483,7 @@ def test_save(self): self.assertEqual(x, x_loaded) def test_save_bf16(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() x = torch.randn(5, dtype=torch.bfloat16, device=xla_device) with tempfile.NamedTemporaryFile() as tf: torch.save(x, tf) @@ -1490,7 +1491,7 @@ def test_save_bf16(self): self.assertEqual(x, x_loaded) def test_save_tuple(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() x = torch.randn(5, device=xla_device) number = 3 with tempfile.NamedTemporaryFile() as tf: @@ -1500,7 +1501,7 @@ def test_save_tuple(self): self.assertEqual(number, number_loaded) def test_save_api(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() model = XlaMNIST().to(xla_device) with tempfile.NamedTemporaryFile() as tf: xm.save(model.state_dict(), tf) @@ -1513,7 +1514,7 @@ def test_save_api(self): def test_serialization_api(self): with tempfile.TemporaryDirectory() as tmpdir: path = os.path.join(tmpdir, 'data.pt') - xla_device = xm.xla_device() + xla_device = torch_xla.device() model = XlaMNIST().to(xla_device) xser.save(model.state_dict(), path) state_dict = xser.load(path) @@ -1523,7 +1524,7 @@ def test_serialization_api(self): self.assertEqual(model.state_dict(), loaded_model.state_dict()) def test_deepcopy(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() x = torch.rand(5, device=xla_device) x0 = x[0] y = copy.deepcopy(x) @@ -1533,7 +1534,7 @@ def test_deepcopy(self): self.assertEqual(x[0], x0) def test_print(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() x = torch.tensor([5], device=xla_device) expected_str = 'tensor([5], device=\'' + str(xla_device) + '\')' self.assertEqual(str(x), expected_str) @@ -1728,14 +1729,14 @@ def test_fn(t): self.runAtenTest([torch.tensor(20.0)], test_fn) def test_view_and_copy_(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() x = torch.tensor([1.5, 2.5, 3.5, 4.5, 5.5, 6.5], device='cpu') y = torch.tensor([0, 0, 0, 0, 0, 0], device=xla_device) y[::2].copy_(x[::2]) self.assertEqual(y, [1, 0, 3, 0, 5, 0]) def test_view_and_multi_sync(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() t1 = torch.zeros(100, device=xla_device) t1[10] = 113 torch_xla.sync() @@ -1745,7 +1746,7 @@ def test_view_and_multi_sync(self): torch_xla._XLAC._get_xla_tensors_text([t1])) def test_binaryop_order(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() x = torch.rand(5, device=xla_device) y = torch.rand(5) self.assertEqual(x + y, y + x) @@ -1753,14 +1754,14 @@ def test_binaryop_order(self): # Since in eager mode the tensor would be materialized and hence _get_xla_tensors_text would not show the prim::Constant node. @skipOnEagerDebug def test_pow_constant(self): - t1 = torch.pow(torch.tensor([2.0, 3.0], device=xm.xla_device()), 5) + t1 = torch.pow(torch.tensor([2.0, 3.0], device=torch_xla.device()), 5) hlo_text = torch_xla._XLAC._get_xla_tensors_text([t1]) const_hlo = hlo_text.split('\n')[1] assert 'prim::Constant' in const_hlo assert 'xla::device_data' not in const_hlo def test_emb_bf16(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() index = torch.ones(1, dtype=torch.long, device=xla_device) emb = torch.nn.Embedding(1024, 128, device=xla_device) emb = emb.to(torch.bfloat16) @@ -1780,7 +1781,7 @@ def test_on_device(device): return m(index) out = test_on_device("cpu") - out_x = test_on_device(xm.xla_device()) + out_x = test_on_device(torch_xla.device()) self.assertEqual(out, out_x.cpu()) def test_transpose_1d(self): @@ -1799,7 +1800,7 @@ def test_fn(t1): def test_sigmoid_bounds(self): torch.manual_seed(0) - xla_device = xm.xla_device() + xla_device = torch_xla.device() for _ in range(100): x = torch.rand(1000).to(xla_device) lower_bound = torch.sigmoid(x * (-100.0)) @@ -1816,7 +1817,7 @@ def test_manual_seed(self): self.assertTrue(torch.allclose(t1.cpu(), t2.cpu())) def test_cached_addcdiv(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() met.clear_all() t1 = torch.randn(1, 3).to(xla_device) @@ -1834,7 +1835,7 @@ def test_cached_addcdiv(self): @skipOnEagerDebug def test_print_execution(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() torch_xla.sync() xm.wait_device_ops() met.clear_all() @@ -1888,7 +1889,7 @@ def test_fn(input): return dropped[1].cpu(), input.grad.cpu() met.clear_all() - xla_device = xm.xla_device() + xla_device = torch_xla.device() input_cpu = torch.randn(7, 7, requires_grad=True) input_xla = torch.randn(7, 7, device=xla_device, requires_grad=True) mask_cpu, grad_cpu = test_fn(input_cpu) @@ -2046,7 +2047,7 @@ def foo(x): x = torch.arange(10).to(dtype) r = foo(x) - device = xm.xla_device() + device = torch_xla.device() Xx = x.to(device) Xr = foo(Xx) @@ -2089,8 +2090,8 @@ def foo(grad, inp): grad = torch.rand(10, 10, dtype=torch.bfloat16) inp = torch.rand(10, 10) - Xgrad = grad.to(xm.xla_device()) - Xinp = inp.to(xm.xla_device()) + Xgrad = grad.to(torch_xla.device()) + Xinp = inp.to(torch_xla.device()) r = foo(grad, inp) Xr = foo(Xgrad, Xinp) @@ -2105,8 +2106,8 @@ def foo(t): t = torch.rand(10, 10, requires_grad=True, dtype=torch.bfloat16) t.retain_grad() t.grad = torch.rand(10, 10, dtype=torch.bfloat16) - xt = t.to(xm.xla_device()) - xt.grad = t.grad.to(xm.xla_device(), dtype=torch.bfloat16) + xt = t.to(torch_xla.device()) + xt.grad = t.grad.to(torch_xla.device(), dtype=torch.bfloat16) foo(t) foo(xt) @@ -2116,7 +2117,7 @@ def foo(t): def test_clip_grad_norm_zero(self): t = torch.rand(10, 10, dtype=torch.bfloat16) - xt = t.to(xm.xla_device()) + xt = t.to(torch_xla.device()) result = torch.nn.utils.clip_grad_norm_(xt, 1.0) self.assertEqual(result.device.type, 'xla') self.assertTrue(torch.allclose(result.cpu(), torch.tensor(0.))) @@ -2129,8 +2130,8 @@ def foo(t0, t1): t0 = torch.rand(10, 10, dtype=torch.bfloat16) t1 = torch.rand(10, 10) - Xt0 = t0.to(xm.xla_device()) - Xt1 = t1.to(xm.xla_device()) + Xt0 = t0.to(torch_xla.device()) + Xt1 = t1.to(torch_xla.device()) r = foo(t0, t1) Xr = foo(Xt0, Xt1) @@ -2171,8 +2172,8 @@ def test(f, xshape, ishapes): x = make_tensor(xshape) ilist = [make_index(s) for s in ishapes] - Xx = x.to(xm.xla_device()) - Xilist = [i.to(xm.xla_device()) for i in ilist] + Xx = x.to(torch_xla.device()) + Xilist = [i.to(torch_xla.device()) for i in ilist] out = f(x, *ilist) Xout = f(Xx, *Xilist) @@ -2212,8 +2213,8 @@ def fn(inp, s): inp = torch.rand(10, dtype=torch.half) s = torch.tensor(7, dtype=torch.double) - Xinp = inp.to(xm.xla_device()) - Xs = s.to(xm.xla_device()) + Xinp = inp.to(torch_xla.device()) + Xs = s.to(torch_xla.device()) out = fn(inp, s) Xout = fn(Xinp, Xs) @@ -2267,7 +2268,7 @@ def foo(x, is_xla=False): return r + 5 inp = torch.rand(1, 3, 10, 10, dtype=torch.double) - Xinp = inp.to(xm.xla_device()) + Xinp = inp.to(torch_xla.device()) out = foo(inp) Xout = foo(Xinp, is_xla=True) @@ -2332,7 +2333,7 @@ def clone_and_maybe_move(tensor, device=None): with self.subTest(sparse=sparse, mode=mode): kwargs_ = {k: clone_and_maybe_move(v) for k, v in kwargs.items()} xla_kwargs = { - k: clone_and_maybe_move(v, device=xm.xla_device()) + k: clone_and_maybe_move(v, device=torch_xla.device()) for k, v in kwargs.items() } @@ -2365,7 +2366,7 @@ def foo(x: torch.Tensor) -> torch.Tensor: input = torch.rand((10, 10), dtype=torch.float16) out = foo(input) - in_xla = input.to(xm.xla_device()) + in_xla = input.to(torch_xla.device()) out_xla = foo(in_xla) self.assertEqual(out.dtype, out_xla.dtype) @@ -2381,7 +2382,7 @@ def test_cummax_0_sized_dimension(self): a = torch.rand(5, 5, 0, 5) expected = torch.cummax(a, dim) - actual = torch.cummax(a.to(xm.xla_device()), dim) + actual = torch.cummax(a.to(torch_xla.device()), dim) self.assertEqual(actual, expected) @@ -2395,7 +2396,7 @@ def run(device): return runf(*args_) actual = run("cpu") - expected = run(xm.xla_device()) + expected = run(torch_xla.device()) self.assertFalse( met.executed_fallback_ops(), msg="expected no fallback operations.") @@ -2454,7 +2455,7 @@ class TestModelComparator(test_utils.XlaTestCase): def test(self): SEED = 42 - xla_device = xm.xla_device() + xla_device = torch_xla.device() x = _gen_tensor(8, 1, 28, 28) xla_x = x.to(xla_device) @@ -2479,13 +2480,13 @@ def test(self): class TestWaitDeviceOps(test_utils.XlaTestCase): def test_wait_device_ops(self): - xm.xla_device() - value = torch.randn(10000, 10000, device=xm.xla_device()) + torch_xla.device() + value = torch.randn(10000, 10000, device=torch_xla.device()) val_list = [] val_mean_list = [] met.clear_all() for _ in range(5): - new_val = value * torch.randn(10000, 10000, device=xm.xla_device()) + new_val = value * torch.randn(10000, 10000, device=torch_xla.device()) val_list.append(new_val) val_mean_list.append(new_val.mean()) torch_xla.sync() @@ -2498,7 +2499,7 @@ class TestDebuggingUtil(test_utils.XlaTestCase): @skipOnEagerDebug def test_get_xla_tensor_debug_info(self): - device = xm.xla_device() + device = torch_xla.device() # test non xla tensor cpu_t1 = torch.randn(5) cpu_t1_info = torch_xla._XLAC._get_xla_tensor_debug_info(cpu_t1) @@ -2533,7 +2534,7 @@ def runOpBuilderTest(self, kwargs=dict()): op = xor.register(name, opfn) if device is None: - device = xm.xla_device() + device = torch_xla.device() if aten_fn is None: aten_fn = opfn tensors = xu.as_list(tensors) @@ -2655,7 +2656,7 @@ class MpDecoratorTest(test_utils.XlaTestCase): @xtu.mp_test def test_mp_decorator(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() self.assertTrue(xla_device.type == 'xla') @@ -2694,7 +2695,7 @@ class TestLoweringContext(test_utils.XlaTestCase): def test_api(self): met.clear_all() - device = xm.xla_device() + device = torch_xla.device() a = torch.tensor([1.0, 2.0, 3.0], device=device) b = torch.tensor([4.0, 5.0, 6.0], device=device) @@ -2755,13 +2756,13 @@ def test_git_revisons(self): self.assertTrue('torch' in revs) def test_send_to_device_grad(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() t = _gen_tensor(2, 2, requires_grad=True) dt = xm.send_cpu_data_to_device([t], xla_device) self.assertTrue(dt[0].requires_grad) def test_send_to_device_single(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() t = _gen_tensor(2, 2) dt = xm.send_cpu_data_to_device(t, xla_device) self.assertEqual(dt[0].device, xla_device) @@ -2861,7 +2862,7 @@ def from_tensors(self, tensors): wpack = PackWrapper(pack) - xla_device = xm.xla_device() + xla_device = torch_xla.device() xdata = xm.send_cpu_data_to_device(wpack, xla_device) self.assertTrue(isinstance(xdata, nn.utils.rnn.PackedSequence)) self.assertEqual(xdata.batch_sizes.device, torch.device('cpu')) @@ -2871,7 +2872,7 @@ def from_tensors(self, tensors): "https://github.com/pytorch/xla/pull/7864#issuecomment-2294034008") def test_as_strided_input_larger(self): size = (5, 5) - device = xm.xla_device() + device = torch_xla.device() a = torch.ones(size, device=device) small_a = a[:, ::2] @@ -2883,7 +2884,7 @@ def _test_move_tensor_cuda_to_xla(self, cpu_tensor): # Assumes CPU-XLA data movement works. cuda_tensor = cpu_tensor.to("cuda") # Move tensor CUDA -> XLA. - xla_tensor = cuda_tensor.to(xm.xla_device()) + xla_tensor = cuda_tensor.to(torch_xla.device()) # Move the XLA tensor back to CPU, and check that it is the same as # the original CPU tensor. self.assertTrue(torch.equal(cpu_tensor, xla_tensor.cpu())) @@ -2901,7 +2902,7 @@ def test_aten_move_scalar_cuda_to_xla(self): self._test_move_tensor_cuda_to_xla(torch.tensor(42)) def test_unsafe_buffer_pointer(self): - xla_device = xm.xla_device() + xla_device = torch_xla.device() xla_tensor_0 = torch.tensor(42).to(xla_device) # `torch_xla.sync()` ensures xtensor->CurrentDataHandle() != nullptr torch_xla.sync() @@ -2909,7 +2910,7 @@ def test_unsafe_buffer_pointer(self): self.assertGreaterEqual(buf_ptr_0, 0) # xtensor->CurrentDataHandle() == nullptr but xtensor->CurrentIrValue().node != nullptr and device_data != nullptr - xla_tensor_1 = torch.tensor(42, device=xm.xla_device()) + xla_tensor_1 = torch.tensor(42, device=torch_xla.device()) buf_ptr_1 = torch_xla._XLAC._unsafe_buffer_pointer(xla_tensor_1) self.assertGreaterEqual(buf_ptr_1, 0) @@ -2918,7 +2919,7 @@ def test_unsafe_buffer_pointer(self): buf_ptr_2 = torch_xla._XLAC._unsafe_buffer_pointer(xla_tensor_2) self.assertGreaterEqual(buf_ptr_2, 0) - xla_tensor_3 = torch.arange(5, device=xm.xla_device()) + xla_tensor_3 = torch.arange(5, device=torch_xla.device()) torch_xla.sync() # Without the `wait_device_ops()`, the pjrt buffer (pjrt_data->buffer) at https://github.com/pytorch/xla/blob/e3fc03314dab5f44e3ed9ccbba6c15fbca3285cd/torch_xla/csrc/runtime/pjrt_computation_client.cc#L467 will be nullptr. xm.wait_device_ops() @@ -2946,14 +2947,14 @@ def _test_dlpack_capsule_conversion_helper(self, xla_tensor): @onlyIfPJRTDeviceIsCUDA @parameterized.parameters(*all_types_and(torch.half, torch.bfloat16)) def test_dlpack_roundtrip_tensor(self, dtype): - xla_device = xm.xla_device() + xla_device = torch_xla.device() # xtensor->CurrentDataHandle() == nullptr but xtensor->CurrentIrValue().node != nullptr and device_data != nullptr # xla_tensor_2 uses XLANativeFunctions::_to_copy xla_tensor_2 = torch.arange(5, dtype=dtype).to(xla_device) self._test_dlpack_capsule_conversion_helper(xla_tensor_2) # xla_tensor_3 uses arange_out IR node. - xla_tensor_3 = torch.arange(5, dtype=dtype, device=xm.xla_device()) + xla_tensor_3 = torch.arange(5, dtype=dtype, device=torch_xla.device()) torch_xla.sync() self._test_dlpack_capsule_conversion_helper(xla_tensor_3) @@ -2963,7 +2964,7 @@ def test_dlpack_roundtrip_tensor(self, dtype): *all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool, torch.uint16, torch.uint32, torch.uint64)) def test_dlpack_roundtrip_scalar(self, dtype): - xla_device = xm.xla_device() + xla_device = torch_xla.device() xla_tensor_0 = torch.tensor(42, dtype=dtype).to(xla_device) # `torch_xla.sync()` ensures xtensor->CurrentDataHandle() != nullptr torch_xla.sync() @@ -2976,7 +2977,7 @@ def test_dlpack_roundtrip_scalar(self, dtype): @onlyIfTorchSupportsCUDA @onlyIfPJRTDeviceIsCUDA def test_dlpack_roundtrip_bool(self): - xla_tensor = torch.ones(1, dtype=torch.bool).to(xm.xla_device()) + xla_tensor = torch.ones(1, dtype=torch.bool).to(torch_xla.device()) self._test_dlpack_capsule_conversion_helper(xla_tensor) @onlyIfTorchSupportsCUDA @@ -3044,7 +3045,7 @@ def test_dlpack_pytorch_cuda_to_xla_protocol_conversion(self): @onlyIfTorchSupportsCUDA @onlyIfPJRTDeviceIsCUDA def test_dlpack_xla_to_pytorch_cuda(self): - xla_t1 = torch.arange(5).to(xm.xla_device()) + xla_t1 = torch.arange(5).to(torch_xla.device()) dlt1 = xdlpack.to_dlpack(xla_t1) cuda_t1 = torch.utils.dlpack.from_dlpack(dlt1) self.assertEqual(cuda_t1.device.type, 'cuda') @@ -3055,7 +3056,7 @@ def test_dlpack_xla_to_pytorch_cuda(self): @onlyIfTorchSupportsCUDA @onlyIfPJRTDeviceIsCUDA def test_dlpack_xla_to_pytorch_cuda_protocol_conversion(self): - xla_t1 = torch.arange(5).to(xm.xla_device()) + xla_t1 = torch.arange(5).to(torch_xla.device()) cuda_t1 = torch.utils.dlpack.from_dlpack(xla_t1) self.assertEqual(cuda_t1.device.type, 'cuda') self.assertEqual(cuda_t1.device.index, xla_t1.device.index) @@ -3120,7 +3121,7 @@ def forward(self, inp): class TestActivationCheckpoint(test_utils.XlaTestCase): def test_dropout(self): - device = xm.xla_device() + device = torch_xla.device() model = SimpleModelWithDropout().to(device) model = checkpoint_module(model) _input = torch.randn(128, 128, requires_grad=True) @@ -3134,7 +3135,7 @@ def test_dropout(self): f"in fwd {model.to_save[0]}, in bwd {model.to_save[1]}") def test_opt_barrier(self): - device = xm.xla_device() + device = torch_xla.device() model = SimpleModelWithDropout().to(device) model = checkpoint_module(model) _input = torch.randn(128, 128, requires_grad=True) @@ -3169,7 +3170,7 @@ def _reference_nms(self, boxes, scores, iou_threshold): def _nms(self, boxes, scores, iou_threshold): import torchvision - device = xm.xla_device() + device = torch_xla.device() return torchvision.ops.nms( boxes.to(device), scores.to(device), iou_threshold).cpu() diff --git a/test/test_operations_hlo.py b/test/test_operations_hlo.py index 25e17b7c265d..37708d199540 100644 --- a/test/test_operations_hlo.py +++ b/test/test_operations_hlo.py @@ -30,15 +30,15 @@ def tearDown(self): super(TestOperationsHlo, self).tearDown() def test_expand(self): - a = torch.rand(1, 5, device=xm.xla_device()) + a = torch.rand(1, 5, device=torch_xla.device()) b = a.expand(5, 5) hlo_text = torch_xla._XLAC._get_xla_tensors_text([b]) assert 'aten::expand' in hlo_text def test_special_scalars_addcdiv_addcmul(self): - a = torch.rand(5, 5).to(xm.xla_device()) - b = torch.rand(5, 5).to(xm.xla_device()) - c = torch.rand(5, 5).to(xm.xla_device()) + a = torch.rand(5, 5).to(torch_xla.device()) + b = torch.rand(5, 5).to(torch_xla.device()) + c = torch.rand(5, 5).to(torch_xla.device()) for op in [torch.addcdiv, torch.addcmul]: out = op(a, b, c, value=1.0) hlo_text = torch_xla._XLAC._get_xla_tensors_text([out]) @@ -52,8 +52,8 @@ def test_special_scalars_addcdiv_addcmul(self): def test_div_by_f64(self): mod = torch.nn.MultiheadAttention(768, 12, batch_first=True) - mod.to(xm.xla_device()) - a = torch.rand(1, 512, 768).to(xm.xla_device()) + mod.to(torch_xla.device()) + a = torch.rand(1, 512, 768).to(torch_xla.device()) b, _ = mod(a, a, a, need_weights=False) b.sum().backward() hlo_text = torch_xla._XLAC._get_xla_tensors_text( @@ -61,8 +61,8 @@ def test_div_by_f64(self): assert 'f64' not in hlo_text def test_dropout_by_u8_mask(self): - mod = torch.nn.Dropout().to(xm.xla_device()) - a = torch.rand(20, 16, dtype=torch.bfloat16).to(xm.xla_device()) + mod = torch.nn.Dropout().to(torch_xla.device()) + a = torch.rand(20, 16, dtype=torch.bfloat16).to(torch_xla.device()) b = mod(a) hlo_text = torch_xla._XLAC._get_xla_tensors_hlo([b]) assert 'u8' in hlo_text diff --git a/test/test_persistent_cache.py b/test/test_persistent_cache.py index 75a739e64638..ccc5c8a568c7 100644 --- a/test/test_persistent_cache.py +++ b/test/test_persistent_cache.py @@ -51,14 +51,14 @@ def _mp_test(rank, tmpdir, metrics): xr.initialize_cache(os.path.join(tmpdir, str(rank))) t = torch.randn(16) - xt = t.to(xm.xla_device()) + xt = t.to(torch_xla.device()) _assert_correctness_and_metrics(t, xt, metrics) def _single_device_test(tmpdir, metrics): xr.initialize_cache(tmpdir) t = torch.randn(16) - xt = t.to(xm.xla_device()) + xt = t.to(torch_xla.device()) _assert_correctness_and_metrics(t, xt, metrics) @@ -66,7 +66,7 @@ def _spmd_replicated_test(tmpdir, metrics): xr.initialize_cache(tmpdir) xr.use_spmd() t = torch.randn(16) - xt = t.to(xm.xla_device()) + xt = t.to(torch_xla.device()) _assert_correctness_and_metrics(t, xt, metrics) @@ -74,7 +74,7 @@ def _spmd_explicitly_replicated_test(tmpdir, metrics): xr.initialize_cache(tmpdir) xr.use_spmd() t = torch.randn(16) - xt = t.to(xm.xla_device()) + xt = t.to(torch_xla.device()) n_dev = xr.global_runtime_device_count() mesh = xs.Mesh(range(n_dev), (n_dev,)) @@ -87,7 +87,7 @@ def _spmd_sharded_test(tmpdir, metrics): xr.use_spmd() t = torch.randn(16) - xt = t.to(xm.xla_device()) + xt = t.to(torch_xla.device()) n_dev = xr.global_runtime_device_count() mesh = xs.Mesh(range(n_dev), (n_dev,)) xs.mark_sharding(xt, mesh, (0,)) diff --git a/test/test_profile_mp_mnist.py b/test/test_profile_mp_mnist.py index 266a7bc5634f..e23c2f59c223 100644 --- a/test/test_profile_mp_mnist.py +++ b/test/test_profile_mp_mnist.py @@ -144,7 +144,7 @@ def train_mnist(flags, # Scale learning rate to num cores lr = flags.lr * xr.world_size() - device = xm.xla_device() + device = torch_xla.device() model = MNIST().to(device) writer = None if xm.is_master_ordinal(): diff --git a/test/test_python_ops.py b/test/test_python_ops.py index 24e48d9a1664..9dc145947f62 100644 --- a/test/test_python_ops.py +++ b/test/test_python_ops.py @@ -29,8 +29,8 @@ def test_put(self, dtype): raise unittest.SkipTest("Dtype {0} is unsupported by XLA".format( str(dtype))) - device = xm.xla_device() - real_device_type = xm.xla_device_hw(str(xm.xla_device())) + device = torch_xla.device() + real_device_type = xm.xla_device_hw(str(torch_xla.device())) if real_device_type == "TPU": raise unittest.SkipTest("TestPut is too slow on TPU. Skipped") @@ -108,7 +108,7 @@ def test_index_copy(self, dtype): raise unittest.SkipTest("Dtype {0} is unsupported by XLA".format( str(dtype))) - device = xm.xla_device() + device = torch_xla.device() # We just test for num_copy <= num_dest, as otherwise there are repeated indices # and the behavior is undefined diff --git a/test/test_syncfree_optimizers.py b/test/test_syncfree_optimizers.py index 991bf8e5c936..8807271440c6 100644 --- a/test/test_syncfree_optimizers.py +++ b/test/test_syncfree_optimizers.py @@ -53,7 +53,7 @@ def _test_optimizer(self, syncfree_optim_cls, ref_optim_cls, optim_kwargs={'lr': 1e-2}): - device = xm.xla_device() + device = torch_xla.device() loss_fn = nn.NLLLoss() # syncfree model torch.manual_seed(0) diff --git a/test/test_torch_distributed_fsdp_frozen_weight.py b/test/test_torch_distributed_fsdp_frozen_weight.py index f492fcef3334..f1d6499935c5 100644 --- a/test/test_torch_distributed_fsdp_frozen_weight.py +++ b/test/test_torch_distributed_fsdp_frozen_weight.py @@ -7,7 +7,7 @@ def _mp_fn(index): - dev = xm.xla_device() + dev = torch_xla.device() if xm.xla_device_hw(dev) not in ('TPU', 'CUDA'): print( 'Default device {} is not a TPU or CUDA device'.format(dev), @@ -19,7 +19,7 @@ def _mp_fn(index): model = FSDP(model) # wrapping the linear module with FSDP - input = torch.rand((2, 1024), device=xm.xla_device()) + input = torch.rand((2, 1024), device=torch_xla.device()) output = model(input) loss = torch.sum(output) diff --git a/test/test_torch_distributed_xla_backend.py b/test/test_torch_distributed_xla_backend.py index bf4573713ec3..a3069a6637ec 100644 --- a/test/test_torch_distributed_xla_backend.py +++ b/test/test_torch_distributed_xla_backend.py @@ -63,7 +63,7 @@ def test_xla_backend_exists(self): self.assertIsNotNone(pg_xla_creator) def test_allreduce(self): - device = xm.xla_device() + device = torch_xla.device() tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank() all_reduce_pattern = r'%all\-reduce\.\d+ = .+ all\-reduce\(' dist.all_reduce(tensor) @@ -72,7 +72,7 @@ def test_allreduce(self): @patch_world(rank=3, size=6) def test_allreduce_with_mesh(self): - device = xm.xla_device() + device = torch_xla.device() tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank() pg_options = {'xla_pg_options': {'spmd': True}} @@ -89,7 +89,7 @@ def test_allreduce_with_mesh(self): @patch_world(rank=3, size=8) def test_allgather(self): - device = xm.xla_device() + device = torch_xla.device() tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank() output_tensors = [torch.zeros_like(tensor, device=device) for _ in range(8)] all_gather_pattern = r'%all\-gather\.\d+ = .+ all\-gather\(' @@ -99,7 +99,7 @@ def test_allgather(self): @patch_world(rank=3, size=8) def test_all_scalar_allgather(self): - device = xm.xla_device() + device = torch_xla.device() tensor = torch.zeros((), device=device) + 1 + 2 * dist.get_rank() output_tensors = [torch.zeros_like(tensor, device=device) for _ in range(8)] all_gather_pattern = r'%all\-gather\.\d+ = .+ all\-gather\(' @@ -109,7 +109,7 @@ def test_all_scalar_allgather(self): @patch_world(rank=3, size=8) def test_allgather_coalesced(self): - device = xm.xla_device() + device = torch_xla.device() tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank() tensor2 = torch.arange(5, device=device) + 1 + 2 * dist.get_rank() pg_xla = get_process_group_xla(rank=3, size=8) @@ -127,7 +127,7 @@ def test_allgather_coalesced(self): hlo_matches(hlo, all_gather_pattern) def test_broadcast(self): - device = xm.xla_device() + device = torch_xla.device() tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank() all_reduce_pattern = r'%all\-reduce\.\d+ = .+ all\-reduce\(' dist.broadcast(tensor, 0) @@ -136,7 +136,7 @@ def test_broadcast(self): # Needed for ZeRO stage 1 def test_reduce_scatter(self): - device = xm.xla_device() + device = torch_xla.device() tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank() input_list = [tensor] output = torch.zeros_like(tensor) @@ -148,7 +148,7 @@ def test_reduce_scatter(self): @skipIf(xr.device_type() == 'CPU', "UNIMPLEMENTED: ReduceScatter is not implemented on CPU.") def test_reduce_scatter_coalesced(self): - device = xm.xla_device() + device = torch_xla.device() tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank() tensor2 = torch.arange(5, device=device) + 1 + 2 * dist.get_rank() input_tensors_list = [[tensor, tensor], [tensor2, tensor2]] @@ -168,7 +168,7 @@ def test_reduce_scatter_coalesced(self): @patch_world(0, 6) def test_send(self): - device = xm.xla_device() + device = torch_xla.device() tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank() input_list = [tensor] @@ -185,11 +185,11 @@ def test_send(self): hlo_matches(hlo, senddone_pattern) # Don't try to run Send on CPU because it's not implemented - torch_xla._XLAC._clear_pending_irs(str(xm.xla_device())) + torch_xla._XLAC._clear_pending_irs(str(torch_xla.device())) @patch_world(0, 6) def test_recv(self): - device = xm.xla_device() + device = torch_xla.device() tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank() with mock.patch.object( @@ -205,7 +205,7 @@ def test_recv(self): hlo_matches(hlo, recvdone_pattern) # Don't try to run Recv on CPU because it's not implemented - torch_xla._XLAC._clear_pending_irs(str(xm.xla_device())) + torch_xla._XLAC._clear_pending_irs(str(torch_xla.device())) @patch_world(rank=0, size=12) def test_new_group_no_ranks(self): @@ -365,7 +365,7 @@ def test_barrier(self): 'monitored_barrier', ) def test_unimplemented_op(self, op): - device = xm.xla_device() + device = torch_xla.device() tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank() pg_xla = dist.group.WORLD self.assertIsInstance(pg_xla, diff --git a/test/test_train_mp_imagenet.py b/test/test_train_mp_imagenet.py index bec580c3831e..efb34a2cc3af 100644 --- a/test/test_train_mp_imagenet.py +++ b/test/test_train_mp_imagenet.py @@ -250,7 +250,7 @@ def train_imagenet(): torch.manual_seed(42) - device = xm.xla_device() + device = torch_xla.device() model = get_model_property('model_fn')().to(device) # Initialization is nondeterministic with multiple threads in PjRt. diff --git a/test/test_train_mp_imagenet_amp.py b/test/test_train_mp_imagenet_amp.py index 0ab5e1fd8007..290857281fd7 100644 --- a/test/test_train_mp_imagenet_amp.py +++ b/test/test_train_mp_imagenet_amp.py @@ -194,7 +194,7 @@ def train_imagenet(): torch.manual_seed(42) - device = xm.xla_device() + device = torch_xla.device() device_hw = xm.xla_device_hw(device) model = get_model_property('model_fn')().to(device) writer = None @@ -229,7 +229,7 @@ def train_loop_fn(loader, epoch): for step, (data, target) in enumerate(loader): optimizer.zero_grad() if FLAGS.amp: - with autocast(xm.xla_device()): + with autocast(torch_xla.device()): output = model(data) loss = loss_fn(output, target) if scaler: diff --git a/test/test_train_mp_imagenet_fsdp.py b/test/test_train_mp_imagenet_fsdp.py index 8c9be15ac2e2..1d939d8385b3 100644 --- a/test/test_train_mp_imagenet_fsdp.py +++ b/test/test_train_mp_imagenet_fsdp.py @@ -241,7 +241,7 @@ def train_imagenet(): torch.manual_seed(42) - device = xm.xla_device() + device = torch_xla.device() model = get_model_property('model_fn')() # Automatic wrapping sub-modules with inner FSDP auto_wrap_policy = None diff --git a/test/test_train_mp_mnist.py b/test/test_train_mp_mnist.py index 9e470719f27b..0a5e46fdcd1f 100644 --- a/test/test_train_mp_mnist.py +++ b/test/test_train_mp_mnist.py @@ -130,7 +130,7 @@ def train_mnist(flags, **kwargs): # Scale learning rate to num cores lr = flags.lr * xr.world_size() - device = xm.xla_device() + device = torch_xla.device() model = MNIST().to(device) # Initialization is nondeterministic with multiple threads in PjRt. diff --git a/test/test_train_mp_mnist_amp.py b/test/test_train_mp_mnist_amp.py index 3fa8770f1a89..0bd393b21f2e 100644 --- a/test/test_train_mp_mnist_amp.py +++ b/test/test_train_mp_mnist_amp.py @@ -130,7 +130,7 @@ def train_mnist(flags, **kwargs): # Scale learning rate to num cores lr = flags.lr * xr.world_size() - device = xm.xla_device() + device = torch_xla.device() device_hw = xm.xla_device_hw(device) model = MNIST().to(device) diff --git a/test/test_train_mp_mnist_fsdp_with_ckpt.py b/test/test_train_mp_mnist_fsdp_with_ckpt.py index 169bfe264a3c..833612a2be49 100644 --- a/test/test_train_mp_mnist_fsdp_with_ckpt.py +++ b/test/test_train_mp_mnist_fsdp_with_ckpt.py @@ -164,7 +164,7 @@ def train_mnist(flags, **kwargs): # Scale learning rate to num cores lr = flags.lr * xr.world_size() - device = xm.xla_device() + device = torch_xla.device() model = MNIST() # Automatic wrapping sub-modules with inner FSDP auto_wrap_policy = None diff --git a/test/test_train_mp_mnist_zero1.py b/test/test_train_mp_mnist_zero1.py index 77e284c98d76..523bf5fc0a19 100644 --- a/test/test_train_mp_mnist_zero1.py +++ b/test/test_train_mp_mnist_zero1.py @@ -114,7 +114,7 @@ def train_mnist(flags, **kwargs): # Scale learning rate to num cores lr = flags.lr * xr.world_size() - device = xm.xla_device() + device = torch_xla.device() model = MNIST().to(device) writer = None diff --git a/test/test_user_computation_debug_cache.py b/test/test_user_computation_debug_cache.py index 88467de3c51b..f83f856c2cfd 100644 --- a/test/test_user_computation_debug_cache.py +++ b/test/test_user_computation_debug_cache.py @@ -40,7 +40,7 @@ def input_scope_0(tensor): def input_scope_1(tensor): return [torch.sin(tensor), torch.cos(tensor)] - device = xm.xla_device() + device = torch_xla.device() init_tensor = torch.tensor(10).to(device) def create_user_computation(fn): diff --git a/test/test_utils.py b/test/test_utils.py index f238f4c82540..6a913f932e4d 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -384,7 +384,7 @@ def compareResults(self, results, xla_results, rel_err=1e-2, abs_err=1e-5): def runAtenTest(self, tensors, fn, device=None, rel_err=1e-2, abs_err=1e-5): if device is None: - device = xm.xla_device() + device = torch_xla.device() tensors = xu.as_list(tensors) xla_tensors = [ x.to(device).detach().requires_grad_(x.requires_grad) for x in tensors diff --git a/test/test_while_loop.py b/test/test_while_loop.py index e8ea617b0f96..4dc0a17a96ea 100644 --- a/test/test_while_loop.py +++ b/test/test_while_loop.py @@ -26,7 +26,7 @@ def _fake_while_loop(cond_fn, body_fn, operands): class WhileLoopTest(unittest.TestCase): def test_while_loop_addition(self): - device = xm.xla_device() + device = torch_xla.device() def cond_fn(iteri, x): return iteri > 0 @@ -41,7 +41,7 @@ def body_fn(iteri, x): self.assertTrue(torch.all(torch.eq(res_with_loop, res_without_loop))) def test_while_loop_addition_nested(self): - device = xm.xla_device() + device = torch_xla.device() def cond_fn(iteri, x): return iteri > 0 @@ -56,7 +56,7 @@ def body_fn(iteri, x): self.assertTrue(torch.all(torch.eq(res_with_loop, res_without_loop))) def test_while_loop_simple_linear_inside_loop(self): - device = xm.xla_device() + device = torch_xla.device() torch.set_grad_enabled(False) class SimpleLinear(torch.nn.Module): @@ -94,7 +94,7 @@ def forward_without_while_loop_op(self, iteri, x): # ====== fori_loop ====== @unittest.skip("Fori_loop is not supported now due to unstable result.") def test_fori_loop_addition(self): - device = xm.xla_device() + device = torch_xla.device() lower = torch.tensor(0, device=device) upper = torch.tensor(50, device=device) diff --git a/test/test_xla_graph_execution.py b/test/test_xla_graph_execution.py index d8aa33f39aee..bbf34321f5c6 100644 --- a/test/test_xla_graph_execution.py +++ b/test/test_xla_graph_execution.py @@ -21,7 +21,7 @@ class TestXlaGraphExecution(test_utils.XlaTestCase): def test_graph_execution_allowed(self): torch_xla._XLAC._set_allow_execution(True) - x = torch.ones(2, device=xm.xla_device()) + x = torch.ones(2, device=torch_xla.device()) self.assertEqual(x[0], 1.0) # This should trigger the checking del x @@ -30,7 +30,7 @@ def test_graph_execution_disallowed_with_error(self): # Trigger runtime error for unexpected graph execution torch_xla._XLAC._set_allow_execution( False) # this flag disallows graph execution - x = torch.ones(2, device=xm.xla_device()) + x = torch.ones(2, device=torch_xla.device()) with self.assertRaises(RuntimeError) as e: self.assertEqual(x[0], 1.0) # This should trigger the checking self.assertIn( diff --git a/test/test_zero1.py b/test/test_zero1.py index e3dc5738f846..8bb2fbc3d822 100644 --- a/test/test_zero1.py +++ b/test/test_zero1.py @@ -34,7 +34,7 @@ class XlaZeRO1Test(test_utils.XlaTestCase): @unittest.skipIf(xr.device_type() == 'TPU', "Crash on TPU") def test_zero1(self): - device = xm.xla_device() + device = torch_xla.device() model = nn.Linear(32, 32) x = torch.ones((32, 32)) @@ -89,7 +89,7 @@ def test_zero1(self): torch_xla.sync() def test_zero1_load(self): - device = xm.xla_device() + device = torch_xla.device() model = nn.Linear(32, 32) x = torch.ones((32, 32)) @@ -153,7 +153,7 @@ def test_zero1_load(self): def _mp_fn(index): - device = xm.xla_device() + device = torch_xla.device() if xm.xla_device_hw(device) in ('TPU', 'CUDA'): test = unittest.main(exit=False) sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/test/torch_distributed/test_ddp.py b/test/torch_distributed/test_ddp.py index d4e3fc77c7f2..1d91f520d5aa 100644 --- a/test/torch_distributed/test_ddp.py +++ b/test/torch_distributed/test_ddp.py @@ -24,7 +24,7 @@ def _ddp_correctness(rank, gradient_as_bucket_view: bool = False): # We cannot run this guard before XMP, # see API_GUIDE.md#running-on-multiple-xla-devices-with-multi-processing. - device = xm.xla_device() + device = torch_xla.device() if xm.xla_device_hw(device) not in ('TPU', 'CUDA'): print( 'Default device {} is not a TPU device'.format(device), diff --git a/test/torch_distributed/test_torch_distributed_all_gather_xla_backend.py b/test/torch_distributed/test_torch_distributed_all_gather_xla_backend.py index 84b950becde6..7c30b211ad49 100644 --- a/test/torch_distributed/test_torch_distributed_all_gather_xla_backend.py +++ b/test/torch_distributed/test_torch_distributed_all_gather_xla_backend.py @@ -9,7 +9,7 @@ def _mp_fn(index): - device = xm.xla_device() + device = torch_xla.device() if xm.xla_device_hw(device) in ('TPU', 'CUDA', 'NEURON'): world_size = xr.world_size() rank = xr.global_ordinal() diff --git a/test/torch_distributed/test_torch_distributed_all_reduce_xla_backend.py b/test/torch_distributed/test_torch_distributed_all_reduce_xla_backend.py index b354ec3d57a0..2fd71d2ed84e 100644 --- a/test/torch_distributed/test_torch_distributed_all_reduce_xla_backend.py +++ b/test/torch_distributed/test_torch_distributed_all_reduce_xla_backend.py @@ -9,7 +9,7 @@ def _mp_fn(index): - device = xm.xla_device() + device = torch_xla.device() if xm.xla_device_hw(device) in ('TPU', 'CUDA', 'NEURON'): world_size = xr.world_size() dist.init_process_group('xla', init_method='xla://') diff --git a/test/torch_distributed/test_torch_distributed_bucketed_all_reduce_xla_backend.py b/test/torch_distributed/test_torch_distributed_bucketed_all_reduce_xla_backend.py index 3d5736b0ec43..c462f7552800 100644 --- a/test/torch_distributed/test_torch_distributed_bucketed_all_reduce_xla_backend.py +++ b/test/torch_distributed/test_torch_distributed_bucketed_all_reduce_xla_backend.py @@ -9,7 +9,7 @@ def _mp_fn(index): - device = xm.xla_device() + device = torch_xla.device() if xm.xla_device_hw(device) in ('TPU', 'CUDA', 'NEURON'): world_size = xr.world_size() rank = xr.global_ordinal() diff --git a/test/torch_distributed/test_torch_distributed_fsdp_meta.py b/test/torch_distributed/test_torch_distributed_fsdp_meta.py index 182a7818ecbb..9b7c77011242 100644 --- a/test/torch_distributed/test_torch_distributed_fsdp_meta.py +++ b/test/torch_distributed/test_torch_distributed_fsdp_meta.py @@ -60,7 +60,7 @@ def _init_with_reset_params(module): """ is_meta = any(t.is_meta for t in module.parameters()) if is_meta: - module.to_empty(device=xm.xla_device()) + module.to_empty(device=torch_xla.device()) with torch.no_grad(): module.reset_parameters() @@ -87,7 +87,7 @@ def _compare_fsdp(self, fsdp1, fsdp2): def _test_simple_model_with_meta_device(self, meta_module_fn, init_fn=None): # Create model on meta device and wrap with FSDP. model = meta_module_fn() - inp = torch.randn(10, 2, device=xm.xla_device()) + inp = torch.randn(10, 2, device=torch_xla.device()) fsdp_meta = XlaFullyShardedDataParallel( model, @@ -99,7 +99,7 @@ def _test_simple_model_with_meta_device(self, meta_module_fn, init_fn=None): meta_opt.step() torch_xla.sync() - regular = MyModel(device=xm.xla_device()) + regular = MyModel(device=torch_xla.device()) fsdp_regular = XlaFullyShardedDataParallel( regular, auto_wrap_policy=always_wrap) regular_opt = torch.optim.SGD(fsdp_regular.parameters(), lr=1e-3) @@ -127,7 +127,7 @@ def meta_module_fn(): def test_simple_model_with_torchdistX_init_fn(self): def meta_module_fn(): - return deferred_init.deferred_init(MyModel, device=xm.xla_device()) + return deferred_init.deferred_init(MyModel, device=torch_xla.device()) self._test_simple_model_with_meta_device( meta_module_fn, init_fn=_init_with_torchdistX) @@ -135,13 +135,13 @@ def meta_module_fn(): def test_simple_model_with_default_torchdistX(self): def meta_module_fn(): - return deferred_init.deferred_init(MyModel, device=xm.xla_device()) + return deferred_init.deferred_init(MyModel, device=torch_xla.device()) self._test_simple_model_with_meta_device(meta_module_fn) def _mp_fn(index): - device = xm.xla_device() + device = torch_xla.device() # This test fails on GPU with 03/30 TF-pin update (https://github.com/pytorch/xla/pull/4840) if xm.xla_device_hw(device) in ('TPU', 'NEURON'): dist.init_process_group('xla', init_method='xla://') diff --git a/test/torch_distributed/test_torch_distributed_multi_all_reduce_xla_backend.py b/test/torch_distributed/test_torch_distributed_multi_all_reduce_xla_backend.py index 8ca45141350e..9089f9d799ff 100644 --- a/test/torch_distributed/test_torch_distributed_multi_all_reduce_xla_backend.py +++ b/test/torch_distributed/test_torch_distributed_multi_all_reduce_xla_backend.py @@ -9,7 +9,7 @@ def _mp_fn(index): - device = xm.xla_device() + device = torch_xla.device() if xm.xla_device_hw(device) in ('TPU', 'CUDA', 'NEURON'): world_size = xr.world_size() rank = xr.global_ordinal() diff --git a/test/torch_distributed/test_torch_distributed_reduce_scatter_xla_backend.py b/test/torch_distributed/test_torch_distributed_reduce_scatter_xla_backend.py index 36e6420dce10..006d3fd33a95 100644 --- a/test/torch_distributed/test_torch_distributed_reduce_scatter_xla_backend.py +++ b/test/torch_distributed/test_torch_distributed_reduce_scatter_xla_backend.py @@ -9,7 +9,7 @@ def _mp_fn(index): - device = xm.xla_device() + device = torch_xla.device() if xm.xla_device_hw(device) in ('TPU', 'CUDA'): world_size = xr.world_size() rank = xr.global_ordinal() diff --git a/test/utils/train_spmd_linear_model.py b/test/utils/train_spmd_linear_model.py index 53ca0c6cc6dd..e2bcb6124f87 100644 --- a/test/utils/train_spmd_linear_model.py +++ b/test/utils/train_spmd_linear_model.py @@ -69,7 +69,7 @@ def forward(self, x): def train(): - device = xm.xla_device() + device = torch_xla.device() torch.manual_seed(42) model = SimpleLinear().to(device) print('===> Preparing data..') @@ -148,5 +148,5 @@ def train_and_evaluate(): xr.use_spmd(auto=FLAGS.auto_spmd) print('Start training loop...') losses, m = train() - t = torch.randn(10, FLAGS.input_dim).to(xm.xla_device()) + t = torch.randn(10, FLAGS.input_dim).to(torch_xla.device()) return [loss.cpu() for loss in losses], m(t).cpu() diff --git a/test/utils/train_spmd_linear_model_grad_acc.py b/test/utils/train_spmd_linear_model_grad_acc.py index b3c107770ae8..294309d62ed6 100644 --- a/test/utils/train_spmd_linear_model_grad_acc.py +++ b/test/utils/train_spmd_linear_model_grad_acc.py @@ -77,7 +77,7 @@ def forward(self, x): def train(): - device = xm.xla_device() + device = torch_xla.device() num_devices = xr.global_runtime_device_count() print(f'num_devices: {num_devices}') # Define a mesh with all devices along one axis @@ -182,6 +182,6 @@ def train_and_evaluate_grad_acc(): xr.use_spmd(auto=FLAGS.auto_spmd) print('Start training loop...') losses, m = train() - t = torch.randn(10, FLAGS.input_dim).to(xm.xla_device()) + t = torch.randn(10, FLAGS.input_dim).to(torch_xla.device()) m(t).cpu() return [loss.cpu() for loss in losses] diff --git a/torch_xla/_dynamo/dynamo_bridge.py b/torch_xla/_dynamo/dynamo_bridge.py index ac7d9d906ff9..ce11cc07b2cf 100644 --- a/torch_xla/_dynamo/dynamo_bridge.py +++ b/torch_xla/_dynamo/dynamo_bridge.py @@ -495,7 +495,7 @@ def extract_graph_helper(xla_model: torch.fx.GraphModule, # 2. All of the pending IRs are result of our warm up cache tracing and they # should be removed to avoid extra computation executed and in place updates op # mistakenlly update the input tensors. - torch_xla._XLAC._clear_pending_irs(str(xm.xla_device())) + torch_xla._XLAC._clear_pending_irs(str(torch_xla.device())) vars_to_return = (xla_args_sharding_spec, args_and_out, graph_hash, arg_index_to_need_update_index, none_remover, @@ -564,7 +564,7 @@ def optimized_mod(*args: tuple): is_cuda_args = original_device.type == "cuda" if is_cuda_args: - args = _maybe_move_tensors_to_device(args, xm.xla_device()) + args = _maybe_move_tensors_to_device(args, torch_xla.device()) if not config.skip_input_data_check: # `torch_xla.sync()` needs to be blocking since we want to access args's @@ -761,7 +761,7 @@ def partition_fx_graph_for_cpu_fallback(xla_model, xla_args, all_xla_args, # UnsupportedNodesCollector might trigger in place ops, need to clear them here. _clear_pending_irs_on_args(all_xla_args_tensor_only, cloned_args) - torch_xla._XLAC._clear_pending_irs(str(xm.xla_device())) + torch_xla._XLAC._clear_pending_irs(str(torch_xla.device())) class XlaOperatorSupport(torch.fx.passes.operator_support.OperatorSupport): @@ -805,7 +805,8 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: def extract_compiled_graph_helper(xla_model: torch.fx.GraphModule, xla_args): if _args_on_cuda(xla_args): - xla_args = tuple(_maybe_move_tensors_to_device(xla_args, xm.xla_device())) + xla_args = tuple( + _maybe_move_tensors_to_device(xla_args, torch_xla.device())) # Synchronize xla_args, so that each FunctionalTensorWrapper argument updates its # value reference before actually computing it. diff --git a/torch_xla/_internal/pjrt.py b/torch_xla/_internal/pjrt.py index 6c1c4c26a392..25a0ee36c36e 100644 --- a/torch_xla/_internal/pjrt.py +++ b/torch_xla/_internal/pjrt.py @@ -104,7 +104,7 @@ def initialize_singleprocess(): plugins.default().configure_single_process() elif runtime.device_type() == 'TPU': tpu.configure_one_chip_topology() - xm.set_replication(xm.xla_device(), []) + xm.set_replication(torch_xla.device(), []) def initialize_multiprocess(local_rank: int, local_world_size: int): @@ -119,7 +119,7 @@ def initialize_multiprocess(local_rank: int, local_world_size: int): neuron.initialize_env(local_rank, local_world_size) devices = xm.get_xla_supported_devices() - xm.set_replication(xm.xla_device(), devices) + xm.set_replication(torch_xla.device(), devices) def run_multiprocess(fn: Callable[..., R], diff --git a/torch_xla/_internal/tpu.py b/torch_xla/_internal/tpu.py index f7118dd7b3f3..079b1d49f6b2 100644 --- a/torch_xla/_internal/tpu.py +++ b/torch_xla/_internal/tpu.py @@ -312,7 +312,7 @@ def discover_master_worker_ip(use_localhost: bool = True) -> str: if xr.is_spmd(): return _spmd_find_master_ip(worker_ips[current_worker_id]) - t = torch.tensor([current_worker_id], device=xm.xla_device()) + t = torch.tensor([current_worker_id], device=torch_xla.device()) xm.collective_broadcast([t]) torch_xla.sync() diff --git a/torch_xla/amp/syncfree/adam.py b/torch_xla/amp/syncfree/adam.py index 4201933ca590..abb2bb55d2d4 100644 --- a/torch_xla/amp/syncfree/adam.py +++ b/torch_xla/amp/syncfree/adam.py @@ -94,7 +94,7 @@ def step(self, closure=None, found_inf: Tensor = None): p, memory_format=torch.preserve_format) else: state['max_exp_avg_sq'] = torch.empty( - 0, dtype=torch.float, device=xm.xla_device()) + 0, dtype=torch.float, device=torch_xla.device()) exp_avgs.append(state['exp_avg']) exp_avg_sqs.append(state['exp_avg_sq']) diff --git a/torch_xla/amp/syncfree/adamw.py b/torch_xla/amp/syncfree/adamw.py index 83e11d46fad9..b1abe5bbda8c 100644 --- a/torch_xla/amp/syncfree/adamw.py +++ b/torch_xla/amp/syncfree/adamw.py @@ -92,7 +92,7 @@ def step(self, closure=None, found_inf: Tensor = None): p, memory_format=torch.preserve_format) else: state['max_exp_avg_sq'] = torch.empty( - 0, dtype=torch.float, device=xm.xla_device()) + 0, dtype=torch.float, device=torch_xla.device()) exp_avgs.append(state['exp_avg']) exp_avg_sqs.append(state['exp_avg_sq']) diff --git a/torch_xla/core/xla_op_registry.py b/torch_xla/core/xla_op_registry.py index aba1c7076c39..62943f4c70c5 100644 --- a/torch_xla/core/xla_op_registry.py +++ b/torch_xla/core/xla_op_registry.py @@ -68,7 +68,7 @@ def slice_and_add(a, b, dimno=0): SLICE_AND_ADD = xor.register('slice_and_add', slice_and_add) def user_computation_test(): - device = xm.xla_device() + device = torch_xla.device() x = torch.randn(2, 2).to(device) y = torch.randn(2, 2).to(device) z = SLICE_AND_ADD(x, y, dimno=0) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 0bf6a7fcda56..f78531dd8eee 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -1232,7 +1232,7 @@ void BuildLoweringContextSubmodule(py::module* m) { * import torch_xla * import torch_xla.core.xla_model as xm * - * device = xm.xla_device() + * device = torch_xla.device() * example = torch.tensor([1.0, 2.0, 3.0, 4.0], device=device) * * def network(x): diff --git a/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py b/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py index af2b1246baf3..ce1b342cdf17 100644 --- a/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py +++ b/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py @@ -139,7 +139,7 @@ class XlaFullyShardedDataParallel(nn.Module): module (nn.Module): module to be wrapped with FSDP. If the input module's parameters and buffers are not already on XLA device, they will be cast to - ``xm.xla_device()`` (after sharding) during FSDP initialization. + ``torch_xla.device()`` (after sharding) during FSDP initialization. reshard_after_forward (bool, Optional): if ``True``, reshard parameters after the forward pass. This saves memory but slows training. This is only relevant when resharding @@ -527,7 +527,7 @@ def __init__( List[Parameter], self._fsdp_wrapped_module.flat_params) + non_flatten_params - self.xla_device = xm.xla_device() + self.xla_device = torch_xla.device() # Shard module parameters in place self._shard_parameters_(params_to_shard) # Cast the module buffers to the specified buffer_dtype @@ -1014,7 +1014,7 @@ def _dummy_forward(self, *args: Any, **kwargs: Any) -> torch.Tensor: A dummy forward pass with minimal computation that sums all inputs and full parameters, e.g. to debug parameter memory consumption. """ - outputs = torch.zeros(1, device=xm.xla_device()) + outputs = torch.zeros(1, device=torch_xla.device()) for t in chain(args, kwargs.values(), self.full_params): if isinstance(t, torch.Tensor) and t.dtype == torch.float32: outputs = outputs + t.mean() @@ -1646,7 +1646,7 @@ def _print_r0(self, msg: str, restart: bool = False) -> None: if restart: self._tstart = time.time() if self.rank == 0: - memory_info = xm.get_memory_info(xm.xla_device()) + memory_info = xm.get_memory_info(torch_xla.device()) gb_free = memory_info["kb_free"] / 1024 / 1024 gb_total = memory_info["kb_total"] / 1024 / 1024 logging.info( diff --git a/torch_xla/distributed/spmd/api.py b/torch_xla/distributed/spmd/api.py index 3c6dcff14e05..77ff9e9ac6ee 100644 --- a/torch_xla/distributed/spmd/api.py +++ b/torch_xla/distributed/spmd/api.py @@ -215,10 +215,10 @@ def xla_distribute_module( if partition_fn: if getattr(partition_fn, '__name__', 'unknown') == "auto_policy": # TODO(yeounoh) allow pre-loading to xla device in the future. - assert next(module.parameters()).device != xm.xla_device(), \ + assert next(module.parameters()).device != torch_xla.device(), \ f"Currently requires module to be on cpu, before xla_distribute_module." xr.use_spmd(auto=True) - module = module.to(xm.xla_device()) + module = module.to(torch_xla.device()) else: # apply partition_fun to submodules for name, submod in module.named_modules(): diff --git a/torch_xla/distributed/spmd/xla_sharding.py b/torch_xla/distributed/spmd/xla_sharding.py index cb61158df903..b7089ffc498e 100644 --- a/torch_xla/distributed/spmd/xla_sharding.py +++ b/torch_xla/distributed/spmd/xla_sharding.py @@ -334,7 +334,7 @@ def _get_physical_tpu_mesh(self, devices: np.ndarray) -> np.ndarray: A np.ndarray of device logical ordinals with shape [global_x, global_y, global_z]. On v2 and v3, global_z is instead cores_per_chip (i.e., 2). """ - assert xm.xla_device_hw(xm.xla_device()) == 'TPU' + assert xm.xla_device_hw(torch_xla.device()) == 'TPU' # coords is a 3-dims tuple representing the device in physical mesh device_coords = [self.device_attributes[d]['coords'] for d in devices] dims = tuple(d + 1 for d in max(device_coords)) @@ -595,9 +595,9 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh, >>> num_devices = xr.global_runtime_device_count() >>> device_ids = np.array(range(num_devices)) >>> mesh = Mesh(device_ids, mesh_shape, ('x', 'y')) - >>> input = torch.randn(8, 32).to(xm.xla_device()) + >>> input = torch.randn(8, 32).to(torch_xla.device()) >>> xs.mark_sharding(input, mesh, (0, None)) # 4-way data parallel - >>> linear = nn.Linear(32, 10).to(xm.xla_device()) + >>> linear = nn.Linear(32, 10).to(torch_xla.device()) >>> xs.mark_sharding(linear.weight, mesh, (None, 1)) # 2-way model parallel """ # We only allow fully specified `partition_spec` to be applicable, as opposed @@ -793,7 +793,7 @@ def can_apply(self, t: torch.Tensor) -> bool: def apply(self, t: torch.Tensor): # TODO(yeounoh) use virtual device interface when available. - assert (t.device == xm.xla_device()) + assert (t.device == torch_xla.device()) mark_sharding(t, self.mesh, self.partition_spec) diff --git a/torch_xla/distributed/xla_multiprocessing.py b/torch_xla/distributed/xla_multiprocessing.py index d699abaebafb..e3b349a4b7fb 100644 --- a/torch_xla/distributed/xla_multiprocessing.py +++ b/torch_xla/distributed/xla_multiprocessing.py @@ -56,7 +56,7 @@ class MpModelWrapper(object): WRAPPED_MODEL = xmp.MpModelWrapper(MyNetwork()) def _mp_fn(index, ...): - device = xm.xla_device() + device = torch_xla.device() model = WRAPPED_MODEL.to(device) ... diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index ca54521c8bc0..ea4c8d54c1a2 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -16,7 +16,7 @@ def fori_loop(lower, upper, body_fun, *input_value): - device = xm.xla_device() + device = torch_xla.device() if (upper < lower): print("ERROR: upper should be a larger number than lower") iteri = upper - lower diff --git a/torch_xla/experimental/scan.py b/torch_xla/experimental/scan.py index d11141db58db..ea2a097d5974 100644 --- a/torch_xla/experimental/scan.py +++ b/torch_xla/experimental/scan.py @@ -129,7 +129,7 @@ def scan(fn, init, xs): >>> y = new_carry >>> return new_carry, y >>> - >>> with xm.xla_device(): + >>> with torch_xla.device(): >>> init = torch.tensor([0.0, 0.0], requires_grad=True) >>> xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], >>> requires_grad=True) diff --git a/torch_xla/experimental/scan_layers.py b/torch_xla/experimental/scan_layers.py index 8ddafed82024..0be37b363909 100644 --- a/torch_xla/experimental/scan_layers.py +++ b/torch_xla/experimental/scan_layers.py @@ -51,7 +51,7 @@ def scan_layers(layers: Iterable[torch.nn.Module], >>> import torch >>> import torch.nn as nn >>> from torch_xla.experimental.scan_layers import scan_layers - >>> with xm.xla_device(): + >>> with torch_xla.device(): >>> layers = [nn.Linear(16, 16) for i in range(10)] >>> input = torch.randn(16) >>> output = scan_layers(layers, input) diff --git a/torch_xla/experimental/spmd_fully_sharded_data_parallel.py b/torch_xla/experimental/spmd_fully_sharded_data_parallel.py index 0b198785d76b..e0e7f5ba6027 100644 --- a/torch_xla/experimental/spmd_fully_sharded_data_parallel.py +++ b/torch_xla/experimental/spmd_fully_sharded_data_parallel.py @@ -109,7 +109,7 @@ def __init__( # Let's move the module to xla device in case it's not moved # by the caller already. - self._orig_module = module.to(xm.xla_device()) + self._orig_module = module.to(torch_xla.device()) self._mesh = mesh # Only handle params which are not already sharded. This enables diff --git a/torch_xla/runtime.py b/torch_xla/runtime.py index e9dbb9d48241..53aa7399c230 100644 --- a/torch_xla/runtime.py +++ b/torch_xla/runtime.py @@ -97,7 +97,7 @@ def is_bf16_supported(): """Returns whether torch.bfloat16 is supported on this environment. """ try: - torch.tensor([1.], dtype=torch.bfloat16, device=xm.xla_device()) + torch.tensor([1.], dtype=torch.bfloat16, device=torch_xla.device()) return True except Exception as e: return False @@ -156,7 +156,7 @@ def local_ordinal() -> int: Local ordinal is in range [0, local_device_count).""" local_rank = xu.getenv_as(xenv.PJRT_LOCAL_PROCESS_RANK, int, 0) devices_per_process = addressable_device_count() - return local_rank * devices_per_process + xm.xla_device().index + return local_rank * devices_per_process + torch_xla.device().index def process_index() -> int: diff --git a/torch_xla/stablehlo.py b/torch_xla/stablehlo.py index 6b3e25584b4c..b88a8131b2d8 100644 --- a/torch_xla/stablehlo.py +++ b/torch_xla/stablehlo.py @@ -341,7 +341,7 @@ def _exported_program_to_stablehlo_bundle(exported_model, assert len(kwargs) == 0, "Export to stablehlo doesnt support kwargs yet." - device = xm.xla_device() + device = torch_xla.device() _flat_input_args = exported_model._graph_module_flat_inputs(args, {}) _flat_input_args = pytree.tree_map_only(torch.Tensor, @@ -352,7 +352,7 @@ def _exported_program_to_stablehlo_bundle(exported_model, torch_xla.sync() xm.wait_device_ops() metrics.clear_counters() - device = xm.xla_device() + device = torch_xla.device() # Run the fx graph tracing using lazy tensor if options.inline_all_constant: From 9d7d365dad62d8810765b3fab1d410ffc9681c94 Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Fri, 23 May 2025 19:28:10 +0000 Subject: [PATCH 4/9] Replace `device=torch_xla.device()` with `device="xla"` --- API_GUIDE.md | 10 +-- benchmarks/matmul_bench.py | 7 +- ...ributed-pytorch-xla-basics-with-pjrt.ipynb | 18 ++--- docs/source/learn/_pjrt.md | 2 +- docs/source/learn/pytorch-on-xla-devices.md | 10 +-- docs/source/learn/troubleshoot.md | 4 +- docs/source/tutorials/precision_tutorial.py | 4 +- examples/scan/scan_examples.py | 10 +-- test/ds/test_dynamic_shapes.py | 6 +- test/dynamo/test_dynamo.py | 4 +- test/metrics_compare_utils_test.py | 2 +- test/neuron/test_neuron_data_types.py | 4 +- test/pjrt/test_metrics.py | 2 +- test/pjrt/test_runtime_multi_cpu.py | 6 +- test/pjrt/test_runtime_multi_gpu.py | 2 +- test/pjrt/test_runtime_tpu.py | 4 +- test/pjrt/test_torchrun.py | 12 ++-- test/scan/test_scan.py | 14 ++-- test/scan/test_scan_debug.py | 6 +- test/scan/test_scan_layers.py | 6 +- test/spmd/test_dtensor_integration.py | 5 +- test/spmd/test_xla_sharding.py | 26 ++++---- test/spmd/test_xla_virtual_device.py | 16 ++--- test/test_callback.py | 4 +- test/test_core_aten_ops.py | 2 +- test/test_data_type.py | 4 +- test/test_dynamic_shapes_detector.py | 16 ++--- test/test_hlo_metadata.py | 6 +- test/test_metrics.py | 4 +- test/test_operations.py | 65 +++++++++---------- test/test_operations_hlo.py | 2 +- test/test_pallas.py | 2 +- ...st_torch_distributed_fsdp_frozen_weight.py | 2 +- test/test_triton.py | 6 +- test/test_xla_graph_execution.py | 4 +- .../test_torch_distributed_fsdp_meta.py | 10 +-- torch_xla/_internal/tpu.py | 2 +- torch_xla/amp/syncfree/adam.py | 2 +- torch_xla/amp/syncfree/adamw.py | 2 +- .../fsdp/xla_fully_sharded_data_parallel.py | 2 +- torch_xla/experimental/scan.py | 2 +- torch_xla/runtime.py | 2 +- 42 files changed, 148 insertions(+), 171 deletions(-) diff --git a/API_GUIDE.md b/API_GUIDE.md index bb8895c1774b..30bb9fddcee1 100644 --- a/API_GUIDE.md +++ b/API_GUIDE.md @@ -15,7 +15,7 @@ import torch import torch_xla import torch_xla.core.xla_model as xm -t = torch.randn(2, 2, device=torch_xla.device()) +t = torch.randn(2, 2, device='xla') print(t.device) print(t) ``` @@ -32,8 +32,8 @@ PyTorch operations can be performed on XLA tensors just like CPU or CUDA tensors For example, XLA tensors can be added together: ```python -t0 = torch.randn(2, 2, device=torch_xla.device()) -t1 = torch.randn(2, 2, device=torch_xla.device()) +t0 = torch.randn(2, 2, device='xla') +t1 = torch.randn(2, 2, device='xla') print(t0 + t1) ``` @@ -46,7 +46,7 @@ print(t0.mm(t1)) Or used with neural network modules: ```python -l_in = torch.randn(10, device=torch_xla.device()) +l_in = torch.randn(10, device='xla') linear = torch.nn.Linear(10, 20).to(torch_xla.device()) l_out = linear(l_in) print(l_out) @@ -56,7 +56,7 @@ Like other device types, XLA tensors only work with other XLA tensors on the same device. So code like ```python -l_in = torch.randn(10, device=torch_xla.device()) +l_in = torch.randn(10, device='xla') linear = torch.nn.Linear(10, 20) l_out = linear(l_in) print(l_out) diff --git a/benchmarks/matmul_bench.py b/benchmarks/matmul_bench.py index 661963bc358c..833dee1a41bc 100644 --- a/benchmarks/matmul_bench.py +++ b/benchmarks/matmul_bench.py @@ -39,10 +39,7 @@ def main(): """ xla_bench_fn = lambda fn: do_bench( - fn, - return_mode='min', - sync_fn=lambda: xm.wait_device_ops(), - device=torch_xla.device()) + fn, return_mode='min', sync_fn=lambda: xm.wait_device_ops(), device='xla') ind_bench_fn = lambda fn: do_bench( fn, return_mode='min', @@ -53,7 +50,7 @@ def main(): for dtype in dtypes: for inductor_matmul, xla_matmul in zip( get_matmuls(device='cuda', dtype=dtype, backend='inductor'), - get_matmuls(device=torch_xla.device(), dtype=dtype, backend='openxla')): + get_matmuls(device='xla', dtype=dtype, backend='openxla')): ind_lhs_shape, ind_rhs_shape, ind_fn = inductor_matmul xla_lhs_shape, xla_rhs_shape, xla_fn = xla_matmul assert ind_lhs_shape == xla_lhs_shape, f"Expect matmul shapes to match for benchmarking. Mismatch lhs: {ind_lhs_shape}, rhs: {xla_rhs_shape}" diff --git a/contrib/kaggle/distributed-pytorch-xla-basics-with-pjrt.ipynb b/contrib/kaggle/distributed-pytorch-xla-basics-with-pjrt.ipynb index 6f0f06d7c146..8d4fbd95bff7 100644 --- a/contrib/kaggle/distributed-pytorch-xla-basics-with-pjrt.ipynb +++ b/contrib/kaggle/distributed-pytorch-xla-basics-with-pjrt.ipynb @@ -273,7 +273,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": { "execution": { "iopub.execute_input": "2024-01-10T19:30:33.219878Z", @@ -318,12 +318,12 @@ ], "source": [ "def add_ones(i, lock):\n", - " x = torch.ones((3, 3), device=torch_xla.device())\n", + " x = torch.ones((3, 3), device='xla')\n", " y = x + x\n", - " \n", + "\n", " # Run graph to compute `y` before printing\n", " torch_xla.sync()\n", - " \n", + "\n", " with lock:\n", " print(i, y)\n", "\n", @@ -340,7 +340,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": { "execution": { "iopub.execute_input": "2024-01-10T19:30:35.656796Z", @@ -378,10 +378,10 @@ "source": [ "def gather_ids(i, lock):\n", " # Create a tensor on each device with the device ID\n", - " t = torch.tensor([i], device=torch_xla.device())\n", + " t = torch.tensor([i], device='xla')\n", " with lock:\n", " print(i, t)\n", - " \n", + "\n", " # Collect and concatenate the IDs\n", " ts = xm.all_gather(t)\n", " torch_xla.sync()\n", @@ -402,7 +402,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "metadata": { "execution": { "iopub.execute_input": "2024-01-10T19:30:38.315927Z", @@ -479,7 +479,7 @@ " loss.backward()\n", "\n", " optimizer.step()\n", - " \n", + "\n", " # Run the pending graph\n", " torch_xla.sync()\n", "\n", diff --git a/docs/source/learn/_pjrt.md b/docs/source/learn/_pjrt.md index 91917b115a5d..edaa56ecee72 100644 --- a/docs/source/learn/_pjrt.md +++ b/docs/source/learn/_pjrt.md @@ -377,7 +377,7 @@ def _all_gather(index: int): # No need to pass in `rank` or `world_size` dist.init_process_group('xla', init_method='xla://') - t = torch.tensor([index], dtype=torch.int32, device=torch_xla.device()) + t = torch.tensor([index], dtype=torch.int32, device='xla') output = [torch.zeros_like(t) for _ in range(dist.get_world_size())] dist.all_gather(output, t) diff --git a/docs/source/learn/pytorch-on-xla-devices.md b/docs/source/learn/pytorch-on-xla-devices.md index 25328f0ab0a0..2e430d9d492e 100644 --- a/docs/source/learn/pytorch-on-xla-devices.md +++ b/docs/source/learn/pytorch-on-xla-devices.md @@ -14,7 +14,7 @@ import torch import torch_xla import torch_xla.core.xla_model as xm -t = torch.randn(2, 2, device=torch_xla.device()) +t = torch.randn(2, 2, device='xla') print(t.device) print(t) ``` @@ -32,8 +32,8 @@ tensors. For example, XLA tensors can be added together: ``` python -t0 = torch.randn(2, 2, device=torch_xla.device()) -t1 = torch.randn(2, 2, device=torch_xla.device()) +t0 = torch.randn(2, 2, device='xla') +t1 = torch.randn(2, 2, device='xla') print(t0 + t1) ``` @@ -46,7 +46,7 @@ print(t0.mm(t1)) Or used with neural network modules: ``` python -l_in = torch.randn(10, device=torch_xla.device()) +l_in = torch.randn(10, device='xla') linear = torch.nn.Linear(10, 20).to(torch_xla.device()) l_out = linear(l_in) print(l_out) @@ -56,7 +56,7 @@ Like other device types, XLA tensors only work with other XLA tensors on the same device. So code like ``` python -l_in = torch.randn(10, device=torch_xla.device()) +l_in = torch.randn(10, device='xla') linear = torch.nn.Linear(10, 20) l_out = linear(l_in) print(l_out) diff --git a/docs/source/learn/troubleshoot.md b/docs/source/learn/troubleshoot.md index 67497bfa5f09..f911fd2579cf 100644 --- a/docs/source/learn/troubleshoot.md +++ b/docs/source/learn/troubleshoot.md @@ -32,8 +32,8 @@ vm:~$ export PJRT_DEVICE=TPU vm:~$ python3 >>> import torch >>> import torch_xla.core.xla_model as xm ->>> t1 = torch.tensor(100, device=torch_xla.device()) ->>> t2 = torch.tensor(200, device=torch_xla.device()) +>>> t1 = torch.tensor(100, device='xla') +>>> t2 = torch.tensor(200, device='xla') >>> print(t1 + t2) tensor(300, device='xla:0') ``` diff --git a/docs/source/tutorials/precision_tutorial.py b/docs/source/tutorials/precision_tutorial.py index 7d883835e96d..5126a74824b5 100644 --- a/docs/source/tutorials/precision_tutorial.py +++ b/docs/source/tutorials/precision_tutorial.py @@ -168,9 +168,9 @@ def fp32_to_binary_fraction(fp32_float: float) -> str: def get_rand_matrix(): """Returns a diagonal matrix of shape 1024, 1024, values between 0.999 and 1.111""" - eye = torch.eye(1024, dtype=torch.float32, device="xla") + eye = torch.eye(1024, dtype=torch.float32, device='xla') rand_ = torch.rand( - (1024, 1024), dtype=torch.float32, device="xla") * 0.2 + 0.9 + (1024, 1024), dtype=torch.float32, device='xla') * 0.2 + 0.9 result = eye * rand_ assert torch.nonzero(result).size(0) == 1024, torch.nonzero(result).size(0) return result diff --git a/examples/scan/scan_examples.py b/examples/scan/scan_examples.py index 5a4097d029ee..211f382a86e7 100644 --- a/examples/scan/scan_examples.py +++ b/examples/scan/scan_examples.py @@ -18,8 +18,8 @@ def cumsum(accumulated, element): return accumulated, accumulated # 2) Define an initial carry and the input tensor. - init_sum = torch.tensor([0.0], device=torch_xla.device()) - xs = torch.tensor([1.0, 2.0, 3.0], device=torch_xla.device()) + init_sum = torch.tensor([0.0], device='xla') + xs = torch.tensor([1.0, 2.0, 3.0], device='xla') torch_xla.sync() # 3) Call `scan` with our combine function, initial carry, and input tensor. @@ -40,15 +40,15 @@ def scan_example_pytree(): # - 'sum' to accumulate the sum of all seen values # - 'count' to count how many values have been seen carry = { - 'sum': torch.tensor([0.0], device=torch_xla.device()), - 'count': torch.tensor([0.0], device=torch_xla.device()) + 'sum': torch.tensor([0.0], device='xla'), + 'count': torch.tensor([0.0], device='xla') } # 2) Define our input PyTree, which in this case is just a dictionary with one leaf: # - 'values' is a 1D tensor representing data points we want to scan over. xs = { 'values': - torch.arange(1, 6, dtype=torch.float32, device=torch_xla.device()) + torch.arange(1, 6, dtype=torch.float32, device='xla') } # Here, xs['values'] has shape [5]. The `scan` function will automatically slice diff --git a/test/ds/test_dynamic_shapes.py b/test/ds/test_dynamic_shapes.py index 3d1f8bb28fbd..57119ac88d6f 100644 --- a/test/ds/test_dynamic_shapes.py +++ b/test/ds/test_dynamic_shapes.py @@ -163,7 +163,7 @@ def test_t_copy(self): self.assertEqual(t2_t.shape[1], 7) def test_nonzero_shape(self): - x = torch.tensor((0, 1, 2, 0, 3, 4), device=torch_xla.device()) + x = torch.tensor((0, 1, 2, 0, 3, 4), device='xla') x_dim0_shape = torch_xla._XLAC._get_xla_tensor_dimension_size( torch.nonzero(x, as_tuple=False), 0) self.assertEqual(x_dim0_shape.item(), 4) @@ -176,14 +176,14 @@ def test_nonzero_correctness(self): self.assertEqual(t2.cpu(), t2_aten) def test_masked_select_shape(self): - x = torch.tensor((0, 1, 2, 0, 3, 4), device=torch_xla.device()) + x = torch.tensor((0, 1, 2, 0, 3, 4), device='xla') mask = x.ge(2) x_dim0_shape = torch_xla._XLAC._get_xla_tensor_dimension_size( torch.masked_select(x, mask), 0) self.assertEqual(x_dim0_shape.item(), 3) def test_nonzero_cast(self): - t1 = torch.ones(5, 2, device=torch_xla.device()) + t1 = torch.ones(5, 2, device='xla') # Result of the nonzero should be the index type. Currently # index type is s64 on cpu and gpu, but s32 on TPU. We should be # able to cast it to any other type without error. diff --git a/test/dynamo/test_dynamo.py b/test/dynamo/test_dynamo.py index 01b8085e02dd..3241dc9372b3 100644 --- a/test/dynamo/test_dynamo.py +++ b/test/dynamo/test_dynamo.py @@ -49,7 +49,7 @@ def inplace_update(self, a): def test_inplace_update_correctness(self, backend): dynamo_inplace = torch.compile( self.inplace_update, backend=backend, fullgraph=True) - t = torch.tensor([0, 1, 2], device=torch_xla.device()) + t = torch.tensor([0, 1, 2], device='xla') for i in range(10): t = dynamo_inplace(t) self.assertTrue(torch.all(torch.eq(t.cpu(), torch.tensor([10, 11, 12])))) @@ -131,7 +131,7 @@ def dummy_fn(self, a): def test_dynamo_with_trace(self): dynamo_dummy = torch.compile( self.dummy_fn, backend="openxla", fullgraph=True) - t = torch.randn(2, 3, 4, device=torch_xla.device()) + t = torch.randn(2, 3, 4, device='xla') for i in range(10): with xp.Trace('build_graph'): t = dynamo_dummy(t) diff --git a/test/metrics_compare_utils_test.py b/test/metrics_compare_utils_test.py index cf2ed9beae73..fa84850811f5 100644 --- a/test/metrics_compare_utils_test.py +++ b/test/metrics_compare_utils_test.py @@ -275,7 +275,7 @@ def test_compare_metrics_reports_new_counters(self): def test_parse_real_metrics(self): print( 'Testing against TPU. If this hangs, check that $XRT_TPU_CONFIG is set') - x = torch.rand(3, 5, device=torch_xla.device()) + x = torch.rand(3, 5, device='xla') x = torch.flatten(x, 1) x = torch.roll(x, 1, 0) x = torch.flip(x, [0, 1]) diff --git a/test/neuron/test_neuron_data_types.py b/test/neuron/test_neuron_data_types.py index 326b3857794e..caecc745f551 100644 --- a/test/neuron/test_neuron_data_types.py +++ b/test/neuron/test_neuron_data_types.py @@ -9,8 +9,8 @@ class NeuronXlaDataTypeTest(unittest.TestCase): def _test_datatypes(self, dtype, op_xla_dtype, op): - t1 = torch.tensor([2, 3], dtype=dtype, device=torch_xla.device()) - t2 = torch.tensor([2, 3], dtype=dtype, device=torch_xla.device()) + t1 = torch.tensor([2, 3], dtype=dtype, device='xla') + t2 = torch.tensor([2, 3], dtype=dtype, device='xla') t3 = op(t1, t2) diff --git a/test/pjrt/test_metrics.py b/test/pjrt/test_metrics.py index 3ff7563f7c15..bbbb3b4491cd 100644 --- a/test/pjrt/test_metrics.py +++ b/test/pjrt/test_metrics.py @@ -27,7 +27,7 @@ def test_metrics_report(self): self.assertEmpty(met.metrics_report()) # Move a tensor to the XLA device and back - torch.rand(3, 3, device=torch_xla.device()).cpu() + torch.rand(3, 3, device='xla').cpu() metrics = met.metrics_report() self.assertNotEmpty(metrics) diff --git a/test/pjrt/test_runtime_multi_cpu.py b/test/pjrt/test_runtime_multi_cpu.py index 1b61f57d47bd..25c3280ce4b5 100644 --- a/test/pjrt/test_runtime_multi_cpu.py +++ b/test/pjrt/test_runtime_multi_cpu.py @@ -68,7 +68,7 @@ def backward(ctx, grad_output): results['device'] = str(torch_xla.device()) return grad_output - x = torch.ones(1, requires_grad=True, device=torch_xla.device()) + x = torch.ones(1, requires_grad=True, device='xla') y = _CustomBackwards.apply(x) y.backward() torch_xla.sync() @@ -110,7 +110,7 @@ def _hlo_dump(tmpdir: str): os.environ['XLA_SAVE_TENSORS_FMT'] = 'hlo' os.environ['XLA_SAVE_TENSORS_FILE'] = os.path.join(tmpdir, 'save.hlo') - x = torch.randn((3, 3), device=torch_xla.device()) + x = torch.randn((3, 3), device='xla') torch_xla.sync() x.cpu() @@ -124,7 +124,7 @@ def test_hlo_dump(self): @staticmethod def _all_reduce_hlo(): - ones = torch.ones((3, 3), device=torch_xla.device()) + ones = torch.ones((3, 3), device='xla') torch_xla.sync() reduced = xm.all_reduce(xm.REDUCE_SUM, ones) diff --git a/test/pjrt/test_runtime_multi_gpu.py b/test/pjrt/test_runtime_multi_gpu.py index e48185af0c0d..4a9f1a2c7a0a 100644 --- a/test/pjrt/test_runtime_multi_gpu.py +++ b/test/pjrt/test_runtime_multi_gpu.py @@ -125,7 +125,7 @@ def backward(ctx, grad_output): results['device'] = str(torch_xla.device()) return grad_output - x = torch.ones(1, requires_grad=True, device=torch_xla.device()) + x = torch.ones(1, requires_grad=True, device='xla') y = _CustomBackwards.apply(x) y.backward() torch_xla.sync() diff --git a/test/pjrt/test_runtime_tpu.py b/test/pjrt/test_runtime_tpu.py index a19e2323c4de..70dec3c90906 100644 --- a/test/pjrt/test_runtime_tpu.py +++ b/test/pjrt/test_runtime_tpu.py @@ -230,8 +230,8 @@ def _execute_time_metric(): begin = time.perf_counter_ns() value = ( - torch.randn(10000, 10000, device=torch_xla.device()) * - torch.randn(10000, 10000, device=torch_xla.device())) + torch.randn(10000, 10000, device='xla') * + torch.randn(10000, 10000, device='xla')) value_mean = value.mean() torch_xla.sync() cpu_value = value_mean.cpu() diff --git a/test/pjrt/test_torchrun.py b/test/pjrt/test_torchrun.py index 0024c189aa75..4c26e1abe534 100644 --- a/test/pjrt/test_torchrun.py +++ b/test/pjrt/test_torchrun.py @@ -26,9 +26,7 @@ def test_all_gather(self): expected_world_size = dist_world_size * devices_per_thread - rank = torch.tensor([dist.get_rank()], - dtype=torch.float32, - device=torch_xla.device()) + rank = torch.tensor([dist.get_rank()], dtype=torch.float32, device='xla') output = [rank.clone() for _ in range(expected_world_size)] dist.all_gather(output, rank) result = torch.concat(output) @@ -52,8 +50,7 @@ def test_all_reduce(self): expected = sum(tensors) xla_tensor = torch.arange( - 2, dtype=torch.int64, - device=torch_xla.device()) + 1 + 2 * dist.get_rank() + 2, dtype=torch.int64, device='xla') + 1 + 2 * dist.get_rank() dist.all_reduce(xla_tensor, op=dist.ReduceOp.SUM) torch_xla.sync() @@ -70,10 +67,9 @@ def test_reduce_scatter(self): world_size * world_size, dtype=torch.int64) expected = torch.split(tensor, world_size)[dist.get_rank()] - tensor_out = torch.zeros( - world_size, dtype=torch.int64, device=torch_xla.device()) + tensor_out = torch.zeros(world_size, dtype=torch.int64, device='xla') tensor_in = torch.arange( - world_size * world_size, dtype=torch.int64, device=torch_xla.device()) + world_size * world_size, dtype=torch.int64, device='xla') dist.reduce_scatter(tensor_out, [tensor_in], op=dist.ReduceOp.SUM) torch_xla.sync() diff --git a/test/scan/test_scan.py b/test/scan/test_scan.py index b61d8648fa2d..cbc2778fc679 100644 --- a/test/scan/test_scan.py +++ b/test/scan/test_scan.py @@ -273,7 +273,7 @@ def test_scan_external_in_place_mutation(self): giving wrong results. """ # TODO(yifeit): Modify this test when external in-place mutation is eventually supported. - weird_global = torch.tensor([0.0, 0.0], device=torch_xla.device()) + weird_global = torch.tensor([0.0, 0.0], device='xla') def step_fn(carry, x): new_carry = carry + x @@ -281,9 +281,8 @@ def step_fn(carry, x): y = new_carry + weird_global return new_carry, y - init = torch.tensor([0.0, 0.0], device=torch_xla.device()) - xs = torch.tensor([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], - device=torch_xla.device()) + init = torch.tensor([0.0, 0.0], device='xla') + xs = torch.tensor([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], device='xla') with self.assertRaisesRegex(AssertionError, "FakeTensor"): scan(step_fn, init, xs) @@ -351,12 +350,11 @@ def test_scan_rand_in_fn(self): def step_fn(carry, x): new_carry = carry + x - y = new_carry + torch.rand(2, device=torch_xla.device()) + y = new_carry + torch.rand(2, device='xla') return new_carry, y - init = torch.tensor([0.0, 0.0], device=torch_xla.device()) - xs = torch.tensor([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], - device=torch_xla.device()) + init = torch.tensor([0.0, 0.0], device='xla') + xs = torch.tensor([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], device='xla') _, ys = scan(step_fn, init, xs) # ys should be a 2D tensor with this shape. self.assertEqual(ys.shape, (3, 2)) diff --git a/test/scan/test_scan_debug.py b/test/scan/test_scan_debug.py index d800a36998df..49247412a0e4 100644 --- a/test/scan/test_scan_debug.py +++ b/test/scan/test_scan_debug.py @@ -36,12 +36,10 @@ def fn2(carry, x): y = x + 42 return carry, y - init = torch.tensor([0.0, 0.0], - requires_grad=True, - device=torch_xla.device()) + init = torch.tensor([0.0, 0.0], requires_grad=True, device='xla') xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], requires_grad=True, - device=torch_xla.device()) + device='xla') # Run some graph involving a scan operation two times. for i in range(2): diff --git a/test/scan/test_scan_layers.py b/test/scan/test_scan_layers.py index f239e1a51d25..fd13f90f7e93 100644 --- a/test/scan/test_scan_layers.py +++ b/test/scan/test_scan_layers.py @@ -265,15 +265,13 @@ def test_heterogenous_layers(self): layer1 = nn.Linear(128, 128).to(torch_xla.device()) layer2 = nn.Sequential(nn.Linear(128, 128).to(torch_xla.device())) with self.assertRaisesRegex(ValueError, "mismatched keys"): - scan_layers([layer1, layer2], - torch.zeros((128,), device=torch_xla.device())) + scan_layers([layer1, layer2], torch.zeros((128,), device='xla')) def test_mismatched_shapes(self): layer1 = nn.Linear(128, 128).to(torch_xla.device()) layer2 = nn.Linear(128, 129).to(torch_xla.device()) with self.assertRaisesRegex(ValueError, "Shape mismatch"): - scan_layers([layer1, layer2], - torch.zeros((128,), device=torch_xla.device())) + scan_layers([layer1, layer2], torch.zeros((128,), device='xla')) if __name__ == '__main__': diff --git a/test/spmd/test_dtensor_integration.py b/test/spmd/test_dtensor_integration.py index 402da525c96f..650b5774dcf0 100644 --- a/test/spmd/test_dtensor_integration.py +++ b/test/spmd/test_dtensor_integration.py @@ -30,10 +30,7 @@ def test_xla_distribute_tensor(self): for requires_grad in [True, False]: tensor_to_shard = torch.randn( - 3 * device_count, - 3, - requires_grad=requires_grad, - device=torch_xla.device()) + 3 * device_count, 3, requires_grad=requires_grad, device='xla') dist_tensor = distribute_tensor(tensor_to_shard, device_mesh, shard_spec) # TODO(yeounoh) switch to DTensor API when XLAShardedTensor inherits DTensor assert type(dist_tensor).__name__ == "XLAShardedTensor" diff --git a/test/spmd/test_xla_sharding.py b/test/spmd/test_xla_sharding.py index 81525faabed8..0a131865d890 100644 --- a/test/spmd/test_xla_sharding.py +++ b/test/spmd/test_xla_sharding.py @@ -36,7 +36,7 @@ def test_xla_sharded_tensor(self): partition_spec = (0, 1) xt1 = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]], dtype=torch.float, - device=torch_xla.device()) + device='xla') xst1 = xs.mark_sharding(xt1, self._get_mesh((1, self.n_devices)), partition_spec) self.assertTrue(isinstance(xst1, XLAShardedTensor)) @@ -59,7 +59,7 @@ def test_sharded_tensor_debug_info(self): partition_spec = (0, 1) xt1 = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]], dtype=torch.float, - device=torch_xla.device()) + device='xla') xst1 = xs.mark_sharding(xt1, self._get_mesh((1, self.n_devices)), partition_spec) @@ -229,7 +229,7 @@ def test_xla_sharding_type(self): self.assertEqual(xt.sharding_type, xs.ShardingType.REPLICATED) def test_custom_tile_assignment(self): - xt = torch.randn(10, 20).to(device=torch_xla.device()) + xt = torch.randn(10, 20).to(device='xla') mesh_shape = (1, self.n_devices) device_ids = np.flip(self.device_ids) mesh = self._get_mesh(mesh_shape, device_ids) @@ -696,10 +696,10 @@ def test_no_sharding(self): partition_spec = (0, 1) t1 = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]], dtype=torch.float, - device=torch_xla.device()) + device='xla') t2 = torch.tensor([[8, 7, 6, 5, 4, 3, 2, 1]], dtype=torch.float, - device=torch_xla.device()) + device='xla') t3 = t1 + t2 t3_expected = [9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0] self.assertEqual(t3.tolist()[0], t3_expected) @@ -708,7 +708,7 @@ def test_xla_sharded_hlo_dump(self): partition_spec = (0, 1) xt1 = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]], dtype=torch.float, - device=torch_xla.device()) + device='xla') xst1 = xs.mark_sharding(xt1, self._get_mesh((1, self.n_devices)), partition_spec) xst2 = xst1 + 5 @@ -912,7 +912,7 @@ def test_sharded_tensor_aliasing(self): partition_spec = (0, 1) xt1 = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]], dtype=torch.float, - device=torch_xla.device()) + device='xla') xst1 = xs.mark_sharding(xt1, self._get_mesh((1, self.n_devices)), partition_spec) xst1 += 1 @@ -1536,7 +1536,7 @@ def test_mark_sharding_with_gradients_basic(self): partition_spec = (0, 1) xt1 = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]], dtype=torch.float, - device=torch_xla.device(), + device='xla', requires_grad=True) mesh = self._get_mesh((1, self.n_devices)) xst1 = xs.mark_sharding_with_gradients(xt1, mesh, partition_spec) @@ -1550,7 +1550,7 @@ def test_mark_sharding_with_gradients_annotation(self): partition_spec = (0,) x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=torch.float, - device=torch_xla.device(), + device='xla', requires_grad=True) # Notice that the function does not modify in-place. y = xs.mark_sharding_with_gradients(x, mesh, partition_spec) @@ -1669,13 +1669,9 @@ def test_get_logical_mesh(self): def test_shard_as(self): mesh = self._get_mesh((self.n_devices,)) partition_spec = (0,) - x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8], - dtype=torch.float, - device=torch_xla.device()) + x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=torch.float, device='xla') x = xs.mark_sharding_with_gradients(x, mesh, partition_spec) - y = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8], - dtype=torch.float, - device=torch_xla.device()) + y = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=torch.float, device='xla') x, y = xs.shard_as(x, y) torch_xla.sync() diff --git a/test/spmd/test_xla_virtual_device.py b/test/spmd/test_xla_virtual_device.py index 5be7f95daf98..fb59653f921b 100644 --- a/test/spmd/test_xla_virtual_device.py +++ b/test/spmd/test_xla_virtual_device.py @@ -23,21 +23,21 @@ def test_mark_sharding(self): partition_spec = (0, 1) xt1 = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]], dtype=torch.float, - device=torch_xla.device()) + device='xla') xs.mark_sharding(xt1, self._get_mesh((1, self.n_devices)), partition_spec) self.assertTrue( torch.allclose( xt1 + 0, torch.tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=torch.float, - device=torch_xla.device()))) + device='xla'))) def test_metrics_recorded(self): met.clear_counters() partition_spec = (0, 1) xt1 = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]], dtype=torch.float, - device=torch_xla.device()) + device='xla') xs.mark_sharding(xt1, self._get_mesh((1, self.n_devices)), partition_spec) self.assertIn("VirtualDeviceUsage", met.counter_names()) self.assertNotEqual(met.counter_value("VirtualDeviceUsage"), 0) @@ -54,17 +54,17 @@ def test_model_weight_metrics(self): def test_no_sharding(self): t1 = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]], dtype=torch.float, - device=torch_xla.device()) + device='xla') t2 = torch.tensor([[8, 7, 6, 5, 4, 3, 2, 1]], dtype=torch.float, - device=torch_xla.device()) + device='xla') t3 = t1 + t2 t3_expected = [9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0] self.assertEqual(t3.tolist()[0], t3_expected) def test_no_sharding_1d(self): - t1 = torch.arange(9, dtype=torch.float, device=torch_xla.device()) - t2 = torch.arange(9, dtype=torch.float, device=torch_xla.device()) + t1 = torch.arange(9, dtype=torch.float, device='xla') + t2 = torch.arange(9, dtype=torch.float, device='xla') t3 = t1 + t2 t3_expected = list(range(0, 18, 2)) self.assertEqual(t3.tolist(), t3_expected) @@ -75,7 +75,7 @@ def test_outbound_data_metrics(self): met.clear_all() xt1 = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]], dtype=torch.float, - device=torch_xla.device()) + device='xla') xs.mark_sharding(xt1, self._get_mesh((1, self.n_devices)), partition_spec) outbound_with_virtual_device = met.metric_data("OutboundData")[1] diff --git a/test/test_callback.py b/test/test_callback.py index 242fef6443ea..296922eba260 100644 --- a/test/test_callback.py +++ b/test/test_callback.py @@ -11,8 +11,8 @@ class TestExperimentalCallback(absltest.TestCase): @staticmethod @torch_xla.compile def executable(): - a, b = torch.randn((100, 100), device=torch_xla.device()), torch.randn( - (100, 100), device=torch_xla.device()) + a, b = torch.randn((100, 100), device='xla'), torch.randn((100, 100), + device='xla') return a @ b def test_callback(self): diff --git a/test/test_core_aten_ops.py b/test/test_core_aten_ops.py index 036ecf2d6c53..361c09de7faf 100644 --- a/test/test_core_aten_ops.py +++ b/test/test_core_aten_ops.py @@ -2932,7 +2932,7 @@ def test_aten_randperm_0(self): kwargs = dict() pytorch = torch.randperm(20) - xla = torch.randperm(20, device=torch_xla.device()) + xla = torch.randperm(20, device='xla') xla_detached = xla.detach().cpu() # Check equal lengths and that the sorted sets are equal. Since these numbers are randomly diff --git a/test/test_data_type.py b/test/test_data_type.py index 46beea6d5115..da4b7d00681f 100644 --- a/test/test_data_type.py +++ b/test/test_data_type.py @@ -29,8 +29,8 @@ def _set_env(self, **kwargs): os.environ[key] = value def _test_datatype(self, dtype, expected_type, op): - t1 = torch.tensor([2, 3], dtype=dtype, device=torch_xla.device()) - t2 = torch.tensor([2, 3], dtype=dtype, device=torch_xla.device()) + t1 = torch.tensor([2, 3], dtype=dtype, device='xla') + t2 = torch.tensor([2, 3], dtype=dtype, device='xla') t3 = op(t1, t2) self.assertEqual(t3.dtype, dtype) diff --git a/test/test_dynamic_shapes_detector.py b/test/test_dynamic_shapes_detector.py index 99a8423e45f5..9e7eb0f6a761 100644 --- a/test/test_dynamic_shapes_detector.py +++ b/test/test_dynamic_shapes_detector.py @@ -50,7 +50,7 @@ def test_single(self): def foo(x): return x + x - inp = torch.rand(10, device=torch_xla.device()) + inp = torch.rand(10, device='xla') self._run_and_compare(foo, args=(inp,), max_different_graphs=1) def test_many_graphs(self): @@ -70,7 +70,7 @@ def foo(x, step): return r * 4 return r0 - inp = torch.rand(10, device=torch_xla.device()) + inp = torch.rand(10, device='xla') for i in range(6): self._run_and_compare(foo, args=(inp, i), max_different_graphs=4) @@ -84,7 +84,7 @@ def test_graph_limit_exceeded_different_input_shape(self): def foo(x): return x + x - inp1 = torch.rand(10, device=torch_xla.device()) + inp1 = torch.rand(10, device='xla') self._run_and_compare( foo, args=(inp1,), max_different_graphs=max_different_graphs) @@ -95,7 +95,7 @@ def foo(x): """) with self.assertRaisesRegex(RuntimeError, expected_error_msg): - inp2 = torch.rand(5, device=torch_xla.device()) + inp2 = torch.rand(5, device='xla') self._run_and_compare( foo, args=(inp2,), max_different_graphs=max_different_graphs) @@ -118,7 +118,7 @@ def foo(x, step): else: return x * 5 - inp = torch.rand(10, device=torch_xla.device()) + inp = torch.rand(10, device='xla') self._run_and_compare( foo, args=(inp, 0), max_different_graphs=max_different_graphs) @@ -157,7 +157,7 @@ def foo(x, step): return r + x return r / 3 - inp = torch.rand(10, device=torch_xla.device()) + inp = torch.rand(10, device='xla') self._run_and_compare( foo, args=(inp, 0), max_different_graphs=max_different_graphs) self._run_and_compare( @@ -194,7 +194,7 @@ def foo(x, mul=False): else: return r - inp = torch.rand(10, device=torch_xla.device()) + inp = torch.rand(10, device='xla') self._run_and_compare( foo, args=(inp, True), max_different_graphs=max_different_graphs) @@ -231,7 +231,7 @@ def foo(x, step): return r + x return r - inp = torch.rand(10, device=torch_xla.device()) + inp = torch.rand(10, device='xla') self._run_and_compare( foo, args=(inp, 0), max_different_graphs=max_different_graphs) self._run_and_compare( diff --git a/test/test_hlo_metadata.py b/test/test_hlo_metadata.py index da9d24bb2342..f6ce7f2fee40 100644 --- a/test/test_hlo_metadata.py +++ b/test/test_hlo_metadata.py @@ -78,10 +78,10 @@ def test_metadata(self): model = torch.nn.Sequential(layer1, nl1, layer2, nl2) with CustomOpNameLowering() as c: - model = model.to(device=torch_xla.device()) - inp = torch.rand(4, 4, device=torch_xla.device()) + model = model.to(device='xla') + inp = torch.rand(4, 4, device='xla') #inp = torch.rand(4, 4) - #inp = inp.to(device=torch_xla.device()) + #inp = inp.to(device='xla') out = model(inp) # Get outer frames diff --git a/test/test_metrics.py b/test/test_metrics.py index 098b516079d6..8749ebd0e2b3 100644 --- a/test/test_metrics.py +++ b/test/test_metrics.py @@ -211,8 +211,8 @@ def test_execute_time_metric(self): begin = time.perf_counter_ns() value = torch.randn( - 10000, 10000, device=torch_xla.device()) * torch.randn( - 10000, 10000, device=torch_xla.device()) + 10000, 10000, device='xla') * torch.randn( + 10000, 10000, device='xla') value_mean = value.mean() torch_xla.sync() cpu_value = value_mean.cpu() diff --git a/test/test_operations.py b/test/test_operations.py index 1fc90898f13e..2dc4f7c2da7a 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -328,7 +328,7 @@ def test(self): class TestSelect(test_utils.XlaTestCase): def test_get_xla_tensor(self): - x = _gen_tensor(14, 24, 8, device=torch_xla.device()) + x = _gen_tensor(14, 24, 8, device='xla') t = x.data.cpu() sx = x.select(1, 12) tx = t.select(1, 12) @@ -343,7 +343,7 @@ def fn(tensor): # Call masked_fill. return tensor.masked_fill(mask, 10) - x = _gen_tensor(2, 2, device=torch_xla.device()) + x = _gen_tensor(2, 2, device='xla') x_cpu = x.cpu() self.assertEqual(fn(x_cpu), fn(x)) @@ -352,7 +352,7 @@ class TestRandom(test_utils.XlaTestCase): def test_random_from_to_bool(self): for from_val, to_val in [[0, 1], [0, 2], [1, 2]]: - x = _gen_tensor(10, device=torch_xla.device()) + x = _gen_tensor(10, device='xla') x.random_(from_val, to_val) delta = 1 self.assertTrue(from_val <= x.to(torch.int).min() < (from_val + delta)) @@ -416,20 +416,20 @@ def test_fn(x): class TestDynamicShape(test_utils.XlaTestCase): def test_nonzero_shape(self): - x = torch.tensor((0, 1, 2, 0, 3, 4), device=torch_xla.device()) + x = torch.tensor((0, 1, 2, 0, 3, 4), device='xla') x_dim0_shape = torch_xla._XLAC._get_xla_tensor_dimension_size( torch.nonzero(x, as_tuple=False), 0) self.assertEqual(x_dim0_shape.item(), 4) def test_masked_select_shape(self): - x = torch.tensor((0, 1, 2, 0, 3, 4), device=torch_xla.device()) + x = torch.tensor((0, 1, 2, 0, 3, 4), device='xla') mask = x.ge(2) x_dim0_shape = torch_xla._XLAC._get_xla_tensor_dimension_size( torch.masked_select(x, mask), 0) self.assertEqual(x_dim0_shape.item(), 3) def test_nonzero_cast(self): - t1 = torch.ones(5, 2, device=torch_xla.device()) + t1 = torch.ones(5, 2, device='xla') # Result of the nonzero should be the index type. Currently # index type is s64 on cpu and gpu, but s32 on TPU. We should be # able to cast it to any other type without error. @@ -598,31 +598,31 @@ def test_div_mixed_device(self): self.assertEqual(output.data, xla_output.data.cpu()) def test_rand(self): - x = torch.rand(3, 5, device=torch_xla.device()) + x = torch.rand(3, 5, device='xla') self.assertEqual(x.device.type, 'xla') def test_randperm(self): - x = torch.randperm(3, device=torch_xla.device(), dtype=torch.int32) + x = torch.randperm(3, device='xla', dtype=torch.int32) self.assertEqual(x.device.type, 'xla') def test_randn_like(self): shape = (5, 1, 1) - x = torch.randn_like(torch.zeros(shape, device=torch_xla.device())) + x = torch.randn_like(torch.zeros(shape, device='xla')) self.assertEqual(x.device.type, 'xla') def test_rand_like(self): shape = (5, 1, 1) - x = torch.rand_like(torch.zeros(shape, device=torch_xla.device())) + x = torch.rand_like(torch.zeros(shape, device='xla')) self.assertEqual(x.device.type, 'xla') def test_randint_like(self): shape = (5, 1, 1) x = torch.randint_like( - torch.zeros(shape, device=torch_xla.device(), dtype=torch.uint8), 6, 10) + torch.zeros(shape, device='xla', dtype=torch.uint8), 6, 10) self.assertEqual(x.device.type, 'xla') def test_no_storage(self): - x = torch.randn(5, device=torch_xla.device()) + x = torch.randn(5, device='xla') self.assertRaises(Exception, x.device) def test_slice_copy(self): @@ -686,9 +686,9 @@ def test_slice_rnd_stepped_assign(self): def test_arange_nan(self): with self.assertRaisesRegex(RuntimeError, r'unsupported range'): - a = torch.arange(-5, float('nan'), device=torch_xla.device()) + a = torch.arange(-5, float('nan'), device='xla') with self.assertRaisesRegex(RuntimeError, r'unsupported range'): - a = torch.arange(float('nan'), 5, device=torch_xla.device()) + a = torch.arange(float('nan'), 5, device='xla') def test_empty_advanced_indexing(self): xla_device = torch_xla.device() @@ -891,7 +891,7 @@ def test_baddmm_integer_types(self): def test_view_empty(self): # These used to throw floating point exception. - empty = torch.empty(0, device=torch_xla.device()) + empty = torch.empty(0, device='xla') with self.assertRaisesRegex( RuntimeError, r'unspecified dimension size -1 can be any value'): empty.view(-1, 0) @@ -917,7 +917,7 @@ def test_fn(device): self.assertEqual(cpu_weight_grad, xla_weight_grad) def test_inplace_view_backprop_base(self): - root = torch.randn(2, 2, device=torch_xla.device(), requires_grad=True) + root = torch.randn(2, 2, device='xla', requires_grad=True) x = root.clone() v1 = x.narrow(0, 0, 1) v1.mul_(2) @@ -925,7 +925,7 @@ def test_inplace_view_backprop_base(self): self.assertEqual(root.grad.tolist(), [[2, 2], [1, 1]]) def test_inplace_view_backprop_view_of_view(self): - root = torch.randn(2, 2, device=torch_xla.device(), requires_grad=True) + root = torch.randn(2, 2, device='xla', requires_grad=True) x = root.clone() v1 = x.narrow(0, 0, 1) v2 = x.narrow(0, 0, 1) @@ -935,7 +935,7 @@ def test_inplace_view_backprop_view_of_view(self): def test_inplace_view_of_view(self): # modify view-of-view and backprop through base - root = torch.randn(2, 2, device=torch_xla.device(), requires_grad=True) + root = torch.randn(2, 2, device='xla', requires_grad=True) x = root.clone() v1 = x.narrow(0, 0, 1) v2 = v1.narrow(1, 1, 1) @@ -944,8 +944,7 @@ def test_inplace_view_of_view(self): self.assertEqual(root.grad.tolist(), [[1, 2], [1, 1]]) def test_inplace_view_multiple_outputs(self): - root = torch.arange( - 9., device=torch_xla.device()).reshape(3, 3).requires_grad_() + root = torch.arange(9., device='xla').reshape(3, 3).requires_grad_() x = root.clone() v1 = x.unbind() with self.assertRaises(RuntimeError): @@ -1040,8 +1039,7 @@ def func(root, b): def test_inplace_view_non_contig(self): root = torch.ones( - 2, 3, 2, device=torch_xla.device()).select(2, - 1).t().requires_grad_(True) + 2, 3, 2, device='xla').select(2, 1).t().requires_grad_(True) x = root.clone() v1 = x.narrow(0, 0, 1) v2 = v1.narrow(1, 1, 1) @@ -1080,12 +1078,12 @@ def func(x): def test_set(self): met.clear_all() - t1 = torch.zeros(50, device=torch_xla.device()) + t1 = torch.zeros(50, device='xla') t1 += 1 torch_xla.sync() self.assertEqual(met.counter_value('DestroyXlaTensor'), 3) - t2 = torch.zeros(10, device=torch_xla.device()) + t2 = torch.zeros(10, device='xla') self.assertEqual(met.counter_value('DestroyXlaTensor'), 4) t1.set_(t2) @@ -1098,12 +1096,12 @@ def test_set(self): def test_replace_xla_tensor(self): met.clear_all() - t1 = torch.zeros(50, device=torch_xla.device()) + t1 = torch.zeros(50, device='xla') t1 += 1 torch_xla.sync() self.assertEqual(met.counter_value('DestroyXlaTensor'), 3) - t2 = torch.zeros(10, device=torch_xla.device()) + t2 = torch.zeros(10, device='xla') self.assertEqual(met.counter_value('DestroyXlaTensor'), 4) torch_xla._XLAC._replace_xla_tensor(t1, t2) self.assertEqual(met.counter_value('DestroyXlaTensor'), 5) @@ -1754,7 +1752,7 @@ def test_binaryop_order(self): # Since in eager mode the tensor would be materialized and hence _get_xla_tensors_text would not show the prim::Constant node. @skipOnEagerDebug def test_pow_constant(self): - t1 = torch.pow(torch.tensor([2.0, 3.0], device=torch_xla.device()), 5) + t1 = torch.pow(torch.tensor([2.0, 3.0], device='xla'), 5) hlo_text = torch_xla._XLAC._get_xla_tensors_text([t1]) const_hlo = hlo_text.split('\n')[1] assert 'prim::Constant' in const_hlo @@ -2333,8 +2331,7 @@ def clone_and_maybe_move(tensor, device=None): with self.subTest(sparse=sparse, mode=mode): kwargs_ = {k: clone_and_maybe_move(v) for k, v in kwargs.items()} xla_kwargs = { - k: clone_and_maybe_move(v, device=torch_xla.device()) - for k, v in kwargs.items() + k: clone_and_maybe_move(v, device='xla') for k, v in kwargs.items() } expected_out, expected_grad = fn(**kwargs_, **extra_kwargs) @@ -2481,12 +2478,12 @@ class TestWaitDeviceOps(test_utils.XlaTestCase): def test_wait_device_ops(self): torch_xla.device() - value = torch.randn(10000, 10000, device=torch_xla.device()) + value = torch.randn(10000, 10000, device='xla') val_list = [] val_mean_list = [] met.clear_all() for _ in range(5): - new_val = value * torch.randn(10000, 10000, device=torch_xla.device()) + new_val = value * torch.randn(10000, 10000, device='xla') val_list.append(new_val) val_mean_list.append(new_val.mean()) torch_xla.sync() @@ -2910,7 +2907,7 @@ def test_unsafe_buffer_pointer(self): self.assertGreaterEqual(buf_ptr_0, 0) # xtensor->CurrentDataHandle() == nullptr but xtensor->CurrentIrValue().node != nullptr and device_data != nullptr - xla_tensor_1 = torch.tensor(42, device=torch_xla.device()) + xla_tensor_1 = torch.tensor(42, device='xla') buf_ptr_1 = torch_xla._XLAC._unsafe_buffer_pointer(xla_tensor_1) self.assertGreaterEqual(buf_ptr_1, 0) @@ -2919,7 +2916,7 @@ def test_unsafe_buffer_pointer(self): buf_ptr_2 = torch_xla._XLAC._unsafe_buffer_pointer(xla_tensor_2) self.assertGreaterEqual(buf_ptr_2, 0) - xla_tensor_3 = torch.arange(5, device=torch_xla.device()) + xla_tensor_3 = torch.arange(5, device='xla') torch_xla.sync() # Without the `wait_device_ops()`, the pjrt buffer (pjrt_data->buffer) at https://github.com/pytorch/xla/blob/e3fc03314dab5f44e3ed9ccbba6c15fbca3285cd/torch_xla/csrc/runtime/pjrt_computation_client.cc#L467 will be nullptr. xm.wait_device_ops() @@ -2954,7 +2951,7 @@ def test_dlpack_roundtrip_tensor(self, dtype): self._test_dlpack_capsule_conversion_helper(xla_tensor_2) # xla_tensor_3 uses arange_out IR node. - xla_tensor_3 = torch.arange(5, dtype=dtype, device=torch_xla.device()) + xla_tensor_3 = torch.arange(5, dtype=dtype, device='xla') torch_xla.sync() self._test_dlpack_capsule_conversion_helper(xla_tensor_3) diff --git a/test/test_operations_hlo.py b/test/test_operations_hlo.py index 37708d199540..c151bdb49be9 100644 --- a/test/test_operations_hlo.py +++ b/test/test_operations_hlo.py @@ -30,7 +30,7 @@ def tearDown(self): super(TestOperationsHlo, self).tearDown() def test_expand(self): - a = torch.rand(1, 5, device=torch_xla.device()) + a = torch.rand(1, 5, device='xla') b = a.expand(5, 5) hlo_text = torch_xla._XLAC._get_xla_tensors_text([b]) assert 'aten::expand' in hlo_text diff --git a/test/test_pallas.py b/test/test_pallas.py index 8e8ab07bd87d..54fcbbc51dd1 100644 --- a/test/test_pallas.py +++ b/test/test_pallas.py @@ -369,7 +369,7 @@ def add_minus_vectors(x: jax.Array, y: jax.Array) -> jax.Array: pt_kernel = make_kernel_from_pallas( add_minus_vectors, lambda x, y: [(x.shape, x.dtype), (x.shape, x.dtype)]) - x = torch.arange(8, device="xla", dtype=torch.float) + x = torch.arange(8, device='xla', dtype=torch.float) o = pt_kernel(x, x) self.assertEqual(len(o), 2) diff --git a/test/test_torch_distributed_fsdp_frozen_weight.py b/test/test_torch_distributed_fsdp_frozen_weight.py index f1d6499935c5..98730dbf7009 100644 --- a/test/test_torch_distributed_fsdp_frozen_weight.py +++ b/test/test_torch_distributed_fsdp_frozen_weight.py @@ -19,7 +19,7 @@ def _mp_fn(index): model = FSDP(model) # wrapping the linear module with FSDP - input = torch.rand((2, 1024), device=torch_xla.device()) + input = torch.rand((2, 1024), device='xla') output = model(input) loss = torch.sum(output) diff --git a/test/test_triton.py b/test/test_triton.py index 21f9f511e41d..c2600526363a 100644 --- a/test/test_triton.py +++ b/test/test_triton.py @@ -267,9 +267,9 @@ def test_gpu_custom_call_triton_flash_attention(self): causal = False stage = 3 if causal else 1 dtype = torch.float16 - q = torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device="xla") - k = torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device="xla") - v = torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device="xla") + q = torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device='xla') + k = torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device='xla') + v = torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device='xla') sm_scale = 0.5 # reference implementation triangle = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda")) diff --git a/test/test_xla_graph_execution.py b/test/test_xla_graph_execution.py index bbf34321f5c6..acf440bf0071 100644 --- a/test/test_xla_graph_execution.py +++ b/test/test_xla_graph_execution.py @@ -21,7 +21,7 @@ class TestXlaGraphExecution(test_utils.XlaTestCase): def test_graph_execution_allowed(self): torch_xla._XLAC._set_allow_execution(True) - x = torch.ones(2, device=torch_xla.device()) + x = torch.ones(2, device='xla') self.assertEqual(x[0], 1.0) # This should trigger the checking del x @@ -30,7 +30,7 @@ def test_graph_execution_disallowed_with_error(self): # Trigger runtime error for unexpected graph execution torch_xla._XLAC._set_allow_execution( False) # this flag disallows graph execution - x = torch.ones(2, device=torch_xla.device()) + x = torch.ones(2, device='xla') with self.assertRaises(RuntimeError) as e: self.assertEqual(x[0], 1.0) # This should trigger the checking self.assertIn( diff --git a/test/torch_distributed/test_torch_distributed_fsdp_meta.py b/test/torch_distributed/test_torch_distributed_fsdp_meta.py index 9b7c77011242..444c47890330 100644 --- a/test/torch_distributed/test_torch_distributed_fsdp_meta.py +++ b/test/torch_distributed/test_torch_distributed_fsdp_meta.py @@ -60,7 +60,7 @@ def _init_with_reset_params(module): """ is_meta = any(t.is_meta for t in module.parameters()) if is_meta: - module.to_empty(device=torch_xla.device()) + module.to_empty(device='xla') with torch.no_grad(): module.reset_parameters() @@ -87,7 +87,7 @@ def _compare_fsdp(self, fsdp1, fsdp2): def _test_simple_model_with_meta_device(self, meta_module_fn, init_fn=None): # Create model on meta device and wrap with FSDP. model = meta_module_fn() - inp = torch.randn(10, 2, device=torch_xla.device()) + inp = torch.randn(10, 2, device='xla') fsdp_meta = XlaFullyShardedDataParallel( model, @@ -99,7 +99,7 @@ def _test_simple_model_with_meta_device(self, meta_module_fn, init_fn=None): meta_opt.step() torch_xla.sync() - regular = MyModel(device=torch_xla.device()) + regular = MyModel(device='xla') fsdp_regular = XlaFullyShardedDataParallel( regular, auto_wrap_policy=always_wrap) regular_opt = torch.optim.SGD(fsdp_regular.parameters(), lr=1e-3) @@ -127,7 +127,7 @@ def meta_module_fn(): def test_simple_model_with_torchdistX_init_fn(self): def meta_module_fn(): - return deferred_init.deferred_init(MyModel, device=torch_xla.device()) + return deferred_init.deferred_init(MyModel, device='xla') self._test_simple_model_with_meta_device( meta_module_fn, init_fn=_init_with_torchdistX) @@ -135,7 +135,7 @@ def meta_module_fn(): def test_simple_model_with_default_torchdistX(self): def meta_module_fn(): - return deferred_init.deferred_init(MyModel, device=torch_xla.device()) + return deferred_init.deferred_init(MyModel, device='xla') self._test_simple_model_with_meta_device(meta_module_fn) diff --git a/torch_xla/_internal/tpu.py b/torch_xla/_internal/tpu.py index 079b1d49f6b2..182c8675fbdd 100644 --- a/torch_xla/_internal/tpu.py +++ b/torch_xla/_internal/tpu.py @@ -312,7 +312,7 @@ def discover_master_worker_ip(use_localhost: bool = True) -> str: if xr.is_spmd(): return _spmd_find_master_ip(worker_ips[current_worker_id]) - t = torch.tensor([current_worker_id], device=torch_xla.device()) + t = torch.tensor([current_worker_id], device='xla') xm.collective_broadcast([t]) torch_xla.sync() diff --git a/torch_xla/amp/syncfree/adam.py b/torch_xla/amp/syncfree/adam.py index abb2bb55d2d4..63dab0a6b3ea 100644 --- a/torch_xla/amp/syncfree/adam.py +++ b/torch_xla/amp/syncfree/adam.py @@ -94,7 +94,7 @@ def step(self, closure=None, found_inf: Tensor = None): p, memory_format=torch.preserve_format) else: state['max_exp_avg_sq'] = torch.empty( - 0, dtype=torch.float, device=torch_xla.device()) + 0, dtype=torch.float, device='xla') exp_avgs.append(state['exp_avg']) exp_avg_sqs.append(state['exp_avg_sq']) diff --git a/torch_xla/amp/syncfree/adamw.py b/torch_xla/amp/syncfree/adamw.py index b1abe5bbda8c..36d6941dc36f 100644 --- a/torch_xla/amp/syncfree/adamw.py +++ b/torch_xla/amp/syncfree/adamw.py @@ -92,7 +92,7 @@ def step(self, closure=None, found_inf: Tensor = None): p, memory_format=torch.preserve_format) else: state['max_exp_avg_sq'] = torch.empty( - 0, dtype=torch.float, device=torch_xla.device()) + 0, dtype=torch.float, device='xla') exp_avgs.append(state['exp_avg']) exp_avg_sqs.append(state['exp_avg_sq']) diff --git a/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py b/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py index ce1b342cdf17..c5605d2b3ed2 100644 --- a/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py +++ b/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py @@ -1014,7 +1014,7 @@ def _dummy_forward(self, *args: Any, **kwargs: Any) -> torch.Tensor: A dummy forward pass with minimal computation that sums all inputs and full parameters, e.g. to debug parameter memory consumption. """ - outputs = torch.zeros(1, device=torch_xla.device()) + outputs = torch.zeros(1, device='xla') for t in chain(args, kwargs.values(), self.full_params): if isinstance(t, torch.Tensor) and t.dtype == torch.float32: outputs = outputs + t.mean() diff --git a/torch_xla/experimental/scan.py b/torch_xla/experimental/scan.py index ea2a097d5974..ba86af2c87d0 100644 --- a/torch_xla/experimental/scan.py +++ b/torch_xla/experimental/scan.py @@ -727,7 +727,7 @@ def defeat_device_data(v: torch.Tensor) -> torch.Tensor: seed_tensor = hoisted_vars[seed_parameter_id] assert seed_tensor.dtype == torch.int64 hoisted_vars[seed_parameter_id] = torch.randint( - 0, 2**62, (num_iters,), dtype=torch.int64, device=torch_xla.device()) + 0, 2**62, (num_iters,), dtype=torch.int64, device='xla') # Add hoisted variables as While computation params as well, # including the potentially updated seed tensor. diff --git a/torch_xla/runtime.py b/torch_xla/runtime.py index 53aa7399c230..2e274190db75 100644 --- a/torch_xla/runtime.py +++ b/torch_xla/runtime.py @@ -97,7 +97,7 @@ def is_bf16_supported(): """Returns whether torch.bfloat16 is supported on this environment. """ try: - torch.tensor([1.], dtype=torch.bfloat16, device=torch_xla.device()) + torch.tensor([1.], dtype=torch.bfloat16, device='xla') return True except Exception as e: return False From 7971bbf2874f2e0abd77e02c5e6873364c58e2e2 Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Fri, 23 May 2025 19:40:37 +0000 Subject: [PATCH 5/9] Replace "xla" with 'xla' --- docs/source/features/pallas.md | 6 +- docs/source/features/triton.md | 4 +- .../test_ragged_paged_attention_benchmark.py | 12 +- test/test_gmm.py | 42 +-- test/test_pallas.py | 286 +++++++++--------- test/test_pallas_spmd.py | 106 +++---- test/test_splash_attention.py | 26 +- test/test_triton.py | 4 +- torch_xla/_patched_functions.py | 2 +- torch_xla/experimental/custom_kernel.py | 10 +- 10 files changed, 249 insertions(+), 249 deletions(-) diff --git a/docs/source/features/pallas.md b/docs/source/features/pallas.md index 078ed5bb5d9e..89714ab9623a 100644 --- a/docs/source/features/pallas.md +++ b/docs/source/features/pallas.md @@ -40,9 +40,9 @@ jax will lock the TPU and torch-xla cannot access it. Example usage: ``` python3 -q = torch.randn(3, 2, 128, 4).to("xla") -k = torch.randn(3, 2, 128, 4).to("xla") -v = torch.randn(3, 2, 128, 4).to("xla") +q = torch.randn(3, 2, 128, 4).to('xla') +k = torch.randn(3, 2, 128, 4).to('xla') +v = torch.randn(3, 2, 128, 4).to('xla') # Adopts any Pallas kernel from torch_xla.experimental.custom_kernel import make_kernel_from_pallas diff --git a/docs/source/features/triton.md b/docs/source/features/triton.md index 25c5642a8836..991583aab221 100644 --- a/docs/source/features/triton.md +++ b/docs/source/features/triton.md @@ -42,8 +42,8 @@ import triton import triton.language as tl size = 16 -x = torch.arange(size, dtype=torch.int64).to("xla") -y = torch.arange(size, dtype=torch.int64).to("xla") +x = torch.arange(size, dtype=torch.int64).to('xla') +y = torch.arange(size, dtype=torch.int64).to('xla') output = torch.empty_like(x) block_size = 8 grid = (triton.cdiv(size, block_size),) diff --git a/test/benchmarks/test_ragged_paged_attention_benchmark.py b/test/benchmarks/test_ragged_paged_attention_benchmark.py index bc63bb0f66f5..dd49c8e3544c 100644 --- a/test/benchmarks/test_ragged_paged_attention_benchmark.py +++ b/test/benchmarks/test_ragged_paged_attention_benchmark.py @@ -161,14 +161,14 @@ def benchmark(args): if _run_with_torch_xla(args.kernel): queries_xla = torch.from_numpy(np.array(queries)).to( - torch.bfloat16).to("xla") + torch.bfloat16).to('xla') k_pages_xla = torch.from_numpy(np.array(k_pages)).to( - torch.bfloat16).to("xla") + torch.bfloat16).to('xla') v_pages_xla = torch.from_numpy(np.array(v_pages)).to( - torch.bfloat16).to("xla") - kv_lens_xla = torch.from_numpy(np.array(kv_lens_np)).to("xla") - page_indices_xla = torch.from_numpy(np.array(page_indices)).to("xla") - cu_q_lens_xla = torch.from_numpy(np.array(cu_q_lens)).to("xla") + torch.bfloat16).to('xla') + kv_lens_xla = torch.from_numpy(np.array(kv_lens_np)).to('xla') + page_indices_xla = torch.from_numpy(np.array(page_indices)).to('xla') + cu_q_lens_xla = torch.from_numpy(np.array(cu_q_lens)).to('xla') def ragged_paged_attention_wrapper(q, k_pages, v_pages, kv_lens, page_indices, cu_q_lens, num_seqs, diff --git a/test/test_gmm.py b/test/test_gmm.py index 35b41f6fa4b1..3ce2cc1656bc 100644 --- a/test/test_gmm.py +++ b/test/test_gmm.py @@ -127,9 +127,9 @@ def test_gmm(self): ref_out = self._reference_gmm(lhs, rhs, group_sizes, transpose_rhs) out = gmm_func( - lhs.to("xla"), - rhs.to("xla"), - group_sizes.to("xla"), + lhs.to('xla'), + rhs.to('xla'), + group_sizes.to('xla'), transpose_rhs=transpose_rhs) # torch.compiled version of the gmm will cache the payload in dynamo layer # hence won't trigger the trace_pallas cache @@ -137,9 +137,9 @@ def test_gmm(self): old_cnt = xr.get_num_cached_compilation_graph() # execute the same gmm func, expected to hit the cache out = gmm_func( - lhs.to("xla"), - rhs.to("xla"), - group_sizes.to("xla"), + lhs.to('xla'), + rhs.to('xla'), + group_sizes.to('xla'), transpose_rhs=transpose_rhs) new_cnt = xr.get_num_cached_compilation_graph() self.assertEqual(old_cnt, new_cnt) @@ -170,13 +170,13 @@ def test_gmm_bf16(self): group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups) ref_out = self._reference_gmm(lhs, rhs, group_sizes) - out = gmm_func(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla")) + out = gmm_func(lhs.to('xla'), rhs.to('xla'), group_sizes.to('xla')) # torch.compiled version of the gmm will cache the payload in dynamo layer # hence won't trigger the trace_pallas cache if test_cache and gmm_func != compiled_gmm: old_cnt = xr.get_num_cached_compilation_graph() # execute the same gmm func, expected to hit the cache - out = gmm_func(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla")) + out = gmm_func(lhs.to('xla'), rhs.to('xla'), group_sizes.to('xla')) new_cnt = xr.get_num_cached_compilation_graph() self.assertEqual(old_cnt, new_cnt) self.assertTrue(torch.allclose(ref_out, out.cpu())) @@ -203,11 +203,11 @@ def test_tgmm(self): group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups) ref_out = self._reference_tgmm(lhs, rhs, group_sizes) - out = tgmm(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla")) + out = tgmm(lhs.to('xla'), rhs.to('xla'), group_sizes.to('xla')) if test_cache: old_cnt = xr.get_num_cached_compilation_graph() # execute the same gmm func, expected to hit the cache - out = tgmm(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla")) + out = tgmm(lhs.to('xla'), rhs.to('xla'), group_sizes.to('xla')) new_cnt = xr.get_num_cached_compilation_graph() self.assertEqual(new_cnt, old_cnt) self.assertTrue(torch.allclose(ref_out, out.cpu())) @@ -234,11 +234,11 @@ def test_tgmm_bf16(self): group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups) ref_out = self._reference_tgmm(lhs, rhs, group_sizes) - out = tgmm(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla")) + out = tgmm(lhs.to('xla'), rhs.to('xla'), group_sizes.to('xla')) if test_cache: old_cnt = xr.get_num_cached_compilation_graph() # execute the same gmm func, expected to hit the cache - out = tgmm(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla")) + out = tgmm(lhs.to('xla'), rhs.to('xla'), group_sizes.to('xla')) new_cnt = xr.get_num_cached_compilation_graph() self.assertEqual(new_cnt, old_cnt) self.assertTrue(torch.allclose(ref_out, out.cpu())) @@ -269,8 +269,8 @@ def test_gmm_backward(self): ref_out_backward = torch.ones_like(ref_out) grad_lhs, grad_rhs = gmm_backward( - ref_out_backward.to("xla"), lhs.to("xla"), rhs.to("xla"), - group_sizes.to("xla")) + ref_out_backward.to('xla'), lhs.to('xla'), rhs.to('xla'), + group_sizes.to('xla')) # same gmm/tgmm was run for the `test_cache=False` case so the # cache should be populated now new_cnt = xr.get_num_cached_compilation_graph() @@ -304,13 +304,13 @@ def test_gmm_backward_2(self): ref_out.sum().backward() torch.manual_seed(42) - lhs_xla = torch.rand(m, k, dtype=lhs_dtype, requires_grad=True).to("xla") + lhs_xla = torch.rand(m, k, dtype=lhs_dtype, requires_grad=True).to('xla') rhs_xla = torch.rand( - num_groups, k, n, dtype=rhs_dtype, requires_grad=True).to("xla") + num_groups, k, n, dtype=rhs_dtype, requires_grad=True).to('xla') lhs_xla.retain_grad() rhs_xla.retain_grad() - out = GMM.apply(lhs_xla, rhs_xla, group_sizes.to("xla")) + out = GMM.apply(lhs_xla, rhs_xla, group_sizes.to('xla')) out.sum().backward() self.assertTrue(torch.allclose(ref_out, out.cpu())) @@ -341,13 +341,13 @@ def test_gmm_backward_3(self): ref_out.sum().backward() torch.manual_seed(42) - lhs_xla = torch.rand(m, k, dtype=lhs_dtype, requires_grad=True).to("xla") + lhs_xla = torch.rand(m, k, dtype=lhs_dtype, requires_grad=True).to('xla') rhs_xla = torch.rand( - num_groups, k, n, dtype=rhs_dtype, requires_grad=True).to("xla") + num_groups, k, n, dtype=rhs_dtype, requires_grad=True).to('xla') lhs_xla.retain_grad() rhs_xla.retain_grad() - out = GMM.apply(lhs_xla, rhs_xla, group_sizes.to("xla")) + out = GMM.apply(lhs_xla, rhs_xla, group_sizes.to('xla')) grad_out = torch.ones_like(out) torch.autograd.backward([out], [grad_out, lhs_xla, rhs_xla]) @@ -380,7 +380,7 @@ def test_gmm_cache_miss(self): rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype) group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups) - out = gmm(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"), tiling) + out = gmm(lhs.to('xla'), rhs.to('xla'), group_sizes.to('xla'), tiling) self.assertEqual(met.counter_value('trace_pallas_cache_hit'), None) diff --git a/test/test_pallas.py b/test/test_pallas.py index 54fcbbc51dd1..080cbb5eed87 100644 --- a/test/test_pallas.py +++ b/test/test_pallas.py @@ -147,8 +147,8 @@ def test_tpu_custom_call_pallas_add(self): # o_ref[...] = x + y payload = "{\"custom_call_config\": {\"body\": \"TUzvUgFNTElSMTguMC4wZ2l0AAErCwEDBQcJAQMLAwUDDQcFDxEJBRMVA2lNDQFLBw8LEw8PDwsPMwsLCwtlCwsLCwsPCw8PEwsTDwsTDwsPDxMLDwUDYQENGwcTDxsPAsICHx0rLQUXAwMnKRURNx1HSRELAQUZHTM1AwsVFxkbHw0hDSMlBRsBAQUdDQlhZmZpbmVfbWFwPChkMCkgLT4gKGQwKT4ABR8FIQUjBSUFJxEDAQUpFS8JHQ8xFwUTAQUrFwUdAR05OwUtFwUlAR0/QQUvFUMJHQ9FFwUVAQUxFREJI3RwdS5tZW1vcnlfc3BhY2U8dm1lbT4AF0sDIQcdAycDIQcBAgIFBwEBAQEBAgQEpwUBEAEHAwEFAxEBEwcDFScHAQEBAQEBBwMDBwMDCwYDAwUFAQcHAwMHAwMLBgMDBQUDCwkGPQMFBQkNBwMLBwMDCwYLAwUFBRENBAsHDwURBQABBgMBBQEAdgcz2wsTGdkNCxMjIR0pJ0MNCwsTDw8PDQkLEWJ1aWx0aW4AZnVuYwB0cHUAYXJpdGgAdmVjdG9yAG1vZHVsZQByZXR1cm4AY29uc3RhbnQAYWRkaQBsb2FkAHN0b3JlAC9ob21lL2p3dGFuL3BhbGxhcy9wYWxsYXNfYWRkLnB5AGFkZF92ZWN0b3JzX2tlcm5lbABkaW1lbnNpb25fc2VtYW50aWNzAGZ1bmN0aW9uX3R5cGUAc2NhbGFyX3ByZWZldGNoAHNjcmF0Y2hfb3BlcmFuZHMAc3ltX25hbWUAbWFpbgB2YWx1ZQAvZ2V0W3RyZWU9UHlUcmVlRGVmKChDdXN0b21Ob2RlKE5ESW5kZXhlclsoUHlUcmVlRGVmKChDdXN0b21Ob2RlKFNsaWNlWygwLCA4KV0sIFtdKSwpKSwgKDgsKSwgKCkpXSwgW10pLCkpXQBhZGRfdmVjdG9ycwA8bW9kdWxlPgAvYWRkAC9zd2FwW3RyZWU9UHlUcmVlRGVmKChDdXN0b21Ob2RlKE5ESW5kZXhlclsoUHlUcmVlRGVmKChDdXN0b21Ob2RlKFNsaWNlWygwLCA4KV0sIFtdKSwpKSwgKDgsKSwgKCkpXSwgW10pLCkpXQA=\", \"needs_layout_passes\": true}}" - x = torch.arange(8, dtype=torch.int).to("xla") - y = torch.arange(8, dtype=torch.int).to("xla") + x = torch.arange(8, dtype=torch.int).to('xla') + y = torch.arange(8, dtype=torch.int).to('xla') expected_output = x + y output = torch_xla._XLAC._xla_tpu_custom_call([x, y], payload, [x.shape], @@ -162,7 +162,7 @@ def test_tpu_custom_call_pallas_add_one(self): # o_ref[...] = x_ref[...] + 1 payload = "{\"custom_call_config\": {\"body\": \"TUzvUgFNTElSMTguMC4wZ2l0AAEtCwEDBQcJAQMLAwUDDQcFDxEJBxMVFwNlSQ0BRwcPCw8PDxMLDzMLCwsLZQsLCwsPCw8LEw8PCxMPCxMTDwsLBQNhAQ0bDxMHFw8CpgIfFSsxBRkdQwMdRQMRCwEDAw8nBRsdKQMDCxUXGRsfCyELIyUFHQEBBR8NCWFmZmluZV9tYXA8KGQwKSAtPiAoZDApPgAFIQUjBSUFJxEHAQUpHS0vBSsXBRsBFTM5HTU3BS0XBS8BHTs9BS8XBUUBAwMPQREDBQUxBTMjdHB1Lm1lbW9yeV9zcGFjZTx2bWVtPgAXRwMhAx0BAgInAyEDAwUFAQEBAQIEBKEFARABBwMBBQMRARMHAxMnBQEBAQEHAxENAwcLBhEDBQUBBQcDBz8DAw0GBwMFAwkJBgcDBQUHCwcDCQ0DBwsGCQMFBQMPDwQJBw0DDwUAAQYDAQUBAMIHNdsLEyEv2QsTIyEdKQ1DDRULCxMPDw8NCQsRYnVpbHRpbgBmdW5jAHRwdQBhcml0aAB2ZWN0b3IAbW9kdWxlAHJldHVybgBjb25zdGFudABhZGRpAGxvYWQAYnJvYWRjYXN0AHN0b3JlAC9ob21lL2p3dGFuL3BhbGxhcy9wYWxsYXNfYWRkLnB5AHZhbHVlAGRpbWVuc2lvbl9zZW1hbnRpY3MAZnVuY3Rpb25fdHlwZQBzY2FsYXJfcHJlZmV0Y2gAc2NyYXRjaF9vcGVyYW5kcwBzeW1fbmFtZQBtYWluAC9nZXRbdHJlZT1QeVRyZWVEZWYoKEN1c3RvbU5vZGUoTkRJbmRleGVyWyhQeVRyZWVEZWYoKEN1c3RvbU5vZGUoU2xpY2VbKDAsIDgpXSwgW10pLCkpLCAoOCwpLCAoKSldLCBbXSksKSldAGFkZF9vbmVfdmVjdG9yc19rZXJuZWwAYWRkX3ZlY3RvcnNfb25lADxtb2R1bGU+AC9hZGQAL3N3YXBbdHJlZT1QeVRyZWVEZWYoKEN1c3RvbU5vZGUoTkRJbmRleGVyWyhQeVRyZWVEZWYoKEN1c3RvbU5vZGUoU2xpY2VbKDAsIDgpXSwgW10pLCkpLCAoOCwpLCAoKSldLCBbXSksKSldAA==\", \"needs_layout_passes\": true}}" - x = torch.arange(8, dtype=torch.int).to("xla") + x = torch.arange(8, dtype=torch.int).to('xla') expected_output = x + 1 output = torch_xla._XLAC._xla_tpu_custom_call([x], payload, [x.shape], @@ -191,9 +191,9 @@ def test_tpu_custom_call_pallas_flash_attention(self): q_mini = torch.arange(128 * 4, dtype=torch.float32).reshape(128, 4) / 13 k_mini = torch.arange( 1000, 1000 + 128 * 4, dtype=torch.float32).reshape(128, 4) / 13 - q = q_mini.broadcast_to(3, 2, 128, 4).to("xla") - k = k_mini.broadcast_to(3, 2, 128, 4).to("xla") - v = torch.ones(3, 2, 128, 4).to("xla") + q = q_mini.broadcast_to(3, 2, 128, 4).to('xla') + k = k_mini.broadcast_to(3, 2, 128, 4).to('xla') + v = torch.ones(3, 2, 128, 4).to('xla') expected_o = self._attention(q, k, v) @@ -244,8 +244,8 @@ def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array: torch.float32, torch.float ] # Add doesn't support torch.float64, torch.bfloat16, torch.float16. for i in range(len(dtypes)): - x = torch.randn((i + 1, i + 1), dtype=dtypes[i]).to("xla") - y = torch.randn((i + 1, i + 1), dtype=dtypes[i]).to("xla") + x = torch.randn((i + 1, i + 1), dtype=dtypes[i]).to('xla') + y = torch.randn((i + 1, i + 1), dtype=dtypes[i]).to('xla') expected_output = x + y output = pt_kernel(x, y) self.assertTrue(torch.allclose(output.cpu(), expected_output.cpu())) @@ -254,8 +254,8 @@ def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array: torch.int32, torch.int ] # Add doesn't support torch.int64, torch.int16, torch.int8, torch.uint8. for i in range(len(dtypes)): - x = torch.arange(i + 1, dtype=dtypes[i]).to("xla") - y = torch.arange(i + 1, dtype=dtypes[i]).to("xla") + x = torch.arange(i + 1, dtype=dtypes[i]).to('xla') + y = torch.arange(i + 1, dtype=dtypes[i]).to('xla') expected_output = x + y output = pt_kernel(x, y) self.assertTrue(torch.allclose(output.cpu(), expected_output.cpu())) @@ -271,9 +271,9 @@ def test_tpu_custom_call_pallas_wrap_flash_attention(self): q_mini = torch.arange(128 * 4, dtype=torch.bfloat16).reshape(128, 4) / 13 k_mini = torch.arange( 1000, 1000 + 128 * 4, dtype=torch.bfloat16).reshape(128, 4) / 13 - q = q_mini.broadcast_to(3, 2, 128, 4).to("xla") - k = k_mini.broadcast_to(3, 2, 128, 4).to("xla") - v = torch.ones(3, 2, 128, 4, dtype=torch.bfloat16).to("xla") + q = q_mini.broadcast_to(3, 2, 128, 4).to('xla') + k = k_mini.broadcast_to(3, 2, 128, 4).to('xla') + v = torch.ones(3, 2, 128, 4, dtype=torch.bfloat16).to('xla') o = flash_attention_kernel(q, k, v) expected_o = self._attention(q, k, v) @@ -286,9 +286,9 @@ def test_tpu_custom_call_pallas_wrap_flash_attention(self): def test_flash_attention_wrapper(self): from torch_xla.experimental.custom_kernel import flash_attention - q = torch.randn(3, 2, 128, 4).to("xla") - k = torch.randn(3, 2, 128, 4).to("xla") - v = torch.randn(3, 2, 128, 4).to("xla") + q = torch.randn(3, 2, 128, 4).to('xla') + k = torch.randn(3, 2, 128, 4).to('xla') + v = torch.randn(3, 2, 128, 4).to('xla') o = flash_attention(q, k, v) expected_o = self._attention(q, k, v) @@ -300,10 +300,10 @@ def test_flash_attention_wrapper(self): def test_flash_attention_wrapper_kv_and_ab_padding(self): from torch_xla.experimental.custom_kernel import flash_attention - q = torch.randn(1, 2, 513, 4).to("xla") - k = torch.randn(1, 2, 513, 4).to("xla") - v = torch.randn(1, 2, 513, 4).to("xla") - ab = torch.randn(1, 2, 513, 513).to("xla") + q = torch.randn(1, 2, 513, 4).to('xla') + k = torch.randn(1, 2, 513, 4).to('xla') + v = torch.randn(1, 2, 513, 4).to('xla') + ab = torch.randn(1, 2, 513, 513).to('xla') o = flash_attention(q, k, v, ab=ab) expected_o = self._attention(q, k, v, ab=ab) @@ -318,9 +318,9 @@ def test_flash_attention_wrapper_with_dynamo(self): def flash_attention_wrapper(q, k, v, causal=False): return torch.ops.xla.flash_attention(q, k, v, causal) - q = torch.randn(3, 2, 128, 4).to("xla") - k = torch.randn(3, 2, 128, 4).to("xla") - v = torch.randn(3, 2, 128, 4).to("xla") + q = torch.randn(3, 2, 128, 4).to('xla') + k = torch.randn(3, 2, 128, 4).to('xla') + v = torch.randn(3, 2, 128, 4).to('xla') compiled_flash_attention = torch.compile( flash_attention_wrapper, backend="openxla") @@ -340,9 +340,9 @@ def flash_attention_wrapper(q, k, v, causal=False): def test_flash_attention_wrapper_causal(self): from torch_xla.experimental.custom_kernel import flash_attention - q = torch.randn(3, 2, 128, 4).to("xla") - k = torch.randn(3, 2, 128, 4).to("xla") - v = torch.randn(3, 2, 128, 4).to("xla") + q = torch.randn(3, 2, 128, 4).to('xla') + k = torch.randn(3, 2, 128, 4).to('xla') + v = torch.randn(3, 2, 128, 4).to('xla') # The causal mask is turned on by default in the wrapper. # It masks out the top right triangle of the attention matrix, therefore it speeds up the compute but also changes the output. @@ -394,9 +394,9 @@ def shape_dtype(q, *arg): flash_attention_kernel = make_kernel_from_pallas(_flash_attention_impl, shape_dtype) - q = torch.randn(3, 2, 128, 4, dtype=torch.bfloat16).to("xla") - k = torch.randn(3, 2, 128, 4, dtype=torch.bfloat16).to("xla") - v = torch.randn(3, 2, 128, 4, dtype=torch.bfloat16).to("xla") + q = torch.randn(3, 2, 128, 4, dtype=torch.bfloat16).to('xla') + k = torch.randn(3, 2, 128, 4, dtype=torch.bfloat16).to('xla') + v = torch.randn(3, 2, 128, 4, dtype=torch.bfloat16).to('xla') o, l, m = flash_attention_kernel( q, @@ -429,13 +429,13 @@ def test__flash_attention_bwd_dkv(self): MIN_BLOCK_SIZE = 128 DEFAULT_MASK_VALUE = -0.7 * float(torch.finfo(torch.float32).max) - q = torch.randn(3, 2, 128, 4).to("xla") - k = torch.randn(3, 2, 128, 4).to("xla") - v = torch.randn(3, 2, 128, 4).to("xla") - l = torch.randn(3, 2, 128).to("xla") - m = torch.randn(3, 2, 128).to("xla") - grad_i = torch.randn(3, 2, 128, dtype=torch.float32).to("xla") - grad_o = torch.randn(3, 2, 128, 4).to("xla") + q = torch.randn(3, 2, 128, 4).to('xla') + k = torch.randn(3, 2, 128, 4).to('xla') + v = torch.randn(3, 2, 128, 4).to('xla') + l = torch.randn(3, 2, 128).to('xla') + m = torch.randn(3, 2, 128).to('xla') + grad_i = torch.randn(3, 2, 128, dtype=torch.float32).to('xla') + grad_o = torch.randn(3, 2, 128, 4).to('xla') payload, _ = trace_pallas( _flash_attention_bwd_dkv, @@ -483,13 +483,13 @@ def test__flash_attention_bwd_dkv(self): MIN_BLOCK_SIZE = 128 DEFAULT_MASK_VALUE = -0.7 * float(torch.finfo(torch.float32).max) - q = torch.randn(3, 2, 128, 4).to("xla") - k = torch.randn(3, 2, 128, 4).to("xla") - v = torch.randn(3, 2, 128, 4).to("xla") - l = torch.randn(3, 2, 128).to("xla") - m = torch.randn(3, 2, 128).to("xla") - grad_i = torch.randn(3, 2, 128, dtype=torch.float32).to("xla") - grad_o = torch.randn(3, 2, 128, 4).to("xla") + q = torch.randn(3, 2, 128, 4).to('xla') + k = torch.randn(3, 2, 128, 4).to('xla') + v = torch.randn(3, 2, 128, 4).to('xla') + l = torch.randn(3, 2, 128).to('xla') + m = torch.randn(3, 2, 128).to('xla') + grad_i = torch.randn(3, 2, 128, dtype=torch.float32).to('xla') + grad_o = torch.randn(3, 2, 128, 4).to('xla') payload, _ = trace_pallas( _flash_attention_bwd_dq, @@ -533,9 +533,9 @@ def test_flash_attention_backward(self): from torch_xla.experimental.custom_kernel import flash_attention torch.manual_seed(42) - q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla") - k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla") - v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla") + q = torch.randn(4, 2, 128, 8, requires_grad=True).to('xla') + k = torch.randn(4, 2, 128, 8, requires_grad=True).to('xla') + v = torch.randn(4, 2, 128, 8, requires_grad=True).to('xla') q.retain_grad() k.retain_grad() v.retain_grad() @@ -550,9 +550,9 @@ def test_flash_attention_backward(self): v_grad = v.grad torch.manual_seed(42) - q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla") - k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla") - v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla") + q = torch.randn(4, 2, 128, 8, requires_grad=True).to('xla') + k = torch.randn(4, 2, 128, 8, requires_grad=True).to('xla') + v = torch.randn(4, 2, 128, 8, requires_grad=True).to('xla') q.retain_grad() k.retain_grad() v.retain_grad() @@ -589,11 +589,11 @@ def test_paged_attention_wrapper(self): head_dim, ) - q_xla = q.to("xla") - k_pages_xla = k_pages.to("xla") - v_pages_xla = v_pages.to("xla") - seq_lens_xla = seq_lens.to("xla") - page_indices_xla = page_indices.to("xla") + q_xla = q.to('xla') + k_pages_xla = k_pages.to('xla') + v_pages_xla = v_pages.to('xla') + seq_lens_xla = seq_lens.to('xla') + page_indices_xla = page_indices.to('xla') output = paged_attention( q_xla, @@ -670,12 +670,12 @@ def _test_ragged_paged_attention( if kv_dtype is torch.float8_e5m2 and tpu.version() <= 4: self.skipTest("TPU v4 or older doesn't support fp8") - q_xla = q.to("xla") - kv_pages_xla = kv_pages.to("xla") - kv_lens_xla = kv_lens.to("xla") - page_indices_xla = page_indices.to("xla") - cu_q_lens_xla = cu_q_lens.to("xla") - num_seqs_xla = torch.tensor([num_seqs], dtype=torch.int32).to("xla") + q_xla = q.to('xla') + kv_pages_xla = kv_pages.to('xla') + kv_lens_xla = kv_lens.to('xla') + page_indices_xla = page_indices.to('xla') + cu_q_lens_xla = cu_q_lens.to('xla') + num_seqs_xla = torch.tensor([num_seqs], dtype=torch.int32).to('xla') if use_dynamo: @@ -914,12 +914,12 @@ def test_paged_attention_multi_queries_wrapper(self): query_len=query_len, ) - q_xla = q.to("xla") - k_pages_xla = k_pages.to("xla") - v_pages_xla = v_pages.to("xla") - kv_seq_lens_xla = kv_seq_lens.to("xla") - page_indices_xla = page_indices.to("xla") - effective_q_lens_xla = effective_q_lens.to("xla") + q_xla = q.to('xla') + k_pages_xla = k_pages.to('xla') + v_pages_xla = v_pages.to('xla') + kv_seq_lens_xla = kv_seq_lens.to('xla') + page_indices_xla = page_indices.to('xla') + effective_q_lens_xla = effective_q_lens.to('xla') output_no_cap = multi_queries_paged_attention( q_xla, @@ -1040,12 +1040,12 @@ def test_paged_attention_multi_queries_wrapper_with_dynamo(self): query_len=query_len, ) - q_xla = q.to("xla") - k_pages_xla = k_pages.to("xla") - v_pages_xla = v_pages.to("xla") - kv_seq_lens_xla = kv_seq_lens.to("xla") - page_indices_xla = page_indices.to("xla") - effective_q_lens_xla = effective_q_lens.to("xla") + q_xla = q.to('xla') + k_pages_xla = k_pages.to('xla') + v_pages_xla = v_pages.to('xla') + kv_seq_lens_xla = kv_seq_lens.to('xla') + page_indices_xla = page_indices.to('xla') + effective_q_lens_xla = effective_q_lens.to('xla') def multi_queries_paged_attention_wrapper(q, k_pages, v_pages, kv_seq_lens, page_indices, effective_q_lens, @@ -1124,11 +1124,11 @@ def test_paged_attention_wrapper_with_megacore_modes(self): head_dim, ) - q_xla = q.to("xla") - k_pages_xla = k_pages.to("xla") - v_pages_xla = v_pages.to("xla") - seq_lens_xla = seq_lens.to("xla") - page_indices_xla = page_indices.to("xla") + q_xla = q.to('xla') + k_pages_xla = k_pages.to('xla') + v_pages_xla = v_pages.to('xla') + seq_lens_xla = seq_lens.to('xla') + page_indices_xla = page_indices.to('xla') outputs = [] for megacore_mode in ['kv_head', 'batch', None]: @@ -1191,11 +1191,11 @@ def test_paged_attention_wrapper_with_dynamo(self): head_dim, ) - q_xla = q.to("xla") - k_pages_xla = k_pages.to("xla") - v_pages_xla = v_pages.to("xla") - seq_lens_xla = seq_lens.to("xla") - page_indices_xla = page_indices.to("xla") + q_xla = q.to('xla') + k_pages_xla = k_pages.to('xla') + v_pages_xla = v_pages.to('xla') + seq_lens_xla = seq_lens.to('xla') + page_indices_xla = page_indices.to('xla') def paged_attention_wrapper(q, k, v, seq_lens, page_indices, pages_per_compute_block, attn_logits_soft_cap): @@ -1271,11 +1271,11 @@ def test_paged_attention_wrapper_with_attn_logits_soft_cap(self): head_dim, ) - q_xla = q.to("xla") - k_pages_xla = k_pages.to("xla") - v_pages_xla = v_pages.to("xla") - seq_lens_xla = seq_lens.to("xla") - page_indices_xla = page_indices.to("xla") + q_xla = q.to('xla') + k_pages_xla = k_pages.to('xla') + v_pages_xla = v_pages.to('xla') + seq_lens_xla = seq_lens.to('xla') + page_indices_xla = page_indices.to('xla') outputs = [] for attn_logits_soft_cap in [1.0, None]: @@ -1328,8 +1328,8 @@ def test_flash_attention_wrapper_segment_ids_1(self): zeros = torch.zeros(3, 32) segment_ids = torch.cat([zeros, zeros + 1, zeros + 2, zeros + 3], dim=1) o = flash_attention( - q.to("xla"), k.to("xla"), v.to("xla"), False, segment_ids.to("xla"), - segment_ids.to("xla")) + q.to('xla'), k.to('xla'), v.to('xla'), False, segment_ids.to('xla'), + segment_ids.to('xla')) jax_q = jnp.array(q.numpy(), dtype=jnp.float32) jax_k = jnp.array(k.numpy(), dtype=jnp.float32) @@ -1352,10 +1352,10 @@ def test_flash_attention_wrapper_segment_ids_1(self): def test_flash_attention_wrapper_segment_ids_2(self): from torch_xla.experimental.custom_kernel import flash_attention - q = torch.randn(3, 2, 128, 4).to("xla") - k = torch.randn(3, 2, 128, 4).to("xla") - v = torch.randn(3, 2, 128, 4).to("xla") - zeros = torch.zeros(3, 32).to("xla") + q = torch.randn(3, 2, 128, 4).to('xla') + k = torch.randn(3, 2, 128, 4).to('xla') + v = torch.randn(3, 2, 128, 4).to('xla') + zeros = torch.zeros(3, 32).to('xla') segment_ids = torch.cat([zeros, zeros + 1, zeros + 2, zeros + 3], dim=1) o = flash_attention(q, k, v, False, segment_ids, segment_ids) @@ -1375,10 +1375,10 @@ def test_flash_attention_backward_segment_ids(self): from torch_xla.experimental.custom_kernel import flash_attention torch.manual_seed(42) - q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla") - k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla") - v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla") - zeros = torch.zeros(4, 32).to("xla") + q = torch.randn(4, 2, 128, 8, requires_grad=True).to('xla') + k = torch.randn(4, 2, 128, 8, requires_grad=True).to('xla') + v = torch.randn(4, 2, 128, 8, requires_grad=True).to('xla') + zeros = torch.zeros(4, 32).to('xla') segment_ids = torch.cat([zeros, zeros + 1, zeros + 2, zeros + 3], dim=1) q.retain_grad() k.retain_grad() @@ -1394,10 +1394,10 @@ def test_flash_attention_backward_segment_ids(self): v_grad = v.grad torch.manual_seed(42) - q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla") - k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla") - v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla") - zeros = torch.zeros(4, 32).to("xla") + q = torch.randn(4, 2, 128, 8, requires_grad=True).to('xla') + k = torch.randn(4, 2, 128, 8, requires_grad=True).to('xla') + v = torch.randn(4, 2, 128, 8, requires_grad=True).to('xla') + zeros = torch.zeros(4, 32).to('xla') segment_ids = torch.cat([zeros, zeros + 1, zeros + 2, zeros + 3], dim=1) q.retain_grad() k.retain_grad() @@ -1422,9 +1422,9 @@ def test_flash_attention_backward_segment_ids(self): def test_flash_attention_wrapper_sm_scale(self): from torch_xla.experimental.custom_kernel import flash_attention - q = torch.randn(3, 2, 128, 4).to("xla") - k = torch.randn(3, 2, 128, 4).to("xla") - v = torch.randn(3, 2, 128, 4).to("xla") + q = torch.randn(3, 2, 128, 4).to('xla') + k = torch.randn(3, 2, 128, 4).to('xla') + v = torch.randn(3, 2, 128, 4).to('xla') sm_scale = 0.7 o = flash_attention(q, k, v, False, None, None, sm_scale) @@ -1438,9 +1438,9 @@ def test_flash_attention_sm_scale_backward(self): from torch_xla.experimental.custom_kernel import flash_attention torch.manual_seed(42) - q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla") - k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla") - v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla") + q = torch.randn(4, 2, 128, 8, requires_grad=True).to('xla') + k = torch.randn(4, 2, 128, 8, requires_grad=True).to('xla') + v = torch.randn(4, 2, 128, 8, requires_grad=True).to('xla') sm_scale = 0.7 q.retain_grad() k.retain_grad() @@ -1456,9 +1456,9 @@ def test_flash_attention_sm_scale_backward(self): v_grad = v.grad torch.manual_seed(42) - q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla") - k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla") - v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla") + q = torch.randn(4, 2, 128, 8, requires_grad=True).to('xla') + k = torch.randn(4, 2, 128, 8, requires_grad=True).to('xla') + v = torch.randn(4, 2, 128, 8, requires_grad=True).to('xla') q.retain_grad() k.retain_grad() v.retain_grad() @@ -1478,11 +1478,11 @@ def test_flash_attention_sm_scale_backward(self): def test_flash_attention_ab(self): from torch_xla.experimental.custom_kernel import flash_attention - q = torch.randn(3, 2, 128, 4).to("xla") - k = torch.randn(3, 2, 128, 4).to("xla") - v = torch.randn(3, 2, 128, 4).to("xla") - mask = (torch.rand(3, 2, 128, 128) > 0.5).to("xla") - ab = torch.ones(3, 2, 128, 128).to("xla") + q = torch.randn(3, 2, 128, 4).to('xla') + k = torch.randn(3, 2, 128, 4).to('xla') + v = torch.randn(3, 2, 128, 4).to('xla') + mask = (torch.rand(3, 2, 128, 128) > 0.5).to('xla') + ab = torch.ones(3, 2, 128, 128).to('xla') ab = ab.masked_fill(mask, torch.finfo(ab.dtype).min) o = flash_attention(q, k, v, ab=ab) @@ -1497,11 +1497,11 @@ def test_flash_attention_ab_backward_1(self): from torch_xla.experimental.custom_kernel import flash_attention torch.manual_seed(42) - q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla") - k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla") - v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla") - mask = (torch.rand(4, 2, 128, 128) > 0.5).to("xla") - ab = torch.ones(4, 2, 128, 128).to("xla") + q = torch.randn(4, 2, 128, 8, requires_grad=True).to('xla') + k = torch.randn(4, 2, 128, 8, requires_grad=True).to('xla') + v = torch.randn(4, 2, 128, 8, requires_grad=True).to('xla') + mask = (torch.rand(4, 2, 128, 128) > 0.5).to('xla') + ab = torch.ones(4, 2, 128, 128).to('xla') ab = ab.masked_fill(mask, torch.finfo(ab.dtype).min) q.retain_grad() k.retain_grad() @@ -1535,11 +1535,11 @@ def test_flash_attention_ab_backward_2(self): from torch_xla.experimental.custom_kernel import flash_attention torch.manual_seed(42) - q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla") - k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla") - v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla") - mask = (torch.rand(4, 2, 128, 128) > 0.5).to("xla") - ab = torch.ones(4, 2, 128, 128).to("xla") + q = torch.randn(4, 2, 128, 8, requires_grad=True).to('xla') + k = torch.randn(4, 2, 128, 8, requires_grad=True).to('xla') + v = torch.randn(4, 2, 128, 8, requires_grad=True).to('xla') + mask = (torch.rand(4, 2, 128, 128) > 0.5).to('xla') + ab = torch.ones(4, 2, 128, 128).to('xla') ab = ab.masked_fill(mask, torch.finfo(ab.dtype).min) ab.requires_grad = True q.retain_grad() @@ -1583,9 +1583,9 @@ def compiler(gm, _): return make_boxed_func(gm) torch.manual_seed(42) - q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla") - k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla") - v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla") + q = torch.randn(4, 2, 128, 8, requires_grad=True).to('xla') + k = torch.randn(4, 2, 128, 8, requires_grad=True).to('xla') + v = torch.randn(4, 2, 128, 8, requires_grad=True).to('xla') q.retain_grad() k.retain_grad() v.retain_grad() @@ -1600,7 +1600,7 @@ def compiler(gm, _): kv_segment_ids, sm_scale) torch_xla.sync() if causal: - attention_mask = torch.triu(torch.ones(SEQ, SEQ), diagonal=1).to("xla") + attention_mask = torch.triu(torch.ones(SEQ, SEQ), diagonal=1).to('xla') else: attention_mask = None @@ -1621,16 +1621,16 @@ def compiler(gm, _): return make_boxed_func(gm) torch.manual_seed(42) - q = torch.randn(4, 2, 128, 8).to("xla") - k = torch.randn(4, 2, 128, 8).to("xla") - v = torch.randn(4, 2, 128, 8).to("xla") + q = torch.randn(4, 2, 128, 8).to('xla') + k = torch.randn(4, 2, 128, 8).to('xla') + v = torch.randn(4, 2, 128, 8).to('xla') B, N, SEQ, H = q.size() causal = False q_segment_ids = None kv_segment_ids = None sm_scale = 1.0 - mask = (torch.rand(4, 2, 128, 128) > 0.5).to("xla") - ab = torch.ones(4, 2, 128, 128).to("xla") + mask = (torch.rand(4, 2, 128, 128) > 0.5).to('xla') + ab = torch.ones(4, 2, 128, 128).to('xla') ab = ab.masked_fill(mask, torch.finfo(ab.dtype).min) compiled_flash_attention = aot_function( @@ -1656,15 +1656,15 @@ def compiler(gm, _): return make_boxed_func(gm) torch.manual_seed(42) - q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla") - k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla") - v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla") + q = torch.randn(4, 2, 128, 8, requires_grad=True).to('xla') + k = torch.randn(4, 2, 128, 8, requires_grad=True).to('xla') + v = torch.randn(4, 2, 128, 8, requires_grad=True).to('xla') q.retain_grad() k.retain_grad() v.retain_grad() B, N, SEQ, H = q.size() - mask = (torch.rand(4, 2, 128, 128) > 0.5).to("xla") - ab = torch.ones(4, 2, 128, 128).to("xla") + mask = (torch.rand(4, 2, 128, 128) > 0.5).to('xla') + ab = torch.ones(4, 2, 128, 128).to('xla') ab = ab.masked_fill(mask, torch.finfo(ab.dtype).min).requires_grad_() ab.retain_grad() @@ -1685,13 +1685,13 @@ def compiler(gm, _): ab_grad = ab.grad torch.manual_seed(42) - expected_q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla") - expected_k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla") - expected_v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla") + expected_q = torch.randn(4, 2, 128, 8, requires_grad=True).to('xla') + expected_k = torch.randn(4, 2, 128, 8, requires_grad=True).to('xla') + expected_v = torch.randn(4, 2, 128, 8, requires_grad=True).to('xla') expected_q.retain_grad() expected_k.retain_grad() expected_v.retain_grad() - expected_ab = torch.ones(4, 2, 128, 128).to("xla") + expected_ab = torch.ones(4, 2, 128, 128).to('xla') expected_ab = expected_ab.masked_fill(mask, torch.finfo( ab.dtype).min).requires_grad_() diff --git a/test/test_pallas_spmd.py b/test/test_pallas_spmd.py index f44c1baeb919..e611fd75784f 100644 --- a/test/test_pallas_spmd.py +++ b/test/test_pallas_spmd.py @@ -65,9 +65,9 @@ def test_flash_attention_spmd_data_parallel(self): n_devices = xr.global_runtime_device_count() xs.set_global_mesh(xs.Mesh(range(n_devices), (n_devices, 1, 1, 1))) - q = torch.randn(8, 2, 128, 8).to("xla") - k = torch.randn(8, 2, 128, 8).to("xla") - v = torch.randn(8, 2, 128, 8).to("xla") + q = torch.randn(8, 2, 128, 8).to('xla') + k = torch.randn(8, 2, 128, 8).to('xla') + v = torch.randn(8, 2, 128, 8).to('xla') o = flash_attention(q, k, v, partition_spec=(0, None, None, None)) dev_ids = ','.join(map(str, range(n_devices))) @@ -89,9 +89,9 @@ def test_flash_attention_spmd_data_parallel_5d(self): range(n_devices), (n_devices // 2, 2, 1, 1, 1), ('fsdp', 'dp', 'a', 'b', 'c'))) - q = torch.randn(4, 2, 2, 128, 4).to("xla") - k = torch.randn(4, 2, 2, 128, 4).to("xla") - v = torch.randn(4, 2, 2, 128, 4).to("xla") + q = torch.randn(4, 2, 2, 128, 4).to('xla') + k = torch.randn(4, 2, 2, 128, 4).to('xla') + v = torch.randn(4, 2, 2, 128, 4).to('xla') o = flash_attention( q, k, v, partition_spec=('fsdp', 'dp', None, None, None)) @@ -111,10 +111,10 @@ def test_flash_attention_spmd_data_parallel_kv_and_ab_padding(self): n_devices = xr.global_runtime_device_count() xs.set_global_mesh(xs.Mesh(range(n_devices), (n_devices, 1, 1, 1))) - q = torch.randn(8, 2, 513, 4).to("xla") - k = torch.randn(8, 2, 513, 4).to("xla") - v = torch.randn(8, 2, 513, 4).to("xla") - ab = torch.randn(8, 2, 513, 513).to("xla") + q = torch.randn(8, 2, 513, 4).to('xla') + k = torch.randn(8, 2, 513, 4).to('xla') + v = torch.randn(8, 2, 513, 4).to('xla') + ab = torch.randn(8, 2, 513, 513).to('xla') o = flash_attention(q, k, v, ab=ab, partition_spec=(0, None, None, None)) dev_ids = ','.join(map(str, range(n_devices))) @@ -134,9 +134,9 @@ def test_flash_attention_backward_spmd_data_parallel(self): xs.set_global_mesh(xs.Mesh(range(n_devices), (n_devices, 1, 1, 1))) torch.manual_seed(42) - q = torch.randn(8, 2, 128, 8, requires_grad=True).to("xla") - k = torch.randn(8, 2, 128, 8, requires_grad=True).to("xla") - v = torch.randn(8, 2, 128, 8, requires_grad=True).to("xla") + q = torch.randn(8, 2, 128, 8, requires_grad=True).to('xla') + k = torch.randn(8, 2, 128, 8, requires_grad=True).to('xla') + v = torch.randn(8, 2, 128, 8, requires_grad=True).to('xla') q.retain_grad() k.retain_grad() v.retain_grad() @@ -162,9 +162,9 @@ def test_flash_attention_backward_spmd_data_parallel(self): f"{{devices=[{n_devices},1,1,1]{dev_ids}}}") torch.manual_seed(42) - q = torch.randn(8, 2, 128, 8, requires_grad=True).to("xla") - k = torch.randn(8, 2, 128, 8, requires_grad=True).to("xla") - v = torch.randn(8, 2, 128, 8, requires_grad=True).to("xla") + q = torch.randn(8, 2, 128, 8, requires_grad=True).to('xla') + k = torch.randn(8, 2, 128, 8, requires_grad=True).to('xla') + v = torch.randn(8, 2, 128, 8, requires_grad=True).to('xla') q.retain_grad() k.retain_grad() v.retain_grad() @@ -191,15 +191,15 @@ def test_flash_attention_wrapper_segment_ids_spmd(self): v = torch.randn(8, 2, 128, 4) zeros = torch.zeros(8, 32) segment_ids = torch.cat([zeros, zeros + 1, zeros + 2, zeros + 3], dim=1) - segment_ids_xla = segment_ids.to("xla") + segment_ids_xla = segment_ids.to('xla') # only shard data dimension o = flash_attention( - q.to("xla"), - k.to("xla"), - v.to("xla"), + q.to('xla'), + k.to('xla'), + v.to('xla'), False, segment_ids_xla, - segment_ids.to("xla"), + segment_ids.to('xla'), partition_spec=("data", None, None, None)) n_devices = xr.global_runtime_device_count() dev_ids = ','.join(map(str, range(n_devices))) @@ -232,10 +232,10 @@ def test_flash_attention_backward_segment_ids_spmd(self): xs.set_global_mesh(xs.get_1d_mesh("data")) torch.manual_seed(42) - q = torch.randn(8, 2, 128, 8, requires_grad=True).to("xla") - k = torch.randn(8, 2, 128, 8, requires_grad=True).to("xla") - v = torch.randn(8, 2, 128, 8, requires_grad=True).to("xla") - zeros = torch.zeros(8, 32).to("xla") + q = torch.randn(8, 2, 128, 8, requires_grad=True).to('xla') + k = torch.randn(8, 2, 128, 8, requires_grad=True).to('xla') + v = torch.randn(8, 2, 128, 8, requires_grad=True).to('xla') + zeros = torch.zeros(8, 32).to('xla') segment_ids = torch.cat([zeros, zeros + 1, zeros + 2, zeros + 3], dim=1) q.retain_grad() k.retain_grad() @@ -271,10 +271,10 @@ def test_flash_attention_backward_segment_ids_spmd(self): torch_xla.sync() torch.manual_seed(42) - q = torch.randn(8, 2, 128, 8, requires_grad=True).to("xla") - k = torch.randn(8, 2, 128, 8, requires_grad=True).to("xla") - v = torch.randn(8, 2, 128, 8, requires_grad=True).to("xla") - zeros = torch.zeros(8, 32).to("xla") + q = torch.randn(8, 2, 128, 8, requires_grad=True).to('xla') + k = torch.randn(8, 2, 128, 8, requires_grad=True).to('xla') + v = torch.randn(8, 2, 128, 8, requires_grad=True).to('xla') + zeros = torch.zeros(8, 32).to('xla') segment_ids = torch.cat([zeros, zeros + 1, zeros + 2, zeros + 3], dim=1) q.retain_grad() k.retain_grad() @@ -310,12 +310,12 @@ def test_cross_flash_attention_wrapper_segment_ids_spmd(self): q_segment_ids = torch.ones(8, q.shape[2], dtype=torch.float32) # only shard data dimension o = flash_attention( - q.to("xla"), - k.to("xla"), - v.to("xla"), + q.to('xla'), + k.to('xla'), + v.to('xla'), False, - q_segment_ids.to("xla"), - kv_segment_ids.to("xla"), + q_segment_ids.to('xla'), + kv_segment_ids.to('xla'), partition_spec=("data", None, None, None)) n_devices = xr.global_runtime_device_count() dev_ids = ','.join(map(str, range(n_devices))) @@ -349,12 +349,12 @@ def test_cross_flash_attention_backward_segment_ids_spmd(self): xs.set_global_mesh(xs.get_1d_mesh("data")) torch.manual_seed(42) - q = torch.randn(8, 2, 1024, 4, requires_grad=True).to("xla") - k = torch.randn(8, 2, 128, 4, requires_grad=True).to("xla") - v = torch.randn(8, 2, 128, 4, requires_grad=True).to("xla") - zeros = torch.zeros(8, 32).to("xla") + q = torch.randn(8, 2, 1024, 4, requires_grad=True).to('xla') + k = torch.randn(8, 2, 128, 4, requires_grad=True).to('xla') + v = torch.randn(8, 2, 128, 4, requires_grad=True).to('xla') + zeros = torch.zeros(8, 32).to('xla') kv_segment_ids = torch.cat([zeros, zeros + 1, zeros + 2, zeros + 3], dim=1) - q_segment_ids = torch.ones(8, q.shape[2], dtype=torch.float32).to("xla") + q_segment_ids = torch.ones(8, q.shape[2], dtype=torch.float32).to('xla') q.retain_grad() k.retain_grad() v.retain_grad() @@ -388,12 +388,12 @@ def test_cross_flash_attention_backward_segment_ids_spmd(self): torch_xla.sync() torch.manual_seed(42) - q = torch.randn(8, 2, 1024, 4, requires_grad=True).to("xla") - k = torch.randn(8, 2, 128, 4, requires_grad=True).to("xla") - v = torch.randn(8, 2, 128, 4, requires_grad=True).to("xla") - zeros = torch.zeros(8, 32).to("xla") + q = torch.randn(8, 2, 1024, 4, requires_grad=True).to('xla') + k = torch.randn(8, 2, 128, 4, requires_grad=True).to('xla') + v = torch.randn(8, 2, 128, 4, requires_grad=True).to('xla') + zeros = torch.zeros(8, 32).to('xla') kv_segment_ids = torch.cat([zeros, zeros + 1, zeros + 2, zeros + 3], dim=1) - q_segment_ids = torch.ones(8, q.shape[2], dtype=torch.float32).to("xla") + q_segment_ids = torch.ones(8, q.shape[2], dtype=torch.float32).to('xla') q.retain_grad() k.retain_grad() v.retain_grad() @@ -449,15 +449,15 @@ def flash_attention_wrapper(q, k, v, casual, q_segment_ids, kv_segment_ids, flash_attention_wrapper, fw_compiler=compiler) torch.manual_seed(42) - q = torch.randn(8, 2, 128, 8, requires_grad=True).to("xla") - k = torch.randn(8, 2, 128, 8, requires_grad=True).to("xla") - v = torch.randn(8, 2, 128, 8, requires_grad=True).to("xla") + q = torch.randn(8, 2, 128, 8, requires_grad=True).to('xla') + k = torch.randn(8, 2, 128, 8, requires_grad=True).to('xla') + v = torch.randn(8, 2, 128, 8, requires_grad=True).to('xla') q.retain_grad() k.retain_grad() v.retain_grad() B, N, SEQ, H = q.size() - mask = (torch.rand(8, 2, 128, 128) > 0.5).to("xla") - ab = torch.ones(8, 2, 128, 128).to("xla") + mask = (torch.rand(8, 2, 128, 128) > 0.5).to('xla') + ab = torch.ones(8, 2, 128, 128).to('xla') ab = ab.masked_fill(mask, torch.finfo(ab.dtype).min).requires_grad_() ab.retain_grad() @@ -476,13 +476,13 @@ def flash_attention_wrapper(q, k, v, casual, q_segment_ids, kv_segment_ids, ab_grad = ab.grad torch.manual_seed(42) - q = torch.randn(8, 2, 128, 8, requires_grad=True).to("xla") - k = torch.randn(8, 2, 128, 8, requires_grad=True).to("xla") - v = torch.randn(8, 2, 128, 8, requires_grad=True).to("xla") + q = torch.randn(8, 2, 128, 8, requires_grad=True).to('xla') + k = torch.randn(8, 2, 128, 8, requires_grad=True).to('xla') + v = torch.randn(8, 2, 128, 8, requires_grad=True).to('xla') q.retain_grad() k.retain_grad() v.retain_grad() - ab = torch.ones(8, 2, 128, 128).to("xla") + ab = torch.ones(8, 2, 128, 128).to('xla') ab = ab.masked_fill(mask, torch.finfo(ab.dtype).min).requires_grad_() ab.retain_grad() diff --git a/test/test_splash_attention.py b/test/test_splash_attention.py index 6e8bb56fe2c3..3b0ccc2122c1 100644 --- a/test/test_splash_attention.py +++ b/test/test_splash_attention.py @@ -79,19 +79,19 @@ def ab_comparsion_input_generation(self): self.NUM_Q_HEADS, self.SEQ_LEN, self.HEAD_DIM, - ).to("xla").requires_grad_() + ).to('xla').requires_grad_() k = torch.randn( self.BATCH_SIZE, self.NUM_KV_HEADS, self.SEQ_LEN, self.HEAD_DIM, - ).to("xla").requires_grad_() + ).to('xla').requires_grad_() v = torch.randn( self.BATCH_SIZE, self.NUM_KV_HEADS, self.SEQ_LEN, self.HEAD_DIM, - ).to("xla").requires_grad_() + ).to('xla').requires_grad_() q_sa = q.clone().detach().requires_grad_() k_sa = k.clone().detach().requires_grad_() v_sa = v.clone().detach().requires_grad_() @@ -118,7 +118,7 @@ def _attention(self, q, k, v, *, attn_mask=None, ab=None): def test_splash_attention_base(self): q, k, v, q_sa, k_sa, v_sa = self.ab_comparsion_input_generation() attention_mask = torch.triu( - torch.ones(self.SEQ_LEN, self.SEQ_LEN), diagonal=1).to("xla") + torch.ones(self.SEQ_LEN, self.SEQ_LEN), diagonal=1).to('xla') o = self._attention(q, k, v, attn_mask=attention_mask) torch_xla.sync() loss = torch.sum(o) @@ -149,13 +149,13 @@ def test_splash_attention_sharding(self): n_devices = xr.global_runtime_device_count() q = ( torch.randn(self.BATCH_SIZE, self.NUM_HEADS, self.SEQ_LEN, - self.HEAD_DIM).requires_grad_().to("xla")) + self.HEAD_DIM).requires_grad_().to('xla')) k = ( torch.randn(self.BATCH_SIZE, self.NUM_HEADS, self.SEQ_LEN, - self.HEAD_DIM).requires_grad_().to("xla")) + self.HEAD_DIM).requires_grad_().to('xla')) v = ( torch.randn(self.BATCH_SIZE, self.NUM_HEADS, self.SEQ_LEN, - self.HEAD_DIM).requires_grad_().to("xla")) + self.HEAD_DIM).requires_grad_().to('xla')) o = splash_attention(q, k, v, self.config.to_json()) torch_xla.sync() self.assertEqual( @@ -168,7 +168,7 @@ def test_splash_attention_sharding(self): @with_jax_high_precision def test_splash_attention_segment_id(self): q, k, v, q_sa, k_sa, v_sa = self.ab_comparsion_input_generation() - zeros = torch.zeros(self.BATCH_SIZE, self.SEQ_LEN // 4).to("xla") + zeros = torch.zeros(self.BATCH_SIZE, self.SEQ_LEN // 4).to('xla') segment_ids = torch.cat([zeros, zeros + 1, zeros + 2, zeros + 3], dim=1) segment_ids_sa = segment_ids.clone().detach() @@ -225,7 +225,7 @@ def compiler(gm, _): splash_attention, fw_compiler=compiler) attention_mask = torch.triu( - torch.ones(self.SEQ_LEN, self.SEQ_LEN), diagonal=1).to("xla") + torch.ones(self.SEQ_LEN, self.SEQ_LEN), diagonal=1).to('xla') o = self._attention(q, kk, vv, attn_mask=attention_mask) torch_xla.sync() loss = torch.sum(o) @@ -255,14 +255,14 @@ def test_splash_attention_cache_hit(self): starting_cache_misses = xb._jax_to_xla_computation_cache_elements() q = ( torch.randn(self.BATCH_SIZE, self.NUM_HEADS, self.SEQ_LEN, - self.HEAD_DIM).requires_grad_().to("xla")) + self.HEAD_DIM).requires_grad_().to('xla')) k = ( torch.randn(self.BATCH_SIZE, self.NUM_HEADS, self.SEQ_LEN, - self.HEAD_DIM).requires_grad_().to("xla")) + self.HEAD_DIM).requires_grad_().to('xla')) v = ( torch.randn(self.BATCH_SIZE, self.NUM_HEADS, self.SEQ_LEN, - self.HEAD_DIM).requires_grad_().to("xla")) - segment_ids = torch.zeros(self.BATCH_SIZE, self.SEQ_LEN).to("xla") + self.HEAD_DIM).requires_grad_().to('xla')) + segment_ids = torch.zeros(self.BATCH_SIZE, self.SEQ_LEN).to('xla') for i in range(self.BATCH_SIZE): segment_ids[i, :] = i diff --git a/test/test_triton.py b/test/test_triton.py index c2600526363a..f69def68c86d 100644 --- a/test/test_triton.py +++ b/test/test_triton.py @@ -248,8 +248,8 @@ class TritonTest(unittest.TestCase): def test_gpu_custom_call_triton_add(self): size = 16 - x = torch.arange(size, dtype=torch.int64).to("xla") - y = torch.arange(size, dtype=torch.int64).to("xla") + x = torch.arange(size, dtype=torch.int64).to('xla') + y = torch.arange(size, dtype=torch.int64).to('xla') output = torch.empty_like(x) block_size = 8 grid = (triton.cdiv(size, block_size),) diff --git a/torch_xla/_patched_functions.py b/torch_xla/_patched_functions.py index 0cf3fcb97c4a..bb3102bf2e2b 100644 --- a/torch_xla/_patched_functions.py +++ b/torch_xla/_patched_functions.py @@ -38,7 +38,7 @@ def clip_grad_norm_(parameters: _tensor_or_tensors, max_norm = float(max_norm) norm_type = float(norm_type) if len(parameters) == 0: - return torch.tensor(0.).to("xla") + return torch.tensor(0.).to('xla') dtype = parameters[0].grad.dtype device = parameters[0].grad.device if norm_type == inf: diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py index 6a12b5b1944e..0733b990c787 100644 --- a/torch_xla/experimental/custom_kernel.py +++ b/torch_xla/experimental/custom_kernel.py @@ -1041,7 +1041,7 @@ def ragged_paged_attention( ], ) - seq_buf_idx = torch.tensor([0, 0], dtype=torch.int32).to("xla") + seq_buf_idx = torch.tensor([0, 0], dtype=torch.int32).to('xla') output = torch_xla._XLAC._xla_tpu_custom_call( [ kv_lens, @@ -1174,8 +1174,8 @@ def multi_queries_paged_attention( q_dtype_for_kernel_launch = q.dtype page_indices_reshaped = page_indices.reshape(-1) - buffer_index = torch.zeros((1,), dtype=torch.int32).to("xla") - step = torch.zeros((1,), dtype=torch.int32).to("xla") + buffer_index = torch.zeros((1,), dtype=torch.int32).to('xla') + step = torch.zeros((1,), dtype=torch.int32).to('xla') q = q.permute(0, 2, 1, 3) MIN_BLOCK_SIZE = 128 output_shape = torch.Size(list(q.shape[:-1]) + [MIN_BLOCK_SIZE]) @@ -1236,8 +1236,8 @@ def paged_attention(q, q_dtype_for_kernel_launch = torch.float32 page_indices_reshaped = page_indices.reshape(-1) - buffer_index = torch.zeros((1,), dtype=torch.int32).to("xla") - step = torch.ones((1,), dtype=torch.int32).to("xla") + buffer_index = torch.zeros((1,), dtype=torch.int32).to('xla') + step = torch.ones((1,), dtype=torch.int32).to('xla') output_shape = torch.Size(list(q.shape[:-1]) + [1]) output, _, _ = torch_xla._XLAC._xla_tpu_custom_call( From e4a181777b69e77cc3f2d216222a9ee8c56c18bb Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Wed, 28 May 2025 02:50:19 +0000 Subject: [PATCH 6/9] Fix error --- torch_xla/torch_xla.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/torch_xla/torch_xla.py b/torch_xla/torch_xla.py index 3b2b327ff5c9..9062d6a9ef21 100644 --- a/torch_xla/torch_xla.py +++ b/torch_xla/torch_xla.py @@ -36,14 +36,15 @@ def device(index: int = None) -> torch.device: torch_xla._XLAC._xla_set_default_device(device) return torch.device(device) - if n is None: + if index is None: return torch.device(torch_xla._XLAC._xla_get_default_device()) devices = xm.get_xla_supported_devices() - if n > len(devices): - raise IndexError('Device index {} out of range in {}'.format(n, devices)) + if index > len(devices): + raise IndexError('Device index {} out of range in {}'.format( + index, devices)) - device = devices[n] + device = devices[index] torch_xla._XLAC._xla_set_default_device(device) return torch.device(device) From 637ead608db99d1d3a3f87229f24f0007608ff63 Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Wed, 28 May 2025 04:24:12 +0000 Subject: [PATCH 7/9] Import missing packages --- examples/train_resnet_amp.py | 1 + test/ds/test_dynamic_shapes.py | 1 + test/pjrt/test_ddp.py | 1 + test/pjrt/test_dtypes.py | 1 + test/pjrt/test_profiler.py | 1 + test/stablehlo/test_composite.py | 1 + test/stablehlo/test_pt2e_qdq.py | 1 + test/stablehlo/test_unbounded_dynamism.py | 1 + torch_xla/distributed/spmd/api.py | 1 + 9 files changed, 9 insertions(+) diff --git a/examples/train_resnet_amp.py b/examples/train_resnet_amp.py index 8082d01524e9..65ac9df208f8 100644 --- a/examples/train_resnet_amp.py +++ b/examples/train_resnet_amp.py @@ -2,6 +2,7 @@ import itertools +import torch_xla import torch_xla.distributed.xla_multiprocessing as xmp import torch_xla.core.xla_model as xm from torch_xla.amp import autocast diff --git a/test/ds/test_dynamic_shapes.py b/test/ds/test_dynamic_shapes.py index 57119ac88d6f..adfdad76b21a 100644 --- a/test/ds/test_dynamic_shapes.py +++ b/test/ds/test_dynamic_shapes.py @@ -2,6 +2,7 @@ import sys import unittest import torch, torch_xla +import torch_xla import torch_xla.core.xla_model as xm import torch_xla.debug.metrics as met diff --git a/test/pjrt/test_ddp.py b/test/pjrt/test_ddp.py index d236b8e11ea1..d93bbe45c4d9 100644 --- a/test/pjrt/test_ddp.py +++ b/test/pjrt/test_ddp.py @@ -5,6 +5,7 @@ import torch.distributed as dist import torch.nn as nn from torch.nn.parallel import DistributedDataParallel as DDP +import torch_xla import torch_xla.core.xla_model as xm import torch_xla.distributed.xla_backend from torch_xla import runtime as xr diff --git a/test/pjrt/test_dtypes.py b/test/pjrt/test_dtypes.py index dd6a4344c94b..5873c189c6c5 100644 --- a/test/pjrt/test_dtypes.py +++ b/test/pjrt/test_dtypes.py @@ -1,5 +1,6 @@ from absl.testing import absltest, parameterized import torch +import torch_xla import torch_xla.core.xla_model as xm import torch_xla.runtime as xr diff --git a/test/pjrt/test_profiler.py b/test/pjrt/test_profiler.py index 15e799473b3d..17892261119a 100644 --- a/test/pjrt/test_profiler.py +++ b/test/pjrt/test_profiler.py @@ -7,6 +7,7 @@ from absl.testing import absltest import torch +import torch_xla import torch_xla.core.xla_model as xm import torch_xla.debug.profiler as xp import torch_xla.runtime as xr diff --git a/test/stablehlo/test_composite.py b/test/stablehlo/test_composite.py index 6e9521c79d5c..8fe211475ba1 100644 --- a/test/stablehlo/test_composite.py +++ b/test/stablehlo/test_composite.py @@ -5,6 +5,7 @@ import torch import torch.nn.functional as F +import torch_xla import torch_xla.core.xla_model as xm import torch_xla.experimental.xla_marker from torch.utils import _pytree as pytree diff --git a/test/stablehlo/test_pt2e_qdq.py b/test/stablehlo/test_pt2e_qdq.py index 3fc1276ec612..34426f978029 100644 --- a/test/stablehlo/test_pt2e_qdq.py +++ b/test/stablehlo/test_pt2e_qdq.py @@ -4,6 +4,7 @@ from typing import Callable, Dict, List import torch +import torch_xla import torch_xla.core.xla_model as xm import torchvision from torch.export import export_for_training diff --git a/test/stablehlo/test_unbounded_dynamism.py b/test/stablehlo/test_unbounded_dynamism.py index aa33a6533437..dc61fc9849be 100644 --- a/test/stablehlo/test_unbounded_dynamism.py +++ b/test/stablehlo/test_unbounded_dynamism.py @@ -6,6 +6,7 @@ import numpy as np import torch +import torch_xla import torch_xla.core.xla_model as xm from torch.export import Dim, export from torch_xla.stablehlo import exported_program_to_stablehlo diff --git a/torch_xla/distributed/spmd/api.py b/torch_xla/distributed/spmd/api.py index 77ff9e9ac6ee..c1bb268cfed4 100644 --- a/torch_xla/distributed/spmd/api.py +++ b/torch_xla/distributed/spmd/api.py @@ -11,6 +11,7 @@ from torch.distributed import DeviceMesh from torch.distributed.tensor.placement_types import Placement, Replicate +import torch_xla import torch_xla.core.xla_model as xm # type:ignore[import] # noqa: F401 import torch_xla.runtime as xr # type:ignore[import] from torch_xla.distributed.spmd import ( # type:ignore[import] From 9a0b1edba9c8bc9d3159dea01ed6ed3a14e543dc Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Wed, 28 May 2025 15:26:10 +0000 Subject: [PATCH 8/9] Fix test errors --- test/scan/test_scan.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/test/scan/test_scan.py b/test/scan/test_scan.py index cbc2778fc679..da79e9c76dce 100644 --- a/test/scan/test_scan.py +++ b/test/scan/test_scan.py @@ -273,7 +273,7 @@ def test_scan_external_in_place_mutation(self): giving wrong results. """ # TODO(yifeit): Modify this test when external in-place mutation is eventually supported. - weird_global = torch.tensor([0.0, 0.0], device='xla') + weird_global = torch.tensor([0.0, 0.0], device=torch_xla.device()) def step_fn(carry, x): new_carry = carry + x @@ -281,8 +281,9 @@ def step_fn(carry, x): y = new_carry + weird_global return new_carry, y - init = torch.tensor([0.0, 0.0], device='xla') - xs = torch.tensor([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], device='xla') + init = torch.tensor([0.0, 0.0], device=torch_xla.device()) + xs = torch.tensor([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], + device=torch_xla.device()) with self.assertRaisesRegex(AssertionError, "FakeTensor"): scan(step_fn, init, xs) @@ -350,11 +351,13 @@ def test_scan_rand_in_fn(self): def step_fn(carry, x): new_carry = carry + x - y = new_carry + torch.rand(2, device='xla') + # TODO: figure out why device='xla' doesn't work + y = new_carry + torch.rand(2, device=torch_xla.device()) return new_carry, y - init = torch.tensor([0.0, 0.0], device='xla') - xs = torch.tensor([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], device='xla') + init = torch.tensor([0.0, 0.0], device=torch_xla.device()) + xs = torch.tensor([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], + device=torch_xla.device()) _, ys = scan(step_fn, init, xs) # ys should be a 2D tensor with this shape. self.assertEqual(ys.shape, (3, 2)) From df6d66435763708440cbfd25d903f8f76952afa3 Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Mon, 2 Jun 2025 19:23:35 +0000 Subject: [PATCH 9/9] Address comments --- test/stablehlo/test_unbounded_dynamism.py | 1 - torch_xla/experimental/scan_layers.py | 1 - 2 files changed, 2 deletions(-) diff --git a/test/stablehlo/test_unbounded_dynamism.py b/test/stablehlo/test_unbounded_dynamism.py index dc61fc9849be..88fce368b668 100644 --- a/test/stablehlo/test_unbounded_dynamism.py +++ b/test/stablehlo/test_unbounded_dynamism.py @@ -7,7 +7,6 @@ import numpy as np import torch import torch_xla -import torch_xla.core.xla_model as xm from torch.export import Dim, export from torch_xla.stablehlo import exported_program_to_stablehlo diff --git a/torch_xla/experimental/scan_layers.py b/torch_xla/experimental/scan_layers.py index 0be37b363909..cfc5de5d1ed1 100644 --- a/torch_xla/experimental/scan_layers.py +++ b/torch_xla/experimental/scan_layers.py @@ -47,7 +47,6 @@ def scan_layers(layers: Iterable[torch.nn.Module], Example: - >>> import torch_xla.core.xla_model as xm >>> import torch >>> import torch.nn as nn >>> from torch_xla.experimental.scan_layers import scan_layers