@@ -55,8 +55,6 @@ def prepare_interpolation_data(
5555 array_evaluated = compiled_expr .eval (mesh , np .arange (num_cells , dtype = np .int32 ))
5656 assert np .prod (Q .value_shape ) == np .prod (expr .ufl_shape )
5757
58- im = Q .element .basix_element .interpolation_matrix
59-
6058 # Get data as (num_cells*num_points,1, expr_shape, num_test_basis_functions*test_block_size)
6159 expr_size = int (np .prod (expr .ufl_shape ))
6260 array_evaluated = array_evaluated .reshape (
@@ -73,19 +71,29 @@ def prepare_interpolation_data(
7371 num_cells * num_points , mesh .geometry .dim , mesh .topology .dim
7472 )
7573
76- Q_vs = Q .element .basix_element .value_size
74+ try :
75+ Q_vs = Q .element .basix_element .value_size
76+ except RuntimeError :
77+ Q_vs = 1 # If we do not have a basix element, assume value size is 1
78+
7779 new_array = np .zeros (
7880 (num_cells * num_points , Q .dofmap .bs * Q_vs , V .dofmap .bs * V .dofmap .dof_layout .num_dofs ),
7981 dtype = np .float64 ,
8082 )
81- for i in range (V .dofmap .bs * V .dofmap .dof_layout .num_dofs ):
82- for q in range (Q .dofmap .bs ):
83- new_array [:, q * Q_vs : (q + 1 ) * Q_vs , i ] = Q .element .basix_element .pull_back (
84- array_evaluated [:, :, q * Q_vs : (q + 1 ) * Q_vs , i ], jacs , detJs , Ks
85- ).reshape (num_cells * num_points , Q_vs )
86- new_array = new_array .reshape (
87- num_cells , num_points , Q .dofmap .bs * Q_vs , V .dofmap .bs * V .dofmap .dof_layout .num_dofs
88- )
83+ if not isinstance (Q .ufl_element ().pullback , ufl .pullback .IdentityPullback ):
84+ for i in range (V .dofmap .bs * V .dofmap .dof_layout .num_dofs ):
85+ for q in range (Q .dofmap .bs ):
86+ new_array [:, q * Q_vs : (q + 1 ) * Q_vs , i ] = Q .element .basix_element .pull_back (
87+ array_evaluated [:, :, q * Q_vs : (q + 1 ) * Q_vs , i ], jacs , detJs , Ks
88+ ).reshape (num_cells * num_points , Q_vs )
89+ new_array = new_array .reshape (
90+ num_cells , num_points , Q .dofmap .bs * Q_vs , V .dofmap .bs * V .dofmap .dof_layout .num_dofs
91+ )
92+ else :
93+ new_array = array_evaluated .reshape (
94+ num_cells , num_points , Q .dofmap .bs * Q_vs , V .dofmap .bs * V .dofmap .dof_layout .num_dofs
95+ )
96+
8997 interpolated_matrix = np .zeros (
9098 (
9199 num_cells ,
@@ -95,6 +103,11 @@ def prepare_interpolation_data(
95103 dtype = np .float64 ,
96104 )
97105
106+ if Q .element .interpolation_ident :
107+ im = np .eye (Q .element .interpolation_points .shape [0 ])
108+ else :
109+ im = Q .element .basix_element .interpolation_matrix
110+
98111 for c in range (num_cells ):
99112 for i in range (V .dofmap .bs * V .dofmap .dof_layout .num_dofs ):
100113 tmp_array = np .zeros ((int (num_points ), Q .dofmap .bs * Q_vs ), dtype = np .float64 )
0 commit comments