-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
45 lines (40 loc) · 2.12 KB
/
main.py
File metadata and controls
45 lines (40 loc) · 2.12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import torch
import torch.nn as nn
import os
from PreProcessing import get_awa_dataframes, DataframeDataset
from training_functions import train_model, epoch_func
from Models import ResNetModel
from graphics import plot_history, plot_prediction_grid, plot_confusion_matrix
# Mini-batch size passed to every data loader built below.
BATCH_SIZE = 64
# Number of epochs passed to train_model.
EPOCHS = 2

if __name__ == "__main__":
    # Script entry point: train a ResNetModel classifier on the
    # Animals_with_Attributes2 dataframes, evaluate it on the test split, and
    # save the model, training-history plot, confusion matrix, metrics file
    # and prediction grid under `save_dir`.
    # NOTE(review): the hierarchy-building step at the bottom is currently
    # commented out, so no hierarchy is produced by this script.

    # Directory that receives the model, metrics and plots.
    save_dir = "./Models/2epoch_resnet18/"
    # exist_ok=True avoids the check-then-create race of
    # `if not os.path.exists(...): os.makedirs(...)`.
    os.makedirs(save_dir, exist_ok=True)

    # Get one dataframe per split and wrap each in a dataset.
    # presumably each dataframe holds (image, class) pairs -- confirm against
    # PreProcessing.get_awa_dataframes.
    train_df, valid_df, test_df = get_awa_dataframes("./Animals_with_Attributes2")
    train_dataset = DataframeDataset(train_df)
    valid_dataset = DataframeDataset(valid_df)
    test_dataset = DataframeDataset(test_df)

    # Get a data loader for each dataset.
    train_dl = train_dataset.get_data_loader(batch_size=BATCH_SIZE)
    valid_dl = valid_dataset.get_data_loader(batch_size=BATCH_SIZE)
    test_dl = test_dataset.get_data_loader(batch_size=BATCH_SIZE)

    # Train the model, then save the training-history plot and the model.
    model = ResNetModel(train_dataset)
    model, history = train_model(model, train_dl, valid_dl, num_epochs=EPOCHS)
    plot_history(
        history,
        save_file=os.path.join(save_dir, "training_history.png"),
        blocking=False,
    )
    # NOTE(review): this pickles the whole module object rather than its
    # state_dict, so loading it later needs the model class to be importable.
    torch.save(model, os.path.join(save_dir, "model.pt"))

    # Evaluate on the test split and save the loss and accuracy.
    print("test")
    test_loss, test_acc, y_pred, y_true = epoch_func(
        model, test_dl, nn.CrossEntropyLoss()
    )
    # presumably encoder.classes_ holds the class labels for the matrix axes
    # -- confirm against DataframeDataset.
    plot_confusion_matrix(
        y_true,
        y_pred,
        train_dataset.encoder.classes_,
        save_file=os.path.join(save_dir, "confusion_matrix.png"),
    )
    # Build the metrics line once so the console and the file always agree.
    metrics_line = f"test_loss:{test_loss}, test_acc:{test_acc}"
    print(metrics_line)
    with open(os.path.join(save_dir, "metrics.txt"), "w", encoding="utf-8") as file:
        file.write(metrics_line)

    # Save the prediction grid.
    plot_prediction_grid(
        model,
        test_dataset,
        save_file=os.path.join(save_dir, "prediction_grid.png"),
    )

    # Create the hierarchy with the model and test dataset.
    # NOTE(review): disabled; create_binary_hierarchy is not among the imports
    # visible in this file and must be imported before re-enabling.
    #create_binary_hierarchy(model, test_dataset, save_file = f"{save_dir}hierarchy")