Skip to content

Commit 3ee1238

Browse files
committed
Replace 'sync' with 'cancellable' in wait/poll/yield
1 parent 00f31f2 commit 3ee1238

File tree

2 files changed

+291
-145
lines changed

2 files changed

+291
-145
lines changed

design/mvp/canonical-abi/definitions.py

Lines changed: 90 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,7 @@ class CanonicalOptions(LiftLowerOptions):
211211
post_return: Optional[Callable] = None
212212
sync: bool = True # = !canonopt.async
213213
callback: Optional[Callable] = None
214+
cancellable: bool = False
214215

215216
### Runtime State
216217

@@ -221,22 +222,17 @@ class CanonicalOptions(LiftLowerOptions):
221222
class ComponentInstance:
222223
table: Table
223224
may_leave: bool
224-
backpressure: bool
225-
calling_sync_export: bool
226-
calling_sync_import: bool
227-
pending_tasks: list[tuple[Task, asyncio.Future]]
228-
starting_pending_task: bool
229-
async_waiting_tasks: asyncio.Condition
225+
no_backpressure: asyncio.Event
226+
num_backpressure_waiters: int
227+
lock: asyncio.Lock
230228

231229
def __init__(self):
232230
self.table = Table()
233231
self.may_leave = True
234-
self.backpressure = False
235-
self.calling_sync_export = False
236-
self.calling_sync_import = False
237-
self.pending_tasks = []
238-
self.starting_pending_task = False
239-
self.async_waiting_tasks = asyncio.Condition(scheduler)
232+
self.no_backpressure = asyncio.Event()
233+
self.no_backpressure.set()
234+
self.num_backpressure_waiters = 0
235+
self.lock = asyncio.Lock()
240236

241237
#### Table State
242238

@@ -497,67 +493,64 @@ def __init__(self, opts, inst, ft, supertask, on_resolve, on_block):
497493
async def enter(self):
498494
assert(scheduler.locked())
499495
self.trap_if_on_the_stack(self.inst)
500-
if not self.may_enter(self) or self.inst.pending_tasks:
501-
f = asyncio.Future()
502-
self.inst.pending_tasks.append((self, f))
503-
if await self.on_block(f) == Cancelled.TRUE:
504-
[i] = [i for i,(t,_) in enumerate(self.inst.pending_tasks) if t == self]
505-
self.inst.pending_tasks.pop(i)
506-
self.on_resolve(None)
507-
return Cancelled.FALSE
508-
assert(self.may_enter(self) and self.inst.starting_pending_task)
509-
self.inst.starting_pending_task = False
510-
if self.opts.sync:
511-
self.inst.calling_sync_export = True
512-
return True
496+
if self.opts.sync or self.opts.callback:
497+
if self.inst.lock.locked():
498+
acquired = asyncio.create_task(self.inst.lock.acquire())
499+
cancelled = await self.wait_on(acquired, cancellable = True, for_callback = False)
500+
if cancelled:
501+
if acquired.done():
502+
self.inst.lock.release()
503+
else:
504+
acquired.cancel()
505+
return Cancelled.TRUE
506+
else:
507+
await self.inst.lock.acquire()
508+
if not self.inst.no_backpressure.is_set() or self.inst.num_backpressure_waiters > 0:
509+
while True:
510+
self.inst.num_backpressure_waiters += 1
511+
maybe_go = self.inst.no_backpressure.wait()
512+
cancelled = await self.wait_on(maybe_go, cancellable = True, for_callback = False)
513+
self.inst.num_backpressure_waiters -= 1
514+
if cancelled:
515+
return Cancelled.TRUE
516+
if self.inst.no_backpressure.is_set():
517+
break
518+
return Cancelled.FALSE
513519

514520
def trap_if_on_the_stack(self, inst):
515521
c = self.supertask
516522
while c is not None:
517523
trap_if(c.inst is inst)
518524
c = c.supertask
519525

520-
def may_enter(self, pending_task):
521-
return not self.inst.backpressure and \
522-
not self.inst.calling_sync_import and \
523-
not (self.inst.calling_sync_export and pending_task.opts.sync)
524-
525-
def maybe_start_pending_task(self):
526-
if self.inst.starting_pending_task:
527-
return
528-
for i,(pending_task,pending_future) in enumerate(self.inst.pending_tasks):
529-
if self.may_enter(pending_task):
530-
self.inst.pending_tasks.pop(i)
531-
self.inst.starting_pending_task = True
532-
pending_future.set_result(None)
533-
return
526+
async def wait_on(self, awaitable, cancellable = False, for_callback = False) -> Cancelled:
527+
f = asyncio.ensure_future(awaitable)
528+
if f.done() and not DETERMINISTIC_PROFILE and random.randint(0,1):
529+
return Cancelled.FALSE
534530

535-
async def wait_on(self, awaitable, sync, cancellable = False) -> bool:
536-
if sync:
537-
assert(not self.inst.calling_sync_import)
538-
self.inst.calling_sync_import = True
539-
else:
540-
self.maybe_start_pending_task()
531+
if for_callback:
532+
self.inst.lock.release()
541533

542-
awaitable = asyncio.ensure_future(awaitable)
543-
if awaitable.done() and not DETERMINISTIC_PROFILE and random.randint(0,1):
544-
cancelled = Cancelled.FALSE
545-
else:
546-
cancelled = await self.on_block(awaitable)
547-
if cancelled and not cancellable:
548-
assert(self.state == Task.State.INITIAL)
549-
self.state = Task.State.PENDING_CANCEL
550-
cancelled = await self.on_block(awaitable)
551-
assert(not cancelled)
534+
cancelled = await self.on_block(f)
535+
if cancelled and not cancellable:
536+
assert(await self.on_block(f) == Cancelled.FALSE)
552537

553-
if sync:
554-
self.inst.calling_sync_import = False
555-
self.inst.async_waiting_tasks.notify_all()
556-
else:
557-
while self.inst.calling_sync_import:
558-
await self.inst.async_waiting_tasks.wait()
538+
if for_callback:
539+
acquired = asyncio.create_task(self.inst.lock.acquire())
540+
cancelled |= await self.on_block(acquired)
541+
if cancelled:
542+
assert(self.on_block(acquired) == Cancelled.FALSE)
559543

560-
return cancelled
544+
if cancelled:
545+
assert(self.state == Task.State.INITIAL)
546+
if not cancellable:
547+
self.state = Task.State.PENDING_CANCEL
548+
return Cancelled.FALSE
549+
else:
550+
self.state = Task.State.CANCEL_DELIVERED
551+
return Cancelled.TRUE
552+
else:
553+
return Cancelled.FALSE
561554

562555
async def call_sync(self, callee, on_start, on_return):
563556
async def sync_on_block(awaitable):
@@ -567,42 +560,36 @@ async def sync_on_block(awaitable):
567560
assert(await self.on_block(awaitable) == Cancelled.FALSE)
568561
return Cancelled.FALSE
569562

570-
assert(not self.inst.calling_sync_import)
571-
self.inst.calling_sync_import = True
572563
await callee(self, on_start, on_return, sync_on_block)
573-
self.inst.calling_sync_import = False
574-
self.inst.async_waiting_tasks.notify_all()
575564

576-
async def wait_for_event(self, waitable_set, sync) -> EventTuple:
577-
if self.state == Task.State.PENDING_CANCEL:
565+
async def wait_for_event(self, waitable_set, cancellable, for_callback) -> EventTuple:
566+
if self.state == Task.State.PENDING_CANCEL and cancellable:
578567
self.state = Task.State.CANCEL_DELIVERED
579568
return (EventCode.TASK_CANCELLED, 0, 0)
580569
else:
581570
waitable_set.num_waiting += 1
582571
e = None
583572
while not e:
584573
maybe_event = waitable_set.maybe_has_pending_event.wait()
585-
if await self.wait_on(maybe_event, sync, cancellable = True):
586-
assert(self.state == Task.State.INITIAL)
587-
self.state = Task.State.CANCEL_DELIVERED
574+
if await self.wait_on(maybe_event, cancellable, for_callback) == Cancelled.TRUE:
588575
return (EventCode.TASK_CANCELLED, 0, 0)
589576
e = waitable_set.poll()
590577
waitable_set.num_waiting -= 1
591578
return e
592579

593-
async def yield_(self, sync) -> EventTuple:
594-
if self.state == Task.State.PENDING_CANCEL:
580+
async def yield_(self, cancellable, for_callback) -> EventTuple:
581+
if self.state == Task.State.PENDING_CANCEL and for_callback:
595582
self.state = Task.State.CANCEL_DELIVERED
596583
return (EventCode.TASK_CANCELLED, 0, 0)
597-
elif await self.wait_on(asyncio.sleep(0), sync, cancellable = True):
598-
assert(self.state == Task.State.INITIAL)
599-
self.state = Task.State.CANCEL_DELIVERED
584+
elif await self.wait_on(asyncio.sleep(0), cancellable, for_callback) == Cancelled.TRUE:
600585
return (EventCode.TASK_CANCELLED, 0, 0)
601586
else:
602587
return (EventCode.NONE, 0, 0)
603588

604-
async def poll_for_event(self, waitable_set, sync) -> Optional[EventTuple]:
605-
event_code,_,_ = e = await self.yield_(sync)
589+
async def poll_for_event(self, waitable_set, cancellable, for_callback) -> Optional[EventTuple]:
590+
waitable_set.num_waiting += 1
591+
event_code,_,_ = e = await self.yield_(cancellable, for_callback)
592+
waitable_set.num_waiting -= 1
606593
if event_code == EventCode.TASK_CANCELLED:
607594
return e
608595
elif (e := waitable_set.poll()):
@@ -624,13 +611,10 @@ def cancel(self):
624611
self.state = Task.State.RESOLVED
625612

626613
def exit(self):
627-
assert(scheduler.locked())
628614
trap_if(self.state != Task.State.RESOLVED)
629615
assert(self.num_borrows == 0)
630-
if self.opts.sync:
631-
assert(self.inst.calling_sync_export)
632-
self.inst.calling_sync_export = False
633-
self.maybe_start_pending_task()
616+
if self.opts.sync or self.opts.callback:
617+
self.inst.lock.release()
634618

635619
#### Subtask State
636620

@@ -1932,7 +1916,9 @@ def lower_flat_values(cx, max_flat, vs, ts, out_param = None):
19321916

19331917
async def canon_lift(opts, inst, ft, callee, caller, on_start, on_resolve, on_block):
19341918
task = Task(opts, inst, ft, caller, on_resolve, on_block)
1935-
if not await task.enter():
1919+
if await task.enter() == Cancelled.TRUE:
1920+
task.cancel()
1921+
task.exit()
19361922
return
19371923

19381924
cx = LiftLowerContext(opts, inst, task)
@@ -1967,15 +1953,15 @@ async def canon_lift(opts, inst, ft, callee, caller, on_start, on_resolve, on_bl
19671953
task.exit()
19681954
return
19691955
case CallbackCode.YIELD:
1970-
e = await task.yield_(sync = False)
1956+
e = await task.yield_(cancellable = True, for_callback = True)
19711957
case CallbackCode.WAIT:
19721958
s = task.inst.table.get(si)
19731959
trap_if(not isinstance(s, WaitableSet))
1974-
e = await task.wait_for_event(s, sync = False)
1960+
e = await task.wait_for_event(s, cancellable = True, for_callback = True)
19751961
case CallbackCode.POLL:
19761962
s = task.inst.table.get(si)
19771963
trap_if(not isinstance(s, WaitableSet))
1978-
e = await task.poll_for_event(s, sync = False)
1964+
e = await task.poll_for_event(s, cancellable = True, for_callback = True)
19791965
event_code, p1, p2 = e
19801966
[packed] = await call_and_trap_on_throw(opts.callback, task, [event_code, p1, p2])
19811967

@@ -2114,8 +2100,11 @@ async def canon_context_set(t, i, task, v):
21142100
### 🔀 `canon backpressure.set`
21152101

21162102
async def canon_backpressure_set(task, flat_args):
2117-
trap_if(task.opts.sync)
2118-
task.inst.backpressure = bool(flat_args[0])
2103+
assert(len(flat_args) == 1)
2104+
if flat_args[0] == 0:
2105+
task.inst.no_backpressure.set()
2106+
else:
2107+
task.inst.no_backpressure.clear()
21192108
return []
21202109

21212110
### 🔀 `canon task.return`
@@ -2140,9 +2129,9 @@ async def canon_task_cancel(task):
21402129

21412130
### 🔀 `canon yield`
21422131

2143-
async def canon_yield(sync, task):
2132+
async def canon_yield(opts, task):
21442133
trap_if(not task.inst.may_leave)
2145-
event_code,_,_ = await task.yield_(sync)
2134+
event_code,_,_ = await task.yield_(opts.cancellable, for_callback = False)
21462135
match event_code:
21472136
case EventCode.NONE:
21482137
return [0]
@@ -2157,12 +2146,12 @@ async def canon_waitable_set_new(task):
21572146

21582147
### 🔀 `canon waitable-set.wait`
21592148

2160-
async def canon_waitable_set_wait(sync, mem, task, si, ptr):
2149+
async def canon_waitable_set_wait(opts, task, si, ptr):
21612150
trap_if(not task.inst.may_leave)
21622151
s = task.inst.table.get(si)
21632152
trap_if(not isinstance(s, WaitableSet))
2164-
e = await task.wait_for_event(s, sync)
2165-
return unpack_event(mem, task, ptr, e)
2153+
e = await task.wait_for_event(s, opts.cancellable, for_callback = False)
2154+
return unpack_event(opts.memory, task, ptr, e)
21662155

21672156
def unpack_event(mem, task, ptr, e: EventTuple):
21682157
event, p1, p2 = e
@@ -2173,12 +2162,12 @@ def unpack_event(mem, task, ptr, e: EventTuple):
21732162

21742163
### 🔀 `canon waitable-set.poll`
21752164

2176-
async def canon_waitable_set_poll(sync, mem, task, si, ptr):
2165+
async def canon_waitable_set_poll(opts, task, si, ptr):
21772166
trap_if(not task.inst.may_leave)
21782167
s = task.inst.table.get(si)
21792168
trap_if(not isinstance(s, WaitableSet))
2180-
e = await task.poll_for_event(s, sync)
2181-
return unpack_event(mem, task, ptr, e)
2169+
e = await task.poll_for_event(s, opts.cancellable, for_callback = False)
2170+
return unpack_event(opts.memory, task, ptr, e)
21822171

21832172
### 🔀 `canon waitable-set.drop`
21842173

@@ -2220,7 +2209,7 @@ async def canon_subtask_cancel(sync, task, i):
22202209
while not subtask.resolved():
22212210
if subtask.has_pending_event():
22222211
_ = subtask.get_event()
2223-
await task.wait_on(subtask.wait_for_pending_event(), sync = True)
2212+
await task.wait_on(subtask.wait_for_pending_event())
22242213
else:
22252214
if not subtask.resolved():
22262215
return [BLOCKED]
@@ -2296,7 +2285,7 @@ def on_copy_done(result):
22962285
e.copy(task.inst, buffer, on_copy, on_copy_done)
22972286

22982287
if opts.sync and not e.has_pending_event():
2299-
await task.wait_on(e.wait_for_pending_event(), sync = True)
2288+
await task.wait_on(e.wait_for_pending_event())
23002289

23012290
if e.has_pending_event():
23022291
code,index,payload = e.get_event()
@@ -2342,7 +2331,7 @@ def on_copy_done(result):
23422331
e.copy(task.inst, buffer, on_copy_done)
23432332

23442333
if opts.sync and not e.has_pending_event():
2345-
await task.wait_on(e.wait_for_pending_event(), sync = True)
2334+
await task.wait_on(e.wait_for_pending_event())
23462335

23472336
if e.has_pending_event():
23482337
code,index,payload = e.get_event()
@@ -2375,7 +2364,7 @@ async def cancel_copy(EndT, event_code, stream_or_future_t, sync, task, i):
23752364
e.shared.cancel()
23762365
if not e.has_pending_event():
23772366
if sync:
2378-
await task.wait_on(e.wait_for_pending_event(), sync = True)
2367+
await task.wait_on(e.wait_for_pending_event())
23792368
else:
23802369
return [BLOCKED]
23812370
code,index,payload = e.get_event()

0 commit comments

Comments
 (0)