Skip to content

Commit fd2d131

Browse files
committed
fix: minor updates (#90)
* fix: add bow/backyard to equipment lookups, fix async cancellation error * chore: cleanup mypy pre-commit config * chore: revert change to vscode settings
1 parent 093b5a9 commit fd2d131

File tree

5 files changed

+60
-4
lines changed

5 files changed

+60
-4
lines changed

.pre-commit-config.yaml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,5 +57,7 @@ repos:
5757
hooks:
5858
- id: mypy
5959
exclude: cli.py
60-
additional_dependencies: [ "pydantic>=2.0.0", "pytest>=8.0.0" ]
60+
additional_dependencies:
61+
- "pydantic>=2.0.0"
62+
- "pytest>=8.0.0"
6163
args: [ "--config-file=./pyproject.toml"]

pyomnilogic_local/api/protocol.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,12 @@ async def _wait_for_ack(self, ack_id: int) -> None:
229229
# Wait for either a message or an error
230230
data_task = asyncio.create_task(self.data_queue.get())
231231
error_task = asyncio.create_task(self.error_queue.get())
232-
done, _ = await asyncio.wait([data_task, error_task], return_when=asyncio.FIRST_COMPLETED)
232+
done, pending = await asyncio.wait([data_task, error_task], return_when=asyncio.FIRST_COMPLETED)
233+
234+
# Cancel any pending tasks to avoid "Task was destroyed but it is pending" warnings
235+
for task in pending:
236+
task.cancel()
237+
233238
if error_task in done:
234239
exc = error_task.result()
235240
if isinstance(exc, Exception):

pyomnilogic_local/omnilogic.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
if TYPE_CHECKING:
1717
from pyomnilogic_local._base import OmniEquipment
18+
from pyomnilogic_local.bow import Bow
1819
from pyomnilogic_local.chlorinator import Chlorinator
1920
from pyomnilogic_local.chlorinator_equip import ChlorinatorEquipment
2021
from pyomnilogic_local.colorlogiclight import ColorLogicLight
@@ -260,6 +261,12 @@ def all_csads(self) -> EquipmentDict[CSAD]:
260261
csads.extend(bow.csads.values())
261262
return EquipmentDict(csads)
262263

264+
@property
265+
def all_bows(self) -> EquipmentDict[Bow]:
266+
"""Returns all Bow instances across all bows in the backyard."""
267+
# Bows are stored directly in backyard as EquipmentDict already
268+
return self.backyard.bow
269+
263270
# Equipment search methods
264271
def get_equipment_by_name(self, name: str) -> OmniEquipment[Any, Any] | None:
265272
"""Find equipment by name across all equipment types.
@@ -272,6 +279,7 @@ def get_equipment_by_name(self, name: str) -> OmniEquipment[Any, Any] | None:
272279
"""
273280
# Search all equipment types
274281
all_equipment: list[OmniEquipment[Any, Any]] = []
282+
all_equipment.extend([self.backyard])
275283
all_equipment.extend(self.all_lights.values())
276284
all_equipment.extend(self.all_relays.values())
277285
all_equipment.extend(self.all_pumps.values())
@@ -283,6 +291,7 @@ def get_equipment_by_name(self, name: str) -> OmniEquipment[Any, Any] | None:
283291
all_equipment.extend(self.all_chlorinator_equipment.values())
284292
all_equipment.extend(self.all_csads.values())
285293
all_equipment.extend(self.all_csad_equipment.values())
294+
all_equipment.extend(self.all_bows.values())
286295
all_equipment.extend(self.groups.values())
287296

288297
for equipment in all_equipment:
@@ -302,6 +311,7 @@ def get_equipment_by_id(self, system_id: int) -> OmniEquipment[Any, Any] | None:
302311
"""
303312
# Search all equipment types
304313
all_equipment: list[OmniEquipment[Any, Any]] = []
314+
all_equipment.extend([self.backyard])
305315
all_equipment.extend(self.all_lights.values())
306316
all_equipment.extend(self.all_relays.values())
307317
all_equipment.extend(self.all_pumps.values())
@@ -313,6 +323,7 @@ def get_equipment_by_id(self, system_id: int) -> OmniEquipment[Any, Any] | None:
313323
all_equipment.extend(self.all_chlorinator_equipment.values())
314324
all_equipment.extend(self.all_csads.values())
315325
all_equipment.extend(self.all_csad_equipment.values())
326+
all_equipment.extend(self.all_bows.values())
316327
all_equipment.extend(self.groups.values())
317328
all_equipment.extend(self.schedules.values())
318329

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ python_version = "3.13"
4747
plugins = [
4848
"pydantic.mypy"
4949
]
50-
follow_imports = "silent"
50+
# follow_imports = "silent"
5151
strict = true
5252
ignore_missing_imports = true
5353
disallow_subclassing_any = false

tests/test_protocol.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import struct
1515
import time
1616
import zlib
17-
from typing import TYPE_CHECKING
17+
from typing import TYPE_CHECKING, Any
1818
from unittest.mock import AsyncMock, MagicMock, patch
1919
from xml.etree import ElementTree as ET
2020

@@ -787,3 +787,41 @@ async def test_receive_file_fragmented_ignores_non_block_messages(caplog: pytest
787787

788788
assert any("other than a blockmessage" in r.message for r in caplog.records)
789789
assert result == "data"
790+
791+
792+
@pytest.mark.asyncio
793+
async def test_wait_for_ack_cancels_pending_tasks() -> None:
794+
"""Test that pending tasks are properly cancelled in _wait_for_ack to avoid warnings."""
795+
protocol = OmniLogicProtocol()
796+
protocol.transport = MagicMock()
797+
798+
# Track tasks created during _wait_for_ack
799+
created_tasks: list[asyncio.Task[Any]] = []
800+
original_create_task = asyncio.create_task
801+
802+
def track_create_task(coro: Any) -> asyncio.Task[Any]:
803+
task: asyncio.Task[Any] = original_create_task(coro)
804+
created_tasks.append(task)
805+
return task
806+
807+
# Queue up an ACK message
808+
ack_msg = OmniLogicMessage(42, MessageType.ACK)
809+
await protocol.data_queue.put(ack_msg)
810+
811+
# Patch create_task to track tasks
812+
with patch("asyncio.create_task", side_effect=track_create_task):
813+
await protocol._wait_for_ack(42)
814+
815+
# Give the event loop a chance to process cancellation
816+
await asyncio.sleep(0)
817+
818+
# Should have created 2 tasks (data_task and error_task)
819+
assert len(created_tasks) == 2
820+
821+
# One should be done (the data_task that got the ACK)
822+
# One should be cancelled (the error_task that was waiting)
823+
done_tasks = [t for t in created_tasks if t.done() and not t.cancelled()]
824+
cancelled_tasks = [t for t in created_tasks if t.cancelled()]
825+
826+
assert len(done_tasks) == 1, "Expected exactly one task to complete normally"
827+
assert len(cancelled_tasks) == 1, "Expected exactly one task to be cancelled"

0 commit comments

Comments
 (0)