import tensorflow as tf
from tensorflow.python.util import nest

-# Track Tuple of state and attention values
-AttentionTuple = collections.namedtuple("AttentionTuple", ("state",
-                                                           "attention"))
-
-
-class ExternalAttentionCellWrapper(tf.contrib.rnn.RNNCell):
-  """Wrapper for external attention states for an encoder-decoder setup."""
-
-  def __init__(self,
-               cell,
-               attn_states,
-               attn_vec_size=None,
-               input_size=None,
-               state_is_tuple=True,
-               reuse=None):
- """Create a cell with attention.
50
-
51
- Args:
52
- cell: an RNNCell, an attention is added to it.
53
- attn_states: External attention states typically the encoder output in the
54
- form [batch_size, time steps, hidden size]
55
- attn_vec_size: integer, the number of convolutional features calculated
56
- on attention state and a size of the hidden layer built from
57
- base cell state. Equal attn_size to by default.
58
- input_size: integer, the size of a hidden linear layer,
59
- built from inputs and attention. Derived from the input tensor
60
- by default.
61
- state_is_tuple: If True, accepted and returned states are n-tuples, where
62
- `n = len(cells)`. Must be set to True else will raise an exception
63
- concatenated along the column axis.
64
- reuse: (optional) Python boolean describing whether to reuse variables
65
- in an existing scope. If not `True`, and the existing scope already has
66
- the given variables, an error is raised.
67
- Raises:
68
- TypeError: if cell is not an RNNCell.
69
- ValueError: if the flag `state_is_tuple` is `False` or if shape of
70
- `attn_states` is not 3 or if innermost dimension (hidden size) is None.
71
- """
-    super(ExternalAttentionCellWrapper, self).__init__(_reuse=reuse)
-    if not state_is_tuple:
-      raise ValueError("Only tuple state is supported")
-
-    self._cell = cell
-    self._input_size = input_size
-
-    # Validate attn_states shape.
-    attn_shape = attn_states.get_shape()
-    if not attn_shape or len(attn_shape) != 3:
-      raise ValueError("attn_shape must be rank 3")
-
-    self._attn_states = attn_states
-    self._attn_size = attn_shape[2].value
-    if self._attn_size is None:
-      raise ValueError("Hidden size of attn_states cannot be None")
-
-    self._attn_vec_size = attn_vec_size
-    if self._attn_vec_size is None:
-      self._attn_vec_size = self._attn_size
-
-    self._reuse = reuse
-
-  @property
-  def state_size(self):
-    return AttentionTuple(self._cell.state_size, self._attn_size)
-
-  @property
-  def output_size(self):
-    return self._attn_size
-
-  def combine_state(self, previous_state):
-    """Combines previous state (from encoder) with internal attention values.
-
-    You must use this function to derive the initial state passed into
-    this cell as it expects a named tuple (AttentionTuple).
-
-    Args:
-      previous_state: State from another block that will be fed into this cell;
-        Must have same structure as the state of the cell wrapped by this.
-    Returns:
-      Combined state (AttentionTuple).
-    """
-    batch_size = self._attn_states.get_shape()[0].value
-    if batch_size is None:
-      batch_size = tf.shape(self._attn_states)[0]
-    zeroed_state = self.zero_state(batch_size, self._attn_states.dtype)
-    return AttentionTuple(previous_state, zeroed_state.attention)
-
-  def call(self, inputs, state):
-    """Long short-term memory cell with attention (LSTMA)."""
-
-    if not isinstance(state, AttentionTuple):
-      raise TypeError("State must be of type AttentionTuple")
-
-    state, attns = state
-    attn_states = self._attn_states
-    attn_length = attn_states.get_shape()[1].value
-    if attn_length is None:
-      attn_length = tf.shape(attn_states)[1]
-
-    input_size = self._input_size
-    if input_size is None:
-      input_size = inputs.get_shape().as_list()[1]
-    if attns is not None:
-      inputs = tf.layers.dense(tf.concat([inputs, attns], axis=1), input_size)
-    lstm_output, new_state = self._cell(inputs, state)
-
-    new_state_cat = tf.concat(nest.flatten(new_state), 1)
-    new_attns = self._attention(new_state_cat, attn_states, attn_length)
-
-    with tf.variable_scope("attn_output_projection"):
-      output = tf.layers.dense(
-          tf.concat([lstm_output, new_attns], axis=1), self._attn_size)
-
-    new_state = AttentionTuple(new_state, new_attns)
-
-    return output, new_state
-
-  def _attention(self, query, attn_states, attn_length):
-    conv2d = tf.nn.conv2d
-    reduce_sum = tf.reduce_sum
-    softmax = tf.nn.softmax
-    tanh = tf.tanh
-
-    with tf.variable_scope("attention"):
-      k = tf.get_variable("attn_w",
-                          [1, 1, self._attn_size, self._attn_vec_size])
-      v = tf.get_variable("attn_v", [self._attn_vec_size, 1])
-      hidden = tf.reshape(attn_states, [-1, attn_length, 1, self._attn_size])
-      hidden_features = conv2d(hidden, k, [1, 1, 1, 1], "SAME")
-      y = tf.layers.dense(query, self._attn_vec_size)
-      y = tf.reshape(y, [-1, 1, 1, self._attn_vec_size])
-      s = reduce_sum(v * tanh(hidden_features + y), [2, 3])
-      a = softmax(s)
-      d = reduce_sum(tf.reshape(a, [-1, attn_length, 1, 1]) * hidden, [1, 2])
-      new_attns = tf.reshape(d, [-1, self._attn_size])
-
-    return new_attns
-

def lstm(inputs, hparams, train, name, initial_state=None):
  """Run LSTM cell on inputs, assuming they are [batch x time x size]."""
@@ -189,7 +51,7 @@ def dropout_lstm_cell():


def lstm_attention_decoder(inputs, hparams, train, name, initial_state,
-                           attn_states):
+                           encoder_outputs):
  """Run LSTM cell with attention on inputs of shape [batch x time x size]."""

  def dropout_lstm_cell():
@@ -198,18 +60,36 @@ def dropout_lstm_cell():
        input_keep_prob=1.0 - hparams.dropout * tf.to_float(train))

  layers = [dropout_lstm_cell() for _ in range(hparams.num_hidden_layers)]
-  cell = ExternalAttentionCellWrapper(
+  AttentionMechanism = (tf.contrib.seq2seq.LuongAttention if hparams.attention_mechanism == "luong"
+                        else tf.contrib.seq2seq.BahdanauAttention)
+  attention_mechanism = AttentionMechanism(hparams.hidden_size, encoder_outputs)
+
+  cell = tf.contrib.seq2seq.AttentionWrapper(
      tf.nn.rnn_cell.MultiRNNCell(layers),
-      attn_states,
-      attn_vec_size=hparams.attn_vec_size)
-  initial_state = cell.combine_state(initial_state)
+      [attention_mechanism]*hparams.num_heads,
+      attention_layer_size=[hparams.attention_layer_size]*hparams.num_heads,
+      output_attention=(hparams.output_attention == 1))
+
+
+  batch_size = inputs.get_shape()[0].value
+  if batch_size is None:
+    batch_size = tf.shape(inputs)[0]
+
+  initial_state = cell.zero_state(batch_size, tf.float32).clone(cell_state=initial_state)
+
  with tf.variable_scope(name):
-    return tf.nn.dynamic_rnn(
+    output, state = tf.nn.dynamic_rnn(
        cell,
        inputs,
        initial_state=initial_state,
        dtype=tf.float32,
        time_major=False)
+
+    # For multi-head attention project output back to hidden size
+    if hparams.output_attention == 1 and hparams.num_heads > 1:
+      output = tf.layers.dense(output, hparams.hidden_size)
+
+    return output, state



def lstm_seq2seq_internal(inputs, targets, hparams, train):
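The rewritten `lstm_attention_decoder` above delegates attention to `tf.contrib.seq2seq`: an attention mechanism built over `encoder_outputs` wraps the `MultiRNNCell`, and `tf.nn.dynamic_rnn` runs the wrapped cell. A minimal single-head sketch of that path follows (TF 1.x); the toy shapes, the two-layer cell, and the plain zero initial state are assumptions for illustration, whereas the function above seeds the decoder with the encoder's final state via `zero_state(...).clone(cell_state=...)`.

import tensorflow as tf

# Assumed toy dimensions; not taken from the commit.
batch_size, src_len, tgt_len, hidden_size = 4, 7, 5, 128

encoder_outputs = tf.zeros([batch_size, src_len, hidden_size])  # stand-in encoder states
decoder_inputs = tf.zeros([batch_size, tgt_len, hidden_size])   # stand-in decoder inputs

attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(hidden_size,
                                                           encoder_outputs)
cell = tf.contrib.seq2seq.AttentionWrapper(
    tf.nn.rnn_cell.MultiRNNCell(
        [tf.nn.rnn_cell.BasicLSTMCell(hidden_size) for _ in range(2)]),
    attention_mechanism,
    attention_layer_size=hidden_size,
    output_attention=True)

# The decoder above injects the encoder's final state via
# cell.zero_state(...).clone(cell_state=...); a plain zero state keeps this
# sketch self-contained.
initial_state = cell.zero_state(batch_size, tf.float32)

with tf.variable_scope("toy_attention_decoder"):
  output, state = tf.nn.dynamic_rnn(
      cell, decoder_inputs, initial_state=initial_state, dtype=tf.float32)
  # output: [batch_size, tgt_len, hidden_size] attention-projected decoder outputs.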
@@ -273,14 +153,49 @@ def lstm_seq2seq():
  hparams.hidden_size = 128
  hparams.num_hidden_layers = 2
  hparams.initializer = "uniform_unit_scaling"
+  hparams.initializer_gain = 1.0
+  hparams.weight_decay = 0.0
+
+  return hparams
+
+def lstm_attention_base():
+  """Base attention params."""
+  hparams = lstm_seq2seq()
+  hparams.add_hparam("attention_layer_size", hparams.hidden_size)
+  hparams.add_hparam("output_attention", int(True))
+  hparams.add_hparam("num_heads", 1)
  return hparams


+@registry.register_hparams
+def lstm_bahdanau_attention():
+  """hparams for LSTM with bahdanau attention."""
+  hparams = lstm_attention_base()
+  hparams.add_hparam("attention_mechanism", "bahdanau")
+  return hparams
+
+@registry.register_hparams
+def lstm_luong_attention():
+  """hparams for LSTM with luong attention."""
+  hparams = lstm_attention_base()
+  hparams.add_hparam("attention_mechanism", "luong")
+  return hparams
+
@registry.register_hparams
def lstm_attention():
-  """hparams for LSTM with attention."""
-  hparams = lstm_seq2seq()
+  """For backwards compatibility, defaults to bahdanau."""
+  return lstm_bahdanau_attention()

-  # Attention
-  hparams.add_hparam("attn_vec_size", hparams.hidden_size)
+@registry.register_hparams
+def lstm_bahdanau_attention_multi():
+  """Multi-head Bahdanau attention."""
+  hparams = lstm_bahdanau_attention()
+  hparams.num_heads = 4
  return hparams
+
+@registry.register_hparams
+def lstm_luong_attention_multi():
+  """Multi-head Luong attention."""
+  hparams = lstm_luong_attention()
+  hparams.num_heads = 4
+  return hparams
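The hparam sets above drive the decoder directly: `attention_mechanism` selects Luong vs. Bahdanau, `num_heads` controls how many copies of the mechanism the `AttentionWrapper` receives, and `attention_layer_size`/`output_attention` shape the attention projection. A hedged sketch of how `lstm_luong_attention_multi`-style settings expand into the wrapper call, with `types.SimpleNamespace` standing in for the registered HParams object (an assumption, not tensor2tensor's API):

from types import SimpleNamespace

import tensorflow as tf

# Stand-in for the registered hparams; field names mirror the ones added in
# this commit, values echo lstm_luong_attention_multi (hidden_size from
# lstm_seq2seq, "luong" mechanism, 4 heads).
hparams = SimpleNamespace(
    hidden_size=128,
    attention_mechanism="luong",
    attention_layer_size=128,
    output_attention=1,
    num_heads=4)

encoder_outputs = tf.zeros([2, 6, hparams.hidden_size])  # dummy encoder states

AttentionMechanismClass = (tf.contrib.seq2seq.LuongAttention
                           if hparams.attention_mechanism == "luong"
                           else tf.contrib.seq2seq.BahdanauAttention)
attention_mechanism = AttentionMechanismClass(hparams.hidden_size,
                                              encoder_outputs)

# num_heads > 1 turns into a list of mechanisms and per-head layer sizes.
cell = tf.contrib.seq2seq.AttentionWrapper(
    tf.nn.rnn_cell.MultiRNNCell(
        [tf.nn.rnn_cell.BasicLSTMCell(hparams.hidden_size) for _ in range(2)]),
    [attention_mechanism] * hparams.num_heads,
    attention_layer_size=[hparams.attention_layer_size] * hparams.num_heads,
    output_attention=(hparams.output_attention == 1))

With `output_attention` enabled and more than one head, the decoder additionally projects the concatenated per-head output back to `hidden_size`, as shown in the decoder hunk above.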