Skip to content

Commit 3633c89

Browse files
implement copy transformer
1 parent abf387c commit 3633c89

File tree

1 file changed

+89
-0
lines changed

1 file changed

+89
-0
lines changed
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
# -*- coding: UTF-8 -*-
2+
import os
3+
import torch
4+
from pysenal import append_jsonlines
5+
from deep_keyphrase.base_predictor import BasePredictor
6+
from deep_keyphrase.dataloader import KeyphraseDataLoader, TOKENS, RAW_BATCH
7+
from deep_keyphrase.utils.constants import BOS_WORD
8+
from deep_keyphrase.utils.vocab_loader import load_vocab
9+
from deep_keyphrase.utils.tokenizer import token_char_tokenize
10+
from .model import CopyTransformer
11+
from .beam_search import TransformerBeamSearch
12+
13+
14+
class CopyTransformerPredictor(BasePredictor):
15+
def __init__(self, model_info, vocab_info, beam_size, max_target_len, max_src_length):
16+
super().__init__(model_info)
17+
if isinstance(vocab_info, str):
18+
self.vocab2id = load_vocab(vocab_info)
19+
elif isinstance(vocab_info, dict):
20+
self.vocab2id = vocab_info
21+
else:
22+
raise ValueError('vocab info type error')
23+
self.id2vocab = dict(zip(self.vocab2id.values(), self.vocab2id.keys()))
24+
self.config = self.load_config(model_info)
25+
self.model = self.load_model(model_info, CopyTransformer(self.config, self.vocab2id))
26+
self.model.eval()
27+
self.beam_size = beam_size
28+
self.max_target_len = max_target_len
29+
self.max_src_len = max_src_length
30+
self.beam_searcher = TransformerBeamSearch(model=self.model,
31+
beam_size=self.beam_size,
32+
max_target_len=self.max_target_len,
33+
id2vocab=self.id2vocab,
34+
bos_idx=self.vocab2id[BOS_WORD],
35+
args=self.config)
36+
37+
def predict(self, text_list, batch_size, delimiter=None):
38+
self.model.eval()
39+
if len(text_list) < batch_size:
40+
batch_size = len(text_list)
41+
text_list = [{TOKENS: token_char_tokenize(i)} for i in text_list]
42+
loader = KeyphraseDataLoader(data_source=text_list,
43+
vocab2id=self.vocab2id,
44+
batch_size=batch_size,
45+
max_oov_count=self.config.max_oov_count,
46+
max_src_len=self.max_src_len,
47+
max_target_len=self.max_target_len,
48+
mode='inference')
49+
result = []
50+
for batch in loader:
51+
with torch.no_grad():
52+
result.extend(self.beam_searcher.beam_search(batch, delimiter=delimiter))
53+
return result
54+
55+
def eval_predict(self, src_filename, dest_filename, batch_size,
56+
model=None, remove_existed=False,
57+
token_field='tokens', keyphrase_field='keyphrases'):
58+
loader = KeyphraseDataLoader(data_source=src_filename,
59+
vocab2id=self.vocab2id,
60+
batch_size=batch_size,
61+
max_oov_count=self.config.max_oov_count,
62+
max_src_len=self.max_src_len,
63+
max_target_len=self.max_target_len,
64+
mode='inference',
65+
pre_fetch=True,
66+
token_field=token_field,
67+
keyphrase_field=keyphrase_field)
68+
if os.path.exists(dest_filename):
69+
print('destination filename {} existed'.format(dest_filename))
70+
if remove_existed:
71+
os.remove(dest_filename)
72+
if model is not None:
73+
model.eval()
74+
self.beam_searcher = TransformerBeamSearch(model=model,
75+
beam_size=self.beam_size,
76+
max_target_len=self.max_target_len,
77+
id2vocab=self.id2vocab,
78+
bos_idx=self.vocab2id[BOS_WORD],
79+
args=self.config)
80+
81+
for batch in loader:
82+
with torch.no_grad():
83+
batch_result = self.beam_searcher.beam_search(batch, delimiter=None)
84+
final_result = []
85+
assert len(batch_result) == len(batch[RAW_BATCH])
86+
for item_input, item_output in zip(batch[RAW_BATCH], batch_result):
87+
item_input['pred_keyphrases'] = item_output
88+
final_result.append(item_input)
89+
append_jsonlines(dest_filename, final_result)

0 commit comments

Comments
 (0)