Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion a_sync/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@
from a_sync.a_sync import ASyncGenericBase, ASyncGenericSingleton, a_sync
from a_sync.a_sync.modifiers.semaphores import apply_semaphore
from a_sync.a_sync.property import ASyncCachedPropertyDescriptor
from a_sync.a_sync.property import ASyncCachedPropertyDescriptor as cached_property
from a_sync.a_sync.property import \
ASyncCachedPropertyDescriptor as cached_property
from a_sync.a_sync.property import ASyncPropertyDescriptor
from a_sync.a_sync.property import ASyncPropertyDescriptor as property
from a_sync.asyncio import as_completed, cgather, create_task, gather, igather
Expand Down
27 changes: 14 additions & 13 deletions a_sync/_smart.pyi
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
from asyncio import AbstractEventLoop, Future, Task
from typing import Any, Awaitable, Generator, Generic, TypeVar
from typing import (TYPE_CHECKING, Any, Awaitable, Generic, Optional, Tuple,
TypeVar, Union)
from weakref import WeakSet

from a_sync._typing import T
from a_sync.primitives.queue import SmartProcessingQueue
if TYPE_CHECKING:
from a_sync.primitives.queue import SmartProcessingQueue

_T = TypeVar("_T")

_Args = tuple[Any, ...]
_Kwargs = tuple[tuple[str, Any], ...]
_Key = tuple[_Args, _Kwargs]
_Args = Tuple[Any]
_Kwargs = Tuple[Tuple[str, Any]]
_Key = Tuple[_Args, _Kwargs]

def shield(arg: Awaitable[_T]) -> SmartFuture[_T] | Future[_T]:
def shield(arg: Awaitable[_T]) -> Union[SmartFuture[_T], "Future[_T]"]:
"""
Wait for a future, shielding it from cancellation.

Expand Down Expand Up @@ -83,9 +84,9 @@ class _SmartFutureMixin(Generic[_T]):
- :class:`SmartTask`
"""

_queue: SmartProcessingQueue[Any, Any, _T] | None = None
_queue: Optional["SmartProcessingQueue[Any, Any, _T]"] = None
_key: _Key
_waiters: WeakSet[SmartTask[_T]]
_waiters: "WeakSet[SmartTask[_T]]"

class SmartFuture(_SmartFutureMixin[_T], Future):
"""
Expand Down Expand Up @@ -123,9 +124,9 @@ class SmartFuture(_SmartFutureMixin[_T], Future):

def create_future(
*,
queue: SmartProcessingQueue | None = None,
key: _Key | None = None,
loop: AbstractEventLoop | None = None,
queue: Optional["SmartProcessingQueue"] = None,
key: Optional[_Key] = None,
loop: Optional[AbstractEventLoop] = None,
) -> SmartFuture[_T]:
"""
Create a :class:`~SmartFuture` instance.
Expand Down Expand Up @@ -183,7 +184,7 @@ class SmartTask(_SmartFutureMixin[_T], Task):
```
"""

def set_smart_task_factory(loop: AbstractEventLoop | None = None) -> None:
def set_smart_task_factory(loop: AbstractEventLoop = None) -> None:
"""
Set the event loop's task factory to :func:`~smart_task_factory` so all tasks will be SmartTask instances.

Expand Down
72 changes: 27 additions & 45 deletions a_sync/_smart.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@ to protect tasks from cancellation.
import asyncio
import typing
import weakref
from collections.abc import Awaitable, Generator
from logging import getLogger
from types import TracebackType
from typing import Any, Optional
from typing import Awaitable, Generator, Optional, Set

cimport cython
from cpython.object cimport PyObject
Expand Down Expand Up @@ -51,10 +50,17 @@ cdef bint _DEBUG_LOGS_ENABLED = logger.isEnabledFor(DEBUG)
cdef object _logger_log = logger._log
del getLogger

# cdef typing
cdef object Any = typing.Any
cdef object Generic = typing.Generic
cdef object Tuple = typing.Tuple
cdef object Union = typing.Union
del typing

cdef object Args = tuple[Any]
cdef object Kwargs = tuple[tuple[str, Any]]
_Key = tuple[Args, Kwargs]

cdef object Args = Tuple[Any]
cdef object Kwargs = Tuple[Tuple[str, Any]]
_Key = Tuple[Args, Kwargs]
cdef object Key = _Key


Expand All @@ -68,7 +74,7 @@ cdef void log_await(object arg):


@cython.linetrace(False)
cdef Py_ssize_t count_waiters(fut: SmartFuture[Any] | SmartTask[Any]):
cdef Py_ssize_t count_waiters(fut: Union["SmartFuture", "SmartTask"]):
if _is_done(fut):
return ZERO
try:
Expand Down Expand Up @@ -172,7 +178,7 @@ cdef inline bint _is_cancelled(fut: Future):


@cython.linetrace(False)
cdef object _get_result(fut: SmartFuture[Any] | SmartTask[Any]):
cdef object _get_result(fut: Union["SmartFuture", "SmartTask"]):
"""Return the result this future represents.

If the future has been cancelled, raises CancelledError. If the
Expand Down Expand Up @@ -220,7 +226,7 @@ cdef object _get_exception(fut: Future):
raise InvalidStateError('Exception is not set.')


class SmartFuture(Future[T]):
class SmartFuture(Future, Generic[T]):
"""
A smart future that tracks waiters and integrates with a smart processing queue.

Expand All @@ -238,18 +244,18 @@ class SmartFuture(Future[T]):
- :class:`asyncio.Future`
"""
_queue: Optional["SmartProcessingQueue[Any, Any, T]"] = None
_key: Key | None = None
_key: Optional[Key] = None

_waiters: "weakref.WeakSet[SmartTask[T]]"

__traceback__: TracebackType | None = None
__traceback__: Optional[TracebackType] = None

def __init__(
self,
*,
queue: Optional["SmartProcessingQueue[Any, Any, T]"] = None,
key: Key | None = None,
loop: AbstractEventLoop | None = None,
key: Optional[Key] = None,
loop: Optional[AbstractEventLoop] = None,
) -> None:
"""
Initialize the SmartFuture with an optional queue and key.
Expand Down Expand Up @@ -377,8 +383,8 @@ cdef inline object current_task(object loop):
@cython.linetrace(False)
cpdef inline object create_future(
queue: Optional["SmartProcessingQueue"] = None,
key: Key | None = None,
loop: AbstractEventLoop | None = None,
key: Optional[Key] = None,
loop: Optional[AbstractEventLoop] = None,
):
"""
Create a :class:`~SmartFuture` instance.
Expand All @@ -404,7 +410,7 @@ cpdef inline object create_future(
return _SmartFuture(queue=queue, key=key, loop=loop or get_event_loop())


class SmartTask(Task[T]):
class SmartTask(Task, Generic[T]):
"""
A smart task that tracks waiters and integrates with a smart processing queue.

Expand All @@ -422,17 +428,17 @@ class SmartTask(Task[T]):
- :class:`asyncio.Task`
"""

_waiters: set["Task[T]"]
_waiters: Set["Task[T]"]

__traceback__: TracebackType | None = None
__traceback__: Optional[TracebackType] = None

@cython.linetrace(False)
def __init__(
self,
coro: Awaitable[T],
*,
loop: AbstractEventLoop | None = None,
name: str | None = None,
loop: Optional[AbstractEventLoop] = None,
name: Optional[str] = None,
) -> None:
"""
Initialize the SmartTask with a coroutine and optional event loop.
Expand Down Expand Up @@ -623,12 +629,7 @@ cpdef object shield(arg: Awaitable[T]):
outer = _SmartFuture(loop=loop)

# special handling to connect SmartFutures to SmartTasks if enabled
waiters = getattr(inner, "_waiters", None)
if isinstance(waiters, WeakSet):
# SmartFuture._waiters is a WeakSet
(<WeakSet>waiters).add(outer)
elif waiters is not None:
# SmartTask _waiters is a builtins.set
if (waiters := getattr(inner, "_waiters", None)) is not None:
waiters.add(outer)

_inner_done_callback, _outer_done_callback = _get_done_callbacks(inner, outer)
Expand All @@ -648,7 +649,7 @@ cdef tuple _get_done_callbacks(inner: Task, outer: Future):
return

if _is_cancelled(inner):
outer.cancel(CancelMessage(inner))
outer.cancel()
else:
exc = _get_exception(inner)
if exc is not None:
Expand All @@ -663,22 +664,6 @@ cdef tuple _get_done_callbacks(inner: Task, outer: Future):
return _inner_done_callback, _outer_done_callback


cdef class CancelMessage:
"""
This class wraps a cancelled task for an asyncio cancel message so that
we can pass it around freely as one object but only construct the string
representation when required by something downstream.
"""
cdef object task
def __cinit__(self, object task) -> None:
self.task = task
def __repr__(self) -> str:
return f"CancelMessage('{str(self)}')"
def __str__(self) -> str:
return f"[a_sync.shield] inner task is cancelled: {self.task!r}"



__all__ = [
"create_future",
"shield",
Expand All @@ -687,6 +672,3 @@ __all__ = [
"smart_task_factory",
"set_smart_task_factory",
]


del Any, Awaitable, Generator, Optional, TracebackType
66 changes: 18 additions & 48 deletions a_sync/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,12 @@ def sync_function(x: int) -> str:
"""

import asyncio
from collections.abc import AsyncIterable, Awaitable, Callable, Coroutine, Iterable
from collections.abc import (AsyncIterable, Awaitable, Callable, Coroutine,
Iterable)
from concurrent.futures._base import Executor
from decimal import Decimal
from typing import TYPE_CHECKING, Any, Literal, Protocol, TypedDict, TypeVar, runtime_checkable
from typing import (TYPE_CHECKING, Any, Literal, Optional, Protocol, TypedDict,
TypeVar, Union, runtime_checkable)

from typing_extensions import ParamSpec

Expand All @@ -91,13 +93,13 @@ def sync_function(x: int) -> str:
P = ParamSpec("P")
"""A :class:`ParamSpec` used everywhere in the lib."""

Numeric = int | float | Decimal
Numeric = Union[int, float, Decimal]
"""Type alias for numeric values of types int, float, or Decimal."""

MaybeAwaitable = Awaitable[T] | T
MaybeAwaitable = Union[Awaitable[T], T]
"""Type alias for values that may or may not be awaitable. Useful for functions that can return either an awaitable or a direct value."""

MaybeCoro = Coroutine[Any, Any, T] | T
MaybeCoro = Union[Coroutine[Any, Any, T], T]
"Type alias for values that may or may not be coroutine."

CoroFn = Callable[P, Awaitable[T]]
Expand All @@ -106,7 +108,7 @@ def sync_function(x: int) -> str:
SyncFn = Callable[P, T]
"""Type alias for synchronous functions."""

AnyFn = CoroFn[P, T] | SyncFn[P, T]
AnyFn = Union[CoroFn[P, T], SyncFn[P, T]]
"Type alias for any function, whether synchronous or asynchronous."


Expand Down Expand Up @@ -144,7 +146,7 @@ def my_method(self, x: int) -> str:
__call__: Callable[P, T]


AnyBoundMethod = CoroBoundMethod[Any, P, T] | SyncBoundMethod[Any, P, T]
AnyBoundMethod = Union[CoroBoundMethod[Any, P, T], SyncBoundMethod[Any, P, T]]
"Type alias for any bound method, whether synchronous or asynchronous."


Expand Down Expand Up @@ -172,7 +174,7 @@ class SyncUnboundMethod(Protocol[I, P, T]):
__get__: Callable[[I, type], SyncBoundMethod[I, P, T]]


AnyUnboundMethod = AsyncUnboundMethod[I, P, T] | SyncUnboundMethod[I, P, T]
AnyUnboundMethod = Union[AsyncUnboundMethod[I, P, T], SyncUnboundMethod[I, P, T]]
"Type alias for any unbound method, whether synchronous or asynchronous."

AsyncGetterFunction = Callable[[I], Awaitable[T]]
Expand All @@ -181,13 +183,13 @@ class SyncUnboundMethod(Protocol[I, P, T]):
SyncGetterFunction = Callable[[I], T]
"Type alias for synchronous getter functions."

AnyGetterFunction = AsyncGetterFunction[I, T] | SyncGetterFunction[I, T]
AnyGetterFunction = Union[AsyncGetterFunction[I, T], SyncGetterFunction[I, T]]
"Type alias for any getter function, whether synchronous or asynchronous."

AsyncDecorator = Callable[[CoroFn[P, T]], CoroFn[P, T]]
"Type alias for decorators for coroutine functions."

AsyncDecoratorOrCoroFn = AsyncDecorator[P, T] | CoroFn[P, T]
AsyncDecoratorOrCoroFn = Union[AsyncDecorator[P, T], CoroFn[P, T]]
"Type alias for either an asynchronous decorator or a coroutine function."

DefaultMode = Literal["sync", "async", None]
Expand All @@ -196,62 +198,30 @@ class SyncUnboundMethod(Protocol[I, P, T]):
CacheType = Literal["memory", None]
"Type alias for cache types."

SemaphoreSpec = asyncio.Semaphore | int | None
SemaphoreSpec = Optional[Union[asyncio.Semaphore, int]]
"Type alias for semaphore specifications."


class _ModifierKwargsBase(TypedDict, total=False):
class ModifierKwargs(TypedDict, total=False):
"""
TypedDict for keyword arguments that modify the behavior of asynchronous operations,
excluding the default mode and executor.
TypedDict for keyword arguments that modify the behavior of asynchronous operations.
"""

default: DefaultMode
cache_type: CacheType
cache_typed: bool
ram_cache_maxsize: int | None
ram_cache_ttl: Numeric | None
runs_per_minute: int | None
semaphore: SemaphoreSpec


class _ModifierKwargsNoDefault(_ModifierKwargsBase, total=False):
"""
TypedDict for keyword arguments that modify the behavior of asynchronous operations,
excluding the default mode.
"""

# sync modifiers
executor: Executor


class _ModifierKwargsNoExecutor(_ModifierKwargsBase, total=False):
"""
TypedDict for keyword arguments that modify the behavior of asynchronous operations,
excluding the executor.
"""

default: DefaultMode


class _ModifierKwargsNoDefaultExecutor(_ModifierKwargsBase, total=False):
"""
TypedDict for keyword arguments that modify the behavior of asynchronous operations,
excluding the default mode and executor.
"""


class ModifierKwargs(_ModifierKwargsNoDefault, total=False):
"""
TypedDict for keyword arguments that modify the behavior of asynchronous operations.
"""

default: DefaultMode


AnyIterable = AsyncIterable[K] | Iterable[K]
AnyIterable = Union[AsyncIterable[K], Iterable[K]]
"Type alias for any iterable, whether synchronous or asynchronous."

AnyIterableOrAwaitableIterable = AnyIterable[K] | Awaitable[AnyIterable[K]]
AnyIterableOrAwaitableIterable = Union[AnyIterable[K], Awaitable[AnyIterable[K]]]
"""
Type alias for any iterable, whether synchronous or asynchronous,
or an awaitable that resolves to any iterable, whether synchronous or asynchronous.
Expand Down
Loading
Loading