#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import logging
from typing import Any, Collection, Dict, Mapping, Optional, Union

import torch
import torch.distributed as dist
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.optim.optimizer import ParamsT

from torchrec.optim.keyed import KeyedOptimizer

logger: logging.Logger = logging.getLogger(__name__)

SEMI_SYNC_GLOBAL_OPTIM_KEY: str = "semi_sync_global_optim"
SEMI_SYNC_LOCAL_OPTIM_KEY: str = "semi_sync_local_optim"

SEMI_SYNC_GLOBAL_STATE_KEY: str = "semi_sync_global_"
SEMI_SYNC_LOCAL_STATE_KEY: str = "semi_sync_local_"

# Step counter keys for special parameters
SEMISYNC_GLOBAL_STEP_COUNTER_KEY: str = "__semisync_global_step_counter__"
SEMISYNC_LOCAL_STEP_COUNTER_KEY: str = "__semisync_local_step_counter__"


class SemisyncOptimizer(KeyedOptimizer):
    """
    Semi-Synchronous Optimizer. Implements semi-synchronous, cross-regional distributed
    training for large-scale recommendation models.

    This optimizer:
    1. Takes a local optimizer (e.g., Shampoo or Adam, wrapped in a KeyedOptimizer)
    2. Takes a global optimizer (e.g., DiLoCo)
    3. Performs local steps with the local optimizer
    4. Periodically performs global aggregation steps using the global optimizer's logic
    5. Exposes the combined state of both optimizers through the KeyedOptimizer interface

    Args:
        global_params (ParamsT): Global parameters for SemisyncOptimizer
        optimizer (KeyedOptimizer): Local optimizer (typically Shampoo or Adam, wrapped in a KeyedOptimizer)
        global_optimizer (KeyedOptimizer): Global optimizer (such as DiLoCo, also wrapped in a KeyedOptimizer)
        num_local_steps (int): Number of local steps before a global sync
        semi_sync_worker_shard_group (Optional[dist.ProcessGroup]): Process group for metrics logging
        offload_global_model (bool): Offload global model parameters onto CPU (Default: False)
        non_blocking (bool): GPU <-> CPU communication is non-blocking (Default: False)
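
    Example (illustrative sketch only, not part of this module): ``model`` is a
    hypothetical ``nn.Module``; ``make_global_optimizer`` is a placeholder for however
    the DiLoCo-style global optimizer is built in your setup, and its inner optimizer
    is expected to expose a ``global_step`` method; the local optimizer is assumed to
    be wrapped via ``KeyedOptimizerWrapper`` from ``torchrec.optim.keyed``::

        named_params = dict(model.named_parameters())
        local_opt = KeyedOptimizerWrapper(
            named_params, lambda params: torch.optim.Adam(params, lr=1e-3)
        )
        global_opt = make_global_optimizer(named_params)  # hypothetical factory
        opt = SemisyncOptimizer(
            global_params=list(model.parameters()),
            optimizer=local_opt,
            global_optimizer=global_opt,
            num_local_steps=16,
        )
        # In the training loop, opt.step() runs a local step every iteration and a
        # global aggregation step every num_local_steps local steps.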
    """

    def __init__(
        self,
        global_params: ParamsT,
        optimizer: KeyedOptimizer,
        global_optimizer: KeyedOptimizer,
        num_local_steps: int = 16,
        semi_sync_worker_shard_group: Optional[dist.ProcessGroup] = None,
        offload_global_model: bool = False,
        non_blocking: bool = False,
    ) -> None:
        # Store the optimizers
        self._optimizer: KeyedOptimizer = optimizer
        self._global_optimizer: KeyedOptimizer = global_optimizer

        assert global_params is not None, "global params must be provided"
        # Determine parameters for global optimizer
        global_params_list = list(global_params)
        self._worker_model_params: list[torch.Tensor] = (
            # pyre-ignore
            [param for pgroup in global_params_list for param in pgroup["params"]]
            if isinstance(global_params_list[0], dict)
            else global_params_list
        )

        # Semi-sync configuration
        self._num_local_steps: int = num_local_steps
        self._local_step_counter: torch.Tensor = torch.tensor(
            0, dtype=torch.int64, device="cpu"
        )
        self._global_step_counter: torch.Tensor = torch.tensor(
            0, dtype=torch.int64, device="cpu"
        )

        # Store process groups info for metrics logging
        self._worker_shard_group: Optional[dist.ProcessGroup] = (
            semi_sync_worker_shard_group
        )
        self._offload_global_model: bool = offload_global_model
        self._non_blocking: bool = non_blocking
        self.defaults: Dict[str, Any] = {"_save_param_groups": False}

        logger.info(
            "Instantiated SemisyncOptimizer with: "
            f"num_local_steps={self._num_local_steps}, "
            f"worker_model_params_count={len(self._worker_model_params)}, "
            f"offload_global_model={offload_global_model}, "
            f"non_blocking={non_blocking}, "
            f"self.defaults={self.defaults}, "
            f"local_optimizer={type(self._optimizer)}, "
            f"global_optimizer={type(self._global_optimizer)}"
        )

    @property
    def param_groups(self) -> Collection[Mapping[str, Any]]:
        """
        Combine param_groups from both local and global optimizers.
        """
        return [
            param_group
            for opt in [self._optimizer, self._global_optimizer]
            for param_group in opt.param_groups
        ]

    @property
    def params(self) -> Mapping[str, Union[torch.Tensor, ShardedTensor]]:
        """
        Combine params from both local and global optimizers.
        If param_key already exists, verify that local and global point to the same parameter tensor.
        """
        ret = dict(self._optimizer.params)

        # Add global params with collision check
        for param_key, param in self._global_optimizer.params.items():
            if param_key in ret:
                assert (
                    ret[param_key] is param
                ), f"Parameter key '{param_key}' exists in both optimizers but points to different tensors"
            else:
                ret[param_key] = param

        # Add step counters as special parameters
        ret[SEMISYNC_GLOBAL_STEP_COUNTER_KEY] = self._global_step_counter
        ret[SEMISYNC_LOCAL_STEP_COUNTER_KEY] = self._local_step_counter

        return ret

    @property
    # pyre-ignore [3]
    def state(self) -> Mapping[torch.Tensor, Any]:
        """
        Combine state from both local and global optimizers.
        - For tensors that exist in both optimizers, prefix the state keys to avoid conflicts.
        - Step counters are embedded into every global parameter's state.
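
        For example (illustrative only, assuming the local optimizer is Adam), the
        local ``exp_avg`` entry for a parameter is exposed here as
        ``semi_sync_local_exp_avg``, a global ``step`` entry as
        ``semi_sync_global_step``, and every global parameter additionally carries
        ``__semisync_global_step_counter__`` and ``__semisync_local_step_counter__``.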
        """
        # Start with prefixed local states
        ret = {
            param: {
                f"{SEMI_SYNC_LOCAL_STATE_KEY}{key}": value
                for key, value in local_state.items()
            }
            for param, local_state in self._optimizer.state.items()
        }

        # Add prefixed global states and step counters
        for param, global_state in self._global_optimizer.state.items():
            prefixed_global_state = {
                f"{SEMI_SYNC_GLOBAL_STATE_KEY}{key}": value
                for key, value in global_state.items()
            }

            if param in ret:
                ret[param].update(prefixed_global_state)
            else:
                ret[param] = prefixed_global_state

            # Add step counters to all global params
            ret[param].update(
                {
                    SEMISYNC_GLOBAL_STEP_COUNTER_KEY: self._global_step_counter,
                    SEMISYNC_LOCAL_STEP_COUNTER_KEY: self._local_step_counter,
                }
            )

        return ret

    @torch.no_grad()
    def step(self, closure: Any = None) -> None:  # pyre-ignore [2]
        """
        Perform the semi-sync optimization step:
        1. Always perform a local optimizer step
        2. Every num_local_steps local steps, perform a global optimizer step

        TODO:
        See more details in D80683853, v8
        - metrics: add _metrics_logger for global grad and pseudo-gradient statistics
        - cpu cache: add GlobalModelParamsCache with non_blocking_transfer for memory efficiency
        - gradient clipping: add gradient clipping for global optimizer if needed
        """
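        # Note (illustrative): with the default num_local_steps=16, the global step
        # below fires at local steps 16, 32, 48, ..., i.e. whenever the local step
        # counter is a positive multiple of num_local_steps.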
        self._local_step_counter.add_(1)
        trigger_global_step = self._local_step_counter.item() > 0 and (
            self._local_step_counter.item() % self._num_local_steps == 0
        )

        # Perform local optimizer step (delegate to the local optimizer)
        self._optimizer.step(closure)

        # Perform global model sync every num_local_steps
        if trigger_global_step:
            self._global_step_counter.add_(1)

            # Step 0: Release the gradient buffer to reduce memory consumption
            self._global_optimizer.zero_grad()

            # Step 1: perform global optimizer step
            # pyre-ignore
            self._global_optimizer._optimizer.global_step(
                self._local_step_counter, closure
            )

            logger.info(
                f"Finished global optimizer step {self._global_step_counter.item()} "
                f"(after {self._local_step_counter.item()} local steps)"
            )

    def zero_grad(self, set_to_none: bool = True) -> None:
        self._optimizer.zero_grad(set_to_none=set_to_none)
        self._global_optimizer.zero_grad(set_to_none=set_to_none)

    def post_load_state_dict(self) -> None:
        """
        Called after KeyedOptimizer.load_state_dict() completes.
        This is where we separate the prefixed combined state back into the individual optimizers.
        """
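        # Illustrative round trip (assuming the local optimizer keeps an "exp_avg"
        # entry): the combined state exposed it as "semi_sync_local_exp_avg"; after
        # load_state_dict(), this hook strips the prefix and places "exp_avg" back
        # into self._optimizer.state for the corresponding parameter.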
        logger.info(
            "SemisyncOptimizer: post_load_state_dict called - separating prefixed combined state"
        )

        # Extract step counters from any param states that contain them
        combined_state = dict(self.state)
        self._post_load_state_dict_step_counter(combined_state)

        # Separate states using dictionary comprehensions
        local_tensors = set(self._optimizer.state.keys())
        global_tensors = set(self._global_optimizer.state.keys())

        local_state = {
            param: self._extract_prefixed_state(param_state, SEMI_SYNC_LOCAL_STATE_KEY)
            for param, param_state in combined_state.items()
            if param in local_tensors
            and any(k.startswith(SEMI_SYNC_LOCAL_STATE_KEY) for k in param_state)
        }

        global_state = {
            param: self._extract_prefixed_state(param_state, SEMI_SYNC_GLOBAL_STATE_KEY)
            for param, param_state in combined_state.items()
            if param in global_tensors
            and any(k.startswith(SEMI_SYNC_GLOBAL_STATE_KEY) for k in param_state)
        }

        # Update optimizer states
        for opt, state, name in [
            (self._optimizer, local_state, "local"),
            (self._global_optimizer, global_state, "global"),
        ]:
            if state:
                opt.state.clear()  # pyre-ignore
                opt.state.update(state)  # pyre-ignore
                logger.info(
                    f"SemisyncOptimizer: Set state on {name} optimizer for {len(state)} parameters"
                )

        # Call post_load_state_dict on individual optimizers if they support it
        for opt in [self._optimizer, self._global_optimizer]:
            if hasattr(opt, "post_load_state_dict"):
                opt.post_load_state_dict()

    def save_param_groups(self, save: bool) -> None:
        self.defaults["_save_param_groups"] = save
        self._optimizer.save_param_groups(save)
        self._global_optimizer.save_param_groups(save)

    def __repr__(self) -> str:
        ret = []
        ret.append(f"{SEMI_SYNC_LOCAL_OPTIM_KEY}: {self._optimizer.__repr__()}")
        ret.append(f"{SEMI_SYNC_GLOBAL_OPTIM_KEY}: {self._global_optimizer.__repr__()}")
        return ", ".join(ret)

    def set_optimizer_step(self, step: int) -> None:
        for opt in [self._optimizer, self._global_optimizer]:
            if hasattr(opt, "set_optimizer_step"):
                # pyre-ignore [16]: Undefined attribute [16]: `KeyedOptimizer` has no attribute `set_optimizer_step`.
                opt.set_optimizer_step(step)

    def update_hyper_parameters(self, params_dict: Dict[str, Any]) -> None:
        for opt in [self._optimizer, self._global_optimizer]:
            if hasattr(opt, "update_hyper_parameters"):
                # pyre-ignore [16]: Undefined attribute [16]: `KeyedOptimizer` has no attribute `update_hyper_parameters`.
                opt.update_hyper_parameters(params_dict)

    @staticmethod
    def _extract_prefixed_state(
        param_state: Dict[str, Any], prefix: str
    ) -> Dict[str, Any]:
        """
        Extract state keys with a specific prefix and remove the prefix.

        Args:
            param_state: Parameter state dictionary
            prefix: Prefix to extract and remove

        Returns:
            Dictionary with prefix removed from matching keys
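
        Example (illustrative sketch; values are arbitrary)::

            >>> SemisyncOptimizer._extract_prefixed_state(
            ...     {"semi_sync_local_momentum": 1, "unrelated": 2}, "semi_sync_local_"
            ... )
            {'momentum': 1}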
        """
        return {
            key[len(prefix) :]: value
            for key, value in param_state.items()
            if key.startswith(prefix)
        }

    def _post_load_state_dict_step_counter(
        self,
        combined_state: Dict[torch.Tensor, Any],  # pyre-ignore
    ) -> None:
        """Extract step counters from any param states that contain them."""
        found = {"global": False, "local": False}

        for param_state in combined_state.values():
            if not found["global"] and SEMISYNC_GLOBAL_STEP_COUNTER_KEY in param_state:
                self._global_step_counter = param_state[
                    SEMISYNC_GLOBAL_STEP_COUNTER_KEY
                ]
                found["global"] = True
            if not found["local"] and SEMISYNC_LOCAL_STEP_COUNTER_KEY in param_state:
                self._local_step_counter = param_state[SEMISYNC_LOCAL_STEP_COUNTER_KEY]
                found["local"] = True
            if all(found.values()):
                break

        missing = [k for k, v in found.items() if not v]
        if missing:
            raise RuntimeError(
                f"Missing {' and '.join(missing)} step counter(s) in checkpoint."
            )