diff --git a/docs/user_guide.md b/docs/user_guide.md index ac3a5da..1f46e37 100644 --- a/docs/user_guide.md +++ b/docs/user_guide.md @@ -111,8 +111,9 @@ To translate a type from a docstring into a valid type annotation, docstub needs Out of the box, docstub will know about builtin types such as `int` or `bool` that don't need an import, and types in `typing`, `collections.abc` from Python's standard library. It will source these from the Python environment it is installed in. In addition to that, docstub will collect all types in the package directory you are running it on. +This also includes imported types, which you can then use within the scope of the module that imports them. -However, if you want to use types from third-party libraries you can tell docstub about them in a configuration file. +However, you can also tell docstub directly about external types in a configuration file. Docstub will look for a `pyproject.toml` or `docstub.toml` in the current working directory. Or, you can point docstub at TOML file(s) explicitly using the `--config` option. In these configuration file(s) you can declare external types directly with @@ -134,8 +135,9 @@ ski = "skimage" which will enable any type that is prefixed with `ski.` or `sklearn.tree.`, e.g. `ski.transform.AffineTransform` or `sklearn.tree.DecisionTreeClassifier`. -In both of these cases, docstub doesn't check that these types actually exist. -Testing the generated stubs with a type checker is recommended. +> [!IMPORTANT] +> Docstub doesn't check that types actually exist or if a symbol is a valid type. +> We always recommend validating the generated stubs with a full type checker! > [!TIP] > Docstub currently collects types statically. diff --git a/examples/example_pkg-stubs/__init__.pyi b/examples/example_pkg-stubs/__init__.pyi index 5a5bef1..c39e273 100644 --- a/examples/example_pkg-stubs/__init__.pyi +++ b/examples/example_pkg-stubs/__init__.pyi @@ -10,3 +10,6 @@ __all__ = [ class CustomException(Exception): pass + +class AnotherType: + pass diff --git a/examples/example_pkg-stubs/_basic.pyi b/examples/example_pkg-stubs/_basic.pyi index 7cdc627..ac83e75 100644 --- a/examples/example_pkg-stubs/_basic.pyi +++ b/examples/example_pkg-stubs/_basic.pyi @@ -3,11 +3,12 @@ import configparser import logging from collections.abc import Sequence +from configparser import ConfigParser as Cfg from typing import Any, Literal, Self, Union from _typeshed import Incomplete -from . import CustomException +from . import AnotherType, CustomException logger: Incomplete @@ -39,6 +40,7 @@ def func_use_from_elsewhere( a3: ExampleClass.NestedClass, a4: ExampleClass.NestedClass, ) -> tuple[CustomException, ExampleClass.NestedClass]: ... +def func_use_from_import(a1: AnotherType, a2: Cfg) -> None: ... class ExampleClass: diff --git a/examples/example_pkg/__init__.py b/examples/example_pkg/__init__.py index ac61e3d..f32d938 100644 --- a/examples/example_pkg/__init__.py +++ b/examples/example_pkg/__init__.py @@ -11,3 +11,7 @@ class CustomException(Exception): pass + + +class AnotherType: + pass diff --git a/examples/example_pkg/_basic.py b/examples/example_pkg/_basic.py index 4f12dd0..a25a49a 100644 --- a/examples/example_pkg/_basic.py +++ b/examples/example_pkg/_basic.py @@ -5,8 +5,11 @@ # Existing imports are preserved import logging +from configparser import ConfigParser as Cfg # noqa: F401 from typing import Literal +from . import AnotherType # noqa: F401 + # Assign-statements are preserved logger = logging.getLogger(__name__) # Inline comments are stripped @@ -88,6 +91,16 @@ def func_use_from_elsewhere(a1, a2, a3, a4): """ +def func_use_from_import(a1, a2): + """Check using symbols made available in this module with from imports. + + Parameters + ---------- + a1 : AnotherType + a2 : Cfg + """ + + class ExampleClass: """Dummy. diff --git a/src/docstub/_analysis.py b/src/docstub/_analysis.py index a5ff9f3..92bcff6 100644 --- a/src/docstub/_analysis.py +++ b/src/docstub/_analysis.py @@ -388,6 +388,18 @@ def visit_AnnAssign(self, node: cst.AnnAssign) -> bool: self._collect_type_annotation(stack) return False + def visit_ImportFrom(self, node: cst.ImportFrom) -> None: + """Collect "from import" targets as usable types within each module.""" + for import_alias in node.names: + if cstm.matches(import_alias, cstm.ImportStar()): + continue + name = import_alias.evaluated_alias + if name is None: + name = import_alias.evaluated_name + assert isinstance(name, str) + stack = [*self._stack, name] + self._collect_type_annotation(stack) + def _collect_type_annotation(self, stack): """Collect an importable type annotation. diff --git a/tests/test_analysis.py b/tests/test_analysis.py index 3188495..53a1c5f 100644 --- a/tests/test_analysis.py +++ b/tests/test_analysis.py @@ -49,7 +49,6 @@ def _module_factory(src, module_name): class Test_TypeCollector: - def test_classes(self, module_factory): module_path = module_factory( src=dedent( @@ -100,6 +99,41 @@ def test_ignores_assigns(self, module_factory, src): imports = TypeCollector.collect(file=module_path) assert len(imports) == 0 + @pytest.mark.parametrize( + "src", + [ + "from calendar import Aug", + "from . import Aug", + "from calendar import August as Aug", + "from . import Agust as Aug", + ], + ) + def test_from_import(self, module_factory, src): + module_path = module_factory(src=src, module_name="sub.module") + imports = TypeCollector.collect(file=module_path) + assert len(imports) == 1 + assert imports == { + "sub.module.Aug": KnownImport(import_path="sub.module", import_name="Aug") + } + + @pytest.mark.parametrize( + "src", + [ + "from calendar import Aug, Dec", + "from . import Aug, Dec", + "from calendar import August as Aug, December as Dec", + "from . import August as Aug, December as Dec", + ], + ) + def test_from_import_multiple(self, module_factory, src): + module_path = module_factory(src=src, module_name="sub.module") + imports = TypeCollector.collect(file=module_path) + assert len(imports) == 2 + assert imports == { + "sub.module.Aug": KnownImport(import_path="sub.module", import_name="Aug"), + "sub.module.Dec": KnownImport(import_path="sub.module", import_name="Dec"), + } + class Test_TypeMatcher: type_prefixes = { # noqa: RUF012