From 9cfb98d00e81bdc35cd7075fba2906c3bbfe82b8 Mon Sep 17 00:00:00 2001 From: Nate Mortensen Date: Wed, 20 Aug 2025 15:11:11 -0700 Subject: [PATCH] Add basic pollers --- cadence/_internal/rpc/__init__.py | 0 cadence/_internal/rpc/metadata.py | 36 +++++++ cadence/client.py | 37 +++++++ cadence/sample/client_example.py | 22 ++++ cadence/worker/__init__.py | 11 ++ cadence/worker/_activity.py | 43 ++++++++ cadence/worker/_decision.py | 46 +++++++++ cadence/worker/_poller.py | 60 +++++++++++ cadence/worker/_types.py | 25 +++++ cadence/worker/_worker.py | 42 ++++++++ tests/cadence/worker/__init__.py | 0 tests/cadence/worker/test_poller.py | 154 ++++++++++++++++++++++++++++ tests/cadence/worker/test_worker.py | 52 ++++++++++ 13 files changed, 528 insertions(+) create mode 100644 cadence/_internal/rpc/__init__.py create mode 100644 cadence/_internal/rpc/metadata.py create mode 100644 cadence/client.py create mode 100644 cadence/sample/client_example.py create mode 100644 cadence/worker/__init__.py create mode 100644 cadence/worker/_activity.py create mode 100644 cadence/worker/_decision.py create mode 100644 cadence/worker/_poller.py create mode 100644 cadence/worker/_types.py create mode 100644 cadence/worker/_worker.py create mode 100644 tests/cadence/worker/__init__.py create mode 100644 tests/cadence/worker/test_poller.py create mode 100644 tests/cadence/worker/test_worker.py diff --git a/cadence/_internal/rpc/__init__.py b/cadence/_internal/rpc/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cadence/_internal/rpc/metadata.py b/cadence/_internal/rpc/metadata.py new file mode 100644 index 0000000..e4c1fe3 --- /dev/null +++ b/cadence/_internal/rpc/metadata.py @@ -0,0 +1,36 @@ +import collections + +from grpc.aio import Metadata +from grpc.aio import UnaryUnaryClientInterceptor, ClientCallDetails + + +class _ClientCallDetails( + collections.namedtuple( + "_ClientCallDetails", ("method", "timeout", "metadata", "credentials", "wait_for_ready") + ), + ClientCallDetails, +): + pass + +class MetadataInterceptor(UnaryUnaryClientInterceptor): + def __init__(self, metadata: Metadata): + self._metadata = metadata + + async def intercept_unary_unary(self, continuation, client_call_details: ClientCallDetails, request): + return await continuation(self._replace_details(client_call_details), request) + + + def _replace_details(self, client_call_details: ClientCallDetails) -> ClientCallDetails: + metadata = client_call_details.metadata + if metadata is None: + metadata = self._metadata + else: + metadata += self._metadata + + return _ClientCallDetails( + method=client_call_details.method, + timeout=client_call_details.timeout, + metadata=metadata, + credentials=client_call_details.credentials, + wait_for_ready=client_call_details.wait_for_ready, + ) diff --git a/cadence/client.py b/cadence/client.py new file mode 100644 index 0000000..0eccd17 --- /dev/null +++ b/cadence/client.py @@ -0,0 +1,37 @@ +import os +import socket +from typing import TypedDict + +from cadence.api.v1.service_worker_pb2_grpc import WorkerAPIStub +from grpc.aio import Channel + + +class ClientOptions(TypedDict, total=False): + domain: str + identity: str + +class Client: + def __init__(self, channel: Channel, options: ClientOptions): + self._channel = channel + self._worker_stub = WorkerAPIStub(channel) + self._options = options + self._identity = options["identity"] if "identity" in options else f"{os.getpid()}@{socket.gethostname()}" + + + @property + def domain(self) -> str: + return self._options["domain"] + + @property + def identity(self) -> str: + return self._identity + + @property + def worker_stub(self) -> WorkerAPIStub: + return self._worker_stub + + + async def close(self): + await self._channel.close() + + diff --git a/cadence/sample/client_example.py b/cadence/sample/client_example.py new file mode 100644 index 0000000..64b9be2 --- /dev/null +++ b/cadence/sample/client_example.py @@ -0,0 +1,22 @@ +import asyncio + +from grpc.aio import insecure_channel, Metadata + +from cadence.client import Client, ClientOptions +from cadence._internal.rpc.metadata import MetadataInterceptor +from cadence.worker import Worker + + +async def main(): + # TODO - Hide all this + metadata = Metadata() + metadata["rpc-service"] = "cadence-frontend" + metadata["rpc-encoding"] = "proto" + metadata["rpc-caller"] = "nate" + async with insecure_channel("localhost:7833", interceptors=[MetadataInterceptor(metadata)]) as channel: + client = Client(channel, ClientOptions(domain="foo")) + worker = Worker(client, "task_list") + await worker.run() + +if __name__ == '__main__': + asyncio.run(main()) diff --git a/cadence/worker/__init__.py b/cadence/worker/__init__.py new file mode 100644 index 0000000..c2959b6 --- /dev/null +++ b/cadence/worker/__init__.py @@ -0,0 +1,11 @@ + + +from ._worker import ( + Worker, + WorkerOptions +) + +__all__ = [ + "Worker", + "WorkerOptions" +] \ No newline at end of file diff --git a/cadence/worker/_activity.py b/cadence/worker/_activity.py new file mode 100644 index 0000000..0ae24e1 --- /dev/null +++ b/cadence/worker/_activity.py @@ -0,0 +1,43 @@ +import asyncio +from typing import Optional + +from cadence.api.v1.common_pb2 import Failure +from cadence.api.v1.service_worker_pb2 import PollForActivityTaskResponse, PollForActivityTaskRequest, \ + RespondActivityTaskFailedRequest +from cadence.api.v1.tasklist_pb2 import TaskList, TaskListKind +from cadence.client import Client +from cadence.worker._types import WorkerOptions, _LONG_POLL_TIMEOUT +from cadence.worker._poller import Poller + + +class ActivityWorker: + def __init__(self, client: Client, task_list: str, options: WorkerOptions): + self._client = client + self._task_list = task_list + self._identity = options["identity"] + permits = asyncio.Semaphore(options["max_concurrent_activity_execution_size"]) + self._poller = Poller[PollForActivityTaskResponse](options["activity_task_pollers"], permits, self._poll, self._execute) + # TODO: Local dispatch, local activities, actually running activities, etc + + async def run(self): + await self._poller.run() + + async def _poll(self) -> Optional[PollForActivityTaskResponse]: + task: PollForActivityTaskResponse = await self._client.worker_stub.PollForActivityTask(PollForActivityTaskRequest( + domain=self._client.domain, + task_list=TaskList(name=self._task_list,kind=TaskListKind.TASK_LIST_KIND_NORMAL), + identity=self._identity, + ), timeout=_LONG_POLL_TIMEOUT) + + if task.task_token: + return task + else: + return None + + async def _execute(self, task: PollForActivityTaskResponse): + await self._client.worker_stub.RespondActivityTaskFailed(RespondActivityTaskFailedRequest( + task_token=task.task_token, + identity=self._identity, + failure=Failure(reason='error', details=b'not implemented'), + )) + diff --git a/cadence/worker/_decision.py b/cadence/worker/_decision.py new file mode 100644 index 0000000..0510f61 --- /dev/null +++ b/cadence/worker/_decision.py @@ -0,0 +1,46 @@ +import asyncio +from typing import Optional + +from cadence.api.v1.common_pb2 import Payload +from cadence.api.v1.service_worker_pb2 import PollForDecisionTaskRequest, PollForDecisionTaskResponse, \ + RespondDecisionTaskFailedRequest +from cadence.api.v1.tasklist_pb2 import TaskList, TaskListKind +from cadence.api.v1.workflow_pb2 import DecisionTaskFailedCause +from cadence.client import Client +from cadence.worker._poller import Poller +from cadence.worker._types import WorkerOptions, _LONG_POLL_TIMEOUT + + +class DecisionWorker: + def __init__(self, client: Client, task_list: str, options: WorkerOptions): + self._client = client + self._task_list = task_list + self._identity = options["identity"] + permits = asyncio.Semaphore(options["max_concurrent_decision_task_execution_size"]) + self._poller = Poller[PollForDecisionTaskResponse](options["decision_task_pollers"], permits, self._poll, self._execute) + # TODO: Sticky poller, actually running workflows, etc. + + async def run(self): + await self._poller.run() + + async def _poll(self) -> Optional[PollForDecisionTaskResponse]: + task: PollForDecisionTaskResponse = await self._client.worker_stub.PollForDecisionTask(PollForDecisionTaskRequest( + domain=self._client.domain, + task_list=TaskList(name=self._task_list,kind=TaskListKind.TASK_LIST_KIND_NORMAL), + identity=self._identity, + ), timeout=_LONG_POLL_TIMEOUT) + + if task.task_token: + return task + else: + return None + + + async def _execute(self, task: PollForDecisionTaskResponse): + await self._client.worker_stub.RespondDecisionTaskFailed(RespondDecisionTaskFailedRequest( + task_token=task.task_token, + cause=DecisionTaskFailedCause.DECISION_TASK_FAILED_CAUSE_UNHANDLED_DECISION, + identity=self._identity, + details=Payload(data=b'not implemented') + )) + diff --git a/cadence/worker/_poller.py b/cadence/worker/_poller.py new file mode 100644 index 0000000..3b2889b --- /dev/null +++ b/cadence/worker/_poller.py @@ -0,0 +1,60 @@ +import asyncio +import logging +from typing import Callable, TypeVar, Generic, Awaitable, Optional + +logger = logging.getLogger(__name__) + +T = TypeVar('T') + +class Poller(Generic[T]): + def __init__(self, num_tasks: int, permits: asyncio.Semaphore, poll: Callable[[], Awaitable[Optional[T]]], callback: Callable[[T], Awaitable[None]]): + self._num_tasks = num_tasks + self._permits = permits + self._poll = poll + self._callback = callback + self._background_tasks: set[asyncio.Task[None]] = set() + pass + + async def run(self): + try: + async with asyncio.TaskGroup() as tg: + for i in range(self._num_tasks): + tg.create_task(self._poll_loop()) + except asyncio.CancelledError: + pass + + + async def _poll_loop(self): + while True: + try: + await self._poll_and_dispatch() + except asyncio.CancelledError as e: + raise e + except Exception: + logger.exception('Exception while polling') + + + async def _poll_and_dispatch(self): + await self._permits.acquire() + try: + task = await self._poll() + except Exception as e: + self._permits.release() + raise e + + if task is None: + self._permits.release() + return + + # Need to store a reference to the async task or it may be garbage collected + scheduled = asyncio.create_task(self._execute_callback(task)) + self._background_tasks.add(scheduled) + scheduled.add_done_callback(self._background_tasks.remove) + + async def _execute_callback(self, task: T): + try: + await self._callback(task) + except Exception: + logger.exception('Exception during callback') + finally: + self._permits.release() \ No newline at end of file diff --git a/cadence/worker/_types.py b/cadence/worker/_types.py new file mode 100644 index 0000000..8b16fed --- /dev/null +++ b/cadence/worker/_types.py @@ -0,0 +1,25 @@ +from typing import TypedDict + + +class WorkerOptions(TypedDict, total=False): + max_concurrent_activity_execution_size: int + max_concurrent_decision_task_execution_size: int + task_list_activities_per_second: float + # Remove these in favor of introducing automatic scaling prior to release + activity_task_pollers: int + decision_task_pollers: int + disable_workflow_worker: bool + disable_activity_worker: bool + identity: str + +_DEFAULT_WORKER_OPTIONS: WorkerOptions = { + "max_concurrent_activity_execution_size": 1000, + "max_concurrent_decision_task_execution_size": 1000, + "task_list_activities_per_second": 0.0, + "activity_task_pollers": 2, + "decision_task_pollers": 2, + "disable_workflow_worker": False, + "disable_activity_worker": False, +} + +_LONG_POLL_TIMEOUT = 60.0 diff --git a/cadence/worker/_worker.py b/cadence/worker/_worker.py new file mode 100644 index 0000000..bb3ccc3 --- /dev/null +++ b/cadence/worker/_worker.py @@ -0,0 +1,42 @@ +import asyncio +import uuid +from typing import Unpack + +from cadence.client import Client +from cadence.worker._activity import ActivityWorker +from cadence.worker._decision import DecisionWorker +from cadence.worker._types import WorkerOptions, _DEFAULT_WORKER_OPTIONS + + +class Worker: + + def __init__(self, client: Client, task_list: str, **kwargs: Unpack[WorkerOptions]): + self._client = client + self._task_list = task_list + + options = WorkerOptions(**kwargs) + _validate_and_copy_defaults(client, task_list, options) + self._options = options + self._activity_worker = ActivityWorker(client, task_list, options) + self._decision_worker = DecisionWorker(client, task_list, options) + + + async def run(self): + async with asyncio.TaskGroup() as tg: + if not self._options["disable_workflow_worker"]: + tg.create_task(self._decision_worker.run()) + if not self._options["disable_activity_worker"]: + tg.create_task(self._activity_worker.run()) + + + +def _validate_and_copy_defaults(client: Client, task_list: str, options: WorkerOptions): + if "identity" not in options: + options["identity"] = f"{client.identity}@{task_list}@{uuid.uuid4()}" + + # TODO: More validation + + for (key, value) in _DEFAULT_WORKER_OPTIONS.items(): + if key not in options: + # noinspection PyTypedDict + options[key] = value diff --git a/tests/cadence/worker/__init__.py b/tests/cadence/worker/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/cadence/worker/test_poller.py b/tests/cadence/worker/test_poller.py new file mode 100644 index 0000000..2cb3e00 --- /dev/null +++ b/tests/cadence/worker/test_poller.py @@ -0,0 +1,154 @@ +import asyncio + +import pytest + +from cadence.worker._poller import Poller + + +@pytest.mark.asyncio +async def test_poller(): + permits = asyncio.Semaphore(1) + incoming = asyncio.Queue() + outgoing = asyncio.Queue() + poller = Poller(1, permits, incoming.get, outgoing.put) + + task = asyncio.create_task(poller.run()) + await incoming.put("foo") + result = await outgoing.get() + + assert result == "foo" + task.cancel() + assert incoming.empty() is True + assert outgoing.empty() is True + +@pytest.mark.asyncio +async def test_poller_empty_task(): + permits = asyncio.Semaphore(1) + incoming = asyncio.Queue() + outgoing = asyncio.Queue() + poller = Poller(1, permits, incoming.get, outgoing.put) + + task = asyncio.create_task(poller.run()) + await incoming.put(None) + await incoming.put("foo") + result = await outgoing.get() + + assert result == "foo" + task.cancel() + +@pytest.mark.asyncio +async def test_poller_num_tasks(): + permits = asyncio.Semaphore(10) + + count = 0 + all_waiting = asyncio.Event() + done = asyncio.Event() + + async def poll_func(): + nonlocal count + + count += 1 + if count == 5 and not all_waiting.is_set(): + all_waiting.set() + + await done.wait() + return "foo" + + outgoing = asyncio.Queue() + poller = Poller(5, permits, poll_func, outgoing.put) + task = asyncio.create_task(poller.run()) + + async with asyncio.timeout(1): + await all_waiting.wait() + + assert outgoing.empty() is True + + task.cancel() + +@pytest.mark.asyncio +async def test_poller_concurrency(): + permits = asyncio.Semaphore(5) + + poll_count = 0 + count = 0 + all_waiting = asyncio.Event() + done = asyncio.Event() + + async def infinite_tasks() -> str: + nonlocal poll_count + poll_count += 1 + return "foo" + + async def never_complete(_: str): + nonlocal count, all_waiting, done + count += 1 + if count == 5 and not all_waiting.is_set(): + all_waiting.set() + + await done.wait() + + poller = Poller(10, permits, infinite_tasks, never_complete) + task = asyncio.create_task(poller.run()) + + # Ensure we receive all 5 + async with asyncio.timeout(1): + await all_waiting.wait() + + # Ensure no extra polls were issued + assert poll_count == 5 + done.set() + task.cancel() + + +@pytest.mark.asyncio +async def test_poller_poll_error(): + permits = asyncio.Semaphore(1) + + done = asyncio.Event() + call_count = 0 + async def poll_func(): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise RuntimeError("oh no") + elif call_count == 2: + return "foo" + else: + await done.wait() + return "bar" + + outgoing = asyncio.Queue() + poller = Poller(1, permits, poll_func, outgoing.put) + + task = asyncio.create_task(poller.run()) + result = await outgoing.get() + + assert result == "foo" + task.cancel() + done.set() + +@pytest.mark.asyncio +async def test_poller_execute_error(): + permits = asyncio.Semaphore(1) + + outgoing = asyncio.Queue() + call_count = 0 + async def execute(item: str): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise RuntimeError("oh no") + await outgoing.put(item) + + incoming = asyncio.Queue() + poller = Poller(1, permits, incoming.get, execute) + + task = asyncio.create_task(poller.run()) + await incoming.put("first") + await incoming.put("second") + result = await outgoing.get() + + assert result == "second" + task.cancel() + + diff --git a/tests/cadence/worker/test_worker.py b/tests/cadence/worker/test_worker.py new file mode 100644 index 0000000..5a3667e --- /dev/null +++ b/tests/cadence/worker/test_worker.py @@ -0,0 +1,52 @@ +import asyncio + +import pytest + +from unittest.mock import AsyncMock, Mock, PropertyMock + +from cadence.api.v1.service_worker_pb2 import PollForDecisionTaskRequest, PollForActivityTaskRequest +from cadence.api.v1.tasklist_pb2 import TaskList, TaskListKind +from cadence.client import Client +from cadence.worker import Worker + + +@pytest.mark.asyncio +async def test_worker(): + client = Mock(spec=Client) + done = asyncio.Event() + both_waited = asyncio.Barrier(3) + + async def poll(_, timeout=0.0): + await both_waited.wait() + await done.wait() + return None + + worker_stub = Mock() + worker_stub.PollForDecisionTask = AsyncMock(side_effect=poll) + worker_stub.PollForActivityTask = AsyncMock(side_effect=poll) + + client.worker_stub = worker_stub + type(client).domain = PropertyMock(return_value="domain") + type(client).identity = PropertyMock(return_value="identity") + + worker = Worker(client, "task_list", activity_task_pollers=1, decision_task_pollers=1, identity="identity") + + task = asyncio.create_task(worker.run()) + + # Wait until both polled + await both_waited.wait() + task.cancel() + with pytest.raises(asyncio.CancelledError): + await task + + worker_stub.PollForDecisionTask.assert_called_once_with(PollForDecisionTaskRequest( + domain="domain", + identity="identity", + task_list=TaskList(name="task_list", kind=TaskListKind.TASK_LIST_KIND_NORMAL), + ), timeout=60.0) + + worker_stub.PollForActivityTask.assert_called_once_with(PollForActivityTaskRequest( + domain="domain", + identity="identity", + task_list=TaskList(name="task_list", kind=TaskListKind.TASK_LIST_KIND_NORMAL), + ), timeout=60.0) \ No newline at end of file