@@ -68,6 +68,9 @@ def __init__(
        # Otherwise, all parameters are treated as replicated and will be clipped locally.
        sharded_param_cnt = 0
        self._replicate_params: List[torch.Tensor] = []
+
+        # self._sharded_params (key: Tuple[dist.ProcessGroup], value: List[torch.Tensor])
+        # maps each process group tuple to its list of sharded parameters.
        self._sharded_params: Dict[Tuple[dist.ProcessGroup], List[torch.Tensor]] = (
            defaultdict(list)
        )
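
A minimal sketch (not part of this patch) of how that mapping gets populated, assuming a hypothetical param_to_pgs dict mapping each parameter to its process groups; strings stand in for real dist.ProcessGroup objects so the snippet is self-contained:

# Illustration only: parameters sharded over the same set of process groups
# land in the same bucket, keyed by the tuple of those groups.
from collections import defaultdict
import torch

param_to_pgs = {
    torch.nn.Parameter(torch.zeros(4)): ["pg_row", "pg_col"],
    torch.nn.Parameter(torch.zeros(8)): ["pg_row", "pg_col"],
    torch.nn.Parameter(torch.zeros(2)): ["pg_dp"],
}
sharded_params = defaultdict(list)
for param, pgs in param_to_pgs.items():
    sharded_params[tuple(pgs)].append(param)
# -> two buckets: ("pg_row", "pg_col") with 2 params and ("pg_dp",) with 1 param
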
@@ -143,90 +146,101 @@ def clip_grad_norm_(self) -> Optional[Union[float, torch.Tensor]]:
        all_grads = []
        total_grad_norm = None

+        sharded_params = self._sharded_params
+        ddp_params = self._replicate_params
        # Process distributed parameters and gradients
-        for pgs, dist_params in self._sharded_params.items():
+        for _, dist_params in sharded_params.items():
            sharded_grads = [
                p.grad._local_tensor if isinstance(p.grad, DTensor) else p.grad
                for p in dist_params
                if p.grad is not None and p.grad.numel() > 0
            ]
-            if len(sharded_grads) == 0:
-                continue
            all_grads.extend(sharded_grads)

-            sharded_grad_norm = _batch_cal_norm(
-                sharded_grads,
-                max_norm,
-                norm_type,
-                pgs,
-            )
-            total_grad_norm = (
-                sharded_grad_norm
-                if total_grad_norm is None
-                else (
-                    torch.maximum(total_grad_norm, sharded_grad_norm)
-                    if norm_type == torch.inf
-                    else total_grad_norm + sharded_grad_norm
-                )
-            )
-
-        square_sharded_grad_norm = total_grad_norm if total_grad_norm is not None else 0
-
        # Process replicated parameters and gradients
-        if self._replicate_params:
-            replicated_grads = [
+        if ddp_params:
+            ddp_grads = [
                p.grad._local_tensor if isinstance(p.grad, DTensor) else p.grad
                for p in self._replicate_params
                if p.grad is not None and p.grad.numel() > 0
            ]
-            all_grads.extend(replicated_grads)
-
-            replicated_grad_norm = _batch_cal_norm(
-                replicated_grads,
-                max_norm,
-                norm_type,
-                None,
-            )
-            total_grad_norm = (
-                replicated_grad_norm
-                if total_grad_norm is None
-                else (
-                    torch.maximum(total_grad_norm, replicated_grad_norm)
-                    if norm_type == torch.inf
-                    else total_grad_norm + replicated_grad_norm
-                )
-            )
-            square_replicated_grad_norm = replicated_grad_norm
-        else:
-            square_replicated_grad_norm = 0
-
-        global log_grad_norm
-        if log_grad_norm:
-            if total_grad_norm is not None and norm_type != torch.inf:
-                # pyre-ignore[58]
-                grad_norm = total_grad_norm ** (1.0 / norm_type)
-            else:
-                grad_norm = total_grad_norm
-
-            rank = dist.get_rank()
-            logger.info(
-                f"Clipping [rank={rank}, step={self._step_num}]: square_sharded_grad_norm = {square_sharded_grad_norm}, square_replicated_grad_norm = {square_replicated_grad_norm}, total_grad_norm = {grad_norm}"
-            )
+            all_grads.extend(ddp_grads)

-        # Aggregation
-        if total_grad_norm is None:
-            return
+        total_grad_norm = _compute_total_norm(
+            ddp_params, sharded_params, norm_type, max_norm
+        )

-        if norm_type != torch.inf:
-            # pyre-ignore [58]: ** is not supported for operand types torch._tensor.Tensor and float.
-            total_grad_norm = total_grad_norm ** (1.0 / norm_type)
        # pyre-ignore [58]: / is not supported for operand types float and Union[float, torch._tensor.Tensor].
        clip_coef = cast(torch.Tensor, max_norm / (total_grad_norm + 1e-6))
        clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
        torch._foreach_mul_(all_grads, clip_coef_clamped)
        return total_grad_norm

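For reference, a small self-contained numeric sketch (not part of the patch) of the clipping step above, showing how the clip coefficient scales the gradients:

import torch

max_norm = 1.0
total_grad_norm = torch.tensor(4.0)  # pretend the combined norm came out to 4
clip_coef = max_norm / (total_grad_norm + 1e-6)      # ~0.25
clip_coef_clamped = torch.clamp(clip_coef, max=1.0)  # still ~0.25
grads = [torch.ones(3), torch.ones(2)]
torch._foreach_mul_(grads, clip_coef_clamped)        # every grad scaled down ~4x
# If total_grad_norm were already <= max_norm, clip_coef would clamp to 1.0
# and the gradients would pass through unchanged.
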
+def _compute_total_norm(
+    ddp_params: List[torch.Tensor] | None = None,
+    fsdp_params: Dict[Tuple[dist.ProcessGroup], List[torch.Tensor]] | None = None,
+    norm_type: float = 2.0,  # a finite p value, or torch.inf
+    max_grad_norm: float = 1.0,
+) -> torch.Tensor:
+    """
+    Given both DDP (replicated) params and sharded (FSDP) params, compute the total
+    gradient norm across the full set of DDP and FSDP params.
+
+    Args:
+        ddp_params (List[torch.Tensor]): list of DDP (replicated) params.
+        fsdp_params (Dict[Tuple[dist.ProcessGroup], List[torch.Tensor]]): dict that maps
+            each process group to a list of sharded params.
+        norm_type (float): type of the used p-norm. Can be ``torch.inf`` for infinity norm.
+        max_grad_norm (float): max norm, forwarded to ``_batch_cal_norm``.
+    """
+
+    ## compute |W|^p corresponding to all DDP params W
+
+    if ddp_params is None:
+        ddp_params = []
+    if fsdp_params is None:
+        fsdp_params = defaultdict(list)
+
+    def get_grad_norm(
+        param_list: List[torch.Tensor],
+        norm_type: float,
+        max_grad_norm: float,
+        pgs: Tuple[dist.ProcessGroup] | None = None,
+    ) -> torch.Tensor:
+        grad_list = [
+            p.grad._local_tensor if isinstance(p.grad, DTensor) else p.grad
+            for p in param_list
+            if p.grad is not None and p.grad.numel() > 0
+        ]
+        return _batch_cal_norm(grad_list, max_grad_norm, norm_type, pgs)
+
+    ddp_grad_norm: torch.Tensor = (
+        get_grad_norm(ddp_params, norm_type, max_grad_norm)
+        if ddp_params
+        else torch.tensor(0.0)
+    )
+
+    ## compute the norm |W|^p corresponding to all sharded params W
+    fsdp_grad_norm: torch.Tensor = torch.tensor(0.0)
+    if fsdp_params:
+        combine_fsdp_norm_operator = (
+            torch.maximum if norm_type == torch.inf else torch.add
+        )
+        for pgs, dist_params in fsdp_params.items():
+            shard_norm = get_grad_norm(dist_params, norm_type, max_grad_norm, pgs)
+            fsdp_grad_norm = combine_fsdp_norm_operator(fsdp_grad_norm, shard_norm)
+
+    combine_norm_operator = (
+        torch.maximum
+        if norm_type == torch.inf
+        else lambda a, b: torch.add(a, b).pow(1.0 / norm_type)
+    )
+
+    total_grad_norm = combine_norm_operator(ddp_grad_norm, fsdp_grad_norm)
+    return total_grad_norm
+
+
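To make the combination rule above concrete, here is a small self-contained check (not part of the patch). It assumes, as the comments in the function state, that each per-bucket result is the sum of |g|^p for finite p, so only the final combination needs the 1/p root:

import torch

p = 2.0
ddp_sq = torch.tensor(9.0)    # e.g. replicated grads whose squared norms sum to 9
fsdp_sq = torch.tensor(16.0)  # e.g. sharded grads whose squared norms sum to 16
total = torch.add(ddp_sq, fsdp_sq).pow(1.0 / p)
print(total)  # tensor(5.) == sqrt(9 + 16)

# For norm_type == torch.inf the function combines with torch.maximum instead,
# and no final root is applied.
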
def _batch_cal_norm(
    grad_list: List[torch.Tensor],
    max_norm: float,