This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 9e7d03f

Merge pull request #377 from kolloldas/master
Update LSTM Attention Model to use tf.contrib.seq2seq.AttentionWrapper
2 parents 172a1b1 + f67483e commit 9e7d03f

File tree

1 file changed: +63, −148 lines


tensor2tensor/models/lstm.py

Lines changed: 63 additions & 148 deletions
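In short, the diff below deletes the hand-rolled ExternalAttentionCellWrapper and builds the decoder cell from tf.contrib.seq2seq instead, picking the mechanism from a new attention_mechanism hparam. A minimal sketch of that dispatch under TF 1.x (the sizes and the placeholder encoder output here are illustrative stand-ins, not values taken from the diff):

import tensorflow as tf

# Illustrative stand-ins; the real values come from hparams and the encoder.
attention_mechanism_name = "luong"   # or "bahdanau"
hidden_size = 128
encoder_outputs = tf.zeros([8, 10, hidden_size])  # [batch, time, hidden]

# Pick the attention class the same way the new decoder code does.
AttentionMechanism = (tf.contrib.seq2seq.LuongAttention
                      if attention_mechanism_name == "luong"
                      else tf.contrib.seq2seq.BahdanauAttention)
attention_mechanism = AttentionMechanism(hidden_size, encoder_outputs)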
@@ -31,144 +31,6 @@
 import tensorflow as tf
 from tensorflow.python.util import nest
 
-# Track Tuple of state and attention values
-AttentionTuple = collections.namedtuple("AttentionTuple", ("state",
-                                                           "attention"))
-
-
-class ExternalAttentionCellWrapper(tf.contrib.rnn.RNNCell):
-  """Wrapper for external attention states for an encoder-decoder setup."""
-
-  def __init__(self,
-               cell,
-               attn_states,
-               attn_vec_size=None,
-               input_size=None,
-               state_is_tuple=True,
-               reuse=None):
-    """Create a cell with attention.
-
-    Args:
-      cell: an RNNCell to which attention is added.
-      attn_states: External attention states, typically the encoder output,
-        in the form [batch_size, time steps, hidden size].
-      attn_vec_size: integer, the number of convolutional features calculated
-        on the attention state and the size of the hidden layer built from the
-        base cell state. Equals attn_size by default.
-      input_size: integer, the size of a hidden linear layer,
-        built from inputs and attention. Derived from the input tensor
-        by default.
-      state_is_tuple: If True, accepted and returned states are n-tuples, where
-        `n = len(cells)`. Must be set to True; otherwise an exception is
-        raised.
-      reuse: (optional) Python boolean describing whether to reuse variables
-        in an existing scope. If not `True`, and the existing scope already has
-        the given variables, an error is raised.
-    Raises:
-      TypeError: if cell is not an RNNCell.
-      ValueError: if the flag `state_is_tuple` is `False`, if the shape of
-        `attn_states` is not rank 3, or if the innermost dimension (hidden
-        size) is None.
-    """
-    super(ExternalAttentionCellWrapper, self).__init__(_reuse=reuse)
-    if not state_is_tuple:
-      raise ValueError("Only tuple state is supported")
-
-    self._cell = cell
-    self._input_size = input_size
-
-    # Validate attn_states shape.
-    attn_shape = attn_states.get_shape()
-    if not attn_shape or len(attn_shape) != 3:
-      raise ValueError("attn_shape must be rank 3")
-
-    self._attn_states = attn_states
-    self._attn_size = attn_shape[2].value
-    if self._attn_size is None:
-      raise ValueError("Hidden size of attn_states cannot be None")
-
-    self._attn_vec_size = attn_vec_size
-    if self._attn_vec_size is None:
-      self._attn_vec_size = self._attn_size
-
-    self._reuse = reuse
-
-  @property
-  def state_size(self):
-    return AttentionTuple(self._cell.state_size, self._attn_size)
-
-  @property
-  def output_size(self):
-    return self._attn_size
-
-  def combine_state(self, previous_state):
-    """Combines previous state (from encoder) with internal attention values.
-
-    You must use this function to derive the initial state passed into
-    this cell, as it expects a named tuple (AttentionTuple).
-
-    Args:
-      previous_state: State from another block that will be fed into this cell;
-        must have the same structure as the state of the cell wrapped by this.
-    Returns:
-      Combined state (AttentionTuple).
-    """
-    batch_size = self._attn_states.get_shape()[0].value
-    if batch_size is None:
-      batch_size = tf.shape(self._attn_states)[0]
-    zeroed_state = self.zero_state(batch_size, self._attn_states.dtype)
-    return AttentionTuple(previous_state, zeroed_state.attention)
-
-  def call(self, inputs, state):
-    """Long short-term memory cell with attention (LSTMA)."""
-
-    if not isinstance(state, AttentionTuple):
-      raise TypeError("State must be of type AttentionTuple")
-
-    state, attns = state
-    attn_states = self._attn_states
-    attn_length = attn_states.get_shape()[1].value
-    if attn_length is None:
-      attn_length = tf.shape(attn_states)[1]
-
-    input_size = self._input_size
-    if input_size is None:
-      input_size = inputs.get_shape().as_list()[1]
-    if attns is not None:
-      inputs = tf.layers.dense(tf.concat([inputs, attns], axis=1), input_size)
-    lstm_output, new_state = self._cell(inputs, state)
-
-    new_state_cat = tf.concat(nest.flatten(new_state), 1)
-    new_attns = self._attention(new_state_cat, attn_states, attn_length)
-
-    with tf.variable_scope("attn_output_projection"):
-      output = tf.layers.dense(
-          tf.concat([lstm_output, new_attns], axis=1), self._attn_size)
-
-    new_state = AttentionTuple(new_state, new_attns)
-
-    return output, new_state
-
-  def _attention(self, query, attn_states, attn_length):
-    conv2d = tf.nn.conv2d
-    reduce_sum = tf.reduce_sum
-    softmax = tf.nn.softmax
-    tanh = tf.tanh
-
-    with tf.variable_scope("attention"):
-      k = tf.get_variable("attn_w",
-                          [1, 1, self._attn_size, self._attn_vec_size])
-      v = tf.get_variable("attn_v", [self._attn_vec_size, 1])
-      hidden = tf.reshape(attn_states, [-1, attn_length, 1, self._attn_size])
-      hidden_features = conv2d(hidden, k, [1, 1, 1, 1], "SAME")
-      y = tf.layers.dense(query, self._attn_vec_size)
-      y = tf.reshape(y, [-1, 1, 1, self._attn_vec_size])
-      s = reduce_sum(v * tanh(hidden_features + y), [2, 3])
-      a = softmax(s)
-      d = reduce_sum(tf.reshape(a, [-1, attn_length, 1, 1]) * hidden, [1, 2])
-      new_attns = tf.reshape(d, [-1, self._attn_size])
-
-      return new_attns
-
 
 def lstm(inputs, hparams, train, name, initial_state=None):
   """Run LSTM cell on inputs, assuming they are [batch x time x size]."""
@@ -189,7 +51,7 @@ def dropout_lstm_cell():
 
 
 def lstm_attention_decoder(inputs, hparams, train, name, initial_state,
-                           attn_states):
+                           encoder_outputs):
   """Run LSTM cell with attention on inputs of shape [batch x time x size]."""
 
   def dropout_lstm_cell():
@@ -198,18 +60,36 @@ def dropout_lstm_cell():
         input_keep_prob=1.0 - hparams.dropout * tf.to_float(train))
 
   layers = [dropout_lstm_cell() for _ in range(hparams.num_hidden_layers)]
-  cell = ExternalAttentionCellWrapper(
+  AttentionMechanism = (tf.contrib.seq2seq.LuongAttention if hparams.attention_mechanism == "luong"
+                        else tf.contrib.seq2seq.BahdanauAttention)
+  attention_mechanism = AttentionMechanism(hparams.hidden_size, encoder_outputs)
+
+  cell = tf.contrib.seq2seq.AttentionWrapper(
       tf.nn.rnn_cell.MultiRNNCell(layers),
-      attn_states,
-      attn_vec_size=hparams.attn_vec_size)
-  initial_state = cell.combine_state(initial_state)
+      [attention_mechanism] * hparams.num_heads,
+      attention_layer_size=[hparams.attention_layer_size] * hparams.num_heads,
+      output_attention=(hparams.output_attention == 1))
+
+  batch_size = inputs.get_shape()[0].value
+  if batch_size is None:
+    batch_size = tf.shape(inputs)[0]
+
+  initial_state = cell.zero_state(batch_size, tf.float32).clone(
+      cell_state=initial_state)
+
   with tf.variable_scope(name):
-    return tf.nn.dynamic_rnn(
+    output, state = tf.nn.dynamic_rnn(
        cell,
        inputs,
        initial_state=initial_state,
        dtype=tf.float32,
        time_major=False)
+
+  # For multi-head attention project output back to hidden size
+  if hparams.output_attention == 1 and hparams.num_heads > 1:
+    output = tf.layers.dense(output, hparams.hidden_size)
+
+  return output, state
 
 
 def lstm_seq2seq_internal(inputs, targets, hparams, train):
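Unlike the old combine_state() path, tf.contrib.seq2seq.AttentionWrapper keeps its own AttentionWrapperState, so the encoder's final state has to be folded in via zero_state(...).clone(cell_state=...), as the hunk above does. A self-contained sketch of that pattern under TF 1.x (all tensors, names, and sizes below are illustrative, not taken from the diff):

import tensorflow as tf

# Illustrative sizes and dummy tensors standing in for real encoder/decoder data.
batch, src_len, tgt_len, size = 4, 7, 5, 32
encoder_outputs = tf.random_normal([batch, src_len, size])
decoder_inputs = tf.random_normal([batch, tgt_len, size])

base_cell = tf.nn.rnn_cell.MultiRNNCell(
    [tf.nn.rnn_cell.BasicLSTMCell(size) for _ in range(2)])
mechanism = tf.contrib.seq2seq.BahdanauAttention(size, encoder_outputs)
cell = tf.contrib.seq2seq.AttentionWrapper(
    base_cell, mechanism, attention_layer_size=size)

# Stand-in for the encoder's final state; a real model would pass the state
# returned by its encoder RNN here.
encoder_final_state = base_cell.zero_state(batch, tf.float32)

# Seed the wrapper's state with the encoder state, then decode as usual.
initial_state = cell.zero_state(batch, tf.float32).clone(
    cell_state=encoder_final_state)
outputs, state = tf.nn.dynamic_rnn(
    cell, decoder_inputs, initial_state=initial_state, dtype=tf.float32)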
@@ -273,14 +153,49 @@ def lstm_seq2seq():
   hparams.hidden_size = 128
   hparams.num_hidden_layers = 2
   hparams.initializer = "uniform_unit_scaling"
+  hparams.initializer_gain = 1.0
+  hparams.weight_decay = 0.0
+
+  return hparams
+
+
+def lstm_attention_base():
+  """Base attention params."""
+  hparams = lstm_seq2seq()
+  hparams.add_hparam("attention_layer_size", hparams.hidden_size)
+  hparams.add_hparam("output_attention", int(True))
+  hparams.add_hparam("num_heads", 1)
   return hparams
 
 
+@registry.register_hparams
+def lstm_bahdanau_attention():
+  """hparams for LSTM with Bahdanau attention."""
+  hparams = lstm_attention_base()
+  hparams.add_hparam("attention_mechanism", "bahdanau")
+  return hparams
+
+
+@registry.register_hparams
+def lstm_luong_attention():
+  """hparams for LSTM with Luong attention."""
+  hparams = lstm_attention_base()
+  hparams.add_hparam("attention_mechanism", "luong")
+  return hparams
+
+
 @registry.register_hparams
 def lstm_attention():
-  """hparams for LSTM with attention."""
-  hparams = lstm_seq2seq()
+  """For backwards compatibility; defaults to Bahdanau attention."""
+  return lstm_bahdanau_attention()
 
-  # Attention
-  hparams.add_hparam("attn_vec_size", hparams.hidden_size)
+
+@registry.register_hparams
+def lstm_bahdanau_attention_multi():
+  """Multi-head Bahdanau attention."""
+  hparams = lstm_bahdanau_attention()
+  hparams.num_heads = 4
   return hparams
+
+
+@registry.register_hparams
+def lstm_luong_attention_multi():
+  """Multi-head Luong attention."""
+  hparams = lstm_luong_attention()
+  hparams.num_heads = 4
+  return hparams
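The new hparams sets stack on top of lstm_seq2seq() via lstm_attention_base(), so selecting one is just a matter of calling (or registering) one of the functions above. A quick sketch of what the multi-head Luong set would contain, assuming the register_hparams decorator returns the function unchanged, as is usual in tensor2tensor:

from tensor2tensor.models import lstm

hparams = lstm.lstm_luong_attention_multi()
assert hparams.attention_mechanism == "luong"
assert hparams.num_heads == 4
assert hparams.attention_layer_size == hparams.hidden_size  # 128 by default
assert hparams.output_attention == 1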
