Skip to content

Commit 3387268

Browse files
add batch reorder for LSTM ONNX compatibility
1 parent e3a253e commit 3387268

File tree

1 file changed

+49
-31
lines changed

1 file changed

+49
-31
lines changed

deep_keyphrase/dataloader.py

Lines changed: 49 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# -*- coding: UTF-8 -*-
22
import random
3+
import time
34
import traceback
45
import sys
56
import threading
@@ -18,6 +19,10 @@
1819
TARGET = 'target'
1920
RAW_BATCH = 'raw'
2021

22+
TRAIN_MODE = 'train'
23+
EVAL_MODE = 'eval'
24+
INFERENCE_MODE = 'inference'
25+
2126

2227
class ExceptionWrapper(object):
2328
"""
@@ -97,32 +102,35 @@ def __init__(self, loader):
97102
self.data_source = loader.data_source
98103
self.batch_size = loader.batch_size
99104
self.num_workers = 8
100-
if self.loader.mode == 'train':
105+
if self.loader.mode == TRAIN_MODE:
101106
self.chunk_size = self.num_workers * 50
102107
else:
103108
self.chunk_size = self.batch_size
104109
self._data = self.load_data(self.chunk_size)
105110
self._batch_count_in_output_queue = 0
106111
self._redundant_batch = []
107-
self.input_queue = multiprocessing.Queue(1000 * self.num_workers)
108-
self.output_queue = multiprocessing.Queue(1000 * self.num_workers)
109-
self.done_event = threading.Event()
110-
self.worker_shutdown = False
111112
self.workers = []
112-
for _ in range(self.num_workers):
113-
worker = multiprocessing.Process(
114-
target=self._train_data_worker_loop
115-
)
116-
self.workers.append(worker)
117-
for worker in self.workers:
118-
worker.daemon = True
119-
worker.start()
120-
121-
if self.loader.mode == 'train' or self.loader.pre_fetch:
113+
self.worker_shutdown = False
114+
115+
if self.loader.mode in {TRAIN_MODE, EVAL_MODE}:
116+
self.input_queue = multiprocessing.Queue(1000 * self.num_workers)
117+
self.output_queue = multiprocessing.Queue(1000 * self.num_workers)
118+
self.done_event = threading.Event()
119+
120+
for _ in range(self.num_workers):
121+
worker = multiprocessing.Process(
122+
target=self._train_data_worker_loop
123+
)
124+
self.workers.append(worker)
125+
for worker in self.workers:
126+
worker.daemon = True
127+
worker.start()
128+
129+
if self.loader.mode in {TRAIN_MODE, EVAL_MODE} or self.loader.pre_fetch:
122130
self.__prefetch()
123131

124132
def __iter__(self):
125-
if self.loader.mode == 'train':
133+
if self.loader.mode == TRAIN_MODE:
126134
yield from self.iter_train()
127135
else:
128136
yield from self.iter_inference()
@@ -138,7 +146,10 @@ def load_data(self, chunk_size):
138146

139147
def _train_data_worker_loop(self):
140148
while True:
149+
if self.done_event.is_set():
150+
return
141151
raw_batch = self.input_queue.get()
152+
# exit signal
142153
if raw_batch is None:
143154
break
144155
try:
@@ -194,31 +205,33 @@ def get_batches(self, item_chunk, batch):
194205
batch = flatten_items
195206
else:
196207
batch.extend(flatten_items)
197-
208+
batches = self.reorder_batch_list(batches)
198209
return batches, batch
199210

211+
def reorder_batch(self, batch):
212+
seq_idx_and_len = [(idx, len(item[TOKENS])) for idx, item in enumerate(batch)]
213+
seq_idx_and_len = sorted(seq_idx_and_len, key=lambda i: i[1], reverse=True)
214+
batch = [batch[idx] for idx, _ in seq_idx_and_len]
215+
return batch
216+
217+
def reorder_batch_list(self, batches):
218+
new_batches = []
219+
for batch in batches:
220+
new_batches.append(self.reorder_batch(batch))
221+
return new_batches
222+
200223
def flatten_raw_item(self, item):
201224
flatten_items = []
202225
for phrase in item[self.loader.keyphrases_field]:
203226
flatten_items.append({'tokens': item['tokens'], 'phrase': phrase})
204227
return flatten_items
205228

206-
def flatten_item(self, item):
207-
tokens = item[TOKENS]
208-
token_with_oov = item[TOKENS_OOV]
209-
oov_count = item[OOV_COUNT]
210-
flatten_items = []
211-
for phrase in item[TARGET_LIST]:
212-
one2one_item = {TOKENS: tokens,
213-
TOKENS_OOV: token_with_oov,
214-
OOV_COUNT: oov_count,
215-
TARGET: phrase}
216-
flatten_items.append(one2one_item)
217-
return flatten_items
218-
219229
def iter_inference(self):
220230
for item_chunk in self._data:
231+
# item_chunk is same as a batch
221232
item_chunk = [self.loader.collate_fn(item, is_inference=True) for item in item_chunk]
233+
if len(item_chunk) > 1:
234+
item_chunk = self.reorder_batch(item_chunk)
222235
yield self.padding_batch_inference(item_chunk)
223236

224237
def padding_batch_train(self, batch):
@@ -276,11 +289,16 @@ def __padding(self, x_raw, max_len):
276289
return x, x_len_list
277290

278291
def _shutdown_workers(self):
279-
if not self.worker_shutdown:
292+
# print('shutdown workers')
293+
if not self.worker_shutdown and self.loader.mode in {'train', 'eval'}:
280294
self.worker_shutdown = True
281295
self.done_event.set()
282296
for _ in self.workers:
283297
self.input_queue.put(None)
298+
time.sleep(1)
299+
300+
for worker in self.workers:
301+
worker.terminate()
284302

285303
def __del__(self):
286304
if self.num_workers > 0:

0 commit comments

Comments
 (0)