
Commit cc68b1b

Migrate torch_xla.device() to torch.device('xla')
1 parent 78cff03 commit cc68b1b


140 files changed (+701 -437 lines)

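The change itself is mechanical: every documented call site that obtained the current device via `torch_xla.device()` now spells it with the stock `torch.device('xla')` constructor. A minimal before/after sketch (the linear layer and tensors below are illustrative placeholders, not taken from this commit):

``` python
import torch
import torch_xla  # importing torch_xla registers the 'xla' device type

# Old spelling (before this commit):
# device = torch_xla.device()

# New spelling (after this commit):
device = torch.device('xla')

# Usage downstream is unchanged: modules and tensors move to the device as before.
model = torch.nn.Linear(4, 2).to(device)
x = torch.randn(8, 4, device=device)
y = model(x)
```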

API_GUIDE.md

Lines changed: 5 additions & 5 deletions

@@ -22,7 +22,7 @@ print(t)

 This code should look familiar. PyTorch/XLA uses the same interface as regular
 PyTorch with a few additions. Importing `torch_xla` initializes PyTorch/XLA, and
-`torch_xla.device()` returns the current XLA device. This may be a CPU or TPU
+`torch.device('xla')` returns the current XLA device. This may be a CPU or TPU
 depending on your environment.

 ## XLA Tensors are PyTorch Tensors

@@ -112,7 +112,7 @@ train_loader = xu.SampleGenerator(
 torch.zeros(batch_size, dtype=torch.int64)),
 sample_count=60000 // batch_size // xr.world_size())

-device = torch_xla.device() # Get the XLA device (TPU).
+device = torch.device('xla') # Get the XLA device (TPU).
 model = MNIST().train().to(device) # Create a model and move it to the device.
 loss_fn = nn.NLLLoss()
 optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

@@ -169,7 +169,7 @@ def _mp_fn(index):
 index: Index of the process.
 """

-device = torch_xla.device() # Get the device assigned to this process.
+device = torch.device('xla') # Get the device assigned to this process.
 # Wrap the loader for multi-device.
 mp_device_loader = pl.MpDeviceLoader(train_loader, device)

@@ -197,7 +197,7 @@ single device snippet. Let's go over then one by one.
 - `torch_xla.launch()`
 - Creates the processes that each run an XLA device.
 - This function is a wrapper of multithreading spawn to allow user run the script with torchrun command line also. Each process will only be able to access the device assigned to the current process. For example on a TPU v4-8, there will be 4 processes being spawn up and each process will own a TPU device.
-- Note that if you print the `torch_xla.device()` on each process you will see `xla:0` on all devices. This is because each process can only see one device. This does not mean multi-process is not functioning. The only exeption is with PJRT runtime on TPU v2 and TPU v3 since there will be `#devices/2` processes and each process will have 2 threads (check this [doc](https://github.com/pytorch/xla/blob/master/docs/pjrt.md#tpus-v2v3-vs-v4) for more details).
+- Note that if you print the `torch.device('xla')` on each process you will see `xla:0` on all devices. This is because each process can only see one device. This does not mean multi-process is not functioning. The only exeption is with PJRT runtime on TPU v2 and TPU v3 since there will be `#devices/2` processes and each process will have 2 threads (check this [doc](https://github.com/pytorch/xla/blob/master/docs/pjrt.md#tpus-v2v3-vs-v4) for more details).
 - `MpDeviceLoader`
 - Loads the training data onto each device.
 - `MpDeviceLoader` can wrap on a torch dataloader. It can preload the data to the device and overlap the dataloading with device execution to improve the performance.

@@ -290,7 +290,7 @@ import torch
 import torch_xla
 import torch_xla.core.xla_model as xm

-device = torch_xla.device()
+device = torch.device('xla')

 t0 = torch.randn(2, 2, device=device)
 t1 = torch.randn(2, 2, device=device)
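Since this file's multi-process section is the densest cluster of renames, here is a condensed, self-contained sketch of that flow with the new spelling; the tiny model and synthetic loader are stand-ins for the guide's MNIST example, not code from this commit:

``` python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl


def _mp_fn(index):
  # Each spawned process sees exactly one XLA device, so torch.device('xla')
  # resolves to the device assigned to this process.
  device = torch.device('xla')

  # Toy stand-ins for the guide's MNIST model and train_loader.
  model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10),
                        nn.LogSoftmax(dim=1)).train().to(device)
  loss_fn = nn.NLLLoss()
  optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
  dataset = TensorDataset(torch.randn(256, 1, 28, 28),
                          torch.randint(0, 10, (256,)))
  train_loader = DataLoader(dataset, batch_size=32)

  # MpDeviceLoader preloads batches onto the device and overlaps data loading
  # with device execution.
  mp_device_loader = pl.MpDeviceLoader(train_loader, device)
  for data, target in mp_device_loader:
    optimizer.zero_grad()
    loss = loss_fn(model(data), target)
    loss.backward()
    xm.optimizer_step(optimizer)  # reduces gradients across replicas and steps


if __name__ == '__main__':
  torch_xla.launch(_mp_fn, args=())
```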

benchmarks/experiment_runner.py

Lines changed: 1 addition & 1 deletion

@@ -255,7 +255,7 @@ def _default_iter_fn(self, benchmark_experiment: BenchmarkExperiment,

 def _pure_wall_time_iter_fn(self, benchmark_experiment: BenchmarkExperiment,
 benchmark_model: BenchmarkModel, input_tensor):
-device = torch_xla.device() if benchmark_experiment.xla else 'cuda'
+device = torch.device('xla') if benchmark_experiment.xla else 'cuda'
 sync_fn = xm.wait_device_ops if benchmark_experiment.xla else torch.cuda.synchronize
 timing, output = bench.do_bench(
 lambda: benchmark_model.model_iter_fn(
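The pattern being touched here, choosing a device and a matching "wait until the device is idle" call depending on the backend, can be sketched in isolation. The `use_xla` flag is illustrative; the real runner reads it from the experiment configuration:

``` python
import torch
import torch_xla
import torch_xla.core.xla_model as xm

use_xla = True  # illustrative stand-in for benchmark_experiment.xla

# Pick the device and the call that blocks until queued device work finishes.
device = torch.device('xla') if use_xla else torch.device('cuda')
sync_fn = xm.wait_device_ops if use_xla else torch.cuda.synchronize

x = torch.randn(1024, 1024, device=device)
y = x @ x
if use_xla:
  torch_xla.sync()  # dispatch the lazily recorded graph to the device
sync_fn()           # wait for the device before reading timings or results
```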

contrib/kaggle/distributed-pytorch-xla-basics-with-pjrt.ipynb

Lines changed: 3 additions & 3 deletions

@@ -193,7 +193,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 7,
+"execution_count": null,
 "metadata": {
 "execution": {
 "iopub.execute_input": "2024-01-10T19:30:28.607393Z",

@@ -210,7 +210,7 @@
 "lock = mp.Manager().Lock()\n",
 "\n",
 "def print_device(i, lock):\n",
-" device = torch_xla.device()\n",
+" device = torch.device('xla')\n",
 " with lock:\n",
 " print('process', i, device)"
 ]

@@ -454,7 +454,7 @@
 "import torch_xla.experimental.pjrt_backend # Required for torch.distributed on TPU v2 and v3\n",
 "\n",
 "def toy_model(index, lock):\n",
-" device = torch_xla.device()\n",
+" device = torch.device('xla')\n",
 " dist.init_process_group('xla', init_method='xla://')\n",
 "\n",
 " # Initialize a basic toy model\n",

docs/source/learn/_pjrt.md

Lines changed: 1 addition & 1 deletion

@@ -73,7 +73,7 @@ import torch_xla.distributed.xla_backend


 def _mp_fn(index):
-device = torch_xla.device()
+device = torch.device('xla')
 - dist.init_process_group('xla', rank=xr.global_ordinal(), world_size=xr.world_size())
 + dist.init_process_group('xla', init_method='xla://')

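The edited snippet pairs the new device spelling with the `xla://` init method, which derives rank and world size from the runtime instead of taking them as arguments. A minimal sketch of that combination, assuming the processes are started with `torch_xla.launch` (the all-reduce is illustrative):

``` python
import torch
import torch.distributed as dist
import torch_xla
import torch_xla.distributed.xla_backend  # registers the 'xla' dist backend


def _mp_fn(index):
  device = torch.device('xla')
  # Rank and world size are inferred from the XLA runtime via xla://.
  dist.init_process_group('xla', init_method='xla://')

  t = torch.ones(2, 2, device=device) * dist.get_rank()
  dist.all_reduce(t)  # illustrative collective across all processes


if __name__ == '__main__':
  torch_xla.launch(_mp_fn, args=())
```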
docs/source/learn/eager.md

Lines changed: 2 additions & 2 deletions

@@ -13,7 +13,7 @@ import torch
 import torch_xla
 import torchvision

-device = torch_xla.device()
+device = torch.device('xla')
 model = torchvision.models.resnet18().to(device)
 input = torch.randn(64, 3, 224, 224).to(device)

@@ -71,7 +71,7 @@ import torchvision
 # Run ops eagerly by default
 torch_xla.experimental.eager_mode(True)

-device = torch_xla.device()
+device = torch.device('xla')
 model = torchvision.models.resnet18().to(device)

 # Mark the function to be compiled
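Both hunks in this file sit inside the eager-mode walkthrough. A condensed sketch of the surrounding example with the new spelling; the `torch_xla.compile` decorator stands in for the part the doc introduces with "Mark the function to be compiled":

``` python
import torch
import torch_xla
import torchvision

# Run ops eagerly by default instead of accumulating a lazy graph.
torch_xla.experimental.eager_mode(True)

device = torch.device('xla')
model = torchvision.models.resnet18().to(device)


# Mark the hot path to be compiled; everything outside it stays eager.
@torch_xla.compile
def run(inputs):
  return model(inputs)


output = run(torch.randn(64, 3, 224, 224, device=device))
```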

docs/source/learn/pytorch-on-xla-devices.md

Lines changed: 5 additions & 5 deletions

@@ -21,7 +21,7 @@ print(t)

 This code should look familiar. PyTorch/XLA uses the same interface as
 regular PyTorch with a few additions. Importing `torch_xla` initializes
-PyTorch/XLA, and `torch_xla.device()` returns the current XLA device. This
+PyTorch/XLA, and `torch.device('xla')` returns the current XLA device. This
 may be a CPU or TPU depending on your environment.

 ## XLA Tensors are PyTorch Tensors

@@ -81,7 +81,7 @@ The following snippet shows a network training on a single XLA device:
 ``` python
 import torch_xla.core.xla_model as xm

-device = torch_xla.device()
+device = torch.device('xla')
 model = MNIST().train().to(device)
 loss_fn = nn.NLLLoss()
 optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

@@ -120,7 +120,7 @@ import torch_xla.core.xla_model as xm
 import torch_xla.distributed.parallel_loader as pl

 def _mp_fn(index):
-device = torch_xla.device()
+device = torch.device('xla')
 mp_device_loader = pl.MpDeviceLoader(train_loader, device)

 model = MNIST().train().to(device)

@@ -148,7 +148,7 @@ previous single device snippet. Let's go over then one by one.
 will only be able to access the device assigned to the current
 process. For example on a TPU v4-8, there will be 4 processes
 being spawn up and each process will own a TPU device.
-- Note that if you print the `torch_xla.device()` on each process you
+- Note that if you print the `torch.device('xla')` on each process you
 will see `xla:0` on all devices. This is because each process
 can only see one device. This does not mean multi-process is not
 functioning. The only execution is with PJRT runtime on TPU v2

@@ -283,7 +283,7 @@ import torch
 import torch_xla
 import torch_xla.core.xla_model as xm

-device = torch_xla.device()
+device = torch.device('xla')

 t0 = torch.randn(2, 2, device=device)
 t1 = torch.randn(2, 2, device=device)
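The tensors in the last hunk are a convenient hook for recalling the execution model behind them: operations on XLA tensors are recorded lazily and only dispatched when the graph is synced or a value is needed on the host. A short sketch (the explicit `torch_xla.sync()` call is illustrative, not part of this commit):

``` python
import torch
import torch_xla

device = torch.device('xla')

t0 = torch.randn(2, 2, device=device)
t1 = torch.randn(2, 2, device=device)

# The addition is recorded into a graph rather than executed immediately.
t2 = t0 + t1
torch_xla.sync()  # dispatch the recorded graph to the device
print(t2)         # fetches the materialized result back to the host
```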

docs/source/learn/xla-overview.md

Lines changed: 4 additions & 4 deletions

@@ -184,7 +184,7 @@ repo. contains examples for training and serving many LLM and diffusion models.

 General guidelines to modify your code:

-- Replace `cuda` with `torch_xla.device()`
+- Replace `cuda` with `torch.device('xla')`
 - Remove progress bar, printing that would access the XLA tensor
 values
 - Reduce logging and callbacks that would access the XLA tensor values

@@ -227,7 +227,7 @@ tutorial, but you can pass the `device` value to the function as well.

 ``` python
 import torch_xla.core.xla_model as xm
-self.device = torch_xla.device()
+self.device = torch.device('xla')
 ```

 Another place in the code that has cuda specific code is DDIM scheduler.

@@ -244,7 +244,7 @@ if attr.device != torch.device("cuda"):
 with

 ``` python
-device = torch_xla.device()
+device = torch.device('xla')
 attr = attr.to(torch.device(device))
 ```

@@ -339,7 +339,7 @@ with the following lines:

 ``` python
 import torch_xla.core.xla_model as xm
-device = torch_xla.device()
+device = torch.device('xla')
 pipe.to(device)
 ```

docs/source/perf/amp.md

Lines changed: 5 additions & 5 deletions

@@ -27,7 +27,7 @@ for input, target in data:
 optimizer.zero_grad()

 # Enables autocasting for the forward pass
-with autocast(torch_xla.device()):
+with autocast(torch.device('xla')):
 output = model(input)
 loss = loss_fn(output, target)

@@ -36,7 +36,7 @@ for input, target in data:
 xm.optimizer_step.(optimizer)
 ```

-`autocast(torch_xla.device())` aliases `torch.autocast('xla')` when the XLA
+`autocast(torch.device('xla'))` aliases `torch.autocast('xla')` when the XLA
 Device is a TPU. Alternatively, if a script is only used with TPUs, then
 `torch.autocast('xla', dtype=torch.bfloat16)` can be directly used.

@@ -115,7 +115,7 @@ for input, target in data:
 optimizer.zero_grad()

 # Enables autocasting for the forward pass
-with autocast(torch_xla.device()):
+with autocast(torch.device('xla')):
 output = model(input)
 loss = loss_fn(output, target)

@@ -127,12 +127,12 @@ for input, target in data:
 scaler.update()
 ```

-`autocast(torch_xla.device())` aliases `torch.cuda.amp.autocast()` when the
+`autocast(torch.device('xla'))` aliases `torch.cuda.amp.autocast()` when the
 XLA Device is a CUDA device (XLA:GPU). Alternatively, if a script is
 only used with CUDA devices, then `torch.cuda.amp.autocast` can be
 directly used, but requires `torch` is compiled with `cuda` support for
 datatype of `torch.bfloat16`. We recommend using
-`autocast(torch_xla.device())` on XLA:GPU as it does not require
+`autocast(torch.device('xla'))` on XLA:GPU as it does not require
 `torch.cuda` support for any datatypes, including `torch.bfloat16`.

 ### AMP for XLA:GPU Best Practices
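For a runnable view of the TPU autocast usage this file documents, here is a condensed AMP step using the `torch.autocast('xla', dtype=torch.bfloat16)` form the doc names as the TPU-only alternative; the model and data are toy stand-ins:

``` python
import torch
import torch.nn as nn
import torch.optim as optim
import torch_xla
import torch_xla.core.xla_model as xm

device = torch.device('xla')
model = nn.Linear(128, 10).to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

inputs = torch.randn(32, 128, device=device)
targets = torch.randint(0, 10, (32,), device=device)

optimizer.zero_grad()
# Autocast the forward pass to bfloat16 on the XLA device.
with torch.autocast('xla', dtype=torch.bfloat16):
  outputs = model(inputs)
  loss = loss_fn(outputs, targets)
loss.backward()
xm.optimizer_step(optimizer)  # all-reduces gradients (if replicated) and steps
torch_xla.sync()              # dispatch the step's graph to the device
```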

docs/source/perf/ddp.md

Lines changed: 1 addition & 1 deletion

@@ -105,7 +105,7 @@ def demo_basic(rank):
 setup(rank, world_size)

 # create model and move it to XLA device
-device = torch_xla.device()
+device = torch.device('xla')
 model = ToyModel().to(device)
 ddp_model = DDP(model, gradient_as_bucket_view=True)

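A self-contained sketch of the DDP flow this file documents, with the new device spelling and the xla backend; the linear layer and random batch stand in for the doc's `ToyModel` and training data:

``` python
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch_xla
import torch_xla.distributed.xla_backend  # registers the 'xla' dist backend
from torch.nn.parallel import DistributedDataParallel as DDP


def demo_basic(rank):
  dist.init_process_group('xla', init_method='xla://')

  # Create the model and move it to the XLA device owned by this process.
  device = torch.device('xla')
  model = nn.Linear(10, 10).to(device)  # stand-in for ToyModel
  ddp_model = DDP(model, gradient_as_bucket_view=True)

  optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
  loss_fn = nn.MSELoss()

  optimizer.zero_grad()
  outputs = ddp_model(torch.randn(20, 10, device=device))
  loss_fn(outputs, torch.randn(20, 10, device=device)).backward()
  optimizer.step()
  torch_xla.sync()  # dispatch the recorded step to the device


if __name__ == '__main__':
  torch_xla.launch(demo_basic)
```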
docs/source/perf/dynamo.md

Lines changed: 2 additions & 2 deletions

@@ -41,7 +41,7 @@ import torchvision
 import torch_xla.core.xla_model as xm

 def eval_model(loader):
-device = torch_xla.device()
+device = torch.device('xla')
 xla_resnet18 = torchvision.models.resnet18().to(device)
 xla_resnet18.eval()
 dynamo_resnet18 = torch.compile(

@@ -129,7 +129,7 @@ def train_model(model, data, target, optimizer):
 return pred

 def train_model_main(loader):
-device = torch_xla.device()
+device = torch.device('xla')
 xla_resnet18 = torchvision.models.resnet18().to(device)
 xla_resnet18.train()
 dynamo_train_model = torch.compile(
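The inference path updated above can be exercised end to end with a short sketch using the `openxla` Dynamo backend; the batch shape and ResNet-18 mirror the doc's example, while the rest is illustrative:

``` python
import torch
import torch_xla
import torchvision

device = torch.device('xla')

xla_resnet18 = torchvision.models.resnet18().to(device)
xla_resnet18.eval()

# Compile with the openxla backend so Dynamo traces through to XLA.
dynamo_resnet18 = torch.compile(xla_resnet18, backend='openxla')

with torch.no_grad():
  batch = torch.randn(8, 3, 224, 224, device=device)
  output = dynamo_resnet18(batch)
```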
