Skip to content

Commit 29f3764

Browse files
Gavin Zhangfacebook-github-bot
authored andcommitted
refactor the total norm computation in grad clipping in APS (#3243)
Summary: Pull Request resolved: #3243 Refactored the previous code for applying gradient clipping across ddp and fsdp parameter. Added a new funciton _compute_total_norm() that takes in the replicated and sharded params provided in the gradientclippingOpitmizer class and computes the total gradient norm of the given parameter. Differential Revision: D79128843
1 parent 2e874f7 commit 29f3764

File tree

1 file changed

+94
-72
lines changed

1 file changed

+94
-72
lines changed

torchrec/optim/clipping.py

Lines changed: 94 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -135,98 +135,120 @@ def step(self, closure: Any = None) -> None:
135135
super().step(closure)
136136
self._step_num += 1
137137

138-
@torch.no_grad()
139138
def clip_grad_norm_(self) -> Optional[Union[float, torch.Tensor]]:
140139
"""Clip the gradient norm of all parameters."""
141140
max_norm = self._max_gradient
142141
norm_type = float(self._norm_type)
143142
all_grads = []
144143
total_grad_norm = None
145144

145+
sharded_params = self._sharded_params
146+
replicate_params = self._replicate_params
147+
146148
# Process distributed parameters and gradients
147-
for pgs, dist_params in self._sharded_params.items():
148-
sharded_grads = [
149-
p.grad._local_tensor if isinstance(p.grad, DTensor) else p.grad
150-
for p in dist_params
151-
if p.grad is not None and p.grad.numel() > 0
152-
]
153-
if len(sharded_grads) == 0:
154-
continue
149+
for dist_params in sharded_params.values():
150+
sharded_grads = _get_grads(dist_params)
155151
all_grads.extend(sharded_grads)
156152

157-
sharded_grad_norm = _batch_cal_norm(
158-
sharded_grads,
159-
max_norm,
160-
norm_type,
161-
pgs,
162-
)
163-
total_grad_norm = (
164-
sharded_grad_norm
165-
if total_grad_norm is None
166-
else (
167-
torch.maximum(total_grad_norm, sharded_grad_norm)
168-
if norm_type == torch.inf
169-
else total_grad_norm + sharded_grad_norm
170-
)
171-
)
172-
173-
square_sharded_grad_norm = total_grad_norm if total_grad_norm is not None else 0
174-
175153
# Process replicated parameters and gradients
176-
if self._replicate_params:
177-
replicated_grads = [
178-
p.grad._local_tensor if isinstance(p.grad, DTensor) else p.grad
179-
for p in self._replicate_params
180-
if p.grad is not None and p.grad.numel() > 0
181-
]
182-
all_grads.extend(replicated_grads)
183-
184-
replicated_grad_norm = _batch_cal_norm(
185-
replicated_grads,
186-
max_norm,
187-
norm_type,
188-
None,
189-
)
190-
total_grad_norm = (
191-
replicated_grad_norm
192-
if total_grad_norm is None
193-
else (
194-
torch.maximum(total_grad_norm, replicated_grad_norm)
195-
if norm_type == torch.inf
196-
else total_grad_norm + replicated_grad_norm
197-
)
198-
)
199-
square_replicated_grad_norm = replicated_grad_norm
200-
else:
201-
square_replicated_grad_norm = 0
202-
203-
global log_grad_norm
204-
if log_grad_norm:
205-
if total_grad_norm is not None and norm_type != torch.inf:
206-
# pyre-ignore[58]
207-
grad_norm = total_grad_norm ** (1.0 / norm_type)
208-
else:
209-
grad_norm = total_grad_norm
154+
if replicate_params:
155+
replicate_grads = _get_grads(replicate_params)
156+
all_grads.extend(replicate_grads)
210157

211-
rank = dist.get_rank()
212-
logger.info(
213-
f"Clipping [rank={rank}, step={self._step_num}]: square_sharded_grad_norm = {square_sharded_grad_norm}, square_replicated_grad_norm = {square_replicated_grad_norm}, total_grad_norm = {grad_norm}"
214-
)
215-
216-
# Aggregation
217-
if total_grad_norm is None:
218-
return
158+
total_grad_norm = _compute_total_norm(
159+
replicate_params, sharded_params, norm_type, max_norm
160+
)
219161

220-
if norm_type != torch.inf:
221-
# pyre-ignore [58]: ** is not supported for operand types torch._tensor.Tensor and float.
222-
total_grad_norm = total_grad_norm ** (1.0 / norm_type)
223162
# pyre-ignore [58]: / is not supported for operand types float and Union[float, torch._tensor.Tensor].
224163
clip_coef = cast(torch.Tensor, max_norm / (total_grad_norm + 1e-6))
225164
clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
226165
torch._foreach_mul_(all_grads, clip_coef_clamped)
227166
return total_grad_norm
228167

229168

169+
def _get_grads(
170+
param_list: List[torch.Tensor],
171+
) -> List[torch.Tensor]:
172+
"""Get the gradients of a list of parameters. Converts DTensors to local tensors if needed."""
173+
grads = [
174+
p.grad._local_tensor if isinstance(p.grad, DTensor) else p.grad
175+
for p in param_list
176+
if p.grad is not None and p.grad.numel() > 0
177+
]
178+
return grads
179+
180+
181+
def _compute_total_norm(
182+
replicate_params: Optional[List[torch.Tensor]] = None,
183+
sharded_params: Optional[Dict[Tuple[dist.ProcessGroup], List[torch.Tensor]]] = None,
184+
norm_type: float = 2.0, # can be a normal float, or torch.inf
185+
max_grad_norm: float = 1.0,
186+
) -> torch.Tensor:
187+
"""
188+
Given both recpliate params and sharded params, compute the total norm of the gradients of the full replicate params and the
189+
full sharded param (parameters with a process group).
190+
191+
Args:
192+
replicate_params (List[torch.Tensor]): list of replicate params
193+
sharded_params (Dict[Tuple[dist.ProcessGroup], List[torch.Tensor]]): dict that maps each process group to a list of tensors
194+
norm_type (Union[float, str]): type of the used p-norm. Can be ``'inf'`` for infinity norm.
195+
max_grad_norm (float): max gradient norm.
196+
"""
197+
198+
## compute |W|^p corresponding to all DDP params W
199+
200+
if replicate_params is None:
201+
replicate_params = []
202+
if sharded_params is None:
203+
sharded_params = defaultdict(list)
204+
205+
def get_grad_norm_power(
206+
param_list: List[torch.Tensor],
207+
norm_type: float,
208+
max_grad_norm: float,
209+
pgs: Optional[Tuple[dist.ProcessGroup]] = None,
210+
) -> torch.Tensor:
211+
"""
212+
Given a list of parameters, convert them to local tensors if they are DTensors,
213+
and compute the squared (or p-th power) norm of the gradients of the parameters.
214+
"""
215+
grad_list = _get_grads(param_list)
216+
return _batch_cal_norm(grad_list, max_grad_norm, norm_type, pgs)
217+
218+
## compute the norm |W|^p corresponding to all sharded params W
219+
sharded_grad_norm: torch.Tensor = torch.tensor(0.0)
220+
if sharded_params:
221+
combine_sharded_norm_operator = (
222+
torch.maximum if norm_type == torch.inf else torch.add
223+
)
224+
225+
# We need to move sharded_grad_norm to the same device as the first shard so that we can do addition (or take max)
226+
# this is specifically for the case where sharded_grad_norm is 0, and replicate_grad_norm is not,
227+
# because by default torch.tensor(0.0) is on cpu, and replicate_grad_norm is on GPU. For MTIA
228+
# specifically, adding a tensor on cpu and a tensor on GPU will result in an error.
229+
for pgs, dist_params in sharded_params.items():
230+
shard_norm = get_grad_norm_power(dist_params, norm_type, max_grad_norm, pgs)
231+
sharded_grad_norm = combine_sharded_norm_operator(
232+
sharded_grad_norm.to(shard_norm.device), shard_norm
233+
)
234+
235+
# Similar to the case above, we move replicate_grad_norm to the same device as sharded_grad_norm so that we can do addition.
236+
replicate_grad_norm: torch.Tensor = (
237+
get_grad_norm_power(replicate_params, norm_type, max_grad_norm)
238+
if replicate_params
239+
else torch.tensor(0.0)
240+
).to(sharded_grad_norm.device)
241+
242+
combine_norm_operator = (
243+
torch.maximum
244+
if norm_type == torch.inf
245+
else lambda a, b: torch.add(a, b).pow(1.0 / norm_type)
246+
)
247+
248+
total_grad_norm = combine_norm_operator(replicate_grad_norm, sharded_grad_norm)
249+
return total_grad_norm
250+
251+
230252
def _batch_cal_norm(
231253
grad_list: List[torch.Tensor],
232254
max_norm: float,

0 commit comments

Comments
 (0)