Skip to content

Commit a9148a1

Browse files
committed
Extract a base class to better handle primitive functions
1 parent 8fe9486 commit a9148a1

File tree

1 file changed

+78
-68
lines changed

1 file changed

+78
-68
lines changed

src/dispatch/function.py

Lines changed: 78 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -33,40 +33,20 @@
3333
"""
3434

3535

36-
P = ParamSpec("P")
37-
T = TypeVar("T")
38-
39-
40-
class Function(Generic[P, T]):
41-
"""Callable wrapper around a function meant to be used throughout the
42-
Dispatch Python SDK.
43-
"""
44-
45-
__slots__ = ("_endpoint", "_client", "_name", "_primitive_func", "_func")
36+
class PrimitiveFunction:
37+
__slots__ = ("_endpoint", "_client", "_name", "_primitive_func")
4638

4739
def __init__(
4840
self,
4941
endpoint: str,
5042
client: Client,
5143
name: str,
5244
primitive_func: PrimitiveFunctionType,
53-
func: Callable[..., Any] | None,
5445
):
5546
self._endpoint = endpoint
5647
self._client = client
5748
self._name = name
5849
self._primitive_func = primitive_func
59-
self._func: Callable[P, Coroutine[Any, Any, T]] | None = (
60-
durable(self._call_async) if func else None
61-
)
62-
63-
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Coroutine[Any, Any, T]:
64-
if self._func is None:
65-
raise ValueError("cannot call a primitive function directly")
66-
return self._func(*args, **kwargs)
67-
68-
def _primitive_call(self, input: Input) -> Output:
69-
return self._primitive_func(input)
7050

7151
@property
7252
def endpoint(self) -> str:
@@ -76,8 +56,62 @@ def endpoint(self) -> str:
7656
def name(self) -> str:
7757
return self._name
7858

59+
def _primitive_call(self, input: Input) -> Output:
60+
return self._primitive_func(input)
61+
62+
def _primitive_dispatch(self, input: Any = None) -> DispatchID:
63+
[dispatch_id] = self._client.dispatch([self._build_primitive_call(input)])
64+
return dispatch_id
65+
66+
def _build_primitive_call(
67+
self, input: Any, correlation_id: int | None = None
68+
) -> Call:
69+
return Call(
70+
correlation_id=correlation_id,
71+
endpoint=self.endpoint,
72+
function=self.name,
73+
input=input,
74+
)
75+
76+
77+
P = ParamSpec("P")
78+
T = TypeVar("T")
79+
80+
81+
class Function(PrimitiveFunction, Generic[P, T]):
82+
"""Callable wrapper around a function meant to be used throughout the
83+
Dispatch Python SDK.
84+
"""
85+
86+
__slots__ = ("_func_indirect",)
87+
88+
def __init__(
89+
self,
90+
endpoint: str,
91+
client: Client,
92+
name: str,
93+
primitive_func: PrimitiveFunctionType,
94+
func: Callable,
95+
):
96+
PrimitiveFunction.__init__(self, endpoint, client, name, primitive_func)
97+
98+
self._func_indirect: Callable[P, Coroutine[Any, Any, T]] = durable(
99+
self._call_async
100+
)
101+
102+
async def _call_async(self, *args: P.args, **kwargs: P.kwargs) -> T:
103+
return await dispatch.coroutine.call(
104+
self.build_call(*args, **kwargs, correlation_id=None)
105+
)
106+
107+
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Coroutine[Any, Any, T]:
108+
"""Call the function asynchronously (through Dispatch), and return a
109+
coroutine that can be awaited to retrieve the call result."""
110+
return self._func_indirect(*args, **kwargs)
111+
79112
def dispatch(self, *args: P.args, **kwargs: P.kwargs) -> DispatchID:
80-
"""Dispatch a call to the function.
113+
"""Dispatch an asynchronous call to the function without
114+
waiting for a result.
81115
82116
The Registry this function was registered with must be initialized
83117
with a Client / api_key for this call facility to be available.
@@ -94,16 +128,6 @@ def dispatch(self, *args: P.args, **kwargs: P.kwargs) -> DispatchID:
94128
"""
95129
return self._primitive_dispatch(Arguments(args, kwargs))
96130

97-
def _primitive_dispatch(self, input: Any = None) -> DispatchID:
98-
[dispatch_id] = self._client.dispatch([self._build_primitive_call(input)])
99-
return dispatch_id
100-
101-
async def _call_async(self, *args: P.args, **kwargs: P.kwargs) -> T:
102-
"""Asynchronously call the function from a @dispatch.function."""
103-
return await dispatch.coroutine.call(
104-
self.build_call(*args, **kwargs, correlation_id=None)
105-
)
106-
107131
def build_call(
108132
self, *args: P.args, correlation_id: int | None = None, **kwargs: P.kwargs
109133
) -> Call:
@@ -123,16 +147,6 @@ def build_call(
123147
Arguments(args, kwargs), correlation_id=correlation_id
124148
)
125149

126-
def _build_primitive_call(
127-
self, input: Any, correlation_id: int | None = None
128-
) -> Call:
129-
return Call(
130-
correlation_id=correlation_id,
131-
endpoint=self.endpoint,
132-
function=self.name,
133-
input=input,
134-
)
135-
136150

137151
class Registry:
138152
"""Registry of local functions."""
@@ -147,7 +161,7 @@ def __init__(self, endpoint: str, client: Client):
147161
client: Client for the Dispatch API. Used to dispatch calls to
148162
local functions.
149163
"""
150-
self._functions: Dict[str, Function] = {}
164+
self._functions: Dict[str, PrimitiveFunction] = {}
151165
self._endpoint = endpoint
152166
self._client = client
153167

@@ -166,10 +180,6 @@ def function(self, func):
166180
logger.info("registering coroutine: %s", func.__qualname__)
167181
return self._register_coroutine(func)
168182

169-
def primitive_function(self, func: PrimitiveFunctionType) -> Function:
170-
"""Decorator that registers primitive functions."""
171-
return self._register_primitive_function(func)
172-
173183
def _register_function(self, func: Callable[P, T]) -> Function[P, T]:
174184
func = durable(func)
175185

@@ -184,40 +194,40 @@ async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
184194
def _register_coroutine(
185195
self, func: Callable[P, Coroutine[Any, Any, T]]
186196
) -> Function[P, T]:
187-
logger.info("registering coroutine: %s", func.__qualname__)
197+
name = func.__qualname__
198+
logger.info("registering coroutine: %s", name)
188199

189200
func = durable(func)
190201

191202
@wraps(func)
192203
def primitive_func(input: Input) -> Output:
193204
return OneShotScheduler(func).run(input)
194205

195-
primitive_func.__qualname__ = f"{func.__qualname__}_primitive"
206+
primitive_func.__qualname__ = f"{name}_primitive"
196207
primitive_func = durable(primitive_func)
197208

198-
return self._register(primitive_func, func)
209+
wrapped_func = Function[P, T](
210+
self._endpoint, self._client, name, primitive_func, func
211+
)
212+
self._register(name, wrapped_func)
213+
return wrapped_func
199214

200-
def _register_primitive_function(
215+
def primitive_function(
201216
self, primitive_func: PrimitiveFunctionType
202-
) -> Function[P, T]:
203-
logger.info("registering primitive function: %s", primitive_func.__qualname__)
204-
return self._register(primitive_func, func=None)
217+
) -> PrimitiveFunction:
218+
"""Decorator that registers primitive functions."""
219+
name = primitive_func.__qualname__
220+
logger.info("registering primitive function: %s", name)
221+
wrapped_func = PrimitiveFunction(
222+
self._endpoint, self._client, name, primitive_func
223+
)
224+
self._register(name, wrapped_func)
225+
return wrapped_func
205226

206-
def _register(
207-
self,
208-
primitive_func: PrimitiveFunctionType,
209-
func: Callable[P, Coroutine[Any, Any, T]] | None,
210-
) -> Function[P, T]:
211-
name = func.__qualname__ if func else primitive_func.__qualname__
227+
def _register(self, name: str, wrapped_func: PrimitiveFunction):
212228
if name in self._functions:
213-
raise ValueError(
214-
f"function or coroutine already registered with name '{name}'"
215-
)
216-
wrapped_func = Function[P, T](
217-
self._endpoint, self._client, name, primitive_func, func
218-
)
229+
raise ValueError(f"function already registered with name '{name}'")
219230
self._functions[name] = wrapped_func
220-
return wrapped_func
221231

222232
def set_client(self, client: Client):
223233
"""Set the Client instance used to dispatch calls to local functions."""

0 commit comments

Comments
 (0)