1
1
# -*- coding: UTF-8 -*-
2
2
import random
3
+ import time
3
4
import traceback
4
5
import sys
5
6
import threading
18
19
TARGET = 'target'
19
20
RAW_BATCH = 'raw'
20
21
22
+ TRAIN_MODE = 'train'
23
+ EVAL_MODE = 'eval'
24
+ INFERENCE_MODE = 'inference'
25
+
21
26
22
27
class ExceptionWrapper (object ):
23
28
"""
@@ -97,32 +102,35 @@ def __init__(self, loader):
97
102
self .data_source = loader .data_source
98
103
self .batch_size = loader .batch_size
99
104
self .num_workers = 8
100
- if self .loader .mode == 'train' :
105
+ if self .loader .mode == TRAIN_MODE :
101
106
self .chunk_size = self .num_workers * 50
102
107
else :
103
108
self .chunk_size = self .batch_size
104
109
self ._data = self .load_data (self .chunk_size )
105
110
self ._batch_count_in_output_queue = 0
106
111
self ._redundant_batch = []
107
- self .input_queue = multiprocessing .Queue (1000 * self .num_workers )
108
- self .output_queue = multiprocessing .Queue (1000 * self .num_workers )
109
- self .done_event = threading .Event ()
110
- self .worker_shutdown = False
111
112
self .workers = []
112
- for _ in range (self .num_workers ):
113
- worker = multiprocessing .Process (
114
- target = self ._train_data_worker_loop
115
- )
116
- self .workers .append (worker )
117
- for worker in self .workers :
118
- worker .daemon = True
119
- worker .start ()
120
-
121
- if self .loader .mode == 'train' or self .loader .pre_fetch :
113
+ self .worker_shutdown = False
114
+
115
+ if self .loader .mode in {TRAIN_MODE , EVAL_MODE }:
116
+ self .input_queue = multiprocessing .Queue (1000 * self .num_workers )
117
+ self .output_queue = multiprocessing .Queue (1000 * self .num_workers )
118
+ self .done_event = threading .Event ()
119
+
120
+ for _ in range (self .num_workers ):
121
+ worker = multiprocessing .Process (
122
+ target = self ._train_data_worker_loop
123
+ )
124
+ self .workers .append (worker )
125
+ for worker in self .workers :
126
+ worker .daemon = True
127
+ worker .start ()
128
+
129
+ if self .loader .mode in {TRAIN_MODE , EVAL_MODE } or self .loader .pre_fetch :
122
130
self .__prefetch ()
123
131
124
132
def __iter__ (self ):
125
- if self .loader .mode == 'train' :
133
+ if self .loader .mode == TRAIN_MODE :
126
134
yield from self .iter_train ()
127
135
else :
128
136
yield from self .iter_inference ()
@@ -138,7 +146,10 @@ def load_data(self, chunk_size):
138
146
139
147
def _train_data_worker_loop (self ):
140
148
while True :
149
+ if self .done_event .is_set ():
150
+ return
141
151
raw_batch = self .input_queue .get ()
152
+ # exit signal
142
153
if raw_batch is None :
143
154
break
144
155
try :
@@ -194,31 +205,33 @@ def get_batches(self, item_chunk, batch):
194
205
batch = flatten_items
195
206
else :
196
207
batch .extend (flatten_items )
197
-
208
+ batches = self . reorder_batch_list ( batches )
198
209
return batches , batch
199
210
211
def reorder_batch(self, batch):
    """Return the items of ``batch`` ordered by token count, longest first.

    The sort key is ``len(item[TOKENS])``.  A new list is returned; the
    input sequence is left untouched.

    :param batch: sequence of items, each indexable by ``TOKENS``.
    :return: list holding the same items in descending token-length order.
    """
    # sorted() is stable, also with reverse=True, so items of equal length
    # keep the relative order they had in the input.
    return sorted(batch, key=lambda item: len(item[TOKENS]), reverse=True)
216
+
217
def reorder_batch_list(self, batches):
    """Apply :meth:`reorder_batch` to every batch in ``batches``.

    :param batches: iterable of batches.
    :return: new list with one reordered batch per input batch, in the
        same order as the input.
    """
    return [self.reorder_batch(one_batch) for one_batch in batches]
222
+
200
223
def flatten_raw_item(self, item):
    """Expand one raw item into one record per keyphrase.

    The phrases are read from ``item[self.loader.keyphrases_field]``.
    Every record is a dict with the keys ``'tokens'`` (the item's token
    object itself, not a copy) and ``'phrase'``.

    :param item: mapping with a ``'tokens'`` entry and a keyphrase entry.
    :return: list of ``{'tokens': ..., 'phrase': ...}`` dicts, one per
        phrase, in phrase order.
    """
    phrases = item[self.loader.keyphrases_field]
    # 'tokens' is looked up per phrase, so an item without phrases never
    # touches that key.
    return [{'tokens': item['tokens'], 'phrase': one_phrase} for one_phrase in phrases]
205
228
206
def flatten_item(self, item):
    """Expand one processed item into one record per target phrase.

    Every record repeats the item's ``TOKENS``, ``TOKENS_OOV`` and
    ``OOV_COUNT`` values and adds a single phrase from
    ``item[TARGET_LIST]`` under ``TARGET``.

    :param item: mapping holding the four fields named above.
    :return: list of dicts, one per target phrase, in phrase order.
    """
    # The shared fields are read up front, before iterating the targets,
    # so a missing field fails even when the target list is empty.
    shared_tokens = item[TOKENS]
    shared_tokens_oov = item[TOKENS_OOV]
    shared_oov_count = item[OOV_COUNT]
    return [
        {
            TOKENS: shared_tokens,
            TOKENS_OOV: shared_tokens_oov,
            OOV_COUNT: shared_oov_count,
            TARGET: target_phrase,
        }
        for target_phrase in item[TARGET_LIST]
    ]
218
-
219
229
def iter_inference(self):
    """Yield one padded batch per chunk of ``self._data``.

    Each chunk is treated as a batch: its items are collated one by one
    with ``self.loader.collate_fn(item, is_inference=True)``, the result
    is passed through :meth:`reorder_batch` when it holds more than one
    item, and the batch is finally handed to
    ``self.padding_batch_inference``.
    """
    collate = self.loader.collate_fn
    for raw_chunk in self._data:
        collated = [collate(entry, is_inference=True) for entry in raw_chunk]
        # Reordering is skipped for batches of zero or one item.
        if len(collated) > 1:
            collated = self.reorder_batch(collated)
        yield self.padding_batch_inference(collated)
223
236
224
237
def padding_batch_train (self , batch ):
@@ -276,11 +289,16 @@ def __padding(self, x_raw, max_len):
276
289
return x , x_len_list
277
290
278
291
def _shutdown_workers(self, grace_period=1.0):
    """Stop the worker processes; does nothing on repeated calls.

    Nothing happens when the workers were already shut down or when the
    loader mode is neither ``'train'`` nor ``'eval'``.  Otherwise the
    done event is set and one ``None`` sentinel per worker is put on the
    input queue, which the worker loop treats as its exit signal.  The
    workers then get up to ``grace_period`` seconds in total to exit, and
    only those still alive afterwards are terminated.

    :param grace_period: total number of seconds to wait for the workers
        to exit on their own before terminating the survivors.
    """
    if self.worker_shutdown or self.loader.mode not in {'train', 'eval'}:
        return

    self.worker_shutdown = True
    # NOTE(review): done_event appears to be created as a threading.Event
    # while the workers are multiprocessing.Process objects, so the flag
    # set here is presumably not observed inside the workers; the None
    # sentinels below are what stops them.  Confirm against __init__.
    self.done_event.set()
    for _ in self.workers:
        self.input_queue.put(None)

    # Wait against one shared deadline instead of sleeping a fixed second:
    # shutdown returns as soon as every worker has exited, and the total
    # wait stays bounded no matter how many workers there are.
    deadline = time.monotonic() + grace_period
    for worker in self.workers:
        worker.join(max(0.0, deadline - time.monotonic()))

    # Force-stop only the workers that ignored the exit signal.
    for worker in self.workers:
        if worker.is_alive():
            worker.terminate()
284
302
285
303
def __del__ (self ):
286
304
if self .num_workers > 0 :
0 commit comments