@@ -57,24 +57,25 @@ def visualize_sharding(sharding: str,
57
57
# eg: '{devices=[2,2]0,1,2,3}'
58
58
# eg: '{replicated}'
59
59
# eg: '{devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}'
60
+ print (f"Visualizing { sharding } (showing up to the first two dimensions)" )
60
61
if sharding == '{replicated}' or len (sharding ) == 0 :
61
62
heights = 1
62
63
widths = 1
63
64
num_devices = xr .global_runtime_device_count ()
64
65
device_ids = list (range (num_devices ))
65
66
slices .setdefault ((0 , 0 ), device_ids )
66
67
else :
67
- sharding_spac = sharding [sharding .index ('[' ):sharding .index (']' ) + 1 ]
68
+ sharding_spac = sharding [sharding .index ('[' ) + 1 :sharding .index (']' )]. split ( "," )
68
69
device_list_original = sharding .split (' last_tile_dim_replicate' )
69
70
if len (device_list_original ) == 2 and device_list_original [1 ] == '}' :
70
71
try :
71
72
device_list_original_first = device_list_original [0 ]
72
73
device_list = device_list_original_first [device_list_original_first .
73
74
index (']' ) + 1 :]
74
75
device_indices_map = [int (s ) for s in device_list .split (',' )]
75
- heights = int (sharding_spac [1 ])
76
- widths = int (sharding_spac [3 ])
77
- last_dim_depth = int (sharding_spac [5 ])
76
+ heights = int (sharding_spac [0 ])
77
+ widths = int (sharding_spac [1 ])
78
+ last_dim_depth = int (sharding_spac [- 1 ])
78
79
devices_len = len (device_indices_map )
79
80
len_after_dim_down = devices_len // last_dim_depth
80
81
for i in range (len_after_dim_down ):
@@ -96,8 +97,8 @@ def visualize_sharding(sharding: str,
96
97
device_list = device_list_original_first [device_list_original_first .
97
98
index (']' ) + 1 :- 1 ]
98
99
device_indices_map = [int (i ) for i in device_list .split (',' )]
99
- heights = int (sharding_spac [1 ])
100
- widths = int (sharding_spac [3 ])
100
+ heights = int (sharding_spac [0 ])
101
+ widths = int (sharding_spac [1 ])
101
102
devices_len = len (device_indices_map )
102
103
for i in range (devices_len ):
103
104
slices .setdefault ((i // widths , i % widths ), device_indices_map [i ])
0 commit comments