     _get_default_rtol_and_atol,
     TestSparseNN,
 )
-from torchrec.distributed.types import ShardedModule, ShardingEnv, ShardingType
+from torchrec.distributed.types import (
+    ShardedModule,
+    ShardingType,
+    ShardingEnv,
+    ModuleCopyMixin,
+)
 from torchrec.modules.embedding_configs import EmbeddingBagConfig
 from torchrec.modules.embedding_modules import EmbeddingBagCollection
 from torchrec.quant.embedding_modules import (
@@ -61,6 +66,25 @@ def _quantize(module: nn.Module, inplace: bool) -> nn.Module:
     )


+class CopyModule(nn.Module, ModuleCopyMixin):
+    def __init__(self) -> None:
+        super().__init__()
+        self.tensor: torch.Tensor = torch.empty((10), device="cpu")
+
+    def copy(self, device: torch.device) -> nn.Module:
+        self.tensor = self.tensor.to(device)
+        return self
+
+
+class NoCopyModule(nn.Module, ModuleCopyMixin):
+    def __init__(self) -> None:
+        super().__init__()
+        self.tensor: torch.Tensor = torch.empty((10), device="cpu")
+
+    def copy(self, device: torch.device) -> nn.Module:
+        return self
+
+
 class QuantModelParallelModelCopyTest(unittest.TestCase):
     def setUp(self) -> None:
         num_features = 4
@@ -173,3 +197,45 @@ def test_quant_pred(self) -> None:
         )
         dmp_1 = dmp.copy(device_1)
         self._recursive_device_check(dmp.module, dmp_1.module, device, device_1)
+
+    # pyre-fixme[56]
+    @unittest.skipIf(
+        torch.cuda.device_count() <= 1,
+        "Not enough GPUs available",
+    )
+    def test_copy_mixin(self) -> None:
+        device = torch.device("cuda:0")
+        device_1 = torch.device("cuda:1")
+        model = TestSparseNN(
+            tables=self.tables,
+            weighted_tables=self.weighted_tables,
+            num_float_features=10,
+            dense_device=device,
+            sparse_device=torch.device("meta"),
+        )
+        # pyre-ignore [16]
+        model.copy = CopyModule()
+        # pyre-ignore [16]
+        model.no_copy = NoCopyModule()
+        quant_model = _quantize(model, inplace=True)
+        dmp = DistributedModelParallel(
+            quant_model,
+            sharders=[
+                cast(
+                    ModuleSharder[torch.nn.Module],
+                    TestQuantEBCSharder(
+                        sharding_type=ShardingType.TABLE_WISE.value,
+                        kernel_type=EmbeddingComputeKernel.BATCHED_QUANT.value,
+                    ),
+                )
+            ],
+            device=None,
+            env=ShardingEnv.from_local(world_size=2, rank=0),
+            init_data_parallel=False,
+        )
+
+        dmp_1 = dmp.copy(device_1)
+        # pyre-ignore [16]
+        self.assertEqual(dmp_1.module.copy.tensor.device, device_1)
+        # pyre-ignore [16]
+        self.assertEqual(dmp_1.module.no_copy.tensor.device, torch.device("cpu"))