diff --git a/devito/core/cpu.py b/devito/core/cpu.py index a1c70ccdfc..3a55a4a4c5 100644 --- a/devito/core/cpu.py +++ b/devito/core/cpu.py @@ -111,6 +111,7 @@ def _normalize_kwargs(cls, **kwargs): ) kwargs['options'].update(o) + kwargs['sym_options'] = cls._normalize_sym_kwargs(**kwargs) return kwargs diff --git a/devito/core/gpu.py b/devito/core/gpu.py index ca6c288c7d..bd13623441 100644 --- a/devito/core/gpu.py +++ b/devito/core/gpu.py @@ -121,6 +121,7 @@ def _normalize_kwargs(cls, **kwargs): ) kwargs['options'].update(o) + kwargs['sym_options'] = cls._normalize_sym_kwargs(**kwargs) return kwargs diff --git a/devito/core/operator.py b/devito/core/operator.py index 974753a9fb..8a52cbf98a 100644 --- a/devito/core/operator.py +++ b/devito/core/operator.py @@ -171,6 +171,30 @@ class BasicOperator(Operator): The target language constructor, to be specified by subclasses. """ + # ------------------------------------------------------------------ + # Symbolic-level option defaults (`sym_opt`). + # These steer mathematical choices made during expression lowering, + # *not* code generation or performance. They are kept separate from + # the `opt` options above to keep the two concerns distinct. + # ------------------------------------------------------------------ + + INTERP_MODE = 'direct' + """ + Default for the `sym_opt={'interp-mode': ...}` option. Controls how + a product of fields living at different staggered locations is mapped + onto a target location: + + * `'direct'` (default): each factor is interpolated to the target + independently. Cheapest stencil. + * `'symmetric'`: factors are first gathered at a common "block" + location, multiplied there, and the result is interpolated once to + the target. Preserves the `I A I^T` matrix structure, so the + discrete operator stays self-adjoint when the continuous one is + (e.g. the elastic stiffness `sigma = C eps`). + + See `examples/userapi/08_staggered_interp.ipynb` for a worked example. + """ + @classmethod def _normalize_kwargs(cls, **kwargs): # Will be populated with dummy values; this method is actually overridden @@ -188,12 +212,30 @@ def _normalize_kwargs(cls, **kwargs): ) kwargs['options'].update(o) + kwargs['sym_options'] = cls._normalize_sym_kwargs(**kwargs) return kwargs + @classmethod + def _normalize_sym_kwargs(cls, **kwargs): + """ + Fill in defaults and validate keys for the `sym_opt` dict passed to + the Operator. Returns the normalized `sym_options` dict. + """ + so = dict(kwargs.get('sym_options', {})) + out = {'interp-mode': so.pop('interp-mode', cls.INTERP_MODE)} + + if so: + raise InvalidOperator( + f'Unrecognized symbolic options: [{", ".join(list(so))}]' + ) + + return out + @classmethod def _check_kwargs(cls, **kwargs): oo = kwargs['options'] + so = kwargs['sym_options'] if oo['mpi'] and oo['mpi'] not in cls.MPI_MODES: raise InvalidOperator(f"Unsupported MPI mode `{oo['mpi']}`") @@ -209,6 +251,9 @@ def _check_kwargs(cls, **kwargs): if oo['errctl'] not in (None, False, 'basic', 'max'): raise InvalidOperator("Illegal `errctl` value") + if so['interp-mode'] not in ('direct', 'symmetric'): + raise InvalidOperator("Illegal `interp-mode` value") + def _autotune(self, args, setup): if setup in [False, 'off']: return args diff --git a/devito/finite_differences/derivative.py b/devito/finite_differences/derivative.py index 22984af6fb..ceb0b67a69 100644 --- a/devito/finite_differences/derivative.py +++ b/devito/finite_differences/derivative.py @@ -89,7 +89,14 @@ class Derivative(sympy.Derivative, Differentiable, Pickable): evaluation are `x0`, `fd_order` and `side`. """ - _fd_priority = 3 + @cached_property + def _fd_priority(self): + # A Derivative inherits the priority of its underlying expression, so + # that `highest_priority(C*v.dx)` and `highest_priority((C*v).dx)` + # agree on the gather location and the two gathering paths + # (`_gather_for_diff` and `Mul._eval_at(interp_mode='symmetric')`) + # produce consistent answers. + return getattr(self.expr, '_fd_priority', 0) __rargs__ = ('expr', '*dims') __rkwargs__ = ('side', 'deriv_order', 'fd_order', 'transpose', '_ppsubs', @@ -472,7 +479,7 @@ def T(self): return self._rebuild(transpose=adjoint) - def _eval_at(self, func): + def _eval_at(self, func, interp_mode='direct', **kwargs): """ Evaluates the derivative at the location of `func`. It is necessary for staggered setup where one could have Eq(u(x + h_x/2), v(x).dx)) in which case v(x).dx @@ -522,7 +529,8 @@ def _eval_at(self, func): return self._rebuild(self.expr, **rkw) args = [self.expr.func(*v) for v in mapper.values()] args.extend([a for a in self.expr.args if a not in self.expr._args_diff]) - args = [self._rebuild(a)._eval_at(func) for a in args] + args = [self._rebuild(a)._eval_at(func, interp_mode=interp_mode, **kwargs) + for a in args] return self.expr.func(*args) elif self.expr.is_Mul: # For Mul, We treat the basic case `u(x + h_x/2) * v(x) which is what appear diff --git a/devito/finite_differences/differentiable.py b/devito/finite_differences/differentiable.py index 503393b089..fd3be2ae1d 100644 --- a/devito/finite_differences/differentiable.py +++ b/devito/finite_differences/differentiable.py @@ -15,6 +15,7 @@ # Moved in 1.13 from sympy.core.basic import ordering_of_classes +from devito.finite_differences.interpolation import interp_at, post_x0_indices from devito.finite_differences.tools import coeff_priority, make_shift_x0 from devito.logger import warning from devito.tools import ( @@ -184,13 +185,11 @@ def coefficients(self): key = lambda x: coeff_priority.get(x, -1) return sorted(coefficients, key=key, reverse=True)[0] - def _eval_at(self, func): - if not func.is_Staggered: - # Cartesian grid, do no waste time - return self + def _eval_at(self, func, **kwargs): return self.func(*[ - getattr(a, '_eval_at', lambda x: a)(func) for a in self.args # noqa: B023 - ]) # false positive + getattr(a, '_eval_at', lambda x, **kw: a)(func, **kwargs) # noqa: B023 + for a in self.args # false positive: lambda is invoked in-place + ]) def _subs(self, old, new, **hints): if old == self: @@ -669,6 +668,63 @@ def _gather_for_diff(self): other = self.func(*other)._eval_at(highest_priority(self)) return self.func(other, *derivs) + def _eval_at(self, func, interp_mode='direct', **kwargs): + """ + Evaluate a Mul at the location of `func`. + + Two modes: + + - `interp_mode='direct'` (default): per-arg evaluation; each factor is + independently evaluated at `func`'s location via + `Differentiable._eval_at`. + + - `interp_mode='symmetric'`: when every Differentiable factor has a + staggering different from `func`'s, apply the `I * (a * I^T * b)` + form: + + 1. Pick a `block` location -- the highest-priority factor's + staggering (NODE is the highest priority, so coefficient-like + NODE factors win, as in the `I * C * I^T` elastic stiffness + pattern). Each factor not at the block is brought there via + `I^T` (an explicit 0-order FD interpolation operator). + Derivatives additionally set `x0` on their own derivative + dimensions to `func`'s indices. + 2. The product is formed at `block`'s location. + 3. The whole product is interpolated to `func` via `I` (an + explicit 0-order FD operator). + + When the trigger does not hold (e.g. some factor already matches + `func`'s staggering), we fall back to `direct`. + """ + if interp_mode != 'symmetric': + return super()._eval_at(func, **kwargs) + + diff, other = split(self.args, lambda a: isinstance(a, Differentiable)) + + # Symmetric form requires every Differentiable factor to differ from + # func; otherwise direct evaluation is cleaner and equivalent. + if len(diff) < 2 or \ + any(a.staggered == func.staggered for a in diff): + return super()._eval_at(func, **kwargs) + + block_indices = highest_priority(self).indices_ref + + # Bring each factor to block's location (I^T where needed) + new_factors = list(other) + for a in diff: + if isinstance(a, sympy.Derivative): + source = post_x0_indices(a, func) + a = a._rebuild(x0={dim: func.indices_ref[dim] for dim in a.dims + if dim in func.indices_ref.getters}) + else: + source = a.indices_ref + new_factors.append(interp_at(a, source, block_indices, + self.interp_order)) + + # Final I from block's location to func + return interp_at(self.func(*new_factors), block_indices, + func.indices_ref, self.interp_order) + class Pow(DifferentiableOp, sympy.Pow): _fd_priority = 0 @@ -1020,7 +1076,7 @@ def _subs(self, old, new, **hints): class DiffDerivative(IndexDerivative, DifferentiableOp): - def _eval_at(self, func): + def _eval_at(self, func, **kwargs): # Like EvalDerivative, a DiffDerivative must have already been evaluated # at a valid x0 and should not be re-evaluated at a different location return self @@ -1074,7 +1130,7 @@ def _new_rawargs(self, *args, **kwargs): kwargs.pop('is_commutative', None) return self.func(*args, **kwargs) - def _eval_at(self, func): + def _eval_at(self, func, **kwargs): # An EvalDerivative must have already been evaluated at a valid x0 # and should not be re-evaluated at a different location return self @@ -1092,7 +1148,7 @@ class diffify: Notes ----- - The name "diffify" stems from SymPy's "simpify", which has an analogous task -- + The name "diffify" stems from SymPy's "simplify", which has an analogous task -- converting all arguments into SymPy core objects. """ diff --git a/devito/finite_differences/interpolation.py b/devito/finite_differences/interpolation.py new file mode 100644 index 0000000000..27ef222c81 --- /dev/null +++ b/devito/finite_differences/interpolation.py @@ -0,0 +1,62 @@ +from contextlib import suppress + +__all__ = ['interp_at', 'interp_mapper', 'post_x0_indices'] + + +def interp_mapper(source, target, dims): + """ + Build a `{dim: target_index}` mapper for dimensions in `dims` where + `source[dim]` differs from `target[dim]`. + + `source` and `target` are dict-like `{dim: index_expr}` (e.g. a plain + dict or a `DimensionTuple`). Dimensions missing from either side are + skipped silently. + """ + mapper = {} + for d in dims: + try: + s = source[d] + t = target[d] + except (KeyError, IndexError): + continue + if s is not t: + mapper[d] = t + return mapper + + +def interp_at(expr, source, target, interp_order): + """ + Build a symbolic 0-order FD interpolation operator on `expr` that maps + values from `source` indices to `target` indices, only on the + dimensions where the two locations differ. + """ + from devito.finite_differences.differentiable import Differentiable + + if not isinstance(expr, Differentiable): + return expr + + mapper = interp_mapper(source, target, expr.dimensions) + if not mapper: + return expr + + return expr.diff(*mapper.keys(), + deriv_order=(0,) * len(mapper), + fd_order=(interp_order,) * len(mapper), + x0=mapper) + + +def post_x0_indices(deriv, func): + """ + Conceptual indices of `deriv` after setting `x0` on its own derivative + dimensions to `func`'s indices. Derivative dims take `func`'s indices; + other dims keep the underlying expression's natural location (so that + `interp_for_fd` does not introduce a spurious second shift). + """ + ref = {} + for dim in deriv.dimensions: + if dim in deriv.dims and dim in func.indices_ref.getters: + ref[dim] = func.indices_ref[dim] + else: + with suppress(KeyError): + ref[dim] = deriv.indices_ref[dim] + return ref diff --git a/devito/operator/operator.py b/devito/operator/operator.py index c0bb6145a6..86899a6b7a 100644 --- a/devito/operator/operator.py +++ b/devito/operator/operator.py @@ -67,6 +67,16 @@ class Operator(Callable): Symbolic substitutions to be applied to ``expressions``. * opt : str The performance optimization level. Defaults to ``configuration['opt']``. + * sym_opt : dict + Symbolic-level options controlling mathematical choices made during + expression lowering (e.g. how staggered multi-factor products are + interpolated). Distinct from ``opt``, which controls code generation + and performance. Accepted keys: + + - ``'interp-mode'`` (``'direct'`` | ``'symmetric'``): selects the + interpolation strategy used by ``Mul._eval_at`` when projecting a + multi-factor expression onto a target staggered location. See the + tutorial at ``examples/userapi/08_staggered_interp.ipynb``. * language : str The target language for shared-memory parallelism. Defaults to ``configuration['language']``. @@ -234,6 +244,7 @@ def _build(cls, expressions, **kwargs): # Potentially required for lazily allocated Functions op._mode = kwargs['mode'] op._options = kwargs['options'] + op._sym_options = kwargs['sym_options'] op._allocator = kwargs['allocator'] op._platform = kwargs['platform'] @@ -341,6 +352,7 @@ def _lower_exprs(cls, expressions, **kwargs): * Shift indices for domain alignment. """ expand = kwargs['options'].get('expand', True) + interp_mode = kwargs.get('sym_options', {}).get('interp-mode', 'direct') # Specialization is performed on unevaluated expressions expressions = cls._specialize_dsl(expressions, **kwargs) @@ -351,7 +363,8 @@ def _lower_exprs(cls, expressions, **kwargs): # ModuloDimensions if not expand: expand = lambda d: d.is_Stepping - expressions = flatten([i._evaluate(expand=expand) for i in expressions]) + expressions = flatten([i._evaluate(expand=expand, interp_mode=interp_mode) + for i in expressions]) # Scalarize the tensor equations, if any expressions = [j for i in expressions for j in i._flatten] @@ -1641,6 +1654,12 @@ def parse_kwargs(**kwargs): mode = 'noop' kwargs['mode'] = mode + # `sym_opt` -- symbolic-level options (mathematical choices, not codegen) + sym_opt = kwargs.pop('sym_opt', None) or {} + if not isinstance(sym_opt, (dict, frozendict)): + raise InvalidOperator(f"Illegal `sym_opt={str(sym_opt)}`") + kwargs['sym_options'] = dict(sym_opt) + # `platform` platform = kwargs.get('platform') if platform is not None: diff --git a/devito/types/dense.py b/devito/types/dense.py index 30205a9712..6b62c91bf4 100644 --- a/devito/types/dense.py +++ b/devito/types/dense.py @@ -15,6 +15,7 @@ from devito.deprecations import deprecations from devito.exceptions import InvalidArgument from devito.finite_differences import Differentiable, generate_fd_shortcuts +from devito.finite_differences.interpolation import interp_mapper from devito.finite_differences.tools import fd_weights_registry from devito.logger import debug, warning from devito.mpi import MPI @@ -1116,22 +1117,24 @@ def __fd_setup__(self): @cached_property def _fd_priority(self): - return 1 if self.staggered.on_node else 2 + # NODE takes precedence: coefficients are conventionally stored at the + # cell centre, so when we gather a product onto a single location + # (either via _gather_for_diff or symmetric Mul._eval_at), NODE is the + # natural one to pick. + return 1.2 if self.staggered.on_node else 1.1 - def _eval_at(self, func): + def _eval_at(self, func, **kwargs): if self.staggered == func.staggered or self.interp_order == 0: return self - mapper = {} - for d in self.dimensions: - try: - if self.indices_ref[d] is not func.indices_ref[d]: - f_idx = func.indices_ref[d]._subs(func.dimensions[d], d) - mapper[self.indices_ref[d]] = f_idx - except KeyError: - pass + # Dims where self and func indices differ -> {dim: func_idx} + diff = interp_mapper(self.indices_ref, func.indices_ref, self.dimensions) + + # Translate into a subs mapper {self_idx: func_idx} aligned on self's dims + subs_map = {self.indices_ref[d]: t._subs(func.dimensions[d], d) + for d, t in diff.items()} - return self.subs(mapper) + return self.subs(subs_map) @classmethod def __staggered_setup__(cls, dimensions, staggered=None, **kwargs): @@ -1545,7 +1548,7 @@ def __shape_setup__(cls, **kwargs): @cached_property def _fd_priority(self): - return 2.1 if self.staggered.on_node else 2.2 + return 2.1 if self.staggered.on_node else 2 @property def time_order(self): diff --git a/devito/types/equation.py b/devito/types/equation.py index 5cbba42e89..5b709ee89e 100644 --- a/devito/types/equation.py +++ b/devito/types/equation.py @@ -110,7 +110,7 @@ def _evaluate(self, **kwargs): """ try: lhs = self.lhs._evaluate(**kwargs) - rhs = self.rhs._eval_at(self.lhs)._evaluate(**kwargs) + rhs = self.rhs._eval_at(self.lhs, **kwargs)._evaluate(**kwargs) except AttributeError: lhs, rhs = self._evaluate_args(**kwargs) eq = self.func(lhs, rhs, subdomain=self.subdomain, diff --git a/devito/types/sparse.py b/devito/types/sparse.py index aeb3c6c544..c66cc6da4a 100644 --- a/devito/types/sparse.py +++ b/devito/types/sparse.py @@ -695,7 +695,7 @@ def _dist_scatter(self, alias=None, data=None): mapper.update(self._dist_subfunc_scatter(sf)) return mapper - def _eval_at(self, func): + def _eval_at(self, func, **kwargs): return self def _halo_exchange(self): diff --git a/devito/types/tensor.py b/devito/types/tensor.py index f42d437c38..3e5586eb32 100644 --- a/devito/types/tensor.py +++ b/devito/types/tensor.py @@ -170,12 +170,13 @@ def __getattr__(self, name): f'{self.__class__!r} object has no attribute {name!r}' ) from e - def _eval_at(self, func): + def _eval_at(self, func, **kwargs): """ Evaluate tensor at func location """ def entries(i, j, func): - return getattr(self[i, j], '_eval_at', lambda x: self[i, j])(func[i, j]) + return getattr(self[i, j], '_eval_at', + lambda x: self[i, j])(func[i, j], **kwargs) entry = lambda i, j: entries(i, j, func) return self._new(self.rows, self.cols, entry) diff --git a/devito/types/utils.py b/devito/types/utils.py index 93d6b73b62..0593c78575 100644 --- a/devito/types/utils.py +++ b/devito/types/utils.py @@ -57,6 +57,16 @@ class Staggering(DimensionTuple): def on_node(self): return not self or all(s == 0 for s in self) + def __eq__(self, other): + # Two empty-or-all-zero Staggerings are equivalent regardless of arity + # (a Function declared with `staggered=NODE` and one declared without + # both live at the cell centre). + if isinstance(other, Staggering) and self.on_node and other.on_node: + return True + return tuple.__eq__(self, other) + + __hash__ = DimensionTuple.__hash__ + @property def _ref(self): if not self: diff --git a/examples/userapi/08_staggered_interpolation.ipynb b/examples/userapi/08_staggered_interpolation.ipynb new file mode 100644 index 0000000000..ec8795f487 --- /dev/null +++ b/examples/userapi/08_staggered_interpolation.ipynb @@ -0,0 +1,870 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "c0", + "metadata": {}, + "source": [ + "# Staggered grids and the `interp-mode` option\n", + "\n", + "When fields in an equation live on different sub-grid offsets - a so-called\n", + "*staggered* grid - Devito automatically interpolates the right-hand side onto\n", + "the location of the left-hand side when it discretises the `Eq`. Most of the\n", + "time this is transparent: you write the equation as usual, build an `Operator`,\n", + "and Devito picks reasonable interpolations.\n", + "\n", + "For some applications (especially elastic wave propagation and other vector\n", + "PDEs) the *way* products of staggered fields are interpolated matters. Devito\n", + "exposes this choice as a *symbolic* option on the `Operator`, passed via\n", + "`sym_opt`:\n", + "\n", + "```python\n", + "Operator(eq, sym_opt={'interp-mode': 'direct'}) # default\n", + "Operator(eq, sym_opt={'interp-mode': 'symmetric'}) # alternative\n", + "```\n", + "\n", + "`sym_opt` is the place where *mathematical* choices made during expression\n", + "lowering live, separate from `opt`, which controls code generation and\n", + "performance. `interp-mode` accepts two values: `'direct'` (the default) and\n", + "`'symmetric'`. Both modes produce the same answer when at least one factor\n", + "already matches the target staggering - the choice only matters when *every*\n", + "factor sits elsewhere.\n", + "\n", + "This tutorial walks through:\n", + "\n", + "* declaring staggered `Function`s,\n", + "* what Devito does with an `Eq` mixing several staggerings,\n", + "* the difference between `'direct'` and `'symmetric'`,\n", + "* a worked example on the elastic stiffness matrix.\n", + "\n", + "You will not need to call any interpolation routine directly - we work\n", + "entirely with `Function`, `Eq`, and `Operator`. For a tour of how `_eval_at`\n", + "actually rewrites derivatives and functions under the hood, see\n", + "`09_fd_evaluation.ipynb`." + ] + }, + { + "cell_type": "markdown", + "id": "c1", + "metadata": {}, + "source": [ + "## Staggered locations\n", + "\n", + "A staggered grid stores different fields at different sub-grid offsets of a\n", + "single cell. In 1D with spacing $h$ the two canonical locations are NODE\n", + "($x_i$) and the half-shifted location $x_{i + 1/2} = x_i + h/2$. A first\n", + "derivative naturally lives between them: the centred two-point stencil\n", + "\n", + "$$\n", + "u'\\!\\left(x_{i + 1/2}\\right) \\;\\approx\\; \\frac{u(x_{i+1}) - u(x_i)}{h}\n", + "$$\n", + "\n", + "is second-order accurate at $x_{i+1/2}$ but only first-order at $x_i$. In\n", + "higher dimensions any subset of axes can be staggered, giving up to $2^n$\n", + "locations in $n$ dimensions.\n", + "\n", + "In Devito, the `staggered` keyword controls the offset. The usual 2D cases:" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "c2", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-12T14:40:29.927586Z", + "iopub.status.busy": "2026-05-12T14:40:29.927293Z", + "iopub.status.idle": "2026-05-12T14:40:31.299379Z", + "shell.execute_reply": "2026-05-12T14:40:31.298910Z" + } + }, + "outputs": [], + "source": [ + "from devito import Function, Grid, NODE\n", + "\n", + "grid = Grid(shape=(11, 11), extent=(1.0, 1.0))\n", + "x, y = grid.dimensions\n", + "\n", + "# cell centre, x-face, y-face, corner\n", + "fn = Function(name='fn', grid=grid, space_order=4, staggered=NODE)\n", + "fx = Function(name='fx', grid=grid, space_order=4, staggered=x)\n", + "fy = Function(name='fy', grid=grid, space_order=4, staggered=y)\n", + "fxy = Function(name='fxy', grid=grid, space_order=4, staggered=(x, y))" + ] + }, + { + "cell_type": "markdown", + "id": "c3", + "metadata": {}, + "source": [ + "Each `Function` knows its natural sample location, which we can read off\n", + "the `indices_ref` property:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "c4", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-12T14:40:31.301235Z", + "iopub.status.busy": "2026-05-12T14:40:31.300970Z", + "iopub.status.idle": "2026-05-12T14:40:31.303818Z", + "shell.execute_reply": "2026-05-12T14:40:31.303603Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "fn indices_ref = (x, y)\n", + "fx indices_ref = (x + h_x/2, y)\n", + "fy indices_ref = (x, y + h_y/2)\n", + "fxy indices_ref = (x + h_x/2, y + h_y/2)\n" + ] + } + ], + "source": [ + "for f in (fn, fx, fy, fxy):\n", + " print(f\"{f.name:<3s} indices_ref = {tuple(f.indices_ref)}\")" + ] + }, + { + "cell_type": "markdown", + "id": "c5", + "metadata": {}, + "source": [ + "## Mixing staggerings in an equation\n", + "\n", + "Suppose we want $f_\\mathrm{xy} = f_\\mathrm{n} \\cdot \\partial_x f_\\mathrm{x}$.\n", + "The three pieces live at three different locations: `fn` at the cell centre,\n", + "`fx.dx` on the $x$-face (a derivative inherits its operand's grid), and `fxy`\n", + "at the corner. We just write it down:" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "c6", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-12T14:40:31.305053Z", + "iopub.status.busy": "2026-05-12T14:40:31.304943Z", + "iopub.status.idle": "2026-05-12T14:40:31.310855Z", + "shell.execute_reply": "2026-05-12T14:40:31.310542Z" + } + }, + "outputs": [ + { + "data": { + "text/latex": [ + "$\\displaystyle fxy(x + h_x/2, y + h_y/2) = fn(x, y) \\frac{\\partial}{\\partial x} fx(x + h_x/2, y)$" + ], + "text/plain": [ + "Eq(fxy(x + h_x/2, y + h_y/2), fn(x, y)*Derivative(fx(x + h_x/2, y), x))" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from devito import Eq\n", + "\n", + "eq = Eq(fxy, fn * fx.dx)\n", + "eq" + ] + }, + { + "cell_type": "markdown", + "id": "c7", + "metadata": {}, + "source": [ + "So far this is purely symbolic. The interpolation happens when we build an\n", + "`Operator`. To see the discretised update statement, we can walk the\n", + "Operator's IR and pick out its `Expression` nodes - the symbolic form is\n", + "much easier to read than the generated C code:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "c8", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-12T14:40:31.312211Z", + "iopub.status.busy": "2026-05-12T14:40:31.312065Z", + "iopub.status.idle": "2026-05-12T14:40:31.405486Z", + "shell.execute_reply": "2026-05-12T14:40:31.405198Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "[Eq(r0, 1/h_x),\n", + " Eq(fxy[x + 4, y + 4], (0.0416666667*(r0*(fx[x + 2, y + 4] + fx[x + 2, y + 5]) - r0*(fx[x + 6, y + 4] + fx[x + 6, y + 5])) + 0.333333333*(-r0*(fx[x + 3, y + 4] + fx[x + 3, y + 5]) + r0*(fx[x + 5, y + 4] + fx[x + 5, y + 5])))*(0.25*fn[x + 4, y + 4] + 0.25*fn[x + 5, y + 4] + 0.25*fn[x + 4, y + 5] + 0.25*fn[x + 5, y + 5]))]" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from devito import Operator\n", + "from devito.ir.iet import Expression, FindNodes\n", + "\n", + "\n", + "def updates(op):\n", + " # Symbolic update equations carried by ``op``\n", + " return [n.expr for n in FindNodes(Expression).visit(op)]\n", + "\n", + "\n", + "op = Operator(eq)\n", + "updates(op)" + ] + }, + { + "cell_type": "markdown", + "id": "c9", + "metadata": {}, + "source": [ + "Reading the right-hand side: `fn` enters through a four-corner average (the\n", + "2D linear interpolation from cell centres to corners), while `fx.dx` is\n", + "computed by a centred fourth-order stencil already at the corner location,\n", + "with a 2-point $y$-shift baked in. Each factor has been moved to the corner\n", + "*independently* - this is what we mean by \"direct\" interpolation.\n", + "\n", + "(`r0` is a hoisted reciprocal of `h_x`, which is why it appears in the\n", + "stencil.)" + ] + }, + { + "cell_type": "markdown", + "id": "c10", + "metadata": {}, + "source": [ + "## The `interp-mode` option\n", + "\n", + "For a product whose factors all live at a staggering different from the\n", + "target, Devito offers two ways of building the discretisation. Because\n", + "this is a *mathematical* choice (it changes the discrete operator, not the\n", + "generated C code) it sits in the `sym_opt` dict on the `Operator`, not in\n", + "`opt`:\n", + "\n", + "```python\n", + "Operator(eq, sym_opt={'interp-mode': 'direct'}) # default\n", + "Operator(eq, sym_opt={'interp-mode': 'symmetric'}) # alternative\n", + "```\n", + "\n", + "`'direct'` is the default because it produces the smallest stencil. For\n", + "operators that need to remain self-adjoint under discretisation, however,\n", + "`'symmetric'` is the right choice.\n", + "\n", + "To see why, let us recall that each shift between two staggered grids is a\n", + "linear map and therefore has a transpose. For the 1D two-point average,\n", + "\n", + "$$\n", + "\\mathbf{I}\\,u\\,(x_{i+1/2}) = \\tfrac{1}{2}\\bigl[u(x_i) + u(x_{i+1})\\bigr],\n", + "$$\n", + "\n", + "$\\mathbf{I}$ maps cell centres to half-steps, and $\\mathbf{I}^{\\!\\top}$ goes\n", + "the other way. Many PDE operators decompose into block matrices of the form\n", + "$\\mathbf{I}\\,\\mathbf{A}\\,\\mathbf{I}^{\\!\\top}$. Preserving that structure in\n", + "the discretisation makes the discrete operator self-adjoint whenever\n", + "$\\mathbf{A}$ is. The two modes correspond to two ways of building such a\n", + "product:\n", + "\n", + "* `'direct'` interpolates every factor to the target independently. The\n", + " result is a product of averages, easy to read and cheap to evaluate.\n", + "* `'symmetric'` first gathers every factor at a common *block* location\n", + " (NODE if any factor sits there), forms the product there, then projects\n", + " the result to the target with a single $\\mathbf{I}$. This preserves the\n", + " $\\mathbf{I}\\,\\mathbf{A}\\,\\mathbf{I}^{\\!\\top}$ matrix structure.\n", + "\n", + "The two agree whenever at least one factor already matches the target - in\n", + "that case there is no \"block\" to gather to and no choice to make." + ] + }, + { + "cell_type": "markdown", + "id": "c11", + "metadata": {}, + "source": [ + "Here is the same equation built in both modes side by side:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "c12", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-12T14:40:31.406996Z", + "iopub.status.busy": "2026-05-12T14:40:31.406859Z", + "iopub.status.idle": "2026-05-12T14:40:31.579916Z", + "shell.execute_reply": "2026-05-12T14:40:31.579652Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "direct :\n", + " Eq(r0, 1/h_x)\n", + " Eq(fxy[x + 4, y + 4], (0.0416666667*(r0*(fx[x + 2, y + 4] + fx[x + 2, y + 5]) - r0*(fx[x + 6, y + 4] + fx[x + 6, y + 5])) + 0.333333333*(-r0*(fx[x + 3, y + 4] + fx[x + 3, y + 5]) + r0*(fx[x + 5, y + 4] + fx[x + 5, y + 5])))*(0.25*fn[x + 4, y + 4] + 0.25*fn[x + 5, y + 4] + 0.25*fn[x + 4, y + 5] + 0.25*fn[x + 5, y + 5]))\n", + "\n", + "symmetric:\n", + " Eq(r0, 1/h_x)\n", + " Eq(r1[y + 1], r0*(0.5*((0.0416666667*(fx[x + 1, y + 5] - fx[x + 6, y + 5]) + 0.291666667*(-fx[x + 2, y + 5] + fx[x + 5, y + 5]) + 0.333333333*(-fx[x + 3, y + 5] + fx[x + 4, y + 5]))*fn[x + 4, y + 5] + (0.0416666667*(fx[x + 2, y + 5] - fx[x + 7, y + 5]) + 0.291666667*(-fx[x + 3, y + 5] + fx[x + 6, y + 5]) + 0.333333333*(-fx[x + 4, y + 5] + fx[x + 5, y + 5]))*fn[x + 5, y + 5])))\n", + " Eq(fxy[x + 4, y + 4], 0.5*(r1[y] + r1[y + 1]))\n" + ] + } + ], + "source": [ + "op_dir = Operator(eq, sym_opt={'interp-mode': 'direct'})\n", + "op_sym = Operator(eq, sym_opt={'interp-mode': 'symmetric'})\n", + "\n", + "print('direct :')\n", + "for u in updates(op_dir):\n", + " print(' ', u)\n", + "print()\n", + "print('symmetric:')\n", + "for u in updates(op_sym):\n", + " print(' ', u)" + ] + }, + { + "cell_type": "markdown", + "id": "c13", + "metadata": {}, + "source": [ + "## A worked example: the elastic stiffness\n", + "\n", + "The constitutive law of linear elasticity, in Voigt notation, reads\n", + "$\\boldsymbol{\\sigma} = \\mathbf{C}\\,\\boldsymbol{\\varepsilon}$ with\n", + "$\\boldsymbol{\\sigma}$ and $\\boldsymbol{\\varepsilon}$ six-component vectors and\n", + "$\\mathbf{C}$ a symmetric $6 \\times 6$ stiffness. On the standard staggered\n", + "grid for elastodynamics:\n", + "\n", + "| Voigt index | Field | Location | `staggered` |\n", + "|-------------|----------------|-----------|-------------|\n", + "| 1, 2, 3 | normal | cell centre | `NODE` |\n", + "| 4 | shear $yz$ | $yz$-edge | `(y, z)` |\n", + "| 5 | shear $xz$ | $xz$-edge | `(x, z)` |\n", + "| 6 | shear $xy$ | $xy$-edge | `(x, y)` |\n", + "\n", + "and the stiffness coefficients $C_{ij}$ live at the cell centre.\n", + "\n", + "`'symmetric'` reproduces the matrix factorisation\n", + "\n", + "$$\n", + "\\sigma_i \\;=\\; \\mathbf{I}\\Bigl(\\sum_j C_{ij}\\,\\mathbf{I}^{\\!\\top}\\!\\varepsilon_j\\Bigr),\n", + "$$\n", + "\n", + "with $\\mathbf{I}$ averaging from the cell centre up to $\\sigma_i$'s location\n", + "and $\\mathbf{I}^{\\!\\top}$ averaging $\\varepsilon_j$ back to the centre. Let's\n", + "build it in Devito." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "c14", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-12T14:40:31.581219Z", + "iopub.status.busy": "2026-05-12T14:40:31.581128Z", + "iopub.status.idle": "2026-05-12T14:40:31.610846Z", + "shell.execute_reply": "2026-05-12T14:40:31.610567Z" + } + }, + "outputs": [], + "source": [ + "from devito import Function, Grid, NODE, Eq, Operator\n", + "\n", + "grid = Grid(shape=(11, 11, 11), extent=(1.0, 1.0, 1.0))\n", + "x, y, z = grid.dimensions\n", + "\n", + "locs = {1: NODE, 2: NODE, 3: NODE,\n", + " 4: (y, z), 5: (x, z), 6: (x, y)}\n", + "\n", + "\n", + "def F(name, stag):\n", + " return Function(name=name, grid=grid, space_order=4, staggered=stag)\n", + "\n", + "\n", + "sigma = {i: F(f's{i}', locs[i]) for i in range(1, 7)}\n", + "eps = {i: F(f'e{i}', locs[i]) for i in range(1, 7)}\n", + "C = {(i, j): F(f'C{i}{j}', NODE) for i in range(1, 7) for j in range(1, 7)}" + ] + }, + { + "cell_type": "markdown", + "id": "c15", + "metadata": {}, + "source": [ + "A small helper that compiles a single-equation operator in either mode and\n", + "returns its symbolic update statement, so we can compare modes side by\n", + "side:" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "c16", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-12T14:40:31.612403Z", + "iopub.status.busy": "2026-05-12T14:40:31.612312Z", + "iopub.status.idle": "2026-05-12T14:40:31.614405Z", + "shell.execute_reply": "2026-05-12T14:40:31.614157Z" + } + }, + "outputs": [], + "source": [ + "def show(eq, mode):\n", + " op = Operator(eq, sym_opt={'interp-mode': mode})\n", + " [u] = [n.expr for n in FindNodes(Expression).visit(op)\n", + " if n.expr.lhs.function is eq.lhs.function]\n", + " return u\n" + ] + }, + { + "cell_type": "markdown", + "id": "c17", + "metadata": {}, + "source": [ + "**Normal-normal block** $i, j \\in \\{1, 2, 3\\}$. Everything is at the cell\n", + "centre, so no interpolation is needed and the two modes agree:" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "c18", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-12T14:40:31.615775Z", + "iopub.status.busy": "2026-05-12T14:40:31.615692Z", + "iopub.status.idle": "2026-05-12T14:40:31.680333Z", + "shell.execute_reply": "2026-05-12T14:40:31.679899Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "direct : Eq(s1[x + 4, y + 4, z + 4], C11[x + 4, y + 4, z + 4]*e1[x + 4, y + 4, z + 4])\n", + "symmetric: Eq(s1[x + 4, y + 4, z + 4], C11[x + 4, y + 4, z + 4]*e1[x + 4, y + 4, z + 4])\n" + ] + } + ], + "source": [ + "eq = Eq(sigma[1], C[(1, 1)] * eps[1])\n", + "print('direct :', show(eq, 'direct'))\n", + "print('symmetric:', show(eq, 'symmetric'))" + ] + }, + { + "cell_type": "markdown", + "id": "c19", + "metadata": {}, + "source": [ + "**Normal $\\times$ shear** $i \\in \\{1, 2, 3\\}$, $j \\in \\{4, 5, 6\\}$. The\n", + "target $\\sigma_1$ already sits where $C_{14}$ does (the cell centre), so the\n", + "symmetric mode falls back to the direct one - only $\\varepsilon_4$ needs to\n", + "be averaged down:" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "c20", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-12T14:40:31.682438Z", + "iopub.status.busy": "2026-05-12T14:40:31.682217Z", + "iopub.status.idle": "2026-05-12T14:40:31.783991Z", + "shell.execute_reply": "2026-05-12T14:40:31.783719Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "direct : Eq(s1[x + 4, y + 4, z + 4], (0.25*e4[x + 4, y + 3, z + 3] + 0.25*e4[x + 4, y + 4, z + 3] + 0.25*e4[x + 4, y + 3, z + 4] + 0.25*e4[x + 4, y + 4, z + 4])*C14[x + 4, y + 4, z + 4])\n", + "symmetric: Eq(s1[x + 4, y + 4, z + 4], (0.25*e4[x + 4, y + 3, z + 3] + 0.25*e4[x + 4, y + 4, z + 3] + 0.25*e4[x + 4, y + 3, z + 4] + 0.25*e4[x + 4, y + 4, z + 4])*C14[x + 4, y + 4, z + 4])\n" + ] + } + ], + "source": [ + "eq = Eq(sigma[1], C[(1, 4)] * eps[4])\n", + "print('direct :', show(eq, 'direct'))\n", + "print('symmetric:', show(eq, 'symmetric'))" + ] + }, + { + "cell_type": "markdown", + "id": "c21", + "metadata": {}, + "source": [ + "**Shear $\\times$ normal** $i \\in \\{4, 5, 6\\}$, $j \\in \\{1, 2, 3\\}$. Now\n", + "$\\sigma_4$ sits at the $yz$-edge while $C_{41}$ and $\\varepsilon_1$ are both\n", + "at the cell centre. The two modes differ:\n", + "\n", + "* `'direct'` averages each factor to the edge independently - giving a\n", + " *product of averages*.\n", + "* `'symmetric'` forms the product at the cell centre first and averages the\n", + " result once - giving an *average of products*." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "c22", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-12T14:40:31.785521Z", + "iopub.status.busy": "2026-05-12T14:40:31.785424Z", + "iopub.status.idle": "2026-05-12T14:40:31.904302Z", + "shell.execute_reply": "2026-05-12T14:40:31.904083Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "direct : Eq(s4[x + 4, y + 4, z + 4], (0.25*C41[x + 4, y + 4, z + 4] + 0.25*C41[x + 4, y + 5, z + 4] + 0.25*C41[x + 4, y + 4, z + 5] + 0.25*C41[x + 4, y + 5, z + 5])*(0.25*e1[x + 4, y + 4, z + 4] + 0.25*e1[x + 4, y + 5, z + 4] + 0.25*e1[x + 4, y + 4, z + 5] + 0.25*e1[x + 4, y + 5, z + 5]))\n", + "\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "symmetric: Eq(s4[x + 4, y + 4, z + 4], 0.25*C41[x + 4, y + 4, z + 4]*e1[x + 4, y + 4, z + 4] + 0.25*C41[x + 4, y + 5, z + 4]*e1[x + 4, y + 5, z + 4] + 0.25*C41[x + 4, y + 4, z + 5]*e1[x + 4, y + 4, z + 5] + 0.25*C41[x + 4, y + 5, z + 5]*e1[x + 4, y + 5, z + 5])\n" + ] + } + ], + "source": [ + "eq = Eq(sigma[4], C[(4, 1)] * eps[1])\n", + "print('direct :', show(eq, 'direct'))\n", + "print()\n", + "print('symmetric:', show(eq, 'symmetric'))" + ] + }, + { + "cell_type": "markdown", + "id": "c23", + "metadata": {}, + "source": [ + "**Shear $\\times$ shear off-diagonal** $i \\neq j$, both in $\\{4, 5, 6\\}$.\n", + "Every factor sits at a different location. This is where `'symmetric'`\n", + "shines: the discrete operator keeps the\n", + "$\\mathbf{I}\\,(C_{ij}\\,\\mathbf{I}^{\\!\\top}\\varepsilon_j)$ structure that\n", + "preserves the self-adjointness of $\\mathbf{C} = \\mathbf{C}^{\\!\\top}$." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "c24", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-12T14:40:31.905740Z", + "iopub.status.busy": "2026-05-12T14:40:31.905626Z", + "iopub.status.idle": "2026-05-12T14:40:32.058302Z", + "shell.execute_reply": "2026-05-12T14:40:32.058048Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "direct :" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " Eq(s4[x + 4, y + 4, z + 4], (0.25*C45[x + 4, y + 4, z + 4] + 0.25*C45[x + 4, y + 5, z + 4] + 0.25*C45[x + 4, y + 4, z + 5] + 0.25*C45[x + 4, y + 5, z + 5])*(0.25*e5[x + 3, y + 4, z + 4] + 0.25*e5[x + 4, y + 4, z + 4] + 0.25*e5[x + 3, y + 5, z + 4] + 0.25*e5[x + 4, y + 5, z + 4]))\n", + "\n", + "symmetric: Eq(s4[x + 4, y + 4, z + 4], 0.5*(r0[z] + r0[z + 1]))\n" + ] + } + ], + "source": [ + "eq = Eq(sigma[4], C[(4, 5)] * eps[5])\n", + "print('direct :', show(eq, 'direct'))\n", + "print()\n", + "print('symmetric:', show(eq, 'symmetric'))" + ] + }, + { + "cell_type": "markdown", + "id": "c25", + "metadata": {}, + "source": [ + "**Shear $\\times$ shear diagonal** $i = j$. The target and $\\varepsilon_i$\n", + "already share their staggering; only $C_{ii}$ needs to be brought to the\n", + "edge, and that happens implicitly through the standard substitution. The two\n", + "modes agree:" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "c26", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-12T14:40:32.059688Z", + "iopub.status.busy": "2026-05-12T14:40:32.059602Z", + "iopub.status.idle": "2026-05-12T14:40:32.164584Z", + "shell.execute_reply": "2026-05-12T14:40:32.164357Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "direct : Eq(s4[x + 4, y + 4, z + 4], (0.25*C44[x + 4, y + 4, z + 4] + 0.25*C44[x + 4, y + 5, z + 4] + 0.25*C44[x + 4, y + 4, z + 5] + 0.25*C44[x + 4, y + 5, z + 5])*e4[x + 4, y + 4, z + 4])\n", + "symmetric: Eq(s4[x + 4, y + 4, z + 4], (0.25*C44[x + 4, y + 4, z + 4] + 0.25*C44[x + 4, y + 5, z + 4] + 0.25*C44[x + 4, y + 4, z + 5] + 0.25*C44[x + 4, y + 5, z + 5])*e4[x + 4, y + 4, z + 4])\n" + ] + } + ], + "source": [ + "eq = Eq(sigma[4], C[(4, 4)] * eps[4])\n", + "print('direct :', show(eq, 'direct'))\n", + "print('symmetric:', show(eq, 'symmetric'))" + ] + }, + { + "cell_type": "markdown", + "id": "c27", + "metadata": {}, + "source": [ + "## Self-adjointness of the discrete stiffness\n", + "\n", + "The point of `'symmetric'` is that the discrete operator inherits the\n", + "self-adjointness of its continuous counterpart. We can verify this with the\n", + "standard adjoint dot-product test: if the discrete $\\mathbf{C}$ is its own\n", + "transpose, then for any two strain-shaped fields $e_1$ and $t_2$,\n", + "\n", + "$$\n", + "\\langle e_1,\\,\\mathbf{C}\\,t_2 \\rangle \\;=\\; \\langle \\mathbf{C}\\,e_1,\\,t_2 \\rangle.\n", + "$$\n", + "\n", + "We assemble random symmetric $C_{ij}$, random $e_1$ and $t_2$, compute both\n", + "sides, and compare:" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "c28", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-12T14:40:32.165895Z", + "iopub.status.busy": "2026-05-12T14:40:32.165823Z", + "iopub.status.idle": "2026-05-12T14:40:32.207425Z", + "shell.execute_reply": "2026-05-12T14:40:32.207187Z" + } + }, + "outputs": [], + "source": [ + "# NBVAL_IGNORE_OUTPUT\n", + "import numpy as np\n", + "\n", + "np.random.seed(1234)\n", + "\n", + "# Random symmetric stiffness\n", + "for i in range(1, 7):\n", + " for j in range(i, 7):\n", + " C[(i, j)].data[:] = np.random.rand(*C[(i, j)].shape)\n", + " if j != i:\n", + " C[(j, i)].data[:] = C[(i, j)].data\n", + "\n", + "# Two right-hand-side strain fields\n", + "e1 = {i: F(f'e1_{i}', locs[i]) for i in range(1, 7)}\n", + "t2 = {i: F(f't2_{i}', locs[i]) for i in range(1, 7)}\n", + "# ... and the outputs\n", + "t1 = {i: F(f't1_{i}', locs[i]) for i in range(1, 7)}\n", + "e2 = {i: F(f'e2_{i}', locs[i]) for i in range(1, 7)}\n", + "\n", + "for i in range(1, 7):\n", + " e1[i].data[:] = 2 * np.random.rand(*e1[i].shape) - 1\n", + " t2[i].data[:] = 2 * np.random.rand(*t2[i].shape) - 1\n", + "\n", + "eqns = []\n", + "for i in range(1, 7):\n", + " eqns.append(Eq(t1[i], sum(C[(i, j)] * e1[j] for j in range(1, 7))))\n", + " eqns.append(Eq(e2[i], sum(C[(i, j)] * t2[j] for j in range(1, 7))))\n", + "\n", + "\n", + "def run(mode):\n", + " Operator(eqns, sym_opt={'interp-mode': mode}).apply()\n", + " lhs = sum(float(np.dot(e1[i].data.flatten(), e2[i].data.flatten()))\n", + " for i in range(1, 7))\n", + " rhs = sum(float(np.dot(t1[i].data.flatten(), t2[i].data.flatten()))\n", + " for i in range(1, 7))\n", + " return lhs, rhs" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "c29", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-12T14:40:32.208811Z", + "iopub.status.busy": "2026-05-12T14:40:32.208715Z", + "iopub.status.idle": "2026-05-12T14:40:35.866848Z", + "shell.execute_reply": "2026-05-12T14:40:35.866548Z" + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "NUMA domain count autodetection failed, assuming 1\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Operator `Kernel` ran in 0.01 s\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Operator `Kernel` ran in 0.01 s\n" + ] + } + ], + "source": [ + "sym_lhs, sym_rhs = run('symmetric')\n", + "sym_rel = abs(sym_lhs - sym_rhs) / max(abs(sym_lhs), abs(sym_rhs))\n", + "\n", + "dir_lhs, dir_rhs = run('direct')\n", + "dir_rel = abs(dir_lhs - dir_rhs) / max(abs(dir_lhs), abs(dir_rhs))" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "c30", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-12T14:40:35.868269Z", + "iopub.status.busy": "2026-05-12T14:40:35.868150Z", + "iopub.status.idle": "2026-05-12T14:40:35.870223Z", + "shell.execute_reply": "2026-05-12T14:40:35.870036Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "symmetric: = 29.060263, = 29.060250\n", + " relative difference = 4.35e-07\n", + "direct : = 27.701068, = 16.359724\n", + " relative difference = 4.09e-01\n" + ] + } + ], + "source": [ + "# NBVAL_IGNORE_OUTPUT\n", + "print(f\"symmetric: = {sym_lhs:.6f}, = {sym_rhs:.6f}\")\n", + "print(f\" relative difference = {sym_rel:.2e}\")\n", + "print(f\"direct : = {dir_lhs:.6f}, = {dir_rhs:.6f}\")\n", + "print(f\" relative difference = {dir_rel:.2e}\")" + ] + }, + { + "cell_type": "markdown", + "id": "c31", + "metadata": {}, + "source": [ + "The `'symmetric'` discretisation passes the dot-product test to float\n", + "precision; the `'direct'` one does not, because each factor was interpolated\n", + "independently and the resulting discrete operator is no longer the transpose\n", + "of itself. A quick assertion makes the contrast machine-checkable (and lets\n", + "CI catch regressions even though the printed floats above can drift in the\n", + "last digit between platforms):" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "c32", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-12T14:40:35.871434Z", + "iopub.status.busy": "2026-05-12T14:40:35.871343Z", + "iopub.status.idle": "2026-05-12T14:40:35.872864Z", + "shell.execute_reply": "2026-05-12T14:40:35.872647Z" + } + }, + "outputs": [], + "source": [ + "assert sym_rel < 1e-5, sym_rel # adjoint identity holds for 'symmetric'\n", + "assert dir_rel > 1e-2, dir_rel # 'direct' is not self-adjoint" + ] + }, + { + "cell_type": "markdown", + "id": "c33", + "metadata": {}, + "source": [ + "## When to use which\n", + "\n", + "| Situation | Mode |\n", + "|------------------------------------------------------------------------|---------------|\n", + "| Acoustic / scalar-wave equations | `'direct'` |\n", + "| Elastic stress-strain or any $\\mathbf{I}\\,\\mathbf{A}\\,\\mathbf{I}^{\\!\\top}$ operator | `'symmetric'` |\n", + "| Adjoint-state inversion needing exact discrete adjoint | `'symmetric'` |\n", + "| Any equation where one factor already matches the target staggering | either |\n", + "\n", + "`'direct'` is the default because it is the cheaper and smaller stencil; pick\n", + "`'symmetric'` deliberately when you need the adjoint structure preserved." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "devito", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/userapi/09_fd_evaluation.ipynb b/examples/userapi/09_fd_evaluation.ipynb new file mode 100644 index 0000000000..7d102542a2 --- /dev/null +++ b/examples/userapi/09_fd_evaluation.ipynb @@ -0,0 +1,642 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "c0", + "metadata": {}, + "source": [ + "# Finite-difference evaluation: from `Eq` to stencil\n", + "\n", + "When you write an equation like `Eq(target, expr)` in Devito and pass it to\n", + "`Operator`, a fair bit happens before any C is generated. The right-hand\n", + "side is *evaluated at the location of the left-hand side*: derivatives get\n", + "their `x0` set to the LHS's sample point, and any factor that sits at a\n", + "different staggering is interpolated.\n", + "\n", + "This tutorial unpacks that step. The goal is to make the rules predictable:\n", + "\n", + "* What does Devito do automatically?\n", + "* Where does it stop, and how do you take over manually?\n", + "* What happens for products, sums, and derivatives mixed together?\n", + "\n", + "We work with `Function`, `Eq`, and `Operator` only - no internal calls. For\n", + "the deeper `'direct'` vs `'symmetric'` interpolation modes used on products\n", + "of staggered fields, see `08_staggered_interp.ipynb`." + ] + }, + { + "cell_type": "markdown", + "id": "c1", + "metadata": {}, + "source": [ + "## The lowering pipeline\n", + "\n", + "The compile-time hook is `Eq._evaluate`. In essence:\n", + "\n", + "```python\n", + "def _evaluate(self, **kw):\n", + " lhs = self.lhs._evaluate(**kw)\n", + " rhs = self.rhs._eval_at(self.lhs, **kw)._evaluate(**kw)\n", + " return Eq(lhs, rhs)\n", + "```\n", + "\n", + "Two passes happen on the RHS:\n", + "\n", + "1. **`_eval_at(lhs)`** moves the symbolic expression to the LHS's sample\n", + " location. For a `Derivative`, this sets the `x0` keyword. For a\n", + " `Function`, this returns either the same function (if it already lives\n", + " at the right place) or a 0-order FD interpolation stencil.\n", + "2. **`_evaluate()`** then expands every derivative into its concrete FD\n", + " stencil at whatever `x0` it now carries.\n", + "\n", + "So the rule of thumb is: when you write `Eq(target, expr)`, every\n", + "sub-expression of `expr` is asked \"please give me your value at\n", + "`target.indices_ref`\". Whether that triggers averaging depends on what the\n", + "sub-expression is." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "c2", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-12T16:25:10.996548Z", + "iopub.status.busy": "2026-05-12T16:25:10.996047Z", + "iopub.status.idle": "2026-05-12T16:25:12.128349Z", + "shell.execute_reply": "2026-05-12T16:25:12.127991Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "fn indices_ref = (x, y)\n", + "fx indices_ref = (x + h_x/2, y)\n", + "fy indices_ref = (x, y + h_y/2)\n", + "fxy indices_ref = (x + h_x/2, y + h_y/2)\n" + ] + } + ], + "source": [ + "from devito import Function, Grid, NODE, Eq, Operator\n", + "from devito.ir.iet import Expression, FindNodes\n", + "\n", + "\n", + "def updates(op):\n", + " \"\"\"Symbolic update equations carried by `op`.\"\"\"\n", + " return [n.expr for n in FindNodes(Expression).visit(op)]\n", + "\n", + "\n", + "grid = Grid(shape=(11, 11), extent=(1.0, 1.0))\n", + "x, y = grid.dimensions\n", + "\n", + "fn = Function(name='fn', grid=grid, space_order=4, staggered=NODE)\n", + "fx = Function(name='fx', grid=grid, space_order=4, staggered=x)\n", + "fy = Function(name='fy', grid=grid, space_order=4, staggered=y)\n", + "fxy = Function(name='fxy', grid=grid, space_order=4, staggered=(x, y))\n", + "\n", + "for f in (fn, fx, fy, fxy):\n", + " print(f.name, 'indices_ref =', tuple(f.indices_ref))" + ] + }, + { + "cell_type": "markdown", + "id": "c3", + "metadata": {}, + "source": [ + "## Where a derivative naturally lives\n", + "\n", + "A finite-difference derivative inherits the grid of its operand. Its\n", + "*natural sample point* depends on the stencil it carries. With Devito's\n", + "default centred stencil, the derivative of a `Function` `f` is built\n", + "around `f`'s indices: differentiating an `f` at `NODE` produces a\n", + "derivative at `NODE`; differentiating an `f` at `x + h_x/2` produces a\n", + "derivative at `x + h_x/2`. The stencil offsets are what shifts, not the\n", + "sample point.\n", + "\n", + "Before any equation context, `f.dx` carries no `x0` and Devito will use\n", + "its natural location:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "c4", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-12T16:25:12.129928Z", + "iopub.status.busy": "2026-05-12T16:25:12.129699Z", + "iopub.status.idle": "2026-05-12T16:25:12.141394Z", + "shell.execute_reply": "2026-05-12T16:25:12.141126Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "fn.dx symbolic : Derivative(fn(x, y), x)\n", + "\n", + "fn.dx.evaluate : 0.0833333333*fn(x - 2*h_x, y)/h_x - 0.666666667*fn(x - h_x, y)/h_x + 0.666666667*fn(x + h_x, y)/h_x - 0.0833333333*fn(x + 2*h_x, y)/h_x\n" + ] + } + ], + "source": [ + "# NBVAL_IGNORE_OUTPUT\n", + "print('fn.dx symbolic :', fn.dx)\n", + "print()\n", + "print('fn.dx.evaluate :', fn.dx.evaluate)" + ] + }, + { + "cell_type": "markdown", + "id": "c5", + "metadata": {}, + "source": [ + "## `_eval_at` on a Derivative: shifting `x0`\n", + "\n", + "When Devito sees `Eq(target, f.dx)`, it calls `f.dx._eval_at(target)`. This\n", + "sets the derivative's `x0` to `target.indices_ref` so that the FD stencil\n", + "is built around the *target's* sample point.\n", + "\n", + "Crucially, **the operand isn't interpolated** - only the stencil offsets\n", + "shift. That keeps the order of accuracy and the stencil shape clean. So\n", + "`fn.dx` at `x + h_x/2` is *not* an average of `fn.dx` values; it is a\n", + "stencil over `fn` samples reshuffled to land on `x + h_x/2`." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "c6", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-12T16:25:12.142800Z", + "iopub.status.busy": "2026-05-12T16:25:12.142708Z", + "iopub.status.idle": "2026-05-12T16:25:12.151364Z", + "shell.execute_reply": "2026-05-12T16:25:12.151141Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--- fn.dx evaluated naturally (no Eq context)\n", + "0.0833333333*fn(x - 2*h_x, y)/h_x - 0.666666667*fn(x - h_x, y)/h_x + 0.666666667*fn(x + h_x, y)/h_x - 0.0833333333*fn(x + 2*h_x, y)/h_x\n", + "\n", + "--- fn.dx as it lowers inside Eq(fx, fn.dx) (target at x + h_x/2)\n", + "-1.125*fn(x, y)/h_x + 0.0416666667*fn(x - h_x, y)/h_x + 1.125*fn(x + h_x, y)/h_x - 0.0416666667*fn(x + 2*h_x, y)/h_x\n" + ] + } + ], + "source": [ + "# NBVAL_IGNORE_OUTPUT\n", + "# Same operand, two different targets, two different x0s\n", + "print('--- fn.dx evaluated naturally (no Eq context)')\n", + "print(fn.dx.evaluate)\n", + "print()\n", + "print('--- fn.dx as it lowers inside Eq(fx, fn.dx) (target at x + h_x/2)')\n", + "print(fn.dx._eval_at(fx).evaluate)" + ] + }, + { + "cell_type": "markdown", + "id": "c7", + "metadata": {}, + "source": [ + "The same shift happens implicitly inside an `Operator`. Comparing two\n", + "equations whose only difference is the LHS:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "c8", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-12T16:25:12.152601Z", + "iopub.status.busy": "2026-05-12T16:25:12.152524Z", + "iopub.status.idle": "2026-05-12T16:25:12.250763Z", + "shell.execute_reply": "2026-05-12T16:25:12.250485Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--- Eq(fn, fn.dx)\n", + " Eq(r0, 1/h_x)\n", + " Eq(fn[x + 4, y + 4], r0*(0.0833333333*(fn[x + 2, y + 4] - fn[x + 6, y + 4]) + 0.666666667*(-fn[x + 3, y + 4] + fn[x + 5, y + 4])))\n", + "\n", + "--- Eq(fx, fn.dx) (note the different fn-indices)\n", + " Eq(r0, 1/h_x)\n", + " Eq(fx[x + 4, y + 4], r0*(0.0416666667*(fn[x + 3, y + 4] - fn[x + 6, y + 4]) + 1.125*(-fn[x + 4, y + 4] + fn[x + 5, y + 4])))\n" + ] + } + ], + "source": [ + "# NBVAL_IGNORE_OUTPUT\n", + "op1 = Operator(Eq(fn, fn.dx)) # target at NODE\n", + "op2 = Operator(Eq(fx, fn.dx)) # target at x + h_x/2\n", + "\n", + "print('--- Eq(fn, fn.dx)')\n", + "for u in updates(op1):\n", + " print(' ', u)\n", + "print()\n", + "print('--- Eq(fx, fn.dx) (note the different fn-indices)')\n", + "for u in updates(op2):\n", + " print(' ', u)" + ] + }, + { + "cell_type": "markdown", + "id": "c9", + "metadata": {}, + "source": [ + "## `_eval_at` on a Function: interpolation\n", + "\n", + "A bare `Function` on the RHS is a different story. If its staggering does\n", + "not match the target's, Devito emits a 0-order FD interpolation operator\n", + "that averages it onto the target's sample points. In 1D this is the\n", + "two-point average; in higher dimensions it tensors out." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "c10", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-12T16:25:12.252256Z", + "iopub.status.busy": "2026-05-12T16:25:12.252158Z", + "iopub.status.idle": "2026-05-12T16:25:12.289144Z", + "shell.execute_reply": "2026-05-12T16:25:12.288860Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " Eq(fn[x + 4, y + 4], 0.5*(fx[x + 3, y + 4] + fx[x + 4, y + 4]))\n" + ] + } + ], + "source": [ + "# fx lives at x + h_x/2, fn at NODE - the assignment averages\n", + "op = Operator(Eq(fn, fx))\n", + "for u in updates(op):\n", + " print(' ', u)" + ] + }, + { + "cell_type": "markdown", + "id": "c11", + "metadata": {}, + "source": [ + "## When auto-interpolation doesn't happen\n", + "\n", + "Three things short-circuit the automatic interpolation. Knowing them is\n", + "the difference between \"magical\" stencils and predictable ones.\n", + "\n", + "### 1. The two sides already share a staggering\n", + "\n", + "If `lhs.indices_ref == rhs.indices_ref` there is nothing to do, so the\n", + "RHS comes through unchanged. This includes the common `staggered=None`\n", + "case: passing `None` is **not** a \"don't interpolate\" flag - it simply\n", + "means \"no staggering specified\", which Devito treats as `NODE`. So\n", + "`staggered=None` and `staggered=NODE` produce *identical* lowerings, and\n", + "both still trigger interpolation when the RHS lives elsewhere." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "c12", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-12T16:25:12.290438Z", + "iopub.status.busy": "2026-05-12T16:25:12.290333Z", + "iopub.status.idle": "2026-05-12T16:25:12.359140Z", + "shell.execute_reply": "2026-05-12T16:25:12.358849Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Eq(fx, fx2) : Eq(fx[x + 4, y + 4], fx2[x + 4, y + 4])\n", + "Eq(fnone, fx) : Eq(fnone[x + 4, y + 4], 0.5*(fx[x + 3, y + 4] + fx[x + 4, y + 4]))\n", + "-> identical to Eq(fn, fx) earlier\n" + ] + } + ], + "source": [ + "# Same staggering on both sides - no interpolation\n", + "fx2 = Function(name='fx2', grid=grid, space_order=4, staggered=x)\n", + "op = Operator(Eq(fx, fx2))\n", + "print('Eq(fx, fx2) :', updates(op)[0])\n", + "\n", + "# staggered=None really is NODE - interpolation still kicks in\n", + "fnone = Function(name='fnone', grid=grid, space_order=4, staggered=None)\n", + "op = Operator(Eq(fnone, fx))\n", + "print('Eq(fnone, fx) :', updates(op)[0])\n", + "print('-> identical to Eq(fn, fx) earlier')" + ] + }, + { + "cell_type": "markdown", + "id": "c13", + "metadata": {}, + "source": [ + "### 2. `interp_order=0` on the Function\n", + "\n", + "Passing `interp_order=0` when you build a `Function` opts it out of\n", + "averaging. The RHS is then sampled at the LHS's grid points without any\n", + "average. This is useful for piecewise-constant material parameters (a\n", + "velocity map, a mask) that should not be smoothed." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "c14", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-12T16:25:12.360553Z", + "iopub.status.busy": "2026-05-12T16:25:12.360464Z", + "iopub.status.idle": "2026-05-12T16:25:12.393101Z", + "shell.execute_reply": "2026-05-12T16:25:12.392827Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Eq(fn, fx0) : Eq(fn[x + 4, y + 4], fx0[x + 4, y + 4])\n", + "-> fx0 sampled directly, no average\n" + ] + } + ], + "source": [ + "fx0 = Function(name='fx0', grid=grid, space_order=4,\n", + " staggered=x, interp_order=0)\n", + "op = Operator(Eq(fn, fx0))\n", + "print('Eq(fn, fx0) :', updates(op)[0])\n", + "print('-> fx0 sampled directly, no average')" + ] + }, + { + "cell_type": "markdown", + "id": "c15", + "metadata": {}, + "source": [ + "### 3. An explicit `x0` on a Derivative\n", + "\n", + "If you pass `x0` to `f.dx(x0=...)` yourself, Devito will **not** overwrite\n", + "it. Use this when you want to control where the derivative is sampled\n", + "independently of the LHS - for example to keep a centred stencil even\n", + "when the target is staggered." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "c16", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-12T16:25:12.394841Z", + "iopub.status.busy": "2026-05-12T16:25:12.394544Z", + "iopub.status.idle": "2026-05-12T16:25:12.478001Z", + "shell.execute_reply": "2026-05-12T16:25:12.477375Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "default : Eq(fx[x + 4, y + 4], r0*(0.0416666667*(fn[x + 3, y + 4] - fn[x + 6, y + 4]) + 1.125*(-fn[x + 4, y + 4] + fn[x + 5, y + 4])))\n", + "explicit : Eq(fx[x + 4, y + 4], r0*(0.0833333333*(fn[x + 2, y + 4] - fn[x + 6, y + 4]) + 0.666666667*(-fn[x + 3, y + 4] + fn[x + 5, y + 4])))\n" + ] + } + ], + "source": [ + "# NBVAL_IGNORE_OUTPUT\n", + "# Default vs. explicit x0\n", + "default = Operator(Eq(fx, fn.dx)) # x0 = x + h_x/2\n", + "explicit = Operator(Eq(fx, fn.dx(x0={x: x}))) # x0 = x (centred)\n", + "print('default :', updates(default)[-1])\n", + "print('explicit :', updates(explicit)[-1])" + ] + }, + { + "cell_type": "markdown", + "id": "c17", + "metadata": {}, + "source": [ + "## Mixed equations\n", + "\n", + "Real equations combine all of the above. The lowering walks the RHS\n", + "recursively: each factor in a product is evaluated at the LHS's location;\n", + "each summand in an `Add` is evaluated independently. Derivatives keep\n", + "their `x0`-shift behaviour, Functions their averaging behaviour." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "c18", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-12T16:25:12.479519Z", + "iopub.status.busy": "2026-05-12T16:25:12.479420Z", + "iopub.status.idle": "2026-05-12T16:25:12.561232Z", + "shell.execute_reply": "2026-05-12T16:25:12.560982Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " Eq(r0, 1/h_x)\n", + " Eq(fxy[x + 4, y + 4], (0.0416666667*(r0*(fx[x + 2, y + 4] + fx[x + 2, y + 5]) - r0*(fx[x + 6, y + 4] + fx[x + 6, y + 5])) + 0.333333333*(-r0*(fx[x + 3, y + 4] + fx[x + 3, y + 5]) + r0*(fx[x + 5, y + 4] + fx[x + 5, y + 5])))*(0.25*fn[x + 4, y + 4] + 0.25*fn[x + 5, y + 4] + 0.25*fn[x + 4, y + 5] + 0.25*fn[x + 5, y + 5]))\n" + ] + } + ], + "source": [ + "# NBVAL_IGNORE_OUTPUT\n", + "# A product of a Function and a derivative, target at the corner\n", + "eq = Eq(fxy, fn * fx.dx)\n", + "op = Operator(eq)\n", + "for u in updates(op):\n", + " print(' ', u)" + ] + }, + { + "cell_type": "markdown", + "id": "c19", + "metadata": {}, + "source": [ + "Reading the right-hand side: `fn` is averaged to the corner (four-point\n", + "average); `fx.dx` has its `x0` set to the corner so the FD stencil is\n", + "already there. The two interpolations happen *independently* - that is\n", + "the `'direct'` mode discussed in `08_staggered_interp.ipynb`. The\n", + "`'symmetric'` mode rewires this product into an `I (a I^T b)` form when\n", + "self-adjointness matters.\n", + "\n", + "A sum behaves the same way, summand by summand:" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "c20", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-12T16:25:12.562590Z", + "iopub.status.busy": "2026-05-12T16:25:12.562501Z", + "iopub.status.idle": "2026-05-12T16:25:12.631478Z", + "shell.execute_reply": "2026-05-12T16:25:12.631188Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " Eq(r0, 1/h_y)\n", + " Eq(fx[x + 4, y + 4], 0.0208333333*(r0*(fy[x + 4, y + 2] + fy[x + 5, y + 2]) - r0*(fy[x + 4, y + 5] + fy[x + 5, y + 5])) + 0.5625*(-r0*(fy[x + 4, y + 3] + fy[x + 5, y + 3]) + r0*(fy[x + 4, y + 4] + fy[x + 5, y + 4])) + 0.5*(fn[x + 4, y + 4] + fn[x + 5, y + 4]))\n" + ] + } + ], + "source": [ + "# NBVAL_IGNORE_OUTPUT\n", + "# fn lives at NODE; fy.dy is naturally at NODE too (a y-staggered\n", + "# field's y-derivative lands at NODE). Target fx is at x + h_x/2.\n", + "eq = Eq(fx, fn + fy.dy)\n", + "op = Operator(eq)\n", + "for u in updates(op):\n", + " print(' ', u)" + ] + }, + { + "cell_type": "markdown", + "id": "c21", + "metadata": {}, + "source": [ + "Both summands had to be brought to `x + h_x/2`: `fn` by averaging in\n", + "`x`, and the `fy.dy` stencil by an `x0` shift on its derivative dim plus\n", + "an interpolation in `y` (since `fy` lives at `y + h_y/2`)." + ] + }, + { + "cell_type": "markdown", + "id": "c22", + "metadata": {}, + "source": [ + "## Cheat sheet\n", + "\n", + "| Step | What Devito does |\n", + "|---------------------------------------------|-----------------------------------------------------------------------------------|\n", + "| `Eq._evaluate` | Calls `rhs._eval_at(lhs)` then `._evaluate()`. |\n", + "| `Derivative._eval_at(target)` | Sets `x0 = target.indices_ref` on the derivative. |\n", + "| `Function._eval_at(target)` | Emits a 0-order FD average to `target.indices_ref` (when staggerings differ). |\n", + "| LHS and RHS share staggering | No-op. |\n", + "| `staggered=None` | Treated as `NODE`. **Does not** opt out of interpolation. |\n", + "| `interp_order=0` | Function opts out of averaging - sampled directly at the target. |\n", + "| Explicit `x0` on a derivative | Devito leaves it alone. |\n", + "\n", + "For the `sym_opt={'interp-mode': ...}` choice that governs how products of\n", + "staggered fields are projected onto a target, see\n", + "`08_staggered_interp.ipynb`." + ] + }, + { + "cell_type": "markdown", + "id": "c23", + "metadata": {}, + "source": [ + "## CI regression guards\n", + "\n", + "The cells above print symbolic stencils for human reading; their textual\n", + "form can drift between SymPy versions and is not load-bearing. The\n", + "following cell pins the *behavioural* invariants of this tutorial so that\n", + "a regression in the lowering will fail CI even when output comparison is\n", + "disabled." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "c24", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-12T16:25:12.632948Z", + "iopub.status.busy": "2026-05-12T16:25:12.632859Z", + "iopub.status.idle": "2026-05-12T16:25:12.850865Z", + "shell.execute_reply": "2026-05-12T16:25:12.850567Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "all invariants hold\n" + ] + } + ], + "source": [ + "# 1. `staggered=None` lowers identically to `staggered=NODE`.\n", + "rhs_none = str(updates(Operator(Eq(fnone, fx)))[0].rhs)\n", + "rhs_node = str(updates(Operator(Eq(fn, fx)))[0].rhs)\n", + "assert rhs_none == rhs_node, (rhs_none, rhs_node)\n", + "\n", + "# 2. `interp_order=0` opts out of averaging - the RHS is sampled directly.\n", + "rhs_io0 = str(updates(Operator(Eq(fn, fx0)))[0].rhs)\n", + "assert rhs_io0 == 'fx0[x + 4, y + 4]', rhs_io0\n", + "\n", + "# 3. Same staggering on both sides is a pure pass-through.\n", + "rhs_same = str(updates(Operator(Eq(fx, fx2)))[0].rhs)\n", + "assert rhs_same == 'fx2[x + 4, y + 4]', rhs_same\n", + "\n", + "# 4. An explicit `x0` on a derivative is not overwritten by `_eval_at` -\n", + "# the two stencils below must differ.\n", + "default_rhs = str(updates(Operator(Eq(fx, fn.dx)))[-1].rhs)\n", + "explicit_rhs = str(updates(Operator(Eq(fx, fn.dx(x0={x: x}))))[-1].rhs)\n", + "assert default_rhs != explicit_rhs\n", + "\n", + "print('all invariants hold')" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tests/test_differentiable.py b/tests/test_differentiable.py index ab8d724b60..9708f7c59e 100644 --- a/tests/test_differentiable.py +++ b/tests/test_differentiable.py @@ -4,7 +4,7 @@ import pytest import sympy -from devito import NODE, Differentiable, Function, Grid +from devito import NODE, Differentiable, Eq, Function, Grid, Operator from devito.finite_differences.differentiable import ( Add, Mul, Pow, SafeInv, diffify, interp_for_fd ) @@ -165,3 +165,372 @@ def test_no_interp(): a_shift = a._subs(x, x + x.spacing / 2) # Should just do nearest grid point, so shift back to original assert a_shift.evaluate == a + + +class TestMulEvalAt: + """ + Verify `Mul._eval_at` in both modes: + + - `interp_mode="direct"`: default per-arg evaluation. + - `interp_mode="symmetric"`: symmetric `I * a * I^T * b.dx` interpolation. + """ + + @staticmethod + def _all_funcs(grid): + x, y = grid.dimensions + return { + 'node': Function(name='fn', grid=grid, space_order=4, staggered=NODE), + 'x': Function(name='fx', grid=grid, space_order=4, staggered=x), + 'y': Function(name='fy', grid=grid, space_order=4, staggered=y), + 'xy': Function(name='fxy', grid=grid, space_order=4, staggered=(x, y)), + } + + @pytest.mark.parametrize('interp_mode', ['direct', 'symmetric']) + @pytest.mark.parametrize('targets', [ + ('node', 'x', 'xy'), + ('node', 'y', 'xy'), + ('x', 'y', 'xy'), + ('node', 'x', 'y'), + ('node', 'xy', 'x'), + ]) + def test_mul_two_funcs(self, interp_mode, targets): + """`a * b` evaluated at `L` references both factors.""" + grid = Grid((11, 11)) + funcs = self._all_funcs(grid) + a_key, b_key, l_key = targets + a, b, L = funcs[a_key], funcs[b_key], funcs[l_key] + + result = (a * b)._eval_at(L, interp_mode=interp_mode) + evaluated_str = str(result.evaluate) + assert a.name in evaluated_str + assert b.name in evaluated_str + + @pytest.mark.parametrize('interp_mode', ['direct', 'symmetric']) + @pytest.mark.parametrize('targets', [ + ('node', 'x', 'y', 'xy'), + ('node', 'x', 'xy', 'y'), + ('x', 'y', 'node', 'xy'), + ]) + def test_mul_three_funcs(self, interp_mode, targets): + """`c * a * b` evaluated at `L` references all three factors.""" + grid = Grid((11, 11)) + funcs = self._all_funcs(grid) + c_key, a_key, b_key, l_key = targets + c, a, b, L = funcs[c_key], funcs[a_key], funcs[b_key], funcs[l_key] + + result = (c * a * b)._eval_at(L, interp_mode=interp_mode) + evaluated_str = str(result.evaluate) + assert a.name in evaluated_str + assert b.name in evaluated_str + assert c.name in evaluated_str + + @pytest.mark.parametrize('interp_mode', ['direct', 'symmetric']) + @pytest.mark.parametrize('targets', [ + ('node', 'x', 'xy'), + ('x', 'node', 'xy'), + ('node', 'xy', 'y'), + ('xy', 'node', 'x'), + ]) + def test_mul_func_deriv(self, interp_mode, targets): + """`a * b.dx` evaluated at `L`: symmetric `I*a*I^T*b.dx` form.""" + grid = Grid((11, 11)) + funcs = self._all_funcs(grid) + a_key, b_key, l_key = targets + a, b, L = funcs[a_key], funcs[b_key], funcs[l_key] + + result = (a * b.dx)._eval_at(L, interp_mode=interp_mode) + evaluated_str = str(result.evaluate) + assert a.name in evaluated_str + assert b.name in evaluated_str + + @pytest.mark.parametrize('interp_mode', ['direct', 'symmetric']) + def test_mul_eval_at_no_op(self, interp_mode): + """`a * b` evaluated at its own location is a no-op.""" + grid = Grid((11, 11)) + a = Function(name='a', grid=grid, space_order=4, staggered=NODE) + b = Function(name='b', grid=grid, space_order=4, staggered=NODE) + + result = (a * b)._eval_at(a, interp_mode=interp_mode) + assert sympy.simplify(result.evaluate - (a * b).evaluate) == 0 + + def test_interp_mode_skips_when_deriv_at_func(self): + """When `b.dx`'s natural staggering matches `c`'s, the symmetric + mode falls back to the default per-arg evaluation (no I^T).""" + grid = Grid((11,)) + x = grid.dimensions[0] + a = Function(name='a', grid=grid, space_order=4, staggered=NODE) + b = Function(name='b', grid=grid, space_order=4, staggered=x) + c = Function(name='c', grid=grid, space_order=4, staggered=x) + + default = (a * b.dx)._eval_at(c).evaluate + symmetric = (a * b.dx)._eval_at(c, interp_mode="symmetric").evaluate + assert sympy.simplify(default - symmetric) == 0 + + def test_interp_mode_applies_symmetric(self): + """When both `a` and `b` differ from `c` in staggering, the + symmetric mode wraps the product in the `I * a * I^T * b.dx` form, + producing two distinct 0-order FD interpolations (`I^T` and `I`).""" + grid = Grid((11, 11)) + x, y = grid.dimensions + a = Function(name='a', grid=grid, space_order=4, staggered=NODE) + b = Function(name='b', grid=grid, space_order=4, staggered=x) + c = Function(name='c', grid=grid, space_order=4, staggered=(x, y)) + + result = (a * b.dx)._eval_at(c, interp_mode="symmetric") + zero_order = [d for d in result.find(sympy.Derivative) + if all(o == 0 for o in d.deriv_order)] + # Two 0-order Derivatives: the outer I (a -> c) and the inner I^T (c -> a) + assert len(zero_order) == 2 + + def test_interp_mode_with_function_factor(self): + """The symmetric mode applies to any Differentiable factor at a + non-matching staggering, not only Derivatives. E.g. `a * bdx` where + `bdx` is a stored Function (e.g. holding a pre-computed derivative).""" + grid = Grid((11, 11)) + x, y = grid.dimensions + a = Function(name='a', grid=grid, space_order=4, staggered=NODE) + bdx = Function(name='bdx', grid=grid, space_order=4, staggered=x) + c = Function(name='c', grid=grid, space_order=4, staggered=(x, y)) + + result = (a * bdx)._eval_at(c, interp_mode="symmetric") + # Symmetric form: outer I (a -> c) and inner I^T (bdx's loc -> a) + zero_order = [d for d in result.find(sympy.Derivative) + if all(o == 0 for o in d.deriv_order)] + assert len(zero_order) == 2 + + def test_interp_mode_same_loc_block_interp(self): + """When all factors share a single location that differs from `func`, + the symmetric mode interpolates the whole product as one block (as + required by the elastic stiffness `I*(C_{ij}*b_j) -> a_i` form), + not per-arg.""" + grid = Grid((11, 11)) + x, y = grid.dimensions + a = Function(name='a', grid=grid, space_order=4, staggered=NODE) + b = Function(name='b', grid=grid, space_order=4, staggered=NODE) + c = Function(name='c', grid=grid, space_order=4, staggered=(x, y)) + + result = (a * b)._eval_at(c, interp_mode="symmetric") + # Single 0-order Derivative wrapping the product: I(a*b) + zero_order = [d for d in result.find(sympy.Derivative) + if all(o == 0 for o in d.deriv_order)] + assert len(zero_order) == 1 + block_str = str(zero_order[0].expr) + assert a.name in block_str + assert b.name in block_str + + def test_interp_mode_factor_at_func_falls_back(self): + """When at least one factor already matches `func`'s staggering, the + symmetric mode is unnecessary and we fall back to the default.""" + grid = Grid((11, 11)) + x, y = grid.dimensions + a = Function(name='a', grid=grid, space_order=4, staggered=NODE) + b = Function(name='b', grid=grid, space_order=4, staggered=NODE) + c = Function(name='c', grid=grid, space_order=4, staggered=NODE) + + default = (a * b)._eval_at(c).evaluate + symmetric = (a * b)._eval_at(c, interp_mode="symmetric").evaluate + assert sympy.simplify(default - symmetric) == 0 + + +class TestElasticStiffness: + """ + Verify `Mul._eval_at(interp_mode="symmetric")` produces the symmetric `a = C*b` + elastic stiffness pattern in 3D Voigt notation, with `C_{ij}` at NODE + and `b`, `a` components at the standard staggered locations: + + - 1, 2, 3 (normal): NODE + - 4 (yz): (y, z)-staggered ("0--") + - 5 (xz): (x, z)-staggered ("-0-") + - 6 (xy): (x, y)-staggered ("--0") + + Expected discrete C matrix (only the structure of interp operators + matters, not the specific weights): + + .. code-block:: text + + [ C11 C12 C13 | C14 I0-- C15 I-0- C16 I--0 + C12 C22 C23 | C24 I0-- C25 I-0- C26 I--0 + C13 C23 C33 | C34 I0-- C35 I-0- C36 I--0 + ----------------------------+----------------------------- + I0++ C14 I0++ C24 I0++ C34 | C44 I0++ C45 I-0- I0++ C46 I--0 + I+0+ C15 I+0+ C25 I+0+ C35 | I+0+ C45 I0-- C55 I+0+ C56 I--0 + I++0 C16 I++0 C26 I++0 C36 | I++0 C46 I0-- I++0 C56 I-0- C66 ] + """ + + @staticmethod + def _setup(): + grid = Grid((11, 11, 11)) + x, y, z = grid.dimensions + + def F(name, stag): + return Function(name=name, grid=grid, space_order=4, staggered=stag) + + # a, b components at Voigt locations + a = {1: F('a1', NODE), 2: F('a2', NODE), 3: F('a3', NODE), + 4: F('a4', (y, z)), 5: F('a5', (x, z)), 6: F('a6', (x, y))} + b = {1: F('b1', NODE), 2: F('b2', NODE), 3: F('b3', NODE), + 4: F('b4', (y, z)), 5: F('b5', (x, z)), 6: F('b6', (x, y))} + # Stiffness C_{ij} all at NODE + C = {(i, j): F(f'C{i}{j}', NODE) for i in range(1, 7) for j in range(1, 7)} + return a, b, C + + @staticmethod + def _zero_order_derivs(expr): + return [d for d in expr.find(sympy.Derivative) + if all(o == 0 for o in d.deriv_order)] + + @pytest.mark.parametrize('i, j', [(1, 1), (1, 2), (2, 3)]) + def test_normal_normal(self, i, j): + """`i, j in {1, 2, 3}`: a, b, C all at NODE -> direct product.""" + a, b, C = self._setup() + result = (C[(i, j)] * b[j])._eval_at(a[i], interp_mode="symmetric") + # No interpolation needed + assert self._zero_order_derivs(result) == [] + + @pytest.mark.parametrize('i, j', [(1, 4), (2, 5), (3, 6)]) + def test_normal_shear_row(self, i, j): + """`i in {1,2,3}, j in {4,5,6}`: bring `b_j` from its stag location + to NODE; `a_i` is at NODE so no outer I. `C_{ij} * I_{j-stag}(b_j)`. + + The default-mode subs path handles this (one factor matches func), + producing the correct `C_{ij} * b_j(NODE indices)` form.""" + a, b, C = self._setup() + result = (C[(i, j)] * b[j])._eval_at(a[i], interp_mode="symmetric") + # b_j gets subs'd to NODE indices; no explicit 0-order Derivative + assert self._zero_order_derivs(result) == [] + # b_j now indexed at a_i's NODE indices, ready for interp_for_fd + evaluated_str = str(result.evaluate) + assert b[j].name in evaluated_str + assert C[(i, j)].name in evaluated_str + + @pytest.mark.parametrize('i, j', [(4, 1), (5, 2), (6, 3)]) + def test_shear_normal_row(self, i, j): + """`i in {4,5,6}, j in {1,2,3}`: `C_{ij}` and `b_j` both at NODE, + target at stag -> single block interp `I(C_{ij} * b_j) -> a_i`.""" + a, b, C = self._setup() + result = (C[(i, j)] * b[j])._eval_at(a[i], interp_mode="symmetric") + zo = self._zero_order_derivs(result) + assert len(zo) == 1 + block_str = str(zo[0].expr) + assert C[(i, j)].name in block_str + assert b[j].name in block_str + + @pytest.mark.parametrize('i', [4, 5, 6]) + def test_shear_diagonal(self, i): + """`i == j in {4,5,6}`: `C_{ii}` at NODE; both `a_i` and `b_i` + at the same stag -> `b_i` matches func, default path applies + (C is implicitly interp'd to `a_i`'s location via subs).""" + a, b, C = self._setup() + result = (C[(i, i)] * b[i])._eval_at(a[i], interp_mode="symmetric") + # b_i matches func -> default path -> no explicit 0-order Derivative + assert self._zero_order_derivs(result) == [] + + @pytest.mark.parametrize('i, j', [(4, 5), (4, 6), (5, 6), (5, 4)]) + def test_shear_shear_offdiag(self, i, j): + """`i, j in {4,5,6}, i != j`: `C_{ij}` at NODE, `b_j` and `a_i` + at *different* stag -> full symmetric `I_{a_i++}(C_{ij} * I_{b_j--}(b_j))`. + Produces two 0-order interp operators (one I^T, one I).""" + a, b, C = self._setup() + result = (C[(i, j)] * b[j])._eval_at(a[i], interp_mode="symmetric") + assert len(self._zero_order_derivs(result)) == 2 + + def test_full_row_shear(self): + """Build the full row 4 of `a = C * b` and verify the structure: + sum over j of `C_{4j} * b_j` evaluated at `a_4`. Every term should + appear and each non-diagonal term contributes a symmetric structure.""" + a, b, C = self._setup() + terms = [(C[(4, j)] * b[j])._eval_at(a[4], interp_mode="symmetric") + for j in range(1, 7)] + # Sanity: every component is referenced + full_str = ''.join(str(t.evaluate) for t in terms) + for j in range(1, 7): + assert b[j].name in full_str + assert C[(4, j)].name in full_str + + +class TestSymmetricAdjoint: + """ + Numerical adjoint-identity check for `interp-mode='symmetric'`. + + For a symmetric stiffness `C` the continuous operator `sigma = C * eps` + is self-adjoint: ` = `. A discretization that + preserves this identity (to numerical precision) preserves + energy / yields the correct adjoint state. The `'symmetric'` interp + mode does exactly that — the `I * A * I^T` factorization makes the + discrete operator self-adjoint when `A` is. + + The companion `'direct'` mode does *not* preserve the identity (each + factor is interpolated independently, so the discrete operator is not + the transpose of itself). + """ + + @staticmethod + def _setup(interp_mode, so=4): + np.random.seed(1234) + + nx, ny, nz = 11, 11, 11 + grid = Grid(shape=(nx, ny, nz), extent=(1.0, 1.0, 1.0)) + x, y, z = grid.dimensions + + # Standard Voigt staggerings + locs = {1: NODE, 2: NODE, 3: NODE, + 4: (y, z), 5: (x, z), 6: (x, y)} + + # Random symmetric 6x6 stiffness, all components at NODE + C = {} + for i in range(1, 7): + for j in range(i, 7): + f = Function(name=f'C{i}{j}', grid=grid, space_order=so, + staggered=NODE) + f.data[:] = np.random.rand(nx, ny, nz) + C[(i, j)] = f + C[(j, i)] = f + + def six(prefix): + return {i: Function(name=f'{prefix}{i}', grid=grid, + space_order=so, staggered=locs[i]) + for i in range(1, 7)} + + e1, e2, t1, t2 = six('e1_'), six('e2_'), six('t1_'), six('t2_') + for i in range(1, 7): + e1[i].data[:] = 2 * np.random.rand(nx, ny, nz) - 1 + t2[i].data[:] = 2 * np.random.rand(nx, ny, nz) - 1 + + # t1 = C * e1 and e2 = C * t2 -- two applications of the same operator + eqns = [] + for i in range(1, 7): + eqns.append(Eq(t1[i], sum(C[(i, j)] * e1[j] for j in range(1, 7)))) + eqns.append(Eq(e2[i], sum(C[(i, j)] * t2[j] for j in range(1, 7)))) + + Operator(eqns, sym_opt={'interp-mode': interp_mode}).apply() + + inner_e = sum(float(np.dot(e1[i].data.flatten(), + e2[i].data.flatten())) + for i in range(1, 7)) + inner_t = sum(float(np.dot(t1[i].data.flatten(), + t2[i].data.flatten())) + for i in range(1, 7)) + return inner_e, inner_t + + def test_symmetric_preserves_adjoint(self): + """` == ` to numerical precision under + `interp-mode='symmetric'`.""" + inner_e, inner_t = self._setup('symmetric') + rel = abs(inner_e - inner_t) / max(abs(inner_e), abs(inner_t)) + assert rel < 1e-5, ( + f' = {inner_e!r} vs = {inner_t!r} ' + f'(rel diff {rel:.3e})' + ) + + def test_direct_breaks_adjoint(self): + """`interp-mode='direct'` interpolates factors independently and so + does *not* preserve the discrete adjoint identity. Recorded as an + explicit large-discrepancy check so a regression that accidentally + makes `'direct'` adjoint-correct also shows up.""" + inner_e, inner_t = self._setup('direct') + rel = abs(inner_e - inner_t) / max(abs(inner_e), abs(inner_t)) + assert rel > 1e-2, ( + f"'direct' mode unexpectedly preserved adjoint identity: " + f' = {inner_e!r}, = {inner_t!r} ' + f'(rel diff {rel:.3e})' + ) diff --git a/tests/test_staggered_utils.py b/tests/test_staggered_utils.py index fab3742638..74ff3a7426 100644 --- a/tests/test_staggered_utils.py +++ b/tests/test_staggered_utils.py @@ -79,8 +79,12 @@ def test_is_param(ndim): @pytest.mark.parametrize('expr, expected', [ - ('(a*b)._gather_for_diff', 'a.subs({x: x0}) * b'), - ('(d*b)._gather_for_diff', 'd.subs({x: x0}) * b'), + # NODE has higher _fd_priority than staggered, so a product of factors at + # different staggerings is gathered at the cell centre (where coefficients + # naturally live). A Derivative inherits the priority of its underlying + # function (so `d.dx` represents `d` for the comparison). + ('(a*b)._gather_for_diff', 'a * b.subs({x0: x})'), + ('(d*b)._gather_for_diff', 'd * b.subs({x0: x})'), ('(d.dx*b)._gather_for_diff', 'd.dx * b.subs({x0: x})'), ('(b*c)._gather_for_diff', 'b * c.subs({x: x0, y0: y})')]) def test_gather_for_diff(expr, expected): @@ -99,9 +103,11 @@ def test_gather_for_diff(expr, expected): @pytest.mark.parametrize('expr, expected', [ ('((a + b).dx._eval_at(a)).is_Add', 'True'), ('(a + b).dx._eval_at(a)', 'a.dx + b.dx._eval_at(a)'), - ('(a*b).dx._eval_at(a).expr', 'a.subs({x: x0}) * b'), + # NODE has higher _fd_priority than staggered, so products are gathered at + # the cell centre (with the staggered factor shifted *down* to NODE). + ('(a*b).dx._eval_at(a).expr', 'a * b.subs({x0: x})'), ('(a * b.dx).dx._eval_at(b).expr._eval_deriv ', - 'a.subs({x: x0}) * b.dx.evaluate')]) + 'a * b.dx.evaluate')]) def test_stagg_fd_composite(expr, expected): grid = Grid((10, 10)) x, y = grid.dimensions