Skip to content

Commit 1282754

Browse files
TroyGarden and meta-codesync[bot]
authored and committed
add embedding config options in benchmark (#3430)
Summary: Pull Request resolved: #3430 # context * add an "additional_tables" option in test_tables so that the benchmark config can experiment with different flavors of tables. * the unweighted tables gain two more tables (0_0 and 0_1), while the weighted tables remain the same. And in case a third embedding module (group of tables) is needed, 2_1 is added there. ``` additional_tables: - - name: additional_tables_0_0 embedding_dim: 128 num_embeddings: 100_000 feature_names: ["additional_0_0"] - name: additional_tables_0_1 embedding_dim: 128 num_embeddings: 100_000 feature_names: ["additional_0_1"] - [] - - name: additional_tables_2_1 embedding_dim: 128 num_embeddings: 100_000 feature_names: ["additional_2_1"] ``` Reviewed By: spmex Differential Revision: D83881145 fbshipit-source-id: c40fdad26bfe9345550908cb67c00204da2a8cbf
1 parent 90eb966 commit 1282754

File tree

5 files changed

+116
-86
lines changed

5 files changed

+116
-86
lines changed

torchrec/distributed/benchmark/README.md

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,13 @@
22
## usage
33
- internal:
44
```
5-
hash=$(hg whereami | cut -c 1-10)
65
buck2 run @fbcode//mode/opt fbcode//torchrec/distributed/benchmark:benchmark_train_pipeline -- \
76
--yaml_config=fbcode/torchrec/distributed/benchmark/yaml/sparse_data_dist_base.yml \
8-
--profile_name=sparse_data_dist_base_${hash:-$USER} # overrides the yaml config
7+
--profile_name=sparse_data_dist_base_$(hg whereami | cut -c 1-10 || echo $USER) # overrides the yaml config
98
```
109
- oss:
1110
```
12-
hash=`git rev-parse --short HEAD`
1311
python -m torchrec.distributed.benchmark.benchmark_train_pipeline \
1412
--yaml_config=fbcode/torchrec/distributed/benchmark/yaml/sparse_data_dist_base.yml \
15-
--profile_name=sparse_data_dist_base_${hash:-$USER} # overrides the yaml config
13+
--profile_name=sparse_data_dist_base_$(git rev-parse --short HEAD || echo $USER) # overrides the yaml config
1614
```

torchrec/distributed/benchmark/benchmark_train_pipeline.py

Lines changed: 3 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@
3939
CPUMemoryStats,
4040
generate_planner,
4141
generate_sharded_model_and_optimizer,
42-
generate_tables,
4342
GPUMemoryStats,
4443
)
4544
from torchrec.distributed.comm import get_local_size
@@ -52,6 +51,7 @@
5251
)
5352
from torchrec.distributed.test_utils.test_input import ModelInput
5453
from torchrec.distributed.test_utils.test_model import TestOverArchLarge
54+
from torchrec.distributed.test_utils.test_tables import EmbeddingTablesConfig
5555
from torchrec.distributed.train_pipeline import TrainPipeline
5656
from torchrec.distributed.types import ShardingType
5757
from torchrec.modules.embedding_configs import EmbeddingBagConfig
@@ -116,28 +116,6 @@ class RunOptions:
116116
export_stacks: bool = False
117117

118118

119-
@dataclass
120-
class EmbeddingTablesConfig:
121-
"""
122-
Configuration for embedding tables.
123-
124-
This class defines the parameters for generating embedding tables with both weighted
125-
and unweighted features.
126-
127-
Args:
128-
num_unweighted_features (int): Number of unweighted features to generate.
129-
Default is 100.
130-
num_weighted_features (int): Number of weighted features to generate.
131-
Default is 100.
132-
embedding_feature_dim (int): Dimension of the embedding vectors.
133-
Default is 128.
134-
"""
135-
136-
num_unweighted_features: int = 100
137-
num_weighted_features: int = 100
138-
embedding_feature_dim: int = 128
139-
140-
141119
@dataclass
142120
class PipelineConfig:
143121
"""
@@ -206,11 +184,7 @@ def main(
206184
pipeline_config: PipelineConfig,
207185
model_config: Optional[BaseModelConfig] = None,
208186
) -> None:
209-
tables, weighted_tables = generate_tables(
210-
num_unweighted_features=table_config.num_unweighted_features,
211-
num_weighted_features=table_config.num_weighted_features,
212-
embedding_feature_dim=table_config.embedding_feature_dim,
213-
)
187+
tables, weighted_tables, *_ = table_config.generate_tables()
214188

215189
if model_config is None:
216190
model_config = create_model_config(
@@ -256,11 +230,7 @@ def run_pipeline(
256230
model_config: BaseModelConfig,
257231
) -> BenchmarkResult:
258232

259-
tables, weighted_tables = generate_tables(
260-
num_unweighted_features=table_config.num_unweighted_features,
261-
num_weighted_features=table_config.num_weighted_features,
262-
embedding_feature_dim=table_config.embedding_feature_dim,
263-
)
233+
tables, weighted_tables, *_ = table_config.generate_tables()
264234

265235
benchmark_res_per_rank = run_multi_process_func(
266236
func=runner,

torchrec/distributed/benchmark/benchmark_utils.py

Lines changed: 0 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -381,55 +381,6 @@ def set_embedding_config(
381381
return embedding_configs, pooling_configs
382382

383383

384-
def generate_tables(
385-
num_unweighted_features: int = 100,
386-
num_weighted_features: int = 100,
387-
embedding_feature_dim: int = 128,
388-
) -> Tuple[
389-
List[EmbeddingBagConfig],
390-
List[EmbeddingBagConfig],
391-
]:
392-
"""
393-
Generate embedding bag configurations for both unweighted and weighted features.
394-
395-
This function creates two lists of EmbeddingBagConfig objects:
396-
1. Unweighted tables: Named as "table_{i}" with feature names "feature_{i}"
397-
2. Weighted tables: Named as "weighted_table_{i}" with feature names "weighted_feature_{i}"
398-
399-
For both types, the number of embeddings scales with the feature index,
400-
calculated as max(i + 1, 100) * 1000.
401-
402-
Args:
403-
num_unweighted_features (int): Number of unweighted features to generate.
404-
num_weighted_features (int): Number of weighted features to generate.
405-
embedding_feature_dim (int): Dimension of the embedding vectors.
406-
407-
Returns:
408-
Tuple[List[EmbeddingBagConfig], List[EmbeddingBagConfig]]: A tuple containing
409-
two lists - the first for unweighted embedding tables and the second for
410-
weighted embedding tables.
411-
"""
412-
tables = [
413-
EmbeddingBagConfig(
414-
num_embeddings=max(i + 1, 100) * 1000,
415-
embedding_dim=embedding_feature_dim,
416-
name="table_" + str(i),
417-
feature_names=["feature_" + str(i)],
418-
)
419-
for i in range(num_unweighted_features)
420-
]
421-
weighted_tables = [
422-
EmbeddingBagConfig(
423-
num_embeddings=max(i + 1, 100) * 1000,
424-
embedding_dim=embedding_feature_dim,
425-
name="weighted_table_" + str(i),
426-
feature_names=["weighted_feature_" + str(i)],
427-
)
428-
for i in range(num_weighted_features)
429-
]
430-
return tables, weighted_tables
431-
432-
433384
def generate_planner(
434385
planner_type: str,
435386
topology: Topology,

torchrec/distributed/benchmark/yaml/sparse_data_dist_base.yml

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,17 @@ EmbeddingTablesConfig:
1313
num_unweighted_features: 100
1414
num_weighted_features: 100
1515
embedding_feature_dim: 128
16+
additional_tables:
17+
- - name: additional_tables_0_0
18+
embedding_dim: 128
19+
num_embeddings: 100_000
20+
feature_names: ["additional_0_0"]
21+
- name: additional_tables_0_1
22+
embedding_dim: 128
23+
num_embeddings: 100_000
24+
feature_names: ["additional_0_1"]
25+
- []
26+
- - name: additional_tables_2_1
27+
embedding_dim: 128
28+
num_embeddings: 100_000
29+
feature_names: ["additional_2_1"]
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
# pyre-strict
9+
10+
from dataclasses import dataclass, field
11+
from typing import Any, Dict, List, Optional
12+
13+
from torchrec.modules.embedding_configs import EmbeddingBagConfig
14+
15+
16+
@dataclass
class EmbeddingTablesConfig:
    """
    Configuration for generating embedding tables for tests and benchmarks.

    This class defines the parameters for generating embedding tables with both
    weighted and unweighted features, plus optional extra tables supplied via
    config (e.g. from a YAML benchmark config).

    Args:
        num_unweighted_features (int): Number of unweighted features to generate.
            Default is 100.
        num_weighted_features (int): Number of weighted features to generate.
            Default is 100.
        embedding_feature_dim (int): Dimension of the embedding vectors.
            Default is 128.
        additional_tables (List[List[Dict[str, Any]]]): Extra tables to merge in,
            grouped by embedding-module index: entry 0 extends the unweighted
            group, entry 1 extends the weighted group, and every further entry
            becomes a new group of its own. Each inner dict is passed as kwargs
            to EmbeddingBagConfig. Default is an empty list.
    """

    num_unweighted_features: int = 100
    num_weighted_features: int = 100
    embedding_feature_dim: int = 128
    additional_tables: List[List[Dict[str, Any]]] = field(default_factory=list)

    def generate_tables(
        self,
    ) -> List[List[EmbeddingBagConfig]]:
        """
        Generate embedding bag configurations grouped per embedding module.

        Builds two base groups:
            1. Unweighted tables named "table_{i}" with feature "feature_{i}".
            2. Weighted tables named "weighted_table_{i}" with feature
               "weighted_feature_{i}".
        For both, the number of embeddings scales with the feature index as
        max(i + 1, 100) * 1000.

        Entries of ``additional_tables`` are then merged in by position:
        index 0 extends the unweighted group, index 1 extends the weighted
        group, and every index >= 2 becomes an extra group appended to the
        result (so benchmarks can configure a third embedding module).

        Returns:
            List[List[EmbeddingBagConfig]]: at least two groups —
            [unweighted_tables, weighted_tables] — followed by one group per
            ``additional_tables`` entry beyond the first two.
        """
        unweighted_tables = [
            EmbeddingBagConfig(
                num_embeddings=max(i + 1, 100) * 1000,
                embedding_dim=self.embedding_feature_dim,
                name="table_" + str(i),
                feature_names=["feature_" + str(i)],
            )
            for i in range(self.num_unweighted_features)
        ]
        weighted_tables = [
            EmbeddingBagConfig(
                num_embeddings=max(i + 1, 100) * 1000,
                embedding_dim=self.embedding_feature_dim,
                name="weighted_table_" + str(i),
                feature_names=["weighted_feature_" + str(i)],
            )
            for i in range(self.num_weighted_features)
        ]
        # Always return the two base groups first so callers can unpack
        # `tables, weighted_tables, *_ = cfg.generate_tables()`.
        tables_list: List[List[EmbeddingBagConfig]] = [
            unweighted_tables,
            weighted_tables,
        ]
        for idx, adts in enumerate(self.additional_tables):
            if idx == 0:
                tables = unweighted_tables
            elif idx == 1:
                tables = weighted_tables
            else:
                # Fix: groups beyond the first two were previously built but
                # never appended to the result, silently dropping them.
                tables = []
                tables_list.append(tables)
            for adt in adts:
                tables.append(EmbeddingBagConfig(**adt))
        return tables_list

0 commit comments

Comments
 (0)