From af3b475b31f3ad5a2751ff3249c837700288d616 Mon Sep 17 00:00:00 2001 From: ZoeLeibowitz Date: Tue, 8 Apr 2025 17:14:46 +0100 Subject: [PATCH 1/7] compiler: Fix printer for unevaluation Mul for issue 2577 --- devito/ir/cgen/printer.py | 2 +- tests/test_symbolics.py | 10 ++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/devito/ir/cgen/printer.py b/devito/ir/cgen/printer.py index cf6eee1a7c..ee1e36fb1a 100644 --- a/devito/ir/cgen/printer.py +++ b/devito/ir/cgen/printer.py @@ -223,7 +223,7 @@ def _print_Mod(self, expr): return '%'.join(args) def _print_Mul(self, expr): - args = [a for a in expr.args if a != -1] + args = [a for a in expr.args if (a != -1. or a != -1)] neg = (len(expr.args) - len(args)) % 2 if len(args) > 1: diff --git a/tests/test_symbolics.py b/tests/test_symbolics.py index 42728ec13d..a21d7486f8 100644 --- a/tests/test_symbolics.py +++ b/tests/test_symbolics.py @@ -15,6 +15,7 @@ INT, FieldFromComposite, IntDiv, Namespace, Rvalue, ReservedWord, ListInitializer, uxreplace, pow_to_mul, retrieve_derivatives, BaseCast) +from devito.symbolics.unevaluation import Mul from devito.tools import as_tuple from devito.types import (Array, Bundle, FIndexed, LocalObject, Object, ComponentAccess, StencilDimension, Symbol as dSymbol) @@ -874,3 +875,12 @@ def test_assumptions(self, op, expr, assumptions, expected): assumptions = eval(assumptions) expected = eval(expected) assert evalrel(op, eqn, assumptions) == expected + + +def test_issue_2577(): + + u = TimeFunction(name='u', grid=Grid((2,))) + eq = Eq(u.forward, Mul(-1, -1., u)) + op = Operator(eq) + + assert '--' not in str(op.ccode) From b9499538cb58de5b2f8abc758c19cb711828dd2d Mon Sep 17 00:00:00 2001 From: Edward Caunt Date: Wed, 9 Apr 2025 09:54:04 +0100 Subject: [PATCH 2/7] tests: Add additional test --- tests/test_symbolics.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/tests/test_symbolics.py b/tests/test_symbolics.py index a21d7486f8..56b38d478d 100644 --- a/tests/test_symbolics.py +++ b/tests/test_symbolics.py @@ -7,7 +7,7 @@ from sympy import Expr, Number, Symbol from devito import (Constant, Dimension, Grid, Function, solve, TimeFunction, Eq, # noqa Operator, SubDimension, norm, Le, Ge, Gt, Lt, Abs, sin, cos, - Min, Max) + Min, Max, SubDomain) from devito.finite_differences.differentiable import SafeInv, Weights from devito.ir import Expression, FindNodes, ccode from devito.symbolics import (retrieve_functions, retrieve_indexed, evalrel, # noqa @@ -884,3 +884,23 @@ def test_issue_2577(): op = Operator(eq) assert '--' not in str(op.ccode) + + +def test_issue_2577a(): + class SD0(SubDomain): + name = 'sd0' + + def define(self, dimensions): + x, = dimensions + return {x: ('middle', 1, 1)} + + grid = Grid(shape=(11,)) + + sd0 = SD0(grid=grid) + + u = Function(name='u', grid=grid, space_order=2) + + eq_u = Eq(u, -(u*u).dxc, subdomain=sd0) + + op = Operator(eq_u) + assert '--' not in str(op.ccode) From be099e86b7e1774c08f546bf2e0b954a3c6b57ba Mon Sep 17 00:00:00 2001 From: Edward Caunt Date: Wed, 9 Apr 2025 09:56:17 +0100 Subject: [PATCH 3/7] tests: Replace Mul with UnevalMul --- tests/test_symbolics.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_symbolics.py b/tests/test_symbolics.py index 56b38d478d..f81f3fda53 100644 --- a/tests/test_symbolics.py +++ b/tests/test_symbolics.py @@ -15,7 +15,7 @@ INT, FieldFromComposite, IntDiv, Namespace, Rvalue, ReservedWord, ListInitializer, uxreplace, pow_to_mul, retrieve_derivatives, BaseCast) -from devito.symbolics.unevaluation import Mul +from devito.symbolics.unevaluation import Mul as UnevalMul from devito.tools import as_tuple from devito.types import (Array, Bundle, FIndexed, LocalObject, Object, ComponentAccess, StencilDimension, Symbol as dSymbol) @@ -880,7 +880,7 @@ def test_assumptions(self, op, expr, assumptions, expected): def test_issue_2577(): u = TimeFunction(name='u', grid=Grid((2,))) - eq = Eq(u.forward, Mul(-1, -1., u)) + eq = Eq(u.forward, UnevalMul(-1, -1., u)) op = Operator(eq) assert '--' not in str(op.ccode) From 0d6259c40555ef07e1cce5838c5891f8e44f2796 Mon Sep 17 00:00:00 2001 From: mloubout Date: Wed, 9 Apr 2025 08:19:07 -0400 Subject: [PATCH 4/7] api: fix Mul arguments processing --- devito/finite_differences/differentiable.py | 24 ++++++++++++--------- tests/test_symbolics.py | 12 ++++++----- 2 files changed, 21 insertions(+), 15 deletions(-) diff --git a/devito/finite_differences/differentiable.py b/devito/finite_differences/differentiable.py index 39458424d4..9a46f02573 100644 --- a/devito/finite_differences/differentiable.py +++ b/devito/finite_differences/differentiable.py @@ -548,20 +548,24 @@ def __new__(cls, *args, **kwargs): nested, others = split(args, lambda e: isinstance(e, Mul)) args = flatten(e.args for e in nested) + list(others) + # Gather all numbers and simplify + nums, others = split(args, lambda e: isinstance(e, (int, float, + sympy.Number, np.number))) + scalar = sympy.Mul(*nums) + try: + scalar = sympy.Integer(scalar) + except TypeError: + pass + # a*0 -> 0 - if any(i == 0 for i in args): + if scalar == 0: return sympy.S.Zero # a*1 -> a - args = [i for i in args if i != 1] - - # a*-1 -> a*-1 - # a*-1*-1 -> a - # a*-1*-1*-1 -> a*-1 - nminus = len([i for i in args if i == sympy.S.NegativeOne]) - args = [i for i in args if i != sympy.S.NegativeOne] - if nminus % 2 == 1: - args.append(sympy.S.NegativeOne) + if scalar == 1: + args = others + else: + args = [scalar] + others # Reorder for homogeneity with pure SymPy types _mulsort(args) diff --git a/tests/test_symbolics.py b/tests/test_symbolics.py index f81f3fda53..b463c653a1 100644 --- a/tests/test_symbolics.py +++ b/tests/test_symbolics.py @@ -8,14 +8,13 @@ from devito import (Constant, Dimension, Grid, Function, solve, TimeFunction, Eq, # noqa Operator, SubDimension, norm, Le, Ge, Gt, Lt, Abs, sin, cos, Min, Max, SubDomain) -from devito.finite_differences.differentiable import SafeInv, Weights +from devito.finite_differences.differentiable import SafeInv, Weights, Mul from devito.ir import Expression, FindNodes, ccode from devito.symbolics import (retrieve_functions, retrieve_indexed, evalrel, # noqa CallFromPointer, Cast, DefFunction, FieldFromPointer, INT, FieldFromComposite, IntDiv, Namespace, Rvalue, ReservedWord, ListInitializer, uxreplace, pow_to_mul, retrieve_derivatives, BaseCast) -from devito.symbolics.unevaluation import Mul as UnevalMul from devito.tools import as_tuple from devito.types import (Array, Bundle, FIndexed, LocalObject, Object, ComponentAccess, StencilDimension, Symbol as dSymbol) @@ -877,16 +876,19 @@ def test_assumptions(self, op, expr, assumptions, expected): assert evalrel(op, eqn, assumptions) == expected -def test_issue_2577(): +def test_issue_2577a(): u = TimeFunction(name='u', grid=Grid((2,))) - eq = Eq(u.forward, UnevalMul(-1, -1., u)) + x = u.grid.dimensions[0] + expr = Mul(-1, -1., x, u) + assert expr.args == (x, u) + eq = Eq(u.forward, expr) op = Operator(eq) assert '--' not in str(op.ccode) -def test_issue_2577a(): +def test_issue_2577b(): class SD0(SubDomain): name = 'sd0' From 17d68fe6ffafcf86138c1fd01942ba72b714b9ec Mon Sep 17 00:00:00 2001 From: ZoeLeibowitz Date: Wed, 9 Apr 2025 14:01:08 +0100 Subject: [PATCH 5/7] compiler: Revert printer hack --- devito/ir/cgen/printer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/devito/ir/cgen/printer.py b/devito/ir/cgen/printer.py index ee1e36fb1a..cf6eee1a7c 100644 --- a/devito/ir/cgen/printer.py +++ b/devito/ir/cgen/printer.py @@ -223,7 +223,7 @@ def _print_Mod(self, expr): return '%'.join(args) def _print_Mul(self, expr): - args = [a for a in expr.args if (a != -1. or a != -1)] + args = [a for a in expr.args if a != -1] neg = (len(expr.args) - len(args)) % 2 if len(args) > 1: From 24422fa2dd3b53bb369af4bc19a24b9be3056e0e Mon Sep 17 00:00:00 2001 From: mloubout Date: Wed, 9 Apr 2025 09:06:23 -0400 Subject: [PATCH 6/7] api: fix non-integer Mul args --- devito/finite_differences/differentiable.py | 6 +----- examples/seismic/tutorials/05_staggered_acoustic.ipynb | 2 +- tests/test_symbolic_coefficients.py | 4 ++-- 3 files changed, 4 insertions(+), 8 deletions(-) diff --git a/devito/finite_differences/differentiable.py b/devito/finite_differences/differentiable.py index 9a46f02573..28257fe6bc 100644 --- a/devito/finite_differences/differentiable.py +++ b/devito/finite_differences/differentiable.py @@ -552,17 +552,13 @@ def __new__(cls, *args, **kwargs): nums, others = split(args, lambda e: isinstance(e, (int, float, sympy.Number, np.number))) scalar = sympy.Mul(*nums) - try: - scalar = sympy.Integer(scalar) - except TypeError: - pass # a*0 -> 0 if scalar == 0: return sympy.S.Zero # a*1 -> a - if scalar == 1: + if scalar - 1 == 0: args = others else: args = [scalar] + others diff --git a/examples/seismic/tutorials/05_staggered_acoustic.ipynb b/examples/seismic/tutorials/05_staggered_acoustic.ipynb index 26250c955c..bfb211605a 100644 --- a/examples/seismic/tutorials/05_staggered_acoustic.ipynb +++ b/examples/seismic/tutorials/05_staggered_acoustic.ipynb @@ -130,7 +130,7 @@ "$\\displaystyle \\left[\\begin{matrix}v_x(t + dt, x + h_x/2, z)\\\\v_z(t + dt, x, z + h_z/2)\\end{matrix}\\right] = \\left[\\begin{matrix}dt \\left(1.0 \\frac{\\partial}{\\partial x} p(t, x, z) + \\frac{v_x(t, x + h_x/2, z)}{dt}\\right)\\\\dt \\left(1.0 \\frac{\\partial}{\\partial z} p(t, x, z) + \\frac{v_z(t, x, z + h_z/2)}{dt}\\right)\\end{matrix}\\right]$" ], "text/plain": [ - "Eq(Vector(v_x(t + dt, x + h_x/2, z), v_z(t + dt, x, z + h_z/2)), Vector(dt*(1.0*Derivative(p(t, x, z), x) + v_x(t, x + h_x/2, z)/dt), dt*(1.0*Derivative(p(t, x, z), z) + v_z(t, x, z + h_z/2)/dt)))" + "Eq(Vector(v_x(t + dt, x + h_x/2, z), v_z(t + dt, x, z + h_z/2)), Vector(dt*(Derivative(p(t, x, z), x) + v_x(t, x + h_x/2, z)/dt), dt*(Derivative(p(t, x, z), z) + v_z(t, x, z + h_z/2)/dt)))" ] }, "execution_count": 7, diff --git a/tests/test_symbolic_coefficients.py b/tests/test_symbolic_coefficients.py index 75bfa5e0ca..e3aa8b282f 100644 --- a/tests/test_symbolic_coefficients.py +++ b/tests/test_symbolic_coefficients.py @@ -202,8 +202,8 @@ def test_staggered_equation(self): eq_f = Eq(f, f.dx2(weights=weights)) - expected = 'Eq(f(x + h_x/2), 1.0*f(x - h_x/2)/h_x**2 - 2.0*f(x + h_x/2)/h_x**2 '\ - '+ 1.0*f(x + 3*h_x/2)/h_x**2)' + expected = 'Eq(f(x + h_x/2), f(x - h_x/2)/h_x**2 - 2.0*f(x + h_x/2)/h_x**2 '\ + '+ f(x + 3*h_x/2)/h_x**2)' assert(str(eq_f.evaluate) == expected) @pytest.mark.parametrize('stagger', [True, False]) From 624bea27b04f1772fb4ac8bac623eba4be15c065 Mon Sep 17 00:00:00 2001 From: mloubout Date: Wed, 16 Apr 2025 08:18:48 -0400 Subject: [PATCH 7/7] compiler: fix division printing --- .github/workflows/pytest-core-nompi.yml | 2 +- devito/finite_differences/differentiable.py | 5 ++--- devito/symbolics/extended_sympy.py | 4 ++++ devito/tools/utils.py | 9 ++++++++- tests/test_dse.py | 2 +- tests/test_symbolics.py | 10 ++++++++-- 6 files changed, 24 insertions(+), 8 deletions(-) diff --git a/.github/workflows/pytest-core-nompi.yml b/.github/workflows/pytest-core-nompi.yml index 352ade878c..f91c04ee41 100644 --- a/.github/workflows/pytest-core-nompi.yml +++ b/.github/workflows/pytest-core-nompi.yml @@ -172,7 +172,7 @@ jobs: - name: Test with pytest run: | - ${{ env.RUN_CMD }} pytest -k "${{ matrix.test-set }}" -m "not parallel" --cov --cov-config=.coveragerc --cov-report=xml ${{ env.TESTS }} + ${{ env.RUN_CMD }} pytest -k "${{ matrix.test-set }}" -m "not parallel" --cov --cov-config=.coveragerc --cov-report=xml tests/ - name: Upload coverage to Codecov if: "!contains(matrix.name, 'docker')" diff --git a/devito/finite_differences/differentiable.py b/devito/finite_differences/differentiable.py index 28257fe6bc..1879b14d16 100644 --- a/devito/finite_differences/differentiable.py +++ b/devito/finite_differences/differentiable.py @@ -17,7 +17,7 @@ from devito.finite_differences.tools import make_shift_x0, coeff_priority from devito.logger import warning from devito.tools import (as_tuple, filter_ordered, flatten, frozendict, - infer_dtype, is_integer, split) + infer_dtype, is_integer, split, is_number) from devito.types import Array, DimensionTuple, Evaluable, StencilDimension from devito.types.basic import AbstractFunction @@ -549,8 +549,7 @@ def __new__(cls, *args, **kwargs): args = flatten(e.args for e in nested) + list(others) # Gather all numbers and simplify - nums, others = split(args, lambda e: isinstance(e, (int, float, - sympy.Number, np.number))) + nums, others = split(args, lambda e: is_number(e)) scalar = sympy.Mul(*nums) # a*0 -> 0 diff --git a/devito/symbolics/extended_sympy.py b/devito/symbolics/extended_sympy.py index ef91ef4d49..8509de6d1e 100644 --- a/devito/symbolics/extended_sympy.py +++ b/devito/symbolics/extended_sympy.py @@ -622,6 +622,10 @@ def __new__(cls, name, arguments=None, template=None, **kwargs): return obj + def _eval_is_commutative(self): + # DefFunction defaults to commutative + return True + @property def name(self): return self._name diff --git a/devito/tools/utils.py b/devito/tools/utils.py index b99eddcf6c..0a28de16a8 100644 --- a/devito/tools/utils.py +++ b/devito/tools/utils.py @@ -12,7 +12,7 @@ 'roundm', 'powerset', 'invert', 'flatten', 'single_or', 'filter_ordered', 'as_mapper', 'filter_sorted', 'pprint', 'sweep', 'all_equal', 'as_list', 'indices_to_slices', 'indices_to_sections', 'transitive_closure', - 'humanbytes', 'contains_val', 'sorted_priority', 'as_set'] + 'humanbytes', 'contains_val', 'sorted_priority', 'as_set', 'is_number'] def prod(iterable, initial=1): @@ -82,6 +82,13 @@ def is_integer(value): return isinstance(value, (int, np.integer, sympy.Integer)) +def is_number(value): + """ + A thorough instance comparison for all number types. + """ + return isinstance(value, (int, float, np.number, sympy.Number)) + + def contains_val(val, items): try: return val in items diff --git a/tests/test_dse.py b/tests/test_dse.py index 06e00ab182..4f37488a86 100644 --- a/tests/test_dse.py +++ b/tests/test_dse.py @@ -88,7 +88,7 @@ def test_pow_to_mul(expr, expected): @pytest.mark.parametrize('expr,expected', [ - ('s - SizeOf("int")*fa[x]', 's - fa[x]*sizeof(int)'), + ('s - SizeOf("int")*fa[x]', 's - sizeof(int)*fa[x]'), ('foo(4*fa[x] + 4*fb[x])', 'foo(4*(fa[x] + fb[x]))'), ('floor(0.1*a + 0.1*fa[x])', 'floor(0.1*(a + fa[x]))'), ('floor(0.1*(a + fa[x]))', 'floor(0.1*(a + fa[x]))'), diff --git a/tests/test_symbolics.py b/tests/test_symbolics.py index b463c653a1..7408f917f3 100644 --- a/tests/test_symbolics.py +++ b/tests/test_symbolics.py @@ -14,7 +14,7 @@ CallFromPointer, Cast, DefFunction, FieldFromPointer, INT, FieldFromComposite, IntDiv, Namespace, Rvalue, ReservedWord, ListInitializer, uxreplace, pow_to_mul, - retrieve_derivatives, BaseCast) + retrieve_derivatives, BaseCast, SizeOf) from devito.tools import as_tuple from devito.types import (Array, Bundle, FIndexed, LocalObject, Object, ComponentAccess, StencilDimension, Symbol as dSymbol) @@ -877,7 +877,6 @@ def test_assumptions(self, op, expr, assumptions, expected): def test_issue_2577a(): - u = TimeFunction(name='u', grid=Grid((2,))) x = u.grid.dimensions[0] expr = Mul(-1, -1., x, u) @@ -906,3 +905,10 @@ def define(self, dimensions): op = Operator(eq_u) assert '--' not in str(op.ccode) + + +def test_print_div(): + a = SizeOf(np.int32) + b = SizeOf(np.int64) + cstr = ccode(a / b) + assert cstr == 'sizeof(int)/sizeof(long)'