-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
114 lines (87 loc) · 4.13 KB
/
train.py
File metadata and controls
114 lines (87 loc) · 4.13 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import torchmetrics
from mlcpl.loss import *
from torch.utils.data import DataLoader
import torchvision
import torch
from mlcpl.helper import *
import os
from pathlib import Path
import torchmetrics
from models import Model
import config
import time
from train_eval_fn import *
from mlcpl.sample_mix import *
from mlcpl.curriculum_labeling import CurriculumLabeling
def main():
device = config.device
output_dir = 'output/train'
train_dataset = config.train_dataset
valid_dataset = config.valid_dataset
train_dataset.transform = torchvision.transforms.Compose([
*config.data_aug,
torchvision.transforms.Resize(config.image_size),
torchvision.transforms.ToTensor(),
])
valid_dataset.transform = torchvision.transforms.Compose([
torchvision.transforms.Resize(config.image_size),
torchvision.transforms.ToTensor(),
])
train_dataset_cl = CurriculumLabeling(train_dataset, transform_for_update=torchvision.transforms.Compose([
torchvision.transforms.Resize(config.image_size),
torchvision.transforms.ToTensor(),
]))
train_dataset_mix = LogicMix(train_dataset_cl, probability=config.probability, mix_num_samples=config.num_samples)
num_categories = train_dataset.num_categories
model = Model(num_categories)
loss_fn_original = PartialAsymmetricLoss(gamma_neg=config.gamma_n, gamma_pos=config.gamma_p, clip=config.m, reduction=None)
loss_fn_aug = PartialAsymmetricLoss(gamma_neg=config.omega_n, gamma_pos=config.omega_p, clip=config.n, reduction=None)
validation_metrics = {
'mAP@C': torchmetrics.classification.MultilabelAveragePrecision(num_categories, average='macro', validate_args=False),
}
monitor_validation_metric_name = 'mAP@C'
train_dataloader = DataLoader(train_dataset_mix, batch_size=config.batch_size, num_workers=0, shuffle=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=config.batch_size, num_workers=0, shuffle=False)
model = model.to(device)
parameters = add_weight_decay(model, weight_decay=config.weight_decay)
ema = ModelEma(model, config.ema)
optimizer = torch.optim.Adam(parameters, lr=config.lr, weight_decay=0)
steps_per_epoch = len(train_dataloader)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, config.lr, steps_per_epoch=steps_per_epoch, epochs=config.epochs, pct_start=0.2)
log_dir = os.path.join(output_dir, 'log')
Path(output_dir).mkdir(parents=True, exist_ok=True)
Path(log_dir).mkdir(parents=True, exist_ok=True)
logger = MultiLogger(log_dir)
best_score = 0
best_at_epoch = -1
for epoch in range(config.epochs):
print(f'Epoch: {epoch}/{config.epochs}')
epoch_start_time = time.time()
# Train Loop
losses = train(model, train_dataloader, loss_fn_original, loss_fn_aug, optimizer, scheduler, model_ema=ema, grad_accum_steps=config.accum_step, device=device)
logger.add('train_loss', torch.mean(losses).detach().numpy())
# Valid Loop
preds, targets = evaluate(ema.module, valid_dataloader, device=device)
# Calculate metrics and logging
for name, metric in validation_metrics.items():
result = metric(preds, targets).detach().numpy()
logger.add('valid_'+name, result)
if name == monitor_validation_metric_name:
current_score = result
if current_score > best_score:
best_score = current_score
best_at_epoch = epoch
print(f'New best {monitor_validation_metric_name}: {best_score:.4f}')
torch.save(ema.module.state_dict(), os.path.join(output_dir, 'best.pth'))
if epoch >= (config.E-1): # -1 becasue epoch starts from 0
train_dataset_cl.update(ema.module, thresholds=config.thresholds)
epoch_end_time = time.time()
print(f'Epoch end. Total time: {(epoch_end_time-epoch_start_time):.2f} sec')
print()
if config.early_stopping is not None:
if epoch - best_at_epoch >= config.early_stopping:
print('Early stopping.')
break
logger.flush()
if __name__=='__main__':
main()