diff --git a/code/systems/s_neural_texture.py b/code/systems/s_neural_texture.py index 098611c..316b53a 100644 --- a/code/systems/s_neural_texture.py +++ b/code/systems/s_neural_texture.py @@ -65,8 +65,9 @@ def forward(self, weights, position, seed): transform_coeff, z_encoding = torch.split(weights, [self.p.texture.t, self.p.texture.e], dim=1) - z_encoding = z_encoding.view(bs, self.p.texture.e, 1, 1) - z_encoding = z_encoding.expand(bs, self.p.texture.e, self.image_res, self.image_res) + if z_encoding.shape[2] == 1: + z_encoding = z_encoding.view(bs, self.p.texture.e, 1, 1) + z_encoding = z_encoding.expand(bs, self.p.texture.e, self.image_res, self.image_res) position = position.unsqueeze(1).expand(bs, self.p.noise.octaves, self.p.dim, h, w) position = position.permute(0, 1, 3, 4, 2) @@ -251,7 +252,7 @@ def test_step(self, batch, batch_nb): z_texture_interpolated = torch.stack(weight_list, dim=2).unsqueeze(-2) - z_texture_interpolated = z_texture_interpolated[:, :-2] + # z_texture_interpolated = z_texture_interpolated[:, :-2] latent_space = z_texture_interpolated.shape[1] z_texture_interpolated = z_texture_interpolated.expand(1, latent_space, self.p.image.image_res, self.p.image.image_res) diff --git a/code/utils/neural_texture_helper.py b/code/utils/neural_texture_helper.py index 9096b02..ef04de6 100644 --- a/code/utils/neural_texture_helper.py +++ b/code/utils/neural_texture_helper.py @@ -102,9 +102,14 @@ def transform_coord(coord, t_coeff, dim): bs, octaves, h, w, dim = coord.size() - t_coeff = t_coeff.reshape(bs, octaves, dim, dim).unsqueeze(2).unsqueeze(2) - - t_coeff = t_coeff.expand(bs, octaves, h, w, dim, dim) + inter = (t_coeff.shape[2] != 1) + if inter: + t_coeff = t_coeff.reshape(bs, octaves, dim, dim, h, w) + t_coeff = t_coeff.permute(0, 1, 4, 5, 2, 3) + else: + t_coeff = t_coeff.reshape(bs, octaves, dim, dim).unsqueeze(2).unsqueeze(2) + + t_coeff = t_coeff.expand(bs, octaves, h, w, dim, dim) t_coeff = t_coeff.reshape(bs * octaves, h, w, dim, dim) transform_matrix = identity_matrix.expand(bs * octaves, dim, dim)