From 650f114381ffd6c4e6d678e9f0c54f5c49ef3de6 Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Fri, 23 May 2025 19:28:10 +0000 Subject: [PATCH 1/3] Migrate `torch_xla.device()` to `torch.device('xla')` --- API_GUIDE.md | 10 +- benchmarks/experiment_runner.py | 2 +- ...ributed-pytorch-xla-basics-with-pjrt.ipynb | 6 +- docs/source/learn/_pjrt.md | 2 +- docs/source/learn/eager.md | 4 +- docs/source/learn/pytorch-on-xla-devices.md | 10 +- docs/source/learn/xla-overview.md | 8 +- docs/source/perf/amp.md | 10 +- docs/source/perf/ddp.md | 2 +- docs/source/perf/dynamo.md | 4 +- docs/source/perf/fori_loop.md | 4 +- docs/source/perf/quantized_ops.md | 2 +- examples/train_decoder_only_base.py | 2 +- examples/train_resnet_amp.py | 2 +- examples/train_resnet_base.py | 2 +- plugins/cpu/README.md | 2 +- plugins/cuda/README.md | 2 +- test/bench.py | 2 +- test/debug_tool/test_mp_pt_xla_debug.py | 2 +- test/debug_tool/test_pt_xla_debug.py | 18 +- test/distributed_util.py | 2 +- test/ds/test_dynamic_shape_models.py | 2 +- test/ds/test_dynamic_shapes.py | 4 +- test/dynamo/test_bridge.py | 14 +- test/dynamo/test_dynamo.py | 24 +- test/dynamo/test_dynamo_aliasing.py | 16 +- test/dynamo/test_dynamo_config.py | 2 +- test/dynamo/test_dynamo_dynamic_shape.py | 16 +- test/dynamo/test_dynamo_graph_dump.py | 2 +- test/dynamo/test_dynamo_integrations_util.py | 12 +- test/dynamo/test_graph_input_matcher.py | 2 +- test/dynamo/test_num_output.py | 2 +- test/dynamo/test_traceable_collectives.py | 2 +- test/eager/test_eager.py | 14 +- test/eager/test_eager_all_reduce_in_place.py | 2 +- test/eager/test_eager_spmd.py | 4 +- test/eager/test_eager_with_torch_compile.py | 4 +- test/eager/test_eager_with_xla_compile.py | 6 +- test/pjrt/test_collective_ops_tpu.py | 20 +- test/pjrt/test_ddp.py | 2 +- test/pjrt/test_profiler.py | 4 +- test/pjrt/test_runtime_multi_cpu.py | 2 +- test/pjrt/test_runtime_multi_gpu.py | 266 ++++++++++++++++++ test/pjrt/test_runtime_tpu.py | 8 +- test/pjrt/test_train_hf_transformer.py | 2 +- test/pytorch_test_base.py | 4 +- test/quantized_ops/test_dot_general.py | 2 +- test/quantized_ops/test_quantized_matmul.py | 2 +- test/scan/test_scan.py | 14 +- test/scan/test_scan_layers.py | 2 +- test/scan/test_scan_pallas.py | 2 +- test/scan/test_scan_spmd.py | 2 +- test/spmd/test_dynamo_spmd.py | 16 +- test/spmd/test_mp_input_sharding.py | 12 +- test/spmd/test_sharding_strategies.py | 2 +- test/spmd/test_spmd_debugging.py | 6 +- test/spmd/test_spmd_graph_dump.py | 2 +- test/spmd/test_spmd_lowering_context.py | 4 +- test/spmd/test_spmd_parameter_wrapping.py | 4 +- test/spmd/test_train_spmd_imagenet.py | 2 +- test/spmd/test_xla_distributed_checkpoint.py | 2 +- test/spmd/test_xla_sharding.py | 32 +-- .../test_xla_spmd_python_api_interaction.py | 10 +- test/spmd/test_xla_virtual_device.py | 12 +- test/stablehlo/test_composite.py | 2 +- test/stablehlo/test_implicit_broadcasting.py | 2 +- test/stablehlo/test_pt2e_qdq.py | 4 +- test/stablehlo/test_stablehlo_compile.py | 2 +- test/stablehlo/test_stablehlo_custom_call.py | 6 +- test/stablehlo/test_stablehlo_inference.py | 4 +- test/stablehlo/test_stablehlo_save_load.py | 6 +- test/stablehlo/test_unbounded_dynamism.py | 2 +- test/stablehlo/test_xla_export_interpreter.py | 2 +- test/test_autocast.py | 10 +- test/test_autocast_xla.py | 2 +- test/test_compilation_cache_utils.py | 2 +- test/test_core_aten_ops.py | 2 +- test/test_data_type.py | 2 +- test/test_env_var_mapper.py | 2 +- test/test_fp8.py | 2 +- test/test_fsdp_auto_wrap.py | 4 +- test/test_grad_checkpoint.py | 2 +- test/test_gradient_accumulation.py | 2 +- test/test_inplace_update.py | 10 +- test/test_input_output_aliases.py | 34 +-- test/test_jax_interop.py | 22 +- test/test_metrics.py | 20 +- test/test_mp_all_gather.py | 2 +- test/test_mp_all_to_all.py | 2 +- test/test_mp_collective_matmul.py | 2 +- test/test_mp_collective_permute.py | 2 +- test/test_mp_distributed_mm.py | 2 +- test/test_mp_early_exit.py | 2 +- test/test_mp_reduce_scatter.py | 2 +- test/test_mp_replication.py | 2 +- test/test_mp_save.py | 2 +- test/test_mp_sync_batch_norm.py | 8 +- test/test_operations.py | 160 +++++------ test/test_placeholder.py | 4 +- test/test_profile_mp_mnist.py | 2 +- test/test_python_ops.py | 6 +- test/test_syncfree_optimizers.py | 2 +- ...st_torch_distributed_fsdp_frozen_weight.py | 2 +- test/test_torch_distributed_xla_backend.py | 26 +- test/test_train_mp_imagenet.py | 2 +- test/test_train_mp_imagenet_amp.py | 4 +- test/test_train_mp_imagenet_fsdp.py | 2 +- test/test_train_mp_mnist.py | 2 +- test/test_train_mp_mnist_amp.py | 2 +- test/test_train_mp_mnist_fsdp_with_ckpt.py | 2 +- test/test_train_mp_mnist_zero1.py | 2 +- test/test_user_computation_debug_cache.py | 2 +- test/test_utils.py | 2 +- test/test_while_loop.py | 8 +- test/test_zero1.py | 6 +- test/torch_distributed/test_ddp.py | 2 +- ...orch_distributed_all_gather_xla_backend.py | 2 +- ...orch_distributed_all_reduce_xla_backend.py | 2 +- ...ributed_bucketed_all_reduce_xla_backend.py | 2 +- .../test_torch_distributed_fsdp_meta.py | 2 +- ...istributed_multi_all_reduce_xla_backend.py | 2 +- ..._distributed_reduce_scatter_xla_backend.py | 2 +- test/utils/train_spmd_linear_model.py | 2 +- .../utils/train_spmd_linear_model_grad_acc.py | 2 +- torch_xla/_dynamo/dynamo_backend2.py | 2 +- torch_xla/_dynamo/dynamo_bridge.py | 8 +- torch_xla/_internal/pjrt.py | 4 +- torch_xla/core/xla_op_registry.py | 2 +- torch_xla/csrc/init_python_bindings.cpp | 2 +- .../fsdp/xla_fully_sharded_data_parallel.py | 6 +- torch_xla/distributed/parallel_loader.py | 2 +- torch_xla/distributed/spmd/api.py | 2 +- torch_xla/distributed/spmd/xla_sharding.py | 4 +- torch_xla/distributed/xla_multiprocessing.py | 2 +- torch_xla/experimental/fori_loop.py | 2 +- .../experimental/gradient_accumulation.py | 2 +- torch_xla/experimental/scan.py | 4 +- torch_xla/experimental/scan_layers.py | 2 +- torch_xla/runtime.py | 2 +- torch_xla/stablehlo.py | 4 +- 140 files changed, 701 insertions(+), 437 deletions(-) create mode 100644 test/pjrt/test_runtime_multi_gpu.py diff --git a/API_GUIDE.md b/API_GUIDE.md index f2e9fc1cc2dd..cd2e1f2fd5f3 100644 --- a/API_GUIDE.md +++ b/API_GUIDE.md @@ -22,7 +22,7 @@ print(t) This code should look familiar. PyTorch/XLA uses the same interface as regular PyTorch with a few additions. Importing `torch_xla` initializes PyTorch/XLA, and -`torch_xla.device()` returns the current XLA device. This may be a CPU or TPU +`torch.device('xla')` returns the current XLA device. This may be a CPU or TPU depending on your environment. ## XLA Tensors are PyTorch Tensors @@ -112,7 +112,7 @@ train_loader = xu.SampleGenerator( torch.zeros(batch_size, dtype=torch.int64)), sample_count=60000 // batch_size // xr.world_size()) -device = torch_xla.device() # Get the XLA device (TPU). +device = torch.device('xla') # Get the XLA device (TPU). model = MNIST().train().to(device) # Create a model and move it to the device. loss_fn = nn.NLLLoss() optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5) @@ -169,7 +169,7 @@ def _mp_fn(index): index: Index of the process. """ - device = torch_xla.device() # Get the device assigned to this process. + device = torch.device('xla') # Get the device assigned to this process. # Wrap the loader for multi-device. mp_device_loader = pl.MpDeviceLoader(train_loader, device) @@ -197,7 +197,7 @@ single device snippet. Let's go over then one by one. - `torch_xla.launch()` - Creates the processes that each run an XLA device. - This function is a wrapper of multithreading spawn to allow user run the script with torchrun command line also. Each process will only be able to access the device assigned to the current process. For example on a TPU v4-8, there will be 4 processes being spawn up and each process will own a TPU device. - - Note that if you print the `torch_xla.device()` on each process you will see `xla:0` on all devices. This is because each process can only see one device. This does not mean multi-process is not functioning. The only exeption is with PJRT runtime on TPU v2 and TPU v3 since there will be `#devices/2` processes and each process will have 2 threads (check this [doc](https://github.com/pytorch/xla/blob/master/docs/pjrt.md#tpus-v2v3-vs-v4) for more details). + - Note that if you print the `torch.device('xla')` on each process you will see `xla:0` on all devices. This is because each process can only see one device. This does not mean multi-process is not functioning. The only exeption is with PJRT runtime on TPU v2 and TPU v3 since there will be `#devices/2` processes and each process will have 2 threads (check this [doc](https://github.com/pytorch/xla/blob/master/docs/pjrt.md#tpus-v2v3-vs-v4) for more details). - `MpDeviceLoader` - Loads the training data onto each device. - `MpDeviceLoader` can wrap on a torch dataloader. It can preload the data to the device and overlap the dataloading with device execution to improve the performance. @@ -290,7 +290,7 @@ import torch import torch_xla import torch_xla.core.xla_model as xm -device = torch_xla.device() +device = torch.device('xla') t0 = torch.randn(2, 2, device=device) t1 = torch.randn(2, 2, device=device) diff --git a/benchmarks/experiment_runner.py b/benchmarks/experiment_runner.py index b784af68e47b..9c5867cbfc9a 100644 --- a/benchmarks/experiment_runner.py +++ b/benchmarks/experiment_runner.py @@ -255,7 +255,7 @@ def _default_iter_fn(self, benchmark_experiment: BenchmarkExperiment, def _pure_wall_time_iter_fn(self, benchmark_experiment: BenchmarkExperiment, benchmark_model: BenchmarkModel, input_tensor): - device = torch_xla.device() if benchmark_experiment.xla else 'cuda' + device = torch.device('xla') if benchmark_experiment.xla else 'cuda' sync_fn = xm.wait_device_ops if benchmark_experiment.xla else torch.cuda.synchronize timing, output = bench.do_bench( lambda: benchmark_model.model_iter_fn( diff --git a/contrib/kaggle/distributed-pytorch-xla-basics-with-pjrt.ipynb b/contrib/kaggle/distributed-pytorch-xla-basics-with-pjrt.ipynb index 8d4fbd95bff7..c829c4b9a36f 100644 --- a/contrib/kaggle/distributed-pytorch-xla-basics-with-pjrt.ipynb +++ b/contrib/kaggle/distributed-pytorch-xla-basics-with-pjrt.ipynb @@ -193,7 +193,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": { "execution": { "iopub.execute_input": "2024-01-10T19:30:28.607393Z", @@ -210,7 +210,7 @@ "lock = mp.Manager().Lock()\n", "\n", "def print_device(i, lock):\n", - " device = torch_xla.device()\n", + " device = torch.device('xla')\n", " with lock:\n", " print('process', i, device)" ] @@ -454,7 +454,7 @@ "import torch_xla.experimental.pjrt_backend # Required for torch.distributed on TPU v2 and v3\n", "\n", "def toy_model(index, lock):\n", - " device = torch_xla.device()\n", + " device = torch.device('xla')\n", " dist.init_process_group('xla', init_method='xla://')\n", "\n", " # Initialize a basic toy model\n", diff --git a/docs/source/learn/_pjrt.md b/docs/source/learn/_pjrt.md index edaa56ecee72..38cd322e7940 100644 --- a/docs/source/learn/_pjrt.md +++ b/docs/source/learn/_pjrt.md @@ -73,7 +73,7 @@ import torch_xla.distributed.xla_backend def _mp_fn(index): - device = torch_xla.device() + device = torch.device('xla') - dist.init_process_group('xla', rank=xr.global_ordinal(), world_size=xr.world_size()) + dist.init_process_group('xla', init_method='xla://') diff --git a/docs/source/learn/eager.md b/docs/source/learn/eager.md index 0d82ae3c581c..cbf54d3a6c32 100644 --- a/docs/source/learn/eager.md +++ b/docs/source/learn/eager.md @@ -13,7 +13,7 @@ import torch import torch_xla import torchvision -device = torch_xla.device() +device = torch.device('xla') model = torchvision.models.resnet18().to(device) input = torch.randn(64, 3, 224, 224).to(device) @@ -71,7 +71,7 @@ import torchvision # Run ops eagerly by default torch_xla.experimental.eager_mode(True) -device = torch_xla.device() +device = torch.device('xla') model = torchvision.models.resnet18().to(device) # Mark the function to be compiled diff --git a/docs/source/learn/pytorch-on-xla-devices.md b/docs/source/learn/pytorch-on-xla-devices.md index c0b48bec1813..0be3ce038e5f 100644 --- a/docs/source/learn/pytorch-on-xla-devices.md +++ b/docs/source/learn/pytorch-on-xla-devices.md @@ -21,7 +21,7 @@ print(t) This code should look familiar. PyTorch/XLA uses the same interface as regular PyTorch with a few additions. Importing `torch_xla` initializes -PyTorch/XLA, and `torch_xla.device()` returns the current XLA device. This +PyTorch/XLA, and `torch.device('xla')` returns the current XLA device. This may be a CPU or TPU depending on your environment. ## XLA Tensors are PyTorch Tensors @@ -81,7 +81,7 @@ The following snippet shows a network training on a single XLA device: ``` python import torch_xla.core.xla_model as xm -device = torch_xla.device() +device = torch.device('xla') model = MNIST().train().to(device) loss_fn = nn.NLLLoss() optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum) @@ -120,7 +120,7 @@ import torch_xla.core.xla_model as xm import torch_xla.distributed.parallel_loader as pl def _mp_fn(index): - device = torch_xla.device() + device = torch.device('xla') mp_device_loader = pl.MpDeviceLoader(train_loader, device) model = MNIST().train().to(device) @@ -148,7 +148,7 @@ previous single device snippet. Let's go over then one by one. will only be able to access the device assigned to the current process. For example on a TPU v4-8, there will be 4 processes being spawn up and each process will own a TPU device. - - Note that if you print the `torch_xla.device()` on each process you + - Note that if you print the `torch.device('xla')` on each process you will see `xla:0` on all devices. This is because each process can only see one device. This does not mean multi-process is not functioning. The only execution is with PJRT runtime on TPU v2 @@ -283,7 +283,7 @@ import torch import torch_xla import torch_xla.core.xla_model as xm -device = torch_xla.device() +device = torch.device('xla') t0 = torch.randn(2, 2, device=device) t1 = torch.randn(2, 2, device=device) diff --git a/docs/source/learn/xla-overview.md b/docs/source/learn/xla-overview.md index f6b0761fd69a..e74247c2fb88 100644 --- a/docs/source/learn/xla-overview.md +++ b/docs/source/learn/xla-overview.md @@ -184,7 +184,7 @@ repo. contains examples for training and serving many LLM and diffusion models. General guidelines to modify your code: -- Replace `cuda` with `torch_xla.device()` +- Replace `cuda` with `torch.device('xla')` - Remove progress bar, printing that would access the XLA tensor values - Reduce logging and callbacks that would access the XLA tensor values @@ -227,7 +227,7 @@ tutorial, but you can pass the `device` value to the function as well. ``` python import torch_xla.core.xla_model as xm - self.device = torch_xla.device() + self.device = torch.device('xla') ``` Another place in the code that has cuda specific code is DDIM scheduler. @@ -244,7 +244,7 @@ if attr.device != torch.device("cuda"): with ``` python -device = torch_xla.device() +device = torch.device('xla') attr = attr.to(torch.device(device)) ``` @@ -339,7 +339,7 @@ with the following lines: ``` python import torch_xla.core.xla_model as xm -device = torch_xla.device() +device = torch.device('xla') pipe.to(device) ``` diff --git a/docs/source/perf/amp.md b/docs/source/perf/amp.md index 4ad48753d45c..0d0db54f1682 100644 --- a/docs/source/perf/amp.md +++ b/docs/source/perf/amp.md @@ -27,7 +27,7 @@ for input, target in data: optimizer.zero_grad() # Enables autocasting for the forward pass - with autocast(torch_xla.device()): + with autocast(torch.device('xla')): output = model(input) loss = loss_fn(output, target) @@ -36,7 +36,7 @@ for input, target in data: xm.optimizer_step.(optimizer) ``` -`autocast(torch_xla.device())` aliases `torch.autocast('xla')` when the XLA +`autocast(torch.device('xla'))` aliases `torch.autocast('xla')` when the XLA Device is a TPU. Alternatively, if a script is only used with TPUs, then `torch.autocast('xla', dtype=torch.bfloat16)` can be directly used. @@ -115,7 +115,7 @@ for input, target in data: optimizer.zero_grad() # Enables autocasting for the forward pass - with autocast(torch_xla.device()): + with autocast(torch.device('xla')): output = model(input) loss = loss_fn(output, target) @@ -127,12 +127,12 @@ for input, target in data: scaler.update() ``` -`autocast(torch_xla.device())` aliases `torch.cuda.amp.autocast()` when the +`autocast(torch.device('xla'))` aliases `torch.cuda.amp.autocast()` when the XLA Device is a CUDA device (XLA:GPU). Alternatively, if a script is only used with CUDA devices, then `torch.cuda.amp.autocast` can be directly used, but requires `torch` is compiled with `cuda` support for datatype of `torch.bfloat16`. We recommend using -`autocast(torch_xla.device())` on XLA:GPU as it does not require +`autocast(torch.device('xla'))` on XLA:GPU as it does not require `torch.cuda` support for any datatypes, including `torch.bfloat16`. ### AMP for XLA:GPU Best Practices diff --git a/docs/source/perf/ddp.md b/docs/source/perf/ddp.md index efc4071d648d..51067d37044a 100644 --- a/docs/source/perf/ddp.md +++ b/docs/source/perf/ddp.md @@ -105,7 +105,7 @@ def demo_basic(rank): setup(rank, world_size) # create model and move it to XLA device - device = torch_xla.device() + device = torch.device('xla') model = ToyModel().to(device) ddp_model = DDP(model, gradient_as_bucket_view=True) diff --git a/docs/source/perf/dynamo.md b/docs/source/perf/dynamo.md index 090decb77371..2ab3982fe820 100644 --- a/docs/source/perf/dynamo.md +++ b/docs/source/perf/dynamo.md @@ -41,7 +41,7 @@ import torchvision import torch_xla.core.xla_model as xm def eval_model(loader): - device = torch_xla.device() + device = torch.device('xla') xla_resnet18 = torchvision.models.resnet18().to(device) xla_resnet18.eval() dynamo_resnet18 = torch.compile( @@ -129,7 +129,7 @@ def train_model(model, data, target, optimizer): return pred def train_model_main(loader): - device = torch_xla.device() + device = torch.device('xla') xla_resnet18 = torchvision.models.resnet18().to(device) xla_resnet18.train() dynamo_train_model = torch.compile( diff --git a/docs/source/perf/fori_loop.md b/docs/source/perf/fori_loop.md index bfdd2bf318ab..b6ebf57e09a8 100644 --- a/docs/source/perf/fori_loop.md +++ b/docs/source/perf/fori_loop.md @@ -30,7 +30,7 @@ result = while_loop(cond_fn, body_fn, init) >>> from torch._higher_order_ops.while_loop import while_loop >>> import torch_xla.core.xla_model as xm >>> ->>> device = torch_xla.device() +>>> device = torch.device('xla') >>> >>> def cond_fn(iteri, x): ... return iteri > 0 @@ -60,7 +60,7 @@ with similar logic: cumulative plus 1 for ten times: >>> import torch_xla >>> import torch_xla.core.xla_model as xm >>> ->>> device = torch_xla.device() +>>> device = torch.device('xla') >>> >>> init_val = torch.tensor(1, device=device) >>> iteri = torch.tensor(50, device=device) diff --git a/docs/source/perf/quantized_ops.md b/docs/source/perf/quantized_ops.md index 6d44b05e433b..8aa9ed063dc0 100644 --- a/docs/source/perf/quantized_ops.md +++ b/docs/source/perf/quantized_ops.md @@ -48,7 +48,7 @@ scaler = torch.randn((N_OUTPUT_FEATURES,), dtype=torch.bfloat16) # Call with torch CPU tensor (For debugging purpose) matmul_output = torch.ops.xla.quantized_matmul(x, w_int, scaler) -device = torch_xla.device() +device = torch.device('xla') x_xla = x.to(device) w_int_xla = w_int.to(device) scaler_xla = scaler.to(device) diff --git a/examples/train_decoder_only_base.py b/examples/train_decoder_only_base.py index b3b3a33590e9..ae6efea7d079 100644 --- a/examples/train_decoder_only_base.py +++ b/examples/train_decoder_only_base.py @@ -35,7 +35,7 @@ def __init__(self, torch.zeros(self.batch_size, self.seq_len, dtype=torch.int64)), sample_count=self.train_dataset_len // self.batch_size) - self.device = torch_xla.device() + self.device = torch.device('xla') self.train_device_loader = pl.MpDeviceLoader(train_loader, self.device) self.model = decoder_cls(self.config).to(self.device) self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.0001) diff --git a/examples/train_resnet_amp.py b/examples/train_resnet_amp.py index f5ca308bed75..7b0b68a10da2 100644 --- a/examples/train_resnet_amp.py +++ b/examples/train_resnet_amp.py @@ -19,7 +19,7 @@ def train_loop_fn(self, loader, epoch): for step, (data, target) in enumerate(loader): self.optimizer.zero_grad() # Enables autocasting for the forward pass - with autocast(torch_xla.device()): + with autocast(torch.device('xla')): output = self.model(data) loss = self.loss_fn(output, target) # TPU amp uses bf16 hence gradient scaling is not necessary. If runnign with XLA:GPU diff --git a/examples/train_resnet_base.py b/examples/train_resnet_base.py index c4a8890e9be7..59ff180934f1 100644 --- a/examples/train_resnet_base.py +++ b/examples/train_resnet_base.py @@ -28,7 +28,7 @@ def __init__(self): sample_count=self.train_dataset_len // self.batch_size // xr.world_size()) - self.device = torch_xla.device() + self.device = torch.device('xla') self.train_device_loader = pl.MpDeviceLoader(train_loader, self.device) self.model = torchvision.models.resnet50().to(self.device) self.optimizer = optim.SGD(self.model.parameters(), weight_decay=1e-4) diff --git a/plugins/cpu/README.md b/plugins/cpu/README.md index 76c9d0b7c88e..d3771094768d 100644 --- a/plugins/cpu/README.md +++ b/plugins/cpu/README.md @@ -38,5 +38,5 @@ plugins.use_dynamic_plugins() plugins.register_plugin('CPU', torch_xla_cpu_plugin.CpuPlugin()) xr.set_device_type('CPU') -print(torch_xla.device()) +print(torch.device('xla')) ``` diff --git a/plugins/cuda/README.md b/plugins/cuda/README.md index 45a002e06f6c..d3760610046c 100644 --- a/plugins/cuda/README.md +++ b/plugins/cuda/README.md @@ -35,5 +35,5 @@ plugins.use_dynamic_plugins() plugins.register_plugin('CUDA', torch_xla_cuda_plugin.CudaPlugin()) xr.set_device_type('CUDA') -print(torch_xla.device()) +print(torch.device('xla')) ``` diff --git a/test/bench.py b/test/bench.py index e5eff86a34d5..bb68dcda052e 100644 --- a/test/bench.py +++ b/test/bench.py @@ -29,7 +29,7 @@ class BaseBench(object): def __init__(self, args): self.args = args - self.device = torch_xla.device() + self.device = torch.device('xla') self.test_time = xu.getenv_as('BENCH_TEST_TIME', float, 5.0) torch.manual_seed(42) diff --git a/test/debug_tool/test_mp_pt_xla_debug.py b/test/debug_tool/test_mp_pt_xla_debug.py index 785554657b14..baf58cea6dfd 100644 --- a/test/debug_tool/test_mp_pt_xla_debug.py +++ b/test/debug_tool/test_mp_pt_xla_debug.py @@ -16,7 +16,7 @@ def _mp_fn(index): assert False, "This test should be run with PT_XLA_DEBUG_FILE" if index == 0: open(debug_file_name, 'w').close() - device = torch_xla.device() + device = torch.device('xla') t1 = torch.randn(10, 10, device=device) t2 = t1 * 100 torch_xla.sync() diff --git a/test/debug_tool/test_pt_xla_debug.py b/test/debug_tool/test_pt_xla_debug.py index 4ebcb2cd1bb9..54abfb98a3b5 100644 --- a/test/debug_tool/test_pt_xla_debug.py +++ b/test/debug_tool/test_pt_xla_debug.py @@ -31,7 +31,7 @@ def setUpClass(cls): def test_eager_sync(self): with torch_xla.experimental.eager_mode_context(True): - device = torch_xla.device() + device = torch.device('xla') t1 = torch.randn(5, 9, device=device) torch_xla.sync() with open(self.debug_file_name, 'rb') as f: @@ -41,7 +41,7 @@ def test_eager_sync(self): open(self.debug_file_name, 'w').close() def test_user_sync(self): - device = torch_xla.device() + device = torch.device('xla') t1 = torch.randn(2, 2, device=device) torch_xla.sync() with open(self.debug_file_name, 'rb') as f: @@ -79,7 +79,7 @@ def test_user_sync(self): open(self.debug_file_name, 'w').close() def test_step_trace(self): - device = torch_xla.device() + device = torch.device('xla') with xp.StepTrace('train_pt_xla_debug'): t1 = torch.randn(3, 3, device=device) with open(self.debug_file_name, 'rb') as f: @@ -111,7 +111,7 @@ def test_step_trace(self): open(self.debug_file_name, 'w').close() def test_dynamo(self): - device = torch_xla.device() + device = torch.device('xla') t1 = torch.randn(4, 4, device=device) def toy_program(t1): @@ -161,7 +161,7 @@ def toy_program(t1): open(self.debug_file_name, 'w').close() def test_torch_xla_compile(self): - device = torch_xla.device() + device = torch.device('xla') t1 = torch.randn(12, 4, device=device) def toy_program(t1): @@ -209,7 +209,7 @@ def toy_program(t1): open(self.debug_file_name, 'w').close() def test_torch_xla_compile_custom_name(self): - device = torch_xla.device() + device = torch.device('xla') t1 = torch.randn(18, 4, device=device) def toy_program2(t1): @@ -239,7 +239,7 @@ def toy_program2(t1): open(self.debug_file_name, 'w').close() def test_parallel_loader(self): - device = torch_xla.device() + device = torch.device('xla') train_dataset_len = 100 batch_size = 10 @@ -287,7 +287,7 @@ def test_parallel_loader(self): open(self.debug_file_name, 'w').close() def test_print(self): - device = torch_xla.device() + device = torch.device('xla') t1 = torch.randn(5, 5, device=device) print(t1) with open(self.debug_file_name, 'rb') as f: @@ -315,7 +315,7 @@ def test_print(self): open(self.debug_file_name, 'w').close() def test_frame(self): - device = torch_xla.device() + device = torch.device('xla') t1 = torch.randn(6, 6, device=device) torch_xla.sync() with open(self.debug_file_name, 'rb') as f: diff --git a/test/distributed_util.py b/test/distributed_util.py index 85069aaabc82..32f04712575e 100644 --- a/test/distributed_util.py +++ b/test/distributed_util.py @@ -101,7 +101,7 @@ def ddp_correctness(init_method: str = 'env://', dist.init_process_group("xla", init_method=init_method) rank, world_size = dist.get_rank(), dist.get_world_size() - device = torch_xla.device() + device = torch.device('xla') # Module initialization is not thread safe. Force threads to initialize one # at a time with the same seed diff --git a/test/ds/test_dynamic_shape_models.py b/test/ds/test_dynamic_shape_models.py index 114c41e5c829..7f5e50a838d8 100644 --- a/test/ds/test_dynamic_shape_models.py +++ b/test/ds/test_dynamic_shape_models.py @@ -17,7 +17,7 @@ # It enables us to run python implementations of CompositeAutogradImplicit ops. # CompositeAutogradImplicit means we don't have an explicit backward formula for an op instead an op is composed of a bunch of ops that do have backward formulas and combines this formulas is equivalent to differentiating the op explicitly. pd = torch._C._EnablePythonDispatcher() -xla_dev = torch_xla.device() +xla_dev = torch.device('xla') class Feedforward(torch.nn.Module): diff --git a/test/ds/test_dynamic_shapes.py b/test/ds/test_dynamic_shapes.py index 46f329de4537..2d9e4d5bb7a2 100644 --- a/test/ds/test_dynamic_shapes.py +++ b/test/ds/test_dynamic_shapes.py @@ -10,7 +10,7 @@ import test_utils pd = torch._C._EnablePythonDispatcher() -dev = torch_xla.device() +dev = torch.device('xla') class TestDynamicShapes(test_utils.XlaTestCase): @@ -192,7 +192,7 @@ def test_nonzero_cast(self): torch_xla.sync() def test_expand_symint_correctness(self): - dev = torch_xla.device() + dev = torch.device('xla') size1 = 5 size2 = 2 t1 = torch.ones([size1, size2]) diff --git a/test/dynamo/test_bridge.py b/test/dynamo/test_bridge.py index 5aa57abd3575..feb5898d9d80 100644 --- a/test/dynamo/test_bridge.py +++ b/test/dynamo/test_bridge.py @@ -116,7 +116,7 @@ def unwrap(cont): def make_reuse_graph_test(module_class, niter=100): def test_wrapper(self): - xla_dev = torch_xla.device() + xla_dev = torch.device('xla') xla_module = module_class().to(device=xla_dev) inputs = tuple(x.to(device=xla_dev) for x in xla_module.get_random_inputs()) metrics.clear_counters() @@ -187,7 +187,7 @@ def make_training_test(model_cls): def test_wrapper(self): import torch_xla.core.xla_model as xm - xla_dev = torch_xla.device() + xla_dev = torch.device('xla') model = model_cls() inputs = model.get_random_inputs() @@ -240,7 +240,7 @@ class Emb(torch.nn.Embedding): def __init__(self): super().__init__(num_embeddings=10, embedding_dim=10, padding_idx=0) - device = torch_xla.device() + device = torch.device('xla') module = Emb() module.to(device) @@ -255,7 +255,7 @@ def test_inputs_not_computed(self): def foo(x): return x * 2 - device = torch_xla.device() + device = torch.device('xla') x = torch.rand(5, device=device) x = x.unsqueeze(dim=-1) self._compile_and_check(foo, (x,)) @@ -265,7 +265,7 @@ def test_factory_copy(self): def foo(device): return torch.arange(5, device="cpu").to(device) - self._compile_and_check(foo, (torch_xla.device(),)) + self._compile_and_check(foo, (torch.device('xla'),)) def test_index_flag_unsupported(self): # The indices of the index operation are represented as @@ -277,7 +277,7 @@ def test_index_flag_unsupported(self): def foo(xt, t): return xt[t] - device = torch_xla.device() + device = torch.device('xla') xt = torch.rand(5, device=device) t = torch.randint(0, 5, (3,)) self._compile_and_check(foo, (xt, t)) @@ -299,7 +299,7 @@ def test_cpu_flag_unsupported(self): def foo(t): return t.cpu() - device = torch_xla.device() + device = torch.device('xla') t = torch.randint(0, 5, (3,), device=device) self._compile_and_check(foo, (t,)) diff --git a/test/dynamo/test_dynamo.py b/test/dynamo/test_dynamo.py index 572d255514a6..d7c55c1c6405 100644 --- a/test/dynamo/test_dynamo.py +++ b/test/dynamo/test_dynamo.py @@ -89,7 +89,7 @@ def test_sync_after_dynamo(self): head_dim = 128 running = 16 - device = torch_xla.device() + device = torch.device('xla') cache = torch.rand((cache_len, kv_heads, head_dim)).to(device) update_indices = torch.randint( 0, cache_len, (running,), dtype=torch.long).to(device) @@ -116,7 +116,7 @@ def copy_a_to_b(a): copy = torch.ops.aten.copy_.default(a, res) return copy - device = torch_xla.device() + device = torch.device('xla') compiled_copy = torch.compile(copy_a_to_b, backend=backend) a = torch.randn(2, 9).to(device) res = compiled_copy(a) @@ -150,7 +150,7 @@ def fn_simple(self, x, y): def _choose_proper_device(self, initialize_on_cuda): if not initialize_on_cuda: - return torch_xla.device() + return torch.device('xla') assert initialize_on_cuda if xr.device_type() != "CUDA" or not torch.cuda.is_available(): @@ -164,7 +164,7 @@ def _choose_proper_device(self, initialize_on_cuda): @skipOnNeuron def test_simple_model(self): - device = torch_xla.device() + device = torch.device('xla') x = torch.tensor(100.0) y = torch.tensor(200.0) xla_x = x.to(device) @@ -413,7 +413,7 @@ def test_resnet18(self, initialize_on_cuda, backend): @skipOnNeuron def test_resnet18_lazy_vs_dynamo(self): sample_count = xu.getenv_as('SAMPLE_COUNT', int, defval=10) - device = torch_xla.device() + device = torch.device('xla') loader = self.get_loader(device, sample_count) resnet18_base = torchvision.models.resnet18() resnet18_base.eval() @@ -448,7 +448,7 @@ def fn_fallback(t): torch._dynamo.reset() met.clear_all() - device = torch_xla.device() + device = torch.device('xla') # Initial tracing dynamo_fn = torch.compile(fn_fallback, backend="openxla") @@ -488,7 +488,7 @@ def fn_fallback(t): torch._dynamo.reset() met.clear_all() - device = torch_xla.device() + device = torch.device('xla') # Initial tracing dynamo_fn = torch.compile(fn_fallback, backend="openxla") @@ -541,7 +541,7 @@ def train_model(self, model, data, target): def test_simple_model(self): torch._dynamo.reset() - device = torch_xla.device() + device = torch.device('xla') input = torch.randn(3, 5, requires_grad=True) xla_input = input.detach().to(device) xla_input.requires_grad = True @@ -577,7 +577,7 @@ def test_simple_model(self): def test_resnet18(self): torch._dynamo.reset() met.clear_counters() - device = torch_xla.device() + device = torch.device('xla') batch_size = xu.getenv_as('BATCH_SIZE', int, defval=4) sample_count = xu.getenv_as('SAMPLE_COUNT', int, defval=10) loader = xu.SampleGenerator( @@ -650,7 +650,7 @@ def train_model(self, model, data, target, optimizer): def test_simple_model(self): torch._dynamo.reset() - device = torch_xla.device() + device = torch.device('xla') input = torch.randn(3, 5, requires_grad=True) saved_input = input.detach().to(device).cpu() xla_input = input.detach().to(device) @@ -673,7 +673,7 @@ def test_simple_model(self): def test_resnet18(self): torch._dynamo.reset() met.clear_counters() - device = torch_xla.device() + device = torch.device('xla') batch_size = xu.getenv_as('BATCH_SIZE', int, defval=4) sample_count = xu.getenv_as('SAMPLE_COUNT', int, defval=10) loader = xu.SampleGenerator( @@ -732,7 +732,7 @@ def test_resnet18(self): class DynamoErrorMessageTest(parameterized.TestCase): def test_mixed_cpu_tensor(self): - device = torch_xla.device() + device = torch.device('xla') input = torch.randn(4, 3, 224, 224) input_xla = input.clone().to(device) resnet18 = torchvision.models.resnet18() diff --git a/test/dynamo/test_dynamo_aliasing.py b/test/dynamo/test_dynamo_aliasing.py index 36bfb5744bd4..709186bec02c 100644 --- a/test/dynamo/test_dynamo_aliasing.py +++ b/test/dynamo/test_dynamo_aliasing.py @@ -11,7 +11,7 @@ class TestBufferDonationUtil(unittest.TestCase): def test_hash_with_buffer_donor(self): - device = torch_xla.device() + device = torch.device('xla') input = torch.randn(5, 5).to(device) res = torch.cos(input) hash_no_donor = torch_xla._XLAC._get_graph_hash([res]) @@ -40,7 +40,7 @@ def dummy_mul(self, input): return input * 1.1 def test_manual_buffer_donation(self): - device = torch_xla.device() + device = torch.device('xla') input = torch.randn(5, 5).to(device) input_cloned = input.cpu().to(device) dummy_inplace_mul_compiled = torch.compile( @@ -55,7 +55,7 @@ def test_manual_buffer_donation(self): torch.allclose(input_cloned.cpu() * 1.1, input.cpu()) def test_manual_buffer_donation_for_non_inplce_op(self): - device = torch_xla.device() + device = torch.device('xla') input = torch.randn(5, 5).to(device) input_cloned = input.cpu().to(device) dummy_mul_compiled = torch.compile(self.dummy_mul, backend='openxla') @@ -81,7 +81,7 @@ def dummy_inplace(input): torch.ops.xla.dynamo_set_buffer_donor_(input, True) input += (0.5 * torch.sin(input)) - device = torch_xla.device() + device = torch.device('xla') input = torch.randn(5, 5).to(device) input_cloned = input.cpu().to(device) dummy_inplace_add_compiled = torch.compile(dummy_inplace, backend='openxla') @@ -109,7 +109,7 @@ def dummy_add(self, input): return input + 1 def test_manual_buffer_donation(self): - device = torch_xla.device() + device = torch.device('xla') input = torch.randn(5, 5).to(device) input_cloned = input.cpu().to(device) dummy_inplace_add_compiled = torch.compile( @@ -127,7 +127,7 @@ def test_manual_buffer_donation(self): self.assertFalse(torch_xla._XLAC._get_buffer_donation(input)) def test_manual_buffer_donation_for_non_inplce_op(self): - device = torch_xla.device() + device = torch.device('xla') input = torch.randn(5, 5).to(device) input_cloned = input.cpu().to(device) dummy_add_compiled = torch.compile(self.dummy_add, backend='openxla') @@ -152,7 +152,7 @@ def test_manual_buffer_donation_for_inplce_op_repeat(self): def dummy_inplace(input): input += (0.3 * torch.cos(input)) - device = torch_xla.device() + device = torch.device('xla') input = torch.randn(5, 5).to(device) input_cloned = input.cpu().to(device) dummy_inplace_add_compiled = torch.compile(dummy_inplace, backend='openxla') @@ -174,7 +174,7 @@ def dummy_inplace(input): self.assertEqual(met.metric_data('CompileTime')[0], 1) def test_buffer_donation_on_non_data_tensor(self): - device = torch_xla.device() + device = torch.device('xla') input = torch.randn(5, 5).to(device) res = input + 1 diff --git a/test/dynamo/test_dynamo_config.py b/test/dynamo/test_dynamo_config.py index 66f21cc84e91..d67c32acd475 100644 --- a/test/dynamo/test_dynamo_config.py +++ b/test/dynamo/test_dynamo_config.py @@ -10,7 +10,7 @@ def dummy_test(self, a): return a.cos().sin() def test_config_skip_input_data_check(self): - device = torch_xla.device() + device = torch.device('xla') print(config.skip_input_data_check) config.skip_input_data_check = True compiled_dummy = torch.compile(self.dummy_test, backend="openxla") diff --git a/test/dynamo/test_dynamo_dynamic_shape.py b/test/dynamo/test_dynamo_dynamic_shape.py index 1aa6905261f7..b475cc4fa904 100644 --- a/test/dynamo/test_dynamo_dynamic_shape.py +++ b/test/dynamo/test_dynamo_dynamic_shape.py @@ -45,7 +45,7 @@ def _get_linear_and_input(self, in_dim: int, out_dum: int, batch_dim: int, def test_dynamic_shape_basic(self): torch_xla.manual_seed(100) - device = torch_xla.device() + device = torch.device('xla') # model setup dummy_linear, dummy_linear_xla, input, input_xla = self._get_linear_and_input( 10, 20, 20, device) @@ -78,7 +78,7 @@ def test_dynamic_shape_basic(self): def test_dynamic_shape_basic_with_mark_dynamic(self): torch_xla.manual_seed(100) - device = torch_xla.device() + device = torch.device('xla') # model setup dummy_linear, dummy_linear_xla, input, input_xla = self._get_linear_and_input( 10, 40, 40, device) @@ -123,7 +123,7 @@ def test_dynamic_shape_basic_with_mark_dynamic(self): def test_dynamic_shape_multiple_batchs(self): torch_xla.manual_seed(100) - device = torch_xla.device() + device = torch.device('xla') # model setup in_dim = 16 out_dum = 32 @@ -180,7 +180,7 @@ def test_dynamic_shape_multiple_batchs(self): def test_dynamic_shape_mix_with_non_dynamic(self): torch_xla.manual_seed(100) - device = torch_xla.device() + device = torch.device('xla') # model setup in_dim = 15 out_dum = 31 @@ -238,7 +238,7 @@ def test_dynamic_shape_mix_with_non_dynamic(self): self.assertEqual(met.metric_data('ExecuteTime')[0], 1) def test_dynamic_decoder(self): - device = torch_xla.device() + device = torch.device('xla') config = DecoderOnlyConfig() config.num_hidden_layers = 2 config.hidden_size = 512 @@ -257,7 +257,7 @@ def test_dynamic_decoder(self): self.assertEqual(met.counter_value('DynamoExtractCompiledGraph'), 2) def test_dynamic_shape_decoder_mark_dynamic(self): - device = torch_xla.device() + device = torch.device('xla') config = DecoderOnlyConfig() config.num_hidden_layers = 2 config.hidden_size = 512 @@ -276,7 +276,7 @@ def test_dynamic_shape_decoder_mark_dynamic(self): self.assertEqual(met.counter_value('DynamoExtractCompiledGraph'), 2) def test_dynamic_shape_no_retracing(self): - device = torch_xla.device() + device = torch.device('xla') # model setup _, dummy_linear_xla, _, input_xla = self._get_linear_and_input( 8, 10, 20, device) @@ -295,7 +295,7 @@ def test_dynamic_shape_no_retracing(self): "Skip right now because with torch._dynamo.config.inline_inbuilt_nn_modules = True, dynamic compiles takes minutes for resnet18." ) def test_dynamic_shape_resnet18(self): - device = torch_xla.device() + device = torch.device('xla') sample_count = xu.getenv_as('SAMPLE_COUNT', int, defval=10) loader = self._get_loader(device, sample_count, batch_size=4) diff --git a/test/dynamo/test_dynamo_graph_dump.py b/test/dynamo/test_dynamo_graph_dump.py index ae0383a47963..5b35221fcea2 100644 --- a/test/dynamo/test_dynamo_graph_dump.py +++ b/test/dynamo/test_dynamo_graph_dump.py @@ -27,7 +27,7 @@ def test_dump_graph_with_dynamo_execution(self): if not save_file: assert False, "This test should be run with XLA_SAVE_TENSORS_FILE" save_file += '.0' - device = torch_xla.device() + device = torch.device('xla') xla_x = torch.tensor(100.0).to(device) xla_y = torch.tensor(200.0).to(device) res_xla_dynamo = self.fn_simple_dynamo(xla_x, xla_y) diff --git a/test/dynamo/test_dynamo_integrations_util.py b/test/dynamo/test_dynamo_integrations_util.py index 293bef17ec05..04d1615817d4 100644 --- a/test/dynamo/test_dynamo_integrations_util.py +++ b/test/dynamo/test_dynamo_integrations_util.py @@ -20,7 +20,7 @@ class PybindTest(unittest.TestCase): def test_get_tensors_xla_device_data_node(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') t1 = torch.randn(20, 5).to(xla_device) t2 = torch.randn(20, 5).to(xla_device) t3 = t2 + t1 @@ -42,7 +42,7 @@ def test_get_tensors_xla_device_data_node(self): assert (expected_tensor_ids == sorted(res_pair[0])) def test_get_base_seed_as_tensor(self): - device = torch_xla.device() + device = torch.device('xla') xm.set_rng_state(23, str(device)) base_seed = torch_xla._XLAC._get_base_seed_as_tensor(str(device)).item() self.assertEqual(23, base_seed) @@ -51,7 +51,7 @@ def test_get_seed_info_id(self): self.assertEqual(torch_xla._XLAC._get_seed_info_id(), -127389) def test_check_tensor_need_materialization(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') t1 = torch.randn(20, 5) assert (torch_xla._XLAC._check_tensor_need_materialization([t1]) == [False]) t1 = t1.to(xla_device) @@ -67,7 +67,7 @@ def test_check_tensor_need_materialization(self): assert (torch_xla._XLAC._check_tensor_need_materialization([t1]) == [True]) def test_get_graph_hash(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') xla_input = torch.randn(64, 256, 14, 14).to(xla_device) xla_dummy_model = dummy_model.to(xla_device) xla_out = xla_dummy_model(xla_input) @@ -85,7 +85,7 @@ def test_get_graph_hash(self): assert (hash == torch_xla._XLAC._get_graph_hash([xla_out_2])) def test_clear_pending_irs(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') torch_xla.sync() t1 = torch.randn(20, 5).to(xla_device) t2 = torch.randn(20, 5).to(xla_device) @@ -104,7 +104,7 @@ def test_clear_pending_irs(self): self.assertEqual(met.metric_data('ExecuteTime')[0], 1) def test_run_cached_graph(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') xla_input = torch.randn(64, 256, 14, 14).to(xla_device) xla_dummy_model = dummy_model.to(xla_device) xla_out = xla_dummy_model(xla_input) diff --git a/test/dynamo/test_graph_input_matcher.py b/test/dynamo/test_graph_input_matcher.py index 70dd0be73f57..7a03139ce029 100644 --- a/test/dynamo/test_graph_input_matcher.py +++ b/test/dynamo/test_graph_input_matcher.py @@ -24,7 +24,7 @@ def get_example_inputs(self): class TestGraphInputMatcher(unittest.TestCase): def test_no_cache_fx_gragh_inputs(self): - xla_dev = torch_xla.device() + xla_dev = torch.device('xla') model = M().to(device=xla_dev) inputs = tree_map_only(torch.Tensor, lambda x: x.to(device=xla_dev), model.get_example_inputs()) diff --git a/test/dynamo/test_num_output.py b/test/dynamo/test_num_output.py index b540e0691643..77081e3f2c5e 100644 --- a/test/dynamo/test_num_output.py +++ b/test/dynamo/test_num_output.py @@ -59,7 +59,7 @@ def get_example_inputs(self): class TestNumOutput(unittest.TestCase): def do_test(self, model_class, expected_num_output): - xla_dev = torch_xla.device() + xla_dev = torch.device('xla') model = model_class().to(device=xla_dev) inputs = tree_map_only(torch.Tensor, lambda x: x.to(device=xla_dev), model.get_example_inputs()) diff --git a/test/dynamo/test_traceable_collectives.py b/test/dynamo/test_traceable_collectives.py index 45bd89266604..58cdd092cd61 100644 --- a/test/dynamo/test_traceable_collectives.py +++ b/test/dynamo/test_traceable_collectives.py @@ -18,7 +18,7 @@ def collective_broadcast_and_cos(input, src): def _mp_fn(index): - device = torch_xla.device() + device = torch.device('xla') world_size = xr.world_size() if xm.xla_device_hw(device) not in ('TPU', 'CUDA', 'NEURON'): print(f'skip this test for hw {xm.xla_device_hw(device)}') diff --git a/test/eager/test_eager.py b/test/eager/test_eager.py index 552382a2dc39..48acb0958ed4 100644 --- a/test/eager/test_eager.py +++ b/test/eager/test_eager.py @@ -20,7 +20,7 @@ def test_eager_basic(self): xm.wait_device_ops() met.clear_all() self.assertTrue(torch_xla.experimental.is_eager_mode()) - device = torch_xla.device() + device = torch.device('xla') # For some reason randn will also trigger an execution of # size [5, 5] full of 0. @@ -36,7 +36,7 @@ def test_eager_basic(self): def test_eager_recompile(self): self.assertTrue(torch_xla.experimental.is_eager_mode()) - device = torch_xla.device() + device = torch.device('xla') t1 = torch.randn(5, 5, device=device) xm.wait_device_ops() @@ -55,7 +55,7 @@ def test_eager_recompile(self): def test_eager_in_place(self): self.assertTrue(torch_xla.experimental.is_eager_mode()) - device = torch_xla.device() + device = torch.device('xla') t1 = torch.randn(5, 5, device=device) xm.wait_device_ops() @@ -67,7 +67,7 @@ def test_eager_in_place(self): def test_eager_random_seed(self): self.assertTrue(torch_xla.experimental.is_eager_mode()) - device = torch_xla.device() + device = torch.device('xla') met.clear_all() t1 = torch.randn(12, 13, device=device) @@ -82,7 +82,7 @@ def test_eager_random_seed(self): def test_eager_set_random_seed(self): self.assertTrue(torch_xla.experimental.is_eager_mode()) - device = torch_xla.device() + device = torch.device('xla') old_seed = 1234 xm.set_rng_state(old_seed) @@ -95,7 +95,7 @@ def test_eager_set_random_seed(self): def test_batch_norm_execute_once(self): xm.wait_device_ops() - device = torch_xla.device() + device = torch.device('xla') m = nn.BatchNorm2d(16).to(device) m.train() input = torch.randn(8, 16, 8, 32).to(device) @@ -112,7 +112,7 @@ def test_batch_norm_execute_once(self): torch_xla._XLAC._get_xla_tensor_debug_info(m.running_mean)) def test_svd_execute_once(self): - device = torch_xla.device() + device = torch.device('xla') a = torch.randn(5, 3).to(device) xm.wait_device_ops() met.clear_all() diff --git a/test/eager/test_eager_all_reduce_in_place.py b/test/eager/test_eager_all_reduce_in_place.py index 7ea68b7fb6e4..5349212bea0c 100644 --- a/test/eager/test_eager_all_reduce_in_place.py +++ b/test/eager/test_eager_all_reduce_in_place.py @@ -10,7 +10,7 @@ def _mp_fn(index): import torch_xla torch_xla.experimental.eager_mode(True) - device = torch_xla.device() + device = torch.device('xla') if xm.xla_device_hw(device) not in ('TPU', 'CUDA', 'NEURON'): return diff --git a/test/eager/test_eager_spmd.py b/test/eager/test_eager_spmd.py index 3b05ba7af652..36a8faa931f5 100644 --- a/test/eager/test_eager_spmd.py +++ b/test/eager/test_eager_spmd.py @@ -39,7 +39,7 @@ def _get_mesh(self, mesh_shape, device_ids=None, axis_names=None): return xs.Mesh(device_ids, mesh_shape, axis_names) def test_eager_spmd_basic(self): - device = torch_xla.device() + device = torch.device('xla') mesh = self._get_mesh((self.n_devices,), axis_names=('data',)) torch.manual_seed(100) linear = torch.nn.Linear(10, 20) @@ -52,7 +52,7 @@ def test_eager_spmd_basic(self): self.assertTrue(torch.allclose(res, res_xla.cpu(), atol=1e-2)) def test_module_to_empty_sharding(self): - device = torch_xla.device() + device = torch.device('xla') mlinear = MultiLinear() mlinear.to(device) torch_xla._XLAC._get_xla_sharding_spec(mlinear.linear1.weight) diff --git a/test/eager/test_eager_with_torch_compile.py b/test/eager/test_eager_with_torch_compile.py index e7604658aa5e..c66fbda1bbc6 100644 --- a/test/eager/test_eager_with_torch_compile.py +++ b/test/eager/test_eager_with_torch_compile.py @@ -19,7 +19,7 @@ def dummy_cos_sin(self, tensor): def test_eager_with_compile_basic(self): met.clear_all() self.assertTrue(torch_xla.experimental.is_eager_mode()) - device = torch_xla.device() + device = torch.device('xla') # this part happens eagerly t1 = torch.randn(5, 5, device=device) @@ -38,7 +38,7 @@ def test_eager_with_compile_basic(self): def test_eager_execute_compiled_multiple_times(self): met.clear_all() self.assertTrue(torch_xla.experimental.is_eager_mode()) - device = torch_xla.device() + device = torch.device('xla') # this part happens eagerly t1 = torch.randn(10, 5, device=device) t1.add_(0.5) diff --git a/test/eager/test_eager_with_xla_compile.py b/test/eager/test_eager_with_xla_compile.py index 5aee35b2a12d..3d4b88ce0dfa 100644 --- a/test/eager/test_eager_with_xla_compile.py +++ b/test/eager/test_eager_with_xla_compile.py @@ -29,7 +29,7 @@ def dummy_graph_break(self, t): def test_eager_with_compile_basic(self): met.clear_all() self.assertTrue(torch_xla.experimental.is_eager_mode()) - device = torch_xla.device() + device = torch.device('xla') # this part happens eagerly t1 = torch.randn(5, 5, device=device) @@ -54,7 +54,7 @@ def test_eager_with_compile_basic(self): def test_eager_execute_compiled_multiple_times(self): met.clear_all() self.assertTrue(torch_xla.experimental.is_eager_mode()) - device = torch_xla.device() + device = torch.device('xla') # this part happens eagerly t1 = torch.randn(10, 5, device=device) t1.add_(0.5) @@ -69,7 +69,7 @@ def test_eager_execute_compiled_multiple_times(self): def test_eager_with_compile_graph_break(self): met.clear_all() self.assertTrue(torch_xla.experimental.is_eager_mode()) - device = torch_xla.device() + device = torch.device('xla') t1 = torch.randn(5, 5, device=device) with self.assertRaisesRegex( diff --git a/test/pjrt/test_collective_ops_tpu.py b/test/pjrt/test_collective_ops_tpu.py index 28b0b7709060..1c7176df8013 100644 --- a/test/pjrt/test_collective_ops_tpu.py +++ b/test/pjrt/test_collective_ops_tpu.py @@ -17,7 +17,7 @@ class TestXMCollectiveOpsTpu(parameterized.TestCase): @staticmethod def _broadcast(sync): torch.manual_seed(xr.global_ordinal()) - device = torch_xla.device() + device = torch.device('xla') model = nn.Linear(5, 5).to(device) if sync: xm.broadcast_master_param(model) @@ -41,7 +41,7 @@ def test_broadcast_master_param(self, sync): @staticmethod def _all_reduce(pin_layout): - device = torch_xla.device() + device = torch.device('xla') # Prevent 0 and 1 from being converted to constants ordinal = xm.send_cpu_data_to_device( torch.tensor( @@ -63,7 +63,7 @@ def test_all_reduce(self, pin_layout): @staticmethod def _all_gather(pin_layout): - device = torch_xla.device() + device = torch.device('xla') ordinal = torch.tensor([xr.global_ordinal()], device=device) out = xm.all_gather(ordinal, pin_layout=pin_layout) torch_xla.sync() @@ -80,7 +80,7 @@ def test_all_gather(self, pin_layout): @staticmethod def _reduce_scatter(pin_layout): - device = torch_xla.device() + device = torch.device('xla') world_size = xr.world_size() tensor = -torch.arange(world_size, dtype=torch.float32).to(device) @@ -105,7 +105,7 @@ def test_reduce_scatter(self, pin_layout): @staticmethod def _all_to_all(pin_layout): - device = torch_xla.device() + device = torch.device('xla') world_size = xr.world_size() tensor = torch.cat( @@ -151,7 +151,7 @@ def callable(input): return input dist.init_process_group("xla", init_method='xla://') - device = torch_xla.device() + device = torch.device('xla') input = torch.tensor([xr.global_ordinal()], dtype=torch.float, device=device) @@ -175,7 +175,7 @@ def callable(output, input): return output dist.init_process_group("xla", init_method='xla://') - device = torch_xla.device() + device = torch.device('xla') input = torch.tensor([xr.global_ordinal()], dtype=torch.float, device=device) @@ -200,7 +200,7 @@ def callable(output, input): def _all_gather(use_dynamo: bool): met.clear_all() dist.init_process_group("xla", init_method='xla://') - device = torch_xla.device() + device = torch.device('xla') def callable(input): output_tensor = [ @@ -229,7 +229,7 @@ def callable(input): def _reduce_scatter(use_dynamo: bool): met.clear_all() dist.init_process_group("xla", init_method='xla://') - device = torch_xla.device() + device = torch.device('xla') def callable(output, input): dist.reduce_scatter_tensor(output, input) @@ -254,7 +254,7 @@ def callable(output, input): def _all_to_all_single(use_dynamo: bool, split_size: int = 1): met.clear_all() dist.init_process_group("xla", init_method='xla://') - device = torch_xla.device() + device = torch.device('xla') def callable(output, input): dist.all_to_all_single(output, input) diff --git a/test/pjrt/test_ddp.py b/test/pjrt/test_ddp.py index d93bbe45c4d9..62e24b804af2 100644 --- a/test/pjrt/test_ddp.py +++ b/test/pjrt/test_ddp.py @@ -26,7 +26,7 @@ class TestPjRtDistributedDataParallel(parameterized.TestCase): @staticmethod def _ddp_init(index: int = ...): dist.init_process_group('xla', init_method='xla://') - device = torch_xla.device() + device = torch.device('xla') model = nn.Linear(10, 10).to(device) ddp_model = DDP(model) diff --git a/test/pjrt/test_profiler.py b/test/pjrt/test_profiler.py index 17892261119a..f9189aa50342 100644 --- a/test/pjrt/test_profiler.py +++ b/test/pjrt/test_profiler.py @@ -33,12 +33,12 @@ class TestPjRtProfiler(absltest.TestCase): def setUp(self): # HACK: ensure libtpu is loaded if using TPU - torch_xla.device() + torch.device('xla') def test_profiler_output(self): tempdir = self.create_tempdir().full_path - device = torch_xla.device() + device = torch.device('xla') ones = torch.ones([5]) with _profile(tempdir): xones = ones.to(device) diff --git a/test/pjrt/test_runtime_multi_cpu.py b/test/pjrt/test_runtime_multi_cpu.py index 25c3280ce4b5..6090066d05f9 100644 --- a/test/pjrt/test_runtime_multi_cpu.py +++ b/test/pjrt/test_runtime_multi_cpu.py @@ -65,7 +65,7 @@ def forward(ctx, x): def backward(ctx, grad_output): results['forward_ordinal'] = ctx.forward_ordinal results['backward_ordinal'] = xr.global_ordinal() - results['device'] = str(torch_xla.device()) + results['device'] = str(torch.device('xla')) return grad_output x = torch.ones(1, requires_grad=True, device='xla') diff --git a/test/pjrt/test_runtime_multi_gpu.py b/test/pjrt/test_runtime_multi_gpu.py new file mode 100644 index 000000000000..25d967e363fd --- /dev/null +++ b/test/pjrt/test_runtime_multi_gpu.py @@ -0,0 +1,266 @@ +import concurrent.futures +import itertools +import os +import queue +import requests +import unittest + +import numpy as np +import torch +import torch.nn as nn +import torch_xla +import torch_xla.core.xla_env_vars as xenv +import torch_xla.core.xla_model as xm +import torch_xla.distributed.xla_multiprocessing as xmp +from torch_xla import runtime as xr +from torch_xla._internal import pjrt +from absl.testing import absltest, parameterized + + +@unittest.skipIf(xr.device_type() != "CUDA", + f"GPU tests should only run on GPU devices.") +class TestExperimentalPjrtMultiGpu(parameterized.TestCase): + + def setUp(self): + xr.set_device_type('CUDA') + + os.environ.update({ + xenv.PJRT_GPU_ASYNC_CLIENT: 'true', + }) + + def test_default_gpu_device(self): + os.environ.pop(xenv.PJRT_GPU_ASYNC_CLIENT, None) + + num_devices = int(os.environ[xenv.GPU_NUM_DEVICES]) + expected = {i: torch.device(f'xla:0') for i in range(num_devices)} + devices_per_process = pjrt.run_multiprocess(xm.xla_device) + self.assertDictEqual(devices_per_process, expected) + + def test_multi_gpu_devices(self): + num_devices = int(os.environ[xenv.GPU_NUM_DEVICES]) + expected = {i: torch.device(f'xla:0') for i in range(num_devices)} + + devices_per_process = pjrt.run_multiprocess(xm.xla_device) + self.assertDictEqual(devices_per_process, expected) + + def test_global_ordinal(self): + num_devices = int(os.environ[xenv.GPU_NUM_DEVICES]) + expected = [i for i in range(num_devices)] + + results = pjrt.run_multiprocess(xr.global_ordinal) + self.assertListEqual(sorted(results.values()), expected) + + def test_local_ordinal(self): + num_devices = int(os.environ[xenv.GPU_NUM_DEVICES]) + expected = [i for i in range(num_devices)] + + results = pjrt.run_multiprocess(xr.local_ordinal) + self.assertListEqual(sorted(results.values()), expected) + + def test_global_device_count(self): + num_devices = int(os.environ[xenv.GPU_NUM_DEVICES]) + expected = {i: num_devices for i in range(num_devices)} + results = pjrt.run_multiprocess(xr.global_device_count) + self.assertEqual(expected, results) + + def test_local_process_count(self): + num_devices = int(os.environ[xenv.GPU_NUM_DEVICES]) + expected = {i: num_devices for i in range(num_devices)} + results = pjrt.run_multiprocess(xr.local_process_count) + self.assertEqual(expected, results) + + def test_world_size(self): + num_devices = int(os.environ[xenv.GPU_NUM_DEVICES]) + expected = {i: num_devices for i in range(num_devices)} + results = pjrt.run_multiprocess(xr.world_size) + self.assertEqual(expected, results) + + def test_addressable_device_count(self): + num_devices = int(os.environ[xenv.GPU_NUM_DEVICES]) + expected = {i: 1 for i in range(num_devices)} + results = pjrt.run_multiprocess(xr.addressable_device_count) + self.assertEqual(expected, results) + + def test_addressable_runtime_device_count(self): + num_devices = int(os.environ[xenv.GPU_NUM_DEVICES]) + expected = {i: 1 for i in range(num_devices)} + results = pjrt.run_multiprocess(xr.addressable_runtime_device_count) + self.assertEqual(expected, results) + + def test_local_device_count(self): + num_devices = int(os.environ[xenv.GPU_NUM_DEVICES]) + # xr.local_process_count() is 2, xr.addressable_device_count() is 1. + expected = {i: num_devices for i in range(num_devices)} + results = pjrt.run_multiprocess(xr.local_device_count) + self.assertEqual(expected, results) + + def test_process_index(self): + num_devices = int(os.environ[xenv.GPU_NUM_DEVICES]) + expected = {i: i for i in range(num_devices)} + results = pjrt.run_multiprocess(xr.process_index) + self.assertEqual(expected, results) + + def test_process_count(self): + num_devices = int(os.environ[xenv.GPU_NUM_DEVICES]) + expected = {i: num_devices for i in range(num_devices)} + results = pjrt.run_multiprocess(xr.process_count) + self.assertEqual(expected, results) + + @staticmethod + def _multi_gpu_backwards(): + results = {} + + class _CustomBackwards(torch.autograd.Function): + + @staticmethod + def forward(ctx, x): + ordinal = xr.global_ordinal() + ctx.forward_ordinal = ordinal + return x + + @staticmethod + def backward(ctx, grad_output): + results['forward_ordinal'] = ctx.forward_ordinal + results['backward_ordinal'] = xr.global_ordinal() + results['device'] = str(torch.device('xla')) + return grad_output + + x = torch.ones(1, requires_grad=True, device='xla') + y = _CustomBackwards.apply(x) + y.backward() + torch_xla.sync() + + return results + + def test_multi_gpu_backwards(self): + os.environ.update({ + xenv.PJRT_GPU_ASYNC_CLIENT: 'true', + }) + num_devices = int(os.environ[xenv.GPU_NUM_DEVICES]) + + expected = { + i: { + 'forward_ordinal': i, + 'backward_ordinal': i, + 'device': f'xla:0' + } for i in range(num_devices) + } + results = pjrt.run_multiprocess(self._multi_gpu_backwards) + + self.assertDictEqual(results, expected) + + @staticmethod + def _spawn(index: int, queue: queue.Queue): + queue.put(index) + + @parameterized.named_parameters(('xmp', xmp.spawn), ('pjrt', pjrt.spawn)) + def test_spawn(self, spawn): + manager = torch.multiprocessing.Manager() + num_devices = int(os.environ[xenv.GPU_NUM_DEVICES]) + queue = manager.Queue(num_devices) + spawn(self._spawn, args=(queue,)) + + indices = sorted(queue.get(block=False) for _ in range(queue.qsize())) + self.assertListEqual(indices, list(range(num_devices))) + + @staticmethod + def _broadcast(sync): + torch.manual_seed(xr.global_ordinal()) + device = torch.device('xla') + model = nn.Linear(5, 5).to(device) + if sync: + xm.broadcast_master_param(model) + + torch_xla.sync() + return next(model.parameters()).detach().cpu().numpy() + + @parameterized.named_parameters(('synchronized_parameters', True), + ('unsynchronized_parameters', False)) + def test_broadcast_master_param(self, sync): + results = pjrt.run_multiprocess(self._broadcast, sync) + master_params = results[0] + for ordinal, worker_params in results.items(): + if sync: + np.testing.assert_array_equal(master_params, worker_params) + elif ordinal != 0: + np.testing.assert_raises(AssertionError, np.testing.assert_array_equal, + master_params, worker_params) + + @staticmethod + def _all_gather(pin_layout): + device = torch.device('xla') + ordinal = torch.tensor([xr.global_ordinal()], device=device) + out = xm.all_gather(ordinal, pin_layout=pin_layout) + torch_xla.sync() + + return out.cpu().numpy() + + @parameterized.named_parameters(('pinned', True), ('unpinned', False)) + def test_all_gather(self, pin_layout): + results = pjrt.run_multiprocess(self._all_gather, pin_layout) + + expected = list(range(len(results))) + for v in results.values(): + np.testing.assert_array_equal(v, expected) + + @staticmethod + def _reduce_scatter(pin_layout): + device = torch.device('xla') + world_size = xr.world_size() + tensor = -torch.arange(world_size, dtype=torch.float32).to(device) + + out = xm.reduce_scatter( + xm.REDUCE_SUM, + tensor, + scale=1.0 / world_size, + scatter_dim=0, + shard_count=world_size, + pin_layout=pin_layout, + ) + torch_xla.sync() + + return out.cpu().numpy() + + # 2023-08-02 04:16:36.520884: F external/xla/xla/service/layout_assignment.cc:157] Check failed: ShapeUtil::Compatible(shape_layout.shape(), instruction->operand(operand_no)->shape()) f32[1]{0} is not compatible with f32[2]{0} (for operand 0 of instruction %reduce-scatter.10 = f32[1]{0} reduce-scatter(f32[2]{0} %add.5), replica_groups={}, constrain_layout=true, dimensions={0}, to_apply=%AddComputation.6) + @parameterized.named_parameters(('pinned', True), ('unpinned', False)) + def test_reduce_scatter(self, pin_layout): + results = pjrt.run_multiprocess(self._reduce_scatter, pin_layout) + + for ordinal, value in results.items(): + np.testing.assert_array_equal(value, [-ordinal]) + + @staticmethod + def _all_to_all(pin_layout): + device = torch.device('xla') + world_size = xr.world_size() + + tensor = torch.cat( + [ + -torch.arange(world_size, dtype=torch.float32).view(-1, 1, 1), + torch.ones(world_size, 1, 1) * xr.global_ordinal(), + ], + dim=1, + ).to(device) + torch_xla.sync() + + out = xm.all_to_all( + tensor, + split_dimension=0, + concat_dimension=2, + split_count=world_size, + pin_layout=pin_layout, + ) + + return out.cpu().numpy() + + @parameterized.named_parameters(('pinned', True), ('unpinned', False)) + def test_all_to_all(self, pin_layout): + results = pjrt.run_multiprocess(self._all_to_all, pin_layout) + + for ordinal, value in results.items(): + np.testing.assert_array_equal(value, [[[-ordinal] * len(results), + list(range(len(results)))]]) + + +if __name__ == '__main__': + absltest.main() diff --git a/test/pjrt/test_runtime_tpu.py b/test/pjrt/test_runtime_tpu.py index 89ad676ca383..aa039166ae67 100644 --- a/test/pjrt/test_runtime_tpu.py +++ b/test/pjrt/test_runtime_tpu.py @@ -172,7 +172,7 @@ def test_local_ordinal_with_discontiguous_global_ordinal_v4_threaded(self): @staticmethod def _spawn_threads() -> Dict[int, torch.device]: results = {} - pjrt.spawn_threads(lambda i: results.setdefault(i, torch_xla.device())) + pjrt.spawn_threads(lambda i: results.setdefault(i, torch.device('xla'))) return results @@ -187,7 +187,7 @@ def test_spawn_threads(self): @staticmethod def _spawn_error(): # Initialize the client in the parent process - torch_xla.device() + torch.device('xla') torch_xla.launch(xm.xla_device) @@ -199,7 +199,7 @@ def test_spawn_error(self): @staticmethod def _runtime_device_attributes(): - return xr.runtime_device_attributes(str(torch_xla.device())) + return xr.runtime_device_attributes(str(torch.device('xla'))) def test_runtime_device_attributes(self): result = pjrt.run_multiprocess(self._runtime_device_attributes) @@ -226,7 +226,7 @@ def test_global_runtime_device_attributes(self): @staticmethod def _execute_time_metric(): # Initialize the client before starting the timer. - torch_xla.device() + torch.device('xla') begin = time.perf_counter_ns() value = ( diff --git a/test/pjrt/test_train_hf_transformer.py b/test/pjrt/test_train_hf_transformer.py index d484edc0a6ce..93d932bab7a8 100644 --- a/test/pjrt/test_train_hf_transformer.py +++ b/test/pjrt/test_train_hf_transformer.py @@ -55,7 +55,7 @@ def finetune(rank, train_dataset, test_dataset, tokenizer, flags): drop_last=True, generator=rng) - device = torch_xla.device() + device = torch.device('xla') model = AutoModelForSequenceClassification.from_pretrained( 'google-bert/bert-base-cased', num_labels=5) model.to(device) diff --git a/test/pytorch_test_base.py b/test/pytorch_test_base.py index 3355f8efba99..72e1f621b54d 100644 --- a/test/pytorch_test_base.py +++ b/test/pytorch_test_base.py @@ -559,7 +559,7 @@ def _alt_lookup(d, keys, defval): def instantiate_test(cls, name, test, *, generic_cls): test_name = name + '_' + cls.device_type class_name = cls.__name__ - real_device_type = xm.xla_device_hw(str(torch_xla.device())) + real_device_type = xm.xla_device_hw(str(torch.device('xla'))) assert real_device_type in DISABLED_TORCH_TESTS, 'Unsupported device type:' + real_device_type disabled_torch_tests = DISABLED_TORCH_TESTS[real_device_type] @@ -632,7 +632,7 @@ def get_primary_device(cls): @classmethod def setUpClass(cls): # Sets the primary test device to the xla_device (CPU or TPU) - cls.primary_device = str(torch_xla.device()) + cls.primary_device = str(torch.device('xla')) torch_xla._XLAC._xla_set_mat_mul_precision('highest') def setUp(self): diff --git a/test/quantized_ops/test_dot_general.py b/test/quantized_ops/test_dot_general.py index 71a39ff56e96..846da4f0255a 100644 --- a/test/quantized_ops/test_dot_general.py +++ b/test/quantized_ops/test_dot_general.py @@ -5,7 +5,7 @@ import torch_xla import unittest -device = torch_xla.device() +device = torch.device('xla') torch.manual_seed(12345) diff --git a/test/quantized_ops/test_quantized_matmul.py b/test/quantized_ops/test_quantized_matmul.py index 88a34c69a4ae..ace38bfee083 100644 --- a/test/quantized_ops/test_quantized_matmul.py +++ b/test/quantized_ops/test_quantized_matmul.py @@ -12,7 +12,7 @@ torch.manual_seed(123456) -device = torch_xla.device() +device = torch.device('xla') class M(torch.nn.Module): diff --git a/test/scan/test_scan.py b/test/scan/test_scan.py index 42c362ee8769..fbf7d5a4cded 100644 --- a/test/scan/test_scan.py +++ b/test/scan/test_scan.py @@ -45,7 +45,7 @@ class TestBase(XlaTestCase): def setUp(self): super().setUp() - self.device = torch_xla.device() + self.device = torch.device('xla') # Clear the scan computation cache before each test to avoid cross-test contamination. scan_module._SCAN_COMPUTATION_CACHE.clear() @@ -288,7 +288,7 @@ def test_scan_external_in_place_mutation(self): giving wrong results. """ # TODO(yifeit): Modify this test when external in-place mutation is eventually supported. - weird_global = torch.tensor([0.0, 0.0], device=torch_xla.device()) + weird_global = torch.tensor([0.0, 0.0], device='xla') def step_fn(carry, x): new_carry = carry + x @@ -296,9 +296,8 @@ def step_fn(carry, x): y = new_carry + weird_global return new_carry, y - init = torch.tensor([0.0, 0.0], device=torch_xla.device()) - xs = torch.tensor([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], - device=torch_xla.device()) + init = torch.tensor([0.0, 0.0], device='xla') + xs = torch.tensor([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], device='xla') with self.assertRaisesRegex(AssertionError, "FakeTensor"): scan(step_fn, init, xs) @@ -371,9 +370,8 @@ def step_fn(carry, x): y = new_carry + torch.rand(2, device=torch_xla.device()) return new_carry, y - init = torch.tensor([0.0, 0.0], device=torch_xla.device()) - xs = torch.tensor([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], - device=torch_xla.device()) + init = torch.tensor([0.0, 0.0], device='xla') + xs = torch.tensor([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], device='xla') _, ys = scan(step_fn, init, xs) # ys should be a 2D tensor with this shape. self.assertEqual(ys.shape, (3, 2)) diff --git a/test/scan/test_scan_layers.py b/test/scan/test_scan_layers.py index ba193ea1eb30..e093e83ccace 100644 --- a/test/scan/test_scan_layers.py +++ b/test/scan/test_scan_layers.py @@ -26,7 +26,7 @@ class ScanLayersTest(XlaTestCase): def setUp(self): super().setUp() - self.device = torch_xla.device() + self.device = torch.device('xla') def assert_different_tensor(self, a: torch.Tensor, b: torch.Tensor): assert a is not b, f"Expected {a} and {b} to be different tensors" diff --git a/test/scan/test_scan_pallas.py b/test/scan/test_scan_pallas.py index a267886cd3f7..6f77d1b52aa7 100644 --- a/test/scan/test_scan_pallas.py +++ b/test/scan/test_scan_pallas.py @@ -72,7 +72,7 @@ def fake_fa_wrapper(self, has_model_weight, use_scan): torch.manual_seed(12) torch_xla.manual_seed(12) hidden_states = torch.randn((8, 4, 256, 256)).requires_grad_().to('xla') - with torch_xla.device(): + with torch.device('xla'): attention_layers = AttentionLayers( has_model_weight, num_layer=3, use_scan=use_scan) hidden_states.retain_grad() diff --git a/test/scan/test_scan_spmd.py b/test/scan/test_scan_spmd.py index 9bf081527c72..2bd1428a842c 100644 --- a/test/scan/test_scan_spmd.py +++ b/test/scan/test_scan_spmd.py @@ -23,7 +23,7 @@ def setUp(self): # Set up a simple SPMD mesh for these tests. self.spmd_mesh = get_1d_mesh(axis_name="model") set_global_mesh(self.spmd_mesh) - self.device = torch_xla.device() + self.device = torch.device('xla') @unittest.skipUnless(xr.global_runtime_device_count() >= 4, "Multiple devices required") diff --git a/test/spmd/test_dynamo_spmd.py b/test/spmd/test_dynamo_spmd.py index 518e4203b459..71b109b90d12 100644 --- a/test/spmd/test_dynamo_spmd.py +++ b/test/spmd/test_dynamo_spmd.py @@ -42,7 +42,7 @@ def setUpClass(cls): super().setUpClass() def test_dynamo_spmd_basic(self): - device = torch_xla.device() + device = torch.device('xla') linear = SimpleLinear().to(device) linear.eval() xla_x = torch.randn(1, 128, device=device) @@ -58,7 +58,7 @@ def test_dynamo_spmd_basic(self): # a ExecuteMetric. def test_dynamo_spmd_output_sharding_spec(self): - device = torch_xla.device() + device = torch.device('xla') linear = SimpleLinear().to(device) linear.eval() xla_x = torch.randn(1, 128, device=device) @@ -74,7 +74,7 @@ def test_dynamo_spmd_output_sharding_spec(self): ) def test_dynamo_spmd_output_sharding_cache(self): met.clear_all() - device = torch_xla.device() + device = torch.device('xla') linear = SimpleLinear().to(device) linear.eval() xla_x = torch.randn(1, 128, device=device) @@ -90,7 +90,7 @@ def test_dynamo_spmd_output_sharding_cache(self): self.assertEqual(met.counter_value('UncachedOutputSharding'), 1) def test_dynamo_sharded_input(self): - device = torch_xla.device() + device = torch.device('xla') linear = SimpleLinear().to(device) linear.eval() xla_x = torch.randn(8, 128, device=device) @@ -103,7 +103,7 @@ def test_dynamo_sharded_input(self): torch.allclose(xla_res.cpu(), dynamo_res.cpu()) def test_dynamo_input_sharding_changed(self): - device = torch_xla.device() + device = torch.device('xla') linear = SimpleLinear().to(device) linear.eval() xla_x = torch.randn(8, 128, device=device) @@ -142,7 +142,7 @@ def test_dynamo_input_sharding_changed(self): @unittest.skipIf(xr.global_runtime_device_count() == 1, "Multiple devices needed to test the mesh change") def test_dynamo_input_sharding_threashold(self): - device = torch_xla.device() + device = torch.device('xla') linear = SimpleLinear().to(device) linear.eval() xla_x = torch.randn(8, 128, device=device) @@ -183,7 +183,7 @@ def test_dynamo_input_sharding_threashold(self): del os.environ['XLA_DYNAMO_INPUT_SHARDING_CHECK_THRESHOLD'] def test_dynamo_spmd_basic_with_dynamo_mark_sharding(self): - device = torch_xla.device() + device = torch.device('xla') linear = SimpleLinear().to(device) linear.eval() xla_x = torch.randn(1, 128, device=device) @@ -202,7 +202,7 @@ def test_dynamo_spmd_basic_with_dynamo_mark_sharding(self): torch.allclose(xla_res.cpu(), dynamo_res.cpu()) def test_dynamo_spmd_activation_sharding_with_dynamo_mark_sharding(self): - device = torch_xla.device() + device = torch.device('xla') mesh = self._get_mesh((1, self.n_devices)) device_ids = mesh.device_ids.tolist() mesh_shape = list(mesh.mesh_shape) diff --git a/test/spmd/test_mp_input_sharding.py b/test/spmd/test_mp_input_sharding.py index dc1e4aba12b0..135215e5b72c 100644 --- a/test/spmd/test_mp_input_sharding.py +++ b/test/spmd/test_mp_input_sharding.py @@ -34,7 +34,7 @@ def __next__(self): @unittest.skipUnless(xr.global_runtime_device_count() > 1, "Multiple devices required for tupled partition spec") def test_multiple_inputs(self): - device = torch_xla.device() + device = torch.device('xla') batch = {'x': torch.randn((16, 128)), 'y': torch.randn((16, 128, 128))} train_loader = self.fake_dataloader(batch) num_devices = xr.global_runtime_device_count() @@ -61,7 +61,7 @@ def test_multiple_inputs(self): @unittest.skipUnless(xr.global_runtime_device_count() > 1, "Multiple devices required for tupled partition spec") def test_single_tensor(self): - device = torch_xla.device() + device = torch.device('xla') batch = torch.randn((16, 128)) train_loader = self.fake_dataloader(batch) num_devices = xr.global_runtime_device_count() @@ -78,7 +78,7 @@ def test_single_tensor(self): @unittest.skipUnless(xr.global_runtime_device_count() > 1, "Multiple devices required for tupled partition spec") def test_error_single_tensor_with_input_sharding_dict(self): - device = torch_xla.device() + device = torch.device('xla') batch = torch.randn((16, 128)) train_loader = self.fake_dataloader(batch) num_devices = xr.global_runtime_device_count() @@ -95,7 +95,7 @@ def test_error_single_tensor_with_input_sharding_dict(self): @unittest.skipUnless(xr.global_runtime_device_count() > 1, "Multiple devices required for tupled partition spec") def test_input_sharding_none(self): - device = torch_xla.device() + device = torch.device('xla') batch = {'x': torch.randn((16, 128)), 'y': torch.randn((16, 128, 128))} train_loader = self.fake_dataloader(batch) num_devices = xr.global_runtime_device_count() @@ -112,7 +112,7 @@ def test_input_sharding_none(self): @unittest.skipUnless(xr.global_runtime_device_count() > 1, "Multiple devices required for tupled partition spec") def test_error_missing_keys(self): - device = torch_xla.device() + device = torch.device('xla') batch = {'x': torch.randn((16, 128)), 'y': torch.randn((16, 128, 128))} train_loader = self.fake_dataloader(batch) mesh = xs.get_1d_mesh('x') @@ -127,7 +127,7 @@ def test_error_missing_keys(self): @unittest.skipUnless(xr.global_runtime_device_count() > 1, "Multiple devices required for tupled partition spec") def test_input_sharding_not_dict(self): - device = torch_xla.device() + device = torch.device('xla') num_devices = xr.global_runtime_device_count() batch = {'x': torch.randn((16, 128)), 'y': torch.randn((16, 128))} train_loader = self.fake_dataloader(batch) diff --git a/test/spmd/test_sharding_strategies.py b/test/spmd/test_sharding_strategies.py index 2dd09580a5a6..8f31fc3dde51 100644 --- a/test/spmd/test_sharding_strategies.py +++ b/test/spmd/test_sharding_strategies.py @@ -146,7 +146,7 @@ def training_step(data): torch.manual_seed(42) tries = 5 -device = torch_xla.device() +device = torch.device('xla') if args.profile: print("Profiler server started at port 9012") server = xp.start_server(9012) diff --git a/test/spmd/test_spmd_debugging.py b/test/spmd/test_spmd_debugging.py index 34221d375e9c..def91adef995 100644 --- a/test/spmd/test_spmd_debugging.py +++ b/test/spmd/test_spmd_debugging.py @@ -209,7 +209,7 @@ def test_single_host_replicated_tpu(self): f"Requires PJRT_DEVICE set to `CPU`.") def test_debugging_spmd_single_host_tiled_cpu(self): from torch_xla.distributed.spmd.debugging import visualize_sharding - device = torch_xla.device() + device = torch.device('xla') num_devices = self.n_devices mesh_shape = (1, num_devices) device_ids = np.array(range(num_devices)) @@ -252,7 +252,7 @@ def test_debugging_spmd_single_host_tiled_cpu(self): f"Requires PJRT_DEVICE set to `CPU`.") def test_single_host_partial_replication_cpu(self): from torch_xla.distributed.spmd.debugging import visualize_sharding - device = torch_xla.device() + device = torch.device('xla') num_devices = self.n_devices mesh_shape = (1, num_devices) device_ids = np.array(range(num_devices)) @@ -295,7 +295,7 @@ def test_single_host_partial_replication_cpu(self): f"Requires PJRT_DEVICE set to `CPU`.") def test_single_host_replicated_cpu(self): from torch_xla.distributed.spmd.debugging import visualize_sharding - device = torch_xla.device() + device = torch.device('xla') num_devices = self.n_devices mesh_shape = (1, num_devices) device_ids = np.array(range(num_devices)) diff --git a/test/spmd/test_spmd_graph_dump.py b/test/spmd/test_spmd_graph_dump.py index 45af3b154934..a0c7011f914f 100644 --- a/test/spmd/test_spmd_graph_dump.py +++ b/test/spmd/test_spmd_graph_dump.py @@ -26,7 +26,7 @@ def test_dump_with_output_sharding(self): assert save_file, "This test should be run with XLA_SAVE_TENSORS_FILE" should_dump_output_sharding = (save_format == 'hlo') save_file += '.0' - device = torch_xla.device() + device = torch.device('xla') xla_x = torch.randn(8, 32).to(device) xla_y = torch.randn(8, 32).to(device) # shard one of the input tensor diff --git a/test/spmd/test_spmd_lowering_context.py b/test/spmd/test_spmd_lowering_context.py index 9bc80194318f..6f6307ab0676 100644 --- a/test/spmd/test_spmd_lowering_context.py +++ b/test/spmd/test_spmd_lowering_context.py @@ -38,7 +38,7 @@ def test_basic(self): mesh_shape = (data_axis, model_axis) spmd_mesh = self._get_mesh(mesh_shape, axis_names=('x', 'y')) - device = torch_xla.device() + device = torch.device('xla') a = torch.zeros(2048, device=device, requires_grad=True) xs.mark_sharding(a, spmd_mesh, ('x',)) b = torch.randn([32, 2048], device=device, requires_grad=True) @@ -108,7 +108,7 @@ def test_device_parameter_id_tensor_mapping(self): mesh_shape = (data_axis, model_axis) spmd_mesh = self._get_mesh(mesh_shape, axis_names=('x', 'y')) - device = torch_xla.device() + device = torch.device('xla') a = torch.randn([32, 2048]).to(device) xs.mark_sharding(a, spmd_mesh, ('x', 'y')) b = torch.ones(2048).to(device) diff --git a/test/spmd/test_spmd_parameter_wrapping.py b/test/spmd/test_spmd_parameter_wrapping.py index 47f1bab8d33f..94267fab2b32 100644 --- a/test/spmd/test_spmd_parameter_wrapping.py +++ b/test/spmd/test_spmd_parameter_wrapping.py @@ -38,7 +38,7 @@ def setUpClass(cls): super().setUpClass() def test_fsdpv2(self): - device = torch_xla.device() + device = torch.device('xla') one_d_mesh = xs.get_1d_mesh("fsdp") xs.set_global_mesh(one_d_mesh) linears = MultiLinear() @@ -56,7 +56,7 @@ def test_fsdpv2(self): self.assertEqual(output.shape, torch.Size([100, 40])) def basic_spmd_test(self): - device = torch_xla.device() + device = torch.device('xla') one_d_mesh = xs.get_1d_mesh("data") input = torch.randn(8, 128) input2 = torch.randn(8, 128) diff --git a/test/spmd/test_train_spmd_imagenet.py b/test/spmd/test_train_spmd_imagenet.py index 006db02dd46e..935470d82446 100644 --- a/test/spmd/test_train_spmd_imagenet.py +++ b/test/spmd/test_train_spmd_imagenet.py @@ -206,7 +206,7 @@ def train_imagenet(): torch.manual_seed(42) - device = torch_xla.device() + device = torch.device('xla') model = get_model_property('model_fn')().to(device) if FLAGS.use_gradient_checkpointing: diff --git a/test/spmd/test_xla_distributed_checkpoint.py b/test/spmd/test_xla_distributed_checkpoint.py index d3fa093e8b13..4223b7b82ee5 100644 --- a/test/spmd/test_xla_distributed_checkpoint.py +++ b/test/spmd/test_xla_distributed_checkpoint.py @@ -76,7 +76,7 @@ def _assert_same_state_dict(self, sd1, sd2, keypath=""): if isinstance(sd1, torch.Tensor): assert sd1.device == sd2.device, f"Tensors on different devices at {keypath}: {sd1} vs {sd2}" - if sd1.device == torch_xla.device(): + if sd1.device == torch.device('xla'): sharding1 = torch_xla._XLAC._get_xla_sharding_spec(sd1) sharding2 = torch_xla._XLAC._get_xla_sharding_spec(sd2) assert sharding1 == sharding2, f"Different sharding on tensors at {keypath}: {sharding1} vs {sharding2}" diff --git a/test/spmd/test_xla_sharding.py b/test/spmd/test_xla_sharding.py index 7b1be7574a1f..455d9006078d 100644 --- a/test/spmd/test_xla_sharding.py +++ b/test/spmd/test_xla_sharding.py @@ -277,7 +277,7 @@ def test_mark_sharding_4d(self): self.assertTrue(torch.allclose(expected, actual)) def test_mark_sharding_not_ordered_sharding_spec_2d(self): - device = torch_xla.device() + device = torch.device('xla') t1 = torch.randn(8, 16, device='cpu') expected = t1 + t1 @@ -290,7 +290,7 @@ def test_mark_sharding_not_ordered_sharding_spec_2d(self): self.assertTrue(torch.allclose(expected, (xt1 + xt1).cpu())) def test_mark_sharding_not_ordered_sharding_spec_3d(self): - device = torch_xla.device() + device = torch.device('xla') t1 = torch.randn(4, 8, 16, device='cpu') expected = t1 + t1 @@ -307,7 +307,7 @@ def test_mark_sharding_not_ordered_sharding_spec_3d(self): self.assertTrue(torch.allclose(expected, (xt1 + xt1).cpu())) def test_mark_sharding_not_ordered_sharding_spec_4d(self): - device = torch_xla.device() + device = torch.device('xla') t1 = torch.randn(32, 4, 8, 16, device='cpu') expected = t1 + t1 @@ -326,7 +326,7 @@ def test_mark_sharding_not_ordered_sharding_spec_4d(self): self.assertTrue(torch.allclose(expected, (xt1 + xt1).cpu())) def test_mark_sharding_partial(self): - device = torch_xla.device() + device = torch.device('xla') t1 = torch.randn(4, 4).to(device) t2 = torch.randn(4, 4).to(device) # Somehow the eager cpu result is different from the xla result. @@ -356,7 +356,7 @@ def test_mark_sharding_partial(self): self.assertTrue(torch.allclose(expected, actual)) def test_propagate_replicated_sharding(self): - device = torch_xla.device() + device = torch.device('xla') t1 = torch.randn(4, 4).to(device) t2 = torch.randn(4, 4).to(device) t3 = t1 @ t2 @@ -368,7 +368,7 @@ def test_propagate_replicated_sharding(self): self.assertIn("replicated", torch_xla._XLAC._get_xla_sharding_spec(t3)) def test_mark_sharding_partial_unordered(self): - device = torch_xla.device() + device = torch.device('xla') t1 = torch.randn(4, 3, 4).to(device) t2 = torch.randn(4, 3, 4).to(device) expected = t1 + t2 @@ -467,7 +467,7 @@ def test_3d_tensor_2d_mesh(self): (self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices)))) def test_partial_replication_addmm(self): - device = torch_xla.device() + device = torch.device('xla') z_dim = 2 if self.n_devices >= 4 else 1 mesh = self._get_mesh((z_dim, self.n_devices // z_dim)) @@ -657,7 +657,7 @@ def test_send_cpu_data_to_device_with_sharding(self): sharding_spec = xs.ShardingSpec(mesh, (0, 1)) self.assertTrue(sharding_spec.can_apply(tensor)) xtensors = xm.send_cpu_data_to_device([tensor], - torch_xla.device(), + torch.device('xla'), input_sharding=sharding_spec) self.assertEqual(len(xtensors), 1) outbound = met.metric_data("OutboundData")[1] @@ -955,7 +955,7 @@ def test_named_partition_spec(self): self.assertTrue("replicated" in sharding_spec) def test_shard_device_data_ir(self): - device = torch_xla.device() + device = torch.device('xla') xla_x = torch.randn(8, 128, device=device) # xla_x now becomes a device data IR xla_y = xla_x * 5 @@ -967,7 +967,7 @@ def test_shard_device_data_ir(self): self.assertTrue(torch.allclose(xla_y.cpu(), xla_x.cpu() * 5)) def test_shard_device_data_ir_after_sync(self): - device = torch_xla.device() + device = torch.device('xla') xla_x = torch.randn(8, 128, device=device) x = xla_x.cpu() # xla_x now becomes a device data IR without XLAData @@ -1370,7 +1370,7 @@ def test_spmd_all_reduce_scale(self): self.assertTrue(torch.allclose(x.cpu(), expected_x)) def test_get_1d_mesh(self): - device = torch_xla.device() + device = torch.device('xla') mesh = xs.get_1d_mesh("data") t1 = torch.randn(8, 8).to(device) xt = xs.mark_sharding(t1, mesh, ("data", None)) @@ -1387,7 +1387,7 @@ def test_get_1d_mesh(self): xr.global_runtime_device_count() > 1, "Multiple devices required for dataloader sharding test") def test_data_loader_with_sharding(self): - device = torch_xla.device() + device = torch.device('xla') mesh = xs.get_1d_mesh("data") batch_size = 8 train_loader = xu.SampleGenerator( @@ -1410,7 +1410,7 @@ def test_data_loader_with_sharding(self): xr.global_runtime_device_count() > 1, "Multiple devices required for dataloader sharding test") def test_data_loader_with_non_batch_size(self): - device = torch_xla.device() + device = torch.device('xla') mesh = xs.get_1d_mesh("data") batch_size = mesh.size() - 1 train_loader = xu.SampleGenerator( @@ -1433,7 +1433,7 @@ def test_data_loader_with_non_batch_size(self): xr.global_runtime_device_count() > 1, "Multiple devices required for dataloader sharding test") def test_data_loader_with_non_batch_size_and_mini_batch(self): - device = torch_xla.device() + device = torch.device('xla') mesh = xs.get_1d_mesh("data") batch_size = mesh.size() - 1 train_loader = xu.SampleGenerator( @@ -1453,7 +1453,7 @@ def test_data_loader_with_non_batch_size_and_mini_batch(self): data, _ = iter(train_device_loader).__next__() def test_fallback(self): - device = torch_xla.device() + device = torch.device('xla') theta: float = 10000 dim = 16 @@ -1487,7 +1487,7 @@ def test_xla_patched_linear(self): import torch_xla.core.xla_model as xm import torch.nn.functional as F - with torch_xla.device(): + with torch.device('xla'): torch_xla.manual_seed(42) x0 = torch.randn(2, 3, requires_grad=True) w0 = torch.randn(4, 3, requires_grad=True) diff --git a/test/spmd/test_xla_spmd_python_api_interaction.py b/test/spmd/test_xla_spmd_python_api_interaction.py index ba051964a108..e4ce72b23003 100644 --- a/test/spmd/test_xla_spmd_python_api_interaction.py +++ b/test/spmd/test_xla_spmd_python_api_interaction.py @@ -38,22 +38,22 @@ def test_is_master_ordinal(self): self.assertTrue(xm.is_master_ordinal()) def test_xla_device(self): - device = torch_xla.device() + device = torch.device('xla') self.assertEqual(device, torch.device('xla:0')) def test_xla_real_devices(self): - device = torch_xla.device() + device = torch.device('xla') device_type = os.environ['PJRT_DEVICE'] self.assertEqual(xm.xla_real_devices([device]), [device_type + ':0']) def test_xla_device_hw(self): - device = torch_xla.device() + device = torch.device('xla') device_type = os.environ['PJRT_DEVICE'] replication_devices = xm.xla_replication_devices([device]) self.assertEqual(xm.xla_device_hw(device), device_type) def test_xla_replication_devices(self): - device = torch_xla.device() + device = torch.device('xla') device_type = os.environ['PJRT_DEVICE'] replication_devices = xm.xla_replication_devices([device]) self.assertEqual(xm.xla_real_devices([device]), [device_type + ':0']) @@ -148,7 +148,7 @@ def setUpClass(cls): @unittest.skipIf(xr.device_type() not in ['TPU', 'CUDA'], f"TPU/GPU autocast test.") def test_xla_autocast_api(self): - device = torch_xla.device() + device = torch.device('xla') t1 = torch.ones([2, 3], device=device, dtype=torch.float32) t2 = torch.ones([3, 2], device=device, dtype=torch.float32) with autocast(device, dtype=torch.bfloat16): diff --git a/test/spmd/test_xla_virtual_device.py b/test/spmd/test_xla_virtual_device.py index 60c8b31a00e9..9c9f2b3a49ec 100644 --- a/test/spmd/test_xla_virtual_device.py +++ b/test/spmd/test_xla_virtual_device.py @@ -88,7 +88,7 @@ def test_non_tensor_scalar(self): sharding_spec = xs.ShardingSpec(self._get_mesh((1, self.n_devices)), (0, 1)) # tensor will have device as `SPMD:0` in c++ xt1 = xm.send_cpu_data_to_device([torch.randn(3, 3)], - torch_xla.device(), + torch.device('xla'), input_sharding=sharding_spec)[0] # we will transfer 0.5 as a device_data to the 'SPMD:0' device, need to make sure # that virtual device can handle this case. @@ -101,7 +101,7 @@ def test_sync_on_virtual_device(self): sharding_spec = xs.ShardingSpec(self._get_mesh((1, self.n_devices)), (0, 1)) # tensor will have device as `SPMD:0` in c++ xt1 = xm.send_cpu_data_to_device([torch.randn(3, 3)], - torch_xla.device(), + torch.device('xla'), input_sharding=sharding_spec)[0] xt2 = xt1 / 0.5 torch_xla.sync(wait=True) @@ -111,7 +111,7 @@ def test_sync_on_virtual_device(self): def test_virtual_device_no_upload(self): met.clear_all() - device = torch_xla.device() + device = torch.device('xla') t1 = torch.randn(5, 5).to(device) t1_debug_info = torch_xla._XLAC._get_xla_tensor_debug_info(t1) # t1's upload to device should be deferred @@ -125,7 +125,7 @@ def test_virtual_device_no_upload(self): def test_virtual_device_upload_after_mark_sharding(self): met.clear_all() partition_spec = (0, 1) - device = torch_xla.device() + device = torch.device('xla') t1 = torch.randn(8, 8).to(device) t1_debug_info = torch_xla._XLAC._get_xla_tensor_debug_info(t1) self.assertIn("Tensor on host: with size [8, 8]", t1_debug_info) @@ -139,7 +139,7 @@ def test_virtual_device_upload_after_mark_sharding(self): def test_virtual_device_upload_after_tracing(self): met.clear_all() - device = torch_xla.device() + device = torch.device('xla') t1 = torch.randn(8, 8).to(device) t1_debug_info = torch_xla._XLAC._get_xla_tensor_debug_info(t1) self.assertIn("Tensor on host: with size [8, 8]", t1_debug_info) @@ -152,7 +152,7 @@ def test_virtual_device_upload_after_tracing(self): def test_virtual_device_upload_for_sharded_dataloader(self): met.clear_counters() - device = torch_xla.device() + device = torch.device('xla') sharding_spec = xs.ShardingSpec(self._get_mesh((1, self.n_devices)), (0, 1)) # tensor will have device as `SPMD:0` in c++ t1 = xm.send_cpu_data_to_device([torch.randn(8, 8)], diff --git a/test/stablehlo/test_composite.py b/test/stablehlo/test_composite.py index 8fe211475ba1..8f95f5b73967 100644 --- a/test/stablehlo/test_composite.py +++ b/test/stablehlo/test_composite.py @@ -71,7 +71,7 @@ class XlaMarkPatternTest(unittest.TestCase): def run_func_get_stablehlo(self, f, input_args): - device = torch_xla.device() + device = torch.device('xla') input_args = pytree.tree_map_only(torch.Tensor, lambda x: x.to(device=device), input_args) exported = torch.export.export(AsModule(f), input_args) diff --git a/test/stablehlo/test_implicit_broadcasting.py b/test/stablehlo/test_implicit_broadcasting.py index 10fbe5789981..04c5dd882a7b 100644 --- a/test/stablehlo/test_implicit_broadcasting.py +++ b/test/stablehlo/test_implicit_broadcasting.py @@ -10,7 +10,7 @@ # The following tests cover the implcit-broadcasting for static and bounded # dynamic shapes. -device = torch_xla.device() +device = torch.device('xla') class ImplicitBroadcasting(unittest.TestCase): diff --git a/test/stablehlo/test_pt2e_qdq.py b/test/stablehlo/test_pt2e_qdq.py index 34426f978029..a5b32cd7d5bc 100644 --- a/test/stablehlo/test_pt2e_qdq.py +++ b/test/stablehlo/test_pt2e_qdq.py @@ -55,7 +55,7 @@ def count_qdq_ops(g: torch.fx.Graph): class PT2EExportTest(unittest.TestCase): def test_per_tensor_qdq(self): - device = torch_xla.device() + device = torch.device('xla') x = torch.randn(2, 3, 4, 5).to(device) x = torch.ops.quantized_decomposed.quantize_per_tensor( x, 0.4, 2, -128, 127, torch.int8) @@ -69,7 +69,7 @@ def test_per_tensor_qdq(self): self.assertEqual(stablehlo_txt.count("stablehlo.uniform_dequantize"), 1) def test_per_channel_qdq(self): - device = torch_xla.device() + device = torch.device('xla') x = torch.randn(2, 3, 4, 5).to(device) scale = torch.tensor([3.2, 5.3, 0.1, 10]).to(device) zero_point = torch.tensor([1, 2, -1, -2], dtype=torch.int64).to(device) diff --git a/test/stablehlo/test_stablehlo_compile.py b/test/stablehlo/test_stablehlo_compile.py index a57faf7ff5f2..a5abc0d27498 100644 --- a/test/stablehlo/test_stablehlo_compile.py +++ b/test/stablehlo/test_stablehlo_compile.py @@ -21,7 +21,7 @@ def test_resnet18_stablehlo_compile(self): torch_input = torch.tensor(np_input).float() cpu_output = resnet18(torch_input) # Run ResNet on XLA device. - device = torch_xla.device() + device = torch.device('xla') # materalize the fake data for test purpose torch_xla.sync() xm.wait_device_ops() diff --git a/test/stablehlo/test_stablehlo_custom_call.py b/test/stablehlo/test_stablehlo_custom_call.py index a315bbc230db..3a73c93f1005 100644 --- a/test/stablehlo/test_stablehlo_custom_call.py +++ b/test/stablehlo/test_stablehlo_custom_call.py @@ -118,7 +118,7 @@ def forward(self, x): # self.assertTrue("api_version = 1" in shlo_text) def test_place_to_host_device(self): - dev = torch_xla.device() + dev = torch.device('xla') a = torch.ones(10, device=dev) b = place_to_host(a) shlo_text = xm.get_stablehlo([b]) @@ -137,7 +137,7 @@ def test_place_to_host_device(self): def test_place_to_host_device_autograd(self): # Test that gradient can flow through place_to_host and place_to_device ops. - dev = torch_xla.device() + dev = torch.device('xla') a = torch.ones(10, device=dev, requires_grad=True) b = place_to_host(a) c = b.sum() @@ -155,7 +155,7 @@ def test_place_to_host_device_aot_autograd(self): # specifically `aot_function`. from functorch.compile import aot_function, make_boxed_func # type: ignore - dev = torch_xla.device() + dev = torch.device('xla') a = torch.ones(10, device=dev, requires_grad=True) def my_fn(x): diff --git a/test/stablehlo/test_stablehlo_inference.py b/test/stablehlo/test_stablehlo_inference.py index a29b66ebceaa..0e1b0ffdfe25 100644 --- a/test/stablehlo/test_stablehlo_inference.py +++ b/test/stablehlo/test_stablehlo_inference.py @@ -67,7 +67,7 @@ def forward(self, x, y): output = m(*data) exported = export_torch_model(m, data) - device = torch_xla.device() + device = torch.device('xla') data = pytree.tree_map_only(torch.Tensor, lambda x: x.to(device), data) output2 = exported(*data).cpu() @@ -91,7 +91,7 @@ def forward(self, inputs): output = m(*data) exported = export_torch_model(m, data) - device = torch_xla.device() + device = torch.device('xla') data = pytree.tree_map_only(torch.Tensor, lambda x: x.to(device), data) output2 = exported(*data) self.assertEqual(len(output2), 2) diff --git a/test/stablehlo/test_stablehlo_save_load.py b/test/stablehlo/test_stablehlo_save_load.py index 71ff463578cb..5d353e25a386 100644 --- a/test/stablehlo/test_stablehlo_save_load.py +++ b/test/stablehlo/test_stablehlo_save_load.py @@ -17,7 +17,7 @@ class StableHloDumpTest(unittest.TestCase): def test_simple(self): - device = torch_xla.device() + device = torch.device('xla') x = torch.tensor([3], device=device) y = torch.tensor([3], device=device) z = x + y @@ -26,7 +26,7 @@ def test_simple(self): self.assertEqual(stablehlo.count("stablehlo.add"), 1) def test_resnet18(self): - device = torch_xla.device() + device = torch.device('xla') xla_resnet18 = torchvision.models.resnet18() xla_resnet18.eval() xla_resnet18 = xla_resnet18.to(device) @@ -66,7 +66,7 @@ class SimpleExportTest(unittest.TestCase): def export_stable_hlo(self, model, args, kwargs=None): if kwargs is None: kwargs = {} - device = torch_xla.device() + device = torch.device('xla') model.eval() model = model.to(device) args = tuple(i.to(device) for i in args if hasattr(i, 'to')) diff --git a/test/stablehlo/test_unbounded_dynamism.py b/test/stablehlo/test_unbounded_dynamism.py index 88fce368b668..5f6a853c87ba 100644 --- a/test/stablehlo/test_unbounded_dynamism.py +++ b/test/stablehlo/test_unbounded_dynamism.py @@ -19,7 +19,7 @@ compare_exported_program_and_saved_model_result, has_tf_package, load_save_model_and_inference, wrap_func_as_nn_module) -device = torch_xla.device() +device = torch.device('xla') os.environ['EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM'] = '1' diff --git a/test/stablehlo/test_xla_export_interpreter.py b/test/stablehlo/test_xla_export_interpreter.py index 51a73a402703..5c336b3b6bcf 100644 --- a/test/stablehlo/test_xla_export_interpreter.py +++ b/test/stablehlo/test_xla_export_interpreter.py @@ -7,7 +7,7 @@ import torch_xla.core.xla_model as xm from torch_xla.stablehlo import exported_program_to_stablehlo -device = torch_xla.device() +device = torch.device('xla') class XLAExportInterpreterTest(unittest.TestCase): diff --git a/test/test_autocast.py b/test/test_autocast.py index ca1f26c05ec1..468e8b932061 100644 --- a/test/test_autocast.py +++ b/test/test_autocast.py @@ -282,7 +282,7 @@ def cast(val, to_type): add_kwargs = {} self.assertFalse(self.is_autocast_enabled()) - with autocast(torch_xla.device(), dtype=autocast_dtype): + with autocast(torch.device('xla'), dtype=autocast_dtype): self.assertTrue(self.is_autocast_enabled()) out_type = out_type if out_type is not None else run_as_type @@ -332,7 +332,7 @@ def compare(first, second): # Compare numerics to Python-side "autocasting" that (we expect) does the same thing # as the C++-side autocasting, and should be bitwise accurate. output_to_compare = output if output is not None else output_method - with autocast(torch_xla.device(), enabled=False): + with autocast(torch.device('xla'), enabled=False): self.assertFalse(self.is_autocast_enabled()) if module is not None and hasattr(module, op): @@ -355,7 +355,7 @@ class TestAutocastTPU(TestAutocastBase): @classmethod def setUpClass(cls): super().setUpClass() - cls.autocast_lists = AutocastTPUTestLists(torch.device(torch_xla.device())) + cls.autocast_lists = AutocastTPUTestLists(torch.device(torch.device('xla'))) def setUp(self): super(TestAutocastTPU, self).setUp() @@ -397,7 +397,7 @@ def test_autocast_methods_expect_builtin_promote(self): op, args, torch.float32, module=None, out_type=out_type) def test_autocast_tpu_check_dtype(self): - with autocast(torch_xla.device(), dtype=torch.float16): + with autocast(torch.device('xla'), dtype=torch.float16): assert not torch.is_autocast_xla_enabled() @@ -408,7 +408,7 @@ class TestOtherOps(unittest.TestCase): xm.xla_device_hw(torch_xla.device()) != 'TPU', "the behavior of batch_norm autocast on TPU is different from others") def test_batch_norm_tpu(self): - device = torch_xla.device() + device = torch.device('xla') data = torch.randn(4, 16, 32, 32, device=device, dtype=torch.bfloat16) batch_norm = torch.nn.BatchNorm2d(16) with autocast(device, dtype=torch.bfloat16): diff --git a/test/test_autocast_xla.py b/test/test_autocast_xla.py index e287cb1bae55..cc8bed32b44c 100644 --- a/test/test_autocast_xla.py +++ b/test/test_autocast_xla.py @@ -6,7 +6,7 @@ import torch_xla.distributed.spmd.xla_sharding as xs -device = torch_xla.device() +device = torch.device('xla') class TestAutocastXla(unittest.TestCase): diff --git a/test/test_compilation_cache_utils.py b/test/test_compilation_cache_utils.py index 0ac8a013d814..5ca4a0bfb396 100644 --- a/test/test_compilation_cache_utils.py +++ b/test/test_compilation_cache_utils.py @@ -31,7 +31,7 @@ def _test_spawn(fn, args): class TestGraphHash(parameterized.TestCase): def _test_num_graph_hash(self, use_dynamo, use_persistent): - xla_dev = torch_xla.device() + xla_dev = torch.device('xla') model = M().to(device=xla_dev) input_shape = (10, 5) if use_dynamo: diff --git a/test/test_core_aten_ops.py b/test/test_core_aten_ops.py index 361c09de7faf..4f099900a83f 100644 --- a/test/test_core_aten_ops.py +++ b/test/test_core_aten_ops.py @@ -46,7 +46,7 @@ def run_export_and_compare(testcase, atol=1e-3, rtol=1e-5, equal_nan=True): - device = torch_xla.device() + device = torch.device('xla') with testcase.subTest('torch_eval'): res = func(*args, **kwargs) with testcase.subTest('torch_xla_eval'): diff --git a/test/test_data_type.py b/test/test_data_type.py index da4b7d00681f..5cdbf4aa2c2f 100644 --- a/test/test_data_type.py +++ b/test/test_data_type.py @@ -55,7 +55,7 @@ def test_datatype_use_32bit_long(self): self._test_datatype(torch.uint64, 'u32', torch.add) def test_module_to_dtype(self): - device = torch_xla.device() + device = torch.device('xla') linear = torch.nn.Linear( 5, 10, dtype=torch.float32).to(device).to(torch.bfloat16) input = torch.randn(10, 5).to(device).to(torch.bfloat16) diff --git a/test/test_env_var_mapper.py b/test/test_env_var_mapper.py index e4dcef2ba8cb..95d5c99595af 100644 --- a/test/test_env_var_mapper.py +++ b/test/test_env_var_mapper.py @@ -15,7 +15,7 @@ def check_env_flag(name, default=''): class EnvVarMapperTest(unittest.TestCase): def test_xla_ir_debug_(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') with xp.Trace('test_xla_ir_debug'): t = torch.tensor([2.0, 3.0], dtype=torch.float, device=xla_device) diff --git a/test/test_fp8.py b/test/test_fp8.py index 2dbf534cb5c7..fc00e0a932ac 100644 --- a/test/test_fp8.py +++ b/test/test_fp8.py @@ -6,7 +6,7 @@ import unittest from absl.testing import parameterized -device = torch_xla.device() +device = torch.device('xla') dtype_parameters = [ torch.float8_e5m2, diff --git a/test/test_fsdp_auto_wrap.py b/test/test_fsdp_auto_wrap.py index 019612899697..55ef03881a88 100644 --- a/test/test_fsdp_auto_wrap.py +++ b/test/test_fsdp_auto_wrap.py @@ -35,7 +35,7 @@ def forward(self, x): "This test fails only on GPU with 03/30 TF-pin update (https://github.com/pytorch/xla/pull/4840)" ) def test(self): - dev = torch_xla.device() + dev = torch.device('xla') input = torch.zeros([16, 16], device=dev) model = self.MyModel(input_size=16, hidden_size=4) model = XlaFullyShardedDataParallel( @@ -48,7 +48,7 @@ def test(self): def _mp_fn(index): - device = torch_xla.device() + device = torch.device('xla') if xm.xla_device_hw(device) in ('TPU', 'CUDA'): test = unittest.main(exit=False) sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/test/test_grad_checkpoint.py b/test/test_grad_checkpoint.py index e4e318ba8310..42c9e4e93686 100644 --- a/test/test_grad_checkpoint.py +++ b/test/test_grad_checkpoint.py @@ -11,7 +11,7 @@ def run(): - device = torch_xla.device() + device = torch.device('xla') model = torch.nn.ModuleList([ torch.nn.Sequential( torch.nn.Conv2d(1024, 1024, 1), diff --git a/test/test_gradient_accumulation.py b/test/test_gradient_accumulation.py index 62ecfc431132..cfe324ed0f54 100644 --- a/test/test_gradient_accumulation.py +++ b/test/test_gradient_accumulation.py @@ -23,7 +23,7 @@ def forward(self, x): class GradAccumulationTest(XlaTestCase): def setUp(self): - self.device = torch_xla.device() + self.device = torch.device('xla') torch.manual_seed(123) def test_basic(self): diff --git a/test/test_inplace_update.py b/test/test_inplace_update.py index 704888d4f6e7..9e718d29ad17 100644 --- a/test/test_inplace_update.py +++ b/test/test_inplace_update.py @@ -11,7 +11,7 @@ class InplaceUpdateTest(unittest.TestCase): def test_aten_op_after_full_update(self): - device = torch_xla.device() + device = torch.device('xla') t = torch.ones(2, 1, device=device) w = torch.ones(1, 2, device=device) t.zero_() @@ -21,7 +21,7 @@ def test_aten_op_after_full_update(self): self.assertTrue(torch.all(torch.eq(y, expected))) def test_aten_op_after_partial_update(self): - device = torch_xla.device() + device = torch.device('xla') t = torch.ones(2, 1, device=device) w = torch.ones(1, 2, device=device) t[0][0] = 0 @@ -31,7 +31,7 @@ def test_aten_op_after_partial_update(self): self.assertTrue(torch.all(torch.eq(y, expected))) def test_non_aten_op_after_full_update(self): - device = torch_xla.device() + device = torch.device('xla') t = torch.ones(2, 1, device=device) w = torch.ones(1, 2, device=device) t.zero_() @@ -41,7 +41,7 @@ def test_non_aten_op_after_full_update(self): self.assertTrue(torch.all(torch.eq(y, expected))) def test_non_aten_op_after_partial_update(self): - device = torch_xla.device() + device = torch.device('xla') t = torch.ones(2, 1, device=device) w = torch.ones(1, 2, device=device) t[0][0] = 0 @@ -53,7 +53,7 @@ def test_non_aten_op_after_partial_update(self): def test_xm_save(self): with temporary_env( XLA_DISABLE_FUNCTIONALIZATION="0", XLA_ENABLE_PARAM_ALIASING="0"): - xla_device = torch_xla.device() + xla_device = torch.device('xla') t1 = torch.tensor([1], device=xla_device) t2 = t1.detach() torch_xla.sync() diff --git a/test/test_input_output_aliases.py b/test/test_input_output_aliases.py index 3f20f9d25c97..be3789f08785 100644 --- a/test/test_input_output_aliases.py +++ b/test/test_input_output_aliases.py @@ -38,7 +38,7 @@ def config_context(value): class InputOutputAliasesTest(parameterized.TestCase): def test_non_view(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') t1 = torch.randn(4, 2, 2).to(xla_device) t2 = torch.randn(4, 2, 2).to(xla_device) torch_xla.sync() @@ -53,7 +53,7 @@ def test_non_view(self): self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 2.0) def test_aliasing_with_cloned(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') met.clear_all() t1 = torch.randn(4, 2, 2).to(xla_device) # t1_cloned share the same storage as t1 @@ -66,7 +66,7 @@ def test_aliasing_with_cloned(self): self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 1.0) def test_aliasing_across_custom_inplace(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') met.clear_all() t1 = torch.randn(4, 5).to(xla_device) t1 *= t1 @@ -78,7 +78,7 @@ def test_aliasing_across_custom_inplace(self): self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 2.0) def test_aliasing_across_sync(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') met.clear_all() t1 = torch.randn(4, 5).to(xla_device) t1 += 1 @@ -96,7 +96,7 @@ def test_aliasing_with_multiple_inplace_update(self): BLOCK_SIZE = 16 DTYPE = torch.bfloat16 num_blocks = 1024 - device = torch_xla.device() + device = torch.device('xla') key = torch.randn( BATCH_SIZE * SEQ_LEN, NUM_KV_HEADS, @@ -145,7 +145,7 @@ def try_grad_accum(model, device, train_x, train_label, accum_steps): torch_xla.sync() return [p.grad.to('cpu').numpy() for p in model.parameters()] - dev = torch_xla.device() + dev = torch.device('xla') train_x_sample = torch.rand((1, 28 * 28)) train_label_sample = torch.tensor([5]) c_model = MLP().to('cpu') @@ -171,7 +171,7 @@ def test_separate_graphs(self): """ Test that paramater aliasing differences should produce different graphs. """ - xla_device = torch_xla.device() + xla_device = torch.device('xla') t0 = torch.tensor([1], device=xla_device) t1 = torch.tensor([2], device=xla_device) torch_xla.sync() @@ -190,7 +190,7 @@ def test_xm_save_no_aliasing(self): """ Test that xm.save() does not perform aliasing. """ - xla_device = torch_xla.device() + xla_device = torch.device('xla') t0 = torch.tensor([1], device=xla_device) t1 = torch.tensor([2], device=xla_device) torch_xla.sync() @@ -212,7 +212,7 @@ def test_device_data_cache_no_aliasing(self): """ Test that device data in DataCache are not aliased. """ - xla_device = torch_xla.device() + xla_device = torch.device('xla') t0 = torch.tensor(42, device=xla_device) # drops the read-only bit on t0's device_data @@ -235,7 +235,7 @@ def test_device_data_cache_no_aliasing(self): def test_user_config_donation_with_ltc_donation(self): met.clear_all() - xla_device = torch_xla.device() + xla_device = torch.device('xla') t0 = torch.randn(4, 2, 2).to(xla_device) t1 = torch.randn(4, 2, 2).to(xla_device) self.assertTrue(torch_xla._XLAC._set_buffer_donation(t0, True)) @@ -255,7 +255,7 @@ def test_user_config_donation_with_ltc_donation_graph_sync( self, enable_buffer_donor_config): with alias_with_buffer_donor_config_context(enable_buffer_donor_config): met.clear_all() - xla_device = torch_xla.device() + xla_device = torch.device('xla') t0 = torch.randn(4, 2, 2).to(xla_device) t1 = torch.randn(4, 2, 2).to(xla_device) self.assertTrue(torch_xla._XLAC._set_buffer_donation(t0, True)) @@ -279,7 +279,7 @@ def test_user_config_donation_with_ltc_donation_graph_sync( def test_user_config_donation_with_ltc_donation_overlap(self): met.clear_all() - xla_device = torch_xla.device() + xla_device = torch.device('xla') t0 = torch.randn(4, 2, 2).to(xla_device) self.assertTrue(torch_xla._XLAC._set_buffer_donation(t0, True)) self.assertTrue(torch_xla._XLAC._get_buffer_donation(t0)) @@ -291,7 +291,7 @@ def test_user_config_donation_with_ltc_donation_overlap(self): def test_user_config_donation(self): with alias_with_buffer_donor_config_context(True): met.clear_all() - xla_device = torch_xla.device() + xla_device = torch.device('xla') t0 = torch.randn(4, 2, 2).to(xla_device) self.assertTrue(torch_xla._XLAC._set_buffer_donation(t0, True)) self.assertTrue(torch_xla._XLAC._get_buffer_donation(t0)) @@ -308,7 +308,7 @@ def test_user_config_donation(self): def test_user_config_donation_inplace_aliasing(self): with alias_with_buffer_donor_config_context(True): met.clear_all() - xla_device = torch_xla.device() + xla_device = torch.device('xla') t0 = torch.randn(4, 2, 2).to(xla_device) self.assertTrue(torch_xla._XLAC._set_buffer_donation(t0, True)) self.assertTrue(torch_xla._XLAC._get_buffer_donation(t0)) @@ -322,7 +322,7 @@ def test_user_config_donation_inplace_aliasing(self): def test_user_config_donation_no_op_sync(self): with alias_with_buffer_donor_config_context(True): - xla_device = torch_xla.device() + xla_device = torch.device('xla') t0 = torch.randn(4, 2, 2).to(xla_device) self.assertTrue(torch_xla._XLAC._set_buffer_donation(t0, True)) torch_xla.sync() @@ -331,7 +331,7 @@ def test_user_config_donation_no_op_sync(self): self.assertTrue(torch_xla._XLAC._get_buffer_donation(t0)) def test_no_op_sync_keep_buffer_donation(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') input = torch.randn(5, 5).to(xla_device) self.assertTrue(torch_xla._XLAC._set_buffer_donation(input, True)) torch_xla.sync() @@ -346,7 +346,7 @@ def test_device_data_node_tracing_aliasing(self): for a given set of unmutated input tensor during its tracing. This helps ensure that aliasings can be retained if using the binding for tracing purposes. """ - xla_device = torch_xla.device() + xla_device = torch.device('xla') t0 = torch.tensor(10).to(xla_device) t1 = t0 + 5 diff --git a/test/test_jax_interop.py b/test/test_jax_interop.py index 5016462b982e..f852f239d524 100644 --- a/test/test_jax_interop.py +++ b/test/test_jax_interop.py @@ -14,7 +14,7 @@ def setUp(self): def test_call_jax(self): """Test that we can call a JAX function from PyTorch/XLA lazy tensor tracing.""" - dev = torch_xla.device() + dev = torch.device('xla') a = torch.ones((3, 3), device=dev) def f(a, b): @@ -29,7 +29,7 @@ def f(a, b): def test_call_jax_input_pytree(self): """Test that call_jax works with PyTree inputs.""" - dev = torch_xla.device() + dev = torch.device('xla') a = torch.ones((2, 2), device=dev) b = torch.ones((2, 2), device=dev) * 2 @@ -55,7 +55,7 @@ def f(inputs): def test_call_jax_output_pytree(self): """Test that call_jax works with PyTree outputs.""" - dev = torch_xla.device() + dev = torch.device('xla') a = torch.ones((2, 2), device=dev) def f(a): @@ -89,7 +89,7 @@ def f(a): def test_call_jax_some_arg_unused(self): """Test when the jax function doesn't use some input arguments.""" - dev = torch_xla.device() + dev = torch.device('xla') a = torch.randn((3, 3), device=dev) b = torch.randn((3, 3), device=dev) c = torch.randn((3, 3), device=dev) @@ -106,7 +106,7 @@ def f(a, b, c, d): def test_call_jax_grad(self): """Test calling a simple jax.grad transformed function.""" - dev = torch_xla.device() + dev = torch.device('xla') a = torch.randn((3, 3), device=dev, requires_grad=True) b = torch.randn((3, 3), device=dev, requires_grad=True) torch_xla.sync() @@ -143,7 +143,7 @@ def f_jax(a, b): def test_call_jax_non_tensor_args(self): """Test that call_jax works with non-tensor arguments.""" - dev = torch_xla.device() + dev = torch.device('xla') a = torch.ones((3, 3), device=dev) def f(a, num: float, string: str, dictionary: dict, none): @@ -173,7 +173,7 @@ def test_call_jax_cache_hlo(self): starting_cache_misses = xb._jax_to_xla_computation_cache_elements() # Let's trace two different jax functions a couple of times. - dev = torch_xla.device() + dev = torch.device('xla') a = torch.ones((3, 3), device=dev) def f(a, b): @@ -198,7 +198,7 @@ def test_call_jax_cache_by_shape(self): starting_cache_misses = xb._jax_to_xla_computation_cache_elements() # Let's trace the same jax function with different shapes. - dev = torch_xla.device() + dev = torch.device('xla') a = torch.ones((3, 3), device=dev) b = torch.ones((2, 2), device=dev) @@ -217,7 +217,7 @@ def test_call_jax_cache_by_tree_spec(self): starting_cache_misses = xb._jax_to_xla_computation_cache_elements() # Let's trace the same jax function with different tree specs. - dev = torch_xla.device() + dev = torch.device('xla') a = torch.ones((3, 3), device=dev) b = torch.ones((3, 2), device=dev) @@ -237,7 +237,7 @@ def test_call_jax_cache_by_static_args(self): starting_cache_misses = xb._jax_to_xla_computation_cache_elements() # Let's trace the same jax function with different static args. - dev = torch_xla.device() + dev = torch.device('xla') a = torch.ones((3, 3), device=dev) def f(a, num: float): @@ -255,7 +255,7 @@ def test_call_jax_with_different_jax_config(self): import jax starting_cache_misses = xb._jax_to_xla_computation_cache_elements() - dev = torch_xla.device() + dev = torch.device('xla') a = torch.ones((3, 3), device=dev) def f(a, b): diff --git a/test/test_metrics.py b/test/test_metrics.py index 69b9ab20a656..6cbbc9fc3340 100644 --- a/test/test_metrics.py +++ b/test/test_metrics.py @@ -24,7 +24,7 @@ def check_metrics_file(): class MetricsTest(unittest.TestCase): def test_clear_counters(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') t1 = torch.tensor(100, device=xla_device) t1 += 2 self.assertIn("xla::add", met.metrics_report()) @@ -39,7 +39,7 @@ def test_clear_counters(self): assert (len(met.counter_names()) > 0) def test_clear_metrics(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') t1 = torch.tensor(156, device=xla_device) self.assertIn("TensorToData", met.metrics_report()) assert (len(met.metric_names()) > 0) @@ -52,7 +52,7 @@ def test_clear_metrics(self): assert (len(met.metric_names()) > 0) def test_tracing_time_metrics(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') met.clear_all() t1 = torch.tensor(156, device=xla_device) t2 = t1 + 100 @@ -61,7 +61,7 @@ def test_tracing_time_metrics(self): def test_eager_metrics(self): with torch_xla.experimental.eager_mode_context(True): - xla_device = torch_xla.device() + xla_device = torch.device('xla') met.clear_all() t1 = torch.tensor(156, device=xla_device) t2 = t1 + 100 @@ -78,7 +78,7 @@ def test_eager_metrics(self): self.assertNotIn('ExecuteTime', met.metric_names()) def test_short_metrics_report_default_list(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') t1 = torch.tensor(1456, device=xla_device) t2 = t1 * 2 torch_xla.sync() @@ -100,7 +100,7 @@ def test_short_metrics_report_default_list(self): assert check_metrics_file() def test_short_metrics_report_custom_list(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') t1 = torch.tensor(100, device=xla_device) t2 = t1 * 2 t1 += 2 @@ -120,7 +120,7 @@ def test_short_metrics_report_custom_list(self): self.assertIn('InputOutputAliasCount', short_report) def test_short_metrics_fallback_counter(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') t1 = torch.tensor(100, device=xla_device) t2 = t1 * 2 # this will trigger a aten::_local_scalar_dense which is the same as fallback counter @@ -135,7 +135,7 @@ def test_short_metrics_fallback_counter(self): def test_metrics_report(self): # TODO(jwtan): Add test to cover TrimIrGraph, SyncTensorsToData, TransferToDeviceAsync, IrValueTensorToXlaData - xla_device = torch_xla.device() + xla_device = torch.device('xla') t1 = torch.tensor(2077, device=xla_device) t2 = t1 * 2 torch_xla.sync() @@ -207,7 +207,7 @@ def test_metrics_report(self): @unittest.skipIf(xr.device_type() != "CPU", f"This test only works on CPU.") def test_execute_time_metric(self): # Initialize the client before starting the timer. - torch_xla.device() + torch.device('xla') begin = time.perf_counter_ns() value = torch.randn( @@ -226,7 +226,7 @@ def test_execute_time_metric(self): def test_pybind_increment_counter(self): met.clear_all() - xla_device = torch_xla.device() + xla_device = torch.device('xla') t1 = torch.tensor(2077, device=xla_device) self.assertEqual(met.counter_value('CreateXlaTensor'), 1) torch_xla._XLAC._xla_increment_counter('CreateXlaTensor', 3) diff --git a/test/test_mp_all_gather.py b/test/test_mp_all_gather.py index 93d64f47ef3e..c364b69875c2 100644 --- a/test/test_mp_all_gather.py +++ b/test/test_mp_all_gather.py @@ -11,7 +11,7 @@ def all_gather(tensor, dim): def _mp_fn(index): - device = torch_xla.device() + device = torch.device('xla') world_size = xr.world_size() input_list_size = 5 if xm.xla_device_hw(device) in ('TPU', 'CUDA', 'NEURON'): diff --git a/test/test_mp_all_to_all.py b/test/test_mp_all_to_all.py index 9761507dea13..5a041463c7cd 100644 --- a/test/test_mp_all_to_all.py +++ b/test/test_mp_all_to_all.py @@ -6,7 +6,7 @@ def _mp_fn(index): - device = torch_xla.device() + device = torch.device('xla') if xm.xla_device_hw(device) in ('TPU', 'NEURON'): slots_per_device = 4 size = slots_per_device * xr.world_size() diff --git a/test/test_mp_collective_matmul.py b/test/test_mp_collective_matmul.py index 29f115c986cd..4e18fee0de2c 100644 --- a/test/test_mp_collective_matmul.py +++ b/test/test_mp_collective_matmul.py @@ -8,7 +8,7 @@ def _mp_fn(index): os.environ["ENABLE_COLLECTIVE_MATMUL_IN_MP"] = "1" - device = torch_xla.device() + device = torch.device('xla') world_size = xr.world_size() groups = [[i for i in range(world_size)]] scale = 1 / world_size diff --git a/test/test_mp_collective_permute.py b/test/test_mp_collective_permute.py index 81a1eb771bcd..31f9cc94ae3b 100644 --- a/test/test_mp_collective_permute.py +++ b/test/test_mp_collective_permute.py @@ -6,7 +6,7 @@ def _mp_fn(index): - device = torch_xla.device() + device = torch.device('xla') if xm.xla_device_hw(device) in ['TPU', 'NEURON']: world_size = xr.world_size() ordinal = xr.global_ordinal() diff --git a/test/test_mp_distributed_mm.py b/test/test_mp_distributed_mm.py index 7d6c7982cb2f..c6630e1a0a04 100644 --- a/test/test_mp_distributed_mm.py +++ b/test/test_mp_distributed_mm.py @@ -7,7 +7,7 @@ def _mp_fn(index): - device = torch_xla.device() + device = torch.device('xla') if xm.xla_device_hw(device) in ('TPU', 'CUDA'): world_size = xr.world_size() diff --git a/test/test_mp_early_exit.py b/test/test_mp_early_exit.py index 89e46722e232..275fb353c8db 100644 --- a/test/test_mp_early_exit.py +++ b/test/test_mp_early_exit.py @@ -12,7 +12,7 @@ def _mp_fn(): dist.init_process_group('xla', init_method='xla://') - device = torch_xla.device() + device = torch.device('xla') if xm.xla_device_hw(device) in ['TPU', 'CUDA']: train_loader = xu.SampleGenerator( data=torch.zeros(1, 12), sample_count=1024) diff --git a/test/test_mp_reduce_scatter.py b/test/test_mp_reduce_scatter.py index 12fc7fdfe1c8..375b8cc85b17 100644 --- a/test/test_mp_reduce_scatter.py +++ b/test/test_mp_reduce_scatter.py @@ -6,7 +6,7 @@ def _mp_fn(index): - device = torch_xla.device() + device = torch.device('xla') world_size = xr.world_size() scale = 1 / world_size scatter_dim = 1 diff --git a/test/test_mp_replication.py b/test/test_mp_replication.py index 61a302a65784..c21a4b83629e 100644 --- a/test/test_mp_replication.py +++ b/test/test_mp_replication.py @@ -10,7 +10,7 @@ def all_reduce(tensor): def _mp_fn(index): - device = torch_xla.device() + device = torch.device('xla') world_size = xr.world_size() if world_size > 1: ones = torch.ones((2, 3)) diff --git a/test/test_mp_save.py b/test/test_mp_save.py index ae9f46df120a..4ab45e9d81a7 100644 --- a/test/test_mp_save.py +++ b/test/test_mp_save.py @@ -35,7 +35,7 @@ def _get_data_str(data): def _mp_fn(index, temp_file): - device = torch_xla.device() + device = torch.device('xla') dd = _create_state_dict(device) xm.save(dd, temp_file) # User needs to manually rendezvous since only master process diff --git a/test/test_mp_sync_batch_norm.py b/test/test_mp_sync_batch_norm.py index fa4f18ad00d2..0ac2f720099d 100644 --- a/test/test_mp_sync_batch_norm.py +++ b/test/test_mp_sync_batch_norm.py @@ -47,7 +47,7 @@ def _sync_bn1d_no_channel(rank): t_global = torch.rand((xr.world_size() * bsz, length)) # XLA SyncBatchNorm - device = torch_xla.device() + device = torch.device('xla') t_xla = t_global[bsz * rank:bsz * (rank + 1), ...].to(device) sbn_xla = xf.SyncBatchNorm(length).to(device) result = run_step(sbn_xla, t_xla) @@ -72,7 +72,7 @@ def _sync_bn1d_multi_channel(rank): t_global = torch.rand((xr.world_size() * bsz, features, length)) # XLA SyncBatchNorm - device = torch_xla.device() + device = torch.device('xla') t_xla = t_global[bsz * rank:bsz * (rank + 1), ...].to(device) sbn_xla = xf.SyncBatchNorm(features).to(device) result = run_step(sbn_xla, t_xla) @@ -97,7 +97,7 @@ def _sync_bn2d(rank): t_global = torch.rand((xr.world_size() * bsz, features, h, w)) # XLA SyncBatchNorm - device = torch_xla.device() + device = torch.device('xla') t_xla = t_global[bsz * rank:bsz * (rank + 1), ...].to(device) sbn_xla = xf.SyncBatchNorm(features).to(device) result = run_step(sbn_xla, t_xla) @@ -122,7 +122,7 @@ def _sync_bn3d(rank): t_global = torch.rand((xr.world_size() * bsz, features, d, h, w)) # XLA SyncBatchNorm - device = torch_xla.device() + device = torch.device('xla') t_xla = t_global[bsz * rank:bsz * (rank + 1), ...].to(device) sbn_xla = xf.SyncBatchNorm(features).to(device) result = run_step(sbn_xla, t_xla) diff --git a/test/test_operations.py b/test/test_operations.py index 3f6774e87413..164787921ecc 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -179,7 +179,7 @@ def onlyIfPJRTDeviceIsCUDA(fn): class TestToXlaTensorArena(test_utils.XlaTestCase): def test(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') kdata = [_gen_tensor(2, 3), _gen_tensor(3, 4)] kdata.append([_gen_tensor(2, 5), _gen_tensor(3, 6)]) @@ -307,7 +307,7 @@ def loop_fn(model, loader, device, context): class TestLongGraphChain(test_utils.XlaTestCase): def test(self): - device = torch_xla.device() + device = torch.device('xla') orig_x = torch.Tensor([[1, 2], [3, 4]]) orig_y = torch.Tensor([[0.1, 0.2], [0.3, 0.4]]) x = orig_x @@ -440,7 +440,7 @@ def test_nonzero_cast(self): class TestOptimizationBarrier(test_utils.XlaTestCase): def test_optimization_barrier_correctness(self): - device = torch_xla.device() + device = torch.device('xla') # only test optimization_barrier on TPU if xm.xla_device_hw(device) != 'TPU': return @@ -459,7 +459,7 @@ def op_fn(a): return xb.Op.tuple((a, a.cast(xb.Type.BF16))) op = xor.register('test_mixed_dtype_tuple', op_fn) - xla_device = torch_xla.device() + xla_device = torch.device('xla') a_tensor = torch.randn([2, 3]).to(xla_device) a_result, a_cast = op(a_tensor) self.assertEqual(a_result.dtype, torch.float) @@ -530,7 +530,7 @@ def test_amp_foreach_non_finite_check_and_unscale_(self): found_inf_output0 = torch.tensor(0, dtype=torch.float32) found_inf_output1 = torch.tensor(1, dtype=torch.float32) - xla_device = torch_xla.device() + xla_device = torch.device('xla') xla_grads0 = grads0.to(xla_device) xla_inv_scale = inv_scale.to(xla_device) xla_found_inf = found_inf.to(xla_device) @@ -627,7 +627,7 @@ def test_no_storage(self): def test_slice_copy(self): a = torch.rand(3, 3, 3) - xla_device = torch_xla.device() + xla_device = torch.device('xla') xla_a = a.to(xla_device) shape = (4, 4, 4) b = a.new(*shape).zero_() @@ -638,7 +638,7 @@ def test_slice_copy(self): def test_slice_assign(self): a = torch.rand(3, 3, 3) - xla_device = torch_xla.device() + xla_device = torch.device('xla') xla_a = a.to(xla_device) shape = (4, 4, 4) b = a.new(*shape).zero_() @@ -649,7 +649,7 @@ def test_slice_assign(self): def test_slice_stepped_assign(self): a = torch.ones((10, 4)) - xla_device = torch_xla.device() + xla_device = torch.device('xla') xla_a = a.to(xla_device) a[:, 0::2] = 2 xla_a[:, 0::2] = 2 @@ -657,14 +657,14 @@ def test_slice_stepped_assign(self): def test_slice_stepped_other_assign(self): a = torch.ones((10, 4)) - xla_device = torch_xla.device() + xla_device = torch.device('xla') xla_a = a.to(xla_device) a[:, 1::4] = 2 xla_a[:, 1::4] = 2 self.assertEqual(a.data, xla_a.data.cpu()) def test_ailing_slice(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') a = torch.ones((1000, 324)).to(xla_device) xla_a = a.to(xla_device) w = a[:, 2::4] @@ -674,7 +674,7 @@ def test_ailing_slice(self): self.assertEqual(w.data, xla_w.data.cpu()) def test_slice_rnd_stepped_assign(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') size = 10 for s in range(0, size - 1): for e in range(1, size - s): @@ -691,7 +691,7 @@ def test_arange_nan(self): a = torch.arange(float('nan'), 5, device='xla') def test_empty_advanced_indexing(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') base = torch.randn(2, 3, 4, 5) xla_base = base.to(device=xla_device) result = base[:, torch.empty(0, 6, dtype=torch.int64)] @@ -702,7 +702,7 @@ def test_empty_advanced_indexing(self): "grad_input produces wrong results after functionalization. pytorch/pytorch#91199" ) def test_empty_strided(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') m = nn.Conv1d(4, 6, kernel_size=3, groups=2) a = torch.rand(2, 4, 6, requires_grad=True) xla_m = copy.deepcopy(m).to(xla_device) @@ -736,7 +736,7 @@ def test_clamp(self): self.assertEqual(b.data, xla_b.data.cpu()) def test_rrelu_module(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') a = torch.rand(1, 2, 2, requires_grad=True) xla_a = a.to(xla_device).detach() xla_a.requires_grad = True @@ -753,7 +753,7 @@ def test_rrelu_module(self): self.assertEqual(a.grad, xla_a.grad.cpu()) def test_max_broadcast(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') a = torch.rand(3, 1, 2) b = torch.rand(4, 2) c = torch.max(a, b) @@ -763,7 +763,7 @@ def test_max_broadcast(self): self.assertEqual(c.data, xla_c.data.cpu()) def test_sgn(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') t = torch.randn(2, 3, dtype=torch.cfloat) # Generate inf+infj t[0][0].real.div_(0) @@ -797,7 +797,7 @@ def test_sgn(self): @skipIfFunctionalizationDisabled("view_as_real unsupported") def test_view_as_real_c64(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') x = torch.randn(4, dtype=torch.cfloat, device=xla_device) real = torch.view_as_real(x) self.assertEqual(real.dtype, torch.float32) @@ -809,7 +809,7 @@ def test_view_as_real_c64(self): @skipIfFunctionalizationDisabled("view_as_real unsupported") def test_view_as_real_c128(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') x = torch.randn(4, dtype=torch.cdouble, device=xla_device) real = torch.view_as_real(x) self.assertEqual(real.dtype, torch.float64) @@ -821,7 +821,7 @@ def test_view_as_real_c128(self): @skipIfFunctionalizationDisabled("view_as_real unsupported") def test_view_as_complex_f32(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') x = torch.randn(4, 2, device=xla_device) complex = torch.view_as_complex(x) self.assertEqual(complex.dtype, torch.complex64) @@ -834,7 +834,7 @@ def test_view_as_complex_f32(self): @skipIfFunctionalizationDisabled("view_as_real unsupported") def test_view_as_complex_f64(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') x = torch.randn(4, 2, dtype=torch.float64, device=xla_device) complex = torch.view_as_complex(x) self.assertEqual(complex.dtype, torch.complex128) @@ -847,7 +847,7 @@ def test_view_as_complex_f64(self): torch_xla._XLAC._get_xla_tensors_text([complex]).split('\n')[-3]) def test_index_put(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') a = torch.tensor([1, 1, 1, 1]).to(xla_device).to(dtype=torch.float32) b = torch.rand(4) > 0.1 a[b] = 10 @@ -912,7 +912,7 @@ def test_fn(device): return loss, linear.weight.grad cpu_loss, cpu_weight_grad = test_fn('cpu') - xla_loss, xla_weight_grad = test_fn(torch_xla.device()) + xla_loss, xla_weight_grad = test_fn(torch.device('xla')) self.assertEqual(cpu_loss, xla_loss) self.assertEqual(cpu_weight_grad, xla_weight_grad) @@ -985,7 +985,7 @@ def func(root, b): def test_inplace_view_backprop_view(self): # modify view and backprop through view - xla_device = torch_xla.device() + xla_device = torch.device('xla') a = torch.tensor([2., 5.], device=xla_device, requires_grad=False) b = torch.tensor([3.], device=xla_device, requires_grad=True) res = a.narrow(0, 1, 1).mul_(b) @@ -1110,7 +1110,7 @@ def test_replace_xla_tensor(self): self.assertTrue(torch.allclose(t2.cpu(), torch.zeros(10))) def test_pred_type(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') a = torch.rand(4) b = torch.rand(4) xla_a = a.to(xla_device) @@ -1132,7 +1132,7 @@ def test_pred_type(self): self.runAtenTest(c, lambda x: x ^ x.byte()) def test_bitwise_and_not(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') a = torch.randint(255, (4,), dtype=torch.long) xla_a = a.to(xla_device) @@ -1142,27 +1142,27 @@ def test_fn(a): self.runAtenTest(a, test_fn) def test_s_copy_dtype(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') a = torch.rand(10).to(xla_device).to(dtype=torch.uint8) b = torch.tensor([0, 1, 2, 3]).to(xla_device) self.assertEqual(a[b].dtype, torch.uint8) def test_slice_zero_sized_dim(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') v = torch.randn(2, 3, 4, 5).to(xla_device) y = v[:, :, :, 1] z = y[:, 1:1, :] self.assertEqual(z.size()[1], 0) def test_byte_dtype(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') x = torch.ByteTensor([0, 1]).to(xla_device) y = torch.ByteTensor([0, 1]).to(xla_device) z = x + y self.assertEqual(z.dtype, torch.uint8) def test_frac_negative(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') a = torch.tensor(-3.2) b = a.frac() xla_a = a.to(xla_device) @@ -1170,7 +1170,7 @@ def test_frac_negative(self): self.assertEqual(b, xla_b) def test_flip(self): - device = torch_xla.device() + device = torch.device('xla') data = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8], device=device).view(2, 2, 2) self.assertEqual( torch.tensor([5, 6, 7, 8, 1, 2, 3, 4]).view(2, 2, 2), data.flip(0)) @@ -1193,7 +1193,7 @@ def test_flip(self): torch.tensor([6, 5, 8, 7, 2, 1, 4, 3]).view(2, 2, 2), data.flip(2, 0)) def test_flip_check_throws(self): - device = torch_xla.device() + device = torch.device('xla') data = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8], device=device).view(2, 2, 2) # not allow flip on the same dim more than once self.assertRaises(RuntimeError, lambda: data.flip(0, 1, 1)) @@ -1205,7 +1205,7 @@ def test_flip_check_throws(self): self.assertRaises(RuntimeError, lambda: data.flip(3)) def test_flip_expand(self): - device = torch_xla.device() + device = torch.device('xla') data = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8], device=device).view(2, 2, 2) expanded_data = torch.arange(1, 4, device=device).view(3, 1).expand(3, 2) transposed_data = torch.arange( @@ -1217,7 +1217,7 @@ def test_flip_expand(self): transposed_data.flip(0, 1, 2)) def test_flip_shape(self): - device = torch_xla.device() + device = torch.device('xla') data = torch.randn(2, 3, 4, device=device) size = [2, 3, 4] test_dims = [] @@ -1227,7 +1227,7 @@ def test_flip_shape(self): self.assertEqual(size, list(data.flip(ds).size())) def test_flip_rectangular(self): - device = torch_xla.device() + device = torch.device('xla') data = torch.tensor([1, 2, 3, 4, 5, 6]).view(2, 3).to(device) flip0_result = torch.tensor([[4, 5, 6], [1, 2, 3]]).to(device) flip1_result = torch.tensor([[3, 2, 1], [6, 5, 4]]).to(device) @@ -1236,13 +1236,13 @@ def test_flip_rectangular(self): self.assertEqual(flip1_result, data.flip(1)) def test_flip_empty_tensor(self): - device = torch_xla.device() + device = torch.device('xla') data = torch.tensor([]) self.assertEqual(data, data.flip(0)) def test_norm_p0(self): # p = 0 is equivalent to nonzero - xla_device = torch_xla.device() + xla_device = torch.device('xla') a = torch.randn(3, 2) xla_a = a.to(xla_device) norm = a.norm(p=0) @@ -1288,7 +1288,7 @@ def test_fn(input, src): self.runAtenTest([torch.zeros(3, 3), torch.ones(3)], test_fn) def test_scatter_add_bool(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') a = torch.tensor([[True, True, True, True, True], [True, True, True, True, True]]) b = torch.zeros(3, 5, dtype=torch.bool) @@ -1333,7 +1333,7 @@ def test_reduction_0dim(self): self.runAtenTest(torch.rand(2, 0, 4), lambda x: torch.mean(x)) self.runAtenTest(torch.rand(2, 0, 4), lambda x: torch.prod(x)) # min & max throws - xla_device = torch_xla.device() + xla_device = torch.device('xla') a = torch.rand(2, 0, 4) xla_a = a.to(xla_device) self.assertRaises(IndexError, lambda: torch.max(a, dim=1)) @@ -1469,11 +1469,11 @@ def check(device): d = a xm.check_view_sharing([a, d]) - check(torch_xla.device()) + check(torch.device('xla')) check(torch.device('cpu')) def test_save(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') x = torch.randn(5, device=xla_device) with tempfile.NamedTemporaryFile() as tf: torch.save(x, tf) @@ -1481,7 +1481,7 @@ def test_save(self): self.assertEqual(x, x_loaded) def test_save_bf16(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') x = torch.randn(5, dtype=torch.bfloat16, device=xla_device) with tempfile.NamedTemporaryFile() as tf: torch.save(x, tf) @@ -1489,7 +1489,7 @@ def test_save_bf16(self): self.assertEqual(x, x_loaded) def test_save_tuple(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') x = torch.randn(5, device=xla_device) number = 3 with tempfile.NamedTemporaryFile() as tf: @@ -1499,7 +1499,7 @@ def test_save_tuple(self): self.assertEqual(number, number_loaded) def test_save_api(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') model = XlaMNIST().to(xla_device) with tempfile.NamedTemporaryFile() as tf: xm.save(model.state_dict(), tf) @@ -1512,7 +1512,7 @@ def test_save_api(self): def test_serialization_api(self): with tempfile.TemporaryDirectory() as tmpdir: path = os.path.join(tmpdir, 'data.pt') - xla_device = torch_xla.device() + xla_device = torch.device('xla') model = XlaMNIST().to(xla_device) xser.save(model.state_dict(), path) state_dict = xser.load(path) @@ -1522,7 +1522,7 @@ def test_serialization_api(self): self.assertEqual(model.state_dict(), loaded_model.state_dict()) def test_deepcopy(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') x = torch.rand(5, device=xla_device) x0 = x[0] y = copy.deepcopy(x) @@ -1532,7 +1532,7 @@ def test_deepcopy(self): self.assertEqual(x[0], x0) def test_print(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') x = torch.tensor([5], device=xla_device) expected_str = 'tensor([5], device=\'' + str(xla_device) + '\')' self.assertEqual(str(x), expected_str) @@ -1727,14 +1727,14 @@ def test_fn(t): self.runAtenTest([torch.tensor(20.0)], test_fn) def test_view_and_copy_(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') x = torch.tensor([1.5, 2.5, 3.5, 4.5, 5.5, 6.5], device='cpu') y = torch.tensor([0, 0, 0, 0, 0, 0], device=xla_device) y[::2].copy_(x[::2]) self.assertEqual(y, [1, 0, 3, 0, 5, 0]) def test_view_and_multi_sync(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') t1 = torch.zeros(100, device=xla_device) t1[10] = 113 torch_xla.sync() @@ -1744,7 +1744,7 @@ def test_view_and_multi_sync(self): torch_xla._XLAC._get_xla_tensors_text([t1])) def test_binaryop_order(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') x = torch.rand(5, device=xla_device) y = torch.rand(5) self.assertEqual(x + y, y + x) @@ -1759,7 +1759,7 @@ def test_pow_constant(self): assert 'xla::device_data' not in const_hlo def test_emb_bf16(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') index = torch.ones(1, dtype=torch.long, device=xla_device) emb = torch.nn.Embedding(1024, 128, device=xla_device) emb = emb.to(torch.bfloat16) @@ -1779,7 +1779,7 @@ def test_on_device(device): return m(index) out = test_on_device("cpu") - out_x = test_on_device(torch_xla.device()) + out_x = test_on_device(torch.device('xla')) self.assertEqual(out, out_x.cpu()) def test_transpose_1d(self): @@ -1798,7 +1798,7 @@ def test_fn(t1): def test_sigmoid_bounds(self): torch.manual_seed(0) - xla_device = torch_xla.device() + xla_device = torch.device('xla') for _ in range(100): x = torch.rand(1000).to(xla_device) lower_bound = torch.sigmoid(x * (-100.0)) @@ -1807,7 +1807,7 @@ def test_sigmoid_bounds(self): assert torch.all(upper_bound <= 1.0) def test_manual_seed(self): - device = torch_xla.device() + device = torch.device('xla') torch_xla.manual_seed(12345) t1 = torch.randn(5, 5, device=device) torch_xla.manual_seed(12345) @@ -1815,7 +1815,7 @@ def test_manual_seed(self): self.assertTrue(torch.allclose(t1.cpu(), t2.cpu())) def test_cached_addcdiv(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') met.clear_all() t1 = torch.randn(1, 3).to(xla_device) @@ -1833,7 +1833,7 @@ def test_cached_addcdiv(self): @skipOnEagerDebug def test_print_execution(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') torch_xla.sync() xm.wait_device_ops() met.clear_all() @@ -1887,7 +1887,7 @@ def test_fn(input): return dropped[1].cpu(), input.grad.cpu() met.clear_all() - xla_device = torch_xla.device() + xla_device = torch.device('xla') input_cpu = torch.randn(7, 7, requires_grad=True) input_xla = torch.randn(7, 7, device=xla_device, requires_grad=True) mask_cpu, grad_cpu = test_fn(input_cpu) @@ -2045,7 +2045,7 @@ def foo(x): x = torch.arange(10).to(dtype) r = foo(x) - device = torch_xla.device() + device = torch.device('xla') Xx = x.to(device) Xr = foo(Xx) @@ -2074,7 +2074,7 @@ def func(input_volume): return F.interpolate( input_volume, size=output_size, mode='trilinear', align_corners=False) - device = torch_xla.device() + device = torch.device('xla') input_volume = torch.randn(1, 3, 16, 32, 32).to(device) met.clear_all() self.runAtenTest((input_volume), func) @@ -2105,7 +2105,7 @@ def foo(t): t.retain_grad() t.grad = torch.rand(10, 10, dtype=torch.bfloat16) xt = t.to('xla') - xt.grad = t.grad.to(torch_xla.device(), dtype=torch.bfloat16) + xt.grad = t.grad.to(torch.device('xla'), dtype=torch.bfloat16) foo(t) foo(xt) @@ -2393,7 +2393,7 @@ def run(device): return runf(*args_) actual = run("cpu") - expected = run(torch_xla.device()) + expected = run(torch.device('xla')) self.assertFalse( met.executed_fallback_ops(), msg="expected no fallback operations.") @@ -2452,7 +2452,7 @@ class TestModelComparator(test_utils.XlaTestCase): def test(self): SEED = 42 - xla_device = torch_xla.device() + xla_device = torch.device('xla') x = _gen_tensor(8, 1, 28, 28) xla_x = x.to(xla_device) @@ -2477,7 +2477,7 @@ def test(self): class TestWaitDeviceOps(test_utils.XlaTestCase): def test_wait_device_ops(self): - torch_xla.device() + torch.device('xla') value = torch.randn(10000, 10000, device='xla') val_list = [] val_mean_list = [] @@ -2496,7 +2496,7 @@ class TestDebuggingUtil(test_utils.XlaTestCase): @skipOnEagerDebug def test_get_xla_tensor_debug_info(self): - device = torch_xla.device() + device = torch.device('xla') # test non xla tensor cpu_t1 = torch.randn(5) cpu_t1_info = torch_xla._XLAC._get_xla_tensor_debug_info(cpu_t1) @@ -2531,7 +2531,7 @@ def runOpBuilderTest(self, kwargs=dict()): op = xor.register(name, opfn) if device is None: - device = torch_xla.device() + device = torch.device('xla') if aten_fn is None: aten_fn = opfn tensors = xu.as_list(tensors) @@ -2653,7 +2653,7 @@ class MpDecoratorTest(test_utils.XlaTestCase): @xtu.mp_test def test_mp_decorator(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') self.assertTrue(xla_device.type == 'xla') @@ -2692,7 +2692,7 @@ class TestLoweringContext(test_utils.XlaTestCase): def test_api(self): met.clear_all() - device = torch_xla.device() + device = torch.device('xla') a = torch.tensor([1.0, 2.0, 3.0], device=device) b = torch.tensor([4.0, 5.0, 6.0], device=device) @@ -2720,7 +2720,7 @@ def test_get_parameters_scalar(self): that appropriately. """ - device = torch_xla.device() + device = torch.device('xla') tensors = [] for i in range(10): # Add three copies of the same value. @@ -2753,13 +2753,13 @@ def test_git_revisons(self): self.assertTrue('torch' in revs) def test_send_to_device_grad(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') t = _gen_tensor(2, 2, requires_grad=True) dt = xm.send_cpu_data_to_device([t], xla_device) self.assertTrue(dt[0].requires_grad) def test_send_to_device_single(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') t = _gen_tensor(2, 2) dt = xm.send_cpu_data_to_device(t, xla_device) self.assertEqual(dt[0].device, xla_device) @@ -2859,7 +2859,7 @@ def from_tensors(self, tensors): wpack = PackWrapper(pack) - xla_device = torch_xla.device() + xla_device = torch.device('xla') xdata = xm.send_cpu_data_to_device(wpack, xla_device) self.assertTrue(isinstance(xdata, nn.utils.rnn.PackedSequence)) self.assertEqual(xdata.batch_sizes.device, torch.device('cpu')) @@ -2869,7 +2869,7 @@ def from_tensors(self, tensors): "https://github.com/pytorch/xla/pull/7864#issuecomment-2294034008") def test_as_strided_input_larger(self): size = (5, 5) - device = torch_xla.device() + device = torch.device('xla') a = torch.ones(size, device=device) small_a = a[:, ::2] @@ -2899,7 +2899,7 @@ def test_aten_move_scalar_cuda_to_xla(self): self._test_move_tensor_cuda_to_xla(torch.tensor(42)) def test_unsafe_buffer_pointer(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') xla_tensor_0 = torch.tensor(42).to(xla_device) # `torch_xla.sync()` ensures xtensor->CurrentDataHandle() != nullptr torch_xla.sync() @@ -2944,7 +2944,7 @@ def _test_dlpack_capsule_conversion_helper(self, xla_tensor): @onlyIfPJRTDeviceIsCUDA @parameterized.parameters(*all_types_and(torch.half, torch.bfloat16)) def test_dlpack_roundtrip_tensor(self, dtype): - xla_device = torch_xla.device() + xla_device = torch.device('xla') # xtensor->CurrentDataHandle() == nullptr but xtensor->CurrentIrValue().node != nullptr and device_data != nullptr # xla_tensor_2 uses XLANativeFunctions::_to_copy xla_tensor_2 = torch.arange(5, dtype=dtype).to(xla_device) @@ -2961,7 +2961,7 @@ def test_dlpack_roundtrip_tensor(self, dtype): *all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool, torch.uint16, torch.uint32, torch.uint64)) def test_dlpack_roundtrip_scalar(self, dtype): - xla_device = torch_xla.device() + xla_device = torch.device('xla') xla_tensor_0 = torch.tensor(42, dtype=dtype).to(xla_device) # `torch_xla.sync()` ensures xtensor->CurrentDataHandle() != nullptr torch_xla.sync() @@ -3118,7 +3118,7 @@ def forward(self, inp): class TestActivationCheckpoint(test_utils.XlaTestCase): def test_dropout(self): - device = torch_xla.device() + device = torch.device('xla') model = SimpleModelWithDropout().to(device) model = checkpoint_module(model) _input = torch.randn(128, 128, requires_grad=True) @@ -3132,7 +3132,7 @@ def test_dropout(self): f"in fwd {model.to_save[0]}, in bwd {model.to_save[1]}") def test_opt_barrier(self): - device = torch_xla.device() + device = torch.device('xla') model = SimpleModelWithDropout().to(device) model = checkpoint_module(model) _input = torch.randn(128, 128, requires_grad=True) @@ -3167,7 +3167,7 @@ def _reference_nms(self, boxes, scores, iou_threshold): def _nms(self, boxes, scores, iou_threshold): import torchvision - device = torch_xla.device() + device = torch.device('xla') return torchvision.ops.nms( boxes.to(device), scores.to(device), iou_threshold).cpu() @@ -3240,7 +3240,7 @@ class TestHelperFunction(test_utils.XlaTestCase): def test_repeat_truncated(self): from torch_xla.experimental.custom_kernel import repeat_with_fixed_output_size met.clear_all() - device = torch_xla.device() + device = torch.device('xla') total_repeat_length = 20 input = torch.randn(10).to(device) repeats = torch.tensor([0, 1, 2, 0, 4, 0, 6, 7, 8, 9]).to(device) @@ -3253,7 +3253,7 @@ def test_repeat_truncated(self): def test_repeat_extended(self): from torch_xla.experimental.custom_kernel import repeat_with_fixed_output_size met.clear_all() - device = torch_xla.device() + device = torch.device('xla') total_repeat_length = 100 input = torch.randn(10).to(device) repeats = torch.tensor([0, 5, 2, 0, 4, 9, 6, 7, 8, 0]).to(device) @@ -3271,7 +3271,7 @@ def test_repeat_extended(self): def test_repeat_special(self): from torch_xla.experimental.custom_kernel import repeat_with_fixed_output_size met.clear_all() - device = torch_xla.device() + device = torch.device('xla') total_repeat_length = 135 num_groups = 8 input = torch.arange(num_groups, dtype=torch.int32).to(device) diff --git a/test/test_placeholder.py b/test/test_placeholder.py index d5506bfacd55..5b6c2096a39e 100644 --- a/test/test_placeholder.py +++ b/test/test_placeholder.py @@ -19,7 +19,7 @@ def test_create_placeholder(self): ): p = create_placeholder_tensor(shape, dtype) assert isinstance(p, torch.Tensor) - assert p.device == torch_xla.device() + assert p.device == torch.device('xla') self.assertEqual(p.dtype, dtype) self.assertEqual(p.shape, shape) self.assertTrue(torch_xla._XLAC._is_placecholder(p)) @@ -56,7 +56,7 @@ def test_placeholder_handle_unique(self): self.assertNotEqual(h1, h2) def test_cannot_get_handle_from_deleted_pjrt_buffer(self): - xla_device = torch_xla.device() + xla_device = torch.device('xla') t0 = torch.randn(4, 2, 2).to(xla_device) t1 = torch.randn(4, 2, 2).to(xla_device) self.assertTrue(torch_xla._XLAC._set_buffer_donation(t0, True)) diff --git a/test/test_profile_mp_mnist.py b/test/test_profile_mp_mnist.py index e23c2f59c223..2dbd67655918 100644 --- a/test/test_profile_mp_mnist.py +++ b/test/test_profile_mp_mnist.py @@ -144,7 +144,7 @@ def train_mnist(flags, # Scale learning rate to num cores lr = flags.lr * xr.world_size() - device = torch_xla.device() + device = torch.device('xla') model = MNIST().to(device) writer = None if xm.is_master_ordinal(): diff --git a/test/test_python_ops.py b/test/test_python_ops.py index 9dc145947f62..557bf5c4c278 100644 --- a/test/test_python_ops.py +++ b/test/test_python_ops.py @@ -29,8 +29,8 @@ def test_put(self, dtype): raise unittest.SkipTest("Dtype {0} is unsupported by XLA".format( str(dtype))) - device = torch_xla.device() - real_device_type = xm.xla_device_hw(str(torch_xla.device())) + device = torch.device('xla') + real_device_type = xm.xla_device_hw(str(torch.device('xla'))) if real_device_type == "TPU": raise unittest.SkipTest("TestPut is too slow on TPU. Skipped") @@ -108,7 +108,7 @@ def test_index_copy(self, dtype): raise unittest.SkipTest("Dtype {0} is unsupported by XLA".format( str(dtype))) - device = torch_xla.device() + device = torch.device('xla') # We just test for num_copy <= num_dest, as otherwise there are repeated indices # and the behavior is undefined diff --git a/test/test_syncfree_optimizers.py b/test/test_syncfree_optimizers.py index 8807271440c6..593ea06b83f1 100644 --- a/test/test_syncfree_optimizers.py +++ b/test/test_syncfree_optimizers.py @@ -53,7 +53,7 @@ def _test_optimizer(self, syncfree_optim_cls, ref_optim_cls, optim_kwargs={'lr': 1e-2}): - device = torch_xla.device() + device = torch.device('xla') loss_fn = nn.NLLLoss() # syncfree model torch.manual_seed(0) diff --git a/test/test_torch_distributed_fsdp_frozen_weight.py b/test/test_torch_distributed_fsdp_frozen_weight.py index 98730dbf7009..898c249625e2 100644 --- a/test/test_torch_distributed_fsdp_frozen_weight.py +++ b/test/test_torch_distributed_fsdp_frozen_weight.py @@ -7,7 +7,7 @@ def _mp_fn(index): - dev = torch_xla.device() + dev = torch.device('xla') if xm.xla_device_hw(dev) not in ('TPU', 'CUDA'): print( 'Default device {} is not a TPU or CUDA device'.format(dev), diff --git a/test/test_torch_distributed_xla_backend.py b/test/test_torch_distributed_xla_backend.py index a3069a6637ec..54626415255f 100644 --- a/test/test_torch_distributed_xla_backend.py +++ b/test/test_torch_distributed_xla_backend.py @@ -63,7 +63,7 @@ def test_xla_backend_exists(self): self.assertIsNotNone(pg_xla_creator) def test_allreduce(self): - device = torch_xla.device() + device = torch.device('xla') tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank() all_reduce_pattern = r'%all\-reduce\.\d+ = .+ all\-reduce\(' dist.all_reduce(tensor) @@ -72,7 +72,7 @@ def test_allreduce(self): @patch_world(rank=3, size=6) def test_allreduce_with_mesh(self): - device = torch_xla.device() + device = torch.device('xla') tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank() pg_options = {'xla_pg_options': {'spmd': True}} @@ -89,7 +89,7 @@ def test_allreduce_with_mesh(self): @patch_world(rank=3, size=8) def test_allgather(self): - device = torch_xla.device() + device = torch.device('xla') tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank() output_tensors = [torch.zeros_like(tensor, device=device) for _ in range(8)] all_gather_pattern = r'%all\-gather\.\d+ = .+ all\-gather\(' @@ -99,7 +99,7 @@ def test_allgather(self): @patch_world(rank=3, size=8) def test_all_scalar_allgather(self): - device = torch_xla.device() + device = torch.device('xla') tensor = torch.zeros((), device=device) + 1 + 2 * dist.get_rank() output_tensors = [torch.zeros_like(tensor, device=device) for _ in range(8)] all_gather_pattern = r'%all\-gather\.\d+ = .+ all\-gather\(' @@ -109,7 +109,7 @@ def test_all_scalar_allgather(self): @patch_world(rank=3, size=8) def test_allgather_coalesced(self): - device = torch_xla.device() + device = torch.device('xla') tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank() tensor2 = torch.arange(5, device=device) + 1 + 2 * dist.get_rank() pg_xla = get_process_group_xla(rank=3, size=8) @@ -127,7 +127,7 @@ def test_allgather_coalesced(self): hlo_matches(hlo, all_gather_pattern) def test_broadcast(self): - device = torch_xla.device() + device = torch.device('xla') tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank() all_reduce_pattern = r'%all\-reduce\.\d+ = .+ all\-reduce\(' dist.broadcast(tensor, 0) @@ -136,7 +136,7 @@ def test_broadcast(self): # Needed for ZeRO stage 1 def test_reduce_scatter(self): - device = torch_xla.device() + device = torch.device('xla') tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank() input_list = [tensor] output = torch.zeros_like(tensor) @@ -148,7 +148,7 @@ def test_reduce_scatter(self): @skipIf(xr.device_type() == 'CPU', "UNIMPLEMENTED: ReduceScatter is not implemented on CPU.") def test_reduce_scatter_coalesced(self): - device = torch_xla.device() + device = torch.device('xla') tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank() tensor2 = torch.arange(5, device=device) + 1 + 2 * dist.get_rank() input_tensors_list = [[tensor, tensor], [tensor2, tensor2]] @@ -168,7 +168,7 @@ def test_reduce_scatter_coalesced(self): @patch_world(0, 6) def test_send(self): - device = torch_xla.device() + device = torch.device('xla') tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank() input_list = [tensor] @@ -185,11 +185,11 @@ def test_send(self): hlo_matches(hlo, senddone_pattern) # Don't try to run Send on CPU because it's not implemented - torch_xla._XLAC._clear_pending_irs(str(torch_xla.device())) + torch_xla._XLAC._clear_pending_irs(str(torch.device('xla'))) @patch_world(0, 6) def test_recv(self): - device = torch_xla.device() + device = torch.device('xla') tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank() with mock.patch.object( @@ -205,7 +205,7 @@ def test_recv(self): hlo_matches(hlo, recvdone_pattern) # Don't try to run Recv on CPU because it's not implemented - torch_xla._XLAC._clear_pending_irs(str(torch_xla.device())) + torch_xla._XLAC._clear_pending_irs(str(torch.device('xla'))) @patch_world(rank=0, size=12) def test_new_group_no_ranks(self): @@ -365,7 +365,7 @@ def test_barrier(self): 'monitored_barrier', ) def test_unimplemented_op(self, op): - device = torch_xla.device() + device = torch.device('xla') tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank() pg_xla = dist.group.WORLD self.assertIsInstance(pg_xla, diff --git a/test/test_train_mp_imagenet.py b/test/test_train_mp_imagenet.py index efb34a2cc3af..0a031c1d0cba 100644 --- a/test/test_train_mp_imagenet.py +++ b/test/test_train_mp_imagenet.py @@ -250,7 +250,7 @@ def train_imagenet(): torch.manual_seed(42) - device = torch_xla.device() + device = torch.device('xla') model = get_model_property('model_fn')().to(device) # Initialization is nondeterministic with multiple threads in PjRt. diff --git a/test/test_train_mp_imagenet_amp.py b/test/test_train_mp_imagenet_amp.py index 290857281fd7..c5bf26b9e4cf 100644 --- a/test/test_train_mp_imagenet_amp.py +++ b/test/test_train_mp_imagenet_amp.py @@ -194,7 +194,7 @@ def train_imagenet(): torch.manual_seed(42) - device = torch_xla.device() + device = torch.device('xla') device_hw = xm.xla_device_hw(device) model = get_model_property('model_fn')().to(device) writer = None @@ -229,7 +229,7 @@ def train_loop_fn(loader, epoch): for step, (data, target) in enumerate(loader): optimizer.zero_grad() if FLAGS.amp: - with autocast(torch_xla.device()): + with autocast(torch.device('xla')): output = model(data) loss = loss_fn(output, target) if scaler: diff --git a/test/test_train_mp_imagenet_fsdp.py b/test/test_train_mp_imagenet_fsdp.py index 1d939d8385b3..3423b3e4df59 100644 --- a/test/test_train_mp_imagenet_fsdp.py +++ b/test/test_train_mp_imagenet_fsdp.py @@ -241,7 +241,7 @@ def train_imagenet(): torch.manual_seed(42) - device = torch_xla.device() + device = torch.device('xla') model = get_model_property('model_fn')() # Automatic wrapping sub-modules with inner FSDP auto_wrap_policy = None diff --git a/test/test_train_mp_mnist.py b/test/test_train_mp_mnist.py index 0a5e46fdcd1f..4aa328752e89 100644 --- a/test/test_train_mp_mnist.py +++ b/test/test_train_mp_mnist.py @@ -130,7 +130,7 @@ def train_mnist(flags, **kwargs): # Scale learning rate to num cores lr = flags.lr * xr.world_size() - device = torch_xla.device() + device = torch.device('xla') model = MNIST().to(device) # Initialization is nondeterministic with multiple threads in PjRt. diff --git a/test/test_train_mp_mnist_amp.py b/test/test_train_mp_mnist_amp.py index 0bd393b21f2e..d6fac172003a 100644 --- a/test/test_train_mp_mnist_amp.py +++ b/test/test_train_mp_mnist_amp.py @@ -130,7 +130,7 @@ def train_mnist(flags, **kwargs): # Scale learning rate to num cores lr = flags.lr * xr.world_size() - device = torch_xla.device() + device = torch.device('xla') device_hw = xm.xla_device_hw(device) model = MNIST().to(device) diff --git a/test/test_train_mp_mnist_fsdp_with_ckpt.py b/test/test_train_mp_mnist_fsdp_with_ckpt.py index 833612a2be49..c6aa20bc1d68 100644 --- a/test/test_train_mp_mnist_fsdp_with_ckpt.py +++ b/test/test_train_mp_mnist_fsdp_with_ckpt.py @@ -164,7 +164,7 @@ def train_mnist(flags, **kwargs): # Scale learning rate to num cores lr = flags.lr * xr.world_size() - device = torch_xla.device() + device = torch.device('xla') model = MNIST() # Automatic wrapping sub-modules with inner FSDP auto_wrap_policy = None diff --git a/test/test_train_mp_mnist_zero1.py b/test/test_train_mp_mnist_zero1.py index 523bf5fc0a19..11926c273697 100644 --- a/test/test_train_mp_mnist_zero1.py +++ b/test/test_train_mp_mnist_zero1.py @@ -114,7 +114,7 @@ def train_mnist(flags, **kwargs): # Scale learning rate to num cores lr = flags.lr * xr.world_size() - device = torch_xla.device() + device = torch.device('xla') model = MNIST().to(device) writer = None diff --git a/test/test_user_computation_debug_cache.py b/test/test_user_computation_debug_cache.py index f83f856c2cfd..a6fb1cd885ae 100644 --- a/test/test_user_computation_debug_cache.py +++ b/test/test_user_computation_debug_cache.py @@ -40,7 +40,7 @@ def input_scope_0(tensor): def input_scope_1(tensor): return [torch.sin(tensor), torch.cos(tensor)] - device = torch_xla.device() + device = torch.device('xla') init_tensor = torch.tensor(10).to(device) def create_user_computation(fn): diff --git a/test/test_utils.py b/test/test_utils.py index 6a913f932e4d..2bbf7255182c 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -384,7 +384,7 @@ def compareResults(self, results, xla_results, rel_err=1e-2, abs_err=1e-5): def runAtenTest(self, tensors, fn, device=None, rel_err=1e-2, abs_err=1e-5): if device is None: - device = torch_xla.device() + device = torch.device('xla') tensors = xu.as_list(tensors) xla_tensors = [ x.to(device).detach().requires_grad_(x.requires_grad) for x in tensors diff --git a/test/test_while_loop.py b/test/test_while_loop.py index 4dc0a17a96ea..d58b18eb3e45 100644 --- a/test/test_while_loop.py +++ b/test/test_while_loop.py @@ -26,7 +26,7 @@ def _fake_while_loop(cond_fn, body_fn, operands): class WhileLoopTest(unittest.TestCase): def test_while_loop_addition(self): - device = torch_xla.device() + device = torch.device('xla') def cond_fn(iteri, x): return iteri > 0 @@ -41,7 +41,7 @@ def body_fn(iteri, x): self.assertTrue(torch.all(torch.eq(res_with_loop, res_without_loop))) def test_while_loop_addition_nested(self): - device = torch_xla.device() + device = torch.device('xla') def cond_fn(iteri, x): return iteri > 0 @@ -56,7 +56,7 @@ def body_fn(iteri, x): self.assertTrue(torch.all(torch.eq(res_with_loop, res_without_loop))) def test_while_loop_simple_linear_inside_loop(self): - device = torch_xla.device() + device = torch.device('xla') torch.set_grad_enabled(False) class SimpleLinear(torch.nn.Module): @@ -94,7 +94,7 @@ def forward_without_while_loop_op(self, iteri, x): # ====== fori_loop ====== @unittest.skip("Fori_loop is not supported now due to unstable result.") def test_fori_loop_addition(self): - device = torch_xla.device() + device = torch.device('xla') lower = torch.tensor(0, device=device) upper = torch.tensor(50, device=device) diff --git a/test/test_zero1.py b/test/test_zero1.py index 8bb2fbc3d822..1a798abc1d9c 100644 --- a/test/test_zero1.py +++ b/test/test_zero1.py @@ -34,7 +34,7 @@ class XlaZeRO1Test(test_utils.XlaTestCase): @unittest.skipIf(xr.device_type() == 'TPU', "Crash on TPU") def test_zero1(self): - device = torch_xla.device() + device = torch.device('xla') model = nn.Linear(32, 32) x = torch.ones((32, 32)) @@ -89,7 +89,7 @@ def test_zero1(self): torch_xla.sync() def test_zero1_load(self): - device = torch_xla.device() + device = torch.device('xla') model = nn.Linear(32, 32) x = torch.ones((32, 32)) @@ -153,7 +153,7 @@ def test_zero1_load(self): def _mp_fn(index): - device = torch_xla.device() + device = torch.device('xla') if xm.xla_device_hw(device) in ('TPU', 'CUDA'): test = unittest.main(exit=False) sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/test/torch_distributed/test_ddp.py b/test/torch_distributed/test_ddp.py index 1d91f520d5aa..6e8c01a3f7b9 100644 --- a/test/torch_distributed/test_ddp.py +++ b/test/torch_distributed/test_ddp.py @@ -24,7 +24,7 @@ def _ddp_correctness(rank, gradient_as_bucket_view: bool = False): # We cannot run this guard before XMP, # see API_GUIDE.md#running-on-multiple-xla-devices-with-multi-processing. - device = torch_xla.device() + device = torch.device('xla') if xm.xla_device_hw(device) not in ('TPU', 'CUDA'): print( 'Default device {} is not a TPU device'.format(device), diff --git a/test/torch_distributed/test_torch_distributed_all_gather_xla_backend.py b/test/torch_distributed/test_torch_distributed_all_gather_xla_backend.py index 7c30b211ad49..125121b8c798 100644 --- a/test/torch_distributed/test_torch_distributed_all_gather_xla_backend.py +++ b/test/torch_distributed/test_torch_distributed_all_gather_xla_backend.py @@ -9,7 +9,7 @@ def _mp_fn(index): - device = torch_xla.device() + device = torch.device('xla') if xm.xla_device_hw(device) in ('TPU', 'CUDA', 'NEURON'): world_size = xr.world_size() rank = xr.global_ordinal() diff --git a/test/torch_distributed/test_torch_distributed_all_reduce_xla_backend.py b/test/torch_distributed/test_torch_distributed_all_reduce_xla_backend.py index 2fd71d2ed84e..18bec4fecdc0 100644 --- a/test/torch_distributed/test_torch_distributed_all_reduce_xla_backend.py +++ b/test/torch_distributed/test_torch_distributed_all_reduce_xla_backend.py @@ -9,7 +9,7 @@ def _mp_fn(index): - device = torch_xla.device() + device = torch.device('xla') if xm.xla_device_hw(device) in ('TPU', 'CUDA', 'NEURON'): world_size = xr.world_size() dist.init_process_group('xla', init_method='xla://') diff --git a/test/torch_distributed/test_torch_distributed_bucketed_all_reduce_xla_backend.py b/test/torch_distributed/test_torch_distributed_bucketed_all_reduce_xla_backend.py index c462f7552800..82eff827fc9f 100644 --- a/test/torch_distributed/test_torch_distributed_bucketed_all_reduce_xla_backend.py +++ b/test/torch_distributed/test_torch_distributed_bucketed_all_reduce_xla_backend.py @@ -9,7 +9,7 @@ def _mp_fn(index): - device = torch_xla.device() + device = torch.device('xla') if xm.xla_device_hw(device) in ('TPU', 'CUDA', 'NEURON'): world_size = xr.world_size() rank = xr.global_ordinal() diff --git a/test/torch_distributed/test_torch_distributed_fsdp_meta.py b/test/torch_distributed/test_torch_distributed_fsdp_meta.py index 444c47890330..2f382eb86246 100644 --- a/test/torch_distributed/test_torch_distributed_fsdp_meta.py +++ b/test/torch_distributed/test_torch_distributed_fsdp_meta.py @@ -141,7 +141,7 @@ def meta_module_fn(): def _mp_fn(index): - device = torch_xla.device() + device = torch.device('xla') # This test fails on GPU with 03/30 TF-pin update (https://github.com/pytorch/xla/pull/4840) if xm.xla_device_hw(device) in ('TPU', 'NEURON'): dist.init_process_group('xla', init_method='xla://') diff --git a/test/torch_distributed/test_torch_distributed_multi_all_reduce_xla_backend.py b/test/torch_distributed/test_torch_distributed_multi_all_reduce_xla_backend.py index 9089f9d799ff..affc32c4a73d 100644 --- a/test/torch_distributed/test_torch_distributed_multi_all_reduce_xla_backend.py +++ b/test/torch_distributed/test_torch_distributed_multi_all_reduce_xla_backend.py @@ -9,7 +9,7 @@ def _mp_fn(index): - device = torch_xla.device() + device = torch.device('xla') if xm.xla_device_hw(device) in ('TPU', 'CUDA', 'NEURON'): world_size = xr.world_size() rank = xr.global_ordinal() diff --git a/test/torch_distributed/test_torch_distributed_reduce_scatter_xla_backend.py b/test/torch_distributed/test_torch_distributed_reduce_scatter_xla_backend.py index 006d3fd33a95..90ccbfb64d0c 100644 --- a/test/torch_distributed/test_torch_distributed_reduce_scatter_xla_backend.py +++ b/test/torch_distributed/test_torch_distributed_reduce_scatter_xla_backend.py @@ -9,7 +9,7 @@ def _mp_fn(index): - device = torch_xla.device() + device = torch.device('xla') if xm.xla_device_hw(device) in ('TPU', 'CUDA'): world_size = xr.world_size() rank = xr.global_ordinal() diff --git a/test/utils/train_spmd_linear_model.py b/test/utils/train_spmd_linear_model.py index ac0dd9f86b22..4407bd665f65 100644 --- a/test/utils/train_spmd_linear_model.py +++ b/test/utils/train_spmd_linear_model.py @@ -69,7 +69,7 @@ def forward(self, x): def train(): - device = torch_xla.device() + device = torch.device('xla') torch.manual_seed(42) model = SimpleLinear().to(device) print('===> Preparing data..') diff --git a/test/utils/train_spmd_linear_model_grad_acc.py b/test/utils/train_spmd_linear_model_grad_acc.py index 62f3e79ae4a0..2d6ccfd71a3c 100644 --- a/test/utils/train_spmd_linear_model_grad_acc.py +++ b/test/utils/train_spmd_linear_model_grad_acc.py @@ -77,7 +77,7 @@ def forward(self, x): def train(): - device = torch_xla.device() + device = torch.device('xla') num_devices = xr.global_runtime_device_count() print(f'num_devices: {num_devices}') # Define a mesh with all devices along one axis diff --git a/torch_xla/_dynamo/dynamo_backend2.py b/torch_xla/_dynamo/dynamo_backend2.py index e3fee43f792b..1d515c9cc63f 100644 --- a/torch_xla/_dynamo/dynamo_backend2.py +++ b/torch_xla/_dynamo/dynamo_backend2.py @@ -34,7 +34,7 @@ def _dynamo_backend(model: torch.fx.GraphModule, sample_args: Any): jax.config.update("jax_enable_x64", True) env = torchax.default_env() - xla_device = torch_xla.device() + xla_device = torch.device('xla') def run_jax(*args, initial_rng_key): args_t = torchax.interop.torch_view(args) diff --git a/torch_xla/_dynamo/dynamo_bridge.py b/torch_xla/_dynamo/dynamo_bridge.py index 7cae4f7392e5..1061406746f7 100644 --- a/torch_xla/_dynamo/dynamo_bridge.py +++ b/torch_xla/_dynamo/dynamo_bridge.py @@ -498,7 +498,7 @@ def extract_graph_helper(xla_model: torch.fx.GraphModule, # 2. All of the pending IRs are result of our warm up cache tracing and they # should be removed to avoid extra computation executed and in place updates op # mistakenlly update the input tensors. - torch_xla._XLAC._clear_pending_irs(str(torch_xla.device())) + torch_xla._XLAC._clear_pending_irs(str(torch.device('xla'))) vars_to_return = (xla_args_sharding_spec, args_and_out, graph_hash, arg_index_to_need_update_index, none_remover, @@ -567,7 +567,7 @@ def optimized_mod(*args: tuple): is_cuda_args = original_device.type == "cuda" if is_cuda_args: - args = _maybe_move_tensors_to_device(args, torch_xla.device()) + args = _maybe_move_tensors_to_device(args, torch.device('xla')) if not config.skip_input_data_check: # `torch_xla.sync()` needs to be blocking since we want to access args's @@ -768,7 +768,7 @@ def partition_fx_graph_for_cpu_fallback(xla_model, xla_args, all_xla_args, # UnsupportedNodesCollector might trigger in place ops, need to clear them here. _clear_pending_irs_on_args(all_xla_args_tensor_only, cloned_args) - torch_xla._XLAC._clear_pending_irs(str(torch_xla.device())) + torch_xla._XLAC._clear_pending_irs(str(torch.device('xla'))) class XlaOperatorSupport(torch.fx.passes.operator_support.OperatorSupport): @@ -813,7 +813,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: def extract_compiled_graph_helper(xla_model: torch.fx.GraphModule, xla_args): if _args_on_cuda(xla_args): xla_args = tuple( - _maybe_move_tensors_to_device(xla_args, torch_xla.device())) + _maybe_move_tensors_to_device(xla_args, torch.device('xla'))) # Synchronize xla_args, so that each FunctionalTensorWrapper argument updates its # value reference before actually computing it. diff --git a/torch_xla/_internal/pjrt.py b/torch_xla/_internal/pjrt.py index 25a0ee36c36e..9689fb7fb703 100644 --- a/torch_xla/_internal/pjrt.py +++ b/torch_xla/_internal/pjrt.py @@ -104,7 +104,7 @@ def initialize_singleprocess(): plugins.default().configure_single_process() elif runtime.device_type() == 'TPU': tpu.configure_one_chip_topology() - xm.set_replication(torch_xla.device(), []) + xm.set_replication(torch.device('xla'), []) def initialize_multiprocess(local_rank: int, local_world_size: int): @@ -119,7 +119,7 @@ def initialize_multiprocess(local_rank: int, local_world_size: int): neuron.initialize_env(local_rank, local_world_size) devices = xm.get_xla_supported_devices() - xm.set_replication(torch_xla.device(), devices) + xm.set_replication(torch.device('xla'), devices) def run_multiprocess(fn: Callable[..., R], diff --git a/torch_xla/core/xla_op_registry.py b/torch_xla/core/xla_op_registry.py index 62943f4c70c5..2e82d49028d4 100644 --- a/torch_xla/core/xla_op_registry.py +++ b/torch_xla/core/xla_op_registry.py @@ -68,7 +68,7 @@ def slice_and_add(a, b, dimno=0): SLICE_AND_ADD = xor.register('slice_and_add', slice_and_add) def user_computation_test(): - device = torch_xla.device() + device = torch.device('xla') x = torch.randn(2, 2).to(device) y = torch.randn(2, 2).to(device) z = SLICE_AND_ADD(x, y, dimno=0) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 68bb7ea7a48e..feda7894c081 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -1310,7 +1310,7 @@ void BuildLoweringContextSubmodule(py::module* m) { * import torch_xla * import torch_xla.core.xla_model as xm * - * device = torch_xla.device() + * device = torch.device('xla') * example = torch.tensor([1.0, 2.0, 3.0, 4.0], device=device) * * def network(x): diff --git a/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py b/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py index c5605d2b3ed2..fb5e41cc92c5 100644 --- a/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py +++ b/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py @@ -139,7 +139,7 @@ class XlaFullyShardedDataParallel(nn.Module): module (nn.Module): module to be wrapped with FSDP. If the input module's parameters and buffers are not already on XLA device, they will be cast to - ``torch_xla.device()`` (after sharding) during FSDP initialization. + ``torch.device('xla')`` (after sharding) during FSDP initialization. reshard_after_forward (bool, Optional): if ``True``, reshard parameters after the forward pass. This saves memory but slows training. This is only relevant when resharding @@ -527,7 +527,7 @@ def __init__( List[Parameter], self._fsdp_wrapped_module.flat_params) + non_flatten_params - self.xla_device = torch_xla.device() + self.xla_device = torch.device('xla') # Shard module parameters in place self._shard_parameters_(params_to_shard) # Cast the module buffers to the specified buffer_dtype @@ -1646,7 +1646,7 @@ def _print_r0(self, msg: str, restart: bool = False) -> None: if restart: self._tstart = time.time() if self.rank == 0: - memory_info = xm.get_memory_info(torch_xla.device()) + memory_info = xm.get_memory_info(torch.device('xla')) gb_free = memory_info["kb_free"] / 1024 / 1024 gb_total = memory_info["kb_total"] / 1024 / 1024 logging.info( diff --git a/torch_xla/distributed/parallel_loader.py b/torch_xla/distributed/parallel_loader.py index 05a37fe9b411..f84b71d32f9d 100644 --- a/torch_xla/distributed/parallel_loader.py +++ b/torch_xla/distributed/parallel_loader.py @@ -265,7 +265,7 @@ class MpDeviceLoader(object): Example: - >>> device = torch_xla.device() + >>> device = torch.device('xla') >>> train_device_loader = MpDeviceLoader(train_loader, device) """ diff --git a/torch_xla/distributed/spmd/api.py b/torch_xla/distributed/spmd/api.py index 567fba1ad015..29c930af5d9a 100644 --- a/torch_xla/distributed/spmd/api.py +++ b/torch_xla/distributed/spmd/api.py @@ -216,7 +216,7 @@ def xla_distribute_module( if partition_fn: if getattr(partition_fn, '__name__', 'unknown') == "auto_policy": # TODO(yeounoh) allow pre-loading to xla device in the future. - assert next(module.parameters()).device != torch_xla.device(), \ + assert next(module.parameters()).device != torch.device('xla'), \ f"Currently requires module to be on cpu, before xla_distribute_module." xr.use_spmd(auto=True) module = module.to('xla') diff --git a/torch_xla/distributed/spmd/xla_sharding.py b/torch_xla/distributed/spmd/xla_sharding.py index 49229b17cffe..239a2bce1043 100644 --- a/torch_xla/distributed/spmd/xla_sharding.py +++ b/torch_xla/distributed/spmd/xla_sharding.py @@ -334,7 +334,7 @@ def _get_physical_tpu_mesh(self, devices: np.ndarray) -> np.ndarray: A np.ndarray of device logical ordinals with shape [global_x, global_y, global_z]. On v2 and v3, global_z is instead cores_per_chip (i.e., 2). """ - assert xm.xla_device_hw(torch_xla.device()) == 'TPU' + assert xm.xla_device_hw(torch.device('xla')) == 'TPU' # coords is a 3-dims tuple representing the device in physical mesh device_coords = [self.device_attributes[d]['coords'] for d in devices] dims = tuple(d + 1 for d in max(device_coords)) @@ -826,7 +826,7 @@ def can_apply(self, t: torch.Tensor) -> bool: def apply(self, t: torch.Tensor): # TODO(yeounoh) use virtual device interface when available. - assert (t.device == torch_xla.device()) + assert (t.device == torch.device('xla')) mark_sharding(t, self.mesh, self.partition_spec) diff --git a/torch_xla/distributed/xla_multiprocessing.py b/torch_xla/distributed/xla_multiprocessing.py index e3b349a4b7fb..b14fde5bb1a8 100644 --- a/torch_xla/distributed/xla_multiprocessing.py +++ b/torch_xla/distributed/xla_multiprocessing.py @@ -56,7 +56,7 @@ class MpModelWrapper(object): WRAPPED_MODEL = xmp.MpModelWrapper(MyNetwork()) def _mp_fn(index, ...): - device = torch_xla.device() + device = torch.device('xla') model = WRAPPED_MODEL.to(device) ... diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index ea4c8d54c1a2..f0100dec87bd 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -16,7 +16,7 @@ def fori_loop(lower, upper, body_fun, *input_value): - device = torch_xla.device() + device = torch.device('xla') if (upper < lower): print("ERROR: upper should be a larger number than lower") iteri = upper - lower diff --git a/torch_xla/experimental/gradient_accumulation.py b/torch_xla/experimental/gradient_accumulation.py index 4e3f8682e68e..0855fffbd62a 100644 --- a/torch_xla/experimental/gradient_accumulation.py +++ b/torch_xla/experimental/gradient_accumulation.py @@ -154,7 +154,7 @@ def _make_init_grad(param): def _gradient_accumulation_impl(context, body_fn, iterable_tensors, params, carried_tensors): builder = XlaBuildHelper('grad_acc') - device = torch_xla.device() + device = torch.device('xla') init_iterator = torch.tensor(0, dtype=torch.int32, device=device) init_loss = torch.tensor(0, dtype=torch.float32, device=device) diff --git a/torch_xla/experimental/scan.py b/torch_xla/experimental/scan.py index 565b569ed726..ee8fabe5a3ac 100644 --- a/torch_xla/experimental/scan.py +++ b/torch_xla/experimental/scan.py @@ -143,7 +143,7 @@ def scan(fn, init, xs): >>> y = new_carry >>> return new_carry, y >>> - >>> with torch_xla.device(): + >>> with torch.device('xla'): >>> init = torch.tensor([0.0, 0.0], requires_grad=True) >>> xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], >>> requires_grad=True) @@ -650,7 +650,7 @@ def make_fake_tensor(v: torch.Tensor) -> torch.Tensor: t = xb.create_placeholder_tensor(v.shape, v.dtype) return t.requires_grad_(v.requires_grad) - device = torch_xla.device() + device = torch.device('xla') fake_carry = tree_map(make_fake_tensor, init) fake_x = tree_map(lambda v: make_fake_tensor(v[0]), xs) diff --git a/torch_xla/experimental/scan_layers.py b/torch_xla/experimental/scan_layers.py index 4e55111caeec..3bbd78196fd5 100644 --- a/torch_xla/experimental/scan_layers.py +++ b/torch_xla/experimental/scan_layers.py @@ -50,7 +50,7 @@ def scan_layers(layers: Iterable[torch.nn.Module], >>> import torch >>> import torch.nn as nn >>> from torch_xla.experimental.scan_layers import scan_layers - >>> with torch_xla.device(): + >>> with torch.device('xla'): >>> layers = [nn.Linear(16, 16) for i in range(10)] >>> input = torch.randn(16) >>> output = scan_layers(layers, input) diff --git a/torch_xla/runtime.py b/torch_xla/runtime.py index 2e274190db75..371a9005d510 100644 --- a/torch_xla/runtime.py +++ b/torch_xla/runtime.py @@ -156,7 +156,7 @@ def local_ordinal() -> int: Local ordinal is in range [0, local_device_count).""" local_rank = xu.getenv_as(xenv.PJRT_LOCAL_PROCESS_RANK, int, 0) devices_per_process = addressable_device_count() - return local_rank * devices_per_process + torch_xla.device().index + return local_rank * devices_per_process + torch.device('xla').index def process_index() -> int: diff --git a/torch_xla/stablehlo.py b/torch_xla/stablehlo.py index b88a8131b2d8..1a70f7972af6 100644 --- a/torch_xla/stablehlo.py +++ b/torch_xla/stablehlo.py @@ -341,7 +341,7 @@ def _exported_program_to_stablehlo_bundle(exported_model, assert len(kwargs) == 0, "Export to stablehlo doesnt support kwargs yet." - device = torch_xla.device() + device = torch.device('xla') _flat_input_args = exported_model._graph_module_flat_inputs(args, {}) _flat_input_args = pytree.tree_map_only(torch.Tensor, @@ -352,7 +352,7 @@ def _exported_program_to_stablehlo_bundle(exported_model, torch_xla.sync() xm.wait_device_ops() metrics.clear_counters() - device = torch_xla.device() + device = torch.device('xla') # Run the fx graph tracing using lazy tensor if options.inline_all_constant: From c5c7bfd051fa53c46b78c3fff21cb7aff04cec33 Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Tue, 27 May 2025 22:01:10 +0000 Subject: [PATCH 2/3] Deprecate API --- torch_xla/core/xla_model.py | 2 +- torch_xla/torch_xla.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py index 6b68e656d333..6d437ecab5ee 100644 --- a/torch_xla/core/xla_model.py +++ b/torch_xla/core/xla_model.py @@ -118,7 +118,7 @@ def master_print(*args: Any, print(*args, file=fd, flush=flush) -@deprecated("Use torch_xla.device instead") +@deprecated("Use torch.device('xla') instead") def xla_device(n: Optional[int] = None, devkind: Optional[str] = None) -> torch.device: """Returns a given instance of an XLA device. diff --git a/torch_xla/torch_xla.py b/torch_xla/torch_xla.py index 9062d6a9ef21..739b1147b28a 100644 --- a/torch_xla/torch_xla.py +++ b/torch_xla/torch_xla.py @@ -4,6 +4,7 @@ import functools import uuid from typing import Any, Callable, List, Optional, Tuple +from typing_extensions import deprecated import weakref import torch @@ -16,6 +17,7 @@ import torch_xla.utils.utils as xu +@deprecated("Use torch.device('xla') instead") def device(index: int = None) -> torch.device: """Returns a given instance of an XLA device. From e578dfe1a2948e3a9b1684e02856adf089cbedd4 Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Fri, 6 Jun 2025 18:39:48 +0000 Subject: [PATCH 3/3] Fix errors --- API_GUIDE.md | 2 +- test/pjrt/test_runtime_multi_cpu.py | 6 +- test/pjrt/test_runtime_multi_gpu.py | 266 ------------------ test/pytorch_test_base.py | 6 +- .../test_xla_spmd_python_api_interaction.py | 2 +- test/test_operations.py | 8 +- torch_xla/_internal/pjrt.py | 4 +- torch_xla/runtime.py | 3 +- 8 files changed, 16 insertions(+), 281 deletions(-) delete mode 100644 test/pjrt/test_runtime_multi_gpu.py diff --git a/API_GUIDE.md b/API_GUIDE.md index cd2e1f2fd5f3..f13ff86ab10b 100644 --- a/API_GUIDE.md +++ b/API_GUIDE.md @@ -197,7 +197,7 @@ single device snippet. Let's go over then one by one. - `torch_xla.launch()` - Creates the processes that each run an XLA device. - This function is a wrapper of multithreading spawn to allow user run the script with torchrun command line also. Each process will only be able to access the device assigned to the current process. For example on a TPU v4-8, there will be 4 processes being spawn up and each process will own a TPU device. - - Note that if you print the `torch.device('xla')` on each process you will see `xla:0` on all devices. This is because each process can only see one device. This does not mean multi-process is not functioning. The only exeption is with PJRT runtime on TPU v2 and TPU v3 since there will be `#devices/2` processes and each process will have 2 threads (check this [doc](https://github.com/pytorch/xla/blob/master/docs/pjrt.md#tpus-v2v3-vs-v4) for more details). + - Note that if you print the `torch_xla.device()` on each process you will see `xla:0` on all devices. This is because each process can only see one device. This does not mean multi-process is not functioning. The only exeption is with PJRT runtime on TPU v2 and TPU v3 since there will be `#devices/2` processes and each process will have 2 threads (check this [doc](https://github.com/pytorch/xla/blob/master/docs/pjrt.md#tpus-v2v3-vs-v4) for more details). - `MpDeviceLoader` - Loads the training data onto each device. - `MpDeviceLoader` can wrap on a torch dataloader. It can preload the data to the device and overlap the dataloading with device execution to improve the performance. diff --git a/test/pjrt/test_runtime_multi_cpu.py b/test/pjrt/test_runtime_multi_cpu.py index 6090066d05f9..71f667765637 100644 --- a/test/pjrt/test_runtime_multi_cpu.py +++ b/test/pjrt/test_runtime_multi_cpu.py @@ -27,7 +27,7 @@ def test_default_cpu_device(self): os.environ.pop(xenv.PJRT_CPU_ASYNC_CLIENT, None) expected = {0: torch.device('xla:0')} - devices_per_process = pjrt.run_multiprocess(xm.xla_device) + devices_per_process = pjrt.run_multiprocess(torch_xla.device) self.assertDictEqual(devices_per_process, expected) def test_multi_cpu_devices(self): @@ -38,7 +38,7 @@ def test_multi_cpu_devices(self): 3: torch.device('xla:3'), } - devices_per_process = pjrt.run_multiprocess(xm.xla_device) + devices_per_process = pjrt.run_multiprocess(torch_xla.device) self.assertDictEqual(devices_per_process, expected) def test_global_ordinal(self): @@ -65,7 +65,7 @@ def forward(ctx, x): def backward(ctx, grad_output): results['forward_ordinal'] = ctx.forward_ordinal results['backward_ordinal'] = xr.global_ordinal() - results['device'] = str(torch.device('xla')) + results['device'] = str(torch_xla.device()) return grad_output x = torch.ones(1, requires_grad=True, device='xla') diff --git a/test/pjrt/test_runtime_multi_gpu.py b/test/pjrt/test_runtime_multi_gpu.py deleted file mode 100644 index 25d967e363fd..000000000000 --- a/test/pjrt/test_runtime_multi_gpu.py +++ /dev/null @@ -1,266 +0,0 @@ -import concurrent.futures -import itertools -import os -import queue -import requests -import unittest - -import numpy as np -import torch -import torch.nn as nn -import torch_xla -import torch_xla.core.xla_env_vars as xenv -import torch_xla.core.xla_model as xm -import torch_xla.distributed.xla_multiprocessing as xmp -from torch_xla import runtime as xr -from torch_xla._internal import pjrt -from absl.testing import absltest, parameterized - - -@unittest.skipIf(xr.device_type() != "CUDA", - f"GPU tests should only run on GPU devices.") -class TestExperimentalPjrtMultiGpu(parameterized.TestCase): - - def setUp(self): - xr.set_device_type('CUDA') - - os.environ.update({ - xenv.PJRT_GPU_ASYNC_CLIENT: 'true', - }) - - def test_default_gpu_device(self): - os.environ.pop(xenv.PJRT_GPU_ASYNC_CLIENT, None) - - num_devices = int(os.environ[xenv.GPU_NUM_DEVICES]) - expected = {i: torch.device(f'xla:0') for i in range(num_devices)} - devices_per_process = pjrt.run_multiprocess(xm.xla_device) - self.assertDictEqual(devices_per_process, expected) - - def test_multi_gpu_devices(self): - num_devices = int(os.environ[xenv.GPU_NUM_DEVICES]) - expected = {i: torch.device(f'xla:0') for i in range(num_devices)} - - devices_per_process = pjrt.run_multiprocess(xm.xla_device) - self.assertDictEqual(devices_per_process, expected) - - def test_global_ordinal(self): - num_devices = int(os.environ[xenv.GPU_NUM_DEVICES]) - expected = [i for i in range(num_devices)] - - results = pjrt.run_multiprocess(xr.global_ordinal) - self.assertListEqual(sorted(results.values()), expected) - - def test_local_ordinal(self): - num_devices = int(os.environ[xenv.GPU_NUM_DEVICES]) - expected = [i for i in range(num_devices)] - - results = pjrt.run_multiprocess(xr.local_ordinal) - self.assertListEqual(sorted(results.values()), expected) - - def test_global_device_count(self): - num_devices = int(os.environ[xenv.GPU_NUM_DEVICES]) - expected = {i: num_devices for i in range(num_devices)} - results = pjrt.run_multiprocess(xr.global_device_count) - self.assertEqual(expected, results) - - def test_local_process_count(self): - num_devices = int(os.environ[xenv.GPU_NUM_DEVICES]) - expected = {i: num_devices for i in range(num_devices)} - results = pjrt.run_multiprocess(xr.local_process_count) - self.assertEqual(expected, results) - - def test_world_size(self): - num_devices = int(os.environ[xenv.GPU_NUM_DEVICES]) - expected = {i: num_devices for i in range(num_devices)} - results = pjrt.run_multiprocess(xr.world_size) - self.assertEqual(expected, results) - - def test_addressable_device_count(self): - num_devices = int(os.environ[xenv.GPU_NUM_DEVICES]) - expected = {i: 1 for i in range(num_devices)} - results = pjrt.run_multiprocess(xr.addressable_device_count) - self.assertEqual(expected, results) - - def test_addressable_runtime_device_count(self): - num_devices = int(os.environ[xenv.GPU_NUM_DEVICES]) - expected = {i: 1 for i in range(num_devices)} - results = pjrt.run_multiprocess(xr.addressable_runtime_device_count) - self.assertEqual(expected, results) - - def test_local_device_count(self): - num_devices = int(os.environ[xenv.GPU_NUM_DEVICES]) - # xr.local_process_count() is 2, xr.addressable_device_count() is 1. - expected = {i: num_devices for i in range(num_devices)} - results = pjrt.run_multiprocess(xr.local_device_count) - self.assertEqual(expected, results) - - def test_process_index(self): - num_devices = int(os.environ[xenv.GPU_NUM_DEVICES]) - expected = {i: i for i in range(num_devices)} - results = pjrt.run_multiprocess(xr.process_index) - self.assertEqual(expected, results) - - def test_process_count(self): - num_devices = int(os.environ[xenv.GPU_NUM_DEVICES]) - expected = {i: num_devices for i in range(num_devices)} - results = pjrt.run_multiprocess(xr.process_count) - self.assertEqual(expected, results) - - @staticmethod - def _multi_gpu_backwards(): - results = {} - - class _CustomBackwards(torch.autograd.Function): - - @staticmethod - def forward(ctx, x): - ordinal = xr.global_ordinal() - ctx.forward_ordinal = ordinal - return x - - @staticmethod - def backward(ctx, grad_output): - results['forward_ordinal'] = ctx.forward_ordinal - results['backward_ordinal'] = xr.global_ordinal() - results['device'] = str(torch.device('xla')) - return grad_output - - x = torch.ones(1, requires_grad=True, device='xla') - y = _CustomBackwards.apply(x) - y.backward() - torch_xla.sync() - - return results - - def test_multi_gpu_backwards(self): - os.environ.update({ - xenv.PJRT_GPU_ASYNC_CLIENT: 'true', - }) - num_devices = int(os.environ[xenv.GPU_NUM_DEVICES]) - - expected = { - i: { - 'forward_ordinal': i, - 'backward_ordinal': i, - 'device': f'xla:0' - } for i in range(num_devices) - } - results = pjrt.run_multiprocess(self._multi_gpu_backwards) - - self.assertDictEqual(results, expected) - - @staticmethod - def _spawn(index: int, queue: queue.Queue): - queue.put(index) - - @parameterized.named_parameters(('xmp', xmp.spawn), ('pjrt', pjrt.spawn)) - def test_spawn(self, spawn): - manager = torch.multiprocessing.Manager() - num_devices = int(os.environ[xenv.GPU_NUM_DEVICES]) - queue = manager.Queue(num_devices) - spawn(self._spawn, args=(queue,)) - - indices = sorted(queue.get(block=False) for _ in range(queue.qsize())) - self.assertListEqual(indices, list(range(num_devices))) - - @staticmethod - def _broadcast(sync): - torch.manual_seed(xr.global_ordinal()) - device = torch.device('xla') - model = nn.Linear(5, 5).to(device) - if sync: - xm.broadcast_master_param(model) - - torch_xla.sync() - return next(model.parameters()).detach().cpu().numpy() - - @parameterized.named_parameters(('synchronized_parameters', True), - ('unsynchronized_parameters', False)) - def test_broadcast_master_param(self, sync): - results = pjrt.run_multiprocess(self._broadcast, sync) - master_params = results[0] - for ordinal, worker_params in results.items(): - if sync: - np.testing.assert_array_equal(master_params, worker_params) - elif ordinal != 0: - np.testing.assert_raises(AssertionError, np.testing.assert_array_equal, - master_params, worker_params) - - @staticmethod - def _all_gather(pin_layout): - device = torch.device('xla') - ordinal = torch.tensor([xr.global_ordinal()], device=device) - out = xm.all_gather(ordinal, pin_layout=pin_layout) - torch_xla.sync() - - return out.cpu().numpy() - - @parameterized.named_parameters(('pinned', True), ('unpinned', False)) - def test_all_gather(self, pin_layout): - results = pjrt.run_multiprocess(self._all_gather, pin_layout) - - expected = list(range(len(results))) - for v in results.values(): - np.testing.assert_array_equal(v, expected) - - @staticmethod - def _reduce_scatter(pin_layout): - device = torch.device('xla') - world_size = xr.world_size() - tensor = -torch.arange(world_size, dtype=torch.float32).to(device) - - out = xm.reduce_scatter( - xm.REDUCE_SUM, - tensor, - scale=1.0 / world_size, - scatter_dim=0, - shard_count=world_size, - pin_layout=pin_layout, - ) - torch_xla.sync() - - return out.cpu().numpy() - - # 2023-08-02 04:16:36.520884: F external/xla/xla/service/layout_assignment.cc:157] Check failed: ShapeUtil::Compatible(shape_layout.shape(), instruction->operand(operand_no)->shape()) f32[1]{0} is not compatible with f32[2]{0} (for operand 0 of instruction %reduce-scatter.10 = f32[1]{0} reduce-scatter(f32[2]{0} %add.5), replica_groups={}, constrain_layout=true, dimensions={0}, to_apply=%AddComputation.6) - @parameterized.named_parameters(('pinned', True), ('unpinned', False)) - def test_reduce_scatter(self, pin_layout): - results = pjrt.run_multiprocess(self._reduce_scatter, pin_layout) - - for ordinal, value in results.items(): - np.testing.assert_array_equal(value, [-ordinal]) - - @staticmethod - def _all_to_all(pin_layout): - device = torch.device('xla') - world_size = xr.world_size() - - tensor = torch.cat( - [ - -torch.arange(world_size, dtype=torch.float32).view(-1, 1, 1), - torch.ones(world_size, 1, 1) * xr.global_ordinal(), - ], - dim=1, - ).to(device) - torch_xla.sync() - - out = xm.all_to_all( - tensor, - split_dimension=0, - concat_dimension=2, - split_count=world_size, - pin_layout=pin_layout, - ) - - return out.cpu().numpy() - - @parameterized.named_parameters(('pinned', True), ('unpinned', False)) - def test_all_to_all(self, pin_layout): - results = pjrt.run_multiprocess(self._all_to_all, pin_layout) - - for ordinal, value in results.items(): - np.testing.assert_array_equal(value, [[[-ordinal] * len(results), - list(range(len(results)))]]) - - -if __name__ == '__main__': - absltest.main() diff --git a/test/pytorch_test_base.py b/test/pytorch_test_base.py index 72e1f621b54d..bb3b7b8114c4 100644 --- a/test/pytorch_test_base.py +++ b/test/pytorch_test_base.py @@ -559,7 +559,7 @@ def _alt_lookup(d, keys, defval): def instantiate_test(cls, name, test, *, generic_cls): test_name = name + '_' + cls.device_type class_name = cls.__name__ - real_device_type = xm.xla_device_hw(str(torch.device('xla'))) + real_device_type = xm.xla_device_hw(str(torch.device('xla:0'))) assert real_device_type in DISABLED_TORCH_TESTS, 'Unsupported device type:' + real_device_type disabled_torch_tests = DISABLED_TORCH_TESTS[real_device_type] @@ -631,8 +631,8 @@ def get_primary_device(cls): @classmethod def setUpClass(cls): - # Sets the primary test device to the xla_device (CPU or TPU) - cls.primary_device = str(torch.device('xla')) + # Sets the primary test device to the torch_xla.device (CPU or TPU) + cls.primary_device = str(torch_xla.device()) torch_xla._XLAC._xla_set_mat_mul_precision('highest') def setUp(self): diff --git a/test/spmd/test_xla_spmd_python_api_interaction.py b/test/spmd/test_xla_spmd_python_api_interaction.py index e4ce72b23003..741392f89562 100644 --- a/test/spmd/test_xla_spmd_python_api_interaction.py +++ b/test/spmd/test_xla_spmd_python_api_interaction.py @@ -38,7 +38,7 @@ def test_is_master_ordinal(self): self.assertTrue(xm.is_master_ordinal()) def test_xla_device(self): - device = torch.device('xla') + device = torch_xla.device() self.assertEqual(device, torch.device('xla:0')) def test_xla_real_devices(self): diff --git a/test/test_operations.py b/test/test_operations.py index 164787921ecc..6b0721684264 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -442,7 +442,7 @@ class TestOptimizationBarrier(test_utils.XlaTestCase): def test_optimization_barrier_correctness(self): device = torch.device('xla') # only test optimization_barrier on TPU - if xm.xla_device_hw(device) != 'TPU': + if xr.device_type() != 'TPU': return x = torch.randn(5, 5, device=device) y = torch.randn(5, 5, device=device) @@ -1532,7 +1532,7 @@ def test_deepcopy(self): self.assertEqual(x[0], x0) def test_print(self): - xla_device = torch.device('xla') + xla_device = torch.device('xla:0') x = torch.tensor([5], device=xla_device) expected_str = 'tensor([5], device=\'' + str(xla_device) + '\')' self.assertEqual(str(x), expected_str) @@ -2759,7 +2759,7 @@ def test_send_to_device_grad(self): self.assertTrue(dt[0].requires_grad) def test_send_to_device_single(self): - xla_device = torch.device('xla') + xla_device = torch.device('xla:0') t = _gen_tensor(2, 2) dt = xm.send_cpu_data_to_device(t, xla_device) self.assertEqual(dt[0].device, xla_device) @@ -2859,7 +2859,7 @@ def from_tensors(self, tensors): wpack = PackWrapper(pack) - xla_device = torch.device('xla') + xla_device = torch.device('xla:0') xdata = xm.send_cpu_data_to_device(wpack, xla_device) self.assertTrue(isinstance(xdata, nn.utils.rnn.PackedSequence)) self.assertEqual(xdata.batch_sizes.device, torch.device('cpu')) diff --git a/torch_xla/_internal/pjrt.py b/torch_xla/_internal/pjrt.py index 9689fb7fb703..25a0ee36c36e 100644 --- a/torch_xla/_internal/pjrt.py +++ b/torch_xla/_internal/pjrt.py @@ -104,7 +104,7 @@ def initialize_singleprocess(): plugins.default().configure_single_process() elif runtime.device_type() == 'TPU': tpu.configure_one_chip_topology() - xm.set_replication(torch.device('xla'), []) + xm.set_replication(torch_xla.device(), []) def initialize_multiprocess(local_rank: int, local_world_size: int): @@ -119,7 +119,7 @@ def initialize_multiprocess(local_rank: int, local_world_size: int): neuron.initialize_env(local_rank, local_world_size) devices = xm.get_xla_supported_devices() - xm.set_replication(torch.device('xla'), devices) + xm.set_replication(torch_xla.device(), devices) def run_multiprocess(fn: Callable[..., R], diff --git a/torch_xla/runtime.py b/torch_xla/runtime.py index 371a9005d510..e5aef103a17a 100644 --- a/torch_xla/runtime.py +++ b/torch_xla/runtime.py @@ -156,7 +156,8 @@ def local_ordinal() -> int: Local ordinal is in range [0, local_device_count).""" local_rank = xu.getenv_as(xenv.PJRT_LOCAL_PROCESS_RANK, int, 0) devices_per_process = addressable_device_count() - return local_rank * devices_per_process + torch.device('xla').index + return local_rank * devices_per_process + torch.device( + torch_xla._XLAC._xla_get_default_device()).index def process_index() -> int: