Skip to content

Commit 4a6e107

Browse files
committed
Replace 'sync' with 'cancellable' in wait/poll/yield
1 parent 00f31f2 commit 4a6e107

File tree

2 files changed

+288
-144
lines changed

2 files changed

+288
-144
lines changed

design/mvp/canonical-abi/definitions.py

Lines changed: 87 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,7 @@ class CanonicalOptions(LiftLowerOptions):
211211
post_return: Optional[Callable] = None
212212
sync: bool = True # = !canonopt.async
213213
callback: Optional[Callable] = None
214+
cancellable: bool = False
214215

215216
### Runtime State
216217

@@ -221,22 +222,17 @@ class CanonicalOptions(LiftLowerOptions):
221222
class ComponentInstance:
222223
table: Table
223224
may_leave: bool
224-
backpressure: bool
225-
calling_sync_export: bool
226-
calling_sync_import: bool
227-
pending_tasks: list[tuple[Task, asyncio.Future]]
228-
starting_pending_task: bool
229-
async_waiting_tasks: asyncio.Condition
225+
no_backpressure: asyncio.Event
226+
num_backpressure_waiters: int
227+
lock: asyncio.Lock
230228

231229
def __init__(self):
232230
self.table = Table()
233231
self.may_leave = True
234-
self.backpressure = False
235-
self.calling_sync_export = False
236-
self.calling_sync_import = False
237-
self.pending_tasks = []
238-
self.starting_pending_task = False
239-
self.async_waiting_tasks = asyncio.Condition(scheduler)
232+
self.no_backpressure = asyncio.Event()
233+
self.no_backpressure.set()
234+
self.num_backpressure_waiters = 0
235+
self.lock = asyncio.Lock()
240236

241237
#### Table State
242238

@@ -497,67 +493,64 @@ def __init__(self, opts, inst, ft, supertask, on_resolve, on_block):
497493
async def enter(self):
498494
assert(scheduler.locked())
499495
self.trap_if_on_the_stack(self.inst)
500-
if not self.may_enter(self) or self.inst.pending_tasks:
501-
f = asyncio.Future()
502-
self.inst.pending_tasks.append((self, f))
503-
if await self.on_block(f) == Cancelled.TRUE:
504-
[i] = [i for i,(t,_) in enumerate(self.inst.pending_tasks) if t == self]
505-
self.inst.pending_tasks.pop(i)
496+
if self.opts.sync or self.opts.callback:
497+
if await self.acquire_instance_lock() == Cancelled.TRUE:
506498
self.on_resolve(None)
507-
return Cancelled.FALSE
508-
assert(self.may_enter(self) and self.inst.starting_pending_task)
509-
self.inst.starting_pending_task = False
510-
if self.opts.sync:
511-
self.inst.calling_sync_export = True
512-
return True
499+
return Cancelled.TRUE
500+
if not self.inst.no_backpressure.is_set() or self.inst.num_backpressure_waiters > 0:
501+
while True:
502+
self.inst.num_backpressure_waiters += 1
503+
cancelled = await self.wait_on(self.inst.no_backpressure.wait(),
504+
cancellable = True, for_callback = False)
505+
self.inst.num_backpressure_waiters -= 1
506+
if cancelled:
507+
self.on_resolve(None)
508+
return Cancelled.TRUE
509+
if self.inst.no_backpressure.is_set():
510+
break
511+
return Cancelled.FALSE
513512

514513
def trap_if_on_the_stack(self, inst):
515514
c = self.supertask
516515
while c is not None:
517516
trap_if(c.inst is inst)
518517
c = c.supertask
519518

520-
def may_enter(self, pending_task):
521-
return not self.inst.backpressure and \
522-
not self.inst.calling_sync_import and \
523-
not (self.inst.calling_sync_export and pending_task.opts.sync)
524-
525-
def maybe_start_pending_task(self):
526-
if self.inst.starting_pending_task:
527-
return
528-
for i,(pending_task,pending_future) in enumerate(self.inst.pending_tasks):
529-
if self.may_enter(pending_task):
530-
self.inst.pending_tasks.pop(i)
531-
self.inst.starting_pending_task = True
532-
pending_future.set_result(None)
533-
return
534-
535-
async def wait_on(self, awaitable, sync, cancellable = False) -> bool:
536-
if sync:
537-
assert(not self.inst.calling_sync_import)
538-
self.inst.calling_sync_import = True
519+
async def acquire_instance_lock(self) -> Cancelled:
520+
if not self.inst.lock.locked():
521+
await self.inst.lock.acquire()
539522
else:
540-
self.maybe_start_pending_task()
523+
acquired = asyncio.create_task(self.inst.lock.acquire())
524+
if await self.on_block(acquired) == Cancelled.TRUE:
525+
assert(self.on_block(acquired) == Cancelled.FALSE)
526+
return Cancelled.TRUE
527+
return Cancelled.FALSE
528+
529+
async def wait_on(self, awaitable, cancellable = False, for_callback = False) -> Cancelled:
530+
f = asyncio.ensure_future(awaitable)
531+
if f.done() and not DETERMINISTIC_PROFILE and random.randint(0,1):
532+
return Cancelled.FALSE
541533

542-
awaitable = asyncio.ensure_future(awaitable)
543-
if awaitable.done() and not DETERMINISTIC_PROFILE and random.randint(0,1):
544-
cancelled = Cancelled.FALSE
545-
else:
546-
cancelled = await self.on_block(awaitable)
547-
if cancelled and not cancellable:
548-
assert(self.state == Task.State.INITIAL)
549-
self.state = Task.State.PENDING_CANCEL
550-
cancelled = await self.on_block(awaitable)
551-
assert(not cancelled)
534+
if for_callback:
535+
self.inst.lock.release()
552536

553-
if sync:
554-
self.inst.calling_sync_import = False
555-
self.inst.async_waiting_tasks.notify_all()
556-
else:
557-
while self.inst.calling_sync_import:
558-
await self.inst.async_waiting_tasks.wait()
537+
cancelled = await self.on_block(f)
538+
if cancelled and not cancellable:
539+
assert(await self.on_block(f) == Cancelled.FALSE)
559540

560-
return cancelled
541+
if for_callback:
542+
cancelled |= await self.acquire_instance_lock()
543+
544+
if cancelled:
545+
assert(self.state == Task.State.INITIAL)
546+
if not cancellable:
547+
self.state = Task.State.PENDING_CANCEL
548+
return Cancelled.FALSE
549+
else:
550+
self.state = Task.State.CANCEL_DELIVERED
551+
return Cancelled.TRUE
552+
else:
553+
return Cancelled.FALSE
561554

562555
async def call_sync(self, callee, on_start, on_return):
563556
async def sync_on_block(awaitable):
@@ -567,42 +560,36 @@ async def sync_on_block(awaitable):
567560
assert(await self.on_block(awaitable) == Cancelled.FALSE)
568561
return Cancelled.FALSE
569562

570-
assert(not self.inst.calling_sync_import)
571-
self.inst.calling_sync_import = True
572563
await callee(self, on_start, on_return, sync_on_block)
573-
self.inst.calling_sync_import = False
574-
self.inst.async_waiting_tasks.notify_all()
575564

576-
async def wait_for_event(self, waitable_set, sync) -> EventTuple:
577-
if self.state == Task.State.PENDING_CANCEL:
565+
async def wait_for_event(self, waitable_set, cancellable, for_callback) -> EventTuple:
566+
if self.state == Task.State.PENDING_CANCEL and cancellable:
578567
self.state = Task.State.CANCEL_DELIVERED
579568
return (EventCode.TASK_CANCELLED, 0, 0)
580569
else:
581570
waitable_set.num_waiting += 1
582571
e = None
583572
while not e:
584573
maybe_event = waitable_set.maybe_has_pending_event.wait()
585-
if await self.wait_on(maybe_event, sync, cancellable = True):
586-
assert(self.state == Task.State.INITIAL)
587-
self.state = Task.State.CANCEL_DELIVERED
574+
if await self.wait_on(maybe_event, cancellable, for_callback) == Cancelled.TRUE:
588575
return (EventCode.TASK_CANCELLED, 0, 0)
589576
e = waitable_set.poll()
590577
waitable_set.num_waiting -= 1
591578
return e
592579

593-
async def yield_(self, sync) -> EventTuple:
594-
if self.state == Task.State.PENDING_CANCEL:
580+
async def yield_(self, cancellable, for_callback) -> EventTuple:
581+
if self.state == Task.State.PENDING_CANCEL and for_callback:
595582
self.state = Task.State.CANCEL_DELIVERED
596583
return (EventCode.TASK_CANCELLED, 0, 0)
597-
elif await self.wait_on(asyncio.sleep(0), sync, cancellable = True):
598-
assert(self.state == Task.State.INITIAL)
599-
self.state = Task.State.CANCEL_DELIVERED
584+
elif await self.wait_on(asyncio.sleep(0), cancellable, for_callback) == Cancelled.TRUE:
600585
return (EventCode.TASK_CANCELLED, 0, 0)
601586
else:
602587
return (EventCode.NONE, 0, 0)
603588

604-
async def poll_for_event(self, waitable_set, sync) -> Optional[EventTuple]:
605-
event_code,_,_ = e = await self.yield_(sync)
589+
async def poll_for_event(self, waitable_set, cancellable, for_callback) -> Optional[EventTuple]:
590+
waitable_set.num_waiting += 1
591+
event_code,_,_ = e = await self.yield_(cancellable, for_callback)
592+
waitable_set.num_waiting -= 1
606593
if event_code == EventCode.TASK_CANCELLED:
607594
return e
608595
elif (e := waitable_set.poll()):
@@ -624,13 +611,10 @@ def cancel(self):
624611
self.state = Task.State.RESOLVED
625612

626613
def exit(self):
627-
assert(scheduler.locked())
628614
trap_if(self.state != Task.State.RESOLVED)
629615
assert(self.num_borrows == 0)
630-
if self.opts.sync:
631-
assert(self.inst.calling_sync_export)
632-
self.inst.calling_sync_export = False
633-
self.maybe_start_pending_task()
616+
if self.opts.sync or self.opts.callback:
617+
self.inst.lock.release()
634618

635619
#### Subtask State
636620

@@ -1932,7 +1916,7 @@ def lower_flat_values(cx, max_flat, vs, ts, out_param = None):
19321916

19331917
async def canon_lift(opts, inst, ft, callee, caller, on_start, on_resolve, on_block):
19341918
task = Task(opts, inst, ft, caller, on_resolve, on_block)
1935-
if not await task.enter():
1919+
if await task.enter() == Cancelled.TRUE:
19361920
return
19371921

19381922
cx = LiftLowerContext(opts, inst, task)
@@ -1967,15 +1951,15 @@ async def canon_lift(opts, inst, ft, callee, caller, on_start, on_resolve, on_bl
19671951
task.exit()
19681952
return
19691953
case CallbackCode.YIELD:
1970-
e = await task.yield_(sync = False)
1954+
e = await task.yield_(cancellable = True, for_callback = True)
19711955
case CallbackCode.WAIT:
19721956
s = task.inst.table.get(si)
19731957
trap_if(not isinstance(s, WaitableSet))
1974-
e = await task.wait_for_event(s, sync = False)
1958+
e = await task.wait_for_event(s, cancellable = True, for_callback = True)
19751959
case CallbackCode.POLL:
19761960
s = task.inst.table.get(si)
19771961
trap_if(not isinstance(s, WaitableSet))
1978-
e = await task.poll_for_event(s, sync = False)
1962+
e = await task.poll_for_event(s, cancellable = True, for_callback = True)
19791963
event_code, p1, p2 = e
19801964
[packed] = await call_and_trap_on_throw(opts.callback, task, [event_code, p1, p2])
19811965

@@ -2114,8 +2098,11 @@ async def canon_context_set(t, i, task, v):
21142098
### 🔀 `canon backpressure.set`
21152099

21162100
async def canon_backpressure_set(task, flat_args):
2117-
trap_if(task.opts.sync)
2118-
task.inst.backpressure = bool(flat_args[0])
2101+
assert(len(flat_args) == 1)
2102+
if flat_args[0] == 0:
2103+
task.inst.no_backpressure.set()
2104+
else:
2105+
task.inst.no_backpressure.clear()
21192106
return []
21202107

21212108
### 🔀 `canon task.return`
@@ -2140,9 +2127,9 @@ async def canon_task_cancel(task):
21402127

21412128
### 🔀 `canon yield`
21422129

2143-
async def canon_yield(sync, task):
2130+
async def canon_yield(opts, task):
21442131
trap_if(not task.inst.may_leave)
2145-
event_code,_,_ = await task.yield_(sync)
2132+
event_code,_,_ = await task.yield_(opts.cancellable, for_callback = False)
21462133
match event_code:
21472134
case EventCode.NONE:
21482135
return [0]
@@ -2157,12 +2144,12 @@ async def canon_waitable_set_new(task):
21572144

21582145
### 🔀 `canon waitable-set.wait`
21592146

2160-
async def canon_waitable_set_wait(sync, mem, task, si, ptr):
2147+
async def canon_waitable_set_wait(opts, task, si, ptr):
21612148
trap_if(not task.inst.may_leave)
21622149
s = task.inst.table.get(si)
21632150
trap_if(not isinstance(s, WaitableSet))
2164-
e = await task.wait_for_event(s, sync)
2165-
return unpack_event(mem, task, ptr, e)
2151+
e = await task.wait_for_event(s, opts.cancellable, for_callback = False)
2152+
return unpack_event(opts.memory, task, ptr, e)
21662153

21672154
def unpack_event(mem, task, ptr, e: EventTuple):
21682155
event, p1, p2 = e
@@ -2173,12 +2160,12 @@ def unpack_event(mem, task, ptr, e: EventTuple):
21732160

21742161
### 🔀 `canon waitable-set.poll`
21752162

2176-
async def canon_waitable_set_poll(sync, mem, task, si, ptr):
2163+
async def canon_waitable_set_poll(opts, task, si, ptr):
21772164
trap_if(not task.inst.may_leave)
21782165
s = task.inst.table.get(si)
21792166
trap_if(not isinstance(s, WaitableSet))
2180-
e = await task.poll_for_event(s, sync)
2181-
return unpack_event(mem, task, ptr, e)
2167+
e = await task.poll_for_event(s, opts.cancellable, for_callback = False)
2168+
return unpack_event(opts.memory, task, ptr, e)
21822169

21832170
### 🔀 `canon waitable-set.drop`
21842171

@@ -2220,7 +2207,7 @@ async def canon_subtask_cancel(sync, task, i):
22202207
while not subtask.resolved():
22212208
if subtask.has_pending_event():
22222209
_ = subtask.get_event()
2223-
await task.wait_on(subtask.wait_for_pending_event(), sync = True)
2210+
await task.wait_on(subtask.wait_for_pending_event())
22242211
else:
22252212
if not subtask.resolved():
22262213
return [BLOCKED]
@@ -2296,7 +2283,7 @@ def on_copy_done(result):
22962283
e.copy(task.inst, buffer, on_copy, on_copy_done)
22972284

22982285
if opts.sync and not e.has_pending_event():
2299-
await task.wait_on(e.wait_for_pending_event(), sync = True)
2286+
await task.wait_on(e.wait_for_pending_event())
23002287

23012288
if e.has_pending_event():
23022289
code,index,payload = e.get_event()
@@ -2342,7 +2329,7 @@ def on_copy_done(result):
23422329
e.copy(task.inst, buffer, on_copy_done)
23432330

23442331
if opts.sync and not e.has_pending_event():
2345-
await task.wait_on(e.wait_for_pending_event(), sync = True)
2332+
await task.wait_on(e.wait_for_pending_event())
23462333

23472334
if e.has_pending_event():
23482335
code,index,payload = e.get_event()
@@ -2375,7 +2362,7 @@ async def cancel_copy(EndT, event_code, stream_or_future_t, sync, task, i):
23752362
e.shared.cancel()
23762363
if not e.has_pending_event():
23772364
if sync:
2378-
await task.wait_on(e.wait_for_pending_event(), sync = True)
2365+
await task.wait_on(e.wait_for_pending_event())
23792366
else:
23802367
return [BLOCKED]
23812368
code,index,payload = e.get_event()

0 commit comments

Comments
 (0)