Skip to content
This repository was archived by the owner on Jan 3, 2024. It is now read-only.

Commit 882825c

Browse files
Update README.md (#9)
* Update README.md * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent a126fe8 commit 882825c

File tree

3 files changed

+60
-26
lines changed

3 files changed

+60
-26
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,4 @@ repos:
1212
- repo: https://github.com/PyCQA/isort
1313
rev: 5.12.0
1414
hooks:
15-
- id: isort
15+
- id: isort

README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
⚠️ The functionality provided by this repository can now be found in [PyTorch Geometric](https://pytorch-geometric.readthedocs.io/en/latest/) in the [to_superpixels()](https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/transforms/to_superpixels.html) function.
2+
3+
---
4+
15
# PyTorch Superpixels
26

37
- [Why use superpixels?](#why-use-superpixels)

example.py

Lines changed: 55 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,22 @@
1+
from functools import partial
2+
from multiprocessing import Pool
3+
from os import cpu_count
4+
from pathlib import Path
5+
6+
import matplotlib.pyplot as plt
7+
import numpy as np
8+
import torch
9+
import torchvision.transforms.functional as F
110
from numpy.core.fromnumeric import product
11+
from skimage.segmentation import find_boundaries, mark_boundaries, slic
212
from skimage.segmentation.boundaries import find_boundaries
3-
import torch
4-
import numpy as np
513
from torchvision.io import read_image
614
from torchvision.models.segmentation import fcn_resnet50
7-
import matplotlib.pyplot as plt
815
from torchvision.transforms.functional import convert_image_dtype
9-
from torchvision.utils import draw_segmentation_masks
10-
from torchvision.utils import make_grid
16+
from torchvision.utils import draw_segmentation_masks, make_grid
17+
1118
from pytorch_superpixels.runtime import superpixelise
12-
from skimage.segmentation import slic, mark_boundaries, find_boundaries
13-
from pathlib import Path
14-
from multiprocessing import Pool
15-
from os import cpu_count
16-
from functools import partial
1719

18-
import torchvision.transforms.functional as F
1920

2021
def show(imgs):
2122
if not isinstance(imgs, list):
@@ -32,9 +33,27 @@ def show(imgs):
3233

3334
if __name__ == "__main__":
3435
sem_classes = [
35-
'__background__', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
36-
'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
37-
'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'
36+
"__background__",
37+
"aeroplane",
38+
"bicycle",
39+
"bird",
40+
"boat",
41+
"bottle",
42+
"bus",
43+
"car",
44+
"cat",
45+
"chair",
46+
"cow",
47+
"diningtable",
48+
"dog",
49+
"horse",
50+
"motorbike",
51+
"person",
52+
"pottedplant",
53+
"sheep",
54+
"sofa",
55+
"train",
56+
"tvmonitor",
3857
]
3958
sem_class_to_idx = {cls: idx for (idx, cls) in enumerate(sem_classes)}
4059
image_dims = [420, 640]
@@ -46,24 +65,25 @@ def show(imgs):
4665
batch = convert_image_dtype(batch_int, dtype=torch.float)
4766

4867
# permute because slic expects the last dimension to be channel
49-
with Pool(processes = cpu_count()-1) as pool:
68+
with Pool(processes=cpu_count() - 1) as pool:
5069
# re-order axes for skimage
51-
args = [x.permute(1,2,0) for x in batch]
70+
args = [x.permute(1, 2, 0) for x in batch]
5271
# 100 segments
53-
kwargs = {"n_segments":100, "start_label":0, "slic_zero":True}
72+
kwargs = {"n_segments": 100, "start_label": 0, "slic_zero": True}
5473
func = partial(slic, **kwargs)
5574
masks_100sp = pool.map(func, args)
5675
# 1000 segments
5776
kwargs["n_segments"] = 1000
5877
func = partial(slic, **kwargs)
5978
masks_1000sp = pool.map(func, args)
6079

61-
6280
model = fcn_resnet50(pretrained=True, progress=False)
6381
model = model.eval()
6482

65-
normalized_batch = F.normalize(batch, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
66-
outputs = model(batch)['out']
83+
normalized_batch = F.normalize(
84+
batch, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
85+
)
86+
outputs = model(batch)["out"]
6787

6888
normalized_masks = torch.nn.functional.softmax(outputs, dim=1)
6989
num_classes = normalized_masks.shape[1]
@@ -73,23 +93,33 @@ def generate_all_class_masks(outputs, masks):
7393
masks = torch.from_numpy(masks)
7494
outputs_sp = superpixelise(outputs, masks)
7595
normalized_masks_sp = torch.nn.functional.softmax(outputs_sp, dim=1)
76-
return normalized_masks_sp[i].argmax(0) == torch.arange(num_classes)[:, None, None]
96+
return (
97+
normalized_masks_sp[i].argmax(0) == torch.arange(num_classes)[:, None, None]
98+
)
7799

78100
to_show = []
79101
for i, image in enumerate(images):
80102
# before
81-
all_classes_masks = normalized_masks[i].argmax(0) == torch.arange(num_classes)[:, None, None]
82-
to_show.append(draw_segmentation_masks(image, masks=all_classes_masks, alpha=.6))
103+
all_classes_masks = (
104+
normalized_masks[i].argmax(0) == torch.arange(num_classes)[:, None, None]
105+
)
106+
to_show.append(
107+
draw_segmentation_masks(image, masks=all_classes_masks, alpha=0.6)
108+
)
83109
# after 100
84110
all_classes_masks_sp = generate_all_class_masks(outputs, masks_100sp)
85-
to_show.append(draw_segmentation_masks(image, masks=all_classes_masks_sp, alpha=.6))
111+
to_show.append(
112+
draw_segmentation_masks(image, masks=all_classes_masks_sp, alpha=0.6)
113+
)
86114
# show superpixel boundaries
87115
boundaries = find_boundaries(masks_100sp[i])
88116
to_show[-1][0:2, boundaries] = 255
89117
to_show[-1][2, boundaries] = 0
90118
# after 1000
91119
all_classes_masks_sp = generate_all_class_masks(outputs, masks_1000sp)
92-
to_show.append(draw_segmentation_masks(image, masks=all_classes_masks_sp, alpha=.6))
120+
to_show.append(
121+
draw_segmentation_masks(image, masks=all_classes_masks_sp, alpha=0.6)
122+
)
93123
# show superpixel boundaries
94124
boundaries = find_boundaries(masks_1000sp[i])
95125
to_show[-1][0:2, boundaries] = 255

0 commit comments

Comments
 (0)