Jk/log grad norms/log grad norms #1068
Changes from 19 commits
@@ -88,6 +88,8 @@ def __init__(self, cf: Config, devices):
         # world_size gets overwritten by current setting during init_ddp()
         self.world_size_original = cf.get("world_size", None)
 
+        self.log_grad_norms = cf.get("log_grad_norms", False)
+
         # create output directory
         if is_root():
             config.get_path_run(cf).mkdir(exist_ok=True, parents=True)
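The flag defaults to False, so gradient-norm logging stays off unless the run config enables it. A minimal sketch of that lookup, assuming only dict-style get() semantics for the config (the project's Config class itself is not shown in this diff):

```python
# Sketch only: a plain dict standing in for the trainer's Config object.
run_config = {"world_size": 4}  # hypothetical run config without the flag

# Same pattern as the diff: missing key -> False, so logging is opt-in.
log_grad_norms = run_config.get("log_grad_norms", False)
print(log_grad_norms)  # False

run_config["log_grad_norms"] = True  # enable it explicitly in the run config
print(run_config.get("log_grad_norms", False))  # True
```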
@@ -539,6 +541,7 @@ def train(self, epoch):
 
         # Unweighted loss, real weighted loss, std for losses that need it
         self.loss_unweighted_hist, self.loss_model_hist, self.stdev_unweighted_hist = [], [], []
+        self.last_grad_norm = 0.0
 
         # training loop
         self.t_start = time.time()
@@ -570,7 +573,13 @@ def train(self, epoch):
 
             # gradient clipping
             self.grad_scaler.unscale_(self.optimizer)
-            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=cf.grad_clip)
+            total_norm = torch.nn.utils.clip_grad_norm_(
+                self.model.parameters(), max_norm=cf.grad_clip
+            )
+
+            # log gradient norms
+            if bidx % log_interval == 0 and self.log_grad_norms:
+                self._log_instant_grad_norms(TRAIN, total_norm)
 
             # optimizer step
             self.grad_scaler.step(self.optimizer)
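Worth noting for this hunk: torch.nn.utils.clip_grad_norm_ clips in place and returns the total gradient norm computed before clipping, so capturing its return value adds no extra work. A self-contained sketch of the pattern with a toy model standing in for the trainer's model and config (the AMP unscale_ step from the diff is only mentioned in a comment, since no GradScaler is used here):

```python
import torch

# Hypothetical stand-ins for the trainer's model, optimizer and grad_clip value.
model = torch.nn.Linear(8, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
grad_clip = 1.0

loss = model(torch.randn(2, 8)).square().mean()
loss.backward()

# Under AMP the gradients must be unscaled first (grad_scaler.unscale_(optimizer)
# in the diff) so the norm is measured in real units.
# clip_grad_norm_ clips in place and returns the pre-clipping total norm.
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip)
print(f"total grad norm before clipping: {total_norm.item():.4f}")

optimizer.step()
```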
@@ -942,6 +951,32 @@ def _log(self, stage: Stage):
 
         self.loss_unweighted_hist, self.loss_model_hist, self.stdev_unweighted_hist = [], [], []
 
+    def _log_instant_grad_norms(self, stage: Stage, total_norm):
+        """
+        Log instantaneous grad norms; we do not average them because of the cost and because we
+        want to measure the actual values.
+
+        Note: When using FSDP2, we need full_tensor().item() instead of .item(), see:
+        https://gist.github.com/Kai-46/a9835ef3f36e76d06afee6c11f388144
+        """
+        self.last_grad_norm = (
+            total_norm.full_tensor().item() if self.cf.world_size > 1 else total_norm.item()
+        )
+
+        grad_norms = {"total_grad_norm": self.last_grad_norm}
+        for name, param in self.model.named_parameters():
+            if param.grad is not None:
+                # grad_norms["grad_norm_" + name] = param.grad.norm().item()
+                grad_norms["grad_norm_" + name] = (
+                    param.grad.norm().full_tensor().item()
+                    if self.cf.world_size > 1
+                    else param.grad.norm().item()
+                )
+                # print(".item():", param.grad.norm().item())
+                # print(".full_tensor().item()", param.grad.norm().full_tensor().item())
+        if is_root():
+            self.train_logger.log_metrics(TRAIN, grad_norms)
+
     def _log_terminal(self, bidx: int, epoch: int, stage: Stage):
         print_freq = self.train_log_freq.terminal
         if bidx % print_freq == 0 and bidx > 0 or stage == VAL:

Contributor (Author), on the per-parameter full_tensor() lines: Same as above, we also need full_tensor() here. Tested it by printing both versions on 2 GPUs.
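The world_size > 1 check is effectively a proxy for "these norms are DTensors under FSDP2", where full_tensor() gathers the sharded tensor before .item() reads the scalar (see the linked gist and the author's 2-GPU test above). As an alternative sketch, not part of this PR, one could duck-type on full_tensor() instead of consulting world_size; to_scalar is a hypothetical helper name:

```python
import torch

def to_scalar(t: torch.Tensor) -> float:
    """Convert a possibly sharded norm tensor to a Python float.

    DTensors produced under FSDP2 expose full_tensor(), which materialises the
    global tensor before .item(); plain tensors are converted directly.
    """
    if hasattr(t, "full_tensor"):
        return t.full_tensor().item()
    return t.item()

# Single-process usage with a plain tensor:
print(to_scalar(torch.linalg.vector_norm(torch.randn(10))))
```

Duck-typing keeps the helper correct even if a run sets world_size > 1 without actually sharding the model, at the cost of relying on the DTensor API surface.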
@@ -964,7 +999,7 @@ def _log_terminal(self, bidx: int, epoch: int, stage: Stage):
         # samples per sec
         dt = time.time() - self.t_start
         pstr = "{:03d} : {:05d}/{:05d} : {:06d} : loss = {:.4E} "
-        pstr += "(lr={:.2E}, s/sec={:.3f})"
+        pstr += "(lr={:.2E}, gradient norm={:.3f}, s/sec={:.3f})"
         len_dataset = len(self.data_loader) // self.cf.batch_size_per_gpu
         logger.info(
             pstr.format(
@@ -974,6 +1009,7 @@ def _log_terminal(self, bidx: int, epoch: int, stage: Stage):
                 self.cf.istep,
                 avg_loss.nanmean().item(),
                 self.lr_scheduler.get_lr(),
+                self.last_grad_norm,
                 (print_freq * self.cf.batch_size_per_gpu) / dt,
             ),
         )
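For reference, with the extra field the terminal line renders roughly as follows (all values below are illustrative):

```python
pstr = "{:03d} : {:05d}/{:05d} : {:06d} : loss = {:.4E} "
pstr += "(lr={:.2E}, gradient norm={:.3f}, s/sec={:.3f})"

# epoch, batch index, batches per epoch, global step, loss, lr, grad norm, samples/sec
print(pstr.format(3, 200, 1500, 4700, 2.3456e-1, 1.0e-4, 0.873, 12.5))
# 003 : 00200/01500 : 004700 : loss = 2.3456E-01 (lr=1.00E-04, gradient norm=0.873, s/sec=12.500)
```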