diff --git a/src/docstub/_analysis.py b/src/docstub/_analysis.py index a5ff9f3..f8cbd5b 100644 --- a/src/docstub/_analysis.py +++ b/src/docstub/_analysis.py @@ -5,14 +5,17 @@ import json import logging import re -from dataclasses import asdict, dataclass +import dataclasses as dc +from itertools import pairwise from functools import cache from pathlib import Path +from typing import Self, ClassVar import libcst as cst import libcst.matchers as cstm from ._utils import accumulate_qualname, module_name_from_path, pyfile_checksum +from . import __version__ logger = logging.getLogger(__name__) @@ -49,7 +52,7 @@ def _shared_leading_qualname(*qualnames): return ".".join(shared) -@dataclass(slots=True, frozen=True) +@dc.dataclass(slots=True, frozen=True) class KnownImport: """Import information associated with a single known type annotation. @@ -208,6 +211,96 @@ def __str__(self) -> str: return out +@dc.dataclass(slots=True, kw_only=True) +class PyNode: + _TYPE_KINDS: ClassVar[set[str]] = { + "builtin", + "class", + "type_alias", + "ann_assign", + "import_from", + "generic_type", + } + _KINDS: ClassVar[set[str]] = {"module"} | _TYPE_KINDS + + name: str + kind: str + loc: str | None = None + parent: Self | None = None + children: list[Self] = dc.field(default_factory=list) + + @property + def is_leaf(self): + return not self.children + + @property + def fullname(self): + names = [node.name for node in self.walk_parents()][::-1] + return ".".join(names + [self.name]) + + @property + def is_type(self): + return self.kind in self._TYPE_KINDS + + @property + def import_statement(self): + module = [] + qualname = [self.name] + for parent in self.walk_parents(): + if parent.kind == "module": + module.insert(0, parent.name) + else: + qualname.insert(0, parent.name) + + if module: + return f"from {'.'.join(module)} import {'.'.join(qualname)}" + else: + return None + + def add_child(self, child): + assert child.parent is None + child.parent = self + self.children.append(child) + + def _walk_tree(self, names=()): + names = names + (self.name,) + yield names, self + for child in self.children: + yield from child._walk_tree(names) + + def walk_tree(self): + yield from self._walk_tree() + + def walk_parents(self): + current = self.parent + while current is not None: + yield current + current = current.parent + + def serialize_tree(self): + raw = {field.name: getattr(self, field.name) for field in dc.fields(self)} + del raw["parent"] + raw["children"] = [child.serialize_tree() for child in self.children] + return raw + + @classmethod + def from_serialized_tree(cls, primitives): + self = cls(**primitives) + if self.parent: + self.parent = cls.from_serialized_tree(self.parent) + self.children = [cls.from_serialized_tree(child) for child in self.children] + return self + + def __repr__(self): + return f"{type(self).__name__}({self.name!r}, kind={self.kind!r})" + + def __post_init__(self): + unsupported_kind = {self.kind} - self._KINDS + if unsupported_kind: + msg = f"unsupported kind {unsupported_kind}, supported are {self._KINDS}" + raise ValueError(msg) + + def _is_type(value): """Check if value is a type. @@ -227,28 +320,33 @@ def _is_type(value): def _builtin_types(): - """Return known imports for all builtins (in the current runtime). + """Builtin types in the current runtime. Returns ------- - known_imports : dict[str, KnownImport] + types : dict[str, PyNode] """ - known_builtins = set(dir(builtins)) + builtins_names = set(dir(builtins)) - known_imports = {} - for name in known_builtins: + types = {} + for name in builtins_names: if name.startswith("_"): continue value = getattr(builtins, name) if not _is_type(value): continue - known_imports[name] = KnownImport(builtin_name=name) + types[name] = PyNode(name=name, kind="builtin") - return known_imports + return types def _runtime_types_in_module(module_name): module = importlib.import_module(module_name) + + modules = [PyNode(name=name, kind="module") for name in module_name.split(".")] + for parent, child in pairwise(modules): + parent.add_child(child) + types = {} for name in module.__all__: if name.startswith("_"): @@ -257,14 +355,14 @@ def _runtime_types_in_module(module_name): if not _is_type(value): continue - import_ = KnownImport(import_path=module_name, import_name=name) - types[name] = import_ - types[f"{module_name}.{name}"] = import_ + pynode = PyNode(name=name, kind="generic_type") + modules[-1].add_child(pynode) + types[pynode.fullname] = pynode return types -def common_known_types(): +def common_types_nicknames(): """Return known imports for commonly supported types. This includes builtin types, and types from the `typing` or @@ -272,23 +370,165 @@ def common_known_types(): Returns ------- - known_imports : dict[str, KnownImport] + types : list[PyNode] + type_nicknames : dict[str, str] Examples -------- >>> types = common_known_types() >>> types["str"] - + PyNode('str', kind='builtin') >>> types["Iterable"] - + PyNode('Iterable', kind='generic_type') + >>> types["Iterable"].fullname + 'collections.abc.Iterable' >>> types["collections.abc.Iterable"] - + PyNode('Iterable', kind='generic_type') + """ + pynodes = _builtin_types() + pynodes |= _runtime_types_in_module("typing") + collections_abc = _runtime_types_in_module("collections.abc") + pynodes |= collections_abc + + type_nicknames = {node.name: fullname for fullname, node in collections_abc.items()} + + return pynodes, type_nicknames + + +class PythonCollector(cst.CSTVisitor): + """Collect types from a given Python file. + + Examples + -------- + >>> types = PythonCollector.collect(__file__) + >>> types[f"{__name__}.TypeCollector"] + """ - known_imports = _builtin_types() - known_imports |= _runtime_types_in_module("typing") - # Overrides containers from typing - known_imports |= _runtime_types_in_module("collections.abc") - return known_imports + + METADATA_DEPENDENCIES = (cst.metadata.PositionProvider,) + + class ImportSerializer: + """Implements the `FuncSerializer` protocol to cache `TypeCollector.collect`.""" + + suffix = ".json" + encoding = "utf-8" + + def hash_args(self, path: Path) -> str: + """Compute a unique hash from the path passed to `TypeCollector.collect`.""" + key = pyfile_checksum(path, salt=__version__) + return key + + def serialize(self, pynode: PyNode) -> bytes: + """Serialize results from `TypeCollector.collect`.""" + primitives = pynode.serialize_tree() + raw = json.dumps(primitives, separators=(",", ":")).encode(self.encoding) + return raw + + def deserialize(self, raw: bytes) -> PyNode: + """Deserialize results from `TypeCollector.collect`.""" + primitives = json.loads(raw.decode(self.encoding)) + pynode = PyNode.from_serialized_tree(primitives) + return pynode + + @classmethod + def collect(cls, file_path): + """Collect importable type annotations in given file. + + Parameters + ---------- + file_path : Path + + Returns + ------- + module_tree : PyNode + """ + file_path = Path(file_path) + with file_path.open("r") as fo: + source = fo.read() + + tree = cst.parse_module(source) + meta_tree = cst.metadata.MetadataWrapper(tree) + collector = cls(file_path=file_path) + meta_tree.visit(collector) + + return collector._root_pynode + + def __init__(self, *, file_path): + """Initialize type collector. + + Parameters + ---------- + module_name : str + """ + full_module_name = module_name_from_path(file_path) + current_module, *parent_modules = full_module_name.split(".")[::-1] + + self._file_path = file_path + self._root_pynode = PyNode( + name=current_module, kind="module", loc=str(file_path) + ) + self._current_pynode = self._root_pynode + + for name in parent_modules: + # TODO set location for parent modules too + parent = PyNode(name=name, kind="module") + parent.add_child(self._root_pynode) + self._root_pynode = parent + + def _get_loc(self, node): + pos = self.get_metadata(cst.metadata.PositionProvider, node).start + loc = f"{self._file_path}:{pos.line}:{pos.column}" + return loc + + def visit_ClassDef(self, node: cst.ClassDef) -> bool: + pynode = PyNode(name=node.name.value, kind="class", loc=self._get_loc(node)) + self._current_pynode.add_child(pynode) + self._current_pynode = pynode + return True + + def leave_ClassDef(self, original_node: cst.ClassDef) -> None: + self._current_pynode = self._current_pynode.parent + + def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: + return False + + def visit_TypeAlias(self, node: cst.TypeAlias) -> bool: + """Collect type alias with 3.12 syntax.""" + pynode = PyNode( + name=node.name.value, kind="type_alias", loc=self._get_loc(node) + ) + self._current_pynode.add_child(pynode) + return False + + def visit_AnnAssign(self, node: cst.AnnAssign) -> bool: + """Collect type alias annotated with `TypeAlias`.""" + is_type_alias = cstm.matches( + node, + cstm.AnnAssign( + annotation=cstm.Annotation(annotation=cstm.Name(value="TypeAlias")) + ), + ) + if is_type_alias and node.value is not None: + names = cstm.findall(node.target, cstm.Name()) + assert len(names) == 1 + pynode = PyNode( + name=names[0].value, kind="ann_assign", loc=self._get_loc(node) + ) + self._current_pynode.add_child(pynode) + return False + + def visit_ImportFrom(self, node: cst.ImportFrom) -> None: + """Collect "from import" targets as usable types within each module.""" + for import_alias in node.names: + if cstm.matches(import_alias, cstm.ImportStar()): + continue + name = import_alias.evaluated_alias + if name is None: + name = import_alias.evaluated_name + assert isinstance(name, str) + + pynode = PyNode(name=name, kind="import_from", loc=self._get_loc(node)) + self._current_pynode.add_child(pynode) class TypeCollector(cst.CSTVisitor): @@ -406,7 +646,7 @@ class TypeMatcher: Attributes ---------- - types : dict[str, KnownImport] + types : dict[str, PyNode] type_prefixes : dict[str, KnownImport] type_nicknames : dict[str, str] successful_queries : int @@ -415,7 +655,7 @@ class TypeMatcher: Examples -------- - >>> from docstub._analysis import TypeMatcher, common_known_types + >>> from docstub._analysis import TypeMatcher >>> db = TypeMatcher() >>> db.match("Any") ('Any', ) @@ -435,12 +675,11 @@ def __init__( type_prefixes : dict[str, KnownImport] type_nicknames : dict[str, str] """ - self.types = types or common_known_types() + self.types = types or {} self.type_prefixes = type_prefixes or {} self.type_nicknames = type_nicknames or {} self.successful_queries = 0 self.unknown_qualnames = [] - self.current_module = None def match(self, search_name): @@ -453,11 +692,9 @@ def match(self, search_name): Returns ------- - type_name : str | None - type_origin : KnownImport | None + type : pynode | None """ - type_name = None - type_origin = None + pynode = None if search_name.startswith("~."): # Sphinx like matching with abbreviated name @@ -470,8 +707,7 @@ def match(self, search_name): } if len(matches) > 1: shortest_key = sorted(matches.keys(), key=lambda x: len(x))[0] - type_origin = matches[shortest_key] - type_name = shortest_key + pynode = matches[shortest_key] logger.warning( "%r in %s matches multiple types %r, using %r", search_name, @@ -480,7 +716,7 @@ def match(self, search_name): shortest_key, ) elif len(matches) == 1: - type_name, type_origin = matches.popitem() + _, pynode = matches.popitem() else: search_name = search_name[2:] logger.debug( @@ -492,38 +728,25 @@ def match(self, search_name): # Replace alias search_name = self.type_nicknames.get(search_name, search_name) - if type_origin is None and self.current_module: + if pynode is None and self.current_module: # Try scope of current module module_name = module_name_from_path(self.current_module) try_qualname = f"{module_name}.{search_name}" - type_origin = self.types.get(try_qualname) - if type_origin: - type_name = search_name + pynode = self.types.get(try_qualname) - if type_origin is None and search_name in self.types: - type_name = search_name - type_origin = self.types[search_name] + if pynode is None and search_name in self.types: + pynode = self.types[search_name] - if type_origin is None: + if pynode is None: # Try a subset of the qualname (first 'a.b.c', then 'a.b' and 'a') for partial_qualname in reversed(accumulate_qualname(search_name)): - type_origin = self.type_prefixes.get(partial_qualname) - if type_origin: - type_name = search_name + pynode = self.type_prefixes.get(partial_qualname) + if pynode: break - if ( - type_origin is not None - and type_name is not None - and type_name != type_origin.target - and not type_name.startswith(type_origin.target) - ): - # Ensure that the annotation matches the import target - type_name = type_name[type_name.find(type_origin.target) :] - - if type_name is not None: + if pynode is not None: self.successful_queries += 1 else: self.unknown_qualnames.append(search_name) - return type_name, type_origin + return pynode diff --git a/src/docstub/_cli.py b/src/docstub/_cli.py index 3bcbf76..d1aef03 100644 --- a/src/docstub/_cli.py +++ b/src/docstub/_cli.py @@ -9,9 +9,9 @@ from ._analysis import ( KnownImport, - TypeCollector, + PythonCollector, TypeMatcher, - common_known_types, + common_types_nicknames, ) from ._cache import FileCache from ._config import Config @@ -93,20 +93,26 @@ def _collect_types(root_path, *, ignore=()): Returns ------- - types : dict[str, ~.KnownImport] + types : dict[str, ~.PyNode] """ - types = common_known_types() + types = {} collect_cached_types = FileCache( - func=TypeCollector.collect, - serializer=TypeCollector.ImportSerializer(), + func=PythonCollector.collect, + serializer=PythonCollector.ImportSerializer(), cache_dir=Path.cwd() / ".docstub_cache", name=f"{__version__}/collected_types", ) if root_path.is_dir(): for source_path in walk_python_package(root_path, ignore=ignore): logger.info("collecting types in %s", source_path) - types_in_source = collect_cached_types(source_path) + + module_tree = collect_cached_types(source_path) + types_in_source = { + ".".join(fullname): pynode + for fullname, pynode in module_tree.walk_tree() + if pynode.is_type + } types.update(types_in_source) return types @@ -228,7 +234,7 @@ def run(root_path, out_dir, config_paths, ignore, group_errors, allow_errors, ve config = _load_configuration(config_paths) config = config.merge(Config(ignore_files=list(ignore))) - types = common_known_types() + types, type_nicknames = common_types_nicknames() types |= _collect_types(root_path, ignore=config.ignore_files) types |= { type_name: KnownImport(import_path=module, import_name=type_name) @@ -244,9 +250,11 @@ def run(root_path, out_dir, config_paths, ignore, group_errors, allow_errors, ve for prefix, module in config.type_prefixes.items() } + type_nicknames |= config.type_nicknames + reporter = GroupedErrorReporter() if group_errors else ErrorReporter() matcher = TypeMatcher( - types=types, type_prefixes=type_prefixes, type_nicknames=config.type_nicknames + types=types, type_prefixes=type_prefixes, type_nicknames=type_nicknames ) stub_transformer = Py2StubTransformer(matcher=matcher, reporter=reporter) diff --git a/src/docstub/_docstrings.py b/src/docstub/_docstrings.py index 1591cd4..3ee20e4 100644 --- a/src/docstub/_docstrings.py +++ b/src/docstub/_docstrings.py @@ -524,13 +524,14 @@ def _match_import(self, qualname, *, meta): Possibly modified or normalized qualname. """ if self.matcher is not None: - annotation_name, known_import = self.matcher.match(qualname) + pynode = self.matcher.match(qualname) + annotation_name = pynode.fullname else: annotation_name = None - known_import = None + pynode = None - if known_import and known_import.has_import: - self._collected_imports.add(known_import) + if pynode and pynode.import_statement: + self._collected_imports.add(pynode.import_statement) if annotation_name: matched_qualname = annotation_name diff --git a/src/docstub/_utils.py b/src/docstub/_utils.py index bbd55bd..0ed2a42 100644 --- a/src/docstub/_utils.py +++ b/src/docstub/_utils.py @@ -106,7 +106,7 @@ def module_name_from_path(path): return name -def pyfile_checksum(path): +def pyfile_checksum(path, salt=""): """Compute a unique key for a Python file. The key takes into account the given `path`, the relative position if the @@ -124,7 +124,7 @@ def pyfile_checksum(path): absolute_path = str(path.resolve()).encode() with open(path, "rb") as fp: content = fp.read() - key = crc32(content + module_name + absolute_path) + key = crc32(content + module_name + absolute_path + salt.encode()) return key