Skip to content

Commit 3d5a9f6

Browse files
TroyGardenfacebook-github-bot
authored andcommitted
add test coverage for loading state dict from cpu for EBC sharding
Summary: # context * OSS CI [test is failing](https://github.com/pytorch/torchrec/actions/runs/17565037476/job/49890092618) due to incorrect sharded EBC weights, which further comes from a recent change D80547120, where the eager-mode EBC is on CPU * the `load_state_dict` is skipped from [day-1](https://www.internalfb.com/diff/D46653763?dst_version_fbid=803682924606242&transaction_fbid=6359829970773670) * actually this feature was added in T156003280 by D46915880 Differential Revision: D82002643
1 parent 6e93e77 commit 3d5a9f6

File tree

1 file changed

+1
-4
lines changed

1 file changed

+1
-4
lines changed

torchrec/distributed/embeddingbag.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -597,10 +597,7 @@ def __init__(
597597
if env.process_group and dist.get_backend(env.process_group) != "fake":
598598
self._initialize_torch_state()
599599

600-
if module.device not in ["meta", "cpu"] and module.device.type not in [
601-
"meta",
602-
"cpu",
603-
]:
600+
if module.device != "meta" and module.device.type != "meta":
604601
self.load_state_dict(module.state_dict(), strict=False)
605602

606603
@classmethod

0 commit comments

Comments
 (0)