
Commit a76ef1c

Replace device=torch_xla.device() with device="xla"
1 parent ca3dc52 commit a76ef1c

39 files changed (+142 -165 lines changed)
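
For reference, a minimal sketch of the before/after usage this commit standardizes on (assuming `torch` and `torch_xla` are installed and a PJRT device is configured, e.g. `PJRT_DEVICE=CPU` or `PJRT_DEVICE=TPU`; both spellings target the current XLA device):

```python
import torch
import torch_xla

# Old spelling: pass the device object returned by torch_xla.device().
t_old = torch.randn(2, 2, device=torch_xla.device())

# New spelling used throughout this commit: pass the "xla" device string.
t_new = torch.randn(2, 2, device="xla")

# Both tensors live on the same default XLA device.
print(t_old.device, t_new.device)
```

Calls such as `.to(torch_xla.device())` in the diffs below are left unchanged; only the `device=` keyword arguments are rewritten.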

API_GUIDE.md

Lines changed: 5 additions & 5 deletions
@@ -15,7 +15,7 @@ import torch
 import torch_xla
 import torch_xla.core.xla_model as xm

-t = torch.randn(2, 2, device=torch_xla.device())
+t = torch.randn(2, 2, device="xla")
 print(t.device)
 print(t)
 ```
@@ -32,8 +32,8 @@ PyTorch operations can be performed on XLA tensors just like CPU or CUDA tensors
 For example, XLA tensors can be added together:

 ```python
-t0 = torch.randn(2, 2, device=torch_xla.device())
-t1 = torch.randn(2, 2, device=torch_xla.device())
+t0 = torch.randn(2, 2, device="xla")
+t1 = torch.randn(2, 2, device="xla")
 print(t0 + t1)
 ```

@@ -46,7 +46,7 @@ print(t0.mm(t1))
 Or used with neural network modules:

 ```python
-l_in = torch.randn(10, device=torch_xla.device())
+l_in = torch.randn(10, device="xla")
 linear = torch.nn.Linear(10, 20).to(torch_xla.device())
 l_out = linear(l_in)
 print(l_out)
@@ -56,7 +56,7 @@ Like other device types, XLA tensors only work with other XLA tensors on the
 same device. So code like

 ```python
-l_in = torch.randn(10, device=torch_xla.device())
+l_in = torch.randn(10, device="xla")
 linear = torch.nn.Linear(10, 20)
 l_out = linear(l_in)
 print(l_out)

benchmarks/matmul_bench.py

Lines changed: 2 additions & 5 deletions
@@ -39,10 +39,7 @@ def main():
   """

   xla_bench_fn = lambda fn: do_bench(
-      fn,
-      return_mode='min',
-      sync_fn=lambda: xm.wait_device_ops(),
-      device=torch_xla.device())
+      fn, return_mode='min', sync_fn=lambda: xm.wait_device_ops(), device="xla")
   ind_bench_fn = lambda fn: do_bench(
       fn,
       return_mode='min',
@@ -53,7 +50,7 @@ def main():
   for dtype in dtypes:
     for inductor_matmul, xla_matmul in zip(
         get_matmuls(device='cuda', dtype=dtype, backend='inductor'),
-        get_matmuls(device=torch_xla.device(), dtype=dtype, backend='openxla')):
+        get_matmuls(device="xla", dtype=dtype, backend='openxla')):
       ind_lhs_shape, ind_rhs_shape, ind_fn = inductor_matmul
       xla_lhs_shape, xla_rhs_shape, xla_fn = xla_matmul
       assert ind_lhs_shape == xla_lhs_shape, f"Expect matmul shapes to match for benchmarking. Mismatch lhs: {ind_lhs_shape}, rhs: {xla_rhs_shape}"

contrib/kaggle/distributed-pytorch-xla-basics-with-pjrt.ipynb

Lines changed: 9 additions & 9 deletions
@@ -273,7 +273,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 9,
+"execution_count": null,
 "metadata": {
 "execution": {
 "iopub.execute_input": "2024-01-10T19:30:33.219878Z",
@@ -318,12 +318,12 @@
 ],
 "source": [
 "def add_ones(i, lock):\n",
-" x = torch.ones((3, 3), device=torch_xla.device())\n",
+" x = torch.ones((3, 3), device=\"xla\")\n",
 " y = x + x\n",
-" \n",
+"\n",
 " # Run graph to compute `y` before printing\n",
 " torch_xla.sync()\n",
-" \n",
+"\n",
 " with lock:\n",
 " print(i, y)\n",
 "\n",
@@ -340,7 +340,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 10,
+"execution_count": null,
 "metadata": {
 "execution": {
 "iopub.execute_input": "2024-01-10T19:30:35.656796Z",
@@ -378,10 +378,10 @@
 "source": [
 "def gather_ids(i, lock):\n",
 " # Create a tensor on each device with the device ID\n",
-" t = torch.tensor([i], device=torch_xla.device())\n",
+" t = torch.tensor([i], device=\"xla\")\n",
 " with lock:\n",
 " print(i, t)\n",
-" \n",
+"\n",
 " # Collect and concatenate the IDs\n",
 " ts = xm.all_gather(t)\n",
 " torch_xla.sync()\n",
@@ -402,7 +402,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 11,
+"execution_count": null,
 "metadata": {
 "execution": {
 "iopub.execute_input": "2024-01-10T19:30:38.315927Z",
@@ -479,7 +479,7 @@
 " loss.backward()\n",
 "\n",
 " optimizer.step()\n",
-" \n",
+"\n",
 " # Run the pending graph\n",
 " torch_xla.sync()\n",
 "\n",

docs/source/learn/_pjrt.md

Lines changed: 1 addition & 1 deletion
@@ -377,7 +377,7 @@ def _all_gather(index: int):
   # No need to pass in `rank` or `world_size`
   dist.init_process_group('xla', init_method='xla://')

-  t = torch.tensor([index], dtype=torch.int32, device=torch_xla.device())
+  t = torch.tensor([index], dtype=torch.int32, device="xla")
   output = [torch.zeros_like(t) for _ in range(dist.get_world_size())]
   dist.all_gather(output, t)

docs/source/learn/pytorch-on-xla-devices.md

Lines changed: 5 additions & 5 deletions
@@ -14,7 +14,7 @@ import torch
 import torch_xla
 import torch_xla.core.xla_model as xm

-t = torch.randn(2, 2, device=torch_xla.device())
+t = torch.randn(2, 2, device="xla")
 print(t.device)
 print(t)
 ```
@@ -32,8 +32,8 @@ tensors.
 For example, XLA tensors can be added together:

 ``` python
-t0 = torch.randn(2, 2, device=torch_xla.device())
-t1 = torch.randn(2, 2, device=torch_xla.device())
+t0 = torch.randn(2, 2, device="xla")
+t1 = torch.randn(2, 2, device="xla")
 print(t0 + t1)
 ```

@@ -46,7 +46,7 @@ print(t0.mm(t1))
 Or used with neural network modules:

 ``` python
-l_in = torch.randn(10, device=torch_xla.device())
+l_in = torch.randn(10, device="xla")
 linear = torch.nn.Linear(10, 20).to(torch_xla.device())
 l_out = linear(l_in)
 print(l_out)
@@ -56,7 +56,7 @@ Like other device types, XLA tensors only work with other XLA tensors on
 the same device. So code like

 ``` python
-l_in = torch.randn(10, device=torch_xla.device())
+l_in = torch.randn(10, device="xla")
 linear = torch.nn.Linear(10, 20)
 l_out = linear(l_in)
 print(l_out)

docs/source/learn/troubleshoot.md

Lines changed: 2 additions & 2 deletions
@@ -32,8 +32,8 @@ vm:~$ export PJRT_DEVICE=TPU
 vm:~$ python3
 >>> import torch
 >>> import torch_xla.core.xla_model as xm
->>> t1 = torch.tensor(100, device=torch_xla.device())
->>> t2 = torch.tensor(200, device=torch_xla.device())
+>>> t1 = torch.tensor(100, device="xla")
+>>> t2 = torch.tensor(200, device="xla")
 >>> print(t1 + t2)
 tensor(300, device='xla:0')
 ```

examples/scan/scan_examples.py

Lines changed: 5 additions & 5 deletions
@@ -18,8 +18,8 @@ def cumsum(accumulated, element):
     return accumulated, accumulated

   # 2) Define an initial carry and the input tensor.
-  init_sum = torch.tensor([0.0], device=torch_xla.device())
-  xs = torch.tensor([1.0, 2.0, 3.0], device=torch_xla.device())
+  init_sum = torch.tensor([0.0], device="xla")
+  xs = torch.tensor([1.0, 2.0, 3.0], device="xla")
   torch_xla.sync()

   # 3) Call `scan` with our combine function, initial carry, and input tensor.
@@ -40,15 +40,15 @@ def scan_example_pytree():
   # - 'sum' to accumulate the sum of all seen values
   # - 'count' to count how many values have been seen
   carry = {
-      'sum': torch.tensor([0.0], device=torch_xla.device()),
-      'count': torch.tensor([0.0], device=torch_xla.device())
+      'sum': torch.tensor([0.0], device="xla"),
+      'count': torch.tensor([0.0], device="xla")
   }

   # 2) Define our input PyTree, which in this case is just a dictionary with one leaf:
   # - 'values' is a 1D tensor representing data points we want to scan over.
   xs = {
       'values':
-          torch.arange(1, 6, dtype=torch.float32, device=torch_xla.device())
+          torch.arange(1, 6, dtype=torch.float32, device="xla")
   }

   # Here, xs['values'] has shape [5]. The `scan` function will automatically slice

test/ds/test_dynamic_shapes.py

Lines changed: 3 additions & 3 deletions
@@ -163,7 +163,7 @@ def test_t_copy(self):
     self.assertEqual(t2_t.shape[1], 7)

   def test_nonzero_shape(self):
-    x = torch.tensor((0, 1, 2, 0, 3, 4), device=torch_xla.device())
+    x = torch.tensor((0, 1, 2, 0, 3, 4), device="xla")
     x_dim0_shape = torch_xla._XLAC._get_xla_tensor_dimension_size(
         torch.nonzero(x, as_tuple=False), 0)
     self.assertEqual(x_dim0_shape.item(), 4)
@@ -176,14 +176,14 @@ def test_nonzero_correctness(self):
     self.assertEqual(t2.cpu(), t2_aten)

   def test_masked_select_shape(self):
-    x = torch.tensor((0, 1, 2, 0, 3, 4), device=torch_xla.device())
+    x = torch.tensor((0, 1, 2, 0, 3, 4), device="xla")
     mask = x.ge(2)
     x_dim0_shape = torch_xla._XLAC._get_xla_tensor_dimension_size(
         torch.masked_select(x, mask), 0)
     self.assertEqual(x_dim0_shape.item(), 3)

   def test_nonzero_cast(self):
-    t1 = torch.ones(5, 2, device=torch_xla.device())
+    t1 = torch.ones(5, 2, device="xla")
     # Result of the nonzero should be the index type. Currently
     # index type is s64 on cpu and gpu, but s32 on TPU. We should be
     # able to cast it to any other type without error.

test/dynamo/test_dynamo.py

Lines changed: 2 additions & 2 deletions
@@ -49,7 +49,7 @@ def inplace_update(self, a):
   def test_inplace_update_correctness(self, backend):
     dynamo_inplace = torch.compile(
         self.inplace_update, backend=backend, fullgraph=True)
-    t = torch.tensor([0, 1, 2], device=torch_xla.device())
+    t = torch.tensor([0, 1, 2], device="xla")
     for i in range(10):
       t = dynamo_inplace(t)
     self.assertTrue(torch.all(torch.eq(t.cpu(), torch.tensor([10, 11, 12]))))
@@ -131,7 +131,7 @@ def dummy_fn(self, a):
   def test_dynamo_with_trace(self):
     dynamo_dummy = torch.compile(
         self.dummy_fn, backend="openxla", fullgraph=True)
-    t = torch.randn(2, 3, 4, device=torch_xla.device())
+    t = torch.randn(2, 3, 4, device="xla")
     for i in range(10):
       with xp.Trace('build_graph'):
         t = dynamo_dummy(t)

test/metrics_compare_utils_test.py

Lines changed: 1 addition & 1 deletion
@@ -275,7 +275,7 @@ def test_compare_metrics_reports_new_counters(self):
   def test_parse_real_metrics(self):
     print(
         'Testing against TPU. If this hangs, check that $XRT_TPU_CONFIG is set')
-    x = torch.rand(3, 5, device=torch_xla.device())
+    x = torch.rand(3, 5, device="xla")
     x = torch.flatten(x, 1)
     x = torch.roll(x, 1, 0)
     x = torch.flip(x, [0, 1])
