Skip to content
Open
4 changes: 3 additions & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ jobs:
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: pip install -e ".[test]"
run: |
pip install -e ".[test]"
pip install git+https://github.com/agronholm/anyio.git#egg=anyio --ignore-installed
- name: Check with mypy and ruff
if: ${{ (matrix.python-version == '3.13') && (matrix.os == 'ubuntu-latest') }}
run: |
Expand Down
64 changes: 34 additions & 30 deletions src/zmq_anyio/_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@
get_cancelled_exc_class,
sleep,
wait_readable,
ClosedResourceError,
notify_closing,
)
from anyio.abc import TaskGroup, TaskStatus
from anyioutils import FIRST_COMPLETED, Future, create_task, wait
from anyioutils import Future, create_task

import zmq
from zmq import EVENTS, POLLIN, POLLOUT
Expand Down Expand Up @@ -890,36 +892,36 @@ async def _start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED):
task_status.started()
self.started.set()
self._thread = get_ident()

async def wait_or_cancel() -> None:
assert self.stopped is not None
await self.stopped.wait()
tg.cancel_scope.cancel()

def fileno() -> int:
if self.closed:
return -1
try:
return self._shadow_sock.fileno()
except zmq.ZMQError:
return -1

try:
while True:
wait_stopped_task = create_task(
self.stopped.wait(),
self._task_group,
exception_handler=ignore_exceptions,
)
tasks = [
create_task(
wait_readable(self._shadow_sock), # type: ignore[arg-type]
self._task_group,
exception_handler=ignore_exceptions,
),
wait_stopped_task,
]
done, pending = await wait(
tasks, self._task_group, return_when=FIRST_COMPLETED
)
for task in pending:
task.cancel()
if wait_stopped_task in done:
while (fd := fileno()) > 0:
async with create_task_group() as tg:
tg.start_soon(wait_or_cancel)
try:
await wait_readable(fd)
except ClosedResourceError:
break
finally:
tg.cancel_scope.cancel()
if self.stopped.is_set():
break
await self._handle_events()
except BaseException:
pass
finally:
self._exited.set()

assert self.stopped is not None
self.stopped.set()
self.stopped.set()

async def stop(self):
assert self._exited is not None
Expand All @@ -933,11 +935,13 @@ async def stop(self):
self.close()

def close(self, linger: int | None = None) -> None:
try:
if not self.closed and self._fd is not None:
fd = self._fd
if not self.closed and fd is not None:
notify_closing(fd)
try:
super().close(linger=linger)
except BaseException:
pass
except BaseException:
pass

assert self.stopped is not None
self.stopped.set()
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def context(contexts):


@pytest.fixture
def sockets(contexts):
async def sockets(contexts):
sockets = []
yield sockets
# ensure any tracked sockets get their contexts cleaned up
Expand Down
Loading