1
1
from os .path import basename
2
2
3
3
import torch
4
- from commode_utils .callback import UploadCheckpointCallback , PrintEpochResultCallback
4
+ from commode_utils .callback import PrintEpochResultCallback , ModelCheckpointWithUpload
5
5
from omegaconf import DictConfig
6
6
from pytorch_lightning import seed_everything , Trainer , LightningModule , LightningDataModule
7
- from pytorch_lightning .callbacks import ModelCheckpoint , EarlyStopping , LearningRateMonitor
7
+ from pytorch_lightning .callbacks import EarlyStopping , LearningRateMonitor
8
8
from pytorch_lightning .loggers import WandbLogger
9
9
10
10
@@ -18,15 +18,14 @@ def train(model: LightningModule, data_module: LightningDataModule, config: Dict
18
18
wandb_logger = WandbLogger (project = f"{ model_name } -- { dataset_name } " , log_model = False , offline = config .log_offline )
19
19
20
20
# define model checkpoint callback
21
- checkpoint_callback = ModelCheckpoint (
21
+ checkpoint_callback = ModelCheckpointWithUpload (
22
22
dirpath = wandb_logger .experiment .dir ,
23
23
filename = "{epoch:02d}-val_loss={val/loss:.4f}" ,
24
24
monitor = "val/loss" ,
25
25
every_n_epochs = params .save_every_epoch ,
26
26
save_top_k = - 1 ,
27
27
auto_insert_metric_name = False ,
28
28
)
29
- upload_checkpoint_callback = UploadCheckpointCallback (wandb_logger .experiment .dir )
30
29
# define early stopping callback
31
30
early_stopping_callback = EarlyStopping (patience = params .patience , monitor = "val/loss" , verbose = True , mode = "min" )
32
31
# define callback for printing intermediate result
@@ -48,7 +47,6 @@ def train(model: LightningModule, data_module: LightningDataModule, config: Dict
48
47
lr_logger ,
49
48
early_stopping_callback ,
50
49
checkpoint_callback ,
51
- upload_checkpoint_callback ,
52
50
print_epoch_result_callback ,
53
51
],
54
52
resume_from_checkpoint = config .get ("checkpoint" , None ),
0 commit comments