This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 228feae

Merge pull request #100 from kolloldas/lstm_attention:
Adding attention to LSTM seq2seq baseline

2 parents b287c0e + 7b20843, commit 228feae

File tree: 3 files changed, +220 -1 lines changed

tensor2tensor/models/common_layers.py

Lines changed: 0 additions & 1 deletion
@@ -331,7 +331,6 @@ def conv2d_kernel(kernel_size_arg, name_suffix):
 
   return conv2d_kernel(kernel_size, "single")
 
-
 def conv(inputs, filters, kernel_size, **kwargs):
   return conv_internal(tf.layers.conv2d, inputs, filters, kernel_size, **kwargs)
 
tensor2tensor/models/lstm.py

Lines changed: 197 additions & 0 deletions
@@ -21,11 +21,152 @@
 # Dependency imports
 
 from tensor2tensor.models import common_layers
+from tensor2tensor.models import common_hparams
 from tensor2tensor.utils import registry
 from tensor2tensor.utils import t2t_model
 
 import tensorflow as tf
+from tensorflow.python.ops import rnn_cell_impl
+from tensorflow.python.util import nest
 
+import collections
+
+# Named tuple tracking the wrapped cell state and the attention values.
+AttentionTuple = collections.namedtuple("AttentionTuple",
+                                        ("state", "attention"))
+
+
+class ExternalAttentionCellWrapper(rnn_cell_impl.RNNCell):
+  """Wrapper for external attention states, for use in an encoder-decoder setup."""
+
+  def __init__(self, cell, attn_states, attn_vec_size=None,
+               input_size=None, state_is_tuple=True, reuse=None):
+    """Create a cell with attention.
+
+    Args:
+      cell: an RNNCell to which attention is added.
+      attn_states: external attention states, typically the encoder output,
+        of shape [batch_size, time_steps, hidden_size].
+      attn_vec_size: integer, the number of convolutional features calculated
+        on the attention state, and the size of the hidden layer built from
+        the base cell state. Defaults to attn_size.
+      input_size: integer, the size of a hidden linear layer built from the
+        inputs and attention. Derived from the input tensor by default.
+      state_is_tuple: if True, accepted and returned states are n-tuples,
+        where `n = len(cells)`. Must be True; otherwise a ValueError is
+        raised.
+      reuse: (optional) Python boolean describing whether to reuse variables
+        in an existing scope. If not `True`, and the existing scope already
+        has the given variables, an error is raised.
+
+    Raises:
+      TypeError: if cell is not an RNNCell.
+      ValueError: if `state_is_tuple` is False, if attn_states is not rank 3,
+        or if its innermost dimension (the hidden size) is None.
+    """
+    super(ExternalAttentionCellWrapper, self).__init__(_reuse=reuse)
+    if not rnn_cell_impl._like_rnncell(cell):  # pylint: disable=protected-access
+      raise TypeError("The parameter cell is not RNNCell.")
+
+    if not state_is_tuple:
+      raise ValueError("Only tuple state is supported")
+
+    self._cell = cell
+    self._input_size = input_size
+
+    # Validate attn_states shape.
+    attn_shape = attn_states.get_shape()
+    if not attn_shape or len(attn_shape) != 3:
+      raise ValueError("attn_states must be rank 3")
+
+    self._attn_states = attn_states
+    self._attn_size = attn_shape[2].value
+    if self._attn_size is None:
+      raise ValueError("Hidden size of attn_states cannot be None")
+
+    self._attn_vec_size = attn_vec_size
+    if self._attn_vec_size is None:
+      self._attn_vec_size = self._attn_size
+
+    self._reuse = reuse
+
+  @property
+  def state_size(self):
+    return AttentionTuple(self._cell.state_size, self._attn_size)
+
+  @property
+  def output_size(self):
+    return self._attn_size
+
+  def combine_state(self, previous_state):
+    """Combines a previous state (usually from an encoder) with the internal
+    attention values.
+
+    Use this function to derive the initial state passed into this cell,
+    since the cell expects a named tuple (AttentionTuple).
+
+    Args:
+      previous_state: state from another block that will be fed into this
+        cell. Must have the same structure as the state of the cell wrapped
+        by this one.
+
+    Returns:
+      The combined state (AttentionTuple).
+    """
+    batch_size = self._attn_states.get_shape()[0].value
+    if batch_size is None:
+      batch_size = tf.shape(self._attn_states)[0]
+    zeroed_state = self.zero_state(batch_size, self._attn_states.dtype)
+    return AttentionTuple(previous_state, zeroed_state.attention)
+
+  def call(self, inputs, state):
+    """Long short-term memory cell with attention (LSTMA)."""
+    if not isinstance(state, AttentionTuple):
+      raise TypeError("State must be of type AttentionTuple")
+
+    state, attns = state
+    attn_states = self._attn_states
+    attn_length = attn_states.get_shape()[1].value
+    if attn_length is None:
+      attn_length = tf.shape(attn_states)[1]
+
+    input_size = self._input_size
+    if input_size is None:
+      input_size = inputs.get_shape().as_list()[1]
+    if attns is not None:
+      inputs = rnn_cell_impl._linear([inputs, attns], input_size, True)
+    lstm_output, new_state = self._cell(inputs, state)
+
+    new_state_cat = tf.concat(nest.flatten(new_state), 1)
+    new_attns = self._attention(new_state_cat, attn_states, attn_length)
+
+    with tf.variable_scope("attn_output_projection"):
+      output = rnn_cell_impl._linear([lstm_output, new_attns],
+                                     self._attn_size, True)
+
+    new_state = AttentionTuple(new_state, new_attns)
+
+    return output, new_state
+
+  def _attention(self, query, attn_states, attn_length):
+    conv2d = tf.nn.conv2d
+    reduce_sum = tf.reduce_sum
+    softmax = tf.nn.softmax
+    tanh = tf.tanh
+
+    with tf.variable_scope("attention"):
+      k = tf.get_variable(
+          "attn_w", [1, 1, self._attn_size, self._attn_vec_size])
+      v = tf.get_variable("attn_v", [self._attn_vec_size, 1])
+      hidden = tf.reshape(attn_states,
+                          [-1, attn_length, 1, self._attn_size])
+      hidden_features = conv2d(hidden, k, [1, 1, 1, 1], "SAME")
+      y = rnn_cell_impl._linear(query, self._attn_vec_size, True)
+      y = tf.reshape(y, [-1, 1, 1, self._attn_vec_size])
+      s = reduce_sum(v * tanh(hidden_features + y), [2, 3])
+      a = softmax(s)
+      d = reduce_sum(
+          tf.reshape(a, [-1, attn_length, 1, 1]) * hidden, [1, 2])
+      new_attns = tf.reshape(d, [-1, self._attn_size])
+
+      return new_attns
 
 def lstm(inputs, hparams, train, name, initial_state=None):
   """Run LSTM cell on inputs, assuming they are [batch x time x size]."""
@@ -44,6 +185,25 @@ def dropout_lstm_cell():
         dtype=tf.float32,
         time_major=False)
 
+def lstm_attention_decoder(inputs, hparams, train, name, initial_state,
+                           attn_states):
+  """Run LSTM cell with attention on inputs, assuming they are [batch x time x size]."""
+
+  def dropout_lstm_cell():
+    return tf.contrib.rnn.DropoutWrapper(
+        tf.nn.rnn_cell.BasicLSTMCell(hparams.hidden_size),
+        input_keep_prob=1.0 - hparams.dropout * tf.to_float(train))
+
+  layers = [dropout_lstm_cell() for _ in range(hparams.num_hidden_layers)]
+  cell = ExternalAttentionCellWrapper(
+      tf.nn.rnn_cell.MultiRNNCell(layers), attn_states,
+      attn_vec_size=hparams.attn_vec_size)
+  initial_state = cell.combine_state(initial_state)
+  with tf.variable_scope(name):
+    return tf.nn.dynamic_rnn(
+        cell,
+        inputs,
+        initial_state=initial_state,
+        dtype=tf.float32,
+        time_major=False)
 
 def lstm_seq2seq_internal(inputs, targets, hparams, train):
   """The basic LSTM seq2seq model, main step used for training."""
@@ -63,6 +223,23 @@ def lstm_seq2seq_internal(inputs, targets, hparams, train):
         initial_state=final_encoder_state)
     return tf.expand_dims(decoder_outputs, axis=2)
 
+def lstm_seq2seq_internal_attention(inputs, targets, hparams, train):
+  """LSTM seq2seq model with attention, main step used for training."""
+  with tf.variable_scope("lstm_seq2seq_attention"):
+    # Flatten inputs.
+    inputs = common_layers.flatten4d3d(inputs)
+    # LSTM encoder.
+    encoder_outputs, final_encoder_state = lstm(
+        tf.reverse(inputs, axis=[1]), hparams, train, "encoder")
+    # LSTM decoder with attention.
+    shifted_targets = common_layers.shift_left(targets)
+    decoder_outputs, _ = lstm_attention_decoder(
+        common_layers.flatten4d3d(shifted_targets),
+        hparams,
+        train,
+        "decoder",
+        final_encoder_state, encoder_outputs)
+    return tf.expand_dims(decoder_outputs, axis=2)
 
 @registry.register_model("baseline_lstm_seq2seq")
 class LSTMSeq2Seq(t2t_model.T2TModel):
@@ -71,3 +248,23 @@ def model_fn_body(self, features):
     train = self._hparams.mode == tf.contrib.learn.ModeKeys.TRAIN
     return lstm_seq2seq_internal(features["inputs"], features["targets"],
                                  self._hparams, train)
+
+
+@registry.register_model("baseline_lstm_seq2seq_attention")
+class LSTMSeq2SeqAttention(t2t_model.T2TModel):
+
+  def model_fn_body(self, features):
+    train = self._hparams.mode == tf.contrib.learn.ModeKeys.TRAIN
+    return lstm_seq2seq_internal_attention(features["inputs"],
+                                           features["targets"],
+                                           self._hparams, train)
+
+
+@registry.register_hparams
+def lstm_attention():
+  """Hparams for LSTM with attention."""
+  hparams = common_hparams.basic_params1()
+  hparams.batch_size = 128
+  hparams.hidden_size = 128
+  hparams.num_hidden_layers = 2
+
+  # Attention.
+  hparams.add_hparam("attn_vec_size", hparams.hidden_size)
+  return hparams
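Editor's note: the following is a minimal sketch, not part of the commit, showing how the new functions fit together outside of the T2TModel subclass. It mirrors lstm_seq2seq_internal_attention above; the batch/length values and tensor names are made up for illustration, and TF 1.x graph mode plus the basic_params1 dropout hparam are assumed.

# Editorial sketch only -- not part of the commit. Toy shapes for illustration.
import tensorflow as tf
from tensor2tensor.models import lstm

hparams = lstm.lstm_attention()   # hidden_size = 128, attn_vec_size = 128
train = True

# Stand-ins for embedded, flattened [batch x time x size] inputs and targets.
encoder_inputs = tf.zeros([4, 7, hparams.hidden_size])
decoder_inputs = tf.zeros([4, 9, hparams.hidden_size])

# Encoder: the same reversed-input LSTM used by lstm_seq2seq_internal_attention.
encoder_outputs, final_encoder_state = lstm.lstm(
    tf.reverse(encoder_inputs, axis=[1]), hparams, train, "encoder")

# Decoder: ExternalAttentionCellWrapper attends over encoder_outputs;
# combine_state() is called inside to wrap final_encoder_state into an
# AttentionTuple before tf.nn.dynamic_rnn runs.
decoder_outputs, _ = lstm.lstm_attention_decoder(
    decoder_inputs, hparams, train, "decoder",
    final_encoder_state, encoder_outputs)   # shape [4, 9, hidden_size]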

tensor2tensor/models/lstm_test.py

Lines changed: 23 additions & 0 deletions
@@ -51,6 +51,29 @@ def testLSTMSeq2Seq(self):
       res = session.run(logits)
       self.assertEqual(res.shape, (3, 6, 1, 1, vocab_size))
 
+  def testLSTMSeq2Seq_attention(self):
+    vocab_size = 9
+    x = np.random.random_integers(1, high=vocab_size - 1, size=(3, 5, 1, 1))
+    y = np.random.random_integers(1, high=vocab_size - 1, size=(3, 6, 1, 1))
+    hparams = lstm.lstm_attention()
+
+    p_hparams = problem_hparams.test_problem_hparams(hparams, vocab_size,
+                                                     vocab_size)
+    x = tf.constant(x, dtype=tf.int32)
+    x._shape = tf.TensorShape([None, None, 1, 1])
+
+    with self.test_session() as session:
+      features = {
+          "inputs": x,
+          "targets": tf.constant(y, dtype=tf.int32),
+      }
+      model = lstm.LSTMSeq2SeqAttention(
+          hparams, tf.contrib.learn.ModeKeys.TRAIN, p_hparams)
+      sharded_logits, _, _ = model.model_fn(features)
+      logits = tf.concat(sharded_logits, 0)
+      session.run(tf.global_variables_initializer())
+      res = session.run(logits)
+      self.assertEqual(res.shape, (3, 6, 1, 1, vocab_size))
 
 if __name__ == "__main__":
   tf.test.main()
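Editor's note: since lstm_test.py ends with tf.test.main(), the new testLSTMSeq2Seq_attention case can be run directly, for example with python tensor2tensor/models/lstm_test.py from the repository root, assuming TensorFlow and the tensor2tensor package are importable.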
