     EmbeddingStorageEstimator,
 )
 from torchrec.distributed.shard import _shard_modules
+
+from torchrec.distributed.sharding.dynamic_sharding import output_sharding_plan_delta
 from torchrec.distributed.sharding_plan import (
     construct_module_sharding_plan,
     get_sharding_constructor_from_type,
 )
 from torchrec.distributed.test_utils.multi_process import MultiProcessContext
 from torchrec.distributed.test_utils.test_model import ModelInput
 
-from torchrec.distributed.types import DataType, ModuleSharder, ShardingEnv
+from torchrec.distributed.types import (
+    DataType,
+    EmbeddingModuleShardingPlan,
+    ModuleSharder,
+    ShardingEnv,
+)
 from torchrec.fx import symbolic_trace
 from torchrec.modules.embedding_configs import EmbeddingBagConfig, EmbeddingConfig
 from torchrec.quant.embedding_modules import (
@@ -368,7 +375,7 @@ def _generate_rank_placements(
     world_size: int,
     num_tables: int,
     ranks_per_tables: List[int],
-    random_seed: int = None,
+    random_seed: Optional[int] = None,
 ) -> List[List[int]]:
     # Cannot include old/new rank generation with hypothesis library due to dependency on world_size
     if random_seed is None:
@@ -387,7 +394,9 @@ def _generate_rank_placements(
 
 
 def default_func_to_benchmark(
-    model: torch.nn.Module, bench_inputs: List[KeyedJaggedTensor]
+    model: torch.nn.Module,
+    bench_inputs: List[KeyedJaggedTensor],
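+    # Populated via benchmark_func_kwargs in init_module_and_run_benchmark
+    # when resharding plans are generated for the Reshard API benchmark.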
+    resharding_plan_diffs: Optional[List[EmbeddingModuleShardingPlan]] = None,
 ) -> None:
     with torch.inference_mode():
         for bench_input in bench_inputs:
@@ -855,11 +864,6 @@ def _run_benchmark_core(
     # Timings
     start_events, end_events, times = [], [], []
 
-<<<<<<< dest: af95f723afd1 - noreply+1244265887488347: [AutoAccept][Codemod...
-=======
-    times = []
-
->>>>>>> source: ea51915e8d98 - isuru: Reshard API Performance Benchmarking
     if device_type == "cuda":
         start_events = [
             torch.cuda.Event(enable_timing=True) for _ in range(num_benchmarks)
@@ -1154,58 +1158,59 @@ def init_module_and_run_benchmark(
         if rank != -1
         else contextlib.nullcontext()
     ) as ctx:
-        module = transform_module(
-            module=module,
-            device=device,
-            inputs=warmup_inputs_cuda,
-            sharder=sharder,
-            sharding_type=sharding_type,
-            compile_mode=compile_mode,
-            world_size=world_size,
-            batch_size=batch_size,
-            # pyre-ignore[6]
-            ctx=ctx,
-            benchmark_unsharded_module=benchmark_unsharded_module,
-        )
-
-        if benchmark_unsharded_module:
-            name = "unsharded" + compile_mode.name
-        else:
-            name = benchmark_type_name(compile_mode, sharding_type)
 
         resharding_plans = []
 
-        import fbvscode
-
-        fbvscode.set_trace()
-
         if new_ranks_per_plan is not None and len(new_ranks_per_plan) > 0:
             sharding_type_constructor = get_sharding_constructor_from_type(
                 sharding_type
             )
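+            # Each entry of new_ranks_per_plan is one resharding step: a
+            # per-table list of target ranks, indexed by table id.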
-            for i, new_ranks in enumerate(new_ranks_per_plan):
+            for new_ranks_per_table in new_ranks_per_plan:
                 new_per_param_sharding = {}
-                for table in tables:
+                for table_id, table in enumerate(tables):
                     if sharding_type == ShardingType.TABLE_WISE:
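+                        # Table-wise placement puts the whole table on a single
+                        # rank, so only the first listed rank is used.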
                         new_per_param_sharding[table.name] = sharding_type_constructor(
-                            rank=new_ranks, compute_kernel=sharder._kernel_type
+                            rank=new_ranks_per_table[table_id][0],
+                            compute_kernel=sharder._kernel_type,
                         )
                     elif sharding_type == ShardingType.COLUMN_WISE:
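+                        # Column-wise placement spreads the table's column
+                        # shards across the full list of target ranks.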
                         new_per_param_sharding[table.name] = sharding_type_constructor(
-                            ranks=new_ranks
+                            ranks=new_ranks_per_table[table_id]
                         )
 
                 new_module_sharding_plan = construct_module_sharding_plan(
-                    module=module.module,
+                    module=module._module,  # Pyre-ignore
                     # Pyre-ignore
                     sharder=sharder,
                     per_param_sharding=new_per_param_sharding,
                     local_size=world_size,
                     world_size=world_size,
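+                    # Use the benchmark device's type rather than assuming
+                    # CUDA whenever it happens to be available.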
-                    device_type="cuda" if torch.cuda.is_available() else "cpu",
+                    device_type=device.type,
                 )
                 resharding_plans.append(new_module_sharding_plan)
-        benchmark_func_kwargs["resharding_plans"] = resharding_plans
+
+        module = transform_module(
+            module=module,
+            device=device,
+            inputs=warmup_inputs_cuda,
+            sharder=sharder,
+            sharding_type=sharding_type,
+            compile_mode=compile_mode,
+            world_size=world_size,
+            batch_size=batch_size,
+            # pyre-ignore[6]
+            ctx=ctx,
+            benchmark_unsharded_module=benchmark_unsharded_module,
+        )
+
+        if benchmark_unsharded_module:
+            name = "unsharded" + compile_mode.name
+        else:
+            name = benchmark_type_name(compile_mode, sharding_type)
+
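+        # Ensure a kwargs dict exists so the plan deltas always reach the
+        # benchmark function.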
+        if benchmark_func_kwargs is None:
+            benchmark_func_kwargs = {}
+        benchmark_func_kwargs["resharding_plan_diffs"] = resharding_plans
 
         res = benchmark(
             name,
@@ -1398,22 +1403,18 @@ def benchmark_module(
             )
 
             if train:
-                total_plans_per_benchmark = bench_iters // resharding_interval
-                total_plans_per_benchmark = max(1, total_plans_per_benchmark)
+
                 new_ranks_per_plan = []
+
                 if enable_resharding:
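+                    # One resharding plan per resharding_interval iterations,
+                    # with at least one plan per benchmark run.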
+                    total_plans_per_benchmark = bench_iters // resharding_interval
+                    total_plans_per_benchmark = max(1, total_plans_per_benchmark)
+
                     num_tables = len(tables)
-                    new_ranks_count_per_plan = [
-                        [] for _ in range(total_plans_per_benchmark)
-                    ]
+                    ranks_per_tables = []
+
                     if sharding_type == ShardingType.TABLE_WISE:
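+                        # Table-wise tables each occupy exactly one rank.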
                         ranks_per_tables = [1 for _ in range(num_tables)]
-                        new_ranks_per_plan = [
-                            _generate_rank_placements(
-                                world_size, num_tables, ranks_per_tables
-                            )
-                            for _ in range(total_plans_per_benchmark)
-                        ]
 
                     elif sharding_type == ShardingType.COLUMN_WISE:
                         valid_candidates = [
@@ -1424,11 +1425,12 @@ def benchmark_module(
                         ranks_per_tables = [
                             random.choice(valid_candidates) for _ in range(num_tables)
                         ]
+
                         new_ranks_per_plan = [
                             _generate_rank_placements(
                                 world_size, num_tables, ranks_per_tables
                             )
-                            for ranks_per_tables in (new_ranks_count_per_plan)
+                            for _ in range(total_plans_per_benchmark)
                         ]
 
                 res = multi_process_benchmark(