diff --git a/pyghthouse/connection/wsconnector.py b/pyghthouse/connection/wsconnector.py index d451f0a..e9ca550 100644 --- a/pyghthouse/connection/wsconnector.py +++ b/pyghthouse/connection/wsconnector.py @@ -1,9 +1,15 @@ +from enum import Enum from threading import Thread, Lock from websocket import WebSocketApp, setdefaulttimeout, ABNF from msgpack import packb, unpackb from ssl import CERT_NONE +class ConnectionState(Enum): + NONE=0 + CONNECTING=1 + CONNECTED=2 + FAILED=3 class WSConnector: class REID: @@ -26,7 +32,7 @@ def __init__(self, username: str, token: str, address: str, on_msg=None, ignore_ self.ws = None self.lock = Lock() self.reid = self.REID() - self.running = False + self.__connection_state = ConnectionState.NONE self.ignore_ssl_cert = ignore_ssl_cert setdefaulttimeout(60) @@ -39,13 +45,14 @@ def start(self): self.ws = WebSocketApp(self.address, on_message=None if self.on_msg is None else self._handle_msg, on_open=self._ready, on_error=self._fail) - self.lock.acquire() - kwargs = {"sslopt": {"cert_reqs": CERT_NONE}} if self.ignore_ssl_cert else None - Thread(target=self.ws.run_forever, kwargs=kwargs).start() self.lock.acquire() # wait for connection to be established - self.lock.release() + self.__connection_state = ConnectionState.CONNECTING + kwargs = {"sslopt": {"cert_reqs": CERT_NONE}} if self.ignore_ssl_cert else None + Thread(target=self.ws.run_forever, name="Websocket Thread", kwargs=kwargs).start() + def _fail(self, ws, err): + self.__connection_state = ConnectionState.FAILED self.lock.release() raise err @@ -53,15 +60,19 @@ def stop(self): if self.ws is not None: with self.lock: print("Closing the connection.") - self.running = False + self.__connection_state = ConnectionState.NONE self.ws.close() self.ws = None def _ready(self, ws): print(f"Connected to {self.address}.") - self.running = True + self.__connection_state = ConnectionState.CONNECTED self.lock.release() + @property + def connection_state(self): + return self.__connection_state + def _handle_msg(self, ws, msg): if isinstance(msg, bytes): msg = unpackb(msg) diff --git a/pyghthouse/ph.py b/pyghthouse/ph.py index 62f92ea..06d755c 100644 --- a/pyghthouse/ph.py +++ b/pyghthouse/ph.py @@ -6,7 +6,7 @@ import numpy as np from pyghthouse.data.canvas import PyghthouseCanvas -from pyghthouse.connection.wsconnector import WSConnector +from pyghthouse.connection.wsconnector import WSConnector,ConnectionState class VerbosityLevel(Enum): @@ -166,7 +166,7 @@ def print_warning(msg): class PHThread(Thread): def __init__(self, parent): - super().__init__() + super().__init__(name = "PHThread") self.parent = parent self._stop_event = Event() @@ -208,10 +208,19 @@ def connect(self): self.connector.start() def start(self): - if not self.connector.running: + if not self.connector.connection_state==ConnectionState.CONNECTED: self.connect() self.stop() self.msg_handler.reset() + while True: + state=self.connector.connection_state + if state==ConnectionState.CONNECTED: + # Connected + break + elif state==ConnectionState.FAILED: + # Connection failed + raise ConnectionError("Failed to connect") + sleep(100) self.ph_thread = self.PHThread(self) self.ph_thread.start()