33
33
"""
34
34
35
35
36
- P = ParamSpec ("P" )
37
- T = TypeVar ("T" )
38
-
39
-
40
- class Function (Generic [P , T ]):
41
- """Callable wrapper around a function meant to be used throughout the
42
- Dispatch Python SDK.
43
- """
44
-
45
- __slots__ = ("_endpoint" , "_client" , "_name" , "_primitive_func" , "_func" )
36
+ class PrimitiveFunction :
37
+ __slots__ = ("_endpoint" , "_client" , "_name" , "_primitive_func" )
46
38
47
39
def __init__ (
48
40
self ,
49
41
endpoint : str ,
50
42
client : Client ,
51
43
name : str ,
52
44
primitive_func : PrimitiveFunctionType ,
53
- func : Callable [..., Any ] | None ,
54
45
):
55
46
self ._endpoint = endpoint
56
47
self ._client = client
57
48
self ._name = name
58
49
self ._primitive_func = primitive_func
59
- self ._func : Callable [P , Coroutine [Any , Any , T ]] | None = (
60
- durable (self ._call_async ) if func else None
61
- )
62
-
63
- def __call__ (self , * args : P .args , ** kwargs : P .kwargs ) -> Coroutine [Any , Any , T ]:
64
- if self ._func is None :
65
- raise ValueError ("cannot call a primitive function directly" )
66
- return self ._func (* args , ** kwargs )
67
-
68
- def _primitive_call (self , input : Input ) -> Output :
69
- return self ._primitive_func (input )
70
50
71
51
@property
72
52
def endpoint (self ) -> str :
@@ -76,8 +56,62 @@ def endpoint(self) -> str:
76
56
def name (self ) -> str :
77
57
return self ._name
78
58
59
+ def _primitive_call (self , input : Input ) -> Output :
60
+ return self ._primitive_func (input )
61
+
62
+ def _primitive_dispatch (self , input : Any = None ) -> DispatchID :
63
+ [dispatch_id ] = self ._client .dispatch ([self ._build_primitive_call (input )])
64
+ return dispatch_id
65
+
66
+ def _build_primitive_call (
67
+ self , input : Any , correlation_id : int | None = None
68
+ ) -> Call :
69
+ return Call (
70
+ correlation_id = correlation_id ,
71
+ endpoint = self .endpoint ,
72
+ function = self .name ,
73
+ input = input ,
74
+ )
75
+
76
+
77
+ P = ParamSpec ("P" )
78
+ T = TypeVar ("T" )
79
+
80
+
81
+ class Function (PrimitiveFunction , Generic [P , T ]):
82
+ """Callable wrapper around a function meant to be used throughout the
83
+ Dispatch Python SDK.
84
+ """
85
+
86
+ __slots__ = ("_func_indirect" ,)
87
+
88
+ def __init__ (
89
+ self ,
90
+ endpoint : str ,
91
+ client : Client ,
92
+ name : str ,
93
+ primitive_func : PrimitiveFunctionType ,
94
+ func : Callable ,
95
+ ):
96
+ PrimitiveFunction .__init__ (self , endpoint , client , name , primitive_func )
97
+
98
+ self ._func_indirect : Callable [P , Coroutine [Any , Any , T ]] = durable (
99
+ self ._call_async
100
+ )
101
+
102
+ async def _call_async (self , * args : P .args , ** kwargs : P .kwargs ) -> T :
103
+ return await dispatch .coroutine .call (
104
+ self .build_call (* args , ** kwargs , correlation_id = None )
105
+ )
106
+
107
+ def __call__ (self , * args : P .args , ** kwargs : P .kwargs ) -> Coroutine [Any , Any , T ]:
108
+ """Call the function asynchronously (through Dispatch), and return a
109
+ coroutine that can be awaited to retrieve the call result."""
110
+ return self ._func_indirect (* args , ** kwargs )
111
+
79
112
def dispatch (self , * args : P .args , ** kwargs : P .kwargs ) -> DispatchID :
80
- """Dispatch a call to the function.
113
+ """Dispatch an asynchronous call to the function without
114
+ waiting for a result.
81
115
82
116
The Registry this function was registered with must be initialized
83
117
with a Client / api_key for this call facility to be available.
@@ -94,16 +128,6 @@ def dispatch(self, *args: P.args, **kwargs: P.kwargs) -> DispatchID:
94
128
"""
95
129
return self ._primitive_dispatch (Arguments (args , kwargs ))
96
130
97
- def _primitive_dispatch (self , input : Any = None ) -> DispatchID :
98
- [dispatch_id ] = self ._client .dispatch ([self ._build_primitive_call (input )])
99
- return dispatch_id
100
-
101
- async def _call_async (self , * args : P .args , ** kwargs : P .kwargs ) -> T :
102
- """Asynchronously call the function from a @dispatch.function."""
103
- return await dispatch .coroutine .call (
104
- self .build_call (* args , ** kwargs , correlation_id = None )
105
- )
106
-
107
131
def build_call (
108
132
self , * args : P .args , correlation_id : int | None = None , ** kwargs : P .kwargs
109
133
) -> Call :
@@ -123,16 +147,6 @@ def build_call(
123
147
Arguments (args , kwargs ), correlation_id = correlation_id
124
148
)
125
149
126
- def _build_primitive_call (
127
- self , input : Any , correlation_id : int | None = None
128
- ) -> Call :
129
- return Call (
130
- correlation_id = correlation_id ,
131
- endpoint = self .endpoint ,
132
- function = self .name ,
133
- input = input ,
134
- )
135
-
136
150
137
151
class Registry :
138
152
"""Registry of local functions."""
@@ -147,7 +161,7 @@ def __init__(self, endpoint: str, client: Client):
147
161
client: Client for the Dispatch API. Used to dispatch calls to
148
162
local functions.
149
163
"""
150
- self ._functions : Dict [str , Function ] = {}
164
+ self ._functions : Dict [str , PrimitiveFunction ] = {}
151
165
self ._endpoint = endpoint
152
166
self ._client = client
153
167
@@ -166,10 +180,6 @@ def function(self, func):
166
180
logger .info ("registering coroutine: %s" , func .__qualname__ )
167
181
return self ._register_coroutine (func )
168
182
169
- def primitive_function (self , func : PrimitiveFunctionType ) -> Function :
170
- """Decorator that registers primitive functions."""
171
- return self ._register_primitive_function (func )
172
-
173
183
def _register_function (self , func : Callable [P , T ]) -> Function [P , T ]:
174
184
func = durable (func )
175
185
@@ -184,40 +194,40 @@ async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
184
194
def _register_coroutine (
185
195
self , func : Callable [P , Coroutine [Any , Any , T ]]
186
196
) -> Function [P , T ]:
187
- logger .info ("registering coroutine: %s" , func .__qualname__ )
197
+ name = func .__qualname__
198
+ logger .info ("registering coroutine: %s" , name )
188
199
189
200
func = durable (func )
190
201
191
202
@wraps (func )
192
203
def primitive_func (input : Input ) -> Output :
193
204
return OneShotScheduler (func ).run (input )
194
205
195
- primitive_func .__qualname__ = f"{ func . __qualname__ } _primitive"
206
+ primitive_func .__qualname__ = f"{ name } _primitive"
196
207
primitive_func = durable (primitive_func )
197
208
198
- return self ._register (primitive_func , func )
209
+ wrapped_func = Function [P , T ](
210
+ self ._endpoint , self ._client , name , primitive_func , func
211
+ )
212
+ self ._register (name , wrapped_func )
213
+ return wrapped_func
199
214
200
- def _register_primitive_function (
215
+ def primitive_function (
201
216
self , primitive_func : PrimitiveFunctionType
202
- ) -> Function [P , T ]:
203
- logger .info ("registering primitive function: %s" , primitive_func .__qualname__ )
204
- return self ._register (primitive_func , func = None )
217
+ ) -> PrimitiveFunction :
218
+ """Decorator that registers primitive functions."""
219
+ name = primitive_func .__qualname__
220
+ logger .info ("registering primitive function: %s" , name )
221
+ wrapped_func = PrimitiveFunction (
222
+ self ._endpoint , self ._client , name , primitive_func
223
+ )
224
+ self ._register (name , wrapped_func )
225
+ return wrapped_func
205
226
206
- def _register (
207
- self ,
208
- primitive_func : PrimitiveFunctionType ,
209
- func : Callable [P , Coroutine [Any , Any , T ]] | None ,
210
- ) -> Function [P , T ]:
211
- name = func .__qualname__ if func else primitive_func .__qualname__
227
+ def _register (self , name : str , wrapped_func : PrimitiveFunction ):
212
228
if name in self ._functions :
213
- raise ValueError (
214
- f"function or coroutine already registered with name '{ name } '"
215
- )
216
- wrapped_func = Function [P , T ](
217
- self ._endpoint , self ._client , name , primitive_func , func
218
- )
229
+ raise ValueError (f"function already registered with name '{ name } '" )
219
230
self ._functions [name ] = wrapped_func
220
- return wrapped_func
221
231
222
232
def set_client (self , client : Client ):
223
233
"""Set the Client instance used to dispatch calls to local functions."""
0 commit comments