Skip to content

Commit 3e5b8cc

Browse files
authored
added memoization of evaluatable objects
1 parent adf5fce commit 3e5b8cc

File tree

2 files changed

+69
-20
lines changed

2 files changed

+69
-20
lines changed

check50/assertions/rewrite.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def visit_Assert(self, node):
5454
self.generic_visit(node)
5555
cond_type = self._identify_comparison_type(node.test)
5656

57+
# Begin adding a named parameter that determines the type of condition
5758
keywords = [ast.keyword(arg="cond_type", value=ast.Constant(value=cond_type))]
5859

5960
# Extract variable names and build context={"var": var, ...}
@@ -75,20 +76,26 @@ def visit_Assert(self, node):
7576
left_str = ast.unparse(left_node)
7677
right_str = ast.unparse(right_node)
7778

79+
# Only add to context if not literal constants
80+
if not isinstance(left_node, ast.Constant):
81+
context_dict.keys.append(ast.Constant(value=left_str))
82+
context_dict.values.append(ast.Constant(value=None))
83+
if not isinstance(right_node, ast.Constant):
84+
context_dict.keys.append(ast.Constant(value=right_str))
85+
context_dict.values.append(ast.Constant(value=None))
86+
87+
7888
keywords.extend([
7989
ast.keyword(arg="left", value=ast.Constant(value=left_str)),
8090
ast.keyword(arg="right", value=ast.Constant(value=right_str))
8191
])
8292

83-
8493
return ast.Expr(
8594
value=ast.Call(
8695
# Create a function called check50_assert
8796
func=ast.Name(id="check50_assert", ctx=ast.Load()),
8897
# Give it these postional arguments:
8998
args=[
90-
# The condition
91-
node.test,
9299
# The string form of the condition
93100
ast.Constant(value=ast.unparse(node.test)),
94101
# The additional msg or exception that the user provided
@@ -150,7 +157,7 @@ def visit_Call(self, node):
150157
# As we travel down the function's subtree, denote this flag as True
151158
self._in_func_chain = True
152159
self.visit(node.func)
153-
self._in_func_chain = already_in_chain # Restore state
160+
self._in_func_chain = already_in_chain # Restore previous state
154161

155162
# Now visit the arguments of this function
156163
for arg in node.args:
@@ -159,8 +166,9 @@ def visit_Call(self, node):
159166
self.visit(kw)
160167

161168
def visit_Name(self, node):
162-
if not self._in_func_chain: # ignore Names of modules
169+
if not self._in_func_chain: # ignore Names of modules/libraries
163170
self.names.add(node.id)
171+
# self.names.add(node.id)
164172

165173
def _get_full_func_name(self, node):
166174
"""

check50/assertions/runtime.py

Lines changed: 56 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from check50 import Failure, Missing, Mismatch
22

3-
def check50_assert(cond, src, msg_or_exc=None, cond_type="unknown", left=None, right=None, context=None):
3+
def check50_assert(src, msg_or_exc=None, cond_type="unknown", left=None, right=None, context=None):
44
"""
55
Asserts a conditional statement. If the condition evaluates to True,
66
nothing happens. Otherwise, it will look for a message or exception that
@@ -29,8 +29,6 @@ def check50_assert(cond, src, msg_or_exc=None, cond_type="unknown", left=None, r
2929
check50_assert(x in y, "x in y", None, "in", x, y)
3030
```
3131
32-
:param cond: The evaluated conditional statement.
33-
:type cond: bool
3432
:param src: The source code string of the conditional expression \
3533
(e.g., 'x in y'), extracted from the AST.
3634
:type src: str
@@ -52,39 +50,82 @@ def check50_assert(cond, src, msg_or_exc=None, cond_type="unknown", left=None, r
5250
:raises check50.Failure: If msg_or_exc is a string, or if cond_type is \
5351
unrecognized.
5452
"""
55-
if cond:
56-
return
57-
53+
# Evaluate all variables and functions within the context dict and generate
54+
# a string of these values
5855
context_str = None
5956
if context or (left and right):
60-
# Add `left` and `right` to `context` so that they can be evaluated in
61-
# the same pass as the other variables
62-
if left and right:
63-
context[left] = None
64-
context[right] = None
65-
# Evaluate context
6657
import inspect
6758
for expr_str in context:
6859
try:
60+
# Grab the global and local variables as of now
6961
caller_frame = inspect.currentframe().f_back
7062
context[expr_str] = eval(expr_str, caller_frame.f_globals, caller_frame.f_locals)
7163
except Exception as e:
7264
context[expr_str] = f"[error evaluating: {e}]"
7365

66+
# produces a string like "var1 = ..., var2 = ..., foo() = ..."
7467
context_str = ", ".join(f"{k} = {repr(v)}" for k, v in (context or {}).items())
7568

69+
# Since we've memoized the functions and variables once, now try and
70+
# evaluate the conditional by substituting the function calls/vars with
71+
# their results
72+
eval_src, eval_context = substitute_expressions(src, context)
73+
cond = eval(eval_src, {}, eval_context)
74+
75+
# Finally, quit if the condition evaluated to True.
76+
if cond:
77+
return
78+
79+
# If `right` or `left` were evaluatable objects, their actual value will be stored in `context`.
80+
# Otherwise, they're still just literals.
81+
right = context.get(right) or right
82+
left = context.get(left) or left
83+
84+
# Since the condition didn't evaluate to True, now, we can raise special
85+
# exceptions.
7686
if isinstance(msg_or_exc, str):
7787
raise Failure(msg_or_exc)
7888
elif isinstance(msg_or_exc, BaseException):
7989
raise msg_or_exc
8090
elif cond_type == 'eq' and left and right:
8191
help_msg = f"checked: {src}"
8292
help_msg += f"\n where {context_str}" if context_str else ""
83-
raise Mismatch(context[right], context[left], help=help_msg)
93+
raise Mismatch(right, left, help=help_msg)
8494
elif cond_type == 'in' and left and right:
8595
help_msg = f"checked: {src}"
8696
help_msg += f"\n where {context_str}" if context_str else ""
87-
raise Missing(context[left], context[right], help=help_msg)
97+
raise Missing(left, right, help=help_msg)
8898
else:
8999
help_msg = f"\n where {context_str}" if context_str else ""
90-
raise Failure(f"check did not pass: {src}" + help_msg)
100+
raise Failure(f"check did not pass: {src} {context}" + help_msg)
101+
102+
def substitute_expressions(src: str, context: dict) -> tuple[str, dict]:
103+
"""
104+
Rewrites `src` by replacing each key in `context` with a placeholder variable name,
105+
and builds a new context dict where those names map to pre-evaluated values.
106+
107+
For instance, given a `src`:
108+
```
109+
check50.run('pwd').stdout() == actual
110+
```
111+
it will create a new `eval_src` as
112+
```
113+
__expr0 == __expr1
114+
```
115+
and use the given context to define these variables:
116+
```
117+
eval_context = {
118+
'__expr0': context['check50.run('pwd').stdout()'],
119+
'__expr1': context['actual']
120+
}
121+
```
122+
"""
123+
new_src = src
124+
new_context = {}
125+
126+
for i, expr in enumerate(sorted(context.keys(), key=len, reverse=True)):
127+
placeholder = f"__expr{i}"
128+
new_src = new_src.replace(expr, placeholder)
129+
new_context[placeholder] = context[expr]
130+
131+
return new_src, new_context

0 commit comments

Comments
 (0)