Commit 6b79c83

New features and refactoring (#100)
* Add aux classification head
* Add ability to change input channels
* Add encoder_depth parameter
* Add mobilenet encoder
* Add hall of fame
1 parent 378431b commit 6b79c83

41 files changed (+2443, -1541 lines)

HALLOFFAME.md

Lines changed: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
# Hall of Fame

The `Segmentation Models` package is widely used in image segmentation competitions.
Here you can find competitions, the names of the winners, and links to their solutions.

Please follow these rules when adding a solution to the "Hall of Fame":

1. The solution should be highly rated (e.g. a Kaggle gold or silver medal).
2. There should be a description of the solution (a forum post / code / blog post / paper / pre-print).

## Kaggle

### [Severstal: Steel Defect Detection](https://www.kaggle.com/c/severstal-steel-defect-detection)

- 1st place.
  [Wuxi Jiangsu](https://www.kaggle.com/rguo97),
  [Hongbo Zhu](https://www.kaggle.com/zhuhongbo),
  [Yizhuo Yu](https://www.kaggle.com/paffpaffyu)
  [[description](https://www.kaggle.com/c/severstal-steel-defect-detection/discussion/114254#latest-675874)]

- 5th place.
  [Guanshuo Xu](https://www.kaggle.com/wowfattie)
  [[description](https://www.kaggle.com/c/severstal-steel-defect-detection/discussion/117208#latest-675385)]

- 9th place.
  [Jacek Poplawski](https://www.linkedin.com/in/jacekpoplawski/)
  [[description](https://www.kaggle.com/c/severstal-steel-defect-detection/discussion/114297#latest-660842)]

- 10th place.
  [Alexey Rozhkov](https://www.linkedin.com/in/alexisrozhkov)
  [[description](https://www.kaggle.com/c/severstal-steel-defect-detection/discussion/114465#latest-659615)]

- 12th place.
  [Pavel Yakubovskiy](https://www.linkedin.com/in/pavel-yakubovskiy/),
  [Ilya Dobrynin](https://www.linkedin.com/in/ilya-dobrynin-79a89b106/),
  [Denis Kolpakov](https://www.linkedin.com/in/denis-kolpakov-ab3137197/)
  [[description](https://www.kaggle.com/c/severstal-steel-defect-detection/discussion/114309#latest-661404)]

- 31st place.
  [Insaf Ashrapov](https://www.linkedin.com/in/iashrapov/),
  [Igor Krashenyi](https://www.linkedin.com/in/igor-krashenyi-38b89b98),
  [Pavel Pleskov](https://www.linkedin.com/in/ppleskov),
  [Anton Zakharenkov](https://www.linkedin.com/in/anton-zakharenkov/),
  [Nikolai Popov](https://www.linkedin.com/in/nikolai-popov-b2157370/)
  [[description](https://www.kaggle.com/c/severstal-steel-defect-detection/discussion/114383#latest-658438)]
  [[code](https://github.com/Diyago/Severstal-Steel-Defect-Detection)]

- 55th place.
  [Karl Hornlund](https://www.linkedin.com/in/karl-hornlund/)
  [[description](https://www.kaggle.com/c/severstal-steel-defect-detection/discussion/114410#latest-672682)]
  [[code](https://github.com/khornlund/severstal-steel-defect-detection)]

- Efficiency round 1st place.
  [Stefan Stefanov](https://www.linkedin.com/in/stefan-stefanov-63a77b1)
  [[description](https://www.kaggle.com/c/severstal-steel-defect-detection/discussion/117486#latest-674229)]


### [Understanding Clouds from Satellite Images](https://www.kaggle.com/c/understanding_cloud_organization)

- 2nd place.
  [Andrey Kiryasov](https://www.kaggle.com/ekydna)
  [[description](https://www.kaggle.com/c/understanding_cloud_organization/discussion/118255#latest-678189)]

- 4th place.
  [Ching-Loong Seow](https://www.linkedin.com/in/clseow/)
  [[description](https://www.kaggle.com/c/understanding_cloud_organization/discussion/118016#latest-677333)]

- 34th place.
  [Karl Hornlund](https://www.linkedin.com/in/karl-hornlund/)
  [[description](https://www.kaggle.com/c/understanding_cloud_organization/discussion/118250#latest-678176)]
  [[code](https://github.com/khornlund/understanding-cloud-organization)]

- 55th place.
  [Pavel Yakubovskiy](https://www.linkedin.com/in/pavel-yakubovskiy/)
  [[description](https://www.kaggle.com/c/understanding_cloud_organization/discussion/118019#latest-678626)]

README.md

Lines changed: 109 additions & 30 deletions
```diff
@@ -7,7 +7,7 @@ The main features of this library are:
 
 - High level API (just two lines to create neural network)
 - 4 models architectures for binary and multi class segmentation (including legendary Unet)
-- 31 available encoders for each architecture
+- 45 available encoders for each architecture
 - All encoders have pre-trained weights for faster and better convergence
 
 ### Table of content
```
```diff
@@ -16,10 +16,14 @@ The main features of this library are:
 3. [Models](#models)
     1. [Architectures](#architectires)
     2. [Encoders](#encoders)
-    3. [Pretrained weights](#weights)
     4. [Models API](#api)
+        1. [Input channels](#input-channels)
+        2. [Auxiliary classification output](#auxiliary-classification-output)
+        3. [Depth](#depth)
 5. [Installation](#installation)
-6. [License](#license)
+6. [Competitions won with the library](#competitions-won-with-the-library)
+7. [License](#license)
+8. [Contributing](#contributing)
 
 ### Quick start <a name="start"></a>
 Since the library is built on the PyTorch framework, created segmentation model is just a PyTorch nn.Module, which can be created as easy as:
```
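For reference, the two-line model creation that this Quick start context line refers to looks like this (a minimal sketch consistent with the library's documented API):

```python
import segmentation_models_pytorch as smp

model = smp.Unet('resnet34', encoder_weights='imagenet')
```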
```diff
@@ -60,33 +64,95 @@ preprocess_input = get_preprocessing_fn('resnet18', pretrained='imagenet')
 
 #### Encoders <a name="encoders"></a>
 
-| Type | Encoder names |
-|------------|---------------|
-| VGG | vgg11, vgg13, vgg16, vgg19, vgg11bn, vgg13bn, vgg16bn, vgg19bn |
-| DenseNet | densenet121, densenet169, densenet201, densenet161 |
-| DPN | dpn68, dpn68b, dpn92, dpn98, dpn107, dpn131 |
-| Inception | inceptionresnetv2 |
-| ResNet | resnet18, resnet34, resnet50, resnet101, resnet152 |
-| ResNeXt | resnext50_32x4d, resnext101_32x8d, resnext101_32x16d, resnext101_32x32d, resnext101_32x48d |
-| SE-ResNet | se_resnet50, se_resnet101, se_resnet152 |
-| SE-ResNeXt | se_resnext50_32x4d, se_resnext101_32x4d |
-| SENet | senet154 |
-| EfficientNet | efficientnet-b0, efficientnet-b1, efficientnet-b2, efficientnet-b3, efficientnet-b4, efficientnet-b5, efficientnet-b6, efficientnet-b7 |
-
-#### Weights <a name="weights"></a>
-
-| Weights name | Encoder names |
-|--------------|---------------|
-| imagenet+5k | dpn68b, dpn92, dpn107 |
-| imagenet | vgg11, vgg13, vgg16, vgg19, vgg11bn, vgg13bn, vgg16bn, vgg19bn, <br> densenet121, densenet169, densenet201, densenet161, dpn68, dpn98, dpn131, <br> inceptionresnetv2, <br> resnet18, resnet34, resnet50, resnet101, resnet152, <br> resnext50_32x4d, resnext101_32x8d, <br> se_resnet50, se_resnet101, se_resnet152, <br> se_resnext50_32x4d, se_resnext101_32x4d, <br> senet154, <br> efficientnet-b0, efficientnet-b1, efficientnet-b2, efficientnet-b3, efficientnet-b4, efficientnet-b5, efficientnet-b6, efficientnet-b7 |
-| [instagram](https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/) | resnext101_32x8d, resnext101_32x16d, resnext101_32x32d, resnext101_32x48d |
+|Encoder             |Weights                        |Params, M|
+|--------------------|:-----------------------------:|:-------:|
+|resnet18            |imagenet                       |11M      |
+|resnet34            |imagenet                       |21M      |
+|resnet50            |imagenet                       |23M      |
+|resnet101           |imagenet                       |42M      |
+|resnet152           |imagenet                       |58M      |
+|resnext50_32x4d     |imagenet                       |22M      |
+|resnext101_32x8d    |imagenet<br>instagram          |86M      |
+|resnext101_32x16d   |instagram                      |191M     |
+|resnext101_32x32d   |instagram                      |466M     |
+|resnext101_32x48d   |instagram                      |826M     |
+|dpn68               |imagenet                       |11M      |
+|dpn68b              |imagenet+5k                    |11M      |
+|dpn92               |imagenet+5k                    |34M      |
+|dpn98               |imagenet                       |58M      |
+|dpn107              |imagenet+5k                    |84M      |
+|dpn131              |imagenet                       |76M      |
+|vgg11               |imagenet                       |9M       |
+|vgg11_bn            |imagenet                       |9M       |
+|vgg13               |imagenet                       |9M       |
+|vgg13_bn            |imagenet                       |9M       |
+|vgg16               |imagenet                       |14M      |
+|vgg16_bn            |imagenet                       |14M      |
+|vgg19               |imagenet                       |20M      |
+|vgg19_bn            |imagenet                       |20M      |
+|senet154            |imagenet                       |113M     |
+|se_resnet50         |imagenet                       |26M      |
+|se_resnet101        |imagenet                       |47M      |
+|se_resnet152        |imagenet                       |64M      |
+|se_resnext50_32x4d  |imagenet                       |25M      |
+|se_resnext101_32x4d |imagenet                       |46M      |
+|densenet121         |imagenet                       |6M       |
+|densenet169         |imagenet                       |12M      |
+|densenet201         |imagenet                       |18M      |
+|densenet161         |imagenet                       |26M      |
+|inceptionresnetv2   |imagenet<br>imagenet+background|54M      |
+|inceptionv4         |imagenet<br>imagenet+background|41M      |
+|efficientnet-b0     |imagenet                       |4M       |
+|efficientnet-b1     |imagenet                       |6M       |
+|efficientnet-b2     |imagenet                       |7M       |
+|efficientnet-b3     |imagenet                       |10M      |
+|efficientnet-b4     |imagenet                       |17M      |
+|efficientnet-b5     |imagenet                       |28M      |
+|efficientnet-b6     |imagenet                       |40M      |
+|efficientnet-b7     |imagenet                       |63M      |
+|mobilenet_v2        |imagenet                       |2M       |
 
```
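Note that `mobilenet_v2` (bottom row) is the encoder added by this commit. Like any entry in the table, it is selected by name; a minimal sketch:

```python
import segmentation_models_pytorch as smp

# mobilenet_v2 is the new ~2M-parameter encoder from the table above
model = smp.Unet('mobilenet_v2', encoder_weights='imagenet')
```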

```diff
 ### Models API <a name="api"></a>
+
 - `model.encoder` - pretrained backbone to extract features of different spatial resolution
-- `model.decoder` - segmentation head, depends on models architecture (`Unet`/`Linknet`/`PSPNet`/`FPN`)
-- `model.activation` - output activation function, one of `sigmoid`, `softmax`
-- `model.forward(x)` - sequentially pass `x` through model's encoder and decoder (return logits!)
-- `model.predict(x)` - inference method, switch model to `.eval()` mode, call `.forward(x)` and apply activation function with `torch.no_grad()`
+- `model.decoder` - depends on the model's architecture (`Unet`/`Linknet`/`PSPNet`/`FPN`)
+- `model.segmentation_head` - last block, produces the required number of mask channels (also includes optional upsampling and activation)
+- `model.classification_head` - optional block that creates a classification head on top of the encoder
+- `model.forward(x)` - sequentially pass `x` through the model's encoder, decoder and segmentation head (and classification head if specified)
 
```
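To see the refactored structure, you can instantiate a model and inspect the blocks listed above (a sketch; that `classification_head` is `None` without `aux_params` is an assumption based on the bullet describing it as optional):

```python
import segmentation_models_pytorch as smp

model = smp.Unet('resnet34')
print(type(model.encoder).__name__)            # pretrained backbone
print(type(model.decoder).__name__)            # architecture-specific decoder
print(type(model.segmentation_head).__name__)  # produces the mask channels
print(model.classification_head)               # assumed None: no aux_params given
```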
The rest of the hunk adds three new subsections:

##### Input channels
The input channels parameter allows you to create models that process tensors with an arbitrary number of channels. If you use pretrained ImageNet weights, the weights of the first convolution are reused for 1- or 2-channel inputs; for more than 4 input channels, the weights of the first convolution are initialized randomly.
```python
model = smp.FPN('resnet34', in_channels=1)
mask = model(torch.ones([1, 1, 64, 64]))
```
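As a sanity check on that rule, a hedged sketch with more than 4 input channels (the 6-channel input and the expected shape are illustrative; a default single-class model is assumed):

```python
import torch
import segmentation_models_pytorch as smp

# in_channels=6 (> 4), so the first convolution is randomly initialized
model = smp.FPN('resnet34', in_channels=6)
mask = model(torch.ones([1, 6, 64, 64]))
print(mask.shape)  # expected torch.Size([1, 1, 64, 64]) for the default 1-class output
```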

##### Auxiliary classification output
All models support the `aux_params` parameter, which defaults to `None`. If `aux_params = None`, no auxiliary classification output is created; otherwise the model produces not only `mask` but also a `label` output of shape `NC`. The classification head consists of GlobalPooling->Dropout(optional)->Linear->Activation(optional) layers, which can be configured by `aux_params` as follows:
```python
aux_params=dict(
    pooling='avg',             # one of 'avg', 'max'
    dropout=0.5,               # dropout ratio, default is None
    activation='sigmoid',      # activation function, default is None
    classes=4,                 # define number of output labels
)
model = smp.Unet('resnet34', classes=4, aux_params=aux_params)
mask, label = model(x)
```
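A joint training step could then look like the following sketch (the BCE-with-logits losses, the 0.4 weighting, and the dummy tensors are illustrative assumptions, not part of the library API; activations are left unset here so both outputs are assumed to be raw logits):

```python
import torch
import torch.nn as nn
import segmentation_models_pytorch as smp

model = smp.Unet('resnet34', classes=4,
                 aux_params=dict(pooling='avg', classes=4))  # no activations -> raw logits

x = torch.rand(2, 3, 64, 64)                   # dummy image batch
seg_target = torch.rand(2, 4, 64, 64).round()  # dummy per-pixel targets
cls_target = torch.rand(2, 4).round()          # dummy image-level targets

mask, label = model(x)                         # forward returns both outputs
seg_loss = nn.BCEWithLogitsLoss()(mask, seg_target)
cls_loss = nn.BCEWithLogitsLoss()(label, cls_target)
loss = seg_loss + 0.4 * cls_loss               # assumed weighting
loss.backward()
```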
##### Depth
The depth parameter specifies the number of downsampling operations in the encoder, so you can make your model lighter by specifying a smaller `depth`.
```python
model = smp.FPN('resnet34', depth=4)
```
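For scale, a small sketch comparing parameter counts at the depths the docs mention (`depth=5` is the full encoder, as used in misc/generate_table.py below; `depth=4` matches the example above — the keyword name is taken from that example):

```python
import segmentation_models_pytorch as smp

for depth in (4, 5):
    model = smp.FPN('resnet34', depth=depth)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"depth={depth}: {n_params / 1e6:.1f}M parameters")
```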

### Installation <a name="installation"></a>
PyPI version:
````diff
@@ -97,11 +163,24 @@ Latest version from source:
 ```bash
 $ pip install git+https://github.com/qubvel/segmentation_models.pytorch
 ```
+
+### Competitions won with the library
+
+The `Segmentation Models` package is widely used in image segmentation competitions.
+[Here](https://github.com/qubvel/segmentation_models.pytorch/blob/master/HALLOFFAME.md) you can find competitions, the names of the winners, and links to their solutions.
+
+
 ### License <a name="license"></a>
 Project is distributed under [MIT License](https://github.com/qubvel/segmentation_models.pytorch/blob/master/LICENSE)
 
-### Run tests
+
+### Contributing
+
+##### Run tests
+```bash
+$ docker build -f docker/Dockerfile.dev -t smp:dev . && docker run --rm smp:dev pytest -p no:cacheprovider
+```
+##### Generate table
 ```bash
-$ docker build -f docker/Dockerfile.dev -t smp:dev .
-$ docker run --rm smp:dev pytest -p no:cacheprovider
+$ docker build -f docker/Dockerfile.dev -t smp:dev . && docker run --rm smp:dev python misc/generate_table.py
 ```
````

docker/Dockerfile.dev

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,4 +1,4 @@
-FROM anibali/pytorch:cuda-9.0
+FROM python:3.6 #anibali/pytorch:cuda-9.0
 
 WORKDIR /tmp/smp/
 
```
examples/cars segmentation (camvid).ipynb

Lines changed: 145 additions & 169 deletions
Large diffs are not rendered by default.

misc/generate_table.py

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
```python
import segmentation_models_pytorch as smp

# Registry of all available encoders: name -> encoder class, params, pretrained settings.
encoders = smp.encoders.encoders


WIDTH = 32
COLUMNS = [
    "Encoder",
    "Weights",
    "Params, M",
]


def wrap_row(r):
    return "|{}|".format(r)


header = "|".join([column.ljust(WIDTH, " ") for column in COLUMNS])
separator = "|".join(["-" * WIDTH] + [":" + "-" * (WIDTH - 2) + ":"] * (len(COLUMNS) - 1))

print(wrap_row(header))
print(wrap_row(separator))

for encoder_name, encoder in encoders.items():
    # Join the names of all pretrained weight sets available for this encoder.
    weights = "<br>".join(encoder["pretrained_settings"].keys())
    encoder_name = encoder_name.ljust(WIDTH, " ")
    weights = weights.ljust(WIDTH, " ")

    # Instantiate the encoder at full depth and count its parameters.
    model = encoder["encoder"](**encoder["params"], depth=5)
    params = sum(p.numel() for p in model.parameters())
    params = str(params // 1000000) + "M"
    params = params.ljust(WIDTH, " ")

    row = "|".join([encoder_name, weights, params])
    print(wrap_row(row))
```

requirements.txt

Lines changed: 2 additions & 2 deletions
```diff
@@ -1,3 +1,3 @@
-torchvision>=0.2.2,<=0.4.0
+torchvision>=0.3.0
 pretrainedmodels==0.7.4
-efficientnet-pytorch==0.4.0
+efficientnet-pytorch>=0.5.1
```
Lines changed: 1 addition & 1 deletion
```diff
@@ -1,3 +1,3 @@
-VERSION = (0, 0, 3)
+VERSION = (0, 1, 0)
 
 __version__ = '.'.join(map(str, VERSION))
```
Lines changed: 11 additions & 1 deletion
```diff
@@ -1 +1,11 @@
-from .encoder_decoder import EncoderDecoder
+from .model import SegmentationModel
+
+from .modules import (
+    Conv2dReLU,
+    Attention,
+)
+
+from .heads import (
+    SegmentationHead,
+    ClassificationHead,
+)
```

segmentation_models_pytorch/base/encoder_decoder.py

Lines changed: 0 additions & 47 deletions
This file was deleted.
