@@ -133,74 +133,70 @@ def _build(self, input_shape, has_cross_attention):
         self._built = True
         self._input_shape = input_shape
         self._has_cross_attention = has_cross_attention
-        feature_size = input_shape[-1]
-        self._attention_head_size = int(feature_size // self.num_heads)
+        # Infer the dimension of our hidden feature size from the build shape.
+        hidden_dim = input_shape[-1]
+        # Attention head size is `hidden_dim` over the number of heads.
+        head_dim = int(hidden_dim // self.num_heads)
+
+        # Self attention layers.
         self._self_attention_layer = keras.layers.MultiHeadAttention(
             num_heads=self.num_heads,
-            key_dim=self._attention_head_size,
-            value_dim=self._attention_head_size,
+            key_dim=head_dim,
             dropout=self.dropout,
             kernel_initializer=self.kernel_initializer,
             bias_initializer=self.bias_initializer,
         )
         self._self_attention_layer._build_from_signature(
-            input_shape, input_shape
+            query=input_shape,
+            value=input_shape,
         )
-
-        self._decoder_attention_layernorm = keras.layers.LayerNormalization(
+        self._self_attention_layernorm = keras.layers.LayerNormalization(
             epsilon=self.layer_norm_epsilon,
         )
+        self._self_attention_dropout = keras.layers.Dropout(
+            rate=self.dropout,
+        )
 
+        # Cross attention layers are optional.
         self._cross_attention_layer = None
         if has_cross_attention:
-            # Create layers for cross attention.
             self._cross_attention_layer = keras.layers.MultiHeadAttention(
                 num_heads=self.num_heads,
-                key_dim=self._attention_head_size,
-                value_dim=feature_size,
+                key_dim=head_dim,
+                value_dim=hidden_dim,
                 dropout=self.dropout,
                 kernel_initializer=self.kernel_initializer,
                 bias_initializer=self.bias_initializer,
             )
             self._cross_attention_layer._build_from_signature(
-                input_shape, input_shape
+                query=input_shape,
+                value=input_shape,
             )
-
             self._cross_attention_layernorm = keras.layers.LayerNormalization(
                 epsilon=self.layer_norm_epsilon,
             )
-
             self._cross_attention_dropout = keras.layers.Dropout(
                 rate=self.dropout,
             )
 
-        self._feedforward_layernorm = keras.layers.LayerNormalization(
-            epsilon=self.layer_norm_epsilon,
-        )
-
-        self._self_attention_dropout = keras.layers.Dropout(rate=self.dropout)
-
-        # First dense layer in the feedforward network, which maps input
-        # feauture size to dimension `self.intermediate_dim`.
-        self._intermediate_dense = keras.layers.Dense(
+        # Feedforward layers.
+        self._feedforward_intermediate_dense = keras.layers.Dense(
             self.intermediate_dim,
             activation=self.activation,
             kernel_initializer=self.kernel_initializer,
             bias_initializer=self.bias_initializer,
         )
-        # Second dense layer in the feedforward network, which maps input
-        # feature size back to the input feature size.
-        self._output_dense = keras.layers.Dense(
-            feature_size,
+        self._feedforward_output_dense = keras.layers.Dense(
+            hidden_dim,
             kernel_initializer=self.kernel_initializer,
             bias_initializer=self.bias_initializer,
         )
-        self._output_dropout = keras.layers.Dropout(rate=self.dropout)
-
-    def _feedforward(self, input):
-        x = self._intermediate_dense(input)
-        x = self._output_dense(x)
-        return self._output_dropout(x)
+        self._feedforward_layernorm = keras.layers.LayerNormalization(
+            epsilon=self.layer_norm_epsilon,
+        )
+        self._feedforward_dropout = keras.layers.Dropout(
+            rate=self.dropout,
+        )
 
     def call(
         self,
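A note on the head-size arithmetic above: `key_dim` is now the per-head size `hidden_dim // self.num_heads`, and `keras.layers.MultiHeadAttention` projects the concatenated heads back to the hidden size. A minimal sketch outside the diff, with made-up shapes, illustrating that behavior:

import tensorflow as tf
from tensorflow import keras

batch_size, seq_len, hidden_dim, num_heads = 2, 8, 64, 4
head_dim = hidden_dim // num_heads  # per-head key/query size, 16 here

attention = keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=head_dim)
x = tf.random.uniform((batch_size, seq_len, hidden_dim))
y = attention(query=x, value=x)  # self attention; output is projected back to hidden_dim
print(y.shape)  # (2, 8, 64)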
@@ -232,6 +228,7 @@ def call(
         Returns:
             A Tensor of the same shape as the `decoder_sequence`.
         """
+
         has_encoder_sequence = encoder_sequence is not None
         if not self._built:
             self._build(decoder_sequence.shape, has_encoder_sequence)
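The lazy build shown here means the layer decides at first call whether to create cross-attention weights, based on whether an `encoder_sequence` is passed. A hypothetical usage sketch (the import path and constructor arguments are assumptions, not part of this diff):

import tensorflow as tf
from keras_nlp.layers import TransformerDecoder  # assumed import path

# Decoder-only: first call has no encoder_sequence, so no cross attention is built.
decoder_only = TransformerDecoder(intermediate_dim=64, num_heads=2)
decoder_only(tf.random.uniform((2, 8, 32)))

# Seq2seq: passing encoder_sequence on the first call builds the cross attention layers.
seq2seq = TransformerDecoder(intermediate_dim=64, num_heads=2)
seq2seq(
    tf.random.uniform((2, 8, 32)),   # decoder_sequence
    tf.random.uniform((2, 10, 32)),  # encoder_sequence
)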
@@ -257,71 +254,64 @@ def call(
257254 "This layer has been built with cross attention, but "
258255 "you did not provide encoder_sequence."
259256 )
257+
258+ x = decoder_sequence # Intermediate result.
259+
260+ # Compute self attention mask.
261+ self_attention_mask = compute_causal_mask (decoder_sequence )
260262 decoder_mask = merge_padding_and_attention_mask (
261263 decoder_sequence , decoder_padding_mask , decoder_attention_mask
262264 )
263- causal_mask = tf .cast (
264- compute_causal_mask (decoder_sequence ),
265- dtype = tf .int32 ,
266- )
267- if decoder_mask is None :
268- decoder_mask = causal_mask
269- else :
270- decoder_mask = tf .minimum (decoder_mask , causal_mask )
265+ if decoder_mask is not None :
266+ self_attention_mask = tf .minimum (decoder_mask , self_attention_mask )
271267
272- residual_decoder_sequence = decoder_sequence
268+ # Self attention block.
269+ residual = x
273270 if self .normalize_first :
274- decoder_sequence = self ._decoder_attention_layernorm (
275- decoder_sequence
276- )
277- # Decoder input self-attention.
278- self_attended = self ._self_attention_layer (
279- decoder_sequence ,
280- decoder_sequence ,
281- decoder_sequence ,
282- attention_mask = decoder_mask ,
271+ x = self ._self_attention_layernorm (x )
272+ x = self ._self_attention_layer (
273+ query = x ,
274+ value = x ,
275+ attention_mask = self_attention_mask ,
283276 )
284- self_attended = self ._self_attention_dropout (self_attended )
285- attention_output = residual_decoder_sequence + self_attended
277+ x = self ._self_attention_dropout (x )
278+ x = x + residual
286279 if not self .normalize_first :
287- attention_output = self ._decoder_attention_layernorm (
288- attention_output
289- )
280+ x = self ._self_attention_layernorm (x )
290281
282+ # Cross attention is optional.
291283 if self ._cross_attention_layer is not None :
292- encoder_mask = merge_padding_and_attention_mask (
284+ # Compute cross attention mask.
285+ cross_attention_mask = merge_padding_and_attention_mask (
293286 encoder_sequence , encoder_padding_mask , encoder_attention_mask
294287 )
295- residual_attention_output = attention_output
288+
289+ # Cross attention block.
290+ residual = x
296291 if self .normalize_first :
297- attention_output = self ._cross_attention_layernorm (
298- attention_output
299- )
300- # Cross attention.
301- cross_attended = self ._cross_attention_layer (
302- query = attention_output ,
292+ x = self ._cross_attention_layernorm (x )
293+ x = self ._cross_attention_layer (
294+ query = x ,
303295 value = encoder_sequence ,
304- key = encoder_sequence ,
305- attention_mask = encoder_mask ,
306- )
307- cross_attended = self ._cross_attention_dropout (
308- cross_attended ,
296+ attention_mask = cross_attention_mask ,
309297 )
310- attention_output = residual_attention_output + cross_attended
298+ x = self ._cross_attention_dropout (x )
299+ x = x + residual
311300 if not self .normalize_first :
312- attention_output = self ._cross_attention_layernorm (
313- attention_output
314- )
301+ x = self ._cross_attention_layernorm (x )
315302
316- residual_attention_output = attention_output
303+ # Feedforward block.
304+ residual = x
317305 if self .normalize_first :
318- attention_output = self ._feedforward_layernorm (attention_output )
319- # Feedforward.
320- feedforward_output = self ._feedforward (attention_output )
321- feedforward_output = residual_attention_output + feedforward_output
306+ x = self ._feedforward_layernorm (x )
307+ x = self ._feedforward_intermediate_dense (x )
308+ x = self ._feedforward_output_dense (x )
309+ x = self ._feedforward_dropout (x )
310+ x = x + residual
322311 if not self .normalize_first :
323- feedforward_output = self ._feedforward_layernorm (feedforward_output )
324- return feedforward_output
312+ x = self ._feedforward_layernorm (x )
313+
314+ return x
325315
326316 def get_config (self ):
327317 config = super ().get_config ()
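One detail worth calling out in the rewritten masking logic: the causal mask is now the default self-attention mask, and any padding/attention mask returned by `merge_padding_and_attention_mask` is folded in with `tf.minimum`, so a position is attendable only when both masks allow it. A toy stand-in (hand-built masks, not the keras_nlp helpers) showing that combination:

import tensorflow as tf

seq_len = 4
# Lower-triangular causal mask: position i may attend to positions <= i.
causal = tf.linalg.band_part(tf.ones((1, seq_len, seq_len), tf.int32), -1, 0)
# Padding mask broadcast to (batch, target_len, source_len): last token is padding.
padding = tf.constant([[1, 1, 1, 0]], tf.int32)
padding = padding[:, tf.newaxis, :] * tf.ones((1, seq_len, 1), tf.int32)
combined = tf.minimum(causal, padding)  # 1 = attend, 0 = masked out
print(combined[0].numpy())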