From a98e18fee2f20ba67c8c019059f207de163d8dfb Mon Sep 17 00:00:00 2001
From: Isuru Janith Ranawaka
Date: Thu, 31 Jul 2025 11:01:58 -0700
Subject: [PATCH 1/2] Reshard API Performance Benchmarking

Summary:
- Identify baseline performance with and without reshard API
- Identify different baselines for different sharding strategies under different data sets
---
 .../distributed/benchmark/benchmark_train.py |   5 +
 .../distributed/benchmark/benchmark_utils.py | 107 ++++++++++++++++++
 2 files changed, 112 insertions(+)

diff --git a/torchrec/distributed/benchmark/benchmark_train.py b/torchrec/distributed/benchmark/benchmark_train.py
index 15ea780f2..55cdff9b6 100644
--- a/torchrec/distributed/benchmark/benchmark_train.py
+++ b/torchrec/distributed/benchmark/benchmark_train.py
@@ -54,6 +54,7 @@ def training_func_to_benchmark(
     bench_inputs: List[KeyedJaggedTensor],
     optimizer: Optional[torch.optim.Optimizer],
 ) -> None:
+
     for bench_input in bench_inputs:
         pooled_embeddings = model(bench_input)
         vals = []
@@ -120,6 +121,7 @@ def benchmark_ebc(


 def main() -> None:
+    # torch.cuda.cudart().cudaProfilerStart()
     args: argparse.Namespace = init_argparse_and_args()

     num_requests = args.bench_iters * args.batch_size * args.num_benchmarks
@@ -203,6 +205,8 @@ def main() -> None:
     for i, write_report_func in enumerate(write_report_funcs_per_module):
         write_report_func(benchmark_results_per_module[i])

+    # torch.cuda.cudart().cudaProfilerStop()
+

 def invoke_main() -> None:
     logging.basicConfig()
@@ -212,4 +216,5 @@ def invoke_main() -> None:


 if __name__ == "__main__":
+
     invoke_main()  # pragma: no cover
diff --git a/torchrec/distributed/benchmark/benchmark_utils.py b/torchrec/distributed/benchmark/benchmark_utils.py
index a9f7a3864..2bb2fa422 100644
--- a/torchrec/distributed/benchmark/benchmark_utils.py
+++ b/torchrec/distributed/benchmark/benchmark_utils.py
@@ -18,6 +18,7 @@
 import json
 import logging
 import os
+import random
 import resource
 import time
 import timeit
@@ -55,6 +56,10 @@
     EmbeddingStorageEstimator,
 )
 from torchrec.distributed.shard import _shard_modules
+from torchrec.distributed.sharding_plan import (
+    construct_module_sharding_plan,
+    get_sharding_constructor_from_type,
+)
 from torchrec.distributed.test_utils.multi_process import MultiProcessContext
 from torchrec.distributed.test_utils.test_model import ModelInput

@@ -359,6 +364,28 @@ def forward(self, input: KeyedJaggedTensor) -> KeyedTensor:

 T = TypeVar("T", bound=torch.nn.Module)


+def _generate_rank_placements(
+    world_size: int,
+    num_tables: int,
+    ranks_per_tables: List[int],
+    random_seed: int = None,
+) -> List[List[int]]:
+    # Cannot include old/new rank generation with hypothesis library due to depedency on world_size
+    if random_seed is None:
+        # Generate a random seed to ensure that the output rank placements can be different each time
+        random_seed = random.randint(0, 10000)
+    placements = []
+    max_rank = world_size - 1
+    random.seed(random_seed)
+    if ranks_per_tables == [0]:
+        ranks_per_tables = [random.randint(1, max_rank) for _ in range(num_tables)]
+    for i in range(num_tables):
+        ranks_per_table = ranks_per_tables[i]
+        placement = sorted(random.sample(range(world_size), ranks_per_table))
+        placements.append(placement)
+    return placements
+
+
 def default_func_to_benchmark(
     model: torch.nn.Module, bench_inputs: List[KeyedJaggedTensor]
 ) -> None:
@@ -679,6 +706,8 @@ def init_argparse_and_args() -> argparse.Namespace:
     parser.add_argument("--num_benchmarks", type=int, default=5)
     parser.add_argument("--embedding_config_json", type=str, default="")
default="") parser.add_argument("--device_type", type=str, default="cuda") + parser.add_argument("--enable_resharding", type=bool, default=False) + parser.add_argument("--resharding_interval", type=int, default=1000) args = parser.parse_args() @@ -728,6 +757,7 @@ def fx_script_module(eager_module: torch.nn.Module) -> torch.nn.Module: # Don't want to modify the module outright # Since module is on cpu, won't cause cuda oom. copied_module = copy.deepcopy(module) + # pyre-ignore [6] plan = planner.plan(copied_module, [sharder]) @@ -780,6 +810,7 @@ def _run_benchmark_core( pre_gpu_load: int = 0, export_stacks: bool = False, reset_accumulated_memory_stats: bool = False, + new_ranks_per_plan: Optional[List[int]] = None, ) -> BenchmarkResult: """Internal helper that contains the core benchmarking logic shared by ``benchmark`` and ``benchmark_func``. All heavy–lifting (timing, memory @@ -824,6 +855,11 @@ def _run_benchmark_core( # Timings start_events, end_events, times = [], [], [] +<<<<<<< dest: af95f723afd1 - noreply+1244265887488347: [AutoAccept][Codemod... +======= + times = [] + +>>>>>>> source: ea51915e8d98 - isuru: Reshard API Performance Benchmarking if device_type == "cuda": start_events = [ torch.cuda.Event(enable_timing=True) for _ in range(num_benchmarks) @@ -1071,6 +1107,7 @@ def init_module_and_run_benchmark( queue: Optional[mp.Queue] = None, pooling_configs: Optional[List[int]] = None, benchmark_unsharded_module: bool = False, + new_ranks_per_plan: Optional[List[List[int]]] = None, ) -> BenchmarkResult: """ There are a couple of caveats here as to why the module has to be initialized @@ -1136,6 +1173,40 @@ def init_module_and_run_benchmark( else: name = benchmark_type_name(compile_mode, sharding_type) + resharding_plans = [] + + import fbvscode + + fbvscode.set_trace() + + if new_ranks_per_plan is not None and len(new_ranks_per_plan) > 0: + sharding_type_constructor = get_sharding_constructor_from_type( + sharding_type + ) + for i, new_ranks in enumerate(new_ranks_per_plan): + new_per_param_sharding = {} + for table in tables: + if sharding_type == ShardingType.TABLE_WISE: + new_per_param_sharding[table.name] = sharding_type_constructor( + rank=new_ranks, compute_kernel=sharder._kernel_type + ) + elif sharding_type == ShardingType.COLUMN_WISE: + new_per_param_sharding[table.name] = sharding_type_constructor( + ranks=new_ranks + ) + + new_module_sharding_plan = construct_module_sharding_plan( + module=module.module, + # Pyre-ignore + sharder=sharder, + per_param_sharding=new_per_param_sharding, + local_size=world_size, + world_size=world_size, + device_type="cuda" if torch.cuda.is_available() else "cpu", + ) + resharding_plans.append(new_module_sharding_plan) + benchmark_func_kwargs["resharding_plans"] = resharding_plans + res = benchmark( name, module, @@ -1244,6 +1315,8 @@ def benchmark_module( pooling_configs: Optional[List[int]] = None, variable_batch_embeddings: bool = False, device_type: str = "cuda", + enable_resharding: bool = False, + resharding_interval: int = 1000, ) -> List[BenchmarkResult]: """ Args: @@ -1325,6 +1398,39 @@ def benchmark_module( ) if train: + total_plans_per_benchmark = bench_iters // resharding_interval + total_plans_per_benchmark = max(1, total_plans_per_benchmark) + new_ranks_per_plan = [] + if enable_resharding: + num_tables = len(tables) + new_ranks_count_per_plan = [ + [] for _ in range(total_plans_per_benchmark) + ] + if sharding_type == ShardingType.TABLE_WISE: + ranks_per_tables = [1 for _ in range(num_tables)] + new_ranks_per_plan = [ + 
+                            _generate_rank_placements(
+                                world_size, num_tables, ranks_per_tables
+                            )
+                            for _ in range(total_plans_per_benchmark)
+                        ]
+
+                    elif sharding_type == ShardingType.COLUMN_WISE:
+                        valid_candidates = [
+                            i
+                            for i in range(1, world_size + 1)
+                            if EMBEDDING_DIM % i == 0
+                        ]
+                        ranks_per_tables = [
+                            random.choice(valid_candidates) for _ in range(num_tables)
+                        ]
+                        new_ranks_per_plan = [
+                            _generate_rank_placements(
+                                world_size, num_tables, ranks_per_tables
+                            )
+                            for ranks_per_tables in (new_ranks_count_per_plan)
+                        ]
+
                 res = multi_process_benchmark(
                     # pyre-ignore[6]
                     callable=init_module_and_run_benchmark,
@@ -1344,6 +1450,7 @@ def benchmark_module(
                     func_to_benchmark=func_to_benchmark,
                     benchmark_func_kwargs=benchmark_func_kwargs,
                     pooling_configs=pooling_configs,
+                    new_ranks_per_plan=new_ranks_per_plan,
                 )
             else:
                 res = init_module_and_run_benchmark(

From d633bd5eb40776a2f98348b1762d4af4fc955e20 Mon Sep 17 00:00:00 2001
From: Isuru Janith Ranawaka
Date: Sun, 3 Aug 2025 09:15:10 -0700
Subject: [PATCH 2/2] Reshard API Performance Benchmarking (#3218)

Summary:
Pull Request resolved: https://github.com/pytorch/torchrec/pull/3218

- Identify baseline performance with and without reshard API
- Identify different baselines for different sharding strategies under different data sets

Differential Revision: D78672730
---
 .../distributed/benchmark/benchmark_train.py |  23 +++-
 .../distributed/benchmark/benchmark_utils.py | 100 +++++++++---------
 torchrec/distributed/model_parallel.py       |   5 +-
 3 files changed, 76 insertions(+), 52 deletions(-)

diff --git a/torchrec/distributed/benchmark/benchmark_train.py b/torchrec/distributed/benchmark/benchmark_train.py
index 55cdff9b6..66d2c0042 100644
--- a/torchrec/distributed/benchmark/benchmark_train.py
+++ b/torchrec/distributed/benchmark/benchmark_train.py
@@ -28,8 +28,9 @@
     write_report,
 )
 from torchrec.distributed.embedding_types import EmbeddingComputeKernel, ShardingType
+from torchrec.distributed.sharding.dynamic_sharding import output_sharding_plan_delta
 from torchrec.distributed.test_utils.test_model import TestEBCSharder
-from torchrec.distributed.types import DataType
+from torchrec.distributed.types import DataType, EmbeddingModuleShardingPlan
 from torchrec.modules.embedding_modules import EmbeddingBagCollection
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

@@ -53,9 +54,27 @@ def training_func_to_benchmark(
     model: torch.nn.Module,
     bench_inputs: List[KeyedJaggedTensor],
     optimizer: Optional[torch.optim.Optimizer],
+    resharding_plan_diffs: Optional[List[EmbeddingModuleShardingPlan]] = None,
 ) -> None:

-    for bench_input in bench_inputs:
+    reshard_idx = 0
+
+    for i, bench_input in enumerate(bench_inputs):
+        if resharding_plan_diffs is not None:
+            if (
+                i > 0
+                and len(resharding_plan_diffs) > 0
+                and i % (len(bench_inputs) / len(resharding_plan_diffs)) == 0
+            ):
+
+                plan_difference = output_sharding_plan_delta(
+                    # Pyre-ignore
+                    model.plan.plan["_module"],
+                    resharding_plan_diffs[reshard_idx],
+                )
+                # Pyre-ignore
+                model.reshard("_module", plan_difference)
+                reshard_idx += 1
         pooled_embeddings = model(bench_input)
         vals = []
         for _name, param in pooled_embeddings.to_dict().items():
diff --git a/torchrec/distributed/benchmark/benchmark_utils.py b/torchrec/distributed/benchmark/benchmark_utils.py
index 2bb2fa422..e5e8b124c 100644
--- a/torchrec/distributed/benchmark/benchmark_utils.py
+++ b/torchrec/distributed/benchmark/benchmark_utils.py
@@ -56,6 +56,8 @@
     EmbeddingStorageEstimator,
 )
 from torchrec.distributed.shard import _shard_modules
+
+from torchrec.distributed.sharding.dynamic_sharding import output_sharding_plan_delta
 from torchrec.distributed.sharding_plan import (
     construct_module_sharding_plan,
     get_sharding_constructor_from_type,
@@ -63,7 +65,12 @@
 from torchrec.distributed.test_utils.multi_process import MultiProcessContext
 from torchrec.distributed.test_utils.test_model import ModelInput

-from torchrec.distributed.types import DataType, ModuleSharder, ShardingEnv
+from torchrec.distributed.types import (
+    DataType,
+    EmbeddingModuleShardingPlan,
+    ModuleSharder,
+    ShardingEnv,
+)
 from torchrec.fx import symbolic_trace
 from torchrec.modules.embedding_configs import EmbeddingBagConfig, EmbeddingConfig
 from torchrec.quant.embedding_modules import (
@@ -368,7 +375,7 @@ def _generate_rank_placements(
     world_size: int,
     num_tables: int,
     ranks_per_tables: List[int],
-    random_seed: int = None,
+    random_seed: Optional[int] = None,
 ) -> List[List[int]]:
     # Cannot include old/new rank generation with hypothesis library due to depedency on world_size
     if random_seed is None:
@@ -387,7 +394,9 @@ def _generate_rank_placements(


 def default_func_to_benchmark(
-    model: torch.nn.Module, bench_inputs: List[KeyedJaggedTensor]
+    model: torch.nn.Module,
+    bench_inputs: List[KeyedJaggedTensor],
+    resharding_plan_diffs: Optional[List[EmbeddingModuleShardingPlan]] = None,
 ) -> None:
     with torch.inference_mode():
         for bench_input in bench_inputs:
@@ -855,11 +864,6 @@ def _run_benchmark_core(

     # Timings
     start_events, end_events, times = [], [], []
-<<<<<<< dest: af95f723afd1 - noreply+1244265887488347: [AutoAccept][Codemod...
-=======
-    times = []
-
->>>>>>> source: ea51915e8d98 - isuru: Reshard API Performance Benchmarking
     if device_type == "cuda":
         start_events = [
             torch.cuda.Event(enable_timing=True) for _ in range(num_benchmarks)
@@ -1154,58 +1158,59 @@ def init_module_and_run_benchmark(
         if rank != -1
         else contextlib.nullcontext()
     ) as ctx:
-        module = transform_module(
-            module=module,
-            device=device,
-            inputs=warmup_inputs_cuda,
-            sharder=sharder,
-            sharding_type=sharding_type,
-            compile_mode=compile_mode,
-            world_size=world_size,
-            batch_size=batch_size,
-            # pyre-ignore[6]
-            ctx=ctx,
-            benchmark_unsharded_module=benchmark_unsharded_module,
-        )
-
-        if benchmark_unsharded_module:
-            name = "unsharded" + compile_mode.name
-        else:
-            name = benchmark_type_name(compile_mode, sharding_type)

         resharding_plans = []

-        import fbvscode
-
-        fbvscode.set_trace()
-
         if new_ranks_per_plan is not None and len(new_ranks_per_plan) > 0:
             sharding_type_constructor = get_sharding_constructor_from_type(
                 sharding_type
             )
-            for i, new_ranks in enumerate(new_ranks_per_plan):
+            for new_ranks_per_table in new_ranks_per_plan:
                 new_per_param_sharding = {}
-                for table in tables:
+                for table_id, table in enumerate(tables):
                     if sharding_type == ShardingType.TABLE_WISE:
                         new_per_param_sharding[table.name] = sharding_type_constructor(
-                            rank=new_ranks, compute_kernel=sharder._kernel_type
+                            rank=new_ranks_per_table[table_id][0],
+                            compute_kernel=sharder._kernel_type,
                         )
                     elif sharding_type == ShardingType.COLUMN_WISE:
                         new_per_param_sharding[table.name] = sharding_type_constructor(
-                            ranks=new_ranks
+                            ranks=new_ranks_per_table[table_id]
                         )

                 new_module_sharding_plan = construct_module_sharding_plan(
-                    module=module.module,
+                    module=module._module,  # Pyre-ignore
                     # Pyre-ignore
                     sharder=sharder,
                     per_param_sharding=new_per_param_sharding,
                     local_size=world_size,
                     world_size=world_size,
-                    device_type="cuda" if torch.cuda.is_available() else "cpu",
+                    device_type=device.type,
                 )
                 resharding_plans.append(new_module_sharding_plan)
-            benchmark_func_kwargs["resharding_plans"] = resharding_plans
+
+        module = transform_module(
+            module=module,
+            device=device,
+            inputs=warmup_inputs_cuda,
+            sharder=sharder,
+            sharding_type=sharding_type,
+            compile_mode=compile_mode,
+            world_size=world_size,
+            batch_size=batch_size,
+            # pyre-ignore[6]
+            ctx=ctx,
+            benchmark_unsharded_module=benchmark_unsharded_module,
+        )
+
+        if benchmark_unsharded_module:
+            name = "unsharded" + compile_mode.name
+        else:
+            name = benchmark_type_name(compile_mode, sharding_type)
+
+        if benchmark_func_kwargs is None:
+            benchmark_func_kwargs = {}
+        benchmark_func_kwargs["resharding_plan_diffs"] = resharding_plans

         res = benchmark(
             name,
@@ -1398,22 +1403,18 @@ def benchmark_module(
             )

             if train:
-                total_plans_per_benchmark = bench_iters // resharding_interval
-                total_plans_per_benchmark = max(1, total_plans_per_benchmark)
+
                 new_ranks_per_plan = []
+
                 if enable_resharding:
+                    total_plans_per_benchmark = bench_iters // resharding_interval
+                    total_plans_per_benchmark = max(1, total_plans_per_benchmark)
+
                     num_tables = len(tables)
-                    new_ranks_count_per_plan = [
-                        [] for _ in range(total_plans_per_benchmark)
-                    ]
+                    ranks_per_tables = []
+
                     if sharding_type == ShardingType.TABLE_WISE:
                         ranks_per_tables = [1 for _ in range(num_tables)]
-                        new_ranks_per_plan = [
-                            _generate_rank_placements(
-                                world_size, num_tables, ranks_per_tables
-                            )
-                            for _ in range(total_plans_per_benchmark)
-                        ]

                     elif sharding_type == ShardingType.COLUMN_WISE:
                         valid_candidates = [
@@ -1424,11 +1425,12 @@ def benchmark_module(
                         ranks_per_tables = [
                             random.choice(valid_candidates) for _ in range(num_tables)
                         ]
+
                         new_ranks_per_plan = [
                             _generate_rank_placements(
                                 world_size, num_tables, ranks_per_tables
                             )
-                            for ranks_per_tables in (new_ranks_count_per_plan)
+                            for _ in range(total_plans_per_benchmark)
                         ]

                 res = multi_process_benchmark(
diff --git a/torchrec/distributed/model_parallel.py b/torchrec/distributed/model_parallel.py
index d09f30781..936d0d2fc 100644
--- a/torchrec/distributed/model_parallel.py
+++ b/torchrec/distributed/model_parallel.py
@@ -742,7 +742,10 @@ def reshard(

         # Need to use .module to maintain FQN consistency
         self._optim: CombinedOptimizer = self._init_optim(
-            self._dmp_wrapped_module.module  # pyre-ignore
+            # pyre-ignore
+            self._dmp_wrapped_module.module
+            if hasattr(self._dmp_wrapped_module, "module")
+            else self._dmp_wrapped_module._module
         )
         self._plan.plan[sharded_module_fqn] = sharded_module.module_sharding_plan
         return sharded_module
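
For reference, the rank-placement generation that feeds the resharding plans above can be exercised on its own. The sketch below mirrors the _generate_rank_placements helper and the table-wise / column-wise candidate selection added to benchmark_module; it is a standalone illustration, not part of the applied diff, and the world size, table count, and embedding dimension are illustrative assumptions rather than values taken from the patch.

import random
from typing import List, Optional


def generate_rank_placements(
    world_size: int,
    num_tables: int,
    ranks_per_tables: List[int],
    random_seed: Optional[int] = None,
) -> List[List[int]]:
    # Each table receives a sorted sample of distinct ranks of the requested size.
    if random_seed is None:
        random_seed = random.randint(0, 10000)
    random.seed(random_seed)
    max_rank = world_size - 1
    if ranks_per_tables == [0]:
        ranks_per_tables = [random.randint(1, max_rank) for _ in range(num_tables)]
    placements = []
    for ranks_per_table in ranks_per_tables:
        placements.append(sorted(random.sample(range(world_size), ranks_per_table)))
    return placements


if __name__ == "__main__":
    # Illustrative sizes only (not from the patch).
    world_size, num_tables, embedding_dim = 4, 3, 128

    # Table-wise: each table is placed on exactly one rank per plan.
    tw_plan = generate_rank_placements(world_size, num_tables, [1] * num_tables)

    # Column-wise: the shard count per table must evenly divide the embedding dim.
    valid_candidates = [i for i in range(1, world_size + 1) if embedding_dim % i == 0]
    cw_counts = [random.choice(valid_candidates) for _ in range(num_tables)]
    cw_plan = generate_rank_placements(world_size, num_tables, cw_counts)

    print("table-wise placements:", tw_plan)   # e.g. [[2], [0], [3]]
    print("column-wise placements:", cw_plan)  # e.g. [[0, 2], [1], [0, 1, 2, 3]]

Each generated placement list is what the benchmark converts into a new module sharding plan via construct_module_sharding_plan, one plan per resharding interval; training_func_to_benchmark then computes the per-plan delta with output_sharding_plan_delta and applies it through the model's reshard call.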