diff --git a/t5_pytorch/t5_pytorch.py b/t5_pytorch/t5_pytorch.py
index 26173b7..498a3fb 100644
--- a/t5_pytorch/t5_pytorch.py
+++ b/t5_pytorch/t5_pytorch.py
@@ -146,16 +146,16 @@ def forward(self, x, mask = None):
 
         q = q * self.scale
 
-        sim = torch.einsum('b h i d, b h j d -> b h i j', q, k)
+        sim = torch.einsum('b h i d, b h j d -> b h i j', q, k) # (b, h, n, n)
 
         sim = self.relative_position_bias(sim)
 
-        # mask
+        # mask (b, n)
 
         mask_value = -torch.finfo(sim.dtype).max
 
         if mask is not None:
-            sim = sim.masked_fill_(~mask, mask_value)
+            sim = sim.masked_fill_(~mask[:, None, :, None], mask_value)
 
         if self.causal:
             i, j = sim.shape[-2:]
@@ -222,19 +222,19 @@ def forward(self, x, context, mask = None, context_mask = None):
 
         q = q * self.scale
 
-        sim = torch.einsum('b h i d, b h j d -> b h i j', q, k)
+        sim = torch.einsum('b h i d, b h j d -> b h i j', q, k) # (b, h, n, n)
 
         #sim = self.relative_position_bias(sim)
 
-        # mask
+        # mask (b, n)
 
         mask_value = -torch.finfo(sim.dtype).max
 
         if mask is not None:
-            sim = sim.masked_fill_(~mask, mask_value)
+            sim = sim.masked_fill_(~mask[:, None, :, None], mask_value)
 
         if context_mask is not None:
-            sim = sim.masked_fill_(~context_mask[:, None, :], mask_value)
+            sim = sim.masked_fill_(~context_mask[:, None, None, :], mask_value)
 
         # attention
 
@@ -360,7 +360,6 @@ def __init__(
     ):
         super().__init__()
 
-        self.embedding = nn.Embedding(enc_num_tokens, dim)
         #self.pos_emb = nn.Embedding(max_seq_len, dim)
 
         self.encoder = T5Encoder(
@@ -392,10 +391,9 @@ def __init__(
         self.encoder.token_emb.weight = self.decoder.token_emb.weight
 
     def forward(self, src, tgt, mask = None, context_mask = None):
-        x = self.embedding(src)
         #x = x + self.pos_emb(torch.arange(x.shape[1], device = x.device))
         x = self.encoder(src, mask = mask)
-        x = self.decoder(tgt, x, mask = mask, context_mask = context_mask)
+        x = self.decoder(tgt, x, mask = context_mask, context_mask = mask)
         x = self.to_logits(x)
         return x
 
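Below is a minimal standalone sketch (not part of the patch) of the mask broadcasting the diff introduces: the (b, n) boolean padding masks are unsqueezed so they broadcast against the (b, h, i, j) attention logits before the softmax. The tensor names and toy sizes are made up for illustration only.

import torch

# toy sizes: batch, heads, query length, key length
b, h, n_q, n_k = 2, 8, 4, 6
sim = torch.randn(b, h, n_q, n_k)                       # attention logits, shape (b, h, i, j)

mask = torch.ones(b, n_q, dtype = torch.bool)           # query-side padding mask, (b, n_q)
context_mask = torch.ones(b, n_k, dtype = torch.bool)   # key-side padding mask, (b, n_k)
context_mask[:, -2:] = False                            # pretend the last two key positions are padding

mask_value = -torch.finfo(sim.dtype).max

# (b, n_q) -> (b, 1, n_q, 1): broadcasts over heads and key positions,
# blanking the rows that belong to padded queries
sim = sim.masked_fill(~mask[:, None, :, None], mask_value)

# (b, n_k) -> (b, 1, 1, n_k): broadcasts over heads and query positions,
# so no query can attend to a padded key
sim = sim.masked_fill(~context_mask[:, None, None, :], mask_value)

attn = sim.softmax(dim = -1)
print(attn[0, 0])   # the last two columns are ~0 for every query row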