Skip to content

Commit be05daf

Browse files
committed
Update to new version of commode-utils
1 parent 84a8d01 commit be05daf

File tree

3 files changed

+9
-11
lines changed

3 files changed

+9
-11
lines changed

code2seq/utils/train.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
from os.path import basename
22

33
import torch
4-
from commode_utils.callback import UploadCheckpointCallback, PrintEpochResultCallback
4+
from commode_utils.callback import PrintEpochResultCallback, ModelCheckpointWithUpload
55
from omegaconf import DictConfig
66
from pytorch_lightning import seed_everything, Trainer, LightningModule, LightningDataModule
7-
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
7+
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor
88
from pytorch_lightning.loggers import WandbLogger
99

1010

@@ -18,15 +18,14 @@ def train(model: LightningModule, data_module: LightningDataModule, config: Dict
1818
wandb_logger = WandbLogger(project=f"{model_name} -- {dataset_name}", log_model=False, offline=config.log_offline)
1919

2020
# define model checkpoint callback
21-
checkpoint_callback = ModelCheckpoint(
21+
checkpoint_callback = ModelCheckpointWithUpload(
2222
dirpath=wandb_logger.experiment.dir,
2323
filename="{epoch:02d}-val_loss={val/loss:.4f}",
2424
monitor="val/loss",
2525
every_n_epochs=params.save_every_epoch,
2626
save_top_k=-1,
2727
auto_insert_metric_name=False,
2828
)
29-
upload_checkpoint_callback = UploadCheckpointCallback(wandb_logger.experiment.dir)
3029
# define early stopping callback
3130
early_stopping_callback = EarlyStopping(patience=params.patience, monitor="val/loss", verbose=True, mode="min")
3231
# define callback for printing intermediate result
@@ -48,7 +47,6 @@ def train(model: LightningModule, data_module: LightningDataModule, config: Dict
4847
lr_logger,
4948
early_stopping_callback,
5049
checkpoint_callback,
51-
upload_checkpoint_callback,
5250
print_epoch_result_callback,
5351
],
5452
resume_from_checkpoint=config.get("checkpoint", None),

requirements.txt

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@ black==21.7b0
22
mypy==0.910
33

44
torch==1.9.0
5-
pytorch-lightning==1.4.2
6-
torchmetrics==0.5.0
5+
pytorch-lightning==1.4.7
6+
torchmetrics==0.5.1
77

8-
tqdm==4.62.1
9-
wandb==0.12.0
8+
tqdm==4.62.2
9+
wandb==0.12.2
1010
omegaconf==2.1.1
11-
commode-utils==0.3.8
11+
commode-utils==0.3.9

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from setuptools import setup, find_packages
22

3-
VERSION = "1.0.1"
3+
VERSION = "1.0.2"
44

55
with open("README.md") as readme_file:
66
readme = readme_file.read()

0 commit comments

Comments
 (0)