77
88import pkg_resources
99from google .api_core import retry
10+ from google .cloud .pubsub_v1 import SubscriberClient
1011
1112from octue .cloud import EXCEPTIONS_MAPPING
1213from octue .compatibility import warn_if_incompatible
3031class OrderedMessageHandler :
3132 """A handler for Google Pub/Sub messages that ensures messages are handled in the order they were sent.
3233
33- :param google.pubsub_v1.services.subscriber.client.SubscriberClient subscriber: a Google Pub/Sub subscriber
3434 :param octue.cloud.pub_sub.subscription.Subscription subscription: the subscription messages are pulled from
35+ :param octue.cloud.pub_sub.service.Service receiving_service: the service that's receiving the messages
3536 :param callable|None handle_monitor_message: a function to handle monitor messages (e.g. send them to an endpoint for plotting or displaying) - this function should take a single JSON-compatible python primitive
3637 :param str|None record_messages_to: if given a path to a JSON file, received messages are saved to it
3738 :param str service_name: an arbitrary name to refer to the service subscribed to by (used for labelling its remote log messages)
@@ -41,20 +42,21 @@ class OrderedMessageHandler:
4142
4243 def __init__ (
4344 self ,
44- subscriber ,
4545 subscription ,
46+ receiving_service ,
4647 handle_monitor_message = None ,
4748 record_messages_to = None ,
4849 service_name = "REMOTE" ,
4950 message_handlers = None ,
5051 ):
51- self .subscriber = subscriber
5252 self .subscription = subscription
53+ self .receiving_service = receiving_service
5354 self .handle_monitor_message = handle_monitor_message
5455 self .record_messages_to = record_messages_to
5556 self .service_name = service_name
5657
5758 self .received_delivery_acknowledgement = None
59+ self ._subscriber = SubscriberClient ()
5860 self ._child_sdk_version = None
5961 self ._heartbeat_checker = None
6062 self ._last_heartbeat = None
@@ -108,55 +110,56 @@ def handle_messages(self, timeout=60, delivery_acknowledgement_timeout=120, maxi
108110 kwargs = {"maximum_heartbeat_interval" : maximum_heartbeat_interval },
109111 )
110112
111- self ._heartbeat_checker .daemon = True
112- self ._heartbeat_checker .start ()
113+ try :
114+ self ._heartbeat_checker .daemon = True
115+ self ._heartbeat_checker .start ()
113116
114- while self ._alive :
117+ while self ._alive :
115118
116- if timeout is not None :
117- run_time = time .perf_counter () - self ._start_time
119+ if timeout is not None :
120+ run_time = time .perf_counter () - self ._start_time
118121
119- if run_time > timeout :
120- raise TimeoutError (
121- f"No final answer received from topic { self .subscription .topic .path !r} after { timeout } seconds." ,
122- )
122+ if run_time > timeout :
123+ raise TimeoutError (
124+ f"No final answer received from topic { self .subscription .topic .path !r} after { timeout } seconds." ,
125+ )
123126
124- pull_timeout = timeout - run_time
127+ pull_timeout = timeout - run_time
125128
126- message = self ._pull_message (
127- timeout = pull_timeout ,
128- delivery_acknowledgement_timeout = delivery_acknowledgement_timeout ,
129- )
129+ message = self ._pull_message (
130+ timeout = pull_timeout ,
131+ delivery_acknowledgement_timeout = delivery_acknowledgement_timeout ,
132+ )
130133
131- self ._waiting_messages [int (message ["message_number" ])] = message
134+ self ._waiting_messages [int (message ["message_number" ])] = message
132135
133- try :
134- while self ._waiting_messages :
135- message = self ._waiting_messages .pop (self ._previous_message_number + 1 )
136+ try :
137+ while self ._waiting_messages :
138+ message = self ._waiting_messages .pop (self ._previous_message_number + 1 )
136139
137- if self .record_messages_to :
138- recorded_messages .append (message )
140+ if self .record_messages_to :
141+ recorded_messages .append (message )
139142
140- result = self ._handle_message (message )
143+ result = self ._handle_message (message )
141144
142- if result is not None :
143- self ._heartbeat_checker .cancel ()
144- return result
145+ if result is not None :
146+ return result
145147
146- except KeyError :
147- pass
148+ except KeyError :
149+ pass
148150
149- finally :
150- self ._heartbeat_checker .cancel ()
151+ finally :
152+ self ._heartbeat_checker .cancel ()
153+ self ._subscriber .close ()
151154
152- if self .record_messages_to :
153- directory_name = os .path .dirname (self .record_messages_to )
155+ if self .record_messages_to :
156+ directory_name = os .path .dirname (self .record_messages_to )
154157
155- if not os .path .exists (directory_name ):
156- os .makedirs (directory_name )
158+ if not os .path .exists (directory_name ):
159+ os .makedirs (directory_name )
157160
158- with open (self .record_messages_to , "w" ) as f :
159- json .dump (recorded_messages , f )
161+ with open (self .record_messages_to , "w" ) as f :
162+ json .dump (recorded_messages , f )
160163
161164 raise TimeoutError (
162165 f"No heartbeat has been received within the maximum allowed interval of { maximum_heartbeat_interval } s."
@@ -193,7 +196,7 @@ def _pull_message(self, timeout, delivery_acknowledgement_timeout):
193196 while True :
194197 logger .debug ("Pulling messages from Google Pub/Sub: attempt %d." , attempt )
195198
196- pull_response = self .subscriber .pull (
199+ pull_response = self ._subscriber .pull (
197200 request = {"subscription" : self .subscription .path , "max_messages" : 1 },
198201 retry = retry .Retry (),
199202 )
@@ -220,11 +223,11 @@ def _pull_message(self, timeout, delivery_acknowledgement_timeout):
220223 f"after { delivery_acknowledgement_timeout } seconds."
221224 )
222225
223- self .subscriber .acknowledge (request = {"subscription" : self .subscription .path , "ack_ids" : [answer .ack_id ]})
226+ self ._subscriber .acknowledge (request = {"subscription" : self .subscription .path , "ack_ids" : [answer .ack_id ]})
224227
225228 logger .debug (
226229 "%r received a message related to question %r." ,
227- self .subscription . topic . service ,
230+ self .receiving_service ,
228231 self .subscription .topic .path .split ("." )[- 1 ],
229232 )
230233
@@ -260,7 +263,7 @@ def _handle_message(self, message):
260263 if isinstance (error , KeyError ):
261264 logger .warning (
262265 "%r received a message of unknown type %r." ,
263- self .subscription . topic . service ,
266+ self .receiving_service ,
264267 message .get ("type" , "unknown" ),
265268 )
266269 return
@@ -275,7 +278,7 @@ def _handle_delivery_acknowledgement(self, message):
275278 :return None:
276279 """
277280 self .received_delivery_acknowledgement = True
278- logger .info ("%r's question was delivered at %s." , self .subscription . topic . service , message ["delivery_time" ])
281+ logger .info ("%r's question was delivered at %s." , self .receiving_service , message ["delivery_time" ])
279282
280283 def _handle_heartbeat (self , message ):
281284 """Record the time the heartbeat was received.
@@ -292,7 +295,7 @@ def _handle_monitor_message(self, message):
292295 :param dict message:
293296 :return None:
294297 """
295- logger .debug ("%r received a monitor message." , self .subscription . topic . service )
298+ logger .debug ("%r received a monitor message." , self .receiving_service )
296299
297300 if self .handle_monitor_message is not None :
298301 self .handle_monitor_message (json .loads (message ["data" ]))
@@ -359,7 +362,7 @@ def _handle_result(self, message):
359362 """
360363 logger .info (
361364 "%r received an answer to question %r." ,
362- self .subscription . topic . service ,
365+ self .receiving_service ,
363366 self .subscription .topic .path .split ("." )[- 1 ],
364367 )
365368
0 commit comments