@@ -221,22 +221,17 @@ class CanonicalOptions(LiftLowerOptions):
221221class ComponentInstance :
222222 table : Table
223223 may_leave : bool
224- backpressure : bool
225- calling_sync_export : bool
226- calling_sync_import : bool
227- pending_tasks : list [tuple [Task , asyncio .Future ]]
228- starting_pending_task : bool
229- async_waiting_tasks : asyncio .Condition
224+ no_backpressure : asyncio .Event
225+ num_backpressure_waiters : int
226+ lock : asyncio .Lock
230227
231228 def __init__ (self ):
232229 self .table = Table ()
233230 self .may_leave = True
234- self .backpressure = False
235- self .calling_sync_export = False
236- self .calling_sync_import = False
237- self .pending_tasks = []
238- self .starting_pending_task = False
239- self .async_waiting_tasks = asyncio .Condition (scheduler )
231+ self .no_backpressure = asyncio .Event ()
232+ self .no_backpressure .set ()
233+ self .num_backpressure_waiters = 0
234+ self .lock = asyncio .Lock ()
240235
241236#### Table State
242237
@@ -464,7 +459,7 @@ class Cancelled(IntEnum):
464459
465460OnStart = Callable [[], list [any ]]
466461OnResolve = Callable [[Optional [list [any ]]], None ]
467- OnBlock = Callable [[Awaitable ], Awaitable [Cancelled ]]
462+ OnBlock = Callable [[asyncio . Future ], Awaitable [Cancelled ]]
468463
469464class Task :
470465 class State (Enum ):
@@ -494,70 +489,65 @@ def __init__(self, opts, inst, ft, supertask, on_resolve, on_block):
494489 self .num_borrows = 0
495490 self .context = ContextLocalStorage ()
496491
497- async def enter (self ):
498- assert (scheduler .locked ())
499- self .trap_if_on_the_stack (self .inst )
500- if not self .may_enter (self ) or self .inst .pending_tasks :
501- f = asyncio .Future ()
502- self .inst .pending_tasks .append ((self , f ))
503- if await self .on_block (f ) == Cancelled .TRUE :
504- [i ] = [i for i ,(t ,_ ) in enumerate (self .inst .pending_tasks ) if t == self ]
505- self .inst .pending_tasks .pop (i )
506- self .on_resolve (None )
507- return Cancelled .FALSE
508- assert (self .may_enter (self ) and self .inst .starting_pending_task )
509- self .inst .starting_pending_task = False
510- if self .opts .sync :
511- self .inst .calling_sync_export = True
512- return True
513-
514492 def trap_if_on_the_stack (self , inst ):
515493 c = self .supertask
516494 while c is not None :
517495 trap_if (c .inst is inst )
518496 c = c .supertask
519497
520- def may_enter (self , pending_task ):
521- return not self .inst .backpressure and \
522- not self .inst .calling_sync_import and \
523- not (self .inst .calling_sync_export and pending_task .opts .sync )
524-
525- def maybe_start_pending_task (self ):
526- if self .inst .starting_pending_task :
527- return
528- for i ,(pending_task ,pending_future ) in enumerate (self .inst .pending_tasks ):
529- if self .may_enter (pending_task ):
530- self .inst .pending_tasks .pop (i )
531- self .inst .starting_pending_task = True
532- pending_future .set_result (None )
533- return
498+ async def enter (self ):
499+ if self .opts .sync or self .opts .callback :
500+ if self .inst .lock .locked ():
501+ acquired = asyncio .create_task (self .inst .lock .acquire ())
502+ cancelled = await self .block_on (acquired , cancellable = True , for_callback = False )
503+ if cancelled :
504+ if acquired .done ():
505+ self .inst .lock .release ()
506+ else :
507+ acquired .cancel ()
508+ return Cancelled .TRUE
509+ else :
510+ await self .inst .lock .acquire ()
511+ if not self .inst .no_backpressure .is_set () or self .inst .num_backpressure_waiters > 0 :
512+ while True :
513+ self .inst .num_backpressure_waiters += 1
514+ maybe_go = self .inst .no_backpressure .wait ()
515+ cancelled = await self .block_on (maybe_go , cancellable = True , for_callback = False )
516+ self .inst .num_backpressure_waiters -= 1
517+ if cancelled :
518+ return Cancelled .TRUE
519+ if self .inst .no_backpressure .is_set ():
520+ break
521+ return Cancelled .FALSE
534522
535- async def wait_on (self , awaitable , sync , cancellable = False ) -> bool :
536- if sync :
537- assert (not self .inst .calling_sync_import )
538- self .inst .calling_sync_import = True
539- else :
540- self .maybe_start_pending_task ()
523+ async def block_on (self , awaitable , cancellable = False , for_callback = False ) -> Cancelled :
524+ f = asyncio .ensure_future (awaitable )
525+ if f .done () and not DETERMINISTIC_PROFILE and random .randint (0 ,1 ):
526+ return Cancelled .FALSE
541527
542- awaitable = asyncio .ensure_future (awaitable )
543- if awaitable .done () and not DETERMINISTIC_PROFILE and random .randint (0 ,1 ):
544- cancelled = Cancelled .FALSE
545- else :
546- cancelled = await self .on_block (awaitable )
547- if cancelled and not cancellable :
548- assert (self .state == Task .State .INITIAL )
549- self .state = Task .State .PENDING_CANCEL
550- cancelled = await self .on_block (awaitable )
551- assert (not cancelled )
528+ if for_callback :
529+ self .inst .lock .release ()
552530
553- if sync :
554- self .inst .calling_sync_import = False
555- self .inst .async_waiting_tasks .notify_all ()
556- else :
557- while self .inst .calling_sync_import :
558- await self .inst .async_waiting_tasks .wait ()
531+ cancelled = await self .on_block (f )
532+ if cancelled and not cancellable :
533+ assert (await self .on_block (f ) == Cancelled .FALSE )
559534
560- return cancelled
535+ if for_callback :
536+ acquired = asyncio .create_task (self .inst .lock .acquire ())
537+ cancelled |= await self .on_block (acquired )
538+ if cancelled :
539+ assert (self .on_block (acquired ) == Cancelled .FALSE )
540+
541+ if cancelled :
542+ assert (self .state == Task .State .INITIAL )
543+ if not cancellable :
544+ self .state = Task .State .PENDING_CANCEL
545+ return Cancelled .FALSE
546+ else :
547+ self .state = Task .State .CANCEL_DELIVERED
548+ return Cancelled .TRUE
549+ else :
550+ return Cancelled .FALSE
561551
562552 async def call_sync (self , callee , on_start , on_return ):
563553 async def sync_on_block (awaitable ):
@@ -567,42 +557,36 @@ async def sync_on_block(awaitable):
567557 assert (await self .on_block (awaitable ) == Cancelled .FALSE )
568558 return Cancelled .FALSE
569559
570- assert (not self .inst .calling_sync_import )
571- self .inst .calling_sync_import = True
572560 await callee (self , on_start , on_return , sync_on_block )
573- self .inst .calling_sync_import = False
574- self .inst .async_waiting_tasks .notify_all ()
575561
576- async def wait_for_event (self , waitable_set , sync ) -> EventTuple :
577- if self .state == Task .State .PENDING_CANCEL :
562+ async def wait_for_event (self , waitable_set , cancellable , for_callback ) -> EventTuple :
563+ if self .state == Task .State .PENDING_CANCEL and cancellable :
578564 self .state = Task .State .CANCEL_DELIVERED
579565 return (EventCode .TASK_CANCELLED , 0 , 0 )
580566 else :
581567 waitable_set .num_waiting += 1
582568 e = None
583569 while not e :
584570 maybe_event = waitable_set .maybe_has_pending_event .wait ()
585- if await self .wait_on (maybe_event , sync , cancellable = True ):
586- assert (self .state == Task .State .INITIAL )
587- self .state = Task .State .CANCEL_DELIVERED
571+ if await self .block_on (maybe_event , cancellable , for_callback ) == Cancelled .TRUE :
588572 return (EventCode .TASK_CANCELLED , 0 , 0 )
589573 e = waitable_set .poll ()
590574 waitable_set .num_waiting -= 1
591575 return e
592576
593- async def yield_ (self , sync ) -> EventTuple :
594- if self .state == Task .State .PENDING_CANCEL :
577+ async def yield_ (self , cancellable , for_callback ) -> EventTuple :
578+ if self .state == Task .State .PENDING_CANCEL and cancellable :
595579 self .state = Task .State .CANCEL_DELIVERED
596580 return (EventCode .TASK_CANCELLED , 0 , 0 )
597- elif await self .wait_on (asyncio .sleep (0 ), sync , cancellable = True ):
598- assert (self .state == Task .State .INITIAL )
599- self .state = Task .State .CANCEL_DELIVERED
581+ elif await self .block_on (asyncio .sleep (0 ), cancellable , for_callback ) == Cancelled .TRUE :
600582 return (EventCode .TASK_CANCELLED , 0 , 0 )
601583 else :
602584 return (EventCode .NONE , 0 , 0 )
603585
604- async def poll_for_event (self , waitable_set , sync ) -> Optional [EventTuple ]:
605- event_code ,_ ,_ = e = await self .yield_ (sync )
586+ async def poll_for_event (self , waitable_set , cancellable , for_callback ) -> Optional [EventTuple ]:
587+ waitable_set .num_waiting += 1
588+ event_code ,_ ,_ = e = await self .yield_ (cancellable , for_callback )
589+ waitable_set .num_waiting -= 1
606590 if event_code == EventCode .TASK_CANCELLED :
607591 return e
608592 elif (e := waitable_set .poll ()):
@@ -624,13 +608,10 @@ def cancel(self):
624608 self .state = Task .State .RESOLVED
625609
626610 def exit (self ):
627- assert (scheduler .locked ())
628611 trap_if (self .state != Task .State .RESOLVED )
629612 assert (self .num_borrows == 0 )
630- if self .opts .sync :
631- assert (self .inst .calling_sync_export )
632- self .inst .calling_sync_export = False
633- self .maybe_start_pending_task ()
613+ if self .opts .sync or self .opts .callback :
614+ self .inst .lock .release ()
634615
635616#### Subtask State
636617
@@ -1932,7 +1913,10 @@ def lower_flat_values(cx, max_flat, vs, ts, out_param = None):
19321913
19331914async def canon_lift (opts , inst , ft , callee , caller , on_start , on_resolve , on_block ):
19341915 task = Task (opts , inst , ft , caller , on_resolve , on_block )
1935- if not await task .enter ():
1916+ task .trap_if_on_the_stack (inst )
1917+ if await task .enter () == Cancelled .TRUE :
1918+ task .cancel ()
1919+ task .exit ()
19361920 return
19371921
19381922 cx = LiftLowerContext (opts , inst , task )
@@ -1967,15 +1951,15 @@ async def canon_lift(opts, inst, ft, callee, caller, on_start, on_resolve, on_bl
19671951 task .exit ()
19681952 return
19691953 case CallbackCode .YIELD :
1970- e = await task .yield_ (sync = False )
1954+ e = await task .yield_ (cancellable = True , for_callback = True )
19711955 case CallbackCode .WAIT :
19721956 s = task .inst .table .get (si )
19731957 trap_if (not isinstance (s , WaitableSet ))
1974- e = await task .wait_for_event (s , sync = False )
1958+ e = await task .wait_for_event (s , cancellable = True , for_callback = True )
19751959 case CallbackCode .POLL :
19761960 s = task .inst .table .get (si )
19771961 trap_if (not isinstance (s , WaitableSet ))
1978- e = await task .poll_for_event (s , sync = False )
1962+ e = await task .poll_for_event (s , cancellable = True , for_callback = True )
19791963 event_code , p1 , p2 = e
19801964 [packed ] = await call_and_trap_on_throw (opts .callback , task , [event_code , p1 , p2 ])
19811965
@@ -2115,7 +2099,11 @@ async def canon_context_set(t, i, task, v):
21152099
21162100async def canon_backpressure_set (task , flat_args ):
21172101 trap_if (task .opts .sync )
2118- task .inst .backpressure = bool (flat_args [0 ])
2102+ assert (len (flat_args ) == 1 )
2103+ if flat_args [0 ] == 0 :
2104+ task .inst .no_backpressure .set ()
2105+ else :
2106+ task .inst .no_backpressure .clear ()
21192107 return []
21202108
21212109### 🔀 `canon task.return`
@@ -2140,9 +2128,9 @@ async def canon_task_cancel(task):
21402128
21412129### 🔀 `canon yield`
21422130
2143- async def canon_yield (sync , task ):
2131+ async def canon_yield (cancellable , task ):
21442132 trap_if (not task .inst .may_leave )
2145- event_code ,_ ,_ = await task .yield_ (sync )
2133+ event_code ,_ ,_ = await task .yield_ (cancellable , for_callback = False )
21462134 match event_code :
21472135 case EventCode .NONE :
21482136 return [0 ]
@@ -2157,11 +2145,11 @@ async def canon_waitable_set_new(task):
21572145
21582146### 🔀 `canon waitable-set.wait`
21592147
2160- async def canon_waitable_set_wait (sync , mem , task , si , ptr ):
2148+ async def canon_waitable_set_wait (cancellable , mem , task , si , ptr ):
21612149 trap_if (not task .inst .may_leave )
21622150 s = task .inst .table .get (si )
21632151 trap_if (not isinstance (s , WaitableSet ))
2164- e = await task .wait_for_event (s , sync )
2152+ e = await task .wait_for_event (s , cancellable , for_callback = False )
21652153 return unpack_event (mem , task , ptr , e )
21662154
21672155def unpack_event (mem , task , ptr , e : EventTuple ):
@@ -2173,11 +2161,11 @@ def unpack_event(mem, task, ptr, e: EventTuple):
21732161
21742162### 🔀 `canon waitable-set.poll`
21752163
2176- async def canon_waitable_set_poll (sync , mem , task , si , ptr ):
2164+ async def canon_waitable_set_poll (cancellable , mem , task , si , ptr ):
21772165 trap_if (not task .inst .may_leave )
21782166 s = task .inst .table .get (si )
21792167 trap_if (not isinstance (s , WaitableSet ))
2180- e = await task .poll_for_event (s , sync )
2168+ e = await task .poll_for_event (s , cancellable , for_callback = False )
21812169 return unpack_event (mem , task , ptr , e )
21822170
21832171### 🔀 `canon waitable-set.drop`
@@ -2220,7 +2208,7 @@ async def canon_subtask_cancel(sync, task, i):
22202208 while not subtask .resolved ():
22212209 if subtask .has_pending_event ():
22222210 _ = subtask .get_event ()
2223- await task .wait_on (subtask .wait_for_pending_event (), sync = True )
2211+ await task .block_on (subtask .wait_for_pending_event ())
22242212 else :
22252213 if not subtask .resolved ():
22262214 return [BLOCKED ]
@@ -2296,7 +2284,7 @@ def on_copy_done(result):
22962284 e .copy (task .inst , buffer , on_copy , on_copy_done )
22972285
22982286 if opts .sync and not e .has_pending_event ():
2299- await task .wait_on (e .wait_for_pending_event (), sync = True )
2287+ await task .block_on (e .wait_for_pending_event ())
23002288
23012289 if e .has_pending_event ():
23022290 code ,index ,payload = e .get_event ()
@@ -2342,7 +2330,7 @@ def on_copy_done(result):
23422330 e .copy (task .inst , buffer , on_copy_done )
23432331
23442332 if opts .sync and not e .has_pending_event ():
2345- await task .wait_on (e .wait_for_pending_event (), sync = True )
2333+ await task .block_on (e .wait_for_pending_event ())
23462334
23472335 if e .has_pending_event ():
23482336 code ,index ,payload = e .get_event ()
@@ -2375,7 +2363,7 @@ async def cancel_copy(EndT, event_code, stream_or_future_t, sync, task, i):
23752363 e .shared .cancel ()
23762364 if not e .has_pending_event ():
23772365 if sync :
2378- await task .wait_on (e .wait_for_pending_event (), sync = True )
2366+ await task .block_on (e .wait_for_pending_event ())
23792367 else :
23802368 return [BLOCKED ]
23812369 code ,index ,payload = e .get_event ()
0 commit comments