Skip to content

Commit 79c572c

Browse files
Some logic improvements + attempt to revert on failed import
1 parent 90a82df commit 79c572c

File tree

2 files changed

+60
-35
lines changed

2 files changed

+60
-35
lines changed

CHANGELOG.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1414
- [BasicLocaliser.set_variants][tanjun.dependencies.BasicLocaliser.set_variants]
1515
now tries to normalise the locale keys of variants passed as keyword arguments
1616
to match the [hikari.Locale][hikari.locales.Locale] values.
17-
- The hot-reloader no-longer keeps a module's old components loaded after
18-
a top-level error (e.g. syntax error) is raised while re-importing the module.
17+
- The hot-reloader no-longer reverts to a module's old components loaded after
18+
the new module's loaders fail or is found to have no loaders.
1919

2020
### Fixed
2121
- Module unloaders will now be called with the correct old module's state

tanjun/clients.py

Lines changed: 58 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -2627,12 +2627,8 @@ def load_modules(self, *modules: typing.Union[str, pathlib.Path]) -> Self:
26272627
with _WrapLoadError(errors.FailedModuleImport, module_path):
26282628
module = load_module()
26292629

2630-
try:
2630+
with _EXPECT_ITER_END:
26312631
generator.send(module)
2632-
except StopIteration:
2633-
pass
2634-
else:
2635-
raise RuntimeError("Generator didn't finish")
26362632

26372633
return self
26382634

@@ -2648,12 +2644,8 @@ async def load_modules_async(self, *modules: typing.Union[str, pathlib.Path]) ->
26482644
with _WrapLoadError(errors.FailedModuleImport, module_path):
26492645
module = await loop.run_in_executor(None, load_module)
26502646

2651-
try:
2647+
with _EXPECT_ITER_END:
26522648
generator.send(module)
2653-
except StopIteration:
2654-
pass
2655-
else:
2656-
raise RuntimeError("Generator didn't finish")
26572649

26582650
def unload_modules(self, *modules: typing.Union[str, pathlib.Path]) -> Self:
26592651
# <<inherited docstring from tanjun.ab.Client>>.
@@ -2679,7 +2671,7 @@ def unload_modules(self, *modules: typing.Union[str, pathlib.Path]) -> Self:
26792671

26802672
def _reload_module(
26812673
self, module_path: typing.Union[str, pathlib.Path], /
2682-
) -> collections.Generator[collections.Callable[[], types.ModuleType], types.ModuleType, None]:
2674+
) -> collections.Generator[collections.Callable[[], types.ModuleType], types.ModuleType | None, None]:
26832675
if isinstance(module_path, str):
26842676
old_module = self._modules.get(module_path)
26852677
load_module: typing.Optional[_ReloadModule] = None
@@ -2697,25 +2689,33 @@ def _reload_module(
26972689
_LOGGER.info("Reloading %s", module_path)
26982690

26992691
old_loaders = _get_loaders(old_module, module_path)
2700-
modules_dict.pop(module_path)
27012692
with _WrapLoadError(errors.FailedModuleUnload, module_path):
2702-
# This will never raise MissingLoaders as we assert this earlier
27032693
self._call_unloaders(module_path, old_loaders)
27042694

27052695
module = yield load_module
27062696

2707-
loaders = _get_loaders(module, module_path)
2697+
if not module:
2698+
# Indicates the newer version of this module couldn't be imported
2699+
try:
2700+
self._call_loaders(module_path, old_loaders)
2701+
2702+
except Exception as exc:
2703+
# TODO: exc here is annoyingly already chained with another
2704+
# error which would've already been logged
2705+
_LOGGER.debug("Failed to revert %s", module_path, exc_info=exc)
2706+
modules_dict.pop(module_path)
2707+
2708+
return
27082709

2709-
# We assert that the new module has loaders early to avoid unnecessarily
2710-
# unloading then rolling back when we know it's going to fail to load.
2711-
if not any(loader.has_load for loader in loaders):
2712-
raise errors.ModuleMissingLoaders(f"Didn't find any loaders in new {module_path}", module_path)
2710+
loaders = _get_loaders(module, module_path)
27132711

27142712
try:
2715-
# This will never raise MissingLoaders as we assert this earlier
27162713
self._call_loaders(module_path, loaders)
2714+
except errors.ModuleMissingLoaders:
2715+
modules_dict.pop(module_path)
2716+
raise
27172717
except Exception as exc:
2718-
self._call_loaders(module_path, old_loaders)
2718+
modules_dict.pop(module_path)
27192719
raise errors.FailedModuleLoad(module_path) from exc
27202720
else:
27212721
modules_dict[module_path] = module
@@ -2728,15 +2728,18 @@ def reload_modules(self, *modules: typing.Union[str, pathlib.Path]) -> Self:
27282728

27292729
generator = self._reload_module(module_path)
27302730
load_module = next(generator)
2731-
with _WrapLoadError(errors.FailedModuleLoad, module_path):
2732-
module = load_module()
27332731

27342732
try:
2733+
module = load_module()
2734+
2735+
except Exception as exc:
2736+
with _EXPECT_ITER_END:
2737+
generator.send(None)
2738+
2739+
raise errors.FailedModuleLoad(module_path) from exc
2740+
2741+
with _EXPECT_ITER_END:
27352742
generator.send(module)
2736-
except StopIteration:
2737-
pass
2738-
else:
2739-
raise RuntimeError("Generator didn't finish")
27402743

27412744
return self
27422745

@@ -2749,17 +2752,18 @@ async def reload_modules_async(self, *modules: typing.Union[str, pathlib.Path])
27492752

27502753
generator = self._reload_module(module_path)
27512754
load_module = next(generator)
2752-
with _WrapLoadError(errors.FailedModuleLoad, module_path):
2753-
module = await loop.run_in_executor(None, load_module)
27542755

27552756
try:
2756-
generator.send(module)
2757+
module = await loop.run_in_executor(None, load_module)
27572758

2758-
except StopIteration:
2759-
pass
2759+
except Exception as exc:
2760+
with _EXPECT_ITER_END:
2761+
generator.send(None)
27602762

2761-
else:
2762-
raise RuntimeError("Generator didn't finish")
2763+
raise errors.FailedModuleLoad(module_path) from exc
2764+
2765+
with _EXPECT_ITER_END:
2766+
generator.send(module)
27632767

27642768
def set_type_dependency(self, type_: type[_T], value: _T, /) -> Self:
27652769
# <<inherited docstring from tanjun.abc.Client>>.
@@ -3185,6 +3189,27 @@ def __exit__(
31853189
raise self._error(*self._args, **self._kwargs) from exc
31863190

31873191

3192+
class _ExpectIterEnd:
3193+
__slots__ = ()
3194+
3195+
def __enter__(self) -> None:
3196+
pass
3197+
3198+
def __exit__(
3199+
self,
3200+
exc_type: typing.Optional[type[BaseException]],
3201+
exc: typing.Optional[BaseException],
3202+
exc_tb: typing.Optional[types.TracebackType],
3203+
) -> bool:
3204+
if exc_type is None:
3205+
raise RuntimeError("Generator didn't finish") from None
3206+
3207+
return exc_type is StopIteration
3208+
3209+
3210+
_EXPECT_ITER_END = _ExpectIterEnd()
3211+
3212+
31883213
def _try_deregister_listener(
31893214
interaction_server: hikari.api.InteractionServer,
31903215
interaction_type: typing.Any,

0 commit comments

Comments
 (0)