diff --git a/config/default_config.yml b/config/default_config.yml
index 679f58dd3..620f5c4ae 100644
--- a/config/default_config.yml
+++ b/config/default_config.yml
@@ -133,6 +133,7 @@
 grad_clip: 1.0
 weight_decay: 0.1
 norm_type: "LayerNorm"
 nn_module: "te"
+log_grad_norms: False
 start_date: 197901010000
 end_date: 202012310000
diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py
index 55d30c7f4..c0a62a21e 100644
--- a/src/weathergen/train/trainer.py
+++ b/src/weathergen/train/trainer.py
@@ -90,6 +90,8 @@ def init(self, cf: Config, devices):
         # world_size gets overwritten by current setting during init_ddp()
         self.world_size_original = cf.get("world_size", None)
 
+        self.log_grad_norms = cf.get("log_grad_norms", False)
+
         # create output directory
         if is_root():
             config.get_path_run(cf).mkdir(exist_ok=True, parents=True)
@@ -604,7 +606,16 @@ def train(self, epoch):
 
             # gradient clipping
             self.grad_scaler.unscale_(self.optimizer)
-            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=cf.grad_clip)
+            total_norm = torch.nn.utils.clip_grad_norm_(
+                self.model.parameters(), max_norm=cf.grad_clip
+            )
+
+            # log gradient norms
+            if self.log_grad_norms:
+                if bidx % self.train_log_freq.terminal == 0:
+                    self.last_grad_norm = self._get_tensor_item(total_norm)
+                if bidx % self.train_log_freq.metrics == 0:
+                    self._log_instant_grad_norms(TRAIN)
 
             # optimizer step
             self.grad_scaler.step(self.optimizer)
@@ -990,6 +1001,26 @@ def _log(self, stage: Stage):
 
         self.loss_unweighted_hist, self.loss_model_hist, self.stdev_unweighted_hist = [], [], []
 
+    def _get_tensor_item(self, tensor):
+        """
+        When using FSDP2, tensor is a DTensor and we need full_tensor().item() instead of .item(),
+        see here: https://gist.github.com/Kai-46/a9835ef3f36e76d06afee6c11f388144
+        """
+        return tensor.full_tensor().item() if isinstance(tensor, DTensor) else tensor.item()
+
+    def _log_instant_grad_norms(self, stage: Stage):
+        """
+        Log instantaneous grad norms, we do not average because of the cost and because we want to
+        measure the actual values.
+        """
+        grad_norms = {"grad_norm.total": self.last_grad_norm}
+        for name, param in self.model.named_parameters():
+            if param.grad is not None:
+                grad_norms["grad_norm." + name] = self._get_tensor_item(param.grad.norm())
+
+        if is_root():
+            self.train_logger.log_metrics(stage, grad_norms)
+
     def _log_terminal(self, bidx: int, epoch: int, stage: Stage):
         print_freq = self.train_log_freq.terminal
         if bidx % print_freq == 0 and bidx > 0 or stage == VAL:
@@ -1011,20 +1042,16 @@
         elif stage == TRAIN:
             # samples per sec
             dt = time.time() - self.t_start
-            pstr = "{:03d} : {:05d}/{:05d} : {:06d} : loss = {:.4E} "
-            pstr += "(lr={:.2E}, s/sec={:.3f})"
             len_dataset = len(self.data_loader) // self.cf.batch_size_per_gpu
-            logger.info(
-                pstr.format(
-                    epoch,
-                    bidx,
-                    len_dataset,
-                    self.cf.istep,
-                    avg_loss.nanmean().item(),
-                    self.lr_scheduler.get_lr(),
-                    (print_freq * self.cf.batch_size_per_gpu) / dt,
-                ),
-            )
+            pstr = (
+                f"{epoch:03d} : {bidx:05d}/{len_dataset:05d} : "
+                + f"{self.cf.istep:06d} : loss = {avg_loss.nanmean().item():.4E} "
+                + f"(lr={self.lr_scheduler.get_lr():.2E}, "
+            )
+            if self.log_grad_norms:
+                pstr += f"gradient norm={self.last_grad_norm:.3f}, "
+            pstr += f"s/sec={(print_freq * self.cf.batch_size_per_gpu) / dt:.3f})"
+            logger.info(pstr)
             logger.info("\t")
             for _, st in enumerate(self.cf.streams):
                 logger.info(