Commit 9bc5d37

Improve readability for encoder/decoder blocks (#353)

* Improve readability for encoder/decoder blocks
* Address review comments
* fixup naming
* Rework so as to not break any existing GCP checkpoints
* Also rename variables in colab notebooks
* Last fixup

1 parent 70ff7b8 commit 9bc5d37

13 files changed: +253 -261 lines changed
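
For context, a minimal usage sketch of the two public layers whose internals this commit renames; shapes and hyperparameters below are illustrative, not taken from the commit:

    import numpy as np
    import keras_nlp

    # Illustrative shapes and hyperparameters.
    batch_size, seq_length, hidden_dim = 2, 16, 64
    encoder_inputs = np.random.uniform(size=(batch_size, seq_length, hidden_dim)).astype("float32")
    decoder_inputs = np.random.uniform(size=(batch_size, seq_length, hidden_dim)).astype("float32")

    encoder = keras_nlp.layers.TransformerEncoder(intermediate_dim=128, num_heads=4)
    decoder = keras_nlp.layers.TransformerDecoder(intermediate_dim=128, num_heads=4)

    encoded = encoder(encoder_inputs)           # self attention block + feedforward block
    decoded = decoder(decoder_inputs, encoded)  # adds a cross attention block over `encoded`
    print(decoded.shape)                        # (2, 16, 64)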

keras_nlp/layers/transformer_decoder.py

Lines changed: 69 additions & 79 deletions
@@ -133,74 +133,70 @@ def _build(self, input_shape, has_cross_attention):
         self._built = True
         self._input_shape = input_shape
         self._has_cross_attention = has_cross_attention
-        feature_size = input_shape[-1]
-        self._attention_head_size = int(feature_size // self.num_heads)
+        # Infer the dimension of our hidden feature size from the build shape.
+        hidden_dim = input_shape[-1]
+        # Attention head size is `hidden_dim` over the number of heads.
+        head_dim = int(hidden_dim // self.num_heads)
+
+        # Self attention layers.
         self._self_attention_layer = keras.layers.MultiHeadAttention(
             num_heads=self.num_heads,
-            key_dim=self._attention_head_size,
-            value_dim=self._attention_head_size,
+            key_dim=head_dim,
             dropout=self.dropout,
             kernel_initializer=self.kernel_initializer,
             bias_initializer=self.bias_initializer,
         )
         self._self_attention_layer._build_from_signature(
-            input_shape, input_shape
+            query=input_shape,
+            value=input_shape,
         )
-
-        self._decoder_attention_layernorm = keras.layers.LayerNormalization(
+        self._self_attention_layernorm = keras.layers.LayerNormalization(
             epsilon=self.layer_norm_epsilon,
         )
+        self._self_attention_dropout = keras.layers.Dropout(
+            rate=self.dropout,
+        )

+        # Cross attention layers are optional.
         self._cross_attention_layer = None
         if has_cross_attention:
-            # Create layers for cross attention.
             self._cross_attention_layer = keras.layers.MultiHeadAttention(
                 num_heads=self.num_heads,
-                key_dim=self._attention_head_size,
-                value_dim=feature_size,
+                key_dim=head_dim,
+                value_dim=hidden_dim,
                 dropout=self.dropout,
                 kernel_initializer=self.kernel_initializer,
                 bias_initializer=self.bias_initializer,
             )
             self._cross_attention_layer._build_from_signature(
-                input_shape, input_shape
+                query=input_shape,
+                value=input_shape,
             )
-
             self._cross_attention_layernorm = keras.layers.LayerNormalization(
                 epsilon=self.layer_norm_epsilon,
             )
-
             self._cross_attention_dropout = keras.layers.Dropout(
                 rate=self.dropout,
             )

-        self._feedforward_layernorm = keras.layers.LayerNormalization(
-            epsilon=self.layer_norm_epsilon,
-        )
-
-        self._self_attention_dropout = keras.layers.Dropout(rate=self.dropout)
-
-        # First dense layer in the feedforward network, which maps input
-        # feauture size to dimension `self.intermediate_dim`.
-        self._intermediate_dense = keras.layers.Dense(
+        # Feedforward layers.
+        self._feedforward_intermediate_dense = keras.layers.Dense(
             self.intermediate_dim,
             activation=self.activation,
             kernel_initializer=self.kernel_initializer,
             bias_initializer=self.bias_initializer,
         )
-        # Second dense layer in the feedforward network, which maps input
-        # feature size back to the input feature size.
-        self._output_dense = keras.layers.Dense(
-            feature_size,
+        self._feedforward_output_dense = keras.layers.Dense(
+            hidden_dim,
             kernel_initializer=self.kernel_initializer,
             bias_initializer=self.bias_initializer,
         )
-        self._output_dropout = keras.layers.Dropout(rate=self.dropout)
-
-    def _feedforward(self, input):
-        x = self._intermediate_dense(input)
-        x = self._output_dense(x)
-        return self._output_dropout(x)
+        self._feedforward_layernorm = keras.layers.LayerNormalization(
+            epsilon=self.layer_norm_epsilon,
+        )
+        self._feedforward_dropout = keras.layers.Dropout(
+            rate=self.dropout,
+        )

     def call(
         self,

@@ -232,6 +228,7 @@ def call(
         Returns:
             A Tensor of the same shape as the `decoder_sequence`.
         """
+
         has_encoder_sequence = encoder_sequence is not None
         if not self._built:
             self._build(decoder_sequence.shape, has_encoder_sequence)

@@ -257,71 +254,64 @@ def call(
                 "This layer has been built with cross attention, but "
                 "you did not provide encoder_sequence."
             )
+
+        x = decoder_sequence  # Intermediate result.
+
+        # Compute self attention mask.
+        self_attention_mask = compute_causal_mask(decoder_sequence)
         decoder_mask = merge_padding_and_attention_mask(
             decoder_sequence, decoder_padding_mask, decoder_attention_mask
         )
-        causal_mask = tf.cast(
-            compute_causal_mask(decoder_sequence),
-            dtype=tf.int32,
-        )
-        if decoder_mask is None:
-            decoder_mask = causal_mask
-        else:
-            decoder_mask = tf.minimum(decoder_mask, causal_mask)
+        if decoder_mask is not None:
+            self_attention_mask = tf.minimum(decoder_mask, self_attention_mask)

-        residual_decoder_sequence = decoder_sequence
+        # Self attention block.
+        residual = x
         if self.normalize_first:
-            decoder_sequence = self._decoder_attention_layernorm(
-                decoder_sequence
-            )
-        # Decoder input self-attention.
-        self_attended = self._self_attention_layer(
-            decoder_sequence,
-            decoder_sequence,
-            decoder_sequence,
-            attention_mask=decoder_mask,
+            x = self._self_attention_layernorm(x)
+        x = self._self_attention_layer(
+            query=x,
+            value=x,
+            attention_mask=self_attention_mask,
         )
-        self_attended = self._self_attention_dropout(self_attended)
-        attention_output = residual_decoder_sequence + self_attended
+        x = self._self_attention_dropout(x)
+        x = x + residual
         if not self.normalize_first:
-            attention_output = self._decoder_attention_layernorm(
-                attention_output
-            )
+            x = self._self_attention_layernorm(x)

+        # Cross attention is optional.
         if self._cross_attention_layer is not None:
-            encoder_mask = merge_padding_and_attention_mask(
+            # Compute cross attention mask.
+            cross_attention_mask = merge_padding_and_attention_mask(
                 encoder_sequence, encoder_padding_mask, encoder_attention_mask
             )
-            residual_attention_output = attention_output
+
+            # Cross attention block.
+            residual = x
             if self.normalize_first:
-                attention_output = self._cross_attention_layernorm(
-                    attention_output
-                )
-            # Cross attention.
-            cross_attended = self._cross_attention_layer(
-                query=attention_output,
+                x = self._cross_attention_layernorm(x)
+            x = self._cross_attention_layer(
+                query=x,
                 value=encoder_sequence,
-                key=encoder_sequence,
-                attention_mask=encoder_mask,
-            )
-            cross_attended = self._cross_attention_dropout(
-                cross_attended,
+                attention_mask=cross_attention_mask,
             )
-            attention_output = residual_attention_output + cross_attended
+            x = self._cross_attention_dropout(x)
+            x = x + residual
             if not self.normalize_first:
-                attention_output = self._cross_attention_layernorm(
-                    attention_output
-                )
+                x = self._cross_attention_layernorm(x)

-        residual_attention_output = attention_output
+        # Feedforward block.
+        residual = x
         if self.normalize_first:
-            attention_output = self._feedforward_layernorm(attention_output)
-        # Feedforward.
-        feedforward_output = self._feedforward(attention_output)
-        feedforward_output = residual_attention_output + feedforward_output
+            x = self._feedforward_layernorm(x)
+        x = self._feedforward_intermediate_dense(x)
+        x = self._feedforward_output_dense(x)
+        x = self._feedforward_dropout(x)
+        x = x + residual
         if not self.normalize_first:
-            feedforward_output = self._feedforward_layernorm(feedforward_output)
-        return feedforward_output
+            x = self._feedforward_layernorm(x)
+
+        return x

     def get_config(self):
         config = super().get_config()
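
The refactored call() above is organized into explicit blocks (self attention, optional cross attention, feedforward) that all share one shape: optional pre-norm, sublayer, dropout, residual add, optional post-norm. A standalone sketch of that pattern, with an arbitrary dense sublayer standing in for attention or feedforward and illustrative hyperparameters:

    import tensorflow as tf
    from tensorflow import keras

    def residual_block(x, sublayer, layernorm, dropout, normalize_first):
        # Optional pre-norm, sublayer, dropout, residual add, optional post-norm:
        # the structure each block in the refactored call() follows.
        residual = x
        if normalize_first:
            x = layernorm(x)
        x = sublayer(x)
        x = dropout(x)
        x = x + residual
        if not normalize_first:
            x = layernorm(x)
        return x

    # Illustrative use with a dense sublayer.
    x = tf.random.uniform((2, 8, 16))
    out = residual_block(
        x,
        sublayer=keras.layers.Dense(16),
        layernorm=keras.layers.LayerNormalization(epsilon=1e-5),
        dropout=keras.layers.Dropout(0.1),
        normalize_first=True,
    )
    print(out.shape)  # (2, 8, 16)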

keras_nlp/layers/transformer_decoder_test.py

Lines changed: 6 additions & 10 deletions
@@ -199,11 +199,9 @@ def test_checkpointing_transformer_decoder(self):
         decoder1(decoder_sequence, encoder_sequence)
         decoder2(decoder_sequence, encoder_sequence)
         # The weights of decoder1 and decoder2 are different.
-        self.assertFalse(
-            all(
-                decoder1._output_dense.trainable_variables[0][0]
-                == decoder2._output_dense.trainable_variables[0][0]
-            )
+        self.assertNotAllClose(
+            decoder1.trainable_variables[0][0],
+            decoder2.trainable_variables[0][0],
         )
         checkpoint = tf.train.Checkpoint(decoder1)
         checkpoint2 = tf.train.Checkpoint(decoder2)

@@ -230,11 +228,9 @@ def test_checkpointing_transformer_decoder_without_cross_attention(self):
         decoder1(decoder_sequence)
         decoder2(decoder_sequence)
         # The weights of decoder1 and decoder2 are different.
-        self.assertFalse(
-            all(
-                decoder1._output_dense.trainable_variables[0][0]
-                == decoder2._output_dense.trainable_variables[0][0]
-            )
+        self.assertNotAllClose(
+            decoder1.trainable_variables[0][0],
+            decoder2.trainable_variables[0][0],
         )
         checkpoint = tf.train.Checkpoint(decoder1)
         checkpoint2 = tf.train.Checkpoint(decoder2)
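
The updated assertions compare the layers' own trainable_variables with the tf.test.TestCase helper assertNotAllClose, so renaming private sublayers such as _output_dense no longer breaks the test. A hypothetical standalone example of the same checkpoint round-trip pattern, with plain Dense layers standing in for the decoders (this is not the actual test body):

    import os
    import tempfile

    import tensorflow as tf
    from tensorflow import keras

    class CheckpointPatternTest(tf.test.TestCase):
        def test_checkpoint_round_trip(self):
            # Two independently initialized layers (stand-ins for the decoders).
            layer1 = keras.layers.Dense(4)
            layer2 = keras.layers.Dense(4)
            data = tf.random.uniform((2, 4))
            layer1(data)
            layer2(data)
            # Freshly initialized kernels should differ.
            self.assertNotAllClose(
                layer1.trainable_variables[0],
                layer2.trainable_variables[0],
            )
            # Save layer1 and restore its weights into layer2.
            checkpoint1 = tf.train.Checkpoint(layer1)
            checkpoint2 = tf.train.Checkpoint(layer2)
            path = checkpoint1.save(os.path.join(tempfile.mkdtemp(), "ckpt"))
            checkpoint2.restore(path)
            self.assertAllClose(
                layer1.trainable_variables[0],
                layer2.trainable_variables[0],
            )

    if __name__ == "__main__":
        tf.test.main()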

keras_nlp/layers/transformer_encoder.py

Lines changed: 49 additions & 41 deletions
@@ -114,53 +114,55 @@ def _build(self, input_shape):
         # Create layers based on input shape.
         self._built = True
         self._input_shape = input_shape
-        feature_size = input_shape[-1]
-        self._attention_head_size = int(feature_size // self.num_heads)
-        self._multi_head_attention_layer = keras.layers.MultiHeadAttention(
+        # Infer the dimension of our hidden feature size from the build shape.
+        hidden_dim = input_shape[-1]
+        # Attention head size is `hidden_dim` over the number of heads.
+        key_dim = int(hidden_dim // self.num_heads)
+
+        # Self attention layers.
+        self._self_attention_layer = keras.layers.MultiHeadAttention(
             num_heads=self.num_heads,
-            key_dim=self._attention_head_size,
-            value_dim=self._attention_head_size,
+            key_dim=key_dim,
             dropout=self.dropout,
             kernel_initializer=self.kernel_initializer,
             bias_initializer=self.bias_initializer,
         )
-        self._multi_head_attention_layer._build_from_signature(
-            input_shape, input_shape
+        self._self_attention_layer._build_from_signature(
+            query=input_shape,
+            value=input_shape,
         )
-
-        self._attention_layernorm = keras.layers.LayerNormalization(
+        self._self_attention_layernorm = keras.layers.LayerNormalization(
             epsilon=self.layer_norm_epsilon,
         )
+        self._self_attention_dropout = keras.layers.Dropout(
+            rate=self.dropout,
+        )
+
+        # Feedforward layers.
         self._feedforward_layernorm = keras.layers.LayerNormalization(
             epsilon=self.layer_norm_epsilon,
         )
-
-        self._attention_dropout = keras.layers.Dropout(rate=self.dropout)
-
-        self._intermediate_dense = keras.layers.Dense(
+        self._feedforward_intermediate_dense = keras.layers.Dense(
             self.intermediate_dim,
             activation=self.activation,
             kernel_initializer=self.kernel_initializer,
             bias_initializer=self.bias_initializer,
         )
-        self._output_dense = keras.layers.Dense(
-            feature_size,
+        self._feedforward_output_dense = keras.layers.Dense(
+            hidden_dim,
             kernel_initializer=self.kernel_initializer,
             bias_initializer=self.bias_initializer,
         )
-        self._output_dropout = keras.layers.Dropout(rate=self.dropout)
-
-    def _feedforward(self, input):
-        x = self._intermediate_dense(input)
-        x = self._output_dense(x)
-        return self._output_dropout(x)
+        self._feedforward_dropout = keras.layers.Dropout(
+            rate=self.dropout,
+        )

     def call(self, inputs, padding_mask=None, attention_mask=None):
         """Forward pass of the TransformerEncoder.

         Args:
             inputs: a Tensor. The input data to TransformerEncoder, should be
-                of shape [batch_size, sequence_length, feature_dim].
+                of shape [batch_size, sequence_length, hidden_dim].
             padding_mask: a boolean Tensor. It indicates if the token should be
                 masked because the token is introduced due to padding.
                 `padding_mask` should have shape [batch_size, sequence_length].

@@ -176,33 +178,39 @@ def call(self, inputs, padding_mask=None, attention_mask=None):
         if not self._built:
             self._build(inputs.shape)

-        mask = merge_padding_and_attention_mask(
-            inputs,
-            padding_mask,
-            attention_mask,
+        x = inputs  # Intermediate result.
+
+        # Compute self attention mask.
+        self_attention_mask = merge_padding_and_attention_mask(
+            inputs, padding_mask, attention_mask
         )

-        residual_inputs = inputs
+        # Self attention block.
+        residual = x
         if self.normalize_first:
-            inputs = self._attention_layernorm(inputs)
-        # Self attention.
-        attended = self._multi_head_attention_layer(
-            inputs, inputs, inputs, attention_mask=mask
+            x = self._self_attention_layernorm(x)
+        x = self._self_attention_layer(
+            query=x,
+            value=x,
+            attention_mask=self_attention_mask,
         )
-        attended = self._attention_dropout(attended)
-        attended = residual_inputs + attended
+        x = self._self_attention_dropout(x)
+        x = x + residual
         if not self.normalize_first:
-            attended = self._attention_layernorm(attended)
+            x = self._self_attention_layernorm(x)

-        residual_attended = attended
+        # Feedforward block.
+        residual = x
         if self.normalize_first:
-            attended = self._feedforward_layernorm(attended)
-        # Feedforward.
-        feedforward_output = self._feedforward(attended)
-        feedforward_output = residual_attended + feedforward_output
+            x = self._feedforward_layernorm(x)
+        x = self._feedforward_intermediate_dense(x)
+        x = self._feedforward_output_dense(x)
+        x = self._feedforward_dropout(x)
+        x = x + residual
         if not self.normalize_first:
-            feedforward_output = self._feedforward_layernorm(feedforward_output)
-        return feedforward_output
+            x = self._feedforward_layernorm(x)
+
+        return x

     def get_config(self):
         config = super().get_config()
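
A short sketch of the call signature documented above: inputs of shape [batch_size, sequence_length, hidden_dim] and an optional boolean padding_mask of shape [batch_size, sequence_length]; values and hyperparameters are illustrative:

    import numpy as np
    import keras_nlp

    # Illustrative values.
    inputs = np.random.uniform(size=(2, 6, 32)).astype("float32")
    padding_mask = np.array(
        [[1, 1, 1, 1, 0, 0],
         [1, 1, 1, 0, 0, 0]],
        dtype=bool,
    )

    encoder = keras_nlp.layers.TransformerEncoder(
        intermediate_dim=64,
        num_heads=2,
        normalize_first=True,  # pre-norm ordering of the blocks shown above
    )
    outputs = encoder(inputs, padding_mask=padding_mask)
    print(outputs.shape)  # (2, 6, 32)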
