Skip to content

Commit 80b1da7

Browse files
committed
Add fixes for quadrature elements
1 parent dcf329b commit 80b1da7

File tree

1 file changed

+24
-11
lines changed

1 file changed

+24
-11
lines changed

src/scifem/interpolation.py

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,6 @@ def prepare_interpolation_data(
5555
array_evaluated = compiled_expr.eval(mesh, np.arange(num_cells, dtype=np.int32))
5656
assert np.prod(Q.value_shape) == np.prod(expr.ufl_shape)
5757

58-
im = Q.element.basix_element.interpolation_matrix
59-
6058
# Get data as (num_cells*num_points,1, expr_shape, num_test_basis_functions*test_block_size)
6159
expr_size = int(np.prod(expr.ufl_shape))
6260
array_evaluated = array_evaluated.reshape(
@@ -73,19 +71,29 @@ def prepare_interpolation_data(
7371
num_cells * num_points, mesh.geometry.dim, mesh.topology.dim
7472
)
7573

76-
Q_vs = Q.element.basix_element.value_size
74+
try:
75+
Q_vs = Q.element.basix_element.value_size
76+
except RuntimeError:
77+
Q_vs = 1 # If we do not have a basix element, assume value size is 1
78+
7779
new_array = np.zeros(
7880
(num_cells * num_points, Q.dofmap.bs * Q_vs, V.dofmap.bs * V.dofmap.dof_layout.num_dofs),
7981
dtype=np.float64,
8082
)
81-
for i in range(V.dofmap.bs * V.dofmap.dof_layout.num_dofs):
82-
for q in range(Q.dofmap.bs):
83-
new_array[:, q * Q_vs : (q + 1) * Q_vs, i] = Q.element.basix_element.pull_back(
84-
array_evaluated[:, :, q * Q_vs : (q + 1) * Q_vs, i], jacs, detJs, Ks
85-
).reshape(num_cells * num_points, Q_vs)
86-
new_array = new_array.reshape(
87-
num_cells, num_points, Q.dofmap.bs * Q_vs, V.dofmap.bs * V.dofmap.dof_layout.num_dofs
88-
)
83+
if not isinstance(Q.ufl_element().pullback, ufl.pullback.IdentityPullback):
84+
for i in range(V.dofmap.bs * V.dofmap.dof_layout.num_dofs):
85+
for q in range(Q.dofmap.bs):
86+
new_array[:, q * Q_vs : (q + 1) * Q_vs, i] = Q.element.basix_element.pull_back(
87+
array_evaluated[:, :, q * Q_vs : (q + 1) * Q_vs, i], jacs, detJs, Ks
88+
).reshape(num_cells * num_points, Q_vs)
89+
new_array = new_array.reshape(
90+
num_cells, num_points, Q.dofmap.bs * Q_vs, V.dofmap.bs * V.dofmap.dof_layout.num_dofs
91+
)
92+
else:
93+
new_array = array_evaluated.reshape(
94+
num_cells, num_points, Q.dofmap.bs * Q_vs, V.dofmap.bs * V.dofmap.dof_layout.num_dofs
95+
)
96+
8997
interpolated_matrix = np.zeros(
9098
(
9199
num_cells,
@@ -95,6 +103,11 @@ def prepare_interpolation_data(
95103
dtype=np.float64,
96104
)
97105

106+
if Q.element.interpolation_ident:
107+
im = np.eye(Q.element.interpolation_points.shape[0])
108+
else:
109+
im = Q.element.basix_element.interpolation_matrix
110+
98111
for c in range(num_cells):
99112
for i in range(V.dofmap.bs * V.dofmap.dof_layout.num_dofs):
100113
tmp_array = np.zeros((int(num_points), Q.dofmap.bs * Q_vs), dtype=np.float64)

0 commit comments

Comments
 (0)