Skip to content

Commit aaabace

Browse files
committed
Fieldsplit: update MatNest Jacobian
1 parent 3b6c6a1 commit aaabace

File tree

5 files changed

+106
-42
lines changed

5 files changed

+106
-42
lines changed

firedrake/dmhooks.py

Lines changed: 20 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -363,43 +363,36 @@ def create_subdm(dm, fields, *args, **kwargs):
363363
:arg fields: The fields in the new sub-DM.
364364
"""
365365
W = get_function_space(dm)
366-
ctx = get_appctx(dm)
367-
coarsen = get_ctx_coarsener(dm)
368-
parent = get_parent(dm)
369366
if len(fields) == 1:
370367
# Subspace is just a single FunctionSpace.
371368
idx, = fields
372-
subdm = W[idx].dm
369+
subspace = W[idx]
373370
iset = W._ises[idx]
374-
add_hook(parent, setup=partial(push_parent, subdm, parent), teardown=partial(pop_parent, subdm, parent),
375-
call_setup=True)
376-
377-
if ctx is not None:
378-
ctx, = ctx.split([(idx, )])
379-
add_hook(parent, setup=partial(push_appctx, subdm, ctx), teardown=partial(pop_appctx, subdm, ctx),
380-
call_setup=True)
381-
add_hook(parent, setup=partial(push_ctx_coarsener, subdm, coarsen), teardown=partial(pop_ctx_coarsener, subdm, coarsen),
382-
call_setup=True)
383-
return iset, subdm
384371
else:
385372
# Need to build an MFS for the subspace
386373
subspace = firedrake.MixedFunctionSpace([W[f] for f in fields])
387374

388-
add_hook(parent, setup=partial(push_parent, subspace.dm, parent), teardown=partial(pop_parent, subspace.dm, parent),
389-
call_setup=True)
390375
# Index set mapping from W into subspace.
391-
iset = PETSc.IS().createGeneral(numpy.concatenate([W._ises[f].indices
392-
for f in fields]),
376+
iset = PETSc.IS().createGeneral(numpy.concatenate([W.dof_dset.field_ises[f].indices for f in fields]),
393377
comm=W._comm)
394-
if ctx is not None:
395-
ctx, = ctx.split([fields])
396-
add_hook(parent, setup=partial(push_appctx, subspace.dm, ctx),
397-
teardown=partial(pop_appctx, subspace.dm, ctx),
398-
call_setup=True)
399-
add_hook(parent, setup=partial(push_ctx_coarsener, subspace.dm, coarsen),
400-
teardown=partial(pop_ctx_coarsener, subspace.dm, coarsen),
401-
call_setup=True)
402-
return iset, subspace.dm
378+
379+
subdm = subspace.dm
380+
parent = get_parent(dm)
381+
add_hook(parent, setup=partial(push_parent, subdm, parent),
382+
teardown=partial(pop_parent, subdm, parent),
383+
call_setup=True)
384+
385+
ctx = get_appctx(dm)
386+
coarsen = get_ctx_coarsener(dm)
387+
if ctx is not None:
388+
ctx, = ctx.split([fields])
389+
add_hook(parent, setup=partial(push_appctx, subdm, ctx),
390+
teardown=partial(pop_appctx, subdm, ctx),
391+
call_setup=True)
392+
add_hook(parent, setup=partial(push_ctx_coarsener, subdm, coarsen),
393+
teardown=partial(pop_ctx_coarsener, subdm, coarsen),
394+
call_setup=True)
395+
return iset, subdm
403396

404397

405398
@PETSc.Log.EventDecorator()

firedrake/formmanipulation.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
import numpy
32
import collections
43

@@ -161,6 +160,7 @@ def cofunction(self, o):
161160
return Cofunction(W, val=MixedDat(o.dat[i] for i in indices))
162161

163162
def matrix(self, o):
163+
from firedrake.bcs import DirichletBC, EquationBC
164164
ises = []
165165
args = []
166166
for a in o.arguments():
@@ -180,8 +180,22 @@ def matrix(self, o):
180180
args.append(asplit)
181181

182182
submat = o.petscmat.createSubMatrix(*ises)
183-
bcs = ()
184-
return AssembledMatrix(tuple(args), bcs, submat)
183+
bcs = []
184+
spaces = [a.function_space() for a in o.arguments()]
185+
for bc in o.bcs:
186+
W = bc.function_space()
187+
W = W.parent or W
188+
189+
number = spaces.index(W)
190+
V = args[number].function_space()
191+
field = self.blocks[number]
192+
if isinstance(bc, DirichletBC):
193+
sub_bc = bc.reconstruct(field=field, V=V, g=bc.function_arg)
194+
elif isinstance(bc, EquationBC):
195+
raise NotImplementedError("Please get in touch if you need this")
196+
if sub_bc is not None:
197+
bcs.append(sub_bc)
198+
return AssembledMatrix(tuple(args), tuple(bcs), submat)
185199

186200
def zero_base_form(self, o):
187201
return ZeroBaseForm(tuple(map(self, o.arguments())))

firedrake/matrix.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -170,9 +170,7 @@ def __init__(self, a, bcs, mat_type, *args, **kwargs):
170170
self.mat_type = mat_type
171171

172172
def assemble(self):
173-
raise NotImplementedError("API compatibility to apply bcs after 'assemble(a)'\
174-
has been removed. Use 'assemble(a, bcs=bcs)', which\
175-
now returns an assembled matrix.")
173+
self.M.assemble()
176174

177175

178176
class ImplicitMatrix(MatrixBase):
@@ -250,3 +248,9 @@ def __init__(self, a, bcs, petscmat, *args, **kwargs):
250248

251249
def mat(self):
252250
return self.petscmat
251+
252+
def assemble(self):
253+
# Bump petsc matrix state by assembling it.
254+
# Ensures that if the matrix changed, the preconditioner is
255+
# updated if necessary.
256+
self.petscmat.assemble()

firedrake/solving_utils.py

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -407,23 +407,33 @@ def split(self, fields):
407407
F += problem.compute_bc_lifting(J, subu)
408408
else:
409409
F = replace(F, {problem.u_restrict: u})
410+
410411
if problem.Jp is not None:
411412
Jp = splitter.split(problem.Jp, argument_indices=(field, field))
412413
Jp = replace(Jp, {problem.u_restrict: u})
413414
else:
414415
Jp = None
415-
bcs = []
416-
for bc in problem.bcs:
417-
if isinstance(bc, DirichletBC):
418-
bc_temp = bc.reconstruct(field=field, V=V, g=bc.function_arg, sub_domain=bc.sub_domain)
419-
elif isinstance(bc, EquationBC):
420-
bc_temp = bc.reconstruct(V, subu, u, field, problem.is_linear)
421-
if bc_temp is not None:
422-
bcs.append(bc_temp)
416+
417+
if isinstance(J, MatrixBase) and J.has_bcs:
418+
bcs = None
419+
else:
420+
bcs = []
421+
for bc in problem.bcs:
422+
if isinstance(bc, DirichletBC):
423+
bc_temp = bc.reconstruct(field=field, V=V, g=bc.function_arg)
424+
elif isinstance(bc, EquationBC):
425+
bc_temp = bc.reconstruct(V, subu, u, field, problem.is_linear)
426+
if bc_temp is not None:
427+
bcs.append(bc_temp)
428+
423429
new_problem = NLVP(F, subu, bcs=bcs, J=J, Jp=Jp, is_linear=problem.is_linear,
424430
form_compiler_parameters=problem.form_compiler_parameters)
425431
new_problem._constant_jacobian = problem._constant_jacobian
426-
splits.append(type(self)(new_problem, mat_type=self.mat_type, pmat_type=self.pmat_type,
432+
splits.append(type(self)(new_problem,
433+
mat_type=self.mat_type,
434+
pmat_type=self.pmat_type,
435+
sub_mat_type=self.sub_mat_type,
436+
sub_pmat_type=self.sub_pmat_type,
427437
appctx=self.appctx,
428438
transfer_manager=self.transfer_manager,
429439
pre_apply_bcs=self.pre_apply_bcs))
@@ -504,6 +514,15 @@ def form_jacobian(snes, X, J, P):
504514
ctx.set_nullspace(ctx._nullspace_T, ises, transpose=True, near=False)
505515
ctx.set_nullspace(ctx._near_nullspace, ises, transpose=False, near=True)
506516

517+
# Bump petsc matrix state of each split by assembling them.
518+
# Ensures that if the matrix changed, the preconditioner is
519+
# updated if necessary.
520+
for field, splits in ctx._splits.items():
521+
for subctx in splits:
522+
subctx._jac.assemble()
523+
if subctx.Jp is not None:
524+
subctx._pjac.assemble()
525+
507526
@staticmethod
508527
def compute_operators(ksp, J, P):
509528
r"""Form the Jacobian for this problem

tests/firedrake/regression/test_nested_fieldsplit_solves.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,40 @@ def test_nested_fieldsplit_solve_parallel(W, A, b, expect):
152152
assert norm(f) < 1e-11
153153

154154

155+
def test_nonlinear_fielsplit():
156+
mesh = UnitIntervalMesh(1)
157+
V = FunctionSpace(mesh, "DG", 0)
158+
Z = V * V * V
159+
160+
u = Function(Z)
161+
u0, u1, u2 = split(u)
162+
v0, v1, v2 = TestFunctions(Z)
163+
164+
F = inner(u0, v0) * dx
165+
F += inner(0.5*u1**2 + u1, v1) * dx
166+
F += inner(u2, v2) * dx
167+
u.subfunctions[1].assign(Constant(1))
168+
169+
sp = {
170+
"mat_type": "nest",
171+
"snes_max_it": 10,
172+
"ksp_type": "fgmres",
173+
"pc_type": "fieldsplit",
174+
"pc_fieldsplit_type": "additive",
175+
"pc_fieldsplit_0_fields": "0",
176+
"pc_fieldsplit_1_fields": "1,2",
177+
"fieldsplit_1_ksp_view_eigenvalues": None,
178+
"fieldsplit": {
179+
"ksp_type": "gmres",
180+
"pc_type": "jacobi",
181+
},
182+
}
183+
J = derivative(F, u)
184+
solver = NonlinearVariationalSolver(NonlinearVariationalProblem(F, u), solver_parameters=sp)
185+
solver.solve()
186+
assert np.allclose(solver.snes.ksp.pc.getFieldSplitSubKSP()[1].computeEigenvalues(), 1)
187+
188+
155189
def test_matrix_types(W):
156190
a = inner(TrialFunction(W), TestFunction(W))*dx
157191

0 commit comments

Comments
 (0)