         EmbeddingStorageEstimator,
     )
 from torchrec.distributed.shard import _shard_modules
+
+from torchrec.distributed.sharding.dynamic_sharding import output_sharding_plan_delta
 from torchrec.distributed.sharding_plan import (
     construct_module_sharding_plan,
     get_sharding_constructor_from_type,
 )
 from torchrec.distributed.test_utils.multi_process import MultiProcessContext
 from torchrec.distributed.test_utils.test_model import ModelInput
 
-from torchrec.distributed.types import DataType, ModuleSharder, ShardingEnv
+from torchrec.distributed.types import (
+    DataType,
+    EmbeddingModuleShardingPlan,
+    ModuleSharder,
+    ShardingEnv,
+)
 from torchrec.fx import symbolic_trace
 from torchrec.modules.embedding_configs import EmbeddingBagConfig, EmbeddingConfig
 from torchrec.quant.embedding_modules import (
@@ -317,7 +324,7 @@ def _generate_rank_placements(
     world_size: int,
     num_tables: int,
     ranks_per_tables: List[int],
-    random_seed: int = None,
+    random_seed: Optional[int] = None,
 ) -> List[List[int]]:
     # Cannot include old/new rank generation with hypothesis library due to dependency on world_size
     if random_seed is None:
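The helper's body sits outside this hunk; as orientation, here is a minimal sketch of what _generate_rank_placements plausibly does, consistent with the signature and the seed check above. The default seed value and the distinct-rank sampling strategy are assumptions, not the PR's actual logic.

import random
from typing import List, Optional


def _generate_rank_placements(
    world_size: int,
    num_tables: int,
    ranks_per_tables: List[int],
    random_seed: Optional[int] = None,
) -> List[List[int]]:
    # ranks_per_tables holds one entry per table: how many ranks that table
    # should be placed on (1 for TABLE_WISE, possibly more for COLUMN_WISE).
    if random_seed is None:
        random_seed = 100  # assumed default; the real value is not in this hunk
    rng = random.Random(random_seed)
    return [
        sorted(rng.sample(range(world_size), ranks_per_table))
        for ranks_per_table in ranks_per_tables
    ]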
@@ -1077,58 +1084,61 @@ def init_module_and_run_benchmark(
         if rank != -1
         else contextlib.nullcontext()
     ) as ctx:
-        module = transform_module(
-            module=module,
-            device=device,
-            inputs=warmup_inputs_cuda,
-            sharder=sharder,
-            sharding_type=sharding_type,
-            compile_mode=compile_mode,
-            world_size=world_size,
-            batch_size=batch_size,
-            # pyre-ignore[6]
-            ctx=ctx,
-            benchmark_unsharded_module=benchmark_unsharded_module,
-        )
-
-        if benchmark_unsharded_module:
-            name = "unsharded" + compile_mode.name
-        else:
-            name = benchmark_type_name(compile_mode, sharding_type)
 
         resharding_plans = []
 
-        import fbvscode
-
-        fbvscode.set_trace()
-
         if new_ranks_per_plan is not None and len(new_ranks_per_plan) > 0:
             sharding_type_constructor = get_sharding_constructor_from_type(
                 sharding_type
             )
-            for i, new_ranks in enumerate(new_ranks_per_plan):
+            for new_ranks_per_table in new_ranks_per_plan:
                 new_per_param_sharding = {}
-                for table in tables:
+                for table_id, table in enumerate(tables):
                     if sharding_type == ShardingType.TABLE_WISE:
                         new_per_param_sharding[table.name] = sharding_type_constructor(
-                            rank=new_ranks, compute_kernel=sharder._kernel_type
+                            rank=new_ranks_per_table[table_id][0],
+                            compute_kernel=sharder._kernel_type,
                         )
                     elif sharding_type == ShardingType.COLUMN_WISE:
                         new_per_param_sharding[table.name] = sharding_type_constructor(
-                            ranks=new_ranks
+                            ranks=new_ranks_per_table[table_id]
                         )
 
                 new_module_sharding_plan = construct_module_sharding_plan(
-                    module=module.module,
+                    module=module._module,  # Pyre-ignore
                     # Pyre-ignore
                     sharder=sharder,
                     per_param_sharding=new_per_param_sharding,
                     local_size=world_size,
                     world_size=world_size,
-                    device_type="cuda" if torch.cuda.is_available() else "cpu",
+                    device_type=device.type,
                 )
                 resharding_plans.append(new_module_sharding_plan)
-            benchmark_func_kwargs["resharding_plans"] = resharding_plans
+
+        module = transform_module(
+            module=module,
+            device=device,
+            inputs=warmup_inputs_cuda,
+            sharder=sharder,
+            sharding_type=sharding_type,
+            compile_mode=compile_mode,
+            world_size=world_size,
+            batch_size=batch_size,
+            # pyre-ignore[6]
+            ctx=ctx,
+            benchmark_unsharded_module=benchmark_unsharded_module,
+        )
+
+        if benchmark_unsharded_module:
+            name = "unsharded" + compile_mode.name
+        else:
+            name = benchmark_type_name(compile_mode, sharding_type)
+
+        # plan_difference = [
+        #     output_sharding_plan_delta(module.plan.plan["_module"], reshard_plan)
+        #     for reshard_plan in resharding_plans
+        # ]
+        benchmark_func_kwargs["resharding_plan_diffs"] = resharding_plans
 
         res = benchmark(
             name,
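For reference, here is a self-contained sketch (not part of the PR) of what this hunk builds in the TABLE_WISE case, plus the delta computation that the commented-out plan_difference block hints at. The table configs, world size, placements, and the make_plan helper are illustrative assumptions; construct_module_sharding_plan, table_wise, and output_sharding_plan_delta are the TorchRec APIs this diff imports, with output_sharding_plan_delta taking (old, new) as the commented-out call suggests.

import torch
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.sharding.dynamic_sharding import output_sharding_plan_delta
from torchrec.distributed.sharding_plan import (
    construct_module_sharding_plan,
    table_wise,
)
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection

tables = [
    EmbeddingBagConfig(
        name=f"table_{i}",
        embedding_dim=64,
        num_embeddings=1_000,
        feature_names=[f"feature_{i}"],
    )
    for i in range(2)
]
ebc = EmbeddingBagCollection(tables=tables, device=torch.device("meta"))
sharder = EmbeddingBagCollectionSharder()


def make_plan(placements):  # one rank per table, as in the TABLE_WISE branch
    return construct_module_sharding_plan(
        ebc,
        per_param_sharding={
            table.name: table_wise(rank=rank)
            for table, rank in zip(tables, placements)
        },
        local_size=2,
        world_size=2,
        device_type="cuda",
        sharder=sharder,
    )


old_plan = make_plan([0, 1])
new_plan = make_plan([1, 0])
# Keeps only the parameters whose placement changed between the two plans.
delta = output_sharding_plan_delta(old_plan, new_plan)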
@@ -1317,22 +1327,18 @@ def benchmark_module(
         )
 
     if train:
-        total_plans_per_benchmark = bench_iters // resharding_interval
-        total_plans_per_benchmark = max(1, total_plans_per_benchmark)
+
         new_ranks_per_plan = []
+
         if enable_resharding:
+            total_plans_per_benchmark = bench_iters // resharding_interval
+            total_plans_per_benchmark = max(1, total_plans_per_benchmark)
+
             num_tables = len(tables)
-            new_ranks_count_per_plan = [
-                [] for _ in range(total_plans_per_benchmark)
-            ]
+            ranks_per_tables = []
+
             if sharding_type == ShardingType.TABLE_WISE:
                 ranks_per_tables = [1 for _ in range(num_tables)]
-                new_ranks_per_plan = [
-                    _generate_rank_placements(
-                        world_size, num_tables, ranks_per_tables
-                    )
-                    for _ in range(total_plans_per_benchmark)
-                ]
 
             elif sharding_type == ShardingType.COLUMN_WISE:
                 valid_candidates = [
@@ -1343,11 +1349,12 @@ def benchmark_module(
                 ranks_per_tables = [
                     random.choice(valid_candidates) for _ in range(num_tables)
                 ]
+
             new_ranks_per_plan = [
                 _generate_rank_placements(
                     world_size, num_tables, ranks_per_tables
                 )
-                for ranks_per_tables in (new_ranks_count_per_plan)
+                for _ in range(total_plans_per_benchmark)
            ]
 
     res = multi_process_benchmark(
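Taken together with the previous hunk, the resharding setup reduces to the flow below. The concrete numbers are illustrative assumptions, and _generate_rank_placements is the benchmark's own helper from earlier in the file.

# Illustrative walk-through of the plan-count arithmetic and plan generation.
bench_iters = 20
resharding_interval = 8
world_size = 4
num_tables = 2

# One resharding plan per interval, but always at least one.
total_plans_per_benchmark = max(1, bench_iters // resharding_interval)  # == 2

# TABLE_WISE places each table on exactly one rank; COLUMN_WISE would draw
# each table's rank count from valid_candidates instead.
ranks_per_tables = [1 for _ in range(num_tables)]

new_ranks_per_plan = [
    _generate_rank_placements(world_size, num_tables, ranks_per_tables)
    for _ in range(total_plans_per_benchmark)
]
# e.g. [[[2], [0]], [[3], [1]]]: one rank list per table, and one such
# placement set per resharding step.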