This repository was archived by the owner on Jun 21, 2024. It is now read-only.

Commit c2deb32

trying to solve the dimensionality bug
1 parent: 301b04f

File tree

1 file changed (+7, -9 lines)

t5_pytorch/t5_pytorch.py

Lines changed: 7 additions & 9 deletions
@@ -146,16 +146,16 @@ def forward(self, x, mask = None):
 
         q = q * self.scale
 
-        sim = torch.einsum('b h i d, b h j d -> b h i j', q, k)
+        sim = torch.einsum('b h i d, b h j d -> b h i j', q, k) # (b, h, n, n)
 
         sim = self.relative_position_bias(sim)
 
-        # mask
+        # mask (b, n)
 
         mask_value = -torch.finfo(sim.dtype).max
 
         if mask is not None:
-            sim = sim.masked_fill_(~mask, mask_value)
+            sim = sim.masked_fill_(~mask[:, None, :, None], mask_value)
 
         if self.causal:
             i, j = sim.shape[-2:]
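A quick shape check of the broadcasting choices may help here. The sketch below is standalone and illustrative (tensor names and sizes are assumptions, not taken from the repository): `sim` is (b, h, n, n) while `mask` is (b, n), so the unreshaped `~mask` does not broadcast against `sim`, which is presumably the dimensionality error this commit chases. The patch's `[:, None, :, None]` indexing blanks out the query rows of padded tokens; the more common convention, `[:, None, None, :]`, blanks out the key columns instead.

```python
import torch

b, h, n = 2, 4, 5
sim = torch.randn(b, h, n, n)          # attention logits (b, h, n, n)
mask = torch.ones(b, n).bool()         # padding mask (b, n), True = keep
mask[:, -2:] = False                   # pretend the last two tokens are padding

mask_value = -torch.finfo(sim.dtype).max

# As in the patch: (b, n) -> (b, 1, n, 1) broadcasts over heads and key positions,
# i.e. it blanks out the *query* rows belonging to padded tokens.
masked_rows = sim.masked_fill(~mask[:, None, :, None], mask_value)

# The more common convention masks the *key* columns: (b, n) -> (b, 1, 1, n),
# so padded tokens can never be attended to.
masked_cols = sim.masked_fill(~mask[:, None, None, :], mask_value)

print(masked_rows.shape, masked_cols.shape)  # both torch.Size([2, 4, 5, 5])
```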
@@ -222,19 +222,19 @@ def forward(self, x, context, mask = None, context_mask = None):
 
         q = q * self.scale
 
-        sim = torch.einsum('b h i d, b h j d -> b h i j', q, k)
+        sim = torch.einsum('b h i d, b h j d -> b h i j', q, k) # (b, h, n, n)
 
         #sim = self.relative_position_bias(sim)
 
-        # mask
+        # mask (b, n)
 
         mask_value = -torch.finfo(sim.dtype).max
 
         if mask is not None:
-            sim = sim.masked_fill_(~mask, mask_value)
+            sim = sim.masked_fill_(~mask[:, None, :, None], mask_value)
 
         if context_mask is not None:
-            sim = sim.masked_fill_(~context_mask[:, None, :], mask_value)
+            sim = sim.masked_fill_(~context_mask[:, None, :, None], mask_value)
 
         # attention
 
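In the cross-attention case the logits are (b, h, n, m), with n the target length and m the context length, so the axis along which `context_mask` broadcasts matters once n != m. A minimal standalone sketch (all names and sizes here are assumptions for illustration, not the repository's):

```python
import torch

b, h, n, m = 2, 4, 7, 5                  # n = target length, m = context length
sim = torch.randn(b, h, n, m)            # cross-attention logits
context_mask = torch.ones(b, m).bool()   # mask over the context sequence

mask_value = -torch.finfo(sim.dtype).max

# Broadcasting the context mask along the key axis matches the (b, h, n, m)
# logits for any n and m:
out = sim.masked_fill(~context_mask[:, None, None, :], mask_value)

# The patch's indexing puts the context length on the query axis instead,
# so it only lines up when n == m:
# sim.masked_fill(~context_mask[:, None, :, None], mask_value)  # shape mismatch here
print(out.shape)  # torch.Size([2, 4, 7, 5])
```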

@@ -360,7 +360,6 @@ def __init__(
     ):
         super().__init__()
 
-        self.embedding = nn.Embedding(enc_num_tokens, dim)
         #self.pos_emb = nn.Embedding(max_seq_len, dim)
 
         self.encoder = T5Encoder(
@@ -392,7 +391,6 @@ def __init__(
         self.encoder.token_emb.weight = self.decoder.token_emb.weight
 
     def forward(self, src, tgt, mask = None, context_mask = None):
-        x = self.embedding(src)
         #x = x + self.pos_emb(torch.arange(x.shape[1], device = x.device))
         x = self.encoder(src, mask = mask)
         x = self.decoder(tgt, x, mask = mask, context_mask = context_mask)
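The last two hunks drop the wrapper's own `self.embedding`, so `src` is passed to the encoder as raw token ids and the encoder's `token_emb` (weight-tied to the decoder's, per the context line above) performs the lookup. A self-contained sketch of this pattern, using toy module names that are not the repository's:

```python
import torch
from torch import nn

# Illustrative only: the encoder owns the token embedding, so the wrapper
# forwards raw token ids instead of embedding them a second time.
class TinyEncoder(nn.Module):
    def __init__(self, num_tokens, dim):
        super().__init__()
        self.token_emb = nn.Embedding(num_tokens, dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, mask = None):
        x = self.token_emb(x)                   # (b, n) ids -> (b, n, dim)
        return self.proj(x)

class TinyWrapper(nn.Module):
    def __init__(self, num_tokens, dim):
        super().__init__()
        # no separate self.embedding here - the encoder embeds for us
        self.encoder = TinyEncoder(num_tokens, dim)

    def forward(self, src, mask = None):
        return self.encoder(src, mask = mask)   # src stays as integer ids

model = TinyWrapper(num_tokens = 512, dim = 64)
out = model(torch.randint(0, 512, (2, 16)))
print(out.shape)  # torch.Size([2, 16, 64])
```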

0 comments