1+ from contextlib import contextmanager
12import numpy as np
23import pytest
34
67from firedrake .__future__ import *
78
89
9- @pytest .fixture (autouse = True , scope = "module" )
10+ @pytest .fixture (autouse = True )
1011def _ ():
1112 get_working_tape ().clear_tape ()
1213 pause_annotation ()
@@ -17,6 +18,16 @@ def _():
1718 pause_reverse_over_forward ()
1819
1920
21+ @contextmanager
22+ def reverse_over_forward ():
23+ continue_annotation ()
24+ continue_reverse_over_forward ()
25+ yield
26+ pause_annotation ()
27+ pause_reverse_over_forward ()
28+
29+
30+ @pytest .mark .skipcomplex
2031def test_assembly ():
2132 mesh = UnitIntervalMesh (10 )
2233 X = SpatialCoordinate (mesh )
@@ -27,13 +38,26 @@ def test_assembly():
2738 zeta = Function (space , name = "tlm_u" ).interpolate (X [0 ])
2839 u .block_variable .tlm_value = zeta .copy (deepcopy = True )
2940
30- continue_annotation ()
31- continue_reverse_over_forward ()
32- J = assemble (u * u * dx )
33- pause_annotation ()
34- pause_reverse_over_forward ()
41+ with reverse_over_forward ():
42+ J = assemble (u * u * dx )
3543
3644 _ = compute_gradient (J .block_variable .tlm_value , Control (u ))
3745 adj_value = u .block_variable .adj_value
3846 assert np .allclose (adj_value .dat .data_ro ,
3947 assemble (2 * inner (zeta , test ) * dx ).dat .data_ro )
48+
49+
50+ @pytest .mark .skipcomplex
51+ def test_constant_assignment ():
52+ a = Constant (2.5 )
53+ a .block_variable .tlm_value = Constant (- 2.0 )
54+
55+ with reverse_over_forward ():
56+ b = Constant (0.0 ).assign (a )
57+
58+ assert float (b .block_variable .tlm_value ) == - 2.0
59+
60+ # Minimal test that the TLM operations are on the tape
61+ _ = compute_gradient (b .block_variable .tlm_value , Control (a .block_variable .tlm_value ))
62+ adj_value = a .block_variable .tlm_value .block_variable .adj_value
63+ assert float (adj_value ) == 1.0
0 commit comments