Commit 71e160f

isururanawaka authored and facebook-github-bot committed
Reshard API Performance Benchmarking
Summary:
- Identify baseline performance with and without the reshard API
- Identify different baselines for different sharding strategies under different data sets

Differential Revision: D78672730
1 parent 8b6c525 commit 71e160f

File tree

3 files changed (+72, -46 lines)


torchrec/distributed/benchmark/benchmark_train.py

Lines changed: 21 additions & 2 deletions
@@ -28,8 +28,9 @@
     write_report,
 )
 from torchrec.distributed.embedding_types import EmbeddingComputeKernel, ShardingType
+from torchrec.distributed.sharding.dynamic_sharding import output_sharding_plan_delta
 from torchrec.distributed.test_utils.test_model import TestEBCSharder
-from torchrec.distributed.types import DataType
+from torchrec.distributed.types import DataType, EmbeddingModuleShardingPlan
 from torchrec.modules.embedding_modules import EmbeddingBagCollection
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
 

@@ -53,9 +54,27 @@ def training_func_to_benchmark(
     model: torch.nn.Module,
     bench_inputs: List[KeyedJaggedTensor],
     optimizer: Optional[torch.optim.Optimizer],
+    resharding_plan_diffs: Optional[List[EmbeddingModuleShardingPlan]] = None,
 ) -> None:
 
-    for bench_input in bench_inputs:
+    reshard_idx = 0
+
+    for i, bench_input in enumerate(bench_inputs):
+        if resharding_plan_diffs is not None:
+            if (
+                i > 0
+                and len(resharding_plan_diffs) > 0
+                and i % (len(bench_inputs) / len(resharding_plan_diffs)) == 0
+            ):
+
+                plan_difference = output_sharding_plan_delta(
+                    # Pyre-ignore
+                    model.plan.plan["_module"],
+                    resharding_plan_diffs[reshard_idx],
+                )
+                # Pyre-ignore
+                model.reshard("_module", plan_difference)
+                reshard_idx += 1
         pooled_embeddings = model(bench_input)
         vals = []
         for _name, param in pooled_embeddings.to_dict().items():
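For context, a minimal sketch (not part of this commit) of the reshard cadence the loop above implements: with N benchmark inputs and K plan diffs, a reshard is triggered roughly every N / K iterations, so the last plan diff in the list may never be applied. The helper name below is hypothetical.

# Minimal sketch, assuming N inputs and K > 0 plan diffs; mirrors the condition
# `i > 0 and i % (len(bench_inputs) / len(resharding_plan_diffs)) == 0` above.
from typing import List


def reshard_iterations(num_bench_inputs: int, num_plan_diffs: int) -> List[int]:
    interval = num_bench_inputs / num_plan_diffs  # float division, as in the benchmark loop
    return [i for i in range(1, num_bench_inputs) if i % interval == 0]


# Example: 100 inputs and 4 plan diffs -> reshards fire at iterations 25, 50 and 75,
# so only the first three plan diffs are consumed.
assert reshard_iterations(100, 4) == [25, 50, 75]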

torchrec/distributed/benchmark/benchmark_utils.py

Lines changed: 50 additions & 43 deletions
@@ -55,14 +55,21 @@
     EmbeddingStorageEstimator,
 )
 from torchrec.distributed.shard import _shard_modules
+
+from torchrec.distributed.sharding.dynamic_sharding import output_sharding_plan_delta
 from torchrec.distributed.sharding_plan import (
     construct_module_sharding_plan,
     get_sharding_constructor_from_type,
 )
 from torchrec.distributed.test_utils.multi_process import MultiProcessContext
 from torchrec.distributed.test_utils.test_model import ModelInput
 
-from torchrec.distributed.types import DataType, ModuleSharder, ShardingEnv
+from torchrec.distributed.types import (
+    DataType,
+    EmbeddingModuleShardingPlan,
+    ModuleSharder,
+    ShardingEnv,
+)
 from torchrec.fx import symbolic_trace
 from torchrec.modules.embedding_configs import EmbeddingBagConfig, EmbeddingConfig
 from torchrec.quant.embedding_modules import (
@@ -317,7 +324,7 @@ def _generate_rank_placements(
     world_size: int,
     num_tables: int,
     ranks_per_tables: List[int],
-    random_seed: int = None,
+    random_seed: Optional[int] = None,
 ) -> List[List[int]]:
     # Cannot include old/new rank generation with hypothesis library due to depedency on world_size
     if random_seed is None:
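The body of `_generate_rank_placements` is not shown in this diff; as a rough illustration only, a generator along these lines would pick, for each table, a random set of ranks of the requested size.

# Illustrative sketch only (assumption); the real helper may differ.
import random
from typing import List, Optional


def generate_rank_placements(
    world_size: int,
    num_tables: int,
    ranks_per_tables: List[int],
    random_seed: Optional[int] = None,
) -> List[List[int]]:
    rng = random.Random(random_seed)
    placements = []
    for num_ranks in ranks_per_tables[:num_tables]:
        # Sample `num_ranks` distinct ranks out of [0, world_size) for this table.
        placements.append(sorted(rng.sample(range(world_size), num_ranks)))
    return placements


# Example: two tables, table-wise (one rank each), on a world of 4 ranks.
print(generate_rank_placements(world_size=4, num_tables=2, ranks_per_tables=[1, 1]))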
@@ -1077,58 +1084,61 @@ def init_module_and_run_benchmark(
         if rank != -1
         else contextlib.nullcontext()
     ) as ctx:
-        module = transform_module(
-            module=module,
-            device=device,
-            inputs=warmup_inputs_cuda,
-            sharder=sharder,
-            sharding_type=sharding_type,
-            compile_mode=compile_mode,
-            world_size=world_size,
-            batch_size=batch_size,
-            # pyre-ignore[6]
-            ctx=ctx,
-            benchmark_unsharded_module=benchmark_unsharded_module,
-        )
-
-        if benchmark_unsharded_module:
-            name = "unsharded" + compile_mode.name
-        else:
-            name = benchmark_type_name(compile_mode, sharding_type)
 
         resharding_plans = []
 
-        import fbvscode
-
-        fbvscode.set_trace()
-
         if new_ranks_per_plan is not None and len(new_ranks_per_plan) > 0:
             sharding_type_constructor = get_sharding_constructor_from_type(
                 sharding_type
             )
-            for i, new_ranks in enumerate(new_ranks_per_plan):
+            for new_ranks_per_table in new_ranks_per_plan:
                 new_per_param_sharding = {}
-                for table in tables:
+                for table_id, table in enumerate(tables):
                     if sharding_type == ShardingType.TABLE_WISE:
                         new_per_param_sharding[table.name] = sharding_type_constructor(
-                            rank=new_ranks, compute_kernel=sharder._kernel_type
+                            rank=new_ranks_per_table[table_id][0],
+                            compute_kernel=sharder._kernel_type,
                         )
                     elif sharding_type == ShardingType.COLUMN_WISE:
                         new_per_param_sharding[table.name] = sharding_type_constructor(
-                            ranks=new_ranks
+                            ranks=new_ranks_per_table[table_id]
                         )
 
                 new_module_sharding_plan = construct_module_sharding_plan(
-                    module=module.module,
+                    module=module._module,  # Pyre-ignore
                     # Pyre-ignore
                     sharder=sharder,
                     per_param_sharding=new_per_param_sharding,
                     local_size=world_size,
                     world_size=world_size,
-                    device_type="cuda" if torch.cuda.is_available() else "cpu",
+                    device_type=device.type,
                 )
                 resharding_plans.append(new_module_sharding_plan)
-        benchmark_func_kwargs["resharding_plans"] = resharding_plans
+
+        module = transform_module(
+            module=module,
+            device=device,
+            inputs=warmup_inputs_cuda,
+            sharder=sharder,
+            sharding_type=sharding_type,
+            compile_mode=compile_mode,
+            world_size=world_size,
+            batch_size=batch_size,
+            # pyre-ignore[6]
+            ctx=ctx,
+            benchmark_unsharded_module=benchmark_unsharded_module,
+        )
+
+        if benchmark_unsharded_module:
+            name = "unsharded" + compile_mode.name
+        else:
+            name = benchmark_type_name(compile_mode, sharding_type)
+
+        # plan_difference = [
+        #     output_sharding_plan_delta(module.plan.plan["_module"], reshard_plan)
+        #     for reshard_plan in resharding_plans
+        # ]
+        benchmark_func_kwargs["resharding_plan_diffs"] = resharding_plans
 
         res = benchmark(
             name,
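To make the loop above easier to follow, here is a hypothetical example (values are illustrative, not from the commit) of the nested structure `new_ranks_per_plan` is expected to have: one entry per resharding plan, and within each plan one rank list per table. Table-wise sharding reads the single rank, column-wise sharding passes the whole list.

# Hypothetical values for illustration only.
new_ranks_per_plan = [
    [[0], [2]],  # plan 0: table 0 -> rank 0, table 1 -> rank 2 (table-wise)
    [[1], [3]],  # plan 1: table 0 -> rank 1, table 1 -> rank 3
]

# A column-wise plan would carry several ranks per table, e.g. two shards each:
column_wise_plan = [[0, 1], [2, 3]]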
@@ -1317,22 +1327,18 @@ def benchmark_module(
         )
 
         if train:
-            total_plans_per_benchmark = bench_iters // resharding_interval
-            total_plans_per_benchmark = max(1, total_plans_per_benchmark)
+
             new_ranks_per_plan = []
+
             if enable_resharding:
+                total_plans_per_benchmark = bench_iters // resharding_interval
+                total_plans_per_benchmark = max(1, total_plans_per_benchmark)
+
                 num_tables = len(tables)
-                new_ranks_count_per_plan = [
-                    [] for _ in range(total_plans_per_benchmark)
-                ]
+                ranks_per_tables = []
+
                 if sharding_type == ShardingType.TABLE_WISE:
                     ranks_per_tables = [1 for _ in range(num_tables)]
-                    new_ranks_per_plan = [
-                        _generate_rank_placements(
-                            world_size, num_tables, ranks_per_tables
-                        )
-                        for _ in range(total_plans_per_benchmark)
-                    ]
 
                 elif sharding_type == ShardingType.COLUMN_WISE:
                     valid_candidates = [
@@ -1343,11 +1349,12 @@ def benchmark_module(
                     ranks_per_tables = [
                         random.choice(valid_candidates) for _ in range(num_tables)
                     ]
+
                 new_ranks_per_plan = [
                     _generate_rank_placements(
                         world_size, num_tables, ranks_per_tables
                     )
-                    for ranks_per_tables in (new_ranks_count_per_plan)
+                    for _ in range(total_plans_per_benchmark)
                 ]
 
         res = multi_process_benchmark(
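As a quick check on the plan-count logic above (numbers are illustrative only; `bench_iters` and `resharding_interval` are benchmark arguments):

# Illustrative numbers only.
bench_iters = 100
resharding_interval = 30
total_plans_per_benchmark = max(1, bench_iters // resharding_interval)
print(total_plans_per_benchmark)  # 3 -> three rank placements generated for this benchmark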

torchrec/distributed/model_parallel.py

Lines changed: 1 addition & 1 deletion
@@ -742,7 +742,7 @@ def reshard(
 
         # Need to use .module to maintain FQN consistency
         self._optim: CombinedOptimizer = self._init_optim(
-            self._dmp_wrapped_module.module  # pyre-ignore
+            self._dmp_wrapped_module.module if hasattr(self._dmp_wrapped_module, "module") else self._dmp_wrapped_module._module  # pyre-ignore
         )
         self._plan.plan[sharded_module_fqn] = sharded_module.module_sharding_plan
         return sharded_module
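The change above guards against wrappers that store the inner module under `_module` instead of exposing a public `module` attribute. In isolation the pattern looks like this (a simplified sketch; `unwrap` is a hypothetical helper, not part of the commit):

import torch


def unwrap(wrapped: torch.nn.Module) -> torch.nn.Module:
    # Prefer the public `.module` attribute (e.g. a DDP-style wrapper); otherwise
    # fall back to a private `_module` slot, matching the fallback added above.
    return wrapped.module if hasattr(wrapped, "module") else wrapped._module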
