Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit d901097

Browse files
Lukasz KaiserRyan Sepassi
authored andcommitted
Autoencoder improvements and related corrections; ordered and discrete residual MNIST samples look reasonable.
PiperOrigin-RevId: 194460369
1 parent 25b101d commit d901097

File tree

5 files changed

+142
-61
lines changed

5 files changed

+142
-61
lines changed

tensor2tensor/layers/common_layers.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2798,10 +2798,18 @@ def dense(x, units, **kwargs):
27982798

27992799

28002800
def mix(x1, x2, steps, is_training,
2801-
min_prob=0.0, max_prob=1.0, mode="lin", simple=False):
2801+
min_prob=0.0, max_prob=1.0,
2802+
mode="lin", simple=False, broadcast_last=False):
28022803
"""Mix starting with x2, mixing mixing, going towards x1."""
28032804
if not is_training:
2804-
return x1
2805+
if max_prob >= 1.0:
2806+
return x1
2807+
alpha_shape = shape_list(x1)
2808+
if broadcast_last:
2809+
alpha_shape = alpha_shape[:-1] + [1]
2810+
alpha = tf.random_uniform(alpha_shape)
2811+
alpha = tf.to_float(tf.less(alpha, max_prob))
2812+
return alpha * x1 + (1.0 - alpha) * x2
28052813

28062814
def get_res():
28072815
"""Create the result. Separate function to speed it up later (see below)."""
@@ -2812,7 +2820,10 @@ def get_res():
28122820
alpha_p = alpha_p * (max_prob - min_prob) + min_prob
28132821
if simple:
28142822
return alpha_p * x1 + (1.0 - alpha_p) * x2
2815-
alpha = tf.random_uniform(shape_list(x1))
2823+
alpha_shape = shape_list(x1)
2824+
if broadcast_last:
2825+
alpha_shape = alpha_shape[:-1] + [1]
2826+
alpha = tf.random_uniform(alpha_shape)
28162827
alpha = tf.to_float(tf.less(alpha, alpha_p))
28172828
return alpha * x1 + (1.0 - alpha) * x2
28182829

tensor2tensor/layers/discretization.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -682,9 +682,9 @@ def isemhash_bottleneck(x, bottleneck_size, bottleneck_noise,
682682
noise = tf.random_uniform(common_layers.shape_list(x))
683683
noise = 2.0 * tf.to_float(tf.less(bottleneck_noise, noise)) - 1.0
684684
d *= noise
685-
d = common_layers.mix(d, 2.0 * y - 1.0, discretize_warmup_steps,
686-
mode == tf.estimator.ModeKeys.TRAIN,
687-
max_prob=isemhash_mix_prob)
685+
d = common_layers.mix(d, 2.0 * y - 1.0, discretize_warmup_steps,
686+
mode == tf.estimator.ModeKeys.TRAIN,
687+
max_prob=isemhash_mix_prob)
688688
return d
689689

690690

tensor2tensor/models/basic.py

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
class BasicFcRelu(t2t_model.T2TModel):
3333

3434
def body(self, features):
35-
hparams = self._hparams
35+
hparams = self.hparams
3636
x = features["inputs"]
3737
shape = common_layers.shape_list(x)
3838
x = tf.reshape(x, [-1, shape[1] * shape[2] * shape[3]])
@@ -53,7 +53,7 @@ def __init__(self, *args, **kwargs):
5353

5454
def bottleneck(self, x):
5555
with tf.variable_scope("bottleneck"):
56-
hparams = self._hparams
56+
hparams = self.hparams
5757
x = tf.layers.dense(x, hparams.bottleneck_size, name="bottleneck")
5858
if hparams.mode == tf.estimator.ModeKeys.TRAIN:
5959
noise = 2.0 * tf.random_uniform(common_layers.shape_list(x)) - 1.0
@@ -68,12 +68,27 @@ def unbottleneck(self, x, res_size):
6868
def bottleneck_loss(self, b):
6969
return 0.0
7070

71+
def make_even_size(self, x):
72+
shape = [dim if dim is not None else -1 for dim in x.get_shape().as_list()]
73+
if shape[1] % 2 == 0 and shape[2] % 2 == 0:
74+
return x
75+
if shape[1] % 2 == 0 and self.is1d:
76+
return x
77+
x, _ = common_layers.pad_to_same_length(
78+
x, x, final_length_divisible_by=2, axis=1)
79+
if self.is1d:
80+
return x
81+
x, _ = common_layers.pad_to_same_length(
82+
x, x, final_length_divisible_by=2, axis=2)
83+
return x
84+
7185
def encoder(self, x):
7286
with tf.variable_scope("encoder"):
73-
hparams = self._hparams
87+
hparams = self.hparams
7488
kernel, strides = self._get_kernel_and_strides()
7589
# Down-convolutions.
7690
for i in range(hparams.num_hidden_layers):
91+
x = self.make_even_size(x)
7792
x = tf.layers.conv2d(
7893
x, hparams.hidden_size * 2**(i + 1), kernel, strides=strides,
7994
padding="SAME", activation=common_layers.belu, name="conv_%d" % i)
@@ -82,7 +97,7 @@ def encoder(self, x):
8297

8398
def decoder(self, x):
8499
with tf.variable_scope("decoder"):
85-
hparams = self._hparams
100+
hparams = self.hparams
86101
kernel, strides = self._get_kernel_and_strides()
87102
# Up-convolutions.
88103
for i in range(hparams.num_hidden_layers):
@@ -94,19 +109,13 @@ def decoder(self, x):
94109
return x
95110

96111
def body(self, features):
97-
hparams = self._hparams
112+
hparams = self.hparams
98113
is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN
99114
if hparams.mode != tf.estimator.ModeKeys.PREDICT:
100115
x = features["targets"]
101116
shape = common_layers.shape_list(x)
102117
is1d = shape[2] == 1
103118
self.is1d = is1d
104-
x, _ = common_layers.pad_to_same_length(
105-
x, x, final_length_divisible_by=2**hparams.num_hidden_layers, axis=1)
106-
if not is1d:
107-
x, _ = common_layers.pad_to_same_length(
108-
x, x, final_length_divisible_by=2**hparams.num_hidden_layers,
109-
axis=2)
110119
# Run encoder.
111120
x = self.encoder(x)
112121
# Bottleneck (mix during early training, not too important but stable).
@@ -122,21 +131,21 @@ def body(self, features):
122131
x = b
123132
else:
124133
b = self.sample()
125-
res_size = self._hparams.hidden_size * 2**self._hparams.num_hidden_layers
134+
res_size = self.hparams.hidden_size * 2**self.hparams.num_hidden_layers
126135
res_size = min(res_size, hparams.max_hidden_size)
127136
x = self.unbottleneck(b, res_size)
128137
# Run decoder.
129138
x = self.decoder(x)
130139
if hparams.mode == tf.estimator.ModeKeys.PREDICT:
131-
return x
140+
return x, {"bottleneck_loss": 0.0}
132141
# Cut to the right size and mix before returning.
133142
res = x[:, :shape[1], :shape[2], :]
134143
res = common_layers.mix(res, features["targets"],
135144
hparams.bottleneck_warmup_steps // 2, is_training)
136145
return res, {"bottleneck_loss": b_loss}
137146

138147
def sample(self):
139-
hp = self._hparams
148+
hp = self.hparams
140149
div_x = 2**hp.num_hidden_layers
141150
div_y = 1 if self.is1d else 2**hp.num_hidden_layers
142151
size = [hp.batch_size, hp.sample_height // div_x, hp.sample_width // div_y,
@@ -158,11 +167,11 @@ def infer(self, features=None, decode_length=50, beam_size=1, top_beams=1,
158167
# Sample and decode.
159168
# TODO(lukaszkaiser): is this a universal enough way to get channels?
160169
try:
161-
num_channels = self._hparams.problem.num_channels
170+
num_channels = self.hparams.problem.num_channels
162171
except AttributeError:
163172
num_channels = 1
164173
features["targets"] = tf.zeros(
165-
[self._hparams.batch_size, 1, 1, num_channels],
174+
[self.hparams.batch_size, 1, 1, num_channels],
166175
dtype=tf.int32)
167176
logits, _ = self(features) # pylint: disable=not-callable
168177
samples = tf.argmax(logits, axis=-1)
@@ -175,7 +184,7 @@ def infer(self, features=None, decode_length=50, beam_size=1, top_beams=1,
175184
return samples
176185

177186
def _get_kernel_and_strides(self):
178-
hparams = self._hparams
187+
hparams = self.hparams
179188
kernel = (hparams.kernel_height, hparams.kernel_width)
180189
kernel = (hparams.kernel_height, 1) if self.is1d else kernel
181190
strides = (2, 1) if self.is1d else (2, 2)

0 commit comments

Comments
 (0)