Skip to content
This repository was archived by the owner on Jan 3, 2024. It is now read-only.

Commit a126fe8

Browse files
authored
Merge pull request #7 from HMellor/add-pre-commit
Add pre-commit and do some formatting
2 parents 0a9e80d + d5d3968 commit a126fe8

File tree

11 files changed

+192
-70
lines changed

11 files changed

+192
-70
lines changed

.gitignore

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -90,18 +90,8 @@ ENV/
9090
env.bak/
9191
venv.bak/
9292

93-
# Spyder project settings
94-
.spyderproject
95-
.spyproject
96-
97-
# Rope project settings
98-
.ropeproject
99-
100-
# mkdocs documentation
101-
/site
102-
103-
# mypy
104-
.mypy_cache/
93+
# VS Code workspace settings
94+
.vscode
10595

10696
# Dataset directories
10797
VOCdevkit

.pre-commit-config.yaml

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
repos:
2+
- repo: https://github.com/pre-commit/pre-commit-hooks
3+
rev: v4.4.0
4+
hooks:
5+
- id: check-yaml
6+
- id: end-of-file-fixer
7+
- id: trailing-whitespace
8+
- repo: https://github.com/psf/black
9+
rev: 23.9.1
10+
hooks:
11+
- id: black
12+
- repo: https://github.com/PyCQA/isort
13+
rev: 5.12.0
14+
hooks:
15+
- id: isort

README.md

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,26 @@
11
# PyTorch Superpixels
2+
23
- [Why use superpixels?](#why-use-superpixels)
34
- [Example usage](#example-usage)
5+
46
## Why use superpixels?
7+
58
Dimensionality reduction allows for the use of simpler networks or more complex objectives. A common way of doing this is to simply downsample the images so that there are fewer pixels to contend with. However, this is a lossy operation so detail (and therefore the upper bound on experimental results) is reduced.
69

710
Superpixels slightly alleviate this problem because they are able to encode information about edges within themselves. Generating superpixels is an unsupervised clustering operation. Whilst there are already clustering packages written for Python (some of which this project depends on), they all operate with NumPy arrays. This means that they cannot take advantage of GPU acceleration in the way that PyTorch tensors can.
811

912
The aim of this project is to bridge the gap between these existing packages and PyTorch so that superpixels can be readily used as an alternative to pixels in various machine learning experiments.
13+
1014
## Example usage
15+
1116
Here is some example code that uses superpixels for semantic segmentation.
17+
1218
```
1319
# Generate list of filenames from your dataset
14-
imageList = pytorch_superpixels.list_loader.image_list(
20+
image_list = pytorch_superpixels.list_loader.ImageList(
1521
'pascal-seg', './VOCdevkit/VOC2012', 'trainval')
1622
# Use this list to create and save 100 superpixel dataset
17-
pytorch_superpixels.preprocess.create_masks(imageList, 100)
23+
pytorch_superpixels.preprocess.create_masks(image_list, 100)
1824
1925
# -----------------------------------------------
2026
# code that sets up model, optimizer, dataloader, metrics, etc.
@@ -48,7 +54,8 @@ for (images, labels, masks) in trainloader:
4854
loss.backward()
4955
optimizer.step()
5056
```
51-
_______________________________________
57+
58+
---
5259

5360
This project stems from a module I created for use in my master's thesis.
5461

example.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
from numpy.core.fromnumeric import product
2+
from skimage.segmentation.boundaries import find_boundaries
3+
import torch
4+
import numpy as np
5+
from torchvision.io import read_image
6+
from torchvision.models.segmentation import fcn_resnet50
7+
import matplotlib.pyplot as plt
8+
from torchvision.transforms.functional import convert_image_dtype
9+
from torchvision.utils import draw_segmentation_masks
10+
from torchvision.utils import make_grid
11+
from pytorch_superpixels.runtime import superpixelise
12+
from skimage.segmentation import slic, mark_boundaries, find_boundaries
13+
from pathlib import Path
14+
from multiprocessing import Pool
15+
from os import cpu_count
16+
from functools import partial
17+
18+
import torchvision.transforms.functional as F
19+
20+
def show(imgs):
21+
if not isinstance(imgs, list):
22+
imgs = [imgs]
23+
fix, axs = plt.subplots(ncols=len(imgs), squeeze=False)
24+
for i, img in enumerate(imgs):
25+
img = img.detach()
26+
img = F.to_pil_image(img)
27+
axs[0, i].imshow(np.asarray(img))
28+
axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
29+
plt.tight_layout()
30+
plt.show()
31+
32+
33+
if __name__ == "__main__":
34+
sem_classes = [
35+
'__background__', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
36+
'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
37+
'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'
38+
]
39+
sem_class_to_idx = {cls: idx for (idx, cls) in enumerate(sem_classes)}
40+
image_dims = [420, 640]
41+
images = [read_image(str(img)) for img in Path("data").glob("*.jpg")]
42+
images = [F.center_crop(image, image_dims) for image in images]
43+
image_size = product(image_dims)
44+
45+
batch_int = torch.stack(images)
46+
batch = convert_image_dtype(batch_int, dtype=torch.float)
47+
48+
# permute because slic expects the last dimension to be channel
49+
with Pool(processes = cpu_count()-1) as pool:
50+
# re-order axes for skimage
51+
args = [x.permute(1,2,0) for x in batch]
52+
# 100 segments
53+
kwargs = {"n_segments":100, "start_label":0, "slic_zero":True}
54+
func = partial(slic, **kwargs)
55+
masks_100sp = pool.map(func, args)
56+
# 1000 segments
57+
kwargs["n_segments"] = 1000
58+
func = partial(slic, **kwargs)
59+
masks_1000sp = pool.map(func, args)
60+
61+
62+
model = fcn_resnet50(pretrained=True, progress=False)
63+
model = model.eval()
64+
65+
normalized_batch = F.normalize(batch, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
66+
outputs = model(batch)['out']
67+
68+
normalized_masks = torch.nn.functional.softmax(outputs, dim=1)
69+
num_classes = normalized_masks.shape[1]
70+
71+
def generate_all_class_masks(outputs, masks):
72+
masks = np.stack(masks)
73+
masks = torch.from_numpy(masks)
74+
outputs_sp = superpixelise(outputs, masks)
75+
normalized_masks_sp = torch.nn.functional.softmax(outputs_sp, dim=1)
76+
return normalized_masks_sp[i].argmax(0) == torch.arange(num_classes)[:, None, None]
77+
78+
to_show = []
79+
for i, image in enumerate(images):
80+
# before
81+
all_classes_masks = normalized_masks[i].argmax(0) == torch.arange(num_classes)[:, None, None]
82+
to_show.append(draw_segmentation_masks(image, masks=all_classes_masks, alpha=.6))
83+
# after 100
84+
all_classes_masks_sp = generate_all_class_masks(outputs, masks_100sp)
85+
to_show.append(draw_segmentation_masks(image, masks=all_classes_masks_sp, alpha=.6))
86+
# show superpixel boundaries
87+
boundaries = find_boundaries(masks_100sp[i])
88+
to_show[-1][0:2, boundaries] = 255
89+
to_show[-1][2, boundaries] = 0
90+
# after 1000
91+
all_classes_masks_sp = generate_all_class_masks(outputs, masks_1000sp)
92+
to_show.append(draw_segmentation_masks(image, masks=all_classes_masks_sp, alpha=.6))
93+
# show superpixel boundaries
94+
boundaries = find_boundaries(masks_1000sp[i])
95+
to_show[-1][0:2, boundaries] = 255
96+
to_show[-1][2, boundaries] = 0
97+
show(make_grid(to_show, nrow=6))

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
[tool.isort]
2+
profile = "black"

pytorch_superpixels/list_loader.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,23 @@
1-
from os.path import join
2-
from os.path import exists
1+
from os.path import exists, join
32

43

5-
class image_list:
6-
def __init__(self, dataset, path, split='trainval'):
4+
class ImageList:
5+
def __init__(self, dataset, path, split="trainval"):
76
# Configured datasets
8-
datasets = {'pascal-seg': {'listPath': 'ImageSets/Segmentation/',
9-
'imagePath': 'JPEGImages',
10-
'targetPath': 'SegmentationClass'}
11-
}
7+
datasets = {
8+
"pascal-seg": {
9+
"listPath": "ImageSets/Segmentation/",
10+
"imagePath": "JPEGImages",
11+
"targetPath": "SegmentationClass",
12+
}
13+
}
1214
# Object variables
1315
self.split = split
1416
self.dataset = dataset
1517
self.path = path
16-
self.listPath = join(path, datasets[dataset]['listPath'])
17-
self.imagePath = join(path, datasets[dataset]['imagePath'])
18-
self.targetPath = join(path, datasets[dataset]['targetPath'])
18+
self.listPath = join(path, datasets[dataset]["listPath"])
19+
self.imagePath = join(path, datasets[dataset]["imagePath"])
20+
self.targetPath = join(path, datasets[dataset]["targetPath"])
1921
self.list = []
2022
# Does the split exist?
2123
list_path = join(self.listPath, self.split + ".txt")

pytorch_superpixels/metrics.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
from os.path import join
2+
13
import torch
24
from tqdm import tqdm
3-
from os.path import join
45

56

67
def mask_accuracy(target, mask):
@@ -18,7 +19,7 @@ def mask_accuracy(target, mask):
1819
def dataset_accuracy(superpixels):
1920
# Generate image list
2021
if superpixels is not None:
21-
image_list = get_image_list('trainval_super')
22+
image_list = get_image_list("trainval_super")
2223
else:
2324
image_list = get_image_list()
2425

@@ -56,10 +57,7 @@ def find_usable_images(split, superpixels):
5657
# Generate image list
5758
image_list = get_image_list(split)
5859
usable = []
59-
target_dir = join(
60-
root,
61-
"SegmentationClass/pre_encoded_{}_sp".format(superpixels)
62-
)
60+
target_dir = join(root, "SegmentationClass/pre_encoded_{}_sp".format(superpixels))
6361
for image_number in image_list:
6462
target_name = image_number + ".pt"
6563
target_path = join(target_dir, target_name)
@@ -83,7 +81,7 @@ def fix_broken_images(superpixels):
8381
def find_size_variance(superpixels):
8482
# Generate image list
8583
if superpixels is not None:
86-
image_list = get_image_list('trainval_super')
84+
image_list = get_image_list("trainval_super")
8785
else:
8886
image_list = get_image_list()
8987
mask_dir = "SegmentationClass/{}_sp".format(superpixels)

pytorch_superpixels/preprocess.py

Lines changed: 30 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,63 +1,64 @@
1+
from multiprocessing import cpu_count
2+
from os import mkdir
3+
from os.path import exists, join
4+
5+
import torch
6+
from joblib import Parallel, delayed
17
from skimage.io import imread
28
from skimage.segmentation import slic
39
from skimage.util import img_as_float
4-
from multiprocessing import cpu_count
5-
from joblib import Parallel, delayed
6-
from os.path import exists
7-
from os.path import join
810
from tqdm import tqdm
9-
from os import mkdir
10-
import torch
1111

1212

13-
def create_masks(imageList, numSegments=100, limOverseg=None):
13+
def create_masks(image_list, num_segments=100, oversegmentation_limit=None):
1414
# Save mask and target for image number
1515
def save_mask(image_number):
1616
# Load image/target pair
17-
image_path = join(imageList.imagePath, image_number + ".jpg")
18-
target_path = join(imageList.targetPath, image_number + ".png")
17+
image_path = join(image_list.imagePath, image_number + ".jpg")
18+
target_path = join(image_list.targetPath, image_number + ".png")
1919
image = img_as_float(imread(image_path))
2020
target = imread(target_path)
2121
target = torch.from_numpy(target)
2222
# Save paths
23-
saveDir = join(imageList.path, 'SuperPixels')
24-
maskDir = join(saveDir, '{}_sp_mask'.format(numSegments))
25-
targetDir = join(saveDir, '{}_sp_target'.format(numSegments))
23+
save_dir = join(image_list.path, "SuperPixels")
24+
mask_dir = join(save_dir, "{}_sp_mask".format(num_segments))
25+
targetDir = join(save_dir, "{}_sp_target".format(num_segments))
2626
# Check that directories exist
27-
if not exists(saveDir):
28-
mkdir(saveDir)
29-
if not exists(maskDir):
30-
mkdir(maskDir)
27+
if not exists(save_dir):
28+
mkdir(save_dir)
29+
if not exists(mask_dir):
30+
mkdir(mask_dir)
3131
if not exists(targetDir):
3232
mkdir(targetDir)
3333
# Define save paths
34-
mask_save_path = join(maskDir, image_number + ".pt")
34+
mask_save_path = join(mask_dir, image_number + ".pt")
3535
target_save_path = join(targetDir, image_number + ".pt")
3636
# If they haven't already been made, make them
3737
if not exists(mask_save_path) and not exists(target_save_path):
3838
# Create mask for image/target pair
3939
mask, target_s = create_mask(
4040
image=image,
4141
target=target,
42-
numSegments=numSegments,
43-
limOverseg=limOverseg
42+
num_segments=num_segments,
43+
oversegmentation_limit=oversegmentation_limit,
4444
)
4545
torch.save(mask, mask_save_path)
4646
torch.save(target_s, target_save_path)
4747

4848
num_cores = cpu_count()
49-
inputs = tqdm(imageList.list)
49+
inputs = tqdm(image_list.list)
5050
# Iterate through all images utilising all CPU cores
51-
Parallel(n_jobs=num_cores)(delayed(save_mask)(image_number)
52-
for image_number in inputs)
51+
Parallel(n_jobs=num_cores)(
52+
delayed(save_mask)(image_number) for image_number in inputs
53+
)
5354

5455

55-
def create_mask(image, target, numSegments, limOverseg):
56+
def create_mask(image, target, num_segments, oversegmentation_limit):
5657
# Perform SLIC segmentation
57-
mask = slic(image, n_segments=numSegments, slic_zero=True)
58+
mask = slic(image, n_segments=num_segments, slic_zero=True)
5859
mask = torch.from_numpy(mask)
5960

60-
if limOverseg is not None:
61+
if oversegmentation_limit is not None:
6162
# Oversegmentation step
6263
superpixels = mask.unique().numel()
6364
overseg = superpixels
@@ -78,15 +79,16 @@ def create_mask(image, target, numSegments, limOverseg):
7879
# Find minority class in superpixel
7980
min_class = min(class_hist)
8081
# Is the minority class large enough for oversegmentation
81-
above_threshold = min_class > class_hist.sum() * limOverseg
82+
above_threshold = min_class > class_hist.sum() * oversegmentation_limit
8283
if above_threshold:
8384
# Leaving one class in supperpixel be
8485
for c in classes[1:]:
8586
# Adding to the oversegmentation offset
8687
overseg += 1
8788
# Add offset to class c in the mask
88-
mask[segment_mask] += (target[segment_mask]
89-
== c).long() * overseg
89+
mask[segment_mask] += (
90+
target[segment_mask] == c
91+
).long() * overseg
9092

9193
# (Re)define how many superpixels there are and create target_s
9294
superpixels = mask.unique().numel()

0 commit comments

Comments
 (0)