Skip to content

Commit c05b03a

Browse files
committed
Improve client using asyncio event
1 parent 8906ecf commit c05b03a

File tree

17 files changed

+241
-301
lines changed

17 files changed

+241
-301
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ sanic | 18793.08 | 5.88 | 12739.37 | 8.
257257
zero(sync) | 28471.47 | 4.12 | 18114.84 | 6.69
258258
zero(async) | 29012.03 | 3.43 | 20956.48 | 5.80
259259

260-
Seems like blacksheep is the aster on hello world, but in more complex operations like saving to redis, zero is the winner! 🏆
260+
Seems like blacksheep is faster on hello world, but in more complex operations like saving to redis, zero is the winner! 🏆
261261

262262
# Roadmap 🗺
263263

tests/concurrency/rps_async.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
async def task(semaphore, items):
1212
async with semaphore:
1313
try:
14-
await async_client.call("sum_async", items)
14+
await async_client.call("sum_sync", items)
1515
# res = await async_client.call("sum_async", items)
1616
# print(res)
1717
except Exception as e:

tests/unit/test_server.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from zero import ZeroServer
1010
from zero.encoder.protocols import Encoder
11-
from zero.zeromq_patterns.protocols import ZeroMQBroker
11+
from zero.zeromq_patterns.interfaces import ZeroMQBroker
1212

1313
DEFAULT_PORT = 5559
1414
DEFAULT_HOST = "0.0.0.0"
@@ -216,6 +216,17 @@ class Message:
216216
def add(msg: Message) -> Message:
217217
return Message()
218218

219+
def test_register_rpc_with_long_name(self):
220+
server = ZeroServer()
221+
222+
with self.assertRaises(ValueError):
223+
224+
@server.register_rpc
225+
def add_this_is_a_very_long_name_for_a_function_more_than_120_characters_ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff(
226+
msg: Tuple[int, int]
227+
) -> int:
228+
return msg[0] + msg[1]
229+
219230
def test_server_run(self):
220231
server = ZeroServer()
221232

zero/protocols/zeromq/client.py

Lines changed: 22 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,9 @@
1-
import asyncio
21
import logging
32
import threading
4-
from typing import Any, Dict, Optional, Type, TypeVar, Union
3+
from typing import Dict, Optional, Type, TypeVar, Union
54

65
from zero import config
76
from zero.encoder import Encoder, get_encoder
8-
from zero.error import TimeoutException
9-
from zero.utils import util
107
from zero.zeromq_patterns import (
118
AsyncZeroMQClient,
129
ZeroMQClient,
@@ -28,7 +25,7 @@ def __init__(
2825
self._default_timeout = default_timeout
2926
self._encoder = encoder or get_encoder(config.ENCODER)
3027

31-
self.client_pool = ZeroMQClientPool(
28+
self.client_pool = ZMQClientPool(
3229
self._address,
3330
self._default_timeout,
3431
self._encoder,
@@ -43,47 +40,17 @@ def call(
4340
) -> T:
4441
zmqc = self.client_pool.get()
4542

46-
_timeout = self._default_timeout if timeout is None else timeout
47-
48-
def _poll_data():
49-
# TODO poll is slow, need to find a better way
50-
if not zmqc.poll(_timeout):
51-
raise TimeoutException(
52-
f"Timeout while sending message at {self._address}"
53-
)
54-
55-
rcv_data = zmqc.recv()
56-
57-
# first 32 bytes as response id
58-
resp_id = rcv_data[:32].decode()
59-
60-
# the rest is response data
61-
resp_data_encoded = rcv_data[32:]
62-
resp_data = (
63-
self._encoder.decode(resp_data_encoded)
64-
if return_type is None
65-
else self._encoder.decode_type(resp_data_encoded, return_type)
66-
)
67-
68-
return resp_id, resp_data
69-
70-
req_id = util.unique_id()
71-
72-
# function name exactly 120 bytes
73-
func_name_bytes = rpc_func_name.ljust(120).encode()
74-
43+
# make function name exactly 80 bytes
44+
func_name_bytes = rpc_func_name.ljust(80).encode()
7545
msg_bytes = b"" if msg is None else self._encoder.encode(msg)
76-
zmqc.send(req_id.encode() + func_name_bytes + msg_bytes)
7746

78-
resp_id, resp_data = None, None
79-
# as the client is synchronous, we know that the response will be available any next poll
80-
# we try to get the response until timeout because a previous call might be timed out
81-
# and the response is still in the socket,
82-
# so we poll until we get the response for this call
83-
while resp_id != req_id:
84-
resp_id, resp_data = _poll_data()
47+
resp_data_bytes = zmqc.request(func_name_bytes + msg_bytes, timeout)
8548

86-
return resp_data # type: ignore
49+
return (
50+
self._encoder.decode(resp_data_bytes)
51+
if return_type is None
52+
else self._encoder.decode_type(resp_data_bytes, return_type)
53+
)
8754

8855
def close(self):
8956
self.client_pool.close()
@@ -99,9 +66,8 @@ def __init__(
9966
self._address = address
10067
self._default_timeout = default_timeout
10168
self._encoder = encoder or get_encoder(config.ENCODER)
102-
self._resp_map: Dict[str, Any] = {}
10369

104-
self.client_pool = AsyncZeroMQClientPool(
70+
self.client_pool = AsyncZMQClientPool(
10571
self._address,
10672
self._default_timeout,
10773
self._encoder,
@@ -116,63 +82,23 @@ async def call(
11682
) -> T:
11783
zmqc = await self.client_pool.get()
11884

119-
_timeout = self._default_timeout if timeout is None else timeout
120-
expire_at = util.current_time_us() + (_timeout * 1000)
121-
122-
async def _poll_data():
123-
# TODO async has issue with poller, after 3-4 calls, it returns empty
124-
# if not await zmqc.poll(_timeout):
125-
# raise TimeoutException(f"Timeout while sending message at {self._address}")
126-
127-
# first 32 bytes as response id
128-
resp = await zmqc.recv()
129-
resp_id = resp[:32].decode()
130-
131-
# the rest is response data
132-
resp_data_encoded = resp[32:]
133-
resp_data = (
134-
self._encoder.decode(resp_data_encoded)
135-
if return_type is None
136-
else self._encoder.decode_type(resp_data_encoded, return_type)
137-
)
138-
self._resp_map[resp_id] = resp_data
139-
140-
# TODO try to use pipe instead of sleep
141-
# await self.peer1.send(b"")
142-
143-
req_id = util.unique_id()
144-
145-
# function name exactly 120 bytes
146-
func_name_bytes = rpc_func_name.ljust(120).encode()
147-
85+
# make function name exactly 80 bytes
86+
func_name_bytes = rpc_func_name.ljust(80).encode()
14887
msg_bytes = b"" if msg is None else self._encoder.encode(msg)
149-
await zmqc.send(req_id.encode() + func_name_bytes + msg_bytes)
150-
151-
# every request poll the data, so whenever a response comes, it will be stored in __resps
152-
# dont need to poll again in the while loop
153-
await _poll_data()
15488

155-
while req_id not in self._resp_map and util.current_time_us() <= expire_at:
156-
# TODO the problem with the zpipe is that we can miss some response
157-
# when we come to this line
158-
# await self.peer2.recv()
159-
await asyncio.sleep(1e-6)
89+
resp_data_bytes = await zmqc.request(func_name_bytes + msg_bytes, timeout)
16090

161-
if util.current_time_us() > expire_at:
162-
raise TimeoutException(
163-
f"Timeout while waiting for response at {self._address}"
164-
)
165-
166-
resp_data = self._resp_map.pop(req_id)
167-
168-
return resp_data
91+
return (
92+
self._encoder.decode(resp_data_bytes)
93+
if return_type is None
94+
else self._encoder.decode_type(resp_data_bytes, return_type)
95+
)
16996

17097
def close(self):
17198
self.client_pool.close()
172-
self._resp_map = {}
17399

174100

175-
class ZeroMQClientPool:
101+
class ZMQClientPool:
176102
"""
177103
Connections are based on different threads and processes.
178104
Each time a call is made it tries to get the connection from the pool,
@@ -196,21 +122,15 @@ def get(self) -> ZeroMQClient:
196122
logging.debug("No connection found in current thread, creating new one")
197123
self._pool[thread_id] = get_client(config.ZEROMQ_PATTERN, self._timeout)
198124
self._pool[thread_id].connect(self._address)
199-
self._try_connect_ping(self._pool[thread_id])
200125
return self._pool[thread_id]
201126

202-
def _try_connect_ping(self, client: ZeroMQClient):
203-
client.send(util.unique_id().encode() + b"connect" + b"")
204-
client.recv()
205-
logging.info("Connected to server at %s", self._address)
206-
207127
def close(self):
208128
for client in self._pool.values():
209129
client.close()
210130
self._pool = {}
211131

212132

213-
class AsyncZeroMQClientPool:
133+
class AsyncZMQClientPool:
214134
"""
215135
Connections are based on different threads and processes.
216136
Each time a call is made it tries to get the connection from the pool,
@@ -235,15 +155,9 @@ async def get(self) -> AsyncZeroMQClient:
235155
self._pool[thread_id] = get_async_client(
236156
config.ZEROMQ_PATTERN, self._timeout
237157
)
238-
self._pool[thread_id].connect(self._address)
239-
await self._try_connect_ping(self._pool[thread_id])
158+
await self._pool[thread_id].connect(self._address)
240159
return self._pool[thread_id]
241160

242-
async def _try_connect_ping(self, client: AsyncZeroMQClient):
243-
await client.send(util.unique_id().encode() + b"connect" + b"")
244-
await client.recv()
245-
logging.info("Connected to server at %s", self._address)
246-
247161
def close(self):
248162
for client in self._pool.values():
249163
client.close()

zero/protocols/zeromq/server.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,15 +66,14 @@ def start(self, workers: int = os.cpu_count() or 1):
6666

6767
self._start_server(workers, spawn_worker)
6868

69-
def _start_server(self, workers: int, spawn_worker: Callable):
69+
def _start_server(self, workers: int, spawn_worker: Callable[[int], None]):
7070
self._pool = Pool(workers)
7171

7272
# process termination signals
7373
util.register_signal_term(self._sig_handler)
7474

75-
# TODO: by default we start the workers with processes,
76-
# but we need support to run only router, without workers
77-
self._pool.map_async(spawn_worker, list(range(1, workers + 1)))
75+
worker_ids = list(range(1, workers + 1))
76+
self._pool.map_async(spawn_worker, worker_ids)
7877

7978
# blocking
8079
with zmq.utils.win32.allow_interrupt(self.stop):

zero/protocols/zeromq/worker.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,7 @@ def process_message(func_name_encoded: bytes, data: bytes) -> Optional[bytes]:
5454
except ValidationError as exc:
5555
logging.exception(exc)
5656
return self._encoder.encode({"__zerror__validation_error": str(exc)})
57-
except (
58-
Exception
59-
) as inner_exc: # pragma: no cover pylint: disable=broad-except
57+
except Exception as inner_exc: # pylint: disable=broad-except
6058
logging.exception(inner_exc)
6159
return self._encoder.encode(
6260
{"__zerror__server_exception": SERVER_PROCESSING_ERROR}

zero/rpc/server.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,6 @@ def register_rpc(self, func: Callable):
8585
Function should have a single argument.
8686
Argument and return should have a type hint.
8787
88-
If the function got exception, client will get None as return value.
89-
9088
Parameters
9189
----------
9290
func: Callable
@@ -135,6 +133,10 @@ def run(self, workers: int = os.cpu_count() or 1):
135133
def _verify_function_name(self, func):
136134
if not isinstance(func, Callable):
137135
raise ValueError(f"register function; not {type(func)}")
136+
if len(func.__name__) > 80:
137+
raise ValueError(
138+
"function name can be at max 80" f" characters; {func.__name__}"
139+
)
138140
if func.__name__ in self._rpc_router:
139141
raise ValueError(
140142
f"cannot have two RPC function same name: `{func.__name__}`"

zero/utils/async_to_sync.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,29 +2,29 @@
22
import threading
33
from functools import wraps
44

5-
_loop = None
6-
_thrd = None
5+
_LOOP = None
6+
_THRD = None
77

88

99
def start_async_loop():
10-
global _loop, _thrd
11-
if _loop is None or _thrd is None or not _thrd.is_alive():
12-
_loop = asyncio.new_event_loop()
13-
_thrd = threading.Thread(
14-
target=_loop.run_forever, name="Async Runner", daemon=True
10+
global _LOOP, _THRD # pylint: disable=global-statement
11+
if _LOOP is None or _THRD is None or not _THRD.is_alive():
12+
_LOOP = asyncio.new_event_loop()
13+
_THRD = threading.Thread(
14+
target=_LOOP.run_forever, name="Async Runner", daemon=True
1515
)
16-
_thrd.start()
16+
_THRD.start()
1717

1818

1919
def async_to_sync(func):
2020
@wraps(func)
2121
def run(*args, **kwargs):
2222
start_async_loop() # Ensure the loop and thread are started
2323
try:
24-
future = asyncio.run_coroutine_threadsafe(func(*args, **kwargs), _loop)
24+
future = asyncio.run_coroutine_threadsafe(func(*args, **kwargs), _LOOP)
2525
return future.result()
26-
except Exception as e:
27-
print(f"Exception occurred: {e}")
26+
except Exception as exc:
27+
print(f"Exception occurred: {exc}")
2828
raise
2929

3030
return run

zero/utils/type_util.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,8 +132,8 @@ def verify_function_return_type(func: Callable):
132132
if origin_type is not None and origin_type in allowed_types:
133133
return
134134

135-
for t in msgspec_types:
136-
if issubclass(return_type, t):
135+
for typ in msgspec_types:
136+
if issubclass(return_type, typ):
137137
return
138138

139139
raise TypeError(

zero/zeromq_patterns/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
from .factory import get_async_client, get_broker, get_client, get_worker
2-
from .protocols import AsyncZeroMQClient, ZeroMQBroker, ZeroMQClient, ZeroMQWorker
2+
from .interfaces import AsyncZeroMQClient, ZeroMQBroker, ZeroMQClient, ZeroMQWorker

0 commit comments

Comments
 (0)