
Commit a98e18f

Isuru Janith Ranawaka authored and facebook-github-bot committed
Reshard API Performance Benchmarking
Summary:

- Identify baseline performance with and without the reshard API
- Establish separate baselines for different sharding strategies across different data sets
1 parent 143d088 commit a98e18f

File tree

torchrec/distributed/benchmark/benchmark_train.py
torchrec/distributed/benchmark/benchmark_utils.py

2 files changed: +100 −0 lines changed

torchrec/distributed/benchmark/benchmark_train.py

Lines changed: 5 additions & 0 deletions
@@ -54,6 +54,7 @@ def training_func_to_benchmark(
     bench_inputs: List[KeyedJaggedTensor],
     optimizer: Optional[torch.optim.Optimizer],
 ) -> None:
+
     for bench_input in bench_inputs:
         pooled_embeddings = model(bench_input)
         vals = []
@@ -120,6 +121,7 @@ def benchmark_ebc(
 
 
 def main() -> None:
+    # torch.cuda.cudart().cudaProfilerStart()
     args: argparse.Namespace = init_argparse_and_args()
 
     num_requests = args.bench_iters * args.batch_size * args.num_benchmarks
@@ -203,6 +205,8 @@ def main() -> None:
     for i, write_report_func in enumerate(write_report_funcs_per_module):
         write_report_func(benchmark_results_per_module[i])
 
+    # torch.cuda.cudart().cudaProfilerStop()
+
 
 def invoke_main() -> None:
     logging.basicConfig()
@@ -212,4 +216,5 @@ def invoke_main() -> None:
 
 
 if __name__ == "__main__":
+
     invoke_main()  # pragma: no cover
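The commented-out cudart hooks bracket the benchmark region for an external CUDA profiler. A minimal sketch of how they would be used once uncommented; the profiler command is illustrative, not part of this commit:

    import torch

    # With the hooks active, a range-gated profiler such as
    #   nsys profile --capture-range=cudaProfilerApi ...
    # records only the code between start and stop.
    torch.cuda.cudart().cudaProfilerStart()
    # ... benchmark iterations run here ...
    torch.cuda.cudart().cudaProfilerStop()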

torchrec/distributed/benchmark/benchmark_utils.py

Lines changed: 95 additions & 0 deletions
@@ -18,6 +18,7 @@
 import json
 import logging
 import os
+import random
 import resource
 import time
 import timeit
@@ -55,6 +56,10 @@
     EmbeddingStorageEstimator,
 )
 from torchrec.distributed.shard import _shard_modules
+from torchrec.distributed.sharding_plan import (
+    construct_module_sharding_plan,
+    get_sharding_constructor_from_type,
+)
 from torchrec.distributed.test_utils.multi_process import MultiProcessContext
 from torchrec.distributed.test_utils.test_model import ModelInput
 
@@ -359,6 +364,28 @@ def forward(self, input: KeyedJaggedTensor) -> KeyedTensor:
 T = TypeVar("T", bound=torch.nn.Module)
 
 
+def _generate_rank_placements(
+    world_size: int,
+    num_tables: int,
+    ranks_per_tables: List[int],
+    random_seed: Optional[int] = None,
+) -> List[List[int]]:
+    # Cannot drive old/new rank generation with the hypothesis library due to the dependency on world_size.
+    if random_seed is None:
+        # Generate a random seed so that rank placements can differ across runs.
+        random_seed = random.randint(0, 10000)
+    placements = []
+    max_rank = world_size - 1
+    random.seed(random_seed)
+    if ranks_per_tables == [0]:
+        ranks_per_tables = [random.randint(1, max_rank) for _ in range(num_tables)]
+    for i in range(num_tables):
+        ranks_per_table = ranks_per_tables[i]
+        placement = sorted(random.sample(range(world_size), ranks_per_table))
+        placements.append(placement)
+    return placements
+
+
 def default_func_to_benchmark(
     model: torch.nn.Module, bench_inputs: List[KeyedJaggedTensor]
 ) -> None:
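For intuition, a minimal usage sketch of the helper above; the world size, table counts, and printed placements are illustrative, and the import path assumes this commit is applied:

    from torchrec.distributed.benchmark.benchmark_utils import _generate_rank_placements

    # Two tables in a world of 4 ranks, one rank per table (table-wise style).
    print(_generate_rank_placements(4, 2, [1, 1]))  # e.g. [[2], [0]]
    # Two tables spread over several ranks each (column-wise style).
    print(_generate_rank_placements(4, 2, [2, 4]))  # e.g. [[1, 3], [0, 1, 2, 3]]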
@@ -679,6 +706,8 @@ def init_argparse_and_args() -> argparse.Namespace:
     parser.add_argument("--num_benchmarks", type=int, default=5)
     parser.add_argument("--embedding_config_json", type=str, default="")
     parser.add_argument("--device_type", type=str, default="cuda")
+    parser.add_argument("--enable_resharding", action="store_true", default=False)
+    parser.add_argument("--resharding_interval", type=int, default=1000)
 
     args = parser.parse_args()
 
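This hunk adds the two flags but does not show them being threaded into the benchmark. A plausible wiring into benchmark_module, which gains the matching keyword arguments below; the elided arguments are whatever the caller already passes, and this wiring is an assumption, not part of the commit:

    args = init_argparse_and_args()
    results = benchmark_module(
        # ... existing arguments as before ...
        enable_resharding=args.enable_resharding,
        resharding_interval=args.resharding_interval,
    )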
@@ -728,6 +757,7 @@ def fx_script_module(eager_module: torch.nn.Module) -> torch.nn.Module:
     # Don't want to modify the module outright
     # Since module is on cpu, won't cause cuda oom.
     copied_module = copy.deepcopy(module)
+
     # pyre-ignore [6]
     plan = planner.plan(copied_module, [sharder])
 
@@ -780,6 +810,7 @@ def _run_benchmark_core(
     pre_gpu_load: int = 0,
     export_stacks: bool = False,
     reset_accumulated_memory_stats: bool = False,
+    new_ranks_per_plan: Optional[List[List[int]]] = None,
 ) -> BenchmarkResult:
     """Internal helper that contains the core benchmarking logic shared by
     ``benchmark`` and ``benchmark_func``. All heavy-lifting (timing, memory
@@ -1071,6 +1102,7 @@ def init_module_and_run_benchmark(
     queue: Optional[mp.Queue] = None,
     pooling_configs: Optional[List[int]] = None,
     benchmark_unsharded_module: bool = False,
+    new_ranks_per_plan: Optional[List[List[int]]] = None,
 ) -> BenchmarkResult:
     """
     There are a couple of caveats here as to why the module has to be initialized
@@ -1136,6 +1168,36 @@ def init_module_and_run_benchmark(
     else:
         name = benchmark_type_name(compile_mode, sharding_type)
 
+    resharding_plans = []
+
+    if new_ranks_per_plan is not None and len(new_ranks_per_plan) > 0:
+        sharding_type_constructor = get_sharding_constructor_from_type(
+            sharding_type
+        )
+        for new_ranks in new_ranks_per_plan:
+            new_per_param_sharding = {}
+            for i, table in enumerate(tables):
+                if sharding_type == ShardingType.TABLE_WISE:
+                    new_per_param_sharding[table.name] = sharding_type_constructor(
+                        rank=new_ranks[i][0], compute_kernel=sharder._kernel_type
+                    )
+                elif sharding_type == ShardingType.COLUMN_WISE:
+                    new_per_param_sharding[table.name] = sharding_type_constructor(
+                        ranks=new_ranks[i]
+                    )
+
+            new_module_sharding_plan = construct_module_sharding_plan(
+                module=module.module,
+                # pyre-ignore
+                sharder=sharder,
+                per_param_sharding=new_per_param_sharding,
+                local_size=world_size,
+                world_size=world_size,
+                device_type="cuda" if torch.cuda.is_available() else "cpu",
+            )
+            resharding_plans.append(new_module_sharding_plan)
+        benchmark_func_kwargs["resharding_plans"] = resharding_plans
+
     res = benchmark(
         name,
         module,
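Each pass through the loop above builds a per_param_sharding mapping and resolves it into a module sharding plan. A hedged illustration for two hypothetical tables resharded column-wise; the table names and rank sets are invented for the example:

    from torchrec.distributed.sharding_plan import column_wise

    # One entry per embedding table; construct_module_sharding_plan(...) then
    # resolves the mapping into an EmbeddingModuleShardingPlan, as done above.
    per_param_sharding = {
        "table_0": column_wise(ranks=[0, 2]),
        "table_1": column_wise(ranks=[1, 3]),
    }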
@@ -1244,6 +1306,8 @@ def benchmark_module(
     pooling_configs: Optional[List[int]] = None,
     variable_batch_embeddings: bool = False,
     device_type: str = "cuda",
+    enable_resharding: bool = False,
+    resharding_interval: int = 1000,
 ) -> List[BenchmarkResult]:
     """
     Args:
@@ -1325,6 +1389,36 @@ def benchmark_module(
         )
 
     if train:
+        total_plans_per_benchmark = bench_iters // resharding_interval
+        total_plans_per_benchmark = max(1, total_plans_per_benchmark)
+        new_ranks_per_plan = []
+        if enable_resharding:
+            num_tables = len(tables)
+            if sharding_type == ShardingType.TABLE_WISE:
+                ranks_per_tables = [1 for _ in range(num_tables)]
+                new_ranks_per_plan = [
+                    _generate_rank_placements(
+                        world_size, num_tables, ranks_per_tables
+                    )
+                    for _ in range(total_plans_per_benchmark)
+                ]
+
+            elif sharding_type == ShardingType.COLUMN_WISE:
+                valid_candidates = [
+                    i
+                    for i in range(1, world_size + 1)
+                    if EMBEDDING_DIM % i == 0
+                ]
+                ranks_per_tables = [
+                    random.choice(valid_candidates) for _ in range(num_tables)
+                ]
+                new_ranks_per_plan = [
+                    _generate_rank_placements(
+                        world_size, num_tables, ranks_per_tables
+                    )
+                    for _ in range(total_plans_per_benchmark)
+                ]
+
         res = multi_process_benchmark(
             # pyre-ignore[6]
             callable=init_module_and_run_benchmark,
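The plan count is simply the benchmark length divided by the resharding interval, floored at one. A quick worked example with illustrative numbers:

    # bench_iters=5000, resharding_interval=1000 -> 5 resharding plans
    # bench_iters=500,  resharding_interval=1000 -> max(1, 0) = 1 plan (never zero)
    total_plans_per_benchmark = max(1, 5000 // 1000)  # == 5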
@@ -1344,6 +1438,7 @@ def benchmark_module(
             func_to_benchmark=func_to_benchmark,
             benchmark_func_kwargs=benchmark_func_kwargs,
             pooling_configs=pooling_configs,
+            new_ranks_per_plan=new_ranks_per_plan,
         )
     else:
         res = init_module_and_run_benchmark(
