6
6
import inspect
7
7
import pkgutil
8
8
import sys
9
+ from contextlib import suppress
10
+ from inspect import isbuiltin , isclass
9
11
from types import ModuleType
10
12
from typing import (
11
13
TYPE_CHECKING ,
15
17
Dict ,
16
18
Iterable ,
17
19
Iterator ,
20
+ List ,
18
21
Optional ,
19
22
Protocol ,
20
23
Set ,
24
27
Union ,
25
28
cast ,
26
29
)
30
+ from warnings import warn
27
31
28
32
try :
29
33
from typing import Self
@@ -59,13 +63,11 @@ def get_origin(tp):
59
63
return None
60
64
61
65
62
- MARKER_EXTRACTORS = []
66
+ MARKER_EXTRACTORS : List [Callable [[Any ], Any ]] = []
67
+ INSPECT_EXCLUSION_FILTERS : List [Callable [[Any ], bool ]] = [isbuiltin ]
63
68
64
- try :
69
+ with suppress ( ImportError ) :
65
70
from fastapi .params import Depends as FastAPIDepends
66
- except ImportError :
67
- pass
68
- else :
69
71
70
72
def extract_marker_from_fastapi (param : Any ) -> Any :
71
73
if isinstance (param , FastAPIDepends ):
@@ -74,11 +76,8 @@ def extract_marker_from_fastapi(param: Any) -> Any:
74
76
75
77
MARKER_EXTRACTORS .append (extract_marker_from_fastapi )
76
78
77
- try :
79
+ with suppress ( ImportError ) :
78
80
from fast_depends .dependencies import Depends as FastDepends
79
- except ImportError :
80
- pass
81
- else :
82
81
83
82
def extract_marker_from_fast_depends (param : Any ) -> Any :
84
83
if isinstance (param , FastDepends ):
@@ -88,16 +87,22 @@ def extract_marker_from_fast_depends(param: Any) -> Any:
88
87
MARKER_EXTRACTORS .append (extract_marker_from_fast_depends )
89
88
90
89
91
- try :
92
- import starlette .requests
93
- except ImportError :
94
- starlette = None
90
+ with suppress (ImportError ):
91
+ from starlette .requests import Request as StarletteRequest
95
92
93
+ def is_starlette_request_cls (obj : Any ) -> bool :
94
+ return isclass (obj ) and _safe_is_subclass (obj , StarletteRequest )
96
95
97
- try :
98
- import werkzeug .local
99
- except ImportError :
100
- werkzeug = None
96
+ INSPECT_EXCLUSION_FILTERS .append (is_starlette_request_cls )
97
+
98
+
99
+ with suppress (ImportError ):
100
+ from werkzeug .local import LocalProxy as WerkzeugLocalProxy
101
+
102
+ def is_werkzeug_local_proxy (obj : Any ) -> bool :
103
+ return isinstance (obj , WerkzeugLocalProxy )
104
+
105
+ INSPECT_EXCLUSION_FILTERS .append (is_werkzeug_local_proxy )
101
106
102
107
from . import providers # noqa: E402
103
108
@@ -130,6 +135,10 @@ def extract_marker_from_fast_depends(param: Any) -> Any:
130
135
Container = Any
131
136
132
137
138
+ class DIWiringWarning (RuntimeWarning ):
139
+ """Base class for all warnings raised by the wiring module."""
140
+
141
+
133
142
class PatchedRegistry :
134
143
135
144
def __init__ (self ) -> None :
@@ -411,30 +420,11 @@ def _create_providers_map(
411
420
return providers_map
412
421
413
422
414
- class InspectFilter :
415
-
416
- def is_excluded (self , instance : object ) -> bool :
417
- if self ._is_werkzeug_local_proxy (instance ):
418
- return True
419
- elif self ._is_starlette_request_cls (instance ):
423
+ def is_excluded_from_inspect (obj : Any ) -> bool :
424
+ for is_excluded in INSPECT_EXCLUSION_FILTERS :
425
+ if is_excluded (obj ):
420
426
return True
421
- elif self ._is_builtin (instance ):
422
- return True
423
- else :
424
- return False
425
-
426
- def _is_werkzeug_local_proxy (self , instance : object ) -> bool :
427
- return werkzeug and isinstance (instance , werkzeug .local .LocalProxy )
428
-
429
- def _is_starlette_request_cls (self , instance : object ) -> bool :
430
- return (
431
- starlette
432
- and isinstance (instance , type )
433
- and _safe_is_subclass (instance , starlette .requests .Request )
434
- )
435
-
436
- def _is_builtin (self , instance : object ) -> bool :
437
- return inspect .isbuiltin (instance )
427
+ return False
438
428
439
429
440
430
def wire ( # noqa: C901
@@ -455,7 +445,7 @@ def wire( # noqa: C901
455
445
456
446
for module in modules :
457
447
for member_name , member in _get_members_and_annotated (module ):
458
- if _inspect_filter . is_excluded (member ):
448
+ if is_excluded_from_inspect (member ):
459
449
continue
460
450
461
451
if _is_marker (member ):
@@ -520,6 +510,11 @@ def unwire( # noqa: C901
520
510
def inject (fn : F ) -> F :
521
511
"""Decorate callable with injecting decorator."""
522
512
reference_injections , reference_closing = _fetch_reference_injections (fn )
513
+
514
+ if not reference_injections :
515
+ warn ("@inject is not required here" , DIWiringWarning , stacklevel = 2 )
516
+ return fn
517
+
523
518
patched = _get_patched (fn , reference_injections , reference_closing )
524
519
return cast (F , patched )
525
520
@@ -1054,7 +1049,6 @@ def is_loader_installed() -> bool:
1054
1049
1055
1050
1056
1051
_patched_registry = PatchedRegistry ()
1057
- _inspect_filter = InspectFilter ()
1058
1052
_loader = AutoLoader ()
1059
1053
1060
1054
# Optimizations
0 commit comments