diff --git a/sqlalchemy_bind_manager/_async_helpers.py b/sqlalchemy_bind_manager/_async_helpers.py new file mode 100644 index 0000000..98704c9 --- /dev/null +++ b/sqlalchemy_bind_manager/_async_helpers.py @@ -0,0 +1,42 @@ +# Copyright (c) 2025 Federico Busetti <729029+febus982@users.noreply.github.com> +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the "Software"), +# to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. +import asyncio +from typing import Coroutine + +# Reference: https://docs.astral.sh/ruff/rules/asyncio-dangling-task/ +_background_asyncio_tasks = set() + + +def run_async_from_sync(coro: Coroutine) -> None: + try: + loop = asyncio.get_event_loop() + if loop.is_running(): + task = loop.create_task(coro) + # Add task to the set. This creates a strong reference. + _background_asyncio_tasks.add(task) + + # To prevent keeping references to finished tasks forever, + # make each task remove its own reference from the set after + # completion: + task.add_done_callback(_background_asyncio_tasks.discard) + else: + loop.run_until_complete(coro) + except RuntimeError: + asyncio.run(coro) diff --git a/sqlalchemy_bind_manager/_bind_manager.py b/sqlalchemy_bind_manager/_bind_manager.py index 0b9c80b..16113a0 100644 --- a/sqlalchemy_bind_manager/_bind_manager.py +++ b/sqlalchemy_bind_manager/_bind_manager.py @@ -32,6 +32,7 @@ from sqlalchemy.orm import Session, sessionmaker from sqlalchemy.orm.decl_api import DeclarativeMeta, registry +from sqlalchemy_bind_manager._async_helpers import run_async_from_sync from sqlalchemy_bind_manager.exceptions import ( InvalidConfigError, NotInitializedBindError, @@ -87,6 +88,13 @@ def __init__( else: self.__init_bind(DEFAULT_BIND_NAME, config) + def __del__(self): + for bind in self.__binds.values(): + if isinstance(bind, SQLAlchemyAsyncBind): + run_async_from_sync(bind.engine.dispose()) + else: + bind.engine.dispose() + def __init_bind(self, name: str, config: SQLAlchemyConfig): if not isinstance(config, SQLAlchemyConfig): raise InvalidConfigError( diff --git a/sqlalchemy_bind_manager/_session_handler.py b/sqlalchemy_bind_manager/_session_handler.py index caba7ff..59dc82d 100644 --- a/sqlalchemy_bind_manager/_session_handler.py +++ b/sqlalchemy_bind_manager/_session_handler.py @@ -28,6 +28,7 @@ ) from sqlalchemy.orm import Session, scoped_session +from sqlalchemy_bind_manager._async_helpers import run_async_from_sync from sqlalchemy_bind_manager._bind_manager import ( SQLAlchemyAsyncBind, SQLAlchemyBind, @@ -73,10 +74,6 @@ def commit(self, session: Session) -> None: raise -# Reference: https://docs.astral.sh/ruff/rules/asyncio-dangling-task/ -_background_asyncio_tasks = set() - - class AsyncSessionHandler: scoped_session: async_scoped_session @@ -91,22 +88,7 @@ def __init__(self, bind: SQLAlchemyAsyncBind): def __del__(self): if not getattr(self, "scoped_session", None): return - - try: - loop = asyncio.get_event_loop() - if loop.is_running(): - task = loop.create_task(self.scoped_session.remove()) - # Add task to the set. This creates a strong reference. - _background_asyncio_tasks.add(task) - - # To prevent keeping references to finished tasks forever, - # make each task remove its own reference from the set after - # completion: - task.add_done_callback(_background_asyncio_tasks.discard) - else: - loop.run_until_complete(self.scoped_session.remove()) - except RuntimeError: - asyncio.run(self.scoped_session.remove()) + run_async_from_sync(self.scoped_session.remove()) @asynccontextmanager async def get_session(self, read_only: bool = False) -> AsyncIterator[AsyncSession]: diff --git a/tests/test_sqlalchemy_bind_manager.py b/tests/test_sqlalchemy_bind_manager.py index 98dc280..a53e014 100644 --- a/tests/test_sqlalchemy_bind_manager.py +++ b/tests/test_sqlalchemy_bind_manager.py @@ -1,3 +1,5 @@ +from unittest.mock import patch + import pytest from sqlalchemy import MetaData from sqlalchemy.ext.asyncio import AsyncSession @@ -71,3 +73,88 @@ def test_multiple_binds(multiple_config): assert async_bind is not None assert isinstance(sa_manager.get_mapper("async"), registry) assert isinstance(sa_manager.get_session("async"), AsyncSession) + + +async def test_engine_is_disposed_on_cleanup(multiple_config): + sa_manager = SQLAlchemyBindManager(multiple_config) + sync_engine = sa_manager.get_bind("default").engine + async_engine = sa_manager.get_bind("async").engine + + original_sync_dispose = sync_engine.dispose + original_async_dispose = async_engine.dispose + + with ( + patch.object( + sync_engine, + "dispose", + wraps=original_sync_dispose, + ) as mocked_dispose, + patch.object( + type(async_engine), + "dispose", + wraps=original_async_dispose, + ) as mocked_async_dispose, + ): + sa_manager = None + + mocked_dispose.assert_called_once() + mocked_async_dispose.assert_called() + + +def test_engine_is_disposed_on_cleanup_even_if_no_loop(multiple_config): + sa_manager = SQLAlchemyBindManager(multiple_config) + sync_engine = sa_manager.get_bind("default").engine + async_engine = sa_manager.get_bind("async").engine + + original_sync_dispose = sync_engine.dispose + original_async_dispose = async_engine.dispose + + with ( + patch.object( + sync_engine, + "dispose", + wraps=original_sync_dispose, + ) as mocked_dispose, + patch.object( + type(async_engine), + "dispose", + wraps=original_async_dispose, + ) as mocked_async_dispose, + ): + sa_manager = None + + mocked_dispose.assert_called_once() + mocked_async_dispose.assert_called() + + +def test_engine_is_disposed_on_cleanup_even_if_loop_search_errors_out( + multiple_config, +): + sa_manager = SQLAlchemyBindManager(multiple_config) + sync_engine = sa_manager.get_bind("default").engine + async_engine = sa_manager.get_bind("async").engine + + original_sync_dispose = sync_engine.dispose + original_async_dispose = async_engine.dispose + + with ( + patch.object( + sync_engine, + "dispose", + wraps=original_sync_dispose, + ) as mocked_dispose, + patch.object( + type(async_engine), + "dispose", + wraps=original_async_dispose, + ) as mocked_async_dispose, + patch( + "asyncio.get_event_loop", + side_effect=RuntimeError(), + ) as mocked_get_event_loop, + ): + sa_manager = None + + mocked_get_event_loop.assert_called_once() + mocked_dispose.assert_called_once() + mocked_async_dispose.assert_called()