     EmbeddingStorageEstimator,
 )
 from torchrec.distributed.shard import _shard_modules
+
+from torchrec.distributed.sharding.dynamic_sharding import output_sharding_plan_delta
 from torchrec.distributed.sharding_plan import (
     construct_module_sharding_plan,
     get_sharding_constructor_from_type,
 )
 from torchrec.distributed.test_utils.multi_process import MultiProcessContext
 from torchrec.distributed.test_utils.test_model import ModelInput
 
-from torchrec.distributed.types import DataType, ModuleSharder, ShardingEnv
+from torchrec.distributed.types import (
+    DataType,
+    EmbeddingModuleShardingPlan,
+    ModuleSharder,
+    ShardingEnv,
+)
 from torchrec.fx import symbolic_trace
 from torchrec.modules.embedding_configs import EmbeddingBagConfig, EmbeddingConfig
 from torchrec.quant.embedding_modules import (
@@ -368,7 +375,7 @@ def _generate_rank_placements(
     world_size: int,
     num_tables: int,
     ranks_per_tables: List[int],
-    random_seed: int = None,
+    random_seed: Optional[int] = None,
 ) -> List[List[int]]:
     # Cannot include old/new rank generation with hypothesis library due to dependency on world_size
     if random_seed is None:
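
For readers skimming the diff, a small usage sketch of the signature above. The concrete ranks in the final comment are purely illustrative, assuming the helper samples `ranks_per_tables[i]` ranks out of `range(world_size)` for each table `i`:

```python
# Illustrative call only; actual placements are random unless the seed is fixed.
placements = _generate_rank_placements(
    world_size=4,
    num_tables=2,
    ranks_per_tables=[1, 1],
    random_seed=0,  # now typed as Optional[int]; None falls back to a generated seed
)
# placements is a List[List[int]], one rank list per table, e.g. [[2], [0]]
```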
@@ -387,7 +394,9 @@ def _generate_rank_placements(
 
 
 def default_func_to_benchmark(
-    model: torch.nn.Module, bench_inputs: List[KeyedJaggedTensor]
+    model: torch.nn.Module,
+    bench_inputs: List[KeyedJaggedTensor],
+    resharding_plan_diffs: Optional[List[EmbeddingModuleShardingPlan]] = None,
 ) -> None:
     with torch.inference_mode():
         for bench_input in bench_inputs:
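
The new `resharding_plan_diffs` argument suggests the benchmark callback can interleave resharding with forward passes, but this hunk only shows the inference loop. The sketch below is therefore an assumption: `apply_resharding_plan` is a hypothetical stand-in for whatever Reshard API call the harness actually issues, and the cadence of applying plans is invented for illustration.

```python
def func_to_benchmark_with_resharding(
    model: torch.nn.Module,
    bench_inputs: List[KeyedJaggedTensor],
    resharding_plan_diffs: Optional[List[EmbeddingModuleShardingPlan]] = None,
) -> None:
    # Hypothetical: spread the plan deltas evenly across the benchmark inputs so
    # the cost of resharding is captured inside the timed region.
    plans = list(resharding_plan_diffs or [])
    interval = max(1, len(bench_inputs) // max(1, len(plans))) if plans else 0
    with torch.inference_mode():
        for i, bench_input in enumerate(bench_inputs):
            if plans and interval and i % interval == 0:
                apply_resharding_plan(model, plans.pop(0))  # hypothetical helper
            model(bench_input)
```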
@@ -855,11 +864,6 @@ def _run_benchmark_core(
     # Timings
     start_events, end_events, times = [], [], []
 
-<<<<<<< dest:   af95f723afd1 - noreply + 1244265887488347: [AutoAccept][Codemod ...
-=======
-    times = []
-
->>>>>>> source: ea51915e8d98 - isuru: Reshard API Performance Benchmarking
     if device_type == "cuda":
         start_events = [
             torch.cuda.Event(enable_timing=True) for _ in range(num_benchmarks)
@@ -1154,58 +1158,59 @@ def init_module_and_run_benchmark(
         if rank != -1
         else contextlib.nullcontext()
     ) as ctx:
-        module = transform_module(
-            module=module,
-            device=device,
-            inputs=warmup_inputs_cuda,
-            sharder=sharder,
-            sharding_type=sharding_type,
-            compile_mode=compile_mode,
-            world_size=world_size,
-            batch_size=batch_size,
-            # pyre-ignore[6]
-            ctx=ctx,
-            benchmark_unsharded_module=benchmark_unsharded_module,
-        )
-
-        if benchmark_unsharded_module:
-            name = "unsharded" + compile_mode.name
-        else:
-            name = benchmark_type_name(compile_mode, sharding_type)
 
         resharding_plans = []
 
-        import fbvscode
-
-        fbvscode.set_trace()
-
        if new_ranks_per_plan is not None and len(new_ranks_per_plan) > 0:
            sharding_type_constructor = get_sharding_constructor_from_type(
                sharding_type
            )
-            for i, new_ranks in enumerate(new_ranks_per_plan):
+            for new_ranks_per_table in new_ranks_per_plan:
                new_per_param_sharding = {}
-                for table in tables:
+                for table_id, table in enumerate(tables):
                    if sharding_type == ShardingType.TABLE_WISE:
                        new_per_param_sharding[table.name] = sharding_type_constructor(
-                            rank=new_ranks, compute_kernel=sharder._kernel_type
+                            rank=new_ranks_per_table[table_id][0],
+                            compute_kernel=sharder._kernel_type,
                        )
                    elif sharding_type == ShardingType.COLUMN_WISE:
                        new_per_param_sharding[table.name] = sharding_type_constructor(
-                            ranks=new_ranks
+                            ranks=new_ranks_per_table[table_id]
                        )
 
                new_module_sharding_plan = construct_module_sharding_plan(
-                    module=module.module,
+                    module=module._module,  # Pyre-ignore
                    # Pyre-ignore
                    sharder=sharder,
                    per_param_sharding=new_per_param_sharding,
                    local_size=world_size,
                    world_size=world_size,
-                    device_type="cuda" if torch.cuda.is_available() else "cpu",
+                    device_type=device.type,
                )
                resharding_plans.append(new_module_sharding_plan)
-        benchmark_func_kwargs["resharding_plans"] = resharding_plans
+
+        module = transform_module(
+            module=module,
+            device=device,
+            inputs=warmup_inputs_cuda,
+            sharder=sharder,
+            sharding_type=sharding_type,
+            compile_mode=compile_mode,
+            world_size=world_size,
+            batch_size=batch_size,
+            # pyre-ignore[6]
+            ctx=ctx,
+            benchmark_unsharded_module=benchmark_unsharded_module,
+        )
+
+        if benchmark_unsharded_module:
+            name = "unsharded" + compile_mode.name
+        else:
+            name = benchmark_type_name(compile_mode, sharding_type)
+
+        if benchmark_func_kwargs is None:
+            benchmark_func_kwargs = {}
+        benchmark_func_kwargs["resharding_plan_diffs"] = resharding_plans
 
        res = benchmark(
            name,
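
To make the data flow concrete: each entry of `new_ranks_per_plan` holds one rank list per table, which the loop above turns into per-parameter sharding options and then a full module sharding plan. The shapes below are illustrative only (two tables, `world_size=4`), and the final line shows an assumed use of the newly imported `output_sharding_plan_delta`, diffing an old plan against a new one so only moved shards are handed to the reshard path; verify the exact signature in `torchrec.distributed.sharding.dynamic_sharding`.

```python
# Illustrative shapes only (two tables, world_size=4).
# TABLE_WISE: one rank per table, so rank=new_ranks_per_table[table_id][0] is used.
new_ranks_per_plan_tw = [[[1], [3]], [[2], [0]]]
# COLUMN_WISE: one rank per column shard, passed as ranks=new_ranks_per_table[table_id].
new_ranks_per_plan_cw = [[[0, 2], [1, 3]], [[1, 3], [0, 2]]]

# Assumed usage of the imported helper (not shown in this diff);
# previous_plan is a hypothetical prior EmbeddingModuleShardingPlan.
plan_diff = output_sharding_plan_delta(previous_plan, new_module_sharding_plan)
```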
@@ -1398,22 +1403,18 @@ def benchmark_module(
     )
 
     if train:
-        total_plans_per_benchmark = bench_iters // resharding_interval
-        total_plans_per_benchmark = max(1, total_plans_per_benchmark)
+
         new_ranks_per_plan = []
+
         if enable_resharding:
+            total_plans_per_benchmark = bench_iters // resharding_interval
+            total_plans_per_benchmark = max(1, total_plans_per_benchmark)
+
             num_tables = len(tables)
-            new_ranks_count_per_plan = [
-                [] for _ in range(total_plans_per_benchmark)
-            ]
+            ranks_per_tables = []
+
             if sharding_type == ShardingType.TABLE_WISE:
                 ranks_per_tables = [1 for _ in range(num_tables)]
-                new_ranks_per_plan = [
-                    _generate_rank_placements(
-                        world_size, num_tables, ranks_per_tables
-                    )
-                    for _ in range(total_plans_per_benchmark)
-                ]
 
             elif sharding_type == ShardingType.COLUMN_WISE:
                 valid_candidates = [
@@ -1424,11 +1425,12 @@ def benchmark_module(
                 ranks_per_tables = [
                     random.choice(valid_candidates) for _ in range(num_tables)
                 ]
+
             new_ranks_per_plan = [
                 _generate_rank_placements(
                     world_size, num_tables, ranks_per_tables
                 )
-                for ranks_per_tables in (new_ranks_count_per_plan)
+                for _ in range(total_plans_per_benchmark)
             ]
 
     res = multi_process_benchmark(
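
A small worked example of the reorganized plan-count logic, using made-up values for `bench_iters` and `resharding_interval`; `valid_candidates` is computed just above this hunk and is not shown here, so the column-wise shard counts below are assumptions.

```python
# Assumed example values, not the benchmark defaults.
bench_iters = 20
resharding_interval = 8
total_plans_per_benchmark = max(1, bench_iters // resharding_interval)  # == 2

# TABLE_WISE: every table keeps a single shard.
ranks_per_tables_tw = [1, 1, 1]  # three tables
# COLUMN_WISE: per-table shard counts drawn from valid_candidates (assumed values).
ranks_per_tables_cw = [2, 4, 2]

# Either way, one placement list is generated per plan, so new_ranks_per_plan
# ends up with total_plans_per_benchmark entries.
```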