Skip to content

Commit 3cf33ab

Browse files
authored
refactor: @scoped
1 parent c5552c8 commit 3cf33ab

File tree

3 files changed

+53
-37
lines changed

3 files changed

+53
-37
lines changed

injection/_core/common/type.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -28,39 +28,39 @@
2828
)
2929

3030

31-
def get_return_types(*args: TypeInfo[Any]) -> Iterator[InputType[Any]]:
32-
for arg in args:
33-
if isinstance(arg, Collection) and not isclass(arg):
34-
inner_args = arg
35-
36-
elif isfunction(arg) and (return_type := get_return_hint(arg)):
37-
inner_args = (return_type,)
38-
39-
else:
40-
yield arg # type: ignore[misc]
41-
continue
42-
43-
yield from get_return_types(*inner_args)
44-
45-
4631
def get_return_hint[T](function: Callable[..., T]) -> InputType[T] | None:
4732
return get_type_hints(function).get("return")
4833

4934

50-
def get_yield_hint[T](
35+
def get_yield_hints[T](
5136
function: Callable[..., Iterator[T]] | Callable[..., AsyncIterator[T]],
5237
) -> tuple[InputType[T]] | tuple[()]:
5338
return_type = get_return_hint(function)
5439

55-
if get_origin(return_type) in {
40+
if get_origin(return_type) in (
5641
AsyncGenerator,
5742
AsyncIterable,
5843
AsyncIterator,
5944
Generator,
6045
Iterable,
6146
Iterator,
62-
}:
47+
):
6348
for arg in get_args(return_type):
6449
return (arg,)
6550

6651
return ()
52+
53+
54+
def iter_return_types(*args: TypeInfo[Any]) -> Iterator[InputType[Any]]:
55+
for arg in args:
56+
if isinstance(arg, Collection) and not isclass(arg):
57+
inner_args = arg
58+
59+
elif isfunction(arg) and (return_type := get_return_hint(arg)):
60+
inner_args = (return_type,)
61+
62+
else:
63+
yield arg # type: ignore[misc]
64+
continue
65+
66+
yield from iter_return_types(*inner_args)

injection/_core/injectables.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
ContextManager,
1111
NoReturn,
1212
Protocol,
13+
Self,
1314
runtime_checkable,
1415
)
1516

@@ -166,6 +167,10 @@ def unlock(self) -> None:
166167
def __get_scope(self) -> Scope:
167168
return get_scope(self.scope_name)
168169

170+
@classmethod
171+
def bind_scope_name(cls, name: str) -> Callable[[Caller[..., R]], Self]:
172+
return partial(cls, scope_name=name)
173+
169174

170175
class AsyncCMScopedInjectable[T](ScopedInjectable[AsyncContextManager[T], T]):
171176
__slots__ = ()

injection/_core/module.py

Lines changed: 30 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from contextlib import asynccontextmanager, contextmanager, suppress
1818
from dataclasses import dataclass, field
1919
from enum import StrEnum
20-
from functools import partial, partialmethod, singledispatchmethod, update_wrapper
20+
from functools import partialmethod, singledispatchmethod, update_wrapper
2121
from inspect import (
2222
BoundArguments,
2323
Signature,
@@ -57,8 +57,8 @@
5757
from injection._core.common.type import (
5858
InputType,
5959
TypeInfo,
60-
get_return_types,
61-
get_yield_hint,
60+
get_yield_hints,
61+
iter_return_types,
6262
)
6363
from injection._core.injectables import (
6464
AsyncCMScopedInjectable,
@@ -169,14 +169,21 @@ def get_default(cls) -> Priority:
169169

170170
type PriorityStr = Literal["low", "high"]
171171

172-
type ContextManagerLikeRecipe[**P, T] = (
172+
type ContextManagerRecipe[**P, T] = (
173173
Callable[P, ContextManager[T]] | Callable[P, AsyncContextManager[T]]
174174
)
175175
type GeneratorRecipe[**P, T] = (
176176
Callable[P, Generator[T, Any, Any]] | Callable[P, AsyncGenerator[T, Any]]
177177
)
178178

179179

180+
@dataclass(repr=False, eq=False, frozen=True, slots=True)
181+
class _ScopedContext[**P, T]:
182+
cls: type[ScopedInjectable[Any, T]]
183+
hints: TypeInfo[T]
184+
wrapper: Recipe[P, T] | ContextManagerRecipe[P, T]
185+
186+
180187
@dataclass(eq=False, frozen=True, slots=True)
181188
class Module(EventListener, InjectionProvider): # type: ignore[misc]
182189
name: str = field(default_factory=lambda: f"anonymous@{new_short_key()}")
@@ -266,29 +273,33 @@ def scoped[**P, T](
266273
def decorator(
267274
wrapped: Recipe[P, T] | GeneratorRecipe[P, T],
268275
) -> Recipe[P, T] | GeneratorRecipe[P, T]:
269-
injectable_class: type[ScopedInjectable[Any, T]]
270-
wrapper: Recipe[P, T] | ContextManagerLikeRecipe[P, T]
271-
272276
if isasyncgenfunction(wrapped):
273-
hint = get_yield_hint(wrapped)
274-
injectable_class = AsyncCMScopedInjectable
275-
wrapper = asynccontextmanager(wrapped)
277+
ctx = _ScopedContext(
278+
cls=AsyncCMScopedInjectable,
279+
hints=get_yield_hints(wrapped),
280+
wrapper=asynccontextmanager(wrapped),
281+
)
276282

277283
elif isgeneratorfunction(wrapped):
278-
hint = get_yield_hint(wrapped)
279-
injectable_class = CMScopedInjectable
280-
wrapper = contextmanager(wrapped)
284+
ctx = _ScopedContext(
285+
cls=CMScopedInjectable,
286+
hints=get_yield_hints(wrapped),
287+
wrapper=contextmanager(wrapped),
288+
)
281289

282290
else:
283-
injectable_class = SimpleScopedInjectable
284-
hint = wrapper = wrapped # type: ignore[assignment]
291+
ctx = _ScopedContext(
292+
cls=SimpleScopedInjectable,
293+
hints=(wrapped,),
294+
wrapper=wrapped,
295+
)
285296

286297
self.injectable(
287-
wrapper,
288-
cls=partial(injectable_class, scope_name=scope_name),
298+
ctx.wrapper,
299+
cls=ctx.cls.bind_scope_name(scope_name),
289300
ignore_type_hint=True,
290301
inject=inject,
291-
on=(hint, on),
302+
on=(ctx.hints, on),
292303
mode=mode,
293304
)
294305
return wrapped
@@ -715,7 +726,7 @@ def __build_key_types(input_cls: Any) -> frozenset[Any]:
715726
config = MatchingTypesConfig(ignore_none=True)
716727
return frozenset(
717728
itertools.chain.from_iterable(
718-
iter_matching_types(cls, config) for cls in get_return_types(input_cls)
729+
iter_matching_types(cls, config) for cls in iter_return_types(input_cls)
719730
)
720731
)
721732

0 commit comments

Comments
 (0)