Skip to content

Commit 8a830ba

Browse files
fix Function._ad_assign_numpy for real space in parallel (#4714)
Co-authored-by: Connor Ward <[email protected]>
1 parent 1c2a50f commit 8a830ba

File tree

2 files changed

+38
-1
lines changed

2 files changed

+38
-1
lines changed

firedrake/adjoint_utils/function.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,12 @@ def _ad_dot(self, other, options=None):
278278
def _ad_assign_numpy(dst, src, offset):
279279
range_begin, range_end = dst.dat.dataset.layout_vec.getOwnershipRange()
280280
m_a_local = src[offset + range_begin:offset + range_end]
281-
dst.dat.data_wo[...] = m_a_local.reshape(dst.dat.data_wo.shape)
281+
if dst.function_space().ufl_element().family() == "Real":
282+
# Real space keeps a redundant copy of the data on every rank
283+
comm = dst.function_space().mesh()._comm
284+
dst.dat.data_wo[...] = comm.bcast(m_a_local, root=0)
285+
else:
286+
dst.dat.data_wo[...] = m_a_local.reshape(dst.dat.data_wo.shape)
282287
offset += dst.dat.dataset.layout_vec.size
283288
return dst, offset
284289

tests/firedrake/adjoint/test_reduced_functional.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from firedrake import *
44
from firedrake.adjoint import *
5+
from pytest_mpi.parallel_assert import parallel_assert
56

67
from numpy.random import rand
78

@@ -251,3 +252,34 @@ def test_interpolate_mixed():
251252
h.subfunctions[0].dat.data[:] = 5
252253
h.subfunctions[1].dat.data[:] = 6
253254
assert taylor_test(Jhat, f, h) > 1.9
255+
256+
257+
@pytest.mark.skipcomplex
258+
@pytest.mark.parallel(2)
259+
def test_real_space_assign_numpy():
260+
"""Check that Function._ad_assign_numpy correctly handles
261+
zero length arrays on some ranks for Real space in parallel.
262+
"""
263+
mesh = UnitSquareMesh(1, 1)
264+
R = FunctionSpace(mesh, "R", 0)
265+
dst = Function(R)
266+
src = dst.dat.dataset.layout_vec.array_r.copy()
267+
data = 1 + np.arange(src.shape[0])
268+
src[:] = data
269+
dst._ad_assign_numpy(dst, src, offset=0)
270+
parallel_assert(np.allclose(dst.dat.data_ro, data))
271+
272+
273+
@pytest.mark.skipcomplex
274+
@pytest.mark.parallel(2)
275+
def test_real_space_parallel():
276+
"""Check that scipy.optimize works for Real space in parallel
277+
despite dat.data array having zero length on some ranks.
278+
"""
279+
mesh = UnitSquareMesh(1, 1)
280+
R = FunctionSpace(mesh, "R", 0)
281+
m = Function(R)
282+
J = assemble((m-1)**2*dx)
283+
Jhat = ReducedFunctional(J, Control(m))
284+
opt = minimize(Jhat)
285+
parallel_assert(np.allclose(opt.dat.data_ro, 1))

0 commit comments

Comments
 (0)