Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions a_sync/primitives/_debug.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@ The mixin provides a framework for managing a debug daemon task, which can be us
"""

from asyncio import AbstractEventLoop, Future
from typing import Any

from a_sync.primitives._loggable import _LoggerMixin

class _LoopBoundMixin(_LoggerMixin):
def __init__(self, *, loop=None): ...
def __init__(self, *, loop: AbstractEventLoop | None = ...) -> None: ...
@property
def _loop(self) -> AbstractEventLoop: ...
def _get_loop(self) -> AbstractEventLoop: ...
Expand All @@ -24,7 +25,13 @@ class _DebugDaemonMixin(_LoopBoundMixin):
:class:`_LoggerMixin` for logging capabilities.
"""

async def _debug_daemon(self, fut: Future, fn, *args, **kwargs) -> None:
async def _debug_daemon(
self,
fut: Future[Any],
fn: Any,
*args: Any,
**kwargs: Any,
) -> None:
"""
Abstract method to define the debug daemon's behavior.

Expand Down
44 changes: 38 additions & 6 deletions a_sync/primitives/locks/semaphore.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ import asyncio
import functools
from logging import Logger
from threading import Thread
from typing import DefaultDict, Literal
from types import TracebackType
from typing import Any, DefaultDict, Literal, override

from typing_extensions import Never

Expand Down Expand Up @@ -48,7 +49,13 @@ class Semaphore(asyncio.Semaphore, _DebugDaemonMixin):
name: str
"""An optional name for the counter, used in debug logs. Defaults to an empty string."""

def __init__(self, value: int, name: str = "", **kwargs) -> None:
def __init__(
self,
value: int = 1,
name: str = "",
loop: asyncio.AbstractEventLoop | None = ...,
**kwargs: Any,
) -> None:
"""
Initialize the semaphore with a given value and optional name for debugging.

Expand Down Expand Up @@ -84,6 +91,7 @@ class Semaphore(asyncio.Semaphore, _DebugDaemonMixin):
return 1
"""

@override
async def acquire(self) -> Literal[True]:
"""
Acquire the semaphore, ensuring that debug logging is enabled if there are waiters.
Expand All @@ -94,7 +102,14 @@ class Semaphore(asyncio.Semaphore, _DebugDaemonMixin):
True when the semaphore is successfully acquired.
"""

async def _debug_daemon(self) -> None:
@override
async def _debug_daemon(
self,
fut: asyncio.Future[Any],
fn: Any,
*args: Any,
**kwargs: Any,
) -> None:
"""
Daemon coroutine (runs in a background task) which will emit a debug log every minute while the semaphore has waiters.

Expand Down Expand Up @@ -133,16 +148,25 @@ class DummySemaphore(asyncio.Semaphore):
name (optional): An optional name for the dummy semaphore.
"""

@override
async def acquire(self) -> Literal[True]:
"""Acquire the dummy semaphore, which is a no-op."""

@override
def release(self) -> None:
"""No-op release method."""

async def __aenter__(self):
@override
async def __aenter__(self) -> None:
"""No-op context manager entry."""

async def __aexit__(self, *args) -> None:
@override
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc: BaseException | None,
tb: TracebackType | None,
) -> None:
"""No-op context manager exit."""

class ThreadsafeSemaphore(Semaphore):
Expand Down Expand Up @@ -175,6 +199,7 @@ class ThreadsafeSemaphore(Semaphore):
name (optional): An optional name for the semaphore.
"""

@override
def __len__(self) -> int: ...
@functools.cached_property
def use_dummy(self) -> bool:
Expand All @@ -200,5 +225,12 @@ class ThreadsafeSemaphore(Semaphore):
return 1
"""

@override
async def __aenter__(self) -> None: ...
async def __aexit__(self, *args) -> None: ...
@override
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc: BaseException | None,
tb: TracebackType | None,
) -> None: ...
25 changes: 25 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,28 @@ line-length = 100

[tool.isort]
line_length = 100

[tool.mypy]
python_version = "3.10"
files = ["tests"]
strict = true
follow_imports = "skip"
enable_error_code = [
"deprecated",
"exhaustive-match",
"explicit-override",
"ignore-without-code",
"mutable-override",
"possibly-undefined",
"redundant-expr",
"redundant-self",
"truthy-bool",
"truthy-iterable",
"unimported-reveal",
"unused-awaitable",
"unused-ignore",
]

[[tool.mypy.overrides]]
module = ["a_sync.*"]
ignore_errors = true
101 changes: 53 additions & 48 deletions tests/primitives/test_counter.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
import asyncio
from collections.abc import Callable
from typing import Any, TypeVar, cast

import pytest

from a_sync.primitives import CounterLock

_F = TypeVar("_F", bound=Callable[..., Any])
asyncio_cooperative = cast(Callable[[_F], _F], pytest.mark.asyncio_cooperative)

@pytest.mark.asyncio_cooperative
async def test_counter_lock():

@asyncio_cooperative
async def test_counter_lock() -> None:
counter = CounterLock()
assert counter._name == ""
assert repr(counter) == "<CounterLock value=0 waiters={}>"
Expand All @@ -20,8 +25,8 @@ async def test_counter_lock():
assert repr(counter) == "<CounterLock value=1 waiters={}>"


@pytest.mark.asyncio_cooperative
async def test_counter_lock_with_name():
@asyncio_cooperative
async def test_counter_lock_with_name() -> None:
counter = CounterLock(name="test")
assert counter._name == "test"
assert repr(counter) == "<CounterLock name=test value=0 waiters={}>"
Expand All @@ -35,24 +40,24 @@ async def test_counter_lock_with_name():
assert repr(counter) == "<CounterLock name=test value=1 waiters={}>"


@pytest.mark.asyncio_cooperative
async def test_counterlock_initialization():
@asyncio_cooperative
async def test_counterlock_initialization() -> None:
counter = CounterLock(start_value=5)
assert counter.value == 5


@pytest.mark.asyncio_cooperative
async def test_counterlock_set():
@asyncio_cooperative
async def test_counterlock_set() -> None:
counter = CounterLock(start_value=0)
counter.set(10)
assert counter.value == 10


@pytest.mark.asyncio_cooperative
async def test_counterlock_wait_for():
@asyncio_cooperative
async def test_counterlock_wait_for() -> None:
counter = CounterLock(start_value=0)

async def waiter():
async def waiter() -> str:
await counter.wait_for(5)
return "done"

Expand All @@ -63,12 +68,12 @@ async def waiter():
assert result == "done"


@pytest.mark.asyncio_cooperative
async def test_counterlock_concurrent_waiters():
@asyncio_cooperative
async def test_counterlock_concurrent_waiters() -> None:
counter = CounterLock(start_value=0)
results = []
results: list[int] = []

async def waiter(index):
async def waiter(index: int) -> None:
await counter.wait_for(5)
results.append(index)

Expand All @@ -79,32 +84,32 @@ async def waiter(index):
assert results == [0, 1, 2]


@pytest.mark.asyncio_cooperative
async def test_counterlock_increment_only():
@asyncio_cooperative
async def test_counterlock_increment_only() -> None:
counter = CounterLock(start_value=5)
with pytest.raises(ValueError):
counter.set(3)


@pytest.mark.asyncio_cooperative
async def test_counterlock_large_value():
@asyncio_cooperative
async def test_counterlock_large_value() -> None:
counter = CounterLock(start_value=0)
large_value = 10**6
counter.set(large_value)
assert counter.value == large_value


@pytest.mark.asyncio_cooperative
async def test_counterlock_zero_value():
@asyncio_cooperative
async def test_counterlock_zero_value() -> None:
counter = CounterLock(start_value=0)
assert counter.value == 0


@pytest.mark.asyncio_cooperative
async def test_counterlock_exception_handling():
@asyncio_cooperative
async def test_counterlock_exception_handling() -> None:
counter = CounterLock(start_value=0)

async def waiter():
async def waiter() -> str:
try:
await counter.wait_for(5)
raise ValueError("Intentional error")
Expand All @@ -116,12 +121,12 @@ async def waiter():
assert result == "Intentional error"


@pytest.mark.asyncio_cooperative
async def test_simultaneous_set_and_wait():
@asyncio_cooperative
async def test_simultaneous_set_and_wait() -> None:
counter = CounterLock(start_value=0)
results = []
results: list[int] = []

async def waiter(index):
async def waiter(index: int) -> None:
await counter.wait_for(5)
results.append(index)

Expand All @@ -131,35 +136,35 @@ async def waiter(index):
assert results == [0, 1, 2, 3, 4]


@pytest.mark.asyncio_cooperative
async def test_reentrant_set():
@asyncio_cooperative
async def test_reentrant_set() -> None:
counter = CounterLock(start_value=0)
counter.set(5)
counter.set(10) # Reentrant set
assert counter.value == 10


def test_counterlock_invalid_start_value():
def test_counterlock_invalid_start_value() -> None:
with pytest.raises(TypeError):
CounterLock(None)


@pytest.mark.asyncio_cooperative
async def test_immediate_set_and_wait():
@asyncio_cooperative
async def test_immediate_set_and_wait() -> None:
counter = CounterLock(start_value=5)

async def waiter():
return await counter.wait_for(5)
async def waiter() -> bool:
return bool(await counter.wait_for(5))

result = await waiter()
assert result is True


@pytest.mark.asyncio_cooperative
async def test_delayed_set():
@asyncio_cooperative
async def test_delayed_set() -> None:
counter = CounterLock(start_value=0)

async def waiter():
async def waiter() -> str:
await counter.wait_for(5)
return "done"

Expand All @@ -170,12 +175,12 @@ async def waiter():
assert result == "done"


@pytest.mark.asyncio_cooperative
async def test_multiple_sets():
@asyncio_cooperative
async def test_multiple_sets() -> None:
counter = CounterLock(start_value=0)
results = []
results: list[int] = []

async def waiter(index):
async def waiter(index: int) -> None:
await counter.wait_for(10)
results.append(index)

Expand All @@ -186,11 +191,11 @@ async def waiter(index):
assert results == [0, 1, 2]


@pytest.mark.asyncio_cooperative
async def test_custom_error_handling():
@asyncio_cooperative
async def test_custom_error_handling() -> None:
counter = CounterLock(start_value=0)

async def waiter():
async def waiter() -> str:
try:
await counter.wait_for(5)
raise RuntimeError("Custom error")
Expand All @@ -202,11 +207,11 @@ async def waiter():
assert result == "Custom error"


@pytest.mark.asyncio_cooperative
async def test_external_interruptions():
@asyncio_cooperative
async def test_external_interruptions() -> None:
counter = CounterLock(start_value=0)

async def waiter():
async def waiter() -> str:
await counter.wait_for(5)
return "done"

Expand Down
Loading