Skip to content

Commit 5e44a59

Browse files
committed
WIP TRIO912 unnecessary checkpoints
1 parent a29608c commit 5e44a59

File tree

4 files changed

+1010
-2
lines changed

4 files changed

+1010
-2
lines changed

flake8_trio/visitors/visitor91x.py

Lines changed: 108 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@ class LoopState:
5656
nodes_needing_checkpoints: list[cst.Return | cst.Yield] = field(
5757
default_factory=list
5858
)
59+
possibly_redundant_lowlevel_checkpoints: list[cst.BaseExpression] = field(
60+
default_factory=list
61+
)
5962

6063
def copy(self):
6164
return LoopState(
@@ -66,6 +69,7 @@ def copy(self):
6669
uncheckpointed_before_break=self.uncheckpointed_before_break.copy(),
6770
artificial_errors=self.artificial_errors.copy(),
6871
nodes_needing_checkpoints=self.nodes_needing_checkpoints.copy(),
72+
possibly_redundant_lowlevel_checkpoints=self.possibly_redundant_lowlevel_checkpoints.copy(),
6973
)
7074

7175

@@ -214,6 +218,22 @@ def leave_Yield(
214218
leave_Return = leave_Yield # type: ignore
215219

216220

221+
# class RemoveLowlevelCheckpoints(cst.CSTTransformer):
222+
# def __init__(self, stmts_to_remove: set[cst.Await]):
223+
# self.stmts_to_remove = stmts_to_remove
224+
#
225+
# def leave_Await(self, original_node: cst.Await, updated_node: cst.Await) -> cst.Await:
226+
# # return original node to preserve identity
227+
# return original_node
228+
#
229+
# # for some reason you can't just return RemovalSentinel from Await, so we have to
230+
# # visit the possible wrappers and modify their bodies instead
231+
#
232+
# def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
233+
# new_body = [stmt for stmt in updated_node.body.body if stmt not in self.stmts_to_remove]
234+
# return updated_node.with_changes(body=updated_node.body.with_changes(body=new_body))
235+
236+
217237
@error_class_cst
218238
@disabled_by_default
219239
class Visitor91X(Flake8TrioVisitor_cst, CommonVisitors):
@@ -226,16 +246,27 @@ class Visitor91X(Flake8TrioVisitor_cst, CommonVisitors):
226246
"{0} from async iterable with no guaranteed checkpoint since {1.name} "
227247
"on line {1.lineno}."
228248
),
249+
"TRIO912": "Redundant checkpoint with no effect on program execution.",
229250
}
230251

231252
def __init__(self, *args: Any, **kwargs: Any):
232253
super().__init__(*args, **kwargs)
233254
self.has_yield = False
234255
self.safe_decorator = False
235256
self.async_function = False
236-
self.uncheckpointed_statements: set[Statement] = set()
237257
self.comp_unknown = False
238258

259+
self.uncheckpointed_statements: set[Statement] = set()
260+
self.checkpointed_by_lowlevel = False
261+
262+
# value == False, not redundant (or not determined to be redundant yet)
263+
# value == True, there were no uncheckpointed statements when we encountered it
264+
# value = expr/stmt, made redundant by the given expr/stmt
265+
self.lowlevel_checkpoints: dict[
266+
cst.Await, cst.BaseStatement | cst.BaseExpression | bool
267+
] = {}
268+
self.lowlevel_checkpoint_updated_nodes: dict[cst.Await, cst.Await] = {}
269+
239270
self.loop_state = LoopState()
240271
self.try_state = TryState()
241272

@@ -258,6 +289,7 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
258289
"safe_decorator",
259290
"async_function",
260291
"uncheckpointed_statements",
292+
"lowlevel_checkpoints",
261293
"loop_state",
262294
"try_state",
263295
copy=True,
@@ -299,8 +331,31 @@ def leave_FunctionDef(
299331
indentedblock = updated_node.body.with_changes(body=new_body)
300332
updated_node = updated_node.with_changes(body=indentedblock)
301333

334+
res: cst.FunctionDef = updated_node
335+
to_remove: set[cst.Await] = set()
336+
for expr, value in self.lowlevel_checkpoints.items():
337+
if value != False:
338+
self.error(expr, error_code="TRIO912")
339+
if self.should_autofix():
340+
to_remove.add(self.lowlevel_checkpoint_updated_nodes.pop(expr))
341+
342+
if to_remove:
343+
new_body = []
344+
for stmt in updated_node.body.body:
345+
if not m.matches(
346+
stmt,
347+
m.SimpleStatementLine(
348+
[m.Expr(m.MatchIfTrue(lambda x: x in to_remove))]
349+
),
350+
):
351+
new_body.append(stmt) # type: ignore
352+
assert new_body != updated_node.body.body
353+
res = updated_node.with_changes(
354+
body=updated_node.body.with_changes(body=new_body)
355+
)
356+
302357
self.restore_state(original_node)
303-
return updated_node # noqa: R504
358+
return res
304359

305360
# error if function exit/return/yields with uncheckpointed statements
306361
# returns a bool indicating if any real (i.e. not artificial) errors were raised
@@ -372,12 +427,48 @@ def error_91x(
372427
error_code="TRIO911" if self.has_yield else "TRIO910",
373428
)
374429

430+
def is_lowlevel_checkpoint(self, node: cst.BaseExpression) -> bool:
431+
# TODO: match against both libraries if both are imported
432+
return m.matches(
433+
node,
434+
m.Call(
435+
m.Attribute(
436+
m.Attribute(m.Name(self.library[0]), m.Name("lowlevel")),
437+
m.Name("checkpoint"),
438+
)
439+
),
440+
)
441+
442+
def visit_Await(self, node: cst.Await) -> None:
443+
# do a match against the awaited expr
444+
# if that is trio.lowlevel.checkpoint, and uncheckpointed statements
445+
# are empty, raise TRIO912.
446+
if self.is_lowlevel_checkpoint(node.expression):
447+
if not self.uncheckpointed_statements:
448+
self.lowlevel_checkpoints[node] = True
449+
elif self.uncheckpointed_statements == {ARTIFICIAL_STATEMENT}:
450+
self.loop_state.possibly_redundant_lowlevel_checkpoints.append(node)
451+
else:
452+
self.lowlevel_checkpoints[node] = False
453+
# if trio.lowlevel.checkpoint and *not* empty, take note of it in a special list.
454+
elif not self.uncheckpointed_statements:
455+
for expr, value in self.lowlevel_checkpoints.items():
456+
if value == False:
457+
self.lowlevel_checkpoints[expr] = node
458+
459+
# if this is not a trio.lowlevel.checkpoint, and there are no uncheckpointed statements, check if there is a lowlevel checkpoint in the special list. If so, raise a TRIO912 for it and remove it.
460+
375461
def leave_Await(
376462
self, original_node: cst.Await, updated_node: cst.Await
377463
) -> cst.Await:
378464
# the expression being awaited is not checkpointed
379465
# so only set checkpoint after the await node
380466

467+
# TODO: dirty hack to get identity right, the logic in visit should maybe be
468+
# moved/split into the leave
469+
if original_node in self.lowlevel_checkpoints:
470+
self.lowlevel_checkpoint_updated_nodes[original_node] = updated_node
471+
381472
# all nodes are now checkpointed
382473
self.uncheckpointed_statements = set()
383474
return updated_node
@@ -494,6 +585,10 @@ def leave_Try(self, original_node: cst.Try, updated_node: cst.Try) -> cst.Try:
494585
self.restore_state(original_node)
495586
return updated_node
496587

588+
# if a previous lowlevel checkpoint is marked as redundant after all bodies, then
589+
# it's redundant.
590+
# If any body marks it as necessary, then it's necessary.
591+
# Otherwise, it keeps it's state from before.
497592
def leave_If_test(self, node: cst.If | cst.IfExp) -> None:
498593
if not self.async_function:
499594
return
@@ -604,6 +699,11 @@ def leave_While_body(self, node: cst.For | cst.While):
604699
if not any_error:
605700
self.loop_state.nodes_needing_checkpoints = []
606701

702+
# but lowlevel checkpoints are redundant
703+
for expr in self.loop_state.possibly_redundant_lowlevel_checkpoints:
704+
self.error(expr, error_code="TRIO912")
705+
# self.possibly_redundant_lowlevel_checkpoints.clear()
706+
607707
# replace artificial statements in else with prebody uncheckpointed statements
608708
# non-artificial stmts before continue/break/at body end will already be in them
609709
for stmts in (
@@ -654,6 +754,12 @@ def leave_While_orelse(self, node: cst.For | cst.While):
654754
# reset break & continue in case of nested loops
655755
self.outer[node]["uncheckpointed_statements"] = self.uncheckpointed_statements
656756

757+
# TODO: if this loop always checkpoints
758+
# e.g. from being an async for, or being guaranteed to run once, or other stuff.
759+
# then we can warn about redundant checkpoints before the loop.
760+
# ... except if the reason we always checkpoint is due to redundant checkpoints
761+
# we're about to remove.... :thinking:
762+
657763
leave_For_orelse = leave_While_orelse
658764

659765
def leave_While(

0 commit comments

Comments
 (0)