1
1
import asyncio
2
+ import collections
2
3
import contextlib
3
4
import logging
4
5
import ssl
@@ -17,8 +18,9 @@ def __init__(self, config):
17
18
self ._config = config
18
19
self ._mdns_resolver = enapter .mdns .Resolver ()
19
20
self ._tls_context = self ._new_tls_context (config )
20
- self ._publisher = None
21
- self ._publisher_connected = asyncio .Event ()
21
+ self ._client = None
22
+ self ._client_ready = asyncio .Event ()
23
+ self ._subscribers = collections .defaultdict (int )
22
24
23
25
@staticmethod
24
26
def _new_logger (config ):
@@ -29,58 +31,89 @@ def config(self):
29
31
return self ._config
30
32
31
33
async def publish (self , * args , ** kwargs ):
32
- await self ._publisher_connected . wait ()
33
- await self . _publisher .publish (* args , ** kwargs )
34
+ client = await self ._wait_client ()
35
+ await client .publish (* args , ** kwargs )
34
36
35
37
@enapter .async_ .generator
36
- async def subscribe (self , * topics ):
38
+ async def subscribe (self , topic ):
37
39
while True :
40
+ client = await self ._wait_client ()
41
+
38
42
try :
39
- async with self . _connect () as subscriber :
40
- for topic in topics :
41
- await subscriber . subscribe ( topic )
42
- self . _logger . info ( "subscriber [%s] connected" , "," . join ( topics ))
43
- async for msg in subscriber . messages :
44
- yield msg
43
+ async with client . messages () as messages :
44
+ async with self . _subscribe ( client , topic ) :
45
+ async for msg in messages :
46
+ if msg . topic . matches ( topic ):
47
+ yield msg
48
+
45
49
except aiomqtt .MqttError as e :
46
50
self ._logger .error (e )
47
51
retry_interval = 5
48
52
await asyncio .sleep (retry_interval )
49
- finally :
50
- self ._logger .info ("subscriber disconnected" )
53
+
54
+ @contextlib .asynccontextmanager
55
+ async def _subscribe (self , client , topic ):
56
+ first_subscriber = not self ._subscribers [topic ]
57
+ self ._subscribers [topic ] += 1
58
+ try :
59
+ if first_subscriber :
60
+ await client .subscribe (topic )
61
+ yield
62
+ finally :
63
+ self ._subscribers [topic ] -= 1
64
+ assert not self ._subscribers [topic ] < 0
65
+ last_unsubscriber = not self ._subscribers [topic ]
66
+ if last_unsubscriber :
67
+ del self ._subscribers [topic ]
68
+ await client .unsubscribe (topic )
69
+
70
+ async def _wait_client (self ):
71
+ await self ._client_ready .wait ()
72
+ assert self ._client_ready .is_set ()
73
+ return self ._client
51
74
52
75
async def _run (self ):
53
76
self ._logger .info ("starting" )
77
+
54
78
self ._started .set ()
79
+
55
80
while True :
56
81
try :
57
- async with self ._connect () as publisher :
58
- self ._logger .info ("publisher connected" )
59
- self ._publisher = publisher
60
- self ._publisher_connected .set ()
61
- async for msg in publisher .messages :
62
- pass
82
+ async with self ._connect () as client :
83
+ self ._client = client
84
+ self ._client_ready .set ()
85
+ self ._logger .info ("client ready" )
86
+
87
+ # tracking disconnect
88
+ async with client .messages () as messages :
89
+ async for msg in messages :
90
+ pass
63
91
except aiomqtt .MqttError as e :
64
92
self ._logger .error (e )
65
93
retry_interval = 5
66
94
await asyncio .sleep (retry_interval )
67
95
finally :
68
- self ._publisher_connected .clear ()
69
- self ._publisher = None
70
- self ._logger .info ("publisher disconnected " )
96
+ self ._client_ready .clear ()
97
+ self ._client = None
98
+ self ._logger .info ("client not ready " )
71
99
72
100
@contextlib .asynccontextmanager
73
101
async def _connect (self ):
74
102
host = await self ._maybe_resolve_mdns (self ._config .host )
75
- async with aiomqtt .Client (
76
- hostname = host ,
77
- port = self ._config .port ,
78
- username = self ._config .user ,
79
- password = self ._config .password ,
80
- logger = self ._logger ,
81
- tls_context = self ._tls_context ,
82
- ) as client :
83
- yield client
103
+
104
+ try :
105
+ async with aiomqtt .Client (
106
+ hostname = host ,
107
+ port = self ._config .port ,
108
+ username = self ._config .user ,
109
+ password = self ._config .password ,
110
+ logger = self ._logger ,
111
+ tls_context = self ._tls_context ,
112
+ ) as client :
113
+ yield client
114
+ except asyncio .CancelledError :
115
+ # FIXME: A cancelled `aiomqtt.Client.connect` leaks resources.
116
+ raise
84
117
85
118
@staticmethod
86
119
def _new_tls_context (config ):
0 commit comments