Skip to content

Commit 76b5a4c

Browse files
committed
Reverse-over-forward AD: GenericSolveBlock
1 parent c35181c commit 76b5a4c

File tree

2 files changed

+67
-0
lines changed

2 files changed

+67
-0
lines changed

firedrake/adjoint_utils/blocks/solving.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,50 @@ def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx,
317317
dFdm, dudm, bcs
318318
)
319319

320+
def solve_tlm(self):
321+
x, = self.get_outputs()
322+
if self.linear:
323+
form = firedrake.action(self.lhs, x.output) - self.rhs
324+
else:
325+
form = self.lhs
326+
327+
tlm_rhs = 0
328+
tlm_bcs = []
329+
for block_variable in self.get_dependencies():
330+
dep = block_variable.output
331+
if dep == x.output:
332+
continue
333+
tlm_dep = block_variable.tlm_value
334+
if isinstance(dep, firedrake.DirichletBC):
335+
if tlm_value is None:
336+
tlm_bcs.append(dep.reconstruct(g=0))
337+
else:
338+
tlm_bcs.append(tlm_value)
339+
elif tlm_dep is not None:
340+
if isinstance(dep, firedrake.MeshGeometry):
341+
dep = firedrake.SpatialCoordinate(dep)
342+
tlm_rhs = tlm_rhs - firedrake.derivative(
343+
form, dep, tlm_dep)
344+
else:
345+
tlm_rhs = tlm_rhs - firedrake.action(
346+
firedrake.derivative(form, dep), tlm_dep)
347+
348+
x.tlm_value = None
349+
if isinstance(tlm_rhs, int) and tlm_rhs == 0:
350+
return
351+
tlm_rhs = ufl.algorithms.expand_derivatives(tlm_rhs)
352+
if tlm_rhs.empty():
353+
return
354+
355+
if self.linear:
356+
J = self.lhs
357+
else:
358+
J = firedrake.derivative(form, x.output, firedrake.TrialFunction(x.output.function_space()))
359+
360+
x.tlm_value = firedrake.Function(x.output.function_space())
361+
firedrake.solve(J == tlm_rhs, x.tlm_value, tlm_bcs, *self.forward_args,
362+
**self.forward_kwargs)
363+
320364
def _assemble_and_solve_tlm_eq(self, dFdu, dFdm, dudm, bcs):
321365
return self._assembled_solve(dFdu, dFdm, dudm, bcs)
322366

tests/regression/test_adjoint_reverse_over_forward.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,3 +82,26 @@ def test_function_assignment():
8282
adj_value = u.block_variable.adj_value
8383
assert np.allclose(adj_value.dat.data_ro,
8484
assemble(9 * inner(zeta.dx(0), test) * dx + 9 * inner(zeta, test.dx(0)) * dx).dat.data_ro)
85+
86+
87+
@pytest.mark.skipcomplex
88+
def test_project():
89+
mesh = UnitIntervalMesh(10)
90+
X = SpatialCoordinate(mesh)
91+
space = FunctionSpace(mesh, "Lagrange", 1)
92+
test = TestFunction(space)
93+
94+
u = Function(space, name="u").interpolate(X[0] - 0.5)
95+
zeta = Function(space, name="tlm_u").interpolate(X[0])
96+
u.block_variable.tlm_value = zeta.copy(deepcopy=True)
97+
98+
space_0 = FunctionSpace(mesh, "Discontinuous Lagrange", 0)
99+
100+
with reverse_over_forward():
101+
v = Function(space_0, name="v").project(u)
102+
J = assemble(v * v * dx)
103+
104+
_ = compute_gradient(J.block_variable.tlm_value, Control(u))
105+
adj_value = u.block_variable.adj_value
106+
assert np.allclose(adj_value.dat.data_ro,
107+
assemble(2 * inner(Function(space_0).project(zeta), test) * dx).dat.data_ro)

0 commit comments

Comments
 (0)