Reshard API Performance Benchmarking #3218

Open · wants to merge 2 commits into main
28 changes: 26 additions & 2 deletions torchrec/distributed/benchmark/benchmark_train.py
@@ -28,8 +28,9 @@
    write_report,
)
from torchrec.distributed.embedding_types import EmbeddingComputeKernel, ShardingType
from torchrec.distributed.sharding.dynamic_sharding import output_sharding_plan_delta
from torchrec.distributed.test_utils.test_model import TestEBCSharder
from torchrec.distributed.types import DataType, EmbeddingModuleShardingPlan
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

@@ -53,8 +54,27 @@ def training_func_to_benchmark(
    model: torch.nn.Module,
    bench_inputs: List[KeyedJaggedTensor],
    optimizer: Optional[torch.optim.Optimizer],
    resharding_plan_diffs: Optional[List[EmbeddingModuleShardingPlan]] = None,
) -> None:
    reshard_idx = 0

    for i, bench_input in enumerate(bench_inputs):
        if resharding_plan_diffs:
            # Spread reshards evenly across the inputs; use integer division so
            # the modulo check is exact, and stop once all plans are consumed.
            interval = max(1, len(bench_inputs) // len(resharding_plan_diffs))
            if (
                i > 0
                and i % interval == 0
                and reshard_idx < len(resharding_plan_diffs)
            ):
                plan_difference = output_sharding_plan_delta(
                    # Pyre-ignore
                    model.plan.plan["_module"],
                    resharding_plan_diffs[reshard_idx],
                )
                # Pyre-ignore
                model.reshard("_module", plan_difference)
                reshard_idx += 1
        pooled_embeddings = model(bench_input)
        vals = []
        for _name, param in pooled_embeddings.to_dict().items():
@@ -120,6 +140,7 @@ def benchmark_ebc(


def main() -> None:
    # torch.cuda.cudart().cudaProfilerStart()
    args: argparse.Namespace = init_argparse_and_args()

    num_requests = args.bench_iters * args.batch_size * args.num_benchmarks
@@ -203,6 +224,8 @@ def main() -> None:
    for i, write_report_func in enumerate(write_report_funcs_per_module):
        write_report_func(benchmark_results_per_module[i])

    # torch.cuda.cudart().cudaProfilerStop()


def invoke_main() -> None:
    logging.basicConfig()
@@ -212,4 +235,5 @@ def invoke_main() -> None:


if __name__ == "__main__":

    invoke_main()  # pragma: no cover
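A note on the cadence: training_func_to_benchmark spreads the supplied plan deltas evenly across the benchmark inputs. A minimal sketch of the arithmetic, assuming 1000 inputs and 4 plan diffs (illustrative numbers, not taken from this PR):

bench_len, num_plans = 1000, 4
interval = max(1, bench_len // num_plans)  # 250
reshard_points = [i for i in range(bench_len) if i > 0 and i % interval == 0]
print(reshard_points)  # [250, 500, 750] -- the fourth plan slot is never applied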
113 changes: 111 additions & 2 deletions torchrec/distributed/benchmark/benchmark_utils.py
@@ -18,6 +18,7 @@
import json
import logging
import os
import random
import resource
import time
import timeit
@@ -55,10 +56,21 @@
    EmbeddingStorageEstimator,
)
from torchrec.distributed.shard import _shard_modules

from torchrec.distributed.sharding.dynamic_sharding import output_sharding_plan_delta
from torchrec.distributed.sharding_plan import (
    construct_module_sharding_plan,
    get_sharding_constructor_from_type,
)
from torchrec.distributed.test_utils.multi_process import MultiProcessContext
from torchrec.distributed.test_utils.test_model import ModelInput

from torchrec.distributed.types import (
    DataType,
    EmbeddingModuleShardingPlan,
    ModuleSharder,
    ShardingEnv,
)
from torchrec.fx import symbolic_trace
from torchrec.modules.embedding_configs import EmbeddingBagConfig, EmbeddingConfig
from torchrec.quant.embedding_modules import (
@@ -359,8 +371,32 @@ def forward(self, input: KeyedJaggedTensor) -> KeyedTensor:
T = TypeVar("T", bound=torch.nn.Module)


def _generate_rank_placements(
    world_size: int,
    num_tables: int,
    ranks_per_tables: List[int],
    random_seed: Optional[int] = None,
) -> List[List[int]]:
    # Cannot reuse the hypothesis-based old/new rank generation here because of
    # the dependency on world_size.
    if random_seed is None:
        # Generate a random seed so rank placements can differ across calls.
        random_seed = random.randint(0, 10000)
    placements = []
    max_rank = world_size - 1
    random.seed(random_seed)
    if ranks_per_tables == [0]:
        ranks_per_tables = [random.randint(1, max_rank) for _ in range(num_tables)]
    for i in range(num_tables):
        ranks_per_table = ranks_per_tables[i]
        placement = sorted(random.sample(range(world_size), ranks_per_table))
        placements.append(placement)
    return placements
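For reference, a hedged usage sketch of _generate_rank_placements (world size, table count, and seed are illustrative; passing ranks_per_tables=[0] instead makes the helper pick shard counts at random):

# Place two tables on a 4-rank world, one rank each (table-wise style).
placements = _generate_rank_placements(
    world_size=4,
    num_tables=2,
    ranks_per_tables=[1, 1],
    random_seed=42,
)
print(placements)  # e.g. [[2], [0]] -- one sorted rank list per table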


def default_func_to_benchmark(
    model: torch.nn.Module,
    bench_inputs: List[KeyedJaggedTensor],
    resharding_plan_diffs: Optional[List[EmbeddingModuleShardingPlan]] = None,
) -> None:
    with torch.inference_mode():
        for bench_input in bench_inputs:
@@ -679,6 +715,8 @@ def init_argparse_and_args() -> argparse.Namespace:
parser.add_argument("--num_benchmarks", type=int, default=5)
parser.add_argument("--embedding_config_json", type=str, default="")
parser.add_argument("--device_type", type=str, default="cuda")
parser.add_argument("--enable_resharding", type=bool, default=False)
parser.add_argument("--resharding_interval", type=int, default=1000)

args = parser.parse_args()
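With the flags above, a resharding benchmark run might be launched as follows (module path assumed from the file paths in this diff):

python -m torchrec.distributed.benchmark.benchmark_train --enable_resharding --resharding_interval 1000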

@@ -728,6 +766,7 @@ def fx_script_module(eager_module: torch.nn.Module) -> torch.nn.Module:
    # Don't want to modify the module outright.
    # Since the module is on CPU, this won't cause a CUDA OOM.
    copied_module = copy.deepcopy(module)

    # pyre-ignore [6]
    plan = planner.plan(copied_module, [sharder])

@@ -780,6 +819,7 @@ def _run_benchmark_core(
    pre_gpu_load: int = 0,
    export_stacks: bool = False,
    reset_accumulated_memory_stats: bool = False,
    new_ranks_per_plan: Optional[List[int]] = None,
) -> BenchmarkResult:
    """Internal helper that contains the core benchmarking logic shared by
    ``benchmark`` and ``benchmark_func``. All heavy lifting (timing, memory
@@ -1071,6 +1111,7 @@ def init_module_and_run_benchmark(
    queue: Optional[mp.Queue] = None,
    pooling_configs: Optional[List[int]] = None,
    benchmark_unsharded_module: bool = False,
    new_ranks_per_plan: Optional[List[List[int]]] = None,
) -> BenchmarkResult:
    """
    There are a couple of caveats here as to why the module has to be initialized
@@ -1117,6 +1158,37 @@
        if rank != -1
        else contextlib.nullcontext()
    ) as ctx:

        resharding_plans = []

        if new_ranks_per_plan is not None and len(new_ranks_per_plan) > 0:
            sharding_type_constructor = get_sharding_constructor_from_type(
                sharding_type
            )
            for new_ranks_per_table in new_ranks_per_plan:
                new_per_param_sharding = {}
                for table_id, table in enumerate(tables):
                    if sharding_type == ShardingType.TABLE_WISE:
                        new_per_param_sharding[table.name] = sharding_type_constructor(
                            rank=new_ranks_per_table[table_id][0],
                            compute_kernel=sharder._kernel_type,
                        )
                    elif sharding_type == ShardingType.COLUMN_WISE:
                        new_per_param_sharding[table.name] = sharding_type_constructor(
                            ranks=new_ranks_per_table[table_id]
                        )

                new_module_sharding_plan = construct_module_sharding_plan(
                    module=module._module,  # Pyre-ignore
                    # Pyre-ignore
                    sharder=sharder,
                    per_param_sharding=new_per_param_sharding,
                    local_size=world_size,
                    world_size=world_size,
                    device_type=device.type,
                )
                resharding_plans.append(new_module_sharding_plan)

        module = transform_module(
            module=module,
            device=device,
@@ -1136,6 +1208,10 @@
    else:
        name = benchmark_type_name(compile_mode, sharding_type)

    if benchmark_func_kwargs is None:
        benchmark_func_kwargs = {}
    benchmark_func_kwargs["resharding_plan_diffs"] = resharding_plans
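    # Note: benchmark() forwards benchmark_func_kwargs to func_to_benchmark, so
    # the plans built above reach training_func_to_benchmark as its
    # resharding_plan_diffs argument (assumed forwarding pattern).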

    res = benchmark(
        name,
        module,
@@ -1244,6 +1320,8 @@ def benchmark_module(
    pooling_configs: Optional[List[int]] = None,
    variable_batch_embeddings: bool = False,
    device_type: str = "cuda",
    enable_resharding: bool = False,
    resharding_interval: int = 1000,
) -> List[BenchmarkResult]:
    """
    Args:
@@ -1325,6 +1403,36 @@
)

    if train:

        new_ranks_per_plan = []

        if enable_resharding:
            total_plans_per_benchmark = bench_iters // resharding_interval
            total_plans_per_benchmark = max(1, total_plans_per_benchmark)

            num_tables = len(tables)
            ranks_per_tables = []

            if sharding_type == ShardingType.TABLE_WISE:
                ranks_per_tables = [1 for _ in range(num_tables)]

            elif sharding_type == ShardingType.COLUMN_WISE:
                valid_candidates = [
                    i for i in range(1, world_size + 1) if EMBEDDING_DIM % i == 0
                ]
                ranks_per_tables = [
                    random.choice(valid_candidates) for _ in range(num_tables)
                ]

            new_ranks_per_plan = [
                _generate_rank_placements(world_size, num_tables, ranks_per_tables)
                for _ in range(total_plans_per_benchmark)
            ]
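        # Example: with EMBEDDING_DIM = 128 and world_size = 8 (illustrative
        # values), valid_candidates above is [1, 2, 4, 8]: column-wise plans
        # only use shard counts that divide the embedding dim evenly.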

        res = multi_process_benchmark(
            # pyre-ignore[6]
            callable=init_module_and_run_benchmark,
@@ -1344,6 +1452,7 @@
            func_to_benchmark=func_to_benchmark,
            benchmark_func_kwargs=benchmark_func_kwargs,
            pooling_configs=pooling_configs,
            new_ranks_per_plan=new_ranks_per_plan,
        )
    else:
        res = init_module_and_run_benchmark(
5 changes: 4 additions & 1 deletion torchrec/distributed/model_parallel.py
@@ -742,7 +742,10 @@ def reshard(

        # Need to use .module to maintain FQN consistency
        self._optim: CombinedOptimizer = self._init_optim(
            # pyre-ignore
            self._dmp_wrapped_module.module
            if hasattr(self._dmp_wrapped_module, "module")
            else self._dmp_wrapped_module._module
        )
        self._plan.plan[sharded_module_fqn] = sharded_module.module_sharding_plan
        return sharded_module
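For context, a hedged sketch of driving the reshard API end to end, mirroring the benchmark code above (the "_module" FQN and helper names are taken from this diff; a sketch, not an authoritative recipe):

# Compute the delta between the current plan and a target plan, then apply it
# in place through DistributedModelParallel.reshard.
delta = output_sharding_plan_delta(
    model.plan.plan["_module"],  # current EmbeddingModuleShardingPlan
    new_module_sharding_plan,  # target, e.g. from construct_module_sharding_plan
)
model.reshard("_module", delta)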