Skip to content

Commit f526de0

Browse files
jeffkbkimfacebook-github-bot
authored andcommitted
2/N: CPU Comms Module
Summary: This diff adds the basic building blocks for a zero overhead RecMetrics implementation. Follow up patches will contain integration with users of torchrec. One of the main pain points of using RecMetricModule is that metric updates and computes are done synchronously. In training jobs, there has been cases where metric updates take +20% of a training iteration. Metric computations, although less frequent, can takes over a couple of seconds. CPUOffloadedRecMetricModule aims to perform all metric updates/computes asynchronously, completely removing them from the critical path. This patch adds: - CPUCommsRecMetricModule: Submodule that all gathers, loads, and computes aggregated metric state tensors across ranks. Differential Revision: D83773528
1 parent d5c2494 commit f526de0

File tree

2 files changed

+501
-0
lines changed

2 files changed

+501
-0
lines changed
Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
# pyre-strict
9+
import logging
10+
from typing import Any, cast, Dict
11+
12+
from torch import nn
13+
14+
from torch.profiler import record_function
15+
16+
from torchrec.metrics.metric_module import RecMetricModule
17+
from torchrec.metrics.metric_state_snapshot import MetricStateSnapshot
18+
from torchrec.metrics.rec_metric import (
19+
RecComputeMode,
20+
RecMetric,
21+
RecMetricComputation,
22+
RecMetricList,
23+
)
24+
25+
logger: logging.Logger = logging.getLogger(__name__)
26+
27+
28+
class CPUCommsRecMetricModule(RecMetricModule):
29+
"""
30+
A submodule of CPUOffloadedRecMetricModule.
31+
32+
The comms module's main purposes are:
33+
1. All gather metric state tensors
34+
2. Load all gathered metric states
35+
3. Compute metrics
36+
37+
This isolation allows CPUOffloadedRecMetricModule from having
38+
to concern about aggregated states and instead focus solely
39+
updating local state tensors and dumping snapshots to the comms module
40+
for metric aggregations.
41+
"""
42+
43+
def __init__(
44+
self,
45+
*args: Any,
46+
**kwargs: Any,
47+
) -> None:
48+
"""
49+
All arguments are the same as RecMetricModule
50+
"""
51+
52+
super().__init__(*args, **kwargs)
53+
54+
rec_metrics_clone = self._clone_rec_metrics()
55+
self.rec_metrics: RecMetricList = rec_metrics_clone
56+
57+
for metric in self.rec_metrics.rec_metrics:
58+
# Disable automatic sync for all metrics - handled manually via
59+
# RecMetricModule.get_pre_compute_states()
60+
metric = cast(RecMetric, metric)
61+
for computation in metric._metrics_computations:
62+
computation = cast(RecMetricComputation, computation)
63+
computation._to_sync = False
64+
65+
def load_local_metric_state_snapshot(
66+
self, state_snapshot: MetricStateSnapshot
67+
) -> None:
68+
"""
69+
Load local metric states before all gather.
70+
MetricStateSnapshot provides already-reduced states.
71+
72+
Args:
73+
state_snapshot (MetricStateSnapshot): a snapshot of metric states to load.
74+
"""
75+
76+
# Load states into comms module to be shared across ranks.
77+
78+
with record_function("## CPUCommsRecMetricModule: load_snapshot ##"):
79+
for metric in self.rec_metrics.rec_metrics:
80+
metric = cast(RecMetric, metric)
81+
compute_mode = metric._compute_mode
82+
if (
83+
compute_mode == RecComputeMode.FUSED_TASKS_COMPUTATION
84+
or compute_mode == RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION
85+
):
86+
prefix = compute_mode.name
87+
computation = metric._metrics_computations[0]
88+
self._load_metric_states(
89+
prefix, computation, state_snapshot.metric_states
90+
)
91+
for task, computation in zip(
92+
metric._tasks, metric._metrics_computations
93+
):
94+
self._load_metric_states(
95+
task.name, computation, state_snapshot.metric_states
96+
)
97+
98+
if state_snapshot.throughput_metric is not None:
99+
self.throughput_metric = state_snapshot.throughput_metric
100+
101+
def _load_metric_states(
102+
self, prefix: str, computation: nn.Module, metric_states: Dict[str, Any]
103+
) -> None:
104+
"""
105+
Load metric states after all gather.
106+
Uses aggregated states.
107+
"""
108+
109+
# All update() calls were done prior. Clear previous computed state.
110+
# Otherwise, we get warnings that compute() was called before
111+
# update() which is not the case.
112+
computation = cast(RecMetricComputation, computation)
113+
computation._update_called = True
114+
computation._computed = None
115+
116+
computation_name = f"{prefix}_{computation.__class__.__name__}"
117+
# Restore all cached states from reductions
118+
for attr_name in computation._reductions:
119+
cache_key = f"{computation_name}_{attr_name}"
120+
if cache_key in metric_states:
121+
cached_value = metric_states[cache_key]
122+
setattr(computation, attr_name, cached_value)
123+
124+
def _clone_rec_metrics(self) -> RecMetricList:
125+
"""
126+
Clone rec_metrics. We need to keep references to the original tasks
127+
and computation to load the state tensors. More importantly, we need to
128+
remove the references to the original metrics to prevent concurrent access
129+
from the update and compute threads.
130+
"""
131+
132+
cloned_metrics = []
133+
for metric in self.rec_metrics.rec_metrics:
134+
metric = cast(RecMetric, metric)
135+
cloned_metric = type(metric)(
136+
world_size=metric._world_size,
137+
my_rank=metric._my_rank,
138+
batch_size=metric._batch_size,
139+
tasks=metric._tasks,
140+
compute_mode=metric._compute_mode,
141+
# Standard initialization passes in the global window size. A RecMetric's
142+
# window size is set as the local window size.
143+
window_size=metric._window_size * metric._world_size,
144+
fused_update_limit=metric._fused_update_limit,
145+
compute_on_all_ranks=metric._metrics_computations[
146+
0
147+
]._compute_on_all_ranks,
148+
should_validate_update=metric._should_validate_update,
149+
# Process group should be none to prevent unwanted distributed syncs.
150+
# This is handled manually via RecMetricModule.get_pre_compute_states()
151+
process_group=None,
152+
)
153+
cloned_metrics.append(cloned_metric)
154+
155+
return RecMetricList(cloned_metrics)

0 commit comments

Comments
 (0)