Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions src/underworld3/function/pure_sympy_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,19 @@ def _expr_hash(expr):
str
Hash string
"""
# sympy.Dummy carries a volatile, globally-unique dummy_index that
# srepr embeds. The evaluate() coordinate-substitution path mints a
# fresh Dummy per call, so an otherwise-identical expression would
# hash differently every call and never hit the cache (issue #194).
# Canonicalise dummies to name-stable Symbols before srepr. This only
# affects the cache *key*; the real sympy.lambdify() call still uses
# the original expr/symbols, so numerics are unchanged. The cache key
# also separately carries the symbol-name tuple, so name-keying here
# is safe and deterministic.
dummies = expr.atoms(sympy.Dummy)
if dummies:
expr = expr.xreplace({d: sympy.Symbol(d.name) for d in dummies})

# Use sympy's srepr for consistent string representation
expr_str = sympy.srepr(expr)
return hashlib.md5(expr_str.encode()).hexdigest()
Expand Down
120 changes: 81 additions & 39 deletions tests/test_0720_lambdify_optimization_paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,51 +317,93 @@ def test_detection_uw3_variable(self, setup_mesh):


class TestPerformanceExpectations:
"""Test that performance is as expected (not exact timing, just sanity checks)."""

def test_lambdify_caching(self, setup_mesh, sample_points):
"""Cached lambdified evaluations should be fast."""
import time

x = setup_mesh.X[0]
expr = sympy.erf(5 * x - 2) / 2

# First call (compilation)
start = time.time()
result1 = uw.function.evaluate(expr, sample_points, rbf=True)
time1 = time.time() - start

# Cached call
start = time.time()
result2 = uw.function.evaluate(expr, sample_points, rbf=True)
time2 = time.time() - start

# Cached should be faster or at least not slower
# (For small evaluations, overhead dominates so speedup may be small)
assert time2 <= time1 * 2, f"Cached call slower than first: {time2} vs {time1}"

# Both should be reasonably fast (< 10ms for 3 points)
assert time2 < 0.01, f"Cached call too slow: {time2}s"

# Results should be identical
assert np.allclose(result1.flatten(), result2.flatten())

def test_rbf_false_not_slow(self, setup_mesh, sample_points):
"""rbf=False should not be dramatically slower for pure sympy."""
import time
"""Verify the lambdify cache *behaviour* (cache hit on repeat), not
wall-clock timing.

One-off wall-clock assertions are inherently flaky on shared CI
runners. The property we actually care about is behavioural: an
identical lambdify request must reuse the cached function object
rather than recompiling. These tests exercise the cache contract
(``get_cached_lambdified``) directly, so they are deterministic.
"""

def test_lambdify_cache_hit(self):
"""``get_cached_lambdified`` returns the *same* object on a repeat.

Validates the cache mechanism's contract directly:
- first request for an expression compiles and stores one entry;
- an identical request returns the very same function object and
adds no entry (a genuine cache hit);
- a structurally different expression is cached separately
(the key actually discriminates).
"""
from underworld3.function.pure_sympy_evaluator import (
_lambdify_cache,
clear_lambdify_cache,
get_cached_lambdified,
)

a, b = sympy.symbols("a b")
expr = sympy.erf(5 * a - 2) / 2 + sympy.sin(b)
symbols = (a, b)

clear_lambdify_cache()
assert len(_lambdify_cache) == 0

# Cache miss: compiles and stores exactly one entry.
f1 = get_cached_lambdified(expr, symbols)
assert len(_lambdify_cache) == 1, "first call did not populate the cache"

# Cache hit: identical request -> same object, no new entry.
f2 = get_cached_lambdified(expr, symbols)
assert f2 is f1, "identical request did not return the cached function"
assert len(_lambdify_cache) == 1, "cache hit must not add an entry"

# The cached function is correct.
val = f1(0.4, 0.7)
assert np.isclose(
val, float(sympy.erf(5 * 0.4 - 2) / 2 + sympy.sin(0.7))
)

# A different expression is a distinct cache entry (key discriminates).
other = sympy.erf(5 * a - 2) / 2 + sympy.cos(b)
g = get_cached_lambdified(other, symbols)
assert g is not f1
assert len(_lambdify_cache) == 2, "distinct expression must be cached separately"

def test_evaluate_cache_stable_across_calls(self, setup_mesh, sample_points):
"""Regression guard for #194 (aggregate cache behaviour, no timing).

Repeated identical ``uw.function.evaluate`` calls must reach a
bounded steady state in ``_lambdify_cache`` -- not grow one entry
per call. Before the #194 fix the cache grew [1,2,3,4,...] because
a fresh ``sympy.Dummy`` (volatile ``dummy_index`` in ``srepr``)
was minted every call, so the cache key never matched. Also
asserts results are identical across calls (no behavioural change
from the cache-key canonicalisation).
"""
from underworld3.function.pure_sympy_evaluator import (
_lambdify_cache,
clear_lambdify_cache,
)

x = setup_mesh.X[0]
expr = sympy.erf(5 * x - 2) / 2

# Warm up cache
_ = uw.function.evaluate(expr, sample_points, rbf=False)
clear_lambdify_cache()
ref = uw.function.evaluate(expr, sample_points, rbf=True)
size_after_first = len(_lambdify_cache)
assert size_after_first >= 1, "first call did not populate the cache"

# Cached call with rbf=False should be fast (< 10ms for 3 points)
start = time.time()
result = uw.function.evaluate(expr, sample_points, rbf=False)
elapsed = time.time() - start
for _ in range(8):
r = uw.function.evaluate(expr, sample_points, rbf=True)
assert np.allclose(r.flatten(), ref.flatten())

assert elapsed < 0.01, f"rbf=False too slow: {elapsed}s (should be < 10ms)"
# Steady state: no growth after the first compile (cache hits).
assert len(_lambdify_cache) == size_after_first, (
f"lambdify cache grew {size_after_first} -> {len(_lambdify_cache)} "
f"across identical evaluate() calls -- regression of #194"
)


if __name__ == "__main__":
Expand Down
Loading