
Commit 0b81db3

Jubeku and sophie-xhonneux authored and committed
Jk/log grad norms/log grad norms (#1068)
* Log gradient norms
* Prototype for recording grad norms
* Address review changes + hide behind feature flag
* Final fixes including backward compatibility
* Ruff
* More ruff stuff
* forecast config with small decoder
* fixed uv.lock
* test gradient logging on multi gpus
* update uv.lock to latest develop version
* revert to default config
* add comment on FSDP2 specifics
* move plot grad script to private repo
* rm seaborn from pyproject
* updating terminal and metrics logging, add get_tensor_item fct
* check for DTensor instead of world size
* revert forecast fct, fix in separate PR
* rename grad_norm log names to exclude from MLFlow
* add log_grad_norms to default config

Co-authored-by: sophiex <[email protected]>
1 parent ecaf28e commit 0b81db3

File tree

* config/default_config.yml
* src/weathergen/train/trainer.py

2 files changed: +41 -13 lines

config/default_config.yml

Lines changed: 1 addition & 0 deletions
@@ -133,6 +133,7 @@ grad_clip: 1.0
 weight_decay: 0.1
 norm_type: "LayerNorm"
 nn_module: "te"
+log_grad_norms: False

 start_date: 197901010000
 end_date: 202012310000
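
The flag is read in the trainer with a default of False, so configs written before this change keep working unchanged. A minimal sketch of that backward-compatible lookup (plain dicts stand in for the project's Config object; the variable names are illustrative only):

    # Backward-compatible lookup: older configs without the new key fall back to False.
    old_cfg = {"grad_clip": 1.0}                          # config predating this commit
    new_cfg = {"grad_clip": 1.0, "log_grad_norms": True}  # config opting in

    for cfg in (old_cfg, new_cfg):
        log_grad_norms = cfg.get("log_grad_norms", False)
        print(log_grad_norms)  # False for old_cfg, True for new_cfg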

src/weathergen/train/trainer.py

Lines changed: 40 additions & 13 deletions
@@ -90,6 +90,8 @@ def init(self, cf: Config, devices):
         # world_size gets overwritten by current setting during init_ddp()
         self.world_size_original = cf.get("world_size", None)

+        self.log_grad_norms = cf.get("log_grad_norms", False)
+
         # create output directory
         if is_root():
             config.get_path_run(cf).mkdir(exist_ok=True, parents=True)
@@ -604,7 +606,16 @@ def train(self, epoch):

             # gradient clipping
             self.grad_scaler.unscale_(self.optimizer)
-            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=cf.grad_clip)
+            total_norm = torch.nn.utils.clip_grad_norm_(
+                self.model.parameters(), max_norm=cf.grad_clip
+            )
+
+            # log gradient norms
+            if self.log_grad_norms:
+                if bidx % self.train_log_freq.terminal == 0:
+                    self.last_grad_norm = self._get_tensor_item(total_norm)
+                if bidx % self.train_log_freq.metrics == 0:
+                    self._log_instant_grad_norms(TRAIN)

             # optimizer step
             self.grad_scaler.step(self.optimizer)
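
For context on the hunk above: torch.nn.utils.clip_grad_norm_ clips the gradients in place and returns the total gradient norm, which is what the new total_norm variable captures for logging. Below is a standalone sketch of that pattern with the grad scaler, distributed setup and feature flag stripped out; model, optimizer and log_every are placeholder names, not project code:

    import torch

    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    log_every = 10  # stands in for train_log_freq.terminal

    for step in range(100):
        loss = model(torch.randn(8, 4)).pow(2).mean()
        loss.backward()
        # clip_grad_norm_ returns the total norm of the gradients (before clipping)
        total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        if step % log_every == 0:
            print(f"step {step}: gradient norm = {total_norm.item():.3f}")
        optimizer.step()
        optimizer.zero_grad()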
@@ -990,6 +1001,26 @@ def _log(self, stage: Stage):

         self.loss_unweighted_hist, self.loss_model_hist, self.stdev_unweighted_hist = [], [], []

+    def _get_tensor_item(self, tensor):
+        """
+        When using FSDP2, tensor is a DTensor and we need full_tensor().item() instead of .item(),
+        see here: https://gist.github.com/Kai-46/a9835ef3f36e76d06afee6c11f388144
+        """
+        return tensor.full_tensor().item() if isinstance(tensor, DTensor) else tensor.item()
+
+    def _log_instant_grad_norms(self, stage: Stage):
+        """
+        Log instantaneous grad norms, we do not average because of the cost and because we want to
+        measure the actual values.
+        """
+        grad_norms = {"grad_norm.total": self.last_grad_norm}
+        for name, param in self.model.named_parameters():
+            if param.grad is not None:
+                grad_norms["grad_norm." + name] = self._get_tensor_item(param.grad.norm())
+
+        if is_root():
+            self.train_logger.log_metrics(stage, grad_norms)
+
     def _log_terminal(self, bidx: int, epoch: int, stage: Stage):
         print_freq = self.train_log_freq.terminal
         if bidx % print_freq == 0 and bidx > 0 or stage == VAL:
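
The two helpers above are the FSDP2-specific part: under FSDP2 the gradients (and hence their norms) are DTensors, which are read via full_tensor().item() rather than .item(). A standalone sketch of the same idea follows; to_scalar and collect_grad_norms are illustrative names, and on older PyTorch versions the DTensor import path may be torch.distributed._tensor instead:

    import torch
    from torch.distributed.tensor import DTensor


    def to_scalar(t: torch.Tensor) -> float:
        # DTensor (e.g. under FSDP2): materialise the full tensor before reading the value.
        return t.full_tensor().item() if isinstance(t, DTensor) else t.item()


    def collect_grad_norms(model: torch.nn.Module) -> dict[str, float]:
        # Per-parameter gradient norms, keyed with a "grad_norm." prefix so downstream
        # metric loggers can filter or exclude them easily.
        return {
            "grad_norm." + name: to_scalar(param.grad.norm())
            for name, param in model.named_parameters()
            if param.grad is not None
        }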
@@ -1011,20 +1042,16 @@ def _log_terminal(self, bidx: int, epoch: int, stage: Stage):
             elif stage == TRAIN:
                 # samples per sec
                 dt = time.time() - self.t_start
-                pstr = "{:03d} : {:05d}/{:05d} : {:06d} : loss = {:.4E} "
-                pstr += "(lr={:.2E}, s/sec={:.3f})"
                 len_dataset = len(self.data_loader) // self.cf.batch_size_per_gpu
-                logger.info(
-                    pstr.format(
-                        epoch,
-                        bidx,
-                        len_dataset,
-                        self.cf.istep,
-                        avg_loss.nanmean().item(),
-                        self.lr_scheduler.get_lr(),
-                        (print_freq * self.cf.batch_size_per_gpu) / dt,
-                    ),
+                pstr = (
+                    f"{epoch:03d} : {bidx:05d}/{len_dataset:05d} : "
+                    + f"{self.cf.istep:06d} : loss = {avg_loss.nanmean().item():.4E} "
+                    + f"(lr={self.lr_scheduler.get_lr():.2E}, "
                 )
+                if self.log_grad_norms:
+                    pstr += f"gradient norm={self.last_grad_norm:.3f}, "
+                pstr += f"s/sec={(print_freq * self.cf.batch_size_per_gpu) / dt:.3f})"
+                logger.info(pstr)
                 logger.info("\t")
                 for _, st in enumerate(self.cf.streams):
                     logger.info(
