# Dependency imports

from tensor2tensor.models import common_layers
+ from tensor2tensor.models import common_hparams
from tensor2tensor.utils import registry
from tensor2tensor.utils import t2t_model

import tensorflow as tf
+ from tensorflow.python.ops import rnn_cell_impl
+ from tensorflow.python.util import nest

+ import collections
+
+ # Tuple tracking the wrapped cell state and the attention values.
+ AttentionTuple = collections.namedtuple("AttentionTuple", ("state", "attention"))
+
+
+ class ExternalAttentionCellWrapper(rnn_cell_impl.RNNCell):
+   """Wrapper for external attention states, for use in an encoder-decoder setup."""
+
+   def __init__(self, cell, attn_states, attn_vec_size=None,
+                input_size=None, state_is_tuple=True, reuse=None):
+     """Create a cell with attention.
+
+     Args:
+       cell: an RNNCell, to which attention is added.
+       attn_states: External attention states, typically the encoder output,
+         of shape [batch_size, time_steps, hidden_size].
+       attn_vec_size: integer, the number of convolutional features calculated
+         on the attention state and the size of the hidden layer built from
+         the base cell state. Equal to attn_size by default.
+       input_size: integer, the size of a hidden linear layer built from the
+         inputs and attention. Derived from the input tensor by default.
+       state_is_tuple: If True, accepted and returned states are 2-tuples of
+         the wrapped cell state and the attention values. Must be set to True;
+         otherwise a ValueError is raised.
+       reuse: (optional) Python boolean describing whether to reuse variables
+         in an existing scope. If not `True`, and the existing scope already
+         has the given variables, an error is raised.
+
+     Raises:
+       TypeError: if cell is not an RNNCell.
+       ValueError: if `state_is_tuple` is `False`, if attn_states is not
+         rank 3, or if its innermost dimension (hidden size) is None.
+     """
+     super(ExternalAttentionCellWrapper, self).__init__(_reuse=reuse)
+     if not rnn_cell_impl._like_rnncell(cell):  # pylint: disable=protected-access
+       raise TypeError("The parameter cell is not RNNCell.")
+
+     if not state_is_tuple:
+       raise ValueError("Only tuple state is supported")
+
+     self._cell = cell
+     self._input_size = input_size
+
+     # Validate attn_states shape.
+     attn_shape = attn_states.get_shape()
+     if not attn_shape or len(attn_shape) != 3:
+       raise ValueError("attn_states must be rank 3")
+
+     self._attn_states = attn_states
+     self._attn_size = attn_shape[2].value
+     if self._attn_size is None:
+       raise ValueError("Hidden size of attn_states cannot be None")
+
+     self._attn_vec_size = attn_vec_size
+     if self._attn_vec_size is None:
+       self._attn_vec_size = self._attn_size
+
+     self._reuse = reuse
+
+   @property
+   def state_size(self):
+     return AttentionTuple(self._cell.state_size, self._attn_size)
+
+   @property
+   def output_size(self):
+     return self._attn_size
+
+   def combine_state(self, previous_state):
+     """Combines previous state (e.g. from an encoder) with attention values.
+
+     You must use this function to derive the initial state passed into this
+     cell, as it expects a named tuple (AttentionTuple).
+
+     Args:
+       previous_state: State from another block that will be fed into this
+         cell. Must have the same structure as the state of the cell wrapped
+         by this one.
+
+     Returns:
+       Combined state (AttentionTuple).
+     """
+     batch_size = self._attn_states.get_shape()[0].value
+     if batch_size is None:
+       batch_size = tf.shape(self._attn_states)[0]
+     zeroed_state = self.zero_state(batch_size, self._attn_states.dtype)
+     return AttentionTuple(previous_state, zeroed_state.attention)
+
+   def call(self, inputs, state):
+     """Long short-term memory cell with attention (LSTMA)."""
+     if not isinstance(state, AttentionTuple):
+       raise TypeError("State must be of type AttentionTuple")
+
+     state, attns = state
+     attn_states = self._attn_states
+     attn_length = attn_states.get_shape()[1].value
+     if attn_length is None:
+       attn_length = tf.shape(attn_states)[1]
+
+     input_size = self._input_size
+     if input_size is None:
+       input_size = inputs.get_shape().as_list()[1]
+     if attns is not None:
+       inputs = rnn_cell_impl._linear([inputs, attns], input_size, True)
+     lstm_output, new_state = self._cell(inputs, state)
+
+     new_state_cat = tf.concat(nest.flatten(new_state), 1)
+     new_attns = self._attention(new_state_cat, attn_states, attn_length)
+
+     with tf.variable_scope("attn_output_projection"):
+       output = rnn_cell_impl._linear([lstm_output, new_attns],
+                                      self._attn_size, True)
+
+     new_state = AttentionTuple(new_state, new_attns)
+
+     return output, new_state
+
+   def _attention(self, query, attn_states, attn_length):
+     conv2d = tf.nn.conv2d
+     reduce_sum = tf.reduce_sum
+     softmax = tf.nn.softmax
+     tanh = tf.tanh
+
+     with tf.variable_scope("attention"):
+       k = tf.get_variable(
+           "attn_w", [1, 1, self._attn_size, self._attn_vec_size])
+       v = tf.get_variable("attn_v", [self._attn_vec_size, 1])
+       hidden = tf.reshape(attn_states,
+                           [-1, attn_length, 1, self._attn_size])
+       hidden_features = conv2d(hidden, k, [1, 1, 1, 1], "SAME")
+       y = rnn_cell_impl._linear(query, self._attn_vec_size, True)
+       y = tf.reshape(y, [-1, 1, 1, self._attn_vec_size])
+       s = reduce_sum(v * tanh(hidden_features + y), [2, 3])
+       a = softmax(s)
+       d = reduce_sum(
+           tf.reshape(a, [-1, attn_length, 1, 1]) * hidden, [1, 2])
+       new_attns = tf.reshape(d, [-1, self._attn_size])
+
+       return new_attns
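
For reference, a minimal usage sketch of the wrapper on its own (TF 1.x style; the tensors, sizes, and variable names below are illustrative assumptions, not part of this commit):

# Hypothetical sketch: wrap a decoder cell with encoder outputs used as the
# external attention states. Shapes are made up for illustration.
encoder_outputs = tf.zeros([8, 20, 128])   # [batch, encoder_time, hidden]
decoder_inputs = tf.zeros([8, 15, 128])    # [batch, decoder_time, hidden]
base_cell = tf.nn.rnn_cell.BasicLSTMCell(128)
attn_cell = ExternalAttentionCellWrapper(base_cell, encoder_outputs)
# The initial state must be built with combine_state() so it carries the
# zero-initialized attention slot the wrapper expects (an AttentionTuple).
initial_state = attn_cell.combine_state(base_cell.zero_state(8, tf.float32))
outputs, final_state = tf.nn.dynamic_rnn(
    attn_cell, decoder_inputs, initial_state=initial_state, dtype=tf.float32)
# outputs has shape [8, 15, 128]; output_size equals the attention hidden size.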

def lstm(inputs, hparams, train, name, initial_state=None):
  """Run LSTM cell on inputs, assuming they are [batch x time x size]."""
@@ -44,6 +185,25 @@ def dropout_lstm_cell():
        dtype=tf.float32,
        time_major=False)

+ def lstm_attention_decoder(inputs, hparams, train, name, initial_state, attn_states):
+   """Run LSTM cell with attention on inputs, assuming they are [batch x time x size]."""
+
+   def dropout_lstm_cell():
+     return tf.contrib.rnn.DropoutWrapper(
+         tf.nn.rnn_cell.BasicLSTMCell(hparams.hidden_size),
+         input_keep_prob=1.0 - hparams.dropout * tf.to_float(train))
+
+   layers = [dropout_lstm_cell() for _ in range(hparams.num_hidden_layers)]
+   cell = ExternalAttentionCellWrapper(tf.nn.rnn_cell.MultiRNNCell(layers), attn_states,
+                                       attn_vec_size=hparams.attn_vec_size)
+   initial_state = cell.combine_state(initial_state)
+   with tf.variable_scope(name):
+     return tf.nn.dynamic_rnn(
+         cell,
+         inputs,
+         initial_state=initial_state,
+         dtype=tf.float32,
+         time_major=False)

def lstm_seq2seq_internal(inputs, targets, hparams, train):
  """The basic LSTM seq2seq model, main step used for training."""
@@ -63,6 +223,23 @@ def lstm_seq2seq_internal(inputs, targets, hparams, train):
        initial_state=final_encoder_state)
    return tf.expand_dims(decoder_outputs, axis=2)

+ def lstm_seq2seq_internal_attention(inputs, targets, hparams, train):
+   """LSTM seq2seq model with attention, main step used for training."""
+   with tf.variable_scope("lstm_seq2seq_attention"):
+     # Flatten inputs.
+     inputs = common_layers.flatten4d3d(inputs)
+     # LSTM encoder.
+     encoder_outputs, final_encoder_state = lstm(
+         tf.reverse(inputs, axis=[1]), hparams, train, "encoder")
+     # LSTM decoder with attention.
+     shifted_targets = common_layers.shift_left(targets)
+     decoder_outputs, _ = lstm_attention_decoder(
+         common_layers.flatten4d3d(shifted_targets),
+         hparams,
+         train,
+         "decoder",
+         final_encoder_state, encoder_outputs)
+     return tf.expand_dims(decoder_outputs, axis=2)

@registry.register_model("baseline_lstm_seq2seq")
class LSTMSeq2Seq(t2t_model.T2TModel):
@@ -71,3 +248,23 @@ def model_fn_body(self, features):
    train = self._hparams.mode == tf.contrib.learn.ModeKeys.TRAIN
    return lstm_seq2seq_internal(features["inputs"], features["targets"],
                                 self._hparams, train)
+
+ @registry.register_model("baseline_lstm_seq2seq_attention")
+ class LSTMSeq2SeqAttention(t2t_model.T2TModel):
+
+   def model_fn_body(self, features):
+     train = self._hparams.mode == tf.contrib.learn.ModeKeys.TRAIN
+     return lstm_seq2seq_internal_attention(features["inputs"], features["targets"],
+                                            self._hparams, train)
+
+ @registry.register_hparams
+ def lstm_attention():
+   """hparams for LSTM with attention."""
+   hparams = common_hparams.basic_params1()
+   hparams.batch_size = 128
+   hparams.hidden_size = 128
+   hparams.num_hidden_layers = 2
+
+   # Attention
+   hparams.add_hparam("attn_vec_size", hparams.hidden_size)
+   return hparams
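
Because attn_vec_size is registered with add_hparam, variants can be layered on top of lstm_attention() in the usual registry style. A hypothetical larger configuration (the name and sizes are illustrative, not part of this commit) could look like:

# Hypothetical variant hparams set built on lstm_attention().
@registry.register_hparams
def lstm_attention_big():
  hparams = lstm_attention()
  hparams.hidden_size = 256
  hparams.attn_vec_size = 256  # attention vector size can differ from hidden_size
  return hparams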