diff --git a/core/models/cgnet.py b/core/models/cgnet.py index 9cae5c837..3c7aa6764 100644 --- a/core/models/cgnet.py +++ b/core/models/cgnet.py @@ -72,27 +72,22 @@ def forward(self, x): # stage 2 out0_cat = self.bn_prelu1(torch.cat([out0, inp1], dim=1)) out1_0 = self.stage2_0(out0_cat) - for i, layer in enumerate(self.stage2): - if i == 0: - out1 = layer(out1_0) - else: - out1 = layer(out1) + out1 = out1_0 + for layer in self.stage2: + out1 = layer(out1) out1_cat = self.bn_prelu2(torch.cat([out1, out1_0, inp2], dim=1)) # stage 3 out2_0 = self.stage3_0(out1_cat) - for i, layer in enumerate(self.stage3): - if i == 0: - out2 = layer(out2_0) - else: - out2 = layer(out2) + out2 = out2_0 + for layer in self.stage3: + out2 = layer(out2) out2_cat = self.bn_prelu3(torch.cat([out2_0, out2], dim=1)) - outputs = [] out = self.head(out2_cat) out = F.interpolate(out, size, mode='bilinear', align_corners=True) - outputs.append(out) - return tuple(outputs) + + return (out,) class _ChannelWiseConv(nn.Module): @@ -158,6 +153,7 @@ def __init__(self, in_channels, out_channels, dilation=2, reduction=16, down=Fal self.reduce = nn.Conv2d(inter_channels * 2, out_channels, 1, bias=False) else: self.conv = _ConvBNPReLU(in_channels, inter_channels, 1, 1, 0, norm_layer=norm_layer, **kwargs) + self.reduce = nn.Identity() self.f_loc = _ChannelWiseConv(inter_channels, inter_channels, **kwargs) self.f_sur = _ChannelWiseConv(inter_channels, inter_channels, dilation, **kwargs) self.bn = norm_layer(inter_channels * 2)