Skip to content

Commit 2ec31f9

Browse files
authored
Merge pull request #2562 from firedrakeproject/connorjward/assign-weighted-sums-only
Only permit scalars and weighted sums for `assign`
2 parents 146397a + ad6d7c8 commit 2ec31f9

File tree

12 files changed

+347
-831
lines changed

12 files changed

+347
-831
lines changed

firedrake/adjoint/function.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -149,10 +149,11 @@ def wrapper(self, other, *args, **kwargs):
149149
def _ad_annotate_iadd(__iadd__):
150150
@wraps(__iadd__)
151151
def wrapper(self, other, **kwargs):
152+
with stop_annotating():
153+
func = __iadd__(self, other, **kwargs)
154+
152155
ad_block_tag = kwargs.pop("ad_block_tag", None)
153156
annotate = annotate_tape(kwargs)
154-
func = __iadd__(self, other, **kwargs)
155-
156157
if annotate:
157158
block = FunctionAssignBlock(func, self + other, ad_block_tag=ad_block_tag)
158159
tape = get_working_tape()
@@ -167,10 +168,11 @@ def wrapper(self, other, **kwargs):
167168
def _ad_annotate_isub(__isub__):
168169
@wraps(__isub__)
169170
def wrapper(self, other, **kwargs):
171+
with stop_annotating():
172+
func = __isub__(self, other, **kwargs)
173+
170174
ad_block_tag = kwargs.pop("ad_block_tag", None)
171175
annotate = annotate_tape(kwargs)
172-
func = __isub__(self, other, **kwargs)
173-
174176
if annotate:
175177
block = FunctionAssignBlock(func, self - other, ad_block_tag=ad_block_tag)
176178
tape = get_working_tape()
@@ -185,10 +187,11 @@ def wrapper(self, other, **kwargs):
185187
def _ad_annotate_imul(__imul__):
186188
@wraps(__imul__)
187189
def wrapper(self, other, **kwargs):
190+
with stop_annotating():
191+
func = __imul__(self, other, **kwargs)
192+
188193
ad_block_tag = kwargs.pop("ad_block_tag", None)
189194
annotate = annotate_tape(kwargs)
190-
func = __imul__(self, other, **kwargs)
191-
192195
if annotate:
193196
block = FunctionAssignBlock(func, self*other, ad_block_tag=ad_block_tag)
194197
tape = get_working_tape()
@@ -203,10 +206,11 @@ def wrapper(self, other, **kwargs):
203206
def _ad_annotate_idiv(__idiv__):
204207
@wraps(__idiv__)
205208
def wrapper(self, other, **kwargs):
209+
with stop_annotating():
210+
func = __idiv__(self, other, **kwargs)
211+
206212
ad_block_tag = kwargs.pop("ad_block_tag", None)
207213
annotate = annotate_tape(kwargs)
208-
func = __idiv__(self, other, **kwargs)
209-
210214
if annotate:
211215
block = FunctionAssignBlock(func, self/other, ad_block_tag=ad_block_tag)
212216
tape = get_working_tape()

firedrake/assemble.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,11 @@
99
import finat
1010
import firedrake
1111
import numpy
12+
from pyadjoint.tape import annotate_tape
1213
from tsfc import kernel_args
1314
from tsfc.finatinterface import create_element
1415
import ufl
15-
from firedrake import (assemble_expressions, extrusion_utils as eutils, matrix, parameters, solving,
16+
from firedrake import (extrusion_utils as eutils, matrix, parameters, solving,
1617
tsfc_interface, utils)
1718
from firedrake.adjoint import annotate_assemble
1819
from firedrake.bcs import DirichletBC, EquationBC, EquationBCSplit
@@ -100,7 +101,7 @@ def assemble(expr, *args, **kwargs):
100101
if isinstance(expr, (ufl.form.Form, slate.TensorBase)):
101102
return _assemble_form(expr, *args, **kwargs)
102103
elif isinstance(expr, ufl.core.expr.Expr):
103-
return assemble_expressions.assemble_expression(expr)
104+
return _assemble_expr(expr)
104105
else:
105106
raise TypeError(f"Unable to assemble: {expr}")
106107

@@ -290,6 +291,20 @@ def _assemble_form(form, tensor=None, bcs=None, *,
290291
return assembler.assemble()
291292

292293

294+
def _assemble_expr(expr):
295+
"""Assemble a pointwise expression.
296+
297+
:arg expr: The :class:`ufl.core.expr.Expr` to be evaluated.
298+
:returns: A :class:`firedrake.Function` containing the result of this evaluation.
299+
"""
300+
try:
301+
coefficients = ufl.algorithms.extract_coefficients(expr)
302+
V, = set(c.function_space() for c in coefficients) - {None}
303+
except ValueError:
304+
raise ValueError("Cannot deduce correct target space from pointwise expression")
305+
return firedrake.Function(V).assign(expr)
306+
307+
293308
def _check_inputs(form, tensor, bcs, diagonal):
294309
# Ensure mesh is 'initialised' as we could have got here without building a
295310
# function space (e.g. if integrating a constant).
@@ -392,6 +407,12 @@ def assemble(self):
392407
393408
:returns: The assembled object.
394409
"""
410+
if annotate_tape():
411+
raise NotImplementedError(
412+
"Taping with explicit FormAssembler objects is not supported yet. "
413+
"Use assemble instead."
414+
)
415+
395416
if self._needs_zeroing:
396417
self._as_pyop2_type(self._tensor).zero()
397418

0 commit comments

Comments
 (0)