From 3d689bcd7e843441e4ee0589d451721569620b0b Mon Sep 17 00:00:00 2001 From: Yangze Luo Date: Fri, 8 Mar 2019 18:57:04 +0800 Subject: [PATCH 1/2] fix generator input size not equal outpout size --- src/networks.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/src/networks.py b/src/networks.py index 1444f1091..0cd99e345 100644 --- a/src/networks.py +++ b/src/networks.py @@ -1,6 +1,19 @@ import torch import torch.nn as nn +def output_align(input, output): + """ + author: @youyuge34 (https://github.com/youyuge34) + In testing, sometimes output is several pixels less than irregular-size input, + here is to fill them + """ + if output.size() != input.size(): + diff_width = input.size(-1) - output.size(-1) + diff_height = input.size(-2) - output.size(-2) + m = nn.ReplicationPad2d((0, diff_width, 0, diff_height)) + output = m(output) + + return output class BaseNetwork(nn.Module): def __init__(self): @@ -78,11 +91,12 @@ def __init__(self, residual_blocks=8, init_weights=True): self.init_weights() def forward(self, x): + inpt = x x = self.encoder(x) x = self.middle(x) x = self.decoder(x) x = (torch.tanh(x) + 1) / 2 - + x = output_align(inpt, x) return x @@ -129,10 +143,12 @@ def __init__(self, residual_blocks=8, use_spectral_norm=True, init_weights=True) self.init_weights() def forward(self, x): + inpt = x x = self.encoder(x) x = self.middle(x) x = self.decoder(x) x = torch.sigmoid(x) + x = output_align(inpt, x) return x From 35b728cd164b2c13b5dc831c17f6e7d5433756ba Mon Sep 17 00:00:00 2001 From: Yangze Luo Date: Fri, 8 Mar 2019 19:23:22 +0800 Subject: [PATCH 2/2] fix input/output size equation in output_align --- src/networks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/networks.py b/src/networks.py index 0cd99e345..f0c374e2b 100644 --- a/src/networks.py +++ b/src/networks.py @@ -7,7 +7,7 @@ def output_align(input, output): In testing, sometimes output is several pixels less than irregular-size input, here is to fill them """ - if output.size() != input.size(): + if output.size()[-2:] != input.size()[-2:]: diff_width = input.size(-1) - output.size(-1) diff_height = input.size(-2) - output.size(-2) m = nn.ReplicationPad2d((0, diff_width, 0, diff_height))