Skip to content

Commit 223479a

Browse files
committed
Remove restored_output, now handled within pyadjoint
1 parent c9794be commit 223479a

File tree

2 files changed

+2
-67
lines changed

2 files changed

+2
-67
lines changed

firedrake/adjoint_utils/blocks/assembly.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from ufl.formatting.ufl2unicode import ufl2unicode
55
from pyadjoint import Block, AdjFloat, create_overloaded_object
66
from firedrake.adjoint_utils.checkpointing import maybe_disk_checkpoint
7-
from .block_utils import isconstant, restored_outputs
7+
from .block_utils import isconstant
88

99

1010
class AssembleBlock(Block):
@@ -168,8 +168,7 @@ def solve_tlm(self):
168168
tau_rhs = ufl.algorithms.expand_derivatives(tlm_rhs)
169169
if tau_rhs.empty():
170170
return
171-
with restored_outputs(x, restore=lambda x: x in form.coefficients()):
172-
x.tlm_value = firedrake.assemble(tau_rhs)
171+
x.tlm_value = firedrake.assemble(tau_rhs)
173172

174173
def prepare_evaluate_hessian(self, inputs, hessian_inputs, adj_inputs,
175174
relevant_dependencies):
Lines changed: 0 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from contextlib import contextmanager
2-
31
import firedrake
42

53

@@ -10,65 +8,3 @@ def isconstant(expr):
108
if isinstance(expr, firedrake.Constant):
119
raise ValueError("Firedrake Constant requires a domain to work with pyadjoint")
1210
return isinstance(expr, firedrake.Function) and expr.ufl_element().family() == "Real"
13-
14-
15-
@contextmanager
16-
def restored_outputs(*X, restore=None):
17-
"""Construct a context manager which can be used to temporarily restore
18-
block variable outputs to saved values.
19-
20-
Parameters
21-
----------
22-
X : tuple[BlockVariable]
23-
Block variables to temporarily restore.
24-
restore : callable
25-
Can be used to exclude variables. Only inputs for which
26-
`restore(x.output)` is true have their outputs temporarily restored.
27-
28-
Returns
29-
-------
30-
31-
The context manager.
32-
33-
Notes
34-
-----
35-
36-
A forward operation is allowed to modify the original variable, e.g. in
37-
38-
.. code-block:: python3
39-
40-
solve(inner(trial, test) * dx
41-
== inner(x * x, test) * dx,
42-
x)
43-
44-
`x` has two versions: the input and the output. Reverse-over-forward AD
45-
requires that we use the symbolic representation `x`, but with input value
46-
`x.block_variable.saved_output`. A context manager can be used to
47-
temporarily restore the value of `x` so that we can perform and annotate
48-
a tangent-linear operation,
49-
50-
.. code-block:: python3
51-
52-
with restored_outputs(x):
53-
# The value of x is now x.block_variable.saved_output
54-
solve(inner(trial, test) * dx
55-
== 2 * inner(x * x.block_variable.tlm_value, test) * dx,
56-
x.block_variable.tlm_value)
57-
# The value of x is again the output from the forward solve(...)
58-
"""
59-
60-
if restore is None:
61-
def restore(x):
62-
return True
63-
64-
X = tuple(x for x in X if restore(x.output))
65-
X_old = tuple(x.output._ad_copy(x) for x in X)
66-
for x in X:
67-
# Ideally would use a generic _ad_assign here
68-
x.output.assign(x.output.block_variable.saved_output)
69-
try:
70-
yield
71-
finally:
72-
for x, x_old in zip(X, X_old):
73-
# Ideally would use a generic _ad_assign here
74-
x.output.assign(x_old)

0 commit comments

Comments
 (0)