Commits (37)
3174045
feat: move redis client lifecycle to app server's one
giancarloromeo Aug 20, 2025
ab99e13
fix: remove unused param
giancarloromeo Aug 20, 2025
0d26e0c
fix: fake server Redis client's lifecycle
giancarloromeo Aug 20, 2025
f2b84eb
fix: typecheck
giancarloromeo Aug 20, 2025
064f8ac
fix: remove unuseful setup
giancarloromeo Aug 20, 2025
f8290e8
fix: celery task manager fixture
giancarloromeo Aug 20, 2025
8999602
fix: task manager property
giancarloromeo Aug 20, 2025
dd0a684
fix: typecheck
giancarloromeo Aug 20, 2025
a81aafc
fix: absolute import
giancarloromeo Aug 20, 2025
d50dbb4
tests: use in-memory Redis
giancarloromeo Aug 20, 2025
86bbe39
fix: rename
giancarloromeo Aug 20, 2025
70bf868
fix: shutdown
giancarloromeo Aug 21, 2025
1f626a2
fix: worker shutdown
giancarloromeo Aug 21, 2025
246d695
fix: use threads
giancarloromeo Aug 21, 2025
41446dc
fix: explicity stop worker
giancarloromeo Aug 21, 2025
ee52317
fix: shutdown
giancarloromeo Aug 21, 2025
97ded5f
fix: raise timeout
giancarloromeo Aug 21, 2025
abeab74
fix: force worker stop
giancarloromeo Aug 21, 2025
8d2c423
fix: remove rabbit
giancarloromeo Aug 21, 2025
1bb6dad
fix: rabbit
giancarloromeo Aug 21, 2025
9327579
fix: use separate Celery app
giancarloromeo Aug 22, 2025
5198a37
fix: add wait after submit
giancarloromeo Aug 22, 2025
f223f79
fix: test indent
giancarloromeo Aug 22, 2025
f347698
fix: loglevel
giancarloromeo Aug 22, 2025
585dc9f
fix: wait
giancarloromeo Aug 22, 2025
ac35b82
Merge branch 'master' into is8159/fix-redis-client-lifecycle
giancarloromeo Aug 22, 2025
ceae3c3
fix: remove partials
giancarloromeo Aug 22, 2025
81161e9
fix: remove unused asserts
giancarloromeo Aug 22, 2025
1e02fe7
fix: assert
giancarloromeo Aug 22, 2025
ad4cc8c
fix: remove partial
giancarloromeo Aug 22, 2025
46f5034
remove unused
giancarloromeo Aug 26, 2025
fec6e98
rename
giancarloromeo Aug 26, 2025
8d4cf9b
Merge branch 'master' into is8159/fix-redis-client-lifecycle
giancarloromeo Aug 26, 2025
8270bb1
Merge remote-tracking branch 'upstream/master' into is8159/fix-redis-…
giancarloromeo Aug 26, 2025
8125b5d
fix: remove unused
giancarloromeo Aug 26, 2025
59c33ae
Merge branch 'master' into is8159/fix-redis-client-lifecycle
giancarloromeo Sep 8, 2025
7618c5e
fix: remove sleep
giancarloromeo Sep 8, 2025
23 changes: 0 additions & 23 deletions packages/celery-library/src/celery_library/common.py
@@ -2,13 +2,9 @@
from typing import Any

from celery import Celery # type: ignore[import-untyped]
from servicelib.redis import RedisClientSDK
from settings_library.celery import CelerySettings
from settings_library.redis import RedisDatabase

from .backends._redis import RedisTaskInfoStore
from .task_manager import CeleryTaskManager


def _celery_configure(celery_settings: CelerySettings) -> dict[str, Any]:
base_config = {
@@ -36,22 +32,3 @@ def create_app(settings: CelerySettings) -> Celery:
),
**_celery_configure(settings),
)


async def create_task_manager(
app: Celery, settings: CelerySettings
) -> CeleryTaskManager:
redis_client_sdk = RedisClientSDK(
settings.CELERY_REDIS_RESULT_BACKEND.build_redis_dsn(
RedisDatabase.CELERY_TASKS
),
client_name="celery_tasks",
)
await redis_client_sdk.setup()
# GCR please address https://github.com/ITISFoundation/osparc-simcore/issues/8159

return CeleryTaskManager(
app,
settings,
RedisTaskInfoStore(redis_client_sdk),
)
18 changes: 4 additions & 14 deletions packages/celery-library/src/celery_library/signals.py
@@ -6,17 +6,14 @@
from celery.worker.worker import WorkController # type: ignore[import-untyped]
from servicelib.celery.app_server import BaseAppServer
from servicelib.logging_utils import log_context
from settings_library.celery import CelerySettings

from .common import create_task_manager
from .utils import get_app_server, set_app_server

_logger = logging.getLogger(__name__)


def on_worker_init(
app_server: BaseAppServer,
celery_settings: CelerySettings,
sender: WorkController,
**_kwargs,
) -> None:
@@ -26,21 +23,14 @@ def _init(startup_complete_event: threading.Event) -> None:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)

async def _setup_task_manager():
assert sender.app # nosec
assert isinstance(sender.app, Celery) # nosec
assert sender.app # nosec
assert isinstance(sender.app, Celery) # nosec

app_server.task_manager = await create_task_manager(
sender.app,
celery_settings,
)

set_app_server(sender.app, app_server)
set_app_server(sender.app, app_server)

app_server.event_loop = loop

loop.run_until_complete(_setup_task_manager())
loop.run_until_complete(app_server.lifespan(startup_complete_event))
loop.run_until_complete(app_server.start_and_hold(startup_complete_event))

thread = threading.Thread(
group=None,
86 changes: 65 additions & 21 deletions packages/celery-library/tests/conftest.py
@@ -2,25 +2,31 @@
# pylint: disable=unused-argument

import datetime
import logging
import threading
from collections.abc import AsyncIterator, Callable
from functools import partial
from typing import Any

import pytest
from celery import Celery # type: ignore[import-untyped]
from celery.contrib.testing.worker import TestWorkController, start_worker
from celery.contrib.testing.worker import (
TestWorkController,
start_worker,
)
from celery.signals import worker_init, worker_shutdown
from celery.worker.worker import WorkController
from celery_library.common import create_task_manager
from celery_library.backends._redis import RedisTaskInfoStore
from celery_library.signals import on_worker_init, on_worker_shutdown
from celery_library.task_manager import CeleryTaskManager
from celery_library.types import register_celery_types
from pytest_simcore.helpers.monkeypatch_envs import setenvs_from_dict
from pytest_simcore.helpers.typing_env import EnvVarsDict
from servicelib.celery.app_server import BaseAppServer
from servicelib.celery.task_manager import TaskManager
from servicelib.redis import RedisClientSDK
from settings_library.celery import CelerySettings
from settings_library.redis import RedisSettings
from settings_library.redis import RedisDatabase, RedisSettings

pytest_plugins = [
"pytest_simcore.docker_compose",
Expand All @@ -33,11 +39,40 @@
]


_logger = logging.getLogger(__name__)


class FakeAppServer(BaseAppServer):
async def lifespan(self, startup_completed_event: threading.Event) -> None:
def __init__(self, app: Celery, settings: CelerySettings):
super().__init__(app)
self._settings = settings
self._task_manager: CeleryTaskManager | None = None

@property
def task_manager(self) -> TaskManager:
assert self._task_manager, "Task manager is not initialized"
return self._task_manager

async def start_and_hold(self, startup_completed_event: threading.Event) -> None:
redis_client_sdk = RedisClientSDK(
self._settings.CELERY_REDIS_RESULT_BACKEND.build_redis_dsn(
RedisDatabase.CELERY_TASKS
),
client_name="pytest_celery_tasks",
)
await redis_client_sdk.setup()

self._task_manager = CeleryTaskManager(
self._app,
self._settings,
RedisTaskInfoStore(redis_client_sdk),
)

startup_completed_event.set()
await self.shutdown_event.wait() # wait for shutdown

await redis_client_sdk.shutdown()


@pytest.fixture
def register_celery_tasks() -> Callable[[Celery], None]:
@@ -51,17 +86,12 @@ def _(celery_app: Celery) -> None: ...
@pytest.fixture
def app_environment(
monkeypatch: pytest.MonkeyPatch,
redis_service: RedisSettings,
env_devel_dict: EnvVarsDict,
) -> EnvVarsDict:
return setenvs_from_dict(
monkeypatch,
{
**env_devel_dict,
"REDIS_SECURE": redis_service.REDIS_SECURE,
"REDIS_HOST": redis_service.REDIS_HOST,
"REDIS_PORT": f"{redis_service.REDIS_PORT}",
"REDIS_PASSWORD": redis_service.REDIS_PASSWORD.get_secret_value(),
},
)

@@ -74,8 +104,8 @@ def celery_settings(


@pytest.fixture
def app_server() -> BaseAppServer:
return FakeAppServer(app=None)
def app_server(celery_app: Celery, celery_settings: CelerySettings) -> BaseAppServer:
return FakeAppServer(app=celery_app, settings=celery_settings)


@pytest.fixture(scope="session")
Expand All @@ -98,11 +128,10 @@ def celery_config() -> dict[str, Any]:
async def with_celery_worker(
celery_app: Celery,
app_server: BaseAppServer,
celery_settings: CelerySettings,
register_celery_tasks: Callable[[Celery], None],
) -> AsyncIterator[TestWorkController]:
def _on_worker_init_wrapper(sender: WorkController, **_kwargs):
return partial(on_worker_init, app_server, celery_settings)(sender, **_kwargs)
return partial(on_worker_init, app_server)(sender, **_kwargs)

worker_init.connect(_on_worker_init_wrapper)
worker_shutdown.connect(on_worker_shutdown)
@@ -111,24 +140,39 @@ def _on_worker_init_wrapper(sender: WorkController, **_kwargs):

with start_worker(
celery_app,
pool="threads",
concurrency=1,
pool="threads",
loglevel="info",
perform_ping_check=False,
queues="default",
) as worker:
yield worker


@pytest.fixture
async def mock_celery_app(celery_config: dict[str, Any]) -> Celery:
return Celery(**celery_config)


@pytest.fixture
async def celery_task_manager(
celery_app: Celery,
mock_celery_app: Celery,
celery_settings: CelerySettings,
with_celery_worker: TestWorkController,
) -> CeleryTaskManager:
use_in_memory_redis: RedisSettings,
) -> AsyncIterator[CeleryTaskManager]:
register_celery_types()

return await create_task_manager(
celery_app,
celery_settings,
)
try:
redis_client_sdk = RedisClientSDK(
use_in_memory_redis.build_redis_dsn(RedisDatabase.CELERY_TASKS),
client_name="pytest_celery_tasks",
)
await redis_client_sdk.setup()

yield CeleryTaskManager(
mock_celery_app,
celery_settings,
RedisTaskInfoStore(redis_client_sdk),
)
finally:
await redis_client_sdk.shutdown()
2 changes: 2 additions & 0 deletions packages/celery-library/tests/unit/test_async_jobs.py
@@ -301,6 +301,8 @@ async def test_async_jobs_cancel(
payload=60 * 10, # test hangs if not cancelled properly
)

await asyncio.sleep(3)

await async_jobs.cancel(
async_jobs_rabbitmq_rpc_client,
rpc_namespace=ASYNC_JOBS_RPC_NAMESPACE,
9 changes: 7 additions & 2 deletions packages/celery-library/tests/unit/test_tasks.py
@@ -13,6 +13,7 @@
import pytest
from celery import Celery, Task
from celery.contrib.abortable import AbortableTask
from celery.worker.worker import WorkController
from celery_library.errors import TransferrableCeleryError
from celery_library.task import register_task
from celery_library.task_manager import CeleryTaskManager
@@ -92,6 +93,7 @@ def _(celery_app: Celery) -> None:

async def test_submitting_task_calling_async_function_results_with_success_state(
celery_task_manager: CeleryTaskManager,
with_celery_worker: WorkController,
):
task_filter = TaskFilter(user_id=42)

@@ -122,6 +124,7 @@ async def test_submitting_task_calling_async_function_results_with_success_state

async def test_submitting_task_with_failure_results_with_error(
celery_task_manager: CeleryTaskManager,
with_celery_worker: WorkController,
):
task_filter = TaskFilter(user_id=42)

@@ -150,6 +153,7 @@ async def test_submitting_task_with_failure_results_with_error(

async def test_cancelling_a_running_task_aborts_and_deletes(
celery_task_manager: CeleryTaskManager,
with_celery_worker: WorkController,
):
task_filter = TaskFilter(user_id=42)

@@ -182,6 +186,7 @@ async def test_cancelling_a_running_task_aborts_and_deletes(

async def test_listing_task_uuids_contains_submitted_task(
celery_task_manager: CeleryTaskManager,
with_celery_worker: WorkController,
):
task_filter = TaskFilter(user_id=42)

@@ -201,5 +206,5 @@ async def test_listing_task_uuids_contains_submitted_task(
tasks = await celery_task_manager.list_tasks(task_filter)
assert any(task.uuid == task_uuid for task in tasks)

tasks = await celery_task_manager.list_tasks(task_filter)
assert any(task.uuid == task_uuid for task in tasks)
tasks = await celery_task_manager.list_tasks(task_filter)
assert any(task.uuid == task_uuid for task in tasks)
packages/service-library/src/servicelib/celery/app_server.py
@@ -31,15 +31,12 @@ def shutdown_event(self) -> asyncio.Event:
return self._shutdown_event

@property
@abstractmethod
def task_manager(self) -> TaskManager:
return self._task_manager

@task_manager.setter
def task_manager(self, manager: TaskManager) -> None:
self._task_manager = manager
raise NotImplementedError

@abstractmethod
async def lifespan(
async def start_and_hold(
Member: check my other comment about renaming this

Member: for interfaces, please add some doc about what is expected, especially when the name does not reveal all details

self,
startup_completed_event: threading.Event,
) -> None:
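One way to address the documentation request above is a docstring on the abstract method spelling out the contract. The sketch below is only an illustration; the class name and docstring wording are assumptions, not part of this PR.

# Sketch only: a possible docstring for the abstract start_and_hold method,
# responding to the review request to document the interface. Class name and
# wording are placeholders, not the PR's actual code.
import threading
from abc import ABC, abstractmethod


class AppServerInterfaceSketch(ABC):
    @abstractmethod
    async def start_and_hold(
        self,
        startup_completed_event: threading.Event,
    ) -> None:
        """Start the app server and block until shutdown is requested.

        Implementations are expected to:
        - set up their resources (e.g. the task manager and its Redis store),
        - set `startup_completed_event` once the server is ready to accept work,
        - wait on the shutdown event, then release resources before returning.
        """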
2 changes: 1 addition & 1 deletion packages/service-library/src/servicelib/celery/models.py
@@ -58,7 +58,7 @@ async def get_task_metadata(self, task_id: TaskID) -> TaskMetadata | None: ...

async def get_task_progress(self, task_id: TaskID) -> ProgressReport | None: ...

async def list_tasks(self, task_context: TaskFilter) -> list[Task]: ...
async def list_tasks(self, task_filter: TaskFilter) -> list[Task]: ...

async def remove_task(self, task_id: TaskID) -> None: ...

@@ -7,6 +7,7 @@
from fastapi import FastAPI

from ...celery.app_server import BaseAppServer
from ...celery.task_manager import TaskManager

_SHUTDOWN_TIMEOUT: Final[float] = datetime.timedelta(seconds=10).total_seconds()

@@ -18,7 +19,13 @@ def __init__(self, app: FastAPI):
super().__init__(app)
self._lifespan_manager: LifespanManager | None = None

async def lifespan(self, startup_completed_event: threading.Event) -> None:
@property
def task_manager(self) -> TaskManager:
assert self.app.state.task_manager, "Task manager is not initialized" # nosec
task_manager: TaskManager = self.app.state.task_manager
return task_manager

async def start_and_hold(self, startup_completed_event: threading.Event) -> None:
Member: it is a lifespan, and the one problem I see here is that the return type is wrong. It should be AsyncIterator[None], which would remove the confusion I think.

Contributor Author: We don't yield anything here. This is the place where the initialized FastAPI instance stays parked, waiting for the shutdown event.

Member: Since this is the primary entrypoint for this service, I would call it run_until_shutdown, which emphasizes the lifecycle clearly (and echoes the naming in the asyncio library).

Regarding @sanderegg's comment: in other parts of the code our approach is to provide a context-manager-like function that bundles the setup and tear-down parts in one place (see https://github.com/ITISFoundation/osparc-simcore/blob/master/packages/service-library/src/servicelib/fastapi/postgres_lifespan.py#L31C11-L31C37).

The approach here is different, since this member function encapsulates the setup and tear-down parts AND runs them. That reduces flexibility, but I guess you do not need it here.

I also understand this function can only be called once, so I would add a protection for that.

Member: TIP: use log_context(INFO, ...) instead of _logger.info

async with LifespanManager(
self.app,
startup_timeout=None, # waits for full app initialization (DB migrations, etc.)
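Taken together, the suggestions in the thread above (rename to run_until_shutdown, protect against a second call, prefer log_context over _logger.info) could look roughly like the sketch below. Class and attribute names are placeholders, and the asgi_lifespan import is an assumption, not the PR's actual code.

# Sketch only: combines the reviewers' suggestions — run_until_shutdown naming,
# a guard against calling it twice, and log_context instead of _logger.info.
import asyncio
import logging
import threading

from asgi_lifespan import LifespanManager  # assumed source of LifespanManager
from fastapi import FastAPI
from servicelib.logging_utils import log_context

_logger = logging.getLogger(__name__)


class FastAPIAppServerSketch:
    def __init__(self, app: FastAPI) -> None:
        self.app = app
        self.shutdown_event = asyncio.Event()
        self._started = False

    async def run_until_shutdown(
        self, startup_completed_event: threading.Event
    ) -> None:
        # protection against a second call, as suggested in the review
        if self._started:
            msg = "run_until_shutdown() can only be called once"
            raise RuntimeError(msg)
        self._started = True

        with log_context(_logger, logging.INFO, "FastAPI app server lifecycle"):
            async with LifespanManager(self.app, startup_timeout=None):
                # the app is fully initialized: signal the worker thread, then
                # park here until shutdown is requested
                startup_completed_event.set()
                await self.shutdown_event.wait()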
@@ -36,7 +36,7 @@
from ..dsm import setup_dsm
from ..dsm_cleaner import setup_dsm_cleaner
from ..exceptions.handlers import set_exception_handlers
from ..modules.celery import setup_task_manager
from ..modules.celery import setup_celery
from ..modules.db import setup_db
from ..modules.long_running_tasks import setup_rest_api_long_running_tasks_for_uploads
from ..modules.rabbitmq import setup as setup_rabbitmq
@@ -71,12 +71,13 @@ def create_app(settings: ApplicationSettings) -> FastAPI: # noqa: C901
setup_s3(app)
setup_client_session(app, tracing_settings=settings.STORAGE_TRACING)

if settings.STORAGE_CELERY and not settings.STORAGE_WORKER_MODE:
setup_rabbitmq(app)
if settings.STORAGE_CELERY:
setup_celery(app, settings=settings.STORAGE_CELERY)

setup_task_manager(app, celery_settings=settings.STORAGE_CELERY)
if not settings.STORAGE_WORKER_MODE:
setup_rabbitmq(app)
setup_rpc_routes(app)

setup_rpc_routes(app)
setup_rest_api_long_running_tasks_for_uploads(app)
setup_rest_api_routes(app, API_VTAG)
set_exception_handlers(app)