diff --git a/invoke/loader.py b/invoke/loader.py index 801d16333..d6226f035 100644 --- a/invoke/loader.py +++ b/invoke/loader.py @@ -105,6 +105,8 @@ class FilesystemLoader(Loader): .. versionadded:: 1.0 """ + POSSIBLE_SUFFIXES = (".py", ".pyc", ".so", ".dll", ".dylib") + # TODO: could introduce config obj here for transmission to Collection # TODO: otherwise Loader has to know about specific bits to transmit, such # as auto-dashes, and has to grow one of those for every bit Collection @@ -120,33 +122,44 @@ def start(self) -> str: # Lazily determine default CWD if configured value is falsey return self._start or os.getcwd() + @classmethod + def _is_package(cls, path: str) -> bool: + for suffix in cls.POSSIBLE_SUFFIXES: + if os.path.exists(os.path.join(path, "__init__" + suffix)): + return True + return False + def find(self, name: str) -> Optional[ModuleSpec]: debug("FilesystemLoader find starting at {!r}".format(self.start)) spec = None - module = "{}.py".format(name) + modules = [ + "{}{}".format(name, suffix) for suffix in self.POSSIBLE_SUFFIXES + ] paths = self.start.split(os.sep) try: # walk the path upwards to check for dynamic import for x in reversed(range(len(paths) + 1)): path = os.sep.join(paths[0:x]) - if module in os.listdir(path): - spec = spec_from_file_location( - name, os.path.join(path, module) - ) - break - elif name in os.listdir(path) and os.path.exists( - os.path.join(path, name, "__init__.py") - ): - basepath = os.path.join(path, name) - spec = spec_from_file_location( - name, - os.path.join(basepath, "__init__.py"), - submodule_search_locations=[basepath], - ) - break - if spec: - debug("Found module: {!r}".format(spec)) - return spec + possible_modules = os.listdir(path) + for module in modules: + if module in possible_modules: + spec = spec_from_file_location( + name, os.path.join(path, module) + ) + break + elif name in possible_modules and self._is_package( + os.path.join(path, name) + ): + basepath = os.path.join(path, name) + spec = spec_from_file_location( + name, + os.path.join(basepath, "__init__.py"), + submodule_search_locations=[basepath], + ) + break + if spec: + debug("Found module: {!r}".format(spec)) + return spec except (FileNotFoundError, ModuleNotFoundError): msg = "ImportError loading {!r}, raising CollectionNotFound" debug(msg.format(name))