
Commit 3faf5e5

TroyGarden authored and meta-codesync[bot] committed
rename function and labels (#3434)
Summary:
Pull Request resolved: #3434

# context
* rename the over-generic function name `benchmark` to `benchmark_model_with_warmup`
* make the argument names consistent between base.py and benchmark_train_pipeline.py, i.e., `profile_name` ==> `name`, `profile` ==> `profile_dir`
* modify the record_function labels in comm_ops.py to reduce duplication and confusion

Reviewed By: spmex
Differential Revision: D83923279
fbshipit-source-id: 74a5b9de66e02338859c2716cb49b30fc6875ee7
1 parent d92a0c3 commit 3faf5e5

File tree

8 files changed: +67, -44 lines

Lines changed: 3 additions & 3 deletions
@@ -1,14 +1,14 @@
 # TorchRec Benchmark
-## usage
+## benchmark_train_pipeline usage
 - internal:
 ```
 buck2 run @fbcode//mode/opt fbcode//torchrec/distributed/benchmark:benchmark_train_pipeline -- \
   --yaml_config=fbcode/torchrec/distributed/benchmark/yaml/sparse_data_dist_base.yml \
-  --profile_name=sparse_data_dist_base_$(hg whereami | cut -c 1-10 || echo $USER) # overrides the yaml config
+  --name=sparse_data_dist_base_$(hg whereami | cut -c 1-10 || echo $USER) # overrides the yaml config
 ```
 - oss:
 ```
 python -m torchrec.distributed.benchmark.benchmark_train_pipeline \
   --yaml_config=fbcode/torchrec/distributed/benchmark/yaml/sparse_data_dist_base.yml \
-  --profile_name=sparse_data_dist_base_$(git rev-parse --short HEAD || echo $USER) # overrides the yaml config
+  --name=sparse_data_dist_base_$(git rev-parse --short HEAD || echo $USER) # overrides the yaml config
 ```

torchrec/distributed/benchmark/base.py

Lines changed: 21 additions & 1 deletion
@@ -681,7 +681,7 @@ def _trace_handler(prof: torch.profiler.profile) -> None:
     )
 
 
-def benchmark(
+def benchmark_model_with_warmup(
     name: str,
     model: torch.nn.Module,
     warmup_inputs: Union[List[KeyedJaggedTensor], List[Dict[str, Any]]],
@@ -750,6 +750,26 @@ def benchmark_func(
     pre_gpu_load: int = 0,
     export_stacks: bool = False,
 ) -> BenchmarkResult:
+    """
+    Args:
+        name: Human-readable benchmark name.
+
+        bench_inputs: List[Dict[str, Any]] will be fed to the function at once
+        prof_inputs: List[Dict[str, Any]] will be fed to the function at once
+        benchmark_func_kwargs: kwargs to be passed to func_to_benchmark
+        func_to_benchmark: Callable that executes one measured iteration.
+            func_to_benchmark(batch_inputs, **kwargs)
+
+        world_size, rank: Distributed context to correctly reset / collect GPU
+            stats. ``rank == -1`` means single-process mode.
+        num_benchmarks: Number of measured iterations.
+        device_type: "cuda" or "cpu".
+        profile_dir: Where to write chrome traces / stack files.
+
+        pre_gpu_load: Number of dummy matmul operations to run before the first
+            measured iteration (helps simulating a loaded allocator).
+        export_stacks: Whether to export flamegraph-compatible stack files.
+    """
     if benchmark_func_kwargs is None:
         benchmark_func_kwargs = {}
 
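To make the new docstring concrete, here is a minimal, self-contained sketch of a `benchmark_func` call against a toy model. The argument names mirror the docstring above and the call site in benchmark_train_pipeline.py below; the model and inputs are illustrative assumptions, not code from this commit.

```python
# Minimal sketch: exercising benchmark_func with a toy model (illustrative values).
from typing import Any, Dict, List

import torch
from torchrec.distributed.benchmark.base import benchmark_func

model = torch.nn.Linear(8, 8)
bench_inputs: List[Dict[str, Any]] = [{"x": torch.randn(4, 8)} for _ in range(10)]

def _func_to_benchmark(batch_inputs: List[Dict[str, Any]], model: torch.nn.Module) -> None:
    # One measured iteration: the whole list of batches is fed at once.
    for batch in batch_inputs:
        model(batch["x"])

result = benchmark_func(
    name="toy_linear",
    bench_inputs=bench_inputs,
    prof_inputs=bench_inputs,
    num_benchmarks=5,   # measured iterations
    num_profiles=2,
    profile_dir=".",    # chrome traces / stack files land here
    world_size=1,
    device_type="cpu",
    func_to_benchmark=_func_to_benchmark,
    benchmark_func_kwargs={"model": model},
)
print(result)
```
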
torchrec/distributed/benchmark/benchmark_train_pipeline.py

Lines changed: 5 additions & 7 deletions
@@ -78,7 +78,7 @@ class RunOptions:
             Default is "kjt" (KeyedJaggedTensor).
         profile (str): Directory to save profiling results. If empty, profiling is disabled.
             Default is "" (disabled).
-        profile_name (str): Name of the profiling file. Default is pipeline classname.
+        name (str): Name of the profiling file. Default is pipeline classname.
         planner_type (str): Type of sharding planner to use. Options are:
             - "embedding": EmbeddingShardingPlanner (default)
             - "hetero": HeteroEmbeddingShardingPlanner
@@ -100,8 +100,8 @@ class RunOptions:
     sharding_type: ShardingType = ShardingType.TABLE_WISE
     compute_kernel: EmbeddingComputeKernel = EmbeddingComputeKernel.FUSED
     input_type: str = "kjt"
-    profile: str = ""
-    profile_name: str = ""
+    name: str = ""
+    profile_dir: str = ""
     planner_type: str = "embedding"
     pooling_factors: Optional[List[float]] = None
     num_poolings: Optional[List[float]] = None
@@ -261,15 +261,13 @@ def _func_to_benchmark(
 
     result = benchmark_func(
         name=(
-            type(pipeline).__name__
-            if run_option.profile_name == ""
-            else run_option.profile_name
+            type(pipeline).__name__ if run_option.name == "" else run_option.name
         ),
         bench_inputs=bench_inputs,  # pyre-ignore
         prof_inputs=bench_inputs,  # pyre-ignore
         num_benchmarks=5,
         num_profiles=2,
-        profile_dir=run_option.profile,
+        profile_dir=run_option.profile_dir,
         world_size=run_option.world_size,
         func_to_benchmark=_func_to_benchmark,
         benchmark_func_kwargs={"model": sharded_model, "pipeline": pipeline},
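
A minimal sketch of the renamed fields in use, assuming the remaining `RunOptions` fields keep the defaults shown above:

```python
# Minimal sketch: constructing RunOptions with the renamed fields.
from torchrec.distributed.benchmark.benchmark_train_pipeline import RunOptions

run_option = RunOptions(
    name="sparse_data_dist_base",  # was `profile_name`; "" falls back to the pipeline class name
    profile_dir=".",               # was `profile`; "" disables profiling
)
```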

torchrec/distributed/benchmark/embedding_collection_wrappers.py

Lines changed: 7 additions & 2 deletions
@@ -57,7 +57,12 @@
 from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor
 
 # Import the shared types and utilities from benchmark_utils
-from .base import benchmark, BenchmarkResult, CompileMode, multi_process_benchmark
+from .base import (
+    benchmark_model_with_warmup,
+    BenchmarkResult,
+    CompileMode,
+    multi_process_benchmark,
+)
 
 logger: logging.Logger = logging.getLogger()
 
@@ -456,7 +461,7 @@ def _init_module_and_run_benchmark(
     else:
         name = _benchmark_type_name(compile_mode, sharding_type)
 
-    res = benchmark(
+    res = benchmark_model_with_warmup(
         name,
         module,
         warmup_inputs_cuda,
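
Because the old symbol is removed rather than aliased, any downstream `from .base import benchmark` will now fail; a quick smoke test of the new name (assuming torchrec at this revision is installed):

```python
# The old `benchmark` symbol no longer resolves in .base after this commit.
from torchrec.distributed.benchmark.base import benchmark_model_with_warmup

assert callable(benchmark_model_with_warmup)
```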

torchrec/distributed/benchmark/yaml/sparse_data_dist_base.yml

Lines changed: 2 additions & 2 deletions
@@ -4,8 +4,8 @@ RunOptions:
   world_size: 2
   num_batches: 10
   sharding_type: table_wise
-  profile: "."
-  profile_name: "sparse_data_dist_base"
+  profile_dir: "."
+  name: "sparse_data_dist_base"
   # export_stacks: True # enable this to export stack traces
 PipelineConfig:
   pipeline: "sparse"
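
A minimal sketch, assuming a generic PyYAML loader (not necessarily the loader the benchmark itself uses), of how the renamed keys read back:

```python
# Minimal sketch: the renamed YAML keys parsed with a generic loader.
import yaml  # assumption: PyYAML is available

text = """
RunOptions:
  profile_dir: "."
  name: "sparse_data_dist_base"
"""
cfg = yaml.safe_load(text)["RunOptions"]
print(cfg["name"], cfg["profile_dir"])  # -> sparse_data_dist_base .
```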

torchrec/distributed/comm_ops.py

Lines changed: 26 additions & 27 deletions
@@ -440,7 +440,7 @@ def all2all_pooled_sync(
     input_split_sizes = [D_local_sum * B_rank for B_rank in batch_size_per_rank]
     qcomm_ctx = None
 
-    with record_function("## alltoall_fwd_single ##"):
+    with record_function("## all2all_pooled ##"):
         sharded_output_embeddings = AllToAllSingle.apply(
             sharded_input_embeddings,
             output_split_sizes,
@@ -558,7 +558,7 @@ def variable_batch_all2all_pooled_sync(
         for split in input_split_sizes
     ]
 
-    with record_function("## alltoall_fwd_single ##"):
+    with record_function("## variable_batch_all2all_pooled ##"):
         if pg._get_backend_name() == "custom":
             sharded_output_embeddings = torch.empty(
                 sum(output_split_sizes),
@@ -674,7 +674,7 @@ def all2all_sequence_sync(
 
     local_T = lengths_after_sparse_data_all2all.shape[0]
     if local_T > 0:
-        with record_function("## alltoall_seq_embedding_fwd_permute ##"):
+        with record_function("## all2all_sequence_permute ##"):
             if not variable_batch_size:
                 (
                     permuted_lengths_after_sparse_data_all2all,
@@ -719,7 +719,7 @@ def all2all_sequence_sync(
     else:
         qcomm_ctx = None
 
-    with record_function("## alltoall_seq_embedding_fwd_single ##"):
+    with record_function("## all2all_sequence ##"):
         sharded_output_embeddings = AllToAllSingle.apply(
             sharded_input_embeddings,
             output_splits,
@@ -989,7 +989,7 @@ def reduce_scatter_v_sync(
         input = rsi.codecs.forward.encode(input)
 
     if rsi.equal_splits:
-        with record_function("## reduce_scatter_base ##"):
+        with record_function("## reduce_scatter_v ##"):
            output = torch.ops.torchrec.reduce_scatter_tensor(
                input,
                reduceOp="sum",
@@ -998,7 +998,7 @@ def reduce_scatter_v_sync(
             gradient_division=get_gradient_division(),
         )
     else:
-        with record_function("## reduce_scatter_v_via_all_to_all_single ##"):
+        with record_function("## reduce_scatter_v (AllToAllSingle) ##"):
            input_splits = rsi.input_splits
            output_splits = [rsi.input_splits[rank]] * world_size
            # TODO(ivankobzarev): Replace with _functional_collectives.reduce_scatter_v when it is added
@@ -1197,7 +1197,7 @@ def forward(
             device=sharded_input_embeddings.device,
         )
 
-        with record_function("## alltoall_fwd_single ##"):
+        with record_function("## All2All_Pooled_fwd ##"):
            req = dist.all_to_all_single(
                output=sharded_output_embeddings,
                input=sharded_input_embeddings,
@@ -1218,7 +1218,6 @@ def forward(
 
     @staticmethod
     # pyre-fixme[2]: Parameter must be annotated.
-    # pyre-fixme[2]: Parameter must be annotated.
     def backward(ctx, *unused) -> Tuple[None, None, None, Tensor]:
         pg = ctx.pg
         my_rank = dist.get_rank(pg)
@@ -1360,7 +1359,7 @@ def backward(ctx, grad_output: Tensor) -> Tuple[None, None, Tensor]:
             device=sharded_grad_output.device,
             dtype=sharded_grad_output.dtype,
         )
-        with record_function("## alltoall_bwd_single ##"):
+        with record_function("## All2All_Pooled_bwd ##"):
            req = dist.all_to_all_single(
                output=sharded_grad_input,
                input=sharded_grad_output,
@@ -1445,7 +1444,7 @@ def forward(
             device=sharded_input_embeddings.device,
         )
 
-        with record_function("## alltoall_fwd_single ##"):
+        with record_function("## Variable_Batch_All2All_Pooled_fwd ##"):
            req = dist.all_to_all_single(
                output=sharded_output_embeddings,
                input=sharded_input_embeddings,
@@ -1564,7 +1563,7 @@ def backward(ctx, grad_output: Tensor) -> Tuple[None, None, Tensor]:
             device=sharded_grad_output.device,
             dtype=sharded_grad_output.dtype,
         )
-        with record_function("## alltoall_bwd_single ##"):
+        with record_function("## Variable_Batch_All2All_Pooled_bwd ##"):
            req = dist.all_to_all_single(
                output=sharded_grad_input,
                input=sharded_grad_output,
@@ -1605,7 +1604,7 @@ def forward(
 
         local_T = lengths_after_sparse_data_all2all.shape[0]
         if local_T > 0:
-            with record_function("## alltoall_seq_embedding_fwd_permute ##"):
+            with record_function("## All2All_Seq_fwd_permute ##"):
                if not variable_batch_size:
                    (
                        permuted_lengths_after_sparse_data_all2all,
@@ -1659,7 +1658,7 @@ def forward(
             device=sharded_input_embeddings.device,
         )
 
-        with record_function("## alltoall_seq_embedding_fwd_single ##"):
+        with record_function("## All2All_Seq_fwd ##"):
            req = dist.all_to_all_single(
                output=sharded_output_embeddings,
                input=sharded_input_embeddings,
@@ -1707,7 +1706,7 @@ def backward(ctx, *unused) -> Tuple[None, None, None, Tensor]:
         myreq.dummy_tensor = None
 
         if permuted_lengths_after_sparse_data_all2all is not None:
-            with record_function("## alltoall_seq_embedding_bwd_permute ##"):
+            with record_function("## All2All_Seq_bwd_permute ##"):
                if not variable_batch_size:
                    _, sharded_grad_input, _ = torch.ops.fbgemm.permute_2D_sparse_data(
                        backward_recat_tensor,
@@ -1788,7 +1787,7 @@ def backward(ctx, sharded_grad_output: Tensor) -> Tuple[None, None, Tensor]:
             device=sharded_grad_output.device,
             dtype=sharded_grad_output.dtype,
         )
-        with record_function("## alltoall_seq_embedding_bwd_single ##"):
+        with record_function("## All2All_Seq_bwd ##"):
            req = dist.all_to_all_single(
                output=sharded_grad_input,
                input=sharded_grad_output.view(-1),
@@ -1822,7 +1821,7 @@ def forward(
             input = a2ai.codecs.forward.encode(input)
 
         output = input.new_empty(sum(output_split_sizes))
-        with record_function("## alltoallv_bwd_single ##"):
+        with record_function("## All2Allv_fwd ##"):
            req = dist.all_to_all_single(
                output,
                input,
@@ -1908,7 +1907,7 @@ def backward(ctx, *grad_outputs) -> Tuple[None, None, Tensor]:
         grad_outputs = [gout.contiguous().view([-1]) for gout in grad_outputs]
         grad_output = torch.cat(grad_outputs)
         grad_input = grad_output.new_empty([a2ai.B_global * sum(a2ai.D_local_list)])
-        with record_function("## alltoall_bwd_single ##"):
+        with record_function("## All2Allv_bwd ##"):
            req = dist.all_to_all_single(
                grad_input,
                grad_output,
@@ -1944,7 +1943,7 @@ def forward(
             dtype=inputs[my_rank].dtype,
             device=inputs[my_rank].device,
         )
-        with record_function("## reduce_scatter ##"):
+        with record_function("## ReduceScatter_fwd ##"):
            req = dist.reduce_scatter(
                output,
                list(inputs),
@@ -2023,7 +2022,7 @@ def backward(ctx, grad_output: Tensor) -> Tuple[None, None, Tensor]:
             for in_size in rsi.input_sizes
         ]
 
-        with record_function("## reduce_scatter_bw (all_gather) ##"):
+        with record_function("## ReduceScatter_bwd (all_gather) ##"):
            req = dist.all_gather(
                grad_inputs,
                grad_output.contiguous(),
@@ -2051,7 +2050,7 @@ def forward(
         if rsi.codecs is not None:
             inputs = rsi.codecs.forward.encode(inputs)
         output = inputs.new_empty((inputs.size(0) // my_size, inputs.size(1)))
-        with record_function("## reduce_scatter_tensor ##"):
+        with record_function("## ReduceScatterBase_fwd (tensor) ##"):
            req = dist.reduce_scatter_tensor(
                output,
                inputs,
@@ -2119,7 +2118,7 @@ def backward(ctx, grad_output: Tensor) -> Tuple[None, None, Tensor]:
         if rsi.codecs is not None:
             grad_output = rsi.codecs.backward.encode(grad_output)
         grad_inputs = grad_output.new_empty(rsi.input_sizes)
-        with record_function("## reduce_scatter_base_bw (all_gather) ##"):
+        with record_function("## ReduceScatterBase_bwd (all_gather) ##"):
            req = dist.all_gather_into_tensor(
                grad_inputs,
                grad_output.contiguous(),
@@ -2148,7 +2147,7 @@ def forward(
             input = agi.codecs.forward.encode(input)
 
         outputs = input.new_empty((input.size(0) * my_size, input.size(1)))
-        with record_function("## all_gather_into_tensor ##"):
+        with record_function("## AllGatherBase_fwd (into_tensor) ##"):
            req = dist.all_gather_into_tensor(
                outputs,
                input,
@@ -2216,7 +2215,7 @@ def backward(ctx, grad_outputs: Tensor) -> Tuple[None, None, Tensor]:
         if agi.codecs is not None:
             grad_outputs = agi.codecs.backward.encode(grad_outputs)
         grad_input = grad_outputs.new_empty(agi.input_size)
-        with record_function("## all_gather_base_bw (reduce_scatter) ##"):
+        with record_function("## AllGatherBase_bw (reduce_scatter_tensor) ##"):
            req = dist.reduce_scatter_tensor(
                grad_input,
                grad_outputs.contiguous(),
@@ -2250,15 +2249,15 @@ def forward(
         # Use dist.reduce_scatter_tensor when a vector reduce-scatter is not needed
         # else use dist.reduce_scatter which internally supports vector reduce-scatter
         if rsi.equal_splits:
-            with record_function("## reduce_scatter_tensor ##"):
+            with record_function("## ReduceScatterV_fwd (tensor) ##"):
                req = dist.reduce_scatter_tensor(
                    output,
                    input,
                    group=pg,
                    async_op=True,
                )
         else:
-            with record_function("## reduce_scatter_v ##"):
+            with record_function("## ReduceScatterV_fwd ##"):
                req = dist.reduce_scatter(
                    output,
                    list(torch.split(input, rsi.input_splits)),
@@ -2331,15 +2330,15 @@ def backward(ctx, grad_output: Tensor) -> Tuple[None, None, Tensor]:
         grad_input = grad_output.new_empty(rsi.total_input_size)
 
         if rsi.equal_splits:
-            with record_function("## reduce_scatter_base_bw (all_gather) ##"):
+            with record_function("## ReduceScatterV_bwd (all_gather) ##"):
                req = dist.all_gather_into_tensor(
                    grad_input,
                    grad_output.contiguous(),
                    group=ctx.pg,
                    async_op=True,
                )
         else:
-            with record_function("## reduce_scatter_v_bw (all_gather_v) ##"):
+            with record_function("## ReduceScatterV_bwd (all_gather_v) ##"):
                req = dist.all_gather(
                    list(torch.split(grad_input, rsi.input_splits)),
                    grad_output.contiguous(),
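
For context on why the label cleanup matters: every `record_function` string becomes a named range in profiler traces, so reusing one string (e.g. `## alltoall_fwd_single ##`) across unrelated collectives merged them into a single row. A self-contained sketch with a toy op (CPU-only; not code from this commit):

```python
# Minimal sketch: a distinct record_function label shows up as its own
# named row in a torch.profiler trace instead of being merged with every
# other op that shared the old generic label.
import torch
from torch.profiler import ProfilerActivity, profile, record_function

a, b = torch.randn(128, 128), torch.randn(128, 128)
with profile(activities=[ProfilerActivity.CPU]) as prof:
    with record_function("## all2all_pooled ##"):
        a @ b

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))
```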

torchrec/distributed/test_utils/multi_process.py

Lines changed: 1 addition & 0 deletions
@@ -222,6 +222,7 @@ def run_multi_process_func(
     os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
 
     if world_size == 1:
+        # skip multiprocess env for single-rank job
         kwargs["world_size"] = 1
         kwargs["rank"] = 0
         result = func(**kwargs)
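
The new comment documents a fast path: with `world_size == 1` the target function runs in-process rather than under the multiprocess launcher. A hypothetical stand-alone illustration (the `trainer` function is invented for the example):

```python
# Hypothetical illustration of the world_size == 1 branch above:
# the target function is invoked directly with rank 0 injected.
def run_single_rank(func, **kwargs):
    kwargs["world_size"] = 1
    kwargs["rank"] = 0
    return func(**kwargs)

def trainer(rank: int, world_size: int) -> str:
    return f"running rank {rank} of {world_size} in-process"

print(run_single_rank(trainer))
```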

torchrec/sparse/tests/jagged_tensor_benchmark.py

Lines changed: 2 additions & 2 deletions
@@ -16,7 +16,7 @@
 
 import torch
 from torchrec.distributed.benchmark.base import (
-    benchmark,
+    benchmark_model_with_warmup,
     BenchmarkResult,
     CPUMemoryStats,
     GPUMemoryStats,
@@ -77,7 +77,7 @@ def wrapped_func(
     setattr(model, "forward", lambda kwargs: fn(**kwargs))
     prof_num = 10
     if device_type == "cuda":
-        result = benchmark(
+        result = benchmark_model_with_warmup(
            name=name,
            model=model,
            warmup_inputs=[],
