Skip to content

Commit 5875713

Browse files
committed
Collect "from imports" as referencable names for docstrings
This makes it possible to use symbols made available via "from imports" within the same module. This lessens the need to declare types manually.
1 parent 2bd7076 commit 5875713

File tree

6 files changed

+70
-2
lines changed

6 files changed

+70
-2
lines changed

examples/example_pkg-stubs/__init__.pyi

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,6 @@ __all__ = [
1010

1111
class CustomException(Exception):
1212
pass
13+
14+
class AnotherType:
15+
pass

examples/example_pkg-stubs/_basic.pyi

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@
33
import configparser
44
import logging
55
from collections.abc import Sequence
6+
from configparser import ConfigParser as Cfg
67
from typing import Any, Literal, Self, Union
78

89
from _typeshed import Incomplete
910

10-
from . import CustomException
11+
from . import AnotherType, CustomException
1112

1213
logger: Incomplete
1314

@@ -39,6 +40,7 @@ def func_use_from_elsewhere(
3940
a3: ExampleClass.NestedClass,
4041
a4: ExampleClass.NestedClass,
4142
) -> tuple[CustomException, ExampleClass.NestedClass]: ...
43+
def func_use_from_import(a1: AnotherType, a2: Cfg) -> None: ...
4244

4345
class ExampleClass:
4446

examples/example_pkg/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,7 @@
1111

1212
class CustomException(Exception):
1313
pass
14+
15+
16+
class AnotherType:
17+
pass

examples/example_pkg/_basic.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,11 @@
55

66
# Existing imports are preserved
77
import logging
8+
from configparser import ConfigParser as Cfg # noqa: F401
89
from typing import Literal
910

11+
from . import AnotherType # noqa: F401
12+
1013
# Assign-statements are preserved
1114
logger = logging.getLogger(__name__) # Inline comments are stripped
1215

@@ -88,6 +91,16 @@ def func_use_from_elsewhere(a1, a2, a3, a4):
8891
"""
8992

9093

94+
def func_use_from_import(a1, a2):
95+
"""Check using symbols made available in this module with from imports.
96+
97+
Parameters
98+
----------
99+
a1 : AnotherType
100+
a2 : Cfg
101+
"""
102+
103+
91104
class ExampleClass:
92105
"""Dummy.
93106

src/docstub/_analysis.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,18 @@ def visit_AnnAssign(self, node: cst.AnnAssign) -> bool:
388388
self._collect_type_annotation(stack)
389389
return False
390390

391+
def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
392+
"""Collect "from import" targets as usable types within each module."""
393+
for import_alias in node.names:
394+
if cstm.matches(import_alias, cstm.ImportStar()):
395+
continue
396+
name = import_alias.evaluated_alias
397+
if name is None:
398+
name = import_alias.evaluated_name
399+
assert isinstance(name, str)
400+
stack = [*self._stack, name]
401+
self._collect_type_annotation(stack)
402+
391403
def _collect_type_annotation(self, stack):
392404
"""Collect an importable type annotation.
393405

tests/test_analysis.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@ def _module_factory(src, module_name):
4949

5050

5151
class Test_TypeCollector:
52-
5352
def test_classes(self, module_factory):
5453
module_path = module_factory(
5554
src=dedent(
@@ -100,6 +99,41 @@ def test_ignores_assigns(self, module_factory, src):
10099
imports = TypeCollector.collect(file=module_path)
101100
assert len(imports) == 0
102101

102+
@pytest.mark.parametrize(
103+
"src",
104+
[
105+
"from calendar import Aug",
106+
"from . import Aug",
107+
"from calendar import August as Aug",
108+
"from . import Agust as Aug",
109+
],
110+
)
111+
def test_from_import(self, module_factory, src):
112+
module_path = module_factory(src=src, module_name="sub.module")
113+
imports = TypeCollector.collect(file=module_path)
114+
assert len(imports) == 1
115+
assert imports == {
116+
"sub.module.Aug": KnownImport(import_path="sub.module", import_name="Aug")
117+
}
118+
119+
@pytest.mark.parametrize(
120+
"src",
121+
[
122+
"from calendar import Aug, Dec",
123+
"from . import Aug, Dec",
124+
"from calendar import August as Aug, December as Dec",
125+
"from . import August as Aug, December as Dec",
126+
],
127+
)
128+
def test_from_import_multiple(self, module_factory, src):
129+
module_path = module_factory(src=src, module_name="sub.module")
130+
imports = TypeCollector.collect(file=module_path)
131+
assert len(imports) == 2
132+
assert imports == {
133+
"sub.module.Aug": KnownImport(import_path="sub.module", import_name="Aug"),
134+
"sub.module.Dec": KnownImport(import_path="sub.module", import_name="Dec"),
135+
}
136+
103137

104138
class Test_TypeMatcher:
105139
type_prefixes = { # noqa: RUF012

0 commit comments

Comments
 (0)