Skip to content

Commit 08a6d32

Browse files
committed
Test improvements
1 parent 692699c commit 08a6d32

File tree

1 file changed

+27
-3
lines changed

1 file changed

+27
-3
lines changed

tests/regression/test_adjoint_reverse_over_forward.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def test_subfunction(idx):
121121

122122
with reverse_over_forward():
123123
u = Function(space, name="u")
124-
u.sub(idx).interpolate(-2 * X[0])
124+
u.sub(idx).interpolate(X[0] - 0.5)
125125
u_ref = u.copy(deepcopy=True)
126126
zeta = Function(space, name="zeta")
127127
zeta.sub(idx).interpolate(X[0])
@@ -140,7 +140,7 @@ def test_subfunction(idx):
140140

141141
@pytest.mark.skipcomplex
142142
def test_interpolate():
143-
mesh = UnitSquareMesh(10, 10)
143+
mesh = UnitIntervalMesh(10)
144144
X = SpatialCoordinate(mesh)
145145
space_a = FunctionSpace(mesh, "Lagrange", 1)
146146
space_b = FunctionSpace(mesh, "Lagrange", 2)
@@ -162,9 +162,33 @@ def test_interpolate():
162162
assemble(6 * inner(u_ref * zeta, test_a) * dx).dat.data_ro)
163163

164164

165+
@pytest.mark.skipcomplex
166+
def test_interpolate_expr():
167+
mesh = UnitIntervalMesh(10)
168+
X = SpatialCoordinate(mesh)
169+
space_a = FunctionSpace(mesh, "Lagrange", 1)
170+
space_b = FunctionSpace(mesh, "Lagrange", 2)
171+
test_a = TestFunction(space_a)
172+
173+
with reverse_over_forward():
174+
u = Function(space_a, name="u").interpolate(X[0] - 0.5)
175+
u_ref = u.copy(deepcopy=True)
176+
zeta = Function(space_a, name="zeta").interpolate(X[0])
177+
u.block_variable.tlm_value = zeta.copy(deepcopy=True)
178+
179+
v = Function(space_b, name="v").interpolate(-3 * u)
180+
J = assemble(v ** 3 * dx)
181+
182+
_ = compute_gradient(J.block_variable.tlm_value, Control(u))
183+
adj_value = u.block_variable.adj_value
184+
assert np.allclose(
185+
adj_value.dat.data_ro,
186+
assemble(-162 * inner(u_ref * zeta, test_a) * dx).dat.data_ro)
187+
188+
165189
@pytest.mark.skipcomplex
166190
def test_project():
167-
mesh = UnitSquareMesh(10, 10)
191+
mesh = UnitIntervalMesh(10)
168192
X = SpatialCoordinate(mesh)
169193
space_a = FunctionSpace(mesh, "Lagrange", 1)
170194
space_b = FunctionSpace(mesh, "Discontinuous Lagrange", 0)

0 commit comments

Comments
 (0)