@@ -146,16 +146,16 @@ def forward(self, x, mask = None):
         q = q * self.scale
 
-        sim = torch.einsum('b h i d, b h j d -> b h i j', q, k)
+        sim = torch.einsum('b h i d, b h j d -> b h i j', q, k) # (b, h, n, n)
 
         sim = self.relative_position_bias(sim)
 
-        # mask
+        # mask (b, n)
 
         mask_value = -torch.finfo(sim.dtype).max
 
         if mask is not None:
-            sim = sim.masked_fill_(~mask, mask_value)
+            sim = sim.masked_fill_(~mask[:, None, :, None], mask_value)
 
         if self.causal:
             i, j = sim.shape[-2:]
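The change above broadcasts the (b, n) padding mask across heads and across one sequence axis of the (b, h, i, j) attention logits, rather than expecting an exact shape match. A minimal, runnable sketch of that broadcasting, with illustrative shapes and names that are not from the commit:

import torch

b, h, n = 2, 4, 6                         # batch, heads, sequence length
sim = torch.randn(b, h, n, n)             # attention logits: (b, h, i, j)
mask = torch.ones(b, n, dtype=torch.bool) # True = real token, False = padding
mask[:, -2:] = False                      # pretend the last two tokens are padding

mask_value = -torch.finfo(sim.dtype).max

# (b, n) -> (b, 1, n, 1): broadcasts over heads (dim 1) and over the last
# sequence axis (dim 3), filling whole rows for masked positions.
sim = sim.masked_fill(~mask[:, None, :, None], mask_value)

attn = sim.softmax(dim=-1)                # masked rows collapse to a uniform distribution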
@@ -222,19 +222,19 @@ def forward(self, x, context, mask = None, context_mask = None):
         q = q * self.scale
 
-        sim = torch.einsum('b h i d, b h j d -> b h i j', q, k)
+        sim = torch.einsum('b h i d, b h j d -> b h i j', q, k) # (b, h, n, n)
 
         #sim = self.relative_position_bias(sim)
 
-        # mask
+        # mask (b, n)
 
         mask_value = -torch.finfo(sim.dtype).max
 
         if mask is not None:
-            sim = sim.masked_fill_(~mask, mask_value)
+            sim = sim.masked_fill_(~mask[:, None, :, None], mask_value)
 
         if context_mask is not None:
-            sim = sim.masked_fill_(~context_mask[:, None, :], mask_value)
+            sim = sim.masked_fill_(~context_mask[:, None, :, None], mask_value)
 
         # attention
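The cross-attention hunk applies the same broadcast to context_mask. One caveat worth noting: [:, None, :, None] places the mask on the query axis (dim 2) and so only broadcasts cleanly when the context and target sequences share a length; masking the context (key) axis for differing lengths would instead index as context_mask[:, None, None, :]. A small sketch of that key-axis variant, with hypothetical names and shapes:

import torch

b, h, n, m = 2, 4, 5, 7                   # batch, heads, target len, context len
sim = torch.randn(b, h, n, m)             # cross-attention logits: (b, h, i, j)
context_mask = torch.ones(b, m, dtype=torch.bool)
context_mask[:, -1] = False               # last context token is padding

mask_value = -torch.finfo(sim.dtype).max

# (b, m) -> (b, 1, 1, m): broadcasts over heads and query positions, so every
# query ignores the padded context token.
sim = sim.masked_fill(~context_mask[:, None, None, :], mask_value)
attn = sim.softmax(dim=-1)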
@@ -360,7 +360,6 @@ def __init__(
     ):
         super().__init__()
 
-        self.embedding = nn.Embedding(enc_num_tokens, dim)
         #self.pos_emb = nn.Embedding(max_seq_len, dim)
 
         self.encoder = T5Encoder(
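The standalone self.embedding removed here becomes redundant once the encoder embeds its own input via token_emb; the matching removal in forward and the embedding weight tying appear in the next hunk, with a small sketch after it.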
@@ -392,7 +391,6 @@ def __init__(
         self.encoder.token_emb.weight = self.decoder.token_emb.weight
 
     def forward(self, src, tgt, mask = None, context_mask = None):
-        x = self.embedding(src)
         #x = x + self.pos_emb(torch.arange(x.shape[1], device = x.device))
         x = self.encoder(src, mask = mask)
         x = self.decoder(tgt, x, mask = mask, context_mask = context_mask)
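The retained assignment self.encoder.token_emb.weight = self.decoder.token_emb.weight makes the two embedding tables share one parameter tensor, which is why the duplicate nn.Embedding and its call in forward can be dropped. A minimal sketch of this tying, assuming both modules expose an nn.Embedding named token_emb as in the repo:

import torch.nn as nn

num_tokens, dim = 1000, 512               # illustrative sizes

enc_emb = nn.Embedding(num_tokens, dim)   # stands in for encoder.token_emb
dec_emb = nn.Embedding(num_tokens, dim)   # stands in for decoder.token_emb

# Rebind the encoder table to the decoder's Parameter: both modules now read
# from, and accumulate gradients into, the same weights.
enc_emb.weight = dec_emb.weight
assert enc_emb.weight is dec_emb.weight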