Jk/log grad norms/log grad norms #1068
Conversation
src/weathergen/train/trainer.py
| """ | ||
| self.last_grad_norm = ( | ||
| total_norm.full_tensor().item() if self.cf.world_size > 1 else total_norm.item() | ||
| ) |
As mentioned here, `full_tensor().item()` is needed in parallel runs with FSDP2. I tested this by logging both ways of calculating:

```
000 : 00010/02048 : 000010 : loss = 1.0287E+00 (lr=1.64E-06, gradient norm=0.983, gradient norm FT=1.403, s/sec=0.236)
ERA5 : 1.0287E+00
000 : 00020/02048 : 000020 : loss = 1.0101E+00 (lr=3.34E-06, gradient norm=0.587, gradient norm FT=0.817, s/sec=0.435)
ERA5 : 1.0101E+00
```
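The commit list at the bottom of this PR mentions a `get_tensor_item` helper and a later switch to checking for `DTensor` instead of `world_size`. A minimal sketch of that idea, assuming PyTorch's `DTensor` API (public under `torch.distributed.tensor` in recent releases, `torch.distributed._tensor` in older ones):

```python
import torch
from torch.distributed.tensor import DTensor


def get_tensor_item(t: torch.Tensor) -> float:
    # Under FSDP2 each rank holds only a shard of the parameters, so
    # reductions like .norm() come back as a DTensor; .full_tensor()
    # all-gathers the shards before .item() reads the global value.
    if isinstance(t, DTensor):
        return t.full_tensor().item()
    return t.item()
```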
```python
    if self.cf.world_size > 1
    else param.grad.norm().item()
)
```
Same as above, we also need `.full_tensor().item()` here in multi-GPU mode. Tested it by printing both versions on 2 GPUs:

```python
print(".item():", param.grad.norm().item())
print(".full_tensor().item()", param.grad.norm().full_tensor().item())
```

```
.item(): 0.028306283056735992
.item(): 0.022433193400502205
.full_tensor().item() 0.03611777722835541
.full_tensor().item() 0.03611777722835541
```
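As a sanity check on these numbers: the shards are disjoint, so the per-rank 2-norms should combine as the root of the sum of squares, and indeed sqrt(0.02831² + 0.02243²) ≈ 0.03612, matching the `.full_tensor()` value on both ranks. This confirms that plain `.item()` only sees the local shard's norm.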
src/weathergen/train/trainer.py
```python
grad_norms["grad_norm_" + name] = (
    param.grad.norm().full_tensor().item()
    if self.cf.world_size > 1
    else param.grad.norm().item()
)
```
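For context, a hedged reconstruction of the per-parameter loop this expression sits in; only the dict-entry expression is taken from the diff, while the surrounding names (`self.model`, the `None` check) are assumptions:

```python
# Hypothetical surrounding loop; the dict-entry expression matches the diff,
# the rest is illustrative.
grad_norms: dict[str, float] = {}
for name, param in self.model.named_parameters():
    if param.grad is None:  # frozen or unused parameters have no gradient
        continue
    grad_norms["grad_norm_" + name] = (
        param.grad.norm().full_tensor().item()
        if self.cf.world_size > 1
        else param.grad.norm().item()
    )
```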
Shouldn't you divide by the number of items in the gradient? Otherwise, if every component of the gradient is equal to some value g, the 2-norm scales as g·sqrt(n) with the number of elements n, and you are biased by batching computations.
Not sure I follow. But as far as I know the gradient norm logging is correct, and people do not commonly account for the number of dimensions; as for batching, this is handled in the forward pass and thus is automatically dealt with during backprop.
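A toy illustration of that last point, assuming the usual setup where the loss is a mean over the batch (this example is mine, not from the PR):

```python
import torch

# Because the loss averages over the batch in the forward pass, the
# 1/batch_size factor is already baked into the gradients.
x = torch.randn(8, 4)                  # batch of 8 samples, 4 features
w = torch.randn(4, requires_grad=True)
loss = (x @ w).pow(2).mean()           # mean over the batch
loss.backward()
# w.grad equals (2 / 8) * x.T @ (x @ w): the batching factor arrives via
# backprop, so the logged gradient norm needs no extra division.
```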
Minor things to fix in the comments; I trust they will happen and thus already approve the PR. If logging is off, there should be no effect on the runs.
@sophie-xhonneux, @tjhunter, should we keep the …

@tjhunter, here is a sample of the metrics logs when gradient logging is on. Should we log them in a separate file? Happy to have a quick chat on that whenever you are available.
* Log gradient norms
* Prototype for recording grad norms
* Address review changes + hide behind feature flag
* Final fixes including backward compatibility
* Ruff
* More ruff stuff
* forecast config with small decoder
* fixed uv.lock
* test gradient logging on multi GPUs
* update uv.lock to latest develop version
* revert to default config
* add comment on FSDP2 specifics
* move plot grad script to private repo
* rm seaborn from pyproject
* updating terminal and metrics logging, add get_tensor_item fct
* check for DTensor instead of world size
* revert forecast fct, fix in separate PR
* rename grad_norm log names to exclude from MLFlow
* add log_grad_norms to default config

Co-authored-by: sophiex <[email protected]>
Description
This PR is based on @sophie-xhonneux's log_grad_norm branch in #685, modified to allow logging gradients when running in parallel on multiple GPUs with FSDP2.
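Per the commit list above, the feature is hidden behind a feature flag and `log_grad_norms` was added to the default config. A sketch of how the gate might look, assuming the `self.cf.*` config style visible in the diff excerpts and reusing the `get_tensor_item` sketch from earlier (the structure is illustrative, not the repo's exact code):

```python
# Hypothetical gate; `log_grad_norms` is the flag named in the commit list,
# everything else here is illustrative.
if getattr(self.cf, "log_grad_norms", False):
    grad_norms = {
        "grad_norm_" + name: get_tensor_item(param.grad.norm())
        for name, param in self.model.named_parameters()
        if param.grad is not None
    }
```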
Issue Number
Closes #688
Checklist before asking for review
- `./scripts/actions.sh lint`
- `./scripts/actions.sh unit-test`
- `./scripts/actions.sh integration-test`
- `launch-slurm.py --time 60`