@@ -153,7 +153,7 @@ def _mha_block(self, input, key_value, attention_mask, key_padding_mask):
             key=key_value,
             value=key_value,
             need_weights=False,
-            attention_mask=attention_mask,
+            attn_mask=attention_mask,
             key_padding_mask=key_padding_mask
         )[0]
         return self.dropout2(input)
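For reference, torch.nn.MultiheadAttention.forward accepts the attention mask under the keyword attn_mask; there is no attention_mask parameter, which is what the hunk above corrects. A minimal sketch of the call pattern, with hypothetical shapes (batch_first=True, boolean masks where True means "do not attend"):

import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=64, num_heads=4, batch_first=True)
query = torch.randn(2, 5, 64)      # (B, L_q, E)
key_value = torch.randn(2, 7, 64)  # (B, L_kv, E)

attn_mask = torch.zeros(5, 7, dtype=torch.bool)         # (L_q, L_kv)
key_padding_mask = torch.zeros(2, 7, dtype=torch.bool)  # (B, L_kv)

# forward() returns (output, attention weights); [0] keeps only the output
out = mha(query, key_value, key_value, need_weights=False,
          attn_mask=attn_mask, key_padding_mask=key_padding_mask)[0]
print(out.shape)  # torch.Size([2, 5, 64])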
@@ -194,7 +194,7 @@ def __init__(
             ) for _ in range(num_layers)
         ])

-        self.out_ln = nn.LayerNorm() if out_ln else nn.Identity()
+        self.out_ln = nn.LayerNorm(embed_dim) if out_ln else nn.Identity()
         self.causal = causal

     def forward(
@@ -254,7 +254,7 @@ def __init__(
             ) for _ in range(num_layers)
         ])

-        self.out_ln = nn.LayerNorm() if out_ln else nn.Identity()
+        self.out_ln = nn.LayerNorm(embed_dim) if out_ln else nn.Identity()
         self.causal = causal

     def forward(
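Both __init__ hunks fix the same construction bug: nn.LayerNorm has a required normalized_shape argument, so nn.LayerNorm() with no arguments raises a TypeError, while nn.LayerNorm(embed_dim) normalizes over the last (feature) dimension. A small sketch, with embed_dim assumed to be the model width:

import torch
import torch.nn as nn

embed_dim = 64  # assumed model width
out_ln = True   # stands in for the out_ln flag
norm = nn.LayerNorm(embed_dim) if out_ln else nn.Identity()

x = torch.randn(2, 10, embed_dim)  # (B, L, E)
print(norm(x).shape)               # torch.Size([2, 10, 64]), normalized over the last dim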
@@ -276,10 +276,7 @@ def forward(
         if tgt_attention_mask is not None:
             tgt_mask = torch.bitwise_or(tgt_attention_mask.to(torch.bool), tgt_mask)
         if do_embedding:
-            tgt = self.input_embed(tgt)
-            if timesteps is not None:
-                timesteps = torch.arange(L).repeat(B, 1).to(tgt.device)
-            tgt = tgt + self.pos_embed(timesteps)
+            tgt = self.pos_encoding(self.input_embed(tgt))
         output = self.embed_dropout(tgt)
         for i, block in enumerate(self.blocks):
             output = block(
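The last hunk replaces the manual token-embedding-plus-learned-position path with a single pos_encoding call. The diff does not show how pos_encoding is defined; the sketch below is one plausible stand-in, a fixed sinusoidal encoding added to (B, L, E) embeddings, with hypothetical dimensions; the repository's actual module may be learned or shaped differently.

import math
import torch
import torch.nn as nn

class SinusoidalPositionalEncoding(nn.Module):
    # Fixed sinusoidal table, added to (B, L, E) inputs; not learned.
    def __init__(self, embed_dim, max_len=5000):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim))
        pe = torch.zeros(max_len, embed_dim)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe)

    def forward(self, x):
        return x + self.pe[: x.size(1)]  # broadcast over the batch dimension

# Usage mirroring the new line: tgt = self.pos_encoding(self.input_embed(tgt))
input_embed = nn.Linear(17, 64)                  # hypothetical input projection
pos_encoding = SinusoidalPositionalEncoding(64)
tgt = torch.randn(2, 10, 17)                     # (B, L, input_dim)
tgt = pos_encoding(input_embed(tgt))
print(tgt.shape)                                 # torch.Size([2, 10, 64])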