1- from contextlib import contextmanager
2-
31import firedrake
42
53
@@ -10,65 +8,3 @@ def isconstant(expr):
108 if isinstance (expr , firedrake .Constant ):
119 raise ValueError ("Firedrake Constant requires a domain to work with pyadjoint" )
1210 return isinstance (expr , firedrake .Function ) and expr .ufl_element ().family () == "Real"
13-
14-
15- @contextmanager
16- def restored_outputs (* X , restore = None ):
17- """Construct a context manager which can be used to temporarily restore
18- block variable outputs to saved values.
19-
20- Parameters
21- ----------
22- X : tuple[BlockVariable]
23- Block variables to temporarily restore.
24- restore : callable
25- Can be used to exclude variables. Only inputs for which
26- `restore(x.output)` is true have their outputs temporarily restored.
27-
28- Returns
29- -------
30-
31- The context manager.
32-
33- Notes
34- -----
35-
36- A forward operation is allowed to modify the original variable, e.g. in
37-
38- .. code-block:: python3
39-
40- solve(inner(trial, test) * dx
41- == inner(x * x, test) * dx,
42- x)
43-
44- `x` has two versions: the input and the output. Reverse-over-forward AD
45- requires that we use the symbolic representation `x`, but with input value
46- `x.block_variable.saved_output`. A context manager can be used to
47- temporarily restore the value of `x` so that we can perform and annotate
48- a tangent-linear operation,
49-
50- .. code-block:: python3
51-
52- with restored_outputs(x):
53- # The value of x is now x.block_variable.saved_output
54- solve(inner(trial, test) * dx
55- == 2 * inner(x * x.block_variable.tlm_value, test) * dx,
56- x.block_variable.tlm_value)
57- # The value of x is again the output from the forward solve(...)
58- """
59-
60- if restore is None :
61- def restore (x ):
62- return True
63-
64- X = tuple (x for x in X if restore (x .output ))
65- X_old = tuple (x .output ._ad_copy (x ) for x in X )
66- for x in X :
67- # Ideally would use a generic _ad_assign here
68- x .output .assign (x .output .block_variable .saved_output )
69- try :
70- yield
71- finally :
72- for x , x_old in zip (X , X_old ):
73- # Ideally would use a generic _ad_assign here
74- x .output .assign (x_old )
0 commit comments