Skip to content

Commit 28e3ea7

Browse files
committed
Reverse-over-forward AD: ConstantAssignBlock
1 parent 223479a commit 28e3ea7

File tree

2 files changed

+47
-8
lines changed

2 files changed

+47
-8
lines changed

firedrake/adjoint_utils/blocks/constant.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
from pyadjoint import Block, OverloadedType
1+
import firedrake
22
import numpy
3-
3+
from pyadjoint import Block, OverloadedType
44
from pyadjoint.reduced_functional_numpy import gather
55
from .block_utils import isconstant
66

@@ -70,6 +70,21 @@ def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx,
7070
values = values.values()
7171
return constant_from_values(block_variable.output, values)
7272

73+
def solve_tlm(self):
74+
x, = self.get_outputs()
75+
if len(x.output.ufl_shape) == 0:
76+
x.tlm_value = firedrake.Constant(0.0)
77+
else:
78+
x.tlm_value = firedrake.Constant(
79+
numpy.reshape(numpy.zeros_like(x.output.values()), x.output.ufl_shape))
80+
if self.assigned_list:
81+
# Not reachable?
82+
raise NotImplementedError
83+
else:
84+
dep, = self.get_dependencies()
85+
if dep.tlm_value is not None:
86+
x.tlm_value.assign(dep.tlm_value)
87+
7388
def prepare_evaluate_hessian(self, inputs, hessian_inputs, adj_inputs,
7489
relevant_dependencies):
7590
return self.prepare_evaluate_adj(inputs, hessian_inputs,

tests/regression/test_adjoint_reverse_over_forward.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from contextlib import contextmanager
12
import numpy as np
23
import pytest
34

@@ -6,7 +7,7 @@
67
from firedrake.__future__ import *
78

89

9-
@pytest.fixture(autouse=True, scope="module")
10+
@pytest.fixture(autouse=True)
1011
def _():
1112
get_working_tape().clear_tape()
1213
pause_annotation()
@@ -17,6 +18,16 @@ def _():
1718
pause_reverse_over_forward()
1819

1920

21+
@contextmanager
22+
def reverse_over_forward():
23+
continue_annotation()
24+
continue_reverse_over_forward()
25+
yield
26+
pause_annotation()
27+
pause_reverse_over_forward()
28+
29+
30+
@pytest.mark.skipcomplex
2031
def test_assembly():
2132
mesh = UnitIntervalMesh(10)
2233
X = SpatialCoordinate(mesh)
@@ -27,13 +38,26 @@ def test_assembly():
2738
zeta = Function(space, name="tlm_u").interpolate(X[0])
2839
u.block_variable.tlm_value = zeta.copy(deepcopy=True)
2940

30-
continue_annotation()
31-
continue_reverse_over_forward()
32-
J = assemble(u * u * dx)
33-
pause_annotation()
34-
pause_reverse_over_forward()
41+
with reverse_over_forward():
42+
J = assemble(u * u * dx)
3543

3644
_ = compute_gradient(J.block_variable.tlm_value, Control(u))
3745
adj_value = u.block_variable.adj_value
3846
assert np.allclose(adj_value.dat.data_ro,
3947
assemble(2 * inner(zeta, test) * dx).dat.data_ro)
48+
49+
50+
@pytest.mark.skipcomplex
51+
def test_constant_assignment():
52+
a = Constant(2.5)
53+
a.block_variable.tlm_value = Constant(-2.0)
54+
55+
with reverse_over_forward():
56+
b = Constant(0.0).assign(a)
57+
58+
assert float(b.block_variable.tlm_value) == -2.0
59+
60+
# Minimal test that the TLM operations are on the tape
61+
_ = compute_gradient(b.block_variable.tlm_value, Control(a.block_variable.tlm_value))
62+
adj_value = a.block_variable.tlm_value.block_variable.adj_value
63+
assert float(adj_value) == 1.0

0 commit comments

Comments (0)