Skip to content

Commit 3ba83dc

Browse files
lizhouyu authored and facebook-github-bot committed
Add tensorboard to display training and evaluation metrics and revise implementation to support DLRMv2 (#3163)
Summary: Pull Request resolved: #3163 ### Major changes - Add TensorBoard to the benchmark testbed, specifically in `benchmark_zch.py`. - Count the number of unique values received by each rank in each epoch by revising `benchmark_zch_utils.py`. - Revise `data/nonzch_remapper.py` so that it no longer depends on the `batch.get_dict()` method; instead, it fetches the attributes of the dataclass `batch` with the built-in `vars()` function. - Revise the DLRMv2 model EBC config initialization to make each table name identical to its feature name. - Revise the DLRMv2 configuration YAML file to set the table size for each feature. - Revise the default value of the "num_embeddings" parameter in `arguments.py` to None. Differential Revision: D77841795
1 parent 42e9eff commit 3ba83dc

19 files changed

+486
-122
lines changed

torchrec/distributed/benchmark/benchmark_zch/arguments.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,11 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
# pyre-strict
19
import argparse
210
from typing import List
311

@@ -25,7 +33,7 @@ def parse_args(argv: List[str]) -> argparse.Namespace:
2533
parser.add_argument(
2634
"--num_embeddings", # ratio of feature ids to embedding table size # 3 axes: x-batch_idx; y-collisions; z-embedding table sizes
2735
type=int,
28-
default=100_000,
36+
default=None,
2937
help="max_ind_size. The number of embeddings in each embedding table. Defaults"
3038
" to 100_000 if num_embeddings_per_feature is not supplied.",
3139
)

torchrec/distributed/benchmark/benchmark_zch/benchmark_zch.py

Lines changed: 263 additions & 54 deletions
Large diffs are not rendered by default.

torchrec/distributed/benchmark/benchmark_zch/benchmark_zch_utils.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,21 @@
1-
import argparse
2-
import copy
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
# pyre-strict
39
import json
410
import logging
511
import os
6-
from typing import Any, Dict
12+
from typing import Any, Dict, Set
713

814
import numpy as np
915

1016
import torch
1117
import torch.nn as nn
12-
import yaml
1318
from torchrec.modules.mc_embedding_modules import ManagedCollisionEmbeddingCollection
14-
from torchrec.modules.mc_modules import (
15-
DistanceLFU_EvictionPolicy,
16-
ManagedCollisionCollection,
17-
MCHManagedCollisionModule,
18-
)
19-
20-
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
2119

2220

2321
def get_module_from_instance(
@@ -104,6 +102,7 @@ def __init__(
104102
self._mch_stats: Dict[str, Any] = (
105103
{}
106104
) # dictionary of {table_name [str]: {metric_name [str]: metric_value [int]}}
105+
self.feature_name_unique_queried_values_set_dict: Dict[str, Set[int]] = {}
107106

108107
# record mcec state to file
109108
def record_mcec_state(self, stage: str) -> None:
@@ -260,6 +259,7 @@ def update(self) -> None:
260259
"collision_cnt": 0,
261260
"rank_total_cnt": 0,
262261
"num_empty_slots": 0,
262+
"num_unique_queries": 0,
263263
}
264264
# get the input feature values
265265
input_feature_values = np.array(rank_feature_value_before_fwd[feature_name])
@@ -313,4 +313,16 @@ def update(self) -> None:
313313
this_rank_total_count - this_rank_hits_count - this_rank_insert_count
314314
)
315315
batch_stats[feature_name]["collision_cnt"] += int(this_rank_collision_count)
316+
# get the unique values in the input feature values
317+
if feature_name not in self.feature_name_unique_queried_values_set_dict:
318+
self.feature_name_unique_queried_values_set_dict[feature_name] = set(
319+
input_feature_values.tolist()
320+
)
321+
else:
322+
self.feature_name_unique_queried_values_set_dict[feature_name].update(
323+
set(input_feature_values.tolist())
324+
)
325+
batch_stats[feature_name]["num_unique_queries"] = len(
326+
self.feature_name_unique_queried_values_set_dict[feature_name]
327+
)
316328
self._mch_stats = batch_stats

torchrec/distributed/benchmark/benchmark_zch/count_dataset_distributions.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,11 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
# pyre-strict
19
import argparse
210
import json
311
import multiprocessing
Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1-
dataset_path: "/home/lizhouyu/oss_github/dlrm/torchrec_dlrm/criteo_1tb/criteo_kaggle_processed"
1+
dataset_path: "/home/lizhouyu/datasets/criteo_kaggle_processed"
22
batch_size: 4096
33
seed: 0
4+
multitask_configs:
5+
- task_name: is_click
6+
task_weight: 1
7+
task_type: classification

torchrec/distributed/benchmark/benchmark_zch/data/configs/kuairand_1k.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
dataset_path: "/home/lizhouyu/oss_github/generative-recommenders/generative_recommenders/dlrm_v3/data/KuaiRand-1K/data"
1+
dataset_path: "/home/lizhouyu/datasets/kuairand-1k/data"
22
batch_size: 16
33
train_split_percentage: 0.75
44
num_workers: 4

torchrec/distributed/benchmark/benchmark_zch/data/get_dataloader.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,11 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
# pyre-strict
19
import argparse
210
import os
311

torchrec/distributed/benchmark/benchmark_zch/data/get_metric_modules.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,11 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
# pyre-strict
19
import argparse
210
import os
311
import sys

torchrec/distributed/benchmark/benchmark_zch/data/nonzch_remapper.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,11 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
# pyre-strict
19
from dataclasses import dataclass
210
from typing import Dict, List, Optional, Tuple, Union
311

@@ -99,6 +107,24 @@ def __init__(
99107
)
100108
self._input_hash_size = input_hash_size
101109

110+
def get_batch_kjt_dict(self, batch: Batch) -> Dict[str, KeyedJaggedTensor]:
111+
"""
112+
Get the KJT in each batch
113+
Parameters:
114+
batch: the batch whose KJTs are to be fetched
115+
Returns:
116+
batch_kjt_dict: a dictionary of [batch_attribute_name: KeyedJaggedTensor]
117+
where only attributes whose values are KeyedJaggedTensor are fetched.
118+
"""
119+
batch_kjt_dict = {} # create a dictionary for return
120+
batch_attr_dict = vars(batch) # get batch's attributes and values
121+
for batch_attr_name, batch_attr_value in batch_attr_dict.items():
122+
if isinstance(
123+
batch_attr_value, KeyedJaggedTensor
124+
): # only fetch attributes whose values are KeyedJaggedTensor
125+
batch_kjt_dict[batch_attr_name] = batch_attr_value
126+
return batch_kjt_dict
127+
102128
def remap(self, batch: Batch) -> Batch:
103129
# for all the attributes under batch, like batch.uih_features, batch.candidates_features,
104130
# get the kjt as a dict, and remap the kjt
@@ -118,7 +144,7 @@ def remap(self, batch: Batch) -> Batch:
118144
# candidates_features: KeyedJaggedTensor
119145

120146
# for every attribute in batch, remap the kjt
121-
for attr_name, feature_kjt_dict in batch.get_dict().items():
147+
for attr_name, feature_kjt_dict in self.get_batch_kjt_dict(batch).items():
122148
# separate feature kjt with {feature_name_1: feature_kjt_1, feature_name_2: feature_kjt_2, ...}
123149
# to multiple dict with {feature_name_1: jt_1}, {feature_name_2: jt_2}, ...
124150
attr_feature_jt_dict = {}

torchrec/distributed/benchmark/benchmark_zch/data/preprocess/kuairand_1k.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,11 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
# pyre-strict
19
import argparse
210

311
import json

0 commit comments

Comments
 (0)