Skip to content

Commit b9b4b1f

Browse files
author
Alexandru Pisarenco
committed
Refactor connection infrastructure with async io
Use selectors to manage sockets in parallel, and enable multiple connections. Simplifies code by a lot.
1 parent 1df96ae commit b9b4b1f

File tree

9 files changed

+333
-369
lines changed

9 files changed

+333
-369
lines changed

client.py

Lines changed: 0 additions & 20 deletions
This file was deleted.

connection.py

Lines changed: 55 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,65 +1,62 @@
11
import logging
22

3-
TYPE_SERVER=0
4-
TYPE_CLIENT=1
5-
63
class Connection:
7-
def __init__(self, sock, target, name):
4+
def __init__(self, sock, address, name, events, context):
85
self.sock = sock
9-
self.target = target
6+
self.address = address
107
self.name = name
11-
12-
def send(self, message):
13-
logging.debug("sending message to {}:\n{}".format(self.name, message))
14-
total = len(message)
15-
total_sent = 0
16-
remaining = message
17-
while total_sent < total:
18-
sent = self.sock.send(remaining)
19-
total_sent += sent
20-
remaining = remaining[sent:]
21-
22-
def __receive_raw(self, length):
23-
total_received = 0
24-
chunks = []
25-
while total_received < length:
26-
chunk = self.sock.recv(min([length - total_received]), 4096)
27-
if chunk==b'':
28-
raise RuntimeError("socket connection broken")
29-
chunks.append(chunk)
30-
total_received += len(chunk)
31-
return b''.join(chunks)
32-
33-
34-
def receive_packet(self):
35-
pack_type = self.__receive_raw(1)
36-
if pack_type == b'N':
37-
# Null message? This message has no length. Just a single byte. Weird.
38-
return pack_type, pack_type
39-
if pack_type == b'\x00':
40-
# Initialization packet. No type. This, and the next 3 bytes are the length
41-
pack_length = self.__receive_raw(3)
42-
pack_length = b''.join([pack_type, pack_length])
43-
pack_header = pack_length
44-
else:
45-
pack_length = self.__receive_raw(4)
46-
pack_header = b''.join([pack_type, pack_length])
47-
pack_length = int.from_bytes(pack_length, 'big')
48-
pack_body = self.__receive_raw(pack_length - 4)
49-
pack = b''.join([pack_header, pack_body])
50-
return pack, pack_type
51-
52-
53-
def receive(self):
54-
packets = []
55-
logging.debug("receive message from {}:".format(self.name))
8+
self.events = events
9+
self.context = context
10+
self.is_reading = False
11+
self.is_writing = False
12+
self.interceptor = None
13+
self.redirect_conn = None
14+
self.out_bytes = b''
15+
self.in_bytes = b''
16+
17+
def parse_length(self, length_bytes):
18+
return int.from_bytes(length_bytes, 'big')
19+
20+
def encode_length(self, length):
21+
return length.to_bytes(4, byteorder='big')
22+
23+
def received(self, in_bytes):
24+
self.in_bytes += in_bytes
25+
# Read packet from byte array while there are enough bytes to make up a packet.
26+
# Otherwise wait for more bytes to be received (break and exit)
5627
while True:
57-
logging.debug("receive packet from {}:".format(self.name))
58-
packet, pack_type = self.receive_packet()
59-
packets.append(packet)
60-
logging.debug("received packet from {}:\n{}".format(self.name, packet))
61-
if not pack_type in (b'B', b'D', b'P', b'E'):
28+
ptype = self.in_bytes[0:1]
29+
if ptype == b'\x00':
30+
if len(self.in_bytes) < 4:
31+
break
32+
header_length = 4
33+
body_length = self.parse_length(self.in_bytes[0:4]) - 4
34+
elif ptype == b'N':
35+
header_length = 1
36+
body_length = 0
37+
else:
38+
if len(self.in_bytes) < 5:
39+
break
40+
header_length = 5
41+
body_length = self.parse_length(self.in_bytes[1:5]) - 4
42+
43+
length = header_length + body_length
44+
if len(self.in_bytes) < length:
6245
break
63-
message = b''.join(packets)
64-
logging.debug("received message from {}:\n{}".format(self.name, message))
65-
return message
46+
header = self.in_bytes[0:header_length]
47+
body = self.in_bytes[header_length:length]
48+
self.process_inbound_packet(header, body)
49+
self.in_bytes = self.in_bytes[length:]
50+
51+
def process_inbound_packet(self, header, body):
52+
if header != b'N':
53+
packet_type = header[0:-4]
54+
logging.info("intercepting packet of type '%s' from %s", packet_type, self.name)
55+
body = self.interceptor.intercept(packet_type, body)
56+
header = packet_type + self.encode_length(len(body) + 4)
57+
message = header + body
58+
logging.debug("Received message. Relaying. Speaker: %s, message:\n%s", self.name, message)
59+
self.redirect_conn.out_bytes += message
60+
61+
def sent(self, num_bytes):
62+
self.out_bytes = self.out_bytes[num_bytes:]

interceptors.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
import logging
2+
3+
class Interceptor:
4+
def __init__(self, interceptor_config, plugins, context):
5+
self.interceptor_config = interceptor_config
6+
self.plugins = plugins
7+
self.context = context
8+
9+
def intercept(self, packet_type, data):
10+
return data
11+
12+
def get_codec(self):
13+
if self.context is not None and 'connect_params' in self.context:
14+
if self.context['connect_params'] is not None and 'client_encoding' in self.context['connect_params']:
15+
return self.context['connect_params']['client_encoding']
16+
return 'utf-8'
17+
18+
19+
class CommandInterceptor(Interceptor):
20+
def intercept(self, packet_type, data):
21+
if self.interceptor_config.queries is not None:
22+
ic_queries = self.interceptor_config.queries
23+
if packet_type == b'Q':
24+
# Query, ends with b'\x00'
25+
data = self.__intercept_query(data, ic_queries)
26+
elif packet_type == b'P':
27+
# Statement that needs parsing.
28+
# First byte of the body is some Statement flag. Ignore, don't lose
29+
# Next is the query, same as above, ends with an b'\x00'
30+
# Last 2 bytes are the number of parameters. Ignore, don't lose
31+
statement = data[0:1]
32+
query = self.__intercept_query(data[1:-2], ic_queries)
33+
params = data[-2:]
34+
data = statement + query + params
35+
elif packet_type == b'':
36+
self.__intercept_context_data(data)
37+
return data
38+
39+
40+
def __intercept_context_data(self, data):
41+
# first 4 bytes and last zero byte are not interesting
42+
relevant_data = data[4:-1]
43+
# Each entry is terminated by b'\x00'
44+
entries = relevant_data.split(b'\x00')[:-1]
45+
entries = dict(zip(entries[0::2], entries[1::2]))
46+
self.context['connect_params'] = {}
47+
# Try to set codec, then transcode the dict
48+
if b'client_encoding' in entries:
49+
self.context['connect_params']['client_encoding'] = entries[b'client_encoding'].decode('ascii')
50+
codec = self.get_codec()
51+
for k, v in entries.items():
52+
self.context['connect_params'][k.decode(codec)] = v.decode(codec)
53+
54+
55+
def __intercept_query(self, query, interceptors):
56+
logging.getLogger('intercept').debug("intercepting query\n%s", query)
57+
# Remove zero byte at the end
58+
query = query[:-1].decode('utf-8')
59+
for interceptor in interceptors:
60+
if interceptor.plugin in self.plugins:
61+
plugin = self.plugins[interceptor.plugin]
62+
if hasattr(plugin, interceptor.function):
63+
func = getattr(plugin, interceptor.function)
64+
query = func(query, self.context)
65+
logging.getLogger('intercept').debug(
66+
"modifying query using interceptor %s.%s\n%s",
67+
interceptor.plugin,
68+
interceptor.function,
69+
query)
70+
else:
71+
raise Exception("Can't find function {} in plugin {}".format(
72+
interceptor.function,
73+
interceptor.plugin
74+
))
75+
else:
76+
raise Exception("Plugin {} not loaded".format(interceptor.plugin))
77+
# Append the zero byte at the end
78+
return query.encode('utf-8') + b'\x00'
79+
80+
81+
class ResponseInterceptor(Interceptor):
82+
pass

pg_connection.py

Lines changed: 0 additions & 21 deletions
This file was deleted.

plugins/tableau_hll/__init__.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def replace(match):
2727

2828
# need to know which columns are hll
2929
if not hll_table.lower() in column_cache:
30-
db_conn_info = context.instance_config.redirect
30+
db_conn_info = context['instance_config'].redirect
3131
conn = None
3232
try:
3333
conn = psycopg2.connect(
@@ -36,8 +36,8 @@ def replace(match):
3636
db_conn_info.host,
3737
db_conn_info.port,
3838
# Get auth information from the proxied request
39-
context.connect_params['database'],
40-
context.connect_params['user']
39+
context['connect_params']['database'],
40+
context['connect_params']['user']
4141
)
4242
)
4343

@@ -83,7 +83,6 @@ def replace(match):
8383
return match.group(0)
8484

8585

86-
query = query.decode('utf-8')
8786
# Matches this string. The 2 groups are `schema.table` and `"alias"`
8887
# FROM schema.table) "alias"
8988
table_result = table_pattern.search(query)
@@ -94,4 +93,4 @@ def replace(match):
9493
# Replaces count(distinct ...) with hll_cardinality(hll_union_agg(...)) :: BIGINT
9594
# where and how it is appropriate
9695
# the inner function `replace` uses the variables `original_table` and `table_alias` from this scope (smelly code)
97-
return field_pattern.sub(replace, query).encode('utf-8')
96+
return field_pattern.sub(replace, query)

plugins/tableau_hll/test.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,19 @@
33
import plugins.tableau_hll as hll
44
import yaml
55

6-
class TestContext():
7-
def __init__(self):
8-
with open(os.path.dirname(__file__) + '/config.yml', 'r') as fp:
9-
self.config = yaml.load(fp)
10-
InstanceConfig = collections.namedtuple('InstanceConfig', 'redirect')
11-
Redirect = collections.namedtuple('Redirect', 'name host port')
12-
self.instance_config = InstanceConfig(redirect=Redirect(**self.config['redirect']))
13-
self.connect_params = self.config['connect_params']
6+
def test_context():
7+
with open(os.path.dirname(__file__) + '/config.yml', 'r') as fp:
8+
config = yaml.load(fp)
9+
InstanceConfig = collections.namedtuple('InstanceConfig', 'redirect')
10+
Redirect = collections.namedtuple('Redirect', 'name host port')
11+
return {
12+
'instance_config': InstanceConfig(redirect=Redirect(**config['redirect']))
13+
'connect_params': config['connect_params']
14+
}
1415

1516

1617
def run():
17-
query = b'SELECT COUNT(DISTINCT "crm_data_source"."Set of Customers") AS "ctd:Set of Customers:ok"\nFROM "crm_dim"."crm_data_source" "crm_data_source"\nHAVING (COUNT(1) > 0);'
18-
out_query = b'SELECT hll_cardinality(hll_union_agg("crm_data_source"."Set of Customers")) :: BIGINT AS "ctd:Set of Customers:ok"\nFROM "crm_dim"."crm_data_source" "crm_data_source"\nHAVING (COUNT(1) > 0);'
19-
context = TestContext()
18+
query = 'SELECT COUNT(DISTINCT "crm_data_source"."Set of Customers") AS "ctd:Set of Customers:ok"\nFROM "crm_dim"."crm_data_source" "crm_data_source"\nHAVING (COUNT(1) > 0);'
19+
out_query = 'SELECT hll_cardinality(hll_union_agg("crm_data_source"."Set of Customers")) :: BIGINT AS "ctd:Set of Customers:ok"\nFROM "crm_dim"."crm_data_source" "crm_data_source"\nHAVING (COUNT(1) > 0);'
20+
context = test_context()
2021
assert hll.rewrite_query(query, context) == out_query

0 commit comments

Comments
 (0)