# -*- coding: UTF-8 -*-
import torch
import torch.nn as nn
+ import torch.nn.functional as F
from torch.nn.modules.transformer import (TransformerEncoder, TransformerDecoder,
                                           TransformerEncoderLayer, TransformerDecoderLayer)
from deep_keyphrase.dataloader import TOKENS, TOKENS_LENS, TOKENS_OOV, UNK_WORD, PAD_WORD, OOV_COUNT


- def get_position_encoding(input_tensor, position, dim_size):
+ def get_position_encoding(input_tensor):
+     batch_size, position, dim_size = input_tensor.size()
    assert dim_size % 2 == 0
-     batch_size = len(input_tensor)
    num_timescales = dim_size // 2
    time_scales = torch.arange(0, position + 1, dtype=torch.float).unsqueeze(1)
    dim_scales = torch.arange(0, num_timescales, dtype=torch.float).unsqueeze(0)
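The rest of get_position_encoding is cut off by the hunk above. For orientation only, here is a minimal sketch of how a sinusoidal encoding could be completed from the time_scales / dim_scales tensors shown, assuming the log-scale formulation of the original Transformer paper; the 1.0e4 constant, the sin/cos concatenation order, and the name sinusoid_position_encoding are assumptions, not part of this commit (which also uses position + 1 timesteps rather than position):

import torch

def sinusoid_position_encoding(input_tensor):
    # Sketch only: expects input_tensor of shape (batch, seq_len, dim).
    batch_size, position, dim_size = input_tensor.size()
    assert dim_size % 2 == 0
    num_timescales = dim_size // 2
    time_scales = torch.arange(0, position, dtype=torch.float).unsqueeze(1)        # (seq_len, 1)
    dim_scales = torch.arange(0, num_timescales, dtype=torch.float).unsqueeze(0)   # (1, dim/2)
    # log-spaced inverse frequencies, one per timescale
    inv_freq = torch.exp(-dim_scales * torch.log(torch.tensor(1.0e4)) / max(num_timescales - 1, 1))
    angles = time_scales * inv_freq                                                # (seq_len, dim/2)
    pos_enc = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)             # (seq_len, dim)
    return pos_enc.unsqueeze(0).expand(batch_size, -1, -1).to(input_tensor.device)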
@@ -68,6 +69,7 @@ def __init__(self, embedding, input_dim, head_size,
                 feed_forward_dim, dropout, num_layers):
        super().__init__()
        self.embedding = embedding
+         self.dropout = dropout
        layer = TransformerEncoderLayer(d_model=input_dim,
                                        nhead=head_size,
                                        dim_feedforward=feed_forward_dim,
@@ -76,22 +78,26 @@ def __init__(self, embedding, input_dim, head_size,

    def forward(self, src_dict):
        batch_size, max_len = src_dict[TOKENS].size()
-         # mask_range = torch.arange(max_len).expand(batch_size, max_len)
-         mask_range = torch.zeros(batch_size, max_len, dtype=torch.bool)
+         mask_range = torch.arange(max_len).unsqueeze(0).repeat(batch_size, 1)
+
        if torch.cuda.is_available():
            mask_range = mask_range.cuda()
-         # mask = mask_range > src_dict[TOKENS_LENS].unsqueeze(1)
+         mask = mask_range >= src_dict[TOKENS_LENS]
+         # mask = (mask_range > src_dict[TOKENS_LENS].unsqueeze(1)).expand(batch_size, max_len, max_len)
        src_embed = self.embedding(src_dict[TOKENS]).transpose(1, 0)
-         # print(src_embed.size(), mask.size())
-         output = self.encoder(src_embed).transpose(1, 0)
-         return output, mask_range
+         pos_embed = get_position_encoding(src_embed)
+         src_embed = src_embed + pos_embed
+         src_embed = F.dropout(src_embed, p=self.dropout, training=self.training)
+         output = self.encoder(src_embed, src_key_padding_mask=mask).transpose(1, 0)
+         return output, mask


class CopyTransformerDecoder(nn.Module):
    def __init__(self, embedding, input_dim, vocab2id, head_size, feed_forward_dim,
                 dropout, num_layers, target_max_len, max_oov_count):
        super().__init__()
        self.embedding = embedding
+         self.dropout = dropout
        self.vocab_size = embedding.num_embeddings
        self.vocab2id = vocab2id
        layer = TransformerDecoderLayer(d_model=input_dim,
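Two details in the encoder forward above are worth flagging. First, get_position_encoding(src_embed) is called after the .transpose(1, 0), so the size() unpacking in the new function sees (seq_len, batch, dim) rather than (batch, seq_len, dim). Second, nn.TransformerEncoder's src_key_padding_mask expects a (batch, src_len) boolean mask with True at padded positions; if src_dict[TOKENS_LENS] is a 1-D tensor of true lengths, the comparison mask_range >= src_dict[TOKENS_LENS] only broadcasts when batch_size happens to equal max_len. A standalone sketch of the mask construction, assuming TOKENS_LENS is a (batch,) LongTensor:

import torch

def length_padding_mask(lengths, max_len):
    # Sketch: (batch, max_len) boolean mask, True at padded positions,
    # which is the convention src_key_padding_mask uses.
    positions = torch.arange(max_len, device=lengths.device).unsqueeze(0)  # (1, max_len)
    return positions >= lengths.unsqueeze(1)                               # (batch, max_len)

# e.g. length_padding_mask(torch.tensor([3, 1]), 4)
# -> tensor([[False, False, False,  True],
#            [False,  True,  True,  True]])

Relative to the hunk above, the only functional difference is the .unsqueeze(1) on the length tensor; the commented-out line kept in the diff already hints at it.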
@@ -124,18 +130,19 @@ def forward(self, prev_output_tokens, prev_decoder_state, position,
        # map copied oov tokens to OOV idx to avoid embedding lookup error
        prev_output_tokens[prev_output_tokens >= self.vocab_size] = self.vocab2id[UNK_WORD]
        token_embed = self.embedding(prev_output_tokens)
-         pos_embed = get_position_encoding(token_embed, position, self.input_dim)
+
+         pos_embed = get_position_encoding(token_embed)
        # B x seq_len x H
        src_embed = token_embed + pos_embed
-         # print(token_embed.size(), pos_embed.size())
-         # print(src_embed.size(), copy_state.size())
        decoder_input = self.embed_proj(torch.cat([src_embed, copy_state], dim=2)).transpose(1, 0)
-         # print(decoder_input.size())
+         decoder_input = F.dropout(decoder_input, p=self.dropout, training=self.training)
        decoder_input_mask = torch.triu(torch.ones(self.input_dim, self.input_dim), 1)
        # B x seq_len x H
-         decoder_output = self.decoder(tgt=decoder_input, memory=encoder_output.transpose(1, 0), )
+         decoder_output = self.decoder(tgt=decoder_input,
+                                       memory=encoder_output.transpose(1, 0),
+                                       memory_key_padding_mask=decoder_input_mask)
        decoder_output = decoder_output.transpose(1, 0)
-         # tgt_mask=decoder_input_mask, memory_mask=encoder_mask)
+
        # B x 1 x H
        decoder_output = decoder_output[:, -1:, :]
        generation_logits = self.generate_proj(decoder_output).squeeze(1)
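In the decoder forward above, the two masks appear to be mixed up: decoder_input_mask is a triangular matrix sized by input_dim (the model dimension) and is passed as memory_key_padding_mask, whereas nn.TransformerDecoder normally takes a (tgt_len, tgt_len) causal mask via tgt_mask and the encoder's (batch, src_len) padding mask via memory_key_padding_mask; the removed comment line mentions exactly those arguments. A hedged sketch of that arrangement, reusing the encoder_mask name from the removed comment as a hypothetical stand-in for the mask returned by the encoder:

import torch

def causal_mask(tgt_len, device=None):
    # Sketch: additive float mask of shape (tgt_len, tgt_len);
    # 0.0 on/below the diagonal, -inf above, so step i only attends to steps <= i.
    return torch.triu(torch.full((tgt_len, tgt_len), float('-inf'), device=device), diagonal=1)

# Hypothetical call site (decoder_input is (tgt_len, batch, hidden) after the transpose):
# decoder_output = self.decoder(tgt=decoder_input,
#                               memory=encoder_output.transpose(1, 0),
#                               tgt_mask=causal_mask(decoder_input.size(0), decoder_input.device),
#                               memory_key_padding_mask=encoder_mask)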