Skip to content

Commit b3a0ca3

Browse files
Merge pull request #79 from stealthrocket/remove-coroutine-decorator
Remove coroutine decorator
2 parents 4624d5f + 2955162 commit b3a0ca3

File tree

7 files changed

+115
-49
lines changed

7 files changed

+115
-49
lines changed

README.md

Lines changed: 57 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ This package implements the Dispatch SDK for Python.
2020
- [Configuration](#configuration)
2121
- [Integration with FastAPI](#integration-with-fastapi)
2222
- [Local testing with ngrok](#local-testing-with-ngrok)
23+
- [Durable coroutines for Python](#durable-coroutines-for-python)
24+
- [Examples](#examples)
2325
- [Contributing](#contributing)
2426

2527
## What is Dispatch?
@@ -45,7 +47,7 @@ The SDK allows Python applications to declare *Stateful Functions* that the
4547
Dispatch scheduler can orchestrate. This is the bare minimum structure used
4648
to declare stateful functions:
4749
```python
48-
@dispatch.function()
50+
@dispatch.function
4951
def action(msg):
5052
...
5153
```
@@ -94,7 +96,7 @@ import requests
9496
app = FastAPI()
9597
dispatch = Dispatch(app)
9698

97-
@dispatch.function()
99+
@dispatch.function
98100
def publish(url, payload):
99101
r = requests.post(url, data=payload)
100102
r.raise_for_status()
@@ -144,7 +146,59 @@ different value, but in this example it would be:
144146
export DISPATCH_ENDPOINT_URL="https://f441-2600-1700-2802-e01f-6861-dbc9-d551-ecfb.ngrok-free.app"
145147
```
146148

147-
### Examples
149+
### Durable coroutines for Python
150+
151+
The `@dispatch.function` decorator can also be applied to Python coroutines
152+
(a.k.a. *async* functions), in which case each await point on another
153+
stateful function becomes a durability step in the execution: if the awaited
154+
operation fails, it is automatically retried and the parent function is paused
155+
until the result becomes available, or a permanent error is raised.
156+
157+
```python
158+
@dispatch.function
159+
async def pipeline(msg):
160+
# Each await point is a durability step, the functions can be run across the
161+
# fleet of service instances and retried as needed without losing track of
162+
# progress through the function execution.
163+
msg = await transform1(msg)
164+
msg = await transform2(msg)
165+
await publish(msg)
166+
167+
@dispatch.function
168+
async def publish(msg):
169+
# Each dispatch function runs concurrently to the others, even if it does
170+
# blocking operations like this POST request, it does not prevent other
171+
# concurrent operations from carrying on in the program.
172+
r = requests.post("https://somewhere.com/", data=msg)
173+
r.raise_for_status()
174+
175+
@dispatch.function
176+
async def transform1(msg):
177+
...
178+
179+
@dispatch.function
180+
async def transform2(msg):
181+
...
182+
```
183+
184+
This model is composable and can be used to create fan-out/fan-in control flows.
185+
`gather` can be used to wait on multiple concurrent calls to stateful functions,
186+
for example:
187+
188+
```python
189+
from dispatch import gather
190+
191+
@dispatch.function
192+
async def process(msgs):
193+
concurrent_calls = [transform(msg) for msg in msgs]
194+
return await gather(*concurrent_calls)
195+
196+
@dispatch.function
197+
async def transform(msg):
198+
...
199+
```
200+
201+
## Examples
148202

149203
Check out the [examples](examples/) directory for code samples to help you get
150204
started with the SDK.

examples/auto_retry/app.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@ def third_party_api_call(x):
4444
return "SUCCESS"
4545

4646

47-
# Use the `dispatch.function` decorator to mark a function as durable.
48-
@dispatch.function()
47+
# Use the `dispatch.function` decorator to declare a stateful function.
48+
@dispatch.function
4949
def some_logic():
5050
print("Executing some logic")
5151
x = rng.randint(0, 5)
@@ -56,8 +56,9 @@ def some_logic():
5656
# This is a normal FastAPI route that handles regular traffic.
5757
@app.get("/")
5858
def root():
59-
# Use the `dispatch` method to call the durable function. This call is
60-
# non-blocking and returns immediately.
59+
# Use the `dispatch` method to call the stateful function. This call is
60+
# returns immediately after scheduling the function call, which happens in
61+
# the background.
6162
some_logic.dispatch()
62-
# Sending an unrelated response immediately.
63+
# Sending a response now that the HTTP handler has completed.
6364
return "OK"

examples/getting_started/app.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,8 @@
6767
dispatch = Dispatch(app)
6868

6969

70-
# Use the `dispatch.function` decorator to mark a function as durable.
71-
@dispatch.function()
70+
# Use the `dispatch.function` decorator declare a stateful function.
71+
@dispatch.function
7272
def publish(url, payload):
7373
r = requests.post(url, data=payload)
7474
r.raise_for_status()
@@ -77,8 +77,9 @@ def publish(url, payload):
7777
# This is a normal FastAPI route that handles regular traffic.
7878
@app.get("/")
7979
def root():
80-
# Use the `dispatch` method to call the durable function. This call is
81-
# non-blocking and returns immediately.
80+
# Use the `dispatch` method to call the stateful function. This call is
81+
# returns immediately after scheduling the function call, which happens in
82+
# the background.
8283
publish.dispatch("https://httpstat.us/200", {"hello": "world"})
83-
# Sending an unrelated response immediately.
84+
# Sending a response now that the HTTP handler has completed.
8485
return "OK"

src/dispatch/experimental/durable/function.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,9 @@ def durable(fn: Callable) -> Callable:
5656
elif isinstance(fn, FunctionType):
5757
return DurableFunction(fn)
5858
else:
59-
raise TypeError("unsupported callable")
59+
raise TypeError(
60+
f"cannot create a durable function from value of type {fn.__qualname__}"
61+
)
6062

6163

6264
class Serializable:

src/dispatch/function.py

Lines changed: 38 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from __future__ import annotations
22

3-
import functools
43
import inspect
54
import logging
5+
from functools import wraps
66
from types import FunctionType
77
from typing import Any, Callable, Dict, TypeAlias
88

@@ -23,12 +23,34 @@
2323
"""
2424

2525

26+
# https://stackoverflow.com/questions/653368/how-to-create-a-decorator-that-can-be-used-either-with-or-without-parameters
27+
def decorator(f):
28+
"""This decorator is intended to declare decorators that can be used with
29+
or without parameters. If the decorated function is called with a single
30+
callable argument, it is assumed to be a function and the decorator is
31+
applied to it. Otherwise, the decorator is called with the arguments
32+
provided and the result is returned.
33+
"""
34+
35+
@wraps(f)
36+
def method(self, *args, **kwargs):
37+
if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
38+
return f(self, args[0])
39+
40+
def wrapper(func):
41+
return f(self, func, *args, **kwargs)
42+
43+
return wrapper
44+
45+
return method
46+
47+
2648
class Function:
2749
"""Callable wrapper around a function meant to be used throughout the
2850
Dispatch Python SDK.
2951
"""
3052

31-
__slots__ = ("_endpoint", "_client", "_name", "_primitive_func", "_func", "call")
53+
__slots__ = ("_endpoint", "_client", "_name", "_primitive_func", "_func")
3254

3355
def __init__(
3456
self,
@@ -42,11 +64,12 @@ def __init__(
4264
self._client = client
4365
self._name = name
4466
self._primitive_func = primitive_func
45-
self._func = func
46-
4767
# FIXME: is there a way to decorate the function at the definition
4868
# without making it a class method?
49-
self.call = durable(self._call_async)
69+
if inspect.iscoroutinefunction(func):
70+
self._func = durable(self._call_async)
71+
else:
72+
self._func = func
5073

5174
def __call__(self, *args, **kwargs):
5275
return self._func(*args, **kwargs)
@@ -90,7 +113,7 @@ def _primitive_dispatch(self, input: Any = None) -> DispatchID:
90113
return dispatch_id
91114

92115
async def _call_async(self, *args, **kwargs) -> Any:
93-
"""Asynchronously call the function from a @dispatch.coroutine."""
116+
"""Asynchronously call the function from a @dispatch.function."""
94117
return await dispatch.coroutine.call(
95118
self.build_call(*args, **kwargs, correlation_id=None)
96119
)
@@ -142,39 +165,27 @@ def __init__(self, endpoint: str, client: Client | None):
142165
self._endpoint = endpoint
143166
self._client = client
144167

145-
def function(self) -> Callable[[FunctionType], Function]:
168+
@decorator
169+
def function(self, func: Callable) -> Function:
146170
"""Returns a decorator that registers functions."""
171+
return self._register_function(func)
147172

148-
# Note: the indirection here means that we can add parameters
149-
# to the decorator later without breaking existing apps.
150-
return self._register_function
151-
152-
def coroutine(self) -> Callable[[FunctionType], Function | FunctionType]:
153-
"""Returns a decorator that registers coroutines."""
154-
155-
# Note: the indirection here means that we can add parameters
156-
# to the decorator later without breaking existing apps.
157-
return self._register_coroutine
158-
159-
def primitive_function(self) -> Callable[[PrimitiveFunctionType], Function]:
173+
@decorator
174+
def primitive_function(self, func: Callable) -> Function:
160175
"""Returns a decorator that registers primitive functions."""
161-
162-
# Note: the indirection here means that we can add parameters
163-
# to the decorator later without breaking existing apps.
164-
return self._register_primitive_function
176+
return self._register_primitive_function(func)
165177

166178
def _register_function(self, func: Callable) -> Function:
167179
if inspect.iscoroutinefunction(func):
168-
raise TypeError(
169-
"async functions must be registered via @dispatch.coroutine"
170-
)
180+
return self._register_coroutine(func)
171181

172182
logger.info("registering function: %s", func.__qualname__)
173183

174184
# Register the function with the experimental.durable package, in case
175185
# it's referenced from a @dispatch.coroutine.
176186
func = durable(func)
177187

188+
@wraps(func)
178189
def primitive_func(input: Input) -> Output:
179190
try:
180191
try:
@@ -196,14 +207,11 @@ def primitive_func(input: Input) -> Output:
196207
return self._register(func, primitive_func)
197208

198209
def _register_coroutine(self, func: Callable) -> Function:
199-
if not inspect.iscoroutinefunction(func):
200-
raise TypeError(f"{func.__qualname__} must be an async function")
201-
202210
logger.info("registering coroutine: %s", func.__qualname__)
203211

204212
func = durable(func)
205213

206-
@functools.wraps(func)
214+
@wraps(func)
207215
def primitive_func(input: Input) -> Output:
208216
return OneShotScheduler(func).run(input)
209217

src/dispatch/scheduler.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ def _init_state(self, input: Input) -> State:
175175

176176
main = self.entry_point(*args, **kwargs)
177177
if not isinstance(main, DurableCoroutine):
178-
raise ValueError("entry point is not a @dispatch.coroutine")
178+
raise ValueError("entry point is not a @dispatch.function")
179179

180180
return State(
181181
version=sys.version,
@@ -255,7 +255,7 @@ def _run(self, input: Input) -> Output:
255255
)
256256
except Exception as e:
257257
logger.exception(
258-
f"@dispatch.coroutine: '{coroutine}' raised an exception"
258+
f"@dispatch.function: '{coroutine}' raised an exception"
259259
)
260260
coroutine_result = CoroutineResult(coroutine_id=coroutine.id, error=e)
261261

@@ -317,7 +317,7 @@ def _run(self, input: Input) -> Output:
317317
g = awaitable.__await__()
318318
if not isinstance(g, DurableGenerator):
319319
raise ValueError(
320-
"gather awaitable is not a @dispatch.coroutine"
320+
"gather awaitable is not a @dispatch.function"
321321
)
322322
child_id = state.next_coroutine_id
323323
state.next_coroutine_id += 1

tests/test_full.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def execute(self):
5858

5959
def test_simple_end_to_end(self):
6060
# The FastAPI server.
61-
@self.dispatch.function()
61+
@self.dispatch.function
6262
def my_function(name: str) -> str:
6363
return f"Hello world: {name}"
6464

@@ -73,7 +73,7 @@ def my_function(name: str) -> str:
7373
self.assertEqual(any_unpickle(resp.exit.result.output), "Hello world: 52")
7474

7575
def test_simple_missing_signature(self):
76-
@self.dispatch.function()
76+
@self.dispatch.function
7777
def my_function(name: str) -> str:
7878
return f"Hello world: {name}"
7979

0 commit comments

Comments
 (0)