diff --git a/torchrec/distributed/benchmark/benchmark_train.py b/torchrec/distributed/benchmark/benchmark_train.py
index 15ea780f2..66d2c0042 100644
--- a/torchrec/distributed/benchmark/benchmark_train.py
+++ b/torchrec/distributed/benchmark/benchmark_train.py
@@ -28,8 +28,9 @@
     write_report,
 )
 from torchrec.distributed.embedding_types import EmbeddingComputeKernel, ShardingType
+from torchrec.distributed.sharding.dynamic_sharding import output_sharding_plan_delta
 from torchrec.distributed.test_utils.test_model import TestEBCSharder
-from torchrec.distributed.types import DataType
+from torchrec.distributed.types import DataType, EmbeddingModuleShardingPlan
 from torchrec.modules.embedding_modules import EmbeddingBagCollection
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
@@ -53,8 +54,27 @@ def training_func_to_benchmark(
     model: torch.nn.Module,
     bench_inputs: List[KeyedJaggedTensor],
     optimizer: Optional[torch.optim.Optimizer],
+    resharding_plan_diffs: Optional[List[EmbeddingModuleShardingPlan]] = None,
 ) -> None:
-    for bench_input in bench_inputs:
+
+    reshard_idx = 0
+
+    for i, bench_input in enumerate(bench_inputs):
+        if resharding_plan_diffs is not None:
+            if (
+                i > 0
+                and len(resharding_plan_diffs) > 0
+                and i % (len(bench_inputs) / len(resharding_plan_diffs)) == 0
+            ):
+
+                plan_difference = output_sharding_plan_delta(
+                    # Pyre-ignore
+                    model.plan.plan["_module"],
+                    resharding_plan_diffs[reshard_idx],
+                )
+                # Pyre-ignore
+                model.reshard("_module", plan_difference)
+                reshard_idx += 1
         pooled_embeddings = model(bench_input)
         vals = []
         for _name, param in pooled_embeddings.to_dict().items():
@@ -120,6 +140,7 @@ def benchmark_ebc(


 def main() -> None:
+    # torch.cuda.cudart().cudaProfilerStart()
     args: argparse.Namespace = init_argparse_and_args()

     num_requests = args.bench_iters * args.batch_size * args.num_benchmarks
@@ -203,6 +224,8 @@ def main() -> None:
     for i, write_report_func in enumerate(write_report_funcs_per_module):
         write_report_func(benchmark_results_per_module[i])

+    # torch.cuda.cudart().cudaProfilerStop()
+

 def invoke_main() -> None:
     logging.basicConfig()
@@ -212,4 +235,5 @@ def invoke_main() -> None:


 if __name__ == "__main__":
+    invoke_main()  # pragma: no cover
diff --git a/torchrec/distributed/benchmark/benchmark_utils.py b/torchrec/distributed/benchmark/benchmark_utils.py
index a9f7a3864..e5e8b124c 100644
--- a/torchrec/distributed/benchmark/benchmark_utils.py
+++ b/torchrec/distributed/benchmark/benchmark_utils.py
@@ -18,6 +18,7 @@
 import json
 import logging
 import os
+import random
 import resource
 import time
 import timeit
@@ -55,10 +56,21 @@
     EmbeddingStorageEstimator,
 )
 from torchrec.distributed.shard import _shard_modules
+
+from torchrec.distributed.sharding.dynamic_sharding import output_sharding_plan_delta
+from torchrec.distributed.sharding_plan import (
+    construct_module_sharding_plan,
+    get_sharding_constructor_from_type,
+)
 from torchrec.distributed.test_utils.multi_process import MultiProcessContext
 from torchrec.distributed.test_utils.test_model import ModelInput
-from torchrec.distributed.types import DataType, ModuleSharder, ShardingEnv
+from torchrec.distributed.types import (
+    DataType,
+    EmbeddingModuleShardingPlan,
+    ModuleSharder,
+    ShardingEnv,
+)
 from torchrec.fx import symbolic_trace
 from torchrec.modules.embedding_configs import EmbeddingBagConfig, EmbeddingConfig
 from torchrec.quant.embedding_modules import (
@@ -359,8 +371,32 @@ def forward(self, input: KeyedJaggedTensor) -> KeyedTensor:
 T = TypeVar("T", bound=torch.nn.Module)


+def _generate_rank_placements(
+    world_size: int,
+    num_tables: int,
+    ranks_per_tables: List[int],
+    random_seed: Optional[int] = None,
+) -> List[List[int]]:
+    # Cannot include old/new rank generation with hypothesis library due to dependency on world_size
+    if random_seed is None:
+        # Generate a random seed to ensure that the output rank placements can be different each time
+        random_seed = random.randint(0, 10000)
+    placements = []
+    max_rank = world_size - 1
+    random.seed(random_seed)
+    if ranks_per_tables == [0]:
+        ranks_per_tables = [random.randint(1, max_rank) for _ in range(num_tables)]
+    for i in range(num_tables):
+        ranks_per_table = ranks_per_tables[i]
+        placement = sorted(random.sample(range(world_size), ranks_per_table))
+        placements.append(placement)
+    return placements
+
+
 def default_func_to_benchmark(
-    model: torch.nn.Module, bench_inputs: List[KeyedJaggedTensor]
+    model: torch.nn.Module,
+    bench_inputs: List[KeyedJaggedTensor],
+    resharding_plan_diffs: Optional[List[EmbeddingModuleShardingPlan]] = None,
 ) -> None:
     with torch.inference_mode():
         for bench_input in bench_inputs:
@@ -679,6 +715,8 @@ def init_argparse_and_args() -> argparse.Namespace:
     parser.add_argument("--num_benchmarks", type=int, default=5)
     parser.add_argument("--embedding_config_json", type=str, default="")
     parser.add_argument("--device_type", type=str, default="cuda")
+    parser.add_argument("--enable_resharding", type=bool, default=False)
+    parser.add_argument("--resharding_interval", type=int, default=1000)

     args = parser.parse_args()
@@ -728,6 +766,7 @@ def fx_script_module(eager_module: torch.nn.Module) -> torch.nn.Module:
     # Don't want to modify the module outright
     # Since module is on cpu, won't cause cuda oom.
     copied_module = copy.deepcopy(module)
+    # pyre-ignore [6]
     plan = planner.plan(copied_module, [sharder])
@@ -780,6 +819,7 @@ def _run_benchmark_core(
     pre_gpu_load: int = 0,
     export_stacks: bool = False,
     reset_accumulated_memory_stats: bool = False,
+    new_ranks_per_plan: Optional[List[int]] = None,
 ) -> BenchmarkResult:
     """Internal helper that contains the core benchmarking logic shared by
     ``benchmark`` and ``benchmark_func``. All heavy-lifting (timing, memory
@@ -1071,6 +1111,7 @@ def init_module_and_run_benchmark(
     queue: Optional[mp.Queue] = None,
     pooling_configs: Optional[List[int]] = None,
     benchmark_unsharded_module: bool = False,
+    new_ranks_per_plan: Optional[List[List[int]]] = None,
 ) -> BenchmarkResult:
     """
     There are a couple of caveats here as to why the module has to be initialized
@@ -1117,6 +1158,37 @@ def init_module_and_run_benchmark(
         if rank != -1
         else contextlib.nullcontext()
     ) as ctx:
+
+        resharding_plans = []
+
+        if new_ranks_per_plan is not None and len(new_ranks_per_plan) > 0:
+            sharding_type_constructor = get_sharding_constructor_from_type(
+                sharding_type
+            )
+            for new_ranks_per_table in new_ranks_per_plan:
+                new_per_param_sharding = {}
+                for table_id, table in enumerate(tables):
+                    if sharding_type == ShardingType.TABLE_WISE:
+                        new_per_param_sharding[table.name] = sharding_type_constructor(
+                            rank=new_ranks_per_table[table_id][0],
+                            compute_kernel=sharder._kernel_type,
+                        )
+                    elif sharding_type == ShardingType.COLUMN_WISE:
+                        new_per_param_sharding[table.name] = sharding_type_constructor(
+                            ranks=new_ranks_per_table[table_id]
+                        )
+
+                new_module_sharding_plan = construct_module_sharding_plan(
+                    module=module._module,  # Pyre-ignore
+                    # Pyre-ignore
+                    sharder=sharder,
+                    per_param_sharding=new_per_param_sharding,
+                    local_size=world_size,
+                    world_size=world_size,
+                    device_type=device.type,
+                )
+                resharding_plans.append(new_module_sharding_plan)
+
         module = transform_module(
             module=module,
             device=device,
@@ -1136,6 +1208,10 @@ def init_module_and_run_benchmark(
         else:
             name = benchmark_type_name(compile_mode, sharding_type)

+        if benchmark_func_kwargs is None:
+            benchmark_func_kwargs = {}
+        benchmark_func_kwargs["resharding_plan_diffs"] = resharding_plans
+
         res = benchmark(
             name,
             module,
@@ -1244,6 +1320,8 @@ def benchmark_module(
     pooling_configs: Optional[List[int]] = None,
     variable_batch_embeddings: bool = False,
     device_type: str = "cuda",
+    enable_resharding: bool = False,
+    resharding_interval: int = 1000,
 ) -> List[BenchmarkResult]:
     """
     Args:
@@ -1325,6 +1403,36 @@ def benchmark_module(
             )

         if train:
+
+            new_ranks_per_plan = []
+
+            if enable_resharding:
+                total_plans_per_benchmark = bench_iters // resharding_interval
+                total_plans_per_benchmark = max(1, total_plans_per_benchmark)
+
+                num_tables = len(tables)
+                ranks_per_tables = []
+
+                if sharding_type == ShardingType.TABLE_WISE:
+                    ranks_per_tables = [1 for _ in range(num_tables)]
+
+                elif sharding_type == ShardingType.COLUMN_WISE:
+                    valid_candidates = [
+                        i
+                        for i in range(1, world_size + 1)
+                        if EMBEDDING_DIM % i == 0
+                    ]
+                    ranks_per_tables = [
+                        random.choice(valid_candidates) for _ in range(num_tables)
+                    ]
+
+                new_ranks_per_plan = [
+                    _generate_rank_placements(
+                        world_size, num_tables, ranks_per_tables
+                    )
+                    for _ in range(total_plans_per_benchmark)
+                ]
+
             res = multi_process_benchmark(
                 # pyre-ignore[6]
                 callable=init_module_and_run_benchmark,
@@ -1344,6 +1452,7 @@ def benchmark_module(
                 func_to_benchmark=func_to_benchmark,
                 benchmark_func_kwargs=benchmark_func_kwargs,
                 pooling_configs=pooling_configs,
+                new_ranks_per_plan=new_ranks_per_plan,
             )
         else:
             res = init_module_and_run_benchmark(
diff --git a/torchrec/distributed/model_parallel.py b/torchrec/distributed/model_parallel.py
index d09f30781..936d0d2fc 100644
--- a/torchrec/distributed/model_parallel.py
+++ b/torchrec/distributed/model_parallel.py
@@ -742,7 +742,10 @@ def reshard(
         # Need to use .module to maintain FQN consistency
         self._optim: CombinedOptimizer = self._init_optim(
-            self._dmp_wrapped_module.module  # pyre-ignore
+            # pyre-ignore
+            self._dmp_wrapped_module.module
+            if hasattr(self._dmp_wrapped_module, "module")
+            else self._dmp_wrapped_module._module
         )
         self._plan.plan[sharded_module_fqn] = sharded_module.module_sharding_plan
         return sharded_module
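For reference, a minimal usage sketch of the _generate_rank_placements helper introduced in benchmark_utils.py above. This is not part of the patch; the world size, table count, ranks-per-table values, and seed are illustrative assumptions.

# Minimal sketch, assuming the patch above has been applied.
from torchrec.distributed.benchmark.benchmark_utils import _generate_rank_placements

world_size = 4
num_tables = 3

# TABLE_WISE-style request: one rank per table.
tw_placements = _generate_rank_placements(world_size, num_tables, [1, 1, 1], random_seed=0)

# COLUMN_WISE-style request: two ranks per table; each placement is a sorted
# sample of distinct ranks drawn from range(world_size).
cw_placements = _generate_rank_placements(world_size, num_tables, [2, 2, 2], random_seed=0)

print(tw_placements)  # e.g. [[1], [0], [3]]
print(cw_placements)  # e.g. [[0, 3], [1, 2], [0, 2]]

With enable_resharding=True, benchmark_module builds max(1, bench_iters // resharding_interval) such placement lists, converts each into a module sharding plan via construct_module_sharding_plan, and training_func_to_benchmark periodically applies the resulting plan deltas through the reshard path shown in model_parallel.py.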