Skip to content

Commit 26e64dd

Browse files
committed
fix _should_rewrite: pass py.path.local
1 parent 682fc18 commit 26e64dd

File tree

3 files changed

+22
-10
lines changed

3 files changed

+22
-10
lines changed

src/_pytest/assertion/rewrite.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@
1818
from typing import Optional
1919
from typing import Set
2020
from typing import Tuple
21+
from typing import TYPE_CHECKING
22+
23+
import py.path
2124

2225
from _pytest._io.saferepr import safeformat
2326
from _pytest._io.saferepr import saferepr
@@ -31,6 +34,10 @@
3134
from _pytest.pathlib import Path
3235
from _pytest.pathlib import PurePath
3336

37+
if TYPE_CHECKING:
38+
from _pytest.assertion import AssertionState
39+
from _pytest.main import Session
40+
3441
# pytest caches rewritten pycs in pycache dirs
3542
PYTEST_TAG = "{}-pytest-{}".format(sys.implementation.cache_tag, version)
3643
PYC_EXT = ".py" + (__debug__ and "c" or "o")
@@ -46,7 +53,7 @@ def __init__(self, config):
4653
self.fnpats = config.getini("python_files")
4754
except ValueError:
4855
self.fnpats = ["test_*.py", "*_test.py"]
49-
self.session = None
56+
self.session = None # type: Optional[Session]
5057
self._rewritten_names = set() # type: Set[str]
5158
self._must_rewrite = set() # type: Set[str]
5259
# flag to guard against trying to rewrite a pyc file while we are already writing another pyc file,
@@ -56,7 +63,7 @@ def __init__(self, config):
5663
self._marked_for_rewrite_cache = {} # type: Dict[str, bool]
5764
self._session_paths_checked = False
5865

59-
def set_session(self, session):
66+
def set_session(self, session: "Optional[Session]") -> None:
6067
self.session = session
6168
self._session_paths_checked = False
6269

@@ -182,14 +189,14 @@ def _early_rewrite_bailout(self, name, state):
182189
state.trace("early skip of rewriting module: {}".format(name))
183190
return True
184191

185-
def _should_rewrite(self, name, fn, state):
192+
def _should_rewrite(self, name: str, fn: str, state: "AssertionState") -> bool:
186193
# always rewrite conftest files
187194
if os.path.basename(fn) == "conftest.py":
188195
state.trace("rewriting conftest file: {!r}".format(fn))
189196
return True
190197

191198
if self.session is not None:
192-
if self.session.isinitpath(fn):
199+
if self.session.isinitpath(py.path.local(fn)):
193200
state.trace(
194201
"matched test file (was specified on cmdline): {!r}".format(fn)
195202
)
@@ -205,7 +212,7 @@ def _should_rewrite(self, name, fn, state):
205212

206213
return self._is_marked_for_rewrite(name, state)
207214

208-
def _is_marked_for_rewrite(self, name: str, state):
215+
def _is_marked_for_rewrite(self, name: str, state: "AssertionState") -> bool:
209216
try:
210217
return self._marked_for_rewrite_cache[name]
211218
except KeyError:

src/_pytest/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -458,7 +458,7 @@ def pytest_runtest_logreport(self, report):
458458

459459
pytest_collectreport = pytest_runtest_logreport
460460

461-
def isinitpath(self, path):
461+
def isinitpath(self, path: "py.path.local") -> bool:
462462
return path in self._initialpaths
463463

464464
def gethookproxy(self, fspath: py.path.local):

testing/test_assertrewrite.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,15 @@
2121
from _pytest.assertion.rewrite import PYC_TAIL
2222
from _pytest.assertion.rewrite import PYTEST_TAG
2323
from _pytest.assertion.rewrite import rewrite_asserts
24+
from _pytest.compat import TYPE_CHECKING
2425
from _pytest.config import ExitCode
2526
from _pytest.pathlib import Path
2627
from _pytest.pytester import Testdir
2728

29+
if TYPE_CHECKING:
30+
from typing import List
31+
from typing import Set
32+
2833

2934
def setup_module(mod):
3035
mod._old_reprcompare = util._reprcompare
@@ -1250,14 +1255,14 @@ def spy_write_pyc(*args, **kwargs):
12501255

12511256
class TestEarlyRewriteBailout:
12521257
@pytest.fixture
1253-
def hook(self, pytestconfig, monkeypatch, testdir):
1258+
def hook(self, pytestconfig, monkeypatch, testdir) -> AssertionRewritingHook:
12541259
"""Returns a patched AssertionRewritingHook instance so we can configure its initial paths and track
12551260
if PathFinder.find_spec has been called.
12561261
"""
12571262
import importlib.machinery
12581263

1259-
self.find_spec_calls = []
1260-
self.initial_paths = set()
1264+
self.find_spec_calls = [] # type: List[str]
1265+
self.initial_paths = set() # type: Set[py.path.local]
12611266

12621267
class StubSession:
12631268
_initialpaths = self.initial_paths
@@ -1281,7 +1286,7 @@ def spy_find_spec(name, path):
12811286
testdir.syspathinsert()
12821287
return hook
12831288

1284-
def test_basic(self, testdir, hook):
1289+
def test_basic(self, testdir: "Testdir", hook: AssertionRewritingHook) -> None:
12851290
"""
12861291
Ensure we avoid calling PathFinder.find_spec when we know for sure a certain
12871292
module will not be rewritten to optimize assertion rewriting (#3918).

0 commit comments

Comments
 (0)