@@ -1459,6 +1459,19 @@ def _create_inverse_indices_permute_indices(
             inverse_indices[1].device,
         )

+    def _is_optimizer_enabled(
+        self,
+        has_local_optimizer: bool,
+        env: ShardingEnv,
+        device: Optional[torch.device],
+    ) -> bool:
+        flag = torch.tensor(
+            [has_local_optimizer], dtype=torch.uint8, device=device
+        )  # e.g. tensor([1]) on a rank that owns optimizer state
+        # Reduce with MAX to check whether any rank has optimizer state
+        dist.all_reduce(flag, op=dist.ReduceOp.MAX, group=env.process_group)
+        return bool(flag.item())
+
     # pyre-ignore [14]
     def input_dist(
         self,
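
The helper added above uses a common collective pattern: a boolean reduced with MAX acts as a logical OR across ranks, so every rank agrees whether *any* rank owns optimizer state. Below is a minimal, self-contained sketch of the same pattern, runnable on CPU with the gloo backend; the two-process setup, addresses, and `run` helper are assumptions for illustration, not part of this diff.

```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def run(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    has_local_optimizer = rank == 0  # only rank 0 "owns" optimizer state here
    flag = torch.tensor([has_local_optimizer], dtype=torch.uint8)
    # MAX over {0, 1} is a logical OR: True everywhere if True anywhere
    dist.all_reduce(flag, op=dist.ReduceOp.MAX)
    assert bool(flag.item())  # holds on every rank, including the empty one

    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    mp.spawn(run, args=(world_size,), nprocs=world_size)
```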
@@ -1698,10 +1711,17 @@ def update_shards(
             return

         current_state = self.state_dict()
-        has_optimizer = len(self._optim._optims) > 0 and all(
+        has_local_optimizer = len(self._optim._optims) > 0 and all(
             len(i) > 0 for i in self._optim.state_dict()["state"].values()
         )

+        # Communicate the optimizer-state flag across all ranks: if one rank owns
+        # all tables and the other ranks own none, transferring weights to an
+        # initially empty rank would create inconsistent state, because that rank
+        # has no local optimizer state and would therefore compute the tensor
+        # splits incorrectly.
+        has_optimizer = self._is_optimizer_enabled(has_local_optimizer, env, device)
+
         # TODO: Save lookup tensors to CPU to eventually avoid recreating them completely again
         # TODO: Ensure lookup tensors are actually being deleted
         for _, lookup in enumerate(self._lookups):
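
The comment above is the crux of the change: buffer sizes for the all-to-all depend on whether optimizer rows travel with the weights, so every rank must use the same flag. A toy illustration with made-up numbers follows; `expected_rows`, the doubling rule, and the shapes are all hypothetical, not the real split computation.

```python
max_dim_0, num_shards = 4, 2  # hypothetical padded rows per shard, shard count


def expected_rows(has_optimizer: bool) -> int:
    # Optimizer rows double the payload when they accompany the weights.
    return num_shards * max_dim_0 * (2 if has_optimizer else 1)


print(expected_rows(True))   # 16 rows: a rank that owns optimizer state
print(expected_rows(False))  # 8 rows: an empty rank trusting its local view
# With the globally reduced flag, both ranks size their buffers for 16 rows.
```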
@@ -1715,7 +1735,7 @@ def update_shards(
         max_dim_0, max_dim_1 = get_largest_dims_from_sharding_plan_updates(
             changed_sharding_params
         )
-        old_optimizer_state = self._optim.state_dict() if has_optimizer else None
+        old_optimizer_state = self._optim.state_dict() if has_local_optimizer else None

         local_shard_names_by_src_rank, local_output_tensor = shards_all_to_all(
             module=self,
@@ -1727,6 +1747,7 @@ def update_shards(
             max_dim_0=max_dim_0,
             max_dim_1=max_dim_1,
             optimizer_state=old_optimizer_state,
+            has_optimizer=has_optimizer,
         )

         for name, param in changed_sharding_params.items():
@@ -1791,30 +1812,25 @@ def update_shards(
         self._optim: CombinedOptimizer = CombinedOptimizer(optims)

         if has_optimizer:
-            split_index = len(local_output_tensor) // 2
-            local_weight_tensors = local_output_tensor[:split_index]
-            local_optimizer_tensors = local_output_tensor[split_index:]
-            # Modifies new_opt_state in place and returns it
             optimizer_state = update_optimizer_state_post_resharding(
                 old_opt_state=old_optimizer_state,  # pyre-ignore
                 new_opt_state=copy.deepcopy(self._optim.state_dict()),
                 ordered_shard_names_and_lengths=local_shard_names_by_src_rank,
-                output_tensor=local_optimizer_tensors,
+                output_tensor=local_output_tensor,
                 max_dim_0=max_dim_0,
+                extend_shard_name=self.extend_shard_name,
             )
-
             self._optim.load_state_dict(optimizer_state)
-        else:
-            local_weight_tensors = local_output_tensor

         current_state = update_state_dict_post_resharding(
             state_dict=current_state,
             ordered_shard_names_and_lengths=local_shard_names_by_src_rank,
-            output_tensor=local_weight_tensors,
+            output_tensor=local_output_tensor,
             new_sharding_params=changed_sharding_params,
             curr_rank=dist.get_rank(),
             extend_shard_name=self.extend_shard_name,
             max_dim_0=max_dim_0,
+            has_optimizer=has_optimizer,
         )

         self.load_state_dict(current_state)
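
The removed lines in this last hunk relied on a halves convention: the first half of `local_output_tensor` held weight rows and the second half the matching optimizer rows. After this diff, `update_state_dict_post_resharding` and `update_optimizer_state_post_resharding` receive the whole tensor plus `has_optimizer` and do their own per-shard slicing. A sketch of the old caller-side split, with assumed shapes (padded rows stacked along dim 0 is an assumption):

```python
import torch

max_dim_0, max_dim_1, num_shards = 4, 8, 3  # hypothetical padded dimensions

# Old convention: [w0, w1, w2, o0, o1, o2] stacked along dim 0.
local_output_tensor = torch.randn(2 * num_shards * max_dim_0, max_dim_1)
split_index = len(local_output_tensor) // 2
local_weight_tensors = local_output_tensor[:split_index]
local_optimizer_tensors = local_output_tensor[split_index:]
assert local_weight_tensors.shape == local_optimizer_tensors.shape

# New convention: the post-resharding helpers take local_output_tensor whole,
# together with has_optimizer, and slice per shard internally.
```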