@@ -56,6 +56,9 @@ class LoopState:
56
56
nodes_needing_checkpoints : list [cst .Return | cst .Yield ] = field (
57
57
default_factory = list
58
58
)
59
+ possibly_redundant_lowlevel_checkpoints : list [cst .BaseExpression ] = field (
60
+ default_factory = list
61
+ )
59
62
60
63
def copy (self ):
61
64
return LoopState (
@@ -66,6 +69,7 @@ def copy(self):
66
69
uncheckpointed_before_break = self .uncheckpointed_before_break .copy (),
67
70
artificial_errors = self .artificial_errors .copy (),
68
71
nodes_needing_checkpoints = self .nodes_needing_checkpoints .copy (),
72
+ possibly_redundant_lowlevel_checkpoints = self .possibly_redundant_lowlevel_checkpoints .copy (),
69
73
)
70
74
71
75
@@ -214,6 +218,22 @@ def leave_Yield(
214
218
leave_Return = leave_Yield # type: ignore
215
219
216
220
221
+ # class RemoveLowlevelCheckpoints(cst.CSTTransformer):
222
+ # def __init__(self, stmts_to_remove: set[cst.Await]):
223
+ # self.stmts_to_remove = stmts_to_remove
224
+ #
225
+ # def leave_Await(self, original_node: cst.Await, updated_node: cst.Await) -> cst.Await:
226
+ # # return original node to preserve identity
227
+ # return original_node
228
+ #
229
+ # # for some reason you can't just return RemovalSentinel from Await, so we have to
230
+ # # visit the possible wrappers and modify their bodies instead
231
+ #
232
+ # def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
233
+ # new_body = [stmt for stmt in updated_node.body.body if stmt not in self.stmts_to_remove]
234
+ # return updated_node.with_changes(body=updated_node.body.with_changes(body=new_body))
235
+
236
+
217
237
@error_class_cst
218
238
@disabled_by_default
219
239
class Visitor91X (Flake8TrioVisitor_cst , CommonVisitors ):
@@ -226,16 +246,27 @@ class Visitor91X(Flake8TrioVisitor_cst, CommonVisitors):
226
246
"{0} from async iterable with no guaranteed checkpoint since {1.name} "
227
247
"on line {1.lineno}."
228
248
),
249
+ "TRIO912" : "Redundant checkpoint with no effect on program execution." ,
229
250
}
230
251
231
252
def __init__ (self , * args : Any , ** kwargs : Any ):
232
253
super ().__init__ (* args , ** kwargs )
233
254
self .has_yield = False
234
255
self .safe_decorator = False
235
256
self .async_function = False
236
- self .uncheckpointed_statements : set [Statement ] = set ()
237
257
self .comp_unknown = False
238
258
259
+ self .uncheckpointed_statements : set [Statement ] = set ()
260
+ self .checkpointed_by_lowlevel = False
261
+
262
+ # value == False, not redundant (or not determined to be redundant yet)
263
+ # value == True, there were no uncheckpointed statements when we encountered it
264
+ # value = expr/stmt, made redundant by the given expr/stmt
265
+ self .lowlevel_checkpoints : dict [
266
+ cst .Await , cst .BaseStatement | cst .BaseExpression | bool
267
+ ] = {}
268
+ self .lowlevel_checkpoint_updated_nodes : dict [cst .Await , cst .Await ] = {}
269
+
239
270
self .loop_state = LoopState ()
240
271
self .try_state = TryState ()
241
272
@@ -258,6 +289,7 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
258
289
"safe_decorator" ,
259
290
"async_function" ,
260
291
"uncheckpointed_statements" ,
292
+ "lowlevel_checkpoints" ,
261
293
"loop_state" ,
262
294
"try_state" ,
263
295
copy = True ,
@@ -299,8 +331,31 @@ def leave_FunctionDef(
299
331
indentedblock = updated_node .body .with_changes (body = new_body )
300
332
updated_node = updated_node .with_changes (body = indentedblock )
301
333
334
+ res : cst .FunctionDef = updated_node
335
+ to_remove : set [cst .Await ] = set ()
336
+ for expr , value in self .lowlevel_checkpoints .items ():
337
+ if value != False :
338
+ self .error (expr , error_code = "TRIO912" )
339
+ if self .should_autofix ():
340
+ to_remove .add (self .lowlevel_checkpoint_updated_nodes .pop (expr ))
341
+
342
+ if to_remove :
343
+ new_body = []
344
+ for stmt in updated_node .body .body :
345
+ if not m .matches (
346
+ stmt ,
347
+ m .SimpleStatementLine (
348
+ [m .Expr (m .MatchIfTrue (lambda x : x in to_remove ))]
349
+ ),
350
+ ):
351
+ new_body .append (stmt ) # type: ignore
352
+ assert new_body != updated_node .body .body
353
+ res = updated_node .with_changes (
354
+ body = updated_node .body .with_changes (body = new_body )
355
+ )
356
+
302
357
self .restore_state (original_node )
303
- return updated_node # noqa: R504
358
+ return res
304
359
305
360
# error if function exit/return/yields with uncheckpointed statements
306
361
# returns a bool indicating if any real (i.e. not artificial) errors were raised
@@ -372,12 +427,48 @@ def error_91x(
372
427
error_code = "TRIO911" if self .has_yield else "TRIO910" ,
373
428
)
374
429
430
+ def is_lowlevel_checkpoint (self , node : cst .BaseExpression ) -> bool :
431
+ # TODO: match against both libraries if both are imported
432
+ return m .matches (
433
+ node ,
434
+ m .Call (
435
+ m .Attribute (
436
+ m .Attribute (m .Name (self .library [0 ]), m .Name ("lowlevel" )),
437
+ m .Name ("checkpoint" ),
438
+ )
439
+ ),
440
+ )
441
+
442
+ def visit_Await (self , node : cst .Await ) -> None :
443
+ # do a match against the awaited expr
444
+ # if that is trio.lowlevel.checkpoint, and uncheckpointed statements
445
+ # are empty, raise TRIO912.
446
+ if self .is_lowlevel_checkpoint (node .expression ):
447
+ if not self .uncheckpointed_statements :
448
+ self .lowlevel_checkpoints [node ] = True
449
+ elif self .uncheckpointed_statements == {ARTIFICIAL_STATEMENT }:
450
+ self .loop_state .possibly_redundant_lowlevel_checkpoints .append (node )
451
+ else :
452
+ self .lowlevel_checkpoints [node ] = False
453
+ # if trio.lowlevel.checkpoint and *not* empty, take note of it in a special list.
454
+ elif not self .uncheckpointed_statements :
455
+ for expr , value in self .lowlevel_checkpoints .items ():
456
+ if value == False :
457
+ self .lowlevel_checkpoints [expr ] = node
458
+
459
+ # if this is not a trio.lowlevel.checkpoint, and there are no uncheckpointed statements, check if there is a lowlevel checkpoint in the special list. If so, raise a TRIO912 for it and remove it.
460
+
375
461
def leave_Await (
376
462
self , original_node : cst .Await , updated_node : cst .Await
377
463
) -> cst .Await :
378
464
# the expression being awaited is not checkpointed
379
465
# so only set checkpoint after the await node
380
466
467
+ # TODO: dirty hack to get identity right, the logic in visit should maybe be
468
+ # moved/split into the leave
469
+ if original_node in self .lowlevel_checkpoints :
470
+ self .lowlevel_checkpoint_updated_nodes [original_node ] = updated_node
471
+
381
472
# all nodes are now checkpointed
382
473
self .uncheckpointed_statements = set ()
383
474
return updated_node
@@ -494,6 +585,10 @@ def leave_Try(self, original_node: cst.Try, updated_node: cst.Try) -> cst.Try:
494
585
self .restore_state (original_node )
495
586
return updated_node
496
587
588
+ # if a previous lowlevel checkpoint is marked as redundant after all bodies, then
589
+ # it's redundant.
590
+ # If any body marks it as necessary, then it's necessary.
591
+ # Otherwise, it keeps it's state from before.
497
592
def leave_If_test (self , node : cst .If | cst .IfExp ) -> None :
498
593
if not self .async_function :
499
594
return
@@ -604,6 +699,11 @@ def leave_While_body(self, node: cst.For | cst.While):
604
699
if not any_error :
605
700
self .loop_state .nodes_needing_checkpoints = []
606
701
702
+ # but lowlevel checkpoints are redundant
703
+ for expr in self .loop_state .possibly_redundant_lowlevel_checkpoints :
704
+ self .error (expr , error_code = "TRIO912" )
705
+ # self.possibly_redundant_lowlevel_checkpoints.clear()
706
+
607
707
# replace artificial statements in else with prebody uncheckpointed statements
608
708
# non-artificial stmts before continue/break/at body end will already be in them
609
709
for stmts in (
@@ -654,6 +754,12 @@ def leave_While_orelse(self, node: cst.For | cst.While):
654
754
# reset break & continue in case of nested loops
655
755
self .outer [node ]["uncheckpointed_statements" ] = self .uncheckpointed_statements
656
756
757
+ # TODO: if this loop always checkpoints
758
+ # e.g. from being an async for, or being guaranteed to run once, or other stuff.
759
+ # then we can warn about redundant checkpoints before the loop.
760
+ # ... except if the reason we always checkpoint is due to redundant checkpoints
761
+ # we're about to remove.... :thinking:
762
+
657
763
leave_For_orelse = leave_While_orelse
658
764
659
765
def leave_While (
0 commit comments