Commit 5d388f3

Isuru Janith Ranawaka authored and facebook-github-bot committed
Reshard API Performance Benchmarking (#3218)
Summary:
Pull Request resolved: #3218
- Identify baseline performance with and without the reshard API
- Identify different baselines for different sharding strategies under different data sets
Differential Revision: D78672730
1 parent 51dde0f commit 5d388f3
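
In brief, the patch threads an optional list of resharding plans through the benchmark entry points so the same training loop can be timed with and without reshard calls. A minimal, runnable sketch of that comparison shape (the function name, loop body, and sleep-based reshard cost below are illustrative stand-ins, not part of this patch):

    import time

    def run_loop(num_iters: int, reshard_every: int = 0) -> None:
        # Stand-in training loop; a nonzero reshard_every pays a periodic
        # "resharding" cost, mirroring the two baselines named in the summary.
        for i in range(num_iters):
            if reshard_every and i > 0 and i % reshard_every == 0:
                time.sleep(0.001)  # stand-in for a reshard call
            sum(range(1000))       # stand-in for one training step

    for label, interval in (("baseline (no reshard)", 0), ("with reshard", 25)):
        start = time.perf_counter()
        run_loop(100, reshard_every=interval)
        print(f"{label}: {time.perf_counter() - start:.4f}s")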

File tree

3 files changed: +76 -52 lines changed

torchrec/distributed/benchmark/benchmark_train.py
torchrec/distributed/benchmark/benchmark_utils.py
torchrec/distributed/model_parallel.py


torchrec/distributed/benchmark/benchmark_train.py

Lines changed: 21 additions & 2 deletions
@@ -28,8 +28,9 @@
     write_report,
 )
 from torchrec.distributed.embedding_types import EmbeddingComputeKernel, ShardingType
+from torchrec.distributed.sharding.dynamic_sharding import output_sharding_plan_delta
 from torchrec.distributed.test_utils.test_model import TestEBCSharder
-from torchrec.distributed.types import DataType
+from torchrec.distributed.types import DataType, EmbeddingModuleShardingPlan
 from torchrec.modules.embedding_modules import EmbeddingBagCollection
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

@@ -53,9 +54,27 @@ def training_func_to_benchmark(
     model: torch.nn.Module,
     bench_inputs: List[KeyedJaggedTensor],
     optimizer: Optional[torch.optim.Optimizer],
+    resharding_plan_diffs: Optional[List[EmbeddingModuleShardingPlan]] = None,
 ) -> None:

-    for bench_input in bench_inputs:
+    reshard_idx = 0
+
+    for i, bench_input in enumerate(bench_inputs):
+        if resharding_plan_diffs is not None:
+            if (
+                i > 0
+                and len(resharding_plan_diffs) > 0
+                and i % (len(bench_inputs) / len(resharding_plan_diffs)) == 0
+            ):
+
+                plan_difference = output_sharding_plan_delta(
+                    # Pyre-ignore
+                    model.plan.plan["_module"],
+                    resharding_plan_diffs[reshard_idx],
+                )
+                # Pyre-ignore
+                model.reshard("_module", plan_difference)
+                reshard_idx += 1
         pooled_embeddings = model(bench_input)
         vals = []
         for _name, param in pooled_embeddings.to_dict().items():
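
To see when the reshard branch above fires, here is the same trigger condition in isolation (pure-Python stand-ins for the inputs and plan diffs, not torchrec code): with N inputs and K plan diffs it fires every N / K iterations, skipping iteration 0.

    # Stand-alone sketch of the trigger in training_func_to_benchmark above:
    # 8 inputs and 4 plan diffs fire the reshard branch at i = 2, 4, and 6.
    bench_inputs = list(range(8))                     # stand-in for KeyedJaggedTensors
    resharding_plan_diffs = ["p0", "p1", "p2", "p3"]  # stand-in plan diffs

    reshard_idx = 0
    for i, _ in enumerate(bench_inputs):
        if i > 0 and i % (len(bench_inputs) / len(resharding_plan_diffs)) == 0:
            print(f"iter {i}: apply plan diff {resharding_plan_diffs[reshard_idx]}")
            reshard_idx += 1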

torchrec/distributed/benchmark/benchmark_utils.py

Lines changed: 51 additions & 49 deletions
@@ -56,14 +56,21 @@
     EmbeddingStorageEstimator,
 )
 from torchrec.distributed.shard import _shard_modules
+
+from torchrec.distributed.sharding.dynamic_sharding import output_sharding_plan_delta
 from torchrec.distributed.sharding_plan import (
     construct_module_sharding_plan,
     get_sharding_constructor_from_type,
 )
 from torchrec.distributed.test_utils.multi_process import MultiProcessContext
 from torchrec.distributed.test_utils.test_model import ModelInput

-from torchrec.distributed.types import DataType, ModuleSharder, ShardingEnv
+from torchrec.distributed.types import (
+    DataType,
+    EmbeddingModuleShardingPlan,
+    ModuleSharder,
+    ShardingEnv,
+)
 from torchrec.fx import symbolic_trace
 from torchrec.modules.embedding_configs import EmbeddingBagConfig, EmbeddingConfig
 from torchrec.quant.embedding_modules import (

@@ -368,7 +375,7 @@ def _generate_rank_placements(
     world_size: int,
     num_tables: int,
     ranks_per_tables: List[int],
-    random_seed: int = None,
+    random_seed: Optional[int] = None,
 ) -> List[List[int]]:
     # Cannot include old/new rank generation with hypothesis library due to depedency on world_size
     if random_seed is None:
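
The annotation fix above matters because `random_seed: int = None` is rejected by type checkers such as Pyre; `Optional[int]` admits the None default. For intuition, here is a seeded placement generator in the spirit of _generate_rank_placements (assumed behavior, illustration only, not the actual implementation):

    import random
    from typing import List, Optional

    def generate_rank_placements(
        world_size: int,
        num_tables: int,
        ranks_per_tables: List[int],
        random_seed: Optional[int] = None,
    ) -> List[List[int]]:
        # Assumed behavior: pick ranks_per_tables[t] distinct ranks out of
        # world_size for each table, reproducibly when seeded.
        assert len(ranks_per_tables) == num_tables
        rng = random.Random(random_seed)
        return [sorted(rng.sample(range(world_size), k)) for k in ranks_per_tables]

    # Table 0 on a single rank, table 1 spread across two ranks.
    print(generate_rank_placements(world_size=4, num_tables=2, ranks_per_tables=[1, 2], random_seed=0))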
@@ -387,7 +394,9 @@


 def default_func_to_benchmark(
-    model: torch.nn.Module, bench_inputs: List[KeyedJaggedTensor]
+    model: torch.nn.Module,
+    bench_inputs: List[KeyedJaggedTensor],
+    resharding_plan_diffs: Optional[List[EmbeddingModuleShardingPlan]] = None,
 ) -> None:
     with torch.inference_mode():
         for bench_input in bench_inputs:

@@ -855,11 +864,6 @@ def _run_benchmark_core(
     # Timings
     start_events, end_events, times = [], [], []

-<<<<<<< dest: af95f723afd1 - noreply+1244265887488347: [AutoAccept][Codemod...
-=======
-    times = []
-
->>>>>>> source: ea51915e8d98 - isuru: Reshard API Performance Benchmarking
     if device_type == "cuda":
         start_events = [
             torch.cuda.Event(enable_timing=True) for _ in range(num_benchmarks)
@@ -1154,58 +1158,59 @@ def init_module_and_run_benchmark(
         if rank != -1
         else contextlib.nullcontext()
     ) as ctx:
-        module = transform_module(
-            module=module,
-            device=device,
-            inputs=warmup_inputs_cuda,
-            sharder=sharder,
-            sharding_type=sharding_type,
-            compile_mode=compile_mode,
-            world_size=world_size,
-            batch_size=batch_size,
-            # pyre-ignore[6]
-            ctx=ctx,
-            benchmark_unsharded_module=benchmark_unsharded_module,
-        )
-
-        if benchmark_unsharded_module:
-            name = "unsharded" + compile_mode.name
-        else:
-            name = benchmark_type_name(compile_mode, sharding_type)

         resharding_plans = []

-        import fbvscode
-
-        fbvscode.set_trace()
-
         if new_ranks_per_plan is not None and len(new_ranks_per_plan) > 0:
             sharding_type_constructor = get_sharding_constructor_from_type(
                 sharding_type
             )
-            for i, new_ranks in enumerate(new_ranks_per_plan):
+            for new_ranks_per_table in new_ranks_per_plan:
                 new_per_param_sharding = {}
-                for table in tables:
+                for table_id, table in enumerate(tables):
                     if sharding_type == ShardingType.TABLE_WISE:
                         new_per_param_sharding[table.name] = sharding_type_constructor(
-                            rank=new_ranks, compute_kernel=sharder._kernel_type
+                            rank=new_ranks_per_table[table_id][0],
+                            compute_kernel=sharder._kernel_type,
                         )
                     elif sharding_type == ShardingType.COLUMN_WISE:
                         new_per_param_sharding[table.name] = sharding_type_constructor(
-                            ranks=new_ranks
+                            ranks=new_ranks_per_table[table_id]
                        )

                 new_module_sharding_plan = construct_module_sharding_plan(
-                    module=module.module,
+                    module=module._module, # Pyre-ignore
                     # Pyre-ignore
                     sharder=sharder,
                     per_param_sharding=new_per_param_sharding,
                     local_size=world_size,
                     world_size=world_size,
-                    device_type="cuda" if torch.cuda.is_available() else "cpu",
+                    device_type=device.type,
                 )
                 resharding_plans.append(new_module_sharding_plan)
-        benchmark_func_kwargs["resharding_plans"] = resharding_plans
+
+        module = transform_module(
+            module=module,
+            device=device,
+            inputs=warmup_inputs_cuda,
+            sharder=sharder,
+            sharding_type=sharding_type,
+            compile_mode=compile_mode,
+            world_size=world_size,
+            batch_size=batch_size,
+            # pyre-ignore[6]
+            ctx=ctx,
+            benchmark_unsharded_module=benchmark_unsharded_module,
+        )
+
+        if benchmark_unsharded_module:
+            name = "unsharded" + compile_mode.name
+        else:
+            name = benchmark_type_name(compile_mode, sharding_type)
+
+        if benchmark_func_kwargs is None:
+            benchmark_func_kwargs = {}
+        benchmark_func_kwargs["resharding_plan_diffs"] = resharding_plans

         res = benchmark(
             name,
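
In the rewritten loop above, new_ranks_per_plan is indexed as plans × tables × ranks: each plan carries one rank list per table, so the TABLE_WISE branch consumes the single entry new_ranks_per_table[table_id][0] while COLUMN_WISE consumes the whole list. A shape sketch with plain lists (illustrative stand-ins, not torchrec objects):

    # Two plans x two tables; each innermost list is the rank placement for one table.
    new_ranks_per_plan = [
        [[0], [1, 3]],  # plan 0: table_0 on rank 0, table_1 on ranks 1 and 3
        [[2], [0, 2]],  # plan 1: table_0 on rank 2, table_1 on ranks 0 and 2
    ]
    tables = ["table_0", "table_1"]

    for plan_id, new_ranks_per_table in enumerate(new_ranks_per_plan):
        for table_id, name in enumerate(tables):
            table_wise_rank = new_ranks_per_table[table_id][0]  # TABLE_WISE branch
            column_wise_ranks = new_ranks_per_table[table_id]   # COLUMN_WISE branch
            print(plan_id, name, table_wise_rank, column_wise_ranks)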
@@ -1398,22 +1403,18 @@ def benchmark_module(
         )

     if train:
-        total_plans_per_benchmark = bench_iters // resharding_interval
-        total_plans_per_benchmark = max(1, total_plans_per_benchmark)
+
         new_ranks_per_plan = []
+
         if enable_resharding:
+            total_plans_per_benchmark = bench_iters // resharding_interval
+            total_plans_per_benchmark = max(1, total_plans_per_benchmark)
+
             num_tables = len(tables)
-            new_ranks_count_per_plan = [
-                [] for _ in range(total_plans_per_benchmark)
-            ]
+            ranks_per_tables = []
+
             if sharding_type == ShardingType.TABLE_WISE:
                 ranks_per_tables = [1 for _ in range(num_tables)]
-                new_ranks_per_plan = [
-                    _generate_rank_placements(
-                        world_size, num_tables, ranks_per_tables
-                    )
-                    for _ in range(total_plans_per_benchmark)
-                ]

             elif sharding_type == ShardingType.COLUMN_WISE:
                 valid_candidates = [

@@ -1424,11 +1425,12 @@ def benchmark_module(
                 ranks_per_tables = [
                     random.choice(valid_candidates) for _ in range(num_tables)
                 ]
+
             new_ranks_per_plan = [
                 _generate_rank_placements(
                     world_size, num_tables, ranks_per_tables
                 )
-                for ranks_per_tables in (new_ranks_count_per_plan)
+                for _ in range(total_plans_per_benchmark)
             ]

     res = multi_process_benchmark(
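
The relocated plan-count arithmetic above (now computed only when enable_resharding is set) fixes how many rank placements are generated per run. In isolation, with illustrative values:

    # One resharding plan per interval, but always at least one.
    bench_iters = 100
    resharding_interval = 30
    total_plans_per_benchmark = max(1, bench_iters // resharding_interval)
    print(total_plans_per_benchmark)  # 3 -> three rank placements generated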

torchrec/distributed/model_parallel.py

Lines changed: 4 additions & 1 deletion
@@ -742,7 +742,10 @@ def reshard(

         # Need to use .module to maintain FQN consistency
         self._optim: CombinedOptimizer = self._init_optim(
-            self._dmp_wrapped_module.module # pyre-ignore
+            # pyre-ignore
+            self._dmp_wrapped_module.module
+            if hasattr(self._dmp_wrapped_module, "module")
+            else self._dmp_wrapped_module._module
         )
         self._plan.plan[sharded_module_fqn] = sharded_module.module_sharding_plan
         return sharded_module
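
The new conditional covers both attribute spellings under which the DMP wrapper may expose the inner module when re-initializing the optimizer. A minimal sketch of the fallback with stand-in classes (not the torchrec types):

    class WrapsWithModule:
        def __init__(self) -> None:
            self.module = "inner"

    class WrapsWithUnderscoreModule:
        def __init__(self) -> None:
            self._module = "inner"

    def unwrap(wrapped):
        # Mirrors the reshard fix above: prefer .module, fall back to ._module.
        return wrapped.module if hasattr(wrapped, "module") else wrapped._module

    print(unwrap(WrapsWithModule()), unwrap(WrapsWithUnderscoreModule()))  # inner inner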
