# coding=utf-8
# Copyright 2017 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Adversarial Transformer."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Dependency imports

from six.moves import xrange  # pylint: disable=redefined-builtin

from tensor2tensor.layers import common_layers
from tensor2tensor.models import transformer
from tensor2tensor.models import transformer_vae
from tensor2tensor.utils import registry
from tensor2tensor.utils import t2t_model

import tensorflow as tf


def encode(x, x_space, hparams, name):
  """Transformer preparations and encoder."""
  with tf.variable_scope(name):
    (encoder_input, encoder_self_attention_bias,
     ed) = transformer.transformer_prepare_encoder(x, x_space, hparams)
    encoder_input = tf.nn.dropout(encoder_input, 1.0 - hparams.dropout)
    return transformer.transformer_encoder(
        encoder_input, encoder_self_attention_bias, hparams), ed


def decode(encoder_output, encoder_decoder_attention_bias, targets,
           hparams, name, reuse=False):
  """Transformer decoder."""
  with tf.variable_scope(name, reuse=reuse):
    targets = common_layers.flatten4d3d(targets)

    decoder_input, decoder_self_bias = transformer.transformer_prepare_decoder(
        targets, hparams)

    decoder_input = tf.nn.dropout(decoder_input,
                                  1.0 - hparams.layer_prepostprocess_dropout)

    decoder_output = transformer.transformer_decoder(
        decoder_input,
        encoder_output,
        decoder_self_bias,
        encoder_decoder_attention_bias,
        hparams)

    # Expand since t2t expects 4d tensors.
    return tf.expand_dims(decoder_output, axis=2)


def reverse_gradient(x, delta=1.0):
  """Identity on the forward pass; scales the gradient by -delta backward."""
  return tf.stop_gradient((1.0 + delta) * x) - delta * x
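

# Note on reverse_gradient: on the forward pass the two terms combine to
# (1 + delta) * x - delta * x = x, so the op is the identity; on the backward
# pass tf.stop_gradient contributes nothing, leaving d/dx = -delta. This is a
# gradient reversal layer with a tunable scale. A minimal check (illustrative
# only, not part of the model):
#
#   x = tf.constant([1.0, 2.0])
#   y = reverse_gradient(x, delta=2.0)
#   g = tf.gradients(y, x)[0]
#   # y evaluates to [1.0, 2.0] and g to [-2.0, -2.0].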


def adversary(embedded, inputs, hparams, name, reuse=False):
  """Discriminator scoring target embeddings, conditioned on the inputs."""
  with tf.variable_scope(name, reuse=reuse):
    # Pad both tensors to a common length, then attach the source
    # representation (stop-gradient at the call sites) as extra channels.
    h0, i0 = common_layers.pad_to_same_length(
        embedded, inputs, final_length_divisible_by=16)
    h0 = tf.concat([h0, tf.expand_dims(i0, axis=2)], axis=-1)
    h0 = tf.layers.dense(h0, hparams.hidden_size, name="io")
    # Shorten the sequence twice, then mean-pool and project down to a
    # single per-example score in (0, 1).
    h1 = transformer_vae.compress(h0, None, False, hparams, "compress1")
    h2 = transformer_vae.compress(h1, None, False, hparams, "compress2")
    res_dense = tf.reduce_mean(h2, axis=[1, 2])
    res_single = tf.squeeze(tf.layers.dense(res_dense, 1), axis=-1)
    return tf.nn.sigmoid(res_single)
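

# Shape sketch for adversary (an illustration, assuming batch size b, target
# length lt, source length ls, and hidden size h): `embedded` arrives as
# [b, lt, 1, h] and `inputs` as [b, ls, h]; after padding both to a common
# length l and concatenating, the discriminator sees [b, l, 1, 2*h], which
# the two compress calls shorten before mean-pooling to one score per
# example.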


def softmax_embed(x, embedding, batch_size, hparams):
  """Softmax x over the vocabulary and embed the resulting distribution."""
  x = tf.reshape(tf.nn.softmax(x), [-1, 34*1024])
  x = tf.matmul(x, embedding)
  return tf.reshape(x, [batch_size, -1, 1, hparams.hidden_size])
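

# softmax_embed turns vocabulary logits into "soft" embeddings: the softmax
# gives a distribution over the 34*1024 symbols, and the matmul with the
# embedding table takes the probability-weighted average of the embedding
# rows. Unlike an argmax followed by a lookup, this keeps the path from
# logits to embeddings differentiable. In terms of shapes, logits of shape
# [batch, length, 1, 34*1024] come out as [batch, length, 1, hidden_size].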


def adv_transformer_internal(inputs, targets, target_space, hparams):
  """Adversarial Transformer, main step used for training."""
  with tf.variable_scope("adv_transformer"):
    batch_size = tf.shape(targets)[0]
    targets = tf.reshape(targets, [batch_size, -1, 1])
    # 34*1024 is the fixed vocabulary size assumed throughout this file.
    embedding = tf.get_variable("embedding", [34*1024, hparams.hidden_size])
    targets_emb = tf.gather(embedding, targets)

    # Noisy embedded targets: perturb the one-hot targets with uniform noise
    # before soft-embedding them, so the adversary never sees exact one-hots.
    targets_noisy = tf.one_hot(targets, 34*1024)
    noise_val = hparams.noise_val
    targets_noisy += tf.random_uniform(tf.shape(targets_noisy),
                                       minval=-noise_val, maxval=noise_val)
    targets_emb_noisy = softmax_embed(
        targets_noisy, embedding, batch_size, hparams)

    # Encoder.
    if inputs is not None:
      inputs_emb = common_layers.flatten4d3d(inputs)
      inputs, ed = encode(inputs_emb, target_space, hparams, "input_enc")
    else:
      ed = None

    # Masking: replace a fraction (roughly equal to `masking`) of the target
    # embeddings with uniform noise. The fraction ramps up over the first
    # ~60k steps, stays near zero early on, and is jittered per batch.
    masking = common_layers.inverse_lin_decay(60000)
    masking *= common_layers.inverse_exp_decay(20000)  # Not much at start.
    masking -= tf.random_uniform([]) * 0.4
    mask = tf.less(masking, tf.random_uniform(tf.shape(targets)))
    mask = tf.expand_dims(tf.to_float(mask), 3)
    noise = tf.random_uniform(tf.shape(targets_emb))
    targets_emb = mask * targets_emb + (1.0 - mask) * noise

    # Decoder, followed by a projection to vocabulary-sized logits.
    res_dec = decode(inputs, ed, targets_emb, hparams, "decoder")
    res = tf.layers.dense(res_dec, 34*1024, name="res_sm")
    res_emb = softmax_embed(res, embedding, batch_size, hparams)

    # Extra steps: with some probability (always outside of training), feed
    # the model's own soft predictions back in for another decoding pass.
    extra_step_prob = masking * 0.6
    if hparams.mode != tf.estimator.ModeKeys.TRAIN:
      extra_step_prob = 1.0
    for _ in xrange(hparams.extra_steps):
      def another_step(emb):
        res_dec = decode(inputs, ed, emb, hparams, "decoder", reuse=True)
        res = tf.layers.dense(res_dec, 34*1024, name="res_sm", reuse=True)
        return softmax_embed(res, embedding, batch_size, hparams), res
      res_emb, res = tf.cond(tf.less(tf.random_uniform([]), extra_step_prob),
                             lambda e=res_emb: another_step(e),
                             lambda: (res_emb, res))

    # Adversary. Note that this block assumes inputs is not None: it
    # conditions on the encoded source (inputs + inputs_emb) from above.
    delta = masking * hparams.delta_max
    true_logit = adversary(tf.stop_gradient(targets_emb_noisy),
                           tf.stop_gradient(inputs + inputs_emb),
                           hparams, "adversary")
    gen_logit = adversary(reverse_gradient(res_emb, delta),
                          tf.stop_gradient(inputs + inputs_emb),
                          hparams, "adversary", reuse=True)
    losses = {"adv": gen_logit - true_logit}
    # Forward value unchanged; gradient on the logits scaled by (1 - masking).
    res = tf.stop_gradient(masking * res) + (1.0 - masking) * res
    return res, losses
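

# How the adversarial loss is wired: the adversary sees noisy embeddings of
# the true targets through tf.stop_gradient (so it cannot reshape them) and
# the generated embeddings through reverse_gradient. Minimizing
# losses["adv"] = gen_logit - true_logit therefore trains the adversary to
# score real targets above generated ones, while the reversed gradient pushes
# the decoder the opposite way, toward outputs the adversary cannot tell
# apart. The usual two-player GAN objective is thus expressed as a single
# loss term rather than two separate optimizers.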


@registry.register_model
class TransformerAdv(t2t_model.T2TModel):
  """Adversarial Transformer."""

  def model_fn_body(self, features):
    inputs = features.get("inputs", None)
    return adv_transformer_internal(
        inputs, features["targets_raw"],
        features["target_space_id"], self._hparams)

  def infer(self, features=None, decode_length=50, beam_size=1, top_beams=1,
            last_position_only=False, alpha=0.0):
    """Produce predictions from the model."""
    if not features:
      features = {}
    inputs_old = None
    if "inputs" in features and len(features["inputs"].shape) < 4:
      inputs_old = features["inputs"]
      features["inputs"] = tf.expand_dims(features["inputs"], 2)

    # Create an initial targets tensor.
    if "partial_targets" in features:
      initial_output = tf.convert_to_tensor(features["partial_targets"])
    else:
      batch_size = tf.shape(features["inputs"])[0]
      length = tf.shape(features["inputs"])[1]
      initial_output = tf.zeros((batch_size, 2 * length, 1, 1), dtype=tf.int64)

    features["targets"] = initial_output
    sharded_logits, _ = self.model_fn(
        features, False, last_position_only=last_position_only)
    # Argmax over the vocabulary dimension of the 5-D logits.
    sharded_samples = self._data_parallelism(tf.argmax, sharded_logits, 4)
    samples = tf.concat(sharded_samples, 0)

    # Iterative refinement: re-run the model a few more times, feeding its
    # own samples back in as targets each time.
    how_many_more_steps = 5
    for _ in xrange(how_many_more_steps):
      with tf.variable_scope(tf.get_variable_scope(), reuse=True):
        features["targets"] = samples
        sharded_logits, _ = self.model_fn(
            features, False, last_position_only=last_position_only)
        sharded_samples = self._data_parallelism(tf.argmax, sharded_logits, 4)
        samples = tf.concat(sharded_samples, 0)

    if inputs_old is not None:  # Restore to not confuse Estimator.
      features["inputs"] = inputs_old
    return samples


@registry.register_hparams
def transformer_adv_small():
  """Set of hyperparameters."""
  hparams = transformer.transformer_small()
  hparams.batch_size = 2048
  hparams.learning_rate_warmup_steps = 4000
  hparams.num_hidden_layers = 3
  hparams.hidden_size = 384
  hparams.filter_size = 2048
  hparams.label_smoothing = 0.0
  hparams.weight_decay = 0.1
  # The body already produces vocabulary-sized logits (res_sm above), so the
  # symbol modality's final (top) transform is skipped.
  hparams.symbol_modality_skip_top = int(True)
  hparams.add_hparam("num_compress_steps", 2)
  hparams.add_hparam("extra_steps", 0)
  hparams.add_hparam("noise_val", 0.3)
  hparams.add_hparam("delta_max", 2.0)
  return hparams


@registry.register_hparams
def transformer_adv_base():
  """Set of hyperparameters."""
  hparams = transformer_adv_small()
  hparams.batch_size = 1024
  hparams.hidden_size = 512
  hparams.filter_size = 4096
  hparams.num_hidden_layers = 6
  return hparams
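

# Example invocation (a sketch: the problem name and directories below are
# placeholders, and the flags follow the t2t-trainer CLI of this era):
#
#   t2t-trainer \
#     --model=transformer_adv \
#     --hparams_set=transformer_adv_small \
#     --problems=translate_ende_wmt32k \
#     --data_dir=$DATA_DIR --output_dir=$TRAIN_DIR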