Skip to content

Commit a0b7fc0

Browse files
committed
always interp whole expr for inject as doesn't make sense
1 parent a05f954 commit a0b7fc0

3 files changed

Lines changed: 10 additions & 57 deletions

File tree

devito/operations/interpolators.py

Lines changed: 8 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -164,21 +164,19 @@ class Injection(UnevaluatedSparseOperation):
164164

165165
__rargs__ = ('field', 'expr', 'implicit_dims') + UnevaluatedSparseOperation.__rargs__
166166

167-
def __new__(cls, field, expr, implicit_dims, interpolator, interp_expr=False):
167+
def __new__(cls, field, expr, implicit_dims, interpolator):
168168
obj = super().__new__(cls, interpolator)
169169

170170
# TODO: unused now, but will be necessary to compute the adjoint
171171
obj.field = field
172172
obj.expr = expr
173173
obj.implicit_dims = implicit_dims
174-
obj.interp_expr = interp_expr
175174

176175
return obj
177176

178177
def operation(self, **kwargs):
179178
return self.interpolator._inject(expr=self.expr, field=self.field,
180-
implicit_dims=self.implicit_dims,
181-
interp_expr=self.interp_expr)
179+
implicit_dims=self.implicit_dims)
182180

183181
def __repr__(self):
184182
return f"Injection({repr(self.expr)} into {repr(self.field)})"
@@ -309,7 +307,7 @@ def _positions(self, implicit_dims):
309307
return [Eq(v, INT(floor(k)), implicit_dims=implicit_dims)
310308
for k, v in self.sfunction._position_map.items()]
311309

312-
def _interp_idx(self, variables, implicit_dims=None, pos_only=(), subdomain=None):
310+
def _interp_idx(self, variables, implicit_dims=None, subdomain=None):
313311
"""
314312
Generate interpolation indices for the DiscreteFunctions in ``variables``.
315313
"""
@@ -333,16 +331,6 @@ def _interp_idx(self, variables, implicit_dims=None, pos_only=(), subdomain=None
333331

334332
idx_subs = {v: v.subs(subs) for v in variables}
335333

336-
# Position only replacement, not radius dependent.
337-
# E.g src.inject(vp(x)*src) needs to use vp[posx] at all points
338-
# not vp[posx + rx]
339-
idx_subs.update({
340-
v: v.subs({
341-
k: p
342-
for (k, p) in zip(mapper, pos, strict=True)
343-
}) for v in pos_only
344-
})
345-
346334
return idx_subs, temps
347335

348336
@check_radius
@@ -368,7 +356,7 @@ def interpolate(self, expr, increment=False, self_subs=None, implicit_dims=None)
368356

369357
@check_radius
370358
@check_coords
371-
def inject(self, field, expr, implicit_dims=None, interp_expr=False):
359+
def inject(self, field, expr, implicit_dims=None):
372360
"""
373361
Generate equations injecting an arbitrary expression into a field.
374362
@@ -383,7 +371,7 @@ def inject(self, field, expr, implicit_dims=None, interp_expr=False):
383371
injection expression, but that should be honored when constructing
384372
the operator.
385373
"""
386-
return Injection(field, expr, implicit_dims, self, interp_expr=interp_expr)
374+
return Injection(field, expr, implicit_dims, self)
387375

388376
def _interpolate(self, expr, increment=False, self_subs=None, implicit_dims=None):
389377
"""
@@ -435,7 +423,7 @@ def _interpolate(self, expr, increment=False, self_subs=None, implicit_dims=None
435423

436424
return temps + summands + last
437425

438-
def _inject(self, field, expr, implicit_dims=None, interp_expr=False):
426+
def _inject(self, field, expr, implicit_dims=None):
439427
"""
440428
Generate equations injecting an arbitrary expression into a field.
441429
@@ -481,10 +469,9 @@ def _inject(self, field, expr, implicit_dims=None, interp_expr=False):
481469
self._rdim(subdomain=subdomain))
482470

483471
# List of indirection indices for all adjacent grid points
484-
finterp = fields + as_tuple(variables) if interp_expr else fields
485-
pos_only = () if interp_expr else variables
472+
finterp = fields + as_tuple(variables)
486473
idx_subs, temps = self._interp_idx(finterp, implicit_dims=implicit_dims,
487-
pos_only=pos_only, subdomain=subdomain)
474+
subdomain=subdomain)
488475

489476
# Substitute coordinate base symbols into the interpolation coefficients
490477
eqns = [Inc(_field.xreplace(idx_subs),

devito/types/sparse.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1089,8 +1089,7 @@ def interpolate(self, expr, u_t=None, p_t=None, increment=False, implicit_dims=N
10891089
return super().interpolate(expr, increment=increment, self_subs=subs,
10901090
implicit_dims=implicit_dims)
10911091

1092-
def inject(self, field, expr, u_t=None, p_t=None, implicit_dims=None,
1093-
interp_expr=False):
1092+
def inject(self, field, expr, u_t=None, p_t=None, implicit_dims=None):
10941093
"""
10951094
Generate equations injecting an arbitrary expression into a field.
10961095
@@ -1115,8 +1114,7 @@ def inject(self, field, expr, u_t=None, p_t=None, implicit_dims=None,
11151114
if p_t is not None:
11161115
expr = expr.subs({self.time_dim: p_t})
11171116

1118-
return super().inject(field, expr, implicit_dims=implicit_dims,
1119-
interp_expr=interp_expr)
1117+
return super().inject(field, expr, implicit_dims=implicit_dims)
11201118

11211119
@property
11221120
def forward(self):

tests/test_interpolation.py

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
SparseTimeFunction, SubDomain, TimeFunction, switchconfig
1212
)
1313
from devito.operations.interpolators import LinearInterpolator, SincInterpolator
14-
from devito.symbolics import retrieve_functions
1514
from examples.seismic import (
1615
AcquisitionGeometry, Receiver, RickerSource, TimeAxis, demo_model
1716
)
@@ -570,37 +569,6 @@ def test_inject_from_field(shape, coords, result, npoints=19):
570569
assert np.allclose(a.data[indices], result, rtol=1.e-5)
571570

572571

573-
def test_inject_interp_expr():
574-
"""
575-
Test that the Function coefficient gets interpolated too.
576-
"""
577-
coords = [(.05, .95), (.45, .45)]
578-
a = unit_box(shape=(11, 11))
579-
a.data[:] = 0.
580-
p = points(a.grid, ranges=coords, npoints=19)
581-
m = Function(name='m', grid=a.grid)
582-
m.data_with_halo[:] = 1.
583-
584-
expr = p.inject(a, m, interp_expr=True)
585-
op = Operator(expr)
586-
587-
op(a=a)
588-
589-
indices = [slice(4, 6, 1) for _ in coords]
590-
indices[0] = slice(1, -1, 1)
591-
assert np.allclose(a.data[indices], 1, rtol=1.e-5)
592-
593-
# Extract interp expr to check indices
594-
e_expr = expr.evaluate
595-
funcs = retrieve_functions(e_expr[-1])
596-
assert m in {f.function for f in funcs}
597-
# All funcs should have the same indices wit radius
598-
# includint the coefficient m
599-
indices = {f.indices for f in funcs}
600-
assert len(indices) == 1
601-
assert str(indices.pop()) == '(rp_pointsx + posx, rp_pointsy + posy)'
602-
603-
604572
@pytest.mark.parametrize('shape', [(50, 50, 50)])
605573
def test_position(shape):
606574
t0 = 0.0 # Start time

0 commit comments

Comments
 (0)