Skip to content

Commit 5aa45e0

Browse files
authored
Merge pull request #90 from stealthrocket/poll-result-error
Handle PollResult errors
2 parents 0a4391d + bba1353 commit 5aa45e0

File tree

6 files changed

+146
-26
lines changed

6 files changed

+146
-26
lines changed

src/dispatch/fastapi.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,8 @@ async def execute(request: fastapi.Request):
258258
)
259259

260260
logger.debug("finished handling run request with status %s", status.name)
261-
return fastapi.Response(content=response.SerializeToString())
261+
return fastapi.Response(
262+
content=response.SerializeToString(), media_type="application/proto"
263+
)
262264

263265
return app

src/dispatch/proto.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,13 @@ class Input:
3636
This class is intended to be used as read-only.
3737
"""
3838

39-
__slots__ = ("_has_input", "_input", "_coroutine_state", "_call_results")
39+
__slots__ = (
40+
"_has_input",
41+
"_input",
42+
"_coroutine_state",
43+
"_call_results",
44+
"_poll_error",
45+
)
4046

4147
def __init__(self, req: function_pb.RunRequest):
4248
self._has_input = req.HasField("input")
@@ -54,6 +60,11 @@ def __init__(self, req: function_pb.RunRequest):
5460
self._call_results = [
5561
CallResult._from_proto(r) for r in req.poll_result.results
5662
]
63+
self._poll_error = (
64+
Error._from_proto(req.poll_result.error)
65+
if req.poll_result.HasField("error")
66+
else None
67+
)
5768

5869
@property
5970
def is_first_call(self) -> bool:
@@ -85,6 +96,11 @@ def call_results(self) -> list[CallResult]:
8596
self._assert_resume()
8697
return self._call_results
8798

99+
@property
100+
def poll_error(self) -> Error | None:
101+
self._assert_resume()
102+
return self._poll_error
103+
88104
def _assert_first_call(self):
89105
if self.is_resume:
90106
raise ValueError("This input is for a resumed coroutine")
@@ -105,14 +121,20 @@ def from_input_arguments(cls, function: str, *args, **kwargs):
105121

106122
@classmethod
107123
def from_poll_results(
108-
cls, function: str, coroutine_state: Any, call_results: list[CallResult]
124+
cls,
125+
function: str,
126+
coroutine_state: Any,
127+
call_results: list[CallResult],
128+
error: Error | None = None,
109129
):
110130
return Input(
111131
req=function_pb.RunRequest(
132+
function=function,
112133
poll_result=poll_pb.PollResult(
113134
coroutine_state=coroutine_state,
114135
results=[result._as_proto() for result in call_results],
115-
)
136+
error=error._as_proto() if error else None,
137+
),
116138
)
117139
)
118140

src/dispatch/scheduler.py

Lines changed: 48 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ class CallResult:
3737

3838

3939
class Future(Protocol):
40-
def add(self, result: CallResult | CoroutineResult): ...
40+
def add_result(self, result: CallResult | CoroutineResult): ...
41+
def add_error(self, error: Exception): ...
4142
def ready(self) -> bool: ...
4243
def error(self) -> Exception | None: ...
4344
def value(self) -> Any: ...
@@ -48,17 +49,25 @@ class CallFuture:
4849
"""A future result of a dispatch.coroutine.call() operation."""
4950

5051
result: CallResult | None = None
52+
first_error: Exception | None = None
5153

52-
def add(self, result: CallResult | CoroutineResult):
54+
def add_result(self, result: CallResult | CoroutineResult):
5355
assert isinstance(result, CallResult)
54-
self.result = result
56+
if self.result is None:
57+
self.result = result
58+
if result.error is not None and self.first_error is None:
59+
self.first_error = result.error
60+
61+
def add_error(self, error: Exception):
62+
if self.first_error is None:
63+
self.first_error = error
5564

5665
def ready(self) -> bool:
57-
return self.result is not None
66+
return self.first_error is not None or self.result is not None
5867

5968
def error(self) -> Exception | None:
60-
assert self.result is not None
61-
return self.result.error
69+
assert self.ready()
70+
return self.first_error
6271

6372
def value(self) -> Any:
6473
assert self.result is not None
@@ -74,7 +83,7 @@ class GatherFuture:
7483
results: dict[CoroutineID, CoroutineResult]
7584
first_error: Exception | None = None
7685

77-
def add(self, result: CallResult | CoroutineResult):
86+
def add_result(self, result: CallResult | CoroutineResult):
7887
assert isinstance(result, CoroutineResult)
7988

8089
try:
@@ -87,6 +96,10 @@ def add(self, result: CallResult | CoroutineResult):
8796

8897
self.results[result.coroutine_id] = result
8998

99+
def add_error(self, error: Exception):
100+
if self.first_error is not None:
101+
self.first_error = error
102+
90103
def ready(self) -> bool:
91104
return self.first_error is not None or len(self.waiting) == 0
92105

@@ -134,6 +147,8 @@ class State:
134147
next_coroutine_id: int
135148
next_call_id: int
136149

150+
prev_calls: list[Coroutine]
151+
137152

138153
class OneShotScheduler:
139154
"""Scheduler for local coroutines.
@@ -183,6 +198,7 @@ def _init_state(self, input: Input) -> State:
183198
ready=[Coroutine(id=0, parent_id=None, coroutine=main)],
184199
next_coroutine_id=1,
185200
next_call_id=1,
201+
prev_calls=[],
186202
)
187203

188204
def _rebuild_state(self, input: Input):
@@ -203,19 +219,37 @@ def _rebuild_state(self, input: Input):
203219
raise IncompatibleStateError from e
204220

205221
def _run(self, input: Input) -> Output:
222+
206223
if input.is_first_call:
207224
state = self._init_state(input)
208225
else:
209226
state = self._rebuild_state(input)
210227

228+
poll_error = input.poll_error
229+
if poll_error is not None:
230+
error = poll_error.to_exception()
231+
logger.debug("dispatching poll error: %s", error)
232+
for coroutine in state.prev_calls:
233+
future = coroutine.result
234+
assert future is not None
235+
future.add_error(error)
236+
if future.ready() and coroutine.id in state.suspended:
237+
state.ready.append(coroutine)
238+
del state.suspended[coroutine.id]
239+
logger.debug("coroutine %s is now ready", coroutine)
240+
241+
state.prev_calls = []
242+
211243
logger.debug("dispatching %d call result(s)", len(input.call_results))
212244
for cr in input.call_results:
213245
assert cr.correlation_id is not None
214246
coroutine_id = correlation_coroutine_id(cr.correlation_id)
215247
call_id = correlation_call_id(cr.correlation_id)
216248

217-
error = cr.error.to_exception() if cr.error is not None else None
218-
call_result = CallResult(call_id=call_id, value=cr.output, error=error)
249+
call_error = cr.error.to_exception() if cr.error is not None else None
250+
call_result = CallResult(
251+
call_id=call_id, value=cr.output, error=call_error
252+
)
219253

220254
try:
221255
owner = state.suspended[coroutine_id]
@@ -226,8 +260,8 @@ def _run(self, input: Input) -> Output:
226260
continue
227261

228262
logger.debug("dispatching %s to %s", call_result, owner)
229-
future.add(call_result)
230-
if future.ready():
263+
future.add_result(call_result)
264+
if future.ready() and owner.id in state.suspended:
231265
state.ready.append(owner)
232266
del state.suspended[owner.id]
233267
logger.debug("owner %s is now ready", owner)
@@ -284,8 +318,8 @@ def _run(self, input: Input) -> Output:
284318
except (KeyError, AssertionError):
285319
logger.warning("discarding %s", coroutine_result)
286320
else:
287-
future.add(coroutine_result)
288-
if future.ready():
321+
future.add_result(coroutine_result)
322+
if future.ready() and parent.id in state.suspended:
289323
state.ready.insert(0, parent)
290324
del state.suspended[parent.id]
291325
logger.debug("parent %s is now ready", parent)
@@ -308,6 +342,7 @@ def _run(self, input: Input) -> Output:
308342
pending_calls.append(call)
309343
coroutine.result = CallFuture()
310344
state.suspended[coroutine.id] = coroutine
345+
state.prev_calls.append(coroutine)
311346

312347
case Gather():
313348
gather = coroutine_yield

src/dispatch/sdk/v1/poll_pb2.py

Lines changed: 6 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/dispatch/sdk/v1/poll_pb2.pyi

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ from google.protobuf.internal import containers as _containers
1111

1212
from buf.validate import validate_pb2 as _validate_pb2
1313
from dispatch.sdk.v1 import call_pb2 as _call_pb2
14+
from dispatch.sdk.v1 import error_pb2 as _error_pb2
1415

1516
DESCRIPTOR: _descriptor.FileDescriptor
1617

@@ -33,13 +34,16 @@ class Poll(_message.Message):
3334
) -> None: ...
3435

3536
class PollResult(_message.Message):
36-
__slots__ = ("coroutine_state", "results")
37+
__slots__ = ("coroutine_state", "results", "error")
3738
COROUTINE_STATE_FIELD_NUMBER: _ClassVar[int]
3839
RESULTS_FIELD_NUMBER: _ClassVar[int]
40+
ERROR_FIELD_NUMBER: _ClassVar[int]
3941
coroutine_state: bytes
4042
results: _containers.RepeatedCompositeFieldContainer[_call_pb2.CallResult]
43+
error: _error_pb2.Error
4144
def __init__(
4245
self,
4346
coroutine_state: _Optional[bytes] = ...,
4447
results: _Optional[_Iterable[_Union[_call_pb2.CallResult, _Mapping]]] = ...,
48+
error: _Optional[_Union[_error_pb2.Error, _Mapping]] = ...,
4549
) -> None: ...

tests/dispatch/test_scheduler.py

Lines changed: 59 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import unittest
2-
from pprint import pprint
32
from typing import Any, Callable
43

54
from dispatch.coroutine import call, gather
@@ -221,6 +220,56 @@ async def main():
221220

222221
self.assertEqual(len(correlation_ids), 8)
223222

223+
def test_poll_error(self):
224+
# The purpose of the test is to ensure that when a poll error occurs,
225+
# we only abort the calls that were made on the previous yield. Any
226+
# other in-flight calls from previous yields are not affected.
227+
228+
@durable
229+
async def c_then_d():
230+
c_result = await call_one("c")
231+
try:
232+
# The poll error will affect this call only.
233+
d_result = await call_one("d")
234+
except RuntimeError as e:
235+
assert str(e) == "too many calls"
236+
d_result = 100
237+
return c_result + d_result
238+
239+
@durable
240+
async def main(c_then_d):
241+
return await gather(
242+
call_concurrently("a", "b"),
243+
c_then_d(),
244+
)
245+
246+
output = self.start(main, c_then_d)
247+
calls = self.assert_poll_call_functions(output, ["a", "b", "c"])
248+
249+
call_a, call_b, call_c = calls
250+
a_result, b_result, c_result = 10, 20, 30
251+
output = self.resume(
252+
main,
253+
output,
254+
[CallResult.from_value(c_result, correlation_id=call_c.correlation_id)],
255+
)
256+
self.assert_poll_call_functions(output, ["d"])
257+
258+
output = self.resume(
259+
main, output, [], poll_error=RuntimeError("too many calls")
260+
)
261+
self.assert_poll_call_functions(output, [])
262+
output = self.resume(
263+
main,
264+
output,
265+
[
266+
CallResult.from_value(a_result, correlation_id=call_a.correlation_id),
267+
CallResult.from_value(b_result, correlation_id=call_b.correlation_id),
268+
],
269+
)
270+
271+
self.assert_exit_result_value(output, [[a_result, b_result], c_result + 100])
272+
224273
def test_raise_indirect(self):
225274
@durable
226275
async def main():
@@ -234,11 +283,18 @@ def start(self, main: Callable, *args: Any, **kwargs: Any) -> Output:
234283
return OneShotScheduler(main).run(input)
235284

236285
def resume(
237-
self, main: Callable, prev_output: Output, call_results: list[CallResult]
286+
self,
287+
main: Callable,
288+
prev_output: Output,
289+
call_results: list[CallResult],
290+
poll_error: Exception | None = None,
238291
):
239292
poll = self.assert_poll(prev_output)
240293
input = Input.from_poll_results(
241-
main.__qualname__, poll.coroutine_state, call_results
294+
main.__qualname__,
295+
poll.coroutine_state,
296+
call_results,
297+
Error.from_exception(poll_error) if poll_error else None,
242298
)
243299
return OneShotScheduler(main).run(input)
244300

0 commit comments

Comments
 (0)