Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 18 additions & 7 deletions a_sync/iter.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ class ASyncIterator(_AwaitableAsyncIterableMixin, Iterator[T]):
__slots__ = ("__anext")


class ASyncGeneratorFunction(Generic[P, T]):
cdef class ASyncGeneratorFunction:
"""
Encapsulates an asynchronous generator function, providing a mechanism to use it as an asynchronous iterator with enhanced capabilities. This class wraps an async generator function, allowing it to be called with parameters and return an :class:`~ASyncIterator` object. It is particularly useful for situations where an async generator function needs to be used in a manner that is consistent with both synchronous and asynchronous execution contexts.

Expand All @@ -416,13 +416,20 @@ class ASyncGeneratorFunction(Generic[P, T]):
- :class:`ASyncIterator`
- :class:`ASyncIterable`
"""
cdef readonly str field_name
cdef readonly object _cache_handle
cdef readonly object __wrapped__
cdef object __weakself__

_cache_handle: TimerHandle
"An asyncio handle used to pop the bound method from `instance.__dict__` 5 minutes after its last use."

__weakself__: "ref[object]" = None
"A weak reference to the instance the function is bound to, if any."

def __cinit__(self):
self.__weakself__ = None

def __init__(
self, async_gen_func: AsyncGenFunc[P, T], instance: Any = None
) -> None:
Expand All @@ -439,11 +446,12 @@ class ASyncGeneratorFunction(Generic[P, T]):

self.__wrapped__ = async_gen_func
"The actual async generator function."

update_wrapper(self, async_gen_func)

if instance is not None:
self._cache_handle = self.__get_cache_handle(instance)
self.__weakself__ = ref(instance, self.__cancel_cache_handle)
update_wrapper(self, self.__wrapped__)

def __repr__(self) -> str:
return "<{} for {} at {}>".format(
Expand All @@ -462,7 +470,7 @@ class ASyncGeneratorFunction(Generic[P, T]):
"""
if self.__weakself__ is None:
return ASyncIterator(self.__wrapped__(*args, **kwargs))
return ASyncIterator(self.__wrapped__(self.__self__, *args, **kwargs))
return ASyncIterator(self.__wrapped__(self.__get_self(), *args, **kwargs))

def __get__(self, instance: V, owner: Type[V]) -> "ASyncGeneratorFunction[P, T]":
"Descriptor method to make the function act like a non-data descriptor."
Expand All @@ -471,16 +479,19 @@ class ASyncGeneratorFunction(Generic[P, T]):

cdef object gen_func
try:
gen_func = instance.__dict__[self.field_name]
gen_func = (<dict>instance.__dict__)[self.field_name]
except KeyError:
gen_func = ASyncGeneratorFunction(self.__wrapped__, instance)
instance.__dict__[self.field_name] = gen_func
(<dict>instance.__dict__)[self.field_name] = gen_func
gen_func._cache_handle.cancel()
gen_func._cache_handle = self.__get_cache_handle(instance)
return gen_func

@property
def __self__(self) -> object:
return self.__get_self()

cdef object __get_self(self):
cdef object instance
try:
instance = self.__weakself__()
Expand All @@ -490,13 +501,13 @@ class ASyncGeneratorFunction(Generic[P, T]):
raise ReferenceError(self)
return instance

def __get_cache_handle(self, instance: object) -> TimerHandle:
cdef object __get_cache_handle(self, instance: object):
# NOTE: we create a strong reference to instance here. I'm not sure if this is good or not but its necessary for now.
return get_event_loop().call_later(
300, delattr, instance, self.field_name
)

def __cancel_cache_handle(self, instance: object) -> None:
cdef void __cancel_cache_handle(self, object instance):
self._cache_handle.cancel()


Expand Down
Loading