Skip to content

Commit f0832e4

Browse files
authored
Merge pull request #76 from juglab/issue-73
Fix for bug # 73
2 parents 04f9760 + bec93d8 commit f0832e4

File tree

6 files changed

+65
-19
lines changed

6 files changed

+65
-19
lines changed

n2v/internals/N2V_DataGenerator.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -183,8 +183,8 @@ def __extract_patches__(self, data, num_patches=None, shape=(256, 256), n_dims=2
183183
patches = []
184184
if n_dims == 2:
185185
if data.shape[1] > shape[0] and data.shape[2] > shape[1]:
186-
for y in range(0, data.shape[1] - shape[0], shape[0]):
187-
for x in range(0, data.shape[2] - shape[1], shape[1]):
186+
for y in range(0, data.shape[1] - shape[0] + 1, shape[0]):
187+
for x in range(0, data.shape[2] - shape[1] + 1, shape[1]):
188188
patches.append(data[:, y:y + shape[0], x:x + shape[1]])
189189

190190
return np.concatenate(patches)
@@ -194,14 +194,13 @@ def __extract_patches__(self, data, num_patches=None, shape=(256, 256), n_dims=2
194194
print("'shape' is too big.")
195195
elif n_dims == 3:
196196
if data.shape[1] > shape[0] and data.shape[2] > shape[1] and data.shape[3] > shape[2]:
197-
for z in range(0, data.shape[1] - shape[0], shape[0]):
198-
for y in range(0, data.shape[2] - shape[1], shape[1]):
199-
for x in range(0, data.shape[3] - shape[2], shape[2]):
197+
for z in range(0, data.shape[1] - shape[0] + 1, shape[0]):
198+
for y in range(0, data.shape[2] - shape[1] + 1, shape[1]):
199+
for x in range(0, data.shape[3] - shape[2] + 1, shape[2]):
200200
patches.append(data[:, z:z + shape[0], y:y + shape[1], x:x + shape[2]])
201201

202202
return np.concatenate(patches)
203-
elif data.shape[1] == shape[0] and data.shape[2] == shape[1] and data.shape[3] == shape[
204-
2]:
203+
elif data.shape[1] == shape[0] and data.shape[2] == shape[1] and data.shape[3] == shape[2]:
205204
return data
206205
else:
207206
print("'shape' is too big.")

n2v/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '0.1.10'
1+
__version__ = '0.1.11'

tests/functional/test_training2D_RGB.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,20 @@
77
from n2v.utils.n2v_utils import manipulate_val_data
88
from n2v.internals.N2V_DataGenerator import N2V_DataGenerator
99
from matplotlib import pyplot as plt
10-
import urllib
10+
import urllib.request
1111
import os
1212
import zipfile
1313

1414
# create a folder for our data
1515
if not os.path.isdir('./data'):
1616
os.mkdir('data')
17-
# check if data has been downloaded already
18-
zipPath="data/RGB.zip"
19-
if not os.path.exists(zipPath):
20-
# download and unzip data
21-
data = urllib.request.urlretrieve('https://cloud.mpi-cbg.de/index.php/s/Frru2hsjjAljpfW/download', zipPath)
22-
with zipfile.ZipFile(zipPath, 'r') as zip_ref:
23-
zip_ref.extractall("data")
17+
# check if data has been downloaded already
18+
zipPath = "data/RGB.zip"
19+
if not os.path.exists(zipPath):
20+
# download and unzip data
21+
data = urllib.request.urlretrieve('https://cloud.mpi-cbg.de/index.php/s/Frru2hsjjAljpfW/download', zipPath)
22+
with zipfile.ZipFile(zipPath, 'r') as zip_ref:
23+
zip_ref.extractall("data")
2424

2525
# For training, we will load __one__ low-SNR RGB image and use the <code>N2V_DataGenerator</code> to extract non-overlapping patches
2626
datagen = N2V_DataGenerator()
@@ -29,7 +29,7 @@
2929
# The function will return a list of images (numpy arrays).
3030
# In the 'dims' parameter we specify the order of dimensions in the image files we are reading:
3131
# 'C' stands for channels (color)
32-
imgs = datagen.load_imgs_from_directory(directory="data/", filter='*.png', dims='YXC')
32+
imgs = datagen.load_imgs_from_directory(directory="./data", filter='*.png', dims='YXC')
3333

3434
print('shape of loaded images: ',imgs[0].shape)
3535
# Remove alpha channel

tests/functional/test_training2D_SEM.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# create a folder for our data
1414
if not os.path.isdir('./data'):
1515
os.mkdir('./data')
16-
zipPath="data/SEM.zip"
16+
zipPath = "data/SEM.zip"
1717
if not os.path.exists(zipPath):
1818
# download and unzip data
1919
data = urllib.request.urlretrieve('https://cloud.mpi-cbg.de/index.php/s/pXgfbobntrw06lC/download', zipPath)
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
from n2v.internals.N2V_DataGenerator import N2V_DataGenerator
2+
import urllib.request
3+
import os
4+
import zipfile
5+
6+
7+
def test_generate_patches_2D():
8+
9+
if not os.path.isdir('data'):
10+
os.mkdir('data')
11+
zip_path = "data/RGB.zip"
12+
if not os.path.exists(zip_path):
13+
data = urllib.request.urlretrieve('https://cloud.mpi-cbg.de/index.php/s/Frru2hsjjAljpfW/download', zip_path)
14+
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
15+
zip_ref.extractall('data')
16+
17+
datagen = N2V_DataGenerator()
18+
19+
imgs = datagen.load_imgs_from_directory(directory="data", filter='*.png', dims='YXC')
20+
imgs[0] = imgs[0][..., :3]
21+
patches = datagen.generate_patches_from_list(imgs, shape=(1100, 2800))
22+
assert len(patches) == 1
23+
patches = datagen.generate_patches_from_list(imgs, shape=(550, 1400))
24+
assert len(patches) == 4
25+
patches = datagen.generate_patches_from_list(imgs, shape=(110, 280))
26+
assert len(patches) == 100
27+
28+
def test_generate_patches_3D():
29+
30+
if not os.path.isdir('data'):
31+
os.mkdir('data')
32+
zip_path = 'data/flywing-data.zip'
33+
if not os.path.exists(zip_path):
34+
# download and unzip data
35+
data = urllib.request.urlretrieve('https://cloud.mpi-cbg.de/index.php/s/RKStdwKo4FlFrxE/download', zip_path)
36+
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
37+
zip_ref.extractall('data')
38+
39+
datagen = N2V_DataGenerator()
40+
41+
imgs = datagen.load_imgs_from_directory(directory="data", filter='*.tif', dims='ZYX')
42+
print(imgs[0].shape)
43+
patches = datagen.generate_patches_from_list(imgs[:1], shape=(35, 520, 692))
44+
assert len(patches) == 1
45+
patches = datagen.generate_patches_from_list(imgs[:1], shape=(5, 52, 174))
46+
assert len(patches) == 210
47+

tests/test_Noise2VoidDataWrapper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from n2v.internals.N2V_DataWrapper import N2V_DataWrapper
1+
from n2v.internals.N2V_DataWrapper import N2V_DataWrapper
22

33
import numpy as np
44

0 commit comments

Comments
 (0)