@@ -142,12 +142,12 @@ def xla_device(n: Optional[int] = None,
142
142
Args:
143
143
n (int, optional): The specific instance (ordinal) to be returned. If
144
144
specified, the specific XLA device instance will be returned. Otherwise
145
- the first device of `devkind` will be returned.
145
+ the first device (default 0) will be returned.
146
146
devkind (string..., optional): If specified, device type such as `TPU`,
147
147
`CUDA`, `CPU`, or custom PJRT device. Deprecated.
148
148
149
149
Returns:
150
- A `torch.device` with the requested instance.
150
+ A `torch.device` with the requested instance of an XLA device .
151
151
"""
152
152
# When SPMD is enabled, we always return `xla:0` to the user, and
153
153
# under the hood we use virtual device logic for every xla tensor
@@ -156,7 +156,16 @@ def xla_device(n: Optional[int] = None,
156
156
torch_xla ._XLAC ._xla_set_default_device (device )
157
157
return torch .device (device )
158
158
159
- return runtime .xla_device (n , devkind )
159
+ if n is None :
160
+ return torch .device (torch_xla ._XLAC ._xla_get_default_device ())
161
+
162
+ devices = xm .get_xla_supported_devices (devkind = devkind )
163
+ if n > len (devices ):
164
+ raise IndexError ('Device index {} out of range in {}' .format (n , devices ))
165
+
166
+ device = devices [n ]
167
+ torch_xla ._XLAC ._xla_set_default_device (device )
168
+ return torch .device (device )
160
169
161
170
162
171
def _xla_real_device (device : torch .device ) -> Any :
0 commit comments