Skip to content

Commit 9a0b1ed

Browse files
committed
Fix test errors
1 parent 637ead6 commit 9a0b1ed

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

test/scan/test_scan.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -273,16 +273,17 @@ def test_scan_external_in_place_mutation(self):
273273
giving wrong results.
274274
"""
275275
# TODO(yifeit): Modify this test when external in-place mutation is eventually supported.
276-
weird_global = torch.tensor([0.0, 0.0], device='xla')
276+
weird_global = torch.tensor([0.0, 0.0], device=torch_xla.device())
277277

278278
def step_fn(carry, x):
279279
new_carry = carry + x
280280
weird_global.add_(1.0)
281281
y = new_carry + weird_global
282282
return new_carry, y
283283

284-
init = torch.tensor([0.0, 0.0], device='xla')
285-
xs = torch.tensor([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], device='xla')
284+
init = torch.tensor([0.0, 0.0], device=torch_xla.device())
285+
xs = torch.tensor([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]],
286+
device=torch_xla.device())
286287

287288
with self.assertRaisesRegex(AssertionError, "FakeTensor"):
288289
scan(step_fn, init, xs)
@@ -350,11 +351,13 @@ def test_scan_rand_in_fn(self):
350351

351352
def step_fn(carry, x):
352353
new_carry = carry + x
353-
y = new_carry + torch.rand(2, device='xla')
354+
# TODO: figure out why device='xla' doesn't work
355+
y = new_carry + torch.rand(2, device=torch_xla.device())
354356
return new_carry, y
355357

356-
init = torch.tensor([0.0, 0.0], device='xla')
357-
xs = torch.tensor([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], device='xla')
358+
init = torch.tensor([0.0, 0.0], device=torch_xla.device())
359+
xs = torch.tensor([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]],
360+
device=torch_xla.device())
358361
_, ys = scan(step_fn, init, xs)
359362
# ys should be a 2D tensor with this shape.
360363
self.assertEqual(ys.shape, (3, 2))

0 commit comments

Comments
 (0)