Skip to content

Commit a45912e

Browse files
committed
Collect "from imports" as referencable names for docstrings
This makes it possible to use symbols made available via "from imports" within the same module. This lessens the need to declare types manually.
1 parent 2bd7076 commit a45912e

File tree

6 files changed

+69
-1
lines changed

6 files changed

+69
-1
lines changed

examples/example_pkg-stubs/__init__.pyi

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,6 @@ __all__ = [
1010

1111
class CustomException(Exception):
1212
pass
13+
14+
class AnotherType:
15+
pass

examples/example_pkg-stubs/_basic.pyi

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@ from . import CustomException
1111

1212
logger: Incomplete
1313

14+
from configparser import ConfigParser
15+
16+
from . import AnotherType
17+
1418
__all__ = [
1519
"func_empty",
1620
"ExampleClass",
@@ -39,6 +43,7 @@ def func_use_from_elsewhere(
3943
a3: ExampleClass.NestedClass,
4044
a4: ExampleClass.NestedClass,
4145
) -> tuple[CustomException, ExampleClass.NestedClass]: ...
46+
def func_use_from_import(a1: AnotherType, a2: ConfigParser) -> None: ...
4247

4348
class ExampleClass:
4449

examples/example_pkg/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,7 @@
1111

1212
class CustomException(Exception):
1313
pass
14+
15+
16+
class AnotherType:
17+
pass

examples/example_pkg/_basic.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,16 @@ def func_use_from_elsewhere(a1, a2, a3, a4):
8888
"""
8989

9090

91+
def func_use_from_import(a1, a2):
92+
"""Check using symbols made available in this module with from imports.
93+
94+
Parameters
95+
----------
96+
a1 : AnotherType
97+
a2 : Cfg
98+
"""
99+
100+
91101
class ExampleClass:
92102
"""Dummy.
93103

src/docstub/_analysis.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,18 @@ def visit_AnnAssign(self, node: cst.AnnAssign) -> bool:
388388
self._collect_type_annotation(stack)
389389
return False
390390

391+
def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
392+
"""Collect "from import" targets as usable types within each module."""
393+
for import_alias in node.names:
394+
if cstm.matches(import_alias, cstm.ImportStar()):
395+
continue
396+
name = import_alias.evaluated_alias
397+
if name is None:
398+
name = import_alias.evaluated_name
399+
assert isinstance(name, str)
400+
stack = [*self._stack, name]
401+
self._collect_type_annotation(stack)
402+
391403
def _collect_type_annotation(self, stack):
392404
"""Collect an importable type annotation.
393405

tests/test_analysis.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@ def _module_factory(src, module_name):
4949

5050

5151
class Test_TypeCollector:
52-
5352
def test_classes(self, module_factory):
5453
module_path = module_factory(
5554
src=dedent(
@@ -100,6 +99,41 @@ def test_ignores_assigns(self, module_factory, src):
10099
imports = TypeCollector.collect(file=module_path)
101100
assert len(imports) == 0
102101

102+
@pytest.mark.parametrize(
103+
"src",
104+
[
105+
"from calendar import Aug",
106+
"from . import Aug",
107+
"from calendar import August as Aug",
108+
"from . import Agust as Aug",
109+
],
110+
)
111+
def test_from_import(self, module_factory, src):
112+
module_path = module_factory(src=src, module_name="sub.module")
113+
imports = TypeCollector.collect(file=module_path)
114+
assert len(imports) == 1
115+
assert imports == {
116+
"sub.module.Aug": KnownImport(import_path="sub.module", import_name="Aug")
117+
}
118+
119+
@pytest.mark.parametrize(
120+
"src",
121+
[
122+
"from calendar import Aug, Dec",
123+
"from . import Aug, Dec",
124+
"from calendar import August as Aug, December as Dec",
125+
"from . import August as Aug, December as Dec",
126+
],
127+
)
128+
def test_from_import_multiple(self, module_factory, src):
129+
module_path = module_factory(src=src, module_name="sub.module")
130+
imports = TypeCollector.collect(file=module_path)
131+
assert len(imports) == 2
132+
assert imports == {
133+
"sub.module.Aug": KnownImport(import_path="sub.module", import_name="Aug"),
134+
"sub.module.Dec": KnownImport(import_path="sub.module", import_name="Dec"),
135+
}
136+
103137

104138
class Test_TypeMatcher:
105139
type_prefixes = { # noqa: RUF012

0 commit comments

Comments
 (0)