@@ -35,18 +35,26 @@ The model was trained on the standard [MNIST](http://yann.lecun.com/exdb/mnist/)
35
35
** Step 1.**
36
36
Clone this repository with `` git `` and install project dependencies.
37
37
38
- ```
38
+ ``` bash
39
39
$ git clone https://github.com/cedrickchee/capsule-net-pytorch.git
40
40
$ cd capsule-net-pytorch
41
41
$ pip install -r requirements.txt
42
42
```
43
43
44
44
** Step 2.**
45
45
Start the training and evaluation:
46
- ```
46
+
47
+ - running on CPU
48
+ ``` bash
47
49
$ python main.py
48
50
```
49
51
52
+ - running on GPU
53
+ - For example, running on 8 GPUs.
54
+ ``` bash
55
+ $ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py --epochs 30 --threads 16 --batch-size 128 --test-batch-size 128
56
+ ```
57
+
50
58
** The default hyper parameters:**
51
59
52
60
| Parameter | Value | CLI arguments |
@@ -68,8 +76,31 @@ $ python main.py
68
76
| Regularization coefficient for reconstruction loss | 0.0005 | --regularization-scale 0.0005 |
69
77
70
78
## Results
71
- - training loss
79
+
80
+ ### Test error
81
+
82
+ CapsNet classification test error on MNIST. The MNIST average and standard deviation results are reported from 3 trials.
83
+
84
+ [ WIP] The results can be reproduced by running the following commands.
85
+
86
+ ``` bash
87
+ python main.py --num-routing 1 --regularization-scale 0.0 # CapsNet-v1
88
+ python main.py --num-routing 1 --regularization-scale 0.0005 # CapsNet-v2
89
+ python main.py --num-routing 3 --regularization-scale 0.0 # CapsNet-v3
90
+ python main.py --num-routing 3 --regularization-scale 0.0005 # CapsNet-v4
72
91
```
92
+
93
+ Method | Routing | Reconstruction | MNIST (%) | * Paper*
94
+ :---------|:------:|:---:|:----:|:----:
95
+ Baseline | -- | -- | -- | * 0.39*
96
+ CapsNet-v1 | 1 | no | -- | * 0.34 (0.032)*
97
+ CapsNet-v2 | 1 | yes | -- | * 0.29 (0.011)*
98
+ CapsNet-v3 | 3 | no | -- | * 0.35 (0.036)*
99
+ CapsNet-v4 | 3 | yes | -- | * 0.25 (0.005)*
100
+
101
+ ### Training loss
102
+
103
+ ``` text
73
104
# Log from the end of the last epoch.
74
105
75
106
... ... ... ... ... ... ... ... ... ... ...
@@ -116,15 +147,30 @@ Epoch: 10 [59776/60000 (100%)] Loss: 0.029488
116
147
Epoch: 10 [44928/60000 (100%)] Loss: 0.045561
117
148
```
118
149
119
- - evaluation accuracy
120
- ```
150
+ ### Evaluation accuracy
151
+ ``` text
121
152
Test set: Average loss: 0.0004, Accuracy: 9885/10000 (99%)
122
153
Checkpoint saved to model_epoch_10.pth
123
154
```
124
155
156
+ ### Reconstruction
157
+
158
+ The results of CapsNet-v4.
159
+
160
+ Digits at left are reconstructed images.
161
+ <table >
162
+ <tr >
163
+ <td>
164
+ <img src="results/reconstructed_images.png"/>
165
+ </td>
166
+ <td>
167
+ </td>
168
+ </tr >
169
+ </table >
170
+
125
171
## TODO
126
172
- [ DONE] Publish results.
127
- - [ WIP ] More testing.
173
+ - [ DONE ] More testing.
128
174
- Separate training and evaluation into independent command.
129
175
- Jupyter Notebook version.
130
176
- Create a sample to show how we can apply CapsNet to real-world application.
0 commit comments