279 changes: 202 additions & 77 deletions torchrec/distributed/planner/planners.py
@@ -39,9 +39,12 @@
from torchrec.distributed.planner.types import (
Enumerator,
hash_planner_context_inputs,
hash_planner_context_inputs_str,
ParameterConstraints,
Partitioner,
PerfModel,
PlanDebugStats,
PlanLoader,
PlannerError,
PlannerErrorType,
Proposer,
@@ -117,6 +120,50 @@ def to_sharding_plan(
return ShardingPlan(plan)


def _update_search_space(
search_space: List[ShardingOption],
loaded_sharding_options: Dict[int, ShardingOption],
) -> List[ShardingOption]:
new_search_space: List[ShardingOption] = []
for so in search_space:
loaded_so = loaded_sharding_options.get(so.storage_hash())
if loaded_so is not None:
new_search_space.append(
ShardingOption(
name=so.name,
tensor=so.tensor,
module=so.module,
input_lengths=so.input_lengths,
batch_size=so.batch_size,
compute_kernel=so.compute_kernel,
sharding_type=so.sharding_type,
partition_by=so.partition_by,
# We only need to update the shards from the loaded plan
shards=loaded_so.shards,
cache_params=so.cache_params,
enforce_hbm=so.enforce_hbm,
stochastic_rounding=so.stochastic_rounding,
bounds_check_mode=so.bounds_check_mode,
dependency=so.dependency,
is_pooled=so.is_pooled,
feature_names=so.feature_names,
output_dtype=so.output_dtype,
key_value_params=so.key_value_params,
)
)
else:
logger.info(
f"Loaded sharding options from storage, but they do not cover the full search space. "
f"Matched {len(new_search_space)} of {len(search_space)} sharding options."
)
raise PlannerError(
error_type=PlannerErrorType.PLAN_LOADING_FAILED,
message="Unable to merge the loaded plan with the enumerated search space due to a sharding option key mismatch.\n",
)

return new_search_space
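
A quick usage sketch (not part of this diff): _update_search_space expects the loaded options keyed by ShardingOption.storage_hash(), and it carries over only the shards from the loaded plan; every other field comes from the freshly enumerated option. The names saved_options and enumerated below are hypothetical.

loaded_by_hash: Dict[int, ShardingOption] = {
    so.storage_hash(): so for so in saved_options  # hypothetical previously saved plan
}
merged = _update_search_space(
    search_space=enumerated,  # hypothetical output of the enumerator
    loaded_sharding_options=loaded_by_hash,
)
# Raises PlannerError(PLAN_LOADING_FAILED) if any enumerated option has no
# matching storage_hash() in the loaded plan.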


def _merge_plans(best_plans: List[ShardingPlan]) -> ShardingPlan:
if len(best_plans) == 1:
return best_plans[0]
@@ -268,6 +315,22 @@ def hash_planner_context_inputs(self) -> int:
self._constraints,
)

def hash_planner_context_inputs_str(self) -> str:
"""
Generates a hash for all planner inputs except for partitioner, proposer, performance model, and stats.
These are all the inputs needed to verify whether a previously generated sharding plan is still valid in a new context.

Returns:
A string hash capturing topology, batch size, enumerator, storage reservation, and constraints.
"""
return hash_planner_context_inputs_str(
self._topology,
self._batch_size,
self._enumerator,
self._storage_reservation,
self._constraints,
)
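
A hedged usage sketch (not part of this diff): this string hash is what a saved plan's recorded context hash is compared against before reuse (see _loader_plan_validation below); planner and saved_plan_hash are hypothetical names.

current_hash = planner.hash_planner_context_inputs_str()
# Reuse a saved plan only when its recorded context hash matches the
# planner's current context (topology, batch size, enumerator,
# storage reservation, constraints).
can_reuse_saved_plan = saved_plan_hash is not None and saved_plan_hash == current_hash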


class EmbeddingShardingPlanner(EmbeddingPlannerBase):
"""
@@ -314,6 +377,7 @@ def __init__(
List[Callable[[List[ShardingOption]], List[ShardingOption]]]
] = None,
timeout_seconds: Optional[int] = None,
plan_loader: Optional[PlanLoader] = None,
) -> None:
super().__init__(
topology=topology,
@@ -346,6 +410,8 @@ def __init__(
else NoopPerfModel(topology=self._topology)
)

self.plan_loader = plan_loader

self._num_proposals: int = 0
self._num_plans: int = 0
self._best_plan: Optional[List[ShardingOption]] = None
@@ -426,86 +492,115 @@ def plan(
# No shardable parameters
return ShardingPlan({})

proposal_cache: Dict[
Tuple[int, ...],
Tuple[bool, Optional[List[ShardingOption]], Optional[float]],
] = {}

for proposer in self._proposers:
proposer.load(search_space=search_space, enumerator=self._enumerator)

start = time.time()
for proposer in self._proposers:
proposal = proposer.propose()

while proposal:
end = time.time()
elapsed = end - start
if self._timeout_seconds:
if elapsed > self._timeout_seconds:
logger.info(
f"Exceeded time limit of {self._timeout_seconds}s. Took {elapsed}s"
)
break
proposal_key = tuple(sorted(map(hash, proposal)))
if proposal_key in proposal_cache:
partitionable, plan, perf_rating = proposal_cache[proposal_key]
proposer.feedback(
partitionable=partitionable,
plan=plan,
perf_rating=perf_rating,
storage_constraint=storage_constraint,
)
proposal = proposer.propose()
continue

self._num_proposals += 1
try:
# plan is just proposal where shard.rank is populated
plan = self._partitioner.partition(
proposal=proposal,
storage_constraint=storage_constraint,
)
self._num_plans += 1
perf_rating = self._perf_model.rate(plan=plan)
if perf_rating < best_perf_rating:
best_perf_rating = perf_rating
best_plan = copy.deepcopy(plan)
proposal_cache[proposal_key] = (True, plan, perf_rating)
proposer.feedback(
partitionable=True,
plan=plan,
perf_rating=perf_rating,
storage_constraint=storage_constraint,
)
except PlannerError as planner_error:
last_planner_error = planner_error
# shallow copy of the proposal
last_proposal: List[ShardingOption] = copy.copy(proposal)
current_storage = cast(
Storage,
reduce(
lambda x, y: x + y,
[
shard.storage
for option in proposal
for shard in option.shards
],
),
)
if current_storage < lowest_storage:
lowest_storage = current_storage
proposal_cache[proposal_key] = (False, proposal, None)
proposer.feedback(
partitionable=False,
plan=proposal,
storage_constraint=storage_constraint,
)
loaded_sharding_options = None
new_search_space: List[ShardingOption] = []
if self.plan_loader is not None:
# validate plan before loading
self._loader_plan_validation(
current_planner_hash=self.hash_planner_context_inputs_str(),
# pyre-fixme[16]: `Optional` has no attribute `plan_context_hash`.
loaded_plan_hash=self.plan_loader.plan_context_hash(),
)
# pyre-ignore
loaded_sharding_options = self.plan_loader.load()
if loaded_sharding_options is not None:
# Merging sharding options from loaded plan with enumerated search space
new_search_space = _update_search_space(
search_space=search_space,
loaded_sharding_options=loaded_sharding_options,
)

# The loaded plan validated successfully and can be used to generate the sharding plan, skipping new plan generation.
if new_search_space:
logger.info(
# pyre-ignore
f"Loded sharding options from Storage with plan id: {self.plan_loader.get_plan_id()} skipping new plan generation"
)
best_plan = copy.deepcopy(new_search_space)
else:
proposal_cache: Dict[
Tuple[int, ...],
Tuple[bool, Optional[List[ShardingOption]], Optional[float]],
] = {}

for proposer in self._proposers:
proposer.load(search_space=search_space, enumerator=self._enumerator)

# clear shard.rank for each sharding_option
reset_shard_rank(proposal)
start = time.time()
for proposer in self._proposers:
proposal = proposer.propose()

while proposal:
end = time.time()
elapsed = end - start
if self._timeout_seconds:
if elapsed > self._timeout_seconds:
logger.info(
f"Exceeded time limit of {self._timeout_seconds}s. Took {elapsed}s"
)
break
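# Proposals are deduplicated by an order-insensitive key so that a
# repeated proposal is answered from the cache below instead of being
# partitioned and rated again.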
proposal_key = tuple(sorted(map(hash, proposal)))
if proposal_key in proposal_cache:
partitionable, plan, perf_rating = proposal_cache[proposal_key]
proposer.feedback(
partitionable=partitionable,
plan=plan,
perf_rating=perf_rating,
storage_constraint=storage_constraint,
)
proposal = proposer.propose()
continue

self._num_proposals += 1
try:
# plan is just proposal where shard.rank is populated
plan = self._partitioner.partition(
proposal=proposal,
storage_constraint=storage_constraint,
)
self._num_plans += 1
perf_rating = self._perf_model.rate(plan=plan)
if perf_rating < best_perf_rating:
best_perf_rating = perf_rating
best_plan = copy.deepcopy(plan)
proposal_cache[proposal_key] = (True, plan, perf_rating)
proposer.feedback(
partitionable=True,
plan=plan,
perf_rating=perf_rating,
storage_constraint=storage_constraint,
)
except PlannerError as planner_error:
last_planner_error = planner_error
# shallow copy of the proposal
last_proposal: List[ShardingOption] = copy.copy(proposal)
current_storage = cast(
Storage,
reduce(
lambda x, y: x + y,
[
shard.storage
for option in proposal
for shard in option.shards
],
),
)
if current_storage < lowest_storage:
lowest_storage = current_storage
proposal_cache[proposal_key] = (False, proposal, None)
proposer.feedback(
partitionable=False,
plan=proposal,
storage_constraint=storage_constraint,
)

# clear shard.rank for each sharding_option
reset_shard_rank(proposal)
proposal = proposer.propose()

if best_plan:
for callback in self._callbacks:
best_plan = callback(best_plan)
@@ -528,6 +623,10 @@ def plan(
enumerator=self._enumerator,
sharders=sharders,
debug=self._debug,
debug_stats=PlanDebugStats(
planner_type=self.__class__.__name__,
timeout_seconds=self._timeout_seconds,
),
)
return sharding_plan
else:
@@ -602,6 +701,32 @@ def plan(
+ last_planner_error_info,
)

def _loader_plan_validation(
self, current_planner_hash: str, loaded_plan_hash: Optional[str]
) -> None:
"""
Validates that the current planner context hash matches the loaded plan context hash.

Args:
current_planner_hash (str): Hash from current planner context
loaded_plan_hash (Optional[str]): Hash from loaded plan context

Raises:
PlannerError: If hashes don't match
"""
if loaded_plan_hash is not None and current_planner_hash != loaded_plan_hash:
# pyre-fixme[16]: `Optional` has no attribute `get_plan_id`.
plan_id = self.plan_loader.get_plan_id() if self.plan_loader else None
error_msg = (
f"Planner input context mismatch detected for {plan_id} and current planner set up:"
f"\nCurrent planner hash: {current_planner_hash}, Loaded plan hash: {loaded_plan_hash}"
)
raise PlannerError(
error_type=PlannerErrorType.PLANNER_INPUT_CONTEXT_MISMATCH,
message="Unable to load, because of planner input mismatch - cannot validate this plan is the best plan for current context.. \n"
+ error_msg,
)
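
For context, a minimal in-memory PlanLoader sketch, assuming only the three methods this planner calls (load, plan_context_hash, get_plan_id); the actual interface in torchrec.distributed.planner.types may differ.

class InMemoryPlanLoader(PlanLoader):
    """Illustrative loader that serves a previously saved plan from memory."""

    def __init__(
        self,
        options: Dict[int, ShardingOption],  # keyed by ShardingOption.storage_hash()
        context_hash: str,  # hash_planner_context_inputs_str() value at save time
        plan_id: str,
    ) -> None:
        self._options = options
        self._context_hash = context_hash
        self._plan_id = plan_id

    def load(self) -> Optional[Dict[int, ShardingOption]]:
        return self._options

    def plan_context_hash(self) -> Optional[str]:
        return self._context_hash

    def get_plan_id(self) -> str:
        return self._plan_id

Wired in via EmbeddingShardingPlanner(..., plan_loader=InMemoryPlanLoader(...)); when the context hashes match, plan() skips proposal search entirely.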


class HeteroEmbeddingShardingPlanner(ShardingPlanner):
"""
3 changes: 3 additions & 0 deletions torchrec/distributed/planner/stats.py
@@ -40,6 +40,7 @@
Enumerator,
ParameterConstraints,
Perf,
PlanDebugStats,
ShardingOption,
Stats,
Storage,
@@ -160,6 +161,7 @@ def log(
sharders: Optional[List[ModuleSharder[nn.Module]]] = None,
enumerator: Optional[Enumerator] = None,
debug: bool = True,
debug_stats: Optional[PlanDebugStats] = None,
) -> None:
"""
Logs stats for a given sharding plan.
@@ -1138,5 +1140,6 @@ def log(
sharders: Optional[List[ModuleSharder[nn.Module]]] = None,
enumerator: Optional[Enumerator] = None,
debug: bool = True,
debug_stats: Optional[PlanDebugStats] = None,
) -> None:
pass
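
For context, a small sketch of the new debug_stats payload that both log overrides now accept; the two fields shown are the ones this PR populates, and the values are illustrative.

from torchrec.distributed.planner.types import PlanDebugStats

debug_stats = PlanDebugStats(
    planner_type="EmbeddingShardingPlanner",  # the planner passes self.__class__.__name__
    timeout_seconds=None,  # the planner forwards its configured timeout, if any
)
# Forwarded by the planner as: stats.log(..., debug=True, debug_stats=debug_stats)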