@@ -18,6 +18,7 @@
 import json
 import logging
 import os
+import random
 import resource
 import time
 import timeit
@@ -55,5 +56,9 @@
     EmbeddingStorageEstimator,
 )
 from torchrec.distributed.shard import _shard_modules
+from torchrec.distributed.sharding_plan import (
+    construct_module_sharding_plan,
+    get_sharding_constructor_from_type,
+)
 from torchrec.distributed.test_utils.multi_process import MultiProcessContext
 from torchrec.distributed.test_utils.test_model import ModelInput
@@ -359,6 +364,31 @@ def forward(self, input: KeyedJaggedTensor) -> KeyedTensor:
 T = TypeVar("T", bound=torch.nn.Module)
 
 
+def _generate_rank_placements(
+    world_size: int,
+    num_tables: int,
+    ranks_per_tables: List[int],
+    random_seed: Optional[int] = None,
+) -> List[List[int]]:
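+    # Illustrative example (hypothetical values): world_size=4, num_tables=2,
+    # ranks_per_tables=[1, 2] could return [[3], [0, 2]]: one sorted list of
+    # ranks per table, sampled without replacement from range(world_size).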
+    # Cannot include old/new rank generation with the hypothesis library due to its dependency on world_size
+    if random_seed is None:
+        # Generate a random seed so that the output rank placements can differ between runs
+        random_seed = random.randint(0, 10000)
+    placements = []
+    max_rank = world_size - 1
+    random.seed(random_seed)
+    if ranks_per_tables == [0]:
+        ranks_per_tables = [random.randint(1, max_rank) for _ in range(num_tables)]
+    for i in range(num_tables):
+        ranks_per_table = ranks_per_tables[i]
+        placement = sorted(random.sample(range(world_size), ranks_per_table))
+        placements.append(placement)
+    return placements
+
+
 def default_func_to_benchmark(
     model: torch.nn.Module, bench_inputs: List[KeyedJaggedTensor]
 ) -> None:
@@ -679,6 +706,10 @@ def init_argparse_and_args() -> argparse.Namespace:
     parser.add_argument("--num_benchmarks", type=int, default=5)
     parser.add_argument("--embedding_config_json", type=str, default="")
     parser.add_argument("--device_type", type=str, default="cuda")
+    parser.add_argument("--enable_resharding", action="store_true", default=False)
+    parser.add_argument("--resharding_interval", type=int, default=1000)
 
     args = parser.parse_args()
 
@@ -728,6 +757,7 @@ def fx_script_module(eager_module: torch.nn.Module) -> torch.nn.Module:
     # Don't want to modify the module outright
     # Since module is on cpu, won't cause cuda oom.
     copied_module = copy.deepcopy(module)
+
     # pyre-ignore [6]
     plan = planner.plan(copied_module, [sharder])
@@ -780,6 +810,7 @@ def _run_benchmark_core(
     pre_gpu_load: int = 0,
     export_stacks: bool = False,
     reset_accumulated_memory_stats: bool = False,
+    new_ranks_per_plan: Optional[List[int]] = None,
 ) -> BenchmarkResult:
     """Internal helper that contains the core benchmarking logic shared by
     ``benchmark`` and ``benchmark_func``. All heavy-lifting (timing, memory
@@ -824,6 +855,7 @@ def _run_benchmark_core(
     # Timings
     start_events, end_events, times = [], [], []
 
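+    # CUDA events capture device-side timings; other device types presumably use host timers.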
     if device_type == "cuda":
         start_events = [
             torch.cuda.Event(enable_timing=True) for _ in range(num_benchmarks)
@@ -1071,6 +1107,7 @@ def init_module_and_run_benchmark(
     queue: Optional[mp.Queue] = None,
     pooling_configs: Optional[List[int]] = None,
     benchmark_unsharded_module: bool = False,
+    new_ranks_per_plan: Optional[List[List[List[int]]]] = None,
 ) -> BenchmarkResult:
     """
     There are a couple of caveats here as to why the module has to be initialized
@@ -1136,6 +1173,43 @@ def init_module_and_run_benchmark(
     else:
         name = benchmark_type_name(compile_mode, sharding_type)
 
+    resharding_plans = []
+
+    if new_ranks_per_plan is not None and len(new_ranks_per_plan) > 0:
+        sharding_type_constructor = get_sharding_constructor_from_type(
+            sharding_type
+        )
+        for new_ranks in new_ranks_per_plan:
+            new_per_param_sharding = {}
+            for table_idx, table in enumerate(tables):
+                # One placement (list of ranks) was generated per table.
+                table_ranks = new_ranks[table_idx]
+                if sharding_type == ShardingType.TABLE_WISE:
+                    new_per_param_sharding[table.name] = sharding_type_constructor(
+                        rank=table_ranks[0], compute_kernel=sharder._kernel_type
+                    )
+                elif sharding_type == ShardingType.COLUMN_WISE:
+                    new_per_param_sharding[table.name] = sharding_type_constructor(
+                        ranks=table_ranks
+                    )
+
+            new_module_sharding_plan = construct_module_sharding_plan(
+                module=module.module,
+                # pyre-ignore
+                sharder=sharder,
+                per_param_sharding=new_per_param_sharding,
+                local_size=world_size,
+                world_size=world_size,
+                device_type="cuda" if torch.cuda.is_available() else "cpu",
+            )
+            resharding_plans.append(new_module_sharding_plan)
+        if benchmark_func_kwargs is None:
+            benchmark_func_kwargs = {}
+        benchmark_func_kwargs["resharding_plans"] = resharding_plans
+
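+    # Assumption: the benchmarked training function reads "resharding_plans" from
+    # its kwargs and applies one plan per resharding interval; the exact contract
+    # is defined by the func_to_benchmark supplied by the caller.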
     res = benchmark(
         name,
         module,
@@ -1244,6 +1315,12 @@ def benchmark_module(
     pooling_configs: Optional[List[int]] = None,
     variable_batch_embeddings: bool = False,
     device_type: str = "cuda",
+    enable_resharding: bool = False,
+    resharding_interval: int = 1000,
 ) -> List[BenchmarkResult]:
     """
     Args:
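+        enable_resharding (bool): whether to benchmark the reshard API by
+            generating and applying new sharding plans during training.
+        resharding_interval (int): number of benchmark iterations between
+            resharding events.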
@@ -1325,6 +1398,41 @@ def benchmark_module(
             )
 
         if train:
+            # One new sharding plan is generated per resharding interval.
+            total_plans_per_benchmark = bench_iters // resharding_interval
+            total_plans_per_benchmark = max(1, total_plans_per_benchmark)
+            new_ranks_per_plan = []
+            if enable_resharding:
+                num_tables = len(tables)
+                if sharding_type == ShardingType.TABLE_WISE:
+                    # Table-wise sharding places each table on exactly one rank.
+                    ranks_per_tables = [1 for _ in range(num_tables)]
+                    new_ranks_per_plan = [
+                        _generate_rank_placements(
+                            world_size, num_tables, ranks_per_tables
+                        )
+                        for _ in range(total_plans_per_benchmark)
+                    ]
+
+                elif sharding_type == ShardingType.COLUMN_WISE:
+                    # Only shard counts that divide EMBEDDING_DIM evenly are valid.
+                    valid_candidates = [
+                        i
+                        for i in range(1, world_size + 1)
+                        if EMBEDDING_DIM % i == 0
+                    ]
+                    ranks_per_tables = [
+                        random.choice(valid_candidates) for _ in range(num_tables)
+                    ]
+                    new_ranks_per_plan = [
+                        _generate_rank_placements(
+                            world_size, num_tables, ranks_per_tables
+                        )
+                        for _ in range(total_plans_per_benchmark)
+                    ]
+
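+            # Illustrative cadence (hypothetical values): bench_iters=2000 with
+            # resharding_interval=500 produces 4 plans, one per interval boundary.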
             res = multi_process_benchmark(
                 # pyre-ignore[6]
                 callable=init_module_and_run_benchmark,
@@ -1344,6 +1450,7 @@ def benchmark_module(
                 func_to_benchmark=func_to_benchmark,
                 benchmark_func_kwargs=benchmark_func_kwargs,
                 pooling_configs=pooling_configs,
+                new_ranks_per_plan=new_ranks_per_plan,
             )
         else:
             res = init_module_and_run_benchmark(