Skip to content

Commit b9af69f

Browse files
authored
Fix action interpolate simplification (#447)
* add comment * fix with test
1 parent ae9f816 commit b9af69f

File tree

2 files changed

+29
-2
lines changed

2 files changed

+29
-2
lines changed

test/test_interpolate.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
)
3636
from ufl.algorithms.expand_indices import expand_indices
3737
from ufl.core.interpolate import Interpolate
38-
from ufl.form import FormSum
38+
from ufl.form import Form, FormSum
3939
from ufl.pullback import identity_pullback
4040
from ufl.sobolevspace import H1
4141

@@ -110,7 +110,7 @@ def test_action_adjoint(V1, V2):
110110
Iu = Interpolate(u, vstar)
111111

112112
v1 = TrialFunction(V1)
113-
Iv = Interpolate(v1, vstar)
113+
Iv = Interpolate(v1, vstar) # V1 -> V2
114114

115115
assert Iv.argument_slots() == (vstar, v1)
116116
assert Iv.arguments() == (vstar, v1)
@@ -126,6 +126,18 @@ def test_action_adjoint(V1, V2):
126126
# -- Adjoint -- #
127127
adjoint(Iv) == Adjoint(Iv)
128128

129+
# action of one-form on interpolation operator
130+
one_form = Argument(V2, 0) * dx
131+
action_one_form = action(one_form, Iv) # adjoint interpolation V2^* -> V1^*
132+
assert isinstance(action_one_form, Interpolate)
133+
assert action_one_form.arguments() == (Argument(V1, 0),)
134+
assert action_one_form.ufl_function_space() == V1.dual()
135+
136+
# zero-form case
137+
action_zero_form = action(one_form, Iu) # a number
138+
assert isinstance(action_zero_form, Form)
139+
assert action_zero_form.arguments() == ()
140+
129141

130142
def test_differentiation(V1, V2):
131143
u = Coefficient(V1)
@@ -180,6 +192,12 @@ def test_differentiation(V1, V2):
180192
# Need to expand indices to be able to match equal (different MultiIndex used for both).
181193
assert expand_indices(dFdIu) == expand_indices(dFdw)
182194

195+
# Derivative of form I(u, V2) wrt coefficient u
196+
J = Iu * dx
197+
dJdu = expand_derivatives(derivative(J, u))
198+
assert isinstance(dJdu, Interpolate)
199+
assert dJdu.arguments() == (Argument(V1, 0),)
200+
183201

184202
def test_extract_base_form_operators(V1, V2):
185203
u = Coefficient(V1)

ufl/action.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@ class Action(BaseForm):
5151

5252
def __new__(cls, *args, **kw):
5353
"""Create a new Action."""
54+
from ufl.algorithms.analysis import extract_arguments
55+
from ufl.algorithms.replace import replace
56+
5457
left, right = args
5558

5659
# Check trivial case
@@ -98,6 +101,12 @@ def __new__(cls, *args, **kw):
98101
and len(left.arguments()) == 1
99102
):
100103
v, operand = right.argument_slots()
104+
# If the operand has an argument, replace it with number 0
105+
operand_args = extract_arguments(operand)
106+
if operand_args:
107+
(old_arg,) = operand_args
108+
new_arg = type(old_arg)(old_arg.ufl_function_space(), 0, old_arg.part())
109+
operand = replace(operand, {old_arg: new_arg})
101110
if v == right.arguments()[0]:
102111
return right._ufl_expr_reconstruct_(operand, v=left)
103112

0 commit comments

Comments
 (0)