@@ -68,6 +68,9 @@ def __init__(
         # Otherwise, all parameters are treated as replicated and will be clipped locally.
         sharded_param_cnt = 0
         self._replicate_params: List[torch.Tensor] = []
+
+        # self._sharded_params: key: Tuple[dist.ProcessGroup], value: List[torch.Tensor].
+        # Maps each tuple of process groups to the list of parameters sharded across them.
         self._sharded_params: Dict[Tuple[dist.ProcessGroup], List[torch.Tensor]] = (
             defaultdict(list)
         )
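A minimal, self-contained sketch of the bucketing this mapping is meant to hold, with `pg_a`, `pg_b`, `param_x`, and `param_y` as hypothetical stand-ins for real process groups and sharded parameters (in practice the keys are tuples of `dist.ProcessGroup` handles, e.g. from `dist.new_group()`):

from collections import defaultdict
from typing import Dict, List, Tuple

import torch

# Hypothetical placeholders; real keys are tuples of dist.ProcessGroup objects.
pg_a, pg_b = object(), object()
param_x = torch.randn(4, requires_grad=True)
param_y = torch.randn(8, requires_grad=True)

# Each parameter is bucketed under the tuple of process groups it is sharded across.
sharded_params: Dict[Tuple, List[torch.Tensor]] = defaultdict(list)
sharded_params[(pg_a,)].append(param_x)
sharded_params[(pg_a, pg_b)].append(param_y)
# sharded_params now maps (pg_a,) -> [param_x] and (pg_a, pg_b) -> [param_y]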
@@ -143,90 +146,105 @@ def clip_grad_norm_(self) -> Optional[Union[float, torch.Tensor]]:
         all_grads = []
         total_grad_norm = None

+        sharded_params = self._sharded_params
+        ddp_params = self._replicate_params
         # Process distributed parameters and gradients
-        for pgs, dist_params in self._sharded_params.items():
+        for _, dist_params in sharded_params.items():
             sharded_grads = [
                 p.grad._local_tensor if isinstance(p.grad, DTensor) else p.grad
                 for p in dist_params
                 if p.grad is not None and p.grad.numel() > 0
             ]
-            if len(sharded_grads) == 0:
-                continue
             all_grads.extend(sharded_grads)

-            sharded_grad_norm = _batch_cal_norm(
-                sharded_grads,
-                max_norm,
-                norm_type,
-                pgs,
-            )
-            total_grad_norm = (
-                sharded_grad_norm
-                if total_grad_norm is None
-                else (
-                    torch.maximum(total_grad_norm, sharded_grad_norm)
-                    if norm_type == torch.inf
-                    else total_grad_norm + sharded_grad_norm
-                )
-            )
-
-        square_sharded_grad_norm = total_grad_norm if total_grad_norm is not None else 0
-
         # Process replicated parameters and gradients
-        if self._replicate_params:
-            replicated_grads = [
+        if ddp_params:
+            ddp_grads = [
                 p.grad._local_tensor if isinstance(p.grad, DTensor) else p.grad
                 for p in self._replicate_params
                 if p.grad is not None and p.grad.numel() > 0
             ]
-            all_grads.extend(replicated_grads)
+            all_grads.extend(ddp_grads)

-            replicated_grad_norm = _batch_cal_norm(
-                replicated_grads,
-                max_norm,
-                norm_type,
-                None,
-            )
-            total_grad_norm = (
-                replicated_grad_norm
-                if total_grad_norm is None
-                else (
-                    torch.maximum(total_grad_norm, replicated_grad_norm)
-                    if norm_type == torch.inf
-                    else total_grad_norm + replicated_grad_norm
-                )
-            )
-            square_replicated_grad_norm = replicated_grad_norm
-        else:
-            square_replicated_grad_norm = 0
-
-        global log_grad_norm
-        if log_grad_norm:
-            if total_grad_norm is not None and norm_type != torch.inf:
-                # pyre-ignore[58]
-                grad_norm = total_grad_norm ** (1.0 / norm_type)
-            else:
-                grad_norm = total_grad_norm
-
-            rank = dist.get_rank()
-            logger.info(
-                f"Clipping [rank={rank}, step={self._step_num}]: square_sharded_grad_norm = {square_sharded_grad_norm}, square_replicated_grad_norm = {square_replicated_grad_norm}, total_grad_norm = {grad_norm}"
-            )
-
-        # Aggregation
-        if total_grad_norm is None:
-            return
+        total_grad_norm = _compute_total_norm(
+            ddp_params, sharded_params, norm_type, max_norm
+        )

-        if norm_type != torch.inf:
-            # pyre-ignore [58]: ** is not supported for operand types torch._tensor.Tensor and float.
-            total_grad_norm = total_grad_norm ** (1.0 / norm_type)
         # pyre-ignore [58]: / is not supported for operand types float and Union[float, torch._tensor.Tensor].
         clip_coef = cast(torch.Tensor, max_norm / (total_grad_norm + 1e-6))
         clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
         torch._foreach_mul_(all_grads, clip_coef_clamped)
         return total_grad_norm


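For intuition, here is a small self-contained sketch of the clipping arithmetic applied above, using made-up numbers (max_norm = 1.0, a total gradient norm of 4.0) rather than real distributed gradients:

import torch

max_norm = 1.0
total_grad_norm = torch.tensor(4.0)                      # assumed, for illustration
grads = [torch.full((3,), 2.0), torch.full((2,), -2.0)]

clip_coef = max_norm / (total_grad_norm + 1e-6)          # ~0.25
clip_coef_clamped = torch.clamp(clip_coef, max=1.0)      # never scales gradients up when norm < max_norm
torch._foreach_mul_(grads, clip_coef_clamped)            # each gradient scaled in place by ~0.25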
+def _compute_total_norm(
+    ddp_params: List[torch.Tensor] = [],
+    fsdp_params: Dict[Tuple[dist.ProcessGroup], List[torch.Tensor]] = (
+        defaultdict(list)
+    ),
+    norm_type: float = 2.0,  # can be a normal float, or torch.inf
+    max_grad_norm: float = 1.0,
+) -> torch.Tensor:
+    """
+    Given both DDP params and sharded (FSDP) params, compute the total norm of the gradients of the
+    full DDP params and the full FSDP params.
+
+    Args:
+        ddp_params (List[torch.Tensor]): list of DDP (replicated) params
+        fsdp_params (Dict[Tuple[dist.ProcessGroup], List[torch.Tensor]]): dict that maps each tuple of
+            process groups to a list of sharded params
+        norm_type (float): type of the used p-norm. Can be ``torch.inf`` for the infinity norm.
+        max_grad_norm (float): max norm value forwarded to ``_batch_cal_norm``
+    """
+
+    ## compute ||g||_p^p over the gradients g of all DDP params
+    ddp_grad_norm: torch.Tensor = torch.tensor(0.0)
+    if ddp_params:
+        ddp_params_grads = [
+            p.grad._local_tensor if isinstance(p.grad, DTensor) else p.grad
+            for p in ddp_params
+            if p.grad is not None and p.grad.numel() > 0
+        ]
+
+        # _batch_cal_norm computes ||grad||_p^p
+        ddp_grad_norm = _batch_cal_norm(
+            ddp_params_grads,
+            max_grad_norm,
+            norm_type,
+            None,
+        )
+
+    ## compute ||g||_p^p over the gradients g of all sharded (FSDP) params
+    fsdp_grad_norm: torch.Tensor = torch.tensor(0.0)
+    if fsdp_params:
+        for pgs, dist_params in fsdp_params.items():
+            sharded_grads = [
+                p.grad._local_tensor if isinstance(p.grad, DTensor) else p.grad
+                for p in dist_params
+                if p.grad is not None and p.grad.numel() > 0
+            ]
+
+            # _batch_cal_norm computes ||shard||_p^p for each shard
+            shard_norm = _batch_cal_norm(
+                sharded_grads,
+                max_grad_norm,
+                norm_type,
+                pgs,
+            )
+
+            if norm_type == torch.inf:
+                fsdp_grad_norm = torch.maximum(fsdp_grad_norm, shard_norm)
+            else:
+                fsdp_grad_norm += shard_norm
+
+    if norm_type == torch.inf:
+        total_grad_norm = torch.maximum(ddp_grad_norm, fsdp_grad_norm)
+    else:
+        total_grad_norm = (ddp_grad_norm + fsdp_grad_norm).pow(1.0 / norm_type)
+
+    return total_grad_norm
+
+
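As a sanity check on the aggregation above, a minimal single-process sketch (plain tensors, no process groups or `_batch_cal_norm`) of combining per-group p-th-power norms into one total norm, mirroring the sum-then-root logic for a finite norm_type:

import torch

# Made-up local gradients standing in for DDP gradients and FSDP gradient shards.
norm_type = 2.0
ddp_grads = [torch.tensor([3.0, 4.0])]            # ||.||_2^2 = 25
fsdp_grads = [torch.tensor([1.0, 2.0, 2.0])]      # ||.||_2^2 = 9

# Accumulate p-th powers per group, then take a single p-th root at the end
# (for norm_type == torch.inf the function above takes a maximum instead).
ddp_pow = sum(g.norm(norm_type) ** norm_type for g in ddp_grads)
fsdp_pow = sum(g.norm(norm_type) ** norm_type for g in fsdp_grads)
total_grad_norm = (ddp_pow + fsdp_pow) ** (1.0 / norm_type)   # sqrt(25 + 9) ≈ 5.83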
 def _batch_cal_norm(
     grad_list: List[torch.Tensor],
     max_norm: float,