
Commit b9dce9b

Lukasz Kaiser authored and Ryan Sepassi committed
Play with a new model with Transformer with a GAN'y part.
PiperOrigin-RevId: 174403736
1 parent c022afd commit b9dce9b

File tree

5 files changed: +238 -0 lines changed

tensor2tensor/layers/common_hparams.py
tensor2tensor/layers/modalities.py
tensor2tensor/layers/modalities_test.py
tensor2tensor/models/__init__.py
tensor2tensor/models/transformer_adv.py

tensor2tensor/layers/common_hparams.py

Lines changed: 3 additions & 0 deletions

@@ -116,12 +116,15 @@ def basic_params1():
       # If set to True, drop sequences longer than max_length during eval.
       # This affects the validity of the evaluation metrics.
       eval_drop_long_sequences=int(False),
+      # TODO(lukaszkaiser): these parameters should probably be set elsewhere.
       # in SymbolModality, share the output embeddings and the softmax
       # variables.
       # You can also share the input embeddings with the output embeddings
       # by using a problem_hparams that uses the same modality object for
       # the input_modality and target_modality.
       shared_embedding_and_softmax_weights=int(False),
+      # In SymbolModality, skip the top layer, assume we're providing logits.
+      symbol_modality_skip_top=int(False),
       # For each feature for which you want to override the default input
       # modality, add an entry to this semicolon-separated string. Entries are
       # formatted "feature_name:modality_type:modality_name", e.g.

tensor2tensor/layers/modalities.py

Lines changed: 2 additions & 0 deletions

@@ -115,6 +115,8 @@ def top(self, body_output, _):
     else:
       scope_name = "softmax"
       reuse = False
+    if self._model_hparams.symbol_modality_skip_top:
+      return tf.expand_dims(body_output, 3)
     with tf.variable_scope(scope_name, reuse=reuse):
       var = self._get_weights()
       if (self._model_hparams.factored_logits and
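These two changes work together: when the new symbol_modality_skip_top hparam is set, SymbolModality.top no longer projects the body output through the softmax weights and instead treats it as vocabulary-sized logits already, only adding the extra axis the downstream code expects. A minimal sketch of that control flow, using a hypothetical toy_top helper with simplified shapes rather than the real modality class:

import tensorflow as tf

def toy_top(body_output, softmax_weights, skip_top):
  """Toy stand-in for the projection-to-logits step (hypothetical helper).

  body_output:     [batch, length, 1, depth] tensor from the model body.
  softmax_weights: [vocab_size, depth] output projection matrix.
  """
  if skip_top:
    # depth is already vocab_size: the body produced logits itself, so only
    # add the extra axis, as the new code path above does.
    return tf.expand_dims(body_output, 3)
  # Default path: project the hidden states onto the vocabulary first.
  logits = tf.tensordot(body_output, softmax_weights, axes=[[3], [1]])
  return tf.expand_dims(logits, 3)

The new transformer_adv model below flips this flag in transformer_adv_small because its body already produces vocabulary-sized logits through its own "res_sm" dense layer.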

tensor2tensor/layers/modalities_test.py

Lines changed: 3 additions & 0 deletions

@@ -40,6 +40,7 @@ def testSymbolModalityInputs(self):
         symbol_modality_num_shards=4,
         hidden_size=hidden_size,
         multiply_embedding_mode="sqrt_depth",
+        symbol_modality_skip_top=0,
         shared_embedding_and_softmax_weights=0)
     x = -1 + np.random.random_integers(
         vocab_size, size=(batch_size, length, 1, 1))
@@ -65,6 +66,7 @@ def testSymbolModalityTargets(self):
         symbol_modality_num_shards=4,
         hidden_size=hidden_size,
         label_smoothing=0.2,
+        symbol_modality_skip_top=0,
         shared_embedding_and_softmax_weights=0,
         factored_logits=0,
         mode=tf.estimator.ModeKeys.TRAIN)
@@ -99,6 +101,7 @@ def testSymbolModalityTargetsFactored(self):
         symbol_modality_num_shards=4,
         hidden_size=hidden_size,
         label_smoothing=0.2,
+        symbol_modality_skip_top=0,
         shared_embedding_and_softmax_weights=0,
         factored_logits=1,
         mode=tf.estimator.ModeKeys.TRAIN)

tensor2tensor/models/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -36,6 +36,7 @@
 from tensor2tensor.models import shake_shake
 from tensor2tensor.models import slicenet
 from tensor2tensor.models import transformer
+from tensor2tensor.models import transformer_adv
 from tensor2tensor.models import transformer_alternative
 from tensor2tensor.models import transformer_moe
 from tensor2tensor.models import transformer_revnet
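The added import matters because registration happens as a side effect of importing the module: the @registry.register_model and @registry.register_hparams decorators in transformer_adv.py run at import time and make the new model and its hparams sets available by name. A toy sketch of that decorator-registry pattern, not t2t's actual registry code:

# Toy decorator registry (hypothetical, for illustration only).
_MODELS = {}

def register_model(cls):
  # The decorator stores the class at class-definition time, so importing the
  # defining module is enough to make the model discoverable by name.
  _MODELS[cls.__name__] = cls
  return cls

@register_model
class TransformerAdv(object):
  pass

print(_MODELS)  # {'TransformerAdv': <class '...TransformerAdv'>}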
tensor2tensor/models/transformer_adv.py (new file)

Lines changed: 229 additions & 0 deletions

@@ -0,0 +1,229 @@
# coding=utf-8
# Copyright 2017 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Adversarial Transformer."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Dependency imports

from tensor2tensor.layers import common_layers
from tensor2tensor.models import transformer
from tensor2tensor.models import transformer_vae
from tensor2tensor.utils import registry
from tensor2tensor.utils import t2t_model

import tensorflow as tf


def encode(x, x_space, hparams, name):
  """Transformer preparations and encoder."""
  with tf.variable_scope(name):
    (encoder_input, encoder_self_attention_bias,
     ed) = transformer.transformer_prepare_encoder(x, x_space, hparams)
    encoder_input = tf.nn.dropout(encoder_input, 1.0 - hparams.dropout)
    return transformer.transformer_encoder(
        encoder_input, encoder_self_attention_bias, hparams), ed


def decode(encoder_output, encoder_decoder_attention_bias, targets,
           hparams, name, reuse=False):
  """Transformer decoder."""
  with tf.variable_scope(name, reuse=reuse):
    targets = common_layers.flatten4d3d(targets)

    decoder_input, decoder_self_bias = transformer.transformer_prepare_decoder(
        targets, hparams)

    decoder_input = tf.nn.dropout(decoder_input,
                                  1.0 - hparams.layer_prepostprocess_dropout)

    decoder_output = transformer.transformer_decoder(
        decoder_input,
        encoder_output,
        decoder_self_bias,
        encoder_decoder_attention_bias,
        hparams)

    # Expand since t2t expects 4d tensors.
    return tf.expand_dims(decoder_output, axis=2)


def reverse_gradient(x, delta=1.0):
  return tf.stop_gradient((1.0 + delta) * x) - delta * x


def adversary(embedded, inputs, hparams, name, reuse=False):
  with tf.variable_scope(name, reuse=reuse):
    h0, i0 = common_layers.pad_to_same_length(
        embedded, inputs, final_length_divisible_by=16)
    h0 = tf.concat([h0, tf.expand_dims(i0, axis=2)], axis=-1)
    h0 = tf.layers.dense(h0, hparams.hidden_size, name="io")
    h1 = transformer_vae.compress(h0, None, False, hparams, "compress1")
    h2 = transformer_vae.compress(h1, None, False, hparams, "compress2")
    res_dense = tf.reduce_mean(h2, axis=[1, 2])
    res_single = tf.squeeze(tf.layers.dense(res_dense, 1), axis=-1)
    return tf.nn.sigmoid(res_single)


def softmax_embed(x, embedding, batch_size, hparams):
  """Softmax x and embed."""
  x = tf.reshape(tf.nn.softmax(x), [-1, 34*1024])
  x = tf.matmul(x, embedding)
  return tf.reshape(x, [batch_size, -1, 1, hparams.hidden_size])


def adv_transformer_internal(inputs, targets, target_space, hparams):
  """Adversarial Transformer, main step used for training."""
  with tf.variable_scope("adv_transformer"):
    batch_size = tf.shape(targets)[0]
    targets = tf.reshape(targets, [batch_size, -1, 1])
    embedding = tf.get_variable("embedding", [34*1024, hparams.hidden_size])
    targets_emb = tf.gather(embedding, targets)

    # Noisy embedded targets.
    targets_noisy = tf.one_hot(targets, 34*1024)
    noise_val = hparams.noise_val
    targets_noisy += tf.random_uniform(tf.shape(targets_noisy),
                                       minval=-noise_val, maxval=noise_val)
    targets_emb_noisy = softmax_embed(
        targets_noisy, embedding, batch_size, hparams)

    # Encoder.
    if inputs is not None:
      inputs_emb = common_layers.flatten4d3d(inputs)
      inputs, ed = encode(inputs_emb, target_space, hparams, "input_enc")
    else:
      ed = None

    # Masking.
    masking = common_layers.inverse_lin_decay(60000)
    masking *= common_layers.inverse_exp_decay(20000)  # Not much at start.
    masking -= tf.random_uniform([]) * 0.4
    mask = tf.less(masking, tf.random_uniform(tf.shape(targets)))
    mask = tf.expand_dims(tf.to_float(mask), 3)
    noise = tf.random_uniform(tf.shape(targets_emb))
    targets_emb = mask * targets_emb + (1.0 - mask) * noise

    # Decoder.
    res_dec = decode(inputs, ed, targets_emb, hparams, "decoder")
    res = tf.layers.dense(res_dec, 34*1024, name="res_sm")
    res_emb = softmax_embed(res, embedding, batch_size, hparams)

    # Extra steps.
    extra_step_prob = masking * 0.6
    if hparams.mode != tf.estimator.ModeKeys.TRAIN:
      extra_step_prob = 1.0
    for _ in xrange(hparams.extra_steps):
      def another_step(emb):
        res_dec = decode(inputs, ed, emb, hparams, "decoder", reuse=True)
        res = tf.layers.dense(res_dec, 34*1024, name="res_sm", reuse=True)
        return softmax_embed(res, embedding, batch_size, hparams), res
      res_emb, res = tf.cond(tf.less(tf.random_uniform([]), extra_step_prob),
                             lambda e=res_emb: another_step(e),
                             lambda: (res_emb, res))

    # Adversary.
    delta = masking * hparams.delta_max
    true_logit = adversary(tf.stop_gradient(targets_emb_noisy),
                           tf.stop_gradient(inputs + inputs_emb),
                           hparams, "adversary")
    gen_logit = adversary(reverse_gradient(res_emb, delta),
                          tf.stop_gradient(inputs + inputs_emb),
                          hparams, "adversary", reuse=True)
    losses = {"adv": gen_logit - true_logit}
    res = tf.stop_gradient(masking * res) + (1.0 - masking) * res
    return res, losses


@registry.register_model
class TransformerAdv(t2t_model.T2TModel):
  """Adversarial Transformer."""

  def model_fn_body(self, features):
    inputs = features.get("inputs", None)
    return adv_transformer_internal(
        inputs, features["targets_raw"],
        features["target_space_id"], self._hparams)

  def infer(self, features=None, decode_length=50, beam_size=1, top_beams=1,
            last_position_only=False, alpha=0.0):
    """Produce predictions from the model."""
    if not features:
      features = {}
    inputs_old = None
    if "inputs" in features and len(features["inputs"].shape) < 4:
      inputs_old = features["inputs"]
      features["inputs"] = tf.expand_dims(features["inputs"], 2)

    # Create an initial targets tensor.
    if "partial_targets" in features:
      initial_output = tf.convert_to_tensor(features["partial_targets"])
    else:
      batch_size = tf.shape(features["inputs"])[0]
      length = tf.shape(features["inputs"])[1]
      initial_output = tf.zeros((batch_size, 2 * length, 1, 1), dtype=tf.int64)

    features["targets"] = initial_output
    sharded_logits, _ = self.model_fn(
        features, False, last_position_only=last_position_only)
    sharded_samples = self._data_parallelism(tf.argmax, sharded_logits, 4)
    samples = tf.concat(sharded_samples, 0)

    # More steps.
    how_many_more_steps = 5
    for _ in xrange(how_many_more_steps):
      with tf.variable_scope(tf.get_variable_scope(), reuse=True):
        features["targets"] = samples
        sharded_logits, _ = self.model_fn(
            features, False, last_position_only=last_position_only)
        sharded_samples = self._data_parallelism(tf.argmax, sharded_logits, 4)
        samples = tf.concat(sharded_samples, 0)

    if inputs_old is not None:  # Restore to not confuse Estimator.
      features["inputs"] = inputs_old
    return samples


@registry.register_hparams
def transformer_adv_small():
  """Set of hyperparameters."""
  hparams = transformer.transformer_small()
  hparams.batch_size = 2048
  hparams.learning_rate_warmup_steps = 4000
  hparams.num_hidden_layers = 3
  hparams.hidden_size = 384
  hparams.filter_size = 2048
  hparams.label_smoothing = 0.0
  hparams.weight_decay = 0.1
  hparams.symbol_modality_skip_top = int(True)
  hparams.add_hparam("num_compress_steps", 2)
  hparams.add_hparam("extra_steps", 0)
  hparams.add_hparam("noise_val", 0.3)
  hparams.add_hparam("delta_max", 2.0)
  return hparams


@registry.register_hparams
def transformer_adv_base():
  """Set of hyperparameters."""
  hparams = transformer_adv_small()
  hparams.batch_size = 1024
  hparams.hidden_size = 512
  hparams.filter_size = 4096
  hparams.num_hidden_layers = 6
  return hparams
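The adversarial coupling in adv_transformer_internal hinges on reverse_gradient: in the forward pass it is the identity, since (1 + delta) * x - delta * x = x, but in the backward pass the tf.stop_gradient term contributes nothing, so the gradient flowing back into the generator is scaled by -delta. Minimizing the "adv" loss (gen_logit - true_logit) therefore trains the adversary to score the noisy real targets above the generated outputs, while the reversed gradient pushes the generator the other way. A small self-contained check of that behavior, written in the same TF 1.x style as the file above but not part of the commit:

import tensorflow as tf

def reverse_gradient(x, delta=1.0):
  # Forward: (1 + delta) * x - delta * x == x, i.e. the identity.
  # Backward: the stop_gradient term has zero gradient, so dy/dx == -delta.
  return tf.stop_gradient((1.0 + delta) * x) - delta * x

x = tf.constant([1.0, 2.0, 3.0])
y = reverse_gradient(x, delta=2.0)
dy_dx = tf.gradients(tf.reduce_sum(y), x)[0]

with tf.Session() as sess:
  print(sess.run(y))      # [1. 2. 3.]    -- values pass through unchanged
  print(sess.run(dy_dx))  # [-2. -2. -2.] -- gradient is reversed and scaled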

0 commit comments
