18
18
EmbeddingTowerSharder ,
19
19
)
20
20
from torchrec .distributed .embedding_types import EmbeddingComputeKernel
21
- from torchrec .distributed .embeddingbag import EmbeddingBagCollectionSharder
21
+ from torchrec .distributed .embeddingbag import (
22
+ EmbeddingBagCollection ,
23
+ EmbeddingBagCollectionSharder ,
24
+ )
22
25
from torchrec .distributed .mc_embeddingbag import (
23
26
ManagedCollisionEmbeddingBagCollectionSharder ,
24
27
)
45
48
[[17 , 80 ], [17 , 80 ], [17 , 80 ], [17 , 80 ], [17 , 80 ], [17 , 80 ], [17 , 80 ], [11 , 80 ]],
46
49
]
47
50
51
# Expected `shard.size` values for the bucketed row-wise tests: one inner list
# per table, eight two-element entries per list. Within a table the first two
# shards hold twice as many rows as the remaining six, and the second element
# is constant per table (20, 40, 60, 80).
# NOTE(review): presumably each entry is [rows, embedding_dim] and the 2:1 split
# comes from spreading 10 buckets over 8 ranks -- confirm against the enumerator.
EXPECTED_RW_SHARD_SIZES_WITH_BUCKETS = [
    [[unit_rows * multiple, dim] for multiple in (2, 2, 1, 1, 1, 1, 1, 1)]
    for unit_rows, dim in ((10, 20), (11, 40), (12, 60), (13, 80))
]
57
+
48
58
# Expected shard offsets for the row-wise fixtures: one inner list per table,
# eight [first, second] pairs per list. The first component advances by a fixed
# per-table step (13, 14, 15, 17) from shard to shard; the second is always 0.
# NOTE(review): presumably [row_offset, column_offset] -- confirm against Shard.
EXPECTED_RW_SHARD_OFFSETS = [
    [[step * shard_idx, 0] for shard_idx in range(8)] for step in (13, 14, 15, 17)
]
54
64
65
# Expected `shard.offset` values for the bucketed row-wise tests: one inner list
# per table, eight [first, second] pairs per list. The first component is a
# per-table unit (10, 11, 12, 13) scaled by 0, 2, 4, 5, ..., 9 -- i.e. the first
# two shards are two units wide and the rest one unit; the second is always 0.
# NOTE(review): presumably [row_offset, column_offset] -- confirm against Shard.
EXPECTED_RW_SHARD_OFFSETS_WITH_BUCKETS = [
    [[unit_rows * units_before, 0] for units_before in (0, 2, 4, 5, 6, 7, 8, 9)]
    for unit_rows in (10, 11, 12, 13)
]
71
+
55
72
56
73
def get_expected_cache_aux_size (rows : int ) -> int :
57
74
# 0.2 is the hardcoded cache load factor assumed in this test
@@ -101,6 +118,48 @@ def get_expected_cache_aux_size(rows: int) -> int:
101
118
],
102
119
]
103
120
121
# Expected `shard.storage` values for test_rw_sharding_with_buckets: one inner
# list per table, eight Storage entries per list. Within a table the first two
# shards share one HBM figure and the remaining six share a smaller one; DDR is
# 0 throughout.
EXPECTED_RW_SHARD_STORAGE_WITH_BUCKETS = [
    [
        Storage(hbm=hbm_bytes, ddr=0)
        for hbm_bytes in [double_width_hbm] * 2 + [single_width_hbm] * 6
    ]
    for double_width_hbm, single_width_hbm in (
        (167488, 166688),
        (1004992, 1003232),
        (1009280, 1006400),
        (2656384, 2652224),
    )
]
104
163
105
164
EXPECTED_UVM_CACHING_RW_SHARD_STORAGE = [
106
165
[
@@ -145,6 +204,48 @@ def get_expected_cache_aux_size(rows: int) -> int:
145
204
],
146
205
]
147
206
207
# Expected `shard.storage` values for test_uvm_caching_rw_sharding_with_buckets:
# one inner list per table, eight Storage entries per list. Within a table the
# first two shards share one (hbm, ddr) pair and the remaining six share
# another whose ddr is exactly half of the first pair's.
EXPECTED_UVM_CACHING_RW_SHARD_STORAGE_WITH_BUCKETS = [
    [
        Storage(hbm=hbm_bytes, ddr=ddr_bytes)
        for hbm_bytes, ddr_bytes in [double_width] * 2 + [single_width] * 6
    ]
    for double_width, single_width in (
        ((166352, 1600), (166120, 800)),
        ((1002335, 3520), (1001904, 1760)),
        ((1004845, 5760), (1004183, 2880)),
        ((2649916, 8320), (2648990, 4160)),
    )
]
148
249
149
250
EXPECTED_TWRW_SHARD_SIZES = [
150
251
[[25 , 20 ], [25 , 20 ], [25 , 20 ], [25 , 20 ]],
@@ -367,6 +468,17 @@ def setUp(self) -> None:
367
468
)
368
469
for i in range (self .num_tables )
369
470
]
471
+ tables_with_buckets = [
472
+ EmbeddingBagConfig (
473
+ num_embeddings = 100 + i * 10 ,
474
+ embedding_dim = 20 + i * 20 ,
475
+ name = "table_" + str (i ),
476
+ feature_names = ["feature_" + str (i )],
477
+ total_num_buckets = 10 ,
478
+ use_virtual_table = True ,
479
+ )
480
+ for i in range (self .num_tables )
481
+ ]
370
482
weighted_tables = [
371
483
EmbeddingBagConfig (
372
484
num_embeddings = (i + 1 ) * 10 ,
@@ -377,6 +489,9 @@ def setUp(self) -> None:
377
489
for i in range (4 )
378
490
]
379
491
self .model = TestSparseNN (tables = tables , weighted_tables = [])
492
+ self .model_with_buckets = EmbeddingBagCollection (
493
+ tables = tables_with_buckets ,
494
+ )
380
495
self .enumerator = EmbeddingEnumerator (
381
496
topology = Topology (
382
497
world_size = self .world_size ,
@@ -514,6 +629,25 @@ def test_rw_sharding(self) -> None:
514
629
EXPECTED_RW_SHARD_STORAGE [i ],
515
630
)
516
631
632
def test_rw_sharding_with_buckets(self) -> None:
    """Enumerate the bucketed model with RWSharder and check every option.

    Each enumerated option must be ROW_WISE, and its shards' sizes, offsets
    and storage must equal the i-th entry of the corresponding
    ``*_WITH_BUCKETS`` fixture.
    """
    rw_sharders = [cast(ModuleSharder[torch.nn.Module], RWSharder())]
    enumerated = self.enumerator.enumerate(self.model_with_buckets, rw_sharders)

    for idx, option in enumerate(enumerated):
        self.assertEqual(option.sharding_type, ShardingType.ROW_WISE.value)

        # Compare each shard attribute against the fixture row of the same index.
        actual_sizes = [shard.size for shard in option.shards]
        self.assertEqual(actual_sizes, EXPECTED_RW_SHARD_SIZES_WITH_BUCKETS[idx])

        actual_offsets = [shard.offset for shard in option.shards]
        self.assertEqual(actual_offsets, EXPECTED_RW_SHARD_OFFSETS_WITH_BUCKETS[idx])

        actual_storage = [shard.storage for shard in option.shards]
        self.assertEqual(actual_storage, EXPECTED_RW_SHARD_STORAGE_WITH_BUCKETS[idx])
650
+
517
651
def test_uvm_caching_rw_sharding (self ) -> None :
518
652
sharding_options = self .enumerator .enumerate (
519
653
self .model ,
@@ -535,6 +669,26 @@ def test_uvm_caching_rw_sharding(self) -> None:
535
669
EXPECTED_UVM_CACHING_RW_SHARD_STORAGE [i ],
536
670
)
537
671
672
def test_uvm_caching_rw_sharding_with_buckets(self) -> None:
    """Enumerate the bucketed model with UVMCachingRWSharder and check every option.

    Each enumerated option must be ROW_WISE. Shard sizes and offsets are
    compared with the same fixtures as the plain row-wise bucketed test;
    storage is compared with the UVM-caching-specific fixture.
    """
    uvm_sharders = [cast(ModuleSharder[torch.nn.Module], UVMCachingRWSharder())]
    enumerated = self.enumerator.enumerate(self.model_with_buckets, uvm_sharders)

    for idx, option in enumerate(enumerated):
        self.assertEqual(option.sharding_type, ShardingType.ROW_WISE.value)

        # Compare each shard attribute against the fixture row of the same index.
        actual_sizes = [shard.size for shard in option.shards]
        self.assertEqual(actual_sizes, EXPECTED_RW_SHARD_SIZES_WITH_BUCKETS[idx])

        actual_offsets = [shard.offset for shard in option.shards]
        self.assertEqual(actual_offsets, EXPECTED_RW_SHARD_OFFSETS_WITH_BUCKETS[idx])

        actual_storage = [shard.storage for shard in option.shards]
        self.assertEqual(
            actual_storage,
            EXPECTED_UVM_CACHING_RW_SHARD_STORAGE_WITH_BUCKETS[idx],
        )
691
+
538
692
def test_twrw_sharding (self ) -> None :
539
693
sharding_options = self .enumerator .enumerate (
540
694
self .model , [cast (ModuleSharder [torch .nn .Module ], TWRWSharder ())]
0 commit comments