Skip to content

Commit c35181c

Browse files
committed
Reverse-over-forward AD: FunctionAssignBlock. Minor test edits.
1 parent 29dfcf3 commit c35181c

File tree

2 files changed

+44
-4
lines changed

2 files changed

+44
-4
lines changed

firedrake/adjoint_utils/blocks/function.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,25 @@ def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx,
129129

130130
return dudm
131131

132+
def solve_tlm(self):
    """Propagate tangent-linear (forward-mode) values through this
    assignment block.

    Accumulates the directional derivative of ``self.expr`` with respect
    to each dependency that carries a ``tlm_value``, and assigns the
    result to the output's ``tlm_value``.  Used for
    reverse-over-forward AD, where the TLM assignment itself is placed
    on the tape.  NOTE(review): assumes this is a method of
    FunctionAssignBlock with `self.expr` set — confirm against the
    enclosing class, which is not visible here.
    """
    # Exactly one output is expected: the assigned Function.
    x, = self.get_outputs()
    expr = self.expr

    # Sum of ufl.derivative(expr, dep, tlm_dep) over all dependencies
    # with a tangent-linear value.  Starts as the plain int 0 so that a
    # tape with no TLM inputs is detected cheaply below.
    tlm_rhs = 0
    for block_variable in self.get_dependencies():
        dep = block_variable.output
        tlm_dep = block_variable.tlm_value
        if tlm_dep is not None:
            tlm_rhs = tlm_rhs + ufl.derivative(expr, dep, tlm_dep)

    # Clear any stale TLM value first, so the early returns below leave
    # the output with no tangent-linear value rather than an old one.
    x.tlm_value = None
    # No dependency had a tlm_value: tlm_rhs is still the int 0.
    if isinstance(tlm_rhs, int) and tlm_rhs == 0:
        return
    tlm_rhs = ufl.algorithms.expand_derivatives(tlm_rhs)
    # The derivative may simplify to the UFL zero form, in which case
    # there is nothing to assign either.
    if isinstance(tlm_rhs, ufl.constantvalue.Zero):
        return
    # Assigning (rather than e.g. interpolating) records the TLM
    # operation on the tape, enabling reverse-over-forward AD.
    x.tlm_value = firedrake.Function(x.output.function_space()).assign(tlm_rhs)
150+
132151
def prepare_evaluate_hessian(self, inputs, hessian_inputs, adj_inputs,
133152
relevant_dependencies):
134153
return self.prepare_evaluate_adj(inputs, hessian_inputs,

tests/regression/test_adjoint_reverse_over_forward.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,17 +34,17 @@ def test_assembly():
3434
space = FunctionSpace(mesh, "Lagrange", 1)
3535
test = TestFunction(space)
3636

37-
u = Function(space, name="u").interpolate(Constant(1.0))
37+
u = Function(space, name="u").interpolate(X[0])
3838
zeta = Function(space, name="tlm_u").interpolate(X[0])
3939
u.block_variable.tlm_value = zeta.copy(deepcopy=True)
4040

4141
with reverse_over_forward():
42-
J = assemble(u * u * dx)
42+
J = assemble(u * u.dx(0) * dx)
4343

4444
_ = compute_gradient(J.block_variable.tlm_value, Control(u))
4545
adj_value = u.block_variable.adj_value
4646
assert np.allclose(adj_value.dat.data_ro,
47-
assemble(2 * inner(zeta, test) * dx).dat.data_ro)
47+
assemble(inner(zeta.dx(0), test) * dx + inner(zeta, test.dx(0)) * dx).dat.data_ro)
4848

4949

5050
@pytest.mark.skipcomplex
@@ -57,7 +57,28 @@ def test_constant_assignment():
5757

5858
assert float(b.block_variable.tlm_value) == -2.0
5959

60-
# Minimal test that the TLM operations are on the tape
60+
# Minimal test that the TLM operation is on the tape
6161
_ = compute_gradient(b.block_variable.tlm_value, Control(a.block_variable.tlm_value))
6262
adj_value = a.block_variable.tlm_value.block_variable.adj_value
6363
assert float(adj_value) == 1.0
64+
65+
66+
@pytest.mark.skipcomplex
def test_function_assignment():
    """Reverse-over-forward AD through a Function-to-Function assignment.

    Tapes ``v = -3 u`` followed by ``J = ∫ v v' dx``, then differentiates
    the forward (TLM) derivative of J in reverse.  With tangent
    direction ``zeta`` the expected adjoint of the TLM is
    ``9 (ζ', φ) + 9 (ζ, φ')`` for test function φ — the factor 9 coming
    from the two factors of -3 introduced by the assignment.
    """
    mesh = UnitIntervalMesh(10)
    X = SpatialCoordinate(mesh)
    space = FunctionSpace(mesh, "Lagrange", 1)
    test = TestFunction(space)

    u = Function(space, name="u").interpolate(X[0] - 0.5)
    # Tangent-linear direction seeded on u before taping.
    zeta = Function(space, name="tlm_u").interpolate(X[0])
    u.block_variable.tlm_value = zeta.copy(deepcopy=True)

    with reverse_over_forward():
        # The assignment is the operation under test (FunctionAssignBlock
        # TLM); J supplies a non-symmetric functional so both derivative
        # terms appear in the adjoint.
        v = Function(space, name="v").assign(-3 * u)
        J = assemble(v * v.dx(0) * dx)

    # Reverse pass over the taped TLM value; we only need its side
    # effect of populating u.block_variable.adj_value.
    _ = compute_gradient(J.block_variable.tlm_value, Control(u))
    adj_value = u.block_variable.adj_value
    assert np.allclose(adj_value.dat.data_ro,
                       assemble(9 * inner(zeta.dx(0), test) * dx + 9 * inner(zeta, test.dx(0)) * dx).dat.data_ro)

0 commit comments

Comments
 (0)