Skip to content

Commit b268e72

Browse files
authored
Merge pull request #101 from stealthrocket/paramspec2
Improve type checking of call results
2 parents 78c0ced + a9148a1 commit b268e72

File tree

1 file changed

+107
-103
lines changed

1 file changed

+107
-103
lines changed

src/dispatch/function.py

Lines changed: 107 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,17 @@
44
import logging
55
from functools import wraps
66
from types import CoroutineType
7-
from typing import Any, Callable, Dict, Generic, ParamSpec, TypeAlias, TypeVar
7+
from typing import (
8+
Any,
9+
Callable,
10+
Coroutine,
11+
Dict,
12+
Generic,
13+
ParamSpec,
14+
TypeAlias,
15+
TypeVar,
16+
overload,
17+
)
818

919
import dispatch.coroutine
1020
from dispatch.client import Client
@@ -23,44 +33,20 @@
2333
"""
2434

2535

26-
P = ParamSpec("P")
27-
T = TypeVar("T")
28-
29-
30-
class Function(Generic[P, T]):
31-
"""Callable wrapper around a function meant to be used throughout the
32-
Dispatch Python SDK.
33-
"""
34-
35-
__slots__ = ("_endpoint", "_client", "_name", "_primitive_func", "_func")
36+
class PrimitiveFunction:
37+
__slots__ = ("_endpoint", "_client", "_name", "_primitive_func")
3638

3739
def __init__(
3840
self,
3941
endpoint: str,
4042
client: Client,
4143
name: str,
4244
primitive_func: PrimitiveFunctionType,
43-
func: Callable[P, T] | None,
44-
coroutine: bool = False,
4545
):
4646
self._endpoint = endpoint
4747
self._client = client
4848
self._name = name
4949
self._primitive_func = primitive_func
50-
if func:
51-
self._func: Callable[P, T] | None = (
52-
durable(self._call_async) if coroutine else func
53-
)
54-
else:
55-
self._func = None
56-
57-
def __call__(self, *args: P.args, **kwargs: P.kwargs):
58-
if self._func is None:
59-
raise ValueError("cannot call a primitive function directly")
60-
return self._func(*args, **kwargs)
61-
62-
def _primitive_call(self, input: Input) -> Output:
63-
return self._primitive_func(input)
6450

6551
@property
6652
def endpoint(self) -> str:
@@ -70,8 +56,62 @@ def endpoint(self) -> str:
7056
def name(self) -> str:
7157
return self._name
7258

59+
def _primitive_call(self, input: Input) -> Output:
60+
return self._primitive_func(input)
61+
62+
def _primitive_dispatch(self, input: Any = None) -> DispatchID:
63+
[dispatch_id] = self._client.dispatch([self._build_primitive_call(input)])
64+
return dispatch_id
65+
66+
def _build_primitive_call(
67+
self, input: Any, correlation_id: int | None = None
68+
) -> Call:
69+
return Call(
70+
correlation_id=correlation_id,
71+
endpoint=self.endpoint,
72+
function=self.name,
73+
input=input,
74+
)
75+
76+
77+
P = ParamSpec("P")
78+
T = TypeVar("T")
79+
80+
81+
class Function(PrimitiveFunction, Generic[P, T]):
82+
"""Callable wrapper around a function meant to be used throughout the
83+
Dispatch Python SDK.
84+
"""
85+
86+
__slots__ = ("_func_indirect",)
87+
88+
def __init__(
89+
self,
90+
endpoint: str,
91+
client: Client,
92+
name: str,
93+
primitive_func: PrimitiveFunctionType,
94+
func: Callable,
95+
):
96+
PrimitiveFunction.__init__(self, endpoint, client, name, primitive_func)
97+
98+
self._func_indirect: Callable[P, Coroutine[Any, Any, T]] = durable(
99+
self._call_async
100+
)
101+
102+
async def _call_async(self, *args: P.args, **kwargs: P.kwargs) -> T:
103+
return await dispatch.coroutine.call(
104+
self.build_call(*args, **kwargs, correlation_id=None)
105+
)
106+
107+
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Coroutine[Any, Any, T]:
108+
"""Call the function asynchronously (through Dispatch), and return a
109+
coroutine that can be awaited to retrieve the call result."""
110+
return self._func_indirect(*args, **kwargs)
111+
73112
def dispatch(self, *args: P.args, **kwargs: P.kwargs) -> DispatchID:
74-
"""Dispatch a call to the function.
113+
"""Dispatch an asynchronous call to the function without
114+
waiting for a result.
75115
76116
The Registry this function was registered with must be initialized
77117
with a Client / api_key for this call facility to be available.
@@ -88,16 +128,6 @@ def dispatch(self, *args: P.args, **kwargs: P.kwargs) -> DispatchID:
88128
"""
89129
return self._primitive_dispatch(Arguments(args, kwargs))
90130

91-
def _primitive_dispatch(self, input: Any = None) -> DispatchID:
92-
[dispatch_id] = self._client.dispatch([self._build_primitive_call(input)])
93-
return dispatch_id
94-
95-
async def _call_async(self, *args: P.args, **kwargs: P.kwargs) -> T:
96-
"""Asynchronously call the function from a @dispatch.function."""
97-
return await dispatch.coroutine.call(
98-
self.build_call(*args, **kwargs, correlation_id=None)
99-
)
100-
101131
def build_call(
102132
self, *args: P.args, correlation_id: int | None = None, **kwargs: P.kwargs
103133
) -> Call:
@@ -117,16 +147,6 @@ def build_call(
117147
Arguments(args, kwargs), correlation_id=correlation_id
118148
)
119149

120-
def _build_primitive_call(
121-
self, input: Any, correlation_id: int | None = None
122-
) -> Call:
123-
return Call(
124-
correlation_id=correlation_id,
125-
endpoint=self.endpoint,
126-
function=self.name,
127-
input=input,
128-
)
129-
130150

131151
class Registry:
132152
"""Registry of local functions."""
@@ -141,89 +161,73 @@ def __init__(self, endpoint: str, client: Client):
141161
client: Client for the Dispatch API. Used to dispatch calls to
142162
local functions.
143163
"""
144-
self._functions: Dict[str, Function] = {}
164+
self._functions: Dict[str, PrimitiveFunction] = {}
145165
self._endpoint = endpoint
146166
self._client = client
147167

148-
def function(self, func: Callable[P, T]) -> Function[P, T]:
168+
@overload
169+
def function(self, func: Callable[P, Coroutine[Any, Any, T]]) -> Function[P, T]: ...
170+
171+
@overload
172+
def function(self, func: Callable[P, T]) -> Function[P, T]: ...
173+
174+
def function(self, func):
149175
"""Decorator that registers functions."""
150-
if inspect.iscoroutinefunction(func):
151-
return self._register_coroutine(func)
152-
return self._register_function(func)
176+
if not inspect.iscoroutinefunction(func):
177+
logger.info("registering function: %s", func.__qualname__)
178+
return self._register_function(func)
153179

154-
def primitive_function(self, func: PrimitiveFunctionType) -> Function:
155-
"""Decorator that registers primitive functions."""
156-
return self._register_primitive_function(func)
180+
logger.info("registering coroutine: %s", func.__qualname__)
181+
return self._register_coroutine(func)
157182

158183
def _register_function(self, func: Callable[P, T]) -> Function[P, T]:
159-
logger.info("registering function: %s", func.__qualname__)
160-
161-
# Register the function with the experimental.durable package, in case
162-
# it's referenced from a coroutine.
163184
func = durable(func)
164185

165186
@wraps(func)
166-
def primitive_func(input: Input) -> Output:
167-
try:
168-
try:
169-
args, kwargs = input.input_arguments()
170-
except ValueError:
171-
raise ValueError("incorrect input for function")
172-
raw_output = func(*args, **kwargs)
173-
except Exception as e:
174-
logger.exception(
175-
f"@dispatch.function: '{func.__name__}' raised an exception"
176-
)
177-
return Output.error(Error.from_exception(e))
178-
else:
179-
return Output.value(raw_output)
180-
181-
primitive_func.__qualname__ = f"{func.__qualname__}_primitive"
182-
primitive_func = durable(primitive_func)
187+
async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
188+
return func(*args, **kwargs)
189+
190+
async_wrapper.__qualname__ = f"{func.__qualname__}_async"
183191

184-
return self._register(primitive_func, func, coroutine=False)
192+
return self._register_coroutine(async_wrapper)
185193

186194
def _register_coroutine(
187-
self, func: Callable[P, CoroutineType[Any, Any, T]]
195+
self, func: Callable[P, Coroutine[Any, Any, T]]
188196
) -> Function[P, T]:
189-
logger.info("registering coroutine: %s", func.__qualname__)
197+
name = func.__qualname__
198+
logger.info("registering coroutine: %s", name)
190199

191200
func = durable(func)
192201

193202
@wraps(func)
194203
def primitive_func(input: Input) -> Output:
195204
return OneShotScheduler(func).run(input)
196205

197-
primitive_func.__qualname__ = f"{func.__qualname__}_primitive"
206+
primitive_func.__qualname__ = f"{name}_primitive"
198207
primitive_func = durable(primitive_func)
199208

200-
return self._register(primitive_func, func, coroutine=True)
209+
wrapped_func = Function[P, T](
210+
self._endpoint, self._client, name, primitive_func, func
211+
)
212+
self._register(name, wrapped_func)
213+
return wrapped_func
201214

202-
def _register_primitive_function(
215+
def primitive_function(
203216
self, primitive_func: PrimitiveFunctionType
204-
) -> Function[P, T]:
205-
logger.info("registering primitive function: %s", primitive_func.__qualname__)
206-
return self._register(primitive_func, func=None, coroutine=False)
217+
) -> PrimitiveFunction:
218+
"""Decorator that registers primitive functions."""
219+
name = primitive_func.__qualname__
220+
logger.info("registering primitive function: %s", name)
221+
wrapped_func = PrimitiveFunction(
222+
self._endpoint, self._client, name, primitive_func
223+
)
224+
self._register(name, wrapped_func)
225+
return wrapped_func
207226

208-
def _register(
209-
self,
210-
primitive_func: PrimitiveFunctionType,
211-
func: Callable[P, T] | None,
212-
coroutine: bool,
213-
) -> Function[P, T]:
214-
if func:
215-
name = func.__qualname__
216-
else:
217-
name = primitive_func.__qualname__
227+
def _register(self, name: str, wrapped_func: PrimitiveFunction):
218228
if name in self._functions:
219-
raise ValueError(
220-
f"function or coroutine already registered with name '{name}'"
221-
)
222-
wrapped_func = Function[P, T](
223-
self._endpoint, self._client, name, primitive_func, func, coroutine
224-
)
229+
raise ValueError(f"function already registered with name '{name}'")
225230
self._functions[name] = wrapped_func
226-
return wrapped_func
227231

228232
def set_client(self, client: Client):
229233
"""Set the Client instance used to dispatch calls to local functions."""

0 commit comments

Comments
 (0)