@@ -1459,6 +1459,16 @@ def _create_inverse_indices_permute_indices(
1459
1459
inverse_indices [1 ].device ,
1460
1460
)
1461
1461
1462
def _is_optimizer_enabled(
    self, has_local_optimizer: bool, env: ShardingEnv, device: torch.device
) -> bool:
    """Return True iff at least one rank in the process group has a local optimizer.

    Each rank contributes its own ``has_local_optimizer`` flag as a uint8
    tensor; a MAX all-reduce over ``env.process_group`` then acts as a
    logical OR across ranks, so the result is 1 when any rank reported True.

    Args:
        has_local_optimizer: whether this rank owns any optimizer state.
        env: sharding environment providing the process group to reduce over.
        device: device on which to allocate the communication tensor.

    Returns:
        True when any rank in the group has local optimizer state.
    """
    local_flag = torch.tensor([has_local_optimizer], dtype=torch.uint8, device=device)
    # MAX-reduce is a logical OR over the per-rank boolean flags.
    dist.all_reduce(local_flag, op=dist.ReduceOp.MAX, group=env.process_group)
    return bool(local_flag.item())
1462
1472
# pyre-ignore [14]
1463
1473
def input_dist (
1464
1474
self ,
@@ -1698,10 +1708,19 @@ def update_shards(
1698
1708
return
1699
1709
1700
1710
current_state = self .state_dict ()
1701
- has_optimizer = len (self ._optim ._optims ) > 0 and all (
1711
+ has_local_optimizer = len (self ._optim ._optims ) > 0 and all (
1702
1712
len (i ) > 0 for i in self ._optim .state_dict ()["state" ].values ()
1703
1713
)
1704
1714
1715
+ # communicate optimizer state across all ranks, because if one rank owns all tables
1716
+ # and other ranks do not own any table, and we later transfer the weights to an empty rank,
1717
+ # this creates inconsistent state, because the initially empty rank does not have optimizer state
1718
+ # hence, incorrectly computes the tensor splits
1719
+
1720
+ # Pyre-ignore
1721
+ has_optimizer = self ._is_optimizer_enabled (has_local_optimizer , env , device )
1722
+
1723
+
1705
1724
# TODO: Saving lookups tensors to CPU to eventually avoid recreating them completely again
1706
1725
# TODO: Ensure lookup tensors are actually being deleted
1707
1726
for _ , lookup in enumerate (self ._lookups ):
@@ -1715,7 +1734,7 @@ def update_shards(
1715
1734
max_dim_0 , max_dim_1 = get_largest_dims_from_sharding_plan_updates (
1716
1735
changed_sharding_params
1717
1736
)
1718
- old_optimizer_state = self ._optim .state_dict () if has_optimizer else None
1737
+ old_optimizer_state = self ._optim .state_dict () if has_local_optimizer else None
1719
1738
1720
1739
local_shard_names_by_src_rank , local_output_tensor = shards_all_to_all (
1721
1740
module = self ,
@@ -1727,6 +1746,7 @@ def update_shards(
1727
1746
max_dim_0 = max_dim_0 ,
1728
1747
max_dim_1 = max_dim_1 ,
1729
1748
optimizer_state = old_optimizer_state ,
1749
+ has_optimizer = has_optimizer ,
1730
1750
)
1731
1751
1732
1752
for name , param in changed_sharding_params .items ():
@@ -1791,30 +1811,24 @@ def update_shards(
1791
1811
self ._optim : CombinedOptimizer = CombinedOptimizer (optims )
1792
1812
1793
1813
if has_optimizer :
1794
- split_index = len (local_output_tensor ) // 2
1795
- local_weight_tensors = local_output_tensor [:split_index ]
1796
- local_optimizer_tensors = local_output_tensor [split_index :]
1797
- # Modifies new_opt_state in place and returns it
1798
1814
optimizer_state = update_optimizer_state_post_resharding (
1799
1815
old_opt_state = old_optimizer_state , # pyre-ignore
1800
1816
new_opt_state = copy .deepcopy (self ._optim .state_dict ()),
1801
1817
ordered_shard_names_and_lengths = local_shard_names_by_src_rank ,
1802
- output_tensor = local_optimizer_tensors ,
1818
+ output_tensor = local_output_tensor ,
1803
1819
max_dim_0 = max_dim_0 ,
1804
1820
)
1805
-
1806
1821
self ._optim .load_state_dict (optimizer_state )
1807
- else :
1808
- local_weight_tensors = local_output_tensor
1809
1822
1810
1823
current_state = update_state_dict_post_resharding (
1811
1824
state_dict = current_state ,
1812
1825
ordered_shard_names_and_lengths = local_shard_names_by_src_rank ,
1813
- output_tensor = local_weight_tensors ,
1826
+ output_tensor = local_output_tensor ,
1814
1827
new_sharding_params = changed_sharding_params ,
1815
1828
curr_rank = dist .get_rank (),
1816
1829
extend_shard_name = self .extend_shard_name ,
1817
1830
max_dim_0 = max_dim_0 ,
1831
+ has_optimizer = has_optimizer ,
1818
1832
)
1819
1833
1820
1834
self .load_state_dict (current_state )
0 commit comments