1- import asyncio
21import logging
32import threading
4- from typing import Any , Dict , Optional , Type , TypeVar , Union
3+ from typing import Dict , Optional , Type , TypeVar , Union
54
65from zero import config
76from zero .encoder import Encoder , get_encoder
8- from zero .error import TimeoutException
9- from zero .utils import util
107from zero .zeromq_patterns import (
118 AsyncZeroMQClient ,
129 ZeroMQClient ,
@@ -28,7 +25,7 @@ def __init__(
2825 self ._default_timeout = default_timeout
2926 self ._encoder = encoder or get_encoder (config .ENCODER )
3027
31- self .client_pool = ZeroMQClientPool (
28+ self .client_pool = ZMQClientPool (
3229 self ._address ,
3330 self ._default_timeout ,
3431 self ._encoder ,
@@ -43,47 +40,17 @@ def call(
4340 ) -> T :
4441 zmqc = self .client_pool .get ()
4542
46- _timeout = self ._default_timeout if timeout is None else timeout
47-
48- def _poll_data ():
49- # TODO poll is slow, need to find a better way
50- if not zmqc .poll (_timeout ):
51- raise TimeoutException (
52- f"Timeout while sending message at { self ._address } "
53- )
54-
55- rcv_data = zmqc .recv ()
56-
57- # first 32 bytes as response id
58- resp_id = rcv_data [:32 ].decode ()
59-
60- # the rest is response data
61- resp_data_encoded = rcv_data [32 :]
62- resp_data = (
63- self ._encoder .decode (resp_data_encoded )
64- if return_type is None
65- else self ._encoder .decode_type (resp_data_encoded , return_type )
66- )
67-
68- return resp_id , resp_data
69-
70- req_id = util .unique_id ()
71-
72- # function name exactly 120 bytes
73- func_name_bytes = rpc_func_name .ljust (120 ).encode ()
74-
43+ # make function name exactly 80 bytes
44+ func_name_bytes = rpc_func_name .ljust (80 ).encode ()
7545 msg_bytes = b"" if msg is None else self ._encoder .encode (msg )
76- zmqc .send (req_id .encode () + func_name_bytes + msg_bytes )
7746
78- resp_id , resp_data = None , None
79- # as the client is synchronous, we know that the response will be available any next poll
80- # we try to get the response until timeout because a previous call might be timed out
81- # and the response is still in the socket,
82- # so we poll until we get the response for this call
83- while resp_id != req_id :
84- resp_id , resp_data = _poll_data ()
47+ resp_data_bytes = zmqc .request (func_name_bytes + msg_bytes , timeout )
8548
86- return resp_data # type: ignore
49+ return (
50+ self ._encoder .decode (resp_data_bytes )
51+ if return_type is None
52+ else self ._encoder .decode_type (resp_data_bytes , return_type )
53+ )
8754
8855 def close (self ):
8956 self .client_pool .close ()
@@ -99,9 +66,8 @@ def __init__(
9966 self ._address = address
10067 self ._default_timeout = default_timeout
10168 self ._encoder = encoder or get_encoder (config .ENCODER )
102- self ._resp_map : Dict [str , Any ] = {}
10369
104- self .client_pool = AsyncZeroMQClientPool (
70+ self .client_pool = AsyncZMQClientPool (
10571 self ._address ,
10672 self ._default_timeout ,
10773 self ._encoder ,
@@ -116,63 +82,23 @@ async def call(
11682 ) -> T :
11783 zmqc = await self .client_pool .get ()
11884
119- _timeout = self ._default_timeout if timeout is None else timeout
120- expire_at = util .current_time_us () + (_timeout * 1000 )
121-
122- async def _poll_data ():
123- # TODO async has issue with poller, after 3-4 calls, it returns empty
124- # if not await zmqc.poll(_timeout):
125- # raise TimeoutException(f"Timeout while sending message at {self._address}")
126-
127- # first 32 bytes as response id
128- resp = await zmqc .recv ()
129- resp_id = resp [:32 ].decode ()
130-
131- # the rest is response data
132- resp_data_encoded = resp [32 :]
133- resp_data = (
134- self ._encoder .decode (resp_data_encoded )
135- if return_type is None
136- else self ._encoder .decode_type (resp_data_encoded , return_type )
137- )
138- self ._resp_map [resp_id ] = resp_data
139-
140- # TODO try to use pipe instead of sleep
141- # await self.peer1.send(b"")
142-
143- req_id = util .unique_id ()
144-
145- # function name exactly 120 bytes
146- func_name_bytes = rpc_func_name .ljust (120 ).encode ()
147-
85+ # make function name exactly 80 bytes
86+ func_name_bytes = rpc_func_name .ljust (80 ).encode ()
14887 msg_bytes = b"" if msg is None else self ._encoder .encode (msg )
149- await zmqc .send (req_id .encode () + func_name_bytes + msg_bytes )
150-
151- # every request poll the data, so whenever a response comes, it will be stored in __resps
152- # dont need to poll again in the while loop
153- await _poll_data ()
15488
155- while req_id not in self ._resp_map and util .current_time_us () <= expire_at :
156- # TODO the problem with the zpipe is that we can miss some response
157- # when we come to this line
158- # await self.peer2.recv()
159- await asyncio .sleep (1e-6 )
89+ resp_data_bytes = await zmqc .request (func_name_bytes + msg_bytes , timeout )
16090
161- if util .current_time_us () > expire_at :
162- raise TimeoutException (
163- f"Timeout while waiting for response at { self ._address } "
164- )
165-
166- resp_data = self ._resp_map .pop (req_id )
167-
168- return resp_data
91+ return (
92+ self ._encoder .decode (resp_data_bytes )
93+ if return_type is None
94+ else self ._encoder .decode_type (resp_data_bytes , return_type )
95+ )
16996
17097 def close (self ):
17198 self .client_pool .close ()
172- self ._resp_map = {}
17399
174100
175- class ZeroMQClientPool :
101+ class ZMQClientPool :
176102 """
177103 Connections are based on different threads and processes.
178104 Each time a call is made it tries to get the connection from the pool,
@@ -196,21 +122,15 @@ def get(self) -> ZeroMQClient:
196122 logging .debug ("No connection found in current thread, creating new one" )
197123 self ._pool [thread_id ] = get_client (config .ZEROMQ_PATTERN , self ._timeout )
198124 self ._pool [thread_id ].connect (self ._address )
199- self ._try_connect_ping (self ._pool [thread_id ])
200125 return self ._pool [thread_id ]
201126
202- def _try_connect_ping (self , client : ZeroMQClient ):
203- client .send (util .unique_id ().encode () + b"connect" + b"" )
204- client .recv ()
205- logging .info ("Connected to server at %s" , self ._address )
206-
207127 def close (self ):
208128 for client in self ._pool .values ():
209129 client .close ()
210130 self ._pool = {}
211131
212132
213- class AsyncZeroMQClientPool :
133+ class AsyncZMQClientPool :
214134 """
215135 Connections are based on different threads and processes.
216136 Each time a call is made it tries to get the connection from the pool,
@@ -235,15 +155,9 @@ async def get(self) -> AsyncZeroMQClient:
235155 self ._pool [thread_id ] = get_async_client (
236156 config .ZEROMQ_PATTERN , self ._timeout
237157 )
238- self ._pool [thread_id ].connect (self ._address )
239- await self ._try_connect_ping (self ._pool [thread_id ])
158+ await self ._pool [thread_id ].connect (self ._address )
240159 return self ._pool [thread_id ]
241160
242- async def _try_connect_ping (self , client : AsyncZeroMQClient ):
243- await client .send (util .unique_id ().encode () + b"connect" + b"" )
244- await client .recv ()
245- logging .info ("Connected to server at %s" , self ._address )
246-
247161 def close (self ):
248162 for client in self ._pool .values ():
249163 client .close ()
0 commit comments