Skip to content

Commit 5df494f

Browse files
committed
feat: support with_state to be applied to methods of classes, not just functions
1 parent 492c4db commit 5df494f

File tree

5 files changed

+94
-14
lines changed

5 files changed

+94
-14
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
- chore: add badges in `README.md` and classifiers in `pyproject.toml`
66
- refactor: move the common code for manipulating the signature of the wrapped functions in `WithStore` and `Autorun` to a utility function
7+
- feat: support `with_state` to be applied to methods of classes, not just functions
78

89
## Version 0.23.0
910

redux/basic_types.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,15 +361,25 @@ def __call__(
361361
# With Store
362362

363363

364+
Self = TypeVar('Self', bound=object, infer_variance=True)
365+
366+
364367
class WithStateDecorator(
365368
Protocol,
366369
Generic[SelectorOutput],
367370
):
371+
@overload
368372
def __call__(
369373
self: WithStateDecorator,
370374
func: Callable[Concatenate[SelectorOutput, Args], ReturnType],
371375
) -> Callable[Args, ReturnType]: ...
372376

377+
@overload
378+
def __call__(
379+
self: WithStateDecorator,
380+
func: Callable[Concatenate[Self, SelectorOutput, Args], ReturnType],
381+
) -> Callable[Concatenate[Self, Args], ReturnType]: ...
382+
373383

374384
class EventSubscriber(Protocol):
375385
def __call__(

redux/main.py

Lines changed: 52 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
ReducerType,
4343
ReturnType,
4444
SelectorOutput,
45+
Self,
4546
SnapshotAtom,
4647
State,
4748
StoreOptions,
@@ -384,20 +385,67 @@ def with_state(
384385
`store._state` is also possible.
385386
"""
386387

388+
@overload
389+
def with_state_decorator(
390+
func: Callable[
391+
Concatenate[SelectorOutput, Args],
392+
ReturnType,
393+
],
394+
) -> Callable[Args, ReturnType]: ...
395+
@overload
396+
def with_state_decorator(
397+
func: Callable[
398+
Concatenate[Self, SelectorOutput, Args],
399+
ReturnType,
400+
],
401+
) -> Callable[Concatenate[Self, Args], ReturnType]: ...
387402
def with_state_decorator(
388403
func: Callable[
389404
Concatenate[SelectorOutput, Args],
390405
ReturnType,
406+
]
407+
| Callable[
408+
Concatenate[Self, SelectorOutput, Args],
409+
ReturnType,
391410
],
392-
) -> Callable[Args, ReturnType]:
411+
) -> Callable[Args, ReturnType] | Callable[Concatenate[Self, Args], ReturnType]:
412+
signature = drop_with_store_parameter(func)
413+
414+
if (
415+
signature.parameters
416+
and next(iter(signature.parameters.values())).name == 'self'
417+
):
418+
func_ = cast(
419+
'Callable[Concatenate[Self, SelectorOutput, Args], ReturnType]',
420+
func,
421+
)
422+
423+
def wrapper(*args: Args.args, **kwargs: Args.kwargs) -> ReturnType:
424+
if self._state is None:
425+
msg = 'Store has not been initialized yet.'
426+
raise RuntimeError(msg)
427+
self_ = cast('Self', args[0])
428+
args_ = cast('Any', args[1:])
429+
return func_(self_, selector(self._state), *args_, **kwargs)
430+
431+
wrapped = wraps(func_)(wrapper)
432+
wrapped.__signature__ = signature # pyright: ignore [reportAttributeAccessIssue]
433+
434+
return wrapped
435+
436+
func_ = cast(
437+
'Callable[Concatenate[SelectorOutput, Args], ReturnType]',
438+
func,
439+
)
440+
393441
def wrapper(*args: Args.args, **kwargs: Args.kwargs) -> ReturnType:
394442
if self._state is None:
395443
msg = 'Store has not been initialized yet.'
396444
raise RuntimeError(msg)
397-
return func(selector(self._state), *args, **kwargs)
445+
return func_(selector(self._state), *args, **kwargs)
398446

399-
wrapped = wraps(func)(wrapper)
400-
wrapped.__signature__ = drop_with_store_parameter(func) # pyright: ignore [reportAttributeAccessIssue]
447+
wrapped = wraps(func_)(wrapper)
448+
wrapped.__signature__ = signature # pyright: ignore [reportAttributeAccessIssue]
401449

402450
return wrapped
403451

tests/test_async.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ async def sync_mirror(value: int) -> int:
101101
async def _(mirrored_value: int) -> None:
102102
if mirrored_value < INCREMENTS:
103103
return
104-
await asyncio.sleep(0.1)
104+
await asyncio.sleep(0.001)
105105
store.dispatch(FinishAction())
106106

107107
dispatch_actions(store)
@@ -124,7 +124,7 @@ async def _(values: tuple[int, int]) -> None:
124124
elif value < INCREMENTS:
125125
store.dispatch(IncrementAction())
126126
else:
127-
await asyncio.sleep(0.1)
127+
await asyncio.sleep(0.001)
128128
store.dispatch(FinishAction())
129129

130130

@@ -139,6 +139,7 @@ async def sync_mirror(value: int) -> int:
139139
@store.autorun(lambda state: (state.value, state.mirrored_value))
140140
async def _(values: tuple[int, int]) -> None:
141141
value, mirrored_value = values
142+
await asyncio.sleep(0.001)
142143
if mirrored_value != value:
143144
sync_mirror_returned_value = sync_mirror()
144145
assert 'awaited=False' in str(sync_mirror_returned_value)
@@ -148,7 +149,6 @@ async def _(values: tuple[int, int]) -> None:
148149
elif value < INCREMENTS:
149150
store.dispatch(IncrementAction())
150151
else:
151-
await asyncio.sleep(0.1)
152152
store.dispatch(FinishAction())
153153

154154

@@ -165,7 +165,7 @@ async def _(value: int) -> int:
165165
async def _(mirrored_value: int) -> None:
166166
if mirrored_value < INCREMENTS:
167167
return
168-
await asyncio.sleep(0.1)
168+
await asyncio.sleep(0.001)
169169
store.dispatch(FinishAction())
170170

171171
dispatch_actions(store)
@@ -184,7 +184,7 @@ async def _(value: int) -> None:
184184
assert await doubled() == value * 2
185185
for _ in range(10):
186186
await doubled()
187-
await asyncio.sleep(0.01)
187+
await asyncio.sleep(0.001)
188188
if value < INCREMENTS:
189189
store.dispatch(IncrementAction())
190190
else:
@@ -209,7 +209,7 @@ async def _(value: int) -> None:
209209
calls_length = len(calls)
210210
assert await doubled() == value * 2
211211
assert len(calls) == calls_length + 1
212-
await asyncio.sleep(0.01)
212+
await asyncio.sleep(0.001)
213213

214214
if value < INCREMENTS:
215215
store.dispatch(IncrementAction())
@@ -230,7 +230,7 @@ async def multiplied(value: int, factor: int) -> int:
230230
async def _(value: int) -> None:
231231
assert await multiplied(factor=2) == value * 2
232232
assert await multiplied(factor=3) == value * 3
233-
await asyncio.sleep(0.01)
233+
await asyncio.sleep(0.001)
234234
if value < INCREMENTS:
235235
store.dispatch(IncrementAction())
236236
else:
@@ -251,7 +251,7 @@ async def doubled(value: int) -> int:
251251
@store.autorun(lambda state: state.value)
252252
async def _(value: int) -> None:
253253
assert await doubled() == value * 2
254-
await asyncio.sleep(0.01)
254+
await asyncio.sleep(0.001)
255255
if value < INCREMENTS:
256256
store.dispatch(IncrementAction())
257257
else:
@@ -265,7 +265,7 @@ def test_subscription(store: StoreType) -> None:
265265
async def render(state: StateType) -> None:
266266
if state.value == INCREMENTS:
267267
unsubscribe()
268-
await asyncio.sleep(0.1)
268+
await asyncio.sleep(0.001)
269269
store.dispatch(FinishAction())
270270

271271
unsubscribe = store._subscribe(render) # noqa: SLF001
@@ -277,7 +277,7 @@ def test_event_subscription(store: StoreType) -> None:
277277
async def handler(event: IncrementEvent) -> None:
278278
if event.post_value == INCREMENTS:
279279
unsubscribe()
280-
await asyncio.sleep(0.1)
280+
await asyncio.sleep(0.001)
281281
store.dispatch(FinishAction())
282282

283283
unsubscribe = store.subscribe_event(IncrementEvent, handler)

tests/test_with_state.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,3 +198,24 @@ def check(self: X, value: int) -> None:
198198
store.dispatch(FinishAction())
199199

200200
check_spy.assert_called_once_with(0)
201+
202+
203+
def test_methods(store: StoreType) -> None:
204+
"""Test `with_state` decorator with methods."""
205+
206+
class SomeClass:
207+
@store.with_state(lambda state: state.value)
208+
def some_method(self, value: int) -> int:
209+
return value
210+
211+
instance = SomeClass()
212+
213+
store.dispatch(InitAction())
214+
215+
assert instance.some_method() == 0
216+
217+
store.dispatch(_IncrementAction())
218+
219+
assert instance.some_method() == 1
220+
221+
store.dispatch(FinishAction())

0 commit comments

Comments
 (0)