-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgraphics.py
More file actions
109 lines (87 loc) · 3.47 KB
/
graphics.py
File metadata and controls
109 lines (87 loc) · 3.47 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import matplotlib.pyplot as plt
import torch
import numpy as np
import seaborn as sns
from sklearn.metrics import confusion_matrix
def display_img_grid(images, captions = None, save_file = None):
    """Display a list of images in a square-ish grid.

    Args:
        images: sequence of HxWxC arrays accepted by ``ax.imshow``.
        captions: optional sequence of titles, parallel to ``images``.
        save_file: optional path; when given, the figure is saved there
            before being shown.
    """
    if not images:
        return  # nothing to draw; subplots(0, 0) would raise
    # ceil, not floor: a floored grid silently drops images whenever
    # len(images) is not a perfect square (e.g. 61 images -> 7x7 = 49 shown).
    grid_size = int(np.ceil(len(images) ** 0.5))
    fig, axs = plt.subplots(grid_size, grid_size, figsize=(15, 15))
    # With a 1x1 grid plt.subplots returns a bare Axes, which has no .flat;
    # normalize to an array so the loop below always works.
    axs = np.atleast_1d(axs)
    for i, ax in enumerate(axs.flat):
        if i < len(images):
            ax.imshow(images[i])
            if captions is not None:
                ax.set_title(captions[i])
        ax.axis('off')
    plt.tight_layout()
    if save_file is not None:
        plt.savefig(save_file)
    plt.show()
def get_eval_grid(model, dataset, n = 60):
    """Run ``model`` over ``dataset`` and collect up to ``n`` images with
    "true vs predicted" captions, suitable for ``display_img_grid``.

    Args:
        model: a torch module producing class logits.
        dataset: project object exposing ``get_data_loader()`` and a label
            ``encoder`` with ``inverse_transform`` — assumed sklearn-style;
            TODO confirm against the dataset class.
        n: maximum number of (image, caption) pairs to collect.

    Returns:
        (images, captions): lists of HxWxC numpy arrays and caption strings.
        Always returns the pair — even when the dataset yields fewer than
        ``n`` samples (the original fell off the loop and returned None).
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.eval()
    model.to(device)
    images = []
    captions = []
    # Inference only: no_grad avoids building autograd graphs for every batch.
    with torch.no_grad():
        for inputs, labels in dataset.get_data_loader():
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            for img_tensor, label, pred in zip(inputs, labels, preds):
                # CHW tensor -> HWC numpy for matplotlib.
                img_numpy = img_tensor.cpu().numpy().transpose(1, 2, 0)
                label_numpy = label.cpu().numpy()
                pred_numpy = pred.cpu().numpy()
                text_label = dataset.encoder.inverse_transform([label_numpy])[0]
                text_pred = dataset.encoder.inverse_transform([pred_numpy])[0]
                images.append(img_numpy)
                captions.append(f"true:{text_label}\npredicted:{text_pred}")
                # >= n (not > n): stop at exactly n items instead of n + 1.
                if len(images) >= n:
                    return images, captions
    return images, captions
def plot_prediction_grid(model, dataset, save_file = None):
    """Evaluate ``model`` on ``dataset`` and show the resulting image grid,
    optionally saving the figure to ``save_file``."""
    grid_images, grid_captions = get_eval_grid(model, dataset)
    display_img_grid(grid_images, captions=grid_captions, save_file=save_file)
def plot_history(history, blocking = False, save_file = None):
    """Plot training curves from a history dict.

    Args:
        history: dict with keys ``"epochs"``, ``"training_acc"``,
            ``"valid_acc"``, ``"training_loss"``, ``"valid_loss"`` — each a
            sequence of per-epoch values.
        blocking: passed to ``plt.show(block=...)``; False lets training
            continue while the window stays open.
        save_file: optional path; when given, the figure is saved there
            before being shown.
    """
    epochs = history["epochs"]
    plt.figure(figsize=(10, 12))
    # Top subplot: accuracy curves.
    plt.subplot(2, 1, 1)
    plt.plot(epochs, history["training_acc"], label='Training Accuracy', marker='o')
    plt.plot(epochs, history["valid_acc"], label='Validation Accuracy', marker='o')
    plt.legend()
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.title('Accuracy Over Epochs')
    plt.grid(True)
    # Bottom subplot: loss curves (the old comment wrongly said "accuracy").
    plt.subplot(2, 1, 2)
    plt.plot(epochs, history["training_loss"], label='Training Loss', marker='o')
    plt.plot(epochs, history["valid_loss"], label='Validation Loss', marker='o')
    plt.legend()
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Loss Over Epochs')
    plt.grid(True)
    plt.tight_layout()
    if save_file is not None:
        plt.savefig(save_file)
    plt.show(block=blocking)
def plot_confusion_matrix(y_true, y_pred, labels, save_file = None):
    """Render a labeled confusion-matrix heatmap.

    Args:
        y_true: ground-truth labels.
        y_pred: predicted labels.
        labels: class labels, in the display order for both axes.
        save_file: optional path; when given, the figure is saved there
            before being shown.
    """
    # Pass labels explicitly: without it sklearn orders rows/columns by
    # sorted unique values, which silently mislabels the heatmap whenever
    # the caller's `labels` order differs (or a class is absent from y_true).
    cm = confusion_matrix(y_true, y_pred, labels=labels)
    fig, ax = plt.subplots(figsize=(19,20))
    # annot=True writes the count in each cell; fmt='d' keeps them integers.
    sns.heatmap(cm, annot=True, fmt='d', xticklabels=labels, yticklabels=labels, ax = ax)
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.title('Confusion Matrix')
    if save_file is not None:
        fig.savefig(save_file, bbox_inches='tight', dpi=300)
    plt.show()