@@ -32,7 +32,7 @@
 class BasicFcRelu(t2t_model.T2TModel):

   def body(self, features):
-    hparams = self._hparams
+    hparams = self.hparams
     x = features["inputs"]
     shape = common_layers.shape_list(x)
     x = tf.reshape(x, [-1, shape[1] * shape[2] * shape[3]])
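The only substantive change in the hunk above is switching from the private `self._hparams` attribute to the public `self.hparams` accessor; the same rename is applied throughout the commit. Below is a minimal, hypothetical sketch of the private-attribute / public-property pattern that rename presumes (ModelSketch is an illustrative stand-in, not the actual t2t_model.T2TModel code):

# Hypothetical sketch only -- not the real T2TModel implementation.
class ModelSketch(object):
  """Stores hyperparameters privately and exposes them via a read-only property."""

  def __init__(self, hparams):
    self._hparams = hparams  # private storage, set once at construction

  @property
  def hparams(self):
    # Public accessor used by call sites such as `hparams = self.hparams` above.
    return self._hparams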
@@ -53,7 +53,7 @@ def __init__(self, *args, **kwargs):

   def bottleneck(self, x):
     with tf.variable_scope("bottleneck"):
-      hparams = self._hparams
+      hparams = self.hparams
       x = tf.layers.dense(x, hparams.bottleneck_size, name="bottleneck")
       if hparams.mode == tf.estimator.ModeKeys.TRAIN:
         noise = 2.0 * tf.random_uniform(common_layers.shape_list(x)) - 1.0
@@ -68,12 +68,27 @@ def unbottleneck(self, x, res_size):
   def bottleneck_loss(self, b):
     return 0.0

+  def make_even_size(self, x):
+    shape = [dim if dim is not None else -1 for dim in x.get_shape().as_list()]
+    if shape[1] % 2 == 0 and shape[2] % 2 == 0:
+      return x
+    if shape[1] % 2 == 0 and self.is1d:
+      return x
+    x, _ = common_layers.pad_to_same_length(
+        x, x, final_length_divisible_by=2, axis=1)
+    if self.is1d:
+      return x
+    x, _ = common_layers.pad_to_same_length(
+        x, x, final_length_divisible_by=2, axis=2)
+    return x
+
   def encoder(self, x):
     with tf.variable_scope("encoder"):
-      hparams = self._hparams
+      hparams = self.hparams
       kernel, strides = self._get_kernel_and_strides()
       # Down-convolutions.
       for i in range(hparams.num_hidden_layers):
+        x = self.make_even_size(x)
         x = tf.layers.conv2d(
             x, hparams.hidden_size * 2**(i + 1), kernel, strides=strides,
             padding="SAME", activation=common_layers.belu, name="conv_%d" % i)
@@ -82,7 +97,7 @@ def encoder(self, x):

   def decoder(self, x):
     with tf.variable_scope("decoder"):
-      hparams = self._hparams
+      hparams = self.hparams
       kernel, strides = self._get_kernel_and_strides()
       # Up-convolutions.
       for i in range(hparams.num_hidden_layers):
@@ -94,19 +109,13 @@ def decoder(self, x):
       return x

   def body(self, features):
-    hparams = self._hparams
+    hparams = self.hparams
     is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN
     if hparams.mode != tf.estimator.ModeKeys.PREDICT:
       x = features["targets"]
       shape = common_layers.shape_list(x)
       is1d = shape[2] == 1
       self.is1d = is1d
-      x, _ = common_layers.pad_to_same_length(
-          x, x, final_length_divisible_by=2**hparams.num_hidden_layers, axis=1)
-      if not is1d:
-        x, _ = common_layers.pad_to_same_length(
-            x, x, final_length_divisible_by=2**hparams.num_hidden_layers,
-            axis=2)
       # Run encoder.
       x = self.encoder(x)
       # Bottleneck (mix during early training, not too important but stable).
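This hunk removes the one-shot padding of the targets to a multiple of 2**num_hidden_layers, which is no longer needed because the encoder now pads layer by layer. Both schemes reach the same bottleneck size, since applying ceil(n / 2) repeatedly for k layers equals ceil(n / 2**k), while the per-layer variant keeps intermediate tensors smaller. A sketch of the arithmetic with illustrative helpers (not tensor2tensor code):

def old_bottleneck_size(length, num_layers):
  """Pad once to a multiple of 2**num_layers, then halve num_layers times."""
  divisor = 2 ** num_layers
  padded = -(-length // divisor) * divisor  # ceil to the next multiple
  return padded // divisor


def new_bottleneck_size(length, num_layers):
  """Pad to even before each stride-2 convolution, then halve."""
  for _ in range(num_layers):
    length = (length + length % 2) // 2
  return length


# Both agree: for length 25 and 3 layers the bottleneck size is 4, but the
# per-layer scheme pads the first layer's input to 26 instead of 32.
assert old_bottleneck_size(25, 3) == new_bottleneck_size(25, 3) == 4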
@@ -122,21 +131,21 @@ def body(self, features):
         x = b
     else:
       b = self.sample()
-      res_size = self._hparams.hidden_size * 2**self._hparams.num_hidden_layers
+      res_size = self.hparams.hidden_size * 2**self.hparams.num_hidden_layers
       res_size = min(res_size, hparams.max_hidden_size)
       x = self.unbottleneck(b, res_size)
     # Run decoder.
     x = self.decoder(x)
     if hparams.mode == tf.estimator.ModeKeys.PREDICT:
-      return x
+      return x, {"bottleneck_loss": 0.0}
     # Cut to the right size and mix before returning.
     res = x[:, :shape[1], :shape[2], :]
     res = common_layers.mix(res, features["targets"],
                             hparams.bottleneck_warmup_steps // 2, is_training)
     return res, {"bottleneck_loss": b_loss}

   def sample(self):
-    hp = self._hparams
+    hp = self.hparams
     div_x = 2**hp.num_hidden_layers
     div_y = 1 if self.is1d else 2**hp.num_hidden_layers
     size = [hp.batch_size, hp.sample_height // div_x, hp.sample_width // div_y,
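The PREDICT branch now returns the same (output, losses-dict) pair as the training path, so every mode hands back a consistent two-value result; the infer() hunk further below unpacks `logits, _ = self(features)` in the same spirit. A hypothetical caller illustrating the contract this change enforces (not the actual T2TModel call path):

def run_body(model, features):
  # Assumes body() always yields an (output, losses) pair, which the
  # PREDICT-mode fix above makes true in every mode.
  output, losses = model.body(features)
  total_extra_loss = sum(losses.values())
  return output, total_extra_loss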
@@ -158,11 +167,11 @@ def infer(self, features=None, decode_length=50, beam_size=1, top_beams=1,
     # Sample and decode.
     # TODO(lukaszkaiser): is this a universal enough way to get channels?
     try:
-      num_channels = self._hparams.problem.num_channels
+      num_channels = self.hparams.problem.num_channels
     except AttributeError:
       num_channels = 1
     features["targets"] = tf.zeros(
-        [self._hparams.batch_size, 1, 1, num_channels],
+        [self.hparams.batch_size, 1, 1, num_channels],
         dtype=tf.int32)
     logits, _ = self(features)  # pylint: disable=not-callable
     samples = tf.argmax(logits, axis=-1)
@@ -175,7 +184,7 @@ def infer(self, features=None, decode_length=50, beam_size=1, top_beams=1,
     return samples

   def _get_kernel_and_strides(self):
-    hparams = self._hparams
+    hparams = self.hparams
     kernel = (hparams.kernel_height, hparams.kernel_width)
     kernel = (hparams.kernel_height, 1) if self.is1d else kernel
     strides = (2, 1) if self.is1d else (2, 2)
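For reference, `_get_kernel_and_strides` collapses the width dimension for 1-D inputs: the kernel width and the axis-2 stride are forced to 1 so the down- and up-convolutions only move along axis 1. A standalone sketch of that selection logic, with a hypothetical namedtuple standing in for the hparams object:

import collections

# HParamsSketch is a hypothetical stand-in for the real hparams object.
HParamsSketch = collections.namedtuple("HParamsSketch",
                                       ["kernel_height", "kernel_width"])


def get_kernel_and_strides(hparams, is1d):
  kernel = (hparams.kernel_height, hparams.kernel_width)
  kernel = (hparams.kernel_height, 1) if is1d else kernel
  strides = (2, 1) if is1d else (2, 2)
  return kernel, strides


# 2-D images: ((4, 4), (2, 2)); 1-D sequences: ((4, 1), (2, 1)).
print(get_kernel_and_strides(HParamsSketch(4, 4), is1d=False))
print(get_kernel_and_strides(HParamsSketch(4, 4), is1d=True))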