Commit 77c80f0

Add MIT Places 365 GoogleNet model (#935)
1 parent cd672a0 commit 77c80f0

File tree

7 files changed: +1159, −28 lines

captum/optim/models/__init__.py

Lines changed: 10 additions & 0 deletions
@@ -8,6 +8,13 @@
 )
 from ._image.inception5h_classes import INCEPTION5H_CLASSES  # noqa: F401
 from ._image.inception_v1 import InceptionV1, googlenet  # noqa: F401
+from ._image.inception_v1_places365 import (  # noqa: F401
+    InceptionV1Places365,
+    googlenet_places365,
+)
+from ._image.inception_v1_places365_classes import (  # noqa: F401
+    INCEPTIONV1_PLACES365_CLASSES,
+)

 __all__ = [
     "RedirectedReluLayer",
@@ -19,4 +26,7 @@
     "InceptionV1",
     "googlenet",
     "INCEPTION5H_CLASSES",
+    "InceptionV1Places365",
+    "googlenet_places365",
+    "INCEPTIONV1_PLACES365_CLASSES",
 ]
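
Note: a minimal usage sketch for the new exports. It assumes googlenet_places365 mirrors the googlenet factory documented further down (a pretrained flag, inputs in the [0, 1] range) and that 224x224 is the expected spatial size; neither detail is confirmed by this diff.

# Hypothetical usage sketch; the kwargs and input size are assumptions.
import torch
from captum.optim.models import INCEPTIONV1_PLACES365_CLASSES, googlenet_places365

model = googlenet_places365(pretrained=True).eval()  # assumed to download the Places365 weights
with torch.no_grad():
    x = torch.rand(1, 3, 224, 224)  # values in [0, 1]
    logits = model(x)
print(INCEPTIONV1_PLACES365_CLASSES[logits.argmax(dim=1).item()])  # predicted scene label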

captum/optim/models/_common.py

Lines changed: 1 addition & 0 deletions
@@ -63,6 +63,7 @@ class RedirectedReluLayer(nn.Module):
     Class for applying RedirectedReLU
     """

+    @torch.jit.ignore
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         return RedirectedReLU.apply(input)
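
Note: @torch.jit.ignore leaves RedirectedReluLayer.forward as a plain Python call, so torch.jit.script no longer attempts to compile the custom autograd call RedirectedReLU.apply. A small sketch of the effect; the wrapper module below is hypothetical and only used for illustration.

import torch
import torch.nn as nn
from captum.optim.models import RedirectedReluLayer  # exported in __all__ above

class TinyNet(nn.Module):
    """Hypothetical module wrapping the layer, just to demonstrate scripting."""

    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3)
        self.relu = RedirectedReluLayer()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.relu(self.conv(x))

scripted = torch.jit.script(TinyNet())    # scripting succeeds; the ignored forward stays in Python
out = scripted(torch.rand(1, 3, 32, 32))  # still callable from eager Python code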

captum/optim/models/_image/inception_v1.py

Lines changed: 126 additions & 28 deletions
@@ -1,4 +1,5 @@
-from typing import Optional, Tuple, Type, Union, cast
+from typing import Optional, Tuple, Type, Union
+from warnings import warn

 import torch
 import torch.nn as nn
@@ -19,24 +20,37 @@ def googlenet(
 ) -> "InceptionV1":
     r"""GoogLeNet (also known as Inception v1 & Inception 5h) model architecture from
     `"Going Deeper with Convolutions" <http://arxiv.org/abs/1409.4842>`_.
+
     Args:
+
         pretrained (bool, optional): If True, returns a model pre-trained on ImageNet.
+            Default: False
         progress (bool, optional): If True, displays a progress bar of the download to
             stderr
+            Default: True
         model_path (str, optional): Optional path for InceptionV1 model file.
+            Default: None
         replace_relus_with_redirectedrelu (bool, optional): If True, return pretrained
             model with Redirected ReLU in place of ReLU layers.
+            Default: *True* when pretrained is True otherwise *False*
         use_linear_modules_only (bool, optional): If True, return pretrained
             model with all nonlinear layers replaced with linear equivalents.
+            Default: False
         aux_logits (bool, optional): If True, adds two auxiliary branches that can
-            improve training. Default: *False* when pretrained is True otherwise *True*
+            improve training.
+            Default: False
         out_features (int, optional): Number of output features in the model used for
-            training. Default: 1008 when pretrained is True.
+            training.
+            Default: 1008
         transform_input (bool, optional): If True, preprocesses the input according to
-            the method with which it was trained on ImageNet. Default: *False*
+            the method with which it was trained on ImageNet.
+            Default: False
         bgr_transform (bool, optional): If True and transform_input is True, perform an
             RGB to BGR transform in the internal preprocessing.
-            Default: *False*
+            Default: False
+
+    Returns:
+        **InceptionV1** (InceptionV1): An Inception5h model.
     """

     if pretrained:
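
Note: a short sketch of the factory as documented above (ImageNet/Inception5h weights, 1008 output features by default). The 224x224 input size and the exact output shape are assumptions rather than something this diff states.

import torch
from captum.optim.models import googlenet

model = googlenet(pretrained=True).eval()  # downloads the Inception5h weights
with torch.no_grad():
    x = torch.rand(1, 3, 224, 224)  # inputs conventionally in [0, 1]; see _transform_input below
    logits = model(x)
print(logits.shape)  # expected: torch.Size([1, 1008]) with the default out_features
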
@@ -69,6 +83,8 @@ def googlenet(

 # Better version of Inception V1 / GoogleNet for Inception5h
 class InceptionV1(nn.Module):
+    __constants__ = ["aux_logits", "transform_input", "bgr_transform"]
+
     def __init__(
         self,
         out_features: int = 1008,
@@ -78,7 +94,29 @@ def __init__(
         replace_relus_with_redirectedrelu: bool = False,
         use_linear_modules_only: bool = False,
     ) -> None:
-        super(InceptionV1, self).__init__()
+        """
+        Args:
+
+            replace_relus_with_redirectedrelu (bool, optional): If True, return
+                pretrained model with Redirected ReLU in place of ReLU layers.
+                Default: False
+            use_linear_modules_only (bool, optional): If True, return pretrained
+                model with all nonlinear layers replaced with linear equivalents.
+                Default: False
+            aux_logits (bool, optional): If True, adds two auxiliary branches that can
+                improve training.
+                Default: False
+            out_features (int, optional): Number of output features in the model used
+                for training.
+                Default: 1008
+            transform_input (bool, optional): If True, preprocesses the input according
+                to the method with which it was trained on ImageNet.
+                Default: False
+            bgr_transform (bool, optional): If True and transform_input is True,
+                perform an RGB to BGR transform in the internal preprocessing.
+                Default: False
+        """
+        super().__init__()
         self.aux_logits = aux_logits
         self.transform_input = transform_input
         self.bgr_transform = bgr_transform
@@ -99,7 +137,6 @@ def __init__(
             out_channels=64,
             kernel_size=(7, 7),
             stride=(2, 2),
-            padding=3,
             groups=1,
             bias=True,
         )
@@ -121,7 +158,6 @@ def __init__(
             out_channels=192,
             kernel_size=(3, 3),
             stride=(1, 1),
-            padding=1,
             groups=1,
             bias=True,
         )
@@ -163,9 +199,18 @@ def __init__(
         self.fc = nn.Linear(1024, out_features)

     def _transform_input(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+
+            x (torch.Tensor): An input tensor to normalize and scale the values of.
+
+        Returns:
+            x (torch.Tensor): A transformed tensor.
+        """
         if self.transform_input:
             assert x.dim() == 3 or x.dim() == 4
-            assert x.min() >= 0.0 and x.max() <= 1.0
+            if x.min() < 0.0 or x.max() > 1.0:
+                warn("Model input has values outside of the range [0, 1].")
             x = x.unsqueeze(0) if x.dim() == 3 else x
             x = x * 255 - 117
             x = x[:, [2, 1, 0]] if self.bgr_transform else x
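
Note: the hunk above relaxes the old hard assert on the [0, 1] input range into a warning. A sketch of the new behavior using a randomly initialized InceptionV1, so no download is needed; the 224x224 input size is an assumption.

import warnings
import torch
from captum.optim.models import InceptionV1

model = InceptionV1(transform_input=True).eval()  # random weights are fine for this check
bad = torch.rand(1, 3, 224, 224) * 2.0            # values in [0, 2], outside the expected range
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    _ = model(bad)  # previously raised AssertionError, now only warns
print(any("outside of the range" in str(w.message) for w in caught))  # True
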
@@ -174,6 +219,15 @@ def _transform_input(self, x: torch.Tensor) -> torch.Tensor:
     def forward(
         self, x: torch.Tensor
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
+        """
+        Args:
+
+            x (torch.Tensor): An input tensor to normalize and scale the values of.
+
+        Returns:
+            x (torch.Tensor or tuple of torch.Tensor): A single or multiple output
+                tensors from the model.
+        """
         x = self._transform_input(x)
         x = self.conv1(x)
         x = self.conv1_relu(x)
@@ -212,7 +266,7 @@ def forward(
         x = self.drop(x)
         x = self.fc(x)
         if not self.aux_logits:
-            return cast(torch.Tensor, x)
+            return x
         else:
             return x, aux1_output, aux2_output

@@ -230,7 +284,25 @@ def __init__(
         activ: Type[nn.Module] = nn.ReLU,
         p_layer: Type[nn.Module] = nn.MaxPool2d,
     ) -> None:
-        super(InceptionModule, self).__init__()
+        """
+        Args:
+
+            in_channels (int, optional): The number of input channels to use for the
+                inception module.
+            c1x1 (int, optional):
+            c3x3reduce (int, optional):
+            c3x3 (int, optional):
+            c5x5reduce (int, optional):
+            c5x5 (int, optional):
+            pool_proj (int, optional):
+            activ (type of nn.Module, optional): The nn.Module class type to use for
+                activation layers.
+                Default: nn.ReLU
+            p_layer (type of nn.Module, optional): The nn.Module class type to use for
+                pooling layers.
+                Default: nn.MaxPool2d
+        """
+        super().__init__()
         self.conv_1x1 = nn.Conv2d(
             in_channels=in_channels,
             out_channels=c1x1,
@@ -254,7 +326,6 @@ def __init__(
             out_channels=c3x3,
             kernel_size=(3, 3),
             stride=(1, 1),
-            padding=1,
             groups=1,
             bias=True,
         )
@@ -273,7 +344,6 @@ def __init__(
             out_channels=c5x5,
             kernel_size=(5, 5),
             stride=(1, 1),
-            padding=1,
             groups=1,
             bias=True,
         )
@@ -289,6 +359,14 @@ def __init__(
         )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+
+            x (torch.Tensor): An input tensor to pass through the Inception Module.
+
+        Returns:
+            x (torch.Tensor): The output tensor of the Inception Module.
+        """
         c1x1 = self.conv_1x1(x)

         c3x3 = self.conv_3x3_reduce(x)
@@ -311,31 +389,51 @@ def __init__(
         out_features: int = 1008,
         activ: Type[nn.Module] = nn.ReLU,
     ) -> None:
-        super(AuxBranch, self).__init__()
+        """
+        Args:
+
+            in_channels (int, optional): The number of input channels to use for the
+                auxiliary branch.
+                Default: 508
+            out_features (int, optional): The number of output features to use for the
+                auxiliary branch.
+                Default: 1008
+            activ (type of nn.Module, optional): The nn.Module class type to use for
+                activation layers.
+                Default: nn.ReLU
+        """
+        super().__init__()
         self.avg_pool = nn.AdaptiveAvgPool2d((4, 4))
-        self.loss_conv = nn.Conv2d(
+        self.conv = nn.Conv2d(
             in_channels=in_channels,
             out_channels=128,
             kernel_size=(1, 1),
             stride=(1, 1),
             groups=1,
             bias=True,
         )
-        self.loss_conv_relu = activ()
-        self.loss_fc = nn.Linear(in_features=2048, out_features=1024, bias=True)
-        self.loss_fc_relu = activ()
-        self.loss_dropout = nn.Dropout(0.699999988079071)
-        self.loss_classifier = nn.Linear(
-            in_features=1024, out_features=out_features, bias=True
-        )
+        self.conv_relu = activ()
+        self.fc1 = nn.Linear(in_features=2048, out_features=1024, bias=True)
+        self.fc1_relu = activ()
+        self.dropout = nn.Dropout(0.699999988079071)
+        self.fc2 = nn.Linear(in_features=1024, out_features=out_features, bias=True)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+
+            x (torch.Tensor): An input tensor to pass through the auxiliary branch
+                module.
+
+        Returns:
+            x (torch.Tensor): The output tensor of the auxiliary branch module.
+        """
         x = self.avg_pool(x)
-        x = self.loss_conv(x)
-        x = self.loss_conv_relu(x)
+        x = self.conv(x)
+        x = self.conv_relu(x)
         x = torch.flatten(x, 1)
-        x = self.loss_fc(x)
-        x = self.loss_fc_relu(x)
-        x = self.loss_dropout(x)
-        x = self.loss_classifier(x)
+        x = self.fc1(x)
+        x = self.fc1_relu(x)
+        x = self.dropout(x)
+        x = self.fc2(x)
         return x
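
Note: the AuxBranch layers were renamed (loss_conv to conv, loss_fc to fc1, loss_classifier to fc2, and the ReLU/dropout attributes accordingly). The downloaded pretrained weights presumably already use the new names; the helper below is a hypothetical sketch for remapping a checkpoint saved against the old names, not a documented migration path, and the "aux1."-style prefixes are assumptions.

from typing import Dict
import torch

# Only the parameterized layers matter; the ReLU and dropout renames carry no weights.
AUX_RENAMES = {"loss_conv": "conv", "loss_fc": "fc1", "loss_classifier": "fc2"}

def remap_aux_keys(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Rename dotted key components, e.g. 'aux1.loss_conv.weight' -> 'aux1.conv.weight'."""
    return {
        ".".join(AUX_RENAMES.get(part, part) for part in key.split(".")): value
        for key, value in state_dict.items()
    }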
