diff --git a/firedrake/ensemble/__init__.py b/firedrake/ensemble/__init__.py index 6662a06f72..1ddf574b3b 100644 --- a/firedrake/ensemble/__init__.py +++ b/firedrake/ensemble/__init__.py @@ -1,3 +1,5 @@ from firedrake.ensemble.ensemble import * # noqa: F401 from firedrake.ensemble.ensemble_function import * # noqa: F401 from firedrake.ensemble.ensemble_functionspace import * # noqa: F401 +from firedrake.ensemble.ensemble_mat import * # noqa: F401 +from firedrake.ensemble.ensemble_pc import * # noqa: F401 diff --git a/firedrake/ensemble/ensemble_functionspace.py b/firedrake/ensemble/ensemble_functionspace.py index b846949258..47523a44fb 100644 --- a/firedrake/ensemble/ensemble_functionspace.py +++ b/firedrake/ensemble/ensemble_functionspace.py @@ -87,7 +87,8 @@ class EnsembleFunctionSpaceBase: on each ensemble member using a :class:`~firedrake.functionspaceimpl.FunctionSpace` from `EnsembleFunctionSpace.local_spaces`. - See also: + See Also + -------- - Primal ensemble objects: :class:`EnsembleFunctionSpace` and :class:`~firedrake.ensemble.ensemble_function.EnsembleFunction`. - Dual ensemble objects: :class:`EnsembleDualSpace` and :class:`~firedrake.ensemble.ensemble_function.EnsembleCofunction`. """ @@ -156,7 +157,7 @@ def dual(self): @cached_property def nlocal_spaces(self): - """The total number of subspaces across all ensemble ranks. + """The number of subspaces on this ensemble rank. """ return len(self.local_spaces) @@ -184,6 +185,12 @@ def nglobal_dofs(self): """ return self.ensemble_comm.allreduce(self.nlocal_comm_dofs) + @cached_property + def global_spaces_offset(self): + """Index of the first local subspace in the global mixed space. + """ + return self.ensemble.ensemble_comm.exscan(self.nlocal_spaces) or 0 + def _component_indices(self, i): """ Return the indices into the local mixed function storage @@ -199,10 +206,14 @@ def create_vec(self): in this function space. """ vec = PETSc.Vec().create(comm=self.global_comm) - vec.setSizes((self.nlocal_dofs, self.nglobal_dofs)) + vec.setSizes((self.nlocal_rank_dofs, self.nglobal_dofs)) vec.setUp() return vec + @cached_property + def layout_vec(self): + return self.create_vec() + def __eq__(self, other): if not isinstance(other, type(self)): local_eq = False diff --git a/firedrake/ensemble/ensemble_mat.py b/firedrake/ensemble/ensemble_mat.py new file mode 100644 index 0000000000..145d7a4ce4 --- /dev/null +++ b/firedrake/ensemble/ensemble_mat.py @@ -0,0 +1,234 @@ +from typing import Iterable +from firedrake.petsc import PETSc +from firedrake.ensemble.ensemble_function import EnsembleFunction, EnsembleFunctionBase +from firedrake.ensemble.ensemble_functionspace import EnsembleFunctionSpaceBase + +__all__ = ( + "EnsembleBlockDiagonalMat", + "EnsembleBlockDiagonalMatrix", +) + + +class EnsembleMatBase: + """ + Base class for python type Mats defined over an :class:`~.ensemble.Ensemble`. + + Parameters + ---------- + row_space : + The function space that the matrix acts on. + Must have the same number of subspaces on each ensemble rank as col_space. + col_space : + The function space for the result of the matrix action. + Must have the same number of subspaces on each ensemble rank as row_space. + + Notes + ----- + The main use of this base class is to enable users to implement the matrix + action as acting on and resulting in an :class:`~.ensemble_function.EnsembleFunction`. + This is done by implementing the ``mult_impl`` method. 
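+
+    For example, a minimal sketch of a subclass (hypothetical names
+    ``ScaledIdentityMat`` and ``alpha``, not part of this module) could
+    implement the action blockwise on the subfunctions::
+
+        class ScaledIdentityMat(EnsembleMatBase):
+            def __init__(self, space, alpha=1.0):
+                super().__init__(space, space)
+                self.alpha = alpha
+
+            def mult_impl(self, A, x, y):
+                # x and y are EnsembleFunctions, so no Vec handling is needed
+                for xsub, ysub in zip(x.subfunctions, y.subfunctions):
+                    ysub.assign(self.alpha * xsub)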
+ + See Also + -------- + .ensemble_pc.EnsemblePCBase + """ + def __init__(self, row_space: EnsembleFunctionSpaceBase, + col_space: EnsembleFunctionSpaceBase): + name = type(self).__name__ + if not isinstance(row_space, EnsembleFunctionSpaceBase): + raise ValueError( + f"{name} row_space must be EnsembleFunctionSpace not {type(row_space).__name__}") + if not isinstance(col_space, EnsembleFunctionSpaceBase): + raise ValueError( + f"{name} col_space must be EnsembleFunctionSpace not {type(col_space).__name__}") + + if row_space.ensemble != col_space.ensemble: + raise ValueError( + f"{name} row and column spaces must have the same Ensemble") + + self.ensemble = row_space.ensemble + self.row_space = row_space + self.col_space = col_space + + # input/output Vecs will be copied in/out of these + # so that base classes can implement mult only in + # terms of Ensemble objects not Vecs. + self.x = EnsembleFunction(self.row_space) + self.y = EnsembleFunction(self.col_space) + + def mult(self, A: PETSc.Mat, x: PETSc.Vec, y: PETSc.Vec): + """Apply the action of the matrix to x, putting the result in y. + + This method will be called by PETSc with x and y as Vecs, and acts + as a wrapper around the ``mult_impl`` method which has x and y as + EnsembleFunction for convenience. + y is not guaranteed to be zero on entry. + """ + with self.x.vec_wo() as xvec: + x.copy(result=xvec) + + self.mult_impl(A, self.x, self.y) + + with self.y.vec_ro() as yvec: + yvec.copy(result=y) + + def mult_impl(self, A: PETSc.Mat, x: EnsembleFunctionBase, y: EnsembleFunctionBase): + """Apply the action of the matrix to x, putting the result in y. + + y is not guaranteed to be zero on entry. + """ + raise NotImplementedError + + +class EnsembleBlockDiagonalMat(EnsembleMatBase): + """ + A python Mat context for a block diagonal matrix defined over an :class:`~.ensemble.Ensemble`. + Each block acts on a single subspace of an :class:`~.ensemble_functionspace.EnsembleFunctionSpace`. + + Parameters + ---------- + block_mats : Iterable[PETSc.Mat] + The PETSc Mats for each block. On each ensemble rank there must be as many + Mats as there are local subspaces of ``row_space`` and ``col_space``, and + the Mat sizes must match the sizes of the corresponding subspaces. + row_space : + The function space that the matrix acts on. + Must have the same number of subspaces on each ensemble rank as col_space. + col_space : + The function space for the result of the matrix action. + Must have the same number of subspaces on each ensemble rank as row_space. + + Notes + ----- + This is a python context, not an actual PETSc.Mat. To create the corresponding + PETSc.Mat users should call :func:`~.EnsembleBlockDiagonalMatrix`. 
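+
+    For example, given an :class:`~.ensemble_functionspace.EnsembleFunctionSpace`
+    ``V`` (hypothetical name, used here as both row and column space), one
+    block per local subspace could be assembled and wrapped roughly as::
+
+        block_mats = [
+            assemble(inner(TrialFunction(Vi), TestFunction(Vi))*dx).petscmat
+            for Vi in V.local_spaces
+        ]
+        mat = EnsembleBlockDiagonalMatrix(block_mats, V, V)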
+ + See Also + -------- + EnsembleBlockDiagonalMatrix + ~.ensemble_pc.EnsembleBJacobiPC + """ + def __init__(self, block_mats: Iterable, + row_space: EnsembleFunctionSpaceBase, + col_space: EnsembleFunctionSpaceBase): + super().__init__(row_space, col_space) + self.block_mats = block_mats + + if self.row_space.nlocal_spaces != self.col_space.nlocal_spaces: + raise ValueError( + "EnsembleBlockDiagonalMat row and col spaces must be the same length," + f" not {row_space.nlocal_spaces=} and {col_space.nlocal_spaces=}") + + if len(self.block_mats) != self.row_space.nlocal_spaces: + raise ValueError( + f"EnsembleBlockDiagonalMat requires one submatrix for each of the" + f" {self.row_space.nlocal_spaces} local subfunctions of the EnsembleFunctionSpace," + f" but only {len(self.block_mats)} provided.") + + for i, (Vrow, Vcol, block) in enumerate(zip(self.row_space.local_spaces, + self.col_space.local_spaces, + self.block_mats)): + # number of columns is row length, and vice-versa + vc_sizes = Vrow.dof_dset.layout_vec.sizes + vr_sizes = Vcol.dof_dset.layout_vec.sizes + mr_sizes, mc_sizes = block.sizes + if (vr_sizes[0] != mr_sizes[0]) or (vr_sizes[1] != mr_sizes[1]): + raise ValueError( + f"Row sizes {mr_sizes} of block {i} and {vr_sizes} of row_space {i} of EnsembleBlockDiagonalMat must match.") + if (vc_sizes[0] != mc_sizes[0]) or (vc_sizes[1] != mc_sizes[1]): + raise ValueError( + f"Col sizes of block {i} and col_space {i} of EnsembleBlockDiagonalMat must match.") + + def mult_impl(self, A, x, y): + for block, xsub, ysub in zip(self.block_mats, + x.subfunctions, + y.subfunctions): + with xsub.dat.vec_ro as xvec, ysub.dat.vec_wo as yvec: + block.mult(xvec, yvec) + + def setUp(self, mat): + for bmat in self.block_mats: + bmat.setUp() + + def view(self, mat, viewer=None): + if viewer is None: + return + if viewer.getType() != PETSc.Viewer.Type.ASCII: + return + viewer.printfASCII(f" firedrake block diagonal Ensemble matrix: {type(self).__name__}\n") + viewer.printfASCII(f" Number of blocks = {self.col_space.nglobal_spaces}, Number of ensemble ranks = {self.ensemble.ensemble_size}\n") + + if viewer.getFormat() != PETSc.Viewer.Format.ASCII_INFO_DETAIL: + viewer.printfASCII(" Local information for first block is in the following Mat objects on rank 0:\n") + prefix = mat.getOptionsPrefix() or "" + viewer.printfASCII(f" Use -{prefix}ksp_view ::ascii_info_detail to display information for all blocks\n") + subviewer = viewer.getSubViewer(self.ensemble.comm) + if self.ensemble.ensemble_rank == 0: + subviewer.pushASCIITab() + self.block_mats[0].view(subviewer) + subviewer.popASCIITab() + viewer.restoreSubViewer(subviewer) + # Comment taken from PCView_BJacobi in https://petsc.org/release/src/ksp/pc/impls/bjacobi/bjacobi.c.html#PCBJACOBI + # extra call needed because of the two calls to PetscViewerASCIIPushSynchronized() in PetscViewerGetSubViewer() + viewer.popASCIISynchronized() + + else: + viewer.pushASCIISynchronized() + viewer.printfASCII(" Local information for each block is in the following Mat objects:\n") + viewer.pushASCIITab() + subviewer = viewer.getSubViewer(self.ensemble.comm) + r = self.ensemble.ensemble_rank + offset = self.col_space.global_spaces_offset + subviewer.printfASCII(f"[{r}] number of local blocks = {self.col_space.nlocal_spaces}, first local block number = {offset}\n") + for i, submat in enumerate(self.block_mats): + subviewer.printfASCII(f"[{r}] local block number {i}, global block number {offset + i}\n") + submat.view(subviewer) + subviewer.printfASCII("- - - - - - - - - - - - - 
- - - - -\n") + viewer.restoreSubViewer(subviewer) + viewer.popASCIITab() + viewer.popASCIISynchronized() + + +def EnsembleBlockDiagonalMatrix(block_mats: Iterable, + row_space: EnsembleFunctionSpaceBase, + col_space: EnsembleFunctionSpaceBase): + """ + A Mat for a block diagonal matrix defined over an :class:`~.ensemble.Ensemble`. + Each block acts on a single subspace of an :class:`~.ensemble_functionspace.EnsembleFunctionSpace`. + This is a convenience function to create a PETSc.Mat with a :class:`.EnsembleBlockDiagonalMat` Python context. + + Parameters + ---------- + block_mats : Iterable[PETSc.Mat] + The PETSc Mats for each block. On each ensemble rank there must be as many + Mats as there are local subspaces of ``row_space`` and ``col_space``, and + the Mat sizes must match the sizes of the corresponding subspaces. + row_space : + The function space that the matrix acts on. + Must have the same number of subspaces on each ensemble rank as col_space. + col_space : + The function space for the result of the matrix action. + Must have the same number of subspaces on each ensemble rank as row_space. + + Returns + ------- + PETSc.Mat : + The PETSc.Mat with an :class:`.EnsembleBlockDiagonalMat` Python context. + + See Also + -------- + EnsembleBlockDiagonalMat + ~.ensemble_pc.EnsembleBJacobiPC + """ + ctx = EnsembleBlockDiagonalMat(block_mats, row_space, col_space) + + # number of columns is row length, and vice-versa + ncols = ctx.col_space.layout_vec.getSizes() + nrows = ctx.row_space.layout_vec.getSizes() + + mat = PETSc.Mat().createPython( + (ncols, nrows), ctx, + comm=ctx.ensemble.global_comm) + mat.setUp() + mat.assemble() + return mat diff --git a/firedrake/ensemble/ensemble_pc.py b/firedrake/ensemble/ensemble_pc.py new file mode 100644 index 0000000000..bb1a8c6838 --- /dev/null +++ b/firedrake/ensemble/ensemble_pc.py @@ -0,0 +1,210 @@ +import petsctools +from firedrake.petsc import PETSc +from firedrake.ensemble.ensemble_function import EnsembleFunction +from firedrake.ensemble.ensemble_mat import EnsembleMatBase, EnsembleBlockDiagonalMat + +__all__ = ( + "EnsembleBJacobiPC", +) + + +def get_default_options(default_prefix, custom_prefix_endings, options=PETSc.Options()): + # TODO: replace with petsctools.get_default_options from https://github.com/firedrakeproject/petsctools/pull/24 + + # build all non-default prefixes + custom_prefixes = [default_prefix + str(ending) + for ending in custom_prefix_endings] + for prefix in custom_prefixes: + if not prefix.endswith("_"): + prefix += "_" + + default_options = { + k.removeprefix(default_prefix): v + for k, v in options.getAll().items() + if (k.startswith(default_prefix) + and not any(k.startswith(prefix) for prefix in custom_prefixes)) + } + assert not any(k.startswith(str(end)) + for k in default_options.keys() + for end in custom_prefix_endings) + return default_options + + +def obj_name(obj): + return f"{type(obj).__module__}.{type(obj).__name__}" + + +class EnsemblePCBase(petsctools.PCBase): + """ + Base class for python type PCs defined over an :class:`~.ensemble.Ensemble`. + + The pc operators must be python Mats with :class:`~.ensemble_mat.EnsembleMatBase`. + + Notes + ----- + The main use of this base class is to enable users to implement the preconditioner + action as acting on and resulting in an :class:`~.ensemble_function.EnsembleFunction`. + This is done by implementing the ``apply_impl`` method. 
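+
+    For example, a minimal sketch of a subclass (hypothetical name
+    ``ScalingPC``, not part of this module) could apply a fixed scaling
+    blockwise::
+
+        class ScalingPC(EnsemblePCBase):
+            def apply_impl(self, pc, x, y):
+                # x and y are EnsembleFunctions, so the action is written
+                # on the subfunctions rather than on raw PETSc Vecs
+                for xsub, ysub in zip(x.subfunctions, y.subfunctions):
+                    ysub.assign(0.5 * xsub)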
+ + See Also + -------- + ~.ensemble_mat.EnsembleMatBase + """ + needs_python_pmat = True + + def initialize(self, pc): + super().initialize(pc) + + if not isinstance(self.pmat, EnsembleMatBase): + pcname = obj_name(self) + pmatname = obj_name(self.pmat) + raise TypeError( + f"PC {pcname} needs an EnsembleMatBase pmat, but it is a {pmatname}") + + self.ensemble = self.pmat.ensemble + + self.row_space = self.pmat.row_space.dual() + self.col_space = self.pmat.col_space.dual() + + self.x = EnsembleFunction(self.row_space) + self.y = EnsembleFunction(self.col_space) + + def apply(self, pc, x, y): + with self.x.vec_wo() as v: + x.copy(result=v) + + self.apply_impl(pc, self.x, self.y) + + with self.y.vec_ro() as v: + v.copy(result=y) + + def apply_impl(self, pc, x, y): + raise NotImplementedError + + +class EnsembleBJacobiPC(EnsemblePCBase): + """ + A python PC context for a block Jacobi method defined over an :class:`~.ensemble.Ensemble`. + Each block acts on a single subspace of an :class:`~.ensemble_functionspace.EnsembleFunctionSpace` + and is (approximately) solved with its own KSP, which defaults to -ksp_type preonly. + + Available options: + + * ``-pc_use_amat`` - use Amat to apply block of operator in inner Krylov method + * ``-sub_%d`` - set options for the ``%d``'th block, numbered from ensemble rank 0. + * ``-sub_`` - set default options for all blocks. + + Notes + ----- + Currently this is only implemented for :class:`~.ensemble_mat.EnsembleBlockDiagonalMat` matrices. + + See Also + -------- + ~.ensemble_mat.EnsembleBlockDiagonalMatrix + ~.ensemble_mat.EnsembleBlockDiagonalMat + """ + prefix = "ebjacobi_" + + def initialize(self, pc): + super().initialize(pc) + + use_amat_prefix = self.parent_prefix + "pc_use_amat" + self.use_amat = PETSc.Options().getBool(use_amat_prefix, False) + + if not isinstance(self.pmat, EnsembleBlockDiagonalMat): + pcname = obj_name(self) + matname = obj_name(self.pmat) + raise TypeError( + f"PC {pcname} needs an EnsembleBlockDiagonalMat pmat, but it is a {matname}") + + if self.use_amat: + if not isinstance(self.amat, EnsembleBlockDiagonalMat): + pcname = obj_name(self) + matname = obj_name(self.amat) + raise TypeError( + f"PC {pcname} needs an EnsembleBlockDiagonalMat amat, but it is a {matname}") + + # default to behaving like a PC + default_options = {'ksp_type': 'preonly'} + + default_sub_prefix = self.parent_prefix + "sub_" + default_sub_options = get_default_options( + default_sub_prefix, range(self.col_space.nglobal_spaces)) + default_options.update(default_sub_options) + + block_offset = self.col_space.global_spaces_offset + + sub_ksps = [] + for i in range(len(self.pmat.block_mats)): + sub_ksp = PETSc.KSP().create( + comm=self.ensemble.comm) + + if self.use_amat: + sub_amat = self.amat.block_mats[i] + else: + sub_amat = self.pmat.block_mats[i] + + sub_pmat = self.pmat.block_mats[i] + + sub_ksp.setOperators(sub_amat, sub_pmat) + + sub_prefix = default_sub_prefix + str(block_offset + i) + + petsctools.set_from_options( + sub_ksp, parameters=default_options, + options_prefix=sub_prefix) + + sub_ksp.incrementTabLevel(1, parent=pc) + sub_ksp.pc.incrementTabLevel(1, parent=pc) + + sub_ksps.append(sub_ksp) + + self.sub_ksps = tuple(sub_ksps) + + def apply_impl(self, pc, x, y): + sub_vecs = zip(self.x.subfunctions, self.y.subfunctions) + for sub_ksp, (subx, suby) in zip(self.sub_ksps, sub_vecs): + with subx.dat.vec_ro as rhs, suby.dat.vec_wo as sol: + with petsctools.inserted_options(sub_ksp): + sub_ksp.solve(rhs, sol) + + def update(self, pc): + for 
sub_ksp in self.sub_ksps: + sub_ksp.setUp() + + def view(self, pc, viewer=None): + super().view(pc, viewer=viewer) + viewer.printfASCII(" firedrake block Jacobi preconditioner for ensemble Mats\n") + if self.use_amat: + viewer.printfASCII(" using Amat local matrix\n") + viewer.printfASCII(f" Number of blocks = {self.col_space.nglobal_spaces}, Number of ensemble ranks = {self.ensemble.ensemble_size}\n") + + if viewer.getFormat() != PETSc.Viewer.Format.ASCII_INFO_DETAIL: + viewer.printfASCII(" Local solver information for first block is in the following KSP and PC objects on rank 0:\n") + prefix = self.parent_prefix + viewer.printfASCII(f" Use -{prefix}ksp_view ::ascii_info_detail to display information for all blocks\n") + subviewer = viewer.getSubViewer(self.ensemble.comm) + if self.ensemble.ensemble_rank == 0: + subviewer.pushASCIITab() + self.sub_ksps[0].view(subviewer) + subviewer.popASCIITab() + viewer.restoreSubViewer(subviewer) + # Comment taken from PCView_BJacobi in https://petsc.org/release/src/ksp/pc/impls/bjacobi/bjacobi.c.html#PCBJACOBI + # extra call needed because of the two calls to PetscViewerASCIIPushSynchronized() in PetscViewerGetSubViewer() + viewer.popASCIISynchronized() + + else: + viewer.pushASCIISynchronized() + viewer.printfASCII(" Local solver information for each block is in the following KSP and PC objects:\n") + viewer.pushASCIITab() + subviewer = viewer.getSubViewer(self.ensemble.comm) + r = self.ensemble.ensemble_rank + offset = self.col_space.global_spaces_offset + subviewer.printfASCII(f"[{r}] number of local blocks = {self.col_space.nlocal_spaces}, first local block number = {offset}\n") + for i, subksp in enumerate(self.sub_ksps): + subviewer.printfASCII(f"[{r}] local block number {i}, global block number {offset + i}\n") + subksp.view(subviewer) + subviewer.printfASCII("- - - - - - - - - - - - - - - - - -\n") + viewer.restoreSubViewer(subviewer) + viewer.popASCIITab() + viewer.popASCIISynchronized() diff --git a/tests/firedrake/ensemble/test_ensemble_mat.py b/tests/firedrake/ensemble/test_ensemble_mat.py new file mode 100644 index 0000000000..321a4d8c2a --- /dev/null +++ b/tests/firedrake/ensemble/test_ensemble_mat.py @@ -0,0 +1,150 @@ +import pytest +from pytest_mpi.parallel_assert import parallel_assert +from firedrake import * +from firedrake.ensemble.ensemble_mat import EnsembleBlockDiagonalMatrix + + +@pytest.mark.parallel([1, 2, 3, 4]) +def test_ensemble_mat(): + # create ensemble + global_ranks = COMM_WORLD.size + nspatial_ranks = 2 if (global_ranks % 2 == 0) else 1 + ensemble = Ensemble(COMM_WORLD, nspatial_ranks) + ensemble_rank = ensemble.ensemble_rank + + # create mesh + mesh = UnitIntervalMesh(10, comm=ensemble.comm) + + # create function spaces + CG = FunctionSpace(mesh, "CG", 1) + DG = FunctionSpace(mesh, "DG", 1+ensemble_rank) + + # create ensemble function spaces / functions + row_space = EnsembleFunctionSpace([CG, CG], ensemble) + col_space = EnsembleFunctionSpace([CG, DG], ensemble) + + # build forms + u, v = TrialFunction(CG), TestFunction(CG) + nu = Constant(ensemble_rank+1) + a0 = inner(u, v)*dx + nu*inner(grad(u), grad(v))*dx + + u, v = TrialFunction(CG), TestFunction(DG) + a1 = (1/nu)*inner(u, v)*dx + + # assemble mats + A0mat = assemble(a0).petscmat + A1mat = assemble(a1).petscmat + mats = [A0mat, A1mat] + + # create ensemble mat + emat = EnsembleBlockDiagonalMatrix(mats, row_space, col_space) + + # build ensemble function lhs and rhs for Ax=y + x = EnsembleFunction(row_space) + y = EnsembleCofunction(col_space.dual()) + ycheck = 
EnsembleCofunction(col_space.dual()) + + for i, xi in enumerate(x.subfunctions): + xi.assign(ensemble_rank + i + 1) + + # assemble reference matmult + for A, xi, yi in zip(mats, x.subfunctions, ycheck.subfunctions): + with xi.dat.vec_ro as xv, yi.dat.vec_wo as yv: + A.mult(xv, yv) + + # assemble matmult + with x.vec_ro() as xv, y.vec_wo() as yv: + emat.mult(xv, yv) + + checks = [ + np.allclose(yi.dat.data_ro, yci.dat.data_ro) + for yi, yci in zip(y.subfunctions, ycheck.subfunctions) + ] + + # check results + parallel_assert( + all(checks), + msg=("Action of EnsembleBlockDiagonalMatrix does not match" + f" actions of local matrices: {checks}") + ) + + +@pytest.mark.parallel([1, 2, 3, 4]) +@pytest.mark.parametrize("default_options", [True, False], + ids=["default_options", "blockwise_options"]) +def test_ensemble_pc(default_options): + # create ensemble + global_ranks = COMM_WORLD.size + nspatial_ranks = 2 if (global_ranks % 2 == 0) else 1 + ensemble = Ensemble(COMM_WORLD, nspatial_ranks) + ensemble_rank = ensemble.ensemble_rank + + # Default PETSc pc is ILU so need a 2D mesh + # because for 1D ILU is an exact solver. + mesh = UnitSquareMesh(8, 8, comm=ensemble.comm) + + # create function spaces + CG = FunctionSpace(mesh, "CG", 2) + DG = FunctionSpace(mesh, "DG", 2+ensemble_rank) + + # create ensemble function spaces / functions + row_space = EnsembleFunctionSpace([CG, DG], ensemble) + col_space = EnsembleFunctionSpace([CG, DG], ensemble) + offset = col_space.global_spaces_offset + + # build forms + u, v = TrialFunction(CG), TestFunction(CG) + nu = Constant(offset + 1) + a0 = inner(u, v)*dx + nu*inner(grad(u), grad(v))*dx + + u, v = TrialFunction(DG), TestFunction(DG) + a1 = (1/nu)*inner(u, v)*dx + + # assemble mats + A0mat = assemble(a0, mat_type='aij').petscmat + A1mat = assemble(a1, mat_type='aij').petscmat + mats = [A0mat, A1mat] + + # create ensemble mat + emat = EnsembleBlockDiagonalMatrix(mats, row_space, col_space) + + # parameters: direct solve on blocks + parameters = { + 'ksp_rtol': 1e-14, + 'ksp_type': 'richardson', + 'pc_type': 'python', + 'pc_python_type': 'firedrake.EnsembleBJacobiPC', + } + if default_options: + parameters['sub_pc_type'] = 'lu' + else: + for i in range(col_space.nglobal_spaces): + parameters[f'sub_{i}_pc_type'] = 'lu' + + # create ensemble ksp + ksp = PETSc.KSP().create(comm=ensemble.global_comm) + ksp.setOperators(emat, emat) + petsctools.set_from_options( + ksp, parameters=parameters, + options_prefix="ensemble") + + x = EnsembleFunction(row_space) + b = EnsembleFunction(col_space.dual()) + + for i, bi in enumerate(b.subfunctions): + bi.assign(offset + i + 1) + + with petsctools.inserted_options(ksp): + with x.vec_wo() as xv, b.vec_ro() as bv: + ksp.solve(bv, xv) + + # 1 richardson iteration should be a direct solve + parallel_assert( + ksp.its == 1, + msg=("EnsembleBJacobiPC took more than one iteration to" + f" solve an EnsembleBlockDiagonalMatrix: {ksp.its=}") + ) + + +if __name__ == "__main__": + test_ensemble_pc(default_options=True)
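+    # Run the matrix action test as well when the module is executed directly.
+    test_ensemble_mat()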