+ from functools import partial
+ from multiprocessing import Pool
+ from os import cpu_count
+ from pathlib import Path
+
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import torch
+ import torchvision.transforms.functional as F
from numpy.core.fromnumeric import product
+ from skimage.segmentation import find_boundaries, mark_boundaries, slic
from skimage.segmentation.boundaries import find_boundaries
- import torch
- import numpy as np
from torchvision.io import read_image
from torchvision.models.segmentation import fcn_resnet50
- import matplotlib.pyplot as plt
from torchvision.transforms.functional import convert_image_dtype
- from torchvision.utils import draw_segmentation_masks
- from torchvision.utils import make_grid
+ from torchvision.utils import draw_segmentation_masks, make_grid
+
from pytorch_superpixels.runtime import superpixelise
- from skimage.segmentation import slic, mark_boundaries, find_boundaries
- from pathlib import Path
- from multiprocessing import Pool
- from os import cpu_count
- from functools import partial

- import torchvision.transforms.functional as F

def show(imgs):
    if not isinstance(imgs, list):
@@ -32,9 +33,27 @@ def show(imgs):

if __name__ == "__main__":
    sem_classes = [
-         '__background__', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
-         'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
-         'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'
+         "__background__",
+         "aeroplane",
+         "bicycle",
+         "bird",
+         "boat",
+         "bottle",
+         "bus",
+         "car",
+         "cat",
+         "chair",
+         "cow",
+         "diningtable",
+         "dog",
+         "horse",
+         "motorbike",
+         "person",
+         "pottedplant",
+         "sheep",
+         "sofa",
+         "train",
+         "tvmonitor",
    ]
    sem_class_to_idx = {cls: idx for (idx, cls) in enumerate(sem_classes)}
    image_dims = [420, 640]
@@ -46,24 +65,25 @@ def show(imgs):
    batch = convert_image_dtype(batch_int, dtype=torch.float)

    # permute because slic expects the last dimension to be channel
-     with Pool(processes=cpu_count()-1) as pool:
+     with Pool(processes=cpu_count() - 1) as pool:
        # re-order axes for skimage
-         args = [x.permute(1,2, 0) for x in batch]
+         args = [x.permute(1, 2, 0) for x in batch]
        # 100 segments
-         kwargs = {"n_segments":100, "start_label":0, "slic_zero":True}
+         kwargs = {"n_segments": 100, "start_label": 0, "slic_zero": True}
        func = partial(slic, **kwargs)
        masks_100sp = pool.map(func, args)
        # 1000 segments
        kwargs["n_segments"] = 1000
        func = partial(slic, **kwargs)
        masks_1000sp = pool.map(func, args)

-
    model = fcn_resnet50(pretrained=True, progress=False)
    model = model.eval()

-     normalized_batch = F.normalize(batch, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
-     outputs = model(batch)['out']
+     normalized_batch = F.normalize(
+         batch, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
+     )
+     outputs = model(batch)["out"]

    normalized_masks = torch.nn.functional.softmax(outputs, dim=1)
    num_classes = normalized_masks.shape[1]
@@ -73,23 +93,33 @@ def generate_all_class_masks(outputs, masks):
        masks = torch.from_numpy(masks)
        outputs_sp = superpixelise(outputs, masks)
        normalized_masks_sp = torch.nn.functional.softmax(outputs_sp, dim=1)
-         return normalized_masks_sp[i].argmax(0) == torch.arange(num_classes)[:, None, None]
+         return (
+             normalized_masks_sp[i].argmax(0) == torch.arange(num_classes)[:, None, None]
+         )

    to_show = []
    for i, image in enumerate(images):
        # before
-         all_classes_masks = normalized_masks[i].argmax(0) == torch.arange(num_classes)[:, None, None]
-         to_show.append(draw_segmentation_masks(image, masks=all_classes_masks, alpha=.6))
+         all_classes_masks = (
+             normalized_masks[i].argmax(0) == torch.arange(num_classes)[:, None, None]
+         )
+         to_show.append(
+             draw_segmentation_masks(image, masks=all_classes_masks, alpha=0.6)
+         )
        # after 100
        all_classes_masks_sp = generate_all_class_masks(outputs, masks_100sp)
-         to_show.append(draw_segmentation_masks(image, masks=all_classes_masks_sp, alpha=.6))
+         to_show.append(
+             draw_segmentation_masks(image, masks=all_classes_masks_sp, alpha=0.6)
+         )
        # show superpixel boundaries
        boundaries = find_boundaries(masks_100sp[i])
        to_show[-1][0:2, boundaries] = 255
        to_show[-1][2, boundaries] = 0
        # after 1000
        all_classes_masks_sp = generate_all_class_masks(outputs, masks_1000sp)
-         to_show.append(draw_segmentation_masks(image, masks=all_classes_masks_sp, alpha=.6))
+         to_show.append(
+             draw_segmentation_masks(image, masks=all_classes_masks_sp, alpha=0.6)
+         )
        # show superpixel boundaries
        boundaries = find_boundaries(masks_1000sp[i])
        to_show[-1][0:2, boundaries] = 255
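The return that black wraps inside `generate_all_class_masks`, and the matching `all_classes_masks` assignment in the loop, both use the same broadcasting comparison, which is easy to misread once it is split across lines. Below is a minimal standalone sketch of what it computes; the shapes and dummy tensor are illustrative only and are not taken from the commit.

# Illustration only (not part of the commit): the per-class boolean masks
# passed to draw_segmentation_masks, built via broadcasting.
import torch

num_classes = 21                            # len(sem_classes) in the script
H, W = 4, 5                                 # dummy spatial size
logits = torch.randn(num_classes, H, W)     # one image's raw output, shape (C, H, W)

pred = torch.nn.functional.softmax(logits, dim=0).argmax(0)   # (H, W) class index per pixel
# torch.arange(num_classes)[:, None, None] has shape (C, 1, 1); comparing it
# against the (H, W) prediction broadcasts to a (C, H, W) boolean tensor with
# one mask per class, which is the layout draw_segmentation_masks expects.
all_classes_masks = pred == torch.arange(num_classes)[:, None, None]
assert all_classes_masks.shape == (num_classes, H, W)

Since argmax is unchanged by softmax, the normalisation step only matters where the probabilities themselves are inspected; the hard per-class masks come out the same either way.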