Skip to content

Commit a33a610

Browse files
authored
Fix point evaluation API change (#4675)
1 parent 8a830ba commit a33a610

File tree

3 files changed

+103
-16
lines changed

3 files changed

+103
-16
lines changed

firedrake/function.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -549,8 +549,12 @@ def evaluate(self, coord, mapping, component, index_values):
549549
# Called by UFL when evaluating expressions at coordinates
550550
if component or index_values:
551551
raise NotImplementedError("Unsupported arguments when attempting to evaluate Function.")
552+
coord = np.asarray(coord, dtype=utils.ScalarType)
552553
evaluator = PointEvaluator(self.function_space().mesh(), coord)
553-
return evaluator.evaluate(self)
554+
result = evaluator.evaluate(self)
555+
if len(coord.shape) == 1:
556+
result = result.squeeze(axis=0)
557+
return result
554558

555559
def at(self, arg, *args, **kwargs):
556560
warnings.warn(
@@ -809,7 +813,7 @@ def evaluate(self, function: Function) -> np.ndarray | Tuple[np.ndarray, ...]:
809813
f_at_points = assemble(interpolate(function, P0DG))
810814
f_at_points_io = Function(P0DG_io).assign(np.nan)
811815
f_at_points_io.interpolate(f_at_points)
812-
result = f_at_points_io.dat.data_ro
816+
result = f_at_points_io.dat.data_ro.copy()
813817

814818
# If redundant, all points are now on rank 0, so we broadcast the result
815819
if self.redundant and self.mesh.comm.size > 1:

tests/firedrake/regression/test_point_eval_api.py

Lines changed: 75 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from os.path import abspath, dirname
2+
from numbers import Number
23
import numpy as np
34
import pytest
45

@@ -184,23 +185,33 @@ def test_point_evaluator_scalar(mesh_and_points):
184185
# Test standard scalar function evaluation at points
185186
f_at_points = evaluator.evaluate(f)
186187
assert np.allclose(f_at_points, [0.2, 0.4, 0.6])
188+
assert isinstance(f_at_points, np.ndarray)
189+
assert f_at_points.shape == (len(evaluator.points),) + f.ufl_shape
190+
assert isinstance(f_at_points[0], Number)
187191

188192
# Test standard scalar function with missing points
189193
eval_missing = PointEvaluator(mesh, np.append(points, [[1.5, 1.5]], axis=0), missing_points_behaviour="ignore")
190194
f_at_points_missing = eval_missing.evaluate(f)
191195
assert np.isnan(f_at_points_missing[-1])
192196

197+
# Can modify result
198+
f_at_points *= 2.0
199+
assert np.allclose(f_at_points, [0.4, 0.8, 1.2])
200+
193201

194202
@pytest.mark.parallel([1, 3])
195203
def test_point_evaluator_vector_tensor_mixed(mesh_and_points):
196204
mesh, evaluator = mesh_and_points
197205
V_vec = VectorFunctionSpace(mesh, "CG", 1)
198-
f_vec = Function(V_vec)
199206
x, y = SpatialCoordinate(mesh)
200-
f_vec.interpolate(as_vector([x, y]))
207+
f_vec = Function(V_vec).interpolate(as_vector([x, y]))
201208
f_vec_at_points = evaluator.evaluate(f_vec)
202209
vec_expected = np.array([[0.1, 0.1], [0.2, 0.2], [0.3, 0.3]])
203210
assert np.allclose(f_vec_at_points, vec_expected)
211+
assert isinstance(f_vec_at_points, np.ndarray)
212+
assert f_vec_at_points.shape == (len(evaluator.points),) + f_vec.ufl_shape
213+
assert isinstance(f_vec_at_points[0, 0], Number)
214+
assert isinstance(f_vec_at_points[0, :], np.ndarray)
204215

205216
V_tensor = TensorFunctionSpace(mesh, "CG", 1, shape=(2, 3))
206217
f_tensor = Function(V_tensor)
@@ -210,15 +221,25 @@ def test_point_evaluator_vector_tensor_mixed(mesh_and_points):
210221
[[0.2, 0.2, 0.04], [0.2, 0.2, 0.04]],
211222
[[0.3, 0.3, 0.09], [0.3, 0.3, 0.09]]])
212223
assert np.allclose(f_tensor_at_points, tensor_expected)
224+
assert f_tensor_at_points.shape == (len(evaluator.points),) + f_tensor.ufl_shape
225+
assert isinstance(f_tensor_at_points[0, 0, 0], Number)
226+
assert isinstance(f_tensor_at_points[0, 0, :], np.ndarray)
227+
assert isinstance(f_tensor_at_points[0, :, :], np.ndarray)
213228

214229
V_mixed = V_vec * V_tensor
215230
f_mixed = Function(V_mixed)
216231
f_vec, f_tensor = f_mixed.subfunctions
217232
f_vec.interpolate(as_vector([x, y]))
218233
f_tensor.interpolate(as_matrix([[x, y, x*y], [y, x, x*y]]))
219234
f_mixed_at_points = evaluator.evaluate(f_mixed)
235+
assert isinstance(f_mixed_at_points, tuple)
236+
assert len(f_mixed_at_points) == 2
220237
assert np.allclose(f_mixed_at_points[0], vec_expected)
221238
assert np.allclose(f_mixed_at_points[1], tensor_expected)
239+
assert isinstance(f_mixed_at_points[0], np.ndarray)
240+
assert isinstance(f_mixed_at_points[1], np.ndarray)
241+
assert f_mixed_at_points[0].shape == (len(evaluator.points),) + f_vec.ufl_shape
242+
assert f_mixed_at_points[1].shape == (len(evaluator.points),) + f_tensor.ufl_shape
222243

223244

224245
@pytest.mark.parallel(3)
@@ -286,58 +307,98 @@ def test_point_evaluator_tolerance():
286307

287308
def test_point_evaluator_inputs_1d():
288309
mesh = UnitIntervalMesh(1)
289-
f = mesh.coordinates
310+
f = mesh.coordinates # ufl_shape (1,)
290311

291312
# one point
292313
for input in [0.2, (0.2,), [0.2], np.array([0.2])]:
293314
e = PointEvaluator(mesh, input)
294-
assert np.allclose(0.2, e.evaluate(f))
315+
res = e.evaluate(f)
316+
assert np.allclose(0.2, res)
317+
assert isinstance(res, np.ndarray)
318+
assert res.shape == (1,)
319+
assert isinstance(res[0], Number)
295320

296321
# multiple points as tuples/list
297322
for input in [
298323
(0.2, 0.3), ((0.2,), (0.3,)), ([0.2], [0.3]),
299324
(np.array(0.2), np.array(0.3)), (np.array([0.2]), np.array([0.3]))
300325
]:
301326
e2 = PointEvaluator(mesh, input)
302-
assert np.allclose([[0.2, 0.3]], e2.evaluate(f))
327+
res = e2.evaluate(f)
328+
assert np.allclose([[0.2, 0.3]], res)
329+
assert isinstance(res, np.ndarray)
330+
assert res.shape == (len(input),)
331+
assert isinstance(res[0], Number)
332+
303333
e3 = PointEvaluator(mesh, list(input))
304-
assert np.allclose([[0.2, 0.3]], e3.evaluate(f))
334+
res2 = e3.evaluate(f)
335+
assert np.allclose([[0.2, 0.3]], res2)
336+
assert isinstance(res2, np.ndarray)
337+
assert res2.shape == (len(input),)
338+
assert isinstance(res2[0], Number)
305339

306340
# multiple points as numpy array
307341
for input in [np.array([0.2, 0.3]), np.array([[0.2], [0.3]])]:
308342
e = PointEvaluator(mesh, input)
309-
assert np.allclose([[0.2, 0.3]], e.evaluate(f))
343+
res = e.evaluate(f)
344+
assert np.allclose([[0.2, 0.3]], res)
345+
assert isinstance(res, np.ndarray)
346+
assert res.shape == (len(input),)
347+
assert isinstance(res[0], Number)
310348

311349
# test incorrect inputs
312350
for input in [[[0.2, 0.3]], ([0.2, 0.3], [0.4, 0.5]), np.array([[0.2, 0.3]])]:
313-
with pytest.raises(ValueError):
351+
with pytest.raises(ValueError, match=r"Point dimension \(2\) does not match geometric dimension \(1\)."):
314352
PointEvaluator(mesh, input)
315353

316354

317355
def test_point_evaluator_inputs_2d():
318356
mesh = UnitSquareMesh(1, 1)
319-
f = mesh.coordinates
357+
f = mesh.coordinates # ufl_shape (2,)
320358

321359
# one point
322360
for input in [(0.2, 0.4), [0.2, 0.4], [[0.2, 0.4]], np.array([0.2, 0.4])]:
323361
e = PointEvaluator(mesh, input)
324-
assert np.allclose([0.2, 0.4], e.evaluate(f))
362+
res = e.evaluate(f)
363+
assert np.allclose([0.2, 0.4], res)
364+
assert isinstance(res, np.ndarray)
365+
assert res.shape == (1,) + f.ufl_shape
366+
assert isinstance(res[0], np.ndarray)
367+
assert isinstance(res[0, 0], Number)
325368

326369
# multiple points as tuple
327370
for input in [
328371
((0.2, 0.4), (0.3, 0.5)), ([0.2, 0.4], [0.3, 0.5]),
329372
(np.array([0.2, 0.4]), np.array([0.3, 0.5]))
330373
]:
331374
e1 = PointEvaluator(mesh, input)
332-
assert np.allclose([[0.2, 0.4], [0.3, 0.5]], e1.evaluate(f))
375+
res1 = e1.evaluate(f)
376+
assert np.allclose([[0.2, 0.4], [0.3, 0.5]], res1)
377+
assert isinstance(res1, np.ndarray)
378+
assert res1.shape == (len(input),) + f.ufl_shape
379+
assert isinstance(res1[0], np.ndarray)
380+
assert isinstance(res1[0, 0], Number)
381+
333382
e2 = PointEvaluator(mesh, list(input))
334383
assert np.allclose([[0.2, 0.4], [0.3, 0.5]], e2.evaluate(f))
335384

336385
# multiple points as numpy array
337-
e = PointEvaluator(mesh, np.array([[0.2, 0.4], [0.3, 0.5]]))
338-
assert np.allclose([[0.2, 0.4], [0.3, 0.5]], e.evaluate(f))
386+
points = np.array([[0.2, 0.4], [0.3, 0.5]])
387+
e = PointEvaluator(mesh, points)
388+
res = e.evaluate(f)
389+
assert np.allclose([[0.2, 0.4], [0.3, 0.5]], res)
390+
assert isinstance(res, np.ndarray)
391+
assert res.shape == (len(points),) + f.ufl_shape
392+
assert isinstance(res[0], np.ndarray)
393+
assert isinstance(res[0, 0], Number)
394+
395+
res2 = e.evaluate(f)
396+
res3 = res + res2
397+
assert np.allclose([[0.4, 0.8], [0.6, 1.0]], res3)
398+
assert isinstance(res3, np.ndarray)
399+
assert res3.shape == (len(points),) + f.ufl_shape
339400

340401
# test incorrect inputs
341402
for input in [0.2, [0.2]]:
342-
with pytest.raises(ValueError):
403+
with pytest.raises(ValueError, match=r"Point dimension \(1\) does not match geometric dimension \(2\)."):
343404
PointEvaluator(mesh, input)

tests/firedrake/regression/test_point_eval_fs.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from os.path import abspath, dirname
2+
from numbers import Number
23
import numpy as np
34
import pytest
45

@@ -104,6 +105,11 @@ def test_triangle_tensor(mesh_triangle, family, degree):
104105

105106
assert np.allclose([[0.4, 0.8], [0.48, 0.08]], f([0.6, 0.4]))
106107
assert np.allclose([[0.9, 0.2], [0.00, 0.18]], f([0.0, 0.9]))
108+
res = f([0.1, 0.2])
109+
assert isinstance(res, np.ndarray)
110+
assert res.shape == (2, 2)
111+
assert isinstance(res[0, :], np.ndarray)
112+
assert isinstance(res[0, 0], Number)
107113

108114

109115
def test_triangle_mixed(mesh_triangle):
@@ -143,6 +149,10 @@ def test_quadrilateral(mesh_quadrilateral, family, degree):
143149
f = Function(V).interpolate((x[0] - 0.5)*(x[1] - 0.2))
144150
assert np.allclose(+0.02, f([0.6, 0.4]))
145151
assert np.allclose(-0.35, f([0.0, 0.9]))
152+
res = f([0.1, 0.2])
153+
assert isinstance(res, np.ndarray)
154+
assert len(res.shape) == 0
155+
assert isinstance(res.item(), Number)
146156

147157

148158
@pytest.mark.parametrize(('family', 'degree'),
@@ -161,6 +171,10 @@ def test_quadrilateral_vector(mesh_quadrilateral, family, degree):
161171

162172
assert np.allclose([0.6, 0.56], f([0.6, 0.4]))
163173
assert np.allclose([1.1, 0.18], f([0.0, 0.9]))
174+
res = f([0.1, 0.2])
175+
assert isinstance(res, np.ndarray)
176+
assert len(res.shape) == 1
177+
assert isinstance(res[0], Number)
164178

165179

166180
@pytest.mark.parametrize(('family', 'degree'),
@@ -172,6 +186,10 @@ def test_tetrahedron(mesh_tetrahedron, family, degree):
172186
f = Function(V).interpolate((x[0] - 0.5)*(x[1] - x[2]))
173187
assert np.allclose(+0.01, f([0.6, 0.4, 0.3]))
174188
assert np.allclose(-0.06, f([0.4, 0.7, 0.1]))
189+
res = f([0.2, 0.3, 0.4])
190+
assert isinstance(res, np.ndarray)
191+
assert len(res.shape) == 0
192+
assert isinstance(res.item(), Number)
175193

176194

177195
@pytest.mark.parametrize(('family', 'degree'),
@@ -192,6 +210,10 @@ def test_tetrahedron_vector(mesh_tetrahedron, family, degree):
192210

193211
assert np.allclose([0.6, 0.54, 0.4], f([0.6, 0.4, 0.3]))
194212
assert np.allclose([0.9, 0.34, 0.7], f([0.4, 0.7, 0.1]))
213+
res = f([0.2, 0.3, 0.4])
214+
assert isinstance(res, np.ndarray)
215+
assert len(res.shape) == 1
216+
assert isinstance(res[0], Number)
195217

196218

197219
def test_point_eval_forces_writes():

0 commit comments

Comments
 (0)