
Commit 0829967

Author: xiayubin (committed)
Commit message: fix copy transformer
1 parent bcc8445 commit 0829967

File tree: 2 files changed (+22 lines, -14 lines)

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+# -*- coding: UTF-8 -*-

deep_keyphrase/copy_transformer/model.py

Lines changed: 21 additions & 14 deletions
@@ -1,14 +1,15 @@
 # -*- coding: UTF-8 -*-
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from torch.nn.modules.transformer import (TransformerEncoder, TransformerDecoder,
                                           TransformerEncoderLayer, TransformerDecoderLayer)
 from deep_keyphrase.dataloader import TOKENS, TOKENS_LENS, TOKENS_OOV, UNK_WORD, PAD_WORD, OOV_COUNT
 
 
-def get_position_encoding(input_tensor, position, dim_size):
+def get_position_encoding(input_tensor):
+    batch_size, position, dim_size = input_tensor.size()
     assert dim_size % 2 == 0
-    batch_size = len(input_tensor)
     num_timescales = dim_size // 2
     time_scales = torch.arange(0, position + 1, dtype=torch.float).unsqueeze(1)
     dim_scales = torch.arange(0, num_timescales, dtype=torch.float).unsqueeze(0)
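The hunk above only shows the top of get_position_encoding: the refactor infers batch_size, position (sequence length) and dim_size from the input tensor instead of taking them as arguments. For reference, a minimal self-contained sketch of a sinusoidal encoding built from the same time_scales/dim_scales setup, following the standard Transformer formulation; everything below the lines shown in the diff is an assumption, not the repository's exact code:

    import math
    import torch

    def sinusoidal_position_encoding(input_tensor):
        # infer shapes from a (batch, seq_len, dim) tensor, as the new signature does
        batch_size, position, dim_size = input_tensor.size()
        assert dim_size % 2 == 0
        num_timescales = dim_size // 2
        time_scales = torch.arange(0, position + 1, dtype=torch.float).unsqueeze(1)   # (L+1, 1)
        dim_scales = torch.arange(0, num_timescales, dtype=torch.float).unsqueeze(0)  # (1, D/2)
        # geometric progression of inverse wavelengths, as in "Attention Is All You Need"
        inv_timescales = torch.exp(-dim_scales * (math.log(10000.0) / max(num_timescales - 1, 1)))
        scaled_time = time_scales * inv_timescales                                    # (L+1, D/2)
        encoding = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1) # (L+1, D)
        # drop position 0 and broadcast over the batch
        return encoding[1:].unsqueeze(0).expand(batch_size, position, dim_size)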
@@ -68,6 +69,7 @@ def __init__(self, embedding, input_dim, head_size,
                  feed_forward_dim, dropout, num_layers):
         super().__init__()
         self.embedding = embedding
+        self.dropout = dropout
         layer = TransformerEncoderLayer(d_model=input_dim,
                                         nhead=head_size,
                                         dim_feedforward=feed_forward_dim,
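This hunk ends just before the encoder stack itself is built from the layer above. For orientation, a minimal sketch of how TransformerEncoderLayer and TransformerEncoder are usually assembled with torch.nn (the numeric values are placeholders, not the project's configuration):

    import torch.nn as nn
    from torch.nn.modules.transformer import TransformerEncoder, TransformerEncoderLayer

    # one self-attention block; d_model must equal the embedding dimension
    layer = TransformerEncoderLayer(d_model=512, nhead=8, dim_feedforward=2048, dropout=0.1)
    # stack num_layers copies of that block
    encoder = TransformerEncoder(layer, num_layers=6)

TransformerEncoder.forward expects input of shape (seq_len, batch, d_model), which is why the forward pass in the next hunk transposes the embedded batch before encoding.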
@@ -76,22 +78,26 @@ def __init__(self, embedding, input_dim, head_size,
 
     def forward(self, src_dict):
         batch_size, max_len = src_dict[TOKENS].size()
-        # mask_range = torch.arange(max_len).expand(batch_size, max_len)
-        mask_range = torch.zeros(batch_size, max_len, dtype=torch.bool)
+        mask_range = torch.arange(max_len).unsqueeze(0).repeat(batch_size, 1)
+
         if torch.cuda.is_available():
             mask_range = mask_range.cuda()
-        # mask = mask_range > src_dict[TOKENS_LENS].unsqueeze(1)
+        mask = mask_range >= src_dict[TOKENS_LENS]
+        # mask = (mask_range > src_dict[TOKENS_LENS].unsqueeze(1)).expand(batch_size, max_len, max_len)
         src_embed = self.embedding(src_dict[TOKENS]).transpose(1, 0)
-        # print(src_embed.size(), mask.size())
-        output = self.encoder(src_embed).transpose(1, 0)
-        return output, mask_range
+        pos_embed = get_position_encoding(src_embed)
+        src_embed = src_embed + pos_embed
+        src_embed = F.dropout(src_embed, p=self.dropout, training=self.training)
+        output = self.encoder(src_embed, src_key_padding_mask=mask).transpose(1, 0)
+        return output, mask
 
 
 class CopyTransformerDecoder(nn.Module):
     def __init__(self, embedding, input_dim, vocab2id, head_size, feed_forward_dim,
                  dropout, num_layers, target_max_len, max_oov_count):
         super().__init__()
         self.embedding = embedding
+        self.dropout = dropout
         self.vocab_size = embedding.num_embeddings
         self.vocab2id = vocab2id
         layer = TransformerDecoderLayer(d_model=input_dim,
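The substantive change in the previous hunk is that the encoder forward now builds a real padding mask and passes it as src_key_padding_mask, and adds the position encoding plus dropout to the source embeddings. As a standalone illustration of the masking convention (made-up lengths, not repository code), src_key_padding_mask has shape (batch, seq_len) and a True entry marks a padded position that attention should ignore:

    import torch

    lengths = torch.tensor([5, 3, 1])                # hypothetical source lengths
    max_len = 6
    positions = torch.arange(max_len).unsqueeze(0)   # (1, max_len)
    pad_mask = positions >= lengths.unsqueeze(1)     # (batch, max_len), True where padded
    # pad_mask[0] -> [False, False, False, False, False, True]
    # pad_mask[2] -> [False, True,  True,  True,  True,  True]
    # output = encoder(src_embed, src_key_padding_mask=pad_mask)

Whether the diff's comparison against src_dict[TOKENS_LENS] broadcasts this way depends on the shape the dataloader gives TOKENS_LENS, which is not visible in this commit.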
@@ -124,18 +130,19 @@ def forward(self, prev_output_tokens, prev_decoder_state, position,
         # map copied oov tokens to OOV idx to avoid embedding lookup error
         prev_output_tokens[prev_output_tokens >= self.vocab_size] = self.vocab2id[UNK_WORD]
         token_embed = self.embedding(prev_output_tokens)
-        pos_embed = get_position_encoding(token_embed, position, self.input_dim)
+
+        pos_embed = get_position_encoding(token_embed)
         # B x seq_len x H
         src_embed = token_embed + pos_embed
-        # print(token_embed.size(),pos_embed.size())
-        # print(src_embed.size(),copy_state.size())
         decoder_input = self.embed_proj(torch.cat([src_embed, copy_state], dim=2)).transpose(1, 0)
-        # print(decoder_input.size())
+        decoder_input = F.dropout(decoder_input, p=self.dropout, training=self.training)
         decoder_input_mask = torch.triu(torch.ones(self.input_dim, self.input_dim), 1)
         # B x seq_len x H
-        decoder_output = self.decoder(tgt=decoder_input, memory=encoder_output.transpose(1, 0), )
+        decoder_output = self.decoder(tgt=decoder_input,
+                                      memory=encoder_output.transpose(1, 0),
+                                      memory_key_padding_mask=decoder_input_mask)
         decoder_output = decoder_output.transpose(1, 0)
-        # tgt_mask=decoder_input_mask, memory_mask=encoder_mask)
+
         # B x 1 x H
         decoder_output = decoder_output[:, -1:, :]
         generation_logits = self.generate_proj(decoder_output).squeeze(1)
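For reference on the decoder-side mask arguments used in the last hunk: in torch.nn, TransformerDecoder.forward accepts tgt_mask, an additive float mask of shape (tgt_len, tgt_len) that is usually the causal upper-triangular mask, and memory_key_padding_mask, a boolean mask of shape (batch, src_len) that normally reuses the padding mask returned by the encoder. A minimal sketch of the conventional causal-mask helper (a generic example, not code from this commit):

    import torch

    def causal_mask(tgt_len):
        # -inf above the diagonal blocks attention to future positions;
        # zeros elsewhere leave permitted positions unchanged (additive mask)
        mask = torch.full((tgt_len, tgt_len), float('-inf'))
        return torch.triu(mask, diagonal=1)

    # typical call, assuming (seq_len, batch, dim) layout and an encoder padding mask:
    # decoder_output = decoder(tgt=decoder_input,
    #                          memory=encoder_output,
    #                          tgt_mask=causal_mask(decoder_input.size(0)),
    #                          memory_key_padding_mask=encoder_pad_mask)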