@@ -273,16 +273,17 @@ def test_scan_external_in_place_mutation(self):
273
273
giving wrong results.
274
274
"""
275
275
# TODO(yifeit): Modify this test when external in-place mutation is eventually supported.
276
- weird_global = torch .tensor ([0.0 , 0.0 ], device = 'xla' )
276
+ weird_global = torch .tensor ([0.0 , 0.0 ], device = torch_xla . device () )
277
277
278
278
def step_fn (carry , x ):
279
279
new_carry = carry + x
280
280
weird_global .add_ (1.0 )
281
281
y = new_carry + weird_global
282
282
return new_carry , y
283
283
284
- init = torch .tensor ([0.0 , 0.0 ], device = 'xla' )
285
- xs = torch .tensor ([[0.0 , 0.0 ], [0.0 , 0.0 ], [0.0 , 0.0 ]], device = 'xla' )
284
+ init = torch .tensor ([0.0 , 0.0 ], device = torch_xla .device ())
285
+ xs = torch .tensor ([[0.0 , 0.0 ], [0.0 , 0.0 ], [0.0 , 0.0 ]],
286
+ device = torch_xla .device ())
286
287
287
288
with self .assertRaisesRegex (AssertionError , "FakeTensor" ):
288
289
scan (step_fn , init , xs )
@@ -350,11 +351,13 @@ def test_scan_rand_in_fn(self):
350
351
351
352
def step_fn (carry , x ):
352
353
new_carry = carry + x
353
- y = new_carry + torch .rand (2 , device = 'xla' )
354
+ # TODO: figure out why device='xla' doesn't work
355
+ y = new_carry + torch .rand (2 , device = torch_xla .device ())
354
356
return new_carry , y
355
357
356
- init = torch .tensor ([0.0 , 0.0 ], device = 'xla' )
357
- xs = torch .tensor ([[0.0 , 0.0 ], [0.0 , 0.0 ], [0.0 , 0.0 ]], device = 'xla' )
358
+ init = torch .tensor ([0.0 , 0.0 ], device = torch_xla .device ())
359
+ xs = torch .tensor ([[0.0 , 0.0 ], [0.0 , 0.0 ], [0.0 , 0.0 ]],
360
+ device = torch_xla .device ())
358
361
_ , ys = scan (step_fn , init , xs )
359
362
# ys should be a 2D tensor with this shape.
360
363
self .assertEqual (ys .shape , (3 , 2 ))
0 commit comments