@@ -34,17 +34,17 @@ def test_assembly():
3434 space = FunctionSpace (mesh , "Lagrange" , 1 )
3535 test = TestFunction (space )
3636
37- u = Function (space , name = "u" ).interpolate (Constant ( 1.0 ) )
37+ u = Function (space , name = "u" ).interpolate (X [ 0 ] )
3838 zeta = Function (space , name = "tlm_u" ).interpolate (X [0 ])
3939 u .block_variable .tlm_value = zeta .copy (deepcopy = True )
4040
4141 with reverse_over_forward ():
42- J = assemble (u * u * dx )
42+ J = assemble (u * u . dx ( 0 ) * dx )
4343
4444 _ = compute_gradient (J .block_variable .tlm_value , Control (u ))
4545 adj_value = u .block_variable .adj_value
4646 assert np .allclose (adj_value .dat .data_ro ,
47- assemble (2 * inner (zeta , test ) * dx ).dat .data_ro )
47+ assemble (inner ( zeta . dx ( 0 ), test ) * dx + inner (zeta , test . dx ( 0 ) ) * dx ).dat .data_ro )
4848
4949
5050@pytest .mark .skipcomplex
@@ -57,7 +57,28 @@ def test_constant_assignment():
5757
5858 assert float (b .block_variable .tlm_value ) == - 2.0
5959
60- # Minimal test that the TLM operations are on the tape
60+ # Minimal test that the TLM operation is on the tape
6161 _ = compute_gradient (b .block_variable .tlm_value , Control (a .block_variable .tlm_value ))
6262 adj_value = a .block_variable .tlm_value .block_variable .adj_value
6363 assert float (adj_value ) == 1.0
64+
65+
66+ @pytest .mark .skipcomplex
67+ def test_function_assignment ():
68+ mesh = UnitIntervalMesh (10 )
69+ X = SpatialCoordinate (mesh )
70+ space = FunctionSpace (mesh , "Lagrange" , 1 )
71+ test = TestFunction (space )
72+
73+ u = Function (space , name = "u" ).interpolate (X [0 ] - 0.5 )
74+ zeta = Function (space , name = "tlm_u" ).interpolate (X [0 ])
75+ u .block_variable .tlm_value = zeta .copy (deepcopy = True )
76+
77+ with reverse_over_forward ():
78+ v = Function (space , name = "v" ).assign (- 3 * u )
79+ J = assemble (v * v .dx (0 ) * dx )
80+
81+ _ = compute_gradient (J .block_variable .tlm_value , Control (u ))
82+ adj_value = u .block_variable .adj_value
83+ assert np .allclose (adj_value .dat .data_ro ,
84+ assemble (9 * inner (zeta .dx (0 ), test ) * dx + 9 * inner (zeta , test .dx (0 )) * dx ).dat .data_ro )
0 commit comments