@@ -211,6 +211,7 @@ class CanonicalOptions(LiftLowerOptions):
211
211
post_return : Optional [Callable ] = None
212
212
sync : bool = True # = !canonopt.async
213
213
callback : Optional [Callable ] = None
214
+ cancellable : bool = False
214
215
215
216
### Runtime State
216
217
@@ -221,22 +222,17 @@ class CanonicalOptions(LiftLowerOptions):
221
222
class ComponentInstance :
222
223
table : Table
223
224
may_leave : bool
224
- backpressure : bool
225
- calling_sync_export : bool
226
- calling_sync_import : bool
227
- pending_tasks : list [tuple [Task , asyncio .Future ]]
228
- starting_pending_task : bool
229
- async_waiting_tasks : asyncio .Condition
225
+ no_backpressure : asyncio .Event
226
+ num_backpressure_waiters : int
227
+ lock : asyncio .Lock
230
228
231
229
def __init__ (self ):
232
230
self .table = Table ()
233
231
self .may_leave = True
234
- self .backpressure = False
235
- self .calling_sync_export = False
236
- self .calling_sync_import = False
237
- self .pending_tasks = []
238
- self .starting_pending_task = False
239
- self .async_waiting_tasks = asyncio .Condition (scheduler )
232
+ self .no_backpressure = asyncio .Event ()
233
+ self .no_backpressure .set ()
234
+ self .num_backpressure_waiters = 0
235
+ self .lock = asyncio .Lock ()
240
236
241
237
#### Table State
242
238
@@ -497,67 +493,64 @@ def __init__(self, opts, inst, ft, supertask, on_resolve, on_block):
497
493
async def enter (self ):
498
494
assert (scheduler .locked ())
499
495
self .trap_if_on_the_stack (self .inst )
500
- if not self .may_enter (self ) or self .inst .pending_tasks :
501
- f = asyncio .Future ()
502
- self .inst .pending_tasks .append ((self , f ))
503
- if await self .on_block (f ) == Cancelled .TRUE :
504
- [i ] = [i for i ,(t ,_ ) in enumerate (self .inst .pending_tasks ) if t == self ]
505
- self .inst .pending_tasks .pop (i )
496
+ if self .opts .sync or self .opts .callback :
497
+ if await self .acquire_instance_lock () == Cancelled .TRUE :
506
498
self .on_resolve (None )
507
- return Cancelled .FALSE
508
- assert (self .may_enter (self ) and self .inst .starting_pending_task )
509
- self .inst .starting_pending_task = False
510
- if self .opts .sync :
511
- self .inst .calling_sync_export = True
512
- return True
499
+ return Cancelled .TRUE
500
+ if not self .inst .no_backpressure .is_set () or self .inst .num_backpressure_waiters > 0 :
501
+ while True :
502
+ self .inst .num_backpressure_waiters += 1
503
+ cancelled = await self .wait_on (self .inst .no_backpressure .wait (),
504
+ cancellable = True , for_callback = False )
505
+ self .inst .num_backpressure_waiters -= 1
506
+ if cancelled :
507
+ self .on_resolve (None )
508
+ return Cancelled .TRUE
509
+ if self .inst .no_backpressure .is_set ():
510
+ break
511
+ return Cancelled .FALSE
513
512
514
513
def trap_if_on_the_stack (self , inst ):
515
514
c = self .supertask
516
515
while c is not None :
517
516
trap_if (c .inst is inst )
518
517
c = c .supertask
519
518
520
- def may_enter (self , pending_task ):
521
- return not self .inst .backpressure and \
522
- not self .inst .calling_sync_import and \
523
- not (self .inst .calling_sync_export and pending_task .opts .sync )
524
-
525
- def maybe_start_pending_task (self ):
526
- if self .inst .starting_pending_task :
527
- return
528
- for i ,(pending_task ,pending_future ) in enumerate (self .inst .pending_tasks ):
529
- if self .may_enter (pending_task ):
530
- self .inst .pending_tasks .pop (i )
531
- self .inst .starting_pending_task = True
532
- pending_future .set_result (None )
533
- return
534
-
535
- async def wait_on (self , awaitable , sync , cancellable = False ) -> bool :
536
- if sync :
537
- assert (not self .inst .calling_sync_import )
538
- self .inst .calling_sync_import = True
519
+ async def acquire_instance_lock (self ) -> Cancelled :
520
+ if not self .inst .lock .locked ():
521
+ await self .inst .lock .acquire ()
539
522
else :
540
- self .maybe_start_pending_task ()
523
+ acquired = asyncio .create_task (self .inst .lock .acquire ())
524
+ if await self .on_block (acquired ) == Cancelled .TRUE :
525
+ assert (self .on_block (acquired ) == Cancelled .FALSE )
526
+ return Cancelled .TRUE
527
+ return Cancelled .FALSE
528
+
529
+ async def wait_on (self , awaitable , cancellable = False , for_callback = False ) -> Cancelled :
530
+ f = asyncio .ensure_future (awaitable )
531
+ if f .done () and not DETERMINISTIC_PROFILE and random .randint (0 ,1 ):
532
+ return Cancelled .FALSE
541
533
542
- awaitable = asyncio .ensure_future (awaitable )
543
- if awaitable .done () and not DETERMINISTIC_PROFILE and random .randint (0 ,1 ):
544
- cancelled = Cancelled .FALSE
545
- else :
546
- cancelled = await self .on_block (awaitable )
547
- if cancelled and not cancellable :
548
- assert (self .state == Task .State .INITIAL )
549
- self .state = Task .State .PENDING_CANCEL
550
- cancelled = await self .on_block (awaitable )
551
- assert (not cancelled )
534
+ if for_callback :
535
+ self .inst .lock .release ()
552
536
553
- if sync :
554
- self .inst .calling_sync_import = False
555
- self .inst .async_waiting_tasks .notify_all ()
556
- else :
557
- while self .inst .calling_sync_import :
558
- await self .inst .async_waiting_tasks .wait ()
537
+ cancelled = await self .on_block (f )
538
+ if cancelled and not cancellable :
539
+ assert (await self .on_block (f ) == Cancelled .FALSE )
559
540
560
- return cancelled
541
+ if for_callback :
542
+ cancelled |= await self .acquire_instance_lock ()
543
+
544
+ if cancelled :
545
+ assert (self .state == Task .State .INITIAL )
546
+ if not cancellable :
547
+ self .state = Task .State .PENDING_CANCEL
548
+ return Cancelled .FALSE
549
+ else :
550
+ self .state = Task .State .CANCEL_DELIVERED
551
+ return Cancelled .TRUE
552
+ else :
553
+ return Cancelled .FALSE
561
554
562
555
async def call_sync (self , callee , on_start , on_return ):
563
556
async def sync_on_block (awaitable ):
@@ -567,42 +560,36 @@ async def sync_on_block(awaitable):
567
560
assert (await self .on_block (awaitable ) == Cancelled .FALSE )
568
561
return Cancelled .FALSE
569
562
570
- assert (not self .inst .calling_sync_import )
571
- self .inst .calling_sync_import = True
572
563
await callee (self , on_start , on_return , sync_on_block )
573
- self .inst .calling_sync_import = False
574
- self .inst .async_waiting_tasks .notify_all ()
575
564
576
- async def wait_for_event (self , waitable_set , sync ) -> EventTuple :
577
- if self .state == Task .State .PENDING_CANCEL :
565
+ async def wait_for_event (self , waitable_set , cancellable , for_callback ) -> EventTuple :
566
+ if self .state == Task .State .PENDING_CANCEL and cancellable :
578
567
self .state = Task .State .CANCEL_DELIVERED
579
568
return (EventCode .TASK_CANCELLED , 0 , 0 )
580
569
else :
581
570
waitable_set .num_waiting += 1
582
571
e = None
583
572
while not e :
584
573
maybe_event = waitable_set .maybe_has_pending_event .wait ()
585
- if await self .wait_on (maybe_event , sync , cancellable = True ):
586
- assert (self .state == Task .State .INITIAL )
587
- self .state = Task .State .CANCEL_DELIVERED
574
+ if await self .wait_on (maybe_event , cancellable , for_callback ) == Cancelled .TRUE :
588
575
return (EventCode .TASK_CANCELLED , 0 , 0 )
589
576
e = waitable_set .poll ()
590
577
waitable_set .num_waiting -= 1
591
578
return e
592
579
593
- async def yield_ (self , sync ) -> EventTuple :
594
- if self .state == Task .State .PENDING_CANCEL :
580
+ async def yield_ (self , cancellable , for_callback ) -> EventTuple :
581
+ if self .state == Task .State .PENDING_CANCEL and for_callback :
595
582
self .state = Task .State .CANCEL_DELIVERED
596
583
return (EventCode .TASK_CANCELLED , 0 , 0 )
597
- elif await self .wait_on (asyncio .sleep (0 ), sync , cancellable = True ):
598
- assert (self .state == Task .State .INITIAL )
599
- self .state = Task .State .CANCEL_DELIVERED
584
+ elif await self .wait_on (asyncio .sleep (0 ), cancellable , for_callback ) == Cancelled .TRUE :
600
585
return (EventCode .TASK_CANCELLED , 0 , 0 )
601
586
else :
602
587
return (EventCode .NONE , 0 , 0 )
603
588
604
- async def poll_for_event (self , waitable_set , sync ) -> Optional [EventTuple ]:
605
- event_code ,_ ,_ = e = await self .yield_ (sync )
589
+ async def poll_for_event (self , waitable_set , cancellable , for_callback ) -> Optional [EventTuple ]:
590
+ waitable_set .num_waiting += 1
591
+ event_code ,_ ,_ = e = await self .yield_ (cancellable , for_callback )
592
+ waitable_set .num_waiting -= 1
606
593
if event_code == EventCode .TASK_CANCELLED :
607
594
return e
608
595
elif (e := waitable_set .poll ()):
@@ -624,13 +611,10 @@ def cancel(self):
624
611
self .state = Task .State .RESOLVED
625
612
626
613
def exit (self ):
627
- assert (scheduler .locked ())
628
614
trap_if (self .state != Task .State .RESOLVED )
629
615
assert (self .num_borrows == 0 )
630
- if self .opts .sync :
631
- assert (self .inst .calling_sync_export )
632
- self .inst .calling_sync_export = False
633
- self .maybe_start_pending_task ()
616
+ if self .opts .sync or self .opts .callback :
617
+ self .inst .lock .release ()
634
618
635
619
#### Subtask State
636
620
@@ -1932,7 +1916,7 @@ def lower_flat_values(cx, max_flat, vs, ts, out_param = None):
1932
1916
1933
1917
async def canon_lift (opts , inst , ft , callee , caller , on_start , on_resolve , on_block ):
1934
1918
task = Task (opts , inst , ft , caller , on_resolve , on_block )
1935
- if not await task .enter ():
1919
+ if await task .enter () == Cancelled . TRUE :
1936
1920
return
1937
1921
1938
1922
cx = LiftLowerContext (opts , inst , task )
@@ -1967,15 +1951,15 @@ async def canon_lift(opts, inst, ft, callee, caller, on_start, on_resolve, on_bl
1967
1951
task .exit ()
1968
1952
return
1969
1953
case CallbackCode .YIELD :
1970
- e = await task .yield_ (sync = False )
1954
+ e = await task .yield_ (cancellable = True , for_callback = True )
1971
1955
case CallbackCode .WAIT :
1972
1956
s = task .inst .table .get (si )
1973
1957
trap_if (not isinstance (s , WaitableSet ))
1974
- e = await task .wait_for_event (s , sync = False )
1958
+ e = await task .wait_for_event (s , cancellable = True , for_callback = True )
1975
1959
case CallbackCode .POLL :
1976
1960
s = task .inst .table .get (si )
1977
1961
trap_if (not isinstance (s , WaitableSet ))
1978
- e = await task .poll_for_event (s , sync = False )
1962
+ e = await task .poll_for_event (s , cancellable = True , for_callback = True )
1979
1963
event_code , p1 , p2 = e
1980
1964
[packed ] = await call_and_trap_on_throw (opts .callback , task , [event_code , p1 , p2 ])
1981
1965
@@ -2114,8 +2098,11 @@ async def canon_context_set(t, i, task, v):
2114
2098
### 🔀 `canon backpressure.set`
2115
2099
2116
2100
async def canon_backpressure_set (task , flat_args ):
2117
- trap_if (task .opts .sync )
2118
- task .inst .backpressure = bool (flat_args [0 ])
2101
+ assert (len (flat_args ) == 1 )
2102
+ if flat_args [0 ] == 0 :
2103
+ task .inst .no_backpressure .set ()
2104
+ else :
2105
+ task .inst .no_backpressure .clear ()
2119
2106
return []
2120
2107
2121
2108
### 🔀 `canon task.return`
@@ -2140,9 +2127,9 @@ async def canon_task_cancel(task):
2140
2127
2141
2128
### 🔀 `canon yield`
2142
2129
2143
- async def canon_yield (sync , task ):
2130
+ async def canon_yield (opts , task ):
2144
2131
trap_if (not task .inst .may_leave )
2145
- event_code ,_ ,_ = await task .yield_ (sync )
2132
+ event_code ,_ ,_ = await task .yield_ (opts . cancellable , for_callback = False )
2146
2133
match event_code :
2147
2134
case EventCode .NONE :
2148
2135
return [0 ]
@@ -2157,12 +2144,12 @@ async def canon_waitable_set_new(task):
2157
2144
2158
2145
### 🔀 `canon waitable-set.wait`
2159
2146
2160
- async def canon_waitable_set_wait (sync , mem , task , si , ptr ):
2147
+ async def canon_waitable_set_wait (opts , task , si , ptr ):
2161
2148
trap_if (not task .inst .may_leave )
2162
2149
s = task .inst .table .get (si )
2163
2150
trap_if (not isinstance (s , WaitableSet ))
2164
- e = await task .wait_for_event (s , sync )
2165
- return unpack_event (mem , task , ptr , e )
2151
+ e = await task .wait_for_event (s , opts . cancellable , for_callback = False )
2152
+ return unpack_event (opts . memory , task , ptr , e )
2166
2153
2167
2154
def unpack_event (mem , task , ptr , e : EventTuple ):
2168
2155
event , p1 , p2 = e
@@ -2173,12 +2160,12 @@ def unpack_event(mem, task, ptr, e: EventTuple):
2173
2160
2174
2161
### 🔀 `canon waitable-set.poll`
2175
2162
2176
- async def canon_waitable_set_poll (sync , mem , task , si , ptr ):
2163
+ async def canon_waitable_set_poll (opts , task , si , ptr ):
2177
2164
trap_if (not task .inst .may_leave )
2178
2165
s = task .inst .table .get (si )
2179
2166
trap_if (not isinstance (s , WaitableSet ))
2180
- e = await task .poll_for_event (s , sync )
2181
- return unpack_event (mem , task , ptr , e )
2167
+ e = await task .poll_for_event (s , opts . cancellable , for_callback = False )
2168
+ return unpack_event (opts . memory , task , ptr , e )
2182
2169
2183
2170
### 🔀 `canon waitable-set.drop`
2184
2171
@@ -2220,7 +2207,7 @@ async def canon_subtask_cancel(sync, task, i):
2220
2207
while not subtask .resolved ():
2221
2208
if subtask .has_pending_event ():
2222
2209
_ = subtask .get_event ()
2223
- await task .wait_on (subtask .wait_for_pending_event (), sync = True )
2210
+ await task .wait_on (subtask .wait_for_pending_event ())
2224
2211
else :
2225
2212
if not subtask .resolved ():
2226
2213
return [BLOCKED ]
@@ -2296,7 +2283,7 @@ def on_copy_done(result):
2296
2283
e .copy (task .inst , buffer , on_copy , on_copy_done )
2297
2284
2298
2285
if opts .sync and not e .has_pending_event ():
2299
- await task .wait_on (e .wait_for_pending_event (), sync = True )
2286
+ await task .wait_on (e .wait_for_pending_event ())
2300
2287
2301
2288
if e .has_pending_event ():
2302
2289
code ,index ,payload = e .get_event ()
@@ -2342,7 +2329,7 @@ def on_copy_done(result):
2342
2329
e .copy (task .inst , buffer , on_copy_done )
2343
2330
2344
2331
if opts .sync and not e .has_pending_event ():
2345
- await task .wait_on (e .wait_for_pending_event (), sync = True )
2332
+ await task .wait_on (e .wait_for_pending_event ())
2346
2333
2347
2334
if e .has_pending_event ():
2348
2335
code ,index ,payload = e .get_event ()
@@ -2375,7 +2362,7 @@ async def cancel_copy(EndT, event_code, stream_or_future_t, sync, task, i):
2375
2362
e .shared .cancel ()
2376
2363
if not e .has_pending_event ():
2377
2364
if sync :
2378
- await task .wait_on (e .wait_for_pending_event (), sync = True )
2365
+ await task .wait_on (e .wait_for_pending_event ())
2379
2366
else :
2380
2367
return [BLOCKED ]
2381
2368
code ,index ,payload = e .get_event ()
0 commit comments