|
1 | 1 | from collections import ChainMap |
2 | | -from contextlib import suppress |
3 | 2 | from functools import cached_property, singledispatch |
4 | 3 | from itertools import product |
5 | 4 |
|
|
16 | 15 | # Moved in 1.13 |
17 | 16 | from sympy.core.basic import ordering_of_classes |
18 | 17 |
|
| 18 | +from devito.finite_differences.interpolation import interp_at, post_x0_indices |
19 | 19 | from devito.finite_differences.tools import coeff_priority, make_shift_x0 |
20 | 20 | from devito.logger import warning |
21 | 21 | from devito.tools import ( |
@@ -699,32 +699,31 @@ def _eval_at(self, func, interp_mode='direct', **kwargs): |
699 | 699 | if interp_mode != 'symmetric': |
700 | 700 | return super()._eval_at(func, **kwargs) |
701 | 701 |
|
702 | | - diff_args = [a for a in self.args if isinstance(a, Differentiable)] |
703 | | - other_args = [a for a in self.args if not isinstance(a, Differentiable)] |
| 702 | + diff, other = split(self.args, lambda a: isinstance(a, Differentiable)) |
704 | 703 |
|
705 | 704 | # Symmetric form requires every Differentiable factor to differ from |
706 | 705 | # func; otherwise direct evaluation is cleaner and equivalent. |
707 | | - if len(diff_args) < 2 or \ |
708 | | - any(a.staggered == func.staggered for a in diff_args): |
| 706 | + if len(diff) < 2 or \ |
| 707 | + any(a.staggered == func.staggered for a in diff): |
709 | 708 | return super()._eval_at(func, **kwargs) |
710 | 709 |
|
711 | 710 | block_indices = highest_priority(self).indices_ref |
712 | 711 |
|
713 | 712 | # Bring each factor to block's location (I^T where needed) |
714 | | - new_factors = list(other_args) |
715 | | - for a in diff_args: |
| 713 | + new_factors = list(other) |
| 714 | + for a in diff: |
716 | 715 | if isinstance(a, sympy.Derivative): |
717 | | - source = _post_x0_indices(a, func) |
| 716 | + source = post_x0_indices(a, func) |
718 | 717 | a = a._rebuild(x0={dim: func.indices_ref[dim] for dim in a.dims |
719 | 718 | if dim in func.indices_ref.getters}) |
720 | 719 | else: |
721 | 720 | source = a.indices_ref |
722 | | - new_factors.append(_interp_at(a, source, block_indices, |
723 | | - self.interp_order)) |
| 721 | + new_factors.append(interp_at(a, source, block_indices, |
| 722 | + self.interp_order)) |
724 | 723 |
|
725 | 724 | # Final I from block's location to func |
726 | | - return _interp_at(self.func(*new_factors), block_indices, |
727 | | - func.indices_ref, self.interp_order) |
| 725 | + return interp_at(self.func(*new_factors), block_indices, |
| 726 | + func.indices_ref, self.interp_order) |
728 | 727 |
|
729 | 728 |
|
730 | 729 | class Pow(DifferentiableOp, sympy.Pow): |
@@ -1251,63 +1250,6 @@ def _diff2sympy(obj): |
1251 | 1250 | evalf_table[Pow] = evalf_table[sympy.Pow] |
1252 | 1251 |
|
1253 | 1252 |
|
1254 | | -def _interp_mapper(source, target, dims): |
1255 | | - """ |
1256 | | - Build a `{dim: target_index}` mapper for dimensions in `dims` where |
1257 | | - `source[dim]` differs from `target[dim]`. |
1258 | | -
|
1259 | | - `source` and `target` are dict-like `{dim: index_expr}` (e.g. a plain |
1260 | | - dict or a `DimensionTuple`). Dimensions missing from either side are |
1261 | | - skipped silently. |
1262 | | - """ |
1263 | | - mapper = {} |
1264 | | - for d in dims: |
1265 | | - try: |
1266 | | - s = source[d] |
1267 | | - t = target[d] |
1268 | | - except (KeyError, IndexError): |
1269 | | - continue |
1270 | | - if s is not t: |
1271 | | - mapper[d] = t |
1272 | | - return mapper |
1273 | | - |
1274 | | - |
1275 | | -def _interp_at(expr, source, target, interp_order): |
1276 | | - """ |
1277 | | - Build a symbolic 0-order FD interpolation operator on `expr` that maps |
1278 | | - values from `source` indices to `target` indices, only on the |
1279 | | - dimensions where the two locations differ. |
1280 | | - """ |
1281 | | - if not isinstance(expr, Differentiable): |
1282 | | - return expr |
1283 | | - |
1284 | | - mapper = _interp_mapper(source, target, expr.dimensions) |
1285 | | - if not mapper: |
1286 | | - return expr |
1287 | | - |
1288 | | - return expr.diff(*mapper.keys(), |
1289 | | - deriv_order=(0,) * len(mapper), |
1290 | | - fd_order=(interp_order,) * len(mapper), |
1291 | | - x0=mapper) |
1292 | | - |
1293 | | - |
1294 | | -def _post_x0_indices(deriv, func): |
1295 | | - """ |
1296 | | - Conceptual indices of `deriv` after setting `x0` on its own derivative |
1297 | | - dimensions to `func`'s indices. Derivative dims take `func`'s indices; |
1298 | | - other dims keep the underlying expression's natural location (so that |
1299 | | - `interp_for_fd` does not introduce a spurious second shift). |
1300 | | - """ |
1301 | | - ref = {} |
1302 | | - for dim in deriv.dimensions: |
1303 | | - if dim in deriv.dims and dim in func.indices_ref.getters: |
1304 | | - ref[dim] = func.indices_ref[dim] |
1305 | | - else: |
1306 | | - with suppress(KeyError): |
1307 | | - ref[dim] = deriv.indices_ref[dim] |
1308 | | - return ref |
1309 | | - |
1310 | | - |
1311 | 1253 | # Interpolation for finite differences |
1312 | 1254 | @singledispatch |
1313 | 1255 | def interp_for_fd(expr, x0, **kwargs): |
|
0 commit comments