From bf93aa1857cfa45d8132b4dbc520b8b33e5fb834 Mon Sep 17 00:00:00 2001 From: pengyuchen Date: Tue, 21 Nov 2017 12:21:06 +0800 Subject: [PATCH] multihead_attention --- modules.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/modules.py b/modules.py index 4222d0a..ff8760b 100644 --- a/modules.py +++ b/modules.py @@ -247,7 +247,10 @@ def multihead_attention(queries, # Restore shape outputs = tf.concat(tf.split(outputs, num_heads, axis=0), axis=2 ) # (N, T_q, C) - + + # Output linear projection (applied after re-merging the attention heads) + outputs = tf.layers.dense(outputs, num_units, activation=tf.nn.relu) # (N, T_q, C) + # Residual connection outputs += queries