Skip to content

Commit 7e13000

Browse files
committed
Rewrite assign
This commit: - Restricts `assign` to only work for weighted sums of coefficients (plus addition of constants). - Expunges codegen in favour of directly manipulating numpy arrays. - Introduces an `Assigner` class to speed up repeated `assign` calls.
1 parent 0331b5d commit 7e13000

File tree

9 files changed

+324
-801
lines changed

9 files changed

+324
-801
lines changed

firedrake/assemble.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,11 @@
99
import finat
1010
import firedrake
1111
import numpy
12+
from pyadjoint.tape import annotate_tape
1213
from tsfc import kernel_args
1314
from tsfc.finatinterface import create_element
1415
import ufl
15-
from firedrake import (assemble_expressions, extrusion_utils as eutils, matrix, parameters, solving,
16+
from firedrake import (extrusion_utils as eutils, matrix, parameters, solving,
1617
tsfc_interface, utils)
1718
from firedrake.adjoint import annotate_assemble
1819
from firedrake.bcs import DirichletBC, EquationBC, EquationBCSplit
@@ -100,7 +101,7 @@ def assemble(expr, *args, **kwargs):
100101
if isinstance(expr, (ufl.form.Form, slate.TensorBase)):
101102
return _assemble_form(expr, *args, **kwargs)
102103
elif isinstance(expr, ufl.core.expr.Expr):
103-
return assemble_expressions.assemble_expression(expr)
104+
return _assemble_expr(expr)
104105
else:
105106
raise TypeError(f"Unable to assemble: {expr}")
106107

@@ -290,6 +291,20 @@ def _assemble_form(form, tensor=None, bcs=None, *,
290291
return assembler.assemble()
291292

292293

294+
def _assemble_expr(expr):
295+
"""Assemble a pointwise expression.
296+
297+
:arg expr: The :class:`ufl.core.expr.Expr` to be evaluated.
298+
:returns: A :class:`firedrake.Function` containing the result of this evaluation.
299+
"""
300+
try:
301+
coefficients = ufl.algorithms.extract_coefficients(expr)
302+
V, = set(c.function_space() for c in coefficients) - {None}
303+
except ValueError:
304+
raise ValueError("Cannot deduce correct target space from pointwise expression")
305+
return firedrake.Function(V).assign(expr)
306+
307+
293308
def _check_inputs(form, tensor, bcs, diagonal):
294309
# Ensure mesh is 'initialised' as we could have got here without building a
295310
# function space (e.g. if integrating a constant).

0 commit comments

Comments
 (0)