Skip to content

Commit 47e78c6

Browse files
committed
Refactor CABI: move common code from Subtask to canon_lower
1 parent 03cd36a commit 47e78c6

File tree

1 file changed

+55
-72
lines changed

1 file changed

+55
-72
lines changed

design/mvp/canonical-abi/definitions.py

Lines changed: 55 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -626,36 +626,19 @@ class State(IntEnum):
626626
CANCELLED_BEFORE_RETURNED = 4
627627

628628
state: State
629-
supertask: Optional[Task]
629+
task: Task
630630
lenders: Optional[list[ResourceHandle]]
631631
request_cancel_begin: asyncio.Future
632632
request_cancel_end: asyncio.Future
633633

634-
def __init__(self, supertask):
634+
def __init__(self, task):
635635
Waitable.__init__(self)
636636
self.state = Subtask.State.STARTING
637-
self.supertask = supertask
637+
self.task = task
638638
self.lenders = []
639639
self.request_cancel_begin = asyncio.Future()
640640
self.request_cancel_end = asyncio.Future()
641641

642-
async def call_sync(self, callee, on_start, on_resolve):
643-
def sync_on_start():
644-
assert(self.state == Subtask.State.STARTING)
645-
self.state = Subtask.State.STARTED
646-
return on_start()
647-
648-
def sync_on_resolve(result):
649-
assert(result is not None)
650-
assert(self.state == Subtask.State.STARTED)
651-
self.state = Subtask.State.RETURNED
652-
on_resolve(result)
653-
654-
await Task.call_sync(self.supertask, callee, sync_on_start, sync_on_resolve)
655-
656-
def cancelled(self):
657-
return self.request_cancel_begin.done()
658-
659642
def resolved(self):
660643
match self.state:
661644
case (Subtask.State.STARTING |
@@ -667,32 +650,18 @@ def resolved(self):
667650
return True
668651

669652
async def request_cancel(self):
670-
assert(not self.cancelled() and not self.resolved())
653+
assert(not self.cancellation_requested() and not self.resolved())
671654
self.request_cancel_begin.set_result(None)
672655
await self.request_cancel_end
673656

657+
def cancellation_requested(self):
658+
return self.request_cancel_begin.done()
659+
674660
async def call_async(self, callee, on_start, on_resolve):
675661
async def do_call():
676-
await callee(self.supertask, async_on_start, async_on_resolve, async_on_block)
662+
await callee(self.task, on_start, on_resolve, async_on_block)
677663
relinquish_control()
678664

679-
def async_on_start():
680-
assert(self.state == Subtask.State.STARTING)
681-
self.state = Subtask.State.STARTED
682-
return on_start()
683-
684-
def async_on_resolve(result):
685-
if result is None:
686-
if self.state == Subtask.State.STARTING:
687-
self.state = Subtask.State.CANCELLED_BEFORE_STARTED
688-
else:
689-
assert(self.state == Subtask.State.STARTED)
690-
self.state = Subtask.State.CANCELLED_BEFORE_RETURNED
691-
else:
692-
assert(self.state == Subtask.State.STARTED)
693-
self.state = Subtask.State.RETURNED
694-
on_resolve(result)
695-
696665
async def async_on_block(awaitable):
697666
relinquish_control()
698667
if not self.request_cancel_end.done():
@@ -1990,52 +1959,65 @@ async def call_and_trap_on_throw(callee, task, args):
19901959
async def canon_lower(opts, ft, callee, task, flat_args):
19911960
trap_if(not task.inst.may_leave)
19921961
subtask = Subtask(task)
1962+
19931963
cx = LiftLowerContext(opts, task.inst, subtask)
19941964
flat_ft = flatten_functype(opts, ft, 'lower')
19951965
assert(types_match_values(flat_ft.params, flat_args))
19961966
flat_args = CoreValueIter(flat_args)
19971967

19981968
if opts.sync:
1999-
def on_start():
2000-
return lift_flat_values(cx, MAX_FLAT_PARAMS, flat_args, ft.param_types())
2001-
2002-
flat_results = None
2003-
def on_resolve(result):
2004-
nonlocal flat_results
2005-
flat_results = lower_flat_values(cx, MAX_FLAT_RESULTS, result, ft.result_type(), flat_args)
1969+
max_flat_params = MAX_FLAT_PARAMS
1970+
max_flat_results = MAX_FLAT_RESULTS
1971+
else:
1972+
max_flat_params = MAX_FLAT_ASYNC_PARAMS
1973+
max_flat_results = 0
20061974

2007-
await subtask.call_sync(callee, on_start, on_resolve)
2008-
assert(types_match_values(flat_ft.results, flat_results))
2009-
subtask.deliver_resolve()
2010-
return flat_results
1975+
on_progress = lambda:()
1976+
flat_results = None
20111977

20121978
def on_start():
20131979
on_progress()
2014-
return lift_flat_values(cx, MAX_FLAT_ASYNC_PARAMS, flat_args, ft.param_types())
1980+
assert(subtask.state == Subtask.State.STARTING)
1981+
subtask.state = Subtask.State.STARTED
1982+
return lift_flat_values(cx, max_flat_params, flat_args, ft.param_types())
20151983

20161984
def on_resolve(result):
20171985
on_progress()
2018-
if result is not None:
2019-
[] = lower_flat_values(cx, 0, result, ft.result_type(), flat_args)
2020-
2021-
subtaski = None
2022-
def on_progress():
2023-
if subtaski is not None:
2024-
def subtask_event():
2025-
if subtask.resolved():
2026-
subtask.deliver_resolve()
2027-
return (EventCode.SUBTASK, subtaski, subtask.state)
2028-
subtask.set_event(subtask_event)
2029-
2030-
await subtask.call_async(callee, on_start, on_resolve)
2031-
if subtask.resolved():
2032-
subtask.deliver_resolve()
2033-
return [Subtask.State.RETURNED]
1986+
if result is None:
1987+
assert(subtask.cancellation_requested())
1988+
if subtask.state == Subtask.State.STARTING:
1989+
subtask.state = Subtask.State.CANCELLED_BEFORE_STARTED
1990+
else:
1991+
assert(subtask.state == Subtask.State.STARTED)
1992+
subtask.state = Subtask.State.CANCELLED_BEFORE_RETURNED
1993+
else:
1994+
assert(subtask.state == Subtask.State.STARTED)
1995+
subtask.state = Subtask.State.RETURNED
1996+
nonlocal flat_results
1997+
flat_results = lower_flat_values(cx, max_flat_results, result, ft.result_type(), flat_args)
20341998

2035-
subtaski = task.inst.table.add(subtask)
2036-
assert(0 < subtaski <= Table.MAX_LENGTH < 2**28)
2037-
assert(0 <= subtask.state < 2**4)
2038-
return [subtask.state | (subtaski << 4)]
1999+
if opts.sync:
2000+
await task.call_sync(callee, on_start, on_resolve)
2001+
assert(types_match_values(flat_ft.results, flat_results))
2002+
subtask.deliver_resolve()
2003+
return flat_results
2004+
else:
2005+
await subtask.call_async(callee, on_start, on_resolve)
2006+
if subtask.resolved():
2007+
assert(flat_results == [])
2008+
subtask.deliver_resolve()
2009+
return [Subtask.State.RETURNED]
2010+
else:
2011+
subtaski = task.inst.table.add(subtask)
2012+
def on_progress():
2013+
def subtask_event():
2014+
if subtask.resolved():
2015+
subtask.deliver_resolve()
2016+
return (EventCode.SUBTASK, subtaski, subtask.state)
2017+
subtask.set_event(subtask_event)
2018+
assert(0 < subtaski <= Table.MAX_LENGTH < 2**28)
2019+
assert(0 <= subtask.state < 2**4)
2020+
return [subtask.state | (subtaski << 4)]
20392021

20402022
### `canon resource.new`
20412023

@@ -2199,7 +2181,8 @@ async def canon_subtask_cancel(sync, task, i):
21992181
trap_if(not task.inst.may_leave)
22002182
subtask = task.inst.table.get(i)
22012183
trap_if(not isinstance(subtask, Subtask))
2202-
trap_if(subtask.resolve_delivered() or subtask.cancelled())
2184+
trap_if(subtask.resolve_delivered())
2185+
trap_if(subtask.cancellation_requested())
22032186
if subtask.resolved():
22042187
assert(subtask.has_pending_event())
22052188
else:

0 commit comments

Comments
 (0)