
Commit afa92f0

kesang357 authored and facebook-github-bot committed
add semi sync optimizer paradigm into torchrec (#3380)
Summary:
Pull Request resolved: #3380

Add semi-sync class in the torchrec repo:
- SemisyncOptimizer takes a local and a global optimizer
- SemisyncOptimizer inherits from KeyedOptimizer, following torchrec design, to resolve CP compatibility for modelstore and CP
- SemisyncOptimizer delegates its step function to two phases: 1. a local optimizer step(); 2. every num_local_steps, a global_optimizer step() is called.

Limitations:
1. Currently, it is NOT real semi-sync, since global params exist on every rank for the global optimizer; we will bring this key feature later.
2. Memory and QPS are NOT optimized at this point; extra work is needed for performance.

Reviewed By: wz337

Differential Revision: D82503673

fbshipit-source-id: bbf71bef6ec59ba3a68c16cd7cf45f5e97dc6ac5
1 parent 29b55bd commit afa92f0
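
The step cadence described in the summary (a local optimizer step on every iteration, a global optimizer step every num_local_steps) can be exercised with a minimal sketch. This is illustration only and relies on assumptions beyond the diff: ToyGlobalSGD is a hypothetical stand-in for a DiLoCo-style outer optimizer, and it exposes a global_step method because SemisyncOptimizer calls global_optimizer._optimizer.global_step(...) on the wrapped optimizer; both optimizers are assumed to be wrapped in torchrec's KeyedOptimizerWrapper.

    import torch

    from torchrec.optim.keyed import KeyedOptimizerWrapper
    from torchrec.optim.semi_sync import SemisyncOptimizer


    class ToyGlobalSGD(torch.optim.SGD):
        # Hypothetical outer optimizer: SemisyncOptimizer expects the wrapped global
        # optimizer to expose `global_step(local_step_counter, closure)`. A real setup
        # would use a DiLoCo-style optimizer stepping on pseudo-gradients; this toy
        # simply reuses SGD's step to illustrate the call cadence.
        def global_step(self, local_step_counter: torch.Tensor, closure=None) -> None:
            self.step(closure)


    model = torch.nn.Linear(8, 1)
    named_params = dict(model.named_parameters())

    local_opt = KeyedOptimizerWrapper(
        named_params, lambda params: torch.optim.Adam(params, lr=1e-3)
    )
    global_opt = KeyedOptimizerWrapper(
        named_params, lambda params: ToyGlobalSGD(params, lr=0.5)
    )

    opt = SemisyncOptimizer(
        global_params=list(model.parameters()),
        optimizer=local_opt,
        global_optimizer=global_opt,
        num_local_steps=4,  # the global step fires on local steps 4, 8, 12, ...
    )

    for _ in range(8):
        opt.zero_grad()
        loss = model(torch.randn(2, 8)).sum()
        loss.backward()
        opt.step()  # local step every iteration; global_step() on steps 4 and 8
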

File tree

3 files changed: +782 −0 lines changed


torchrec/optim/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -38,6 +38,7 @@
     SGD,
 )
 from torchrec.optim.rowwise_adagrad import RowWiseAdagrad  # noqa
+from torchrec.optim.semi_sync import SemisyncOptimizer  # noqa
 from torchrec.optim.warmup import WarmupOptimizer, WarmupPolicy, WarmupStage  # noqa

 from . import (  # noqa
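
With this re-export, the new class should be importable from the package root as well as from its defining module; a quick sanity check (assuming this commit is installed):

    from torchrec.optim import SemisyncOptimizer
    from torchrec.optim.semi_sync import SemisyncOptimizer as _SemisyncOptimizer

    assert SemisyncOptimizer is _SemisyncOptimizer
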

torchrec/optim/semi_sync.py

Lines changed: 339 additions & 0 deletions
@@ -0,0 +1,339 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import logging
from typing import Any, Collection, Dict, Mapping, Optional, Union

import torch
import torch.distributed as dist
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.optim.optimizer import ParamsT

from torchrec.optim.keyed import KeyedOptimizer

logger: logging.Logger = logging.getLogger(__name__)

SEMI_SYNC_GLOBAL_OPTIM_KEY: str = "semi_sync_global_optim"
SEMI_SYNC_LOCAL_OPTIM_KEY: str = "semi_sync_local_optim"

SEMI_SYNC_GLOBAL_STATE_KEY: str = "semi_sync_global_"
SEMI_SYNC_LOCAL_STATE_KEY: str = "semi_sync_local_"

# Step counter keys for special parameters
SEMISYNC_GLOBAL_STEP_COUNTER_KEY: str = "__semisync_global_step_counter__"
SEMISYNC_LOCAL_STEP_COUNTER_KEY: str = "__semisync_local_step_counter__"


class SemisyncOptimizer(KeyedOptimizer):
    """
    Semi-Synchronous Optimizer. Implements semi-synchronous training for cross-regional
    distributed training of large-scale recommendation models.

    This optimizer:
    1. Takes a local optimizer (e.g., Shampoo or Adam wrapped in KeyedOptimizer)
    2. Takes a global optimizer (e.g., DiLoCo)
    3. Performs local steps on the local optimizer
    4. Periodically performs global aggregation steps using the logic of the global optimizer
    5. Exposes combined state from both optimizers through the KeyedOptimizer interface

    Args:
        global_params (ParamsT): Global parameters for SemisyncOptimizer
        optimizer (KeyedOptimizer): Local optimizer (typically Shampoo or Adam, wrapped in KeyedOptimizer)
        global_optimizer (KeyedOptimizer): Global optimizer (such as DiLoCo, also wrapped in KeyedOptimizer)
        num_local_steps (int): Number of local steps before global sync
        semi_sync_worker_shard_group (Optional[dist.ProcessGroup]): Process group for metrics logging
        offload_global_model (bool): Offload global model parameters onto CPU (Default: False)
        non_blocking (bool): GPU <-> CPU communication is non-blocking (Default: False)
    """

    def __init__(
        self,
        global_params: ParamsT,
        optimizer: KeyedOptimizer,
        global_optimizer: KeyedOptimizer,
        num_local_steps: int = 16,
        semi_sync_worker_shard_group: Optional[dist.ProcessGroup] = None,
        offload_global_model: bool = False,
        non_blocking: bool = False,
    ) -> None:
        # Store the optimizers
        self._optimizer: KeyedOptimizer = optimizer
        self._global_optimizer: KeyedOptimizer = global_optimizer

        assert global_params is not None, "global params must be provided"
        # Determine parameters for global optimizer
        global_params_list = list(global_params)
        self._worker_model_params: list[torch.Tensor] = (
            # pyre-ignore
            [param for pgroup in global_params_list for param in pgroup["params"]]
            if isinstance(global_params_list[0], dict)
            else global_params_list
        )

        # Semi-sync configuration
        self._num_local_steps: int = num_local_steps
        self._local_step_counter: torch.Tensor = torch.tensor(
            0, dtype=torch.int64, device="cpu"
        )
        self._global_step_counter: torch.Tensor = torch.tensor(
            0, dtype=torch.int64, device="cpu"
        )

        # Store process groups info for metrics logging
        self._worker_shard_group: Optional[dist.ProcessGroup] = (
            semi_sync_worker_shard_group
        )
        self._offload_global_model: bool = offload_global_model
        self._non_blocking: bool = non_blocking
        self.defaults: Dict[str, Any] = {"_save_param_groups": False}

        logger.info(
            "Instantiated SemisyncOptimizer with: "
            f"num_local_steps={self._num_local_steps}, "
            f"worker_model_params_count={len(self._worker_model_params)}, "
            f"offload_global_model={offload_global_model}, "
            f"non_blocking={non_blocking}, "
            f"self.defaults={self.defaults}, "
            f"local_optimizer={type(self._optimizer)}, "
            f"global_optimizer={type(self._global_optimizer)}"
        )

    @property
    def param_groups(self) -> Collection[Mapping[str, Any]]:
        """
        Combine param_groups from both local and global optimizers.
        """
        return [
            param_group
            for opt in [self._optimizer, self._global_optimizer]
            for param_group in opt.param_groups
        ]

    @property
    def params(self) -> Mapping[str, Union[torch.Tensor, ShardedTensor]]:
        """
        Combine params from both local and global optimizers.
        If param_key already exists, verify that local and global point to the same parameter tensor.
        """
        ret = dict(self._optimizer.params)

        # Add global params with collision check
        for param_key, param in self._global_optimizer.params.items():
            if param_key in ret:
                assert (
                    ret[param_key] is param
                ), f"Parameter key '{param_key}' exists in both optimizers but points to different tensors"
            else:
                ret[param_key] = param

        # Add step counters as special parameters
        ret[SEMISYNC_GLOBAL_STEP_COUNTER_KEY] = self._global_step_counter
        ret[SEMISYNC_LOCAL_STEP_COUNTER_KEY] = self._local_step_counter

        return ret

    @property
    # pyre-ignore [3]
    def state(self) -> Mapping[torch.Tensor, Any]:
        """
        Combine state from both local and global optimizers.
        - For tensors that exist in both optimizers, prefix the state keys to avoid conflicts.
        - Step counters are embedded into every global parameter's state.
        """
        # Start with prefixed local states
        ret = {
            param: {
                f"{SEMI_SYNC_LOCAL_STATE_KEY}{key}": value
                for key, value in local_state.items()
            }
            for param, local_state in self._optimizer.state.items()
        }

        # Add prefixed global states and step counters
        for param, global_state in self._global_optimizer.state.items():
            prefixed_global_state = {
                f"{SEMI_SYNC_GLOBAL_STATE_KEY}{key}": value
                for key, value in global_state.items()
            }

            if param in ret:
                ret[param].update(prefixed_global_state)
            else:
                ret[param] = prefixed_global_state

            # Add step counters to all global params
            ret[param].update(
                {
                    SEMISYNC_GLOBAL_STEP_COUNTER_KEY: self._global_step_counter,
                    SEMISYNC_LOCAL_STEP_COUNTER_KEY: self._local_step_counter,
                }
            )

        return ret

    @torch.no_grad()
    def step(self, closure: Any = None) -> None:  # pyre-ignore [2]
        """
        Perform semi-sync optimization step:
        1. Always perform local optimizer step
        2. At every num_local_steps, perform global optimizer step

        TODO:
        See more details in D80683853, v8
        - metrics: add _metrics_logger for global grad and pseudo-gradient statistics
        - cpu cache: add GlobalModelParamsCache with non_blocking_transfer for memory efficiency
        - gradient clipping: add gradient clipping for global optimizer if needed
        """
        self._local_step_counter.add_(1)
        trigger_global_step = self._local_step_counter.item() > 0 and (
            self._local_step_counter.item() % self._num_local_steps == 0
        )

        # Perform local optimizer step (delegate to the local optimizer)
        self._optimizer.step(closure)

        # Perform global model sync every num_local_steps
        if trigger_global_step:
            self._global_step_counter.add_(1)

            # Step 0: Release the gradient buffer to reduce memory consumption
            self._global_optimizer.zero_grad()

            # Step 1: perform global optimizer step
            # pyre-ignore
            self._global_optimizer._optimizer.global_step(
                self._local_step_counter, closure
            )

            logger.info(
                f"Finished global optimizer step {self._global_step_counter.item()} "
                f"(after {self._local_step_counter.item()} local steps)"
            )

    def zero_grad(self, set_to_none: bool = True) -> None:
        self._optimizer.zero_grad(set_to_none=set_to_none)
        self._global_optimizer.zero_grad(set_to_none=set_to_none)

    def post_load_state_dict(self) -> None:
        """
        Called after KeyedOptimizer.load_state_dict() completes.
        This is where we separate the prefixed combined state back to individual optimizers.
        """
        logger.info(
            "SemisyncOptimizer: post_load_state_dict called - separating prefixed combined state"
        )

        # Extract step counters from any param states that contain them
        combined_state = dict(self.state)
        self._post_load_state_dict_step_counter(combined_state)

        # Separate states using dictionary comprehensions
        local_tensors = set(self._optimizer.state.keys())
        global_tensors = set(self._global_optimizer.state.keys())

        local_state = {
            param: self._extract_prefixed_state(param_state, SEMI_SYNC_LOCAL_STATE_KEY)
            for param, param_state in combined_state.items()
            if param in local_tensors
            and any(k.startswith(SEMI_SYNC_LOCAL_STATE_KEY) for k in param_state)
        }

        global_state = {
            param: self._extract_prefixed_state(param_state, SEMI_SYNC_GLOBAL_STATE_KEY)
            for param, param_state in combined_state.items()
            if param in global_tensors
            and any(k.startswith(SEMI_SYNC_GLOBAL_STATE_KEY) for k in param_state)
        }

        # Update optimizer states
        for opt, state, name in [
            (self._optimizer, local_state, "local"),
            (self._global_optimizer, global_state, "global"),
        ]:
            if state:
                opt.state.clear()  # pyre-ignore
                opt.state.update(state)  # pyre-ignore
                logger.info(
                    f"SemisyncOptimizer: Set state on {name} optimizer for {len(state)} parameters"
                )

        # Call post_load_state_dict on individual optimizers if they support it
        for opt in [self._optimizer, self._global_optimizer]:
            if hasattr(opt, "post_load_state_dict"):
                opt.post_load_state_dict()

    def save_param_groups(self, save: bool) -> None:
        self.defaults["_save_param_groups"] = save
        self._optimizer.save_param_groups(save)
        self._global_optimizer.save_param_groups(save)

    def __repr__(self) -> str:
        ret = []
        ret.append(f"{SEMI_SYNC_LOCAL_OPTIM_KEY}: {self._optimizer.__repr__()}")
        ret.append(f"{SEMI_SYNC_GLOBAL_OPTIM_KEY}: {self._global_optimizer.__repr__()}")
        return ", ".join(ret)

    def set_optimizer_step(self, step: int) -> None:
        for opt in [self._optimizer, self._global_optimizer]:
            if hasattr(opt, "set_optimizer_step"):
                # pyre-ignore [16]: Undefined attribute [16]: `KeyedOptimizer` has no attribute `set_optimizer_step`.
                opt.set_optimizer_step(step)

    def update_hyper_parameters(self, params_dict: Dict[str, Any]) -> None:
        for opt in [self._optimizer, self._global_optimizer]:
            if hasattr(opt, "update_hyper_parameters"):
                # pyre-ignore [16]: Undefined attribute [16]: `KeyedOptimizer` has no attribute `update_hyper_parameters`.
                opt.update_hyper_parameters(params_dict)

    @staticmethod
    def _extract_prefixed_state(
        param_state: Dict[str, Any], prefix: str
    ) -> Dict[str, Any]:
        """
        Extract state keys with a specific prefix and remove the prefix.

        Args:
            param_state: Parameter state dictionary
            prefix: Prefix to extract and remove

        Returns:
            Dictionary with prefix removed from matching keys
        """
        return {
            key[len(prefix) :]: value
            for key, value in param_state.items()
            if key.startswith(prefix)
        }

    def _post_load_state_dict_step_counter(
        self,
        combined_state: Dict[torch.Tensor, Any],  # pyre-ignore
    ) -> None:
        """Extract step counters from any param states that contain them."""
        found = {"global": False, "local": False}

        for param_state in combined_state.values():
            if not found["global"] and SEMISYNC_GLOBAL_STEP_COUNTER_KEY in param_state:
                self._global_step_counter = param_state[
                    SEMISYNC_GLOBAL_STEP_COUNTER_KEY
                ]
                found["global"] = True
            if not found["local"] and SEMISYNC_LOCAL_STEP_COUNTER_KEY in param_state:
                self._local_step_counter = param_state[SEMISYNC_LOCAL_STEP_COUNTER_KEY]
                found["local"] = True
            if all(found.values()):
                break

        missing = [k for k, v in found.items() if not v]
        if missing:
            raise RuntimeError(
                f"Missing {' and '.join(missing)} step counter(s) in checkpoint."
            )
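
To make the state-key prefixing concrete: the sketch below (with hypothetical state keys, not taken from the diff) shows how a combined per-parameter state entry exposed by SemisyncOptimizer.state is split back into its local and global pieces by the _extract_prefixed_state helper, which is what post_load_state_dict relies on after a checkpoint load.

    import torch

    from torchrec.optim.semi_sync import (
        SEMI_SYNC_GLOBAL_STATE_KEY,
        SEMI_SYNC_LOCAL_STATE_KEY,
        SemisyncOptimizer,
    )

    # A hypothetical combined state entry, shaped like what `SemisyncOptimizer.state` exposes.
    param_state = {
        f"{SEMI_SYNC_LOCAL_STATE_KEY}exp_avg": torch.zeros(4),
        f"{SEMI_SYNC_LOCAL_STATE_KEY}exp_avg_sq": torch.zeros(4),
        f"{SEMI_SYNC_GLOBAL_STATE_KEY}momentum_buffer": torch.zeros(4),
    }

    # The same helper used in post_load_state_dict() splits the combined state back out.
    local = SemisyncOptimizer._extract_prefixed_state(param_state, SEMI_SYNC_LOCAL_STATE_KEY)
    global_ = SemisyncOptimizer._extract_prefixed_state(param_state, SEMI_SYNC_GLOBAL_STATE_KEY)

    print(sorted(local))    # ['exp_avg', 'exp_avg_sq']
    print(sorted(global_))  # ['momentum_buffer']
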
