Skip to content

Commit 883d31e

Browse files
committed
api: move symbolic config to its own options
1 parent e0d0d5d commit 883d31e

10 files changed

Lines changed: 880 additions & 182 deletions

File tree

devito/core/cpu.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,6 @@ def _normalize_kwargs(cls, **kwargs):
8787

8888
# Code generation options for derivatives
8989
o['expand'] = oo.pop('expand', cls.EXPAND)
90-
o['interp-mode'] = oo.pop('interp-mode', cls.INTERP_MODE)
9190
o['deriv-collect'] = oo.pop('deriv-collect', cls.DERIV_COLLECT)
9291
o['deriv-schedule'] = oo.pop('deriv-schedule', cls.DERIV_SCHEDULE)
9392
o['deriv-unroll'] = oo.pop('deriv-unroll', False)
@@ -112,6 +111,7 @@ def _normalize_kwargs(cls, **kwargs):
112111
)
113112

114113
kwargs['options'].update(o)
114+
kwargs['sym_options'] = cls._normalize_sym_kwargs(**kwargs)
115115

116116
return kwargs
117117

devito/core/gpu.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,6 @@ def _normalize_kwargs(cls, **kwargs):
102102

103103
# Code generation options for derivatives
104104
o['expand'] = oo.pop('expand', cls.EXPAND)
105-
o['interp-mode'] = oo.pop('interp-mode', cls.INTERP_MODE)
106105
o['deriv-collect'] = oo.pop('deriv-collect', cls.DERIV_COLLECT)
107106
o['deriv-schedule'] = oo.pop('deriv-schedule', cls.DERIV_SCHEDULE)
108107
o['deriv-unroll'] = oo.pop('deriv-unroll', False)
@@ -122,6 +121,7 @@ def _normalize_kwargs(cls, **kwargs):
122121
)
123122

124123
kwargs['options'].update(o)
124+
kwargs['sym_options'] = cls._normalize_sym_kwargs(**kwargs)
125125

126126
return kwargs
127127

devito/core/operator.py

Lines changed: 45 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -125,26 +125,6 @@ class BasicOperator(Operator):
125125
finite-difference derivatives.
126126
"""
127127

128-
INTERP_MODE = 'direct'
129-
"""
130-
Interpolation mode used by `Mul._eval_at` when projecting a multi-factor
131-
expression onto a target staggered location:
132-
133-
* `'direct'` (default): each factor is shifted to `func`'s location
134-
independently (`Function._eval_at` per arg). Cheapest stencil; the
135-
mode to pick unless you need an explicitly self-adjoint discretization.
136-
137-
* `'symmetric'`: when every factor lives at a staggering different from
138-
`func`'s, the symmetric form `I * (a * I^T * b)` is built -- all
139-
factors are gathered at the highest-priority "block" location via
140-
`I^T`, multiplied there, and the product is interpolated to `func`
141-
via `I`. Use this for operators whose continuous form decomposes as
142-
`I * A * I^T` (e.g. the elastic stiffness `σ = C ε`).
143-
144-
See `examples/userapi/08_staggered_interp.ipynb` for the maths and a
145-
worked elastic-stiffness example.
146-
"""
147-
148128
DERIV_COLLECT = True
149129
"""
150130
Factorize finite-difference derivatives exploiting the linearity of the FD
@@ -191,6 +171,30 @@ class BasicOperator(Operator):
191171
The target language constructor, to be specified by subclasses.
192172
"""
193173

174+
# ------------------------------------------------------------------
175+
# Symbolic-level option defaults (`sym_opt`).
176+
# These steer mathematical choices made during expression lowering,
177+
# *not* code generation or performance. They are kept separate from
178+
# the `opt` options above to keep the two concerns distinct.
179+
# ------------------------------------------------------------------
180+
181+
INTERP_MODE = 'direct'
182+
"""
183+
Default for the `sym_opt={'interp-mode': ...}` option. Controls how
184+
a product of fields living at different staggered locations is mapped
185+
onto a target location:
186+
187+
* `'direct'` (default): each factor is interpolated to the target
188+
independently. Cheapest stencil.
189+
* `'symmetric'`: factors are first gathered at a common "block"
190+
location, multiplied there, and the result is interpolated once to
191+
the target. Preserves the `I A I^T` matrix structure, so the
192+
discrete operator stays self-adjoint when the continuous one is
193+
(e.g. the elastic stiffness `sigma = C eps`).
194+
195+
See `examples/userapi/08_staggered_interp.ipynb` for a worked example.
196+
"""
197+
194198
@classmethod
195199
def _normalize_kwargs(cls, **kwargs):
196200
# Will be populated with dummy values; this method is actually overridden
@@ -208,12 +212,30 @@ def _normalize_kwargs(cls, **kwargs):
208212
)
209213

210214
kwargs['options'].update(o)
215+
kwargs['sym_options'] = cls._normalize_sym_kwargs(**kwargs)
211216

212217
return kwargs
213218

219+
@classmethod
220+
def _normalize_sym_kwargs(cls, **kwargs):
221+
"""
222+
Fill in defaults and validate keys for the `sym_opt` dict passed to
223+
the Operator. Returns the normalized `sym_options` dict.
224+
"""
225+
so = dict(kwargs.get('sym_options', {}))
226+
out = {'interp-mode': so.pop('interp-mode', cls.INTERP_MODE)}
227+
228+
if so:
229+
raise InvalidOperator(
230+
f'Unrecognized symbolic options: [{", ".join(list(so))}]'
231+
)
232+
233+
return out
234+
214235
@classmethod
215236
def _check_kwargs(cls, **kwargs):
216237
oo = kwargs['options']
238+
so = kwargs['sym_options']
217239

218240
if oo['mpi'] and oo['mpi'] not in cls.MPI_MODES:
219241
raise InvalidOperator(f"Unsupported MPI mode `{oo['mpi']}`")
@@ -229,6 +251,9 @@ def _check_kwargs(cls, **kwargs):
229251
if oo['errctl'] not in (None, False, 'basic', 'max'):
230252
raise InvalidOperator("Illegal `errctl` value")
231253

254+
if so['interp-mode'] not in ('direct', 'symmetric'):
255+
raise InvalidOperator("Illegal `interp-mode` value")
256+
232257
def _autotune(self, args, setup):
233258
if setup in [False, 'off']:
234259
return args

devito/finite_differences/differentiable.py

Lines changed: 11 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from collections import ChainMap
2-
from contextlib import suppress
32
from functools import cached_property, singledispatch
43
from itertools import product
54

@@ -16,6 +15,7 @@
1615
# Moved in 1.13
1716
from sympy.core.basic import ordering_of_classes
1817

18+
from devito.finite_differences.interpolation import interp_at, post_x0_indices
1919
from devito.finite_differences.tools import coeff_priority, make_shift_x0
2020
from devito.logger import warning
2121
from devito.tools import (
@@ -699,32 +699,31 @@ def _eval_at(self, func, interp_mode='direct', **kwargs):
699699
if interp_mode != 'symmetric':
700700
return super()._eval_at(func, **kwargs)
701701

702-
diff_args = [a for a in self.args if isinstance(a, Differentiable)]
703-
other_args = [a for a in self.args if not isinstance(a, Differentiable)]
702+
diff, other = split(self.args, lambda a: isinstance(a, Differentiable))
704703

705704
# Symmetric form requires every Differentiable factor to differ from
706705
# func; otherwise direct evaluation is cleaner and equivalent.
707-
if len(diff_args) < 2 or \
708-
any(a.staggered == func.staggered for a in diff_args):
706+
if len(diff) < 2 or \
707+
any(a.staggered == func.staggered for a in diff):
709708
return super()._eval_at(func, **kwargs)
710709

711710
block_indices = highest_priority(self).indices_ref
712711

713712
# Bring each factor to block's location (I^T where needed)
714-
new_factors = list(other_args)
715-
for a in diff_args:
713+
new_factors = list(other)
714+
for a in diff:
716715
if isinstance(a, sympy.Derivative):
717-
source = _post_x0_indices(a, func)
716+
source = post_x0_indices(a, func)
718717
a = a._rebuild(x0={dim: func.indices_ref[dim] for dim in a.dims
719718
if dim in func.indices_ref.getters})
720719
else:
721720
source = a.indices_ref
722-
new_factors.append(_interp_at(a, source, block_indices,
723-
self.interp_order))
721+
new_factors.append(interp_at(a, source, block_indices,
722+
self.interp_order))
724723

725724
# Final I from block's location to func
726-
return _interp_at(self.func(*new_factors), block_indices,
727-
func.indices_ref, self.interp_order)
725+
return interp_at(self.func(*new_factors), block_indices,
726+
func.indices_ref, self.interp_order)
728727

729728

730729
class Pow(DifferentiableOp, sympy.Pow):
@@ -1251,63 +1250,6 @@ def _diff2sympy(obj):
12511250
evalf_table[Pow] = evalf_table[sympy.Pow]
12521251

12531252

1254-
def _interp_mapper(source, target, dims):
1255-
"""
1256-
Build a `{dim: target_index}` mapper for dimensions in `dims` where
1257-
`source[dim]` differs from `target[dim]`.
1258-
1259-
`source` and `target` are dict-like `{dim: index_expr}` (e.g. a plain
1260-
dict or a `DimensionTuple`). Dimensions missing from either side are
1261-
skipped silently.
1262-
"""
1263-
mapper = {}
1264-
for d in dims:
1265-
try:
1266-
s = source[d]
1267-
t = target[d]
1268-
except (KeyError, IndexError):
1269-
continue
1270-
if s is not t:
1271-
mapper[d] = t
1272-
return mapper
1273-
1274-
1275-
def _interp_at(expr, source, target, interp_order):
1276-
"""
1277-
Build a symbolic 0-order FD interpolation operator on `expr` that maps
1278-
values from `source` indices to `target` indices, only on the
1279-
dimensions where the two locations differ.
1280-
"""
1281-
if not isinstance(expr, Differentiable):
1282-
return expr
1283-
1284-
mapper = _interp_mapper(source, target, expr.dimensions)
1285-
if not mapper:
1286-
return expr
1287-
1288-
return expr.diff(*mapper.keys(),
1289-
deriv_order=(0,) * len(mapper),
1290-
fd_order=(interp_order,) * len(mapper),
1291-
x0=mapper)
1292-
1293-
1294-
def _post_x0_indices(deriv, func):
1295-
"""
1296-
Conceptual indices of `deriv` after setting `x0` on its own derivative
1297-
dimensions to `func`'s indices. Derivative dims take `func`'s indices;
1298-
other dims keep the underlying expression's natural location (so that
1299-
`interp_for_fd` does not introduce a spurious second shift).
1300-
"""
1301-
ref = {}
1302-
for dim in deriv.dimensions:
1303-
if dim in deriv.dims and dim in func.indices_ref.getters:
1304-
ref[dim] = func.indices_ref[dim]
1305-
else:
1306-
with suppress(KeyError):
1307-
ref[dim] = deriv.indices_ref[dim]
1308-
return ref
1309-
1310-
13111253
# Interpolation for finite differences
13121254
@singledispatch
13131255
def interp_for_fd(expr, x0, **kwargs):
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
from contextlib import suppress
2+
3+
__all__ = ['interp_at', 'interp_mapper', 'post_x0_indices']
4+
5+
6+
def interp_mapper(source, target, dims):
7+
"""
8+
Build a `{dim: target_index}` mapper for dimensions in `dims` where
9+
`source[dim]` differs from `target[dim]`.
10+
11+
`source` and `target` are dict-like `{dim: index_expr}` (e.g. a plain
12+
dict or a `DimensionTuple`). Dimensions missing from either side are
13+
skipped silently.
14+
"""
15+
mapper = {}
16+
for d in dims:
17+
try:
18+
s = source[d]
19+
t = target[d]
20+
except (KeyError, IndexError):
21+
continue
22+
if s is not t:
23+
mapper[d] = t
24+
return mapper
25+
26+
27+
def interp_at(expr, source, target, interp_order):
28+
"""
29+
Build a symbolic 0-order FD interpolation operator on `expr` that maps
30+
values from `source` indices to `target` indices, only on the
31+
dimensions where the two locations differ.
32+
"""
33+
from devito.finite_differences.differentiable import Differentiable
34+
35+
if not isinstance(expr, Differentiable):
36+
return expr
37+
38+
mapper = interp_mapper(source, target, expr.dimensions)
39+
if not mapper:
40+
return expr
41+
42+
return expr.diff(*mapper.keys(),
43+
deriv_order=(0,) * len(mapper),
44+
fd_order=(interp_order,) * len(mapper),
45+
x0=mapper)
46+
47+
48+
def post_x0_indices(deriv, func):
49+
"""
50+
Conceptual indices of `deriv` after setting `x0` on its own derivative
51+
dimensions to `func`'s indices. Derivative dims take `func`'s indices;
52+
other dims keep the underlying expression's natural location (so that
53+
`interp_for_fd` does not introduce a spurious second shift).
54+
"""
55+
ref = {}
56+
for dim in deriv.dimensions:
57+
if dim in deriv.dims and dim in func.indices_ref.getters:
58+
ref[dim] = func.indices_ref[dim]
59+
else:
60+
with suppress(KeyError):
61+
ref[dim] = deriv.indices_ref[dim]
62+
return ref

devito/operator/operator.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,16 @@ class Operator(Callable):
6767
Symbolic substitutions to be applied to ``expressions``.
6868
* opt : str
6969
The performance optimization level. Defaults to ``configuration['opt']``.
70+
* sym_opt : dict
71+
Symbolic-level options controlling mathematical choices made during
72+
expression lowering (e.g. how staggered multi-factor products are
73+
interpolated). Distinct from ``opt``, which controls code generation
74+
and performance. Accepted keys:
75+
76+
- ``'interp-mode'`` (``'direct'`` | ``'symmetric'``): selects the
77+
interpolation strategy used by ``Mul._eval_at`` when projecting a
78+
multi-factor expression onto a target staggered location. See the
79+
tutorial at ``examples/userapi/08_staggered_interp.ipynb``.
7080
* language : str
7181
The target language for shared-memory parallelism. Defaults to
7282
``configuration['language']``.
@@ -234,6 +244,7 @@ def _build(cls, expressions, **kwargs):
234244
# Potentially required for lazily allocated Functions
235245
op._mode = kwargs['mode']
236246
op._options = kwargs['options']
247+
op._sym_options = kwargs['sym_options']
237248
op._allocator = kwargs['allocator']
238249
op._platform = kwargs['platform']
239250

@@ -341,7 +352,7 @@ def _lower_exprs(cls, expressions, **kwargs):
341352
* Shift indices for domain alignment.
342353
"""
343354
expand = kwargs['options'].get('expand', True)
344-
interp_mode = kwargs['options'].get('interp-mode', 'direct')
355+
interp_mode = kwargs.get('sym_options', {}).get('interp-mode', 'direct')
345356

346357
# Specialization is performed on unevaluated expressions
347358
expressions = cls._specialize_dsl(expressions, **kwargs)
@@ -1643,6 +1654,12 @@ def parse_kwargs(**kwargs):
16431654
mode = 'noop'
16441655
kwargs['mode'] = mode
16451656

1657+
# `sym_opt` -- symbolic-level options (mathematical choices, not codegen)
1658+
sym_opt = kwargs.pop('sym_opt', None) or {}
1659+
if not isinstance(sym_opt, (dict, frozendict)):
1660+
raise InvalidOperator(f"Illegal `sym_opt={str(sym_opt)}`")
1661+
kwargs['sym_options'] = dict(sym_opt)
1662+
16461663
# `platform`
16471664
platform = kwargs.get('platform')
16481665
if platform is not None:

devito/types/dense.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from devito.deprecations import deprecations
1616
from devito.exceptions import InvalidArgument
1717
from devito.finite_differences import Differentiable, generate_fd_shortcuts
18-
from devito.finite_differences.differentiable import _interp_mapper
18+
from devito.finite_differences.interpolation import interp_mapper
1919
from devito.finite_differences.tools import fd_weights_registry
2020
from devito.logger import debug, warning
2121
from devito.mpi import MPI
@@ -1128,7 +1128,7 @@ def _eval_at(self, func, **kwargs):
11281128
return self
11291129

11301130
# Dims where self and func indices differ -> {dim: func_idx}
1131-
diff = _interp_mapper(self.indices_ref, func.indices_ref, self.dimensions)
1131+
diff = interp_mapper(self.indices_ref, func.indices_ref, self.dimensions)
11321132

11331133
# Translate into a subs mapper {self_idx: func_idx} aligned on self's dims
11341134
subs_map = {self.indices_ref[d]: t._subs(func.dimensions[d], d)

0 commit comments

Comments
 (0)