Commit a1c4457 (parent: c2d8326)

fix: transformer refactor bugs
2 files changed: +5 −8 lines


offlinerllib/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,2 +1,2 @@
 
-__version__ = "0.1.1"
+__version__ = "0.1.2"

offlinerllib/module/net/attention/transformer.py

Lines changed: 4 additions & 7 deletions
@@ -153,7 +153,7 @@ def _mha_block(self, input, key_value, attention_mask, key_padding_mask):
             key=key_value,
             value=key_value,
             need_weights=False,
-            attention_mask=attention_mask,
+            attn_mask=attention_mask,
             key_padding_mask=key_padding_mask
         )[0]
         return self.dropout2(input)
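Note: torch.nn.MultiheadAttention.forward() takes the keyword attn_mask, not attention_mask, so the old call raised an unexpected-keyword TypeError. A minimal standalone sketch of the corrected call follows; the sizes, batch_first setting, and causal mask here are illustrative assumptions, not values from the repo.

import torch
import torch.nn as nn

# Illustrative sizes; offlinerllib's actual embed_dim/num_heads are not shown in this diff.
mha = nn.MultiheadAttention(embed_dim=64, num_heads=4, batch_first=True)
x = torch.randn(2, 10, 64)  # (batch, seq_len, embed_dim)
# Boolean attention mask: True marks positions a query may NOT attend to (here: future positions).
attn_mask = torch.triu(torch.ones(10, 10, dtype=torch.bool), diagonal=1)
# Boolean padding mask: True marks padded key positions to ignore (none here).
key_padding_mask = torch.zeros(2, 10, dtype=torch.bool)
out, _ = mha(
    query=x, key=x, value=x,
    need_weights=False,
    attn_mask=attn_mask,              # correct keyword
    key_padding_mask=key_padding_mask,
)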
@@ -194,7 +194,7 @@ def __init__(
             ) for _ in range(num_layers)
         ])

-        self.out_ln = nn.LayerNorm() if out_ln else nn.Identity()
+        self.out_ln = nn.LayerNorm(embed_dim) if out_ln else nn.Identity()
         self.causal = causal

     def forward(
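Note: nn.LayerNorm requires a normalized_shape argument, so nn.LayerNorm() with no arguments fails at construction time; passing embed_dim normalizes over the feature dimension. The same fix is applied again in the hunk at line 257 below. A minimal sketch, with an illustrative embed_dim:

import torch
import torch.nn as nn

embed_dim = 64                      # illustrative; in the repo this is a constructor argument
out_ln = nn.LayerNorm(embed_dim)    # nn.LayerNorm() without normalized_shape raises TypeError
x = torch.randn(2, 10, embed_dim)
y = out_ln(x)                       # normalizes each token over the last (feature) dimension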
@@ -254,7 +254,7 @@ def __init__(
             ) for _ in range(num_layers)
         ])

-        self.out_ln = nn.LayerNorm() if out_ln else nn.Identity()
+        self.out_ln = nn.LayerNorm(embed_dim) if out_ln else nn.Identity()
         self.causal = causal

     def forward(
@@ -276,10 +276,7 @@ def forward(
         if tgt_attention_mask is not None:
             tgt_mask = torch.bitwise_or(tgt_attention_mask.to(torch.bool), tgt_mask)
         if do_embedding:
-            tgt = self.input_embed(tgt)
-            if timesteps is not None:
-                timesteps = torch.arange(L).repeat(B, 1).to(tgt.device)
-            tgt = tgt + self.pos_embed(timesteps)
+            tgt = self.pos_encoding(self.input_embed(tgt))
         output = self.embed_dropout(tgt)
         for i, block in enumerate(self.blocks):
             output = block(
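Note: the removed block overwrote a caller-supplied timesteps with torch.arange (the condition appears inverted) and added self.pos_embed by hand; the new code delegates both token embedding and position addition to self.pos_encoding. The real pos_encoding module is not shown in this diff; a hypothetical module of roughly this shape would reproduce the arange-based behaviour:

import torch
import torch.nn as nn

# Hypothetical sketch only: the actual offlinerllib pos_encoding module is not part of this diff.
class LearnedPositionalEncoding(nn.Module):
    def __init__(self, embed_dim: int, max_len: int = 1024):
        super().__init__()
        self.pos_embed = nn.Embedding(max_len, embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, embed_dim); add one learned embedding per position index
        B, L, _ = x.shape
        positions = torch.arange(L, device=x.device).unsqueeze(0).expand(B, L)
        return x + self.pos_embed(positions)

With such a module, pos_encoding(input_embed(tgt)) collapses the removed four lines into the single call shown in the diff.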
