diff --git a/torchrec/distributed/benchmark/benchmark_train.py b/torchrec/distributed/benchmark/benchmark_train.py index 15ea780f2..66d2c0042 100644 --- a/torchrec/distributed/benchmark/benchmark_train.py +++ b/torchrec/distributed/benchmark/benchmark_train.py @@ -28,8 +28,9 @@ write_report, ) from torchrec.distributed.embedding_types import EmbeddingComputeKernel, ShardingType +from torchrec.distributed.sharding.dynamic_sharding import output_sharding_plan_delta from torchrec.distributed.test_utils.test_model import TestEBCSharder -from torchrec.distributed.types import DataType +from torchrec.distributed.types import DataType, EmbeddingModuleShardingPlan from torchrec.modules.embedding_modules import EmbeddingBagCollection from torchrec.sparse.jagged_tensor import KeyedJaggedTensor @@ -53,8 +54,27 @@ def training_func_to_benchmark( model: torch.nn.Module, bench_inputs: List[KeyedJaggedTensor], optimizer: Optional[torch.optim.Optimizer], + resharding_plan_diffs: Optional[List[EmbeddingModuleShardingPlan]] = None, ) -> None: - for bench_input in bench_inputs: + + reshard_idx = 0 + + for i, bench_input in enumerate(bench_inputs): + if resharding_plan_diffs is not None: + if ( + i > 0 + and len(resharding_plan_diffs) > 0 + and i % (len(bench_inputs) / len(resharding_plan_diffs)) == 0 + ): + + plan_difference = output_sharding_plan_delta( + # Pyre-ignore + model.plan.plan["_module"], + resharding_plan_diffs[reshard_idx], + ) + # Pyre-ignore + model.reshard("_module", plan_difference) + reshard_idx += 1 pooled_embeddings = model(bench_input) vals = [] for _name, param in pooled_embeddings.to_dict().items(): @@ -120,6 +140,7 @@ def benchmark_ebc( def main() -> None: + # torch.cuda.cudart().cudaProfilerStart() args: argparse.Namespace = init_argparse_and_args() num_requests = args.bench_iters * args.batch_size * args.num_benchmarks @@ -203,6 +224,8 @@ def main() -> None: for i, write_report_func in enumerate(write_report_funcs_per_module): write_report_func(benchmark_results_per_module[i]) + # torch.cuda.cudart().cudaProfilerStop() + def invoke_main() -> None: logging.basicConfig() @@ -212,4 +235,5 @@ def invoke_main() -> None: if __name__ == "__main__": + invoke_main() # pragma: no cover diff --git a/torchrec/distributed/benchmark/benchmark_utils.py b/torchrec/distributed/benchmark/benchmark_utils.py index a9f7a3864..e5e8b124c 100644 --- a/torchrec/distributed/benchmark/benchmark_utils.py +++ b/torchrec/distributed/benchmark/benchmark_utils.py @@ -18,6 +18,7 @@ import json import logging import os +import random import resource import time import timeit @@ -55,10 +56,21 @@ EmbeddingStorageEstimator, ) from torchrec.distributed.shard import _shard_modules + +from torchrec.distributed.sharding.dynamic_sharding import output_sharding_plan_delta +from torchrec.distributed.sharding_plan import ( + construct_module_sharding_plan, + get_sharding_constructor_from_type, +) from torchrec.distributed.test_utils.multi_process import MultiProcessContext from torchrec.distributed.test_utils.test_model import ModelInput -from torchrec.distributed.types import DataType, ModuleSharder, ShardingEnv +from torchrec.distributed.types import ( + DataType, + EmbeddingModuleShardingPlan, + ModuleSharder, + ShardingEnv, +) from torchrec.fx import symbolic_trace from torchrec.modules.embedding_configs import EmbeddingBagConfig, EmbeddingConfig from torchrec.quant.embedding_modules import ( @@ -359,8 +371,32 @@ def forward(self, input: KeyedJaggedTensor) -> KeyedTensor: T = TypeVar("T", bound=torch.nn.Module) +def 
_generate_rank_placements( + world_size: int, + num_tables: int, + ranks_per_tables: List[int], + random_seed: Optional[int] = None, +) -> List[List[int]]: + # Cannot include old/new rank generation with hypothesis library due to depedency on world_size + if random_seed is None: + # Generate a random seed to ensure that the output rank placements can be different each time + random_seed = random.randint(0, 10000) + placements = [] + max_rank = world_size - 1 + random.seed(random_seed) + if ranks_per_tables == [0]: + ranks_per_tables = [random.randint(1, max_rank) for _ in range(num_tables)] + for i in range(num_tables): + ranks_per_table = ranks_per_tables[i] + placement = sorted(random.sample(range(world_size), ranks_per_table)) + placements.append(placement) + return placements + + def default_func_to_benchmark( - model: torch.nn.Module, bench_inputs: List[KeyedJaggedTensor] + model: torch.nn.Module, + bench_inputs: List[KeyedJaggedTensor], + resharding_plan_diffs: Optional[List[EmbeddingModuleShardingPlan]] = None, ) -> None: with torch.inference_mode(): for bench_input in bench_inputs: @@ -679,6 +715,8 @@ def init_argparse_and_args() -> argparse.Namespace: parser.add_argument("--num_benchmarks", type=int, default=5) parser.add_argument("--embedding_config_json", type=str, default="") parser.add_argument("--device_type", type=str, default="cuda") + parser.add_argument("--enable_resharding", type=bool, default=False) + parser.add_argument("--resharding_interval", type=int, default=1000) args = parser.parse_args() @@ -728,6 +766,7 @@ def fx_script_module(eager_module: torch.nn.Module) -> torch.nn.Module: # Don't want to modify the module outright # Since module is on cpu, won't cause cuda oom. copied_module = copy.deepcopy(module) + # pyre-ignore [6] plan = planner.plan(copied_module, [sharder]) @@ -780,6 +819,7 @@ def _run_benchmark_core( pre_gpu_load: int = 0, export_stacks: bool = False, reset_accumulated_memory_stats: bool = False, + new_ranks_per_plan: Optional[List[int]] = None, ) -> BenchmarkResult: """Internal helper that contains the core benchmarking logic shared by ``benchmark`` and ``benchmark_func``. 
All heavy–lifting (timing, memory @@ -1071,6 +1111,7 @@ def init_module_and_run_benchmark( queue: Optional[mp.Queue] = None, pooling_configs: Optional[List[int]] = None, benchmark_unsharded_module: bool = False, + new_ranks_per_plan: Optional[List[List[int]]] = None, ) -> BenchmarkResult: """ There are a couple of caveats here as to why the module has to be initialized @@ -1117,6 +1158,37 @@ def init_module_and_run_benchmark( if rank != -1 else contextlib.nullcontext() ) as ctx: + + resharding_plans = [] + + if new_ranks_per_plan is not None and len(new_ranks_per_plan) > 0: + sharding_type_constructor = get_sharding_constructor_from_type( + sharding_type + ) + for new_ranks_per_table in new_ranks_per_plan: + new_per_param_sharding = {} + for table_id, table in enumerate(tables): + if sharding_type == ShardingType.TABLE_WISE: + new_per_param_sharding[table.name] = sharding_type_constructor( + rank=new_ranks_per_table[table_id][0], + compute_kernel=sharder._kernel_type, + ) + elif sharding_type == ShardingType.COLUMN_WISE: + new_per_param_sharding[table.name] = sharding_type_constructor( + ranks=new_ranks_per_table[table_id] + ) + + new_module_sharding_plan = construct_module_sharding_plan( + module=module._module, # Pyre-ignore + # Pyre-ignore + sharder=sharder, + per_param_sharding=new_per_param_sharding, + local_size=world_size, + world_size=world_size, + device_type=device.type, + ) + resharding_plans.append(new_module_sharding_plan) + module = transform_module( module=module, device=device, @@ -1136,6 +1208,10 @@ def init_module_and_run_benchmark( else: name = benchmark_type_name(compile_mode, sharding_type) + if benchmark_func_kwargs is None: + benchmark_func_kwargs = {} + benchmark_func_kwargs["resharding_plan_diffs"] = resharding_plans + res = benchmark( name, module, @@ -1244,6 +1320,8 @@ def benchmark_module( pooling_configs: Optional[List[int]] = None, variable_batch_embeddings: bool = False, device_type: str = "cuda", + enable_resharding: bool = False, + resharding_interval: int = 1000, ) -> List[BenchmarkResult]: """ Args: @@ -1325,6 +1403,36 @@ def benchmark_module( ) if train: + + new_ranks_per_plan = [] + + if enable_resharding: + total_plans_per_benchmark = bench_iters // resharding_interval + total_plans_per_benchmark = max(1, total_plans_per_benchmark) + + num_tables = len(tables) + ranks_per_tables = [] + + if sharding_type == ShardingType.TABLE_WISE: + ranks_per_tables = [1 for _ in range(num_tables)] + + elif sharding_type == ShardingType.COLUMN_WISE: + valid_candidates = [ + i + for i in range(1, world_size + 1) + if EMBEDDING_DIM % i == 0 + ] + ranks_per_tables = [ + random.choice(valid_candidates) for _ in range(num_tables) + ] + + new_ranks_per_plan = [ + _generate_rank_placements( + world_size, num_tables, ranks_per_tables + ) + for _ in range(total_plans_per_benchmark) + ] + res = multi_process_benchmark( # pyre-ignore[6] callable=init_module_and_run_benchmark, @@ -1344,6 +1452,7 @@ def benchmark_module( func_to_benchmark=func_to_benchmark, benchmark_func_kwargs=benchmark_func_kwargs, pooling_configs=pooling_configs, + new_ranks_per_plan=new_ranks_per_plan, ) else: res = init_module_and_run_benchmark( diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py index 2a7f2fa39..fe66a87fe 100644 --- a/torchrec/distributed/embeddingbag.py +++ b/torchrec/distributed/embeddingbag.py @@ -1459,6 +1459,19 @@ def _create_inverse_indices_permute_indices( inverse_indices[1].device, ) + def _is_optimizer_enabled( + self, + has_local_optimizer: 
bool, + env: ShardingEnv, + device: Optional[torch.device], + ) -> bool: + flag = torch.tensor( + [has_local_optimizer], dtype=torch.uint8, device=device + ) # example: True + # Reduce with MAX to check if any process has True + dist.all_reduce(flag, op=dist.ReduceOp.MAX, group=env.process_group) + return bool(flag.item()) + # pyre-ignore [14] def input_dist( self, @@ -1698,10 +1711,17 @@ def update_shards( return current_state = self.state_dict() - has_optimizer = len(self._optim._optims) > 0 and all( + has_local_optimizer = len(self._optim._optims) > 0 and all( len(i) > 0 for i in self._optim.state_dict()["state"].values() ) + # Communicate the optimizer state across all ranks: if one rank owns all tables + # and the other ranks do not own any table, transferring the weights to an empty rank + # creates inconsistent state, because the initially empty rank has no optimizer state + # and hence incorrectly computes the tensor splits. + + has_optimizer = self._is_optimizer_enabled(has_local_optimizer, env, device) + # TODO: Saving lookups tensors to CPU to eventually avoid recreating them completely again # TODO: Ensure lookup tensors are actually being deleted for _, lookup in enumerate(self._lookups): @@ -1715,7 +1735,7 @@ max_dim_0, max_dim_1 = get_largest_dims_from_sharding_plan_updates( changed_sharding_params ) - old_optimizer_state = self._optim.state_dict() if has_optimizer else None + old_optimizer_state = self._optim.state_dict() if has_local_optimizer else None local_shard_names_by_src_rank, local_output_tensor = shards_all_to_all( module=self, @@ -1727,6 +1747,7 @@ max_dim_0=max_dim_0, max_dim_1=max_dim_1, optimizer_state=old_optimizer_state, + has_optimizer=has_optimizer, ) for name, param in changed_sharding_params.items(): @@ -1791,30 +1812,25 @@ self._optim: CombinedOptimizer = CombinedOptimizer(optims) if has_optimizer: - split_index = len(local_output_tensor) // 2 - local_weight_tensors = local_output_tensor[:split_index] - local_optimizer_tensors = local_output_tensor[split_index:] - # Modifies new_opt_state in place and returns it optimizer_state = update_optimizer_state_post_resharding( old_opt_state=old_optimizer_state, # pyre-ignore new_opt_state=copy.deepcopy(self._optim.state_dict()), ordered_shard_names_and_lengths=local_shard_names_by_src_rank, - output_tensor=local_optimizer_tensors, + output_tensor=local_output_tensor, max_dim_0=max_dim_0, + extend_shard_name=self.extend_shard_name, ) - self._optim.load_state_dict(optimizer_state) - else: - local_weight_tensors = local_output_tensor current_state = update_state_dict_post_resharding( state_dict=current_state, ordered_shard_names_and_lengths=local_shard_names_by_src_rank, - output_tensor=local_weight_tensors, + output_tensor=local_output_tensor, new_sharding_params=changed_sharding_params, curr_rank=dist.get_rank(), extend_shard_name=self.extend_shard_name, max_dim_0=max_dim_0, + has_optimizer=has_optimizer, ) self.load_state_dict(current_state) diff --git a/torchrec/distributed/model_parallel.py b/torchrec/distributed/model_parallel.py index d09f30781..936d0d2fc 100644 --- a/torchrec/distributed/model_parallel.py +++ b/torchrec/distributed/model_parallel.py @@ -742,7 +742,10 @@ def reshard( # Need to use .module to maintain FQN consistency self._optim: CombinedOptimizer = self._init_optim( - self._dmp_wrapped_module.module # pyre-ignore + # pyre-ignore + self._dmp_wrapped_module.module + if hasattr(self._dmp_wrapped_module, "module") + else
self._dmp_wrapped_module._module ) self._plan.plan[sharded_module_fqn] = sharded_module.module_sharding_plan return sharded_module diff --git a/torchrec/distributed/sharding/dynamic_sharding.py b/torchrec/distributed/sharding/dynamic_sharding.py index caaa9752b..6c5e12e83 100644 --- a/torchrec/distributed/sharding/dynamic_sharding.py +++ b/torchrec/distributed/sharding/dynamic_sharding.py @@ -75,12 +75,34 @@ ] """ +ShardToRankMapping = Dict[int, List[Tuple[int, int, int, List[int]]]] +""" +ShardToRankMapping is a type alias for a dictionary that maps source ranks to a list of shard metadata tuples. +Each key in the dictionary is an integer representing a source rank. +The value associated with each key is a list of tuples, where each tuple contains metadata about a shard. +The structure of each tuple is as follows: +1. First Element (int): Represents the position of the shard metadata in the original module sharding plan. + This indicates the order or index of the shard within the sharding plan. +2. Second Element (int): Represents the index of the shard among the source rank's local shards. + This is used to select the specific local shard tensor from the sharded tensor and optimizer state. +3. Third Element (int): Represents the destination rank. + This indicates the rank to which the shard is being redistributed. +4. Fourth Element (List[int]): The [start, end) column offsets of the slice of the source shard being redistributed. + These offsets are used to split the source shard tensor before the all-to-all exchange. + + E.g. {0: [(1, 0, 2, [0, 4])]} means that source rank 0 sends columns [0, 4) of one of its shards to destination rank 2: + the shard sits at index 1 in the source module sharding plan, it is the first local shard tensor on rank 0, + and columns 0 through 3 of that shard are redistributed. +""" + def _generate_shard_allocation_metadata( shard_name: str, source_params: ParameterSharding, destination_params: ParameterSharding, -) -> Dict[int, List[Tuple[int, List[int]]]]: + rank: int, + world_size: int, +) -> Dict[int, List[Tuple[int, int, int, List[int]]]]: """ Generates a mapping of shards to ranks for redistribution of data. @@ -98,11 +120,12 @@ Dict[int, List[Tuple[int, List[int]]]]: A dictionary mapping source ranks to a list of tuples, where each tuple contains a destination rank and the corresponding shard offsets.
""" - shard_to_rank_mapping: Dict[int, List[Tuple[int, List[int]]]] = {} + shard_to_rank_mapping: Dict[int, List[Tuple[int, int, int, List[int]]]] = {} src_rank_index = 0 dst_rank_index = 0 curr_source_offset = 0 curr_dst_offset = 0 + local_shard_indices = [0 for _ in range(world_size)] assert source_params.ranks is not None assert destination_params.ranks is not None @@ -136,6 +159,8 @@ def _generate_shard_allocation_metadata( # Pyre-ignore shard_to_rank_mapping[source_params.ranks[src_rank_index]].append( ( + src_rank_index, + local_shard_indices[source_params.ranks[src_rank_index]], destination_params.ranks[dst_rank_index], [curr_source_offset, next_source_offset], ) @@ -144,6 +169,7 @@ def _generate_shard_allocation_metadata( curr_dst_offset = next_dst_offset if next_source_offset >= src_shard_size[1]: + local_shard_indices[source_params.ranks[src_rank_index]] += 1 src_rank_index += 1 curr_source_offset = 0 @@ -158,7 +184,7 @@ def _process_shard_redistribution_metadata( shard_name: str, max_dim_0: int, max_dim_1: int, - shard_to_rank_mapping: Dict[int, List[Tuple[int, List[int]]]], + shard_to_rank_mapping: ShardToRankMapping, sharded_tensor: ShardedTensor, input_splits_per_rank: List[List[int]], output_splits_per_rank: List[List[int]], @@ -167,28 +193,32 @@ def _process_shard_redistribution_metadata( local_table_to_opt_by_dst_rank: List[List[torch.Tensor]], optimizer_state: Optional[Dict[str, Dict[str, Dict[str, ShardedTensor]]]] = None, extend_shard_name: Callable[[str], str] = lambda x: x, + has_optimizer: bool = False, ) -> Tuple[int, int]: """ - calculates shard redistribution metadata across ranks and processes optimizer state if present. + Processes the metadata for shard redistribution across ranks. - This function handles the redistribution of tensor shards from source ranks to destination ranks - based on the provided shard-to-rank mapping. It also processes optimizer state if available, - ensuring that the data is correctly padded and split for communication between ranks. + This function handles the redistribution of shards from source ranks to destination ranks + based on the provided shard-to-rank mapping. It updates the input and output splits for + each rank and manages the optimizer state if present. Args: rank (int): The current rank of the process. shard_name (str): The name of the shard being processed. max_dim_0 (int): The maximum dimension size of dim 0 for padding. max_dim_1 (int): The maximum dimension size of dim 1 for padding. - shard_to_rank_mapping (Dict[int, List[Tuple[int, List[int]]]]): Mapping of source ranks to destination ranks and split offsets. - sharded_tensor (ShardedTensor): The sharded tensor to be redistributed. + shard_to_rank_mapping ShardToRankMapping: Mapping of source ranks to destination ranks and shard offsets. + sharded_tensor (ShardedTensor): The sharded tensor being processed. input_splits_per_rank (List[List[int]]): Input split sizes for each rank. output_splits_per_rank (List[List[int]]): Output split sizes for each rank. shard_names_to_lengths_by_src_rank (List[List[Tuple[str, List[int]]]]): List of shard names and sizes by source rank. - local_table_to_input_tensor_by_dst_rank (List[List[torch.Tensor]]): Local input tensors by destination rank. - local_table_to_opt_by_dst_rank (List[List[torch.Tensor]]): Local optimizer tensors by destination rank. - optimizer_state (Optional[Dict[str, Dict[str, Dict[str, ShardedTensor]]]]): Optimizer state if available. 
+ local_table_to_input_tensor_by_dst_rank (List[List[torch.Tensor]]): Local input tensors for each destination rank. + local_table_to_opt_by_dst_rank (List[List[torch.Tensor]]): Local optimizer tensors for each destination rank. + optimizer_state (Optional[Dict[str, Dict[str, Dict[str, ShardedTensor]]]]): The optimizer state, if present. extend_shard_name (Callable[[str], str]): Function to extend shard names. + has_optimizer (bool): Flag indicating whether any rank has an optimizer state for this shard. Destination ranks + initially may not have an optimizer state, but if any source rank has an optimizer state, then all destinations + should recreate that optimizer state. Returns: Tuple[int, int]: Counts of output tensors and optimizer tensors processed. @@ -196,59 +226,65 @@ output_tensor_count = 0 output_optimizer_count = 0 - has_optimizer = optimizer_state is not None + has_local_optimizer = ( + optimizer_state is not None + ) # indicates whether this rank locally holds optimizer state for the sharded table. # Process each shard mapping from source to destination for src_rank, dsts in shard_to_rank_mapping.items(): - for dst_rank, split_offsets in dsts: + for shard_id, local_shard_id, dst_rank, split_offsets in dsts: # Get shard metadata - shard_metadata = sharded_tensor.metadata().shards_metadata[0] + shard_metadata = sharded_tensor.metadata().shards_metadata[shard_id] shard_size = shard_metadata.shard_sizes assert split_offsets[0] >= 0 assert split_offsets[1] <= shard_size[1] # Update the shard size with new size - shard_size = [shard_size[0], split_offsets[1] - split_offsets[0]] + shard_size = [shard_size[0], split_offsets[1] - split_offsets[0], shard_id] # Update split sizes for communication input_splits_per_rank[src_rank][dst_rank] += max_dim_0 output_splits_per_rank[dst_rank][src_rank] += max_dim_0 - if has_optimizer: - input_splits_per_rank[src_rank][dst_rank] += max_dim_0 - output_splits_per_rank[dst_rank][src_rank] += max_dim_0 # Process data being sent from current rank if src_rank == rank: # Handle optimizer state if present - if has_optimizer and optimizer_state is not None: - + if has_local_optimizer: + # Pyre-ignore local_optimizer_shards = optimizer_state["state"][ extend_shard_name(shard_name) ][tmp_momentum_extender(shard_name)].local_shards() - assert ( - len(local_optimizer_shards) == 1 - ), "Expected exactly one local optimizer shard" + # assert ( + # len(local_optimizer_shards) == 1 + # ), "Expected exactly one local optimizer shard" + + local_optimizer_tensor = local_optimizer_shards[ + local_shard_id + ].tensor - local_optimizer_tensor = local_optimizer_shards[0].tensor if len(local_optimizer_tensor.size()) == 1: # 1D Optimizer Tensor # Convert to 2D Tensor, transpose, for AllToAll local_optimizer_tensor = local_optimizer_tensor.view( local_optimizer_tensor.size(0), 1 ) + else: + local_optimizer_tensor = local_optimizer_tensor[ + :, split_offsets[0] : split_offsets[1] + ] padded_optimizer_tensor = pad_tensor_to_max_dims( local_optimizer_tensor, max_dim_0, max_dim_1 ) local_table_to_opt_by_dst_rank[dst_rank].append( padded_optimizer_tensor ) + input_splits_per_rank[src_rank][dst_rank] += max_dim_0 # Handle main tensor data - local_shards = sharded_tensor.local_shards() - assert len(local_shards) == 1, "Expected exactly one local shard" + local_shard = sharded_tensor.local_shards()[local_shard_id] # cut the tensor based on split points - dst_t =
local_shard.tensor[:, split_offsets[0] : split_offsets[1]] padded_tensor = pad_tensor_to_max_dims(dst_t, max_dim_0, max_dim_1) local_table_to_input_tensor_by_dst_rank[dst_rank].append(padded_tensor) @@ -261,6 +297,7 @@ def _process_shard_redistribution_metadata( output_tensor_count += max_dim_0 if has_optimizer: output_optimizer_count += max_dim_0 + output_splits_per_rank[dst_rank][src_rank] += max_dim_0 return output_tensor_count, output_optimizer_count @@ -269,7 +306,11 @@ def _create_local_shard_tensors( ordered_shard_names_and_lengths: OrderedShardNamesWithSizes, output_tensor: torch.Tensor, max_dim_0: int, -) -> Dict[str, torch.Tensor]: + has_optimizer: bool = False, + optimizer_mode: bool = False, + new_opt_state_state: Optional[Dict[str, Dict[str, ShardedTensor]]] = None, + extend_shard_name: Optional[Callable[[str], str]] = None, +) -> Dict[str, List[torch.Tensor]]: """ Creates local shard tensors from the output tensor based on the ordered shard names and lengths. @@ -285,21 +326,73 @@ def _create_local_shard_tensors( Returns: Dict[str, torch.Tensor]: A dictionary mapping shard names to their corresponding local output tensors. + has_optimizer (bool): Flag indicating whether optimizer is enabled and optimizer weights are present. + It is helpful, to determine the split indexes of the output_tensor. + + e.g output_tensor_format when has_optimizer enabled = [ST1,OPW1,ST2,OPW2,ST3,OPW3,ST4,OPW4] + e.g output_tensor_format when has_optimizer disabled = [ST1,ST2,ST3,ST4] """ - slice_index = 0 - shard_name_to_local_output_tensor: Dict[str, torch.Tensor] = {} + + shard_name_to_local_output_tensor: Dict[str, List[torch.Tensor]] = {} + + slice_index = 0 if not optimizer_mode else max_dim_0 + step_size = max_dim_0 + + splitted_shards_with_names: Dict[str, List[Tuple[int, torch.Tensor]]] = {} + for shard_name, shard_size in ordered_shard_names_and_lengths: - end_slice_index = slice_index + max_dim_0 + + shard_id = shard_size[2] + end_slice_index = slice_index + step_size cur_t = output_tensor[slice_index:end_slice_index] cur_t = pad_tensor_to_max_dims(cur_t, shard_size[0], shard_size[1]) - if shard_name not in shard_name_to_local_output_tensor.keys(): - shard_name_to_local_output_tensor[shard_name] = cur_t + + extended_shard_name = ( + extend_shard_name(shard_name) if extend_shard_name else shard_name + ) + new_opt_state_state = new_opt_state_state if new_opt_state_state else {} + + momentum_name = tmp_momentum_extender(shard_name) + + if ( + optimizer_mode + and new_opt_state_state is not None + and extended_shard_name in new_opt_state_state.keys() + ): + sharded_t = new_opt_state_state[extended_shard_name][momentum_name] + assert len(sharded_t._local_shards) == 1 + + if len(sharded_t._local_shards[0].tensor.size()) == 1: + cur_t = cur_t * shard_size[1] # Supporting RowWise Adagrad operation + + if shard_name not in splitted_shards_with_names: + splitted_shards_with_names[shard_name] = [(shard_id, cur_t)] else: - # CW sharding may have multiple shards per rank in many to one case, so we need to concatenate them - shard_name_to_local_output_tensor[shard_name] = torch.cat( - (shard_name_to_local_output_tensor[shard_name], cur_t), dim=1 - ) - slice_index = end_slice_index + splitted_shards_with_names[shard_name].append((shard_id, cur_t)) + slice_index = ( + end_slice_index if not has_optimizer else end_slice_index + max_dim_0 + ) + + # Assuming splitted_shards_with_names is already populated + for shard_name, shards in splitted_shards_with_names.items(): + # Sort shards by shard_id if needed, 
since, CW sharding can have multiple shards for the same table + shards.sort(key=lambda x: x[0]) + + for _, curr_t in shards: + # Initialize shard_name_to_local_output_tensor[shard_name] if it doesn't exist + if shard_name not in shard_name_to_local_output_tensor: + # Initialize with a list containing the first tensor + shard_name_to_local_output_tensor[shard_name] = [curr_t] + else: + # Since we always assume one tensor in the list, concatenate with it + # TODO: Extend this for multiple shards per table for same rank for new state. + # TODO: Although original plan supports min_partition, we assume changing plan has only one shard per table IN CW sharding + concatenated_tensor = torch.cat( + (shard_name_to_local_output_tensor[shard_name][0], curr_t), dim=1 + ) + # Replace the existing tensor with the concatenated one + shard_name_to_local_output_tensor[shard_name][0] = concatenated_tensor + return shard_name_to_local_output_tensor @@ -313,6 +406,7 @@ def shards_all_to_all( max_dim_1: int, extend_shard_name: Callable[[str], str] = lambda x: x, optimizer_state: Optional[Dict[str, Dict[str, Dict[str, ShardedTensor]]]] = None, + has_optimizer: bool = False, ) -> Tuple[OrderedShardNamesWithSizes, torch.Tensor]: """ Performs an all-to-all communication to redistribute shards across ranks based on new sharding parameters. @@ -337,6 +431,12 @@ def shards_all_to_all( max_dim_1 (int): The maximum dimension size of dim 1 across all tables in the module. + optimizer_state (Optional[Dict[str, Dict[str, Dict[str, ShardedTensor]]]]): The optimizer state, if present. + A dictionary mapping shard names to their optimizer states, which are themselves + dictionaries mapping parameter names to their optimizer states. + + has_optimizer (bool): Flag indicating if optimizer state is present. + Returns: Tuple[List[Tuple[str, List[int]]], torch.Tensor]: Two outputs containing: - A list of shard name and the corresponding shard_size in dim 0 & 1 that were sent to the current rank. 
@@ -350,7 +450,7 @@ def shards_all_to_all( # Module sharding plan is used to get the source ranks for each shard assert hasattr(module, "module_sharding_plan") - has_optimizer = optimizer_state is not None + has_local_optimizer = has_optimizer and optimizer_state is not None world_size = env.world_size rank = dist.get_rank() @@ -379,6 +479,8 @@ def shards_all_to_all( shard_name=shard_name, source_params=src_params, destination_params=param, + rank=rank, + world_size=world_size, ) tensor_count, optimizer_count = _process_shard_redistribution_metadata( @@ -395,6 +497,7 @@ def shards_all_to_all( local_table_to_opt_by_dst_rank=local_table_to_opt_by_dst_rank, optimizer_state=optimizer_state, extend_shard_name=extend_shard_name, + has_optimizer=has_optimizer, ) output_tensor_tensor_count += tensor_count @@ -405,18 +508,9 @@ def shards_all_to_all( local_output_splits = output_splits_per_rank[rank] local_input_tensor = torch.empty([0, max_dim_1], device=device) - for sub_l in local_table_to_input_tensor_by_dst_rank: - for shard_info in sub_l: - local_input_tensor = torch.cat( - ( - local_input_tensor, - shard_info, - ), - dim=0, - ) - for sub_l in local_table_to_opt_by_dst_rank: - for shard_info in sub_l: + for i, sub_l in enumerate(local_table_to_input_tensor_by_dst_rank): + for j, shard_info in enumerate(sub_l): local_input_tensor = torch.cat( ( local_input_tensor, @@ -425,6 +519,16 @@ def shards_all_to_all( dim=0, ) + if has_local_optimizer: + shard_info = local_table_to_opt_by_dst_rank[i][j] + local_input_tensor = torch.cat( + ( + local_input_tensor, + shard_info, + ), + dim=0, + ) + receive_count = output_tensor_tensor_count + output_optimizer_tensor_count max_embedding_size = max_dim_1 local_output_tensor = torch.empty( @@ -463,6 +567,7 @@ def update_state_dict_post_resharding( curr_rank: int, max_dim_0: int, extend_shard_name: Callable[[str], str] = lambda x: x, + has_optimizer: bool = False, ) -> Dict[str, ShardedTensor]: """ Updates and returns the given state_dict with new placements and @@ -492,35 +597,36 @@ def update_state_dict_post_resharding( Dict[str, ShardedTensor]: The updated state dictionary with new shard placements and local shards. """ - shard_name_to_local_output_tensor: Dict[str, torch.Tensor] = ( + shard_name_to_local_output_tensor: Dict[str, List[torch.Tensor]] = ( _create_local_shard_tensors( - ordered_shard_names_and_lengths, output_tensor, max_dim_0 + ordered_shard_names_and_lengths, + output_tensor, + max_dim_0, + has_optimizer=has_optimizer, + optimizer_mode=False, ) ) for shard_name, param in new_sharding_params.items(): extended_name = extend_shard_name(shard_name) + sharded_t = state_dict[extended_name] + sharded_t.metadata().shards_metadata.clear() + # pyre-ignore for i in range(len(param.ranks)): # pyre-ignore r = param.ranks[i] - sharded_t = state_dict[extended_name] - # Update placements - if len(sharded_t.metadata().shards_metadata) > i: - # pyre-ignore - sharded_t.metadata().shards_metadata[i] = param.sharding_spec.shards[i] - else: - sharded_t.metadata().shards_metadata.append( - param.sharding_spec.shards[i] - ) + # Update placements + # pyre-ignore + sharded_t.metadata().shards_metadata.append(param.sharding_spec.shards[i]) # Update local shards if r == curr_rank: assert len(output_tensor) > 0 # slice output tensor for correct size. 
sharded_t._local_shards = [ Shard( - tensor=shard_name_to_local_output_tensor[shard_name], + tensor=shard_name_to_local_output_tensor[shard_name][0], metadata=param.sharding_spec.shards[i], ) ] @@ -537,33 +643,51 @@ def update_optimizer_state_post_resharding( ordered_shard_names_and_lengths: OrderedShardNamesWithSizes, output_tensor: torch.Tensor, max_dim_0: int, + extend_shard_name: Callable[[str], str] = lambda x: x, ) -> Dict[str, Dict[str, Dict[str, ShardedTensor]]]: - new_opt_state_state = new_opt_state["state"] - old_opt_state_state = old_opt_state["state"] + new_opt_state_state = new_opt_state["state"] if new_opt_state else None + old_opt_state_state = old_opt_state["state"] if old_opt_state else None # Remove padding and store tensors by shard name - - shard_name_to_local_output_tensor: Dict[str, torch.Tensor] = ( + shard_name_to_local_output_tensor: Dict[str, List[torch.Tensor]] = ( _create_local_shard_tensors( - ordered_shard_names_and_lengths, output_tensor, max_dim_0 + ordered_shard_names_and_lengths, + output_tensor, + max_dim_0, + has_optimizer=True, + optimizer_mode=True, + new_opt_state_state=new_opt_state_state, + extend_shard_name=extend_shard_name, ) ) + if new_opt_state_state is None or len(new_opt_state_state) == 0: + return new_opt_state for extended_shard_name, item in new_opt_state_state.items(): - if extended_shard_name in old_opt_state_state: + shard_name = extract_shard_name(extended_shard_name) + + if ( + old_opt_state_state is not None + and extended_shard_name in old_opt_state_state + and shard_name not in shard_name_to_local_output_tensor.keys() + ): + new_opt_state_state[extended_shard_name] = old_opt_state_state[ extended_shard_name ] else: - shard_name = extract_shard_name(extended_shard_name) momentum_name = tmp_momentum_extender(shard_name) sharded_t = item[momentum_name] assert len(sharded_t._local_shards) == 1 # local_tensor is updated in-pace for CW sharding - local_tensor = shard_name_to_local_output_tensor[shard_name] + local_tensor = shard_name_to_local_output_tensor[shard_name][0] if len(sharded_t._local_shards[0].tensor.size()) == 1: # Need to transpose 1D optimizer tensor, due to previous conversion - local_tensor = local_tensor.T[0] + + local_tensor_dim = local_tensor.size()[1] + squared_sum_t = torch.sum(local_tensor, dim=1, keepdim=True) + mean_squared_sum_t = torch.div(squared_sum_t, local_tensor_dim) + local_tensor = mean_squared_sum_t.T[0] sharded_t._local_shards = [ Shard( tensor=local_tensor, @@ -571,6 +695,8 @@ def update_optimizer_state_post_resharding( ) for shard in sharded_t._local_shards ] + item[momentum_name] = sharded_t + new_opt_state_state[extended_shard_name] = item return new_opt_state diff --git a/torchrec/distributed/sharding_plan.py b/torchrec/distributed/sharding_plan.py index 8db53fe8a..f0c8a847e 100644 --- a/torchrec/distributed/sharding_plan.py +++ b/torchrec/distributed/sharding_plan.py @@ -625,6 +625,7 @@ def _parameter_sharding_generator( def column_wise( ranks: Optional[List[int]] = None, size_per_rank: Optional[List[int]] = None, + compute_kernel: Optional[str] = None, ) -> ParameterShardingGenerator: """ Returns a generator of ParameterShardingPlan for `ShardingType::COLUMN_WISE` for construct_module_sharding_plan. 
@@ -694,6 +695,7 @@ def _parameter_sharding_generator( local_size, device_type, sharder, + compute_kernel=compute_kernel, ) return _parameter_sharding_generator diff --git a/torchrec/distributed/test_utils/test_model_parallel.py b/torchrec/distributed/test_utils/test_model_parallel.py index 66c45776c..107b927af 100644 --- a/torchrec/distributed/test_utils/test_model_parallel.py +++ b/torchrec/distributed/test_utils/test_model_parallel.py @@ -263,6 +263,11 @@ def _test_dynamic_sharding( constraints[name] = ParameterConstraints( sharding_types=[sharding_type.value], ) + if sharding_type == ShardingType.COLUMN_WISE: + constraints[name] = ParameterConstraints( + sharding_types=[sharding_type.value], + min_partition=4, + ) self._run_multi_process_test( callable=dynamic_sharding_test, @@ -291,6 +296,7 @@ def _test_dynamic_sharding( offsets_dtype=offsets_dtype, lengths_dtype=lengths_dtype, random_seed=random_seed, + sharding_type=sharding_type, ) diff --git a/torchrec/distributed/test_utils/test_sharding.py b/torchrec/distributed/test_utils/test_sharding.py index 9d55360f4..ff72f9fa2 100644 --- a/torchrec/distributed/test_utils/test_sharding.py +++ b/torchrec/distributed/test_utils/test_sharding.py @@ -364,7 +364,7 @@ def dynamic_sharding_test( random.randint(0, batch_size) if allow_zero_batch_size else batch_size ) num_steps = 2 - embedding_dim = 16 + num_float_features = 16 # Generate model & inputs. (global_model, inputs) = gen_model_and_input( model_class=model_class, @@ -382,7 +382,7 @@ def dynamic_sharding_test( # weighted_tables=weighted_tables, embedding_groups=embedding_groups, world_size=world_size, - num_float_features=embedding_dim, + num_float_features=num_float_features, variable_batch_size=variable_batch_size, batch_size=batch_size, feature_processor_modules=feature_processor_modules, @@ -409,7 +409,7 @@ def dynamic_sharding_test( embedding_groups=embedding_groups, dense_device=ctx.device, sparse_device=torch.device("meta"), - num_float_features=embedding_dim, + num_float_features=num_float_features, feature_processor_modules=feature_processor_modules, ) @@ -420,7 +420,7 @@ def dynamic_sharding_test( embedding_groups=embedding_groups, dense_device=ctx.device, sparse_device=torch.device("meta"), - num_float_features=embedding_dim, + num_float_features=num_float_features, feature_processor_modules=feature_processor_modules, ) @@ -495,19 +495,19 @@ def dynamic_sharding_test( num_tables = len(tables) ranks_per_tables = [1 for _ in range(num_tables)] - - # CW sharding - valid_candidates = [ - i for i in range(1, world_size + 1) if embedding_dim % i == 0 - ] - ranks_per_tables_for_CW = [ - random.choice(valid_candidates) for _ in range(num_tables) - ] - new_ranks = generate_rank_placements( world_size, num_tables, ranks_per_tables, random_seed ) + ranks_per_tables_for_CW = [] + for table in tables: + + # CW sharding + valid_candidates = [ + i for i in range(1, world_size + 1) if table.embedding_dim % i == 0 + ] + ranks_per_tables_for_CW.append(random.choice(valid_candidates)) + new_ranks_cw = generate_rank_placements( world_size, num_tables, ranks_per_tables_for_CW, random_seed ) @@ -536,7 +536,11 @@ def dynamic_sharding_test( ) elif sharding_type == ShardingType.COLUMN_WISE: new_per_param_sharding[table_name] = sharding_type_constructor( - ranks=new_ranks_cw[i] + ranks=new_ranks_cw[i], compute_kernel=kernel_type + ) + else: + raise NotImplementedError( + f"Dynamic Sharding currently does not support {sharding_type}" ) new_module_sharding_plan = construct_module_sharding_plan( @@ -594,6 
+598,7 @@ def dynamic_sharding_test( dict(in_backward_optimizer_filter(local_m1_dmp.named_parameters())), lambda params: torch.optim.SGD(params, lr=0.1), ) + local_m1_opt = CombinedOptimizer([local_m1_dmp.fused_optimizer, dense_m1_optim]) # Run a single training step of the sharded model. @@ -611,7 +616,9 @@ def dynamic_sharding_test( ) local_m1_dmp.reshard("sparse.ebc", new_module_sharding_plan_delta) + # Must recreate local_m1_opt, because current local_m1_opt is a copy of underlying fused_opt + local_m1_opt = CombinedOptimizer([local_m1_dmp.fused_optimizer, dense_m1_optim]) local_m1_pred = gen_full_pred_after_one_step( diff --git a/torchrec/distributed/tests/test_dynamic_sharding.py b/torchrec/distributed/tests/test_dynamic_sharding.py index eddf15faa..92a7db0ce 100644 --- a/torchrec/distributed/tests/test_dynamic_sharding.py +++ b/torchrec/distributed/tests/test_dynamic_sharding.py @@ -513,10 +513,9 @@ class MultiRankDMPDynamicShardingTest(ModelParallelTestShared): torch.cuda.device_count() <= 1, "Not enough GPUs, this test requires at least two GPUs", ) - @given( # pyre-ignore + @given( # Pyre-ignore sharder_type=st.sampled_from( [ - # SharderType.EMBEDDING_BAG.value, SharderType.EMBEDDING_BAG_COLLECTION.value, ] ), @@ -543,28 +542,9 @@ class MultiRankDMPDynamicShardingTest(ModelParallelTestShared): ), ] ), - apply_optimizer_in_backward_config=st.sampled_from( - [ - None, - { - "embedding_bags": (optim.Adagrad, {"lr": 0.04}), - }, - { - "embedding_bags": (torch.optim.SGD, {"lr": 0.01}), - }, - { - "embedding_bags": ( - trec_optim.RowWiseAdagrad, - {"lr": 0.01}, - ), - }, - ] - ), - variable_batch_size=st.sampled_from( - [False] - ), # TODO: Enable variable batch size st.booleans(), data_type=st.sampled_from([DataType.FP16, DataType.FP32]), random_seed=st.integers(0, 1000), + world_size=st.sampled_from([2, 4, 8]), ) @settings(verbosity=Verbosity.verbose, max_examples=8, deadline=None) def test_sharding( @@ -573,15 +553,14 @@ def test_sharding( sharding_type: str, kernel_type: str, qcomms_config: Optional[QCommsConfig], - apply_optimizer_in_backward_config: Optional[ - Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] - ], - variable_batch_size: bool, data_type: DataType, - random_seed: int, # Random seed value for deterministically generating sharding plan for resharding + random_seed: int, + world_size: int, ) -> None: """ - Tests resharding from DMP module interface, rather than EBC level. 
+ Tests resharding from DMP module interface with conditional optimizer selection: + - For Table-Wise sharding: All optimizers including RowWiseAdagrad + - For Column-Wise sharding: Only Adagrad and SGD (no RowWiseAdagrad) """ if ( self.device == torch.device("cpu") @@ -589,15 +568,38 @@ def test_sharding( ): self.skipTest("CPU does not support uvm.") + # Fixed to False as variable batch size with CW is more complex + variable_batch_size = False + assume( sharder_type == SharderType.EMBEDDING_BAG_COLLECTION.value or not variable_batch_size ) + # Define optimizer options based on sharding type + if sharding_type == ShardingType.COLUMN_WISE.value: + # For Column-Wise sharding, exclude RowWiseAdagrad + # TODO: fix rowise adagrad logic similar to 2D sharding + optimizer_options = [ + None, + {"embedding_bags": (optim.Adagrad, {"lr": 0.04})}, + {"embedding_bags": (torch.optim.SGD, {"lr": 0.01})}, + ] + else: + # For Table-Wise sharding, include all optimizers + optimizer_options = [ + None, + {"embedding_bags": (optim.Adagrad, {"lr": 0.04})}, + {"embedding_bags": (torch.optim.SGD, {"lr": 0.01})}, + {"embedding_bags": (trec_optim.RowWiseAdagrad, {"lr": 0.01})}, + ] + + # Randomly select one optimizer from the appropriate options + apply_optimizer_in_backward_config = random.choice(optimizer_options) + sharding_type_e = ShardingType(sharding_type) self._test_dynamic_sharding( - # pyre-ignore[6] - sharders=[ + sharders=[ # Pyre-ignore create_test_sharder( sharder_type, sharding_type, @@ -608,11 +610,13 @@ def test_sharding( ], backend=self.backend, qcomms_config=qcomms_config, + # Pyre-ignore apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, variable_batch_size=variable_batch_size, data_type=data_type, sharding_type=sharding_type_e, random_seed=random_seed, + world_size=world_size, ) diff --git a/torchrec/optim/keyed.py b/torchrec/optim/keyed.py index 2f6b75d5f..65c3d69fd 100644 --- a/torchrec/optim/keyed.py +++ b/torchrec/optim/keyed.py @@ -135,6 +135,10 @@ def _update_param_state_dict_object( f"Different number of shards {num_shards} vs {num_new_shards} for the path of {json.dumps(parent_keys)}" ) for shard, new_shard in zip(v.local_shards(), new_v.local_shards()): + # CW sharding can create different size tensors + if shard.tensor.shape != new_shard.tensor.shape: + # Resize the tensor to match the new shape + shard.tensor.resize_(new_shard.tensor.shape) shard.tensor.detach().copy_(new_shard.tensor) elif isinstance(v, DTensor): assert isinstance(new_v, DTensor)
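Reviewer note (illustrative, not part of the diff): the sketch below shows how the pieces added in this PR are intended to fit together — build a candidate EmbeddingModuleShardingPlan with construct_module_sharding_plan, reduce it to a delta against the currently active plan with output_sharding_plan_delta, and apply it via reshard, which is what training_func_to_benchmark does at a fixed interval over bench_inputs. The "_module" FQN and the dmp_model / unsharded_ebc / sharder / tables names are assumptions made for the example, not values taken from this PR.

# Minimal sketch, assuming a DistributedModelParallel-wrapped EmbeddingBagCollection
# registered under the "_module" FQN (as in the benchmark above) and a column-wise target plan.
from typing import List

from torchrec.distributed.embedding_types import ShardingType
from torchrec.distributed.sharding.dynamic_sharding import output_sharding_plan_delta
from torchrec.distributed.sharding_plan import (
    construct_module_sharding_plan,
    get_sharding_constructor_from_type,
)


def apply_one_resharding_step(
    dmp_model,  # assumed: DistributedModelParallel wrapping the benchmarked module
    unsharded_ebc,  # assumed: the unsharded EmbeddingBagCollection behind "_module"
    sharder,  # assumed: e.g. a TestEBCSharder instance
    tables,  # assumed: List[EmbeddingBagConfig] used to build the model
    new_ranks_per_table: List[List[int]],  # e.g. from _generate_rank_placements(...)
    world_size: int,
    device_type: str = "cuda",
) -> None:
    # 1. Describe the new placement per table (column-wise in this sketch).
    cw = get_sharding_constructor_from_type(ShardingType.COLUMN_WISE)
    per_param_sharding = {
        table.name: cw(ranks=new_ranks_per_table[i]) for i, table in enumerate(tables)
    }

    # 2. Build a full module sharding plan for the new placement.
    new_plan = construct_module_sharding_plan(
        module=unsharded_ebc,
        sharder=sharder,
        per_param_sharding=per_param_sharding,
        local_size=world_size,
        world_size=world_size,
        device_type=device_type,
    )

    # 3. Keep only the parameters whose placement actually changed.
    delta = output_sharding_plan_delta(dmp_model.plan.plan["_module"], new_plan)

    # 4. Redistribute the shards (and, if present, optimizer state) in place.
    dmp_model.reshard("_module", delta)

If the sharded module lives under a different FQN in the plan, both the plan lookup and the reshard target change accordingly.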