66import logging
77import os
88import struct
9- from typing import Awaitable , List , Optional , TypeVar
9+ from typing import Awaitable , Callable , List , Optional , TypeVar
1010
1111import reactivex .operators as op
1212import semver
@@ -78,7 +78,7 @@ class PybricksHub:
7878 has not been connected yet or the connected hub has Pybricks profile < v1.2.0.
7979 """
8080
81- def __init__ (self , device : BLEDevice ):
81+ def __init__ (self ):
8282 self .connection_state_observable = BehaviorSubject (ConnectionState .DISCONNECTED )
8383 self .status_observable = BehaviorSubject (StatusFlag (0 ))
8484 self ._stdout_subject = Subject ()
@@ -120,11 +120,6 @@ def __init__(self, device: BLEDevice):
120120 # File handle for logging
121121 self .log_file = None
122122
123- def handle_disconnect (_ : BleakClient ):
124- self ._handle_disconnect ()
125-
126- self .client = BleakClient (device , disconnected_callback = handle_disconnect )
127-
128123 @property
129124 def stdout_observable (self ) -> Observable [bytes ]:
130125 """
@@ -237,16 +232,6 @@ def _handle_disconnect(self):
237232 self .connection_state_observable .on_next (ConnectionState .DISCONNECTED )
238233
239234 async def connect (self ):
240- """Connects to a device that was discovered with :meth:`pybricksdev.ble.find_device`
241-
242- Raises:
243- BleakError: if connecting failed (or old firmware without Device
244- Information Service)
245- RuntimeError: if Pybricks Protocol version is not supported
246- """
247- # TODO: Fix this
248- # logger.info(f"Connecting to {device.name}")
249-
250235 if self .connection_state_observable .value != ConnectionState .DISCONNECTED :
251236 raise RuntimeError (
252237 f"attempting to connect with invalid state: { self .connection_state_observable .value } "
@@ -259,48 +244,12 @@ async def connect(self):
259244 self .connection_state_observable .on_next , ConnectionState .DISCONNECTED
260245 )
261246
262- await self .client . connect ()
247+ await self ._client_connect ()
263248
264249 stack .push_async_callback (self .disconnect )
265250
266- logger .info ("Connected successfully!" )
267-
268- fw_version = await self .client .read_gatt_char (FW_REV_UUID )
269- self .fw_version = Version (fw_version .decode ())
270-
271- protocol_version = await self .client .read_gatt_char (SW_REV_UUID )
272- protocol_version = semver .VersionInfo .parse (protocol_version .decode ())
273-
274- if (
275- protocol_version < PYBRICKS_PROTOCOL_VERSION
276- or protocol_version >= PYBRICKS_PROTOCOL_VERSION .bump_major ()
277- ):
278- raise RuntimeError (
279- f"Unsupported Pybricks protocol version: { protocol_version } "
280- )
281-
282- pnp_id = await self .client .read_gatt_char (PNP_ID_UUID )
283- _ , _ , self .hub_kind , self .hub_variant = unpack_pnp_id (pnp_id )
284-
285- if protocol_version >= "1.2.0" :
286- caps = await self .client .read_gatt_char (PYBRICKS_HUB_CAPABILITIES_UUID )
287- (
288- self ._max_write_size ,
289- self ._capability_flags ,
290- self ._max_user_program_size ,
291- ) = unpack_hub_capabilities (caps )
292- else :
293- # HACK: prior to profile v1.2.0 isn't a proper way to get the
294- # MPY ABI version from hub so we use heuristics on the firmware version
295- self ._mpy_abi_version = (
296- 6 if self .fw_version >= Version ("3.2.0b2" ) else 5
297- )
298-
299- if protocol_version < "1.3.0" :
300- self ._legacy_stdio = True
301-
302- await self .client .start_notify (NUS_TX_UUID , self ._nus_handler )
303- await self .client .start_notify (
251+ await self .start_notify (NUS_TX_UUID , self ._nus_handler )
252+ await self .start_notify (
304253 PYBRICKS_COMMAND_EVENT_UUID , self ._pybricks_service_handler
305254 )
306255
@@ -314,7 +263,7 @@ async def disconnect(self):
314263
315264 if self .connection_state_observable .value == ConnectionState .CONNECTED :
316265 self .connection_state_observable .on_next (ConnectionState .DISCONNECTING )
317- await self .client . disconnect ()
266+ await self ._client_disconnect ()
318267 # ConnectionState.DISCONNECTED should be set by disconnect callback
319268 assert (
320269 self .connection_state_observable .value == ConnectionState .DISCONNECTED
@@ -453,7 +402,7 @@ async def download_user_program(self, program: bytes) -> None:
453402 )
454403
455404 # clear user program meta so hub doesn't try to run invalid program
456- await self .client . write_gatt_char (
405+ await self .write_gatt_char (
457406 PYBRICKS_COMMAND_EVENT_UUID ,
458407 struct .pack ("<BI" , Command .WRITE_USER_PROGRAM_META , 0 ),
459408 response = True ,
@@ -467,7 +416,7 @@ async def download_user_program(self, program: bytes) -> None:
467416 total = len (program ), unit = "B" , unit_scale = True
468417 ) as pbar :
469418 for i , c in enumerate (chunk (program , payload_size )):
470- await self .client . write_gatt_char (
419+ await self .write_gatt_char (
471420 PYBRICKS_COMMAND_EVENT_UUID ,
472421 struct .pack (
473422 f"<BI{ len (c )} s" ,
@@ -480,7 +429,7 @@ async def download_user_program(self, program: bytes) -> None:
480429 pbar .update (len (c ))
481430
482431 # set the metadata to notify that writing was successful
483- await self .client . write_gatt_char (
432+ await self .write_gatt_char (
484433 PYBRICKS_COMMAND_EVENT_UUID ,
485434 struct .pack ("<BI" , Command .WRITE_USER_PROGRAM_META , len (program )),
486435 response = True ,
@@ -492,7 +441,7 @@ async def start_user_program(self) -> None:
492441
493442 Requires hub with Pybricks Profile >= v1.2.0.
494443 """
495- await self .client . write_gatt_char (
444+ await self .write_gatt_char (
496445 PYBRICKS_COMMAND_EVENT_UUID ,
497446 struct .pack ("<B" , Command .START_USER_PROGRAM ),
498447 response = True ,
@@ -502,7 +451,7 @@ async def stop_user_program(self) -> None:
502451 """
503452 Stops the user program on the hub if it is running.
504453 """
505- await self .client . write_gatt_char (
454+ await self .write_gatt_char (
506455 PYBRICKS_COMMAND_EVENT_UUID ,
507456 struct .pack ("<B" , Command .STOP_USER_PROGRAM ),
508457 response = True ,
@@ -682,3 +631,79 @@ async def _wait_for_user_program_stop(self):
682631 # the user program running status flag
683632 # https://github.com/pybricks/support/issues/305
684633 await asyncio .sleep (0.3 )
634+
635+
636+ class PybricksHubBLE (PybricksHub ):
637+ _device : BLEDevice
638+ _client : BleakClient
639+
640+ def __init__ (self , device : BLEDevice ):
641+ super ().__init__ ()
642+
643+ self ._device = device
644+
645+ def handle_disconnect (_ : BleakClient ):
646+ self ._handle_disconnect ()
647+
648+ self ._client = BleakClient (
649+ self ._device , disconnected_callback = handle_disconnect
650+ )
651+
652+ async def _client_connect (self ) -> bool :
653+ """Connects to a device that was discovered with :meth:`pybricksdev.ble.find_device`
654+
655+ Raises:
656+ BleakError: if connecting failed (or old firmware without Device
657+ Information Service)
658+ RuntimeError: if Pybricks Protocol version is not supported
659+ """
660+
661+ logger .info (f"Connecting to { self ._device .name } " )
662+ await self ._client .connect ()
663+ logger .info ("Connected successfully!" )
664+
665+ fw_version = await self .read_gatt_char (FW_REV_UUID )
666+ self .fw_version = Version (fw_version .decode ())
667+
668+ protocol_version = await self .read_gatt_char (SW_REV_UUID )
669+ protocol_version = semver .VersionInfo .parse (protocol_version .decode ())
670+
671+ if (
672+ protocol_version < PYBRICKS_PROTOCOL_VERSION
673+ or protocol_version >= PYBRICKS_PROTOCOL_VERSION .bump_major ()
674+ ):
675+ raise RuntimeError (
676+ f"Unsupported Pybricks protocol version: { protocol_version } "
677+ )
678+
679+ pnp_id = await self .read_gatt_char (PNP_ID_UUID )
680+ _ , _ , self .hub_kind , self .hub_variant = unpack_pnp_id (pnp_id )
681+
682+ if protocol_version >= "1.2.0" :
683+ caps = await self .read_gatt_char (PYBRICKS_HUB_CAPABILITIES_UUID )
684+ (
685+ self ._max_write_size ,
686+ self ._capability_flags ,
687+ self ._max_user_program_size ,
688+ ) = unpack_hub_capabilities (caps )
689+ else :
690+ # HACK: prior to profile v1.2.0 isn't a proper way to get the
691+ # MPY ABI version from hub so we use heuristics on the firmware version
692+ self ._mpy_abi_version = 6 if self .fw_version >= Version ("3.2.0b2" ) else 5
693+
694+ if protocol_version < "1.3.0" :
695+ self ._legacy_stdio = True
696+
697+ return True
698+
699+ async def _client_disconnect (self ) -> bool :
700+ return await self ._client .disconnect ()
701+
702+ async def read_gatt_char (self , uuid : str ) -> bytearray :
703+ return await self ._client .read_gatt_char (uuid )
704+
705+ async def write_gatt_char (self , uuid : str , data , response : bool ) -> None :
706+ return await self ._client .write_gatt_char (uuid , data , response )
707+
708+ async def start_notify (self , uuid : str , callback : Callable ) -> None :
709+ return await self ._client .start_notify (uuid , callback )
0 commit comments