
Commit 9633587

isururanawaka authored and facebook-github-bot committed
Fixing CW E2E test case with multiple ranks
Differential Revision: D79023990
1 parent 5c8e5e2 commit 9633587

7 files changed: +198 −100 lines

7 files changed

+198
-100
lines changed

torchrec/distributed/embeddingbag.py

Lines changed: 25 additions & 11 deletions
@@ -1459,6 +1459,16 @@ def _create_inverse_indices_permute_indices(
             inverse_indices[1].device,
         )
 
+    def _is_optimizer_enabled(
+        self, has_local_optimizer: bool, env: ShardingEnv, device: torch.device
+    ) -> bool:
+        flag = torch.tensor(
+            [has_local_optimizer], dtype=torch.uint8, device=device
+        )  # example: True
+        # Reduce with MAX to check if any process has True
+        dist.all_reduce(flag, op=dist.ReduceOp.MAX, group=env.process_group)
+        return bool(flag.item())
+
     # pyre-ignore [14]
     def input_dist(
         self,
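The new _is_optimizer_enabled helper uses the standard "logical OR across ranks" trick: each rank contributes a 0/1 flag, and a MAX all-reduce returns 1 on every rank if any rank has local optimizer state. Below is a minimal, self-contained sketch of that pattern; the gloo process group, worker function, and world size are illustrative scaffolding, not torchrec code.

# Minimal sketch of the MAX all-reduce "does any rank have it?" pattern
# used by _is_optimizer_enabled. Runs two local processes over gloo;
# run() and WORLD_SIZE are illustrative names, not torchrec APIs.
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

WORLD_SIZE = 2


def run(rank: int) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"
    dist.init_process_group("gloo", rank=rank, world_size=WORLD_SIZE)

    # Only rank 0 pretends to have local optimizer state.
    has_local_optimizer = rank == 0
    flag = torch.tensor([has_local_optimizer], dtype=torch.uint8)

    # MAX over 0/1 flags acts as a logical OR across all ranks.
    dist.all_reduce(flag, op=dist.ReduceOp.MAX)
    has_optimizer = bool(flag.item())

    # Every rank now agrees: has_optimizer is True even on rank 1.
    print(f"rank {rank}: local={has_local_optimizer}, global={has_optimizer}")
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(run, nprocs=WORLD_SIZE)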
@@ -1698,10 +1708,19 @@ def update_shards(
             return
 
         current_state = self.state_dict()
-        has_optimizer = len(self._optim._optims) > 0 and all(
+        has_local_optimizer = len(self._optim._optims) > 0 and all(
             len(i) > 0 for i in self._optim.state_dict()["state"].values()
         )
 
+        # Communicate the optimizer flag across all ranks: if one rank owns all tables
+        # and the other ranks own none, transferring weights to an initially empty rank
+        # would otherwise create inconsistent state, because that rank has no local
+        # optimizer state and would compute the tensor splits incorrectly.
+
+        # Pyre-ignore
+        has_optimizer = self._is_optimizer_enabled(has_local_optimizer, env, device)
+
         # TODO: Saving lookups tensors to CPU to eventually avoid recreating them completely again
         # TODO: Ensure lookup tensors are actually being deleted
         for _, lookup in enumerate(self._lookups):
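A toy calculation of the failure mode described in the comment above, under the assumption (consistent with the halfway split removed later in this commit) that optimizer rows occupy as much of the all-to-all buffer as the weight rows; all sizes here are made up.

# Hypothetical sizes only: illustrates why a local-only optimizer check makes
# ranks disagree about how large the resharding buffers should be.
num_shard_rows = 4                          # weight rows being moved for a shard
rows_with_optimizer = 2 * num_shard_rows    # weights + optimizer rows travel together

# Rank 0 owns the tables, so it plans to send weights and optimizer rows.
rank0_expected_rows = rows_with_optimizer       # 8

# Rank 1 owns nothing yet; with only a local check it assumes "no optimizer"
# and sizes its splits for weights alone.
rank1_expected_rows_old = num_shard_rows        # 4 -> mismatched splits

# With the MAX all-reduce, both ranks agree that optimizer state exists.
rank1_expected_rows_new = rows_with_optimizer   # 8 -> consistent splits
assert rank0_expected_rows == rank1_expected_rows_new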
@@ -1715,7 +1734,7 @@ def update_shards(
         max_dim_0, max_dim_1 = get_largest_dims_from_sharding_plan_updates(
             changed_sharding_params
         )
-        old_optimizer_state = self._optim.state_dict() if has_optimizer else None
+        old_optimizer_state = self._optim.state_dict() if has_local_optimizer else None
 
         local_shard_names_by_src_rank, local_output_tensor = shards_all_to_all(
             module=self,
@@ -1727,6 +1746,7 @@ def update_shards(
             max_dim_0=max_dim_0,
             max_dim_1=max_dim_1,
             optimizer_state=old_optimizer_state,
+            has_optimizer=has_optimizer,
         )
 
         for name, param in changed_sharding_params.items():
@@ -1791,30 +1811,24 @@ def update_shards(
         self._optim: CombinedOptimizer = CombinedOptimizer(optims)
 
         if has_optimizer:
-            split_index = len(local_output_tensor) // 2
-            local_weight_tensors = local_output_tensor[:split_index]
-            local_optimizer_tensors = local_output_tensor[split_index:]
-            # Modifies new_opt_state in place and returns it
             optimizer_state = update_optimizer_state_post_resharding(
                 old_opt_state=old_optimizer_state,  # pyre-ignore
                 new_opt_state=copy.deepcopy(self._optim.state_dict()),
                 ordered_shard_names_and_lengths=local_shard_names_by_src_rank,
-                output_tensor=local_optimizer_tensors,
+                output_tensor=local_output_tensor,
                 max_dim_0=max_dim_0,
             )
-
             self._optim.load_state_dict(optimizer_state)
-        else:
-            local_weight_tensors = local_output_tensor
 
         current_state = update_state_dict_post_resharding(
             state_dict=current_state,
             ordered_shard_names_and_lengths=local_shard_names_by_src_rank,
-            output_tensor=local_weight_tensors,
+            output_tensor=local_output_tensor,
             new_sharding_params=changed_sharding_params,
             curr_rank=dist.get_rank(),
             extend_shard_name=self.extend_shard_name,
             max_dim_0=max_dim_0,
+            has_optimizer=has_optimizer,
         )
 
         self.load_state_dict(current_state)
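For context, the deleted lines were the caller-side halfway split of local_output_tensor into a weight half and an optimizer half; the new code forwards the whole buffer plus has_optimizer and leaves the unpacking to the post-resharding helpers. A rough before/after sketch of the call sites, with a made-up stand-in buffer (in torchrec this may be a list of per-shard tensors):

# Stand-in illustration of what this hunk removes. Only the slicing mirrors
# the deleted code; the buffer contents and sizes are invented.
import torch

local_output_tensor = [torch.zeros(4) for _ in range(8)]  # pretend all-to-all output

# Old call sites: slice the buffer in half before handing it to the helpers.
split_index = len(local_output_tensor) // 2
local_weight_tensors = local_output_tensor[:split_index]
local_optimizer_tensors = local_output_tensor[split_index:]

# New call sites: pass the whole buffer plus has_optimizer, and let
# update_optimizer_state_post_resharding / update_state_dict_post_resharding
# decide how to unpack it, keeping the layout consistent with shards_all_to_all:
has_optimizer = True
# update_optimizer_state_post_resharding(..., output_tensor=local_output_tensor, ...)
# update_state_dict_post_resharding(..., output_tensor=local_output_tensor,
#                                   has_optimizer=has_optimizer, ...)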
