-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path image_util.py
More file actions
111 lines (101 loc) · 3.88 KB
/
image_util.py
File metadata and controls
111 lines (101 loc) · 3.88 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from datetime import datetime
import math
import time
import random
import numpy as np
from PIL import Image
import os
import random
import matplotlib.pyplot as plt
import pickle as pk
from glob import glob
# img = Image.open('../train_images/landscape.jpg')
# npArray = np.array(img)
# plt.figure('oriImage')
# plt.imshow(npArray)
# plt.show()
class imageUtil(object):
    """Cycle through the .bmp images under a dataset directory as float batches.

    The recursive list of ``*.bmp`` files is shuffled and split 90/10 into
    train/test once, then pickled to ``<basename(dir)>-split.pkl`` in the
    current working directory so later runs reuse the same split.

    Each image is served twice in a row: first as loaded, then mirrored along
    axis 1 (``np.flip(..., axis=1)``, i.e. left-right for PIL-shaped arrays).
    Pixel values are divided by 255 and stored as float32.
    """

    def __init__(self, dir=os.path.join('..', 'data', 'Raise_6k')):
        """Build (or load the cached) train/test split for the images in *dir*.

        Args:
            dir: Root directory searched recursively for ``*.bmp`` files.

        Raises:
            ValueError: If no cached split exists and *dir* contains no
                ``*.bmp`` files. Nothing is cached in that case, so correcting
                the path works on the next run.
        """
        self.dataDir = dir
        # normpath drops a trailing separator; without it basename('a/b/')
        # is '' and every dataset would share the cache file '-split.pkl'.
        dataSplitFile = os.path.basename(os.path.normpath(dir)) + '-split.pkl'
        if not os.path.exists(dataSplitFile):
            fns = []
            for root, _, _ in os.walk(dir):
                fns += glob(os.path.join(root, '*.bmp'))
            if not fns:
                # Caching an empty split would make this failure permanent.
                raise ValueError('no .bmp images found under %r' % (dir,))
            random.shuffle(fns)
            split = len(fns) * 9 // 10
            self.trainImgFileNames = fns[:split]
            self.testImgFileNames = fns[split:]
            with open(dataSplitFile, 'wb') as f:
                pk.dump([self.trainImgFileNames, self.testImgFileNames], f)
        else:
            # NOTE: pickle.load can execute arbitrary code; only reuse split
            # files written by this class.
            with open(dataSplitFile, 'rb') as f:
                self.trainImgFileNames, self.testImgFileNames = pk.load(f)
        self.trainImgMinId = 0
        self.trainImgMaxId = len(self.trainImgFileNames) - 1
        self.testImgMinId = 0
        self.testImgMaxId = len(self.testImgFileNames) - 1
        # Start on the last index with the "already mirrored" flag set, so the
        # first update wraps to index 0 and loads that file un-mirrored.
        self.trainImgId = self.trainImgMaxId
        self.testImgId = self.testImgMaxId
        self.trainFlipped = True
        self.testFlipped = True

    @staticmethod
    def _nextIndex(current, minId, maxId):
        """Return the index after *current*, wrapping within [minId, maxId].

        Raises:
            ValueError: If the range is empty (no files to iterate over).
        """
        count = maxId - minId + 1
        if count <= 0:
            raise ValueError('no image files to iterate over')
        # Offset by minId before the modulo so the wrap is correct for any
        # minId (identical to the previous formula when minId == 0).
        return (current + 1 - minId) % count + minId

    @staticmethod
    def _loadFloatImage(path):
        """Load the image at *path* as float32 divided by 255, closing the file."""
        with Image.open(path) as img:
            return np.array(img, dtype='float32') / 255.0

    def updateTrainImg(self):
        """Advance the training stream by one step.

        If the current image has already been served mirrored, load the next
        file un-mirrored; otherwise mirror the current image along axis 1.
        """
        if self.trainFlipped:
            self.trainImgId = self._nextIndex(
                self.trainImgId, self.trainImgMinId, self.trainImgMaxId)
            self.train_image_float_array = self._loadFloatImage(
                self.trainImgFileNames[self.trainImgId])
            self.trainFlipped = False
        else:
            self.train_image_float_array = np.flip(
                self.train_image_float_array, axis=1)
            self.trainFlipped = True

    def updateTestImg(self):
        """Advance the test stream by one step (same scheme as updateTrainImg)."""
        if self.testFlipped:
            self.testImgId = self._nextIndex(
                self.testImgId, self.testImgMinId, self.testImgMaxId)
            self.test_image_float_array = self._loadFloatImage(
                self.testImgFileNames[self.testImgId])
            self.testFlipped = False
        else:
            self.test_image_float_array = np.flip(
                self.test_image_float_array, axis=1)
            self.testFlipped = True

    @staticmethod
    def _nextColorImage(update, currentArray, numFiles, splitName):
        """Call *update* until *currentArray()* is 3-D; return it with a batch axis.

        2-D images (e.g. single-channel ones) are skipped. Each file takes two
        updates (as loaded, then mirrored), so ``2 * numFiles + 1`` updates try
        every file at least once. If none of them is 3-D, raise instead of
        looping forever.

        Raises:
            RuntimeError: If no 3-D image is found within one full cycle.
        """
        for _ in range(2 * max(numFiles, 0) + 1):
            update()
            arr = currentArray()
            if arr.ndim == 3:
                return arr[np.newaxis, ...]
        raise RuntimeError('no 3-D (multi-channel) image among the %d %s files'
                           % (numFiles, splitName))

    def getTrainImage(self):
        """Return the next 3-D training image as an array of shape (1, H, W, C)."""
        return self._nextColorImage(
            self.updateTrainImg,
            lambda: self.train_image_float_array,
            self.trainImgMaxId - self.trainImgMinId + 1,
            'training')

    def getTestImage(self):
        """Return the next 3-D test image as an array of shape (1, H, W, C)."""
        return self._nextColorImage(
            self.updateTestImg,
            lambda: self.test_image_float_array,
            self.testImgMaxId - self.testImgMinId + 1,
            'test')

    @staticmethod
    def _stackBatch(getImage, batch_size):
        """Concatenate *batch_size* results of *getImage* along axis 0.

        Assumes every image has the same (H, W, C); np.concatenate raises
        ValueError otherwise.

        Raises:
            ValueError: If batch_size < 1.
        """
        if batch_size < 1:
            raise ValueError('batch_size must be at least 1, got %r' % (batch_size,))
        # Collect first and concatenate once; concatenating inside the loop
        # re-copies the growing batch on every iteration.
        return np.concatenate([getImage() for _ in range(batch_size)], axis=0)

    def generateTrainImageBatch(self, batch_size, is_train=True):
        """Return *batch_size* training images stacked to (batch_size, H, W, C).

        ``is_train`` is unused and kept only for backward compatibility.
        """
        return self._stackBatch(self.getTrainImage, batch_size)

    def generateTestImageBatch(self, batch_size, is_train=False):
        """Return *batch_size* test images stacked to (batch_size, H, W, C).

        ``is_train`` is unused and kept only for backward compatibility.
        """
        return self._stackBatch(self.getTestImage, batch_size)
if __name__ == '__main__':
    # Manual smoke run: pull 1000 two-image training batches from the default
    # dataset location, printing the step number before each batch.
    data_dir = os.path.join('..', 'data', 'Raise_6k')
    generator = imageUtil(data_dir)
    for step in range(1000):
        print(step)
        generator.generateTrainImageBatch(2)
        # Wait for Enter before fetching the next batch.
        input('continue')