
Commit 38f9abf

nipung90 authored and facebook-github-bot committed
Allow uneven row-wise sharding based on the number of buckets for ZCH
Summary: This diff enables the use of the num_buckets ParameterConstraint in the planner. The presence of this constraint indicates that ZCH bucketing should be used as part of row-wise sharding plans.

## Without num_buckets present

The current row-wise sharding strategy is used.

## With num_buckets present

* When devices have the same amount of memory available: the buckets are divided evenly across hosts, and one additional bucket is given to as many hosts as needed to absorb the remainder. For example (test case 2): hash_size = 100, num_devices = 4, num_buckets = 10. Each bucket has 10 rows and the buckets are distributed as [3, 3, 2, 2], so the rows are distributed as [30, 30, 20, 20].
* When devices have uneven amounts of memory: the buckets are distributed in proportion to each device's share of the total memory across all devices, and any buckets left over when they do not split exactly along the memory ratios are placed on the last device. For example: hash_size = 45, num_buckets = 9, bucket_size = 5. With a memory ratio of 2:1:1, the buckets are distributed as [4, 2, 3], so the rows are distributed as [20, 10, 15].

Differential Revision: D79659949
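For illustration, the bucket split described in the summary can be reproduced with a short sketch. This is a hedged, self-contained example using a hypothetical helper name (split_buckets); it is not the planner's actual implementation, only the arithmetic the summary describes, and it reproduces both worked examples:

from typing import List, Optional


def split_buckets(
    num_buckets: int,
    bucket_size: int,
    num_devices: Optional[int] = None,
    memory_sizes: Optional[List[int]] = None,
) -> List[int]:
    """Rows per device for a ZCH-bucketed row-wise shard (illustrative only)."""
    if memory_sizes is None:
        # Even memory: spread buckets evenly; the first `extra` devices take one more.
        assert num_devices is not None
        base, extra = divmod(num_buckets, num_devices)
        buckets = [base + (1 if d < extra else 0) for d in range(num_devices)]
    else:
        # Uneven memory: buckets proportional to each device's share of total memory;
        # buckets that do not fit the ratios exactly land on the last device.
        total = sum(memory_sizes)
        buckets = [num_buckets * m // total for m in memory_sizes]
        buckets[-1] += num_buckets - sum(buckets)
    return [b * bucket_size for b in buckets]


# Examples from the summary:
print(split_buckets(num_buckets=10, bucket_size=10, num_devices=4))         # [30, 30, 20, 20]
print(split_buckets(num_buckets=9, bucket_size=5, memory_sizes=[2, 1, 1]))  # [20, 10, 15]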
1 parent 60ef897 commit 38f9abf

File tree

4 files changed: +392 -21 lines changed


torchrec/distributed/planner/enumerators.py

Lines changed: 33 additions & 1 deletion
@@ -38,6 +38,10 @@
     ShardingType,
 )
 from torchrec.modules.embedding_configs import DataType
+from torchrec.modules.embedding_modules import (
+    EmbeddingBagCollection,
+    EmbeddingCollection,
+)
 from torchrec.modules.embedding_tower import EmbeddingTower, EmbeddingTowerCollection
 
 
@@ -178,7 +182,7 @@ def enumerate(
             # skip for other device groups
             if device_group and device_group != self._compute_device:
                 continue
-
+            num_buckets = self._get_num_buckets(name, child_module)
             sharding_options_per_table: List[ShardingOption] = []
 
             for sharding_type in self._filter_sharding_types(
@@ -200,6 +204,7 @@ def enumerate(
                             sharding_type=sharding_type,
                             col_wise_shard_dim=col_wise_shard_dim,
                             device_memory_sizes=self._device_memory_sizes,
+                            num_buckets=num_buckets,
                         )
                     except ZeroDivisionError as e:
                         # Re-raise with additional context about the table and module
@@ -264,6 +269,33 @@ def enumerate(
         self._last_stored_search_space = copy.deepcopy(sharding_options)
         return sharding_options
 
+    def _get_num_buckets(self, parameter: str, module: nn.Module) -> Optional[int]:
+        """
+        Get the number of buckets for each embedding table.
+
+        Args:
+            parameter (str): name of the embedding table.
+            module (nn.Module): module to be sharded.
+
+        Returns:
+            Optional[int]: number of buckets for the table, or None if the module is not an EmbeddingBagCollection/EmbeddingCollection or the table is not found.
+        """
+        # If module is not an EmbeddingBagCollection or EmbeddingCollection, return None
+        if isinstance(module, EmbeddingBagCollection):
+            embedding_configs = module.embedding_bag_configs()
+        elif isinstance(module, EmbeddingCollection):
+            embedding_configs = module.embedding_configs()
+        else:
+            return None
+
+        # Find the embedding config for the table with the same name as parameter input
+        for config in embedding_configs:
+            if config.name == parameter and config.use_virtual_table:
+                return config.total_num_buckets
+
+        # If table with matching name not found, return None
+        return None
+
     @property
     def last_stored_search_space(self) -> Optional[List[ShardingOption]]:
         # NOTE: This is the last search space stored by enumerate(...), do not use
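For context, the lookup above only returns a value for tables that are configured as virtual (ZCH) tables. A minimal sketch, assuming the same config fields used by the test setup further below (total_num_buckets and use_virtual_table on EmbeddingBagConfig); this is not part of the diff, only an illustration:

from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection

# A table that opts into ZCH bucketing; without use_virtual_table=True the
# enumerator's _get_num_buckets(...) would return None for it.
bucketed_table = EmbeddingBagConfig(
    num_embeddings=100,
    embedding_dim=20,
    name="table_0",
    feature_names=["feature_0"],
    total_num_buckets=10,
    use_virtual_table=True,
)
ebc = EmbeddingBagCollection(tables=[bucketed_table])
# Expected: _get_num_buckets("table_0", ebc) -> 10; any other table name, or a module
# that is neither an EmbeddingBagCollection nor an EmbeddingCollection -> None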

torchrec/distributed/planner/tests/test_enumerators.py

Lines changed: 155 additions & 1 deletion
@@ -18,7 +18,10 @@
     EmbeddingTowerSharder,
 )
 from torchrec.distributed.embedding_types import EmbeddingComputeKernel
-from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
+from torchrec.distributed.embeddingbag import (
+    EmbeddingBagCollection,
+    EmbeddingBagCollectionSharder,
+)
 from torchrec.distributed.mc_embeddingbag import (
     ManagedCollisionEmbeddingBagCollectionSharder,
 )
@@ -45,13 +48,27 @@
     [[17, 80], [17, 80], [17, 80], [17, 80], [17, 80], [17, 80], [17, 80], [11, 80]],
 ]
 
+EXPECTED_RW_SHARD_SIZES_WITH_BUCKETS = [
+    [[20, 20], [20, 20], [10, 20], [10, 20], [10, 20], [10, 20], [10, 20], [10, 20]],
+    [[22, 40], [22, 40], [11, 40], [11, 40], [11, 40], [11, 40], [11, 40], [11, 40]],
+    [[24, 60], [24, 60], [12, 60], [12, 60], [12, 60], [12, 60], [12, 60], [12, 60]],
+    [[26, 80], [26, 80], [13, 80], [13, 80], [13, 80], [13, 80], [13, 80], [13, 80]],
+]
+
 EXPECTED_RW_SHARD_OFFSETS = [
     [[0, 0], [13, 0], [26, 0], [39, 0], [52, 0], [65, 0], [78, 0], [91, 0]],
     [[0, 0], [14, 0], [28, 0], [42, 0], [56, 0], [70, 0], [84, 0], [98, 0]],
     [[0, 0], [15, 0], [30, 0], [45, 0], [60, 0], [75, 0], [90, 0], [105, 0]],
     [[0, 0], [17, 0], [34, 0], [51, 0], [68, 0], [85, 0], [102, 0], [119, 0]],
 ]
 
+EXPECTED_RW_SHARD_OFFSETS_WITH_BUCKETS = [
+    [[0, 0], [20, 0], [40, 0], [50, 0], [60, 0], [70, 0], [80, 0], [90, 0]],
+    [[0, 0], [22, 0], [44, 0], [55, 0], [66, 0], [77, 0], [88, 0], [99, 0]],
+    [[0, 0], [24, 0], [48, 0], [60, 0], [72, 0], [84, 0], [96, 0], [108, 0]],
+    [[0, 0], [26, 0], [52, 0], [65, 0], [78, 0], [91, 0], [104, 0], [117, 0]],
+]
+
 
 def get_expected_cache_aux_size(rows: int) -> int:
     # 0.2 is the hardcoded cache load factor assumed in this test
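The *_WITH_BUCKETS expectations follow directly from the even-memory bucket rule: table_0 has 100 rows split into 10 buckets of 10 rows, spread over world_size=8 ranks. A small sketch of that arithmetic (illustrative only, not the planner's code) that reproduces the first row of the constants above:

# table_0: num_embeddings=100, total_num_buckets=10, embedding_dim=20, 8 ranks
hash_size, num_buckets, world_size, dim = 100, 10, 8, 20
bucket_size = hash_size // num_buckets             # 10 rows per bucket
base, extra = divmod(num_buckets, world_size)      # 1 bucket per rank, 2 left over
rows = [(base + (1 if rank < extra else 0)) * bucket_size for rank in range(world_size)]

sizes = [[r, dim] for r in rows]
offsets, acc = [], 0
for r in rows:
    offsets.append([acc, 0])
    acc += r

print(sizes)    # [[20, 20], [20, 20], [10, 20], ...] == EXPECTED_RW_SHARD_SIZES_WITH_BUCKETS[0]
print(offsets)  # [[0, 0], [20, 0], [40, 0], [50, 0], ...] == EXPECTED_RW_SHARD_OFFSETS_WITH_BUCKETS[0]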
@@ -101,6 +118,48 @@ def get_expected_cache_aux_size(rows: int) -> int:
     ],
 ]
 
+EXPECTED_RW_SHARD_STORAGE_WITH_BUCKETS = [
+    [
+        Storage(hbm=167488, ddr=0),
+        Storage(hbm=167488, ddr=0),
+        Storage(hbm=166688, ddr=0),
+        Storage(hbm=166688, ddr=0),
+        Storage(hbm=166688, ddr=0),
+        Storage(hbm=166688, ddr=0),
+        Storage(hbm=166688, ddr=0),
+        Storage(hbm=166688, ddr=0),
+    ],
+    [
+        Storage(hbm=1004992, ddr=0),
+        Storage(hbm=1004992, ddr=0),
+        Storage(hbm=1003232, ddr=0),
+        Storage(hbm=1003232, ddr=0),
+        Storage(hbm=1003232, ddr=0),
+        Storage(hbm=1003232, ddr=0),
+        Storage(hbm=1003232, ddr=0),
+        Storage(hbm=1003232, ddr=0),
+    ],
+    [
+        Storage(hbm=1009280, ddr=0),
+        Storage(hbm=1009280, ddr=0),
+        Storage(hbm=1006400, ddr=0),
+        Storage(hbm=1006400, ddr=0),
+        Storage(hbm=1006400, ddr=0),
+        Storage(hbm=1006400, ddr=0),
+        Storage(hbm=1006400, ddr=0),
+        Storage(hbm=1006400, ddr=0),
+    ],
+    [
+        Storage(hbm=2656384, ddr=0),
+        Storage(hbm=2656384, ddr=0),
+        Storage(hbm=2652224, ddr=0),
+        Storage(hbm=2652224, ddr=0),
+        Storage(hbm=2652224, ddr=0),
+        Storage(hbm=2652224, ddr=0),
+        Storage(hbm=2652224, ddr=0),
+        Storage(hbm=2652224, ddr=0),
+    ],
+]
 
 EXPECTED_UVM_CACHING_RW_SHARD_STORAGE = [
     [
@@ -145,6 +204,48 @@
     ],
 ]
 
+EXPECTED_UVM_CACHING_RW_SHARD_STORAGE_WITH_BUCKETS = [
+    [
+        Storage(hbm=166352, ddr=1600),
+        Storage(hbm=166352, ddr=1600),
+        Storage(hbm=166120, ddr=800),
+        Storage(hbm=166120, ddr=800),
+        Storage(hbm=166120, ddr=800),
+        Storage(hbm=166120, ddr=800),
+        Storage(hbm=166120, ddr=800),
+        Storage(hbm=166120, ddr=800),
+    ],
+    [
+        Storage(hbm=1002335, ddr=3520),
+        Storage(hbm=1002335, ddr=3520),
+        Storage(hbm=1001904, ddr=1760),
+        Storage(hbm=1001904, ddr=1760),
+        Storage(hbm=1001904, ddr=1760),
+        Storage(hbm=1001904, ddr=1760),
+        Storage(hbm=1001904, ddr=1760),
+        Storage(hbm=1001904, ddr=1760),
+    ],
+    [
+        Storage(hbm=1004845, ddr=5760),
+        Storage(hbm=1004845, ddr=5760),
+        Storage(hbm=1004183, ddr=2880),
+        Storage(hbm=1004183, ddr=2880),
+        Storage(hbm=1004183, ddr=2880),
+        Storage(hbm=1004183, ddr=2880),
+        Storage(hbm=1004183, ddr=2880),
+        Storage(hbm=1004183, ddr=2880),
+    ],
+    [
+        Storage(hbm=2649916, ddr=8320),
+        Storage(hbm=2649916, ddr=8320),
+        Storage(hbm=2648990, ddr=4160),
+        Storage(hbm=2648990, ddr=4160),
+        Storage(hbm=2648990, ddr=4160),
+        Storage(hbm=2648990, ddr=4160),
+        Storage(hbm=2648990, ddr=4160),
+        Storage(hbm=2648990, ddr=4160),
+    ],
+]
 
 EXPECTED_TWRW_SHARD_SIZES = [
     [[25, 20], [25, 20], [25, 20], [25, 20]],
@@ -367,6 +468,17 @@ def setUp(self) -> None:
             )
             for i in range(self.num_tables)
         ]
+        tables_with_buckets = [
+            EmbeddingBagConfig(
+                num_embeddings=100 + i * 10,
+                embedding_dim=20 + i * 20,
+                name="table_" + str(i),
+                feature_names=["feature_" + str(i)],
+                total_num_buckets=10,
+                use_virtual_table=True,
+            )
+            for i in range(self.num_tables)
+        ]
         weighted_tables = [
             EmbeddingBagConfig(
                 num_embeddings=(i + 1) * 10,
@@ -377,6 +489,9 @@ def setUp(self) -> None:
            for i in range(4)
        ]
        self.model = TestSparseNN(tables=tables, weighted_tables=[])
+        self.model_with_buckets = EmbeddingBagCollection(
+            tables=tables_with_buckets,
+        )
        self.enumerator = EmbeddingEnumerator(
            topology=Topology(
                world_size=self.world_size,
@@ -514,6 +629,25 @@ def test_rw_sharding(self) -> None:
                 EXPECTED_RW_SHARD_STORAGE[i],
             )
 
+    def test_rw_sharding_with_buckets(self) -> None:
+        sharding_options = self.enumerator.enumerate(
+            self.model_with_buckets, [cast(ModuleSharder[torch.nn.Module], RWSharder())]
+        )
+        for i, sharding_option in enumerate(sharding_options):
+            self.assertEqual(sharding_option.sharding_type, ShardingType.ROW_WISE.value)
+            self.assertEqual(
+                [shard.size for shard in sharding_option.shards],
+                EXPECTED_RW_SHARD_SIZES_WITH_BUCKETS[i],
+            )
+            self.assertEqual(
+                [shard.offset for shard in sharding_option.shards],
+                EXPECTED_RW_SHARD_OFFSETS_WITH_BUCKETS[i],
+            )
+            self.assertEqual(
+                [shard.storage for shard in sharding_option.shards],
+                EXPECTED_RW_SHARD_STORAGE_WITH_BUCKETS[i],
+            )
+
     def test_uvm_caching_rw_sharding(self) -> None:
         sharding_options = self.enumerator.enumerate(
             self.model,
@@ -535,6 +669,26 @@ def test_uvm_caching_rw_sharding(self) -> None:
                 EXPECTED_UVM_CACHING_RW_SHARD_STORAGE[i],
             )
 
+    def test_uvm_caching_rw_sharding_with_buckets(self) -> None:
+        sharding_options = self.enumerator.enumerate(
+            self.model_with_buckets,
+            [cast(ModuleSharder[torch.nn.Module], UVMCachingRWSharder())],
+        )
+        for i, sharding_option in enumerate(sharding_options):
+            self.assertEqual(sharding_option.sharding_type, ShardingType.ROW_WISE.value)
+            self.assertEqual(
+                [shard.size for shard in sharding_option.shards],
+                EXPECTED_RW_SHARD_SIZES_WITH_BUCKETS[i],
+            )
+            self.assertEqual(
+                [shard.offset for shard in sharding_option.shards],
+                EXPECTED_RW_SHARD_OFFSETS_WITH_BUCKETS[i],
+            )
+            self.assertEqual(
+                [shard.storage for shard in sharding_option.shards],
+                EXPECTED_UVM_CACHING_RW_SHARD_STORAGE_WITH_BUCKETS[i],
+            )
+
     def test_twrw_sharding(self) -> None:
         sharding_options = self.enumerator.enumerate(
             self.model, [cast(ModuleSharder[torch.nn.Module], TWRWSharder())]
