Skip to content

Commit 78c0ced

Browse files
authored
Merge pull request #100 from stealthrocket/paramspec
Improve type safety with ParamSpec
2 parents 0a518bc + 51f0023 commit 78c0ced

File tree

8 files changed

+67
-73
lines changed

8 files changed

+67
-73
lines changed

src/buf/validate/py.typed

Whitespace-only changes.

src/dispatch/experimental/durable/py.typed

Whitespace-only changes.

src/dispatch/function.py

Lines changed: 44 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
import inspect
44
import logging
55
from functools import wraps
6-
from types import FunctionType
7-
from typing import Any, Callable, Dict, TypeAlias
6+
from types import CoroutineType
7+
from typing import Any, Callable, Dict, Generic, ParamSpec, TypeAlias, TypeVar
88

99
import dispatch.coroutine
1010
from dispatch.client import Client
@@ -23,29 +23,11 @@
2323
"""
2424

2525

26-
# https://stackoverflow.com/questions/653368/how-to-create-a-decorator-that-can-be-used-either-with-or-without-parameters
27-
def decorator(f):
28-
"""This decorator is intended to declare decorators that can be used with
29-
or without parameters. If the decorated function is called with a single
30-
callable argument, it is assumed to be a function and the decorator is
31-
applied to it. Otherwise, the decorator is called with the arguments
32-
provided and the result is returned.
33-
"""
34-
35-
@wraps(f)
36-
def method(self, *args, **kwargs):
37-
if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
38-
return f(self, args[0])
39-
40-
def wrapper(func):
41-
return f(self, func, *args, **kwargs)
42-
43-
return wrapper
26+
P = ParamSpec("P")
27+
T = TypeVar("T")
4428

45-
return method
4629

47-
48-
class Function:
30+
class Function(Generic[P, T]):
4931
"""Callable wrapper around a function meant to be used throughout the
5032
Dispatch Python SDK.
5133
"""
@@ -58,18 +40,23 @@ def __init__(
5840
client: Client,
5941
name: str,
6042
primitive_func: PrimitiveFunctionType,
61-
func: Callable,
43+
func: Callable[P, T] | None,
6244
coroutine: bool = False,
6345
):
6446
self._endpoint = endpoint
6547
self._client = client
6648
self._name = name
6749
self._primitive_func = primitive_func
68-
# FIXME: is there a way to decorate the function at the definition
69-
# without making it a class method?
70-
self._func = durable(self._call_async) if coroutine else func
50+
if func:
51+
self._func: Callable[P, T] | None = (
52+
durable(self._call_async) if coroutine else func
53+
)
54+
else:
55+
self._func = None
7156

72-
def __call__(self, *args, **kwargs):
57+
def __call__(self, *args: P.args, **kwargs: P.kwargs):
58+
if self._func is None:
59+
raise ValueError("cannot call a primitive function directly")
7360
return self._func(*args, **kwargs)
7461

7562
def _primitive_call(self, input: Input) -> Output:
@@ -83,7 +70,7 @@ def endpoint(self) -> str:
8370
def name(self) -> str:
8471
return self._name
8572

86-
def dispatch(self, *args: Any, **kwargs: Any) -> DispatchID:
73+
def dispatch(self, *args: P.args, **kwargs: P.kwargs) -> DispatchID:
8774
"""Dispatch a call to the function.
8875
8976
The Registry this function was registered with must be initialized
@@ -105,14 +92,14 @@ def _primitive_dispatch(self, input: Any = None) -> DispatchID:
10592
[dispatch_id] = self._client.dispatch([self._build_primitive_call(input)])
10693
return dispatch_id
10794

108-
async def _call_async(self, *args, **kwargs) -> Any:
95+
async def _call_async(self, *args: P.args, **kwargs: P.kwargs) -> T:
10996
"""Asynchronously call the function from a @dispatch.function."""
11097
return await dispatch.coroutine.call(
11198
self.build_call(*args, **kwargs, correlation_id=None)
11299
)
113100

114101
def build_call(
115-
self, *args: Any, correlation_id: int | None = None, **kwargs: Any
102+
self, *args: P.args, correlation_id: int | None = None, **kwargs: P.kwargs
116103
) -> Call:
117104
"""Create a Call for this function with the provided input. Useful to
118105
generate calls when using the Client.
@@ -158,24 +145,21 @@ def __init__(self, endpoint: str, client: Client):
158145
self._endpoint = endpoint
159146
self._client = client
160147

161-
@decorator
162-
def function(self, func: Callable) -> Function:
163-
"""Returns a decorator that registers functions."""
148+
def function(self, func: Callable[P, T]) -> Function[P, T]:
149+
"""Decorator that registers functions."""
150+
if inspect.iscoroutinefunction(func):
151+
return self._register_coroutine(func)
164152
return self._register_function(func)
165153

166-
@decorator
167-
def primitive_function(self, func: Callable) -> Function:
168-
"""Returns a decorator that registers primitive functions."""
154+
def primitive_function(self, func: PrimitiveFunctionType) -> Function:
155+
"""Decorator that registers primitive functions."""
169156
return self._register_primitive_function(func)
170157

171-
def _register_function(self, func: Callable) -> Function:
172-
if inspect.iscoroutinefunction(func):
173-
return self._register_coroutine(func)
174-
158+
def _register_function(self, func: Callable[P, T]) -> Function[P, T]:
175159
logger.info("registering function: %s", func.__qualname__)
176160

177161
# Register the function with the experimental.durable package, in case
178-
# it's referenced from a @dispatch.coroutine.
162+
# it's referenced from a coroutine.
179163
func = durable(func)
180164

181165
@wraps(func)
@@ -199,7 +183,9 @@ def primitive_func(input: Input) -> Output:
199183

200184
return self._register(primitive_func, func, coroutine=False)
201185

202-
def _register_coroutine(self, func: Callable) -> Function:
186+
def _register_coroutine(
187+
self, func: Callable[P, CoroutineType[Any, Any, T]]
188+
) -> Function[P, T]:
203189
logger.info("registering coroutine: %s", func.__qualname__)
204190

205191
func = durable(func)
@@ -213,19 +199,27 @@ def primitive_func(input: Input) -> Output:
213199

214200
return self._register(primitive_func, func, coroutine=True)
215201

216-
def _register_primitive_function(self, func: PrimitiveFunctionType) -> Function:
217-
logger.info("registering primitive function: %s", func.__qualname__)
218-
return self._register(func, func, coroutine=inspect.iscoroutinefunction(func))
202+
def _register_primitive_function(
203+
self, primitive_func: PrimitiveFunctionType
204+
) -> Function[P, T]:
205+
logger.info("registering primitive function: %s", primitive_func.__qualname__)
206+
return self._register(primitive_func, func=None, coroutine=False)
219207

220208
def _register(
221-
self, primitive_func: PrimitiveFunctionType, func: Callable, coroutine: bool
222-
) -> Function:
223-
name = func.__qualname__
209+
self,
210+
primitive_func: PrimitiveFunctionType,
211+
func: Callable[P, T] | None,
212+
coroutine: bool,
213+
) -> Function[P, T]:
214+
if func:
215+
name = func.__qualname__
216+
else:
217+
name = primitive_func.__qualname__
224218
if name in self._functions:
225219
raise ValueError(
226220
f"function or coroutine already registered with name '{name}'"
227221
)
228-
wrapped_func = Function(
222+
wrapped_func = Function[P, T](
229223
self._endpoint, self._client, name, primitive_func, func, coroutine
230224
)
231225
self._functions[name] = wrapped_func

src/dispatch/integrations/py.typed

Whitespace-only changes.

src/dispatch/py.typed

Whitespace-only changes.

src/dispatch/sdk/v1/py.typed

Whitespace-only changes.

tests/dispatch/test_function.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ def setUp(self):
1111
self.dispatch = Registry(endpoint="http://example.com", client=self.client)
1212

1313
def test_serializable(self):
14-
@self.dispatch.function()
14+
@self.dispatch.function
1515
def my_function():
1616
pass
1717

tests/test_fastapi.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def test_fastapi_simple_request(self):
7070
app = fastapi.FastAPI()
7171
dispatch = create_dispatch_instance(app, endpoint="http://127.0.0.1:9999/")
7272

73-
@dispatch.primitive_function()
73+
@dispatch.primitive_function
7474
def my_function(input: Input) -> Output:
7575
return Output.value(
7676
f"You told me: '{input.input}' ({len(input.input)} characters)"
@@ -159,7 +159,7 @@ def proto_call(self, call: call_pb.Call) -> call_pb.CallResult:
159159
return resp.exit.result
160160

161161
def test_no_input(self):
162-
@self.dispatch.primitive_function()
162+
@self.dispatch.primitive_function
163163
def my_function(input: Input) -> Output:
164164
return Output.value("Hello World!")
165165

@@ -178,7 +178,7 @@ def test_missing_coroutine(self):
178178
self.assertEqual(cm.exception.response.status_code, 404)
179179

180180
def test_string_input(self):
181-
@self.dispatch.primitive_function()
181+
@self.dispatch.primitive_function
182182
def my_function(input: Input) -> Output:
183183
return Output.value(f"You sent '{input.input}'")
184184

@@ -187,7 +187,7 @@ def my_function(input: Input) -> Output:
187187
self.assertEqual(out, "You sent 'cool stuff'")
188188

189189
def test_error_on_access_state_in_first_call(self):
190-
@self.dispatch.primitive_function()
190+
@self.dispatch.primitive_function
191191
def my_function(input: Input) -> Output:
192192
try:
193193
print(input.coroutine_state)
@@ -206,7 +206,7 @@ def my_function(input: Input) -> Output:
206206
)
207207

208208
def test_error_on_access_input_in_second_call(self):
209-
@self.dispatch.primitive_function()
209+
@self.dispatch.primitive_function
210210
def my_function(input: Input) -> Output:
211211
if input.is_first_call:
212212
return Output.poll(state=42)
@@ -230,22 +230,22 @@ def my_function(input: Input) -> Output:
230230
)
231231

232232
def test_duplicate_coro(self):
233-
@self.dispatch.primitive_function()
233+
@self.dispatch.primitive_function
234234
def my_function(input: Input) -> Output:
235235
return Output.value("Do one thing")
236236

237237
with self.assertRaises(ValueError):
238238

239-
@self.dispatch.primitive_function()
239+
@self.dispatch.primitive_function
240240
def my_function(input: Input) -> Output:
241241
return Output.value("Do something else")
242242

243243
def test_two_simple_coroutines(self):
244-
@self.dispatch.primitive_function()
244+
@self.dispatch.primitive_function
245245
def echoroutine(input: Input) -> Output:
246246
return Output.value(f"Echo: '{input.input}'")
247247

248-
@self.dispatch.primitive_function()
248+
@self.dispatch.primitive_function
249249
def len_coroutine(input: Input) -> Output:
250250
return Output.value(f"Length: {len(input.input)}")
251251

@@ -259,7 +259,7 @@ def len_coroutine(input: Input) -> Output:
259259
self.assertEqual(out, "Length: 10")
260260

261261
def test_coroutine_with_state(self):
262-
@self.dispatch.primitive_function()
262+
@self.dispatch.primitive_function
263263
def coroutine3(input: Input) -> Output:
264264
if input.is_first_call:
265265
counter = input.input
@@ -293,11 +293,11 @@ def coroutine3(input: Input) -> Output:
293293
self.assertEqual(out, "done")
294294

295295
def test_coroutine_poll(self):
296-
@self.dispatch.primitive_function()
296+
@self.dispatch.primitive_function
297297
def coro_compute_len(input: Input) -> Output:
298298
return Output.value(len(input.input))
299299

300-
@self.dispatch.primitive_function()
300+
@self.dispatch.primitive_function
301301
def coroutine_main(input: Input) -> Output:
302302
if input.is_first_call:
303303
text: str = input.input
@@ -333,11 +333,11 @@ def coroutine_main(input: Input) -> Output:
333333
self.assertEqual("length=10 text='cool stuff'", out)
334334

335335
def test_coroutine_poll_error(self):
336-
@self.dispatch.primitive_function()
336+
@self.dispatch.primitive_function
337337
def coro_compute_len(input: Input) -> Output:
338338
return Output.error(Error(Status.PERMANENT_ERROR, "type", "Dead"))
339339

340-
@self.dispatch.primitive_function()
340+
@self.dispatch.primitive_function
341341
def coroutine_main(input: Input) -> Output:
342342
if input.is_first_call:
343343
text: str = input.input
@@ -372,7 +372,7 @@ def coroutine_main(input: Input) -> Output:
372372
self.assertEqual(out, "msg=Dead type='type'")
373373

374374
def test_coroutine_error(self):
375-
@self.dispatch.primitive_function()
375+
@self.dispatch.primitive_function
376376
def mycoro(input: Input) -> Output:
377377
return Output.error(Error(Status.PERMANENT_ERROR, "sometype", "dead"))
378378

@@ -381,7 +381,7 @@ def mycoro(input: Input) -> Output:
381381
self.assertEqual("dead", resp.exit.result.error.message)
382382

383383
def test_coroutine_expected_exception(self):
384-
@self.dispatch.primitive_function()
384+
@self.dispatch.primitive_function
385385
def mycoro(input: Input) -> Output:
386386
try:
387387
1 / 0
@@ -395,7 +395,7 @@ def mycoro(input: Input) -> Output:
395395
self.assertEqual(Status.PERMANENT_ERROR, resp.status)
396396

397397
def test_coroutine_unexpected_exception(self):
398-
@self.dispatch.function()
398+
@self.dispatch.function
399399
def mycoro():
400400
1 / 0
401401
self.fail("should not reach here")
@@ -406,7 +406,7 @@ def mycoro():
406406
self.assertEqual(Status.PERMANENT_ERROR, resp.status)
407407

408408
def test_specific_status(self):
409-
@self.dispatch.primitive_function()
409+
@self.dispatch.primitive_function
410410
def mycoro(input: Input) -> Output:
411411
return Output.error(Error(Status.THROTTLED, "foo", "bar"))
412412

@@ -416,11 +416,11 @@ def mycoro(input: Input) -> Output:
416416
self.assertEqual(Status.THROTTLED, resp.status)
417417

418418
def test_tailcall(self):
419-
@self.dispatch.function()
419+
@self.dispatch.function
420420
def other_coroutine(value: Any) -> str:
421421
return f"Hello {value}"
422422

423-
@self.dispatch.primitive_function()
423+
@self.dispatch.primitive_function
424424
def mycoro(input: Input) -> Output:
425425
return Output.tail_call(other_coroutine._build_primitive_call(42))
426426

@@ -429,7 +429,7 @@ def mycoro(input: Input) -> Output:
429429
self.assertEqual(42, any_unpickle(resp.exit.tail_call.input))
430430

431431
def test_library_error_categorization(self):
432-
@self.dispatch.function()
432+
@self.dispatch.function
433433
def get(path: str) -> httpx.Response:
434434
http_response = self.http_client.get(path)
435435
http_response.raise_for_status()
@@ -445,7 +445,7 @@ def get(path: str) -> httpx.Response:
445445
self.assertEqual(Status.NOT_FOUND, Status(resp.status))
446446

447447
def test_library_output_categorization(self):
448-
@self.dispatch.function()
448+
@self.dispatch.function
449449
def get(path: str) -> httpx.Response:
450450
http_response = self.http_client.get(path)
451451
http_response.status_code = 429

0 commit comments

Comments
 (0)