@@ -18,6 +18,7 @@
 import json
 import logging
 import os
+import random
 import time
 import timeit
 from dataclasses import dataclass, fields, is_dataclass, MISSING
@@ -54,6 +55,10 @@
     EmbeddingStorageEstimator,
 )
 from torchrec.distributed.shard import _shard_modules
+from torchrec.distributed.sharding_plan import (
+    construct_module_sharding_plan,
+    get_sharding_constructor_from_type,
+)
 from torchrec.distributed.test_utils.multi_process import MultiProcessContext
 from torchrec.distributed.test_utils.test_model import ModelInput
 
@@ -308,6 +313,28 @@ def forward(self, input: KeyedJaggedTensor) -> KeyedTensor:
 T = TypeVar("T", bound=torch.nn.Module)
 
 
+def _generate_rank_placements(
+    world_size: int,
+    num_tables: int,
+    ranks_per_tables: List[int],
+    random_seed: Optional[int] = None,
+) -> List[List[int]]:
+    # Cannot include old/new rank generation with the hypothesis library due to the dependency on world_size.
+    if random_seed is None:
+        # Generate a random seed so that the output rank placements can differ between runs.
+        random_seed = random.randint(0, 10000)
+    placements = []
+    max_rank = world_size - 1
+    random.seed(random_seed)
+    if ranks_per_tables == [0]:
+        ranks_per_tables = [random.randint(1, max_rank) for _ in range(num_tables)]
+    for i in range(num_tables):
+        ranks_per_table = ranks_per_tables[i]
+        placement = sorted(random.sample(range(world_size), ranks_per_table))
+        placements.append(placement)
+    return placements
+
+
 def default_func_to_benchmark(
     model: torch.nn.Module, bench_inputs: List[KeyedJaggedTensor]
 ) -> None:
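For reference, a quick sketch of what the helper above produces; the output is random, so the placements shown are just one possibility:

# One possible outcome with world_size=4 and two tables.
placements = _generate_rank_placements(world_size=4, num_tables=2, ranks_per_tables=[1, 2])
# e.g. [[2], [0, 3]] -- table 0 placed on rank 2, table 1 spread over ranks 0 and 3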
@@ -595,6 +622,8 @@ def init_argparse_and_args() -> argparse.Namespace:
     parser.add_argument("--num_benchmarks", type=int, default=5)
     parser.add_argument("--embedding_config_json", type=str, default="")
     parser.add_argument("--device_type", type=str, default="cuda")
+    # argparse treats type=bool as bool(str), which is True for any non-empty value; use a flag instead.
+    parser.add_argument("--enable_resharding", action="store_true")
+    parser.add_argument("--resharding_interval", type=int, default=1000)
 
     args = parser.parse_args()
 
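Assuming the new flags are wired through to benchmark_module, an invocation would look roughly like the following; the entry-point module path is illustrative and not part of this diff:

python -m torchrec.distributed.benchmark.benchmark_train \
    --enable_resharding \
    --resharding_interval 500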
@@ -644,6 +673,7 @@ def fx_script_module(eager_module: torch.nn.Module) -> torch.nn.Module:
     # Don't want to modify the module outright
     # Since module is on cpu, won't cause cuda oom.
     copied_module = copy.deepcopy(module)
+
     # pyre-ignore [6]
     plan = planner.plan(copied_module, [sharder])
 
@@ -700,6 +730,7 @@ def benchmark(
     enable_logging: bool = True,
     device_type: str = "cuda",
     benchmark_unsharded_module: bool = False,
+    new_ranks_per_plan: Optional[List[int]] = None,
 ) -> BenchmarkResult:
     memory_stats: List[MemoryStats] = []
     if enable_logging:
@@ -728,6 +759,7 @@ def benchmark(
         benchmark_func_kwargs = {}
 
     times = []
+
     if device_type == "cuda":
         for i in range(num_benchmarks):
             start[i].record()
@@ -998,6 +1030,7 @@ def init_module_and_run_benchmark(
     queue: Optional[mp.Queue] = None,
     pooling_configs: Optional[List[int]] = None,
     benchmark_unsharded_module: bool = False,
+    new_ranks_per_plan: Optional[List[List[int]]] = None,
 ) -> BenchmarkResult:
     """
     There are a couple of caveats here as to why the module has to be initialized
@@ -1063,6 +1096,40 @@ def init_module_and_run_benchmark(
         else:
             name = benchmark_type_name(compile_mode, sharding_type)
 
+        resharding_plans = []
+
+        if new_ranks_per_plan is not None and len(new_ranks_per_plan) > 0:
+            sharding_type_constructor = get_sharding_constructor_from_type(
+                sharding_type
+            )
+            for new_ranks in new_ranks_per_plan:
+                new_per_param_sharding = {}
+                for table in tables:
+                    if sharding_type == ShardingType.TABLE_WISE:
+                        new_per_param_sharding[table.name] = sharding_type_constructor(
+                            rank=new_ranks, compute_kernel=sharder._kernel_type
+                        )
+                    elif sharding_type == ShardingType.COLUMN_WISE:
+                        new_per_param_sharding[table.name] = sharding_type_constructor(
+                            ranks=new_ranks
+                        )
+
+                new_module_sharding_plan = construct_module_sharding_plan(
+                    module=module.module,
+                    # pyre-ignore
+                    sharder=sharder,
+                    per_param_sharding=new_per_param_sharding,
+                    local_size=world_size,
+                    world_size=world_size,
+                    device_type="cuda" if torch.cuda.is_available() else "cpu",
+                )
+                resharding_plans.append(new_module_sharding_plan)
+            if benchmark_func_kwargs is None:
+                benchmark_func_kwargs = {}
+            benchmark_func_kwargs["resharding_plans"] = resharding_plans
+
         res = benchmark(
             name,
             module,
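The generated plans are handed to the benchmarked callable via benchmark_func_kwargs["resharding_plans"], but the consumer is not part of this diff. Below is a minimal sketch of what such a train-loop callable could look like; the function name and the reshard_fn hook are assumptions, since the diff does not show how a plan is actually applied to the sharded module.

from typing import Callable, List, Optional

import torch
from torchrec.distributed.types import EmbeddingModuleShardingPlan
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


def resharding_func_to_benchmark(  # hypothetical consumer of the kwarg above
    model: torch.nn.Module,
    bench_inputs: List[KeyedJaggedTensor],
    resharding_plans: Optional[List[EmbeddingModuleShardingPlan]] = None,
    resharding_interval: int = 1000,
    # Assumed hook that applies a new plan to the sharded module; not shown in this diff.
    reshard_fn: Optional[Callable[[torch.nn.Module, EmbeddingModuleShardingPlan], None]] = None,
) -> None:
    for i, bench_input in enumerate(bench_inputs):
        if resharding_plans and reshard_fn and i > 0 and i % resharding_interval == 0:
            plan = resharding_plans[(i // resharding_interval) % len(resharding_plans)]
            reshard_fn(model, plan)
        model(bench_input)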
@@ -1167,6 +1234,8 @@ def benchmark_module(
     pooling_configs: Optional[List[int]] = None,
     variable_batch_embeddings: bool = False,
     device_type: str = "cuda",
+    enable_resharding: bool = False,
+    resharding_interval: int = 1000,
 ) -> List[BenchmarkResult]:
     """
     Args:
@@ -1248,6 +1317,39 @@ def benchmark_module(
             )
 
             if train:
+                total_plans_per_benchmark = bench_iters // resharding_interval
+                total_plans_per_benchmark = max(1, total_plans_per_benchmark)
+                new_ranks_per_plan = []
+                if enable_resharding:
+                    num_tables = len(tables)
+                    if sharding_type == ShardingType.TABLE_WISE:
+                        # Table-wise: each table is placed on exactly one rank.
+                        ranks_per_tables = [1 for _ in range(num_tables)]
+                        new_ranks_per_plan = [
+                            _generate_rank_placements(
+                                world_size, num_tables, ranks_per_tables
+                            )
+                            for _ in range(total_plans_per_benchmark)
+                        ]
+                    elif sharding_type == ShardingType.COLUMN_WISE:
+                        # Column-wise: the shard count per table must divide EMBEDDING_DIM evenly.
+                        valid_candidates = [
+                            i
+                            for i in range(1, world_size + 1)
+                            if EMBEDDING_DIM % i == 0
+                        ]
+                        ranks_per_tables = [
+                            random.choice(valid_candidates) for _ in range(num_tables)
+                        ]
+                        new_ranks_per_plan = [
+                            _generate_rank_placements(
+                                world_size, num_tables, ranks_per_tables
+                            )
+                            for _ in range(total_plans_per_benchmark)
+                        ]
+
                 res = multi_process_benchmark(
                     # pyre-ignore[6]
                     callable=init_module_and_run_benchmark,
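The valid_candidates filter above encodes the column-wise constraint that the number of shards must divide the embedding dimension. A small illustration; the EMBEDDING_DIM value of 128 is an assumption, not taken from this diff:

# Assuming EMBEDDING_DIM = 128 and world_size = 8:
EMBEDDING_DIM = 128
world_size = 8
valid_candidates = [i for i in range(1, world_size + 1) if EMBEDDING_DIM % i == 0]
print(valid_candidates)  # [1, 2, 4, 8] -- 3, 5, 6, 7 do not divide 128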
@@ -1267,6 +1369,7 @@ def benchmark_module(
                     func_to_benchmark=func_to_benchmark,
                     benchmark_func_kwargs=benchmark_func_kwargs,
                     pooling_configs=pooling_configs,
+                    new_ranks_per_plan=new_ranks_per_plan,
                 )
             else:
                 res = init_module_and_run_benchmark(