Merged (28 commits)
a0039ec  Log gradient norms (sophie-xhonneux, Aug 6, 2025)
e83903b  Prototype for recording grad norms (sophie-xhonneux, Aug 6, 2025)
d2995b4  Address review changes + hide behind feature flag (sophie-xhonneux, Aug 7, 2025)
26c6869  Final fixes including backward compatibility (sophie-xhonneux, Aug 7, 2025)
66da0d7  Merge branch 'develop' into sophiex/dev/log-grad-norms (sophie-xhonneux, Aug 7, 2025)
9a66f72  Ruff (sophie-xhonneux, Aug 7, 2025)
22a6fd7  More ruff stuff (sophie-xhonneux, Aug 7, 2025)
87e7d3b  Merge branch 'develop' into jk/log-grad-norms/log-grad-norms (Jubeku, Oct 7, 2025)
754d31c  forecast config with small decoder (Jubeku, Oct 8, 2025)
cd7948f  Merge branch 'develop' into jk/log-grad-norms/log-grad-norms (Jubeku, Oct 9, 2025)
7c756a3  fixed uv.lock (Jubeku, Oct 9, 2025)
41716a6  test gradient logging on mutli gpus (Jubeku, Oct 9, 2025)
8bdbac4  update uv.lock to latest develop version (Jubeku, Oct 13, 2025)
da92f8f  revert to default confit (Jubeku, Oct 13, 2025)
a072c35  add comment on FSDP2 specifics (Jubeku, Oct 13, 2025)
6d477be  Merge branch 'develop' into jk/log-grad-norms/log-grad-norms (Jubeku, Oct 13, 2025)
b30b69a  Merge branch 'develop' into jk/log-grad-norms/log-grad-norms (Jubeku, Oct 16, 2025)
c8fadf6  move plot grad script to private repo (Jubeku, Oct 16, 2025)
8bd7383  rm seaborn from pyproject (Jubeku, Oct 16, 2025)
37a428d  Merge branch 'develop' into jk/log-grad-norms/log-grad-norms (Jubeku, Oct 20, 2025)
9892dfa  updating terminal and metrics loggin, add get_tensor_item fct (Jubeku, Oct 21, 2025)
2885062  check for DTensor instead of world size (Jubeku, Oct 21, 2025)
cbb1c85  revert forecast fct, fix in separate PR (Jubeku, Oct 21, 2025)
2bf714b  Merge branch 'develop' into jk/log-grad-norms/log-grad-norms (Jubeku, Oct 21, 2025)
75749df  rename grad_norm log names to exclude from MLFlow (Jubeku, Oct 23, 2025)
8b25312  Merge branch 'develop' into jk/log-grad-norms/log-grad-norms (Jubeku, Oct 23, 2025)
f1c24fa  add log_grad_norms to default config (Jubeku, Oct 24, 2025)
f1ff748  Merge branch 'develop' into jk/log-grad-norms/log-grad-norms (Jubeku, Oct 24, 2025)
40 changes: 38 additions & 2 deletions src/weathergen/train/trainer.py
@@ -88,6 +88,8 @@ def init(self, cf: Config, devices):
# world_size gets overwritten by current setting during init_ddp()
self.world_size_original = cf.get("world_size", None)

self.log_grad_norms = cf.get("log_grad_norms", False)

# create output directory
if is_root():
config.get_path_run(cf).mkdir(exist_ok=True, parents=True)
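
For context (not part of the diff): reading the flag with cf.get("log_grad_norms", False) is what keeps older configs working, since runs whose config predates this PR simply fall back to the disabled default. A minimal stand-in with plain dicts, assuming the project's Config object behaves like a mapping here:

```python
# Plain-dict illustration of the backward-compatible default; the real cf is
# the project's Config object, assumed here to support .get() the same way.
old_cfg = {"world_size": 4}                           # config from before this PR
new_cfg = {"world_size": 4, "log_grad_norms": True}   # config that opts in

assert old_cfg.get("log_grad_norms", False) is False  # feature stays off
assert new_cfg.get("log_grad_norms", False) is True   # feature enabled explicitly
```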
@@ -539,6 +541,7 @@ def train(self, epoch):

# Unweighted loss, real weighted loss, std for losses that need it
self.loss_unweighted_hist, self.loss_model_hist, self.stdev_unweighted_hist = [], [], []
self.last_grad_norm = 0.0

# training loop
self.t_start = time.time()
@@ -570,7 +573,13 @@

# gradient clipping
self.grad_scaler.unscale_(self.optimizer)
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=cf.grad_clip)
total_norm = torch.nn.utils.clip_grad_norm_(
self.model.parameters(), max_norm=cf.grad_clip
)

# log gradient norms
if bidx % log_interval == 0 and self.log_grad_norms:
self._log_instant_grad_norms(TRAIN, total_norm)

# optimizer step
self.grad_scaler.step(self.optimizer)
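
For illustration (not the PR's code): the hunk above relies on the usual AMP ordering, where grad_scaler.unscale_ runs before clipping so that clip_grad_norm_ sees, and returns, the true gradient norm rather than the loss-scaled one. A minimal self-contained sketch of that ordering:

```python
import torch

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler(enabled=False)  # enabled=True on CUDA runs

loss = model(torch.randn(4, 8)).sum()
scaler.scale(loss).backward()

scaler.unscale_(optimizer)  # undo loss scaling before measuring/clipping
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
print(f"total grad norm: {total_norm.item():.3f}")  # the value worth logging

scaler.step(optimizer)
scaler.update()
```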
@@ -942,6 +951,32 @@ def _log(self, stage: Stage):

self.loss_unweighted_hist, self.loss_model_hist, self.stdev_unweighted_hist = [], [], []

def _log_instant_grad_norms(self, stage: Stage, total_norm):
"""
Log instantaneous grad norms; we do not average them, both because of the cost and
because we want to measure the actual values.

Note: When using FSDP2, we need full_tensor().item() instead of .item(), see here:
https://gist.github.com/Kai-46/a9835ef3f36e76d06afee6c11f388144
"""
self.last_grad_norm = (
total_norm.full_tensor().item() if self.cf.world_size > 1 else total_norm.item()
)

Contributor Author:
As mentioned here, full_tensor().item() is needed in parallel runs with FSDP2. I tested this by logging both ways of calculating:

000 : 00010/02048 : 000010 : loss = 1.0287E+00 (lr=1.64E-06, gradient norm=0.983, gradient norm FT=1.403, s/sec=0.236)
ERA5 : 1.0287E+00
000 : 00020/02048 : 000020 : loss = 1.0101E+00 (lr=3.34E-06, gradient norm=0.587, gradient norm FT=0.817, s/sec=0.435)
ERA5 : 1.0101E+00

grad_norms = {"total_grad_norm": self.last_grad_norm}
for name, param in self.model.named_parameters():
if param.grad is not None:
# grad_norms["grad_norm_" + name] = param.grad.norm().item()
grad_norms["grad_norm_" + name] = (
param.grad.norm().full_tensor().item()
if self.cf.world_size > 1
else param.grad.norm().item()
)

Collaborator:
Shouldn't you divide by the number of items in the gradient? Otherwise, if every component of the gradient is equal, you are biased by batching computations.

Contributor:
Not sure I follow. But as far as I know the gradient norm logging is correct, and people do not commonly account for the number of dimensions; as for batching, this is handled in the forward pass and thus is automatically dealt with during backprop.

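A side note on the exchange above about dividing by the number of items: the L2 norm of a gradient whose entries are all equal grows like sqrt(n), so raw per-parameter norms of tensors with very different sizes are not directly comparable; whether to normalize is a logging convention, and this PR logs raw norms. A small standalone check (not part of the PR):

```python
import torch

g_small = torch.full((10,), 0.01)
g_large = torch.full((1_000_000,), 0.01)
print(g_small.norm().item())  # ~0.0316, i.e. 0.01 * sqrt(10)
print(g_large.norm().item())  # ~10.0,   i.e. 0.01 * sqrt(1e6)
```
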
Contributor Author:
Same as above, we also need .full_tensor().item() here in multi-gpu mode. Tested it by printing both versions on 2 GPUs:

print(".item():", param.grad.norm().item())
print(".full_tensor().item()", param.grad.norm().full_tensor().item())

.item(): 0.028306283056735992
.item(): 0.022433193400502205
.full_tensor().item() 0.03611777722835541
.full_tensor().item() 0.03611777722835541

# print(".item():", param.grad.norm().item())
# print(".full_tensor().item()", param.grad.norm().full_tensor().item())
if is_root():
self.train_logger.log_metrics(TRAIN, grad_norms)

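Both review threads above come down to the same FSDP2 detail: with sharded parameters the computed norms are DTensors, so .item() alone returns only the local shard's value and .full_tensor() must gather first. The commit history mentions a get_tensor_item helper and a later switch to checking for DTensor rather than the world size; a sketch of that idea, with the import path and signature as assumptions rather than the PR's actual code:

```python
import torch

try:
    from torch.distributed.tensor import DTensor  # public path in recent PyTorch
except ImportError:
    from torch.distributed._tensor import DTensor  # older releases


def get_tensor_item(t: torch.Tensor) -> float:
    """Return a Python scalar from a possibly sharded (DTensor) result.

    With FSDP2, clip_grad_norm_ and param.grad.norm() return DTensors; calling
    .item() directly would yield only the local shard's contribution, so the
    tensor is gathered with .full_tensor() first.
    """
    if isinstance(t, DTensor):
        return t.full_tensor().item()
    return t.item()
```

The diff shown here still branches on self.cf.world_size > 1, which is equivalent as long as every multi-GPU run uses FSDP2; the DTensor check named in the later commit makes the intent explicit.
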
def _log_terminal(self, bidx: int, epoch: int, stage: Stage):
print_freq = self.train_log_freq.terminal
if bidx % print_freq == 0 and bidx > 0 or stage == VAL:
@@ -964,7 +999,7 @@ def _log_terminal(self, bidx: int, epoch: int, stage: Stage):
# samples per sec
dt = time.time() - self.t_start
pstr = "{:03d} : {:05d}/{:05d} : {:06d} : loss = {:.4E} "
pstr += "(lr={:.2E}, s/sec={:.3f})"
pstr += "(lr={:.2E}, gradient norm={:.3f}, s/sec={:.3f})"
len_dataset = len(self.data_loader) // self.cf.batch_size_per_gpu
logger.info(
pstr.format(
@@ -974,6 +1009,7 @@ def _log_terminal(self, bidx: int, epoch: int, stage: Stage):
self.cf.istep,
avg_loss.nanmean().item(),
self.lr_scheduler.get_lr(),
self.last_grad_norm,
(print_freq * self.cf.batch_size_per_gpu) / dt,
),
)
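
For reference, with illustrative values rather than output from a real run, the extended format string renders the same kind of terminal line quoted in the review comment above:

```python
pstr = "{:03d} : {:05d}/{:05d} : {:06d} : loss = {:.4E} "
pstr += "(lr={:.2E}, gradient norm={:.3f}, s/sec={:.3f})"
print(pstr.format(0, 10, 2048, 10, 1.0287, 1.64e-06, 0.983, 0.236))
# 000 : 00010/02048 : 000010 : loss = 1.0287E+00 (lr=1.64E-06, gradient norm=0.983, s/sec=0.236)
```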