From 8c89d2aa4ddf3ade1c0bfe2bc2777fe5f96a181b Mon Sep 17 00:00:00 2001
From: Ali Afzal
Date: Sun, 14 Sep 2025 13:33:26 -0700
Subject: [PATCH 1/5] Add ConfigeratorStats to store sharding plan in config store (#3327)

Summary:
internal

Context:
This change is part of the effort to improve the planner's overall UX and reliability.

This Diff:
1. Add ConfigeratorStats to upload the sharding plan to the config store.

**How is a sharding plan stored in Configerator?**
The Thrift definition of a sharding plan includes two fields: Topology and Dict[int, ShardingOption].
1. Topology: The Topology field contains the information mentioned in diff D79142495.
2. Dict[int, ShardingOption]: This field represents a dictionary where each key is a 64-bit hash of a sharding option and the value is the corresponding Thrift-converted sharding option. The hash is calculated using the storage_hash function on the ShardingOption object, which takes into account factors such as the fqn, sharding type, compute kernel, cache load factor, and number of shards.

**How can a loaded plan be merged with an enumerated search space?**

**Background:** When a plan is preserved during the logging stage, a hash is generated to ensure that the same plan can be loaded and validated later. The [hash is calculated](https://www.internalfb.com/code/fbsource/[fdf90ff2be9041f867bc6c9e4aec6ee94862fa11]/fbcode/torchrec/distributed/planner/types.py?lines=1010-1026) using input fields such as topology, batch size, constraints, storage reservation, and storage reservation policy, as well as fields from the sharding options such as fqn, sharding type, kernel type, shards, and cache parameters.

Once the plan is loaded and validated, we can safely assume that the loaded sharding options map 1:1 to the enumerated sharding options. During the loading process, we traverse the enumerated search space, calculate the storage hash for each sharding option, look up the corresponding sharding option from the loaded plan, and replace the shards of the enumerated sharding option with those of the loaded sharding option.

This approach enables us to generate precise sharding options that can be seamlessly converted into a sharding plan, just as the planner does, while also ensuring consistent logging and facilitating plan replay.

Differential Revision: D81185992
---
 torchrec/distributed/planner/types.py | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)

diff --git a/torchrec/distributed/planner/types.py b/torchrec/distributed/planner/types.py
index 1dba9f36c..862d453b9 100644
--- a/torchrec/distributed/planner/types.py
+++ b/torchrec/distributed/planner/types.py
@@ -638,6 +638,22 @@ def __hash__(self) -> int:
             )
         )
 
+    def storage_hash(self) -> int:
+        """
+        Hash that uniquely identifies a sharding option based on its inputs before
+        planning. This is needed to restore sharding options from a loaded plan.
+        Hash is computed based on the following attributes:
+        - fqn
+        - sharding_type
+        - compute_kernel
+        - cache_load_factor and num_shards
+        """
+        # Use BLAKE2b for deterministic hashing, constrained to 64-bit signed int range
+        hash_str = f"{self.fqn}|{self.sharding_type}|{self.compute_kernel}|{self.cache_load_factor}|{self.num_shards}"
+        hash_bytes = hashlib.blake2b(hash_str.encode("utf-8"), digest_size=7).digest()
+        hash_int = int.from_bytes(hash_bytes, byteorder="big")
+        return hash_int
+
     def __deepcopy__(
         self, memo: Optional[Dict[int, "ShardingOption"]]
     ) -> "ShardingOption":

From dfb35a4830aaa5bc6db97b4ed491c5b56d43983c Mon Sep 17 00:00:00 2001
From: Ali Afzal
Date: Sun, 14 Sep 2025 13:33:26 -0700
Subject: [PATCH 2/5] Integrating planner stats db with ConfigeratorStats (#3331)

Summary:
internal

Context:
A planner stats DB is introduced in this diff to track metadata and perf metrics associated with a sharding plan.

This Diff:
1. Added methods to insert, select, and delete planner stats DB rows.
2. Unit tests for the planner stats DB.
3. Integration of the planner stats DB with ConfigeratorStats.

Differential Revision: D81216987
---
 torchrec/distributed/planner/planners.py |  5 ++
 torchrec/distributed/planner/stats.py    |  3 ++
 torchrec/distributed/planner/types.py    | 63 ++++++++++++++++++++++++
 3 files changed, 71 insertions(+)

diff --git a/torchrec/distributed/planner/planners.py b/torchrec/distributed/planner/planners.py
index b2a2d1f9d..6f0b3da2f 100644
--- a/torchrec/distributed/planner/planners.py
+++ b/torchrec/distributed/planner/planners.py
@@ -42,6 +42,7 @@
     ParameterConstraints,
     Partitioner,
     PerfModel,
+    PlanDebugStats,
     PlannerError,
     PlannerErrorType,
     Proposer,
@@ -528,6 +529,10 @@ def plan(
                 enumerator=self._enumerator,
                 sharders=sharders,
                 debug=self._debug,
+                debug_stats=PlanDebugStats(
+                    planner_type=self.__class__.__name__,
+                    timeout_seconds=self._timeout_seconds,
+                ),
             )
             return sharding_plan
         else:
diff --git a/torchrec/distributed/planner/stats.py b/torchrec/distributed/planner/stats.py
index 82d2e367b..431c336a8 100644
--- a/torchrec/distributed/planner/stats.py
+++ b/torchrec/distributed/planner/stats.py
@@ -40,6 +40,7 @@
     Enumerator,
     ParameterConstraints,
     Perf,
+    PlanDebugStats,
     ShardingOption,
     Stats,
     Storage,
@@ -160,6 +161,7 @@ def log(
         sharders: Optional[List[ModuleSharder[nn.Module]]] = None,
        enumerator: Optional[Enumerator] = None,
         debug: bool = True,
+        debug_stats: Optional[PlanDebugStats] = None,
     ) -> None:
         """
         Logs stats for a given sharding plan.
@@ -1138,5 +1140,6 @@ def log(
         sharders: Optional[List[ModuleSharder[nn.Module]]] = None,
         enumerator: Optional[Enumerator] = None,
         debug: bool = True,
+        debug_stats: Optional[PlanDebugStats] = None,
     ) -> None:
         pass
diff --git a/torchrec/distributed/planner/types.py b/torchrec/distributed/planner/types.py
index 862d453b9..c421ac0fc 100644
--- a/torchrec/distributed/planner/types.py
+++ b/torchrec/distributed/planner/types.py
@@ -960,6 +960,16 @@ def partition(
         ...
 
 
+@dataclass
+class PlanDebugStats:
+    """
+    Representation of debug stats associated with a sharding plan, used for logging.
+    """
+
+    planner_type: str
+    timeout_seconds: Optional[int]
+
+
 class Stats(abc.ABC):
     """
     Logs statistics related to the sharding plan.
@@ -980,6 +990,7 @@ def log(
         sharders: Optional[List[ModuleSharder[nn.Module]]] = None,
         enumerator: Optional[Enumerator] = None,
         debug: bool = False,
+        debug_stats: Optional[PlanDebugStats] = None,
     ) -> None:
         """
         See class description
@@ -1007,6 +1018,16 @@ def hash_sha256_to_int(hashable_list: List[Any]) -> int:  # pyre-ignore
     return int(hash_digest, 16)
 
 
+def hash_sha256_str(hashable_list: List[Any]) -> str:  # pyre-ignore
+    """
+    Hashes the given data using SHA256 and returns the hash as a string.
+    """
+    serialized_list = str(hashable_list).encode("utf-8")
+    hash_object = hashlib.sha256(serialized_list)
+    hash_digest = hash_object.hexdigest()
+    return hash_digest
+
+
 def hash_planner_context_inputs(
     topology: Topology,
     batch_size: int,
@@ -1047,3 +1068,45 @@ def hash_planner_context_inputs(
         constraints.items() if constraints else None,
     ]
     return hash_function(hashable_list)
+
+
+def hash_planner_context_inputs_str(
+    topology: Topology,
+    batch_size: int,
+    enumerator: Enumerator,
+    storage_reservation: StorageReservation,
+    constraints: Optional[Dict[str, ParameterConstraints]],
+    # pyre-ignore
+    hash_function: Callable[[List[Any]], str] = hash_sha256_str,
+) -> str:
+    assert hasattr(
+        enumerator, "last_stored_search_space"
+    ), "This enumerator is not compatible with hashing"
+    assert (
+        enumerator.last_stored_search_space is not None  # pyre-ignore
+    ), "Unable to hash planner context without an enumerator that has a precomputed search space"
+    search_space = enumerator.last_stored_search_space
+    storage_reservation_policy = type(storage_reservation).__name__
+
+    assert (
+        storage_reservation._last_reserved_topology is not None  # pyre-ignore
+    ), "Unable to hash planner context without a storage reservation that has a precomputed topology"
+
+    hashable_list = [
+        topology,
+        batch_size,
+        [
+            [
+                shard_option.fqn,
+                shard_option.sharding_type,
+                shard_option.compute_kernel,
+                tuple(shard_option.shards),
+                shard_option.cache_params,
+            ]
+            for shard_option in search_space
+        ],
+        storage_reservation_policy,
+        storage_reservation._last_reserved_topology,
+        constraints.items() if constraints else None,
+    ]
+    return hash_function(hashable_list)

From d24729b969a08f97884019a7509e8d84d1ceb757 Mon Sep 17 00:00:00 2001
From: Ali Afzal
Date: Sun, 14 Sep 2025 13:33:26 -0700
Subject: [PATCH 3/5] PlanLoader addition into planner (#3355)

Summary:
**Summary:**
* Added PlanLoader abstract base class to enable loading pre-computed sharding plans from stored locations within the planner.
* Supports two key scenarios:
  1. Reusing previously computed and stored sharding plans to avoid regeneration costs
  2. Using sharding plans from previous runs as starting points for iterative improvements
* Defines two abstract methods:
  * `load()`: Returns a dictionary mapping sharding option hashes to ShardingOption objects
  * `plan_context_hash()`: Returns the planner-context hash used to validate plan integrity
* Part of the broader effort to improve planner UX and reliability by enabling plan persistence and reuse across training runs

Reviewed By: mserturk

Differential Revision: D81571293
---
 torchrec/distributed/planner/types.py | 34 +++++++++++++++++++++++++++
 1 file changed, 34 insertions(+)

diff --git a/torchrec/distributed/planner/types.py b/torchrec/distributed/planner/types.py
index c421ac0fc..157fe48fe 100644
--- a/torchrec/distributed/planner/types.py
+++ b/torchrec/distributed/planner/types.py
@@ -998,6 +998,40 @@ def log(
         ...
 
 
+class PlanLoader(abc.ABC): + """ + Retrieves a pre-computed sharding plan from its stored location. This is useful in two scenarios: + 1. To utilize a specific sharding plan that was previously computed and stored, saving the cost of re-generating the plan + 2. To use a sharding plan from previous runs as a starting point for the next run, allowing for improvement over time. + """ + + @abc.abstractmethod + def load( + self, + ) -> Optional[Dict[int, ShardingOption]]: + """ + Load sharding plan from its stored location. + + Returns: + Dict[int, ShardingOption]: loaded sharding plan. key is hash of sharding option to map to sharding option with enumerated sharding option. + """ + ... + + @abc.abstractmethod + def plan_context_hash( + self, + ) -> Optional[str]: + """ + Input context hash of a sharding plan. + + Returns: + str: hash of sharding plan context. + """ + ... + + ... + + @dataclass class CriticalPathEstimate: comms_estimate: float From 9f5b47f602bd32003e727a5c3b216547bc7dc1eb Mon Sep 17 00:00:00 2001 From: Ali Afzal Date: Sun, 14 Sep 2025 13:33:26 -0700 Subject: [PATCH 4/5] Configerator based PlanLoader implementation (#3356) Summary: Add ConfigeratorPlanLoader an implementation of the PlanLoader interface to enable: **Key Features:** 1. Plan Retrieval: Loads compressed sharding plans from Configerator using plan_id 2. Database Integration: Queries PlannerStatsDB to get storage location and context hash 3. Decompression: Uses zstd to decompress stored plan data 4. Thrift Conversion: Deserializes Thrift structures and converts back to Python ShardingOption objects 5. Error Handling: Failure scenarios with configurable fallback behavior **Error Handling & Fallback Scenarios:** The implementation supports two distinct error handling modes controlled by `enable_fallback`: **Normal Mode (enable_fallback=False - Default):** - Raises `PlannerError` with `PLAN_LOADING_FAILED` type for any failure - Error scenarios include: - Network connectivity issues (Configerator service unavailable) - Invalid plan id or config path - Data decompression failures - Thrift deserialization errors - Thrift-to-Python conversion failures **Fallback Mode (enable_fallback=True):** - Returns `None` instead of raising exceptions on loading failures - Logs detailed warning messages with plan_id, config_path, and error details - Enables graceful degradation where system can fall back to alternative planning strategies - Suitable for development, experimentation, or scenarios prioritizing availability over strict error handling - Warning logs include full context for debugging: plan ID, Configerator path, and original error Differential Revision: D81573577 --- torchrec/distributed/planner/types.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchrec/distributed/planner/types.py b/torchrec/distributed/planner/types.py index 157fe48fe..fe676cf9f 100644 --- a/torchrec/distributed/planner/types.py +++ b/torchrec/distributed/planner/types.py @@ -796,6 +796,7 @@ class PlannerErrorType(Enum): PARTITION = "partition" OTHER = "other" PLANNER_INPUT_CONTEXT_MISMATCH = "planner_input_context_mismatch" + PLAN_LOADING_FAILED = "plan_loading_failed" class PlannerError(Exception): From 482ba49769a833d73c3823003cb63c5d8462603b Mon Sep 17 00:00:00 2001 From: Ali Afzal Date: Sun, 14 Sep 2025 13:33:26 -0700 Subject: [PATCH 5/5] Integrating PlanLoader within OSS Planner MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: Integrate PlanLoader functionality within the 
EmbeddingShardingPlanner to enable loading and reusing pre-computed sharding plans. This integration extends the OSS planner with plan loading capabilities. This diff includes: * PlanLoader Integration in EmbeddingShardingPlanner: - Added optional `plan_loader` parameter to EmbeddingShardingPlanner constructor - Integrated plan validation using context hash comparison to ensure loaded plans are compatible with current planner configuration - Fallback to normal planning when plan loader returns null * Plan Loading Workflow:Check if loaded plan context hash matches current planner context * If mismatch detected → raise PlannerError * If validation passes → load sharding options from storage * Map loaded sharding options to current search space using storage_hash * Skip planning phase and use pre-computed plan if available * Search Space Reconstruction: * Mapping of loaded sharding options to enumerated search space * Preserving all original ShardingOption metadata while replacing shard assignments Differential Revision: D81279558 --- torchrec/distributed/planner/planners.py | 274 ++++++++++++++++------- 1 file changed, 197 insertions(+), 77 deletions(-) diff --git a/torchrec/distributed/planner/planners.py b/torchrec/distributed/planner/planners.py index 6f0b3da2f..1214357c6 100644 --- a/torchrec/distributed/planner/planners.py +++ b/torchrec/distributed/planner/planners.py @@ -39,10 +39,12 @@ from torchrec.distributed.planner.types import ( Enumerator, hash_planner_context_inputs, + hash_planner_context_inputs_str, ParameterConstraints, Partitioner, PerfModel, PlanDebugStats, + PlanLoader, PlannerError, PlannerErrorType, Proposer, @@ -118,6 +120,50 @@ def to_sharding_plan( return ShardingPlan(plan) +def _update_search_space( + search_space: List[ShardingOption], + loaded_sharding_options: Dict[int, ShardingOption], +) -> List[ShardingOption]: + new_search_space: List[ShardingOption] = [] + for so in search_space: + loaded_so = loaded_sharding_options.get(so.storage_hash()) + if loaded_so is not None: + new_search_space.append( + ShardingOption( + name=so.name, + tensor=so.tensor, + module=so.module, + input_lengths=so.input_lengths, + batch_size=so.batch_size, + compute_kernel=so.compute_kernel, + sharding_type=so.sharding_type, + partition_by=so.partition_by, + # We only need to update the shards from the loaded plan + shards=loaded_so.shards, + cache_params=so.cache_params, + enforce_hbm=so.enforce_hbm, + stochastic_rounding=so.stochastic_rounding, + bounds_check_mode=so.bounds_check_mode, + dependency=so.dependency, + is_pooled=so.is_pooled, + feature_names=so.feature_names, + output_dtype=so.output_dtype, + key_value_params=so.key_value_params, + ) + ) + else: + logger.info( + f"Loaded sharding options from Storage, but not all search space is covered. " + f"Loaded {len(new_search_space)} out of {len(search_space)} search space." + ) + raise PlannerError( + error_type=PlannerErrorType.PLAN_LOADING_FAILED, + message="Unable to create merge loaded plan with enumerated space due to sharded option key mismatch. \n", + ) + + return new_search_space + + def _merge_plans(best_plans: List[ShardingPlan]) -> ShardingPlan: if len(best_plans) == 1: return best_plans[0] @@ -269,6 +315,22 @@ def hash_planner_context_inputs(self) -> int: self._constraints, ) + def hash_planner_context_inputs_str(self) -> str: + """ + Generates a hash for all planner inputs except for partitioner, proposer, performance model, and stats. 
+ These are all the inputs needed to verify whether a previously generated sharding plan is still valid in a new context. + + Returns: + Generates a hash capturing topology, batch size, enumerator, storage reservation, stats and constraints. + """ + return hash_planner_context_inputs_str( + self._topology, + self._batch_size, + self._enumerator, + self._storage_reservation, + self._constraints, + ) + class EmbeddingShardingPlanner(EmbeddingPlannerBase): """ @@ -315,6 +377,7 @@ def __init__( List[Callable[[List[ShardingOption]], List[ShardingOption]]] ] = None, timeout_seconds: Optional[int] = None, + plan_loader: Optional[PlanLoader] = None, ) -> None: super().__init__( topology=topology, @@ -347,6 +410,8 @@ def __init__( else NoopPerfModel(topology=self._topology) ) + self.plan_loader = plan_loader + self._num_proposals: int = 0 self._num_plans: int = 0 self._best_plan: Optional[List[ShardingOption]] = None @@ -427,86 +492,115 @@ def plan( # No shardable parameters return ShardingPlan({}) - proposal_cache: Dict[ - Tuple[int, ...], - Tuple[bool, Optional[List[ShardingOption]], Optional[float]], - ] = {} - - for proposer in self._proposers: - proposer.load(search_space=search_space, enumerator=self._enumerator) - - start = time.time() - for proposer in self._proposers: - proposal = proposer.propose() - - while proposal: - end = time.time() - elapsed = end - start - if self._timeout_seconds: - if elapsed > self._timeout_seconds: - logger.info( - f"Exceeded time limit of {self._timeout_seconds}s. Took {elapsed}s" - ) - break - proposal_key = tuple(sorted(map(hash, proposal))) - if proposal_key in proposal_cache: - partitionable, plan, perf_rating = proposal_cache[proposal_key] - proposer.feedback( - partitionable=partitionable, - plan=plan, - perf_rating=perf_rating, - storage_constraint=storage_constraint, - ) - proposal = proposer.propose() - continue - - self._num_proposals += 1 - try: - # plan is just proposal where shard.rank is populated - plan = self._partitioner.partition( - proposal=proposal, - storage_constraint=storage_constraint, - ) - self._num_plans += 1 - perf_rating = self._perf_model.rate(plan=plan) - if perf_rating < best_perf_rating: - best_perf_rating = perf_rating - best_plan = copy.deepcopy(plan) - proposal_cache[proposal_key] = (True, plan, perf_rating) - proposer.feedback( - partitionable=True, - plan=plan, - perf_rating=perf_rating, - storage_constraint=storage_constraint, - ) - except PlannerError as planner_error: - last_planner_error = planner_error - # shallow copy of the proposal - last_proposal: List[ShardingOption] = copy.copy(proposal) - current_storage = cast( - Storage, - reduce( - lambda x, y: x + y, - [ - shard.storage - for option in proposal - for shard in option.shards - ], - ), - ) - if current_storage < lowest_storage: - lowest_storage = current_storage - proposal_cache[proposal_key] = (False, proposal, None) - proposer.feedback( - partitionable=False, - plan=proposal, - storage_constraint=storage_constraint, - ) + loaded_sharding_options = None + new_search_space: List[ShardingOption] = [] + # import fbvscode + + # fbvscode.set_trace() + if self.plan_loader is not None: + # validate plan before loading + self._loader_plan_validation( + current_planner_hash=self.hash_planner_context_inputs_str(), + # pyre-fixme[16]: `Optional` has no attribute `plan_context_hash`. 
+ loaded_plan_hash=self.plan_loader.plan_context_hash(), + ) + # pyre-ignore + loaded_sharding_options = self.plan_loader.load() + if loaded_sharding_options is not None: + # Merging sharding options from loaded plan with enumerated search space + new_search_space = _update_search_space( + search_space=search_space, + loaded_sharding_options=loaded_sharding_options, + ) + + # Loaded plan is validated successfully and can be used for generate the sharding plan, skipping new plan generation. + if new_search_space: + logger.info( + # pyre-ignore + f"Loded sharding options from Storage with plan id: {self.plan_loader.get_plan_id()} skipping new plan generation" + ) + best_plan = copy.deepcopy(new_search_space) + else: + proposal_cache: Dict[ + Tuple[int, ...], + Tuple[bool, Optional[List[ShardingOption]], Optional[float]], + ] = {} + + for proposer in self._proposers: + proposer.load(search_space=search_space, enumerator=self._enumerator) - # clear shard.rank for each sharding_option - reset_shard_rank(proposal) + start = time.time() + for proposer in self._proposers: proposal = proposer.propose() + while proposal: + end = time.time() + elapsed = end - start + if self._timeout_seconds: + if elapsed > self._timeout_seconds: + logger.info( + f"Exceeded time limit of {self._timeout_seconds}s. Took {elapsed}s" + ) + break + proposal_key = tuple(sorted(map(hash, proposal))) + if proposal_key in proposal_cache: + partitionable, plan, perf_rating = proposal_cache[proposal_key] + proposer.feedback( + partitionable=partitionable, + plan=plan, + perf_rating=perf_rating, + storage_constraint=storage_constraint, + ) + proposal = proposer.propose() + continue + + self._num_proposals += 1 + try: + # plan is just proposal where shard.rank is populated + plan = self._partitioner.partition( + proposal=proposal, + storage_constraint=storage_constraint, + ) + self._num_plans += 1 + perf_rating = self._perf_model.rate(plan=plan) + if perf_rating < best_perf_rating: + best_perf_rating = perf_rating + best_plan = copy.deepcopy(plan) + proposal_cache[proposal_key] = (True, plan, perf_rating) + proposer.feedback( + partitionable=True, + plan=plan, + perf_rating=perf_rating, + storage_constraint=storage_constraint, + ) + except PlannerError as planner_error: + last_planner_error = planner_error + # shallow copy of the proposal + last_proposal: List[ShardingOption] = copy.copy(proposal) + current_storage = cast( + Storage, + reduce( + lambda x, y: x + y, + [ + shard.storage + for option in proposal + for shard in option.shards + ], + ), + ) + if current_storage < lowest_storage: + lowest_storage = current_storage + proposal_cache[proposal_key] = (False, proposal, None) + proposer.feedback( + partitionable=False, + plan=proposal, + storage_constraint=storage_constraint, + ) + + # clear shard.rank for each sharding_option + reset_shard_rank(proposal) + proposal = proposer.propose() + if best_plan: for callback in self._callbacks: best_plan = callback(best_plan) @@ -607,6 +701,32 @@ def plan( + last_planner_error_info, ) + def _loader_plan_validation( + self, current_planner_hash: str, loaded_plan_hash: Optional[str] + ) -> None: + """ + Validates that the current planner context hash matches the loaded plan context hash. 
+
+        Args:
+            current_planner_hash (str): Hash from the current planner context
+            loaded_plan_hash (Optional[str]): Hash from the loaded plan context
+
+        Raises:
+            PlannerError: If the hashes don't match
+        """
+        if loaded_plan_hash is not None and current_planner_hash != loaded_plan_hash:
+            # pyre-fixme[16]: `Optional` has no attribute `get_plan_id`.
+            plan_id = self.plan_loader.get_plan_id() if self.plan_loader else None
+            error_msg = (
+                f"Planner input context mismatch detected between plan {plan_id} and the current planner setup:"
+                f"\nCurrent planner hash: {current_planner_hash}, Loaded plan hash: {loaded_plan_hash}"
+            )
+            raise PlannerError(
+                error_type=PlannerErrorType.PLANNER_INPUT_CONTEXT_MISMATCH,
+                message="Unable to load plan because of a planner input mismatch - cannot validate that this plan is the best plan for the current context.\n"
+                + error_msg,
+            )
+
 
 class HeteroEmbeddingShardingPlanner(ShardingPlanner):
     """