
Commit 3fd902f

jeffkbkim authored and facebook-github-bot committed
2/N: CPU Comms Module (meta-pytorch#3424)
Summary: This diff adds the basic building blocks for a zero-overhead RecMetrics implementation. Follow-up patches will contain the integration with users of torchrec.

One of the main pain points of using RecMetricModule is that metric updates and computes are done synchronously. In training jobs, there have been cases where metric updates take over 20% of a training iteration. Metric computations, although less frequent, can take more than a couple of seconds. CPUOffloadedRecMetricModule aims to perform all metric updates/computes asynchronously, completely removing them from the critical path.

This patch adds:

- CPUCommsRecMetricModule: Submodule that all-gathers, loads, and computes aggregated metric state tensors across ranks.

Differential Revision: D83773528
1 parent 90eb966 commit 3fd902f
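
For context, here is a minimal sketch (not part of this diff) of how the comms module is meant to be driven. It assumes a producer, e.g. the forthcoming CPUOffloadedRecMetricModule, that dumps MetricStateSnapshot objects off the critical path; the `aggregate_and_compute` wrapper and its orchestration are assumptions, while load_local_metric_state_snapshot() comes from this diff and compute() from the existing RecMetricModule API:

from typing import Any, Dict

from torchrec.metrics.metric_state_snapshot import MetricStateSnapshot


def aggregate_and_compute(
    comms_module: "CPUCommsRecMetricModule",
    snapshot: MetricStateSnapshot,
) -> Dict[str, Any]:
    # Load the already-reduced local states produced by the update thread.
    comms_module.load_local_metric_state_snapshot(snapshot)
    # Compute aggregated metrics. Cross-rank syncing is handled manually
    # (see RecMetricModule.get_pre_compute_states()); torchmetrics'
    # automatic sync is disabled in CPUCommsRecMetricModule.__init__.
    return comms_module.compute()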

2 files changed: +513 -0 lines changed
Lines changed: 167 additions & 0 deletions
@@ -0,0 +1,167 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
import logging
from typing import Any, cast, Dict

from torch import nn

from torch.profiler import record_function

from torchrec.metrics.metric_module import RecMetricModule
from torchrec.metrics.metric_state_snapshot import MetricStateSnapshot
from torchrec.metrics.rec_metric import (
    RecComputeMode,
    RecMetric,
    RecMetricComputation,
    RecMetricList,
)

logger: logging.Logger = logging.getLogger(__name__)


class CPUCommsRecMetricModule(RecMetricModule):
    """
    A submodule of CPUOffloadedRecMetricModule.

    The comms module's main purposes are:
    1. All gather metric state tensors
    2. Load all gathered metric states
    3. Compute metrics

    This isolation frees CPUOffloadedRecMetricModule from having to
    concern itself with aggregated states, letting it focus solely on
    updating local state tensors and dumping snapshots to the comms
    module for metric aggregation.
    """

    def __init__(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """
        All arguments are the same as RecMetricModule.
        """
        super().__init__(*args, **kwargs)

        rec_metrics_clone = self._clone_rec_metrics()
        self.rec_metrics: RecMetricList = rec_metrics_clone

        for metric in self.rec_metrics.rec_metrics:
            # Disable automatic sync for all metrics - handled manually via
            # RecMetricModule.get_pre_compute_states()
            metric = cast(RecMetric, metric)
            for computation in metric._metrics_computations:
                computation = cast(RecMetricComputation, computation)
                computation._to_sync = False

    def load_local_metric_state_snapshot(
        self, state_snapshot: MetricStateSnapshot
    ) -> None:
        """
        Load local metric states before all gather.
        MetricStateSnapshot provides already-reduced states.

        Args:
            state_snapshot (MetricStateSnapshot): a snapshot of metric states to load.
        """
        # Load states into the comms module to be shared across ranks.
        with record_function("## CPUCommsRecMetricModule: load_snapshot ##"):
            for metric in self.rec_metrics.rec_metrics:
                metric = cast(RecMetric, metric)
                compute_mode = metric._compute_mode
                if (
                    compute_mode == RecComputeMode.FUSED_TASKS_COMPUTATION
                    or compute_mode == RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION
                ):
                    prefix = compute_mode.name
                    computation = metric._metrics_computations[0]
                    self._load_metric_states(
                        prefix, computation, state_snapshot.metric_states
                    )
                else:
                    for task, computation in zip(
                        metric._tasks, metric._metrics_computations
                    ):
                        self._load_metric_states(
                            task.name, computation, state_snapshot.metric_states
                        )

            if state_snapshot.throughput_metric is not None:
                self.throughput_metric = state_snapshot.throughput_metric

    def _load_metric_states(
        self, prefix: str, computation: nn.Module, metric_states: Dict[str, Any]
    ) -> None:
        """
        Load metric states after all gather.
        Uses aggregated states.
        """
        # All update() calls were done prior. Clear previously computed state.
        # Otherwise, we get warnings that compute() was called before
        # update(), which is not the case.
        computation = cast(RecMetricComputation, computation)
        set_update_called(computation)
        computation._computed = None

        computation_name = f"{prefix}_{computation.__class__.__name__}"
        # Restore all cached states from reductions
        for attr_name in computation._reductions:
            cache_key = f"{computation_name}_{attr_name}"
            if cache_key in metric_states:
                cached_value = metric_states[cache_key]
                setattr(computation, attr_name, cached_value)

    def _clone_rec_metrics(self) -> RecMetricList:
        """
        Clone rec_metrics. We need to keep references to the original tasks
        and computations to load the state tensors. More importantly, we need
        to remove the references to the original metrics to prevent concurrent
        access from the update and compute threads.
        """
        cloned_metrics = []
        for metric in self.rec_metrics.rec_metrics:
            metric = cast(RecMetric, metric)
            cloned_metric = type(metric)(
                world_size=metric._world_size,
                my_rank=metric._my_rank,
                batch_size=metric._batch_size,
                tasks=metric._tasks,
                compute_mode=metric._compute_mode,
                # Standard initialization passes in the global window size. A
                # RecMetric's window size is set as the local window size.
                window_size=metric._window_size * metric._world_size,
                fused_update_limit=metric._fused_update_limit,
                compute_on_all_ranks=metric._metrics_computations[
                    0
                ]._compute_on_all_ranks,
                should_validate_update=metric._should_validate_update,
                # Process group should be None to prevent unwanted distributed
                # syncs. Syncing is handled manually via
                # RecMetricModule.get_pre_compute_states().
                process_group=None,
            )
            cloned_metrics.append(cloned_metric)

        return RecMetricList(cloned_metrics)


def set_update_called(computation: RecMetricComputation) -> None:
    """
    Set _update_called to True for RecMetricComputation.
    This is a workaround for torchmetrics 1.0.3+, where _update_called
    is derived from _update_count and can no longer be set directly.
    """
    try:
        computation._update_called = True
    except AttributeError:
        # pyre-ignore
        computation._update_count = 1
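
For reference, an illustration (not in the diff) of the cache-key scheme that _load_metric_states() expects in MetricStateSnapshot.metric_states; the metric, task, and state names below are hypothetical:

# Per-task (unfused) mode: the prefix is the task name, one computation per task.
#   computation_name = "ctr_task_NEMetricComputation"
#   cache_key        = "ctr_task_NEMetricComputation_cross_entropy_sum"
#
# Fused modes: the prefix is the compute mode's name, and the single fused
# computation lives at metric._metrics_computations[0].
#   computation_name = "FUSED_TASKS_COMPUTATION_NEMetricComputation"
#   cache_key        = "FUSED_TASKS_COMPUTATION_NEMetricComputation_cross_entropy_sum"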

0 commit comments