@@ -211,6 +211,7 @@ class CanonicalOptions(LiftLowerOptions):
211
211
post_return : Optional [Callable ] = None
212
212
sync : bool = True # = !canonopt.async
213
213
callback : Optional [Callable ] = None
214
+ cancellable : bool = False
214
215
215
216
### Runtime State
216
217
@@ -221,22 +222,17 @@ class CanonicalOptions(LiftLowerOptions):
221
222
class ComponentInstance :
222
223
table : Table
223
224
may_leave : bool
224
- backpressure : bool
225
- calling_sync_export : bool
226
- calling_sync_import : bool
227
- pending_tasks : list [tuple [Task , asyncio .Future ]]
228
- starting_pending_task : bool
229
- async_waiting_tasks : asyncio .Condition
225
+ no_backpressure : asyncio .Event
226
+ num_backpressure_waiters : int
227
+ lock : asyncio .Lock
230
228
231
229
def __init__ (self ):
232
230
self .table = Table ()
233
231
self .may_leave = True
234
- self .backpressure = False
235
- self .calling_sync_export = False
236
- self .calling_sync_import = False
237
- self .pending_tasks = []
238
- self .starting_pending_task = False
239
- self .async_waiting_tasks = asyncio .Condition (scheduler )
232
+ self .no_backpressure = asyncio .Event ()
233
+ self .no_backpressure .set ()
234
+ self .num_backpressure_waiters = 0
235
+ self .lock = asyncio .Lock ()
240
236
241
237
#### Table State
242
238
@@ -497,67 +493,64 @@ def __init__(self, opts, inst, ft, supertask, on_resolve, on_block):
497
493
async def enter (self ):
498
494
assert (scheduler .locked ())
499
495
self .trap_if_on_the_stack (self .inst )
500
- if not self .may_enter (self ) or self .inst .pending_tasks :
501
- f = asyncio .Future ()
502
- self .inst .pending_tasks .append ((self , f ))
503
- if await self .on_block (f ) == Cancelled .TRUE :
504
- [i ] = [i for i ,(t ,_ ) in enumerate (self .inst .pending_tasks ) if t == self ]
505
- self .inst .pending_tasks .pop (i )
506
- self .on_resolve (None )
507
- return Cancelled .FALSE
508
- assert (self .may_enter (self ) and self .inst .starting_pending_task )
509
- self .inst .starting_pending_task = False
510
- if self .opts .sync :
511
- self .inst .calling_sync_export = True
512
- return True
496
+ if self .opts .sync or self .opts .callback :
497
+ if self .inst .lock .locked ():
498
+ acquired = asyncio .create_task (self .inst .lock .acquire ())
499
+ cancelled = await self .wait_on (acquired , cancellable = True , for_callback = False )
500
+ if cancelled :
501
+ if acquired .done ():
502
+ self .inst .lock .release ()
503
+ else :
504
+ acquired .cancel ()
505
+ return Cancelled .TRUE
506
+ else :
507
+ await self .inst .lock .acquire ()
508
+ if not self .inst .no_backpressure .is_set () or self .inst .num_backpressure_waiters > 0 :
509
+ while True :
510
+ self .inst .num_backpressure_waiters += 1
511
+ maybe_go = self .inst .no_backpressure .wait ()
512
+ cancelled = await self .wait_on (maybe_go , cancellable = True , for_callback = False )
513
+ self .inst .num_backpressure_waiters -= 1
514
+ if cancelled :
515
+ return Cancelled .TRUE
516
+ if self .inst .no_backpressure .is_set ():
517
+ break
518
+ return Cancelled .FALSE
513
519
514
520
def trap_if_on_the_stack (self , inst ):
515
521
c = self .supertask
516
522
while c is not None :
517
523
trap_if (c .inst is inst )
518
524
c = c .supertask
519
525
520
- def may_enter (self , pending_task ):
521
- return not self .inst .backpressure and \
522
- not self .inst .calling_sync_import and \
523
- not (self .inst .calling_sync_export and pending_task .opts .sync )
524
-
525
- def maybe_start_pending_task (self ):
526
- if self .inst .starting_pending_task :
527
- return
528
- for i ,(pending_task ,pending_future ) in enumerate (self .inst .pending_tasks ):
529
- if self .may_enter (pending_task ):
530
- self .inst .pending_tasks .pop (i )
531
- self .inst .starting_pending_task = True
532
- pending_future .set_result (None )
533
- return
526
+ async def wait_on (self , awaitable , cancellable = False , for_callback = False ) -> Cancelled :
527
+ f = asyncio .ensure_future (awaitable )
528
+ if f .done () and not DETERMINISTIC_PROFILE and random .randint (0 ,1 ):
529
+ return Cancelled .FALSE
534
530
535
- async def wait_on (self , awaitable , sync , cancellable = False ) -> bool :
536
- if sync :
537
- assert (not self .inst .calling_sync_import )
538
- self .inst .calling_sync_import = True
539
- else :
540
- self .maybe_start_pending_task ()
531
+ if for_callback :
532
+ self .inst .lock .release ()
541
533
542
- awaitable = asyncio .ensure_future (awaitable )
543
- if awaitable .done () and not DETERMINISTIC_PROFILE and random .randint (0 ,1 ):
544
- cancelled = Cancelled .FALSE
545
- else :
546
- cancelled = await self .on_block (awaitable )
547
- if cancelled and not cancellable :
548
- assert (self .state == Task .State .INITIAL )
549
- self .state = Task .State .PENDING_CANCEL
550
- cancelled = await self .on_block (awaitable )
551
- assert (not cancelled )
534
+ cancelled = await self .on_block (f )
535
+ if cancelled and not cancellable :
536
+ assert (await self .on_block (f ) == Cancelled .FALSE )
552
537
553
- if sync :
554
- self .inst .calling_sync_import = False
555
- self .inst .async_waiting_tasks .notify_all ()
556
- else :
557
- while self .inst .calling_sync_import :
558
- await self .inst .async_waiting_tasks .wait ()
538
+ if for_callback :
539
+ acquired = asyncio .create_task (self .inst .lock .acquire ())
540
+ cancelled |= await self .on_block (acquired )
541
+ if cancelled :
542
+ assert (self .on_block (acquired ) == Cancelled .FALSE )
559
543
560
- return cancelled
544
+ if cancelled :
545
+ assert (self .state == Task .State .INITIAL )
546
+ if not cancellable :
547
+ self .state = Task .State .PENDING_CANCEL
548
+ return Cancelled .FALSE
549
+ else :
550
+ self .state = Task .State .CANCEL_DELIVERED
551
+ return Cancelled .TRUE
552
+ else :
553
+ return Cancelled .FALSE
561
554
562
555
async def call_sync (self , callee , on_start , on_return ):
563
556
async def sync_on_block (awaitable ):
@@ -567,42 +560,36 @@ async def sync_on_block(awaitable):
567
560
assert (await self .on_block (awaitable ) == Cancelled .FALSE )
568
561
return Cancelled .FALSE
569
562
570
- assert (not self .inst .calling_sync_import )
571
- self .inst .calling_sync_import = True
572
563
await callee (self , on_start , on_return , sync_on_block )
573
- self .inst .calling_sync_import = False
574
- self .inst .async_waiting_tasks .notify_all ()
575
564
576
- async def wait_for_event (self , waitable_set , sync ) -> EventTuple :
577
- if self .state == Task .State .PENDING_CANCEL :
565
+ async def wait_for_event (self , waitable_set , cancellable , for_callback ) -> EventTuple :
566
+ if self .state == Task .State .PENDING_CANCEL and cancellable :
578
567
self .state = Task .State .CANCEL_DELIVERED
579
568
return (EventCode .TASK_CANCELLED , 0 , 0 )
580
569
else :
581
570
waitable_set .num_waiting += 1
582
571
e = None
583
572
while not e :
584
573
maybe_event = waitable_set .maybe_has_pending_event .wait ()
585
- if await self .wait_on (maybe_event , sync , cancellable = True ):
586
- assert (self .state == Task .State .INITIAL )
587
- self .state = Task .State .CANCEL_DELIVERED
574
+ if await self .wait_on (maybe_event , cancellable , for_callback ) == Cancelled .TRUE :
588
575
return (EventCode .TASK_CANCELLED , 0 , 0 )
589
576
e = waitable_set .poll ()
590
577
waitable_set .num_waiting -= 1
591
578
return e
592
579
593
- async def yield_ (self , sync ) -> EventTuple :
594
- if self .state == Task .State .PENDING_CANCEL :
580
+ async def yield_ (self , cancellable , for_callback ) -> EventTuple :
581
+ if self .state == Task .State .PENDING_CANCEL and for_callback :
595
582
self .state = Task .State .CANCEL_DELIVERED
596
583
return (EventCode .TASK_CANCELLED , 0 , 0 )
597
- elif await self .wait_on (asyncio .sleep (0 ), sync , cancellable = True ):
598
- assert (self .state == Task .State .INITIAL )
599
- self .state = Task .State .CANCEL_DELIVERED
584
+ elif await self .wait_on (asyncio .sleep (0 ), cancellable , for_callback ) == Cancelled .TRUE :
600
585
return (EventCode .TASK_CANCELLED , 0 , 0 )
601
586
else :
602
587
return (EventCode .NONE , 0 , 0 )
603
588
604
- async def poll_for_event (self , waitable_set , sync ) -> Optional [EventTuple ]:
605
- event_code ,_ ,_ = e = await self .yield_ (sync )
589
+ async def poll_for_event (self , waitable_set , cancellable , for_callback ) -> Optional [EventTuple ]:
590
+ waitable_set .num_waiting += 1
591
+ event_code ,_ ,_ = e = await self .yield_ (cancellable , for_callback )
592
+ waitable_set .num_waiting -= 1
606
593
if event_code == EventCode .TASK_CANCELLED :
607
594
return e
608
595
elif (e := waitable_set .poll ()):
@@ -624,13 +611,10 @@ def cancel(self):
624
611
self .state = Task .State .RESOLVED
625
612
626
613
def exit (self ):
627
- assert (scheduler .locked ())
628
614
trap_if (self .state != Task .State .RESOLVED )
629
615
assert (self .num_borrows == 0 )
630
- if self .opts .sync :
631
- assert (self .inst .calling_sync_export )
632
- self .inst .calling_sync_export = False
633
- self .maybe_start_pending_task ()
616
+ if self .opts .sync or self .opts .callback :
617
+ self .inst .lock .release ()
634
618
635
619
#### Subtask State
636
620
@@ -1932,7 +1916,9 @@ def lower_flat_values(cx, max_flat, vs, ts, out_param = None):
1932
1916
1933
1917
async def canon_lift (opts , inst , ft , callee , caller , on_start , on_resolve , on_block ):
1934
1918
task = Task (opts , inst , ft , caller , on_resolve , on_block )
1935
- if not await task .enter ():
1919
+ if await task .enter () == Cancelled .TRUE :
1920
+ task .cancel ()
1921
+ task .exit ()
1936
1922
return
1937
1923
1938
1924
cx = LiftLowerContext (opts , inst , task )
@@ -1967,15 +1953,15 @@ async def canon_lift(opts, inst, ft, callee, caller, on_start, on_resolve, on_bl
1967
1953
task .exit ()
1968
1954
return
1969
1955
case CallbackCode .YIELD :
1970
- e = await task .yield_ (sync = False )
1956
+ e = await task .yield_ (cancellable = True , for_callback = True )
1971
1957
case CallbackCode .WAIT :
1972
1958
s = task .inst .table .get (si )
1973
1959
trap_if (not isinstance (s , WaitableSet ))
1974
- e = await task .wait_for_event (s , sync = False )
1960
+ e = await task .wait_for_event (s , cancellable = True , for_callback = True )
1975
1961
case CallbackCode .POLL :
1976
1962
s = task .inst .table .get (si )
1977
1963
trap_if (not isinstance (s , WaitableSet ))
1978
- e = await task .poll_for_event (s , sync = False )
1964
+ e = await task .poll_for_event (s , cancellable = True , for_callback = True )
1979
1965
event_code , p1 , p2 = e
1980
1966
[packed ] = await call_and_trap_on_throw (opts .callback , task , [event_code , p1 , p2 ])
1981
1967
@@ -2114,8 +2100,11 @@ async def canon_context_set(t, i, task, v):
2114
2100
### 🔀 `canon backpressure.set`
2115
2101
2116
2102
async def canon_backpressure_set (task , flat_args ):
2117
- trap_if (task .opts .sync )
2118
- task .inst .backpressure = bool (flat_args [0 ])
2103
+ assert (len (flat_args ) == 1 )
2104
+ if flat_args [0 ] == 0 :
2105
+ task .inst .no_backpressure .set ()
2106
+ else :
2107
+ task .inst .no_backpressure .clear ()
2119
2108
return []
2120
2109
2121
2110
### 🔀 `canon task.return`
@@ -2140,9 +2129,9 @@ async def canon_task_cancel(task):
2140
2129
2141
2130
### 🔀 `canon yield`
2142
2131
2143
- async def canon_yield (sync , task ):
2132
+ async def canon_yield (opts , task ):
2144
2133
trap_if (not task .inst .may_leave )
2145
- event_code ,_ ,_ = await task .yield_ (sync )
2134
+ event_code ,_ ,_ = await task .yield_ (opts . cancellable , for_callback = False )
2146
2135
match event_code :
2147
2136
case EventCode .NONE :
2148
2137
return [0 ]
@@ -2157,12 +2146,12 @@ async def canon_waitable_set_new(task):
2157
2146
2158
2147
### 🔀 `canon waitable-set.wait`
2159
2148
2160
- async def canon_waitable_set_wait (sync , mem , task , si , ptr ):
2149
+ async def canon_waitable_set_wait (opts , task , si , ptr ):
2161
2150
trap_if (not task .inst .may_leave )
2162
2151
s = task .inst .table .get (si )
2163
2152
trap_if (not isinstance (s , WaitableSet ))
2164
- e = await task .wait_for_event (s , sync )
2165
- return unpack_event (mem , task , ptr , e )
2153
+ e = await task .wait_for_event (s , opts . cancellable , for_callback = False )
2154
+ return unpack_event (opts . memory , task , ptr , e )
2166
2155
2167
2156
def unpack_event (mem , task , ptr , e : EventTuple ):
2168
2157
event , p1 , p2 = e
@@ -2173,12 +2162,12 @@ def unpack_event(mem, task, ptr, e: EventTuple):
2173
2162
2174
2163
### 🔀 `canon waitable-set.poll`
2175
2164
2176
- async def canon_waitable_set_poll (sync , mem , task , si , ptr ):
2165
+ async def canon_waitable_set_poll (opts , task , si , ptr ):
2177
2166
trap_if (not task .inst .may_leave )
2178
2167
s = task .inst .table .get (si )
2179
2168
trap_if (not isinstance (s , WaitableSet ))
2180
- e = await task .poll_for_event (s , sync )
2181
- return unpack_event (mem , task , ptr , e )
2169
+ e = await task .poll_for_event (s , opts . cancellable , for_callback = False )
2170
+ return unpack_event (opts . memory , task , ptr , e )
2182
2171
2183
2172
### 🔀 `canon waitable-set.drop`
2184
2173
@@ -2220,7 +2209,7 @@ async def canon_subtask_cancel(sync, task, i):
2220
2209
while not subtask .resolved ():
2221
2210
if subtask .has_pending_event ():
2222
2211
_ = subtask .get_event ()
2223
- await task .wait_on (subtask .wait_for_pending_event (), sync = True )
2212
+ await task .wait_on (subtask .wait_for_pending_event ())
2224
2213
else :
2225
2214
if not subtask .resolved ():
2226
2215
return [BLOCKED ]
@@ -2296,7 +2285,7 @@ def on_copy_done(result):
2296
2285
e .copy (task .inst , buffer , on_copy , on_copy_done )
2297
2286
2298
2287
if opts .sync and not e .has_pending_event ():
2299
- await task .wait_on (e .wait_for_pending_event (), sync = True )
2288
+ await task .wait_on (e .wait_for_pending_event ())
2300
2289
2301
2290
if e .has_pending_event ():
2302
2291
code ,index ,payload = e .get_event ()
@@ -2342,7 +2331,7 @@ def on_copy_done(result):
2342
2331
e .copy (task .inst , buffer , on_copy_done )
2343
2332
2344
2333
if opts .sync and not e .has_pending_event ():
2345
- await task .wait_on (e .wait_for_pending_event (), sync = True )
2334
+ await task .wait_on (e .wait_for_pending_event ())
2346
2335
2347
2336
if e .has_pending_event ():
2348
2337
code ,index ,payload = e .get_event ()
@@ -2375,7 +2364,7 @@ async def cancel_copy(EndT, event_code, stream_or_future_t, sync, task, i):
2375
2364
e .shared .cancel ()
2376
2365
if not e .has_pending_event ():
2377
2366
if sync :
2378
- await task .wait_on (e .wait_for_pending_event (), sync = True )
2367
+ await task .wait_on (e .wait_for_pending_event ())
2379
2368
else :
2380
2369
return [BLOCKED ]
2381
2370
code ,index ,payload = e .get_event ()
0 commit comments