-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtraining_functions.py
More file actions
144 lines (122 loc) · 5.66 KB
/
training_functions.py
File metadata and controls
144 lines (122 loc) · 5.66 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import matplotlib.pyplot as plt
import time
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from tempfile import TemporaryDirectory
from graphics import plot_history, plot_prediction_grid, plot_confusion_matrix
def epoch_func(model, data_loader, criterion, optimizer = None):
    """
    Run one full pass of `data_loader` through `model`.

    If `optimizer` is given, the pass trains the model (forward, backward,
    step per batch); otherwise the model is only evaluated, and autograd is
    disabled for the whole pass so no graph memory is spent.

    Parameters
    ----------
    model : nn.Module
    data_loader : iterable yielding (inputs, labels) batches; must support len()
    criterion : loss function, e.g. nn.CrossEntropyLoss()
    optimizer : torch.optim.Optimizer or None (None => evaluation mode)

    Returns
    -------
    (avg_loss, accuracy, y_pred, y_true) — avg_loss and accuracy are Python
    floats averaged over all instances; y_pred / y_true are flat lists of int
    class indices in iteration order.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    training = optimizer is not None
    if training:
        model.train()  # Set model to training mode (dropout/batch-norm active)
    else:
        model.eval()   # Set model to evaluate mode
    model.to(device)
    epoch_loss = 0.0
    epoch_correct = 0
    current_batch_num = 0
    epoch_start = time.time()
    num_instances_seen = 0
    y_pred = []
    y_true = []
    # Only build the autograd graph when we actually train; evaluation runs
    # under no_grad semantics, which the original code omitted.
    with torch.set_grad_enabled(training):
        for inputs, labels in data_loader:
            num_instances_seen += len(inputs)
            current_batch_num += 1
            inputs = inputs.to(device)
            labels = labels.to(device)
            if training:
                optimizer.zero_grad()  # clear stale gradients before this batch's backward
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()
            y_pred.extend(int(pred) for pred in preds)
            y_true.extend(int(label) for label in labels)
            # criterion averages over the batch; multiply back to sum per-instance loss
            epoch_loss += loss.item() * inputs.size(0)
            epoch_correct += int(torch.sum(preds == labels))
            time_elapsed = time.time() - epoch_start
            print(f"\r ({time_elapsed:0.2f}s, {current_batch_num/len(data_loader)*100:0.2f}%) batch: {current_batch_num}/{len(data_loader)}, loss: {epoch_loss/num_instances_seen:0.3f}, accuracy: {epoch_correct/num_instances_seen:0.3f}",end='', flush=True)
    print()
    return (epoch_loss/num_instances_seen), (epoch_correct/num_instances_seen), y_pred, y_true
def train_model(model, train_dl, valid_dl, num_epochs=4):
    """
    Train `model` with SGD + step LR decay, keeping the best checkpoint.

    Takes a model, two data loaders (training and validation) and the number
    of epochs to train for. An initial measurement pass records loss/accuracy
    before any training; each subsequent epoch trains on `train_dl` and
    evaluates on `valid_dl`. Whenever the validation loss improves, the
    weights are checkpointed to a temporary file, and the best checkpoint is
    restored into the model before returning.

    Returns
    -------
    (model, history) — the model with its weights set to the best checkpoint,
    and a dict with keys "epochs", "training_acc", "training_loss",
    "valid_acc", "valid_loss"; each list has num_epochs + 1 entries
    (the initial measurement plus one per epoch).
    """
    CE_criterion = nn.CrossEntropyLoss()
    SGD_optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    # Decay the learning rate by 10x every 7 epochs.
    lr_decay_scheduler = lr_scheduler.StepLR(SGD_optimizer, step_size=7, gamma=0.1)
    best_valid_loss = float("inf")  # proper sentinel instead of a magic 999999
    training_accuracies = []
    training_losses = []
    valid_accuracies = []
    valid_losses = []
    print("initial measurements")
    # Baseline loss/accuracy before any weight update (no optimizer => eval pass).
    print(" training")
    initial_training_loss, initial_training_acc, _, _ = epoch_func(model, train_dl, CE_criterion)
    training_accuracies.append(initial_training_acc)
    training_losses.append(initial_training_loss)
    print(" valid")
    initial_valid_loss, initial_valid_acc, _, _ = epoch_func(model, valid_dl, CE_criterion)
    valid_accuracies.append(initial_valid_acc)
    valid_losses.append(initial_valid_loss)
    # Checkpoints live in a temporary directory that is cleaned up automatically.
    with TemporaryDirectory() as tempdir:
        best_model_params_path = os.path.join(tempdir, 'best_model_params.pt')
        torch.save(model.state_dict(), best_model_params_path)
        for i in range(num_epochs):
            print(f"\n\n Epoch {i+1}/{num_epochs}")
            print(" training")
            training_loss, training_acc, _, _ = epoch_func(model, train_dl, criterion=CE_criterion, optimizer=SGD_optimizer)
            training_accuracies.append(training_acc)
            training_losses.append(training_loss)
            lr_decay_scheduler.step()
            print(" valid")
            valid_loss, valid_acc, _, _ = epoch_func(model, valid_dl, criterion=CE_criterion)
            valid_accuracies.append(valid_acc)
            valid_losses.append(valid_loss)
            if valid_loss < best_valid_loss:
                best_valid_loss = valid_loss
                torch.save(model.state_dict(), best_model_params_path)
                print(" new best valid loss")
        # Load best model weights (must happen before the temp dir is deleted).
        model.load_state_dict(torch.load(best_model_params_path))
    # epoch_func may return tensor accuracies; normalize to plain floats.
    training_accuracies = [float(acc) for acc in training_accuracies]
    valid_accuracies = [float(acc) for acc in valid_accuracies]
    history = {
        "epochs" : list(range(num_epochs+1)),
        "training_acc": training_accuracies,
        "training_loss": training_losses,
        "valid_acc": valid_accuracies,
        "valid_loss": valid_losses
    }
    return model, history
if __name__ == "__main__":
    from PreProcessing import get_awa_dataframes, DataframeDataset
    from Models import ResNetModel

    # Build the three dataset splits from the AwA2 directory.
    df_train, df_valid, df_test = get_awa_dataframes("./Animals_with_Attributes2")
    ds_train = DataframeDataset(df_train)
    ds_valid = DataframeDataset(df_valid)
    ds_test = DataframeDataset(df_test)

    # Single-epoch training run as a quick end-to-end check.
    net = ResNetModel(ds_train)
    net, history = train_model(net, ds_train.get_data_loader(), ds_valid.get_data_loader(), num_epochs=1)
    #plot_history(history)

    # Evaluate on the held-out test split and report.
    print("test")
    test_loss, test_acc, y_pred, y_true = epoch_func(net, ds_test.get_data_loader(), nn.CrossEntropyLoss())
    print(f"test_loss:{test_loss}, test_acc:{test_acc}")
    print(f"y_pred: {len(y_pred)}, y_true:{len(y_true)}")
    print(f"y_pred: {type(y_pred)}, y_true:{type(y_true)}")
    #print(f"y_pred: {y_pred[0].shape}, y_true:{y_true[0].shape}")
    plot_confusion_matrix(y_pred, y_true, labels = ds_train.encoder.classes_)
    plot_prediction_grid(net, ds_test)