18
18
from typing import Optional
19
19
from typing import Set
20
20
from typing import Tuple
21
+ from typing import TYPE_CHECKING
22
+
23
+ import py .path
21
24
22
25
from _pytest ._io .saferepr import safeformat
23
26
from _pytest ._io .saferepr import saferepr
31
34
from _pytest .pathlib import Path
32
35
from _pytest .pathlib import PurePath
33
36
37
+ if TYPE_CHECKING :
38
+ from _pytest .assertion import AssertionState
39
+ from _pytest .main import Session
40
+
34
41
# pytest caches rewritten pycs in pycache dirs
35
42
PYTEST_TAG = "{}-pytest-{}" .format (sys .implementation .cache_tag , version )
36
43
PYC_EXT = ".py" + (__debug__ and "c" or "o" )
@@ -46,7 +53,7 @@ def __init__(self, config):
46
53
self .fnpats = config .getini ("python_files" )
47
54
except ValueError :
48
55
self .fnpats = ["test_*.py" , "*_test.py" ]
49
- self .session = None
56
+ self .session = None # type: Optional[Session]
50
57
self ._rewritten_names = set () # type: Set[str]
51
58
self ._must_rewrite = set () # type: Set[str]
52
59
# flag to guard against trying to rewrite a pyc file while we are already writing another pyc file,
@@ -56,7 +63,7 @@ def __init__(self, config):
56
63
self ._marked_for_rewrite_cache = {} # type: Dict[str, bool]
57
64
self ._session_paths_checked = False
58
65
59
- def set_session (self , session ) :
66
+ def set_session (self , session : "Optional[Session]" ) -> None :
60
67
self .session = session
61
68
self ._session_paths_checked = False
62
69
@@ -182,14 +189,14 @@ def _early_rewrite_bailout(self, name, state):
182
189
state .trace ("early skip of rewriting module: {}" .format (name ))
183
190
return True
184
191
185
- def _should_rewrite (self , name , fn , state ) :
192
+ def _should_rewrite (self , name : str , fn : str , state : "AssertionState" ) -> bool :
186
193
# always rewrite conftest files
187
194
if os .path .basename (fn ) == "conftest.py" :
188
195
state .trace ("rewriting conftest file: {!r}" .format (fn ))
189
196
return True
190
197
191
198
if self .session is not None :
192
- if self .session .isinitpath (fn ):
199
+ if self .session .isinitpath (py . path . local ( fn ) ):
193
200
state .trace (
194
201
"matched test file (was specified on cmdline): {!r}" .format (fn )
195
202
)
@@ -205,7 +212,7 @@ def _should_rewrite(self, name, fn, state):
205
212
206
213
return self ._is_marked_for_rewrite (name , state )
207
214
208
- def _is_marked_for_rewrite (self , name : str , state ) :
215
+ def _is_marked_for_rewrite (self , name : str , state : "AssertionState" ) -> bool :
209
216
try :
210
217
return self ._marked_for_rewrite_cache [name ]
211
218
except KeyError :
0 commit comments