Skip to content

Commit 8b6c525

Browse files
Isuru Janith Ranawakafacebook-github-bot
authored andcommitted
Reshard API Performance Benchmarking
Summary: - Identify baseline performance with and without reshard API - Identify different baselines for different sharding strategies under different data sets
1 parent 1b8004a commit 8b6c525

File tree

2 files changed

+108
-0
lines changed

2 files changed

+108
-0
lines changed

torchrec/distributed/benchmark/benchmark_train.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def training_func_to_benchmark(
5454
bench_inputs: List[KeyedJaggedTensor],
5555
optimizer: Optional[torch.optim.Optimizer],
5656
) -> None:
57+
5758
for bench_input in bench_inputs:
5859
pooled_embeddings = model(bench_input)
5960
vals = []
@@ -120,6 +121,7 @@ def benchmark_ebc(
120121

121122

122123
def main() -> None:
124+
# torch.cuda.cudart().cudaProfilerStart()
123125
args: argparse.Namespace = init_argparse_and_args()
124126

125127
num_requests = args.bench_iters * args.batch_size * args.num_benchmarks
@@ -203,6 +205,8 @@ def main() -> None:
203205
for i, write_report_func in enumerate(write_report_funcs_per_module):
204206
write_report_func(benchmark_results_per_module[i])
205207

208+
# torch.cuda.cudart().cudaProfilerStop()
209+
206210

207211
def invoke_main() -> None:
208212
logging.basicConfig()
@@ -212,4 +216,5 @@ def invoke_main() -> None:
212216

213217

214218
if __name__ == "__main__":
219+
215220
invoke_main() # pragma: no cover

torchrec/distributed/benchmark/benchmark_utils.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import json
1919
import logging
2020
import os
21+
import random
2122
import time
2223
import timeit
2324
from dataclasses import dataclass, fields, is_dataclass, MISSING
@@ -54,6 +55,10 @@
5455
EmbeddingStorageEstimator,
5556
)
5657
from torchrec.distributed.shard import _shard_modules
58+
from torchrec.distributed.sharding_plan import (
59+
construct_module_sharding_plan,
60+
get_sharding_constructor_from_type,
61+
)
5762
from torchrec.distributed.test_utils.multi_process import MultiProcessContext
5863
from torchrec.distributed.test_utils.test_model import ModelInput
5964

@@ -308,6 +313,28 @@ def forward(self, input: KeyedJaggedTensor) -> KeyedTensor:
308313
T = TypeVar("T", bound=torch.nn.Module)
309314

310315

316+
def _generate_rank_placements(
317+
world_size: int,
318+
num_tables: int,
319+
ranks_per_tables: List[int],
320+
random_seed: int = None,
321+
) -> List[List[int]]:
322+
# Cannot include old/new rank generation with hypothesis library due to depedency on world_size
323+
if random_seed is None:
324+
# Generate a random seed to ensure that the output rank placements can be different each time
325+
random_seed = random.randint(0, 10000)
326+
placements = []
327+
max_rank = world_size - 1
328+
random.seed(random_seed)
329+
if ranks_per_tables == [0]:
330+
ranks_per_tables = [random.randint(1, max_rank) for _ in range(num_tables)]
331+
for i in range(num_tables):
332+
ranks_per_table = ranks_per_tables[i]
333+
placement = sorted(random.sample(range(world_size), ranks_per_table))
334+
placements.append(placement)
335+
return placements
336+
337+
311338
def default_func_to_benchmark(
312339
model: torch.nn.Module, bench_inputs: List[KeyedJaggedTensor]
313340
) -> None:
@@ -595,6 +622,8 @@ def init_argparse_and_args() -> argparse.Namespace:
595622
parser.add_argument("--num_benchmarks", type=int, default=5)
596623
parser.add_argument("--embedding_config_json", type=str, default="")
597624
parser.add_argument("--device_type", type=str, default="cuda")
625+
parser.add_argument("--enable_resharding", type=bool, default=False)
626+
parser.add_argument("--resharding_interval", type=int, default=1000)
598627

599628
args = parser.parse_args()
600629

@@ -644,6 +673,7 @@ def fx_script_module(eager_module: torch.nn.Module) -> torch.nn.Module:
644673
# Don't want to modify the module outright
645674
# Since module is on cpu, won't cause cuda oom.
646675
copied_module = copy.deepcopy(module)
676+
647677
# pyre-ignore [6]
648678
plan = planner.plan(copied_module, [sharder])
649679

@@ -700,6 +730,7 @@ def benchmark(
700730
enable_logging: bool = True,
701731
device_type: str = "cuda",
702732
benchmark_unsharded_module: bool = False,
733+
new_ranks_per_plan: Optional[List[int]] = None,
703734
) -> BenchmarkResult:
704735
memory_stats: List[MemoryStats] = []
705736
if enable_logging:
@@ -728,6 +759,7 @@ def benchmark(
728759
benchmark_func_kwargs = {}
729760

730761
times = []
762+
731763
if device_type == "cuda":
732764
for i in range(num_benchmarks):
733765
start[i].record()
@@ -998,6 +1030,7 @@ def init_module_and_run_benchmark(
9981030
queue: Optional[mp.Queue] = None,
9991031
pooling_configs: Optional[List[int]] = None,
10001032
benchmark_unsharded_module: bool = False,
1033+
new_ranks_per_plan: Optional[List[List[int]]] = None,
10011034
) -> BenchmarkResult:
10021035
"""
10031036
There are a couple of caveats here as to why the module has to be initialized
@@ -1063,6 +1096,40 @@ def init_module_and_run_benchmark(
10631096
else:
10641097
name = benchmark_type_name(compile_mode, sharding_type)
10651098

1099+
resharding_plans = []
1100+
1101+
import fbvscode
1102+
1103+
fbvscode.set_trace()
1104+
1105+
if new_ranks_per_plan is not None and len(new_ranks_per_plan) > 0:
1106+
sharding_type_constructor = get_sharding_constructor_from_type(
1107+
sharding_type
1108+
)
1109+
for i, new_ranks in enumerate(new_ranks_per_plan):
1110+
new_per_param_sharding = {}
1111+
for table in tables:
1112+
if sharding_type == ShardingType.TABLE_WISE:
1113+
new_per_param_sharding[table.name] = sharding_type_constructor(
1114+
rank=new_ranks, compute_kernel=sharder._kernel_type
1115+
)
1116+
elif sharding_type == ShardingType.COLUMN_WISE:
1117+
new_per_param_sharding[table.name] = sharding_type_constructor(
1118+
ranks=new_ranks
1119+
)
1120+
1121+
new_module_sharding_plan = construct_module_sharding_plan(
1122+
module=module.module,
1123+
# Pyre-ignore
1124+
sharder=sharder,
1125+
per_param_sharding=new_per_param_sharding,
1126+
local_size=world_size,
1127+
world_size=world_size,
1128+
device_type="cuda" if torch.cuda.is_available() else "cpu",
1129+
)
1130+
resharding_plans.append(new_module_sharding_plan)
1131+
benchmark_func_kwargs["resharding_plans"] = resharding_plans
1132+
10661133
res = benchmark(
10671134
name,
10681135
module,
@@ -1167,6 +1234,8 @@ def benchmark_module(
11671234
pooling_configs: Optional[List[int]] = None,
11681235
variable_batch_embeddings: bool = False,
11691236
device_type: str = "cuda",
1237+
enable_resharding: bool = False,
1238+
resharding_interval: int = 1000,
11701239
) -> List[BenchmarkResult]:
11711240
"""
11721241
Args:
@@ -1248,6 +1317,39 @@ def benchmark_module(
12481317
)
12491318

12501319
if train:
1320+
total_plans_per_benchmark = bench_iters // resharding_interval
1321+
total_plans_per_benchmark = max(1, total_plans_per_benchmark)
1322+
new_ranks_per_plan = []
1323+
if enable_resharding:
1324+
num_tables = len(tables)
1325+
new_ranks_count_per_plan = [
1326+
[] for _ in range(total_plans_per_benchmark)
1327+
]
1328+
if sharding_type == ShardingType.TABLE_WISE:
1329+
ranks_per_tables = [1 for _ in range(num_tables)]
1330+
new_ranks_per_plan = [
1331+
_generate_rank_placements(
1332+
world_size, num_tables, ranks_per_tables
1333+
)
1334+
for _ in range(total_plans_per_benchmark)
1335+
]
1336+
1337+
elif sharding_type == ShardingType.COLUMN_WISE:
1338+
valid_candidates = [
1339+
i
1340+
for i in range(1, world_size + 1)
1341+
if EMBEDDING_DIM % i == 0
1342+
]
1343+
ranks_per_tables = [
1344+
random.choice(valid_candidates) for _ in range(num_tables)
1345+
]
1346+
new_ranks_per_plan = [
1347+
_generate_rank_placements(
1348+
world_size, num_tables, ranks_per_tables
1349+
)
1350+
for ranks_per_tables in (new_ranks_count_per_plan)
1351+
]
1352+
12511353
res = multi_process_benchmark(
12521354
# pyre-ignore[6]
12531355
callable=init_module_and_run_benchmark,
@@ -1267,6 +1369,7 @@ def benchmark_module(
12671369
func_to_benchmark=func_to_benchmark,
12681370
benchmark_func_kwargs=benchmark_func_kwargs,
12691371
pooling_configs=pooling_configs,
1372+
new_ranks_per_plan=new_ranks_per_plan,
12701373
)
12711374
else:
12721375
res = init_module_and_run_benchmark(

0 commit comments

Comments
 (0)