server.py
186 lines (158 loc) · 6.65 KB
# Modified from
# https://docs.python.org/3.8/library/socketserver.html#asynchronous-mixins
# https://stackoverflow.com/questions/46138771/python-multipleclient-server-with-queues
import socket
import threading
import socketserver
import queue
import json
import argparse
import time
from typing import Tuple
from qa import QaTorchInferenceSession, QaOnnxInferenceSession
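# QaOnnxInferenceSession and QaTorchInferenceSession are assumed to expose a
# run(question, text) method returning the answer string; that is the only
# interface this server relies on (see qa.py in this repository).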


class InferenceExecutionThread(threading.Thread):
    """
    Inference execution thread.
    """
    def __init__(self,
                 model_filepath: str,
                 tokenizer_filepath: str,
                 inference_engine_type: str = "onnx") -> None:
        super(InferenceExecutionThread, self).__init__()
        # Python queue library is thread-safe.
        # https://docs.python.org/3.8/library/queue.html#module-Queue
        # We can put tasks into the queue from multiple threads safely.
        self.model_filepath = model_filepath
        self.tokenizer_filepath = tokenizer_filepath
        self.inference_engine_type = inference_engine_type
        if self.inference_engine_type == "onnx":
            self.inference_session = QaOnnxInferenceSession(
                model_filepath=self.model_filepath,
                tokenizer_filepath=self.tokenizer_filepath)
        elif self.inference_engine_type == "pytorch":
            self.inference_session = QaTorchInferenceSession(
                model_filepath=self.model_filepath,
                tokenizer_filepath=self.tokenizer_filepath)
        else:
            raise RuntimeError("Unsupported inference engine type.")

    def run(self) -> None:
        """
        Run inference for the tasks in the queue.
        """
        # Each worker thread polls the shared queue for (handler, request)
        # tuples, runs inference, and sends the answer back on the handler's
        # socket.
        while True:
            if not request_content_queue.empty():
                print(
                    "Current Thread: {}, Number of Active Threads: {}".format(
                        threading.current_thread().name,
                        threading.active_count()))
                handler, data_dict = request_content_queue.get()
                question = data_dict["question"]
                text = data_dict["text"]
                start_time = time.time()
                answer = self.inference_session.run(question=question,
                                                    text=text)
                end_time = time.time()
                latency = (end_time - start_time) * 1000
                print("Server Inference Latency: {} ms".format(latency))
                if answer in ["", "[CLS]"]:
                    answer = "Unknown"
                response = bytes(answer, "utf-8")
                print("Sending answer \"{}\" ...".format(answer))
                handler.request.sendall(response)
                request_content_queue.task_done()
                print("Inference Done.")


class ThreadedTCPRequestHandler(socketserver.BaseRequestHandler):
    """
    TCP request handler.
    """
    def handle(self) -> None:
        """
        Handle method to override.

        Each message received from the client is expected to be a JSON object
        with "question" and "text" fields; the parsed request is enqueued
        together with this handler so that a worker thread can reply on the
        same connection.
        """
        while True:
            data = str(self.request.recv(1024), "utf-8")
            if not data:
                print("User disconnected.")
                break
            data_dict = json.loads(data)
            print("{} wrote:".format(self.client_address[0]))
            print(data)
            request_content_queue.put((self, data_dict))
            print("Task sent to queue.")


class ThreadedTCPServer(socketserver.ThreadingMixIn, socketserver.TCPServer):
    """
    Multithreaded TCP server.
    """
    pass


def main() -> None:

    host_default = "0.0.0.0"
    port_default = 9999
    num_inference_sessions_default = 2
    inference_engine_type_default = "onnx"

    parser = argparse.ArgumentParser(
        description="Question and answer server.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--host",
                        type=str,
                        help="Host IP to bind the server to.",
                        default=host_default)
    parser.add_argument("--port",
                        type=int,
                        help="Port to listen on.",
                        default=port_default)
    parser.add_argument("--num_inference_sessions",
                        type=int,
                        help="Number of inference sessions.",
                        default=num_inference_sessions_default)
    parser.add_argument("--inference_engine_type",
                        type=str,
                        choices=["onnx", "pytorch"],
                        help="Inference engine type.",
                        default=inference_engine_type_default)
    argv = parser.parse_args()

    host = argv.host
    port = argv.port
    num_inference_sessions = argv.num_inference_sessions
    inference_engine_type = argv.inference_engine_type

    onnx_model_filepath = "./saved_models/bert-base-cased-squad2_model.onnx"
    torch_model_filepath = "./saved_models/bert-base-cased-squad2_model.pt"
    tokenizer_filepath = "./saved_models/bert-base-cased-squad2_tokenizer.pt"

    global request_content_queue
    # Use a single shared queue for all worker threads.
    # In testing, giving each worker thread its own queue spread the requests
    # evenly across the queues, but it slowed down the end-to-end latency
    # significantly, so multiple queues are not used here.
    request_content_queue = queue.Queue()

    # Number of inference sessions.
    # Each inference session gets executed in an independent execution thread.
    global execution_threads

    if inference_engine_type == "onnx":
        model_filepath = onnx_model_filepath
    elif inference_engine_type == "pytorch":
        model_filepath = torch_model_filepath
    else:
        raise RuntimeError("Unsupported inference engine type.")

    print("Starting QA {} engine x {} ...".format(inference_engine_type,
                                                  num_inference_sessions))
    execution_threads = [
        InferenceExecutionThread(model_filepath=model_filepath,
                                 tokenizer_filepath=tokenizer_filepath,
                                 inference_engine_type=inference_engine_type)
        for _ in range(num_inference_sessions)
    ]
    for execution_thread in execution_threads:
        execution_thread.start()

    print("Starting QA Server ...")
    # Create the server, binding to the given host and port.
    with ThreadedTCPServer((host, port), ThreadedTCPRequestHandler) as server:
        # Activate the server; this will keep running until you
        # interrupt the program with Ctrl-C.
        print("=" * 50)
        print("QA Server")
        print("=" * 50)
        server.serve_forever()

    for execution_thread in execution_threads:
        execution_thread.join()


if __name__ == "__main__":
    main()
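

# Example usage (a minimal sketch, not part of the server itself): each request
# sent to this server is a JSON object with "question" and "text" fields, and
# the reply is the raw UTF-8 answer string. A hypothetical client could look
# like this (host, port, and the question/context strings are placeholders):
#
#     import json
#     import socket
#
#     with socket.create_connection(("127.0.0.1", 9999)) as sock:
#         request = {
#             "question": "Who wrote Hamlet?",
#             "text": "Hamlet is a tragedy written by William Shakespeare."
#         }
#         sock.sendall(bytes(json.dumps(request), "utf-8"))
#         answer = str(sock.recv(1024), "utf-8")
#         print(answer)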