Skip to content

Commit 2a8f9e9

Browse files
committed
Fix errors
1 parent fb18d8a commit 2a8f9e9

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

torch_xla/runtime.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,8 @@ def local_ordinal() -> int:
156156
Local ordinal is in range [0, local_device_count)."""
157157
local_rank = xu.getenv_as(xenv.PJRT_LOCAL_PROCESS_RANK, int, 0)
158158
devices_per_process = addressable_device_count()
159-
return local_rank * devices_per_process + torch.device('xla').index
159+
return local_rank * devices_per_process + torch.device(
160+
torch_xla._XLAC._xla_get_default_device()).index
160161

161162

162163
def process_index() -> int:

0 commit comments

Comments
 (0)