-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathPreProcessing.py
More file actions
128 lines (108 loc) · 4.48 KB
/
PreProcessing.py
File metadata and controls
128 lines (108 loc) · 4.48 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import torch
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from PIL import Image
import pandas as pd
import os
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import LabelEncoder
from graphics import display_img_grid
# Random state passed to both train_test_split calls in get_awa_dataframes.
SEED: int = 0
# Side length, in pixels, of the square images DataframeDataset resizes to.
IMG_SIDE_LENGTH: int = 256
def get_awa_dataframes(base_dir, holdout_size=0.2, test_fraction=0.5, seed=None):
    """Build stratified train/validation/test splits of the Animals_with_Attributes2 images.

    Scans ``<base_dir>/JPEGImages``. Every sub-directory is treated as a class
    label, and every regular, non-hidden file inside it as an image of that class.

    Args:
        base_dir: Path to the root of the Animals_with_Attributes2 dataset.
        holdout_size: Fraction of all images held out from training
            (validation + test together). Defaults to 0.2.
        test_fraction: Fraction of the held-out images assigned to the test
            split; the rest go to validation. Defaults to 0.5.
        seed: ``random_state`` for both splits. Defaults to the module-level ``SEED``.

    Returns:
        A ``(train_df, valid_df, test_df)`` tuple of DataFrames. Each has the
        columns ``"path"`` (path to an image file) and ``"label"`` (its class
        name). Both splits are stratified by label, so with the defaults the
        result is an 80/10/10 split.
    """
    if seed is None:
        seed = SEED
    labels_dir = os.path.join(base_dir, "JPEGImages")
    image_paths = []
    image_labels = []
    # os.listdir returns entries in filesystem-dependent order. Sort them so the
    # seeded splits below give the same result on every machine.
    for label in sorted(os.listdir(labels_dir)):
        label_path = os.path.join(labels_dir, label)
        if not os.path.isdir(label_path):
            continue
        for image_file in sorted(os.listdir(label_path)):
            img_path = os.path.join(label_path, image_file)
            # Skip hidden files (e.g. .DS_Store) and nested directories. They
            # are not images and would fail to load later.
            if image_file.startswith(".") or not os.path.isfile(img_path):
                continue
            image_paths.append(img_path)
            image_labels.append(label)
    df = pd.DataFrame({
        "path": image_paths,
        "label": image_labels,
    })
    train_df, temp_df = train_test_split(
        df, test_size=holdout_size, stratify=df["label"], random_state=seed
    )
    valid_df, test_df = train_test_split(
        temp_df, test_size=test_fraction, stratify=temp_df["label"], random_state=seed
    )
    return train_df, valid_df, test_df
class DataframeDataset(torch.utils.data.Dataset):
    """Torch dataset wrapper for a DataFrame with columns ``"path"`` and ``"label"``.

    Each item is an ``(image_tensor, label_tensor)`` pair:

    * The image at ``path`` is converted to RGB and resized to
      ``IMG_SIDE_LENGTH`` x ``IMG_SIDE_LENGTH``. It is optionally randomly
      augmented, then turned into a float tensor by ``transforms.ToTensor``.
    * The ``label`` is integer-encoded with a ``LabelEncoder``.

    Args:
        dataframe: DataFrame with ``"path"`` and ``"label"`` columns.
        encoder: An already fitted ``LabelEncoder``. If omitted, a new one is
            fitted on ``dataframe["label"]``. Pass the training dataset's
            ``encoder`` to the validation/test datasets so every split uses the
            same label -> index mapping.
        augment: If true, apply random affine and brightness augmentation.
    """

    def __init__(self, dataframe, encoder=None, augment=False):
        self.dataframe = dataframe
        if encoder is None:
            encoder = LabelEncoder()
            encoder.fit(self.dataframe["label"])
        self.encoder = encoder
        self.augment = augment
        # Build both pipelines once rather than on every __getitem__ call.
        # Keeping both lets callers toggle self.augment after construction.
        self._augment_transform = transforms.Compose([
            # Random rotation in (-90, 90) degrees plus an x-axis shear in (-10, 10) degrees.
            transforms.RandomAffine(degrees=90, shear=10),
            # Brightness scaled by a random factor drawn from [0.8, 1.2].
            transforms.ColorJitter(brightness=0.2),
            transforms.ToTensor(),
        ])
        # NOTE: no Normalize step is applied, so pixel values stay in [0, 1].
        self._plain_transform = transforms.Compose([
            transforms.ToTensor(),
        ])

    def __len__(self):
        """Return the number of rows (images) in the wrapped DataFrame."""
        return len(self.dataframe)

    def __getitem__(self, index):
        """Load the image in row ``index`` and return ``(image_tensor, label_tensor)``."""
        row = self.dataframe.iloc[index]
        # The context manager closes the underlying file handle.
        # Converting to RGB makes grayscale/CMYK/palette files produce
        # 3-channel tensors like every other image, so batches can be stacked.
        with Image.open(row["path"]) as img:
            resized_img = img.convert("RGB").resize((IMG_SIDE_LENGTH, IMG_SIDE_LENGTH))
        transform = self._augment_transform if self.augment else self._plain_transform
        img_tensor = transform(resized_img)
        # Map the string label to its integer class index.
        label = self.encoder.transform([row["label"]])[0]
        label_tensor = torch.tensor(label)
        return img_tensor, label_tensor

    def __str__(self):
        """Return the per-label image counts of the wrapped DataFrame."""
        return f"{self.dataframe['label'].value_counts()}"

    def show_samples(self, n=60):
        """Debugging aid: display up to ``n`` items with their text labels in an image grid."""
        images = []
        labels = []
        for i in range(min(len(self), n)):
            img_tensor, label_tensor = self[i]
            # CHW tensor -> HWC array, the layout image-plotting code expects.
            images.append(img_tensor.numpy().transpose(1, 2, 0))
            labels.append(self.encoder.inverse_transform([label_tensor.item()])[0])
        display_img_grid(images, captions=labels)

    def get_data_loader(self, batch_size=16, shuffle=True, **loader_kwargs):
        """Return a ``DataLoader`` over this dataset.

        Args:
            batch_size: Number of items per batch.
            shuffle: Whether to reshuffle every epoch. Pass ``False`` for
                validation/test loaders.
            **loader_kwargs: Extra arguments forwarded to ``DataLoader``
                (e.g. ``num_workers``).
        """
        return DataLoader(self, batch_size=batch_size, shuffle=shuffle, **loader_kwargs)
def play_data_loader(data_loader, n=60, display_fn=None):
    """Debugging aid: display the first ``n`` images yielded by ``data_loader``.

    Args:
        data_loader: Iterable of batches whose first element is a batch of CHW
            image tensors, e.g. ``DataframeDataset.get_data_loader()``.
        n: Maximum number of images to display.
        display_fn: Callable that receives the list of HWC numpy images.
            Defaults to ``display_img_grid``.

    Shows exactly ``min(n, total images)`` images. If the loader is empty or
    ``n <= 0``, nothing is displayed.
    """
    if display_fn is None:
        display_fn = display_img_grid
    images = []
    for batch in data_loader:
        for img_tensor in batch[0]:
            # Stop mid-batch so no more than n images are collected.
            if len(images) >= n:
                break
            # CHW tensor -> HWC array for plotting.
            images.append(img_tensor.numpy().transpose(1, 2, 0))
        if len(images) >= n:
            break
    # Display whatever was collected, even when the loader held fewer than n images.
    if images:
        display_fn(images)
if __name__ == "__main__":
    # Manual smoke test: build the splits, print them, then preview augmented
    # training images both directly and through a DataLoader.
    splits = get_awa_dataframes("./Animals_with_Attributes2")
    train_df, valid_df, test_df = splits
    print(f"train_df:\n{train_df}")
    print(f"valid_df:\n{valid_df}")
    print(f"test_df:\n{test_df}")

    augmented_train = DataframeDataset(train_df, augment=True)
    augmented_train.show_samples()

    loader = augmented_train.get_data_loader()
    play_data_loader(loader)