Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
3174045
feat: move redis client lifecycle to app server's one
giancarloromeo Aug 20, 2025
ab99e13
fix: remove unused param
giancarloromeo Aug 20, 2025
0d26e0c
fix: fake server Redis client's lifecycle
giancarloromeo Aug 20, 2025
f2b84eb
fix: typecheck
giancarloromeo Aug 20, 2025
064f8ac
fix: remove unuseful setup
giancarloromeo Aug 20, 2025
f8290e8
fix: celery task manager fixture
giancarloromeo Aug 20, 2025
8999602
fix: task manager property
giancarloromeo Aug 20, 2025
dd0a684
fix: typecheck
giancarloromeo Aug 20, 2025
a81aafc
fix: absolute import
giancarloromeo Aug 20, 2025
d50dbb4
tests: use in-memory Redis
giancarloromeo Aug 20, 2025
86bbe39
fix: rename
giancarloromeo Aug 20, 2025
70bf868
fix: shutdown
giancarloromeo Aug 21, 2025
1f626a2
fix: worker shutdown
giancarloromeo Aug 21, 2025
246d695
fix: use threads
giancarloromeo Aug 21, 2025
41446dc
fix: explicity stop worker
giancarloromeo Aug 21, 2025
ee52317
fix: shutdown
giancarloromeo Aug 21, 2025
97ded5f
fix: raise timeout
giancarloromeo Aug 21, 2025
abeab74
fix: force worker stop
giancarloromeo Aug 21, 2025
8d2c423
fix: remove rabbit
giancarloromeo Aug 21, 2025
1bb6dad
fix: rabbit
giancarloromeo Aug 21, 2025
9327579
fix: use separate Celery app
giancarloromeo Aug 22, 2025
5198a37
fix: add wait after submit
giancarloromeo Aug 22, 2025
f223f79
fix: test indent
giancarloromeo Aug 22, 2025
f347698
fix: loglevel
giancarloromeo Aug 22, 2025
585dc9f
fix: wait
giancarloromeo Aug 22, 2025
ac35b82
Merge branch 'master' into is8159/fix-redis-client-lifecycle
giancarloromeo Aug 22, 2025
ceae3c3
fix: remove partials
giancarloromeo Aug 22, 2025
81161e9
fix: remove unused asserts
giancarloromeo Aug 22, 2025
1e02fe7
fix: assert
giancarloromeo Aug 22, 2025
ad4cc8c
fix: remove partial
giancarloromeo Aug 22, 2025
46f5034
remove unused
giancarloromeo Aug 26, 2025
fec6e98
rename
giancarloromeo Aug 26, 2025
8d4cf9b
Merge branch 'master' into is8159/fix-redis-client-lifecycle
giancarloromeo Aug 26, 2025
8270bb1
Merge remote-tracking branch 'upstream/master' into is8159/fix-redis-…
giancarloromeo Aug 26, 2025
8125b5d
fix: remove unused
giancarloromeo Aug 26, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 0 additions & 23 deletions packages/celery-library/src/celery_library/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,9 @@
from typing import Any

from celery import Celery # type: ignore[import-untyped]
from servicelib.redis import RedisClientSDK
from settings_library.celery import CelerySettings
from settings_library.redis import RedisDatabase

from .backends._redis import RedisTaskInfoStore
from .task_manager import CeleryTaskManager


def _celery_configure(celery_settings: CelerySettings) -> dict[str, Any]:
base_config = {
Expand Down Expand Up @@ -36,22 +32,3 @@ def create_app(settings: CelerySettings) -> Celery:
),
**_celery_configure(settings),
)


async def create_task_manager(
app: Celery, settings: CelerySettings
) -> CeleryTaskManager:
redis_client_sdk = RedisClientSDK(
settings.CELERY_REDIS_RESULT_BACKEND.build_redis_dsn(
RedisDatabase.CELERY_TASKS
),
client_name="celery_tasks",
)
await redis_client_sdk.setup()
# GCR please address https://github.com/ITISFoundation/osparc-simcore/issues/8159

return CeleryTaskManager(
app,
settings,
RedisTaskInfoStore(redis_client_sdk),
)
18 changes: 4 additions & 14 deletions packages/celery-library/src/celery_library/signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,14 @@
from celery.worker.worker import WorkController # type: ignore[import-untyped]
from servicelib.celery.app_server import BaseAppServer
from servicelib.logging_utils import log_context
from settings_library.celery import CelerySettings

from .common import create_task_manager
from .utils import get_app_server, set_app_server

_logger = logging.getLogger(__name__)


def on_worker_init(
app_server: BaseAppServer,
celery_settings: CelerySettings,
sender: WorkController,
**_kwargs,
) -> None:
Expand All @@ -26,21 +23,14 @@ def _init(startup_complete_event: threading.Event) -> None:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)

async def _setup_task_manager():
assert sender.app # nosec
assert isinstance(sender.app, Celery) # nosec
assert sender.app # nosec
assert isinstance(sender.app, Celery) # nosec

app_server.task_manager = await create_task_manager(
sender.app,
celery_settings,
)

set_app_server(sender.app, app_server)
set_app_server(sender.app, app_server)

app_server.event_loop = loop

loop.run_until_complete(_setup_task_manager())
loop.run_until_complete(app_server.lifespan(startup_complete_event))
loop.run_until_complete(app_server.run_until_shutdown(startup_complete_event))

thread = threading.Thread(
group=None,
Expand Down
89 changes: 67 additions & 22 deletions packages/celery-library/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,30 @@
# pylint: disable=unused-argument

import datetime
import logging
import threading
from collections.abc import AsyncIterator, Callable
from functools import partial
from typing import Any

import pytest
from celery import Celery # type: ignore[import-untyped]
from celery.contrib.testing.worker import TestWorkController, start_worker
from celery.contrib.testing.worker import (
TestWorkController,
start_worker,
)
from celery.signals import worker_init, worker_shutdown
from celery.worker.worker import WorkController
from celery_library.common import create_task_manager
from celery_library.backends._redis import RedisTaskInfoStore
from celery_library.signals import on_worker_init, on_worker_shutdown
from celery_library.task_manager import CeleryTaskManager
from celery_library.types import register_celery_types
from pytest_simcore.helpers.monkeypatch_envs import setenvs_from_dict
from pytest_simcore.helpers.typing_env import EnvVarsDict
from servicelib.celery.app_server import BaseAppServer
from servicelib.celery.task_manager import TaskManager
from servicelib.redis import RedisClientSDK
from settings_library.celery import CelerySettings
from settings_library.redis import RedisSettings
from settings_library.redis import RedisDatabase, RedisSettings

pytest_plugins = [
"pytest_simcore.docker_compose",
Expand All @@ -33,11 +38,42 @@
]


_logger = logging.getLogger(__name__)


class FakeAppServer(BaseAppServer):
async def lifespan(self, startup_completed_event: threading.Event) -> None:
def __init__(self, app: Celery, settings: CelerySettings):
super().__init__(app)
self._settings = settings
self._task_manager: CeleryTaskManager | None = None

@property
def task_manager(self) -> TaskManager:
assert self._task_manager, "Task manager is not initialized"
return self._task_manager

async def run_until_shutdown(
self, startup_completed_event: threading.Event
) -> None:
redis_client_sdk = RedisClientSDK(
self._settings.CELERY_REDIS_RESULT_BACKEND.build_redis_dsn(
RedisDatabase.CELERY_TASKS
),
client_name="pytest_celery_tasks",
)
await redis_client_sdk.setup()

self._task_manager = CeleryTaskManager(
self._app,
self._settings,
RedisTaskInfoStore(redis_client_sdk),
)

startup_completed_event.set()
await self.shutdown_event.wait() # wait for shutdown

await redis_client_sdk.shutdown()


@pytest.fixture
def register_celery_tasks() -> Callable[[Celery], None]:
Expand All @@ -51,17 +87,12 @@ def _(celery_app: Celery) -> None: ...
@pytest.fixture
def app_environment(
monkeypatch: pytest.MonkeyPatch,
redis_service: RedisSettings,
env_devel_dict: EnvVarsDict,
) -> EnvVarsDict:
return setenvs_from_dict(
monkeypatch,
{
**env_devel_dict,
"REDIS_SECURE": redis_service.REDIS_SECURE,
"REDIS_HOST": redis_service.REDIS_HOST,
"REDIS_PORT": f"{redis_service.REDIS_PORT}",
"REDIS_PASSWORD": redis_service.REDIS_PASSWORD.get_secret_value(),
},
)

Expand All @@ -74,8 +105,8 @@ def celery_settings(


@pytest.fixture
def app_server() -> BaseAppServer:
return FakeAppServer(app=None)
def app_server(celery_app: Celery, celery_settings: CelerySettings) -> BaseAppServer:
return FakeAppServer(app=celery_app, settings=celery_settings)


@pytest.fixture(scope="session")
Expand All @@ -98,11 +129,10 @@ def celery_config() -> dict[str, Any]:
async def with_celery_worker(
celery_app: Celery,
app_server: BaseAppServer,
celery_settings: CelerySettings,
register_celery_tasks: Callable[[Celery], None],
) -> AsyncIterator[TestWorkController]:
def _on_worker_init_wrapper(sender: WorkController, **_kwargs):
return partial(on_worker_init, app_server, celery_settings)(sender, **_kwargs)
return on_worker_init(app_server, sender, **_kwargs)

worker_init.connect(_on_worker_init_wrapper)
worker_shutdown.connect(on_worker_shutdown)
Expand All @@ -111,24 +141,39 @@ def _on_worker_init_wrapper(sender: WorkController, **_kwargs):

with start_worker(
celery_app,
pool="threads",
concurrency=1,
pool="threads",
loglevel="info",
perform_ping_check=False,
queues="default",
) as worker:
yield worker


@pytest.fixture
async def mock_celery_app(celery_config: dict[str, Any]) -> Celery:
return Celery(**celery_config)


@pytest.fixture
async def celery_task_manager(
celery_app: Celery,
mock_celery_app: Celery,
celery_settings: CelerySettings,
with_celery_worker: TestWorkController,
) -> CeleryTaskManager:
use_in_memory_redis: RedisSettings,
) -> AsyncIterator[CeleryTaskManager]:
register_celery_types()

return await create_task_manager(
celery_app,
celery_settings,
)
try:
redis_client_sdk = RedisClientSDK(
use_in_memory_redis.build_redis_dsn(RedisDatabase.CELERY_TASKS),
client_name="pytest_celery_tasks",
)
await redis_client_sdk.setup()

yield CeleryTaskManager(
mock_celery_app,
celery_settings,
RedisTaskInfoStore(redis_client_sdk),
)
finally:
await redis_client_sdk.shutdown()
2 changes: 2 additions & 0 deletions packages/celery-library/tests/unit/test_async_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,8 @@ async def test_async_jobs_cancel(
payload=60 * 10, # test hangs if not cancelled properly
)

await asyncio.sleep(3.0) # wait a bit before cancelling

await async_jobs.cancel(
async_jobs_rabbitmq_rpc_client,
rpc_namespace=ASYNC_JOBS_RPC_NAMESPACE,
Expand Down
9 changes: 7 additions & 2 deletions packages/celery-library/tests/unit/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import pytest
from celery import Celery, Task
from celery.contrib.abortable import AbortableTask
from celery.worker.worker import WorkController
from celery_library.errors import TransferrableCeleryError
from celery_library.task import register_task
from celery_library.task_manager import CeleryTaskManager
Expand Down Expand Up @@ -92,6 +93,7 @@ def _(celery_app: Celery) -> None:

async def test_submitting_task_calling_async_function_results_with_success_state(
celery_task_manager: CeleryTaskManager,
with_celery_worker: WorkController,
):
task_filter = TaskFilter(user_id=42)

Expand Down Expand Up @@ -122,6 +124,7 @@ async def test_submitting_task_calling_async_function_results_with_success_state

async def test_submitting_task_with_failure_results_with_error(
celery_task_manager: CeleryTaskManager,
with_celery_worker: WorkController,
):
task_filter = TaskFilter(user_id=42)

Expand Down Expand Up @@ -150,6 +153,7 @@ async def test_submitting_task_with_failure_results_with_error(

async def test_cancelling_a_running_task_aborts_and_deletes(
celery_task_manager: CeleryTaskManager,
with_celery_worker: WorkController,
):
task_filter = TaskFilter(user_id=42)

Expand Down Expand Up @@ -182,6 +186,7 @@ async def test_cancelling_a_running_task_aborts_and_deletes(

async def test_listing_task_uuids_contains_submitted_task(
celery_task_manager: CeleryTaskManager,
with_celery_worker: WorkController,
):
task_filter = TaskFilter(user_id=42)

Expand All @@ -201,5 +206,5 @@ async def test_listing_task_uuids_contains_submitted_task(
tasks = await celery_task_manager.list_tasks(task_filter)
assert any(task.uuid == task_uuid for task in tasks)

tasks = await celery_task_manager.list_tasks(task_filter)
assert any(task.uuid == task_uuid for task in tasks)
tasks = await celery_task_manager.list_tasks(task_filter)
assert any(task.uuid == task_uuid for task in tasks)
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,12 @@ def shutdown_event(self) -> asyncio.Event:
return self._shutdown_event

@property
@abstractmethod
def task_manager(self) -> TaskManager:
return self._task_manager

@task_manager.setter
def task_manager(self, manager: TaskManager) -> None:
self._task_manager = manager
raise NotImplementedError

@abstractmethod
async def lifespan(
async def run_until_shutdown(
self,
startup_completed_event: threading.Event,
) -> None:
Expand Down
2 changes: 1 addition & 1 deletion packages/service-library/src/servicelib/celery/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ async def get_task_metadata(self, task_id: TaskID) -> TaskMetadata | None: ...

async def get_task_progress(self, task_id: TaskID) -> ProgressReport | None: ...

async def list_tasks(self, task_context: TaskFilter) -> list[Task]: ...
async def list_tasks(self, task_filter: TaskFilter) -> list[Task]: ...

async def remove_task(self, task_id: TaskID) -> None: ...

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Protocol
from typing import Any, Protocol, runtime_checkable

from models_library.progress_bar import ProgressReport

Expand All @@ -12,6 +12,7 @@
)


@runtime_checkable
class TaskManager(Protocol):
async def submit_task(
self, task_metadata: TaskMetadata, *, task_filter: TaskFilter, **task_param
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,24 @@
from fastapi import FastAPI

from ...celery.app_server import BaseAppServer
from ...celery.task_manager import TaskManager

_SHUTDOWN_TIMEOUT: Final[float] = datetime.timedelta(seconds=10).total_seconds()

_logger = logging.getLogger(__name__)


class FastAPIAppServer(BaseAppServer[FastAPI]):
def __init__(self, app: FastAPI):
super().__init__(app)
self._lifespan_manager: LifespanManager | None = None
@property
def task_manager(self) -> TaskManager:
task_manager = self.app.state.task_manager
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The order in which the app state is setup is very important and here I do not see how this is guaranteed. Can you please show me offline how the workflow works?

assert task_manager, "Task manager is not initialized" # nosec
assert isinstance(task_manager, TaskManager)
return task_manager

async def lifespan(self, startup_completed_event: threading.Event) -> None:
async def run_until_shutdown(
self, startup_completed_event: threading.Event
) -> None:
async with LifespanManager(
self.app,
startup_timeout=None, # waits for full app initialization (DB migrations, etc.)
Expand Down
Loading
Loading