Commit 0c5d591

Zheng Yan authored and facebook-github-bot committed

add ModuleCopyMixin to allow inference allocation overrides (#314)

Summary:
Pull Request resolved: #314

Some modules need to override .copy so that they can be placed on CPU during inference time.

Reviewed By: yinghai

Differential Revision: D36235459

fbshipit-source-id: 19786560ae51960deb574895604c168f81373715
1 parent d31f430 commit 0c5d591
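
To illustrate what the change enables, here is a minimal, hypothetical sketch (not part of the commit) of an inference-side module that mixes in ModuleCopyMixin and overrides copy so that a large CPU-resident buffer stays put when DistributedModelParallel.copy relocates the rest of the model to a GPU. The class CpuPinnedCache and its buffer are invented for illustration.

import torch
from torch import nn

from torchrec.distributed.types import ModuleCopyMixin


class CpuPinnedCache(nn.Module, ModuleCopyMixin):
    """Hypothetical module whose large lookup buffer should stay on CPU at inference."""

    def __init__(self, num_rows: int = 100_000, dim: int = 16) -> None:
        super().__init__()
        # Large buffer intentionally kept in host memory.
        self.cache: torch.Tensor = torch.zeros(num_rows, dim, device="cpu")

    def copy(self, device: torch.device) -> nn.Module:
        # Ignore the target device so DistributedModelParallel.copy leaves
        # this module on CPU (the mixin default would call self.to(device)).
        return self

DistributedModelParallel.copy calls copy(device) on any ModuleCopyMixin it finds, so this module is left untouched; the NoCopyModule test added below exercises exactly this behavior.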

File tree

3 files changed: +84 -7 lines changed

torchrec/distributed/model_parallel.py

Lines changed: 6 additions & 2 deletions

@@ -29,6 +29,7 @@
     ShardedModule,
     ShardingEnv,
     ShardingPlan,
+    ModuleCopyMixin,
 )
 from torchrec.distributed.utils import (
     add_prefix_to_state_dict,
@@ -284,14 +285,17 @@ def _copy_if_device_match(tensor: torch.Tensor) -> torch.Tensor:
             return tensor

         # if this is a sharded module, customize the copy
-        if isinstance(module, ShardedModule):
+        if isinstance(module, ModuleCopyMixin):
             return module.copy(device)
         # this could be dense or a compound module
         for name, child in module.named_children():
             # potential DFS cache or bottom-up can save runtime
             # search immediate submodules
             if not any(
-                [isinstance(submodule, ShardedModule) for submodule in child.modules()]
+                [
+                    isinstance(submodule, ModuleCopyMixin)
+                    for submodule in child.modules()
+                ]
             ):
                 # if not containing ShardedModule down this submodule (this is a dense module)
                 # copy it.
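
For readers who only see the hunk above, here is a simplified, hypothetical reconstruction of the copy traversal it touches. Everything outside the hunk's visible context, including the function name _copy_to_device_sketch, the recursion, and the setattr calls, is an assumption rather than the actual torchrec implementation; the point is the dispatch this commit changes: a module that mixes in ModuleCopyMixin gets its own copy(device), while purely dense subtrees are moved with a plain .to(device).

import torch
from torch import nn

from torchrec.distributed.types import ModuleCopyMixin


def _copy_to_device_sketch(module: nn.Module, device: torch.device) -> nn.Module:
    # A module that customizes copying decides for itself how (or whether) to move.
    if isinstance(module, ModuleCopyMixin):
        return module.copy(device)
    # Otherwise inspect immediate children.
    for name, child in module.named_children():
        if not any(
            isinstance(submodule, ModuleCopyMixin) for submodule in child.modules()
        ):
            # No ModuleCopyMixin anywhere below: a dense subtree, move it wholesale.
            setattr(module, name, child.to(device))
        else:
            # Mixed subtree: recurse so nested ModuleCopyMixin modules keep control.
            setattr(module, name, _copy_to_device_sketch(child, device))
    return module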

torchrec/distributed/tests/test_quant_model_parallel.py

Lines changed: 67 additions & 1 deletion

@@ -17,7 +17,12 @@
     _get_default_rtol_and_atol,
     TestSparseNN,
 )
-from torchrec.distributed.types import ShardedModule, ShardingEnv, ShardingType
+from torchrec.distributed.types import (
+    ShardedModule,
+    ShardingType,
+    ShardingEnv,
+    ModuleCopyMixin,
+)
 from torchrec.modules.embedding_configs import EmbeddingBagConfig
 from torchrec.modules.embedding_modules import EmbeddingBagCollection
 from torchrec.quant.embedding_modules import (
@@ -61,6 +66,25 @@ def _quantize(module: nn.Module, inplace: bool) -> nn.Module:
     )


+class CopyModule(nn.Module, ModuleCopyMixin):
+    def __init__(self) -> None:
+        super().__init__()
+        self.tensor: torch.Tensor = torch.empty((10), device="cpu")
+
+    def copy(self, device: torch.device) -> nn.Module:
+        self.tensor = self.tensor.to(device)
+        return self
+
+
+class NoCopyModule(nn.Module, ModuleCopyMixin):
+    def __init__(self) -> None:
+        super().__init__()
+        self.tensor: torch.Tensor = torch.empty((10), device="cpu")
+
+    def copy(self, device: torch.device) -> nn.Module:
+        return self
+
+
 class QuantModelParallelModelCopyTest(unittest.TestCase):
     def setUp(self) -> None:
         num_features = 4
@@ -173,3 +197,45 @@ def test_quant_pred(self) -> None:
         )
         dmp_1 = dmp.copy(device_1)
         self._recursive_device_check(dmp.module, dmp_1.module, device, device_1)
+
+    # pyre-fixme[56]
+    @unittest.skipIf(
+        torch.cuda.device_count() <= 1,
+        "Not enough GPUs available",
+    )
+    def test_copy_mixin(self) -> None:
+        device = torch.device("cuda:0")
+        device_1 = torch.device("cuda:1")
+        model = TestSparseNN(
+            tables=self.tables,
+            weighted_tables=self.weighted_tables,
+            num_float_features=10,
+            dense_device=device,
+            sparse_device=torch.device("meta"),
+        )
+        # pyre-ignore [16]
+        model.copy = CopyModule()
+        # pyre-ignore [16]
+        model.no_copy = NoCopyModule()
+        quant_model = _quantize(model, inplace=True)
+        dmp = DistributedModelParallel(
+            quant_model,
+            sharders=[
+                cast(
+                    ModuleSharder[torch.nn.Module],
+                    TestQuantEBCSharder(
+                        sharding_type=ShardingType.TABLE_WISE.value,
+                        kernel_type=EmbeddingComputeKernel.BATCHED_QUANT.value,
+                    ),
+                )
+            ],
+            device=None,
+            env=ShardingEnv.from_local(world_size=2, rank=0),
+            init_data_parallel=False,
+        )
+
+        dmp_1 = dmp.copy(device_1)
+        # pyre-ignore [16]
+        self.assertEqual(dmp_1.module.copy.tensor.device, device_1)
+        # pyre-ignore [16]
+        self.assertEqual(dmp_1.module.no_copy.tensor.device, torch.device("cpu"))

torchrec/distributed/types.py

Lines changed: 11 additions & 4 deletions

@@ -356,7 +356,17 @@ def from_local(cls, world_size: int, rank: int) -> "ShardingEnv":
         return cls(world_size, rank, None)


-class ShardedModule(abc.ABC, nn.Module, Generic[CompIn, DistOut, Out]):
+class ModuleCopyMixin:
+    """
+    A mixin to allow modules to override copy behaviors in DMP.
+    """
+
+    def copy(self, device: torch.device) -> nn.Module:
+        # pyre-ignore [16]
+        return self.to(device)
+
+
+class ShardedModule(abc.ABC, nn.Module, Generic[CompIn, DistOut, Out], ModuleCopyMixin):
     """
     All model-parallel modules implement this interface.
     Inputs and outputs are data-parallel.
@@ -423,9 +433,6 @@ def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]:
         for key, _ in self.named_parameters(prefix):
             yield key

-    def copy(self, device: torch.device) -> nn.Module:
-        return self.to(device)
-

 class ModuleSharder(abc.ABC, Generic[M]):
     """
