Skip to content

Commit 78d4d3d

Browse files
committed
refactor: setting auto_await for async autorun functions will make them return None, setting it to False will make them return the awaitable, the awaitable can be awaited multiple times, as it cashes the result if comparator is not changed, it can't be set for sync functions
1 parent 220af95 commit 78d4d3d

File tree

5 files changed

+174
-136
lines changed

5 files changed

+174
-136
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
- refactor: remove `WithState` as it wasn't doing anything beyond `functools.wraps`
66
- refactor: autorun doesn't inform subscribers when the output value is not changed
77
- refactor: add `autorun_class` and `side_effect_runner_class` to improve extensibility
8+
- refactor: setting `auto_await` for async autorun functions will make them return `None`, setting it to `False` will make them return the awaitable, the awaitable can be `await`ed multiple times, as it cashes the result if comparator is not changed, it can't be set for sync functions
89

910
## Version 0.22.2
1011

redux/autorun.py

Lines changed: 41 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from __future__ import annotations
44

55
import asyncio
6-
import functools
76
import inspect
87
import weakref
98
from asyncio import Future, Task, iscoroutine, iscoroutinefunction
@@ -12,19 +11,21 @@
1211
Any,
1312
Concatenate,
1413
Generic,
15-
TypeVar,
14+
Literal,
1615
cast,
1716
)
1817

1918
from redux.basic_types import (
2019
Action,
2120
Args,
22-
AutorunOptions,
21+
AutoAwait,
22+
AutorunOptionsType,
2323
ComparatorOutput,
2424
Event,
2525
ReturnType,
2626
SelectorOutput,
2727
State,
28+
T,
2829
)
2930

3031
if TYPE_CHECKING:
@@ -33,26 +34,37 @@
3334
from redux.main import Store
3435

3536

36-
T = TypeVar('T')
37-
38-
3937
class AwaitableWrapper(Generic[T]):
4038
"""A wrapper for a coroutine to track if it has been awaited."""
4139

40+
_unawaited = object()
41+
value: tuple[Literal[False], None] | tuple[Literal[True], T]
42+
4243
def __init__(self, coro: Coroutine[None, None, T]) -> None:
4344
"""Initialize the AwaitableWrapper with a coroutine."""
4445
self.coro = coro
45-
self.awaited = False
46+
self.value = (False, None)
4647

4748
def __await__(self) -> Generator[None, None, T]:
4849
"""Await the coroutine and set the awaited flag to True."""
49-
self.awaited = True
50-
return self.coro.__await__()
50+
return self._wrap().__await__()
51+
52+
async def _wrap(self) -> T:
53+
"""Wrap the coroutine and set the awaited flag to True."""
54+
if self.value[0] is True:
55+
return self.value[1]
56+
self.value = (True, await self.coro)
57+
return self.value[1]
5158

5259
def close(self) -> None:
5360
"""Close the coroutine if it has not been awaited."""
5461
self.coro.close()
5562

63+
@property
64+
def awaited(self) -> bool:
65+
"""Check if the coroutine has been awaited."""
66+
return self.value[0] is True
67+
5668
def __repr__(self) -> str:
5769
"""Return a string representation of the AwaitableWrapper."""
5870
return f'AwaitableWrapper({self.coro}, awaited={self.awaited})'
@@ -71,7 +83,7 @@ class Autorun(
7183
):
7284
"""Run a wrapped function in response to specific state changes in the store."""
7385

74-
def __init__( # noqa: C901, PLR0912
86+
def __init__( # noqa: C901, PLR0912, PLR0915
7587
self: Autorun,
7688
*,
7789
store: Store[State, Action, Event],
@@ -81,7 +93,7 @@ def __init__( # noqa: C901, PLR0912
8193
Concatenate[SelectorOutput, Args],
8294
ReturnType,
8395
],
84-
options: AutorunOptions[ReturnType],
96+
options: AutorunOptionsType[ReturnType, AutoAwait],
8597
) -> None:
8698
"""Initialize the Autorun instance."""
8799
if hasattr(func, '__name__'):
@@ -121,7 +133,7 @@ def __init__( # noqa: C901, PLR0912
121133
self._func = weakref.ref(func, self.unsubscribe)
122134
self._is_coroutine = (
123135
asyncio.coroutines._is_coroutine # pyright: ignore [reportAttributeAccessIssue] # noqa: SLF001
124-
if asyncio.iscoroutinefunction(func)
136+
if asyncio.iscoroutinefunction(func) and options.auto_await is False
125137
else None
126138
)
127139
self._options = options
@@ -132,8 +144,16 @@ def __init__( # noqa: C901, PLR0912
132144
object(),
133145
)
134146
if iscoroutinefunction(func):
135-
self._latest_value = Future()
136-
self._latest_value.set_result(options.default_value)
147+
148+
async def default_value_wrapper() -> ReturnType | None:
149+
return options.default_value
150+
151+
create_task = self._store._create_task # noqa: SLF001
152+
default_value = default_value_wrapper()
153+
154+
if create_task:
155+
create_task(default_value)
156+
self._latest_value: ReturnType = default_value
137157
else:
138158
self._latest_value: ReturnType = options.default_value
139159
self._subscriptions: set[
@@ -145,11 +165,11 @@ def __init__( # noqa: C901, PLR0912
145165
self.call()
146166

147167
if self._options.reactive:
148-
self._unsubscribe = store._subscribe(self._react) # noqa: SLF001
168+
self._unsubscribe = store._subscribe(self.react) # noqa: SLF001
149169
else:
150170
self._unsubscribe = None
151171

152-
def _react(
172+
def react(
153173
self: Autorun,
154174
state: State,
155175
) -> None:
@@ -275,27 +295,17 @@ def call(
275295
create_task = self._store._create_task # noqa: SLF001
276296
previous_value = self._latest_value
277297
if iscoroutine(value) and create_task:
278-
if self._options.auto_await:
279-
future = Future()
280-
self._latest_value = cast('ReturnType', future)
281-
create_task(
282-
value,
283-
callback=functools.partial(
284-
self._task_callback,
285-
future=future,
286-
),
287-
)
288-
else:
298+
if self._options.auto_await is False:
289299
if (
290300
self._latest_value is not None
291301
and isinstance(self._latest_value, AwaitableWrapper)
292302
and not self._latest_value.awaited
293303
):
294304
self._latest_value.close()
295-
self._latest_value = cast(
296-
'ReturnType',
297-
AwaitableWrapper(value),
298-
)
305+
self._latest_value = cast('ReturnType', AwaitableWrapper(value))
306+
else:
307+
self._latest_value = cast('ReturnType', None)
308+
create_task(value)
299309
else:
300310
self._latest_value = value
301311
if self._latest_value is not previous_value:

redux/basic_types.py

Lines changed: 89 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,12 @@
99
Any,
1010
Concatenate,
1111
Generic,
12-
Never,
12+
Literal,
1313
ParamSpec,
1414
Protocol,
1515
TypeAlias,
1616
TypeGuard,
17+
cast,
1718
overload,
1819
)
1920

@@ -144,10 +145,75 @@ class StoreOptions(Immutable, Generic[Action, Event]):
144145

145146
# Autorun
146147

148+
AutoAwait = TypeVar('AutoAwait', bound=Literal[True, False, None], infer_variance=True)
147149

148-
class AutorunOptions(Immutable, Generic[ReturnType]):
150+
151+
class AutorunOptionsType(Immutable, Generic[ReturnType, AutoAwait]):
152+
default_value: ReturnType | None = None
153+
auto_await: AutoAwait = cast('AutoAwait', val=None)
154+
initial_call: bool = True
155+
reactive: bool = True
156+
memoization: bool = True
157+
keep_ref: bool = True
158+
subscribers_initial_run: bool = True
159+
subscribers_keep_ref: bool = True
160+
161+
@overload
162+
def __init__(
163+
self: AutorunOptionsType[ReturnType, Literal[None]], # type: ignore[reportInvalidTypeVar]
164+
*,
165+
default_value: ReturnType | None = None,
166+
auto_await: Literal[None] | None = None,
167+
initial_call: bool = True,
168+
reactive: bool = True,
169+
memoization: bool = True,
170+
keep_ref: bool = True,
171+
subscribers_initial_run: bool = True,
172+
subscribers_keep_ref: bool = True,
173+
) -> None: ...
174+
@overload
175+
def __init__(
176+
self: AutorunOptionsType[ReturnType, Literal[True]], # type: ignore[reportInvalidTypeVar]
177+
*,
178+
default_value: ReturnType | None = None,
179+
auto_await: Literal[True],
180+
initial_call: bool = True,
181+
reactive: bool = True,
182+
memoization: bool = True,
183+
keep_ref: bool = True,
184+
subscribers_initial_run: bool = True,
185+
subscribers_keep_ref: bool = True,
186+
) -> None: ...
187+
@overload
188+
def __init__(
189+
self: AutorunOptionsType[ReturnType, Literal[False]], # type: ignore[reportInvalidTypeVar]
190+
*,
191+
default_value: ReturnType | None = None,
192+
auto_await: Literal[False],
193+
initial_call: bool = True,
194+
reactive: bool = True,
195+
memoization: bool = True,
196+
keep_ref: bool = True,
197+
subscribers_initial_run: bool = True,
198+
subscribers_keep_ref: bool = True,
199+
) -> None: ...
200+
def __init__( # noqa: PLR0913
201+
self: AutorunOptionsType,
202+
*,
203+
default_value: ReturnType | None = None,
204+
auto_await: bool | None = None,
205+
initial_call: bool = True,
206+
reactive: bool = True,
207+
memoization: bool = True,
208+
keep_ref: bool = True,
209+
subscribers_initial_run: bool = True,
210+
subscribers_keep_ref: bool = True,
211+
) -> None: ...
212+
213+
214+
class AutorunOptionsImplementation(Immutable, Generic[ReturnType, AutoAwait]):
149215
default_value: ReturnType | None = None
150-
auto_await: bool = True
216+
auto_await: AutoAwait = cast('AutoAwait', val=None)
151217
initial_call: bool = True
152218
reactive: bool = True
153219
memoization: bool = True
@@ -156,8 +222,7 @@ class AutorunOptions(Immutable, Generic[ReturnType]):
156222
subscribers_keep_ref: bool = True
157223

158224

159-
AutorunOptionsWithDefault = AutorunOptions[ReturnType]
160-
AutorunOptionsWithoutDefault = AutorunOptions[Never]
225+
AutorunOptions = cast('type[AutorunOptionsType]', AutorunOptionsImplementation)
161226

162227

163228
class AutorunReturnType(
@@ -186,34 +251,39 @@ def unsubscribe(self: AutorunReturnType) -> None: ...
186251
__name__: str
187252

188253

189-
class AutorunDecorator(Protocol, Generic[SelectorOutput, ReturnType]):
254+
class AutorunDecorator(Protocol, Generic[ReturnType, SelectorOutput, AutoAwait]):
190255
@overload
191256
def __call__(
192-
self: AutorunDecorator,
257+
self: AutorunDecorator[ReturnType, SelectorOutput, Literal[None]],
258+
func: Callable[
259+
Concatenate[SelectorOutput, Args],
260+
Awaitable[ReturnType],
261+
],
262+
) -> AutorunReturnType[None, Args]: ...
263+
@overload
264+
def __call__(
265+
self: AutorunDecorator[ReturnType, SelectorOutput, Literal[None]],
193266
func: Callable[
194267
Concatenate[SelectorOutput, Args],
195268
ReturnType,
196269
],
197270
) -> AutorunReturnType[ReturnType, Args]: ...
198-
199271
@overload
200272
def __call__(
201-
self: AutorunDecorator,
273+
self: AutorunDecorator[ReturnType, SelectorOutput, Literal[True]],
202274
func: Callable[
203275
Concatenate[SelectorOutput, Args],
204276
Awaitable[ReturnType],
205277
],
206-
) -> AutorunReturnType[Awaitable[ReturnType], Args]: ...
207-
208-
209-
class UnknownAutorunDecorator(Protocol, Generic[SelectorOutput]):
278+
) -> AutorunReturnType[None, Args]: ...
279+
@overload
210280
def __call__(
211-
self: UnknownAutorunDecorator,
281+
self: AutorunDecorator[ReturnType, SelectorOutput, Literal[False]],
212282
func: Callable[
213283
Concatenate[SelectorOutput, Args],
214-
ReturnType,
284+
Awaitable[ReturnType],
215285
],
216-
) -> AutorunReturnType[ReturnType, Args]: ...
286+
) -> AutorunReturnType[Awaitable[ReturnType], Args]: ...
217287

218288

219289
# View
@@ -227,10 +297,6 @@ class ViewOptions(Immutable, Generic[ReturnType]):
227297
subscribers_keep_ref: bool = True
228298

229299

230-
ViewOptionsWithDefault = ViewOptions[ReturnType]
231-
ViewOptionsWithoutDefault = ViewOptions[Never]
232-
233-
234300
class ViewReturnType(
235301
Protocol,
236302
Generic[ReturnType, Args],
@@ -257,17 +323,8 @@ def unsubscribe(self: ViewReturnType) -> None: ...
257323

258324
class ViewDecorator(
259325
Protocol,
260-
Generic[SelectorOutput, ReturnType],
326+
Generic[ReturnType, SelectorOutput],
261327
):
262-
@overload
263-
def __call__(
264-
self: ViewDecorator,
265-
func: Callable[
266-
Concatenate[SelectorOutput, Args],
267-
ReturnType,
268-
],
269-
) -> ViewReturnType[ReturnType, Args]: ...
270-
271328
@overload
272329
def __call__(
273330
self: ViewDecorator,
@@ -277,10 +334,9 @@ def __call__(
277334
],
278335
) -> ViewReturnType[Awaitable[ReturnType], Args]: ...
279336

280-
281-
class UnknownViewDecorator(Protocol, Generic[SelectorOutput]):
337+
@overload
282338
def __call__(
283-
self: UnknownViewDecorator,
339+
self: ViewDecorator,
284340
func: Callable[
285341
Concatenate[SelectorOutput, Args],
286342
ReturnType,

0 commit comments

Comments
 (0)