Skip to content

Commit 369db0b

Browse files
committed
Reverse-over-forward AD: AssembleBlock
1 parent 12e2082 commit 369db0b

File tree

4 files changed

+133
-2
lines changed

4 files changed

+133
-2
lines changed

firedrake/adjoint/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@
1919

2020
from pyadjoint.tape import Tape, set_working_tape, get_working_tape, \
2121
pause_annotation, continue_annotation, \
22-
stop_annotating, annotate_tape # noqa F401
22+
stop_annotating, annotate_tape, \
23+
pause_reverse_over_forward, continue_reverse_over_forward, \
24+
stop_reverse_over_forward # noqa F401
2325
from pyadjoint.reduced_functional import ReducedFunctional # noqa F401
2426
from firedrake.adjoint_utils.checkpointing import \
2527
enable_disk_checkpointing, pause_disk_checkpointing, \

firedrake/adjoint_utils/blocks/assembly.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from ufl.formatting.ufl2unicode import ufl2unicode
55
from pyadjoint import Block, AdjFloat, create_overloaded_object
66
from firedrake.adjoint_utils.checkpointing import maybe_disk_checkpoint
7-
from .block_utils import isconstant
7+
from .block_utils import isconstant, restored_outputs
88

99

1010
class AssembleBlock(Block):
@@ -145,6 +145,32 @@ def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx,
145145
dform = firedrake.assemble(dform)
146146
return dform
147147

148+
def solve_tlm(self):
149+
x, = self.get_outputs()
150+
form = self.form
151+
152+
tlm_rhs = 0
153+
for block_variable in self.get_dependencies():
154+
dep = block_variable.output
155+
tlm_dep = block_variable.tlm_value
156+
if tlm_dep is not None:
157+
if isinstance(dep, firedrake.MeshGeometry):
158+
dep = firedrake.SpatialCoordinate(dep)
159+
tlm_rhs = tlm_rhs + firedrake.derivative(
160+
form, dep, tlm_dep)
161+
else:
162+
tlm_rhs = tlm_rhs + firedrake.action(
163+
firedrake.derivative(form, dep), tlm_dep)
164+
165+
x.tlm_value = None
166+
if isinstance(tlm_rhs, int) and tlm_rhs == 0:
167+
return
168+
tau_rhs = ufl.algorithms.expand_derivatives(tlm_rhs)
169+
if tau_rhs.empty():
170+
return
171+
with restored_outputs(x, restore=lambda x: x in form.coefficients()):
172+
x.tlm_value = firedrake.assemble(tau_rhs)
173+
148174
def prepare_evaluate_hessian(self, inputs, hessian_inputs, adj_inputs,
149175
relevant_dependencies):
150176
return self.prepare_evaluate_adj(inputs, adj_inputs,

firedrake/adjoint_utils/blocks/block_utils.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from contextlib import contextmanager
2+
13
import firedrake
24

35

@@ -8,3 +10,65 @@ def isconstant(expr):
810
if isinstance(expr, firedrake.Constant):
911
raise ValueError("Firedrake Constant requires a domain to work with pyadjoint")
1012
return isinstance(expr, firedrake.Function) and expr.ufl_element().family() == "Real"
13+
14+
15+
@contextmanager
16+
def restored_outputs(*X, restore=None):
17+
"""Construct a context manager which can be used to temporarily restore
18+
block variable outputs to saved values.
19+
20+
Parameters
21+
----------
22+
X : tuple[BlockVariable]
23+
Block variables to temporarily restore.
24+
restore : callable
25+
Can be used to exclude variables. Only inputs for which
26+
`restore(x.output)` is true have their outputs temporarily restored.
27+
28+
Returns
29+
-------
30+
31+
The context manager.
32+
33+
Notes
34+
-----
35+
36+
A forward operation is allowed to modify the original variable, e.g. in
37+
38+
.. code-block:: python3
39+
40+
solve(inner(trial, test) * dx
41+
== inner(x * x, test) * dx,
42+
x)
43+
44+
`x` has two versions: the input and the output. Reverse-over-forward AD
45+
requires that we use the symbolic representation `x`, but with input value
46+
`x.block_variable.saved_output`. A context manager can be used to
47+
temporarily restore the value of `x` so that we can perform and annotate
48+
a tangent-linear operation,
49+
50+
.. code-block:: python3
51+
52+
with restored_outputs(x):
53+
# The value of x is now x.block_variable.saved_output
54+
solve(inner(trial, test) * dx
55+
== 2 * inner(x * x.block_variable.tlm_value, test) * dx,
56+
x.block_variable.tlm_value)
57+
# The value of x is again the output from the forward solve(...)
58+
"""
59+
60+
if restore is None:
61+
def restore(x):
62+
return True
63+
64+
X = tuple(x for x in X if restore(x.output))
65+
X_old = tuple(x.output._ad_copy(x) for x in X)
66+
for x in X:
67+
# Ideally would use a generic _ad_assign here
68+
x.output.assign(x.output.block_variable.saved_output)
69+
try:
70+
yield
71+
finally:
72+
for x, x_old in zip(X, X_old):
73+
# Ideally would use a generic _ad_assign here
74+
x.output.assign(x_old)
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import numpy as np
2+
import pytest
3+
4+
from firedrake import *
5+
from firedrake.adjoint import *
6+
from firedrake.__future__ import *
7+
8+
9+
@pytest.fixture(autouse=True, scope="module")
10+
def _():
11+
get_working_tape().clear_tape()
12+
pause_annotation()
13+
pause_reverse_over_forward()
14+
yield
15+
get_working_tape().clear_tape()
16+
pause_annotation()
17+
pause_reverse_over_forward()
18+
19+
20+
def test_assembly():
21+
mesh = UnitIntervalMesh(10)
22+
X = SpatialCoordinate(mesh)
23+
space = FunctionSpace(mesh, "Lagrange", 1)
24+
test = TestFunction(space)
25+
26+
u = Function(space, name="u").interpolate(Constant(1.0))
27+
zeta = Function(space, name="tlm_u").interpolate(X[0])
28+
u.block_variable.tlm_value = zeta.copy(deepcopy=True)
29+
30+
continue_annotation()
31+
continue_reverse_over_forward()
32+
J = assemble(u * u * dx)
33+
pause_annotation()
34+
pause_reverse_over_forward()
35+
36+
_ = compute_gradient(J.block_variable.tlm_value, Control(u))
37+
adj_value = u.block_variable.adj_value
38+
assert np.allclose(adj_value.dat.data_ro,
39+
assemble(2 * inner(zeta, test) * dx).dat.data_ro)

0 commit comments

Comments
 (0)