
Commit 85b2080

Manual Resharding given Manifold Paths
Differential Revision: D82241141
1 parent eac316e commit 85b2080

File tree

2 files changed: +6 -174 lines


torchrec/distributed/benchmark/benchmark_resharding_handler.py

Lines changed: 0 additions & 169 deletions
This file was deleted.

torchrec/distributed/model_parallel.py

Lines changed: 6 additions & 5 deletions
@@ -258,11 +258,12 @@ def __init__(
            device = torch.device("cpu")
        self.device: torch.device = device

-        if sharders is None:
-            sharders = get_default_sharders()
+        self.sharders: List[ModuleSharder[nn.modules.module.Module]] = (
+            get_default_sharders() if sharders is None else sharders
+        )

        self._sharder_map: Dict[Type[nn.Module], ModuleSharder[nn.Module]] = {
-            sharder.module_type: sharder for sharder in sharders
+            sharder.module_type: sharder for sharder in self.sharders
        }

        if data_parallel_wrapper is None:
@@ -279,9 +280,9 @@ def __init__(
            )
            pg = self._env.process_group
            if pg is not None:
-                plan = planner.collective_plan(module, sharders, pg)
+                plan = planner.collective_plan(module, self.sharders, pg)
            else:
-                plan = planner.plan(module, sharders)
+                plan = planner.plan(module, self.sharders)
        self._plan: ShardingPlan = plan
        self._dmp_wrapped_module: nn.Module = self._init_dmp(module)
        self._optim: CombinedOptimizer = self._init_optim(self._dmp_wrapped_module)
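
The change above promotes the resolved sharder list from a local variable to public instance state (self.sharders), so the same sharders that produced the sharding plan remain reachable after construction, for example by a manual resharding step. Below is a minimal sketch of how that attribute might be consumed; the helper function and its name are illustrative assumptions, not part of this commit:

from typing import Dict, Type

import torch.nn as nn

from torchrec.distributed.model_parallel import DistributedModelParallel
from torchrec.distributed.types import ModuleSharder


def sharder_for(
    model: DistributedModelParallel, module_type: Type[nn.Module]
) -> ModuleSharder[nn.Module]:
    # Hypothetical helper: rebuild the same module_type -> sharder mapping
    # that __init__ stores in self._sharder_map, but from the public
    # `sharders` attribute this diff introduces.
    sharder_map: Dict[Type[nn.Module], ModuleSharder[nn.Module]] = {
        sharder.module_type: sharder for sharder in model.sharders
    }
    return sharder_map[module_type]

For instance, looking up the sharder for an embedding module type after DMP construction would return the same sharder instance that was passed to the planner when the plan was generated.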
