4
4
import logging
5
5
from functools import wraps
6
6
from types import CoroutineType
7
- from typing import Any , Callable , Dict , Generic , ParamSpec , TypeAlias , TypeVar
7
+ from typing import (
8
+ Any ,
9
+ Callable ,
10
+ Coroutine ,
11
+ Dict ,
12
+ Generic ,
13
+ ParamSpec ,
14
+ TypeAlias ,
15
+ TypeVar ,
16
+ overload ,
17
+ )
8
18
9
19
import dispatch .coroutine
10
20
from dispatch .client import Client
23
33
"""
24
34
25
35
26
- P = ParamSpec ("P" )
27
- T = TypeVar ("T" )
28
-
29
-
30
- class Function (Generic [P , T ]):
31
- """Callable wrapper around a function meant to be used throughout the
32
- Dispatch Python SDK.
33
- """
34
-
35
- __slots__ = ("_endpoint" , "_client" , "_name" , "_primitive_func" , "_func" )
36
+ class PrimitiveFunction :
37
+ __slots__ = ("_endpoint" , "_client" , "_name" , "_primitive_func" )
36
38
37
39
def __init__ (
38
40
self ,
39
41
endpoint : str ,
40
42
client : Client ,
41
43
name : str ,
42
44
primitive_func : PrimitiveFunctionType ,
43
- func : Callable [P , T ] | None ,
44
- coroutine : bool = False ,
45
45
):
46
46
self ._endpoint = endpoint
47
47
self ._client = client
48
48
self ._name = name
49
49
self ._primitive_func = primitive_func
50
- if func :
51
- self ._func : Callable [P , T ] | None = (
52
- durable (self ._call_async ) if coroutine else func
53
- )
54
- else :
55
- self ._func = None
56
-
57
- def __call__ (self , * args : P .args , ** kwargs : P .kwargs ):
58
- if self ._func is None :
59
- raise ValueError ("cannot call a primitive function directly" )
60
- return self ._func (* args , ** kwargs )
61
-
62
- def _primitive_call (self , input : Input ) -> Output :
63
- return self ._primitive_func (input )
64
50
65
51
@property
66
52
def endpoint (self ) -> str :
@@ -70,8 +56,62 @@ def endpoint(self) -> str:
70
56
def name (self ) -> str :
71
57
return self ._name
72
58
59
+ def _primitive_call (self , input : Input ) -> Output :
60
+ return self ._primitive_func (input )
61
+
62
+ def _primitive_dispatch (self , input : Any = None ) -> DispatchID :
63
+ [dispatch_id ] = self ._client .dispatch ([self ._build_primitive_call (input )])
64
+ return dispatch_id
65
+
66
+ def _build_primitive_call (
67
+ self , input : Any , correlation_id : int | None = None
68
+ ) -> Call :
69
+ return Call (
70
+ correlation_id = correlation_id ,
71
+ endpoint = self .endpoint ,
72
+ function = self .name ,
73
+ input = input ,
74
+ )
75
+
76
+
77
+ P = ParamSpec ("P" )
78
+ T = TypeVar ("T" )
79
+
80
+
81
+ class Function (PrimitiveFunction , Generic [P , T ]):
82
+ """Callable wrapper around a function meant to be used throughout the
83
+ Dispatch Python SDK.
84
+ """
85
+
86
+ __slots__ = ("_func_indirect" ,)
87
+
88
+ def __init__ (
89
+ self ,
90
+ endpoint : str ,
91
+ client : Client ,
92
+ name : str ,
93
+ primitive_func : PrimitiveFunctionType ,
94
+ func : Callable ,
95
+ ):
96
+ PrimitiveFunction .__init__ (self , endpoint , client , name , primitive_func )
97
+
98
+ self ._func_indirect : Callable [P , Coroutine [Any , Any , T ]] = durable (
99
+ self ._call_async
100
+ )
101
+
102
+ async def _call_async (self , * args : P .args , ** kwargs : P .kwargs ) -> T :
103
+ return await dispatch .coroutine .call (
104
+ self .build_call (* args , ** kwargs , correlation_id = None )
105
+ )
106
+
107
+ def __call__ (self , * args : P .args , ** kwargs : P .kwargs ) -> Coroutine [Any , Any , T ]:
108
+ """Call the function asynchronously (through Dispatch), and return a
109
+ coroutine that can be awaited to retrieve the call result."""
110
+ return self ._func_indirect (* args , ** kwargs )
111
+
73
112
def dispatch (self , * args : P .args , ** kwargs : P .kwargs ) -> DispatchID :
74
- """Dispatch a call to the function.
113
+ """Dispatch an asynchronous call to the function without
114
+ waiting for a result.
75
115
76
116
The Registry this function was registered with must be initialized
77
117
with a Client / api_key for this call facility to be available.
@@ -88,16 +128,6 @@ def dispatch(self, *args: P.args, **kwargs: P.kwargs) -> DispatchID:
88
128
"""
89
129
return self ._primitive_dispatch (Arguments (args , kwargs ))
90
130
91
- def _primitive_dispatch (self , input : Any = None ) -> DispatchID :
92
- [dispatch_id ] = self ._client .dispatch ([self ._build_primitive_call (input )])
93
- return dispatch_id
94
-
95
- async def _call_async (self , * args : P .args , ** kwargs : P .kwargs ) -> T :
96
- """Asynchronously call the function from a @dispatch.function."""
97
- return await dispatch .coroutine .call (
98
- self .build_call (* args , ** kwargs , correlation_id = None )
99
- )
100
-
101
131
def build_call (
102
132
self , * args : P .args , correlation_id : int | None = None , ** kwargs : P .kwargs
103
133
) -> Call :
@@ -117,16 +147,6 @@ def build_call(
117
147
Arguments (args , kwargs ), correlation_id = correlation_id
118
148
)
119
149
120
- def _build_primitive_call (
121
- self , input : Any , correlation_id : int | None = None
122
- ) -> Call :
123
- return Call (
124
- correlation_id = correlation_id ,
125
- endpoint = self .endpoint ,
126
- function = self .name ,
127
- input = input ,
128
- )
129
-
130
150
131
151
class Registry :
132
152
"""Registry of local functions."""
@@ -141,89 +161,73 @@ def __init__(self, endpoint: str, client: Client):
141
161
client: Client for the Dispatch API. Used to dispatch calls to
142
162
local functions.
143
163
"""
144
- self ._functions : Dict [str , Function ] = {}
164
+ self ._functions : Dict [str , PrimitiveFunction ] = {}
145
165
self ._endpoint = endpoint
146
166
self ._client = client
147
167
148
- def function (self , func : Callable [P , T ]) -> Function [P , T ]:
168
+ @overload
169
+ def function (self , func : Callable [P , Coroutine [Any , Any , T ]]) -> Function [P , T ]: ...
170
+
171
+ @overload
172
+ def function (self , func : Callable [P , T ]) -> Function [P , T ]: ...
173
+
174
+ def function (self , func ):
149
175
"""Decorator that registers functions."""
150
- if inspect .iscoroutinefunction (func ):
151
- return self . _register_coroutine ( func )
152
- return self ._register_function (func )
176
+ if not inspect .iscoroutinefunction (func ):
177
+ logger . info ( "registering function: %s" , func . __qualname__ )
178
+ return self ._register_function (func )
153
179
154
- def primitive_function (self , func : PrimitiveFunctionType ) -> Function :
155
- """Decorator that registers primitive functions."""
156
- return self ._register_primitive_function (func )
180
+ logger .info ("registering coroutine: %s" , func .__qualname__ )
181
+ return self ._register_coroutine (func )
157
182
158
183
def _register_function (self , func : Callable [P , T ]) -> Function [P , T ]:
159
- logger .info ("registering function: %s" , func .__qualname__ )
160
-
161
- # Register the function with the experimental.durable package, in case
162
- # it's referenced from a coroutine.
163
184
func = durable (func )
164
185
165
186
@wraps (func )
166
- def primitive_func (input : Input ) -> Output :
167
- try :
168
- try :
169
- args , kwargs = input .input_arguments ()
170
- except ValueError :
171
- raise ValueError ("incorrect input for function" )
172
- raw_output = func (* args , ** kwargs )
173
- except Exception as e :
174
- logger .exception (
175
- f"@dispatch.function: '{ func .__name__ } ' raised an exception"
176
- )
177
- return Output .error (Error .from_exception (e ))
178
- else :
179
- return Output .value (raw_output )
180
-
181
- primitive_func .__qualname__ = f"{ func .__qualname__ } _primitive"
182
- primitive_func = durable (primitive_func )
187
+ async def async_wrapper (* args : P .args , ** kwargs : P .kwargs ) -> T :
188
+ return func (* args , ** kwargs )
189
+
190
+ async_wrapper .__qualname__ = f"{ func .__qualname__ } _async"
183
191
184
- return self ._register ( primitive_func , func , coroutine = False )
192
+ return self ._register_coroutine ( async_wrapper )
185
193
186
194
def _register_coroutine (
187
- self , func : Callable [P , CoroutineType [Any , Any , T ]]
195
+ self , func : Callable [P , Coroutine [Any , Any , T ]]
188
196
) -> Function [P , T ]:
189
- logger .info ("registering coroutine: %s" , func .__qualname__ )
197
+ name = func .__qualname__
198
+ logger .info ("registering coroutine: %s" , name )
190
199
191
200
func = durable (func )
192
201
193
202
@wraps (func )
194
203
def primitive_func (input : Input ) -> Output :
195
204
return OneShotScheduler (func ).run (input )
196
205
197
- primitive_func .__qualname__ = f"{ func . __qualname__ } _primitive"
206
+ primitive_func .__qualname__ = f"{ name } _primitive"
198
207
primitive_func = durable (primitive_func )
199
208
200
- return self ._register (primitive_func , func , coroutine = True )
209
+ wrapped_func = Function [P , T ](
210
+ self ._endpoint , self ._client , name , primitive_func , func
211
+ )
212
+ self ._register (name , wrapped_func )
213
+ return wrapped_func
201
214
202
- def _register_primitive_function (
215
+ def primitive_function (
203
216
self , primitive_func : PrimitiveFunctionType
204
- ) -> Function [P , T ]:
205
- logger .info ("registering primitive function: %s" , primitive_func .__qualname__ )
206
- return self ._register (primitive_func , func = None , coroutine = False )
217
+ ) -> PrimitiveFunction :
218
+ """Decorator that registers primitive functions."""
219
+ name = primitive_func .__qualname__
220
+ logger .info ("registering primitive function: %s" , name )
221
+ wrapped_func = PrimitiveFunction (
222
+ self ._endpoint , self ._client , name , primitive_func
223
+ )
224
+ self ._register (name , wrapped_func )
225
+ return wrapped_func
207
226
208
- def _register (
209
- self ,
210
- primitive_func : PrimitiveFunctionType ,
211
- func : Callable [P , T ] | None ,
212
- coroutine : bool ,
213
- ) -> Function [P , T ]:
214
- if func :
215
- name = func .__qualname__
216
- else :
217
- name = primitive_func .__qualname__
227
+ def _register (self , name : str , wrapped_func : PrimitiveFunction ):
218
228
if name in self ._functions :
219
- raise ValueError (
220
- f"function or coroutine already registered with name '{ name } '"
221
- )
222
- wrapped_func = Function [P , T ](
223
- self ._endpoint , self ._client , name , primitive_func , func , coroutine
224
- )
229
+ raise ValueError (f"function already registered with name '{ name } '" )
225
230
self ._functions [name ] = wrapped_func
226
- return wrapped_func
227
231
228
232
def set_client (self , client : Client ):
229
233
"""Set the Client instance used to dispatch calls to local functions."""
0 commit comments