Skip to content

Commit fa3a464

Browse files
committed
Fix spmd sharding visualization when device index is >= 10
1 parent cc15111 commit fa3a464

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

torch_xla/distributed/spmd/debugging.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,24 +57,25 @@ def visualize_sharding(sharding: str,
5757
# eg: '{devices=[2,2]0,1,2,3}'
5858
# eg: '{replicated}'
5959
# eg: '{devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}'
60+
print(f"Visualizing {sharding} (showing up to the first two dimensions)")
6061
if sharding == '{replicated}' or len(sharding) == 0:
6162
heights = 1
6263
widths = 1
6364
num_devices = xr.global_runtime_device_count()
6465
device_ids = list(range(num_devices))
6566
slices.setdefault((0, 0), device_ids)
6667
else:
67-
sharding_spac = sharding[sharding.index('['):sharding.index(']') + 1]
68+
sharding_spac = sharding[sharding.index('[') + 1:sharding.index(']')].split(",")
6869
device_list_original = sharding.split(' last_tile_dim_replicate')
6970
if len(device_list_original) == 2 and device_list_original[1] == '}':
7071
try:
7172
device_list_original_first = device_list_original[0]
7273
device_list = device_list_original_first[device_list_original_first.
7374
index(']') + 1:]
7475
device_indices_map = [int(s) for s in device_list.split(',')]
75-
heights = int(sharding_spac[1])
76-
widths = int(sharding_spac[3])
77-
last_dim_depth = int(sharding_spac[5])
76+
heights = int(sharding_spac[0])
77+
widths = int(sharding_spac[1])
78+
last_dim_depth = int(sharding_spac[-1])
7879
devices_len = len(device_indices_map)
7980
len_after_dim_down = devices_len // last_dim_depth
8081
for i in range(len_after_dim_down):
@@ -96,8 +97,8 @@ def visualize_sharding(sharding: str,
9697
device_list = device_list_original_first[device_list_original_first.
9798
index(']') + 1:-1]
9899
device_indices_map = [int(i) for i in device_list.split(',')]
99-
heights = int(sharding_spac[1])
100-
widths = int(sharding_spac[3])
100+
heights = int(sharding_spac[0])
101+
widths = int(sharding_spac[1])
101102
devices_len = len(device_indices_map)
102103
for i in range(devices_len):
103104
slices.setdefault((i // widths, i % widths), device_indices_map[i])

0 commit comments

Comments
 (0)