Commit cd8e0aa

wz337 authored and facebook-github-bot committed
Refactor _batch_cal_norm and remove #pyre-ignore
Summary: As title. Remove the repeated `total_grad_norm` calculation and fix pyre errors.

Differential Revision: D78398248
1 parent 5a8a005 commit cd8e0aa

File tree

1 file changed: +11 −16 lines

torchrec/optim/clipping.py

Lines changed: 11 additions & 16 deletions
@@ -200,28 +200,23 @@ def clip_grad_norm_(self) -> Optional[Union[float, torch.Tensor]]:
         else:
             square_replicated_grad_norm = 0

+        if total_grad_norm is not None:
+            total_grad_norm = (
+                torch.pow(total_grad_norm, 1.0 / norm_type)
+                if norm_type != torch.inf
+                else total_grad_norm
+            )
+        else:
+            return None
+
         global log_grad_norm
         if log_grad_norm:
-            if total_grad_norm is not None and norm_type != torch.inf:
-                # pyre-ignore[58]
-                grad_norm = total_grad_norm ** (1.0 / norm_type)
-            else:
-                grad_norm = 0
-
             rank = dist.get_rank()
             logger.info(
-                f"Clipping [rank={rank}, step={self._step_num}]: square_sharded_grad_norm = {square_sharded_grad_norm}, square_replicated_grad_norm = {square_replicated_grad_norm}, total_grad_norm = {grad_norm}"
+                f"Clipping [rank={rank}, step={self._step_num}]: square_sharded_grad_norm = {square_sharded_grad_norm}, square_replicated_grad_norm = {square_replicated_grad_norm}, total_grad_norm = {total_grad_norm}"
             )

-        # Aggregation
-        if total_grad_norm is None:
-            return
-
-        if norm_type != torch.inf:
-            # pyre-ignore [58]: ** is not supported for operand types torch._tensor.Tensor and float.
-            total_grad_norm = total_grad_norm ** (1.0 / norm_type)
-        # pyre-ignore [58]: / is not supported for operand types float and Union[float, torch._tensor.Tensor].
-        clip_coef = cast(torch.Tensor, max_norm / (total_grad_norm + 1e-6))
+        clip_coef = torch.tensor(max_norm) / (total_grad_norm + 1e-6)
         clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
         torch._foreach_mul_(all_grads, clip_coef_clamped)
         return total_grad_norm
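For context, the new control flow is the standard clip-by-total-norm pattern: take the 1/p root of the accumulated norm via torch.pow (skipping the root for the inf-norm), return early when there is nothing to clip, then scale every gradient by min(1, max_norm / (total_grad_norm + 1e-6)). The sketch below is a minimal standalone illustration of that pattern, not the TorchRec implementation: the helper name clip_grads_by_total_norm and the naive per-tensor norm accumulation are assumptions for illustration, while the clip_coef / clamp / _foreach_mul_ sequence follows the diff.

```python
from typing import List, Optional

import torch


def clip_grads_by_total_norm(
    grads: List[torch.Tensor],
    max_norm: float,
    norm_type: float = 2.0,
) -> Optional[torch.Tensor]:
    """Hypothetical standalone helper mirroring the refactored control flow."""
    if not grads:
        # Mirrors the new early `return None` when there is no norm to clip.
        return None

    if norm_type == torch.inf:
        # Inf-norm: the maximum absolute gradient entry; no 1/p root needed.
        total_grad_norm = torch.max(torch.stack([g.abs().max() for g in grads]))
    else:
        # Accumulate sum(|g|^p) across tensors and take the 1/p root with
        # torch.pow, matching the diff's replacement of the ** operator.
        total_grad_norm = torch.pow(
            torch.stack([torch.sum(g.abs() ** norm_type) for g in grads]).sum(),
            1.0 / norm_type,
        )

    # Scale every gradient in place by min(1, max_norm / (norm + eps)),
    # the same clip_coef / clamp / _foreach_mul_ sequence as in the diff.
    clip_coef = torch.tensor(max_norm) / (total_grad_norm + 1e-6)
    clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
    torch._foreach_mul_(grads, clip_coef_clamped)
    return total_grad_norm
```

Judging from the removed comments, switching from `total_grad_norm ** (1.0 / norm_type)` and `max_norm / (...)` to `torch.pow(...)` and `torch.tensor(max_norm) / (...)` appears to be what allows the `# pyre-ignore [58]` suppressions to be dropped, since pyre rejected those operators between Tensor and float operands.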
