-
Notifications
You must be signed in to change notification settings - Fork 14
Expand file tree
/
Copy pathcheckpointing.py
More file actions
27 lines (22 loc) · 867 Bytes
/
checkpointing.py
File metadata and controls
27 lines (22 loc) · 867 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import os
import pytorch_lightning as pl
from pathlib import Path
from pytorch_lightning.callbacks import ModelCheckpoint
class PeriodicCheckpoint(ModelCheckpoint):
    """Save a training checkpoint every ``every`` optimizer steps.

    The stock ``ModelCheckpoint`` keys off validation/epoch boundaries; this
    subclass hooks ``on_before_zero_grad`` so checkpoints land at a fixed
    global-step cadence, written as ``latest-<global_step>.ckpt`` in
    ``dirpath``.  NOTE(review): because ``0 % every == 0``, a checkpoint is
    also written at global step 0 — confirm that is intended.
    """

    def __init__(self, every: int, dirpath: str):
        """
        Args:
            every: checkpoint interval in optimizer steps; must be positive.
            dirpath: directory for ``latest-<step>.ckpt`` files; created
                (including parents) if it does not exist.

        Raises:
            ValueError: if ``every`` is not a positive integer.
        """
        super().__init__()
        # Validate eagerly: every <= 0 would otherwise surface as a
        # ZeroDivisionError deep inside training, far from the bad argument.
        if every <= 0:
            raise ValueError(f"'every' must be a positive int, got {every!r}")
        self.every = every
        self.dirpath = dirpath
        # makedirs(..., exist_ok=True) replaces the original
        # exists()/mkdir() pair: no check-then-create race, and nested
        # paths work too.
        os.makedirs(dirpath, exist_ok=True)

    def on_before_zero_grad(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule, *args, **kwargs
    ):
        """Write ``latest-<global_step>.ckpt`` when the step hits the cadence.

        Older checkpoints are intentionally left on disk; nothing here
        deletes previous ``latest-*.ckpt`` files.
        """
        if pl_module.global_step % self.every == 0:
            assert self.dirpath is not None
            current = Path(self.dirpath) / f"latest-{pl_module.global_step}.ckpt"
            # Original code also computed the *previous* checkpoint path but
            # never used it (dead code) and print()-ed the path (debug
            # leftover); both removed.
            trainer.save_checkpoint(current)