Skip to content

Commit 287e019

Browse files
committed
Reverse-over-forward AD: SupermeshProjectBlock
1 parent 76b5a4c commit 287e019

File tree

2 files changed

+40
-8
lines changed

2 files changed

+40
-8
lines changed

firedrake/adjoint_utils/blocks/solving.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -842,6 +842,7 @@ def __init__(self, source, target_space, target, bcs=[], **kwargs):
842842
mesh = target_space.mesh()
843843
self.source_space = source.function_space()
844844
self.target_space = target_space
845+
self._kwargs = dict(kwargs)
845846
self.projector = firedrake.Projector(source, target_space, **kwargs)
846847

847848
# Assemble mixed mass matrix
@@ -922,6 +923,15 @@ def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx,
922923
prepared)
923924
return dJdm
924925

926+
def solve_tlm(self):
927+
x, = self.get_outputs()
928+
dep, = self.get_dependencies()
929+
if dep.tlm_value is None:
930+
x.tlm_value = None
931+
else:
932+
x.tlm_value = firedrake.Function(x.output.function_space())
933+
firedrake.project(dep.tlm_value, x.tlm_value, **self._kwargs)
934+
925935
def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs,
926936
block_variable, idx,
927937
relevant_dependencies, prepared=None):

tests/regression/test_adjoint_reverse_over_forward.py

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -86,22 +86,44 @@ def test_function_assignment():
8686

8787
@pytest.mark.skipcomplex
8888
def test_project():
89-
mesh = UnitIntervalMesh(10)
89+
mesh = UnitSquareMesh(10, 10)
9090
X = SpatialCoordinate(mesh)
91-
space = FunctionSpace(mesh, "Lagrange", 1)
92-
test = TestFunction(space)
91+
space_a = FunctionSpace(mesh, "Lagrange", 1)
92+
space_b = FunctionSpace(mesh, "Discontinuous Lagrange", 0)
93+
test_a = TestFunction(space_a)
9394

94-
u = Function(space, name="u").interpolate(X[0] - 0.5)
95-
zeta = Function(space, name="tlm_u").interpolate(X[0])
95+
u = Function(space_a, name="u").interpolate(X[0] - 0.5)
96+
zeta = Function(space_a, name="tlm_u").interpolate(X[0])
9697
u.block_variable.tlm_value = zeta.copy(deepcopy=True)
9798

98-
space_0 = FunctionSpace(mesh, "Discontinuous Lagrange", 0)
99+
with reverse_over_forward():
100+
v = Function(space_b, name="v").project(u)
101+
J = assemble(v * v * dx)
102+
103+
_ = compute_gradient(J.block_variable.tlm_value, Control(u))
104+
adj_value = u.block_variable.adj_value
105+
assert np.allclose(adj_value.dat.data_ro,
106+
assemble(2 * inner(Function(space_b).project(zeta), test_a) * dx).dat.data_ro)
107+
108+
109+
@pytest.mark.skipcomplex
110+
def test_supermesh_project():
111+
mesh_a = UnitSquareMesh(10, 10)
112+
mesh_b = UnitSquareMesh(5, 20)
113+
X_a = SpatialCoordinate(mesh_a)
114+
space_a = FunctionSpace(mesh_a, "Lagrange", 1)
115+
space_b = FunctionSpace(mesh_b, "Discontinuous Lagrange", 0)
116+
test_a = TestFunction(space_a)
117+
118+
u = Function(space_a, name="u").interpolate(X_a[0] - 0.5)
119+
zeta = Function(space_a, name="tlm_u").interpolate(X_a[0])
120+
u.block_variable.tlm_value = zeta.copy(deepcopy=True)
99121

100122
with reverse_over_forward():
101-
v = Function(space_0, name="v").project(u)
123+
v = Function(space_b, name="v").project(u)
102124
J = assemble(v * v * dx)
103125

104126
_ = compute_gradient(J.block_variable.tlm_value, Control(u))
105127
adj_value = u.block_variable.adj_value
106128
assert np.allclose(adj_value.dat.data_ro,
107-
assemble(2 * inner(Function(space_0).project(zeta), test) * dx).dat.data_ro)
129+
assemble(2 * inner(Function(space_a).project(Function(space_b).project(zeta)), test_a) * dx).dat.data_ro)

0 commit comments

Comments
 (0)