diff --git a/.github/prompts/update-user-messages.prompt.md b/.github/prompts/update-user-messages.prompt.md
index 37eee04951b..29ababc2ef1 100644
--- a/.github/prompts/update-user-messages.prompt.md
+++ b/.github/prompts/update-user-messages.prompt.md
@@ -1,6 +1,7 @@
 ---
 mode: 'edit'
 description: 'Update user messages'
+model: Claude Sonnet 3.5
 ---
 
 This prompt guide is for updating user-facing messages in ${file} or ${selection}
@@ -43,7 +44,17 @@ When modifying user messages, follow **as close as possible** these rules:
    user_message("Unable to load project.", _version=1)
    ```
 
-3. **Message Style**: Follow **strictly** the guidelines in `${workspaceFolder}/docs/user-messages-guidelines.md`
+3. **Message Style**: Follow **STRICTLY ALL 10 GUIDELINES** in `${workspaceFolder}/docs/user-messages-guidelines.md`:
+   - Be Clear and Concise
+   - Provide Specific and Actionable Information
+   - Avoid Technical Jargon
+   - Use a Polite and Non-Blaming Tone
+   - Avoid Negative Words and Phrases
+   - Place Messages Appropriately
+   - Use Inline Validation When Possible
+   - Avoid Using All-Caps and Excessive Punctuation
+   - **Use Humor Sparingly** - Avoid casual phrases like "Oops!", "Whoops!", or overly informal language
+   - Offer Alternative Solutions or Support
 
 4. **Preserve Context**: Ensure the modified message conveys the same meaning and context as the original.
 
@@ -56,8 +67,10 @@ When modifying user messages, follow **as close as possible** these rules:
    # After
    user_message("Your session has expired. Please log in again.", _version=3)
    ```
+
 6. **Replace 'Study' by 'Project'**: If the message contains the word 'Study', replace it with 'Project' to align with our terminology.
 
+7. **Professional Tone**: Maintain a professional, helpful tone. Avoid humor, casual expressions, or overly informal language that might not be appropriate for all users or situations.
 
 ## Examples
 
@@ -91,4 +104,14 @@ return HttpErrorInfo(status.HTTP_404_NOT_FOUND, user_message("User not found.",
 return HttpErrorInfo(status.HTTP_404_NOT_FOUND, user_message("The requested user could not be found.", _version=2))
 ```
 
-Remember: The goal is to improve clarity and helpfulness for end-users while maintaining accurate versioning for tracking changes.
+### Example 4: Removing Humor (Guideline 9)
+
+```python
+# Before
+user_message("Oops! Something went wrong, but we've noted it down and we'll sort it out ASAP. Thanks for your patience!")
+
+# After
+user_message("Something went wrong on our end. We've been notified and will resolve this issue as soon as possible. Thank you for your patience.", _version=1)
+```
+
+Remember: The goal is to improve clarity and helpfulness for end-users while maintaining accurate versioning for tracking changes. 
**Always check that your updated messages comply with ALL 10 guidelines, especially avoiding humor and maintaining a professional tone.** diff --git a/packages/celery-library/tests/unit/test_tasks.py b/packages/celery-library/tests/unit/test_tasks.py index a4edfb7540a..35da31aa180 100644 --- a/packages/celery-library/tests/unit/test_tasks.py +++ b/packages/celery-library/tests/unit/test_tasks.py @@ -11,8 +11,8 @@ from random import randint import pytest -from celery import Celery, Task -from celery.contrib.abortable import AbortableTask +from celery import Celery, Task # pylint: disable=no-name-in-module +from celery.contrib.abortable import AbortableTask # pylint: disable=no-name-in-module from celery_library.errors import TransferrableCeleryError from celery_library.task import register_task from celery_library.task_manager import CeleryTaskManager diff --git a/packages/models-library/src/models_library/api_schemas_long_running_tasks/base.py b/packages/models-library/src/models_library/api_schemas_long_running_tasks/base.py index d6e132c5361..38f2fa2f926 100644 --- a/packages/models-library/src/models_library/api_schemas_long_running_tasks/base.py +++ b/packages/models-library/src/models_library/api_schemas_long_running_tasks/base.py @@ -2,7 +2,8 @@ from collections.abc import Awaitable, Callable from typing import Annotated, TypeAlias -from pydantic import BaseModel, Field, field_validator, validate_call +from pydantic import BaseModel, ConfigDict, Field, field_validator, validate_call +from pydantic.config import JsonDict _logger = logging.getLogger(__name__) @@ -23,6 +24,22 @@ class TaskProgress(BaseModel): message: ProgressMessage = "" percent: ProgressPercent = 0.0 + @staticmethod + def _update_json_schema_extra(schema: JsonDict) -> None: + schema.update( + { + "examples": [ + { + "task_id": "3ac48b54-a48d-4c5e-a6ac-dcaddb9eaa59", + "message": "Halfway done", + "percent": 0.5, + } + ] + } + ) + + model_config = ConfigDict(json_schema_extra=_update_json_schema_extra) + # used to propagate progress updates internally _update_callback: Callable[["TaskProgress"], Awaitable[None]] | None = None diff --git a/packages/models-library/src/models_library/api_schemas_webserver/functions.py b/packages/models-library/src/models_library/api_schemas_webserver/functions.py index 66193388074..226db44f68d 100644 --- a/packages/models-library/src/models_library/api_schemas_webserver/functions.py +++ b/packages/models-library/src/models_library/api_schemas_webserver/functions.py @@ -1,7 +1,7 @@ import datetime from typing import Annotated, TypeAlias -from pydantic import Field, HttpUrl +from pydantic import ConfigDict, Field, HttpUrl from ..functions import ( Function, @@ -141,6 +141,45 @@ class RegisteredProjectFunctionGet(RegisteredProjectFunction, OutputSchema): modified_at: Annotated[datetime.datetime, Field(alias="lastChangeDate")] access_rights: dict[GroupID, FunctionGroupAccessRightsGet] thumbnail: HttpUrl | None = None + model_config = ConfigDict( + populate_by_name=True, + json_schema_extra={ + "examples": [ + { + "function_class": "PROJECT", + "title": "Example Project Function", + "description": "This is an example project function.", + "input_schema": { + "schema_content": { + "type": "object", + "properties": {"input1": {"type": "integer"}}, + }, + "schema_class": "application/schema+json", + }, + "output_schema": { + "schema_content": { + "type": "object", + "properties": {"output1": {"type": "string"}}, + }, + "schema_class": "application/schema+json", + }, + "default_inputs": None, + 
"project_id": "11111111-1111-1111-1111-111111111111", + "uid": "22222222-2222-2222-2222-222222222222", + "created_at": "2024-01-01T12:00:00", + "modified_at": "2024-01-02T12:00:00", + "access_rights": { + "5": { + "read": True, + "write": False, + "execute": True, + } + }, + "thumbnail": None, + }, + ] + }, + ) class SolverFunctionToRegister(SolverFunction, InputSchema): ... diff --git a/packages/models-library/src/models_library/functions.py b/packages/models-library/src/models_library/functions.py index 31434106c49..df99fcc4189 100644 --- a/packages/models-library/src/models_library/functions.py +++ b/packages/models-library/src/models_library/functions.py @@ -114,7 +114,37 @@ class ProjectFunction(FunctionBase): class RegisteredProjectFunction(ProjectFunction, RegisteredFunctionBase): - pass + model_config = ConfigDict( + populate_by_name=True, + json_schema_extra={ + "examples": [ + { + "function_class": "PROJECT", + "title": "Example Project Function", + "description": "This is an example project function.", + "input_schema": { + "schema_content": { + "type": "object", + "properties": {"input1": {"type": "integer"}}, + }, + "schema_class": "application/schema+json", + }, + "output_schema": { + "schema_content": { + "type": "object", + "properties": {"output1": {"type": "string"}}, + }, + "schema_class": "application/schema+json", + }, + "default_inputs": None, + "project_id": "11111111-1111-1111-1111-111111111111", + "uid": "22222222-2222-2222-2222-222222222222", + "created_at": "2024-01-01T12:00:00", + "modified_at": "2024-01-02T12:00:00", + }, + ] + }, + ) SolverJobID: TypeAlias = UUID diff --git a/packages/models-library/src/models_library/progress_bar.py b/packages/models-library/src/models_library/progress_bar.py index ad8130570e5..21fb158a0eb 100644 --- a/packages/models-library/src/models_library/progress_bar.py +++ b/packages/models-library/src/models_library/progress_bar.py @@ -1,6 +1,7 @@ from typing import Literal, TypeAlias from pydantic import BaseModel, ConfigDict +from pydantic.config import JsonDict # NOTE: keep a list of possible unit, and please use correct official unit names ProgressUnit: TypeAlias = Literal["Byte"] @@ -13,9 +14,10 @@ class ProgressStructuredMessage(BaseModel): unit: str | None = None sub: "ProgressStructuredMessage | None" = None - model_config = ConfigDict( - json_schema_extra={ - "examples": [ + @staticmethod + def _update_json_schema_extra(schema: JsonDict) -> None: + schema.update( + examples=[ { "description": "some description", "current": 12.2, @@ -39,8 +41,9 @@ class ProgressStructuredMessage(BaseModel): }, }, ] - } - ) + ) + + model_config = ConfigDict(json_schema_extra=_update_json_schema_extra) UNITLESS = None @@ -96,7 +99,17 @@ def composed_message(self) -> str: { "actual_value": 0.3, "total": 1.0, - "message": ProgressStructuredMessage.model_config["json_schema_extra"]["examples"][2], # type: ignore [index] + "message": { + "description": "downloading", + "current": 2.0, + "total": 5, + "sub": { + "description": "port 2", + "current": 12.2, + "total": 123, + "unit": "Byte", + }, + }, }, ] }, diff --git a/packages/pytest-simcore/src/pytest_simcore/celery_library_mocks.py b/packages/pytest-simcore/src/pytest_simcore/celery_library_mocks.py new file mode 100644 index 00000000000..c027bc0cbd4 --- /dev/null +++ b/packages/pytest-simcore/src/pytest_simcore/celery_library_mocks.py @@ -0,0 +1,96 @@ +# pylint: disable=redefined-outer-name + +from collections.abc import Callable + +import pytest +from faker import Faker +from pytest_mock import 
MockerFixture, MockType +from servicelib.celery.models import TaskStatus, TaskUUID +from servicelib.celery.task_manager import Task, TaskManager + +_faker = Faker() + + +@pytest.fixture +def submit_task_return_value() -> TaskUUID: + return TaskUUID(_faker.uuid4()) + + +@pytest.fixture +def cancel_task_return_value() -> None: + return None + + +@pytest.fixture +def get_task_result_return_value() -> dict: + return {"result": "example"} + + +@pytest.fixture +def get_task_status_return_value() -> TaskStatus: + example = TaskStatus.model_json_schema()["examples"][0] + return TaskStatus.model_validate(example) + + +@pytest.fixture +def list_tasks_return_value() -> list[Task]: + examples = Task.model_json_schema()["examples"] + assert len(examples) > 0 + return [Task.model_validate(example) for example in examples] + + +@pytest.fixture +def set_task_progress_return_value() -> None: + return None + + +@pytest.fixture +def mock_task_manager_object( + mocker: MockerFixture, + submit_task_return_value: TaskUUID, + cancel_task_return_value: None, + get_task_result_return_value: dict, + get_task_status_return_value: TaskStatus, + list_tasks_return_value: list[Task], + set_task_progress_return_value: None, +) -> MockType: + """ + Returns a TaskManager mock with overridable return values for each method. + If a return value is an Exception, the method will raise it. + """ + mock = mocker.Mock(spec=TaskManager) + + def _set_return_or_raise(method, value): + if isinstance(value, Exception): + method.side_effect = lambda *a, **kw: (_ for _ in ()).throw(value) + else: + method.return_value = value + + _set_return_or_raise(mock.submit_task, submit_task_return_value) + _set_return_or_raise(mock.cancel_task, cancel_task_return_value) + _set_return_or_raise(mock.get_task_result, get_task_result_return_value) + _set_return_or_raise(mock.get_task_status, get_task_status_return_value) + _set_return_or_raise(mock.list_tasks, list_tasks_return_value) + _set_return_or_raise(mock.set_task_progress, set_task_progress_return_value) + return mock + + +@pytest.fixture +def mock_task_manager_object_raising_factory( + mocker: MockerFixture, +) -> Callable[[Exception], MockType]: + def _factory(task_manager_exception: Exception) -> MockType: + mock = mocker.Mock(spec=TaskManager) + + def _raise_exc(*args, **kwargs): + raise task_manager_exception + + mock.submit_task.side_effect = _raise_exc + mock.cancel_task.side_effect = _raise_exc + mock.get_task_result.side_effect = _raise_exc + mock.get_task_status.side_effect = _raise_exc + mock.list_tasks.side_effect = _raise_exc + mock.set_task_progress.side_effect = _raise_exc + return mock + + return _factory diff --git a/packages/pytest-simcore/src/pytest_simcore/simcore_services.py b/packages/pytest-simcore/src/pytest_simcore/simcore_services.py index 77c607cbb09..274a8edb44a 100644 --- a/packages/pytest-simcore/src/pytest_simcore/simcore_services.py +++ b/packages/pytest-simcore/src/pytest_simcore/simcore_services.py @@ -29,6 +29,7 @@ _SERVICES_TO_SKIP: Final[set[str]] = { + "api-worker", "agent", # global mode deploy (NO exposed ports, has http API) "dask-sidecar", # global mode deploy (NO exposed ports, **NO** http API) "migration", diff --git a/packages/service-library/src/servicelib/celery/models.py b/packages/service-library/src/servicelib/celery/models.py index 40756553377..0c46e1716b1 100644 --- a/packages/service-library/src/servicelib/celery/models.py +++ b/packages/service-library/src/servicelib/celery/models.py @@ -4,7 +4,8 @@ from uuid import UUID from 
models_library.progress_bar import ProgressReport
-from pydantic import BaseModel, StringConstraints
+from pydantic import BaseModel, ConfigDict, StringConstraints
+from pydantic.config import JsonDict
 
 TaskID: TypeAlias = str
 TaskName: TypeAlias = Annotated[
@@ -28,6 +29,7 @@ class TaskState(StrEnum):
 class TasksQueue(StrEnum):
     CPU_BOUND = "cpu_bound"
     DEFAULT = "default"
+    API_WORKER_QUEUE = "api_worker_queue"
 
 
 class TaskMetadata(BaseModel):
@@ -40,6 +42,41 @@ class Task(BaseModel):
     uuid: TaskUUID
     metadata: TaskMetadata
 
+    @staticmethod
+    def _update_json_schema_extra(schema: JsonDict) -> None:
+        schema.update(
+            {
+                "examples": [
+                    {
+                        "uuid": "123e4567-e89b-12d3-a456-426614174000",
+                        "metadata": {
+                            "name": "task1",
+                            "ephemeral": True,
+                            "queue": "default",
+                        },
+                    },
+                    {
+                        "uuid": "223e4567-e89b-12d3-a456-426614174001",
+                        "metadata": {
+                            "name": "task2",
+                            "ephemeral": False,
+                            "queue": "cpu_bound",
+                        },
+                    },
+                    {
+                        "uuid": "323e4567-e89b-12d3-a456-426614174002",
+                        "metadata": {
+                            "name": "task3",
+                            "ephemeral": True,
+                            "queue": "default",
+                        },
+                    },
+                ]
+            }
+        )
+
+    model_config = ConfigDict(json_schema_extra=_update_json_schema_extra)
+
 
 _TASK_DONE = {TaskState.SUCCESS, TaskState.FAILURE, TaskState.ABORTED}
 
@@ -72,6 +109,33 @@ class TaskStatus(BaseModel):
     task_state: TaskState
     progress_report: ProgressReport
 
+    @staticmethod
+    def _update_json_schema_extra(schema: JsonDict) -> None:
+
+        schema.update(
+            {
+                "examples": [
+                    {
+                        "task_uuid": "123e4567-e89b-12d3-a456-426614174000",
+                        "task_state": "SUCCESS",
+                        "progress_report": {
+                            "actual_value": 0.5,
+                            "total": 1.0,
+                            "attempts": 1,
+                            "unit": "Byte",
+                            "message": {
+                                "description": "some description",
+                                "current": 12.2,
+                                "total": 123,
+                            },
+                        },
+                    }
+                ]
+            }
+        )
+
+    model_config = ConfigDict(json_schema_extra=_update_json_schema_extra)
+
     @property
     def is_done(self) -> bool:
         return self.task_state in _TASK_DONE
diff --git a/packages/settings-library/src/settings_library/postgres.py b/packages/settings-library/src/settings_library/postgres.py
index 325a3288414..90d456cbda0 100644
--- a/packages/settings-library/src/settings_library/postgres.py
+++ b/packages/settings-library/src/settings_library/postgres.py
@@ -1,5 +1,5 @@
 from functools import cached_property
-from typing import Annotated
+from typing import Annotated, Self
 from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse
 
 from pydantic import (
@@ -7,8 +7,7 @@
     Field,
     PostgresDsn,
     SecretStr,
-    ValidationInfo,
-    field_validator,
+    model_validator,
 )
 from pydantic.config import JsonDict
 from pydantic_settings import SettingsConfigDict
@@ -50,13 +49,15 @@ class PostgresSettings(BaseCustomSettings):
         ),
     ] = None
 
-    @field_validator("POSTGRES_MAXSIZE")
-    @classmethod
-    def _check_size(cls, v, info: ValidationInfo):
-        if info.data["POSTGRES_MINSIZE"] > v:
-            msg = f"assert POSTGRES_MINSIZE={info.data['POSTGRES_MINSIZE']} <= POSTGRES_MAXSIZE={v}"
+    @model_validator(mode="after")
+    def validate_postgres_sizes(self) -> Self:
+        if self.POSTGRES_MINSIZE > self.POSTGRES_MAXSIZE:
+            msg = (
+                f"assert POSTGRES_MINSIZE={self.POSTGRES_MINSIZE} <= "
+                f"POSTGRES_MAXSIZE={self.POSTGRES_MAXSIZE}"
+            )
             raise ValueError(msg)
-        return v
+        return self
 
     @cached_property
     def dsn(self) -> str:
diff --git a/services/api-server/.env-devel b/services/api-server/.env-devel
index 29d4830d47f..a18401e3a5b 100644 --- a/services/api-server/.env-devel +++ b/services/api-server/.env-devel @@ -28,9 +28,12 @@ POSTGRES_PASSWORD=test POSTGRES_DB=test POSTGRES_HOST=127.0.0.1 -# Enables debug -SC_BOOT_MODE=debug - +# rabbit +RABBIT_HOST=rabbit +RABBIT_PASSWORD=adminadmin +RABBIT_PORT=5672 +RABBIT_SECURE=false +RABBIT_USER=admin # webserver WEBSERVER_HOST=webserver diff --git a/services/api-server/Makefile b/services/api-server/Makefile index 4db8527326b..555c88f6ec3 100644 --- a/services/api-server/Makefile +++ b/services/api-server/Makefile @@ -28,7 +28,9 @@ reqs: ## compiles pip requirements (.in -> .txt) define _create_and_validate_openapi # generating openapi specs file under $< (NOTE: Skips DEV FEATURES since this OAS is the 'offically released'!) - @source .env; \ + set -o allexport; \ + source .env; \ + set +o allexport; \ export API_SERVER_DEV_FEATURES_ENABLED=$1; \ python3 -c "import json; from $(APP_PACKAGE_NAME).main import *; print( json.dumps(app_factory().openapi(), indent=2) )" > $@ diff --git a/services/api-server/docker/boot.sh b/services/api-server/docker/boot.sh index 0f19b262c78..227be9c56b9 100755 --- a/services/api-server/docker/boot.sh +++ b/services/api-server/docker/boot.sh @@ -39,23 +39,51 @@ APP_LOG_LEVEL=${API_SERVER_LOGLEVEL:-${LOG_LEVEL:-${LOGLEVEL:-INFO}}} SERVER_LOG_LEVEL=$(echo "${APP_LOG_LEVEL}" | tr '[:upper:]' '[:lower:]') echo "$INFO" "Log-level app/server: $APP_LOG_LEVEL/$SERVER_LOG_LEVEL" -if [ "${SC_BOOT_MODE}" = "debug" ]; then - reload_dir_packages=$(fdfind src /devel/packages --exec echo '--reload-dir {} ' | tr '\n' ' ') +if [ "${API_SERVER_WORKER_MODE}" = "true" ]; then + if [ "${SC_BOOT_MODE}" = "debug" ]; then + exec watchmedo auto-restart \ + --directory /devel/packages \ + --directory services/api-server \ + --pattern "*.py" \ + --recursive \ + -- \ + celery \ + --app=boot_celery_worker:app \ + --workdir=services/api-server/docker \ + worker --pool=threads \ + --loglevel="${API_SERVER_LOGLEVEL}" \ + --concurrency="${CELERY_CONCURRENCY}" \ + --hostname="${API_SERVER_WORKER_NAME}" \ + --queues="${CELERY_QUEUES:-default}" + else + exec celery \ + --app=boot_celery_worker:app \ + --workdir=services/api-server/docker \ + worker --pool=threads \ + --loglevel="${API_SERVER_LOGLEVEL}" \ + --concurrency="${CELERY_CONCURRENCY}" \ + --hostname="${API_SERVER_WORKER_NAME}" \ + --queues="${CELERY_QUEUES:-default}" + fi +else + if [ "${SC_BOOT_MODE}" = "debug" ]; then + reload_dir_packages=$(fdfind src /devel/packages --exec echo '--reload-dir {} ' | tr '\n' ' ') - exec sh -c " - cd services/api-server/src/simcore_service_api_server && \ - python -Xfrozen_modules=off -m debugpy --listen 0.0.0.0:${API_SERVER_REMOTE_DEBUG_PORT} -m \ - uvicorn \ - --factory main:app_factory \ + exec sh -c " + cd services/api-server/src/simcore_service_api_server && \ + python -Xfrozen_modules=off -m debugpy --listen 0.0.0.0:${API_SERVER_REMOTE_DEBUG_PORT} -m \ + uvicorn \ + --factory main:app_factory \ + --host 0.0.0.0 \ + --reload \ + $reload_dir_packages \ + --reload-dir . \ + --log-level \"${SERVER_LOG_LEVEL}\" + " + else + exec uvicorn \ + --factory simcore_service_api_server.main:app_factory \ --host 0.0.0.0 \ - --reload \ - $reload_dir_packages \ - --reload-dir . 
\ - --log-level \"${SERVER_LOG_LEVEL}\" - " -else - exec uvicorn \ - --factory simcore_service_api_server.main:app_factory \ - --host 0.0.0.0 \ - --log-level "${SERVER_LOG_LEVEL}" + --log-level "${SERVER_LOG_LEVEL}" + fi fi diff --git a/services/api-server/docker/boot_celery_worker.py b/services/api-server/docker/boot_celery_worker.py new file mode 100644 index 00000000000..e0c7e119ced --- /dev/null +++ b/services/api-server/docker/boot_celery_worker.py @@ -0,0 +1,13 @@ +from celery.signals import worker_init, worker_shutdown # type: ignore[import-untyped] +from celery_library.signals import ( + on_worker_shutdown, +) +from simcore_service_api_server.celery_worker.worker_main import ( + get_app, + worker_init_wrapper, +) + +app = get_app() + +worker_init.connect(worker_init_wrapper) +worker_shutdown.connect(on_worker_shutdown) diff --git a/services/api-server/docker/healthcheck.py b/services/api-server/docker/healthcheck.py index 808782f3261..66ba806d0db 100755 --- a/services/api-server/docker/healthcheck.py +++ b/services/api-server/docker/healthcheck.py @@ -18,18 +18,49 @@ """ import os +import subprocess import sys from urllib.request import urlopen +from simcore_service_api_server.core.settings import ApplicationSettings + SUCCESS, UNHEALTHY = 0, 1 # Disabled if boots with debugger ok = os.environ.get("SC_BOOT_MODE", "").lower() == "debug" +app_settings = ApplicationSettings.create_from_envs() + + +def _is_celery_worker_healthy(): + assert app_settings.API_SERVER_CELERY + broker_url = app_settings.API_SERVER_CELERY.CELERY_RABBIT_BROKER.dsn + + try: + result = subprocess.run( + [ + "celery", + "--broker", + broker_url, + "inspect", + "ping", + "--destination", + "celery@" + os.getenv("API_SERVER_WORKER_NAME", "worker"), + ], + capture_output=True, + text=True, + check=True, + ) + return "pong" in result.stdout + except subprocess.CalledProcessError: + return False + + # Queries host # pylint: disable=consider-using-with ok = ( ok + or (app_settings.API_SERVER_WORKER_MODE and _is_celery_worker_healthy()) or urlopen( "{host}{baseurl}".format( host=sys.argv[1], baseurl=os.environ.get("SIMCORE_NODE_BASEPATH", "") diff --git a/services/api-server/requirements/_test.in b/services/api-server/requirements/_test.in index 432cb44b7f5..067ce3a7305 100644 --- a/services/api-server/requirements/_test.in +++ b/services/api-server/requirements/_test.in @@ -24,6 +24,7 @@ pact-python pyinstrument pytest pytest-asyncio +pytest-celery pytest-cov pytest-docker pytest-mock diff --git a/services/api-server/requirements/_test.txt b/services/api-server/requirements/_test.txt index 6711b99216a..d1909ce4a50 100644 --- a/services/api-server/requirements/_test.txt +++ b/services/api-server/requirements/_test.txt @@ -17,6 +17,10 @@ alembic==1.14.0 # via # -c requirements/_base.txt # -r requirements/_test.in +amqp==5.3.1 + # via + # -c requirements/_base.txt + # kombu annotated-types==0.7.0 # via # -c requirements/_base.txt @@ -45,6 +49,10 @@ aws-sam-translator==1.55.0 # cfn-lint aws-xray-sdk==2.14.0 # via moto +billiard==4.2.1 + # via + # -c requirements/_base.txt + # celery boto3==1.38.1 # via # aws-sam-translator @@ -57,6 +65,10 @@ botocore==1.38.1 # s3transfer botocore-stubs==1.37.4 # via types-boto3 +celery==5.5.3 + # via + # -c requirements/_base.txt + # pytest-celery certifi==2024.8.30 # via # -c requirements/../../../requirements/constraints.txt @@ -81,9 +93,25 @@ click==8.2.1 # via # -c requirements/_base.txt # -r requirements/_test.in + # celery + # click-didyoumean + # click-plugins + # click-repl # flask 
# pact-python # uvicorn +click-didyoumean==0.3.1 + # via + # -c requirements/_base.txt + # celery +click-plugins==1.1.1.2 + # via + # -c requirements/_base.txt + # celery +click-repl==0.3.0 + # via + # -c requirements/_base.txt + # celery coverage==7.6.12 # via pytest-cov cryptography==44.0.0 @@ -93,10 +121,14 @@ cryptography==44.0.0 # moto # python-jose # sshpubkeys +debugpy==1.8.16 + # via pytest-celery docker==7.1.0 # via # -r requirements/_test.in # moto + # pytest-celery + # pytest-docker-tools ecdsa==0.19.0 # via # moto @@ -185,6 +217,10 @@ jsonschema==3.2.0 # openapi-spec-validator junit-xml==1.9 # via cfn-lint +kombu==5.5.4 + # via + # -c requirements/_base.txt + # celery mako==1.3.10 # via # -c requirements/../../../requirements/constraints.txt @@ -219,6 +255,7 @@ packaging==24.2 # via # -c requirements/_base.txt # aioresponses + # kombu # pytest pact-python==2.3.1 # via -r requirements/_test.in @@ -232,6 +269,10 @@ pluggy==1.5.0 # via # pytest # pytest-cov +prompt-toolkit==3.0.51 + # via + # -c requirements/_base.txt + # click-repl propcache==0.2.1 # via # -c requirements/_base.txt @@ -241,6 +282,7 @@ psutil==6.1.0 # via # -c requirements/_base.txt # pact-python + # pytest-celery pyasn1==0.4.8 # via # python-jose @@ -278,13 +320,18 @@ pytest==8.4.1 # pytest-asyncio # pytest-cov # pytest-docker + # pytest-docker-tools # pytest-mock pytest-asyncio==1.0.0 # via -r requirements/_test.in +pytest-celery==1.1.3 + # via -r requirements/_test.in pytest-cov==6.2.1 # via -r requirements/_test.in pytest-docker==3.2.3 # via -r requirements/_test.in +pytest-docker-tools==3.1.9 + # via pytest-celery pytest-mock==3.14.1 # via -r requirements/_test.in pytest-runner==6.0.1 @@ -293,6 +340,7 @@ python-dateutil==2.9.0.post0 # via # -c requirements/_base.txt # botocore + # celery # moto python-jose==3.4.0 # via moto @@ -344,6 +392,7 @@ setuptools==80.9.0 # moto # openapi-spec-validator # pbr + # pytest-celery six==1.17.0 # via # -c requirements/_base.txt @@ -375,6 +424,10 @@ starlette==0.47.2 # -c requirements/../../../requirements/constraints.txt # -c requirements/_base.txt # fastapi +tenacity==9.0.0 + # via + # -c requirements/_base.txt + # pytest-celery types-aiofiles==24.1.0.20241221 # via -r requirements/_test.in types-awscrt==0.23.10 @@ -406,6 +459,7 @@ tzdata==2025.2 # via # -c requirements/_base.txt # faker + # kombu urllib3==2.5.0 # via # -c requirements/../../../requirements/constraints.txt @@ -418,6 +472,16 @@ uvicorn==0.34.2 # via # -c requirements/_base.txt # pact-python +vine==5.1.0 + # via + # -c requirements/_base.txt + # amqp + # celery + # kombu +wcwidth==0.2.13 + # via + # -c requirements/_base.txt + # prompt-toolkit werkzeug==2.1.2 # via # flask diff --git a/services/api-server/requirements/ci.txt b/services/api-server/requirements/ci.txt index cc1799cee07..9d4fff8972a 100644 --- a/services/api-server/requirements/ci.txt +++ b/services/api-server/requirements/ci.txt @@ -12,6 +12,7 @@ --requirement _tools.txt # installs this repo's packages +simcore-celery-library @ ../../packages/celery-library/ simcore-common-library @ ../../packages/common-library simcore-models-library @ ../../packages/models-library simcore-postgres-database @ ../../packages/postgres-database/ diff --git a/services/api-server/requirements/dev.txt b/services/api-server/requirements/dev.txt index 5afc552d753..85f3f1c428e 100644 --- a/services/api-server/requirements/dev.txt +++ b/services/api-server/requirements/dev.txt @@ -12,6 +12,7 @@ --requirement _tools.txt # installs this repo's packages +--editable 
../../packages/celery-library/
 --editable ../../packages/common-library
 --editable ../../packages/models-library
 --editable ../../packages/postgres-database
diff --git a/services/api-server/requirements/prod.txt b/services/api-server/requirements/prod.txt
index 9d4d747507e..9df71af1a63 100644
--- a/services/api-server/requirements/prod.txt
+++ b/services/api-server/requirements/prod.txt
@@ -10,6 +10,7 @@
 --requirement _base.txt
 
 # installs this repo's packages
+simcore-celery-library @ ../../packages/celery-library/
 simcore-models-library @ ../../packages/models-library
 simcore-common-library @ ../../packages/common-library/
 simcore-postgres-database @ ../../packages/postgres-database/
diff --git a/services/api-server/src/simcore_service_api_server/_constants.py b/services/api-server/src/simcore_service_api_server/_constants.py
index 7bfbfd43907..512a987b640 100644
--- a/services/api-server/src/simcore_service_api_server/_constants.py
+++ b/services/api-server/src/simcore_service_api_server/_constants.py
@@ -1,9 +1,12 @@
 from typing import Final
 
-MSG_BACKEND_SERVICE_UNAVAILABLE: Final[
-    str
-] = "backend service is disabled or unreachable"
+from common_library.user_messages import user_message
 
-MSG_INTERNAL_ERROR_USER_FRIENDLY_TEMPLATE: Final[
-    str
-] = "Oops! Something went wrong, but we've noted it down and we'll sort it out ASAP. Thanks for your patience!"
+MSG_BACKEND_SERVICE_UNAVAILABLE: Final[str] = user_message(
+    "The service is currently unavailable. Please try again later.", _version=1
+)
+
+MSG_INTERNAL_ERROR_USER_FRIENDLY_TEMPLATE: Final[str] = user_message(
+    "Something went wrong on our end. We've been notified and will resolve this issue as soon as possible. Thank you for your patience.",
+    _version=2,
+)
diff --git a/services/api-server/src/simcore_service_api_server/_service_function_jobs.py b/services/api-server/src/simcore_service_api_server/_service_function_jobs.py
index ab88ea5d101..4038ab9b6fc 100644
--- a/services/api-server/src/simcore_service_api_server/_service_function_jobs.py
+++ b/services/api-server/src/simcore_service_api_server/_service_function_jobs.py
@@ -1,4 +1,5 @@
 from dataclasses import dataclass
+from typing import overload
 
 import jsonschema
 from common_library.exclude import as_dict_exclude_none
@@ -6,8 +7,6 @@
     FunctionClass,
     FunctionID,
     FunctionInputs,
-    FunctionInputsList,
-    FunctionJobCollection,
     FunctionJobCollectionID,
     FunctionJobID,
     FunctionJobStatus,
@@ -15,8 +14,12 @@
     ProjectFunctionJob,
     RegisteredFunction,
     RegisteredFunctionJob,
-    RegisteredFunctionJobCollection,
+    RegisteredFunctionJobPatch,
+    RegisteredProjectFunctionJobPatch,
+    RegisteredSolverFunctionJobPatch,
     SolverFunctionJob,
+    SolverJobID,
+    TaskID,
 )
 from models_library.functions_errors import (
     FunctionExecuteAccessDeniedError,
@@ -35,7 +38,12 @@
 from pydantic import ValidationError
 
 from ._service_jobs import JobService
+from .exceptions.function_errors import (
+    FunctionJobCacheNotFoundError,
+    FunctionJobProjectMissingError,
+)
 from .models.api_resources import JobLinks
+from .models.domain.functions import PreRegisteredFunctionJobData
 from .models.schemas.jobs import (
     JobInputs,
     JobPricingSpecification,
@@ -120,7 +128,7 @@ async def validate_function_inputs(
     async def inspect_function_job(
         self, function: RegisteredFunction, function_job: RegisteredFunctionJob
     ) -> FunctionJobStatus:
-
+        """Raises FunctionJobProjectMissingError if no project is associated with job"""
         stored_job_status = await self._web_rpc_client.get_function_job_status(
function_job_id=function_job.uid, user_id=self.user_id, @@ -134,14 +142,16 @@ async def inspect_function_job( function.function_class == FunctionClass.PROJECT and function_job.function_class == FunctionClass.PROJECT ): - assert function_job.project_job_id is not None # nosec + if function_job.project_job_id is None: + raise FunctionJobProjectMissingError() job_status = await self._job_service.inspect_study_job( job_id=function_job.project_job_id, ) elif (function.function_class == FunctionClass.SOLVER) and ( function_job.function_class == FunctionClass.SOLVER ): - assert function_job.solver_job_id is not None # nosec + if function_job.solver_job_id is None: + raise FunctionJobProjectMissingError() job_status = await self._job_service.inspect_solver_job( solver_key=function.solver_key, version=function.solver_version, @@ -162,16 +172,33 @@ async def inspect_function_job( job_status=new_job_status, ) - async def run_function( + async def create_function_job_inputs( # pylint: disable=no-self-use self, *, function: RegisteredFunction, function_inputs: FunctionInputs, - pricing_spec: JobPricingSpecification | None, - job_links: JobLinks, - x_simcore_parent_project_uuid: NodeID | None, - x_simcore_parent_node_id: NodeID | None, + ) -> JobInputs: + joined_inputs = join_inputs( + function.default_inputs, + function_inputs, + ) + return JobInputs( + values=joined_inputs or {}, + ) + + async def get_cached_function_job( + self, + *, + function: RegisteredFunction, + job_inputs: JobInputs, ) -> RegisteredFunctionJob: + """ + N.B. this function checks access rights + + raises FunctionsExecuteApiAccessDeniedError if user cannot execute functions + raises FunctionJobCacheNotFoundError if no cached job is found + + """ user_api_access_rights = ( await self._web_rpc_client.get_functions_user_api_access_rights( @@ -195,22 +222,9 @@ async def run_function( function_id=function.uid, ) - joined_inputs = join_inputs( - function.default_inputs, - function_inputs, - ) - - if function.input_schema is not None: - is_valid, validation_str = await self.validate_function_inputs( - function_id=function.uid, - inputs=joined_inputs, - ) - if not is_valid: - raise FunctionInputsValidationError(error=validation_str) - if cached_function_jobs := await self._web_rpc_client.find_cached_function_jobs( function_id=function.uid, - inputs=joined_inputs, + inputs=job_inputs.values, user_id=self.user_id, product_name=self.product_name, ): @@ -222,10 +236,156 @@ async def run_function( if job_status.status == RunningState.SUCCESS: return cached_function_job + raise FunctionJobCacheNotFoundError() + + async def pre_register_function_job( + self, + *, + function: RegisteredFunction, + job_inputs: JobInputs, + ) -> PreRegisteredFunctionJobData: + + if function.input_schema is not None: + is_valid, validation_str = await self.validate_function_inputs( + function_id=function.uid, + inputs=job_inputs.values, + ) + if not is_valid: + raise FunctionInputsValidationError(error=validation_str) + + if function.function_class == FunctionClass.PROJECT: + job = await self._web_rpc_client.register_function_job( + function_job=ProjectFunctionJob( + function_uid=function.uid, + title=f"Function job of function {function.uid}", + description=function.description, + inputs=job_inputs.values, + outputs=None, + project_job_id=None, + job_creation_task_id=None, + ), + user_id=self.user_id, + product_name=self.product_name, + ) + + elif function.function_class == FunctionClass.SOLVER: + job = await self._web_rpc_client.register_function_job( + 
function_job=SolverFunctionJob( + function_uid=function.uid, + title=f"Function job of function {function.uid}", + description=function.description, + inputs=job_inputs.values, + outputs=None, + solver_job_id=None, + job_creation_task_id=None, + ), + user_id=self.user_id, + product_name=self.product_name, + ) + else: + raise UnsupportedFunctionClassError( + function_class=function.function_class, + ) + + return PreRegisteredFunctionJobData( + function_job_id=job.uid, + job_inputs=job_inputs, + ) + + @overload + async def patch_registered_function_job( + self, + *, + user_id: UserID, + product_name: ProductName, + function_job_id: FunctionJobID, + function_class: FunctionClass, + job_creation_task_id: TaskID | None, + ) -> RegisteredFunctionJob: ... + + @overload + async def patch_registered_function_job( + self, + *, + user_id: UserID, + product_name: ProductName, + function_job_id: FunctionJobID, + function_class: FunctionClass, + job_creation_task_id: TaskID | None, + project_job_id: ProjectID | None, + ) -> RegisteredFunctionJob: ... + + @overload + async def patch_registered_function_job( + self, + *, + user_id: UserID, + product_name: ProductName, + function_job_id: FunctionJobID, + function_class: FunctionClass, + job_creation_task_id: TaskID | None, + solver_job_id: SolverJobID | None, + ) -> RegisteredFunctionJob: ... + + async def patch_registered_function_job( + self, + *, + user_id: UserID, + product_name: ProductName, + function_job_id: FunctionJobID, + function_class: FunctionClass, + job_creation_task_id: TaskID | None, + project_job_id: ProjectID | None = None, + solver_job_id: SolverJobID | None = None, + ) -> RegisteredFunctionJob: + # Only allow one of project_job_id or solver_job_id depending on function_class + patch: RegisteredFunctionJobPatch + if function_class == FunctionClass.PROJECT: + patch = RegisteredProjectFunctionJobPatch( + title=None, + description=None, + inputs=None, + outputs=None, + job_creation_task_id=job_creation_task_id, + project_job_id=project_job_id, + ) + elif function_class == FunctionClass.SOLVER: + patch = RegisteredSolverFunctionJobPatch( + title=None, + description=None, + inputs=None, + outputs=None, + job_creation_task_id=job_creation_task_id, + solver_job_id=solver_job_id, + ) + else: + raise UnsupportedFunctionClassError( + function_class=function_class, + ) + return await self._web_rpc_client.patch_registered_function_job( + user_id=user_id, + product_name=product_name, + function_job_id=function_job_id, + registered_function_job_patch=patch, + ) + + async def run_function( + self, + *, + job_creation_task_id: TaskID | None, + function: RegisteredFunction, + pre_registered_function_job_data: PreRegisteredFunctionJobData, + pricing_spec: JobPricingSpecification | None, + job_links: JobLinks, + x_simcore_parent_project_uuid: NodeID | None, + x_simcore_parent_node_id: NodeID | None, + ) -> RegisteredFunctionJob: + """N.B. this function does not check access rights. 
Use get_cached_function_job for that""" + if function.function_class == FunctionClass.PROJECT: study_job = await self._job_service.create_studies_job( study_id=function.project_id, - job_inputs=JobInputs(values=joined_inputs or {}), + job_inputs=pre_registered_function_job_data.job_inputs, hidden=True, job_links=job_links, x_simcore_parent_project_uuid=x_simcore_parent_project_uuid, @@ -236,25 +396,20 @@ async def run_function( job_id=study_job.id, pricing_spec=pricing_spec, ) - return await self._web_rpc_client.register_function_job( - function_job=ProjectFunctionJob( - function_uid=function.uid, - title=f"Function job of function {function.uid}", - description=function.description, - inputs=joined_inputs, - outputs=None, - project_job_id=study_job.id, - job_creation_task_id=None, - ), + return await self.patch_registered_function_job( user_id=self.user_id, product_name=self.product_name, + function_job_id=pre_registered_function_job_data.function_job_id, + function_class=FunctionClass.PROJECT, + job_creation_task_id=job_creation_task_id, + project_job_id=study_job.id, ) if function.function_class == FunctionClass.SOLVER: solver_job = await self._job_service.create_solver_job( solver_key=function.solver_key, version=function.solver_version, - inputs=JobInputs(values=joined_inputs or {}), + inputs=pre_registered_function_job_data.job_inputs, job_links=job_links, hidden=True, x_simcore_parent_project_uuid=x_simcore_parent_project_uuid, @@ -266,18 +421,13 @@ async def run_function( job_id=solver_job.id, pricing_spec=pricing_spec, ) - return await self._web_rpc_client.register_function_job( - function_job=SolverFunctionJob( - function_uid=function.uid, - title=f"Function job of function {function.uid}", - description=function.description, - inputs=joined_inputs, - outputs=None, - solver_job_id=solver_job.id, - job_creation_task_id=None, - ), + return await self.patch_registered_function_job( user_id=self.user_id, product_name=self.product_name, + function_job_id=pre_registered_function_job_data.function_job_id, + function_class=FunctionClass.SOLVER, + job_creation_task_id=job_creation_task_id, + solver_job_id=solver_job.id, ) raise UnsupportedFunctionClassError( @@ -287,33 +437,22 @@ async def run_function( async def map_function( self, *, + job_creation_task_id: TaskID | None, function: RegisteredFunction, - function_inputs_list: FunctionInputsList, + pre_registered_function_job_data_list: list[PreRegisteredFunctionJobData], job_links: JobLinks, pricing_spec: JobPricingSpecification | None, x_simcore_parent_project_uuid: ProjectID | None, x_simcore_parent_node_id: NodeID | None, - ) -> RegisteredFunctionJobCollection: + ) -> None: - function_jobs = [ + for data in pre_registered_function_job_data_list: await self.run_function( + job_creation_task_id=job_creation_task_id, function=function, - function_inputs=function_inputs, + pre_registered_function_job_data=data, pricing_spec=pricing_spec, job_links=job_links, x_simcore_parent_project_uuid=x_simcore_parent_project_uuid, x_simcore_parent_node_id=x_simcore_parent_node_id, ) - for function_inputs in function_inputs_list - ] - - function_job_collection_description = f"Function job collection of map of function {function.uid} with {len(function_inputs_list)} inputs" - return await self._web_rpc_client.register_function_job_collection( - function_job_collection=FunctionJobCollection( - title="Function job collection of function map", - description=function_job_collection_description, - job_ids=[function_job.uid for function_job in function_jobs], 
- ), - user_id=self.user_id, - product_name=self.product_name, - ) diff --git a/services/api-server/src/simcore_service_api_server/_service_studies.py b/services/api-server/src/simcore_service_api_server/_service_studies.py deleted file mode 100644 index 89fa5196e34..00000000000 --- a/services/api-server/src/simcore_service_api_server/_service_studies.py +++ /dev/null @@ -1,27 +0,0 @@ -from dataclasses import dataclass - -from models_library.products import ProductName -from models_library.rest_pagination import ( - MAXIMUM_NUMBER_OF_ITEMS_PER_PAGE, -) -from models_library.users import UserID - -from ._service_jobs import JobService -from ._service_utils import check_user_product_consistency - -DEFAULT_PAGINATION_LIMIT = MAXIMUM_NUMBER_OF_ITEMS_PER_PAGE - 1 - - -@dataclass(frozen=True, kw_only=True) -class StudyService: - job_service: JobService - user_id: UserID - product_name: ProductName - - def __post_init__(self): - check_user_product_consistency( - service_cls_name=self.__class__.__name__, - service_provider=self.job_service, - user_id=self.user_id, - product_name=self.product_name, - ) diff --git a/services/api-server/src/simcore_service_api_server/api/dependencies/celery.py b/services/api-server/src/simcore_service_api_server/api/dependencies/celery.py new file mode 100644 index 00000000000..1fa0ccfb3e4 --- /dev/null +++ b/services/api-server/src/simcore_service_api_server/api/dependencies/celery.py @@ -0,0 +1,13 @@ +from typing import Final + +from celery_library.task_manager import CeleryTaskManager +from fastapi import FastAPI + +ASYNC_JOB_CLIENT_NAME: Final[str] = "API_SERVER" + + +def get_task_manager(app: FastAPI) -> CeleryTaskManager: + assert hasattr(app.state, "task_manager") # nosec + task_manager = app.state.task_manager + assert isinstance(task_manager, CeleryTaskManager) # nosec + return task_manager diff --git a/services/api-server/src/simcore_service_api_server/api/dependencies/webserver_http.py b/services/api-server/src/simcore_service_api_server/api/dependencies/webserver_http.py index 377356f22c0..df4325dc1ff 100644 --- a/services/api-server/src/simcore_service_api_server/api/dependencies/webserver_http.py +++ b/services/api-server/src/simcore_service_api_server/api/dependencies/webserver_http.py @@ -4,7 +4,6 @@ from common_library.json_serialization import json_dumps from cryptography.fernet import Fernet from fastapi import Depends, FastAPI, HTTPException, status -from fastapi.requests import Request from ..._constants import MSG_BACKEND_SERVICE_UNAVAILABLE from ...core.settings import ApplicationSettings, WebServerSettings @@ -29,19 +28,16 @@ def _get_settings( return settings -def _get_encrypt(request: Request) -> Fernet | None: - e: Fernet | None = getattr(request.app.state, "webserver_fernet", None) - return e - - def get_session_cookie( identity: Annotated[str, Depends(get_active_user_email)], settings: Annotated[WebServerSettings, Depends(_get_settings)], - fernet: Annotated[Fernet | None, Depends(_get_encrypt)], + app: Annotated[FastAPI, Depends(get_app)], ) -> dict: # Based on aiohttp_session and aiohttp_security # SEE services/web/server/tests/unit/with_dbs/test_login.py + fernet: Fernet | None = getattr(app.state, "webserver_fernet", None) + if fernet is None: raise HTTPException( status.HTTP_503_SERVICE_UNAVAILABLE, detail=MSG_BACKEND_SERVICE_UNAVAILABLE diff --git a/services/api-server/src/simcore_service_api_server/api/routes/function_jobs_routes.py b/services/api-server/src/simcore_service_api_server/api/routes/function_jobs_routes.py index 
d34cc0da0c6..b43f5840ce0 100644 --- a/services/api-server/src/simcore_service_api_server/api/routes/function_jobs_routes.py +++ b/services/api-server/src/simcore_service_api_server/api/routes/function_jobs_routes.py @@ -1,5 +1,7 @@ +from logging import getLogger from typing import Annotated, Final +from common_library.error_codes import create_error_code from fastapi import APIRouter, Depends, FastAPI, HTTPException, status from fastapi_pagination.api import create_page from fastapi_pagination.bases import AbstractPage @@ -15,10 +17,13 @@ from models_library.functions import RegisteredFunction from models_library.functions_errors import ( UnsupportedFunctionClassError, + UnsupportedFunctionFunctionJobClassCombinationError, ) from models_library.products import ProductName from models_library.users import UserID +from servicelib.celery.models import TaskUUID from servicelib.fastapi.dependencies import get_app +from servicelib.logging_errors import create_troubleshootting_log_kwargs from simcore_service_api_server.models.schemas.functions_filters import ( FunctionJobsListFilters, ) @@ -26,12 +31,14 @@ from ..._service_function_jobs import FunctionJobService from ..._service_jobs import JobService +from ...exceptions.function_errors import FunctionJobProjectMissingError from ...models.pagination import Page, PaginationParams from ...models.schemas.errors import ErrorGet from ...services_http.storage import StorageApi from ...services_http.webserver import AuthSession from ...services_rpc.wb_api_server import WbApiRpcClient from ..dependencies.authentication import get_current_user_id, get_product_name +from ..dependencies.celery import get_task_manager from ..dependencies.database import get_db_asyncpg_engine from ..dependencies.functions import ( get_function_from_functionjob, @@ -52,6 +59,9 @@ FMSG_CHANGELOG_NEW_IN_VERSION, create_route_description, ) +from .tasks import _get_task_filter + +_logger = getLogger(__name__) # pylint: disable=too-many-arguments # pylint: disable=cyclic-import @@ -60,6 +70,8 @@ JOB_LIST_FILTER_PAGE_RELEASE_VERSION = "0.11.0" JOB_LOG_RELEASE_VERSION = "0.11.0" +_JOB_CREATION_TASK_STATUS_PREFIX: Final[str] = "JOB_CREATION_TASK_STATUS_" + function_job_router = APIRouter() _COMMON_FUNCTION_JOB_ERROR_RESPONSES: Final[dict] = { @@ -196,6 +208,9 @@ async def delete_function_job( ), ) async def function_job_status( + app: Annotated[FastAPI, Depends(get_app)], + user_id: Annotated[UserID, Depends(get_current_user_id)], + product_name: Annotated[ProductName, Depends(get_product_name)], function_job: Annotated[ RegisteredFunctionJob, Depends(get_function_job_dependency) ], @@ -204,10 +219,43 @@ async def function_job_status( FunctionJobService, Depends(get_function_job_service) ], ) -> FunctionJobStatus: + try: + return await function_job_service.inspect_function_job( + function=function, function_job=function_job + ) + except FunctionJobProjectMissingError as exc: + if ( + function.function_class == FunctionClass.PROJECT + and function_job.function_class == FunctionClass.PROJECT + ) or ( + function.function_class == FunctionClass.SOLVER + and function_job.function_class == FunctionClass.SOLVER + ): + if task_id := function_job.job_creation_task_id: + task_manager = get_task_manager(app) + task_filter = _get_task_filter(user_id, product_name) + task_status = await task_manager.get_task_status( + task_uuid=TaskUUID(task_id), task_filter=task_filter + ) + return FunctionJobStatus( + status=f"{_JOB_CREATION_TASK_STATUS_PREFIX}{task_status.task_state}" + ) + user_error_msg = 
f"The creation of job {function_job.uid} failed" + support_id = create_error_code(Exception()) + _logger.exception( + **create_troubleshootting_log_kwargs( + user_error_msg, + error=Exception(), + error_code=support_id, + tip="Initial call to run metamodeling function must have failed", + ) + ) + raise - return await function_job_service.inspect_function_job( - function=function, function_job=function_job - ) + raise UnsupportedFunctionFunctionJobClassCombinationError( + function_class=function.function_class, + function_job_class=function_job.function_class, + ) from exc async def get_function_from_functionjobid( @@ -266,7 +314,11 @@ async def function_job_outputs( function.function_class == FunctionClass.PROJECT and function_job.function_class == FunctionClass.PROJECT ): - assert function_job.project_job_id is not None # nosec + if function_job.project_job_id is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Function job outputs not found", + ) new_outputs = dict( ( await studies_jobs.get_study_job_outputs( @@ -282,7 +334,11 @@ async def function_job_outputs( function.function_class == FunctionClass.SOLVER and function_job.function_class == FunctionClass.SOLVER ): - assert function_job.solver_job_id is not None # nosec + if function_job.solver_job_id is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Function job outputs not found", + ) new_outputs = dict( ( await solvers_jobs_read.get_job_outputs( diff --git a/services/api-server/src/simcore_service_api_server/api/routes/functions_routes.py b/services/api-server/src/simcore_service_api_server/api/routes/functions_routes.py index e1d05e129c3..097fa1fc5b0 100644 --- a/services/api-server/src/simcore_service_api_server/api/routes/functions_routes.py +++ b/services/api-server/src/simcore_service_api_server/api/routes/functions_routes.py @@ -1,3 +1,5 @@ +import contextlib + # pylint: disable=too-many-positional-arguments from collections.abc import Callable from typing import Annotated, Final, Literal @@ -16,19 +18,34 @@ RegisteredFunctionJob, RegisteredFunctionJobCollection, ) +from models_library.api_schemas_rpc_async_jobs.async_jobs import AsyncJobFilter +from models_library.functions import FunctionJobCollection, FunctionJobID from models_library.products import ProductName from models_library.projects import ProjectID from models_library.projects_nodes_io import NodeID from models_library.users import UserID +from servicelib.celery.models import TaskFilter, TaskID, TaskMetadata, TasksQueue from servicelib.fastapi.dependencies import get_reverse_url_mapper from ..._service_function_jobs import FunctionJobService from ..._service_functions import FunctionService +from ...celery_worker.worker_tasks.functions_tasks import function_map as map_task +from ...celery_worker.worker_tasks.functions_tasks import ( + run_function as run_function_task, +) +from ...exceptions.function_errors import FunctionJobCacheNotFoundError +from ...models.domain.functions import PreRegisteredFunctionJobData from ...models.pagination import Page, PaginationParams from ...models.schemas.errors import ErrorGet from ...models.schemas.jobs import JobPricingSpecification from ...services_rpc.wb_api_server import WbApiRpcClient -from ..dependencies.authentication import get_current_user_id, get_product_name +from ..dependencies.authentication import ( + Identity, + get_current_identity, + get_current_user_id, + get_product_name, +) +from ..dependencies.celery import ASYNC_JOB_CLIENT_NAME, get_task_manager 
from ..dependencies.services import ( get_function_job_service, get_function_service, @@ -313,16 +330,19 @@ async def validate_function_inputs( ) async def run_function( # noqa: PLR0913 request: Request, + user_identity: Annotated[Identity, Depends(get_current_identity)], to_run_function: Annotated[RegisteredFunction, Depends(get_function)], url_for: Annotated[Callable, Depends(get_reverse_url_mapper)], function_inputs: FunctionInputs, function_service: Annotated[FunctionService, Depends(get_function_service)], - function_jobs_service: Annotated[ + function_job_service: Annotated[ FunctionJobService, Depends(get_function_job_service) ], x_simcore_parent_project_uuid: Annotated[ProjectID | Literal["null"], Header()], x_simcore_parent_node_id: Annotated[NodeID | Literal["null"], Header()], ) -> RegisteredFunctionJob: + # preprocess inputs + task_manager = get_task_manager(request.app) parent_project_uuid = ( x_simcore_parent_project_uuid if isinstance(x_simcore_parent_project_uuid, ProjectID) @@ -335,16 +355,57 @@ async def run_function( # noqa: PLR0913 ) pricing_spec = JobPricingSpecification.create_from_headers(request.headers) job_links = await function_service.get_function_job_links(to_run_function, url_for) + job_inputs = await function_job_service.create_function_job_inputs( + function=to_run_function, function_inputs=function_inputs + ) - return await function_jobs_service.run_function( + # check if results are cached + with contextlib.suppress(FunctionJobCacheNotFoundError): + return await function_job_service.get_cached_function_job( + function=to_run_function, + job_inputs=job_inputs, + ) + + pre_registered_function_job_data = ( + await function_job_service.pre_register_function_job( + function=to_run_function, + job_inputs=job_inputs, + ) + ) + + # run function in celery task + job_filter = AsyncJobFilter( + user_id=user_identity.user_id, + product_name=user_identity.product_name, + client_name=ASYNC_JOB_CLIENT_NAME, + ) + task_filter = TaskFilter.model_validate(job_filter.model_dump()) + task_name = run_function_task.__name__ + + task_uuid = await task_manager.submit_task( + TaskMetadata( + name=task_name, + ephemeral=True, + queue=TasksQueue.API_WORKER_QUEUE, + ), + task_filter=task_filter, + user_identity=user_identity, function=to_run_function, - function_inputs=function_inputs, + pre_registered_function_job_data=pre_registered_function_job_data, pricing_spec=pricing_spec, job_links=job_links, x_simcore_parent_project_uuid=parent_project_uuid, x_simcore_parent_node_id=parent_node_id, ) + return await function_job_service.patch_registered_function_job( + user_id=user_identity.user_id, + product_name=user_identity.product_name, + function_job_id=pre_registered_function_job_data.function_job_id, + function_class=to_run_function.function_class, + job_creation_task_id=TaskID(task_uuid), + ) + @function_router.delete( "/{function_id:uuid}", @@ -385,6 +446,7 @@ async def delete_function( ) async def map_function( # noqa: PLR0913 request: Request, + user_identity: Annotated[Identity, Depends(get_current_identity)], to_run_function: Annotated[RegisteredFunction, Depends(get_function)], function_inputs_list: FunctionInputsList, url_for: Annotated[Callable, Depends(get_reverse_url_mapper)], @@ -392,10 +454,12 @@ async def map_function( # noqa: PLR0913 FunctionJobService, Depends(get_function_job_service) ], function_service: Annotated[FunctionService, Depends(get_function_service)], + web_api_rpc_client: Annotated[WbApiRpcClient, Depends(get_wb_api_rpc_client)], 
x_simcore_parent_project_uuid: Annotated[ProjectID | Literal["null"], Header()], x_simcore_parent_node_id: Annotated[NodeID | Literal["null"], Header()], ) -> RegisteredFunctionJobCollection: + task_manager = get_task_manager(request.app) parent_project_uuid = ( x_simcore_parent_project_uuid if isinstance(x_simcore_parent_project_uuid, ProjectID) @@ -407,14 +471,75 @@ async def map_function( # noqa: PLR0913 else None ) pricing_spec = JobPricingSpecification.create_from_headers(request.headers) - job_links = await function_service.get_function_job_links(to_run_function, url_for) - return await function_jobs_service.map_function( + job_inputs_list = [ + await function_jobs_service.create_function_job_inputs( + function=to_run_function, function_inputs=function_inputs + ) + for function_inputs in function_inputs_list + ] + + job_ids: list[FunctionJobID] = [] + pre_registered_function_job_data_list: list[PreRegisteredFunctionJobData] = [] + + for job_inputs in job_inputs_list: + try: + cached_job = await function_jobs_service.get_cached_function_job( + function=to_run_function, + job_inputs=job_inputs, + ) + job_ids.append(cached_job.uid) + except FunctionJobCacheNotFoundError: + data = await function_jobs_service.pre_register_function_job( + function=to_run_function, + job_inputs=job_inputs, + ) + pre_registered_function_job_data_list.append(data) + job_ids.append(data.function_job_id) + + # run map in celery task + job_filter = AsyncJobFilter( + user_id=user_identity.user_id, + product_name=user_identity.product_name, + client_name=ASYNC_JOB_CLIENT_NAME, + ) + task_filter = TaskFilter.model_validate(job_filter.model_dump()) + task_name = map_task.__name__ + + task_uuid = await task_manager.submit_task( + TaskMetadata( + name=task_name, + ephemeral=True, + queue=TasksQueue.API_WORKER_QUEUE, + ), + task_filter=task_filter, + user_identity=user_identity, function=to_run_function, - function_inputs_list=function_inputs_list, + pre_registered_function_job_data_list=pre_registered_function_job_data_list, pricing_spec=pricing_spec, job_links=job_links, x_simcore_parent_project_uuid=parent_project_uuid, x_simcore_parent_node_id=parent_node_id, ) + + # patch pre-registered function jobs + for data in pre_registered_function_job_data_list: + await function_jobs_service.patch_registered_function_job( + user_id=user_identity.user_id, + product_name=user_identity.product_name, + function_job_id=data.function_job_id, + function_class=to_run_function.function_class, + job_creation_task_id=TaskID(task_uuid), + ) + + function_job_collection_description = f"Function job collection of map of function {to_run_function.uid} with {len(pre_registered_function_job_data_list)} inputs" + return await web_api_rpc_client.register_function_job_collection( + function_job_collection=FunctionJobCollection( + title="Function job collection of function map", + description=function_job_collection_description, + job_ids=job_ids, + ), + user_id=user_identity.user_id, + product_name=user_identity.product_name, + ) diff --git a/services/api-server/src/simcore_service_api_server/api/routes/tasks.py b/services/api-server/src/simcore_service_api_server/api/routes/tasks.py index ff0f12f2d69..9837a5f625f 100644 --- a/services/api-server/src/simcore_service_api_server/api/routes/tasks.py +++ b/services/api-server/src/simcore_service_api_server/api/routes/tasks.py @@ -1,7 +1,8 @@ import logging -from typing import Annotated, Any, Final +from typing import Annotated, Any -from fastapi import APIRouter, Depends, FastAPI, status +from 
common_library.error_codes import create_error_code +from fastapi import APIRouter, Depends, FastAPI, HTTPException, status from models_library.api_schemas_long_running_tasks.base import TaskProgress from models_library.api_schemas_long_running_tasks.tasks import ( TaskGet, @@ -14,28 +15,28 @@ ) from models_library.products import ProductName from models_library.users import UserID +from servicelib.celery.models import TaskFilter, TaskState, TaskUUID from servicelib.fastapi.dependencies import get_app +from servicelib.logging_errors import create_troubleshootting_log_kwargs from ...models.schemas.base import ApiServerEnvelope from ...models.schemas.errors import ErrorGet -from ...services_rpc.async_jobs import AsyncJobClient from ..dependencies.authentication import get_current_user_id, get_product_name -from ..dependencies.tasks import get_async_jobs_client +from ..dependencies.celery import ASYNC_JOB_CLIENT_NAME, get_task_manager from ._constants import ( FMSG_CHANGELOG_NEW_IN_VERSION, create_route_description, ) -_ASYNC_JOB_CLIENT_NAME: Final[str] = "API_SERVER" - router = APIRouter() _logger = logging.getLogger(__name__) -def _get_job_filter(user_id: UserID, product_name: ProductName) -> AsyncJobFilter: - return AsyncJobFilter( - user_id=user_id, product_name=product_name, client_name=_ASYNC_JOB_CLIENT_NAME +def _get_task_filter(user_id: UserID, product_name: ProductName) -> TaskFilter: + job_filter = AsyncJobFilter( + user_id=user_id, product_name=product_name, client_name=ASYNC_JOB_CLIENT_NAME ) + return TaskFilter.model_validate(job_filter.model_dump()) _DEFAULT_TASK_STATUS_CODES: dict[int | str, dict[str, Any]] = { @@ -62,26 +63,28 @@ async def list_tasks( app: Annotated[FastAPI, Depends(get_app)], user_id: Annotated[UserID, Depends(get_current_user_id)], product_name: Annotated[ProductName, Depends(get_product_name)], - async_jobs: Annotated[AsyncJobClient, Depends(get_async_jobs_client)], ): - user_async_jobs = await async_jobs.list_jobs( - job_filter=_get_job_filter(user_id, product_name), - filter_="", + + task_manager = get_task_manager(app) + + tasks = await task_manager.list_tasks( + task_filter=_get_task_filter(user_id, product_name), ) + app_router = app.router data = [ TaskGet( - task_id=f"{job.job_id}", - task_name=job.job_name, + task_id=f"{task.uuid}", + task_name=task.metadata.name, status_href=app_router.url_path_for( - "get_task_status", task_id=f"{job.job_id}" + "get_task_status", task_id=f"{task.uuid}" ), - abort_href=app_router.url_path_for("cancel_task", task_id=f"{job.job_id}"), + abort_href=app_router.url_path_for("cancel_task", task_id=f"{task.uuid}"), result_href=app_router.url_path_for( - "get_task_result", task_id=f"{job.job_id}" + "get_task_result", task_id=f"{task.uuid}" ), ) - for job in user_async_jobs + for task in tasks ] return ApiServerEnvelope(data=data) @@ -100,20 +103,23 @@ async def list_tasks( ) async def get_task_status( task_id: AsyncJobId, + app: Annotated[FastAPI, Depends(get_app)], user_id: Annotated[UserID, Depends(get_current_user_id)], product_name: Annotated[ProductName, Depends(get_product_name)], - async_jobs: Annotated[AsyncJobClient, Depends(get_async_jobs_client)], ): - async_job_rpc_status = await async_jobs.status( - job_id=task_id, - job_filter=_get_job_filter(user_id, product_name), + task_manager = get_task_manager(app) + + task_status = await task_manager.get_task_status( + task_filter=_get_task_filter(user_id, product_name), + task_uuid=TaskUUID(f"{task_id}"), ) - _task_id = f"{async_job_rpc_status.job_id}" + return 
TaskStatus( task_progress=TaskProgress( - task_id=_task_id, percent=async_job_rpc_status.progress.percent_value + task_id=f"{task_status.task_uuid}", + percent=task_status.progress_report.percent_value, ), - done=async_job_rpc_status.done, + done=task_status.is_done, started=None, ) @@ -132,13 +138,15 @@ async def get_task_status( ) async def cancel_task( task_id: AsyncJobId, + app: Annotated[FastAPI, Depends(get_app)], user_id: Annotated[UserID, Depends(get_current_user_id)], product_name: Annotated[ProductName, Depends(get_product_name)], - async_jobs: Annotated[AsyncJobClient, Depends(get_async_jobs_client)], ): - await async_jobs.cancel( - job_id=task_id, - job_filter=_get_job_filter(user_id, product_name), + task_manager = get_task_manager(app) + + await task_manager.cancel_task( + task_filter=_get_task_filter(user_id, product_name), + task_uuid=TaskUUID(f"{task_id}"), ) @@ -166,12 +174,49 @@ async def cancel_task( ) async def get_task_result( task_id: AsyncJobId, + app: Annotated[FastAPI, Depends(get_app)], user_id: Annotated[UserID, Depends(get_current_user_id)], product_name: Annotated[ProductName, Depends(get_product_name)], - async_jobs: Annotated[AsyncJobClient, Depends(get_async_jobs_client)], ): - async_job_rpc_result = await async_jobs.result( - job_id=task_id, - job_filter=_get_job_filter(user_id, product_name), + task_manager = get_task_manager(app) + task_filter = _get_task_filter(user_id, product_name) + + task_status = await task_manager.get_task_status( + task_filter=task_filter, + task_uuid=TaskUUID(f"{task_id}"), + ) + + if not task_status.is_done: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Task result not available yet", + ) + if task_status.task_state == TaskState.ABORTED: + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail="Task was cancelled", + ) + + task_result = await task_manager.get_task_result( + task_filter=task_filter, + task_uuid=TaskUUID(f"{task_id}"), ) - return TaskResult(result=async_job_rpc_result.result, error=None) + + if task_status.task_state == TaskState.FAILURE: + assert isinstance(task_result, Exception) + user_error_msg = f"The execution of task {task_id} failed" + support_id = create_error_code(task_result) + _logger.exception( + **create_troubleshootting_log_kwargs( + user_error_msg, + error=task_result, + error_code=support_id, + tip="Unexpected error in Celery", + ) + ) + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail=user_error_msg, + ) + + return TaskResult(result=task_result, error=None) diff --git a/services/api-server/src/simcore_service_api_server/celery_worker/__init__.py b/services/api-server/src/simcore_service_api_server/celery_worker/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/services/api-server/src/simcore_service_api_server/celery_worker/worker_main.py b/services/api-server/src/simcore_service_api_server/celery_worker/worker_main.py new file mode 100644 index 00000000000..e70b7f79112 --- /dev/null +++ b/services/api-server/src/simcore_service_api_server/celery_worker/worker_main.py @@ -0,0 +1,42 @@ +"""Main application to be deployed in for example uvicorn.""" + +from functools import partial + +from celery_library.common import create_app as create_celery_app +from celery_library.signals import ( + on_worker_init, +) +from servicelib.fastapi.celery.app_server import FastAPIAppServer +from servicelib.logging_utils import setup_loggers + +from ..core.application import create_app +from ..core.settings import 
ApplicationSettings +from .worker_tasks.tasks import setup_worker_tasks + + +def get_app(): + _settings = ApplicationSettings.create_from_envs() + + setup_loggers( + log_format_local_dev_enabled=_settings.API_SERVER_LOG_FORMAT_LOCAL_DEV_ENABLED, + logger_filter_mapping=_settings.API_SERVER_LOG_FILTER_MAPPING, + tracing_settings=_settings.API_SERVER_TRACING, + log_base_level=_settings.log_level, + noisy_loggers=None, + ) + + assert _settings.API_SERVER_CELERY # nosec + app = create_celery_app(_settings.API_SERVER_CELERY) + setup_worker_tasks(app) + + return app + + +def worker_init_wrapper(sender, **_kwargs): + _settings = ApplicationSettings.create_from_envs() + assert _settings.API_SERVER_CELERY # nosec + app_server = FastAPIAppServer(app=create_app(_settings)) + + return partial(on_worker_init, app_server, _settings.API_SERVER_CELERY)( + sender, **_kwargs + ) diff --git a/services/api-server/src/simcore_service_api_server/celery_worker/worker_tasks/__init__.py b/services/api-server/src/simcore_service_api_server/celery_worker/worker_tasks/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/services/api-server/src/simcore_service_api_server/celery_worker/worker_tasks/functions_tasks.py b/services/api-server/src/simcore_service_api_server/celery_worker/worker_tasks/functions_tasks.py new file mode 100644 index 00000000000..4c3697c1ca8 --- /dev/null +++ b/services/api-server/src/simcore_service_api_server/celery_worker/worker_tasks/functions_tasks.py @@ -0,0 +1,147 @@ +from celery import ( # type: ignore[import-untyped] # pylint: disable=no-name-in-module + Task, +) +from celery_library.utils import get_app_server # pylint: disable=no-name-in-module +from fastapi import FastAPI +from models_library.functions import RegisteredFunction, RegisteredFunctionJob +from models_library.projects import ProjectID +from models_library.projects_nodes_io import NodeID +from servicelib.celery.models import TaskID +from simcore_service_api_server._service_function_jobs import FunctionJobService + +from ...api.dependencies.authentication import Identity +from ...api.dependencies.rabbitmq import get_rabbitmq_rpc_client +from ...api.dependencies.services import ( + get_catalog_service, + get_directorv2_service, + get_function_job_service, + get_job_service, + get_solver_service, + get_storage_service, +) +from ...api.dependencies.webserver_http import get_session_cookie, get_webserver_session +from ...api.dependencies.webserver_rpc import get_wb_api_rpc_client +from ...models.api_resources import JobLinks +from ...models.domain.functions import PreRegisteredFunctionJobData +from ...models.schemas.jobs import JobPricingSpecification +from ...services_http.director_v2 import DirectorV2Api +from ...services_http.storage import StorageApi + + +async def _assemble_function_job_service( + *, app: FastAPI, user_identity: Identity +) -> FunctionJobService: + # This should ideally be done by a dependency injection system (like it is done in the api-server). + # However, for that we would need to introduce a dependency injection system which is not coupled to, + # but compatible with FastAPI's Depends. One suggestion: https://github.com/ets-labs/python-dependency-injector. + # See also https://github.com/fastapi/fastapi/issues/1105#issuecomment-609919850. 
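
The comment block above deserves emphasis: inside the Celery worker there is no request scope, so every collaborator that FastAPI would normally resolve through `Depends` has to be constructed by hand, which is exactly what `_assemble_function_job_service` does in the lines that follow. A toy illustration of the two styles (all names here are invented for the example, none are api-server code):

```python
from typing import Annotated

from fastapi import Depends, FastAPI


class Greeter:
    def __init__(self, greeting: str) -> None:
        self._greeting = greeting

    def greet(self, name: str) -> str:
        return f"{self._greeting}, {name}!"


def get_greeting() -> str:
    return "hello"


def get_greeter(greeting: Annotated[str, Depends(get_greeting)]) -> Greeter:
    return Greeter(greeting)


app = FastAPI()


@app.get("/greet/{name}")
def greet(name: str, greeter: Annotated[Greeter, Depends(get_greeter)]) -> str:
    # inside a request handler, FastAPI resolves get_greeting -> get_greeter for us
    return greeter.greet(name)


def assemble_greeter() -> Greeter:
    # in a worker there is no request, so the dependency chain is resolved manually
    return get_greeter(greeting=get_greeting())
```
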
+ settings = app.state.settings + assert settings.API_SERVER_WEBSERVER # nosec + session_cookie = get_session_cookie( + identity=user_identity.email, settings=settings.API_SERVER_WEBSERVER, app=app + ) + + rpc_client = get_rabbitmq_rpc_client(app=app) + web_server_rest_client = get_webserver_session( + app=app, session_cookies=session_cookie, identity=user_identity + ) + web_api_rpc_client = await get_wb_api_rpc_client(app=app) + director2_api = DirectorV2Api.get_instance(app=app) + assert isinstance(director2_api, DirectorV2Api) # nosec + storage_api = StorageApi.get_instance(app=app) + assert isinstance(storage_api, StorageApi) # nosec + catalog_service = get_catalog_service( + rpc_client=rpc_client, + user_id=user_identity.user_id, + product_name=user_identity.product_name, + ) + + storage_service = get_storage_service( + rpc_client=rpc_client, + user_id=user_identity.user_id, + product_name=user_identity.product_name, + ) + directorv2_service = get_directorv2_service(rpc_client=rpc_client) + + solver_service = get_solver_service( + catalog_service=catalog_service, + user_id=user_identity.user_id, + product_name=user_identity.product_name, + ) + + job_service = get_job_service( + web_rest_api=web_server_rest_client, + director2_api=director2_api, + storage_api=storage_api, + web_rpc_api=web_api_rpc_client, + storage_service=storage_service, + directorv2_service=directorv2_service, + user_id=user_identity.user_id, + product_name=user_identity.product_name, + solver_service=solver_service, + ) + + return get_function_job_service( + web_rpc_api=web_api_rpc_client, + job_service=job_service, + user_id=user_identity.user_id, + product_name=user_identity.product_name, + ) + + +async def run_function( + task: Task, + task_id: TaskID, + *, + user_identity: Identity, + function: RegisteredFunction, + pre_registered_function_job_data: PreRegisteredFunctionJobData, + pricing_spec: JobPricingSpecification | None, + job_links: JobLinks, + x_simcore_parent_project_uuid: ProjectID | None, + x_simcore_parent_node_id: NodeID | None, +) -> RegisteredFunctionJob: + assert task_id # nosec + app = get_app_server(task.app).app + function_job_service = await _assemble_function_job_service( + app=app, user_identity=user_identity + ) + + return await function_job_service.run_function( + job_creation_task_id=task_id, + function=function, + pre_registered_function_job_data=pre_registered_function_job_data, + pricing_spec=pricing_spec, + job_links=job_links, + x_simcore_parent_project_uuid=x_simcore_parent_project_uuid, + x_simcore_parent_node_id=x_simcore_parent_node_id, + ) + + +async def function_map( + task: Task, + task_id: TaskID, + *, + user_identity: Identity, + function: RegisteredFunction, + pre_registered_function_job_data_list: list[PreRegisteredFunctionJobData], + job_links: JobLinks, + pricing_spec: JobPricingSpecification | None, + x_simcore_parent_project_uuid: ProjectID | None, + x_simcore_parent_node_id: NodeID | None, +) -> None: + assert task_id # nosec + app = get_app_server(task.app).app + function_job_service = await _assemble_function_job_service( + app=app, user_identity=user_identity + ) + + return await function_job_service.map_function( + job_creation_task_id=task_id, + function=function, + pre_registered_function_job_data_list=pre_registered_function_job_data_list, + pricing_spec=pricing_spec, + job_links=job_links, + x_simcore_parent_project_uuid=x_simcore_parent_project_uuid, + x_simcore_parent_node_id=x_simcore_parent_node_id, + ) diff --git 
a/services/api-server/src/simcore_service_api_server/celery_worker/worker_tasks/tasks.py b/services/api-server/src/simcore_service_api_server/celery_worker/worker_tasks/tasks.py
new file mode 100644
index 00000000000..d51d4a4cdba
--- /dev/null
+++ b/services/api-server/src/simcore_service_api_server/celery_worker/worker_tasks/tasks.py
@@ -0,0 +1,45 @@
+import logging
+
+from celery import (  # type: ignore[import-untyped] # pylint: disable=no-name-in-module
+    Celery,
+)
+from celery_library.task import register_task
+from celery_library.types import register_celery_types, register_pydantic_types
+from models_library.functions import (
+    RegisteredProjectFunction,
+    RegisteredProjectFunctionJob,
+    RegisteredPythonCodeFunction,
+    RegisteredSolverFunction,
+    RegisteredSolverFunctionJob,
+)
+from servicelib.logging_utils import log_context
+
+from ...api.dependencies.authentication import Identity
+from ...models.api_resources import JobLinks
+from ...models.domain.functions import PreRegisteredFunctionJobData
+from ...models.schemas.jobs import JobInputs, JobPricingSpecification
+from .functions_tasks import function_map, run_function
+
+_logger = logging.getLogger(__name__)
+
+pydantic_types_to_register = (
+    Identity,
+    JobInputs,
+    JobLinks,
+    JobPricingSpecification,
+    PreRegisteredFunctionJobData,
+    RegisteredProjectFunction,
+    RegisteredProjectFunctionJob,
+    RegisteredPythonCodeFunction,
+    RegisteredSolverFunction,
+    RegisteredSolverFunctionJob,
+)
+
+
+def setup_worker_tasks(app: Celery) -> None:
+    register_celery_types()
+    register_pydantic_types(*pydantic_types_to_register)
+
+    with log_context(_logger, logging.INFO, msg="worker task registration"):
+        register_task(app, run_function)
+        register_task(app, function_map)
diff --git a/services/api-server/src/simcore_service_api_server/clients/celery_task_manager.py b/services/api-server/src/simcore_service_api_server/clients/celery_task_manager.py
new file mode 100644
index 00000000000..0b4ac4c2f4e
--- /dev/null
+++ b/services/api-server/src/simcore_service_api_server/clients/celery_task_manager.py
@@ -0,0 +1,18 @@
+from celery_library.common import create_app, create_task_manager
+from celery_library.types import register_celery_types, register_pydantic_types
+from fastapi import FastAPI
+from settings_library.celery import CelerySettings
+
+from ..celery_worker.worker_tasks.tasks import pydantic_types_to_register
+
+
+def setup_task_manager(app: FastAPI, celery_settings: CelerySettings) -> None:
+    async def on_startup() -> None:
+        app.state.task_manager = await create_task_manager(
+            create_app(celery_settings), celery_settings
+        )
+
+    register_celery_types()
+    register_pydantic_types(*pydantic_types_to_register)
+
+    app.add_event_handler("startup", on_startup)
diff --git a/services/api-server/src/simcore_service_api_server/core/application.py b/services/api-server/src/simcore_service_api_server/core/application.py
index 33505c35c5f..572001ddc9d 100644
--- a/services/api-server/src/simcore_service_api_server/core/application.py
+++ b/services/api-server/src/simcore_service_api_server/core/application.py
@@ -15,6 +15,7 @@
 from .._meta import API_VERSION, API_VTAG, APP_NAME
 from ..api.root import create_router
 from ..api.routes.health import router as health_router
+from ..clients.celery_task_manager import setup_task_manager
 from ..clients.postgres import setup_postgres
 from ..services_http import director_v2, storage, webserver
 from ..services_http.rabbitmq import setup_rabbitmq
@@ -88,6 +89,9 @@ def
create_app(settings: ApplicationSettings | None = None) -> FastAPI: setup_rabbitmq(app) + if settings.API_SERVER_CELERY and not settings.API_SERVER_WORKER_MODE: + setup_task_manager(app, settings.API_SERVER_CELERY) + if app.state.settings.API_SERVER_PROMETHEUS_INSTRUMENTATION_ENABLED: setup_prometheus_instrumentation(app) diff --git a/services/api-server/src/simcore_service_api_server/core/settings.py b/services/api-server/src/simcore_service_api_server/core/settings.py index 59f6812b896..9b622a5ddc9 100644 --- a/services/api-server/src/simcore_service_api_server/core/settings.py +++ b/services/api-server/src/simcore_service_api_server/core/settings.py @@ -13,6 +13,7 @@ ) from servicelib.logging_utils_filtering import LoggerName, MessageSubstring from settings_library.base import BaseCustomSettings +from settings_library.celery import CelerySettings from settings_library.director_v2 import DirectorV2Settings from settings_library.postgres import PostgresSettings from settings_library.rabbit import RabbitSettings @@ -102,6 +103,10 @@ class ApplicationSettings(BasicSettings): # DOCKER BOOT SC_BOOT_MODE: BootModeEnum | None = None + API_SERVER_CELERY: Annotated[ + CelerySettings | None, Field(json_schema_extra={"auto_default_from_env": True}) + ] = None + API_SERVER_POSTGRES: Annotated[ PostgresSettings | None, Field(json_schema_extra={"auto_default_from_env": True}), @@ -142,6 +147,10 @@ class ApplicationSettings(BasicSettings): ), ] + API_SERVER_WORKER_MODE: Annotated[ + bool, Field(description="If True, the API server runs in worker mode") + ] = False + @cached_property def debug(self) -> bool: """If True, debug tracebacks should be returned on errors.""" diff --git a/services/api-server/src/simcore_service_api_server/exceptions/backend_errors.py b/services/api-server/src/simcore_service_api_server/exceptions/backend_errors.py index 33960e49f6b..5257bfad700 100644 --- a/services/api-server/src/simcore_service_api_server/exceptions/backend_errors.py +++ b/services/api-server/src/simcore_service_api_server/exceptions/backend_errors.py @@ -8,6 +8,7 @@ class BaseBackEndError(ApiServerBaseError): """status_code: the default return status which will be returned to the client calling the api-server (in case this exception is raised)""" + msg_template = "The api-server encountered an error when contacting the backend" status_code = status.HTTP_502_BAD_GATEWAY @classmethod diff --git a/services/api-server/src/simcore_service_api_server/exceptions/function_errors.py b/services/api-server/src/simcore_service_api_server/exceptions/function_errors.py new file mode 100644 index 00000000000..28d44a43556 --- /dev/null +++ b/services/api-server/src/simcore_service_api_server/exceptions/function_errors.py @@ -0,0 +1,17 @@ +from fastapi import status + +from .backend_errors import BaseBackEndError + + +class BaseFunctionBackendError(BaseBackEndError): + pass + + +class FunctionJobCacheNotFoundError(BaseBackEndError): + msg_template: str = "No cached function job found." 
+    status_code: int = status.HTTP_404_NOT_FOUND
+
+
+class FunctionJobProjectMissingError(BaseBackEndError):
+    msg_template: str = "Could not process function job"
+    status_code: int = status.HTTP_500_INTERNAL_SERVER_ERROR
diff --git a/services/api-server/src/simcore_service_api_server/exceptions/handlers/__init__.py b/services/api-server/src/simcore_service_api_server/exceptions/handlers/__init__.py
index 2385ea984f4..adecb5d7203 100644
--- a/services/api-server/src/simcore_service_api_server/exceptions/handlers/__init__.py
+++ b/services/api-server/src/simcore_service_api_server/exceptions/handlers/__init__.py
@@ -1,3 +1,6 @@
+from celery.exceptions import (  # type: ignore[import-untyped] #pylint: disable=no-name-in-module
+    CeleryError,
+)
 from fastapi import FastAPI
 from fastapi.exceptions import RequestValidationError
 from httpx import HTTPError as HttpxException
@@ -37,6 +40,18 @@ def setup(app: FastAPI, *, is_debug: bool = False):
             error_message="This endpoint is still not implemented (under development)",
         ),
     )
+
+    app.add_exception_handler(
+        CeleryError,
+        make_handler_for_exception(
+            CeleryError,
+            status.HTTP_503_SERVICE_UNAVAILABLE,
+            error_message=MSG_INTERNAL_ERROR_USER_FRIENDLY_TEMPLATE,
+            add_exception_to_message=is_debug,
+            add_oec_to_message=True,
+        ),
+    )
+
     app.add_exception_handler(
         Exception,
         make_handler_for_exception(
diff --git a/services/api-server/src/simcore_service_api_server/models/api_resources.py b/services/api-server/src/simcore_service_api_server/models/api_resources.py
index b16a0414b83..f796eb0af86 100644
--- a/services/api-server/src/simcore_service_api_server/models/api_resources.py
+++ b/services/api-server/src/simcore_service_api_server/models/api_resources.py
@@ -4,7 +4,7 @@
 from uuid import UUID
 
 import parse  # type: ignore[import-untyped]
-from pydantic import AfterValidator, BaseModel, Field, HttpUrl, TypeAdapter
+from pydantic import AfterValidator, BaseModel, ConfigDict, Field, HttpUrl, TypeAdapter
 from pydantic.types import StringConstraints
 
 # RESOURCE NAMES https://google.aip.dev/122
@@ -103,6 +103,22 @@ def _url_missing_only_job_id(url: str | None) -> str | None:
 
 
 class JobLinks(BaseModel):
+    @staticmethod
+    def _update_json_schema_extra(schema: dict) -> None:
+        schema.update(
+            {
+                "examples": [
+                    {
+                        "url_template": "https://api.osparc.io/v0/jobs/{job_id}",
+                        "runner_url_template": "https://runner.osparc.io/dashboard",
+                        "outputs_url_template": "https://api.osparc.io/v0/jobs/{job_id}/outputs",
+                    }
+                ]
+            }
+        )
+
+    model_config = ConfigDict(json_schema_extra=_update_json_schema_extra)
+
     url_template: Annotated[str | None, AfterValidator(_url_missing_only_job_id)]
     runner_url_template: str | None
     outputs_url_template: Annotated[
diff --git a/services/api-server/src/simcore_service_api_server/models/domain/functions.py b/services/api-server/src/simcore_service_api_server/models/domain/functions.py
new file mode 100644
index 00000000000..ff4e56ba34b
--- /dev/null
+++ b/services/api-server/src/simcore_service_api_server/models/domain/functions.py
@@ -0,0 +1,9 @@
+from models_library.functions import FunctionJobID
+from pydantic import BaseModel
+
+from ...models.schemas.jobs import JobInputs
+
+
+class PreRegisteredFunctionJobData(BaseModel):
+    function_job_id: FunctionJobID
+    job_inputs: JobInputs
diff --git a/services/api-server/src/simcore_service_api_server/services_rpc/wb_api_server.py b/services/api-server/src/simcore_service_api_server/services_rpc/wb_api_server.py
index 4fde8b5403c..75f7c5a7d21 100644
---
a/services/api-server/src/simcore_service_api_server/services_rpc/wb_api_server.py +++ b/services/api-server/src/simcore_service_api_server/services_rpc/wb_api_server.py @@ -28,6 +28,7 @@ FunctionOutputs, FunctionUserAccessRights, FunctionUserApiAccessRights, + RegisteredFunctionJobPatch, ) from models_library.licenses import LicensedItemID from models_library.products import ProductName @@ -489,6 +490,22 @@ async def register_function_job( function_job=function_job, ) + async def patch_registered_function_job( + self, + *, + user_id: UserID, + product_name: ProductName, + function_job_id: FunctionJobID, + registered_function_job_patch: RegisteredFunctionJobPatch, + ) -> RegisteredFunctionJob: + return await functions_rpc_interface.patch_registered_function_job( + self._client, + user_id=user_id, + product_name=product_name, + function_job_uuid=function_job_id, + registered_function_job_patch=registered_function_job_patch, + ) + async def get_function_input_schema( self, *, user_id: UserID, product_name: ProductName, function_id: FunctionID ) -> FunctionInputSchema: diff --git a/services/api-server/tests/conftest.py b/services/api-server/tests/conftest.py index 2fd59c2f626..f0c05db2d1f 100644 --- a/services/api-server/tests/conftest.py +++ b/services/api-server/tests/conftest.py @@ -29,6 +29,7 @@ "pytest_simcore.pydantic_models", "pytest_simcore.pytest_global_environs", "pytest_simcore.rabbit_service", + "pytest_simcore.redis_service", "pytest_simcore.repository_paths", "pytest_simcore.schemas", "pytest_simcore.services_api_mocks_for_aiohttp_clients", @@ -71,6 +72,10 @@ def default_app_env_vars( env_vars["API_SERVER_DEV_FEATURES_ENABLED"] = "1" env_vars["API_SERVER_LOG_FORMAT_LOCAL_DEV_ENABLED"] = "1" env_vars["API_SERVER_PROMETHEUS_INSTRUMENTATION_ENABLED"] = "0" + env_vars["POSTGRES_MINSIZE"] = "2" + env_vars["POSTGRES_MAXSIZE"] = "10" + env_vars["API_SERVER_CELERY"] = "null" + env_vars["API_SERVER_RABBITMQ"] = "null" return env_vars diff --git a/services/api-server/tests/unit/api_functions/celery/conftest.py b/services/api-server/tests/unit/api_functions/celery/conftest.py new file mode 100644 index 00000000000..993ba4b73ab --- /dev/null +++ b/services/api-server/tests/unit/api_functions/celery/conftest.py @@ -0,0 +1,150 @@ +# pylint: disable=unused-argument +# pylint: disable=redefined-outer-name +# pylint: disable=too-many-positional-arguments +# pylint: disable=no-name-in-module + + +import datetime +from collections.abc import AsyncIterator, Callable +from functools import partial +from typing import Any + +import pytest +from celery import Celery # pylint: disable=no-name-in-module +from celery.contrib.testing.worker import ( # pylint: disable=no-name-in-module + TestWorkController, + start_worker, +) +from celery.signals import ( # pylint: disable=no-name-in-module + worker_init, + worker_shutdown, +) +from celery.worker.worker import WorkController # pylint: disable=no-name-in-module +from celery_library.signals import on_worker_init, on_worker_shutdown +from pytest_mock import MockerFixture +from pytest_simcore.helpers.monkeypatch_envs import delenvs_from_dict, setenvs_from_dict +from pytest_simcore.helpers.typing_env import EnvVarsDict +from servicelib.fastapi.celery.app_server import FastAPIAppServer +from settings_library.redis import RedisSettings +from simcore_service_api_server.celery_worker.worker_main import setup_worker_tasks +from simcore_service_api_server.clients import celery_task_manager +from simcore_service_api_server.core.application import create_app +from 
simcore_service_api_server.core.settings import ApplicationSettings
+
+
+@pytest.fixture(scope="session")
+def celery_config() -> dict[str, Any]:
+    return {
+        "broker_connection_retry_on_startup": True,
+        "broker_url": "memory://localhost//",
+        "result_backend": "cache+memory://localhost//",
+        "result_expires": datetime.timedelta(days=7),
+        "result_extended": True,
+        "pool": "threads",
+        "task_default_queue": "default",
+        "task_send_sent_event": True,
+        "task_track_started": True,
+        "worker_send_task_events": True,
+    }
+
+
+@pytest.fixture
+async def mocked_log_streamer_setup(mocker: MockerFixture) -> MockerFixture:
+    # mock the log streamer: it looks for non-existent queues; should be solved more elegantly
+    from simcore_service_api_server.services_http import rabbitmq
+
+    mock_log_streamer = mocker.patch.object(rabbitmq, "LogDistributor", spec=True)
+    return mock_log_streamer
+
+
+@pytest.fixture
+def mock_celery_app(mocker: MockerFixture, celery_config: dict[str, Any]) -> Celery:
+    celery_app = Celery(**celery_config)
+
+    mocker.patch.object(
+        celery_task_manager,
+        celery_task_manager.create_app.__name__,
+        lambda settings: celery_app,
+    )
+
+    return celery_app
+
+
+@pytest.fixture
+def app_environment(
+    mock_celery_app: Celery,
+    mocked_log_streamer_setup: MockerFixture,
+    use_in_memory_redis: RedisSettings,
+    monkeypatch: pytest.MonkeyPatch,
+    app_environment: EnvVarsDict,
+    rabbit_env_vars_dict: EnvVarsDict,
+) -> EnvVarsDict:
+    # do not init other services
+    delenvs_from_dict(monkeypatch, ["API_SERVER_RABBITMQ", "API_SERVER_CELERY"])
+    env_vars_dict = setenvs_from_dict(
+        monkeypatch,
+        {
+            **rabbit_env_vars_dict,
+            "API_SERVER_POSTGRES": "null",
+            "API_SERVER_HEALTH_CHECK_TASK_PERIOD_SECONDS": "3",
+            "API_SERVER_HEALTH_CHECK_TASK_TIMEOUT_SECONDS": "1",
+        },
+    )
+
+    settings = ApplicationSettings.create_from_envs()
+    assert settings.API_SERVER_CELERY is not None
+
+    return env_vars_dict
+
+
+@pytest.fixture
+def register_celery_tasks() -> Callable[[Celery], None]:
+    """override if tasks are needed"""
+
+    def _(celery_app: Celery) -> None: ...
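
`register_celery_tasks` is deliberately a no-op hook: test modules that need extra tasks override the fixture (or parametrize it, as the tests below do) to install their own. A minimal override might look like this, with `probe_task` being a hypothetical task invented for the example:

```python
import pytest
from celery import Celery, Task  # pylint: disable=no-name-in-module
from celery_library.task import register_task
from servicelib.celery.models import TaskID


async def probe_task(task: Task, task_id: TaskID) -> str:
    # hypothetical task used only by this test module
    return "pong"


@pytest.fixture
def register_celery_tasks():
    def _(celery_app: Celery) -> None:
        register_task(celery_app, probe_task)

    return _
```
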
+
+    return _
+
+
+@pytest.fixture
+def add_worker_tasks() -> bool:
+    "override to not add default worker tasks"
+    return True
+
+
+@pytest.fixture
+async def with_api_server_celery_worker(
+    app_environment: EnvVarsDict,
+    celery_app: Celery,
+    monkeypatch: pytest.MonkeyPatch,
+    register_celery_tasks: Callable[[Celery], None],
+    add_worker_tasks: bool,
+) -> AsyncIterator[TestWorkController]:
+    # Signals must be explicitly connected
+    monkeypatch.setenv("API_SERVER_WORKER_MODE", "true")
+    app_settings = ApplicationSettings.create_from_envs()
+
+    app_server = FastAPIAppServer(app=create_app(app_settings))
+
+    def _on_worker_init_wrapper(sender: WorkController, **_kwargs):
+        assert app_settings.API_SERVER_CELERY  # nosec
+        return partial(on_worker_init, app_server, app_settings.API_SERVER_CELERY)(
+            sender, **_kwargs
+        )
+
+    worker_init.connect(_on_worker_init_wrapper)
+    worker_shutdown.connect(on_worker_shutdown)
+
+    if add_worker_tasks:
+        setup_worker_tasks(celery_app)
+    register_celery_tasks(celery_app)
+
+    with start_worker(
+        celery_app,
+        pool="threads",
+        concurrency=1,
+        loglevel="info",
+        perform_ping_check=False,
+        queues="api_worker_queue",
+    ) as worker:
+        yield worker
diff --git a/services/api-server/tests/unit/api_functions/celery/test_functions.py b/services/api-server/tests/unit/api_functions/celery/test_functions.py
new file mode 100644
index 00000000000..130a05bebaa
--- /dev/null
+++ b/services/api-server/tests/unit/api_functions/celery/test_functions.py
@@ -0,0 +1,558 @@
+# pylint: disable=unused-argument
+# pylint: disable=redefined-outer-name
+# pylint: disable=no-name-in-module
+# pylint: disable=too-many-positional-arguments
+# pylint: disable=too-many-arguments
+
+
+import datetime
+import inspect
+from collections.abc import Callable
+from functools import partial
+from pathlib import Path
+from typing import Any
+
+import httpx
+import pytest
+import respx
+from celery import Celery, Task  # pylint: disable=no-name-in-module
+from celery.contrib.testing.worker import (
+    TestWorkController,  # pylint: disable=no-name-in-module
+)
+from celery_library.task import register_task
+from celery_library.types import register_pydantic_types
+from faker import Faker
+from fastapi import FastAPI, status
+from httpx import AsyncClient, BasicAuth, HTTPStatusError
+from models_library.api_schemas_long_running_tasks.tasks import (
+    TaskResult,
+    TaskStatus,
+)
+from models_library.api_schemas_rpc_async_jobs.async_jobs import AsyncJobFilter
+from models_library.functions import (
+    FunctionClass,
+    FunctionID,
+    FunctionJobCollection,
+    FunctionJobID,
+    FunctionUserAccessRights,
+    FunctionUserApiAccessRights,
+    RegisteredFunction,
+    RegisteredFunctionJob,
+    RegisteredFunctionJobCollection,
+    RegisteredProjectFunction,
+    RegisteredProjectFunctionJob,
+    RegisteredProjectFunctionJobPatch,
+)
+from models_library.projects import ProjectID
+from models_library.users import UserID
+from pytest_mock import MockerFixture, MockType
+from pytest_simcore.helpers.httpx_calls_capture_models import HttpApiCallCaptureModel
+from servicelib.celery.models import TaskFilter, TaskID, TaskMetadata, TasksQueue
+from servicelib.common_headers import (
+    X_SIMCORE_PARENT_NODE_ID,
+    X_SIMCORE_PARENT_PROJECT_UUID,
+)
+from simcore_service_api_server._meta import API_VTAG
+from simcore_service_api_server.api.dependencies.authentication import Identity
+from simcore_service_api_server.api.dependencies.celery import (
+    ASYNC_JOB_CLIENT_NAME,
+    get_task_manager,
+)
+from
simcore_service_api_server.celery_worker.worker_tasks.functions_tasks import ( + run_function as run_function_task, +) +from simcore_service_api_server.exceptions.backend_errors import BaseBackEndError +from simcore_service_api_server.models.api_resources import JobLinks +from simcore_service_api_server.models.domain.functions import ( + PreRegisteredFunctionJobData, +) +from simcore_service_api_server.models.schemas.jobs import ( + JobPricingSpecification, + NodeID, +) +from tenacity import ( + AsyncRetrying, + retry_if_exception_type, + stop_after_delay, + wait_fixed, +) + +pytest_simcore_core_services_selection = ["postgres", "rabbit"] +pytest_simcore_ops_services_selection = ["adminer"] + +_faker = Faker() + + +async def wait_for_task_result( + client: AsyncClient, + auth: BasicAuth, + task_id: str, + timeout: float = 30.0, +) -> TaskResult: + + async for attempt in AsyncRetrying( + stop=stop_after_delay(timeout), + wait=wait_fixed(wait=datetime.timedelta(seconds=1.0)), + reraise=True, + retry=retry_if_exception_type(AssertionError), + ): + with attempt: + + response = await client.get(f"/{API_VTAG}/tasks/{task_id}", auth=auth) + response.raise_for_status() + status = TaskStatus.model_validate(response.json()) + assert status.done is True + + assert status.done is True + response = await client.get(f"/{API_VTAG}/tasks/{task_id}/result", auth=auth) + response.raise_for_status() + return TaskResult.model_validate(response.json()) + + +def _register_fake_run_function_task() -> Callable[[Celery], None]: + + async def run_function( + task: Task, + task_id: TaskID, + *, + user_identity: Identity, + function: RegisteredFunction, + pre_registered_function_job_data: PreRegisteredFunctionJobData, + pricing_spec: JobPricingSpecification | None, + job_links: JobLinks, + x_simcore_parent_project_uuid: NodeID | None, + x_simcore_parent_node_id: NodeID | None, + ) -> RegisteredFunctionJob: + return RegisteredProjectFunctionJob( + title=_faker.sentence(), + description=_faker.paragraph(), + function_uid=FunctionID(_faker.uuid4()), + inputs=pre_registered_function_job_data.job_inputs.values, + outputs=None, + function_class=FunctionClass.PROJECT, + uid=FunctionJobID(_faker.uuid4()), + created_at=_faker.date_time(), + project_job_id=ProjectID(_faker.uuid4()), + job_creation_task_id=None, + ) + + # check our mock task is correct + assert run_function_task.__name__ == run_function.__name__ + assert inspect.signature(run_function_task) == inspect.signature( + run_function + ), f"Signature mismatch: {inspect.signature(run_function_task)} != {inspect.signature(run_function)}" + + def _(celery_app: Celery) -> None: + register_pydantic_types(RegisteredProjectFunctionJob) + register_task(celery_app, run_function) + + return _ + + +async def _patch_registered_function_job_side_effect( + mock_registered_project_function_job: RegisteredFunctionJob, *args, **kwargs +): + registered_function_job_patch = kwargs["registered_function_job_patch"] + assert isinstance(registered_function_job_patch, RegisteredProjectFunctionJobPatch) + job_creation_task_id = registered_function_job_patch.job_creation_task_id + assert job_creation_task_id is not None + return mock_registered_project_function_job.model_copy( + update={"job_creation_task_id": job_creation_task_id} + ) + + +@pytest.mark.parametrize("register_celery_tasks", [_register_fake_run_function_task()]) +@pytest.mark.parametrize("add_worker_tasks", [False]) +async def test_with_fake_run_function( + app: FastAPI, + client: AsyncClient, + auth: BasicAuth, + mocker: 
MockerFixture, + with_api_server_celery_worker: TestWorkController, + mock_handler_in_functions_rpc_interface: Callable[ + [str, Any, Exception | None, Callable | None], None + ], + mock_registered_project_function: RegisteredProjectFunction, + mock_registered_project_function_job: RegisteredFunctionJob, + user_id: UserID, +): + + body = { + "input_1": _faker.uuid4(), + "input_2": _faker.pyfloat(min_value=0, max_value=100), + "input_3": _faker.pyint(min_value=0, max_value=100), + "input_4": _faker.boolean(), + "input_5": _faker.sentence(), + "input_6": [ + _faker.pyfloat(min_value=0, max_value=100) + for _ in range(_faker.pyint(min_value=5, max_value=100)) + ], + } + + mock_handler_in_functions_rpc_interface( + "get_function_user_permissions", + FunctionUserAccessRights( + user_id=user_id, + execute=True, + read=True, + write=True, + ), + None, + None, + ) + mock_handler_in_functions_rpc_interface( + "get_functions_user_api_access_rights", + FunctionUserApiAccessRights( + user_id=user_id, + read_functions=True, + write_functions=True, + execute_functions=True, + read_function_jobs=True, + write_function_jobs=True, + execute_function_jobs=True, + read_function_job_collections=True, + write_function_job_collections=True, + execute_function_job_collections=True, + ), + None, + None, + ) + mock_handler_in_functions_rpc_interface( + "get_function", mock_registered_project_function, None, None + ) + mock_handler_in_functions_rpc_interface("find_cached_function_jobs", [], None, None) + mock_handler_in_functions_rpc_interface( + "register_function_job", mock_registered_project_function_job, None, None + ) + + mock_handler_in_functions_rpc_interface( + "patch_registered_function_job", + None, + None, + partial( + _patch_registered_function_job_side_effect, + mock_registered_project_function_job, + ), + ) + + headers = {} + headers[X_SIMCORE_PARENT_PROJECT_UUID] = "null" + headers[X_SIMCORE_PARENT_NODE_ID] = "null" + + response = await client.post( + f"/{API_VTAG}/functions/{_faker.uuid4()}:run", + auth=auth, + json=body, + headers=headers, + ) + + assert response.status_code == status.HTTP_200_OK + function_job = RegisteredProjectFunctionJob.model_validate(response.json()) + celery_task_id = function_job.job_creation_task_id + assert celery_task_id is not None + # Poll until task completion and get result + result = await wait_for_task_result(client, auth, celery_task_id) + RegisteredProjectFunctionJob.model_validate(result.result) + + +def _register_exception_task(exception: Exception) -> Callable[[Celery], None]: + + async def exception_task( + task: Task, + task_id: TaskID, + ): + raise exception + + def _(celery_app: Celery) -> None: + register_task(celery_app, exception_task) + + return _ + + +@pytest.mark.parametrize( + "register_celery_tasks", + [ + _register_exception_task(ValueError("Test error")), + _register_exception_task(Exception("Test error")), + _register_exception_task(BaseBackEndError()), + ], +) +@pytest.mark.parametrize("add_worker_tasks", [False]) +async def test_celery_error_propagation( + app: FastAPI, + client: AsyncClient, + auth: BasicAuth, + with_api_server_celery_worker: TestWorkController, +): + + user_identity = Identity( + user_id=_faker.pyint(), product_name=_faker.word(), email=_faker.email() + ) + job_filter = AsyncJobFilter( + user_id=user_identity.user_id, + product_name=user_identity.product_name, + client_name=ASYNC_JOB_CLIENT_NAME, + ) + task_manager = get_task_manager(app=app) + task_uuid = await task_manager.submit_task( + task_metadata=TaskMetadata( + 
name="exception_task", queue=TasksQueue.API_WORKER_QUEUE + ), + task_filter=TaskFilter.model_validate(job_filter.model_dump()), + ) + + with pytest.raises(HTTPStatusError) as exc_info: + await wait_for_task_result(client, auth, f"{task_uuid}") + + assert exc_info.value.response.status_code == status.HTTP_503_SERVICE_UNAVAILABLE + + +@pytest.mark.parametrize( + "parent_project_uuid, parent_node_uuid, expected_status_code", + [ + (None, None, status.HTTP_422_UNPROCESSABLE_ENTITY), + (f"{_faker.uuid4()}", None, status.HTTP_422_UNPROCESSABLE_ENTITY), + (None, f"{_faker.uuid4()}", status.HTTP_422_UNPROCESSABLE_ENTITY), + (f"{_faker.uuid4()}", f"{_faker.uuid4()}", status.HTTP_200_OK), + ("null", "null", status.HTTP_200_OK), + ], +) +@pytest.mark.parametrize("capture", ["run_study_function_parent_info.json"]) +@pytest.mark.parametrize("mocked_app_dependencies", [None]) +async def test_run_project_function_parent_info( + app: FastAPI, + with_api_server_celery_worker: TestWorkController, + client: AsyncClient, + mock_handler_in_functions_rpc_interface: Callable[ + [str, Any, Exception | None, Callable | None], None + ], + mock_registered_project_function: RegisteredProjectFunction, + mock_registered_project_function_job: RegisteredFunctionJob, + auth: httpx.BasicAuth, + user_id: UserID, + mocked_webserver_rest_api_base: respx.MockRouter, + mocked_directorv2_rest_api_base: respx.MockRouter, + mocked_webserver_rpc_api: dict[str, MockType], + create_respx_mock_from_capture, + project_tests_dir: Path, + parent_project_uuid: str | None, + parent_node_uuid: str | None, + expected_status_code: int, + capture: str, +) -> None: + def _default_side_effect( + request: httpx.Request, + path_params: dict[str, Any], + capture: HttpApiCallCaptureModel, + ) -> Any: + if request.method == "POST" and request.url.path.endswith("/projects"): + if parent_project_uuid and parent_project_uuid != "null": + _parent_uuid = request.headers.get(X_SIMCORE_PARENT_PROJECT_UUID) + assert _parent_uuid is not None + assert parent_project_uuid == _parent_uuid + if parent_node_uuid and parent_node_uuid != "null": + _parent_node_uuid = request.headers.get(X_SIMCORE_PARENT_NODE_ID) + assert _parent_node_uuid is not None + assert parent_node_uuid == _parent_node_uuid + return capture.response_body + + create_respx_mock_from_capture( + respx_mocks=[mocked_webserver_rest_api_base, mocked_directorv2_rest_api_base], + capture_path=project_tests_dir / "mocks" / capture, + side_effects_callbacks=[_default_side_effect] * 50, + ) + + mock_handler_in_functions_rpc_interface( + "get_function_user_permissions", + FunctionUserAccessRights( + user_id=user_id, + execute=True, + read=True, + write=True, + ), + None, + None, + ) + mock_handler_in_functions_rpc_interface( + "get_function", mock_registered_project_function, None, None + ) + mock_handler_in_functions_rpc_interface("find_cached_function_jobs", [], None, None) + mock_handler_in_functions_rpc_interface( + "register_function_job", mock_registered_project_function_job, None, None + ) + mock_handler_in_functions_rpc_interface( + "get_functions_user_api_access_rights", + FunctionUserApiAccessRights( + user_id=user_id, + execute_functions=True, + write_functions=True, + read_functions=True, + ), + None, + None, + ) + mock_handler_in_functions_rpc_interface( + "patch_registered_function_job", + None, + None, + partial( + _patch_registered_function_job_side_effect, + mock_registered_project_function_job, + ), + ) + + headers = {} + if parent_project_uuid: + headers[X_SIMCORE_PARENT_PROJECT_UUID] = 
parent_project_uuid + if parent_node_uuid: + headers[X_SIMCORE_PARENT_NODE_ID] = parent_node_uuid + + response = await client.post( + f"{API_VTAG}/functions/{mock_registered_project_function.uid}:run", + json={}, + auth=auth, + headers=headers, + ) + assert response.status_code == expected_status_code + if response.status_code == status.HTTP_200_OK: + function_job = RegisteredProjectFunctionJob.model_validate(response.json()) + celery_task_id = function_job.job_creation_task_id + assert celery_task_id is not None + # Poll until task completion and get result + result = await wait_for_task_result(client, auth, celery_task_id) + RegisteredProjectFunctionJob.model_validate(result.result) + + +@pytest.mark.parametrize( + "parent_project_uuid, parent_node_uuid, expected_status_code", + [ + (None, None, status.HTTP_422_UNPROCESSABLE_ENTITY), + (f"{_faker.uuid4()}", None, status.HTTP_422_UNPROCESSABLE_ENTITY), + (None, f"{_faker.uuid4()}", status.HTTP_422_UNPROCESSABLE_ENTITY), + (f"{_faker.uuid4()}", f"{_faker.uuid4()}", status.HTTP_200_OK), + ("null", "null", status.HTTP_200_OK), + ], +) +@pytest.mark.parametrize("capture", ["run_study_function_parent_info.json"]) +@pytest.mark.parametrize("mocked_app_dependencies", [None]) +async def test_map_function_parent_info( + app: FastAPI, + with_api_server_celery_worker: TestWorkController, + client: AsyncClient, + mock_handler_in_functions_rpc_interface: Callable[ + [str, Any, Exception | None, Callable | None], MockType + ], + mock_registered_project_function: RegisteredProjectFunction, + mock_registered_project_function_job: RegisteredFunctionJob, + auth: httpx.BasicAuth, + user_id: UserID, + mocked_webserver_rest_api_base: respx.MockRouter, + mocked_directorv2_rest_api_base: respx.MockRouter, + mocked_webserver_rpc_api: dict[str, MockType], + create_respx_mock_from_capture, + project_tests_dir: Path, + parent_project_uuid: str | None, + parent_node_uuid: str | None, + expected_status_code: int, + capture: str, +) -> None: + + side_effect_checks = {} + + def _default_side_effect( + side_effect_checks: dict, + request: httpx.Request, + path_params: dict[str, Any], + capture: HttpApiCallCaptureModel, + ) -> Any: + if request.method == "POST" and request.url.path.endswith("/projects"): + side_effect_checks["headers_checked"] = True + if parent_project_uuid and parent_project_uuid != "null": + _parent_uuid = request.headers.get(X_SIMCORE_PARENT_PROJECT_UUID) + assert _parent_uuid is not None + assert parent_project_uuid == _parent_uuid + if parent_node_uuid and parent_node_uuid != "null": + _parent_node_uuid = request.headers.get(X_SIMCORE_PARENT_NODE_ID) + assert _parent_node_uuid is not None + assert parent_node_uuid == _parent_node_uuid + return capture.response_body + + create_respx_mock_from_capture( + respx_mocks=[mocked_webserver_rest_api_base, mocked_directorv2_rest_api_base], + capture_path=project_tests_dir / "mocks" / capture, + side_effects_callbacks=[partial(_default_side_effect, side_effect_checks)] * 50, + ) + + mock_handler_in_functions_rpc_interface( + "get_function_user_permissions", + FunctionUserAccessRights( + user_id=user_id, + execute=True, + read=True, + write=True, + ), + None, + None, + ) + mock_handler_in_functions_rpc_interface( + "get_function", mock_registered_project_function, None, None + ) + mock_handler_in_functions_rpc_interface("find_cached_function_jobs", [], None, None) + mock_handler_in_functions_rpc_interface( + "register_function_job", mock_registered_project_function_job, None, None + ) + 
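
The parametrized header matrix above (422 whenever either header is missing, 200 for two UUIDs or two "null" strings) follows directly from how the endpoints declare the parent headers: both are required, and each accepts either a UUID or the literal string "null". A standalone sketch of that contract, using plain `UUID` to stand in for `ProjectID`/`NodeID` and an invented `/probe` route:

```python
from typing import Annotated, Literal
from uuid import UUID

from fastapi import FastAPI, Header

app = FastAPI()


@app.post("/probe")
def probe(
    x_simcore_parent_project_uuid: Annotated[UUID | Literal["null"], Header()],
    x_simcore_parent_node_id: Annotated[UUID | Literal["null"], Header()],
) -> dict:
    # Both headers are required: omitting either one yields a 422.
    # The literal string "null" is accepted and interpreted as "no parent".
    parent_project = (
        None
        if x_simcore_parent_project_uuid == "null"
        else x_simcore_parent_project_uuid
    )
    return {"has_parent_project": parent_project is not None}
```
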
mock_handler_in_functions_rpc_interface( + "get_functions_user_api_access_rights", + FunctionUserApiAccessRights( + user_id=user_id, + execute_functions=True, + write_functions=True, + read_functions=True, + ), + None, + None, + ) + mock_handler_in_functions_rpc_interface( + "register_function_job_collection", + RegisteredFunctionJobCollection( + uid=FunctionJobID(_faker.uuid4()), + title="Test Collection", + description="A test function job collection", + job_ids=[], + created_at=datetime.datetime.now(datetime.UTC), + ), + None, + None, + ) + + patch_mock = mock_handler_in_functions_rpc_interface( + "patch_registered_function_job", + None, + None, + partial( + _patch_registered_function_job_side_effect, + mock_registered_project_function_job, + ), + ) + + headers = {} + if parent_project_uuid: + headers[X_SIMCORE_PARENT_PROJECT_UUID] = parent_project_uuid + if parent_node_uuid: + headers[X_SIMCORE_PARENT_NODE_ID] = parent_node_uuid + + response = await client.post( + f"{API_VTAG}/functions/{mock_registered_project_function.uid}:map", + json=[{}, {}], + auth=auth, + headers=headers, + ) + assert response.status_code == expected_status_code + + if expected_status_code == status.HTTP_200_OK: + FunctionJobCollection.model_validate(response.json()) + task_id = patch_mock.call_args.kwargs[ + "registered_function_job_patch" + ].job_creation_task_id + await wait_for_task_result(client, auth, f"{task_id}") + assert side_effect_checks["headers_checked"] is True diff --git a/services/api-server/tests/unit/api_functions/conftest.py b/services/api-server/tests/unit/api_functions/conftest.py index 70acd1244d7..fd60ede5b17 100644 --- a/services/api-server/tests/unit/api_functions/conftest.py +++ b/services/api-server/tests/unit/api_functions/conftest.py @@ -260,21 +260,24 @@ def mock_registered_function_job_collection( @pytest.fixture() def mock_handler_in_functions_rpc_interface( mock_wb_api_server_rpc: MockerFixture, -) -> Callable[[str, Any, Exception | None], None]: +) -> Callable[[str, Any, Exception | None, Callable | None], MockType]: def _mock( handler_name: str = "", return_value: Any = None, exception: Exception | None = None, + side_effect: Callable | None = None, ) -> MockType: from servicelib.rabbitmq.rpc_interfaces.webserver.functions import ( functions_rpc_interface, ) + assert exception is None or side_effect is None + return mock_wb_api_server_rpc.patch.object( functions_rpc_interface, handler_name, return_value=return_value, - side_effect=exception, + side_effect=exception or side_effect, ) return _mock diff --git a/services/api-server/tests/unit/api_functions/test_api_routers_function_jobs.py b/services/api-server/tests/unit/api_functions/test_api_routers_function_jobs.py index 6bdda8e4ecc..534051212a7 100644 --- a/services/api-server/tests/unit/api_functions/test_api_routers_function_jobs.py +++ b/services/api-server/tests/unit/api_functions/test_api_routers_function_jobs.py @@ -1,5 +1,8 @@ # pylint: disable=unused-argument +# pylint: disable=too-many-arguments +# pylint: disable=too-many-positional-arguments +import random import uuid from collections.abc import Callable from datetime import datetime @@ -8,21 +11,36 @@ import httpx import pytest +from celery_library.task_manager import CeleryTaskManager +from faker import Faker +from fastapi import FastAPI, status from httpx import AsyncClient from models_library.api_schemas_webserver.functions import ( ProjectFunctionJob, RegisteredProjectFunctionJob, ) -from models_library.functions import FunctionJobStatus, 
RegisteredProjectFunction +from models_library.functions import ( + FunctionJobStatus, + RegisteredProjectFunction, + TaskID, +) from models_library.products import ProductName +from models_library.progress_bar import ProgressReport, ProgressStructuredMessage +from models_library.projects import ProjectID from models_library.projects_state import RunningState from models_library.rest_pagination import PageMetaInfoLimitOffset from models_library.users import UserID from pytest_mock import MockerFixture, MockType -from servicelib.aiohttp import status +from servicelib.celery.models import TaskFilter, TaskState, TaskStatus, TaskUUID from simcore_service_api_server._meta import API_VTAG +from simcore_service_api_server.api.routes import function_jobs_routes +from simcore_service_api_server.api.routes.function_jobs_routes import ( + _JOB_CREATION_TASK_STATUS_PREFIX, +) from simcore_service_api_server.models.schemas.jobs import JobStatus +_faker = Faker() + async def test_delete_function_job( client: AsyncClient, @@ -179,19 +197,78 @@ def mocked_list_function_jobs(offset: int, limit: int): @pytest.mark.parametrize("job_status", ["SUCCESS", "FAILED", "STARTED"]) +@pytest.mark.parametrize( + "project_job_id, job_creation_task_id, celery_task_state", + [ + ( + ProjectID(_faker.uuid4()), + TaskID(_faker.uuid4()), + random.choice(list(TaskState)), + ), + (None, None, random.choice(list(TaskState))), + (None, TaskID(_faker.uuid4()), random.choice(list(TaskState))), + ], +) async def test_get_function_job_status( + app: FastAPI, mocked_app_dependencies: None, client: AsyncClient, + mocker: MockerFixture, mock_handler_in_functions_rpc_interface: Callable[[str, Any], None], mock_registered_project_function_job: RegisteredProjectFunctionJob, mock_registered_project_function: RegisteredProjectFunction, mock_method_in_jobs_service: Callable[[str, Any], None], auth: httpx.BasicAuth, job_status: str, + project_job_id: ProjectID, + job_creation_task_id: TaskID | None, + celery_task_state: TaskState, ) -> None: + _expected_return_status = ( + status.HTTP_500_INTERNAL_SERVER_ERROR + if job_status != "SUCCESS" + and job_status != "FAILED" + and (project_job_id is None and job_creation_task_id is None) + else status.HTTP_200_OK + ) + + def _mock_task_manager(*args, **kwargs) -> CeleryTaskManager: + async def _get_task_status( + task_uuid: TaskUUID, task_filter: TaskFilter + ) -> TaskStatus: + assert f"{task_uuid}" == job_creation_task_id + return TaskStatus( + task_uuid=task_uuid, + task_state=celery_task_state, + progress_report=ProgressReport( + actual_value=0.5, + total=1.0, + attempt=1, + unit=None, + message=ProgressStructuredMessage.model_validate( + ProgressStructuredMessage.model_json_schema()["$defs"][ + "ProgressStructuredMessage" + ]["examples"][0] + ), + ), + ) + + obj = mocker.Mock(spec=CeleryTaskManager) + obj.get_task_status = _get_task_status + return obj + + mocker.patch.object(function_jobs_routes, "get_task_manager", _mock_task_manager) + mock_handler_in_functions_rpc_interface( - "get_function_job", mock_registered_project_function_job + "get_function_job", + mock_registered_project_function_job.model_copy( + update={ + "user_id": ANY, + "project_job_id": project_job_id, + "job_creation_task_id": job_creation_task_id, + } + ), ) mock_handler_in_functions_rpc_interface( "get_function", mock_registered_project_function @@ -219,12 +296,29 @@ async def test_get_function_job_status( f"{API_VTAG}/function_jobs/{mock_registered_project_function_job.uid}/status", auth=auth, ) - assert 
response.status_code == status.HTTP_200_OK - data = response.json() - assert data["status"] == job_status - - -@pytest.mark.parametrize("job_outputs", [{"X+Y": 42, "X-Y": 10}]) + assert response.status_code == _expected_return_status + if response.status_code == status.HTTP_200_OK: + data = response.json() + if ( + project_job_id is not None + or job_status == "SUCCESS" + or job_status == "FAILED" + ): + assert data["status"] == job_status + else: + assert ( + data["status"] + == f"{_JOB_CREATION_TASK_STATUS_PREFIX}{celery_task_state}" + ) + + +@pytest.mark.parametrize( + "job_outputs, project_job_id", + [ + (None, None), + ({"X+Y": 42, "X-Y": 10}, ProjectID(_faker.uuid4())), + ], +) async def test_get_function_job_outputs( client: AsyncClient, mock_handler_in_functions_rpc_interface: Callable[[str, Any], None], @@ -232,11 +326,25 @@ async def test_get_function_job_outputs( mock_registered_project_function: RegisteredProjectFunction, mocked_webserver_rpc_api: dict[str, MockType], auth: httpx.BasicAuth, - job_outputs: dict[str, Any], + job_outputs: dict[str, Any] | None, + project_job_id: ProjectID | None, ) -> None: + _expected_return_status = ( + status.HTTP_404_NOT_FOUND + if project_job_id is None and job_outputs is None + else status.HTTP_200_OK + ) + mock_handler_in_functions_rpc_interface( - "get_function_job", mock_registered_project_function_job + "get_function_job", + mock_registered_project_function_job.model_copy( + update={ + "user_id": ANY, + "project_job_id": project_job_id, + "job_creation_task_id": None, + } + ), ) mock_handler_in_functions_rpc_interface( "get_function", mock_registered_project_function @@ -247,6 +355,7 @@ async def test_get_function_job_outputs( f"{API_VTAG}/function_jobs/{mock_registered_project_function_job.uid}/outputs", auth=auth, ) - assert response.status_code == status.HTTP_200_OK - data = response.json() - assert data == job_outputs + assert response.status_code == _expected_return_status + if response.status_code == status.HTTP_200_OK: + data = response.json() + assert data == job_outputs diff --git a/services/api-server/tests/unit/api_functions/test_api_routers_functions.py b/services/api-server/tests/unit/api_functions/test_api_routers_functions.py index f4092d3afd2..c473d876cf5 100644 --- a/services/api-server/tests/unit/api_functions/test_api_routers_functions.py +++ b/services/api-server/tests/unit/api_functions/test_api_routers_functions.py @@ -3,18 +3,19 @@ # pylint: disable=too-many-positional-arguments # pylint: disable=redefined-outer-name -import datetime from collections.abc import Callable -from functools import partial from pathlib import Path from typing import Any from unittest.mock import MagicMock -from uuid import UUID, uuid4 +from uuid import uuid4 import httpx import pytest import respx +from celery import Task # pylint: disable=no-name-in-module +from celery_library.task_manager import CeleryTaskManager from faker import Faker +from fastapi import FastAPI from httpx import AsyncClient from models_library.api_schemas_long_running_tasks.tasks import TaskGet from models_library.functions import ( @@ -23,8 +24,8 @@ ProjectFunction, RegisteredFunction, RegisteredFunctionJob, - RegisteredFunctionJobCollection, RegisteredProjectFunction, + RegisteredProjectFunctionJob, ) from models_library.functions_errors import ( FunctionIDNotFoundError, @@ -32,14 +33,26 @@ ) from models_library.rest_pagination import PageMetaInfoLimitOffset from models_library.users import UserID +from pydantic import EmailStr from pytest_mock import 
diff --git a/services/api-server/tests/unit/api_functions/test_api_routers_functions.py b/services/api-server/tests/unit/api_functions/test_api_routers_functions.py
index f4092d3afd2..c473d876cf5 100644
--- a/services/api-server/tests/unit/api_functions/test_api_routers_functions.py
+++ b/services/api-server/tests/unit/api_functions/test_api_routers_functions.py
@@ -3,18 +3,19 @@
 # pylint: disable=too-many-positional-arguments
 # pylint: disable=redefined-outer-name
 
-import datetime
 from collections.abc import Callable
-from functools import partial
 from pathlib import Path
 from typing import Any
 from unittest.mock import MagicMock
-from uuid import UUID, uuid4
+from uuid import uuid4
 
 import httpx
 import pytest
 import respx
+from celery import Task  # pylint: disable=no-name-in-module
+from celery_library.task_manager import CeleryTaskManager
 from faker import Faker
+from fastapi import FastAPI
 from httpx import AsyncClient
 from models_library.api_schemas_long_running_tasks.tasks import TaskGet
 from models_library.functions import (
@@ -23,8 +24,8 @@
     ProjectFunction,
     RegisteredFunction,
     RegisteredFunctionJob,
-    RegisteredFunctionJobCollection,
     RegisteredProjectFunction,
+    RegisteredProjectFunctionJob,
 )
 from models_library.functions_errors import (
     FunctionIDNotFoundError,
@@ -32,14 +33,26 @@
 )
 from models_library.rest_pagination import PageMetaInfoLimitOffset
 from models_library.users import UserID
+from pydantic import EmailStr
 from pytest_mock import MockerFixture, MockType
 from pytest_simcore.helpers.httpx_calls_capture_models import HttpApiCallCaptureModel
 from servicelib.aiohttp import status
+from servicelib.celery.app_server import BaseAppServer
+from servicelib.celery.models import TaskID
 from servicelib.common_headers import (
     X_SIMCORE_PARENT_NODE_ID,
     X_SIMCORE_PARENT_PROJECT_UUID,
 )
+from servicelib.rabbitmq._client_rpc import RabbitMQRPCClient
 from simcore_service_api_server._meta import API_VTAG
+from simcore_service_api_server.api.dependencies.authentication import Identity
+from simcore_service_api_server.celery_worker.worker_tasks import functions_tasks
+from simcore_service_api_server.models.api_resources import JobLinks
+from simcore_service_api_server.models.domain.functions import (
+    PreRegisteredFunctionJobData,
+)
+from simcore_service_api_server.models.schemas.jobs import JobInputs
+from simcore_service_api_server.services_rpc.wb_api_server import WbApiRpcClient
 
 _faker = Faker()
 
@@ -316,6 +329,7 @@ async def test_delete_function(
 )
 async def test_run_map_function_not_allowed(
     client: AsyncClient,
+    mocker: MockerFixture,
     mock_handler_in_functions_rpc_interface: Callable[[str, Any], None],
     mock_registered_project_function: RegisteredProjectFunction,
     auth: httpx.BasicAuth,
@@ -328,6 +342,11 @@
 ) -> None:
     """Test that running a function is not allowed."""
 
+    mocker.patch(
+        "simcore_service_api_server.api.routes.functions_routes.get_task_manager",
+        return_value=mocker.MagicMock(spec=CeleryTaskManager),
+    )
+
     mock_handler_in_functions_rpc_interface(
         "get_function_user_permissions",
         FunctionUserAccessRights(
@@ -379,155 +398,67 @@ async def async_magic():
 )
 
 
-@pytest.mark.parametrize(
-    "parent_project_uuid, parent_node_uuid, expected_status_code",
-    [
-        (None, None, status.HTTP_422_UNPROCESSABLE_ENTITY),
-        (f"{_faker.uuid4()}", None, status.HTTP_422_UNPROCESSABLE_ENTITY),
-        (None, f"{_faker.uuid4()}", status.HTTP_422_UNPROCESSABLE_ENTITY),
-        (f"{_faker.uuid4()}", f"{_faker.uuid4()}", status.HTTP_200_OK),
-        ("null", "null", status.HTTP_200_OK),
-    ],
-)
 @pytest.mark.parametrize("capture", ["run_study_function_parent_info.json"])
-async def test_run_project_function_parent_info(
+async def test_run_project_function(
+    mocker: MockerFixture,
+    mocked_webserver_rpc_api: dict[str, MockType],
+    app: FastAPI,
     client: AsyncClient,
     mock_handler_in_functions_rpc_interface: Callable[[str, Any], None],
     mock_registered_project_function: RegisteredProjectFunction,
     mock_registered_project_function_job: RegisteredFunctionJob,
     auth: httpx.BasicAuth,
-    user_id: UserID,
+    user_identity: Identity,
+    user_email: EmailStr,
+    job_links: JobLinks,
     mocked_webserver_rest_api_base: respx.MockRouter,
     mocked_directorv2_rest_api_base: respx.MockRouter,
-    mocked_webserver_rpc_api: dict[str, MockType],
     create_respx_mock_from_capture,
     project_tests_dir: Path,
-    parent_project_uuid: str | None,
-    parent_node_uuid: str | None,
-    expected_status_code: int,
    capture: str,
 ) -> None:
-    def _default_side_effect(
-        request: httpx.Request,
-        path_params: dict[str, Any],
-        capture: HttpApiCallCaptureModel,
-    ) -> Any:
-        if request.method == "POST" and request.url.path.endswith("/projects"):
-            if parent_project_uuid and parent_project_uuid != "null":
-                _parent_uuid = request.headers.get(X_SIMCORE_PARENT_PROJECT_UUID)
-                assert _parent_uuid is not None
-                assert parent_project_uuid == _parent_uuid
-            if parent_node_uuid and parent_node_uuid != "null":
-                _parent_node_uuid = request.headers.get(X_SIMCORE_PARENT_NODE_ID)
-                assert _parent_node_uuid is not None
-                assert parent_node_uuid == _parent_node_uuid
-        return capture.response_body
 
-    create_respx_mock_from_capture(
-        respx_mocks=[mocked_webserver_rest_api_base, mocked_directorv2_rest_api_base],
-        capture_path=project_tests_dir / "mocks" / capture,
-        side_effects_callbacks=[_default_side_effect] * 50,
-    )
+    def _get_app_server(celery_app: Any) -> FastAPI:
+        app_server = mocker.Mock(spec=BaseAppServer)
+        app_server.app = app
+        return app_server
 
-    mock_handler_in_functions_rpc_interface(
-        "get_function_user_permissions",
-        FunctionUserAccessRights(
-            user_id=user_id,
-            execute=True,
-            read=True,
-            write=True,
-        ),
-    )
-    mock_handler_in_functions_rpc_interface(
-        "get_function", mock_registered_project_function
-    )
-    mock_handler_in_functions_rpc_interface("find_cached_function_jobs", [])
-    mock_handler_in_functions_rpc_interface(
-        "register_function_job", mock_registered_project_function_job
-    )
-    mock_handler_in_functions_rpc_interface(
-        "get_functions_user_api_access_rights",
-        FunctionUserApiAccessRights(
-            user_id=user_id,
-            execute_functions=True,
-            write_functions=True,
-            read_functions=True,
-        ),
-    )
+    mocker.patch.object(functions_tasks, "get_app_server", _get_app_server)
 
-    headers = {}
-    if parent_project_uuid:
-        headers[X_SIMCORE_PARENT_PROJECT_UUID] = parent_project_uuid
-    if parent_node_uuid:
-        headers[X_SIMCORE_PARENT_NODE_ID] = parent_node_uuid
+    def _get_rabbitmq_rpc_client(app: FastAPI) -> RabbitMQRPCClient:
+        return mocker.MagicMock(spec=RabbitMQRPCClient)
 
-    response = await client.post(
-        f"{API_VTAG}/functions/{mock_registered_project_function.uid}:run",
-        json={},
-        auth=auth,
-        headers=headers,
+    mocker.patch.object(
+        functions_tasks, "get_rabbitmq_rpc_client", _get_rabbitmq_rpc_client
     )
-    assert response.status_code == expected_status_code
 
+    async def _get_wb_api_rpc_client(app: FastAPI) -> WbApiRpcClient:
+        wb_api_rpc_client = WbApiRpcClient(
+            _client=mocker.MagicMock(spec=RabbitMQRPCClient)
+        )
+        return wb_api_rpc_client
 
-@pytest.mark.parametrize(
-    "parent_project_uuid, parent_node_uuid, expected_status_code",
-    [
-        (None, None, status.HTTP_422_UNPROCESSABLE_ENTITY),
-        (f"{_faker.uuid4()}", None, status.HTTP_422_UNPROCESSABLE_ENTITY),
-        (None, f"{_faker.uuid4()}", status.HTTP_422_UNPROCESSABLE_ENTITY),
-        (f"{_faker.uuid4()}", f"{_faker.uuid4()}", status.HTTP_200_OK),
-        ("null", "null", status.HTTP_200_OK),
-    ],
-)
-@pytest.mark.parametrize("capture", ["run_study_function_parent_info.json"])
-async def test_map_function_parent_info(
-    client: AsyncClient,
-    mock_handler_in_functions_rpc_interface: Callable[[str, Any], None],
-    mock_registered_project_function: RegisteredProjectFunction,
-    mock_registered_project_function_job: RegisteredFunctionJob,
-    auth: httpx.BasicAuth,
-    user_id: UserID,
-    mocked_webserver_rest_api_base: respx.MockRouter,
-    mocked_directorv2_rest_api_base: respx.MockRouter,
-    mocked_webserver_rpc_api: dict[str, MockType],
-    create_respx_mock_from_capture,
-    project_tests_dir: Path,
-    parent_project_uuid: str | None,
-    parent_node_uuid: str | None,
-    expected_status_code: int,
-    capture: str,
-) -> None:
-    side_effect_checks = {}
+    mocker.patch.object(
+        functions_tasks, "get_wb_api_rpc_client", _get_wb_api_rpc_client
+    )
 
     def _default_side_effect(
-        side_effect_checks: dict,
         request: httpx.Request,
         path_params: dict[str, Any],
         capture: HttpApiCallCaptureModel,
     ) -> Any:
-        if request.method == "POST" and request.url.path.endswith("/projects"):
-            side_effect_checks["headers_checked"] = True
-            if parent_project_uuid and parent_project_uuid != "null":
-                _parent_uuid = request.headers.get(X_SIMCORE_PARENT_PROJECT_UUID)
-                assert _parent_uuid is not None
-                assert parent_project_uuid == _parent_uuid
-            if parent_node_uuid and parent_node_uuid != "null":
-                _parent_node_uuid = request.headers.get(X_SIMCORE_PARENT_NODE_ID)
-                assert _parent_node_uuid is not None
-                assert parent_node_uuid == _parent_node_uuid
         return capture.response_body
 
     create_respx_mock_from_capture(
         respx_mocks=[mocked_webserver_rest_api_base, mocked_directorv2_rest_api_base],
         capture_path=project_tests_dir / "mocks" / capture,
-        side_effects_callbacks=[partial(_default_side_effect, side_effect_checks)] * 50,
+        side_effects_callbacks=[_default_side_effect] * 50,
     )
 
     mock_handler_in_functions_rpc_interface(
         "get_function_user_permissions",
         FunctionUserAccessRights(
-            user_id=user_id,
+            user_id=user_identity.user_id,
             execute=True,
             read=True,
             write=True,
@@ -543,38 +474,33 @@ def _default_side_effect(
     mock_handler_in_functions_rpc_interface(
         "get_functions_user_api_access_rights",
         FunctionUserApiAccessRights(
-            user_id=user_id,
+            user_id=user_identity.user_id,
             execute_functions=True,
             write_functions=True,
             read_functions=True,
         ),
     )
     mock_handler_in_functions_rpc_interface(
-        "register_function_job_collection",
-        RegisteredFunctionJobCollection(
-            uid=UUID(_faker.uuid4()),
-            title="Test Collection",
-            description="A test function job collection",
-            job_ids=[],
-            created_at=datetime.datetime.now(datetime.UTC),
-        ),
+        "patch_registered_function_job", mock_registered_project_function_job
     )
 
-    headers = {}
-    if parent_project_uuid:
-        headers[X_SIMCORE_PARENT_PROJECT_UUID] = parent_project_uuid
-    if parent_node_uuid:
-        headers[X_SIMCORE_PARENT_NODE_ID] = parent_node_uuid
+    pre_registered_function_job_data = PreRegisteredFunctionJobData(
+        job_inputs=JobInputs(values={}),
+        function_job_id=mock_registered_project_function.uid,
+    )
 
-    response = await client.post(
-        f"{API_VTAG}/functions/{mock_registered_project_function.uid}:map",
-        json=[{}, {}],
-        auth=auth,
-        headers=headers,
+    job = await functions_tasks.run_function(
+        task=MagicMock(spec=Task),
+        task_id=TaskID(_faker.uuid4()),
+        user_identity=user_identity,
+        function=mock_registered_project_function,
+        pre_registered_function_job_data=pre_registered_function_job_data,
+        pricing_spec=None,
+        job_links=job_links,
+        x_simcore_parent_project_uuid=None,
+        x_simcore_parent_node_id=None,
    )
-    if expected_status_code == status.HTTP_200_OK:
-        assert side_effect_checks["headers_checked"] is True
-    assert response.status_code == expected_status_code
+    assert isinstance(job, RegisteredProjectFunctionJob)
 
 
 async def test_export_logs_project_function_job(
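Note: `test_run_project_function` now patches three module-level lookups on `functions_tasks` before calling the worker task directly. A hypothetical fixture (not part of this change) could bundle those patches so sibling worker-task tests can reuse them:

```python
# Hypothetical helper fixture; names mirror the patches used in
# test_run_project_function above, but this is a sketch, not repo code.
import pytest
from fastapi import FastAPI
from pytest_mock import MockerFixture
from servicelib.celery.app_server import BaseAppServer
from servicelib.rabbitmq._client_rpc import RabbitMQRPCClient
from simcore_service_api_server.celery_worker.worker_tasks import functions_tasks
from simcore_service_api_server.services_rpc.wb_api_server import WbApiRpcClient


@pytest.fixture
def patched_worker_task_dependencies(app: FastAPI, mocker: MockerFixture) -> None:
    """Patch the module-level lookups that worker tasks resolve at call time."""

    def _get_app_server(celery_app: object) -> BaseAppServer:
        # The worker task only needs .app, so a spec'd mock suffices.
        app_server = mocker.Mock(spec=BaseAppServer)
        app_server.app = app
        return app_server

    async def _get_wb_api_rpc_client(app: FastAPI) -> WbApiRpcClient:
        return WbApiRpcClient(_client=mocker.MagicMock(spec=RabbitMQRPCClient))

    mocker.patch.object(functions_tasks, "get_app_server", _get_app_server)
    mocker.patch.object(
        functions_tasks,
        "get_rabbitmq_rpc_client",
        lambda app: mocker.MagicMock(spec=RabbitMQRPCClient),
    )
    mocker.patch.object(
        functions_tasks, "get_wb_api_rpc_client", _get_wb_api_rpc_client
    )
```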
diff --git a/services/api-server/tests/unit/conftest.py b/services/api-server/tests/unit/conftest.py
index 51370bb7ce1..1865519bb1a 100644
--- a/services/api-server/tests/unit/conftest.py
+++ b/services/api-server/tests/unit/conftest.py
@@ -57,8 +57,10 @@
 from pytest_simcore.simcore_webserver_projects_rest_api import GET_PROJECT
 from requests.auth import HTTPBasicAuth
 from respx import MockRouter
+from simcore_service_api_server.api.dependencies.authentication import Identity
 from simcore_service_api_server.core.application import create_app
 from simcore_service_api_server.core.settings import ApplicationSettings
+from simcore_service_api_server.models.api_resources import JobLinks
 from simcore_service_api_server.repository.api_keys import UserAndProductTuple
 from simcore_service_api_server.services_http.solver_job_outputs import ResultsTypes
 from simcore_service_api_server.services_rpc.wb_api_server import WbApiRpcClient
@@ -69,6 +71,19 @@ def product_name() -> ProductName:
     return "osparc"
 
 
+@pytest.fixture
+def user_identity(
+    user_id: UserID,
+    user_email: EmailStr,
+    product_name: ProductName,
+) -> Identity:
+    return Identity(
+        user_id=user_id,
+        product_name=product_name,
+        email=user_email,
+    )
+
+
 @pytest.fixture
 def app_environment(
     monkeypatch: pytest.MonkeyPatch,
@@ -114,7 +129,6 @@ def mock_missing_plugins(app_environment: EnvVarsDict, mocker: MockerFixture):
         "setup_prometheus_instrumentation",
         autospec=True,
     )
-
     return app_environment
 
 
@@ -550,6 +564,12 @@ def project_job_rpc_get() -> ProjectJobRpcGet:
     return ProjectJobRpcGet.model_validate(example)
 
 
+@pytest.fixture
+def job_links() -> JobLinks:
+    example = JobLinks.model_json_schema()["examples"][0]
+    return JobLinks.model_validate(example)
+
+
 @pytest.fixture
 def mocked_webserver_rpc_api(
     mocked_app_dependencies: None,
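Note: the new `job_links` fixture follows a pattern already used in this conftest (see `project_job_rpc_get`): take the first entry of a model's published JSON-schema examples and validate it back into an instance. If this keeps recurring, a generic helper (hypothetical, not part of this change) could factor it out:

```python
# Sketch of a generic helper for the "fixture from schema example" pattern.
from typing import TypeVar

from pydantic import BaseModel

TModel = TypeVar("TModel", bound=BaseModel)


def model_from_first_example(model_cls: type[TModel]) -> TModel:
    """Build an instance from the first example in the model's JSON schema."""
    example = model_cls.model_json_schema()["examples"][0]
    return model_cls.model_validate(example)
```

With such a helper, `job_links` would reduce to `return model_from_first_example(JobLinks)`.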
- """ - from servicelib.rabbitmq.rpc_interfaces.async_jobs import async_jobs - - mocks = {} - - # Get all callable methods from the side effects class that are not built-ins - side_effect_methods = [ - method_name - for method_name in dir(async_jobs_rpc_side_effects) - if not method_name.startswith("_") - and callable(getattr(async_jobs_rpc_side_effects, method_name)) - ] - - # Create mocks for each method in catalog_rpc that has a corresponding side effect - for method_name in side_effect_methods: - assert hasattr(async_jobs, method_name) - mocks[method_name] = mocker.patch.object( - async_jobs, - method_name, - autospec=True, - side_effect=getattr(async_jobs_rpc_side_effects, method_name), - ) - - return mocks + mocker.patch.object(task_routes, "get_task_manager", _get_task_manager) + return mock_task_manager_object -@pytest.mark.parametrize( - "async_job_error, expected_status_code", - [ - (None, status.HTTP_200_OK), - ( - JobSchedulerError( - exc=Exception("A very rare exception raised by the scheduler") - ), - status.HTTP_500_INTERNAL_SERVER_ERROR, - ), - ], -) -async def test_get_async_jobs( +async def test_list_celery_tasks( + mock_task_manager: MockType, client: AsyncClient, - mocked_async_jobs_rpc_api: dict[str, MockType], auth: BasicAuth, - expected_status_code: int, ): response = await client.get("/v0/tasks", auth=auth) - assert mocked_async_jobs_rpc_api["list_jobs"].called - assert response.status_code == expected_status_code + assert mock_task_manager.list_tasks.called + assert response.status_code == status.HTTP_200_OK - if response.status_code == status.HTTP_200_OK: - result = ApiServerEnvelope[list[TaskGet]].model_validate_json(response.text) - assert len(result.data) > 0 - assert all(isinstance(task, TaskGet) for task in result.data) - task = result.data[0] - assert task.abort_href == f"/v0/tasks/{task.task_id}:cancel" - assert task.result_href == f"/v0/tasks/{task.task_id}/result" - assert task.status_href == f"/v0/tasks/{task.task_id}" + result = ApiServerEnvelope[list[TaskGet]].model_validate_json(response.text) + assert len(result.data) > 0 + assert all(isinstance(task, TaskGet) for task in result.data) + task = result.data[0] + assert task.abort_href == f"/v0/tasks/{task.task_id}:cancel" + assert task.result_href == f"/v0/tasks/{task.task_id}/result" + assert task.status_href == f"/v0/tasks/{task.task_id}" -@pytest.mark.parametrize( - "async_job_error, expected_status_code", - [ - (None, status.HTTP_200_OK), - ( - JobSchedulerError( - exc=Exception("A very rare exception raised by the scheduler") - ), - status.HTTP_500_INTERNAL_SERVER_ERROR, - ), - ], -) -async def test_get_async_jobs_status( +async def test_get_task_status( + mock_task_manager: MockType, client: AsyncClient, - mocked_async_jobs_rpc_api: dict[str, MockType], auth: BasicAuth, - expected_status_code: int, ): task_id = f"{_faker.uuid4()}" response = await client.get(f"/v0/tasks/{task_id}", auth=auth) - assert mocked_async_jobs_rpc_api["status"].called - assert f"{mocked_async_jobs_rpc_api['status'].call_args[1]['job_id']}" == task_id - assert response.status_code == expected_status_code - if response.status_code == status.HTTP_200_OK: - TaskStatus.model_validate_json(response.text) + assert mock_task_manager.get_task_status.called + assert response.status_code == status.HTTP_200_OK + TaskStatus.model_validate_json(response.text) -@pytest.mark.parametrize( - "async_job_error, expected_status_code", - [ - (None, status.HTTP_204_NO_CONTENT), - ( - JobSchedulerError( - exc=Exception("A very rare exception 
-                exc=Exception("A very rare exception raised by the scheduler")
-            ),
-            status.HTTP_500_INTERNAL_SERVER_ERROR,
-        ),
-    ],
-)
-async def test_cancel_async_job(
+async def test_cancel_task(
+    mock_task_manager: MockType,
     client: AsyncClient,
-    mocked_async_jobs_rpc_api: dict[str, MockType],
     auth: BasicAuth,
-    expected_status_code: int,
 ):
     task_id = f"{_faker.uuid4()}"
     response = await client.post(f"/v0/tasks/{task_id}:cancel", auth=auth)
-    assert mocked_async_jobs_rpc_api["cancel"].called
-    assert f"{mocked_async_jobs_rpc_api['cancel'].call_args[1]['job_id']}" == task_id
-    assert response.status_code == expected_status_code
+    assert mock_task_manager.cancel_task.called
+    assert response.status_code == status.HTTP_204_NO_CONTENT
+
+
+async def test_get_task_result(
+    mock_task_manager: MockType,
+    client: AsyncClient,
+    auth: BasicAuth,
+):
+    task_id = f"{_faker.uuid4()}"
+    response = await client.get(f"/v0/tasks/{task_id}/result", auth=auth)
+    assert response.status_code == status.HTTP_200_OK
+    assert mock_task_manager.get_task_result.called
+    assert f"{mock_task_manager.get_task_result.call_args[1]['task_uuid']}" == task_id
 
 
 @pytest.mark.parametrize(
-    "async_job_error, expected_status_code",
+    "method, url, list_tasks_return_value, get_task_status_return_value, cancel_task_return_value, get_task_result_return_value, expected_status_code",
     [
-        (None, status.HTTP_200_OK),
         (
-            JobError(
-                job_id=_faker.uuid4(),
-                exc_type=Exception,
-                exc_message="An exception from inside the async job",
-            ),
-            status.HTTP_500_INTERNAL_SERVER_ERROR,
+            "GET",
+            "/v0/tasks",
+            CeleryError(),
+            None,
+            None,
+            None,
+            status.HTTP_503_SERVICE_UNAVAILABLE,
         ),
         (
-            JobNotDoneError(job_id=_faker.uuid4()),
-            status.HTTP_404_NOT_FOUND,
+            "GET",
+            f"/v0/tasks/{_faker.uuid4()}",
+            None,
+            CeleryError(),
+            None,
+            None,
+            status.HTTP_503_SERVICE_UNAVAILABLE,
        ),
         (
-            JobAbortedError(job_id=_faker.uuid4()),
-            status.HTTP_409_CONFLICT,
+            "POST",
+            f"/v0/tasks/{_faker.uuid4()}:cancel",
+            None,
+            None,
+            CeleryError(),
+            None,
+            status.HTTP_503_SERVICE_UNAVAILABLE,
+        ),
+        (
+            "GET",
+            f"/v0/tasks/{_faker.uuid4()}/result",
+            None,
+            CeleryError(),
+            None,
+            None,
+            status.HTTP_503_SERVICE_UNAVAILABLE,
+        ),
+        (
+            "GET",
+            f"/v0/tasks/{_faker.uuid4()}/result",
+            None,
+            CeleryTaskStatus(
+                task_uuid=TaskUUID("123e4567-e89b-12d3-a456-426614174000"),
+                task_state=TaskState.STARTED,
+                progress_report=ProgressReport(
+                    actual_value=0.5,
+                    total=1.0,
+                    unit="Byte",
+                    message=ProgressStructuredMessage.model_validate(
+                        {
+                            "description": "some description",
+                            "current": 12.2,
+                            "total": 123,
+                        }
+                    ),
+                ),
+            ),
+            None,
+            None,
+            status.HTTP_404_NOT_FOUND,
         ),
         (
-            JobSchedulerError(
-                exc=Exception("A very rare exception raised by the scheduler")
+            "GET",
+            f"/v0/tasks/{_faker.uuid4()}/result",
+            None,
+            CeleryTaskStatus(
+                task_uuid=TaskUUID("123e4567-e89b-12d3-a456-426614174000"),
+                task_state=TaskState.ABORTED,
+                progress_report=ProgressReport(
+                    actual_value=0.5,
+                    total=1.0,
+                    unit="Byte",
+                    message=ProgressStructuredMessage.model_validate(
+                        {
+                            "description": "some description",
+                            "current": 12.2,
+                            "total": 123,
+                        }
+                    ),
+                ),
             ),
-            status.HTTP_500_INTERNAL_SERVER_ERROR,
+            None,
+            None,
+            status.HTTP_409_CONFLICT,
         ),
     ],
 )
-async def test_get_async_job_result(
+async def test_celery_error_propagation(
+    mock_task_manager: MockType,
     client: AsyncClient,
-    mocked_async_jobs_rpc_api: dict[str, MockType],
     auth: BasicAuth,
+    method: Literal["GET", "POST"],
+    url: str,
     expected_status_code: int,
 ):
-    task_id = f"{_faker.uuid4()}"
-    response = await client.get(f"/v0/tasks/{task_id}/result", auth=auth)
+    response = await client.request(method=method, url=url, auth=auth)
     assert response.status_code == expected_status_code
-    assert mocked_async_jobs_rpc_api["result"].called
-    assert f"{mocked_async_jobs_rpc_api['result'].call_args[1]['job_id']}" == task_id
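Note: the parametrized cases above pin down how backend failures and non-terminal task states map onto HTTP codes. A condensed, illustrative view of that mapping follows; the real handlers live in the api-server's task routes and exception handlers, and the `SUCCESS` member is assumed to exist on `TaskState` (mirroring Celery's states):

```python
# Illustrative sketch only: summarizes the mapping the cases above verify.
from celery.exceptions import CeleryError
from fastapi import status
from servicelib.celery.models import TaskState


def result_endpoint_status(error: Exception | None, task_state: TaskState) -> int:
    if isinstance(error, CeleryError):
        # The Celery backend itself failed: report the service as unavailable.
        return status.HTTP_503_SERVICE_UNAVAILABLE
    if task_state is TaskState.ABORTED:
        # A cancelled task will never produce a result.
        return status.HTTP_409_CONFLICT
    if task_state is not TaskState.SUCCESS:  # assumed member, see lead-in
        # Task still running: there is no result to fetch yet.
        return status.HTTP_404_NOT_FOUND
    return status.HTTP_200_OK
```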
client.get(f"/v0/tasks/{task_id}/result", auth=auth) + response = await client.request(method=method, url=url, auth=auth) assert response.status_code == expected_status_code - assert mocked_async_jobs_rpc_api["result"].called - assert f"{mocked_async_jobs_rpc_api['result'].call_args[1]['job_id']}" == task_id diff --git a/services/docker-compose.devel.yml b/services/docker-compose.devel.yml index 085a78ef0c7..28e5a8bfa95 100644 --- a/services/docker-compose.devel.yml +++ b/services/docker-compose.devel.yml @@ -21,6 +21,16 @@ services: - ../packages:/devel/packages - ${HOST_UV_CACHE_DIR}:/home/scu/.cache/uv + api-worker: + environment: + <<: *common-environment + API_SERVER_PROFILING : ${API_SERVER_PROFILING} + API_SERVER_LOGLEVEL: DEBUG + volumes: + - ./api-server:/devel/services/api-server + - ../packages:/devel/packages + - ${HOST_UV_CACHE_DIR}:/home/scu/.cache/uv + autoscaling: environment: <<: *common-environment diff --git a/services/docker-compose.local.yml b/services/docker-compose.local.yml index 5f84ba9bf10..f1b1514ba55 100644 --- a/services/docker-compose.local.yml +++ b/services/docker-compose.local.yml @@ -149,6 +149,15 @@ services: ports: - "8080" - "3022:3000" + + api-worker: + environment: + <<: *common_environment + API_SERVER_REMOTE_DEBUG_PORT : 3000 + ports: + - "8080" + - "3025:3000" + webserver: environment: &webserver_environment_local <<: *common_environment diff --git a/services/docker-compose.yml b/services/docker-compose.yml index 7b78dd6df21..dbda14eb590 100644 --- a/services/docker-compose.yml +++ b/services/docker-compose.yml @@ -25,13 +25,14 @@ services: image: ${DOCKER_REGISTRY:-itisfoundation}/api-server:${DOCKER_IMAGE_TAG:-latest} init: true hostname: "{{.Node.Hostname}}-{{.Task.Slot}}" - environment: + environment: &api_server_environment <<: *tracing_open_telemetry_environs API_SERVER_DEV_FEATURES_ENABLED: ${API_SERVER_DEV_FEATURES_ENABLED} API_SERVER_LOG_FORMAT_LOCAL_DEV_ENABLED: ${LOG_FORMAT_LOCAL_DEV_ENABLED} API_SERVER_LOG_FILTER_MAPPING: ${LOG_FILTER_MAPPING} API_SERVER_LOGLEVEL: ${API_SERVER_LOGLEVEL} API_SERVER_PROFILING: ${API_SERVER_PROFILING} + API_SERVER_WORKER_MODE: "false" CATALOG_HOST: ${CATALOG_HOST} CATALOG_PORT: ${CATALOG_PORT} @@ -51,6 +52,12 @@ services: RABBIT_SECURE: ${RABBIT_SECURE} RABBIT_USER: ${RABBIT_USER} + REDIS_HOST: ${REDIS_HOST} + REDIS_PORT: ${REDIS_PORT} + REDIS_SECURE: ${REDIS_SECURE} + REDIS_USER: ${REDIS_USER} + REDIS_PASSWORD: ${REDIS_PASSWORD} + STORAGE_HOST: ${STORAGE_HOST} STORAGE_PORT: ${STORAGE_PORT} @@ -75,9 +82,23 @@ services: - traefik.http.routers.${SWARM_STACK_NAME}_api-server.entrypoints=simcore_api - traefik.http.routers.${SWARM_STACK_NAME}_api-server.priority=3 - traefik.http.routers.${SWARM_STACK_NAME}_api-server.middlewares=${SWARM_STACK_NAME}_gzip@swarm,ratelimit-${SWARM_STACK_NAME}_api-server,inflightreq-${SWARM_STACK_NAME}_api-server - networks: + networks: &api_server_networks - default + + api-worker: + image: ${DOCKER_REGISTRY:-itisfoundation}/api-server:${DOCKER_IMAGE_TAG:-latest} + init: true + hostname: "api-worker-{{.Node.Hostname}}-{{.Task.Slot}}" + environment: + <<: *api_server_environment + API_SERVER_WORKER_NAME: "api-worker-{{.Node.Hostname}}-{{.Task.Slot}}-{{.Task.ID}}" + API_SERVER_WORKER_MODE: "true" + CELERY_CONCURRENCY: 100 + CELERY_QUEUES: "api_worker_queue" + networks: *api_server_networks + + autoscaling: image: ${DOCKER_REGISTRY:-itisfoundation}/autoscaling:${DOCKER_IMAGE_TAG:-latest} init: true diff --git a/services/web/server/tests/integration/conftest.py 
diff --git a/services/web/server/tests/integration/conftest.py b/services/web/server/tests/integration/conftest.py
index edca3137527..5fc7ea7b893 100644
--- a/services/web/server/tests/integration/conftest.py
+++ b/services/web/server/tests/integration/conftest.py
@@ -64,6 +64,7 @@ def webserver_environ(
     # the test webserver is built-up in webserver_service fixture that runs
     # on the host.
     EXCLUDED_SERVICES = [
+        "api-worker",
         "dask-scheduler",
         "director",
         "docker-api-proxy",