Skip to content

Commit 6ece4ea

Browse files
TroyGarden authored and meta-codesync[bot] committed
refactor train pipeline config
Summary:
# context
* continue refactoring the pipeline benchmark
* move the pipeline generation config and function to a new file "test_utils/test_pipeline.py", so they can be used by other test cases
* the design logic is to use a config to generate each component (tables, pipeline, input, etc.) under test_utils
* rename benchmark_base.py to base.py per the suggestion from the previous diff

Differential Revision: D83890174
1 parent a48cef7 commit 6ece4ea

11 files changed

+136
-130
lines changed
File renamed without changes.

torchrec/distributed/benchmark/benchmark_inference.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
import torch
2020

21-
from torchrec.distributed.benchmark.benchmark_base import (
21+
from torchrec.distributed.benchmark.base import (
2222
BenchmarkResult,
2323
CompileMode,
2424
DLRM_NUM_EMBEDDINGS_PER_FEATURE,

torchrec/distributed/benchmark/benchmark_set_sharding_context_post_a2a.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
import click
1515
import torch
16-
from torchrec.distributed.benchmark.benchmark_base import benchmark_func
16+
from torchrec.distributed.benchmark.base import benchmark_func
1717
from torchrec.distributed.embedding import EmbeddingCollectionContext
1818
from torchrec.distributed.embedding_sharding import _set_sharding_context_post_a2a
1919
from torchrec.distributed.sharding.sequence_sharding import SequenceShardingContext

torchrec/distributed/benchmark/benchmark_split_table_batched_embeddings.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
ComputeDevice,
2727
SplitTableBatchedEmbeddingBagsCodegen,
2828
)
29-
from torchrec.distributed.benchmark.benchmark_base import benchmark_func
29+
from torchrec.distributed.benchmark.base import benchmark_func
3030
from torchrec.distributed.test_utils.test_model import ModelInput
3131
from torchrec.modules.embedding_configs import EmbeddingBagConfig
3232
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

torchrec/distributed/benchmark/benchmark_train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from typing import List, Optional, Tuple
1818

1919
import torch
20-
from torchrec.distributed.benchmark.benchmark_base import (
20+
from torchrec.distributed.benchmark.base import (
2121
BenchmarkResult,
2222
CompileMode,
2323
init_argparse_and_args,

torchrec/distributed/benchmark/benchmark_train_pipeline.py

Lines changed: 4 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
OSS (external):
1717
python -m torchrec.distributed.benchmark.benchmark_train_pipeline --world_size=4 --pipeline=sparse --batch_size=10
1818
19-
Adding New Model Support:
19+
To support a new model in pipeline benchmark:
2020
See benchmark_pipeline_utils.py for step-by-step instructions.
2121
"""
2222

@@ -26,7 +26,7 @@
2626
import torch
2727
from fbgemm_gpu.split_embedding_configs import EmbOptimType
2828
from torch import nn
29-
from torchrec.distributed.benchmark.benchmark_base import (
29+
from torchrec.distributed.benchmark.base import (
3030
benchmark_func,
3131
BenchmarkResult,
3232
cmd_conf,
@@ -37,7 +37,6 @@
3737
BaseModelConfig,
3838
create_model_config,
3939
generate_data,
40-
generate_pipeline,
4140
generate_planner,
4241
generate_sharded_model_and_optimizer,
4342
)
@@ -51,6 +50,7 @@
5150
)
5251
from torchrec.distributed.test_utils.test_input import ModelInput
5352
from torchrec.distributed.test_utils.test_model import TestOverArchLarge
53+
from torchrec.distributed.test_utils.test_pipeline import PipelineConfig
5454
from torchrec.distributed.test_utils.test_tables import EmbeddingTablesConfig
5555
from torchrec.distributed.train_pipeline import TrainPipeline
5656
from torchrec.distributed.types import ShardingType
@@ -116,33 +116,6 @@ class RunOptions:
116116
export_stacks: bool = False
117117

118118

119-
@dataclass
120-
class PipelineConfig:
121-
"""
122-
Configuration for training pipelines.
123-
124-
This class defines the parameters for configuring the training pipeline.
125-
126-
Args:
127-
pipeline (str): The type of training pipeline to use. Options include:
128-
- "base": Basic training pipeline
129-
- "sparse": Pipeline optimized for sparse operations
130-
- "fused": Pipeline with fused sparse distribution
131-
- "semi": Semi-synchronous training pipeline
132-
- "prefetch": Pipeline with prefetching for sparse distribution
133-
Default is "base".
134-
emb_lookup_stream (str): The stream to use for embedding lookups.
135-
Only used by certain pipeline types (e.g., "fused").
136-
Default is "data_dist".
137-
apply_jit (bool): Whether to apply JIT (Just-In-Time) compilation to the model.
138-
Default is False.
139-
"""
140-
141-
pipeline: str = "base"
142-
emb_lookup_stream: str = "data_dist"
143-
apply_jit: bool = False
144-
145-
146119
@dataclass
147120
class ModelSelectionConfig:
148121
model_name: str = "test_sparse_nn"
@@ -279,13 +252,10 @@ def _func_to_benchmark(
279252
except StopIteration:
280253
break
281254

282-
pipeline = generate_pipeline(
283-
pipeline_type=pipeline_config.pipeline,
284-
emb_lookup_stream=pipeline_config.emb_lookup_stream,
255+
pipeline = pipeline_config.generate_pipeline(
285256
model=sharded_model,
286257
opt=optimizer,
287258
device=ctx.device,
288-
apply_jit=pipeline_config.apply_jit,
289259
)
290260
pipeline.progress(iter(bench_inputs))
291261

torchrec/distributed/benchmark/benchmark_utils.py

Lines changed: 1 addition & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
"""
1111
Utilities for benchmarking training pipelines with different model configurations.
1212
13-
Adding New Model Support:
13+
To support a new model in pipeline benchmark:
1414
1. Create config class inheriting from BaseModelConfig with generate_model() method
1515
2. Add the model to model_configs dict in create_model_config()
1616
3. Add model-specific params to ModelSelectionConfig and create_model_config's arguments in benchmark_train_pipeline.py
@@ -39,15 +39,6 @@
3939
TestTowerCollectionSparseNN,
4040
TestTowerSparseNN,
4141
)
42-
from torchrec.distributed.train_pipeline import (
43-
TrainPipelineBase,
44-
TrainPipelineFusedSparseDist,
45-
TrainPipelineSparseDist,
46-
)
47-
from torchrec.distributed.train_pipeline.train_pipelines import (
48-
PrefetchTrainPipelineSparseDist,
49-
TrainPipelineSemiSync,
50-
)
5142
from torchrec.distributed.types import ModuleSharder, ShardingEnv, ShardingType
5243
from torchrec.models.deepfm import SimpleDeepFMNNWrapper
5344
from torchrec.models.dlrm import DLRMWrapper
@@ -249,80 +240,6 @@ def create_model_config(model_name: str, **kwargs) -> BaseModelConfig:
249240
return model_class(**filtered_kwargs)
250241

251242

252-
def generate_pipeline(
253-
pipeline_type: str,
254-
emb_lookup_stream: str,
255-
model: nn.Module,
256-
opt: torch.optim.Optimizer,
257-
device: torch.device,
258-
apply_jit: bool = False,
259-
) -> Union[TrainPipelineBase, TrainPipelineSparseDist]:
260-
"""
261-
Generate a training pipeline instance based on the configuration.
262-
263-
This function creates and returns the appropriate training pipeline object
264-
based on the pipeline type specified. Different pipeline types are optimized
265-
for different training scenarios.
266-
267-
Args:
268-
pipeline_type (str): The type of training pipeline to use. Options include:
269-
- "base": Basic training pipeline
270-
- "sparse": Pipeline optimized for sparse operations
271-
- "fused": Pipeline with fused sparse distribution
272-
- "semi": Semi-synchronous training pipeline
273-
- "prefetch": Pipeline with prefetching for sparse distribution
274-
emb_lookup_stream (str): The stream to use for embedding lookups.
275-
Only used by certain pipeline types (e.g., "fused").
276-
model (nn.Module): The model to be trained.
277-
opt (torch.optim.Optimizer): The optimizer to use for training.
278-
device (torch.device): The device to run the training on.
279-
apply_jit (bool): Whether to apply JIT (Just-In-Time) compilation to the model.
280-
Default is False.
281-
282-
Returns:
283-
Union[TrainPipelineBase, TrainPipelineSparseDist]: An instance of the
284-
appropriate training pipeline class based on the configuration.
285-
286-
Raises:
287-
RuntimeError: If an unknown pipeline type is specified.
288-
"""
289-
290-
_pipeline_cls: Dict[
291-
str, Type[Union[TrainPipelineBase, TrainPipelineSparseDist]]
292-
] = {
293-
"base": TrainPipelineBase,
294-
"sparse": TrainPipelineSparseDist,
295-
"fused": TrainPipelineFusedSparseDist,
296-
"semi": TrainPipelineSemiSync,
297-
"prefetch": PrefetchTrainPipelineSparseDist,
298-
}
299-
300-
if pipeline_type == "semi":
301-
return TrainPipelineSemiSync(
302-
model=model,
303-
optimizer=opt,
304-
device=device,
305-
start_batch=0,
306-
apply_jit=apply_jit,
307-
)
308-
elif pipeline_type == "fused":
309-
return TrainPipelineFusedSparseDist(
310-
model=model,
311-
optimizer=opt,
312-
device=device,
313-
emb_lookup_stream=emb_lookup_stream,
314-
apply_jit=apply_jit,
315-
)
316-
elif pipeline_type == "base":
317-
assert apply_jit is False, "JIT is not supported for base pipeline"
318-
319-
return TrainPipelineBase(model=model, optimizer=opt, device=device)
320-
else:
321-
Pipeline = _pipeline_cls[pipeline_type]
322-
# pyre-ignore[28]
323-
return Pipeline(model=model, optimizer=opt, device=device, apply_jit=apply_jit)
324-
325-
326243
def generate_data(
327244
tables: List[EmbeddingBagConfig],
328245
weighted_tables: List[EmbeddingBagConfig],

torchrec/distributed/benchmark/embedding_collection_wrappers.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,12 +57,7 @@
5757
from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor
5858

5959
# Import the shared types and utilities from benchmark_utils
60-
from .benchmark_base import (
61-
benchmark,
62-
BenchmarkResult,
63-
CompileMode,
64-
multi_process_benchmark,
65-
)
60+
from .base import benchmark, BenchmarkResult, CompileMode, multi_process_benchmark
6661

6762
logger: logging.Logger = logging.getLogger()
6863

torchrec/distributed/test_utils/test_pipeline.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
# pyre-strict
9+
10+
from dataclasses import dataclass
11+
from typing import Dict, Type, Union
12+
13+
import torch
14+
from torch import nn
15+
16+
from torchrec.distributed.train_pipeline import (
17+
TrainPipelineBase,
18+
TrainPipelineFusedSparseDist,
19+
TrainPipelineSparseDist,
20+
)
21+
from torchrec.distributed.train_pipeline.train_pipelines import (
22+
PrefetchTrainPipelineSparseDist,
23+
TrainPipelineSemiSync,
24+
)
25+
26+
27+
@dataclass
28+
class PipelineConfig:
29+
"""
30+
Configuration for training pipelines.
31+
32+
This class defines the parameters for configuring the training pipeline.
33+
34+
Args:
35+
pipeline (str): The type of training pipeline to use. Options include:
36+
- "base": Basic training pipeline
37+
- "sparse": Pipeline optimized for sparse operations
38+
- "fused": Pipeline with fused sparse distribution
39+
- "semi": Semi-synchronous training pipeline
40+
- "prefetch": Pipeline with prefetching for sparse distribution
41+
Default is "base".
42+
emb_lookup_stream (str): The stream to use for embedding lookups.
43+
Only used by certain pipeline types (e.g., "fused").
44+
Default is "data_dist".
45+
apply_jit (bool): Whether to apply JIT (Just-In-Time) compilation to the model.
46+
Default is False.
47+
"""
48+
49+
pipeline: str = "base"
50+
emb_lookup_stream: str = "data_dist"
51+
apply_jit: bool = False
52+
53+
def generate_pipeline(
54+
self,
55+
model: nn.Module,
56+
opt: torch.optim.Optimizer,
57+
device: torch.device,
58+
) -> Union[TrainPipelineBase, TrainPipelineSparseDist]:
59+
"""
60+
Generate a training pipeline instance based on the configuration.
61+
62+
This function creates and returns the appropriate training pipeline object
63+
based on the pipeline type specified. Different pipeline types are optimized
64+
for different training scenarios.
65+
66+
Args:
67+
pipeline_type (str): The type of training pipeline to use. Options include:
68+
- "base": Basic training pipeline
69+
- "sparse": Pipeline optimized for sparse operations
70+
- "fused": Pipeline with fused sparse distribution
71+
- "semi": Semi-synchronous training pipeline
72+
- "prefetch": Pipeline with prefetching for sparse distribution
73+
emb_lookup_stream (str): The stream to use for embedding lookups.
74+
Only used by certain pipeline types (e.g., "fused").
75+
model (nn.Module): The model to be trained.
76+
opt (torch.optim.Optimizer): The optimizer to use for training.
77+
device (torch.device): The device to run the training on.
78+
apply_jit (bool): Whether to apply JIT (Just-In-Time) compilation to the model.
79+
Default is False.
80+
81+
Returns:
82+
Union[TrainPipelineBase, TrainPipelineSparseDist]: An instance of the
83+
appropriate training pipeline class based on the configuration.
84+
85+
Raises:
86+
RuntimeError: If an unknown pipeline type is specified.
87+
"""
88+
89+
_pipeline_cls: Dict[
90+
str, Type[Union[TrainPipelineBase, TrainPipelineSparseDist]]
91+
] = {
92+
"base": TrainPipelineBase,
93+
"sparse": TrainPipelineSparseDist,
94+
"fused": TrainPipelineFusedSparseDist,
95+
"semi": TrainPipelineSemiSync,
96+
"prefetch": PrefetchTrainPipelineSparseDist,
97+
}
98+
99+
if self.pipeline == "semi":
100+
return TrainPipelineSemiSync(
101+
model=model,
102+
optimizer=opt,
103+
device=device,
104+
start_batch=0,
105+
apply_jit=self.apply_jit,
106+
)
107+
elif self.pipeline == "fused":
108+
return TrainPipelineFusedSparseDist(
109+
model=model,
110+
optimizer=opt,
111+
device=device,
112+
emb_lookup_stream=self.emb_lookup_stream,
113+
apply_jit=self.apply_jit,
114+
)
115+
elif self.pipeline == "base":
116+
assert self.apply_jit is False, "JIT is not supported for base pipeline"
117+
118+
return TrainPipelineBase(model=model, optimizer=opt, device=device)
119+
else:
120+
Pipeline = _pipeline_cls[self.pipeline]
121+
# pyre-ignore[28]
122+
return Pipeline(
123+
model=model, optimizer=opt, device=device, apply_jit=self.apply_jit
124+
)

torchrec/sparse/tests/jagged_tensor_benchmark.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import click
1616

1717
import torch
18-
from torchrec.distributed.benchmark.benchmark_base import (
18+
from torchrec.distributed.benchmark.base import (
1919
benchmark,
2020
BenchmarkResult,
2121
CPUMemoryStats,

0 commit comments

Comments
 (0)