
Commit 022ecde

Migrate runtime.xla_device in favor of core.xla_model.xla_device (#9200)
1 parent: a159c9c · commit 022ecde

165 files changed: +985 −1008 lines

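Throughout the diff, calls to `xm.xla_device()` are replaced with `torch_xla.device()` or with the `'xla'` device string. A minimal before/after sketch of the pattern, assuming a PyTorch/XLA installation where both spellings are still available:

```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm

# Old spelling, being migrated away from in this commit:
t_old = torch.randn(2, 2, device=xm.xla_device())

# New spellings used by the diff: ask torch_xla for the current device,
# or pass the 'xla' device string directly to the tensor factory.
device = torch_xla.device()
t_new = torch.randn(2, 2, device='xla')

print(t_old.device, t_new.device, device)
```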

API_GUIDE.md

Lines changed: 12 additions & 12 deletions
@@ -15,14 +15,14 @@ import torch
 import torch_xla
 import torch_xla.core.xla_model as xm
 
-t = torch.randn(2, 2, device=xm.xla_device())
+t = torch.randn(2, 2, device='xla')
 print(t.device)
 print(t)
 ```
 
 This code should look familiar. PyTorch/XLA uses the same interface as regular
 PyTorch with a few additions. Importing `torch_xla` initializes PyTorch/XLA, and
-`xm.xla_device()` returns the current XLA device. This may be a CPU or TPU
+`torch_xla.device()` returns the current XLA device. This may be a CPU or TPU
 depending on your environment.
 
 ## XLA Tensors are PyTorch Tensors
@@ -32,8 +32,8 @@ PyTorch operations can be performed on XLA tensors just like CPU or CUDA tensors
 For example, XLA tensors can be added together:
 
 ```python
-t0 = torch.randn(2, 2, device=xm.xla_device())
-t1 = torch.randn(2, 2, device=xm.xla_device())
+t0 = torch.randn(2, 2, device='xla')
+t1 = torch.randn(2, 2, device='xla')
 print(t0 + t1)
 ```
 
@@ -46,8 +46,8 @@ print(t0.mm(t1))
 Or used with neural network modules:
 
 ```python
-l_in = torch.randn(10, device=xm.xla_device())
-linear = torch.nn.Linear(10, 20).to(xm.xla_device())
+l_in = torch.randn(10, device='xla')
+linear = torch.nn.Linear(10, 20).to(torch_xla.device())
 l_out = linear(l_in)
 print(l_out)
 ```
@@ -56,7 +56,7 @@ Like other device types, XLA tensors only work with other XLA tensors on the
 same device. So code like
 
 ```python
-l_in = torch.randn(10, device=xm.xla_device())
+l_in = torch.randn(10, device='xla')
 linear = torch.nn.Linear(10, 20)
 l_out = linear(l_in)
 print(l_out)
@@ -109,10 +109,10 @@ class MNIST(nn.Module):
 batch_size = 128
 train_loader = xu.SampleGenerator(
     data=(torch.zeros(batch_size, 1, 28, 28),
-          torch.zeros(batch_size, dtype=torch.int64)),
+        torch.zeros(batch_size, dtype=torch.int64)),
     sample_count=60000 // batch_size // xr.world_size())
 
-device = xm.xla_device()  # Get the XLA device (TPU).
+device = torch_xla.device()  # Get the XLA device (TPU).
 model = MNIST().train().to(device)  # Create a model and move it to the device.
 loss_fn = nn.NLLLoss()
 optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
@@ -169,7 +169,7 @@ def _mp_fn(index):
       index: Index of the process.
   """
 
-  device = xm.xla_device()  # Get the device assigned to this process.
+  device = torch_xla.device()  # Get the device assigned to this process.
   # Wrap the loader for multi-device.
   mp_device_loader = pl.MpDeviceLoader(train_loader, device)
 
@@ -197,7 +197,7 @@ single device snippet. Let's go over then one by one.
 - `torch_xla.launch()`
   - Creates the processes that each run an XLA device.
   - This function is a wrapper of multithreading spawn to allow user run the script with torchrun command line also. Each process will only be able to access the device assigned to the current process. For example on a TPU v4-8, there will be 4 processes being spawn up and each process will own a TPU device.
-  - Note that if you print the `xm.xla_device()` on each process you will see `xla:0` on all devices. This is because each process can only see one device. This does not mean multi-process is not functioning. The only exeption is with PJRT runtime on TPU v2 and TPU v3 since there will be `#devices/2` processes and each process will have 2 threads (check this [doc](https://github.com/pytorch/xla/blob/master/docs/pjrt.md#tpus-v2v3-vs-v4) for more details).
+  - Note that if you print the `torch_xla.device()` on each process you will see `xla:0` on all devices. This is because each process can only see one device. This does not mean multi-process is not functioning. The only exeption is with PJRT runtime on TPU v2 and TPU v3 since there will be `#devices/2` processes and each process will have 2 threads (check this [doc](https://github.com/pytorch/xla/blob/master/docs/pjrt.md#tpus-v2v3-vs-v4) for more details).
 - `MpDeviceLoader`
   - Loads the training data onto each device.
   - `MpDeviceLoader` can wrap on a torch dataloader. It can preload the data to the device and overlap the dataloading with device execution to improve the performance.
@@ -290,7 +290,7 @@ import torch
 import torch_xla
 import torch_xla.core.xla_model as xm
 
-device = xm.xla_device()
+device = torch_xla.device()
 
 t0 = torch.randn(2, 2, device=device)
 t1 = torch.randn(2, 2, device=device)
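The API_GUIDE hunks above describe `torch_xla.launch()` and `MpDeviceLoader` in prose. A minimal runnable sketch of that multi-process entry point, assuming a machine with XLA devices and using the same `xu.SampleGenerator` helper shown in the guide (batch size and sample count here are illustrative):

```python
import torch
import torch_xla
import torch_xla.distributed.parallel_loader as pl
import torch_xla.utils.utils as xu


def _mp_fn(index):
  # Each spawned process sees exactly one XLA device, so this usually
  # prints `xla:0` in every process (TPU v2/v3 are the exception noted above).
  device = torch_xla.device()
  train_loader = xu.SampleGenerator(
      data=(torch.zeros(8, 1, 28, 28), torch.zeros(8, dtype=torch.int64)),
      sample_count=10)
  # MpDeviceLoader preloads batches onto the device and overlaps data
  # loading with device execution.
  mp_device_loader = pl.MpDeviceLoader(train_loader, device)
  for step, (data, target) in enumerate(mp_device_loader):
    print(index, step, data.device)


if __name__ == '__main__':
  torch_xla.launch(_mp_fn, args=())
```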

README.md

Lines changed: 1 addition & 1 deletion
@@ -196,7 +196,7 @@ If you're using `DistributedDataParallel`, make the following changes:
 + # Rank and world size are inferred from the XLA device runtime
 + dist.init_process_group("xla", init_method='xla://')
 +
-+ model.to(xm.xla_device())
++ model.to(torch_xla.device())
 + ddp_model = DDP(model, gradient_as_bucket_view=True)
 
 - model = model.to(rank)
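For context, a hedged sketch of the full `DistributedDataParallel` pattern the README fragment above comes from, assuming a multi-process XLA launch; the model, data, and optimizer are placeholders:

```python
import torch
import torch.distributed as dist
import torch.optim as optim
import torch_xla
import torch_xla.distributed.xla_backend  # registers the 'xla' process-group backend
from torch.nn.parallel import DistributedDataParallel as DDP


def _mp_fn(index):
  # Rank and world size are inferred from the XLA device runtime.
  dist.init_process_group("xla", init_method='xla://')

  model = torch.nn.Linear(128, 10).to(torch_xla.device())
  ddp_model = DDP(model, gradient_as_bucket_view=True)
  optimizer = optim.SGD(ddp_model.parameters(), lr=0.01)

  inputs = torch.randn(32, 128, device='xla')
  loss = ddp_model(inputs).sum()
  loss.backward()
  optimizer.step()
  torch_xla.sync()  # run the pending graph


if __name__ == '__main__':
  torch_xla.launch(_mp_fn, args=())
```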

benchmarks/benchmark_experiment.py

Lines changed: 1 addition & 1 deletion
@@ -208,7 +208,7 @@ def update_process_env(self, process_env: Dict[str, str]):
   def get_device(self):
     if self.torch_xla2:
       # Initiate the model in CPU first for xla2. We will move the model to jax device later.
-      # This is because we don't have xm.xla_device() function in torch_xla2.
+      # This is because we don't have torch_xla.device() function in torch_xla2.
       return torch.device("cpu")
     if self.xla:
       return xm.xla_device(devkind=self.accelerator.upper())

benchmarks/experiment_runner.py

Lines changed: 1 addition & 1 deletion
@@ -255,7 +255,7 @@ def _default_iter_fn(self, benchmark_experiment: BenchmarkExperiment,
 
   def _pure_wall_time_iter_fn(self, benchmark_experiment: BenchmarkExperiment,
                               benchmark_model: BenchmarkModel, input_tensor):
-    device = xm.xla_device() if benchmark_experiment.xla else 'cuda'
+    device = torch_xla.device() if benchmark_experiment.xla else 'cuda'
     sync_fn = xm.wait_device_ops if benchmark_experiment.xla else torch.cuda.synchronize
     timing, output = bench.do_bench(
         lambda: benchmark_model.model_iter_fn(
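The hunk above picks `xm.wait_device_ops` as the synchronization function when timing XLA runs. A small sketch of why that matters for wall-clock measurement, assuming an XLA device is available; the tensor sizes are illustrative and this is not the repo's `do_bench` helper:

```python
import time

import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = torch_xla.device()
x = torch.randn(1024, 1024, device=device)

start = time.perf_counter()
y = x @ x
torch_xla.sync()      # dispatch the pending lazy graph
xm.wait_device_ops()  # block until the device has actually finished executing
print('wall time (s):', time.perf_counter() - start)
```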

benchmarks/matmul_bench.py

Lines changed: 2 additions & 5 deletions
@@ -39,10 +39,7 @@ def main():
   """
 
   xla_bench_fn = lambda fn: do_bench(
-      fn,
-      return_mode='min',
-      sync_fn=lambda: xm.wait_device_ops(),
-      device=xm.xla_device())
+      fn, return_mode='min', sync_fn=lambda: xm.wait_device_ops(), device='xla')
   ind_bench_fn = lambda fn: do_bench(
       fn,
       return_mode='min',
@@ -53,7 +50,7 @@ def main():
   for dtype in dtypes:
     for inductor_matmul, xla_matmul in zip(
         get_matmuls(device='cuda', dtype=dtype, backend='inductor'),
-        get_matmuls(device=xm.xla_device(), dtype=dtype, backend='openxla')):
+        get_matmuls(device='xla', dtype=dtype, backend='openxla')):
       ind_lhs_shape, ind_rhs_shape, ind_fn = inductor_matmul
       xla_lhs_shape, xla_rhs_shape, xla_fn = xla_matmul
       assert ind_lhs_shape == xla_lhs_shape, f"Expect matmul shapes to match for benchmarking. Mismatch lhs: {ind_lhs_shape}, rhs: {xla_rhs_shape}"

contrib/kaggle/distributed-pytorch-xla-basics-with-pjrt.ipynb

Lines changed: 12 additions & 12 deletions
@@ -188,7 +188,7 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"To get the current process/thread's default XLA device, use `xm.xla_device()`. XLA devices are numbered as `xla:i`, where `i` is the index of the device within the current process. Since each process has two devices on a TPU v3, this will be `xla:0` or `xla:1`."
+"To get the current process/thread's default XLA device, use `torch_xla.device()`. XLA devices are numbered as `xla:i`, where `i` is the index of the device within the current process. Since each process has two devices on a TPU v3, this will be `xla:0` or `xla:1`."
 ]
 },
 {
@@ -210,7 +210,7 @@
 "lock = mp.Manager().Lock()\n",
 "\n",
 "def print_device(i, lock):\n",
-"    device = xm.xla_device()\n",
+"    device = torch_xla.device()\n",
 "    with lock:\n",
 "        print('process', i, device)"
 ]
@@ -273,7 +273,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 9,
+"execution_count": null,
 "metadata": {
 "execution": {
 "iopub.execute_input": "2024-01-10T19:30:33.219878Z",
@@ -318,12 +318,12 @@
 ],
 "source": [
 "def add_ones(i, lock):\n",
-"    x = torch.ones((3, 3), device=xm.xla_device())\n",
+"    x = torch.ones((3, 3), device='xla')\n",
 "    y = x + x\n",
-"    \n",
+"\n",
 "    # Run graph to compute `y` before printing\n",
 "    torch_xla.sync()\n",
-"    \n",
+"\n",
 "    with lock:\n",
 "        print(i, y)\n",
 "\n",
@@ -340,7 +340,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 10,
+"execution_count": null,
 "metadata": {
 "execution": {
 "iopub.execute_input": "2024-01-10T19:30:35.656796Z",
@@ -378,10 +378,10 @@
 "source": [
 "def gather_ids(i, lock):\n",
 "    # Create a tensor on each device with the device ID\n",
-"    t = torch.tensor([i], device=xm.xla_device())\n",
+"    t = torch.tensor([i], device='xla')\n",
 "    with lock:\n",
 "        print(i, t)\n",
-"    \n",
+"\n",
 "    # Collect and concatenate the IDs\n",
 "    ts = xm.all_gather(t)\n",
 "    torch_xla.sync()\n",
@@ -402,7 +402,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 11,
+"execution_count": null,
 "metadata": {
 "execution": {
 "iopub.execute_input": "2024-01-10T19:30:38.315927Z",
@@ -454,7 +454,7 @@
 "import torch_xla.experimental.pjrt_backend # Required for torch.distributed on TPU v2 and v3\n",
 "\n",
 "def toy_model(index, lock):\n",
-"    device = xm.xla_device()\n",
+"    device = torch_xla.device()\n",
 "    dist.init_process_group('xla', init_method='xla://')\n",
 "\n",
 "    # Initialize a basic toy model\n",
@@ -479,7 +479,7 @@
 "    loss.backward()\n",
 "\n",
 "    optimizer.step()\n",
-"    \n",
+"\n",
 "    # Run the pending graph\n",
 "    torch_xla.sync()\n",
 "\n",

contrib/kaggle/pytorch-xla-2-0-on-kaggle.ipynb

Lines changed: 1 addition & 1 deletion
@@ -172,7 +172,7 @@
 "\n",
 "pipeline = DiffusionPipeline.from_pretrained(\"runwayml/stable-diffusion-v1-5\")\n",
 "# Move the model to the first TPU core\n",
-"pipeline = pipeline.to(xm.xla_device())"
+"pipeline = pipeline.to(torch_xla.device())"
 ]
 },
 {
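A short usage sketch of the notebook change above, assuming the `diffusers` package and an XLA/TPU runtime are available; the prompt and output filename are illustrative:

```python
import torch_xla
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
# Move the pipeline to the first TPU core.
pipeline = pipeline.to(torch_xla.device())

# Generate one image; .images[0] is a PIL image in the diffusers API.
image = pipeline("a photograph of an astronaut riding a horse").images[0]
image.save("astronaut.png")
```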

docs/source/features/pallas.md

Lines changed: 3 additions & 3 deletions
@@ -40,9 +40,9 @@ jax will lock the TPU and torch-xla cannot access it.
 Example usage:
 
 ``` python3
-q = torch.randn(3, 2, 128, 4).to("xla")
-k = torch.randn(3, 2, 128, 4).to("xla")
-v = torch.randn(3, 2, 128, 4).to("xla")
+q = torch.randn(3, 2, 128, 4).to('xla')
+k = torch.randn(3, 2, 128, 4).to('xla')
+v = torch.randn(3, 2, 128, 4).to('xla')
 
 # Adopts any Pallas kernel
 from torch_xla.experimental.custom_kernel import make_kernel_from_pallas

docs/source/features/triton.md

Lines changed: 2 additions & 2 deletions
@@ -42,8 +42,8 @@ import triton
 import triton.language as tl
 
 size = 16
-x = torch.arange(size, dtype=torch.int64).to("xla")
-y = torch.arange(size, dtype=torch.int64).to("xla")
+x = torch.arange(size, dtype=torch.int64).to('xla')
+y = torch.arange(size, dtype=torch.int64).to('xla')
 output = torch.empty_like(x)
 block_size = 8
 grid = (triton.cdiv(size, block_size),)

docs/source/learn/_pjrt.md

Lines changed: 2 additions & 2 deletions
@@ -73,7 +73,7 @@ import torch_xla.distributed.xla_backend
 
 
 def _mp_fn(index):
-  device = xm.xla_device()
+  device = torch_xla.device()
 - dist.init_process_group('xla', rank=xr.global_ordinal(), world_size=xr.world_size())
 + dist.init_process_group('xla', init_method='xla://')
 
@@ -377,7 +377,7 @@ def _all_gather(index: int):
   # No need to pass in `rank` or `world_size`
   dist.init_process_group('xla', init_method='xla://')
 
-  t = torch.tensor([index], dtype=torch.int32, device=xm.xla_device())
+  t = torch.tensor([index], dtype=torch.int32, device='xla')
   output = [torch.zeros_like(t) for _ in range(dist.get_world_size())]
   dist.all_gather(output, t)
 
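The `_all_gather` hunk above is nearly self-contained. A hedged sketch of running it end to end, assuming a multi-process XLA launch:

```python
import torch
import torch.distributed as dist
import torch_xla
import torch_xla.distributed.xla_backend  # registers the 'xla' backend


def _all_gather(index: int):
  # No need to pass in `rank` or `world_size`.
  dist.init_process_group('xla', init_method='xla://')

  t = torch.tensor([index], dtype=torch.int32, device='xla')
  output = [torch.zeros_like(t) for _ in range(dist.get_world_size())]
  dist.all_gather(output, t)
  torch_xla.sync()
  print(index, output)


if __name__ == '__main__':
  torch_xla.launch(_all_gather, args=())
```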
