diff --git a/packages/celery-library/src/celery_library/backends/_redis.py b/packages/celery-library/src/celery_library/backends/redis.py similarity index 79% rename from packages/celery-library/src/celery_library/backends/_redis.py rename to packages/celery-library/src/celery_library/backends/redis.py index 2bd80ed76d8..9878cd5e063 100644 --- a/packages/celery-library/src/celery_library/backends/_redis.py +++ b/packages/celery-library/src/celery_library/backends/redis.py @@ -12,7 +12,7 @@ TaskMetadata, TaskUUID, ) -from servicelib.redis import RedisClientSDK +from servicelib.redis import RedisClientSDK, handle_redis_returns_union_types from ..utils import build_task_id_prefix @@ -41,18 +41,24 @@ async def create_task( expiry: timedelta, ) -> None: task_key = _build_key(task_id) - await self._redis_client_sdk.redis.hset( - name=task_key, - key=_CELERY_TASK_METADATA_KEY, - value=task_metadata.model_dump_json(), - ) # type: ignore + await handle_redis_returns_union_types( + self._redis_client_sdk.redis.hset( + name=task_key, + key=_CELERY_TASK_METADATA_KEY, + value=task_metadata.model_dump_json(), + ) + ) await self._redis_client_sdk.redis.expire( task_key, expiry, ) async def get_task_metadata(self, task_id: TaskID) -> TaskMetadata | None: - raw_result = await self._redis_client_sdk.redis.hget(_build_key(task_id), _CELERY_TASK_METADATA_KEY) # type: ignore + raw_result = await handle_redis_returns_union_types( + self._redis_client_sdk.redis.hget( + _build_key(task_id), _CELERY_TASK_METADATA_KEY + ) + ) if not raw_result: return None @@ -65,7 +71,11 @@ async def get_task_metadata(self, task_id: TaskID) -> TaskMetadata | None: return None async def get_task_progress(self, task_id: TaskID) -> ProgressReport | None: - raw_result = await self._redis_client_sdk.redis.hget(_build_key(task_id), _CELERY_TASK_PROGRESS_KEY) # type: ignore + raw_result = await handle_redis_returns_union_types( + self._redis_client_sdk.redis.hget( + _build_key(task_id), _CELERY_TASK_PROGRESS_KEY + ) + ) if not raw_result: return None @@ -121,11 +131,13 @@ async def remove_task(self, task_id: TaskID) -> None: await self._redis_client_sdk.redis.delete(_build_key(task_id)) async def set_task_progress(self, task_id: TaskID, report: ProgressReport) -> None: - await self._redis_client_sdk.redis.hset( - name=_build_key(task_id), - key=_CELERY_TASK_PROGRESS_KEY, - value=report.model_dump_json(), - ) # type: ignore + await handle_redis_returns_union_types( + self._redis_client_sdk.redis.hset( + name=_build_key(task_id), + key=_CELERY_TASK_PROGRESS_KEY, + value=report.model_dump_json(), + ) + ) async def task_exists(self, task_id: TaskID) -> bool: n = await self._redis_client_sdk.redis.exists(_build_key(task_id)) diff --git a/packages/celery-library/src/celery_library/common.py b/packages/celery-library/src/celery_library/common.py index d50e75597c6..ef45ef4c8b9 100644 --- a/packages/celery-library/src/celery_library/common.py +++ b/packages/celery-library/src/celery_library/common.py @@ -2,13 +2,9 @@ from typing import Any from celery import Celery # type: ignore[import-untyped] -from servicelib.redis import RedisClientSDK from settings_library.celery import CelerySettings from settings_library.redis import RedisDatabase -from .backends._redis import RedisTaskInfoStore -from .task_manager import CeleryTaskManager - def _celery_configure(celery_settings: CelerySettings) -> dict[str, Any]: base_config = { @@ -36,22 +32,3 @@ def create_app(settings: CelerySettings) -> Celery: ), **_celery_configure(settings), ) - - -async def create_task_manager( - app: Celery, settings: CelerySettings -) -> CeleryTaskManager: - redis_client_sdk = RedisClientSDK( - settings.CELERY_REDIS_RESULT_BACKEND.build_redis_dsn( - RedisDatabase.CELERY_TASKS - ), - client_name="celery_tasks", - ) - await redis_client_sdk.setup() - # GCR please address https://github.com/ITISFoundation/osparc-simcore/issues/8159 - - return CeleryTaskManager( - app, - settings, - RedisTaskInfoStore(redis_client_sdk), - ) diff --git a/packages/celery-library/src/celery_library/signals.py b/packages/celery-library/src/celery_library/signals.py index dd5bf047e65..02f1a56f0ec 100644 --- a/packages/celery-library/src/celery_library/signals.py +++ b/packages/celery-library/src/celery_library/signals.py @@ -6,18 +6,15 @@ from celery.worker.worker import WorkController # type: ignore[import-untyped] from servicelib.celery.app_server import BaseAppServer from servicelib.logging_utils import log_context -from settings_library.celery import CelerySettings -from .common import create_task_manager from .utils import get_app_server, set_app_server _logger = logging.getLogger(__name__) def on_worker_init( - app_server: BaseAppServer, - celery_settings: CelerySettings, sender: WorkController, + app_server: BaseAppServer, **_kwargs, ) -> None: startup_complete_event = threading.Event() @@ -26,21 +23,14 @@ def _init(startup_complete_event: threading.Event) -> None: loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) - async def _setup_task_manager(): - assert sender.app # nosec - assert isinstance(sender.app, Celery) # nosec - - app_server.task_manager = await create_task_manager( - sender.app, - celery_settings, - ) + assert sender.app # nosec + assert isinstance(sender.app, Celery) # nosec - set_app_server(sender.app, app_server) + set_app_server(sender.app, app_server) app_server.event_loop = loop - loop.run_until_complete(_setup_task_manager()) - loop.run_until_complete(app_server.lifespan(startup_complete_event)) + loop.run_until_complete(app_server.run_until_shutdown(startup_complete_event)) thread = threading.Thread( group=None, diff --git a/packages/celery-library/tests/conftest.py b/packages/celery-library/tests/conftest.py index e9fc599136a..8e8bc976820 100644 --- a/packages/celery-library/tests/conftest.py +++ b/packages/celery-library/tests/conftest.py @@ -2,25 +2,30 @@ # pylint: disable=unused-argument import datetime +import logging import threading from collections.abc import AsyncIterator, Callable -from functools import partial from typing import Any import pytest from celery import Celery # type: ignore[import-untyped] -from celery.contrib.testing.worker import TestWorkController, start_worker +from celery.contrib.testing.worker import ( + TestWorkController, + start_worker, +) from celery.signals import worker_init, worker_shutdown from celery.worker.worker import WorkController -from celery_library.common import create_task_manager +from celery_library.backends.redis import RedisTaskInfoStore from celery_library.signals import on_worker_init, on_worker_shutdown from celery_library.task_manager import CeleryTaskManager from celery_library.types import register_celery_types from pytest_simcore.helpers.monkeypatch_envs import setenvs_from_dict from pytest_simcore.helpers.typing_env import EnvVarsDict from servicelib.celery.app_server import BaseAppServer +from servicelib.celery.task_manager import TaskManager +from servicelib.redis import RedisClientSDK from settings_library.celery import CelerySettings -from settings_library.redis import RedisSettings +from settings_library.redis import RedisDatabase, RedisSettings pytest_plugins = [ "pytest_simcore.docker_compose", @@ -33,11 +38,42 @@ ] +_logger = logging.getLogger(__name__) + + class FakeAppServer(BaseAppServer): - async def lifespan(self, startup_completed_event: threading.Event) -> None: + def __init__(self, app: Celery, settings: CelerySettings): + super().__init__(app) + self._settings = settings + self._task_manager: CeleryTaskManager | None = None + + @property + def task_manager(self) -> TaskManager: + assert self._task_manager, "Task manager is not initialized" + return self._task_manager + + async def run_until_shutdown( + self, startup_completed_event: threading.Event + ) -> None: + redis_client_sdk = RedisClientSDK( + self._settings.CELERY_REDIS_RESULT_BACKEND.build_redis_dsn( + RedisDatabase.CELERY_TASKS + ), + client_name="pytest_celery_tasks", + ) + await redis_client_sdk.setup() + + self._task_manager = CeleryTaskManager( + self._app, + self._settings, + RedisTaskInfoStore(redis_client_sdk), + ) + startup_completed_event.set() await self.shutdown_event.wait() # wait for shutdown + await redis_client_sdk.shutdown() + @pytest.fixture def register_celery_tasks() -> Callable[[Celery], None]: @@ -51,17 +87,12 @@ def _(celery_app: Celery) -> None: ... @pytest.fixture def app_environment( monkeypatch: pytest.MonkeyPatch, - redis_service: RedisSettings, env_devel_dict: EnvVarsDict, ) -> EnvVarsDict: return setenvs_from_dict( monkeypatch, { **env_devel_dict, - "REDIS_SECURE": redis_service.REDIS_SECURE, - "REDIS_HOST": redis_service.REDIS_HOST, - "REDIS_PORT": f"{redis_service.REDIS_PORT}", - "REDIS_PASSWORD": redis_service.REDIS_PASSWORD.get_secret_value(), }, ) @@ -74,8 +105,8 @@ def celery_settings( @pytest.fixture -def app_server() -> BaseAppServer: - return FakeAppServer(app=None) +def app_server(celery_app: Celery, celery_settings: CelerySettings) -> BaseAppServer: + return FakeAppServer(app=celery_app, settings=celery_settings) @pytest.fixture(scope="session") @@ -98,11 +129,10 @@ def celery_config() -> dict[str, Any]: async def with_celery_worker( celery_app: Celery, app_server: BaseAppServer, - celery_settings: CelerySettings, register_celery_tasks: Callable[[Celery], None], ) -> AsyncIterator[TestWorkController]: def _on_worker_init_wrapper(sender: WorkController, **_kwargs): - return partial(on_worker_init, app_server, celery_settings)(sender, **_kwargs) + return on_worker_init(sender, app_server, **_kwargs) worker_init.connect(_on_worker_init_wrapper) worker_shutdown.connect(on_worker_shutdown) @@ -111,8 +141,8 @@ def _on_worker_init_wrapper(sender: WorkController, **_kwargs): with start_worker( celery_app, - pool="threads", concurrency=1, + pool="threads", loglevel="info", perform_ping_check=False, queues="default", @@ -120,15 +150,30 @@ def _on_worker_init_wrapper(sender: WorkController, **_kwargs): yield worker +@pytest.fixture +async def mock_celery_app(celery_config: dict[str, Any]) -> Celery: + return Celery(**celery_config) + + @pytest.fixture async def celery_task_manager( - celery_app: Celery, + mock_celery_app: Celery, celery_settings: CelerySettings, - with_celery_worker: TestWorkController, -) -> CeleryTaskManager: + use_in_memory_redis: RedisSettings, +) -> AsyncIterator[CeleryTaskManager]: register_celery_types() - return await create_task_manager( - celery_app, - celery_settings, - ) + try: + redis_client_sdk = RedisClientSDK( + use_in_memory_redis.build_redis_dsn(RedisDatabase.CELERY_TASKS), + client_name="pytest_celery_tasks", + ) + await redis_client_sdk.setup() + + yield CeleryTaskManager( + mock_celery_app, + celery_settings, + RedisTaskInfoStore(redis_client_sdk), + ) + finally: + await redis_client_sdk.shutdown() diff --git a/packages/celery-library/tests/unit/test_tasks.py b/packages/celery-library/tests/unit/test_tasks.py index d3e768c9ff1..a9a5ad5c6a1 100644 --- a/packages/celery-library/tests/unit/test_tasks.py +++ b/packages/celery-library/tests/unit/test_tasks.py @@ -12,6 +12,7 @@ import pytest from celery import Celery, Task # pylint: disable=no-name-in-module +from celery.worker.worker import WorkController # pylint: disable=no-name-in-module from celery_library.errors import TaskNotFoundError, TransferrableCeleryError from celery_library.task import register_task from celery_library.task_manager import CeleryTaskManager @@ -33,6 +34,10 @@ pytest_simcore_ops_services_selection = [] +class MyTaskFilter(TaskFilter): + user_id: int + + async def _fake_file_processor( celery_app: Celery, task_name: str, task_id: str, files: list[str] ) -> str: @@ -91,8 +96,9 @@ def _(celery_app: Celery) -> None: async def test_submitting_task_calling_async_function_results_with_success_state( celery_task_manager: CeleryTaskManager, + with_celery_worker: WorkController, ): - task_filter = TaskFilter(user_id=42) + task_filter = MyTaskFilter(user_id=42) task_uuid = await celery_task_manager.submit_task( TaskMetadata( @@ -121,8 +127,9 @@ async def test_submitting_task_calling_async_function_results_with_success_state async def test_submitting_task_with_failure_results_with_error( celery_task_manager: CeleryTaskManager, + with_celery_worker: WorkController, ): - task_filter = TaskFilter(user_id=42) + task_filter = MyTaskFilter(user_id=42) task_uuid = await celery_task_manager.submit_task( TaskMetadata( @@ -149,8 +156,9 @@ async def test_submitting_task_with_failure_results_with_error( async def test_cancelling_a_running_task_aborts_and_deletes( celery_task_manager: CeleryTaskManager, + with_celery_worker: WorkController, ): - task_filter = TaskFilter(user_id=42) + task_filter = MyTaskFilter(user_id=42) task_uuid = await celery_task_manager.submit_task( TaskMetadata( @@ -171,8 +179,9 @@ async def test_cancelling_a_running_task_aborts_and_deletes( async def test_listing_task_uuids_contains_submitted_task( celery_task_manager: CeleryTaskManager, + with_celery_worker: WorkController, ): - task_filter = TaskFilter(user_id=42) + task_filter = MyTaskFilter(user_id=42) task_uuid = await celery_task_manager.submit_task( TaskMetadata( @@ -190,5 +199,5 @@ async def test_listing_task_uuids_contains_submitted_task( tasks = await celery_task_manager.list_tasks(task_filter) assert any(task.uuid == task_uuid for task in tasks) - tasks = await celery_task_manager.list_tasks(task_filter) - assert any(task.uuid == task_uuid for task in tasks) + tasks = await celery_task_manager.list_tasks(task_filter) + assert any(task.uuid == task_uuid for task in tasks) diff --git a/packages/service-library/src/servicelib/celery/app_server.py b/packages/service-library/src/servicelib/celery/app_server.py index 9312497aa31..c11d4b46acd 100644 --- a/packages/service-library/src/servicelib/celery/app_server.py +++ b/packages/service-library/src/servicelib/celery/app_server.py @@ -31,16 +31,14 @@ def shutdown_event(self) -> asyncio.Event: return self._shutdown_event @property + @abstractmethod def task_manager(self) -> TaskManager: - return self._task_manager - - @task_manager.setter - def task_manager(self, manager: TaskManager) -> None: - self._task_manager = manager + raise NotImplementedError @abstractmethod - async def lifespan( + async def run_until_shutdown( self, startup_completed_event: threading.Event, ) -> None: + """Used to initialize the app server until shutdown event is set.""" raise NotImplementedError diff --git a/packages/service-library/src/servicelib/celery/models.py b/packages/service-library/src/servicelib/celery/models.py index c35fc98504e..db1a07c80ee 100644 --- a/packages/service-library/src/servicelib/celery/models.py +++ b/packages/service-library/src/servicelib/celery/models.py @@ -97,7 +97,7 @@ async def get_task_metadata(self, task_id: TaskID) -> TaskMetadata | None: ... async def get_task_progress(self, task_id: TaskID) -> ProgressReport | None: ... - async def list_tasks(self, task_context: TaskFilter) -> list[Task]: ... + async def list_tasks(self, task_filter: TaskFilter) -> list[Task]: ... async def remove_task(self, task_id: TaskID) -> None: ... diff --git a/packages/service-library/src/servicelib/celery/task_manager.py b/packages/service-library/src/servicelib/celery/task_manager.py index 68a62edbb8a..94c4019e027 100644 --- a/packages/service-library/src/servicelib/celery/task_manager.py +++ b/packages/service-library/src/servicelib/celery/task_manager.py @@ -1,4 +1,4 @@ -from typing import Any, Protocol +from typing import Any, Protocol, runtime_checkable from models_library.progress_bar import ProgressReport @@ -12,6 +12,7 @@ ) +@runtime_checkable class TaskManager(Protocol): async def submit_task( self, task_metadata: TaskMetadata, *, task_filter: TaskFilter, **task_param diff --git a/packages/service-library/src/servicelib/fastapi/celery/app_server.py b/packages/service-library/src/servicelib/fastapi/celery/app_server.py index e1a1d3255ac..3c42aa9144d 100644 --- a/packages/service-library/src/servicelib/fastapi/celery/app_server.py +++ b/packages/service-library/src/servicelib/fastapi/celery/app_server.py @@ -7,6 +7,7 @@ from fastapi import FastAPI from ...celery.app_server import BaseAppServer +from ...celery.task_manager import TaskManager _SHUTDOWN_TIMEOUT: Final[float] = datetime.timedelta(seconds=10).total_seconds() @@ -14,11 +15,16 @@ class FastAPIAppServer(BaseAppServer[FastAPI]): - def __init__(self, app: FastAPI): - super().__init__(app) - self._lifespan_manager: LifespanManager | None = None + @property + def task_manager(self) -> TaskManager: + task_manager = self.app.state.task_manager + assert task_manager, "Task manager is not initialized" # nosec + assert isinstance(task_manager, TaskManager) + return task_manager - async def lifespan(self, startup_completed_event: threading.Event) -> None: + async def run_until_shutdown( + self, startup_completed_event: threading.Event + ) -> None: async with LifespanManager( self.app, startup_timeout=None, # waits for full app initialization (DB migrations, etc.) diff --git a/services/api-server/src/simcore_service_api_server/celery_worker/worker_main.py b/services/api-server/src/simcore_service_api_server/celery_worker/worker_main.py index e70b7f79112..82881b6af69 100644 --- a/services/api-server/src/simcore_service_api_server/celery_worker/worker_main.py +++ b/services/api-server/src/simcore_service_api_server/celery_worker/worker_main.py @@ -37,6 +37,4 @@ def worker_init_wrapper(sender, **_kwargs): assert _settings.API_SERVER_CELERY # nosec app_server = FastAPIAppServer(app=create_app(_settings)) - return partial(on_worker_init, app_server, _settings.API_SERVER_CELERY)( - sender, **_kwargs - ) + return partial(on_worker_init, app_server=app_server)(sender, **_kwargs) diff --git a/services/api-server/src/simcore_service_api_server/clients/celery_task_manager.py b/services/api-server/src/simcore_service_api_server/clients/celery_task_manager.py index 0b4ac4c2f4e..8f9f002e1d5 100644 --- a/services/api-server/src/simcore_service_api_server/clients/celery_task_manager.py +++ b/services/api-server/src/simcore_service_api_server/clients/celery_task_manager.py @@ -1,18 +1,48 @@ -from celery_library.common import create_app, create_task_manager +import logging + +from celery_library.backends.redis import RedisTaskInfoStore +from celery_library.common import create_app +from celery_library.task_manager import CeleryTaskManager from celery_library.types import register_celery_types, register_pydantic_types from fastapi import FastAPI +from servicelib.logging_utils import log_context +from servicelib.redis import RedisClientSDK from settings_library.celery import CelerySettings +from settings_library.redis import RedisDatabase from ..celery_worker.worker_tasks.tasks import pydantic_types_to_register +_logger = logging.getLogger(__name__) + -def setup_task_manager(app: FastAPI, celery_settings: CelerySettings) -> None: +def setup_task_manager(app: FastAPI, settings: CelerySettings) -> None: async def on_startup() -> None: - app.state.task_manager = await create_task_manager( - create_app(celery_settings), celery_settings - ) + with log_context(_logger, logging.INFO, "Setting up Celery"): + redis_client_sdk = RedisClientSDK( + settings.CELERY_REDIS_RESULT_BACKEND.build_redis_dsn( + RedisDatabase.CELERY_TASKS + ), + client_name="api_server_celery_tasks", + ) + app.state.celery_tasks_redis_client_sdk = redis_client_sdk + await redis_client_sdk.setup() + + app.state.task_manager = CeleryTaskManager( + create_app(settings), + settings, + RedisTaskInfoStore(redis_client_sdk), + ) + + register_celery_types() + register_pydantic_types(*pydantic_types_to_register) - register_celery_types() - register_pydantic_types(*pydantic_types_to_register) + async def on_shutdown() -> None: + with log_context(_logger, logging.INFO, "Shutting down Celery"): + redis_client_sdk: RedisClientSDK | None = ( + app.state.celery_tasks_redis_client_sdk + ) + if redis_client_sdk: + await redis_client_sdk.shutdown() app.add_event_handler("startup", on_startup) + app.add_event_handler("shutdown", on_shutdown) diff --git a/services/api-server/tests/unit/api_functions/celery/conftest.py b/services/api-server/tests/unit/api_functions/celery/conftest.py index 993ba4b73ab..0a5c933a728 100644 --- a/services/api-server/tests/unit/api_functions/celery/conftest.py +++ b/services/api-server/tests/unit/api_functions/celery/conftest.py @@ -6,7 +6,6 @@ import datetime from collections.abc import AsyncIterator, Callable -from functools import partial from typing import Any import pytest @@ -126,11 +125,8 @@ async def with_api_server_celery_worker( app_server = FastAPIAppServer(app=create_app(app_settings)) - def _on_worker_init_wrapper(sender: WorkController, **_kwargs): - assert app_settings.API_SERVER_CELERY # nosec - return partial(on_worker_init, app_server, app_settings.API_SERVER_CELERY)( - sender, **_kwargs - ) + def _on_worker_init_wrapper(sender: WorkController, **kwargs): + return on_worker_init(sender, app_server=app_server, **kwargs) worker_init.connect(_on_worker_init_wrapper) worker_shutdown.connect(on_worker_shutdown) diff --git a/services/storage/src/simcore_service_storage/core/application.py b/services/storage/src/simcore_service_storage/core/application.py index 305e13bb3ea..cf3bb4546fc 100644 --- a/services/storage/src/simcore_service_storage/core/application.py +++ b/services/storage/src/simcore_service_storage/core/application.py @@ -70,11 +70,11 @@ def create_app(settings: ApplicationSettings) -> FastAPI: # noqa: C901 setup_s3(app) setup_client_session(app, tracing_settings=settings.STORAGE_TRACING) - if settings.STORAGE_CELERY and not settings.STORAGE_WORKER_MODE: - setup_rabbitmq(app) - - setup_task_manager(app, celery_settings=settings.STORAGE_CELERY) + if settings.STORAGE_CELERY: + setup_task_manager(app, settings=settings.STORAGE_CELERY) + if not settings.STORAGE_WORKER_MODE: + setup_rabbitmq(app) setup_rpc_routes(app) setup_rest_api_routes(app, API_VTAG) diff --git a/services/storage/src/simcore_service_storage/modules/celery/__init__.py b/services/storage/src/simcore_service_storage/modules/celery/__init__.py index 684262d3f9b..0dcb3a2ea5e 100644 --- a/services/storage/src/simcore_service_storage/modules/celery/__init__.py +++ b/services/storage/src/simcore_service_storage/modules/celery/__init__.py @@ -1,4 +1,7 @@ -from celery_library.common import create_app, create_task_manager +import logging + +from celery_library.backends.redis import RedisTaskInfoStore +from celery_library.common import create_app from celery_library.task_manager import CeleryTaskManager from celery_library.types import register_celery_types, register_pydantic_types from fastapi import FastAPI @@ -6,21 +9,47 @@ FileUploadCompletionBody, FoldersBody, ) +from servicelib.logging_utils import log_context +from servicelib.redis import RedisClientSDK from settings_library.celery import CelerySettings +from settings_library.redis import RedisDatabase from ...models import FileMetaData +_logger = logging.getLogger(__name__) + -def setup_task_manager(app: FastAPI, celery_settings: CelerySettings) -> None: +def setup_task_manager(app: FastAPI, settings: CelerySettings) -> None: async def on_startup() -> None: - app.state.task_manager = await create_task_manager( - create_app(celery_settings), celery_settings - ) + with log_context(_logger, logging.INFO, "Setting up Celery"): + redis_client_sdk = RedisClientSDK( + settings.CELERY_REDIS_RESULT_BACKEND.build_redis_dsn( + RedisDatabase.CELERY_TASKS + ), + client_name="storage_celery_tasks", + ) + app.state.celery_tasks_redis_client_sdk = redis_client_sdk + await redis_client_sdk.setup() + + app.state.task_manager = CeleryTaskManager( + create_app(settings), + settings, + RedisTaskInfoStore(redis_client_sdk), + ) + + register_celery_types() + register_pydantic_types(FileUploadCompletionBody, FileMetaData, FoldersBody) - register_celery_types() - register_pydantic_types(FileUploadCompletionBody, FileMetaData, FoldersBody) + async def on_shutdown() -> None: + with log_context(_logger, logging.INFO, "Shutting down Celery"): + redis_client_sdk: RedisClientSDK | None = ( + app.state.celery_tasks_redis_client_sdk + ) + if redis_client_sdk: + await redis_client_sdk.shutdown() app.add_event_handler("startup", on_startup) + app.add_event_handler("shutdown", on_shutdown) def get_task_manager_from_app(app: FastAPI) -> CeleryTaskManager: diff --git a/services/storage/src/simcore_service_storage/modules/celery/worker_main.py b/services/storage/src/simcore_service_storage/modules/celery/worker_main.py index 396ed37accf..f2e90e90024 100644 --- a/services/storage/src/simcore_service_storage/modules/celery/worker_main.py +++ b/services/storage/src/simcore_service_storage/modules/celery/worker_main.py @@ -1,7 +1,5 @@ """Main application to be deployed in for example uvicorn.""" -from functools import partial - from celery.signals import worker_init, worker_shutdown # type: ignore[import-untyped] from celery_library.common import create_app as create_celery_app from celery_library.signals import ( @@ -32,11 +30,8 @@ app_server = FastAPIAppServer(app=create_app(_settings)) -def worker_init_wrapper(sender, **_kwargs): - assert _settings.STORAGE_CELERY # nosec - return partial(on_worker_init, app_server, _settings.STORAGE_CELERY)( - sender, **_kwargs - ) +def worker_init_wrapper(sender, **kwargs): + return on_worker_init(sender, app_server, **kwargs) worker_init.connect(worker_init_wrapper) diff --git a/services/storage/tests/conftest.py b/services/storage/tests/conftest.py index 32813640197..802c3fab387 100644 --- a/services/storage/tests/conftest.py +++ b/services/storage/tests/conftest.py @@ -12,7 +12,6 @@ import random import sys from collections.abc import AsyncIterator, Awaitable, Callable -from functools import partial from pathlib import Path from typing import Any, Final, cast @@ -1018,10 +1017,7 @@ async def with_storage_celery_worker( app_server = FastAPIAppServer(app=create_app(app_settings)) def _on_worker_init_wrapper(sender: WorkController, **_kwargs): - assert app_settings.STORAGE_CELERY # nosec - return partial(on_worker_init, app_server, app_settings.STORAGE_CELERY)( - sender, **_kwargs - ) + return on_worker_init(sender, app_server, **_kwargs) worker_init.connect(_on_worker_init_wrapper) worker_shutdown.connect(on_worker_shutdown)