
Commit f67483e

Project outputs to hidden size for multi-head attention

1 parent: 349c6ee

1 file changed: +7 −1 lines

tensor2tensor/models/lstm.py

Lines changed: 7 additions & 1 deletion
@@ -78,12 +78,18 @@ def dropout_lstm_cell():
   initial_state = cell.zero_state(batch_size, tf.float32).clone(cell_state=initial_state)

   with tf.variable_scope(name):
-    return tf.nn.dynamic_rnn(
+    output, state = tf.nn.dynamic_rnn(
         cell,
         inputs,
         initial_state=initial_state,
         dtype=tf.float32,
         time_major=False)
+
+    # For multi-head attention project output back to hidden size
+    if hparams.output_attention == 1 and hparams.num_heads > 1:
+      output = tf.layers.dense(output, hparams.hidden_size)
+
+    return output, state


 def lstm_seq2seq_internal(inputs, targets, hparams, train):
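Why the projection is needed: with output_attention enabled, the attention-wrapped decoder cell emits attention contexts rather than the raw LSTM output, and with num_heads > 1 the per-head contexts are concatenated, so the last dimension of output becomes a multiple of the per-head attention size instead of hparams.hidden_size. The added tf.layers.dense maps it back so downstream layers that expect [batch, time, hidden_size] keep working. Below is a minimal sketch of the shape fix, written against TF 1.x to match the diff; the concrete sizes (batch, num_heads, head_size, hidden_size) are invented for illustration and are not values from this commit.

import tensorflow as tf  # TF 1.x APIs (tf.placeholder, tf.layers), as in the diff

# Illustrative shapes only: four heads of width 64 concatenate to 256,
# while downstream layers expect hidden_size = 128.
batch, time, num_heads, head_size, hidden_size = 8, 10, 4, 64, 128

# Stand-in for the dynamic_rnn output of an attention-wrapped cell whose
# output is the concatenation of the per-head attention contexts.
output = tf.placeholder(tf.float32, [batch, time, num_heads * head_size])

# The commit's fix: a dense layer acts on the last axis of an N-D tensor,
# projecting [batch, time, num_heads * head_size] to [batch, time, hidden_size].
projected = tf.layers.dense(output, hidden_size)
print(projected.shape)  # (8, 10, 128)

A single projection suffices because tf.layers.dense applies one weight matrix over the last axis, leaving the batch and time axes untouched.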
