@@ -59,7 +59,7 @@ def __init__(
         super().__init__(optimizer)
         self._clipping = clipping
         self._max_gradient = max_gradient
-        self._norm_type = norm_type
+        self._norm_type = float(norm_type)
         self._check_meta: bool = True
         self._enable_global_grad_clip = enable_global_grad_clip
         self._step_num = 0
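
Converting norm_type once in __init__ is safe because Python's float() also accepts the string 'inf', and the result compares equal to torch.inf, so the per-step float() call removed in the next hunk becomes redundant. A minimal standalone check of that assumption (illustrative only, not part of this change):

import torch

# float() handles both numeric p values and the string "inf".
assert float("inf") == torch.inf
assert float(2) == 2.0

# torch.nn.utils.clip_grad_norm_ accepts the resulting float directly as norm_type,
# so converting once at construction time behaves the same as converting every step.
p = torch.nn.Parameter(torch.ones(4))
p.grad = torch.full((4,), 10.0)
torch.nn.utils.clip_grad_norm_([p], max_norm=1.0, norm_type=float("inf"))
assert torch.allclose(p.grad, torch.ones(4))  # inf-norm clipped 10.0 down to ~1.0
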
@@ -124,7 +124,7 @@ def step(self, closure: Any = None) -> None:
                 torch.nn.utils.clip_grad_norm_(
                     replicate_params,
                     self._max_gradient,
-                    norm_type=float(self._norm_type),
+                    norm_type=self._norm_type,
                 )
             else:
                 self.clip_grad_norm_()
@@ -135,98 +135,101 @@ def step(self, closure: Any = None) -> None:
         super().step(closure)
         self._step_num += 1

-    @torch.no_grad()
     def clip_grad_norm_(self) -> Optional[Union[float, torch.Tensor]]:
         """Clip the gradient norm of all parameters."""
-        max_norm = self._max_gradient
-        norm_type = float(self._norm_type)
+
+        # self._norm_type is already a float: __init__ converts it, covering the case where it was the string 'inf'.
         all_grads = []
         total_grad_norm = None

-        # Process distributed parameters and gradients
-        for pgs, dist_params in self._sharded_params.items():
-            sharded_grads = [
-                p.grad._local_tensor if isinstance(p.grad, DTensor) else p.grad
-                for p in dist_params
-                if p.grad is not None and p.grad.numel() > 0
-            ]
-            if len(sharded_grads) == 0:
-                continue
-            all_grads.extend(sharded_grads)
-
-            sharded_grad_norm = _batch_cal_norm(
-                sharded_grads,
-                max_norm,
-                norm_type,
-                pgs,
-            )
-            total_grad_norm = (
-                sharded_grad_norm
-                if total_grad_norm is None
-                else (
-                    torch.maximum(total_grad_norm, sharded_grad_norm)
-                    if norm_type == torch.inf
-                    else total_grad_norm + sharded_grad_norm
-                )
-            )
+        sharded_params = self._sharded_params
+        replicate_params = self._replicate_params

-        square_sharded_grad_norm = total_grad_norm if total_grad_norm is not None else 0
+        # Process distributed parameters and gradients
+        sharded_grads = {
+            pgs: _get_grads(dist_params) for pgs, dist_params in sharded_params.items()
+        }
+        all_grads.extend(g for grads in sharded_grads.values() for g in grads)  # flatten grads from every process group

         # Process replicated parameters and gradients
-        if self._replicate_params:
-            replicated_grads = [
-                p.grad._local_tensor if isinstance(p.grad, DTensor) else p.grad
-                for p in self._replicate_params
-                if p.grad is not None and p.grad.numel() > 0
-            ]
-            all_grads.extend(replicated_grads)
-
-            replicated_grad_norm = _batch_cal_norm(
-                replicated_grads,
-                max_norm,
-                norm_type,
-                None,
-            )
-            total_grad_norm = (
-                replicated_grad_norm
-                if total_grad_norm is None
-                else (
-                    torch.maximum(total_grad_norm, replicated_grad_norm)
-                    if norm_type == torch.inf
-                    else total_grad_norm + replicated_grad_norm
-                )
-            )
-            square_replicated_grad_norm = replicated_grad_norm
-        else:
-            square_replicated_grad_norm = 0
-
-        global log_grad_norm
-        if log_grad_norm:
-            if total_grad_norm is not None and norm_type != torch.inf:
-                # pyre-ignore[58]
-                grad_norm = total_grad_norm ** (1.0 / norm_type)
-            else:
-                grad_norm = total_grad_norm
-
-            rank = dist.get_rank()
-            logger.info(
-                f"Clipping [rank={rank}, step={self._step_num}]: square_sharded_grad_norm = {square_sharded_grad_norm}, square_replicated_grad_norm = {square_replicated_grad_norm}, total_grad_norm = {grad_norm}"
-            )
-
-        # Aggregation
-        if total_grad_norm is None:
-            return
+        replicate_grads = _get_grads(replicate_params)
+        all_grads.extend(replicate_grads)
+
+        total_grad_norm = _compute_total_norm(
+            replicate_grads=replicate_grads,
+            sharded_grads=sharded_grads,
+            norm_type=self._norm_type,
+            max_grad_norm=self._max_gradient,
+        )

-        if norm_type != torch.inf:
-            # pyre-ignore [58]: ** is not supported for operand types torch._tensor.Tensor and float.
-            total_grad_norm = total_grad_norm ** (1.0 / norm_type)
         # pyre-ignore [58]: / is not supported for operand types float and Union[float, torch._tensor.Tensor].
-        clip_coef = cast(torch.Tensor, max_norm / (total_grad_norm + 1e-6))
+        clip_coef = cast(torch.Tensor, self._max_gradient / (total_grad_norm + 1e-6))
         clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
         torch._foreach_mul_(all_grads, clip_coef_clamped)
         return total_grad_norm


+def _get_grads(
+    param_list: List[torch.Tensor],
+) -> List[torch.Tensor]:
+    """Get the gradients of a list of parameters. Converts DTensors to local tensors if needed."""
+    grads = [
+        p.grad._local_tensor if isinstance(p.grad, DTensor) else p.grad
+        for p in param_list
+        if p.grad is not None and p.grad.numel() > 0
+    ]
+    return grads
+
+
+def _compute_total_norm(
+    replicate_grads: List[torch.Tensor],
+    sharded_grads: Dict[Tuple[dist.ProcessGroup], List[torch.Tensor]],
+    norm_type: float = 2.0,  # a finite p, or torch.inf
+    max_grad_norm: float = 1.0,
+) -> torch.Tensor:
+    """
+    Given both replicate grads and sharded grads, compute the total norm of the gradients of the full
+    replicate params and the full sharded params (parameters with a process group).
+
+    Args:
+        replicate_grads (List[torch.Tensor]): list of gradients for replicate params
+        sharded_grads (Dict[Tuple[dist.ProcessGroup], List[torch.Tensor]]): dict that maps each process group to a list of gradients for sharded params
+        norm_type (float): type of the used p-norm. Can be torch.inf for infinity norm.
+        max_grad_norm (float): max gradient norm.
+    """
+
+    # compute the norm |W|^p over all sharded params W
+    sharded_grad_norm: torch.Tensor = torch.tensor(0.0)
+    combine_norm_operator = torch.maximum if norm_type == torch.inf else torch.add
+
+    # Move sharded_grad_norm to the same device as each shard's norm before adding (or taking the max).
+    # This matters when sharded_grad_norm is 0 and replicate_grad_norm is not: torch.tensor(0.0) lives
+    # on CPU by default while replicate_grad_norm is on the accelerator, and on MTIA in particular,
+    # adding a CPU tensor to a device tensor raises an error.
+    for pgs, dist_params in sharded_grads.items():
+        current_shard_norm = _batch_cal_norm(dist_params, max_grad_norm, norm_type, pgs)
+        sharded_grad_norm = combine_norm_operator(
+            sharded_grad_norm.to(current_shard_norm.device), current_shard_norm
+        )
+    # compute |W|^p over all replicate params W
+    # Similarly, move replicate_grad_norm to the same device as sharded_grad_norm so the two can be combined.
+    replicate_grad_norm: torch.Tensor = (
+        _batch_cal_norm(replicate_grads, max_grad_norm, norm_type)
+        if replicate_grads
+        else torch.tensor(0.0)
+    ).to(sharded_grad_norm.device)
+
+    # In the p-norm case, we are given |W_sharded|^p and |W_replicate|^p; the total norm is the p-th root
+    # of their sum. In the inf-norm case, we are given max(|W_sharded|) and max(|W_replicate|), and the
+    # total norm is max(max(|W_sharded|), max(|W_replicate|)).
+    combined_norm = combine_norm_operator(replicate_grad_norm, sharded_grad_norm)
+    total_grad_norm = (
+        combined_norm.pow(1.0 / norm_type) if norm_type != torch.inf else combined_norm
+    )
+
+    return total_grad_norm
+
+
 def _batch_cal_norm(
     grad_list: List[torch.Tensor],
     max_norm: float,
@@ -236,7 +239,6 @@ def _batch_cal_norm(
236239 """Helper function that calculates the norm of a list of gradients in batches. If process_groups
237240 are passed in, the norm will be aggregated across all ranks in the process group.
238241 """
239-
240242 global use_64bit_grad_norm
241243 if use_64bit_grad_norm :
242244 grad_norms = torch .linalg .vector_norm (
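
As a sanity check of the norm-combination logic in _compute_total_norm: per the old code above (and the log message naming the quantities "square_..._grad_norm"), _batch_cal_norm is understood to return the p-th power of the local norm for a finite p, and the max absolute value for torch.inf; summing the sharded and replicate contributions and taking the p-th root then recovers the norm over all gradients, while the inf case reduces to a max of maxes. The sketch below checks only that arithmetic on plain local tensors; _local_norm and the sample tensors are hypothetical stand-ins, with no process-group reduction, DTensor handling, or clipping involved.

import torch

def _local_norm(grads, norm_type):
    # Hypothetical stand-in for _batch_cal_norm's contract as used above:
    # returns ||g||_p ** p for finite p and max(|g|) for norm_type == torch.inf.
    # No process-group all-reduce here; single-rank illustration only.
    flat = torch.cat([g.reshape(-1) for g in grads])
    if norm_type == torch.inf:
        return flat.abs().max()
    return flat.norm(norm_type) ** norm_type

sharded = [torch.randn(8), torch.randn(3)]   # stand-ins for sharded grads
replicate = [torch.randn(5)]                 # stand-ins for replicate grads
everything = torch.cat([g.reshape(-1) for g in sharded + replicate])

for p in (2.0, torch.inf):
    combine = torch.maximum if p == torch.inf else torch.add
    combined = combine(_local_norm(sharded, p), _local_norm(replicate, p))
    total = combined if p == torch.inf else combined.pow(1.0 / p)
    expected = everything.abs().max() if p == torch.inf else everything.norm(p)
    assert torch.allclose(total, expected)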