Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion torchrec/distributed/benchmark/benchmark_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import torch

from torchrec.distributed.benchmark.benchmark_base import (
from torchrec.distributed.benchmark.base import (
BenchmarkResult,
CompileMode,
DLRM_NUM_EMBEDDINGS_PER_FEATURE,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

import click
import torch
from torchrec.distributed.benchmark.benchmark_base import benchmark_func
from torchrec.distributed.benchmark.base import benchmark_func
from torchrec.distributed.embedding import EmbeddingCollectionContext
from torchrec.distributed.embedding_sharding import _set_sharding_context_post_a2a
from torchrec.distributed.sharding.sequence_sharding import SequenceShardingContext
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
ComputeDevice,
SplitTableBatchedEmbeddingBagsCodegen,
)
from torchrec.distributed.benchmark.benchmark_base import benchmark_func
from torchrec.distributed.benchmark.base import benchmark_func
from torchrec.distributed.test_utils.test_model import ModelInput
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
Expand Down
2 changes: 1 addition & 1 deletion torchrec/distributed/benchmark/benchmark_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from typing import List, Optional, Tuple

import torch
from torchrec.distributed.benchmark.benchmark_base import (
from torchrec.distributed.benchmark.base import (
BenchmarkResult,
CompileMode,
init_argparse_and_args,
Expand Down
40 changes: 5 additions & 35 deletions torchrec/distributed/benchmark/benchmark_train_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
OSS (external):
python -m torchrec.distributed.benchmark.benchmark_train_pipeline --world_size=4 --pipeline=sparse --batch_size=10

Adding New Model Support:
To support a new model in pipeline benchmark:
See benchmark_pipeline_utils.py for step-by-step instructions.
"""

Expand All @@ -26,7 +26,7 @@
import torch
from fbgemm_gpu.split_embedding_configs import EmbOptimType
from torch import nn
from torchrec.distributed.benchmark.benchmark_base import (
from torchrec.distributed.benchmark.base import (
benchmark_func,
BenchmarkResult,
cmd_conf,
Expand All @@ -37,7 +37,6 @@
BaseModelConfig,
create_model_config,
generate_data,
generate_pipeline,
generate_planner,
generate_sharded_model_and_optimizer,
)
Expand All @@ -49,9 +48,10 @@
MultiProcessContext,
run_multi_process_func,
)
from torchrec.distributed.test_utils.table_config import EmbeddingTablesConfig
from torchrec.distributed.test_utils.test_input import ModelInput
from torchrec.distributed.test_utils.test_model import TestOverArchLarge
from torchrec.distributed.test_utils.test_tables import EmbeddingTablesConfig
from torchrec.distributed.test_utils.train_pipeline import PipelineConfig
from torchrec.distributed.train_pipeline import TrainPipeline
from torchrec.distributed.types import ShardingType
from torchrec.modules.embedding_configs import EmbeddingBagConfig
Expand Down Expand Up @@ -116,33 +116,6 @@ class RunOptions:
export_stacks: bool = False


@dataclass
class PipelineConfig:
    """
    Settings that select and parameterize a training pipeline.

    Attributes:
        pipeline (str): Name of the training pipeline to run. Recognized
            names are:
                - "base": basic training pipeline
                - "sparse": pipeline optimized for sparse operations
                - "fused": pipeline with fused sparse distribution
                - "semi": semi-synchronous training pipeline
                - "prefetch": pipeline with prefetching for sparse distribution
            Defaults to "base".
        emb_lookup_stream (str): Stream used for embedding lookups. Only
            consulted by some pipeline types (e.g. "fused").
            Defaults to "data_dist".
        apply_jit (bool): Whether JIT (Just-In-Time) compilation is applied
            to the model. Defaults to False.
    """

    # Field order is part of the positional constructor signature; keep it.
    pipeline: str = "base"
    emb_lookup_stream: str = "data_dist"
    apply_jit: bool = False


@dataclass
class ModelSelectionConfig:
model_name: str = "test_sparse_nn"
Expand Down Expand Up @@ -279,13 +252,10 @@ def _func_to_benchmark(
except StopIteration:
break

pipeline = generate_pipeline(
pipeline_type=pipeline_config.pipeline,
emb_lookup_stream=pipeline_config.emb_lookup_stream,
pipeline = pipeline_config.generate_pipeline(
model=sharded_model,
opt=optimizer,
device=ctx.device,
apply_jit=pipeline_config.apply_jit,
)
pipeline.progress(iter(bench_inputs))

Expand Down
85 changes: 1 addition & 84 deletions torchrec/distributed/benchmark/benchmark_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"""
Utilities for benchmarking training pipelines with different model configurations.

Adding New Model Support:
To support a new model in pipeline benchmark:
1. Create config class inheriting from BaseModelConfig with generate_model() method
2. Add the model to model_configs dict in create_model_config()
3. Add model-specific params to ModelSelectionConfig and create_model_config's arguments in benchmark_train_pipeline.py
Expand Down Expand Up @@ -39,15 +39,6 @@
TestTowerCollectionSparseNN,
TestTowerSparseNN,
)
from torchrec.distributed.train_pipeline import (
TrainPipelineBase,
TrainPipelineFusedSparseDist,
TrainPipelineSparseDist,
)
from torchrec.distributed.train_pipeline.train_pipelines import (
PrefetchTrainPipelineSparseDist,
TrainPipelineSemiSync,
)
from torchrec.distributed.types import ModuleSharder, ShardingEnv, ShardingType
from torchrec.models.deepfm import SimpleDeepFMNNWrapper
from torchrec.models.dlrm import DLRMWrapper
Expand Down Expand Up @@ -249,80 +240,6 @@ def create_model_config(model_name: str, **kwargs) -> BaseModelConfig:
return model_class(**filtered_kwargs)


def generate_pipeline(
    pipeline_type: str,
    emb_lookup_stream: str,
    model: nn.Module,
    opt: torch.optim.Optimizer,
    device: torch.device,
    apply_jit: bool = False,
) -> Union[TrainPipelineBase, TrainPipelineSparseDist]:
    """
    Generate a training pipeline instance based on the configuration.

    This function creates and returns the training pipeline object that
    corresponds to ``pipeline_type``.

    Args:
        pipeline_type (str): The type of training pipeline to use. Options include:
            - "base": Basic training pipeline
            - "sparse": Pipeline optimized for sparse operations
            - "fused": Pipeline with fused sparse distribution
            - "semi": Semi-synchronous training pipeline
            - "prefetch": Pipeline with prefetching for sparse distribution
        emb_lookup_stream (str): The stream to use for embedding lookups.
            Only forwarded to the "fused" pipeline.
        model (nn.Module): The model to be trained.
        opt (torch.optim.Optimizer): The optimizer to use for training.
        device (torch.device): The device to run the training on.
        apply_jit (bool): Whether to apply JIT (Just-In-Time) compilation to the model.
            Must be False for the "base" pipeline. Default is False.

    Returns:
        Union[TrainPipelineBase, TrainPipelineSparseDist]: An instance of the
            training pipeline class registered for ``pipeline_type``.

    Raises:
        RuntimeError: If an unknown pipeline type is specified.
        AssertionError: If ``apply_jit`` is not False for the "base" pipeline.
    """

    _pipeline_cls: Dict[
        str, Type[Union[TrainPipelineBase, TrainPipelineSparseDist]]
    ] = {
        "base": TrainPipelineBase,
        "sparse": TrainPipelineSparseDist,
        "fused": TrainPipelineFusedSparseDist,
        "semi": TrainPipelineSemiSync,
        "prefetch": PrefetchTrainPipelineSparseDist,
    }

    # Fail with the documented RuntimeError rather than a bare KeyError from
    # the registry lookup at the end of the function.
    if pipeline_type not in _pipeline_cls:
        raise RuntimeError(
            f"Unknown pipeline type {pipeline_type!r}; "
            f"expected one of {sorted(_pipeline_cls)}"
        )

    if pipeline_type == "semi":
        # The semi-sync pipeline is constructed with an explicit start_batch.
        return TrainPipelineSemiSync(
            model=model,
            optimizer=opt,
            device=device,
            start_batch=0,
            apply_jit=apply_jit,
        )
    elif pipeline_type == "fused":
        # Only the fused pipeline receives emb_lookup_stream.
        return TrainPipelineFusedSparseDist(
            model=model,
            optimizer=opt,
            device=device,
            emb_lookup_stream=emb_lookup_stream,
            apply_jit=apply_jit,
        )
    elif pipeline_type == "base":
        # Raised explicitly (same exception type as the former assert) so the
        # check is not stripped when Python runs with -O.
        if apply_jit is not False:
            raise AssertionError("JIT is not supported for base pipeline")

        return TrainPipelineBase(model=model, optimizer=opt, device=device)
    else:
        Pipeline = _pipeline_cls[pipeline_type]
        # pyre-ignore[28]
        return Pipeline(model=model, optimizer=opt, device=device, apply_jit=apply_jit)


def generate_data(
tables: List[EmbeddingBagConfig],
weighted_tables: List[EmbeddingBagConfig],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,7 @@
from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor

# Import the shared types and utilities from the sibling `base` module
from .benchmark_base import (
benchmark,
BenchmarkResult,
CompileMode,
multi_process_benchmark,
)
from .base import benchmark, BenchmarkResult, CompileMode, multi_process_benchmark

logger: logging.Logger = logging.getLogger()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
# pyre-strict

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List

from torchrec.modules.embedding_configs import EmbeddingBagConfig

Expand Down
124 changes: 124 additions & 0 deletions torchrec/distributed/test_utils/train_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from dataclasses import dataclass
from typing import Dict, Type, Union

import torch
from torch import nn

from torchrec.distributed.train_pipeline import (
TrainPipelineBase,
TrainPipelineFusedSparseDist,
TrainPipelineSparseDist,
)
from torchrec.distributed.train_pipeline.train_pipelines import (
PrefetchTrainPipelineSparseDist,
TrainPipelineSemiSync,
)


@dataclass
class PipelineConfig:
    """
    Configuration for training pipelines.

    This class defines the parameters for configuring the training pipeline
    and builds the matching pipeline object via ``generate_pipeline``.

    Args:
        pipeline (str): The type of training pipeline to use. Options include:
            - "base": Basic training pipeline
            - "sparse": Pipeline optimized for sparse operations
            - "fused": Pipeline with fused sparse distribution
            - "semi": Semi-synchronous training pipeline
            - "prefetch": Pipeline with prefetching for sparse distribution
            Default is "base".
        emb_lookup_stream (str): The stream to use for embedding lookups.
            Only used by certain pipeline types (e.g., "fused").
            Default is "data_dist".
        apply_jit (bool): Whether to apply JIT (Just-In-Time) compilation to the model.
            Default is False.
    """

    pipeline: str = "base"
    emb_lookup_stream: str = "data_dist"
    apply_jit: bool = False

    def generate_pipeline(
        self,
        model: nn.Module,
        opt: torch.optim.Optimizer,
        device: torch.device,
    ) -> Union[TrainPipelineBase, TrainPipelineSparseDist]:
        """
        Generate a training pipeline instance based on this configuration.

        The pipeline type, embedding lookup stream and JIT flag are read from
        ``self.pipeline``, ``self.emb_lookup_stream`` and ``self.apply_jit``;
        only the objects to train with are passed as arguments.

        Args:
            model (nn.Module): The model to be trained.
            opt (torch.optim.Optimizer): The optimizer to use for training.
            device (torch.device): The device to run the training on.

        Returns:
            Union[TrainPipelineBase, TrainPipelineSparseDist]: An instance of the
                training pipeline class registered for ``self.pipeline``.

        Raises:
            RuntimeError: If ``self.pipeline`` is not a known pipeline type.
            AssertionError: If ``self.apply_jit`` is not False for the "base"
                pipeline.
        """

        _pipeline_cls: Dict[
            str, Type[Union[TrainPipelineBase, TrainPipelineSparseDist]]
        ] = {
            "base": TrainPipelineBase,
            "sparse": TrainPipelineSparseDist,
            "fused": TrainPipelineFusedSparseDist,
            "semi": TrainPipelineSemiSync,
            "prefetch": PrefetchTrainPipelineSparseDist,
        }

        # Fail with the documented RuntimeError rather than a bare KeyError
        # from the registry lookup at the end of the method.
        if self.pipeline not in _pipeline_cls:
            raise RuntimeError(
                f"Unknown pipeline type {self.pipeline!r}; "
                f"expected one of {sorted(_pipeline_cls)}"
            )

        if self.pipeline == "semi":
            # The semi-sync pipeline is constructed with an explicit start_batch.
            return TrainPipelineSemiSync(
                model=model,
                optimizer=opt,
                device=device,
                start_batch=0,
                apply_jit=self.apply_jit,
            )
        elif self.pipeline == "fused":
            # Only the fused pipeline receives emb_lookup_stream.
            return TrainPipelineFusedSparseDist(
                model=model,
                optimizer=opt,
                device=device,
                emb_lookup_stream=self.emb_lookup_stream,
                apply_jit=self.apply_jit,
            )
        elif self.pipeline == "base":
            # Raised explicitly (same exception type as the former assert) so
            # the check is not stripped when Python runs with -O.
            if self.apply_jit is not False:
                raise AssertionError("JIT is not supported for base pipeline")

            return TrainPipelineBase(model=model, optimizer=opt, device=device)
        else:
            Pipeline = _pipeline_cls[self.pipeline]
            # pyre-ignore[28]
            return Pipeline(
                model=model, optimizer=opt, device=device, apply_jit=self.apply_jit
            )
2 changes: 1 addition & 1 deletion torchrec/sparse/tests/jagged_tensor_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import click

import torch
from torchrec.distributed.benchmark.benchmark_base import (
from torchrec.distributed.benchmark.base import (
benchmark,
BenchmarkResult,
CPUMemoryStats,
Expand Down
2 changes: 1 addition & 1 deletion torchrec/sparse/tests/keyed_jagged_tensor_benchmark_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
# Otherwise will get error
# NotImplementedError: fbgemm::permute_1D_sparse_data: We could not find the abstract impl for this operator.
from fbgemm_gpu import sparse_ops # noqa: F401, E402
from torchrec.distributed.benchmark.benchmark_base import (
from torchrec.distributed.benchmark.base import (
BenchmarkResult,
CPUMemoryStats,
GPUMemoryStats,
Expand Down
Loading