Skip to content

Commit a48cef7

Browse files
TroyGardenmeta-codesync[bot]
authored andcommitted
refactor pipeline benchmark (#3432)
Summary: Pull Request resolved: #3432 # context * refactor and clean up the benchmark_uitls and benchmark_pipeline_utils files * run-related functions and wrappers are now in the benchmark_base.py, such as `_run_benchmark_core`, `multi_process_benchmark`, and `BenchmarkResult`. * components generation related functions are now in the benchmark_uilts.py, such as `generate_pipeline`, `generate_planner`, and `generate_sharded_model_and_optimizer`. * also remove some redundent functions such as `_init_module_and_run_benchmark` and `benchmark_module`, which are available in embedding_collection_wrappers.py (D79512602) Reviewed By: spmex Differential Revision: D83883554 fbshipit-source-id: b98759582041477d36200a8a540740bdb6015bc7
1 parent 1282754 commit a48cef7

12 files changed

+1186
-1474
lines changed

torchrec/distributed/benchmark/benchmark_base.py

Lines changed: 778 additions & 0 deletions
Large diffs are not rendered by default.

torchrec/distributed/benchmark/benchmark_inference.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
import torch
2020

21-
from torchrec.distributed.benchmark.benchmark_utils import (
21+
from torchrec.distributed.benchmark.benchmark_base import (
2222
BenchmarkResult,
2323
CompileMode,
2424
DLRM_NUM_EMBEDDINGS_PER_FEATURE,

torchrec/distributed/benchmark/benchmark_pipeline_utils.py

Lines changed: 0 additions & 355 deletions
This file was deleted.

torchrec/distributed/benchmark/benchmark_set_sharding_context_post_a2a.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
import click
1515
import torch
16-
from torchrec.distributed.benchmark.benchmark_utils import benchmark_func
16+
from torchrec.distributed.benchmark.benchmark_base import benchmark_func
1717
from torchrec.distributed.embedding import EmbeddingCollectionContext
1818
from torchrec.distributed.embedding_sharding import _set_sharding_context_post_a2a
1919
from torchrec.distributed.sharding.sequence_sharding import SequenceShardingContext

torchrec/distributed/benchmark/benchmark_split_table_batched_embeddings.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
ComputeDevice,
2727
SplitTableBatchedEmbeddingBagsCodegen,
2828
)
29-
from torchrec.distributed.benchmark.benchmark_utils import benchmark_func
29+
from torchrec.distributed.benchmark.benchmark_base import benchmark_func
3030
from torchrec.distributed.test_utils.test_model import ModelInput
3131
from torchrec.modules.embedding_configs import EmbeddingBagConfig
3232
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

torchrec/distributed/benchmark/benchmark_train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from typing import List, Optional, Tuple
1818

1919
import torch
20-
from torchrec.distributed.benchmark.benchmark_utils import (
20+
from torchrec.distributed.benchmark.benchmark_base import (
2121
BenchmarkResult,
2222
CompileMode,
2323
init_argparse_and_args,

0 commit comments

Comments
 (0)