From bf9d69ecedff3daa096a7eb0a3beacd19cbd5c86 Mon Sep 17 00:00:00 2001 From: atb1995 Date: Fri, 7 Mar 2025 18:01:10 +0000 Subject: [PATCH 01/43] IDC seems to work, RIDIC implementation started, but imcomplete --- gusto/time_discretisation/sdc.py | 531 ++++++++++++++++++++++++++++++- 1 file changed, 530 insertions(+), 1 deletion(-) diff --git a/gusto/time_discretisation/sdc.py b/gusto/time_discretisation/sdc.py index 3e634f17a..54c52beb6 100644 --- a/gusto/time_discretisation/sdc.py +++ b/gusto/time_discretisation/sdc.py @@ -57,7 +57,7 @@ from qmat import genQCoeffs, genQDeltaCoeffs -__all__ = ["SDC"] +__all__ = ["SDC", "IDC"] class SDC(object, metaclass=ABCMeta): @@ -509,3 +509,532 @@ def apply(self, x_out, x_in): x_out.assign(self.Unodes[-1]) else: x_out.assign(self.Unodes[-1]) + +class IDC(SDC): + """Class for Integral Deferred Correction schemes.""" + + def __init__(self, base_scheme, domain, M, maxk, field_name=None, + linear_solver_parameters=None, nonlinear_solver_parameters=None, final_update=True, + limiter=None, options=None, initial_guess="base"): + """ + Initialise IDC object + Args: + base_scheme (:class:`TimeDiscretisation`): Base time stepping scheme to get first guess of solution on + quadrature nodes. + domain (:class:`Domain`): the model's domain object, containing the + mesh and the compatible function spaces. + M (int): Number of quadrature nodes to compute spectral integration over + maxk (int): Max number of correction interations + field_name (str, optional): name of the field to be evolved. + Defaults to None. + linear_solver_parameters (dict, optional): dictionary of parameters to + pass to the underlying linear solver. Defaults to None. + nonlinear_solver_parameters (dict, optional): dictionary of parameters to + pass to the underlying nonlinear solver. Defaults to None. + final_update (bool, optional): Whether to compute final update, or just take last + quadrature value. Defaults to True + limiter (:class:`Limiter` object, optional): a limiter to apply to + the evolving field to enforce monotonicity. Defaults to None. + options (:class:`AdvectionOptions`, optional): an object containing + + initial_guess (str, optional): Initial guess to be base timestepper, or copy + """ + super(IDC, self).__init__(base_scheme, domain, M, maxk, 'GAUSS', 'EQUID', 'BE', 'FE', + 'N2N', field_name, linear_solver_parameters, nonlinear_solver_parameters, + final_update, limiter, options, initial_guess) + def setup(self, equation, apply_bcs=True, *active_labels): + """ + Set up the IDC time discretisation based on the equation.n + + Args: + equation (:class:`PrognosticEquation`): the model's equation. + apply_bcs (bool, optional): whether to apply the equation's boundary + conditions. Defaults to True. + *active_labels (:class:`Label`): labels indicating which terms of + the equation to include. + """ + # Inherit from base time discretisation + super(IDC, self).setup(equation, apply_bcs, *active_labels) + self.source_Ukp1_m = Function(self.W) + self.source_Uk_m = Function(self.W) + self.Uk_mp1 = Function(self.W) + self.Uk_m = Function(self.W) + self.Ukp1_m = Function(self.W) + self.dt = Constant(0.0) + + def compute_quad(self): + """ + Computes integration of F(y) on quadrature nodes + """ + for j in range(self.M): + self.quad[j].assign(0.) + for k in range(self.M): + self.quad[j] += float(self.Q[j, k])*self.fUnodes[k] + + @property + def res_rhs(self): + """Set up the residual for the calculation of F(y).""" + a = self.residual.label_map(lambda t: t.has_label(time_derivative), + replace_subject(self.Urhs, old_idx=self.idx), + drop) + # F(y) + L = self.residual.label_map(lambda t: any(t.has_label(time_derivative, source_label)), + drop, + replace_subject(self.Uin, old_idx=self.idx)) + L_source = self.residual.label_map(lambda t: t.has_label(source_label), + replace_subject(self.source_in, old_idx=self.idx), + drop) + residual_rhs = a - (L + L_source) + return residual_rhs.form + + @property + def res(self): + """Set up the discretisation's residual for a given node m.""" + # Add time derivative terms y^(k+1)_m - y_start for node m. y_start is y_n for Z2N formulation + # and y^(k)_m for N2N formulation + mass_form = self.residual.label_map( + lambda t: t.has_label(time_derivative), + map_if_false=drop) + residual = mass_form.label_map(all_terms, + map_if_true=replace_subject(self.U_SDC, old_idx=self.idx)) + residual -= mass_form.label_map(all_terms, + map_if_true=replace_subject(self.U_start, old_idx=self.idx)) + + # Calculate source terms + r_source_kp1 = self.residual.label_map( + lambda t: t.has_label(source_label), + map_if_true=replace_subject(self.source_Ukp1_m, old_idx=self.idx), + map_if_false=drop) + r_source_kp1 = r_source_kp1.label_map( + all_terms, + lambda t: Constant(self.dt)*t) + residual += r_source_kp1 + + r_source_k = self.residual.label_map( + lambda t: t.has_label(source_label), + map_if_true=replace_subject(self.source_Uk_m, old_idx=self.idx), + map_if_false=drop) + r_source_k = r_source_k.label_map( + all_terms, + map_if_true=lambda t: Constant(self.dt)*t) + residual -= r_source_k + + # Add on final implicit terms + # Qdelta_imp[m,m]*(F(y_(m)^(k+1)) - F(y_(m)^k)) + r_imp_kp1 = self.residual.label_map( + lambda t: t.has_label(implicit), + map_if_true=replace_subject(self.U_SDC, old_idx=self.idx), + map_if_false=drop) + r_imp_kp1 = r_imp_kp1.label_map( + all_terms, + lambda t: Constant(self.dt)*t) + residual += r_imp_kp1 + r_imp_k = self.residual.label_map( + lambda t: t.has_label(implicit), + map_if_true=replace_subject(self.Uk_mp1, old_idx=self.idx), + map_if_false=drop) + r_imp_k = r_imp_k.label_map( + all_terms, + lambda t: Constant(self.dt)*t) + residual -= r_imp_k + + r_exp_kp1 = self.residual.label_map( + lambda t: t.has_label(explicit), + map_if_true=replace_subject(self.Ukp1_m, old_idx=self.idx), + map_if_false=drop) + r_exp_kp1 = r_exp_kp1.label_map( + all_terms, + lambda t: Constant(self.dt)*t) + + residual += r_exp_kp1 + r_exp_k = self.residual.label_map( + lambda t: t.has_label(explicit), + map_if_true=replace_subject(self.Uk_m, old_idx=self.idx), + map_if_false=drop) + r_exp_k = r_exp_k.label_map( + all_terms, + lambda t: Constant(self.dt)*t) + residual -= r_exp_k + + + # Add on sum(j=1,M) s_mj*F(y_m^k) for N2N formulation, where s_mj = q_mj-q_m-1j + # and s1j = q1j. + Q = self.residual.label_map(lambda t: t.has_label(time_derivative), + replace_subject(self.Q_, old_idx=self.idx), + drop) + residual += Q + return residual.form + + @cached_property + def solver(self): + """Set up a list of solvers for each problem at a node m.""" + # setup solver using residual defined in derived class + problem = NonlinearVariationalProblem(self.res, self.U_SDC, bcs=self.bcs) + solver_name = self.field_name+self.__class__.__name__ + solver = NonlinearVariationalSolver(problem, solver_parameters=self.nonlinear_solver_parameters, options_prefix=solver_name) + return solver + + @cached_property + def solver_rhs(self): + """Set up the problem and the solver for mass matrix inversion.""" + # setup linear solver using rhs residual defined in derived class + prob_rhs = NonlinearVariationalProblem(self.res_rhs, self.Urhs, bcs=self.bcs) + solver_name = self.field_name+self.__class__.__name__+"_rhs" + return NonlinearVariationalSolver(prob_rhs, solver_parameters=self.linear_solver_parameters, + options_prefix=solver_name) + + @wrapper_apply + def apply(self, x_out, x_in): + self.Un.assign(x_in) + + # Compute initial guess on quadrature nodes with low-order + # base timestepper + self.Unodes[0].assign(self.Un) + if (self.base_flag): + for m in range(self.M): + self.base.dt = float(self.dtau[m]) + self.base.apply(self.Unodes[m+1], self.Unodes[m]) + else: + for m in range(self.M): + self.Unodes[m+1].assign(self.Un) + for m in range(self.M+1): + for evaluate in self.evaluate_source: + evaluate(self.Unodes[m], self.base.dt, x_out=self.source_Uk[m]) + + # Iterate through correction sweeps + k = 0 + while k < self.maxk: + k += 1 + + # Compute for N2N: sum(j=1,M) (s_mj*F(y_m^k) + s_mj*S(y_m^k)) + for m in range(1, self.M+1): + self.Uin.assign(self.Unodes[m]) + # Include source terms + for evaluate in self.evaluate_source: + evaluate(self.Uin, self.base.dt, x_out=self.source_in) + self.solver_rhs.solve() + self.fUnodes[m-1].assign(self.Urhs) + self.compute_quad() + + # Loop through quadrature nodes and solve + self.Unodes1[0].assign(self.Unodes[0]) + for evaluate in self.evaluate_source: + evaluate(self.Unodes[0], self.base.dt, x_out=self.source_Uk[0]) + for m in range(1, self.M+1): + # Set S matrix + self.Q_.assign(self.quad[m-1]) + + # Set initial guess for solver, and pick correct solver + self.U_start.assign(self.Unodes1[m-1]) + self.Ukp1_m.assign(self.Unodes1[m-1]) + self.Uk_mp1.assign(self.Unodes[m]) + self.Uk_m.assign(self.Unodes[m-1]) + self.source_Ukp1_m.assign(self.source_Ukp1[m-1]) + self.source_Uk_m.assign(self.source_Uk[m-1]) + self.dt.assign(float(self.dtau[m-1])) + self.U_SDC.assign(self.Unodes[m]) + + # Compute + # for N2N: + # y_m^(k+1) = y_(m-1)^(k+1) + dtau_m*(F(y_(m)^(k+1)) - F(y_(m)^k) + # + S(y_(m-1)^(k+1)) - S(y_(m-1)^k)) + # + sum(j=1,M) s_mj*(F+S)(y^k) + self.solver.solve() + self.Unodes1[m].assign(self.U_SDC) + + # Evaluate source terms + for evaluate in self.evaluate_source: + evaluate(self.Unodes1[m], self.base.dt, x_out=self.source_Ukp1[m]) + + # Apply limiter if required + if self.limiter is not None: + self.limiter.apply(self.Unodes1[m]) + for m in range(1, self.M+1): + self.Unodes[m].assign(self.Unodes1[m]) + self.source_Uk[m].assign(self.source_Ukp1[m]) + + x_out.assign(self.Unodes[-1]) + + +class RIDC(SDC): + """Class for Revisionist Integral Deferred Correction schemes.""" + + def __init__(self, base_scheme, domain, M, maxk, field_name=None, + linear_solver_parameters=None, nonlinear_solver_parameters=None, final_update=True, + limiter=None, options=None, initial_guess="base"): + """ + Initialise IDC object + Args: + base_scheme (:class:`TimeDiscretisation`): Base time stepping scheme to get first guess of solution on + quadrature nodes. + domain (:class:`Domain`): the model's domain object, containing the + mesh and the compatible function spaces. + M (int): Number of quadrature nodes to compute spectral integration over + maxk (int): Max number of correction interations + field_name (str, optional): name of the field to be evolved. + Defaults to None. + linear_solver_parameters (dict, optional): dictionary of parameters to + pass to the underlying linear solver. Defaults to None. + nonlinear_solver_parameters (dict, optional): dictionary of parameters to + pass to the underlying nonlinear solver. Defaults to None. + final_update (bool, optional): Whether to compute final update, or just take last + quadrature value. Defaults to True + limiter (:class:`Limiter` object, optional): a limiter to apply to + the evolving field to enforce monotonicity. Defaults to None. + options (:class:`AdvectionOptions`, optional): an object containing + + initial_guess (str, optional): Initial guess to be base timestepper, or copy + """ + super(IDC, self).__init__(base_scheme, domain, M, maxk, 'GAUSS', 'EQUID', 'BE', 'FE', + 'N2N', field_name, linear_solver_parameters, nonlinear_solver_parameters, + final_update, limiter, options, initial_guess) + def setup(self, equation, apply_bcs=True, *active_labels): + """ + Set up the IDC time discretisation based on the equation.n + + Args: + equation (:class:`PrognosticEquation`): the model's equation. + apply_bcs (bool, optional): whether to apply the equation's boundary + conditions. Defaults to True. + *active_labels (:class:`Label`): labels indicating which terms of + the equation to include. + """ + # Inherit from base time discretisation + super(IDC, self).setup(equation, apply_bcs, *active_labels) + self.source_Ukp1_m = Function(self.W) + self.source_Uk_m = Function(self.W) + self.Uk_mp1 = Function(self.W) + self.Uk_m = Function(self.W) + self.Ukp1_m = Function(self.W) + self.dt = Constant(0.0) + + def compute_quad(self, Q, fUnodes, M_val): + """ + Computes integration of F(y) on quadrature nodes + """ + quad.assign(0.) + for k in range(0, self.M): + quad += float(Q[M_val, k])*fUnodes[k] + return quad + + def compute_quad_final(self, Q, fUnodes, M_val): + """ + Computes final integration of F(y) on quadrature nodes + """ + quad.assign(0.) + for k in range(0, self.M): + quad += float(Q[self.M-1, k])*fUnodes[M_val - self.M + k] + return quad + + @property + def res_rhs(self): + """Set up the residual for the calculation of F(y).""" + a = self.residual.label_map(lambda t: t.has_label(time_derivative), + replace_subject(self.Urhs, old_idx=self.idx), + drop) + # F(y) + L = self.residual.label_map(lambda t: any(t.has_label(time_derivative, source_label)), + drop, + replace_subject(self.Uin, old_idx=self.idx)) + L_source = self.residual.label_map(lambda t: t.has_label(source_label), + replace_subject(self.source_in, old_idx=self.idx), + drop) + residual_rhs = a - (L + L_source) + return residual_rhs.form + + @property + def res(self): + """Set up the discretisation's residual for a given node m.""" + # Add time derivative terms y^(k+1)_m - y_start for node m. y_start is y_n for Z2N formulation + # and y^(k)_m for N2N formulation + mass_form = self.residual.label_map( + lambda t: t.has_label(time_derivative), + map_if_false=drop) + residual = mass_form.label_map(all_terms, + map_if_true=replace_subject(self.U_SDC, old_idx=self.idx)) + residual -= mass_form.label_map(all_terms, + map_if_true=replace_subject(self.U_start, old_idx=self.idx)) + + # Calculate source terms + r_source_kp1 = self.residual.label_map( + lambda t: t.has_label(source_label), + map_if_true=replace_subject(self.source_Ukp1_m, old_idx=self.idx), + map_if_false=drop) + r_source_kp1 = r_source_kp1.label_map( + all_terms, + lambda t: Constant(self.dt)*t) + residual += r_source_kp1 + + r_source_k = self.residual.label_map( + lambda t: t.has_label(source_label), + map_if_true=replace_subject(self.source_Uk_m, old_idx=self.idx), + map_if_false=drop) + r_source_k = r_source_k.label_map( + all_terms, + map_if_true=lambda t: Constant(self.dt)*t) + residual -= r_source_k + + # Add on final implicit terms + # Qdelta_imp[m,m]*(F(y_(m)^(k+1)) - F(y_(m)^k)) + r_imp_kp1 = self.residual.label_map( + lambda t: t.has_label(implicit), + map_if_true=replace_subject(self.U_SDC, old_idx=self.idx), + map_if_false=drop) + r_imp_kp1 = r_imp_kp1.label_map( + all_terms, + lambda t: Constant(self.dt)*t) + residual += r_imp_kp1 + r_imp_k = self.residual.label_map( + lambda t: t.has_label(implicit), + map_if_true=replace_subject(self.Uk_mp1, old_idx=self.idx), + map_if_false=drop) + r_imp_k = r_imp_k.label_map( + all_terms, + lambda t: Constant(self.dt)*t) + residual -= r_imp_k + + r_exp_kp1 = self.residual.label_map( + lambda t: t.has_label(explicit), + map_if_true=replace_subject(self.Ukp1_m, old_idx=self.idx), + map_if_false=drop) + r_exp_kp1 = r_exp_kp1.label_map( + all_terms, + lambda t: Constant(self.dt)*t) + + residual += r_exp_kp1 + r_exp_k = self.residual.label_map( + lambda t: t.has_label(explicit), + map_if_true=replace_subject(self.Uk_m, old_idx=self.idx), + map_if_false=drop) + r_exp_k = r_exp_k.label_map( + all_terms, + lambda t: Constant(self.dt)*t) + residual -= r_exp_k + + + # Add on sum(j=1,M) s_mj*F(y_m^k) for N2N formulation, where s_mj = q_mj-q_m-1j + # and s1j = q1j. + Q = self.residual.label_map(lambda t: t.has_label(time_derivative), + replace_subject(self.Q_, old_idx=self.idx), + drop) + residual += Q + return residual.form + + @cached_property + def solver(self): + """Set up a list of solvers for each problem at a node m.""" + # setup solver using residual defined in derived class + problem = NonlinearVariationalProblem(self.res, self.U_SDC, bcs=self.bcs) + solver_name = self.field_name+self.__class__.__name__ + solver = NonlinearVariationalSolver(problem, solver_parameters=self.nonlinear_solver_parameters, options_prefix=solver_name) + return solver + + @cached_property + def solver_rhs(self): + """Set up the problem and the solver for mass matrix inversion.""" + # setup linear solver using rhs residual defined in derived class + prob_rhs = NonlinearVariationalProblem(self.res_rhs, self.Urhs, bcs=self.bcs) + solver_name = self.field_name+self.__class__.__name__+"_rhs" + return NonlinearVariationalSolver(prob_rhs, solver_parameters=self.linear_solver_parameters, + options_prefix=solver_name) + + @wrapper_apply + def apply(self, x_out, x_in): + self.Un.assign(x_in) + + # Compute initial guess on quadrature nodes with low-order + # base timestepper + self.Unodes[0].assign(self.Un) + if (self.base_flag): + for m in range(self.M): + self.base.dt = float(self.dtau[m]) + self.base.apply(self.Unodes[m+1], self.Unodes[m]) + else: + for m in range(self.M): + self.Unodes[m+1].assign(self.Un) + for m in range(self.M+1): + for evaluate in self.evaluate_source: + evaluate(self.Unodes[m], self.base.dt, x_out=self.source_Uk[m]) + + # Iterate through correction sweeps + k = 0 + while k < self.maxk: + k += 1 + + # Compute for N2N: sum(j=1,M) (s_mj*F(y_m^k) + s_mj*S(y_m^k)) + for m in range(1, self.M+1): + self.Uin.assign(self.Unodes[m]) + # Include source terms + for evaluate in self.evaluate_source: + evaluate(self.Uin, self.base.dt, x_out=self.source_in) + self.solver_rhs.solve() + self.fUnodes[m-1].assign(self.Urhs) + + # Loop through quadrature nodes and solve + self.Unodes1[0].assign(self.Unodes[0]) + for evaluate in self.evaluate_source: + evaluate(self.Unodes[0], self.base.dt, x_out=self.source_Uk[0]) + for m in range(1, self.maxk+1): + # Set S matrix + self.Q_.assign(self.compute_quad(self, self.Q, self.fUnodes, m-1)) + + # Set initial guess for solver, and pick correct solver + self.U_start.assign(self.Unodes1[m-1]) + self.Ukp1_m.assign(self.Unodes1[m-1]) + self.Uk_mp1.assign(self.Unodes[m]) + self.Uk_m.assign(self.Unodes[m-1]) + self.source_Ukp1_m.assign(self.source_Ukp1[m-1]) + self.source_Uk_m.assign(self.source_Uk[m-1]) + self.dt.assign(float(self.dtau[m-1])) + self.U_SDC.assign(self.Unodes[m]) + + # Compute + # for N2N: + # y_m^(k+1) = y_(m-1)^(k+1) + dtau_m*(F(y_(m)^(k+1)) - F(y_(m)^k) + # + S(y_(m-1)^(k+1)) - S(y_(m-1)^k)) + # + sum(j=1,M) s_mj*(F+S)(y^k) + self.solver.solve() + self.Unodes1[m].assign(self.U_SDC) + + # Evaluate source terms + for evaluate in self.evaluate_source: + evaluate(self.Unodes1[m], self.base.dt, x_out=self.source_Ukp1[m]) + + # Apply limiter if required + if self.limiter is not None: + self.limiter.apply(self.Unodes1[m]) + for m in range(self.maxk, self.M+1): + # Set S matrix + self.Q_.assign(self.compute_quad_final(self, self.Q, self.fUnodes, m-1)) + + # Set initial guess for solver, and pick correct solver + self.U_start.assign(self.Unodes1[m-1]) + self.Ukp1_m.assign(self.Unodes1[m-1]) + self.Uk_mp1.assign(self.Unodes[m]) + self.Uk_m.assign(self.Unodes[m-1]) + self.source_Ukp1_m.assign(self.source_Ukp1[m-1]) + self.source_Uk_m.assign(self.source_Uk[m-1]) + self.dt.assign(float(self.dtau[m-1])) + self.U_SDC.assign(self.Unodes[m]) + + # Compute + # for N2N: + # y_m^(k+1) = y_(m-1)^(k+1) + dtau_m*(F(y_(m)^(k+1)) - F(y_(m)^k) + # + S(y_(m-1)^(k+1)) - S(y_(m-1)^k)) + # + sum(j=1,M) s_mj*(F+S)(y^k) + self.solver.solve() + self.Unodes1[m].assign(self.U_SDC) + + # Evaluate source terms + for evaluate in self.evaluate_source: + evaluate(self.Unodes1[m], self.base.dt, x_out=self.source_Ukp1[m]) + + # Apply limiter if required + if self.limiter is not None: + self.limiter.apply(self.Unodes1[m]) + + for m in range(1, self.M+1): + self.Unodes[m].assign(self.Unodes1[m]) + self.source_Uk[m].assign(self.source_Ukp1[m]) + + x_out.assign(self.Unodes[-1]) \ No newline at end of file From abbd4848423ea9cfa0186d05f81f22c068d356ab Mon Sep 17 00:00:00 2001 From: atb1995 Date: Wed, 12 Mar 2025 17:35:12 +0000 Subject: [PATCH 02/43] RIDC and IDC running, RIDC not getting as high convergence at IDC --- .../explicit_runge_kutta.py | 9 +- gusto/time_discretisation/sdc.py | 482 +++++++++--------- .../time_discretisation.py | 2 + gusto/timestepping/timestepper.py | 2 + 4 files changed, 256 insertions(+), 239 deletions(-) diff --git a/gusto/time_discretisation/explicit_runge_kutta.py b/gusto/time_discretisation/explicit_runge_kutta.py index 187339df8..77cff79d4 100644 --- a/gusto/time_discretisation/explicit_runge_kutta.py +++ b/gusto/time_discretisation/explicit_runge_kutta.py @@ -321,7 +321,7 @@ def solve_stage(self, x0, stage): if self.rk_formulation == RungeKuttaFormulation.increment: self.x1.assign(x0) - + print("stage", stage) for i in range(stage): self.x1.assign(self.x1 + self.dt*self.butcher_matrix[stage-1, i]*self.k[i]) for evaluate in self.evaluate_source: @@ -340,7 +340,8 @@ def solve_stage(self, x0, stage): self.x1.assign(x0) for i in range(self.nStages): self.x1.assign(self.x1 + self.dt*self.butcher_matrix[stage, i]*self.k[i]) - self.x1.assign(self.x1) + #self.x1.assign(self.x1) + print("Hello") if self.limiter is not None: self.limiter.apply(self.x1) @@ -451,6 +452,10 @@ def apply_cycle(self, x_out, x_in): for i in range(self.nStages): self.solve_stage(x_in, i) x_out.assign(self.x1) + self.xdiff = Function(self.fs) + self.xdiff.assign(x_out-x_in) + print("xdiff", np.max(self.xdiff.dat.data), np.min(self.xdiff.dat.data)) + #breakpoint() class ForwardEuler(ExplicitRungeKutta): diff --git a/gusto/time_discretisation/sdc.py b/gusto/time_discretisation/sdc.py index 54c52beb6..dbe888918 100644 --- a/gusto/time_discretisation/sdc.py +++ b/gusto/time_discretisation/sdc.py @@ -57,7 +57,7 @@ from qmat import genQCoeffs, genQDeltaCoeffs -__all__ = ["SDC", "IDC"] +__all__ = ["SDC", "IDC", "RIDC"] class SDC(object, metaclass=ABCMeta): @@ -108,6 +108,7 @@ def __init__(self, base_scheme, domain, M, maxk, quad_type, node_type, qdelta_im # Initialise parameters self.base = base_scheme + self.base.dt = domain.dt self.field_name = field_name self.domain = domain self.dt_coarse = domain.dt @@ -414,13 +415,18 @@ def apply(self, x_out, x_in): if (self.base_flag): for m in range(self.M): self.base.dt = float(self.dtau[m]) - self.base.apply(self.Unodes[m+1], self.Unodes[m]) + self.base.apply_cycle(self.Unodes[m+1], self.Unodes[m]) + print(self.base.dt) else: for m in range(self.M): self.Unodes[m+1].assign(self.Un) for m in range(self.M+1): for evaluate in self.evaluate_source: evaluate(self.Unodes[m], self.base.dt, x_out=self.source_Uk[m]) + self.Udiff = Function(self.W) + self.Udiff.assign(self.Unodes[0] - self.Unodes[-1]) + print(np.max(self.Udiff.dat.data[:]), np.min(self.Udiff.dat.data[:])) + #breakpoint() # Iterate through correction sweeps k = 0 @@ -510,12 +516,12 @@ def apply(self, x_out, x_in): else: x_out.assign(self.Unodes[-1]) -class IDC(SDC): +class IDC(object, metaclass=ABCMeta): """Class for Integral Deferred Correction schemes.""" - def __init__(self, base_scheme, domain, M, maxk, field_name=None, - linear_solver_parameters=None, nonlinear_solver_parameters=None, final_update=True, - limiter=None, options=None, initial_guess="base"): + def __init__(self, base_scheme, domain, M, K, field_name=None, + linear_solver_parameters=None, nonlinear_solver_parameters=None, + limiter=None, options=None): """ Initialise IDC object Args: @@ -524,7 +530,7 @@ def __init__(self, base_scheme, domain, M, maxk, field_name=None, domain (:class:`Domain`): the model's domain object, containing the mesh and the compatible function spaces. M (int): Number of quadrature nodes to compute spectral integration over - maxk (int): Max number of correction interations + K (int): Max number of correction interations field_name (str, optional): name of the field to be evolved. Defaults to None. linear_solver_parameters (dict, optional): dictionary of parameters to @@ -539,12 +545,41 @@ def __init__(self, base_scheme, domain, M, maxk, field_name=None, initial_guess (str, optional): Initial guess to be base timestepper, or copy """ - super(IDC, self).__init__(base_scheme, domain, M, maxk, 'GAUSS', 'EQUID', 'BE', 'FE', - 'N2N', field_name, linear_solver_parameters, nonlinear_solver_parameters, - final_update, limiter, options, initial_guess) + self.base = base_scheme + self.field_name = field_name + self.domain = domain + self.dt_coarse = domain.dt + self.limiter = limiter + self.augmentation = self.base.augmentation + self.wrapper = self.base.wrapper + self.K = K + self.M = M + self.dt = Constant(float(self.dt_coarse)/(self.M - 1)) + self.nodes = np.arange(0, M*float(self.dt), float(self.dt)) + integration_matrix = self.lagrange_integration_matrix(self.M) + self.Q = 0.5 * (self.M-1) * float(self.dt) * integration_matrix + + # Set default linear and nonlinear solver options if none passed in + if linear_solver_parameters is None: + self.linear_solver_parameters = {'snes_type': 'ksponly', + 'ksp_type': 'cg', + 'pc_type': 'bjacobi', + 'sub_pc_type': 'ilu'} + else: + self.linear_solver_parameters = linear_solver_parameters + + if nonlinear_solver_parameters is None: + self.nonlinear_solver_parameters = {'snes_type': 'newtonls', + 'ksp_type': 'gmres', + 'pc_type': 'bjacobi', + 'sub_pc_type': 'ilu'} + else: + self.nonlinear_solver_parameters = nonlinear_solver_parameters + + def setup(self, equation, apply_bcs=True, *active_labels): """ - Set up the IDC time discretisation based on the equation.n + Set up the SDC time discretisation based on the equation.n Args: equation (:class:`PrognosticEquation`): the model's equation. @@ -554,22 +589,130 @@ def setup(self, equation, apply_bcs=True, *active_labels): the equation to include. """ # Inherit from base time discretisation - super(IDC, self).setup(equation, apply_bcs, *active_labels) - self.source_Ukp1_m = Function(self.W) - self.source_Uk_m = Function(self.W) - self.Uk_mp1 = Function(self.W) - self.Uk_m = Function(self.W) - self.Ukp1_m = Function(self.W) - self.dt = Constant(0.0) + self.base.setup(equation, apply_bcs, *active_labels) + self.equation = self.base.equation + self.residual = self.base.residual + self.evaluate_source = self.base.evaluate_source - def compute_quad(self): + for t in self.residual: + # Check all terms are labeled implicit or explicit + if ((not t.has_label(implicit)) and (not t.has_label(explicit)) + and (not t.has_label(time_derivative)) and (not t.has_label(source_label))): + raise NotImplementedError("Non time-derivative or source terms must be labeled as implicit or explicit") + + # Set up bcs + self.bcs = self.base.bcs + + # Set up SDC variables + if self.field_name is not None and hasattr(equation, "field_names"): + self.idx = equation.field_names.index(self.field_name) + W = equation.spaces[self.idx] + else: + self.field_name = equation.field_name + W = equation.function_space + self.idx = None + self.W = W + self.Unodes = [Function(W) for _ in range(self.M)] + self.Unodes1 = [Function(W) for _ in range(self.M)] + self.fUnodes = [Function(W) for _ in range(self.M)] + self.quad = [Function(W) for _ in range(self.M)] + self.source_Uk = [Function(W) for _ in range(self.M)] + self.source_Ukp1 = [Function(W) for _ in range(self.M)] + self.U_SDC = Function(W) + self.U_start = Function(W) + self.Un = Function(W) + self.Q_ = Function(W) + self.quad_final = Function(W) + self.U_fin = Function(W) + self.Urhs = Function(W) + self.Uin = Function(W) + self.source_in = Function(W) + self.source_Ukp1_m = Function(W) + self.source_Uk_m = Function(W) + self.Uk_mp1 = Function(W) + self.Uk_m = Function(W) + self.Ukp1_m = Function(W) + + @property + def nlevels(self): + return 1 + + def equidistant_nodes(self ,M): + # This returns a grid of M equispaced nodes from -1 to 1 + grid = np.linspace(-1., 1., M) + return grid + + def lagrange_polynomial(self, index, nodes): + # This returns the coefficients of the Lagrange polynomial l_m with m=index + + M = len(nodes) + + # c is the denominator + c = 1. + for k in range(M): + if k != index: + c *= (nodes[index] - nodes[k]) + + coeffs = np.zeros(M) + coeffs[0] = 1. + m = 0 + + for k in range(M): + if k != index: + m += 1 + d1 = np.zeros(M) + d2 = np.zeros(M) + + d1 = (-1.)*nodes[k] * coeffs + d2[1:m+1] = coeffs[0:m] + + coeffs = d1+d2 + return coeffs / c + + def integrate_polynomial(self, p): + # given a list of coefficients of a polynomial p, this returns those of the integral of p + integral_coeffs = np.zeros(len(p)+1) + + for n, pn in enumerate(p): + integral_coeffs[n+1] = 1/(n+1) * pn + + return integral_coeffs + + def evaluate(self, p, a, b): + # given a list of coefficients of a polynomial p, this returns the value of p(b)-p(a) + value = 0. + for n, pn in enumerate(p): + value += pn * (b**n - a**n) + + return value + + def lagrange_integration_matrix(self, M): + # using the functions defined above, this returns the MxM integration matrix + + # set up equidistant nodes and initialise matrix to zero + nodes = self.equidistant_nodes(M) + L = len(nodes) + int_matrix = np.zeros((L, L)) + + # fill in matrix values + for index in range(L): + coeff_p = self.lagrange_polynomial(index, nodes) + int_coeff = self.integrate_polynomial(coeff_p) + + for n in range(L-1): + int_matrix[n+1, index] = self.evaluate(int_coeff, nodes[n], nodes[n+1]) + + return int_matrix + + def compute_quad(self, Q, fUnodes, m): """ Computes integration of F(y) on quadrature nodes """ - for j in range(self.M): - self.quad[j].assign(0.) - for k in range(self.M): - self.quad[j] += float(self.Q[j, k])*self.fUnodes[k] + quad = Function(self.W) + quad.assign(0.) + for k in range(0, np.shape(Q)[0]): + quad += float(Q[m, k])*fUnodes[k] + return quad @property def res_rhs(self): @@ -645,7 +788,6 @@ def res(self): r_exp_kp1 = r_exp_kp1.label_map( all_terms, lambda t: Constant(self.dt)*t) - residual += r_exp_kp1 r_exp_k = self.residual.label_map( lambda t: t.has_label(explicit), @@ -690,49 +832,44 @@ def apply(self, x_out, x_in): # Compute initial guess on quadrature nodes with low-order # base timestepper self.Unodes[0].assign(self.Un) - if (self.base_flag): - for m in range(self.M): - self.base.dt = float(self.dtau[m]) - self.base.apply(self.Unodes[m+1], self.Unodes[m]) - else: - for m in range(self.M): - self.Unodes[m+1].assign(self.Un) - for m in range(self.M+1): + + for m in range(self.M-1): + self.base.dt = float(self.dt) + print(self.base.dt) + self.base.apply(self.Unodes[m+1], self.Unodes[m]) + #breakpoint() + for m in range(self.M): for evaluate in self.evaluate_source: evaluate(self.Unodes[m], self.base.dt, x_out=self.source_Uk[m]) # Iterate through correction sweeps - k = 0 - while k < self.maxk: - k += 1 - + for k in range(1, self.K+1): + print("Correction sweep", k) # Compute for N2N: sum(j=1,M) (s_mj*F(y_m^k) + s_mj*S(y_m^k)) - for m in range(1, self.M+1): + for m in range(self.M): self.Uin.assign(self.Unodes[m]) # Include source terms for evaluate in self.evaluate_source: evaluate(self.Uin, self.base.dt, x_out=self.source_in) self.solver_rhs.solve() - self.fUnodes[m-1].assign(self.Urhs) - self.compute_quad() + self.fUnodes[m].assign(self.Urhs) # Loop through quadrature nodes and solve self.Unodes1[0].assign(self.Unodes[0]) for evaluate in self.evaluate_source: evaluate(self.Unodes[0], self.base.dt, x_out=self.source_Uk[0]) - for m in range(1, self.M+1): + for m in range(0, self.M-1): # Set S matrix - self.Q_.assign(self.quad[m-1]) + self.Q_.assign(self.compute_quad(self.Q, self.fUnodes, m+1)) # Set initial guess for solver, and pick correct solver - self.U_start.assign(self.Unodes1[m-1]) - self.Ukp1_m.assign(self.Unodes1[m-1]) - self.Uk_mp1.assign(self.Unodes[m]) - self.Uk_m.assign(self.Unodes[m-1]) - self.source_Ukp1_m.assign(self.source_Ukp1[m-1]) - self.source_Uk_m.assign(self.source_Uk[m-1]) - self.dt.assign(float(self.dtau[m-1])) - self.U_SDC.assign(self.Unodes[m]) + self.U_start.assign(self.Unodes1[m]) + self.Ukp1_m.assign(self.Unodes1[m]) + self.Uk_mp1.assign(self.Unodes[m+1]) + self.Uk_m.assign(self.Unodes[m]) + self.source_Ukp1_m.assign(self.source_Ukp1[m]) + self.source_Uk_m.assign(self.source_Uk[m]) + self.U_SDC.assign(self.Unodes[m+1]) # Compute # for N2N: @@ -740,28 +877,28 @@ def apply(self, x_out, x_in): # + S(y_(m-1)^(k+1)) - S(y_(m-1)^k)) # + sum(j=1,M) s_mj*(F+S)(y^k) self.solver.solve() - self.Unodes1[m].assign(self.U_SDC) + self.Unodes1[m+1].assign(self.U_SDC) # Evaluate source terms for evaluate in self.evaluate_source: - evaluate(self.Unodes1[m], self.base.dt, x_out=self.source_Ukp1[m]) + evaluate(self.Unodes1[m+1], self.base.dt, x_out=self.source_Ukp1[m+1]) # Apply limiter if required if self.limiter is not None: - self.limiter.apply(self.Unodes1[m]) - for m in range(1, self.M+1): + self.limiter.apply(self.Unodes1[m+1]) + + for m in range(self.M): self.Unodes[m].assign(self.Unodes1[m]) self.source_Uk[m].assign(self.source_Ukp1[m]) x_out.assign(self.Unodes[-1]) - -class RIDC(SDC): +class RIDC(IDC): """Class for Revisionist Integral Deferred Correction schemes.""" - def __init__(self, base_scheme, domain, M, maxk, field_name=None, - linear_solver_parameters=None, nonlinear_solver_parameters=None, final_update=True, - limiter=None, options=None, initial_guess="base"): + def __init__(self, base_scheme, domain, M, K, field_name=None, + linear_solver_parameters=None, nonlinear_solver_parameters=None, + limiter=None, options=None): """ Initialise IDC object Args: @@ -770,7 +907,7 @@ def __init__(self, base_scheme, domain, M, maxk, field_name=None, domain (:class:`Domain`): the model's domain object, containing the mesh and the compatible function spaces. M (int): Number of quadrature nodes to compute spectral integration over - maxk (int): Max number of correction interations + K (int): Max number of correction interations field_name (str, optional): name of the field to be evolved. Defaults to None. linear_solver_parameters (dict, optional): dictionary of parameters to @@ -785,12 +922,17 @@ def __init__(self, base_scheme, domain, M, maxk, field_name=None, initial_guess (str, optional): Initial guess to be base timestepper, or copy """ - super(IDC, self).__init__(base_scheme, domain, M, maxk, 'GAUSS', 'EQUID', 'BE', 'FE', - 'N2N', field_name, linear_solver_parameters, nonlinear_solver_parameters, - final_update, limiter, options, initial_guess) + super(RIDC, self).__init__(base_scheme, domain, M, K, + field_name, linear_solver_parameters, nonlinear_solver_parameters, + limiter, options) + self.dt = Constant(float(self.dt_coarse)/(self.M - 1)) + self.nodes = np.arange(0, M*float(self.dt), float(self.dt)) + integration_matrix = self.lagrange_integration_matrix(self.K) + self.Q = 0.5 * (self.K-1) * float(self.dt) * integration_matrix + def setup(self, equation, apply_bcs=True, *active_labels): """ - Set up the IDC time discretisation based on the equation.n + Set up the SDC time discretisation based on the equation.n Args: equation (:class:`PrognosticEquation`): the model's equation. @@ -799,145 +941,18 @@ def setup(self, equation, apply_bcs=True, *active_labels): *active_labels (:class:`Label`): labels indicating which terms of the equation to include. """ - # Inherit from base time discretisation - super(IDC, self).setup(equation, apply_bcs, *active_labels) - self.source_Ukp1_m = Function(self.W) - self.source_Uk_m = Function(self.W) - self.Uk_mp1 = Function(self.W) - self.Uk_m = Function(self.W) - self.Ukp1_m = Function(self.W) - self.dt = Constant(0.0) - - def compute_quad(self, Q, fUnodes, M_val): - """ - Computes integration of F(y) on quadrature nodes - """ - quad.assign(0.) - for k in range(0, self.M): - quad += float(Q[M_val, k])*fUnodes[k] - return quad + super(RIDC, self).setup(equation, apply_bcs, *active_labels) - def compute_quad_final(self, Q, fUnodes, M_val): + def compute_quad_final(self, Q, fUnodes, m): """ Computes final integration of F(y) on quadrature nodes """ + quad = Function(self.W) quad.assign(0.) - for k in range(0, self.M): - quad += float(Q[self.M-1, k])*fUnodes[M_val - self.M + k] + for k in range(0, np.shape(Q)[0]): + quad += float(Q[-1, k])*fUnodes[m + 1 - self.K + k] return quad - @property - def res_rhs(self): - """Set up the residual for the calculation of F(y).""" - a = self.residual.label_map(lambda t: t.has_label(time_derivative), - replace_subject(self.Urhs, old_idx=self.idx), - drop) - # F(y) - L = self.residual.label_map(lambda t: any(t.has_label(time_derivative, source_label)), - drop, - replace_subject(self.Uin, old_idx=self.idx)) - L_source = self.residual.label_map(lambda t: t.has_label(source_label), - replace_subject(self.source_in, old_idx=self.idx), - drop) - residual_rhs = a - (L + L_source) - return residual_rhs.form - - @property - def res(self): - """Set up the discretisation's residual for a given node m.""" - # Add time derivative terms y^(k+1)_m - y_start for node m. y_start is y_n for Z2N formulation - # and y^(k)_m for N2N formulation - mass_form = self.residual.label_map( - lambda t: t.has_label(time_derivative), - map_if_false=drop) - residual = mass_form.label_map(all_terms, - map_if_true=replace_subject(self.U_SDC, old_idx=self.idx)) - residual -= mass_form.label_map(all_terms, - map_if_true=replace_subject(self.U_start, old_idx=self.idx)) - - # Calculate source terms - r_source_kp1 = self.residual.label_map( - lambda t: t.has_label(source_label), - map_if_true=replace_subject(self.source_Ukp1_m, old_idx=self.idx), - map_if_false=drop) - r_source_kp1 = r_source_kp1.label_map( - all_terms, - lambda t: Constant(self.dt)*t) - residual += r_source_kp1 - - r_source_k = self.residual.label_map( - lambda t: t.has_label(source_label), - map_if_true=replace_subject(self.source_Uk_m, old_idx=self.idx), - map_if_false=drop) - r_source_k = r_source_k.label_map( - all_terms, - map_if_true=lambda t: Constant(self.dt)*t) - residual -= r_source_k - - # Add on final implicit terms - # Qdelta_imp[m,m]*(F(y_(m)^(k+1)) - F(y_(m)^k)) - r_imp_kp1 = self.residual.label_map( - lambda t: t.has_label(implicit), - map_if_true=replace_subject(self.U_SDC, old_idx=self.idx), - map_if_false=drop) - r_imp_kp1 = r_imp_kp1.label_map( - all_terms, - lambda t: Constant(self.dt)*t) - residual += r_imp_kp1 - r_imp_k = self.residual.label_map( - lambda t: t.has_label(implicit), - map_if_true=replace_subject(self.Uk_mp1, old_idx=self.idx), - map_if_false=drop) - r_imp_k = r_imp_k.label_map( - all_terms, - lambda t: Constant(self.dt)*t) - residual -= r_imp_k - - r_exp_kp1 = self.residual.label_map( - lambda t: t.has_label(explicit), - map_if_true=replace_subject(self.Ukp1_m, old_idx=self.idx), - map_if_false=drop) - r_exp_kp1 = r_exp_kp1.label_map( - all_terms, - lambda t: Constant(self.dt)*t) - - residual += r_exp_kp1 - r_exp_k = self.residual.label_map( - lambda t: t.has_label(explicit), - map_if_true=replace_subject(self.Uk_m, old_idx=self.idx), - map_if_false=drop) - r_exp_k = r_exp_k.label_map( - all_terms, - lambda t: Constant(self.dt)*t) - residual -= r_exp_k - - - # Add on sum(j=1,M) s_mj*F(y_m^k) for N2N formulation, where s_mj = q_mj-q_m-1j - # and s1j = q1j. - Q = self.residual.label_map(lambda t: t.has_label(time_derivative), - replace_subject(self.Q_, old_idx=self.idx), - drop) - residual += Q - return residual.form - - @cached_property - def solver(self): - """Set up a list of solvers for each problem at a node m.""" - # setup solver using residual defined in derived class - problem = NonlinearVariationalProblem(self.res, self.U_SDC, bcs=self.bcs) - solver_name = self.field_name+self.__class__.__name__ - solver = NonlinearVariationalSolver(problem, solver_parameters=self.nonlinear_solver_parameters, options_prefix=solver_name) - return solver - - @cached_property - def solver_rhs(self): - """Set up the problem and the solver for mass matrix inversion.""" - # setup linear solver using rhs residual defined in derived class - prob_rhs = NonlinearVariationalProblem(self.res_rhs, self.Urhs, bcs=self.bcs) - solver_name = self.field_name+self.__class__.__name__+"_rhs" - return NonlinearVariationalSolver(prob_rhs, solver_parameters=self.linear_solver_parameters, - options_prefix=solver_name) - @wrapper_apply def apply(self, x_out, x_in): self.Un.assign(x_in) @@ -945,48 +960,42 @@ def apply(self, x_out, x_in): # Compute initial guess on quadrature nodes with low-order # base timestepper self.Unodes[0].assign(self.Un) - if (self.base_flag): - for m in range(self.M): - self.base.dt = float(self.dtau[m]) - self.base.apply(self.Unodes[m+1], self.Unodes[m]) - else: - for m in range(self.M): - self.Unodes[m+1].assign(self.Un) - for m in range(self.M+1): + for m in range(self.M-1): + self.base.dt = float(self.dt) + self.base.apply(self.Unodes[m+1], self.Unodes[m]) + for m in range(self.M): for evaluate in self.evaluate_source: evaluate(self.Unodes[m], self.base.dt, x_out=self.source_Uk[m]) # Iterate through correction sweeps - k = 0 - while k < self.maxk: - k += 1 - + for k in range(1, self.K+1): + print("Correction sweep", k) # Compute for N2N: sum(j=1,M) (s_mj*F(y_m^k) + s_mj*S(y_m^k)) - for m in range(1, self.M+1): + for m in range(self.M): self.Uin.assign(self.Unodes[m]) # Include source terms for evaluate in self.evaluate_source: evaluate(self.Uin, self.base.dt, x_out=self.source_in) self.solver_rhs.solve() - self.fUnodes[m-1].assign(self.Urhs) - + self.fUnodes[m].assign(self.Urhs) + #self.compute_quad() # Loop through quadrature nodes and solve self.Unodes1[0].assign(self.Unodes[0]) for evaluate in self.evaluate_source: evaluate(self.Unodes[0], self.base.dt, x_out=self.source_Uk[0]) - for m in range(1, self.maxk+1): + for m in range(0, self.K-1): # Set S matrix - self.Q_.assign(self.compute_quad(self, self.Q, self.fUnodes, m-1)) + self.Q_.assign(self.compute_quad(self.Q, self.fUnodes, m+1)) + #self.Q_.assign(self.quad[m-1]) # Set initial guess for solver, and pick correct solver - self.U_start.assign(self.Unodes1[m-1]) - self.Ukp1_m.assign(self.Unodes1[m-1]) - self.Uk_mp1.assign(self.Unodes[m]) - self.Uk_m.assign(self.Unodes[m-1]) - self.source_Ukp1_m.assign(self.source_Ukp1[m-1]) - self.source_Uk_m.assign(self.source_Uk[m-1]) - self.dt.assign(float(self.dtau[m-1])) - self.U_SDC.assign(self.Unodes[m]) + self.U_start.assign(self.Unodes1[m]) + self.Ukp1_m.assign(self.Unodes1[m]) + self.Uk_mp1.assign(self.Unodes[m+1]) + self.Uk_m.assign(self.Unodes[m]) + self.source_Ukp1_m.assign(self.source_Ukp1[m]) + self.source_Uk_m.assign(self.source_Uk[m]) + self.U_SDC.assign(self.Unodes[m+1]) # Compute # for N2N: @@ -994,28 +1003,27 @@ def apply(self, x_out, x_in): # + S(y_(m-1)^(k+1)) - S(y_(m-1)^k)) # + sum(j=1,M) s_mj*(F+S)(y^k) self.solver.solve() - self.Unodes1[m].assign(self.U_SDC) + self.Unodes1[m+1].assign(self.U_SDC) # Evaluate source terms for evaluate in self.evaluate_source: - evaluate(self.Unodes1[m], self.base.dt, x_out=self.source_Ukp1[m]) + evaluate(self.Unodes1[m+1], self.base.dt, x_out=self.source_Ukp1[m+1]) # Apply limiter if required if self.limiter is not None: - self.limiter.apply(self.Unodes1[m]) - for m in range(self.maxk, self.M+1): + self.limiter.apply(self.Unodes1[m+1]) + for m in range(self.K-1, self.M-1): # Set S matrix - self.Q_.assign(self.compute_quad_final(self, self.Q, self.fUnodes, m-1)) + self.Q_.assign(self.compute_quad_final(self.Q, self.fUnodes, m+1)) # Set initial guess for solver, and pick correct solver - self.U_start.assign(self.Unodes1[m-1]) - self.Ukp1_m.assign(self.Unodes1[m-1]) - self.Uk_mp1.assign(self.Unodes[m]) - self.Uk_m.assign(self.Unodes[m-1]) - self.source_Ukp1_m.assign(self.source_Ukp1[m-1]) - self.source_Uk_m.assign(self.source_Uk[m-1]) - self.dt.assign(float(self.dtau[m-1])) - self.U_SDC.assign(self.Unodes[m]) + self.U_start.assign(self.Unodes1[m]) + self.Ukp1_m.assign(self.Unodes1[m]) + self.Uk_mp1.assign(self.Unodes[m+1]) + self.Uk_m.assign(self.Unodes[m]) + self.source_Ukp1_m.assign(self.source_Ukp1[m]) + self.source_Uk_m.assign(self.source_Uk[m]) + self.U_SDC.assign(self.Unodes[m+1]) # Compute # for N2N: @@ -1023,17 +1031,17 @@ def apply(self, x_out, x_in): # + S(y_(m-1)^(k+1)) - S(y_(m-1)^k)) # + sum(j=1,M) s_mj*(F+S)(y^k) self.solver.solve() - self.Unodes1[m].assign(self.U_SDC) + self.Unodes1[m+1].assign(self.U_SDC) # Evaluate source terms for evaluate in self.evaluate_source: - evaluate(self.Unodes1[m], self.base.dt, x_out=self.source_Ukp1[m]) + evaluate(self.Unodes1[m+1], self.base.dt, x_out=self.source_Ukp1[m+1]) # Apply limiter if required if self.limiter is not None: - self.limiter.apply(self.Unodes1[m]) + self.limiter.apply(self.Unodes1[m+1]) - for m in range(1, self.M+1): + for m in range(self.M): self.Unodes[m].assign(self.Unodes1[m]) self.source_Uk[m].assign(self.source_Ukp1[m]) diff --git a/gusto/time_discretisation/time_discretisation.py b/gusto/time_discretisation/time_discretisation.py index 7b5eee9a6..21a60fade 100644 --- a/gusto/time_discretisation/time_discretisation.py +++ b/gusto/time_discretisation/time_discretisation.py @@ -22,6 +22,8 @@ from gusto.core.logging import logger, DEBUG, logging_ksp_monitor_true_residual from gusto.time_discretisation.wrappers import * from gusto.solvers import mass_parameters +import numpy as np + __all__ = ["TimeDiscretisation", "ExplicitTimeDiscretisation", "BackwardEuler", "ThetaMethod", "TrapeziumRule", "TR_BDF2"] diff --git a/gusto/timestepping/timestepper.py b/gusto/timestepping/timestepper.py index f4f5bba0a..de99fbe99 100644 --- a/gusto/timestepping/timestepper.py +++ b/gusto/timestepping/timestepper.py @@ -351,6 +351,8 @@ def setup_scheme(self): self.setup_equation(self.equation) self.scheme.setup(self.equation) self.setup_transporting_velocity(self.scheme) + if hasattr(self.scheme, 'base'): + self.setup_transporting_velocity(self.scheme.base) if self.io.output.log_courant: self.scheme.courant_max = self.io.courant_max From e5b48d0ee424559e381e926091a00c8f2d84c8be Mon Sep 17 00:00:00 2001 From: atb1995 Date: Thu, 27 Mar 2025 11:18:43 +0000 Subject: [PATCH 03/43] RIDC and RIDC_Reduced working --- gusto/time_discretisation/sdc.py | 1108 ++++++++++++++++++++++++++---- 1 file changed, 990 insertions(+), 118 deletions(-) diff --git a/gusto/time_discretisation/sdc.py b/gusto/time_discretisation/sdc.py index dbe888918..5ee2b1bb7 100644 --- a/gusto/time_discretisation/sdc.py +++ b/gusto/time_discretisation/sdc.py @@ -57,7 +57,7 @@ from qmat import genQCoeffs, genQDeltaCoeffs -__all__ = ["SDC", "IDC", "RIDC"] +__all__ = ["SDC", "IDC", "RIDC", "RIDC_Reduced"] class SDC(object, metaclass=ABCMeta): @@ -529,7 +529,7 @@ def __init__(self, base_scheme, domain, M, K, field_name=None, quadrature nodes. domain (:class:`Domain`): the model's domain object, containing the mesh and the compatible function spaces. - M (int): Number of quadrature nodes to compute spectral integration over + M (int): Number of subintervals K (int): Max number of correction interations field_name (str, optional): name of the field to be evolved. Defaults to None. @@ -554,10 +554,12 @@ def __init__(self, base_scheme, domain, M, K, field_name=None, self.wrapper = self.base.wrapper self.K = K self.M = M - self.dt = Constant(float(self.dt_coarse)/(self.M - 1)) - self.nodes = np.arange(0, M*float(self.dt), float(self.dt)) - integration_matrix = self.lagrange_integration_matrix(self.M) - self.Q = 0.5 * (self.M-1) * float(self.dt) * integration_matrix + self.dt = Constant(float(self.dt_coarse)/(self.M)) + self.nodes = np.arange(0, (M+1)*float(self.dt), float(self.dt)) + integration_matrix = self.lagrange_integration_matrix(self.M+1) + + # Rescale nodes from [0, 1] to [0, dt] + self.Q = float(self.dt_coarse) * integration_matrix # Set default linear and nonlinear solver options if none passed in if linear_solver_parameters is None: @@ -612,12 +614,12 @@ def setup(self, equation, apply_bcs=True, *active_labels): W = equation.function_space self.idx = None self.W = W - self.Unodes = [Function(W) for _ in range(self.M)] - self.Unodes1 = [Function(W) for _ in range(self.M)] - self.fUnodes = [Function(W) for _ in range(self.M)] - self.quad = [Function(W) for _ in range(self.M)] - self.source_Uk = [Function(W) for _ in range(self.M)] - self.source_Ukp1 = [Function(W) for _ in range(self.M)] + self.Unodes = [Function(W) for _ in range(self.M+1)] + self.Unodes1 = [Function(W) for _ in range(self.M+1)] + self.fUnodes = [Function(W) for _ in range(self.M+1)] + self.quad = [Function(W) for _ in range(self.M+1)] + self.source_Uk = [Function(W) for _ in range(self.M+1)] + self.source_Ukp1 = [Function(W) for _ in range(self.M+1)] self.U_SDC = Function(W) self.U_start = Function(W) self.Un = Function(W) @@ -691,6 +693,7 @@ def lagrange_integration_matrix(self, M): # set up equidistant nodes and initialise matrix to zero nodes = self.equidistant_nodes(M) + nodes = 0.5 * (nodes + 1) L = len(nodes) int_matrix = np.zeros((L, L)) @@ -833,12 +836,12 @@ def apply(self, x_out, x_in): # base timestepper self.Unodes[0].assign(self.Un) - for m in range(self.M-1): + for m in range(self.M): self.base.dt = float(self.dt) print(self.base.dt) self.base.apply(self.Unodes[m+1], self.Unodes[m]) #breakpoint() - for m in range(self.M): + for m in range(self.M+1): for evaluate in self.evaluate_source: evaluate(self.Unodes[m], self.base.dt, x_out=self.source_Uk[m]) @@ -846,7 +849,7 @@ def apply(self, x_out, x_in): for k in range(1, self.K+1): print("Correction sweep", k) # Compute for N2N: sum(j=1,M) (s_mj*F(y_m^k) + s_mj*S(y_m^k)) - for m in range(self.M): + for m in range(self.M+1): self.Uin.assign(self.Unodes[m]) # Include source terms for evaluate in self.evaluate_source: @@ -858,7 +861,7 @@ def apply(self, x_out, x_in): self.Unodes1[0].assign(self.Unodes[0]) for evaluate in self.evaluate_source: evaluate(self.Unodes[0], self.base.dt, x_out=self.source_Uk[0]) - for m in range(0, self.M-1): + for m in range(0, self.M): # Set S matrix self.Q_.assign(self.compute_quad(self.Q, self.fUnodes, m+1)) @@ -887,14 +890,14 @@ def apply(self, x_out, x_in): if self.limiter is not None: self.limiter.apply(self.Unodes1[m+1]) - for m in range(self.M): + for m in range(self.M+1): self.Unodes[m].assign(self.Unodes1[m]) self.source_Uk[m].assign(self.source_Ukp1[m]) x_out.assign(self.Unodes[-1]) -class RIDC(IDC): - """Class for Revisionist Integral Deferred Correction schemes.""" +class RIDC(object, metaclass=ABCMeta): + """Class for Integral Deferred Correction schemes.""" def __init__(self, base_scheme, domain, M, K, field_name=None, linear_solver_parameters=None, nonlinear_solver_parameters=None, @@ -906,7 +909,7 @@ def __init__(self, base_scheme, domain, M, K, field_name=None, quadrature nodes. domain (:class:`Domain`): the model's domain object, containing the mesh and the compatible function spaces. - M (int): Number of quadrature nodes to compute spectral integration over + M (int): Number of subintervals K (int): Max number of correction interations field_name (str, optional): name of the field to be evolved. Defaults to None. @@ -922,13 +925,45 @@ def __init__(self, base_scheme, domain, M, K, field_name=None, initial_guess (str, optional): Initial guess to be base timestepper, or copy """ - super(RIDC, self).__init__(base_scheme, domain, M, K, - field_name, linear_solver_parameters, nonlinear_solver_parameters, - limiter, options) - self.dt = Constant(float(self.dt_coarse)/(self.M - 1)) - self.nodes = np.arange(0, M*float(self.dt), float(self.dt)) - integration_matrix = self.lagrange_integration_matrix(self.K) - self.Q = 0.5 * (self.K-1) * float(self.dt) * integration_matrix + self.base = base_scheme + self.field_name = field_name + self.domain = domain + self.dt_coarse = domain.dt + self.limiter = limiter + self.augmentation = self.base.augmentation + self.wrapper = self.base.wrapper + self.K = K + self.M = M + self.dt = Constant(float(self.dt_coarse)/(self.M)) + self.nodes = np.arange(0, (M+1)*float(self.dt), float(self.dt)) + integration_matrix = self.lagrange_integration_matrix(self.K+1) + # Rescale integration matrix to be over [0, self.dt_coarse] rather than [-1, 1] + self.Q = 0.5 * (self.K) * float(self.dt) * integration_matrix + + # integration_matrix = self.lagrange_integration_matrix(self.K) + # # Rescale integration matrix to be over [0, self.dt_coarse] rather than [-1, 1] + # self.Q = 0.5 * (self.K-1) * float(self.dt) * integration_matrix + + + #self.Q = float(self.dt_coarse) * integration_matrix + + # Set default linear and nonlinear solver options if none passed in + if linear_solver_parameters is None: + self.linear_solver_parameters = {'snes_type': 'ksponly', + 'ksp_type': 'cg', + 'pc_type': 'bjacobi', + 'sub_pc_type': 'ilu'} + else: + self.linear_solver_parameters = linear_solver_parameters + + if nonlinear_solver_parameters is None: + self.nonlinear_solver_parameters = {'snes_type': 'newtonls', + 'ksp_type': 'gmres', + 'pc_type': 'bjacobi', + 'sub_pc_type': 'ilu'} + else: + self.nonlinear_solver_parameters = nonlinear_solver_parameters + def setup(self, equation, apply_bcs=True, *active_labels): """ @@ -941,108 +976,945 @@ def setup(self, equation, apply_bcs=True, *active_labels): *active_labels (:class:`Label`): labels indicating which terms of the equation to include. """ - super(RIDC, self).setup(equation, apply_bcs, *active_labels) + # Inherit from base time discretisation + self.base.setup(equation, apply_bcs, *active_labels) + self.equation = self.base.equation + self.residual = self.base.residual + self.evaluate_source = self.base.evaluate_source - def compute_quad_final(self, Q, fUnodes, m): - """ - Computes final integration of F(y) on quadrature nodes - """ - quad = Function(self.W) - quad.assign(0.) - for k in range(0, np.shape(Q)[0]): - quad += float(Q[-1, k])*fUnodes[m + 1 - self.K + k] - return quad + for t in self.residual: + # Check all terms are labeled implicit or explicit + if ((not t.has_label(implicit)) and (not t.has_label(explicit)) + and (not t.has_label(time_derivative)) and (not t.has_label(source_label))): + raise NotImplementedError("Non time-derivative or source terms must be labeled as implicit or explicit") - @wrapper_apply - def apply(self, x_out, x_in): - self.Un.assign(x_in) + # Set up bcs + self.bcs = self.base.bcs - # Compute initial guess on quadrature nodes with low-order - # base timestepper - self.Unodes[0].assign(self.Un) - for m in range(self.M-1): - self.base.dt = float(self.dt) - self.base.apply(self.Unodes[m+1], self.Unodes[m]) - for m in range(self.M): - for evaluate in self.evaluate_source: - evaluate(self.Unodes[m], self.base.dt, x_out=self.source_Uk[m]) + # Set up SDC variables + if self.field_name is not None and hasattr(equation, "field_names"): + self.idx = equation.field_names.index(self.field_name) + W = equation.spaces[self.idx] + else: + self.field_name = equation.field_name + W = equation.function_space + self.idx = None + self.W = W + self.Unodes = [Function(W) for _ in range(self.M+1)] + self.Unodes1 = [Function(W) for _ in range(self.M+1)] + self.fUnodes = [Function(W) for _ in range(self.M+1)] + self.quad = [Function(W) for _ in range(self.M+1)] + self.source_Uk = [Function(W) for _ in range(self.M+1)] + self.source_Ukp1 = [Function(W) for _ in range(self.M+1)] + self.U_SDC = Function(W) + self.U_start = Function(W) + self.Un = Function(W) + self.Q_ = Function(W) + self.quad_final = Function(W) + self.U_fin = Function(W) + self.Urhs = Function(W) + self.Uin = Function(W) + self.source_in = Function(W) + self.source_Ukp1_m = Function(W) + self.source_Uk_m = Function(W) + self.Uk_mp1 = Function(W) + self.Uk_m = Function(W) + self.Ukp1_m = Function(W) - # Iterate through correction sweeps - for k in range(1, self.K+1): - print("Correction sweep", k) - # Compute for N2N: sum(j=1,M) (s_mj*F(y_m^k) + s_mj*S(y_m^k)) - for m in range(self.M): - self.Uin.assign(self.Unodes[m]) - # Include source terms - for evaluate in self.evaluate_source: - evaluate(self.Uin, self.base.dt, x_out=self.source_in) - self.solver_rhs.solve() - self.fUnodes[m].assign(self.Urhs) - #self.compute_quad() - # Loop through quadrature nodes and solve - self.Unodes1[0].assign(self.Unodes[0]) - for evaluate in self.evaluate_source: - evaluate(self.Unodes[0], self.base.dt, x_out=self.source_Uk[0]) - for m in range(0, self.K-1): - # Set S matrix - self.Q_.assign(self.compute_quad(self.Q, self.fUnodes, m+1)) - #self.Q_.assign(self.quad[m-1]) + @property + def nlevels(self): + return 1 - # Set initial guess for solver, and pick correct solver - self.U_start.assign(self.Unodes1[m]) - self.Ukp1_m.assign(self.Unodes1[m]) - self.Uk_mp1.assign(self.Unodes[m+1]) - self.Uk_m.assign(self.Unodes[m]) - self.source_Ukp1_m.assign(self.source_Ukp1[m]) - self.source_Uk_m.assign(self.source_Uk[m]) - self.U_SDC.assign(self.Unodes[m+1]) + def equidistant_nodes(self ,M): + # This returns a grid of M equispaced nodes from -1 to 1 + grid = np.linspace(-1., 1., M) + #grid = 0.5 * (grid + 1) + return grid - # Compute - # for N2N: - # y_m^(k+1) = y_(m-1)^(k+1) + dtau_m*(F(y_(m)^(k+1)) - F(y_(m)^k) - # + S(y_(m-1)^(k+1)) - S(y_(m-1)^k)) - # + sum(j=1,M) s_mj*(F+S)(y^k) - self.solver.solve() - self.Unodes1[m+1].assign(self.U_SDC) + def lagrange_polynomial(self, index, nodes): + # This returns the coefficients of the Lagrange polynomial l_m with m=index - # Evaluate source terms - for evaluate in self.evaluate_source: - evaluate(self.Unodes1[m+1], self.base.dt, x_out=self.source_Ukp1[m+1]) + M = len(nodes) - # Apply limiter if required - if self.limiter is not None: - self.limiter.apply(self.Unodes1[m+1]) - for m in range(self.K-1, self.M-1): - # Set S matrix - self.Q_.assign(self.compute_quad_final(self.Q, self.fUnodes, m+1)) + # c is the denominator + c = 1. + for k in range(M): + if k != index: + c *= (nodes[index] - nodes[k]) - # Set initial guess for solver, and pick correct solver - self.U_start.assign(self.Unodes1[m]) - self.Ukp1_m.assign(self.Unodes1[m]) - self.Uk_mp1.assign(self.Unodes[m+1]) - self.Uk_m.assign(self.Unodes[m]) - self.source_Ukp1_m.assign(self.source_Ukp1[m]) - self.source_Uk_m.assign(self.source_Uk[m]) - self.U_SDC.assign(self.Unodes[m+1]) + coeffs = np.zeros(M) + coeffs[0] = 1. + m = 0 - # Compute - # for N2N: - # y_m^(k+1) = y_(m-1)^(k+1) + dtau_m*(F(y_(m)^(k+1)) - F(y_(m)^k) - # + S(y_(m-1)^(k+1)) - S(y_(m-1)^k)) - # + sum(j=1,M) s_mj*(F+S)(y^k) - self.solver.solve() - self.Unodes1[m+1].assign(self.U_SDC) + for k in range(M): + if k != index: + m += 1 + d1 = np.zeros(M) + d2 = np.zeros(M) - # Evaluate source terms - for evaluate in self.evaluate_source: - evaluate(self.Unodes1[m+1], self.base.dt, x_out=self.source_Ukp1[m+1]) + d1 = (-1.)*nodes[k] * coeffs + d2[1:m+1] = coeffs[0:m] - # Apply limiter if required - if self.limiter is not None: - self.limiter.apply(self.Unodes1[m+1]) + coeffs = d1+d2 + return coeffs / c - for m in range(self.M): - self.Unodes[m].assign(self.Unodes1[m]) - self.source_Uk[m].assign(self.source_Ukp1[m]) + def integrate_polynomial(self, p): + # given a list of coefficients of a polynomial p, this returns those of the integral of p + integral_coeffs = np.zeros(len(p)+1) + + for n, pn in enumerate(p): + integral_coeffs[n+1] = 1/(n+1) * pn + + return integral_coeffs + + def evaluate(self, p, a, b): + # given a list of coefficients of a polynomial p, this returns the value of p(b)-p(a) + value = 0. + for n, pn in enumerate(p): + value += pn * (b**n - a**n) + + return value + + def lagrange_integration_matrix(self, M): + # using the functions defined above, this returns the MxM integration matrix + + # set up equidistant nodes and initialise matrix to zero + nodes = self.equidistant_nodes(M) + L = len(nodes) + int_matrix = np.zeros((L, L)) + + # fill in matrix values + for index in range(L): + coeff_p = self.lagrange_polynomial(index, nodes) + int_coeff = self.integrate_polynomial(coeff_p) + + for n in range(L-1): + int_matrix[n+1, index] = self.evaluate(int_coeff, nodes[n], nodes[n+1]) + + return int_matrix + + def compute_quad(self, Q, fUnodes, m): + """ + Computes integration of F(y) on quadrature nodes + """ + quad = Function(self.W) + quad.assign(0.) + for k in range(0, self.K+1): + quad += float(Q[m, k])*fUnodes[k] + return quad + + def compute_quad_final(self, Q, fUnodes, m): + """ + Computes final integration of F(y) on quadrature nodes + """ + quad = Function(self.W) + quad.assign(0.) + for k in range(0, self.K+1): + print(k, self.K) + quad += float(Q[-1, k])*fUnodes[m - self.K + k] + return quad + + @property + def res_rhs(self): + """Set up the residual for the calculation of F(y).""" + a = self.residual.label_map(lambda t: t.has_label(time_derivative), + replace_subject(self.Urhs, old_idx=self.idx), + drop) + # F(y) + L = self.residual.label_map(lambda t: any(t.has_label(time_derivative, source_label)), + drop, + replace_subject(self.Uin, old_idx=self.idx)) + L_source = self.residual.label_map(lambda t: t.has_label(source_label), + replace_subject(self.source_in, old_idx=self.idx), + drop) + residual_rhs = a - (L + L_source) + return residual_rhs.form + + @property + def res(self): + """Set up the discretisation's residual for a given node m.""" + # Add time derivative terms y^(k+1)_m - y_start for node m. y_start is y_n for Z2N formulation + # and y^(k)_m for N2N formulation + mass_form = self.residual.label_map( + lambda t: t.has_label(time_derivative), + map_if_false=drop) + residual = mass_form.label_map(all_terms, + map_if_true=replace_subject(self.U_SDC, old_idx=self.idx)) + residual -= mass_form.label_map(all_terms, + map_if_true=replace_subject(self.U_start, old_idx=self.idx)) + + # Calculate source terms + r_source_kp1 = self.residual.label_map( + lambda t: t.has_label(source_label), + map_if_true=replace_subject(self.source_Ukp1_m, old_idx=self.idx), + map_if_false=drop) + r_source_kp1 = r_source_kp1.label_map( + all_terms, + lambda t: Constant(self.dt)*t) + residual += r_source_kp1 + + r_source_k = self.residual.label_map( + lambda t: t.has_label(source_label), + map_if_true=replace_subject(self.source_Uk_m, old_idx=self.idx), + map_if_false=drop) + r_source_k = r_source_k.label_map( + all_terms, + map_if_true=lambda t: Constant(self.dt)*t) + residual -= r_source_k + + # Add on final implicit terms + # Qdelta_imp[m,m]*(F(y_(m)^(k+1)) - F(y_(m)^k)) + r_imp_kp1 = self.residual.label_map( + lambda t: t.has_label(implicit), + map_if_true=replace_subject(self.U_SDC, old_idx=self.idx), + map_if_false=drop) + r_imp_kp1 = r_imp_kp1.label_map( + all_terms, + lambda t: Constant(self.dt)*t) + residual += r_imp_kp1 + r_imp_k = self.residual.label_map( + lambda t: t.has_label(implicit), + map_if_true=replace_subject(self.Uk_mp1, old_idx=self.idx), + map_if_false=drop) + r_imp_k = r_imp_k.label_map( + all_terms, + lambda t: Constant(self.dt)*t) + residual -= r_imp_k + + r_exp_kp1 = self.residual.label_map( + lambda t: t.has_label(explicit), + map_if_true=replace_subject(self.Ukp1_m, old_idx=self.idx), + map_if_false=drop) + r_exp_kp1 = r_exp_kp1.label_map( + all_terms, + lambda t: Constant(self.dt)*t) + residual += r_exp_kp1 + r_exp_k = self.residual.label_map( + lambda t: t.has_label(explicit), + map_if_true=replace_subject(self.Uk_m, old_idx=self.idx), + map_if_false=drop) + r_exp_k = r_exp_k.label_map( + all_terms, + lambda t: Constant(self.dt)*t) + residual -= r_exp_k + + + # Add on sum(j=1,M) s_mj*F(y_m^k) for N2N formulation, where s_mj = q_mj-q_m-1j + # and s1j = q1j. + Q = self.residual.label_map(lambda t: t.has_label(time_derivative), + replace_subject(self.Q_, old_idx=self.idx), + drop) + residual += Q + return residual.form + + @cached_property + def solver(self): + """Set up a list of solvers for each problem at a node m.""" + # setup solver using residual defined in derived class + problem = NonlinearVariationalProblem(self.res, self.U_SDC, bcs=self.bcs) + solver_name = self.field_name+self.__class__.__name__ + solver = NonlinearVariationalSolver(problem, solver_parameters=self.nonlinear_solver_parameters, options_prefix=solver_name) + return solver + + @cached_property + def solver_rhs(self): + """Set up the problem and the solver for mass matrix inversion.""" + # setup linear solver using rhs residual defined in derived class + prob_rhs = NonlinearVariationalProblem(self.res_rhs, self.Urhs, bcs=self.bcs) + solver_name = self.field_name+self.__class__.__name__+"_rhs" + return NonlinearVariationalSolver(prob_rhs, solver_parameters=self.linear_solver_parameters, + options_prefix=solver_name) + + @wrapper_apply + def apply(self, x_out, x_in): + self.Un.assign(x_in) + + # Compute initial guess on quadrature nodes with low-order + # base timestepper + self.Unodes[0].assign(self.Un) + + for m in range(self.M): + self.base.dt = float(self.dt) + print(self.base.dt) + self.base.apply(self.Unodes[m+1], self.Unodes[m]) + #breakpoint() + for m in range(self.M+1): + for evaluate in self.evaluate_source: + evaluate(self.Unodes[m], self.base.dt, x_out=self.source_Uk[m]) + + # Iterate through correction sweeps + for k in range(1, self.K+1): + print("Correction sweep", k) + # Compute for N2N: sum(j=1,M) (s_mj*F(y_m^k) + s_mj*S(y_m^k)) + for m in range(self.M+1): + self.Uin.assign(self.Unodes[m]) + # Include source terms + for evaluate in self.evaluate_source: + evaluate(self.Uin, self.base.dt, x_out=self.source_in) + self.solver_rhs.solve() + self.fUnodes[m].assign(self.Urhs) + + # Loop through quadrature nodes and solve + self.Unodes1[0].assign(self.Unodes[0]) + for evaluate in self.evaluate_source: + evaluate(self.Unodes[0], self.base.dt, x_out=self.source_Uk[0]) + for m in range(0, self.K): + # Set S matrix + self.Q_.assign(self.compute_quad(self.Q, self.fUnodes, m+1)) + + # Set initial guess for solver, and pick correct solver + self.U_start.assign(self.Unodes1[m]) + self.Ukp1_m.assign(self.Unodes1[m]) + self.Uk_mp1.assign(self.Unodes[m+1]) + self.Uk_m.assign(self.Unodes[m]) + self.source_Ukp1_m.assign(self.source_Ukp1[m]) + self.source_Uk_m.assign(self.source_Uk[m]) + self.U_SDC.assign(self.Unodes[m+1]) + + # Compute + # for N2N: + # y_m^(k+1) = y_(m-1)^(k+1) + dtau_m*(F(y_(m)^(k+1)) - F(y_(m)^k) + # + S(y_(m-1)^(k+1)) - S(y_(m-1)^k)) + # + sum(j=1,M) s_mj*(F+S)(y^k) + self.solver.solve() + self.Unodes1[m+1].assign(self.U_SDC) + + # Evaluate source terms + for evaluate in self.evaluate_source: + evaluate(self.Unodes1[m+1], self.base.dt, x_out=self.source_Ukp1[m+1]) + + # Apply limiter if required + if self.limiter is not None: + self.limiter.apply(self.Unodes1[m+1]) + for m in range(self.K, self.M): + # Set S matrix + self.Q_.assign(self.compute_quad_final(self.Q, self.fUnodes, m+1)) + + # Set initial guess for solver, and pick correct solver + self.U_start.assign(self.Unodes1[m]) + self.Ukp1_m.assign(self.Unodes1[m]) + self.Uk_mp1.assign(self.Unodes[m+1]) + self.Uk_m.assign(self.Unodes[m]) + self.source_Ukp1_m.assign(self.source_Ukp1[m]) + self.source_Uk_m.assign(self.source_Uk[m]) + self.U_SDC.assign(self.Unodes[m+1]) + + # Compute + # for N2N: + # y_m^(k+1) = y_(m-1)^(k+1) + dtau_m*(F(y_(m)^(k+1)) - F(y_(m)^k) + # + S(y_(m-1)^(k+1)) - S(y_(m-1)^k)) + # + sum(j=1,M) s_mj*(F+S)(y^k) + self.solver.solve() + self.Unodes1[m+1].assign(self.U_SDC) + + # Evaluate source terms + for evaluate in self.evaluate_source: + evaluate(self.Unodes1[m+1], self.base.dt, x_out=self.source_Ukp1[m+1]) + + # Apply limiter if required + if self.limiter is not None: + self.limiter.apply(self.Unodes1[m+1]) + + for m in range(self.M+1): + self.Unodes[m].assign(self.Unodes1[m]) + self.source_Uk[m].assign(self.source_Ukp1[m]) + + x_out.assign(self.Unodes[-1]) + +class RIDC_Reduced(object, metaclass=ABCMeta): + """Class for Integral Deferred Correction schemes.""" + + def __init__(self, base_scheme, domain, M, K, field_name=None, + linear_solver_parameters=None, nonlinear_solver_parameters=None, + limiter=None, options=None): + """ + Initialise IDC object + Args: + base_scheme (:class:`TimeDiscretisation`): Base time stepping scheme to get first guess of solution on + quadrature nodes. + domain (:class:`Domain`): the model's domain object, containing the + mesh and the compatible function spaces. + M (int): Number of subintervals + K (int): Max number of correction interations + field_name (str, optional): name of the field to be evolved. + Defaults to None. + linear_solver_parameters (dict, optional): dictionary of parameters to + pass to the underlying linear solver. Defaults to None. + nonlinear_solver_parameters (dict, optional): dictionary of parameters to + pass to the underlying nonlinear solver. Defaults to None. + final_update (bool, optional): Whether to compute final update, or just take last + quadrature value. Defaults to True + limiter (:class:`Limiter` object, optional): a limiter to apply to + the evolving field to enforce monotonicity. Defaults to None. + options (:class:`AdvectionOptions`, optional): an object containing + + initial_guess (str, optional): Initial guess to be base timestepper, or copy + """ + self.base = base_scheme + self.field_name = field_name + self.domain = domain + self.dt_coarse = domain.dt + self.limiter = limiter + self.augmentation = self.base.augmentation + self.wrapper = self.base.wrapper + self.K = K + self.M = M + self.dt = Constant(float(self.dt_coarse)/(self.M)) + self.nodes = np.arange(0, (M+1)*float(self.dt), float(self.dt)) + self.Q = [] + for l in range(1, self.K+1): + integration_matrix = self.lagrange_integration_matrix(l+1) + integration_matrix = 0.5 * (l) * float(self.dt) * integration_matrix + self.Q.append(integration_matrix) + print(integration_matrix) + print(l) + print(self.K) + # Rescale integration matrix to be over [0, self.dt_coarse] rather than [-1, 1] + # self.Q = 0.5 * (self.K) * float(self.dt) * integration_matrix + + # integration_matrix = self.lagrange_integration_matrix(self.K) + # # Rescale integration matrix to be over [0, self.dt_coarse] rather than [-1, 1] + # self.Q = 0.5 * (self.K-1) * float(self.dt) * integration_matrix + + + #self.Q = float(self.dt_coarse) * integration_matrix + + # Set default linear and nonlinear solver options if none passed in + if linear_solver_parameters is None: + self.linear_solver_parameters = {'snes_type': 'ksponly', + 'ksp_type': 'cg', + 'pc_type': 'bjacobi', + 'sub_pc_type': 'ilu'} + else: + self.linear_solver_parameters = linear_solver_parameters + + if nonlinear_solver_parameters is None: + self.nonlinear_solver_parameters = {'snes_type': 'newtonls', + 'ksp_type': 'gmres', + 'pc_type': 'bjacobi', + 'sub_pc_type': 'ilu'} + else: + self.nonlinear_solver_parameters = nonlinear_solver_parameters + + + def setup(self, equation, apply_bcs=True, *active_labels): + """ + Set up the SDC time discretisation based on the equation.n + + Args: + equation (:class:`PrognosticEquation`): the model's equation. + apply_bcs (bool, optional): whether to apply the equation's boundary + conditions. Defaults to True. + *active_labels (:class:`Label`): labels indicating which terms of + the equation to include. + """ + # Inherit from base time discretisation + self.base.setup(equation, apply_bcs, *active_labels) + self.equation = self.base.equation + self.residual = self.base.residual + self.evaluate_source = self.base.evaluate_source + + for t in self.residual: + # Check all terms are labeled implicit or explicit + if ((not t.has_label(implicit)) and (not t.has_label(explicit)) + and (not t.has_label(time_derivative)) and (not t.has_label(source_label))): + raise NotImplementedError("Non time-derivative or source terms must be labeled as implicit or explicit") + + # Set up bcs + self.bcs = self.base.bcs + + # Set up SDC variables + if self.field_name is not None and hasattr(equation, "field_names"): + self.idx = equation.field_names.index(self.field_name) + W = equation.spaces[self.idx] + else: + self.field_name = equation.field_name + W = equation.function_space + self.idx = None + self.W = W + self.Unodes = [Function(W) for _ in range(self.M+1)] + self.Unodes1 = [Function(W) for _ in range(self.M+1)] + self.fUnodes = [Function(W) for _ in range(self.M+1)] + self.quad = [Function(W) for _ in range(self.M+1)] + self.source_Uk = [Function(W) for _ in range(self.M+1)] + self.source_Ukp1 = [Function(W) for _ in range(self.M+1)] + self.U_SDC = Function(W) + self.U_start = Function(W) + self.Un = Function(W) + self.Q_ = Function(W) + self.quad_final = Function(W) + self.U_fin = Function(W) + self.Urhs = Function(W) + self.Uin = Function(W) + self.source_in = Function(W) + self.source_Ukp1_m = Function(W) + self.source_Uk_m = Function(W) + self.Uk_mp1 = Function(W) + self.Uk_m = Function(W) + self.Ukp1_m = Function(W) + + @property + def nlevels(self): + return 1 + + def equidistant_nodes(self ,M): + # This returns a grid of M equispaced nodes from -1 to 1 + grid = np.linspace(-1., 1., M) + #grid = 0.5 * (grid + 1) + return grid + + def lagrange_polynomial(self, index, nodes): + # This returns the coefficients of the Lagrange polynomial l_m with m=index + + M = len(nodes) + + # c is the denominator + c = 1. + for k in range(M): + if k != index: + c *= (nodes[index] - nodes[k]) + + coeffs = np.zeros(M) + coeffs[0] = 1. + m = 0 + + for k in range(M): + if k != index: + m += 1 + d1 = np.zeros(M) + d2 = np.zeros(M) + + d1 = (-1.)*nodes[k] * coeffs + d2[1:m+1] = coeffs[0:m] + + coeffs = d1+d2 + return coeffs / c + + def integrate_polynomial(self, p): + # given a list of coefficients of a polynomial p, this returns those of the integral of p + integral_coeffs = np.zeros(len(p)+1) + + for n, pn in enumerate(p): + integral_coeffs[n+1] = 1/(n+1) * pn + + return integral_coeffs + + def evaluate(self, p, a, b): + # given a list of coefficients of a polynomial p, this returns the value of p(b)-p(a) + value = 0. + for n, pn in enumerate(p): + value += pn * (b**n - a**n) + + return value + + def lagrange_integration_matrix(self, M): + # using the functions defined above, this returns the MxM integration matrix + + # set up equidistant nodes and initialise matrix to zero + nodes = self.equidistant_nodes(M) + L = len(nodes) + int_matrix = np.zeros((L, L)) + + # fill in matrix values + for index in range(L): + coeff_p = self.lagrange_polynomial(index, nodes) + int_coeff = self.integrate_polynomial(coeff_p) + + for n in range(L-1): + int_matrix[n+1, index] = self.evaluate(int_coeff, nodes[n], nodes[n+1]) + + return int_matrix + + def compute_quad(self, Q, fUnodes, m): + """ + Computes integration of F(y) on quadrature nodes + """ + quad = Function(self.W) + quad.assign(0.) + for k in range(0, np.shape(Q)[1]): + quad += float(Q[m, k])*fUnodes[k] + return quad + + def compute_quad_final(self, Q, fUnodes, m): + """ + Computes final integration of F(y) on quadrature nodes + """ + quad = Function(self.W) + quad.assign(0.) + l = np.shape(Q)[0] - 1 + for k in range(0, l+1): + print(k, self.K) + quad += float(Q[-1, k])*fUnodes[m - l + k] + return quad + + @property + def res_rhs(self): + """Set up the residual for the calculation of F(y).""" + a = self.residual.label_map(lambda t: t.has_label(time_derivative), + replace_subject(self.Urhs, old_idx=self.idx), + drop) + # F(y) + L = self.residual.label_map(lambda t: any(t.has_label(time_derivative, source_label)), + drop, + replace_subject(self.Uin, old_idx=self.idx)) + L_source = self.residual.label_map(lambda t: t.has_label(source_label), + replace_subject(self.source_in, old_idx=self.idx), + drop) + residual_rhs = a - (L + L_source) + return residual_rhs.form + + @property + def res(self): + """Set up the discretisation's residual for a given node m.""" + # Add time derivative terms y^(k+1)_m - y_start for node m. y_start is y_n for Z2N formulation + # and y^(k)_m for N2N formulation + mass_form = self.residual.label_map( + lambda t: t.has_label(time_derivative), + map_if_false=drop) + residual = mass_form.label_map(all_terms, + map_if_true=replace_subject(self.U_SDC, old_idx=self.idx)) + residual -= mass_form.label_map(all_terms, + map_if_true=replace_subject(self.U_start, old_idx=self.idx)) + + # Calculate source terms + r_source_kp1 = self.residual.label_map( + lambda t: t.has_label(source_label), + map_if_true=replace_subject(self.source_Ukp1_m, old_idx=self.idx), + map_if_false=drop) + r_source_kp1 = r_source_kp1.label_map( + all_terms, + lambda t: Constant(self.dt)*t) + residual += r_source_kp1 + + r_source_k = self.residual.label_map( + lambda t: t.has_label(source_label), + map_if_true=replace_subject(self.source_Uk_m, old_idx=self.idx), + map_if_false=drop) + r_source_k = r_source_k.label_map( + all_terms, + map_if_true=lambda t: Constant(self.dt)*t) + residual -= r_source_k + + # Add on final implicit terms + # Qdelta_imp[m,m]*(F(y_(m)^(k+1)) - F(y_(m)^k)) + r_imp_kp1 = self.residual.label_map( + lambda t: t.has_label(implicit), + map_if_true=replace_subject(self.U_SDC, old_idx=self.idx), + map_if_false=drop) + r_imp_kp1 = r_imp_kp1.label_map( + all_terms, + lambda t: Constant(self.dt)*t) + residual += r_imp_kp1 + r_imp_k = self.residual.label_map( + lambda t: t.has_label(implicit), + map_if_true=replace_subject(self.Uk_mp1, old_idx=self.idx), + map_if_false=drop) + r_imp_k = r_imp_k.label_map( + all_terms, + lambda t: Constant(self.dt)*t) + residual -= r_imp_k + + r_exp_kp1 = self.residual.label_map( + lambda t: t.has_label(explicit), + map_if_true=replace_subject(self.Ukp1_m, old_idx=self.idx), + map_if_false=drop) + r_exp_kp1 = r_exp_kp1.label_map( + all_terms, + lambda t: Constant(self.dt)*t) + residual += r_exp_kp1 + r_exp_k = self.residual.label_map( + lambda t: t.has_label(explicit), + map_if_true=replace_subject(self.Uk_m, old_idx=self.idx), + map_if_false=drop) + r_exp_k = r_exp_k.label_map( + all_terms, + lambda t: Constant(self.dt)*t) + residual -= r_exp_k + + + # Add on sum(j=1,M) s_mj*F(y_m^k) for N2N formulation, where s_mj = q_mj-q_m-1j + # and s1j = q1j. + Q = self.residual.label_map(lambda t: t.has_label(time_derivative), + replace_subject(self.Q_, old_idx=self.idx), + drop) + residual += Q + return residual.form + + @cached_property + def solver(self): + """Set up a list of solvers for each problem at a node m.""" + # setup solver using residual defined in derived class + problem = NonlinearVariationalProblem(self.res, self.U_SDC, bcs=self.bcs) + solver_name = self.field_name+self.__class__.__name__ + solver = NonlinearVariationalSolver(problem, solver_parameters=self.nonlinear_solver_parameters, options_prefix=solver_name) + return solver + + @cached_property + def solver_rhs(self): + """Set up the problem and the solver for mass matrix inversion.""" + # setup linear solver using rhs residual defined in derived class + prob_rhs = NonlinearVariationalProblem(self.res_rhs, self.Urhs, bcs=self.bcs) + solver_name = self.field_name+self.__class__.__name__+"_rhs" + return NonlinearVariationalSolver(prob_rhs, solver_parameters=self.linear_solver_parameters, + options_prefix=solver_name) + + @wrapper_apply + def apply(self, x_out, x_in): + self.Un.assign(x_in) + + # Compute initial guess on quadrature nodes with low-order + # base timestepper + self.Unodes[0].assign(self.Un) + + for m in range(self.M): + self.base.dt = float(self.dt) + print(self.base.dt) + self.base.apply(self.Unodes[m+1], self.Unodes[m]) + #breakpoint() + for m in range(self.M+1): + for evaluate in self.evaluate_source: + evaluate(self.Unodes[m], self.base.dt, x_out=self.source_Uk[m]) + + # Iterate through correction sweeps + for k in range(1, self.K+1): + print("Correction sweep", k) + # Compute for N2N: sum(j=1,M) (s_mj*F(y_m^k) + s_mj*S(y_m^k)) + for m in range(self.M+1): + self.Uin.assign(self.Unodes[m]) + # Include source terms + for evaluate in self.evaluate_source: + evaluate(self.Uin, self.base.dt, x_out=self.source_in) + self.solver_rhs.solve() + self.fUnodes[m].assign(self.Urhs) + + # Loop through quadrature nodes and solve + self.Unodes1[0].assign(self.Unodes[0]) + for evaluate in self.evaluate_source: + evaluate(self.Unodes[0], self.base.dt, x_out=self.source_Uk[0]) + for m in range(0, k): + # Set S matrix + self.Q_.assign(self.compute_quad(self.Q[k-1], self.fUnodes, m+1)) + + # Set initial guess for solver, and pick correct solver + self.U_start.assign(self.Unodes1[m]) + self.Ukp1_m.assign(self.Unodes1[m]) + self.Uk_mp1.assign(self.Unodes[m+1]) + self.Uk_m.assign(self.Unodes[m]) + self.source_Ukp1_m.assign(self.source_Ukp1[m]) + self.source_Uk_m.assign(self.source_Uk[m]) + self.U_SDC.assign(self.Unodes[m+1]) + + # Compute + # for N2N: + # y_m^(k+1) = y_(m-1)^(k+1) + dtau_m*(F(y_(m)^(k+1)) - F(y_(m)^k) + # + S(y_(m-1)^(k+1)) - S(y_(m-1)^k)) + # + sum(j=1,M) s_mj*(F+S)(y^k) + self.solver.solve() + self.Unodes1[m+1].assign(self.U_SDC) + + # Evaluate source terms + for evaluate in self.evaluate_source: + evaluate(self.Unodes1[m+1], self.base.dt, x_out=self.source_Ukp1[m+1]) + + # Apply limiter if required + if self.limiter is not None: + self.limiter.apply(self.Unodes1[m+1]) + for m in range(k, self.M): + # Set S matrix + self.Q_.assign(self.compute_quad_final(self.Q[k-1], self.fUnodes, m+1)) + + # Set initial guess for solver, and pick correct solver + self.U_start.assign(self.Unodes1[m]) + self.Ukp1_m.assign(self.Unodes1[m]) + self.Uk_mp1.assign(self.Unodes[m+1]) + self.Uk_m.assign(self.Unodes[m]) + self.source_Ukp1_m.assign(self.source_Ukp1[m]) + self.source_Uk_m.assign(self.source_Uk[m]) + self.U_SDC.assign(self.Unodes[m+1]) + + # Compute + # for N2N: + # y_m^(k+1) = y_(m-1)^(k+1) + dtau_m*(F(y_(m)^(k+1)) - F(y_(m)^k) + # + S(y_(m-1)^(k+1)) - S(y_(m-1)^k)) + # + sum(j=1,M) s_mj*(F+S)(y^k) + self.solver.solve() + self.Unodes1[m+1].assign(self.U_SDC) + + # Evaluate source terms + for evaluate in self.evaluate_source: + evaluate(self.Unodes1[m+1], self.base.dt, x_out=self.source_Ukp1[m+1]) + + # Apply limiter if required + if self.limiter is not None: + self.limiter.apply(self.Unodes1[m+1]) + + for m in range(self.M+1): + self.Unodes[m].assign(self.Unodes1[m]) + self.source_Uk[m].assign(self.source_Ukp1[m]) + + x_out.assign(self.Unodes[-1]) - x_out.assign(self.Unodes[-1]) \ No newline at end of file +# class RIDC2(IDC): +# """Class for Revisionist Integral Deferred Correction schemes.""" + +# def __init__(self, base_scheme, domain, M, K, field_name=None, +# linear_solver_parameters=None, nonlinear_solver_parameters=None, +# limiter=None, options=None): +# """ +# Initialise IDC object +# Args: +# base_scheme (:class:`TimeDiscretisation`): Base time stepping scheme to get first guess of solution on +# quadrature nodes. +# domain (:class:`Domain`): the model's domain object, containing the +# mesh and the compatible function spaces. +# M (int): Number of quadrature nodes to compute spectral integration over +# K (int): Max number of correction interations +# field_name (str, optional): name of the field to be evolved. +# Defaults to None. +# linear_solver_parameters (dict, optional): dictionary of parameters to +# pass to the underlying linear solver. Defaults to None. +# nonlinear_solver_parameters (dict, optional): dictionary of parameters to +# pass to the underlying nonlinear solver. Defaults to None. +# final_update (bool, optional): Whether to compute final update, or just take last +# quadrature value. Defaults to True +# limiter (:class:`Limiter` object, optional): a limiter to apply to +# the evolving field to enforce monotonicity. Defaults to None. +# options (:class:`AdvectionOptions`, optional): an object containing + +# initial_guess (str, optional): Initial guess to be base timestepper, or copy +# """ +# super(RIDC, self).__init__(base_scheme, domain, M, K, +# field_name, linear_solver_parameters, nonlinear_solver_parameters, +# limiter, options) +# self.dt = Constant(float(self.dt_coarse)/(self.M - 1)) +# self.nodes = np.arange(0, M*float(self.dt), float(self.dt)) +# #integration_matrix = self.lagrange_integration_matrix(self.K) +# #self.Q = 0.5 * (self.K-1) * float(self.dt) * integration_matrix + +# def setup(self, equation, apply_bcs=True, *active_labels): +# """ +# Set up the SDC time discretisation based on the equation.n + +# Args: +# equation (:class:`PrognosticEquation`): the model's equation. +# apply_bcs (bool, optional): whether to apply the equation's boundary +# conditions. Defaults to True. +# *active_labels (:class:`Label`): labels indicating which terms of +# the equation to include. +# """ +# super(RIDC, self).setup(equation, apply_bcs, *active_labels) + +# def compute_quad(self, Q, fUnodes, m): +# """ +# Computes integration of F(y) on quadrature nodes +# """ +# quad = Function(self.W) +# quad.assign(0.) +# for k in range(0, np.shape(Q)[0]): +# quad += float(Q[m, k])*fUnodes[k] +# return quad + +# def compute_quad_final(self, Q, fUnodes, m): +# """ +# Computes final integration of F(y) on quadrature nodes +# """ +# quad = Function(self.W) +# quad.assign(0.) +# for k in range(0, self.K): +# quad += float(Q[-1, k])*fUnodes[m + 1 - self.K + k] +# return quad + +# @wrapper_apply +# def apply(self, x_out, x_in): +# self.Un.assign(x_in) + +# # Compute initial guess on quadrature nodes with low-order +# # base timestepper +# self.Unodes[0].assign(self.Un) + +# for m in range(self.M-1): +# self.base.dt = float(self.dt) +# self.base.apply(self.Unodes[m+1], self.Unodes[m]) +# for m in range(self.M): +# for evaluate in self.evaluate_source: +# evaluate(self.Unodes[m], self.base.dt, x_out=self.source_Uk[m]) + +# # Iterate through correction sweeps +# for k in range(1, self.K): +# print("Correction sweep", k) +# # Compute for N2N: sum(j=1,M) (s_mj*F(y_m^k) + s_mj*S(y_m^k)) +# for m in range(self.M): +# self.Uin.assign(self.Unodes[m]) +# # Include source terms +# for evaluate in self.evaluate_source: +# evaluate(self.Uin, self.base.dt, x_out=self.source_in) +# self.solver_rhs.solve() +# self.fUnodes[m].assign(self.Urhs) +# #self.compute_quad() +# # Loop through quadrature nodes and solve +# self.Unodes1[0].assign(self.Unodes[0]) +# for evaluate in self.evaluate_source: +# evaluate(self.Unodes[0], self.base.dt, x_out=self.source_Uk[0]) +# for m in range(0, self.M-1): +# # Set S matrix +# self.Q_.assign(self.compute_quad(self.Q, self.fUnodes, m+1)) +# #self.Q_.assign(self.quad[m-1]) + +# # Set initial guess for solver, and pick correct solver +# self.U_start.assign(self.Unodes1[m]) +# self.Ukp1_m.assign(self.Unodes1[m]) +# self.Uk_mp1.assign(self.Unodes[m+1]) +# self.Uk_m.assign(self.Unodes[m]) +# self.source_Ukp1_m.assign(self.source_Ukp1[m]) +# self.source_Uk_m.assign(self.source_Uk[m]) +# self.U_SDC.assign(self.Unodes[m+1]) + +# # Compute +# # for N2N: +# # y_m^(k+1) = y_(m-1)^(k+1) + dtau_m*(F(y_(m)^(k+1)) - F(y_(m)^k) +# # + S(y_(m-1)^(k+1)) - S(y_(m-1)^k)) +# # + sum(j=1,M) s_mj*(F+S)(y^k) +# self.solver.solve() +# self.Unodes1[m+1].assign(self.U_SDC) + +# # Evaluate source terms +# for evaluate in self.evaluate_source: +# evaluate(self.Unodes1[m+1], self.base.dt, x_out=self.source_Ukp1[m+1]) + +# # Apply limiter if required +# if self.limiter is not None: +# self.limiter.apply(self.Unodes1[m+1]) +# # for m in range(self.K-1, self.M-1): +# # # Set S matrix +# # self.Q_.assign(self.compute_quad_final(self.Q, self.fUnodes, m+1)) + +# # # Set initial guess for solver, and pick correct solver +# # self.U_start.assign(self.Unodes1[m]) +# # self.Ukp1_m.assign(self.Unodes1[m]) +# # self.Uk_mp1.assign(self.Unodes[m+1]) +# # self.Uk_m.assign(self.Unodes[m]) +# # self.source_Ukp1_m.assign(self.source_Ukp1[m]) +# # self.source_Uk_m.assign(self.source_Uk[m]) +# # self.U_SDC.assign(self.Unodes[m+1]) + +# # # Compute +# # # for N2N: +# # # y_m^(k+1) = y_(m-1)^(k+1) + dtau_m*(F(y_(m)^(k+1)) - F(y_(m)^k) +# # # + S(y_(m-1)^(k+1)) - S(y_(m-1)^k)) +# # # + sum(j=1,M) s_mj*(F+S)(y^k) +# # self.solver.solve() +# # self.Unodes1[m+1].assign(self.U_SDC) + +# # # Evaluate source terms +# # for evaluate in self.evaluate_source: +# # evaluate(self.Unodes1[m+1], self.base.dt, x_out=self.source_Ukp1[m+1]) + +# # # Apply limiter if required +# # if self.limiter is not None: +# # self.limiter.apply(self.Unodes1[m+1]) + +# for m in range(self.M): +# self.Unodes[m].assign(self.Unodes1[m]) +# self.source_Uk[m].assign(self.source_Ukp1[m]) + +# x_out.assign(self.Unodes[-1]) \ No newline at end of file From 801bb063ce46bce730f965df422244199b721870 Mon Sep 17 00:00:00 2001 From: atb1995 Date: Wed, 2 Apr 2025 11:07:07 +0100 Subject: [PATCH 04/43] Parallel RIDC is working :) --- .../explicit_runge_kutta.py | 9 - gusto/time_discretisation/sdc.py | 509 +++++++++++++++++- 2 files changed, 495 insertions(+), 23 deletions(-) diff --git a/gusto/time_discretisation/explicit_runge_kutta.py b/gusto/time_discretisation/explicit_runge_kutta.py index 77cff79d4..34c47b977 100644 --- a/gusto/time_discretisation/explicit_runge_kutta.py +++ b/gusto/time_discretisation/explicit_runge_kutta.py @@ -321,7 +321,6 @@ def solve_stage(self, x0, stage): if self.rk_formulation == RungeKuttaFormulation.increment: self.x1.assign(x0) - print("stage", stage) for i in range(stage): self.x1.assign(self.x1 + self.dt*self.butcher_matrix[stage-1, i]*self.k[i]) for evaluate in self.evaluate_source: @@ -340,9 +339,6 @@ def solve_stage(self, x0, stage): self.x1.assign(x0) for i in range(self.nStages): self.x1.assign(self.x1 + self.dt*self.butcher_matrix[stage, i]*self.k[i]) - #self.x1.assign(self.x1) - print("Hello") - if self.limiter is not None: self.limiter.apply(self.x1) @@ -452,11 +448,6 @@ def apply_cycle(self, x_out, x_in): for i in range(self.nStages): self.solve_stage(x_in, i) x_out.assign(self.x1) - self.xdiff = Function(self.fs) - self.xdiff.assign(x_out-x_in) - print("xdiff", np.max(self.xdiff.dat.data), np.min(self.xdiff.dat.data)) - #breakpoint() - class ForwardEuler(ExplicitRungeKutta): """ diff --git a/gusto/time_discretisation/sdc.py b/gusto/time_discretisation/sdc.py index 5ee2b1bb7..64089a0c2 100644 --- a/gusto/time_discretisation/sdc.py +++ b/gusto/time_discretisation/sdc.py @@ -57,7 +57,12 @@ from qmat import genQCoeffs, genQDeltaCoeffs -__all__ = ["SDC", "IDC", "RIDC", "RIDC_Reduced"] +from gusto.core.logging import ( + logger, DEBUG, logging_ksp_monitor_true_residual, + attach_custom_monitor +) + +__all__ = ["SDC", "IDC", "RIDC", "RIDC_Reduced", "Parallel_RIDC"] class SDC(object, metaclass=ABCMeta): @@ -416,7 +421,6 @@ def apply(self, x_out, x_in): for m in range(self.M): self.base.dt = float(self.dtau[m]) self.base.apply_cycle(self.Unodes[m+1], self.Unodes[m]) - print(self.base.dt) else: for m in range(self.M): self.Unodes[m+1].assign(self.Un) @@ -425,7 +429,6 @@ def apply(self, x_out, x_in): evaluate(self.Unodes[m], self.base.dt, x_out=self.source_Uk[m]) self.Udiff = Function(self.W) self.Udiff.assign(self.Unodes[0] - self.Unodes[-1]) - print(np.max(self.Udiff.dat.data[:]), np.min(self.Udiff.dat.data[:])) #breakpoint() # Iterate through correction sweeps @@ -555,6 +558,7 @@ def __init__(self, base_scheme, domain, M, K, field_name=None, self.K = K self.M = M self.dt = Constant(float(self.dt_coarse)/(self.M)) + print("dt", float(self.dt)) self.nodes = np.arange(0, (M+1)*float(self.dt), float(self.dt)) integration_matrix = self.lagrange_integration_matrix(self.M+1) @@ -838,7 +842,6 @@ def apply(self, x_out, x_in): for m in range(self.M): self.base.dt = float(self.dt) - print(self.base.dt) self.base.apply(self.Unodes[m+1], self.Unodes[m]) #breakpoint() for m in range(self.M+1): @@ -847,7 +850,6 @@ def apply(self, x_out, x_in): # Iterate through correction sweeps for k in range(1, self.K+1): - print("Correction sweep", k) # Compute for N2N: sum(j=1,M) (s_mj*F(y_m^k) + s_mj*S(y_m^k)) for m in range(self.M+1): self.Uin.assign(self.Unodes[m]) @@ -1110,7 +1112,6 @@ def compute_quad_final(self, Q, fUnodes, m): quad = Function(self.W) quad.assign(0.) for k in range(0, self.K+1): - print(k, self.K) quad += float(Q[-1, k])*fUnodes[m - self.K + k] return quad @@ -1235,7 +1236,6 @@ def apply(self, x_out, x_in): for m in range(self.M): self.base.dt = float(self.dt) - print(self.base.dt) self.base.apply(self.Unodes[m+1], self.Unodes[m]) #breakpoint() for m in range(self.M+1): @@ -1244,7 +1244,6 @@ def apply(self, x_out, x_in): # Iterate through correction sweeps for k in range(1, self.K+1): - print("Correction sweep", k) # Compute for N2N: sum(j=1,M) (s_mj*F(y_m^k) + s_mj*S(y_m^k)) for m in range(self.M+1): self.Uin.assign(self.Unodes[m]) @@ -1366,9 +1365,6 @@ def __init__(self, base_scheme, domain, M, K, field_name=None, integration_matrix = self.lagrange_integration_matrix(l+1) integration_matrix = 0.5 * (l) * float(self.dt) * integration_matrix self.Q.append(integration_matrix) - print(integration_matrix) - print(l) - print(self.K) # Rescale integration matrix to be over [0, self.dt_coarse] rather than [-1, 1] # self.Q = 0.5 * (self.K) * float(self.dt) * integration_matrix @@ -1543,7 +1539,6 @@ def compute_quad_final(self, Q, fUnodes, m): quad.assign(0.) l = np.shape(Q)[0] - 1 for k in range(0, l+1): - print(k, self.K) quad += float(Q[-1, k])*fUnodes[m - l + k] return quad @@ -1668,7 +1663,6 @@ def apply(self, x_out, x_in): for m in range(self.M): self.base.dt = float(self.dt) - print(self.base.dt) self.base.apply(self.Unodes[m+1], self.Unodes[m]) #breakpoint() for m in range(self.M+1): @@ -1677,7 +1671,6 @@ def apply(self, x_out, x_in): # Iterate through correction sweeps for k in range(1, self.K+1): - print("Correction sweep", k) # Compute for N2N: sum(j=1,M) (s_mj*F(y_m^k) + s_mj*S(y_m^k)) for m in range(self.M+1): self.Uin.assign(self.Unodes[m]) @@ -1754,6 +1747,494 @@ def apply(self, x_out, x_in): x_out.assign(self.Unodes[-1]) +class Parallel_RIDC(object, metaclass=ABCMeta): + """Class for Integral Deferred Correction schemes.""" + + def __init__(self, base_scheme, domain, M, K, field_name=None, + linear_solver_parameters=None, nonlinear_solver_parameters=None, + limiter=None, options=None, communicator=None): + """ + Initialise IDC object + Args: + base_scheme (:class:`TimeDiscretisation`): Base time stepping scheme to get first guess of solution on + quadrature nodes. + domain (:class:`Domain`): the model's domain object, containing the + mesh and the compatible function spaces. + M (int): Number of subintervals + K (int): Max number of correction interations + field_name (str, optional): name of the field to be evolved. + Defaults to None. + linear_solver_parameters (dict, optional): dictionary of parameters to + pass to the underlying linear solver. Defaults to None. + nonlinear_solver_parameters (dict, optional): dictionary of parameters to + pass to the underlying nonlinear solver. Defaults to None. + final_update (bool, optional): Whether to compute final update, or just take last + quadrature value. Defaults to True + limiter (:class:`Limiter` object, optional): a limiter to apply to + the evolving field to enforce monotonicity. Defaults to None. + options (:class:`AdvectionOptions`, optional): an object containing + + initial_guess (str, optional): Initial guess to be base timestepper, or copy + communicator (MPI communicator, optional): communicator for parallel execution. Defaults to None. + """ + self.base = base_scheme + self.field_name = field_name + self.domain = domain + self.dt_coarse = domain.dt + self.limiter = limiter + self.augmentation = self.base.augmentation + self.wrapper = self.base.wrapper + self.K = K + self.M = M + self.dt = Constant(float(self.dt_coarse)/(self.M)) + self.nodes = np.arange(0, (M+1)*float(self.dt), float(self.dt)) + self.Q = [] + for l in range(1, self.K+1): + integration_matrix = self.lagrange_integration_matrix(l+1) + integration_matrix = 0.5 * (l) * float(self.dt) * integration_matrix + self.Q.append(integration_matrix) + # Rescale integration matrix to be over [0, self.dt_coarse] rather than [-1, 1] + # self.Q = 0.5 * (self.K) * float(self.dt) * integration_matrix + + # integration_matrix = self.lagrange_integration_matrix(self.K) + # # Rescale integration matrix to be over [0, self.dt_coarse] rather than [-1, 1] + # self.Q = 0.5 * (self.K-1) * float(self.dt) * integration_matrix + + + #self.Q = float(self.dt_coarse) * integration_matrix + + # Set default linear and nonlinear solver options if none passed in + if linear_solver_parameters is None: + self.linear_solver_parameters = {'snes_type': 'ksponly', + 'ksp_type': 'cg', + 'pc_type': 'bjacobi', + 'sub_pc_type': 'ilu'} + else: + self.linear_solver_parameters = linear_solver_parameters + + if nonlinear_solver_parameters is None: + self.nonlinear_solver_parameters = {'snes_type': 'newtonls', + 'ksp_type': 'gmres', + 'pc_type': 'bjacobi', + 'sub_pc_type': 'ilu'} + else: + self.nonlinear_solver_parameters = nonlinear_solver_parameters + self.comm = communicator + + + def setup(self, equation, apply_bcs=True, *active_labels): + """ + Set up the SDC time discretisation based on the equation.n + + Args: + equation (:class:`PrognosticEquation`): the model's equation. + apply_bcs (bool, optional): whether to apply the equation's boundary + conditions. Defaults to True. + *active_labels (:class:`Label`): labels indicating which terms of + the equation to include. + """ + # Inherit from base time discretisation + self.base.setup(equation, apply_bcs, *active_labels) + self.equation = self.base.equation + self.residual = self.base.residual + self.evaluate_source = self.base.evaluate_source + + for t in self.residual: + # Check all terms are labeled implicit or explicit + if ((not t.has_label(implicit)) and (not t.has_label(explicit)) + and (not t.has_label(time_derivative)) and (not t.has_label(source_label))): + raise NotImplementedError("Non time-derivative or source terms must be labeled as implicit or explicit") + + # Set up bcs + self.bcs = self.base.bcs + + # Set up SDC variables + if self.field_name is not None and hasattr(equation, "field_names"): + self.idx = equation.field_names.index(self.field_name) + W = equation.spaces[self.idx] + else: + self.field_name = equation.field_name + W = equation.function_space + self.idx = None + self.W = W + self.Unodes = [Function(W) for _ in range(self.M+1)] + self.Unodes1 = [Function(W) for _ in range(self.M+1)] + self.fUnodes = [Function(W) for _ in range(self.M+1)] + self.quad = [Function(W) for _ in range(self.M+1)] + self.source_Uk = [Function(W) for _ in range(self.M+1)] + self.source_Ukp1 = [Function(W) for _ in range(self.M+1)] + self.U_SDC = Function(W) + self.U_start = Function(W) + self.Un = Function(W) + self.Q_ = Function(W) + self.quad_final = Function(W) + self.U_fin = Function(W) + self.Urhs = Function(W) + self.Uin = Function(W) + self.source_in = Function(W) + self.source_Ukp1_m = Function(W) + self.source_Uk_m = Function(W) + self.Uk_mp1 = Function(W) + self.Uk_m = Function(W) + self.Ukp1_m = Function(W) + self.node_count = 0 + + @property + def nlevels(self): + return 1 + + def equidistant_nodes(self ,M): + # This returns a grid of M equispaced nodes from -1 to 1 + grid = np.linspace(-1., 1., M) + #grid = 0.5 * (grid + 1) + return grid + + def lagrange_polynomial(self, index, nodes): + # This returns the coefficients of the Lagrange polynomial l_m with m=index + + M = len(nodes) + + # c is the denominator + c = 1. + for k in range(M): + if k != index: + c *= (nodes[index] - nodes[k]) + + coeffs = np.zeros(M) + coeffs[0] = 1. + m = 0 + + for k in range(M): + if k != index: + m += 1 + d1 = np.zeros(M) + d2 = np.zeros(M) + + d1 = (-1.)*nodes[k] * coeffs + d2[1:m+1] = coeffs[0:m] + + coeffs = d1+d2 + return coeffs / c + + def integrate_polynomial(self, p): + # given a list of coefficients of a polynomial p, this returns those of the integral of p + integral_coeffs = np.zeros(len(p)+1) + + for n, pn in enumerate(p): + integral_coeffs[n+1] = 1/(n+1) * pn + + return integral_coeffs + + def evaluate(self, p, a, b): + # given a list of coefficients of a polynomial p, this returns the value of p(b)-p(a) + value = 0. + for n, pn in enumerate(p): + value += pn * (b**n - a**n) + + return value + + def lagrange_integration_matrix(self, M): + # using the functions defined above, this returns the MxM integration matrix + + # set up equidistant nodes and initialise matrix to zero + nodes = self.equidistant_nodes(M) + L = len(nodes) + int_matrix = np.zeros((L, L)) + + # fill in matrix values + for index in range(L): + coeff_p = self.lagrange_polynomial(index, nodes) + int_coeff = self.integrate_polynomial(coeff_p) + + for n in range(L-1): + int_matrix[n+1, index] = self.evaluate(int_coeff, nodes[n], nodes[n+1]) + + return int_matrix + + def compute_quad(self, Q, fUnodes, m): + """ + Computes integration of F(y) on quadrature nodes + """ + quad = Function(self.W) + quad.assign(0.) + for k in range(0, np.shape(Q)[1]): + quad += float(Q[m, k])*fUnodes[k] + return quad + + def compute_quad_final(self, Q, fUnodes, m): + """ + Computes final integration of F(y) on quadrature nodes + """ + quad = Function(self.W) + quad.assign(0.) + l = np.shape(Q)[0] - 1 + for k in range(0, l+1): + quad += float(Q[-1, k])*fUnodes[m - l + k] + return quad + + @property + def res_rhs(self): + """Set up the residual for the calculation of F(y).""" + a = self.residual.label_map(lambda t: t.has_label(time_derivative), + replace_subject(self.Urhs, old_idx=self.idx), + drop) + # F(y) + L = self.residual.label_map(lambda t: any(t.has_label(time_derivative, source_label)), + drop, + replace_subject(self.Uin, old_idx=self.idx)) + L_source = self.residual.label_map(lambda t: t.has_label(source_label), + replace_subject(self.source_in, old_idx=self.idx), + drop) + residual_rhs = a - (L + L_source) + return residual_rhs.form + + @property + def res(self): + """Set up the discretisation's residual for a given node m.""" + # Add time derivative terms y^(k+1)_m - y_start for node m. y_start is y_n for Z2N formulation + # and y^(k)_m for N2N formulation + mass_form = self.residual.label_map( + lambda t: t.has_label(time_derivative), + map_if_false=drop) + residual = mass_form.label_map(all_terms, + map_if_true=replace_subject(self.U_SDC, old_idx=self.idx)) + residual -= mass_form.label_map(all_terms, + map_if_true=replace_subject(self.U_start, old_idx=self.idx)) + + # Calculate source terms + r_source_kp1 = self.residual.label_map( + lambda t: t.has_label(source_label), + map_if_true=replace_subject(self.source_Ukp1_m, old_idx=self.idx), + map_if_false=drop) + r_source_kp1 = r_source_kp1.label_map( + all_terms, + lambda t: Constant(self.dt)*t) + residual += r_source_kp1 + + r_source_k = self.residual.label_map( + lambda t: t.has_label(source_label), + map_if_true=replace_subject(self.source_Uk_m, old_idx=self.idx), + map_if_false=drop) + r_source_k = r_source_k.label_map( + all_terms, + map_if_true=lambda t: Constant(self.dt)*t) + residual -= r_source_k + + # Add on final implicit terms + # Qdelta_imp[m,m]*(F(y_(m)^(k+1)) - F(y_(m)^k)) + r_imp_kp1 = self.residual.label_map( + lambda t: t.has_label(implicit), + map_if_true=replace_subject(self.U_SDC, old_idx=self.idx), + map_if_false=drop) + r_imp_kp1 = r_imp_kp1.label_map( + all_terms, + lambda t: Constant(self.dt)*t) + residual += r_imp_kp1 + r_imp_k = self.residual.label_map( + lambda t: t.has_label(implicit), + map_if_true=replace_subject(self.Uk_mp1, old_idx=self.idx), + map_if_false=drop) + r_imp_k = r_imp_k.label_map( + all_terms, + lambda t: Constant(self.dt)*t) + residual -= r_imp_k + + r_exp_kp1 = self.residual.label_map( + lambda t: t.has_label(explicit), + map_if_true=replace_subject(self.Ukp1_m, old_idx=self.idx), + map_if_false=drop) + r_exp_kp1 = r_exp_kp1.label_map( + all_terms, + lambda t: Constant(self.dt)*t) + residual += r_exp_kp1 + r_exp_k = self.residual.label_map( + lambda t: t.has_label(explicit), + map_if_true=replace_subject(self.Uk_m, old_idx=self.idx), + map_if_false=drop) + r_exp_k = r_exp_k.label_map( + all_terms, + lambda t: Constant(self.dt)*t) + residual -= r_exp_k + + + # Add on sum(j=1,M) s_mj*F(y_m^k) for N2N formulation, where s_mj = q_mj-q_m-1j + # and s1j = q1j. + Q = self.residual.label_map(lambda t: t.has_label(time_derivative), + replace_subject(self.Q_, old_idx=self.idx), + drop) + residual += Q + return residual.form + + @cached_property + def solver(self): + """Set up a list of solvers for each problem at a node m.""" + # setup solver using residual defined in derived class + problem = NonlinearVariationalProblem(self.res, self.U_SDC, bcs=self.bcs) + solver_name = self.field_name+self.__class__.__name__ + solver = NonlinearVariationalSolver(problem, solver_parameters=self.nonlinear_solver_parameters, options_prefix=solver_name) + return solver + + @cached_property + def solver_rhs(self): + """Set up the problem and the solver for mass matrix inversion.""" + # setup linear solver using rhs residual defined in derived class + prob_rhs = NonlinearVariationalProblem(self.res_rhs, self.Urhs, bcs=self.bcs) + solver_name = self.field_name+self.__class__.__name__+"_rhs" + return NonlinearVariationalSolver(prob_rhs, solver_parameters=self.linear_solver_parameters, + options_prefix=solver_name) + + @wrapper_apply + def apply(self, x_out, x_in): + + # Parallelised code: + # Correction sweep + x_out.assign(x_in) + self.kval = self.comm.ensemble_comm.rank + #logger.info(f'Communicator: {self.kval:.2e}') + #breakpoint() + self.Un.assign(x_in) + self.Unodes[0].assign(self.Un) + # Loop through quadrature nodes and solve + self.Unodes1[0].assign(self.Unodes[0]) + for evaluate in self.evaluate_source: + evaluate(self.Unodes[0], self.base.dt, x_out=self.source_Uk[0]) + self.Uin.assign(self.Unodes[0]) + self.solver_rhs.solve() + self.fUnodes[0].assign(self.Urhs) + # On first communicator + if (self.comm.ensemble_comm.rank == 0): + logger.info(f'Starting base timestepper: {self.kval:.2e}') + # base timestepper + + for m in range(self.M): + self.base.dt = float(self.dt) + logger.info(f'Base stepper: {self.kval:.2e}') + self.base.apply(self.Unodes[m+1], self.Unodes[m]) + logger.info(f'Base stepper done: {self.kval:.2e}') + for evaluate in self.evaluate_source: + evaluate(self.Unodes[m+1], self.base.dt, x_out=self.source_Uk[m+1]) + + # Send base guess to k+1 correction + #breakpoint() + # for i in range(1, self.K): + # self.comm.send(self.node_count, dest=int(i), tag=11) # Send data to all other processes + self.comm.send(self.Unodes[m+1], dest=self.kval+1, tag=100+m+1) + logger.info(f'Sent data to process {self.kval+1:.2e} from process {self.kval:.2e}') + + else: + for m in range(1, self.kval + 1): + self.comm.recv(self.Unodes[m], source=self.kval-1, tag=100+m) + logger.info(f'Recieved data to process {self.kval:.2e} from process {self.kval-1:.2e}') + self.Uin.assign(self.Unodes[m]) + for evaluate in self.evaluate_source: + evaluate(self.Uin, self.base.dt, x_out=self.source_in) + self.solver_rhs.solve() + self.fUnodes[m].assign(self.Urhs) + # if k == 1: + # for m in range(0, self.kval*(self.kval+1)/2): + # self.Unodes[m] = self.comm.recv(source=0, tag=123) + # else: + # for m in range(0, self.kval*(self.kval+1)/2): + # self.Unodes1[m] = self.comm.recv(source=self.kval-1, tag=11) + # self.Unodes[m].assign(self.Unodes1[m]) + for m in range(0, self.kval): + # if (m >= self.kval*(self.kval+1)/2): + # self.Unodes1[m] = self.comm.recv(source=self.kval-1, tag=11) + # self.Unodes[m].assign(self.Unodes1[m]) + # Get f(u) + # if (m > self.kval+2): + # self.Uin.assign(self.Unodes[m]) + # # Include source terms + # for evaluate in self.evaluate_source: + # evaluate(self.Uin, self.base.dt, x_out=self.source_in) + # self.solver_rhs.solve() + # self.fUnodes[m].assign(self.Urhs) + # Set S matrix + self.Q_.assign(self.compute_quad(self.Q[self.kval-1], self.fUnodes, m+1)) + + # Set initial guess for solver, and pick correct solver + self.U_start.assign(self.Unodes1[m]) + self.Ukp1_m.assign(self.Unodes1[m]) + self.Uk_mp1.assign(self.Unodes[m+1]) + self.Uk_m.assign(self.Unodes[m]) + self.source_Ukp1_m.assign(self.source_Ukp1[m]) + self.source_Uk_m.assign(self.source_Uk[m]) + self.U_SDC.assign(self.Unodes[m+1]) + + # Compute + # for N2N: + # y_m^(k+1) = y_(m-1)^(k+1) + dtau_m*(F(y_(m)^(k+1)) - F(y_(m)^k) + # + S(y_(m-1)^(k+1)) - S(y_(m-1)^k)) + # + sum(j=1,M) s_mj*(F+S)(y^k) + self.solver.solve() + self.Unodes1[m+1].assign(self.U_SDC) + + # Evaluate source terms + for evaluate in self.evaluate_source: + evaluate(self.Unodes1[m+1], self.base.dt, x_out=self.source_Ukp1[m+1]) + + # Apply limiter if required + if self.limiter is not None: + self.limiter.apply(self.Unodes1[m+1]) + # Send our updated value to next communicator + if self.kval < self.K: + self.comm.send(self.Unodes1[m+1], dest=self.kval+1, tag=100+m+1) + logger.info(f'Sent data to process {self.kval+1:.2e} from process {self.kval:.2e}') + + for m in range(self.kval, self.M): + self.comm.recv(self.Unodes[m+1], source=self.kval-1, tag=100+m+1) + logger.info(f'Recieved data to process {self.kval:.2e} from process {self.kval-1:.2e}') + self.Uin.assign(self.Unodes[m+1]) + for evaluate in self.evaluate_source: + evaluate(self.Uin, self.base.dt, x_out=self.source_in) + self.solver_rhs.solve() + self.fUnodes[m+1].assign(self.Urhs) + + # Set S matrix + self.Q_.assign(self.compute_quad_final(self.Q[self.kval-1], self.fUnodes, m+1)) + + # Set initial guess for solver, and pick correct solver + self.U_start.assign(self.Unodes1[m]) + self.Ukp1_m.assign(self.Unodes1[m]) + self.Uk_mp1.assign(self.Unodes[m+1]) + self.Uk_m.assign(self.Unodes[m]) + self.source_Ukp1_m.assign(self.source_Ukp1[m]) + self.source_Uk_m.assign(self.source_Uk[m]) + self.U_SDC.assign(self.Unodes[m+1]) + + # Compute + # for N2N: + # y_m^(k+1) = y_(m-1)^(k+1) + dtau_m*(F(y_(m)^(k+1)) - F(y_(m)^k) + # + S(y_(m-1)^(k+1)) - S(y_(m-1)^k)) + # + sum(j=1,M) s_mj*(F+S)(y^k) + self.solver.solve() + self.Unodes1[m+1].assign(self.U_SDC) + + # Evaluate source terms + for evaluate in self.evaluate_source: + evaluate(self.Unodes1[m+1], self.base.dt, x_out=self.source_Ukp1[m+1]) + + # Apply limiter if required + if self.limiter is not None: + self.limiter.apply(self.Unodes1[m+1]) + + # Send our updated value to next communicator + if self.kval < self.K: + self.comm.send(self.Unodes1[m+1], dest=self.kval+1, tag=100+m+1) + logger.info(f'Sent data to process {self.kval+1:.2e} from process {self.kval:.2e}') + + if (self.kval == self.K): + logger.info(f'Broadcasting for: {self.kval:.2e}') + x_out.assign(self.Unodes1[-1]) + for i in range(self.K): + # Send the final result to all other ranks + self.comm.send(x_out, dest=i, tag=200) + else: + + # Receive the final result from Rank K + self.comm.recv(x_out, source=self.K, tag=200) + + # class RIDC2(IDC): # """Class for Revisionist Integral Deferred Correction schemes.""" From ab5645785eebe8dab9814b8077706c0b5bc1cdd3 Mon Sep 17 00:00:00 2001 From: atb1995 Date: Fri, 4 Apr 2025 14:48:13 +0100 Subject: [PATCH 05/43] Parallel IO working --- gusto/core/io.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/gusto/core/io.py b/gusto/core/io.py index 1e69a8651..a375f6cf3 100644 --- a/gusto/core/io.py +++ b/gusto/core/io.py @@ -55,7 +55,7 @@ def pick_up_mesh(output, mesh_name, comm=COMM_WORLD): else: dumpdir = path.join("results", output.dirname) chkfile = path.join(dumpdir, "chkpt.h5") - with CheckpointFile(chkfile, 'r', comm=comm) as chk: + with CheckpointFile(chkfile, 'r', comm=mesh.comm) as chk: mesh = chk.load_mesh(mesh_name) if dumpdir: @@ -635,7 +635,7 @@ def pick_up_from_checkpoint(self, state_fields, comm=COMM_WORLD): step = chk.read_attribute("/", "step") else: - with CheckpointFile(chkfile, 'r', comm) as chk: + with CheckpointFile(chkfile, 'r', self.domain.mesh.comm) as chk: mesh = self.domain.mesh # Recover compulsory fields from the checkpoint for field_name in self.to_pick_up: @@ -739,7 +739,7 @@ def dump(self, state_fields, time_data): if last_ref_update_time is not None: self.chkpt.write_attribute("/", "last_ref_update_time", last_ref_update_time) else: - with CheckpointFile(self.chkpt_path, 'w') as chk: + with CheckpointFile(self.chkpt_path, 'w', self.mesh.comm) as chk: chk.save_mesh(self.domain.mesh) for field_name in self.to_pick_up: chk.save_function(state_fields(field_name), name=field_name) @@ -943,7 +943,7 @@ def make_nc_dataset(filename, access, comm): """ try: - nc_field_file = Dataset(filename, access, parallel=True) + nc_field_file = Dataset(filename, access, parallel=True, comm=comm) nc_supports_parallel = True except ValueError: # parallel netCDF not available, use the serial version instead From 465b79281d2b5fe93bc7bfa74124e93b10b22d8e Mon Sep 17 00:00:00 2001 From: atb1995 Date: Fri, 25 Apr 2025 11:01:33 +0100 Subject: [PATCH 06/43] Adding time parallel SDC --- gusto/time_discretisation/sdc.py | 180 ++++++++++++++++++++++++++++++- 1 file changed, 178 insertions(+), 2 deletions(-) diff --git a/gusto/time_discretisation/sdc.py b/gusto/time_discretisation/sdc.py index 64089a0c2..7306d809c 100644 --- a/gusto/time_discretisation/sdc.py +++ b/gusto/time_discretisation/sdc.py @@ -62,7 +62,7 @@ attach_custom_monitor ) -__all__ = ["SDC", "IDC", "RIDC", "RIDC_Reduced", "Parallel_RIDC"] +__all__ = ["SDC", "IDC", "RIDC", "RIDC_Reduced", "Parallel_RIDC", "Parallel_SDC"] class SDC(object, metaclass=ABCMeta): @@ -2398,4 +2398,180 @@ def apply(self, x_out, x_in): # self.Unodes[m].assign(self.Unodes1[m]) # self.source_Uk[m].assign(self.source_Ukp1[m]) -# x_out.assign(self.Unodes[-1]) \ No newline at end of file +# x_out.assign(self.Unodes[-1]) + +class Parallel_SDC(SDC): + """Class for Spectral Deferred Correction schemes.""" + + def __init__(self, base_scheme, domain, M, maxk, quad_type, node_type, qdelta_imp, qdelta_exp, + field_name=None, + linear_solver_parameters=None, nonlinear_solver_parameters=None, final_update=True, + limiter=None, options=None, initial_guess="base", communicator=None): + """ + Initialise SDC object + Args: + base_scheme (:class:`TimeDiscretisation`): Base time stepping scheme to get first guess of solution on + quadrature nodes. + domain (:class:`Domain`): the model's domain object, containing the + mesh and the compatible function spaces. + M (int): Number of quadrature nodes to compute spectral integration over + maxk (int): Max number of correction interations + quad_type (str): Type of quadrature to be used. Options are + GAUSS, RADAU-LEFT, RADAU-RIGHT and LOBATTO + node_type (str): Node type to be used. Options are + EQUID, LEGENDRE, CHEBY-1, CHEBY-2, CHEBY-3 and CHEBY-4 + qdelta_imp (str): Implicit Qdelta matrix to be used. Options are + BE, LU, TRAP, EXACT, PIC, OPT, WEIRD, MIN-SR-NS, MIN-SR-S + qdelta_exp (str): Explicit Qdelta matrix to be used. Options are + FE, EXACT, PIC + formulation (str, optional): Whether to use node-to-node or zero-to-node + formulation. Options are N2N and Z2N. Defaults to N2N + field_name (str, optional): name of the field to be evolved. + Defaults to None. + linear_solver_parameters (dict, optional): dictionary of parameters to + pass to the underlying linear solver. Defaults to None. + nonlinear_solver_parameters (dict, optional): dictionary of parameters to + pass to the underlying nonlinear solver. Defaults to None. + final_update (bool, optional): Whether to compute final update, or just take last + quadrature value. Defaults to True + limiter (:class:`Limiter` object, optional): a limiter to apply to + the evolving field to enforce monotonicity. Defaults to None. + options (:class:`AdvectionOptions`, optional): an object containing + options to either be passed to the spatial discretisation, or + to control the "wrapper" methods, such as Embedded DG or a + recovery method. Defaults to None. + initial_guess (str, optional): Initial guess to be base timestepper, or copy + """ + super().__init__(base_scheme, domain, M, maxk, quad_type, node_type, qdelta_imp, qdelta_exp, + formulation = "Z2N", field_name=field_name, + linear_solver_parameters=linear_solver_parameters, nonlinear_solver_parameters= nonlinear_solver_parameters, + final_update=final_update, + limiter=limiter, options=options, initial_guess=initial_guess) + self.comm = communicator + + def compute_quad(self): + """ + Computes integration of F(y) on quadrature nodes + """ + x = Function(self.W) + for j in range(self.M): + x.assign(float(self.Q[j, self.comm.ensemble_comm.rank])*self.fUnodes[self.comm.ensemble_comm.rank]) + self.comm.reduce(x, self.quad[j], root=j) + + def compute_quad_final(self): + """ + Computes final integration of F(y) on quadrature nodes + """ + x = Function(self.W) + x.assign(float(self.Qfin[self.comm.ensemble_comm.rank])*self.fUnodes[self.comm.ensemble_comm.rank]) + self.comm.allreduce(x, self.quad_final) + + + @wrapper_apply + def apply(self, x_out, x_in): + self.Un.assign(x_in) + self.U_start.assign(self.Un) + solver_list = self.solvers + + # Compute initial guess on quadrature nodes with low-order + # base timestepper + self.Unodes[0].assign(self.Un) + if (self.base_flag): + for m in range(self.M): + self.base.dt = float(self.dtau[m]) + self.base.apply(self.Unodes[m+1], self.Unodes[m]) + else: + for m in range(self.M): + self.Unodes[m+1].assign(self.Un) + for m in range(self.M+1): + for evaluate in self.evaluate_source: + evaluate(self.Unodes[m], self.base.dt, x_out=self.source_Uk[m]) + + # Iterate through correction sweeps + k = 0 + while k < self.maxk: + k += 1 + + if self.qdelta_imp_type == "MIN-SR-FLEX": + # Recompute Implicit Q_delta matrix for each iteration k + self.Qdelta_imp = genQDeltaCoeffs( + self.qdelta_imp_type, + form=self.formulation, + nodes=self.nodes, + Q=self.Q, + nNodes=self.M, + nodeType=self.node_type, + quadType=self.quad_type, + k=k + ) + + # Compute for N2N: sum(j=1,M) (s_mj*F(y_m^k) + s_mj*S(y_m^k)) + # for Z2N: sum(j=1,M) (q_mj*F(y_m^k) + q_mj*S(y_m^k)) + self.Uin.assign(self.Unodes[self.comm.ensemble_comm.rank+1]) + # Include source terms + for evaluate in self.evaluate_source: + evaluate(self.Uin, self.base.dt, x_out=self.source_in) + self.solver_rhs.solve() + self.fUnodes[self.comm.ensemble_comm.rank].assign(self.Urhs) + + self.compute_quad() + + # Loop through quadrature nodes and solve + self.Unodes1[0].assign(self.Unodes[0]) + for evaluate in self.evaluate_source: + evaluate(self.Unodes[0], self.base.dt, x_out=self.source_Uk[0]) + + + # Set Q or S matrix + self.Q_.assign(self.quad[self.comm.ensemble_comm.rank]) + + # Set initial guess for solver, and pick correct solver + self.solver = solver_list[self.comm.ensemble_comm.rank] + self.U_SDC.assign(self.Unodes[self.comm.ensemble_comm.rank+1]) + + # Compute + # for N2N: + # y_m^(k+1) = y_(m-1)^(k+1) + dtau_m*(F(y_(m)^(k+1)) - F(y_(m)^k) + # + S(y_(m-1)^(k+1)) - S(y_(m-1)^k)) + # + sum(j=1,M) s_mj*(F+S)(y^k) + # for Z2N: + # y_m^(k+1) = y^n + sum(j=1,m) Qdelta_imp[m,j]*(F(y_(m)^(k+1)) - F(y_(m)^k)) + # + sum(j=1,M) Q_delta_exp[m,j]*(S(y_(m-1)^(k+1)) - S(y_(m-1)^k)) + self.solver.solve() + self.Unodes1[self.comm.ensemble_comm.rank+1].assign(self.U_SDC) + + # Evaluate source terms + for evaluate in self.evaluate_source: + evaluate(self.Unodes1[self.comm.ensemble_comm.rank+1], self.base.dt, x_out=self.source_Ukp1[self.comm.ensemble_comm.rank+1]) + + # Apply limiter if required + if self.limiter is not None: + self.limiter.apply(self.Unodes1[self.comm.ensemble_comm.rank+1]) + + self.Unodes[self.comm.ensemble_comm.rank+1].assign(self.Unodes1[self.comm.ensemble_comm.rank+1]) + self.source_Uk[self.comm.ensemble_comm.rank+1].assign(self.source_Ukp1[self.comm.ensemble_comm.rank+1]) + + if self.maxk > 0: + # Compute value at dt rather than final quadrature node tau_M + if self.final_update: + self.Uin.assign(self.Unodes1[self.comm.ensemble_comm.rank+1]) + self.source_in.assign(self.source_Ukp1[self.comm.ensemble_comm.rank+1]) + self.solver_rhs.solve() + self.fUnodes[self.comm.ensemble_comm.rank].assign(self.Urhs) + self.compute_quad_final() + # Compute y_(n+1) = y_n + sum(j=1,M) q_j*F(y_j) + if self.comm.ensemble_comm.rank == self.M-1: + self.U_fin.assign(self.Unodes[-1]) + self.comm.bcast(self.U_fin, self.M -1) + self.solver_fin.solve() + # Apply limiter if required + if self.limiter is not None: + self.limiter.apply(self.U_fin) + x_out.assign(self.U_fin) + else: + # Take value at final quadrature node dtau_M + if self.comm.ensemble_comm.rank == self.M-1: + x_out.assign(self.Unodes[-1]) + self.comm.bcast(x_out, self.M -1) + else: + x_out.assign(self.Unodes[-1]) From 7b6b42e09eb01549085fbd62563dc19719bc16e1 Mon Sep 17 00:00:00 2001 From: atb1995 Date: Mon, 12 May 2025 14:42:00 +0100 Subject: [PATCH 07/43] Small changes to SDC and RIDC --- gusto/time_discretisation/sdc.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/gusto/time_discretisation/sdc.py b/gusto/time_discretisation/sdc.py index b41e1df68..354056879 100644 --- a/gusto/time_discretisation/sdc.py +++ b/gusto/time_discretisation/sdc.py @@ -133,6 +133,7 @@ def __init__(self, base_scheme, domain, M, maxk, quad_type, node_type, qdelta_im # Rescale to be over [0,dt] rather than [0,1] self.nodes = float(self.dt_coarse)*self.nodes + self.dtau = np.diff(np.append(0, self.nodes)) self.Q = float(self.dt_coarse)*self.Q self.Qfin = float(self.dt_coarse)*self.weights @@ -140,13 +141,14 @@ def __init__(self, base_scheme, domain, M, maxk, quad_type, node_type, qdelta_im self.formulation = formulation self.node_type = node_type self.quad_type = quad_type + # breakpoint() + # prin # Get Q_delta matrices self.Qdelta_imp = genQDeltaCoeffs(qdelta_imp, form=formulation, nodes=self.nodes, Q=self.Q, nNodes=M, nodeType=node_type, quadType=quad_type) self.Qdelta_exp = genQDeltaCoeffs(qdelta_exp, form=formulation, nodes=self.nodes, Q=self.Q, nNodes=M, nodeType=node_type, quadType=quad_type) - # Set default linear and nonlinear solver options if none passed in if linear_solver_parameters is None: self.linear_solver_parameters = {'snes_type': 'ksponly', @@ -427,9 +429,6 @@ def apply(self, x_out, x_in): for m in range(self.M+1): for evaluate in self.evaluate_source: evaluate(self.Unodes[m], self.base.dt, x_out=self.source_Uk[m]) - self.Udiff = Function(self.W) - self.Udiff.assign(self.Unodes[0] - self.Unodes[-1]) - #breakpoint() # Iterate through correction sweeps k = 0 From e9687a771b9aed961cdd59d71f05eb51887d7efe Mon Sep 17 00:00:00 2001 From: atb1995 Date: Tue, 13 May 2025 11:15:28 +0100 Subject: [PATCH 08/43] Parallel test added --- gusto/time_discretisation/__init__.py | 3 +- .../deferred_correction.py | 975 +++++++ gusto/time_discretisation/parallel_dc.py | 365 +++ gusto/time_discretisation/sdc.py | 2576 ----------------- .../time_discretisation.py | 1 - integration-tests/conftest.py | 43 +- ...est_sdc.py => test_deferred_correction.py} | 34 +- integration-tests/model/test_parallel_dc.py | 70 + 8 files changed, 1472 insertions(+), 2595 deletions(-) create mode 100644 gusto/time_discretisation/deferred_correction.py create mode 100644 gusto/time_discretisation/parallel_dc.py delete mode 100644 gusto/time_discretisation/sdc.py rename integration-tests/model/{test_sdc.py => test_deferred_correction.py} (68%) create mode 100644 integration-tests/model/test_parallel_dc.py diff --git a/gusto/time_discretisation/__init__.py b/gusto/time_discretisation/__init__.py index 20b428449..e26959ad2 100644 --- a/gusto/time_discretisation/__init__.py +++ b/gusto/time_discretisation/__init__.py @@ -4,4 +4,5 @@ from gusto.time_discretisation.imex_runge_kutta import * # noqa from gusto.time_discretisation.multi_level_schemes import * # noqa from gusto.time_discretisation.wrappers import * # noqa -from gusto.time_discretisation.sdc import * # noqa \ No newline at end of file +from gusto.time_discretisation.deferred_correction import * # noqa +from gusto.time_discretisation.parallel_dc import * # noqa \ No newline at end of file diff --git a/gusto/time_discretisation/deferred_correction.py b/gusto/time_discretisation/deferred_correction.py new file mode 100644 index 000000000..99a5e1968 --- /dev/null +++ b/gusto/time_discretisation/deferred_correction.py @@ -0,0 +1,975 @@ +u""" +Objects for discretising time derivatives using Deferred Correction (DC) +Methods. We have Spectral Deferred Correction (SDC) and Serial Revisionist Integral +Deferred Correction (RIDC) methods. + +SDC and RIDC objects discretise ∂y/∂t = F(y), for variable y, time t and +operator F. + +Written in Picard integral form this equation is +y(t) = y_n + int[t_n,t] F(y(s)) ds + +================================================================================ +SDC Formulation: +================================================================================ + +SDC methods are based on the idea of integrating the function F(y) over the +interval [t_n, t_n+1] using quadrature. We can then evaluate the function +using some quadrature rule, we can evaluate y on a temporal quadrature node as +y_m = y_n + sum[j=1,M] q_mj*F(y_j) +where q_mj can be found by integrating Lagrange polynomials. This is similar to +how Runge-Kutta methods are formed. + +In matrix form this equation is: +(I - dt*Q*F)(y)=y_n + +Computing y by Picard iteration through k we get: +y^(k+1)=y^k + (y_n - (I - dt*Q*F)(y^k)) + +Finally, to get our SDC method we precondition this system, using some approximation +of Q, Q_delta: +(I - dt*Q_delta*F)(y^(k+1)) = y_n + dt*(Q - Q_delta)F(y^k) + +The zero-to-node (Z2N) formulation is then: +y_m^(k+1) = y_n + sum(j=1,M) q'_mj*(F(y_j^(k+1)) - F(y_j^k)) + + sum(j=1,M) q_mj*F(y_(m-1)^k) +for entires q_mj in Q and q'_mj in Q_delta. + +Node-wise from previous quadrature node (N2N formulation), the implicit SDC calculation is: +y_m^(k+1) = y_(m-1)^(k+1) + dtau_m*(F(y_(m)^(k+1)) - F(y_(m)^k)) + + sum(j=1,M) s_mj*F(y_(m-1)^k) +where s_mj = q_mj - q_(m-1)j for entires q_ik in Q. + +Key choices in our SDC method are: +- Choice of quadrature node type (e.g. gauss-lobatto) +- Number of quadrature nodes +- Number of iterations - each iteration increases the order of accuracy up to + the order of the underlying quadrature +- Choice of Q_delta (e.g. Forward Euler, Backward Euler, LU-trick) +- How to get initial solution on quadrature nodes + +================================================================================ +RIDC Formulation: +================================================================================ + +RIDC methods are closely related to SDC methods, but use equidistant nodes and +a slightly different formulation, discretising the error equation in a different way. +The idea is to use a low-order method to get an initial guess of the solution, and then +correct this solution using a high-order method. The correction is done by solving +the error equation, which is derived from the original equation. +SDC can also be thought of in this way. + +The error equation is: + + +""" + +from abc import ABCMeta +import numpy as np +from firedrake import ( + Function, NonlinearVariationalProblem, NonlinearVariationalSolver, Constant +) +from firedrake.fml import ( + replace_subject, all_terms, drop +) +from firedrake.utils import cached_property +from gusto.time_discretisation.time_discretisation import wrapper_apply +from gusto.core.labels import (time_derivative, implicit, explicit, source_label) +from qmat import genQCoeffs, genQDeltaCoeffs + +__all__ = ["SDC", "RIDC"] + +class SDC(object, metaclass=ABCMeta): + """Class for Spectral Deferred Correction schemes.""" + + def __init__(self, base_scheme, domain, M, maxk, quad_type, node_type, qdelta_imp, qdelta_exp, + formulation="N2N", field_name=None, + linear_solver_parameters=None, nonlinear_solver_parameters=None, final_update=True, + limiter=None, options=None, initial_guess="base"): + """ + Initialise SDC object + Args: + base_scheme (:class:`TimeDiscretisation`): Base time stepping scheme to get first guess of solution on + quadrature nodes. + domain (:class:`Domain`): the model's domain object, containing the + mesh and the compatible function spaces. + M (int): Number of quadrature nodes to compute spectral integration over + maxk (int): Max number of correction interations + quad_type (str): Type of quadrature to be used. Options are + GAUSS, RADAU-LEFT, RADAU-RIGHT and LOBATTO + node_type (str): Node type to be used. Options are + EQUID, LEGENDRE, CHEBY-1, CHEBY-2, CHEBY-3 and CHEBY-4 + qdelta_imp (str): Implicit Qdelta matrix to be used. Options are + BE, LU, TRAP, EXACT, PIC, OPT, WEIRD, MIN-SR-NS, MIN-SR-S + qdelta_exp (str): Explicit Qdelta matrix to be used. Options are + FE, EXACT, PIC + formulation (str, optional): Whether to use node-to-node or zero-to-node + formulation. Options are N2N and Z2N. Defaults to N2N + field_name (str, optional): name of the field to be evolved. + Defaults to None. + linear_solver_parameters (dict, optional): dictionary of parameters to + pass to the underlying linear solver. Defaults to None. + nonlinear_solver_parameters (dict, optional): dictionary of parameters to + pass to the underlying nonlinear solver. Defaults to None. + final_update (bool, optional): Whether to compute final update, or just take last + quadrature value. Defaults to True + limiter (:class:`Limiter` object, optional): a limiter to apply to + the evolving field to enforce monotonicity. Defaults to None. + options (:class:`AdvectionOptions`, optional): an object containing + options to either be passed to the spatial discretisation, or + to control the "wrapper" methods, such as Embedded DG or a + recovery method. Defaults to None. + initial_guess (str, optional): Initial guess to be base timestepper, or copy + """ + # Check the configuration options + if (not (formulation == "N2N" or formulation == "Z2N")): + raise ValueError('Formulation not implemented') + + # Initialise parameters + self.base = base_scheme + self.base.dt = domain.dt + self.field_name = field_name + self.domain = domain + self.dt_coarse = domain.dt + self.M = M + self.maxk = maxk + self.final_update = final_update + self.formulation = formulation + self.limiter = limiter + self.augmentation = self.base.augmentation + self.wrapper = self.base.wrapper + + # Get quadrature nodes and weights + self.nodes, self.weights, self.Q = genQCoeffs("Collocation", nNodes=M, + nodeType=node_type, + quadType=quad_type, + form=formulation) + + # Rescale to be over [0,dt] rather than [0,1] + self.nodes = float(self.dt_coarse)*self.nodes + + self.dtau = np.diff(np.append(0, self.nodes)) + self.Q = float(self.dt_coarse)*self.Q + self.Qfin = float(self.dt_coarse)*self.weights + self.qdelta_imp_type = qdelta_imp + self.formulation = formulation + self.node_type = node_type + self.quad_type = quad_type + + # Get Q_delta matrices + self.Qdelta_imp = genQDeltaCoeffs(qdelta_imp, form=formulation, + nodes=self.nodes, Q=self.Q, nNodes=M, nodeType=node_type, quadType=quad_type) + self.Qdelta_exp = genQDeltaCoeffs(qdelta_exp, form=formulation, + nodes=self.nodes, Q=self.Q, nNodes=M, nodeType=node_type, quadType=quad_type) + # Set default linear and nonlinear solver options if none passed in + if linear_solver_parameters is None: + self.linear_solver_parameters = {'snes_type': 'ksponly', + 'ksp_type': 'cg', + 'pc_type': 'bjacobi', + 'sub_pc_type': 'ilu'} + else: + self.linear_solver_parameters = linear_solver_parameters + + if nonlinear_solver_parameters is None: + self.nonlinear_solver_parameters = {'snes_type': 'newtonls', + 'ksp_type': 'gmres', + 'pc_type': 'bjacobi', + 'sub_pc_type': 'ilu'} + else: + self.nonlinear_solver_parameters = nonlinear_solver_parameters + + # Flag to check wheter initial guess is generated using base time discretisation + # (i.e. Forward Euler) + if (initial_guess == "base"): + self.base_flag = True + else: + self.base_flag = False + + def setup(self, equation, apply_bcs=True, *active_labels): + """ + Set up the SDC time discretisation based on the equation.n + + Args: + equation (:class:`PrognosticEquation`): the model's equation. + apply_bcs (bool, optional): whether to apply the equation's boundary + conditions. Defaults to True. + *active_labels (:class:`Label`): labels indicating which terms of + the equation to include. + """ + # Inherit from base time discretisation + self.base.setup(equation, apply_bcs, *active_labels) + self.equation = self.base.equation + self.residual = self.base.residual + self.evaluate_source = self.base.evaluate_source + + for t in self.residual: + # Check all terms are labeled implicit or explicit + if ((not t.has_label(implicit)) and (not t.has_label(explicit)) + and (not t.has_label(time_derivative)) and (not t.has_label(source_label))): + raise NotImplementedError("Non time-derivative or source terms must be labeled as implicit or explicit") + + # Set up bcs + self.bcs = self.base.bcs + + # Set up SDC variables + if self.field_name is not None and hasattr(equation, "field_names"): + self.idx = equation.field_names.index(self.field_name) + W = equation.spaces[self.idx] + else: + self.field_name = equation.field_name + W = equation.function_space + self.idx = None + self.W = W + self.Unodes = [Function(W) for _ in range(self.M+1)] + self.Unodes1 = [Function(W) for _ in range(self.M+1)] + self.fUnodes = [Function(W) for _ in range(self.M)] + self.quad = [Function(W) for _ in range(self.M)] + self.source_Uk = [Function(W) for _ in range(self.M+1)] + self.source_Ukp1 = [Function(W) for _ in range(self.M+1)] + self.U_DC = Function(W) + self.U_start = Function(W) + self.Un = Function(W) + self.Q_ = Function(W) + self.quad_final = Function(W) + self.U_fin = Function(W) + self.Urhs = Function(W) + self.Uin = Function(W) + self.source_in = Function(W) + + @property + def nlevels(self): + return 1 + + def compute_quad(self): + """ + Computes integration of F(y) on quadrature nodes + """ + for j in range(self.M): + self.quad[j].assign(0.) + for k in range(self.M): + self.quad[j] += float(self.Q[j, k])*self.fUnodes[k] + + def compute_quad_final(self): + """ + Computes final integration of F(y) on quadrature nodes + """ + self.quad_final.assign(0.) + for k in range(self.M): + self.quad_final += float(self.Qfin[k])*self.fUnodes[k] + + @property + def res_rhs(self): + """Set up the residual for the calculation of F(y).""" + a = self.residual.label_map(lambda t: t.has_label(time_derivative), + replace_subject(self.Urhs, old_idx=self.idx), + drop) + # F(y) + L = self.residual.label_map(lambda t: any(t.has_label(time_derivative, source_label)), + drop, + replace_subject(self.Uin, old_idx=self.idx)) + L_source = self.residual.label_map(lambda t: t.has_label(source_label), + replace_subject(self.source_in, old_idx=self.idx), + drop) + residual_rhs = a - (L + L_source) + return residual_rhs.form + + @property + def res_fin(self): + """Set up the residual for final solve.""" + # y_(n+1) + a = self.residual.label_map(lambda t: t.has_label(time_derivative), + replace_subject(self.U_fin, old_idx=self.idx), + drop) + # y_n + F_exp = self.residual.label_map(lambda t: t.has_label(time_derivative), + replace_subject(self.Un, old_idx=self.idx), + drop) + F_exp = F_exp.label_map(lambda t: t.has_label(time_derivative), + lambda t: -1*t) + + # sum(j=1,M) q_j*F(y_j) + Q = self.residual.label_map(lambda t: t.has_label(time_derivative), + replace_subject(self.quad_final, old_idx=self.idx), + drop) + + residual_final = a + F_exp + Q + return residual_final.form + + def res(self, m): + """Set up the discretisation's residual for a given node m.""" + # Add time derivative terms y^(k+1)_m - y_start for node m. y_start is y_n for Z2N formulation + # and y^(k)_m for N2N formulation + mass_form = self.residual.label_map( + lambda t: t.has_label(time_derivative), + map_if_false=drop) + residual = mass_form.label_map(all_terms, + map_if_true=replace_subject(self.U_DC, old_idx=self.idx)) + residual -= mass_form.label_map(all_terms, + map_if_true=replace_subject(self.U_start, old_idx=self.idx)) + # Loop through nodes up to m-1 and calcualte + # sum(j=1,m-1) Qdelta_imp[m,j]*(F(y_(m)^(k+1)) - F(y_(m)^k)) + for i in range(m): + r_imp_kp1 = self.residual.label_map( + lambda t: t.has_label(implicit), + map_if_true=replace_subject(self.Unodes1[i+1], old_idx=self.idx), + map_if_false=drop) + r_imp_kp1 = r_imp_kp1.label_map( + all_terms, + lambda t: Constant(self.Qdelta_imp[m, i])*t) + residual += r_imp_kp1 + r_imp_k = self.residual.label_map( + lambda t: t.has_label(implicit), + map_if_true=replace_subject(self.Unodes[i+1], old_idx=self.idx), + map_if_false=drop) + r_imp_k = r_imp_k.label_map( + all_terms, + lambda t: Constant(self.Qdelta_imp[m, i])*t) + residual -= r_imp_k + # Loop through nodes up to m-1 and calcualte + # sum(j=1,M) Q_delta_exp[m,j]*(S(y_(m-1)^(k+1)) - S(y_(m-1)^k)) + for i in range(self.M): + r_exp_kp1 = self.residual.label_map( + lambda t: t.has_label(explicit), + map_if_true=replace_subject(self.Unodes1[i+1], old_idx=self.idx), + map_if_false=drop) + r_exp_kp1 = r_exp_kp1.label_map( + all_terms, + lambda t: Constant(self.Qdelta_exp[m, i])*t) + + residual += r_exp_kp1 + r_exp_k = self.residual.label_map( + lambda t: t.has_label(explicit), + map_if_true=replace_subject(self.Unodes[i+1], old_idx=self.idx), + map_if_false=drop) + r_exp_k = r_exp_k.label_map( + all_terms, + lambda t: Constant(self.Qdelta_exp[m, i])*t) + residual -= r_exp_k + + # Calculate source terms + r_source_kp1 = self.residual.label_map( + lambda t: t.has_label(source_label), + map_if_true=replace_subject(self.source_Ukp1[i+1], old_idx=self.idx), + map_if_false=drop) + r_source_kp1 = r_source_kp1.label_map( + all_terms, + lambda t: Constant(self.Qdelta_exp[m, i])*t) + residual += r_source_kp1 + + r_source_k = self.residual.label_map( + lambda t: t.has_label(source_label), + map_if_true=replace_subject(self.source_Uk[i+1], old_idx=self.idx), + map_if_false=drop) + r_source_k = r_source_k.label_map( + all_terms, + map_if_true=lambda t: Constant(self.Qdelta_exp[m, i])*t) + residual -= r_source_k + + # Add on final implicit terms + # Qdelta_imp[m,m]*(F(y_(m)^(k+1)) - F(y_(m)^k)) + r_imp_kp1 = self.residual.label_map( + lambda t: t.has_label(implicit), + map_if_true=replace_subject(self.U_DC, old_idx=self.idx), + map_if_false=drop) + r_imp_kp1 = r_imp_kp1.label_map( + all_terms, + lambda t: Constant(self.Qdelta_imp[m, m])*t) + residual += r_imp_kp1 + r_imp_k = self.residual.label_map( + lambda t: t.has_label(implicit), + map_if_true=replace_subject(self.Unodes[m+1], old_idx=self.idx), + map_if_false=drop) + r_imp_k = r_imp_k.label_map( + all_terms, + lambda t: Constant(self.Qdelta_imp[m, m])*t) + residual -= r_imp_k + + # Add on error term. sum(j=1,M) q_mj*F(y_m^k) for Z2N formulation + # and sum(j=1,M) s_mj*F(y_m^k) for N2N formulation, where s_mj = q_mj-q_m-1j + # and s1j = q1j. + Q = self.residual.label_map(lambda t: t.has_label(time_derivative), + replace_subject(self.Q_, old_idx=self.idx), + drop) + residual += Q + return residual.form + + @cached_property + def solvers(self): + """Set up a list of solvers for each problem at a node m.""" + solvers = [] + for m in range(self.M): + # setup solver using residual defined in derived class + problem = NonlinearVariationalProblem(self.res(m), self.U_DC, bcs=self.bcs) + solver_name = self.field_name+self.__class__.__name__ + "%s" % (m) + solvers.append(NonlinearVariationalSolver(problem, solver_parameters=self.nonlinear_solver_parameters, options_prefix=solver_name)) + return solvers + + @cached_property + def solver_fin(self): + """Set up the problem and the solver for final update.""" + # setup linear solver using final residual defined in derived class + prob_fin = NonlinearVariationalProblem(self.res_fin, self.U_fin, bcs=self.bcs) + solver_name = self.field_name+self.__class__.__name__+"_final" + return NonlinearVariationalSolver(prob_fin, solver_parameters=self.linear_solver_parameters, + options_prefix=solver_name) + + @cached_property + def solver_rhs(self): + """Set up the problem and the solver for mass matrix inversion.""" + # setup linear solver using rhs residual defined in derived class + prob_rhs = NonlinearVariationalProblem(self.res_rhs, self.Urhs, bcs=self.bcs) + solver_name = self.field_name+self.__class__.__name__+"_rhs" + return NonlinearVariationalSolver(prob_rhs, solver_parameters=self.linear_solver_parameters, + options_prefix=solver_name) + + @wrapper_apply + def apply(self, x_out, x_in): + self.Un.assign(x_in) + self.U_start.assign(self.Un) + solver_list = self.solvers + + # Compute initial guess on quadrature nodes with low-order + # base timestepper + self.Unodes[0].assign(self.Un) + if (self.base_flag): + for m in range(self.M): + self.base.dt = float(self.dtau[m]) + self.base.apply(self.Unodes[m+1], self.Unodes[m]) + else: + for m in range(self.M): + self.Unodes[m+1].assign(self.Un) + for m in range(self.M+1): + for evaluate in self.evaluate_source: + evaluate(self.Unodes[m], self.base.dt, x_out=self.source_Uk[m]) + + # Iterate through correction sweeps + k = 0 + while k < self.maxk: + k += 1 + + if self.qdelta_imp_type == "MIN-SR-FLEX": + # Recompute Implicit Q_delta matrix for each iteration k + self.Qdelta_imp = genQDeltaCoeffs( + self.qdelta_imp_type, + form=self.formulation, + nodes=self.nodes, + Q=self.Q, + nNodes=self.M, + nodeType=self.node_type, + quadType=self.quad_type, + k=k + ) + + # Compute for N2N: sum(j=1,M) (s_mj*F(y_m^k) + s_mj*S(y_m^k)) + # for Z2N: sum(j=1,M) (q_mj*F(y_m^k) + q_mj*S(y_m^k)) + for m in range(1, self.M+1): + self.Uin.assign(self.Unodes[m]) + # Include source terms + for evaluate in self.evaluate_source: + evaluate(self.Uin, self.base.dt, x_out=self.source_in) + self.solver_rhs.solve() + self.fUnodes[m-1].assign(self.Urhs) + self.compute_quad() + + # Loop through quadrature nodes and solve + self.Unodes1[0].assign(self.Unodes[0]) + for evaluate in self.evaluate_source: + evaluate(self.Unodes[0], self.base.dt, x_out=self.source_Uk[0]) + for m in range(1, self.M+1): + # Set Q or S matrix + self.Q_.assign(self.quad[m-1]) + + # Set initial guess for solver, and pick correct solver + if (self.formulation == "N2N"): + self.U_start.assign(self.Unodes1[m-1]) + self.solver = solver_list[m-1] + self.U_DC.assign(self.Unodes[m]) + + # Compute + # for N2N: + # y_m^(k+1) = y_(m-1)^(k+1) + dtau_m*(F(y_(m)^(k+1)) - F(y_(m)^k) + # + S(y_(m-1)^(k+1)) - S(y_(m-1)^k)) + # + sum(j=1,M) s_mj*(F+S)(y^k) + # for Z2N: + # y_m^(k+1) = y^n + sum(j=1,m) Qdelta_imp[m,j]*(F(y_(m)^(k+1)) - F(y_(m)^k)) + # + sum(j=1,M) Q_delta_exp[m,j]*(S(y_(m-1)^(k+1)) - S(y_(m-1)^k)) + self.solver.solve() + self.Unodes1[m].assign(self.U_DC) + + # Evaluate source terms + for evaluate in self.evaluate_source: + evaluate(self.Unodes1[m], self.base.dt, x_out=self.source_Ukp1[m]) + + # Apply limiter if required + if self.limiter is not None: + self.limiter.apply(self.Unodes1[m]) + for m in range(1, self.M+1): + self.Unodes[m].assign(self.Unodes1[m]) + self.source_Uk[m].assign(self.source_Ukp1[m]) + + if self.maxk > 0: + # Compute value at dt rather than final quadrature node tau_M + if self.final_update: + for m in range(1, self.M+1): + self.Uin.assign(self.Unodes1[m]) + self.source_in.assign(self.source_Ukp1[m]) + self.solver_rhs.solve() + self.fUnodes[m-1].assign(self.Urhs) + self.compute_quad_final() + # Compute y_(n+1) = y_n + sum(j=1,M) q_j*F(y_j) + + self.U_fin.assign(self.Unodes[-1]) + self.solver_fin.solve() + # Apply limiter if required + if self.limiter is not None: + self.limiter.apply(self.U_fin) + x_out.assign(self.U_fin) + else: + # Take value at final quadrature node dtau_M + x_out.assign(self.Unodes[-1]) + else: + x_out.assign(self.Unodes[-1]) + +class RIDC(object, metaclass=ABCMeta): + """Class for Revisionist Integral Deferred Correction schemes.""" + + def __init__(self, base_scheme, domain, M, K, field_name=None, + linear_solver_parameters=None, nonlinear_solver_parameters=None, + limiter=None, options=None, reduced=True): + """ + Initialise RIDC object + Args: + base_scheme (:class:`TimeDiscretisation`): Base time stepping scheme to get first guess of solution on + quadrature nodes. + domain (:class:`Domain`): the model's domain object, containing the + mesh and the compatible function spaces. + M (int): Number of subintervals + K (int): Max number of correction interations + field_name (str, optional): name of the field to be evolved. + Defaults to None. + linear_solver_parameters (dict, optional): dictionary of parameters to + pass to the underlying linear solver. Defaults to None. + nonlinear_solver_parameters (dict, optional): dictionary of parameters to + pass to the underlying nonlinear solver. Defaults to None. + limiter (:class:`Limiter` object, optional): a limiter to apply to + the evolving field to enforce monotonicity. Defaults to None. + options (:class:`AdvectionOptions`, optional): an object containing + reduced (bool, optional): whether to use reduced or full stencils for RIDC. + """ + self.base = base_scheme + self.field_name = field_name + self.domain = domain + self.dt_coarse = domain.dt + self.limiter = limiter + self.augmentation = self.base.augmentation + self.wrapper = self.base.wrapper + self.K = K + self.M = M + self.reduced = reduced + self.dt = Constant(float(self.dt_coarse)/(self.M)) + + # Use equidistant nodes + self.nodes = np.arange(0, (M+1)*float(self.dt), float(self.dt)) + + if reduced: + self.Q = [] + for l in range(1, self.K+1): + integration_matrix = self.lagrange_integration_matrix(l+1) + integration_matrix = 0.5 * (l) * float(self.dt) * integration_matrix + self.Q.append(integration_matrix) + else: + # Get integration weights + integration_matrix = self.lagrange_integration_matrix(self.K+1) + # Rescale integration matrix to be over [0, self.dt_coarse] rather than [-1, 1] + self.Q = 0.5 * (self.K) * float(self.dt) * integration_matrix + + # Set default linear and nonlinear solver options if none passed in + if linear_solver_parameters is None: + self.linear_solver_parameters = {'snes_type': 'ksponly', + 'ksp_type': 'cg', + 'pc_type': 'bjacobi', + 'sub_pc_type': 'ilu'} + else: + self.linear_solver_parameters = linear_solver_parameters + + if nonlinear_solver_parameters is None: + self.nonlinear_solver_parameters = {'snes_type': 'newtonls', + 'ksp_type': 'gmres', + 'pc_type': 'bjacobi', + 'sub_pc_type': 'ilu'} + else: + self.nonlinear_solver_parameters = nonlinear_solver_parameters + + + def setup(self, equation, apply_bcs=True, *active_labels): + """ + Set up the RIDC time discretisation based on the equation. + + Args: + equation (:class:`PrognosticEquation`): the model's equation. + apply_bcs (bool, optional): whether to apply the equation's boundary + conditions. Defaults to True. + *active_labels (:class:`Label`): labels indicating which terms of + the equation to include. + """ + # Inherit from base time discretisation + self.base.setup(equation, apply_bcs, *active_labels) + self.equation = self.base.equation + self.residual = self.base.residual + self.evaluate_source = self.base.evaluate_source + + for t in self.residual: + # Check all terms are labeled implicit or explicit + if ((not t.has_label(implicit)) and (not t.has_label(explicit)) + and (not t.has_label(time_derivative)) and (not t.has_label(source_label))): + raise NotImplementedError("Non time-derivative or source terms must be labeled as implicit or explicit") + + # Set up bcs + self.bcs = self.base.bcs + + # Set up RIDC variables + if self.field_name is not None and hasattr(equation, "field_names"): + self.idx = equation.field_names.index(self.field_name) + W = equation.spaces[self.idx] + else: + self.field_name = equation.field_name + W = equation.function_space + self.idx = None + self.W = W + self.Unodes = [Function(W) for _ in range(self.M+1)] + self.Unodes1 = [Function(W) for _ in range(self.M+1)] + self.fUnodes = [Function(W) for _ in range(self.M+1)] + self.quad = [Function(W) for _ in range(self.M+1)] + self.source_Uk = [Function(W) for _ in range(self.M+1)] + self.source_Ukp1 = [Function(W) for _ in range(self.M+1)] + self.U_DC = Function(W) + self.U_start = Function(W) + self.Un = Function(W) + self.Q_ = Function(W) + self.quad_final = Function(W) + self.U_fin = Function(W) + self.Urhs = Function(W) + self.Uin = Function(W) + self.source_in = Function(W) + self.source_Ukp1_m = Function(W) + self.source_Uk_m = Function(W) + self.Uk_mp1 = Function(W) + self.Uk_m = Function(W) + self.Ukp1_m = Function(W) + + @property + def nlevels(self): + return 1 + + def equidistant_nodes(self ,M): + """ + Returns a grid of M equispaced nodes from -1 to 1 + """ + grid = np.linspace(-1., 1., M) + return grid + + def lagrange_polynomial(self, index, nodes): + """ + Returns the coefficients of the Lagrange polynomial l_m with m=index + """ + + M = len(nodes) + + # c is the denominator + c = 1. + for k in range(M): + if k != index: + c *= (nodes[index] - nodes[k]) + + coeffs = np.zeros(M) + coeffs[0] = 1. + m = 0 + + for k in range(M): + if k != index: + m += 1 + d1 = np.zeros(M) + d2 = np.zeros(M) + + d1 = (-1.)*nodes[k] * coeffs + d2[1:m+1] = coeffs[0:m] + + coeffs = d1+d2 + return coeffs / c + + def integrate_polynomial(self, p): + """ + Given a list of coefficients of a polynomial p, + this returns those of the integral of p + """ + integral_coeffs = np.zeros(len(p)+1) + + for n, pn in enumerate(p): + integral_coeffs[n+1] = 1/(n+1) * pn + + return integral_coeffs + + def evaluate(self, p, a, b): + """ + Given a list of coefficients of a polynomial p, this returns the value of p(b)-p(a) + """ + value = 0. + for n, pn in enumerate(p): + value += pn * (b**n - a**n) + + return value + + def lagrange_integration_matrix(self, M): + """ + Returns the integration matrix for the Lagrange polynomial of order M + """ + + # Set up equidistant nodes and initialise matrix to zero + nodes = self.equidistant_nodes(M) + L = len(nodes) + int_matrix = np.zeros((L, L)) + + # Fill in matrix values + for index in range(L): + coeff_p = self.lagrange_polynomial(index, nodes) + int_coeff = self.integrate_polynomial(coeff_p) + + for n in range(L-1): + int_matrix[n+1, index] = self.evaluate(int_coeff, nodes[n], nodes[n+1]) + + return int_matrix + + def compute_quad(self, Q, fUnodes, m): + """ + Computes integration of F(y) on quadrature nodes + """ + quad = Function(self.W) + quad.assign(0.) + for k in range(0, np.shape(Q)[1]): + quad += float(Q[m, k])*fUnodes[k] + return quad + + def compute_quad_final(self, Q, fUnodes, m): + """ + Computes final integration of F(y) on quadrature nodes + """ + quad = Function(self.W) + quad.assign(0.) + if self.reduced: + l = np.shape(Q)[0] - 1 + else: + l = self.K + for k in range(0, l+1): + quad += float(Q[-1, k])*fUnodes[m - l + k] + return quad + + @property + def res_rhs(self): + """Set up the residual for the calculation of F(y).""" + a = self.residual.label_map(lambda t: t.has_label(time_derivative), + replace_subject(self.Urhs, old_idx=self.idx), + drop) + # F(y) + L = self.residual.label_map(lambda t: any(t.has_label(time_derivative, source_label)), + drop, + replace_subject(self.Uin, old_idx=self.idx)) + L_source = self.residual.label_map(lambda t: t.has_label(source_label), + replace_subject(self.source_in, old_idx=self.idx), + drop) + residual_rhs = a - (L + L_source) + return residual_rhs.form + + @property + def res(self): + """Set up the discretisation's residual.""" + # Add time derivative terms y^(k+1)_m - y_n + mass_form = self.residual.label_map( + lambda t: t.has_label(time_derivative), + map_if_false=drop) + residual = mass_form.label_map(all_terms, + map_if_true=replace_subject(self.U_DC, old_idx=self.idx)) + residual -= mass_form.label_map(all_terms, + map_if_true=replace_subject(self.U_start, old_idx=self.idx)) + + # Calculate source terms + r_source_kp1 = self.residual.label_map( + lambda t: t.has_label(source_label), + map_if_true=replace_subject(self.source_Ukp1_m, old_idx=self.idx), + map_if_false=drop) + r_source_kp1 = r_source_kp1.label_map( + all_terms, + lambda t: Constant(self.dt)*t) + residual += r_source_kp1 + + r_source_k = self.residual.label_map( + lambda t: t.has_label(source_label), + map_if_true=replace_subject(self.source_Uk_m, old_idx=self.idx), + map_if_false=drop) + r_source_k = r_source_k.label_map( + all_terms, + map_if_true=lambda t: Constant(self.dt)*t) + residual -= r_source_k + + # Add on final implicit terms + # dt*(F(y_(m)^(k+1)) - F(y_(m)^k)) + r_imp_kp1 = self.residual.label_map( + lambda t: t.has_label(implicit), + map_if_true=replace_subject(self.U_DC, old_idx=self.idx), + map_if_false=drop) + r_imp_kp1 = r_imp_kp1.label_map( + all_terms, + lambda t: Constant(self.dt)*t) + residual += r_imp_kp1 + r_imp_k = self.residual.label_map( + lambda t: t.has_label(implicit), + map_if_true=replace_subject(self.Uk_mp1, old_idx=self.idx), + map_if_false=drop) + r_imp_k = r_imp_k.label_map( + all_terms, + lambda t: Constant(self.dt)*t) + residual -= r_imp_k + + r_exp_kp1 = self.residual.label_map( + lambda t: t.has_label(explicit), + map_if_true=replace_subject(self.Ukp1_m, old_idx=self.idx), + map_if_false=drop) + r_exp_kp1 = r_exp_kp1.label_map( + all_terms, + lambda t: Constant(self.dt)*t) + residual += r_exp_kp1 + r_exp_k = self.residual.label_map( + lambda t: t.has_label(explicit), + map_if_true=replace_subject(self.Uk_m, old_idx=self.idx), + map_if_false=drop) + r_exp_k = r_exp_k.label_map( + all_terms, + lambda t: Constant(self.dt)*t) + residual -= r_exp_k + + + # Add on sum(j=1,M) s_mj*F(y_m^k), where s_mj = q_mj-q_m-1j + # and s1j = q1j. + Q = self.residual.label_map(lambda t: t.has_label(time_derivative), + replace_subject(self.Q_, old_idx=self.idx), + drop) + residual += Q + return residual.form + + @cached_property + def solver(self): + """Set up the problem and the solver for the nonlinear solve.""" + # setup solver using residual defined in derived class + problem = NonlinearVariationalProblem(self.res, self.U_DC, bcs=self.bcs) + solver_name = self.field_name+self.__class__.__name__ + solver = NonlinearVariationalSolver(problem, solver_parameters=self.nonlinear_solver_parameters, options_prefix=solver_name) + return solver + + @cached_property + def solver_rhs(self): + """Set up the problem and the solver for mass matrix inversion.""" + # setup linear solver using rhs residual defined in derived class + prob_rhs = NonlinearVariationalProblem(self.res_rhs, self.Urhs, bcs=self.bcs) + solver_name = self.field_name+self.__class__.__name__+"_rhs" + return NonlinearVariationalSolver(prob_rhs, solver_parameters=self.linear_solver_parameters, + options_prefix=solver_name) + + @wrapper_apply + def apply(self, x_out, x_in): + self.Un.assign(x_in) + + # Compute initial guess on quadrature nodes with low-order + # base timestepper + self.Unodes[0].assign(self.Un) + self.M1 = self.K + + for m in range(self.M): + self.base.dt = float(self.dt) + self.base.apply(self.Unodes[m+1], self.Unodes[m]) + + for m in range(self.M+1): + for evaluate in self.evaluate_source: + evaluate(self.Unodes[m], self.base.dt, x_out=self.source_Uk[m]) + + # Iterate through correction sweeps + for k in range(1, self.K+1): + # Compute: sum(j=1,M) (s_mj*F(y_m^k) + s_mj*S(y_m^k)) + for m in range(self.M+1): + self.Uin.assign(self.Unodes[m]) + # Include source terms + for evaluate in self.evaluate_source: + evaluate(self.Uin, self.base.dt, x_out=self.source_in) + self.solver_rhs.solve() + self.fUnodes[m].assign(self.Urhs) + + # Loop through quadrature nodes and solve + self.Unodes1[0].assign(self.Unodes[0]) + for evaluate in self.evaluate_source: + evaluate(self.Unodes[0], self.base.dt, x_out=self.source_Uk[0]) + if self.reduced: + self.M1 = k + for m in range(0, self.M1): + # Set integration matrix + if self.reduced: + self.Q_.assign(self.compute_quad(self.Q[k-1], self.fUnodes, m+1)) + else: + self.Q_.assign(self.compute_quad(self.Q, self.fUnodes, m+1)) + + # Set initial guess for solver, and pick correct solver + self.U_start.assign(self.Unodes1[m]) + self.Ukp1_m.assign(self.Unodes1[m]) + self.Uk_mp1.assign(self.Unodes[m+1]) + self.Uk_m.assign(self.Unodes[m]) + self.source_Ukp1_m.assign(self.source_Ukp1[m]) + self.source_Uk_m.assign(self.source_Uk[m]) + self.U_DC.assign(self.Unodes[m+1]) + + # Compute: + # y_m^(k+1) = y_(m-1)^(k+1) + dt*(F(y_(m)^(k+1)) - F(y_(m)^k) + # + S(y_(m-1)^(k+1)) - S(y_(m-1)^k)) + # + sum(j=1,M) s_mj*(F+S)(y^k) + self.solver.solve() + self.Unodes1[m+1].assign(self.U_DC) + + # Evaluate source terms + for evaluate in self.evaluate_source: + evaluate(self.Unodes1[m+1], self.base.dt, x_out=self.source_Ukp1[m+1]) + + # Apply limiter if required + if self.limiter is not None: + self.limiter.apply(self.Unodes1[m+1]) + for m in range(self.M1, self.M): + # Set integration matrix + if self.reduced: + self.Q_.assign(self.compute_quad_final(self.Q[k-1], self.fUnodes, m+1)) + else: + self.Q_.assign(self.compute_quad_final(self.Q, self.fUnodes, m+1)) + + # Set initial guess for solver, and pick correct solver + self.U_start.assign(self.Unodes1[m]) + self.Ukp1_m.assign(self.Unodes1[m]) + self.Uk_mp1.assign(self.Unodes[m+1]) + self.Uk_m.assign(self.Unodes[m]) + self.source_Ukp1_m.assign(self.source_Ukp1[m]) + self.source_Uk_m.assign(self.source_Uk[m]) + self.U_DC.assign(self.Unodes[m+1]) + + # Compute: + # y_m^(k+1) = y_(m-1)^(k+1) + dt*(F(y_(m)^(k+1)) - F(y_(m)^k) + # + S(y_(m-1)^(k+1)) - S(y_(m-1)^k)) + # + sum(j=1,M) s_mj*(F+S)(y^k) + self.solver.solve() + self.Unodes1[m+1].assign(self.U_DC) + + # Evaluate source terms + for evaluate in self.evaluate_source: + evaluate(self.Unodes1[m+1], self.base.dt, x_out=self.source_Ukp1[m+1]) + + # Apply limiter if required + if self.limiter is not None: + self.limiter.apply(self.Unodes1[m+1]) + + for m in range(self.M+1): + self.Unodes[m].assign(self.Unodes1[m]) + self.source_Uk[m].assign(self.source_Ukp1[m]) + + x_out.assign(self.Unodes[-1]) diff --git a/gusto/time_discretisation/parallel_dc.py b/gusto/time_discretisation/parallel_dc.py new file mode 100644 index 000000000..f27844010 --- /dev/null +++ b/gusto/time_discretisation/parallel_dc.py @@ -0,0 +1,365 @@ +u""" +Objects for discretising time derivatives using time-parallel Deferred Correction +Methods. + +This module inherits from the serial SDC and RIDC classes, and implements the +parallelisation of the SDC and RIDC methods using MPI. + +SDC parallelises across the quadrature nodes by using diagonal QDelta matrices, +while RIDC parallelises across the correction iterations by using a reduced stencil +and pipelining. +""" + +from firedrake import ( + Function +) +from gusto.time_discretisation.time_discretisation import wrapper_apply +from qmat import genQDeltaCoeffs +from gusto.time_discretisation.deferred_correction import SDC, RIDC + +__all__ = ["Parallel_RIDC", "Parallel_SDC"] + +class Parallel_RIDC(RIDC): + """Class for Parallel Revisionist Integral Deferred Correction schemes.""" + + def __init__(self, base_scheme, domain, M, K, field_name=None, + linear_solver_parameters=None, nonlinear_solver_parameters=None, + limiter=None, options=None, communicator=None): + """ + Initialise RIDC object + Args: + base_scheme (:class:`TimeDiscretisation`): Base time stepping scheme to get first guess of solution on + quadrature nodes. + domain (:class:`Domain`): the model's domain object, containing the + mesh and the compatible function spaces. + M (int): Number of subintervals + K (int): Max number of correction interations + field_name (str, optional): name of the field to be evolved. + Defaults to None. + linear_solver_parameters (dict, optional): dictionary of parameters to + pass to the underlying linear solver. Defaults to None. + nonlinear_solver_parameters (dict, optional): dictionary of parameters to + pass to the underlying nonlinear solver. Defaults to None. + limiter (:class:`Limiter` object, optional): a limiter to apply to + the evolving field to enforce monotonicity. Defaults to None. + options (:class:`AdvectionOptions`, optional): an object containing + communicator (MPI communicator, optional): communicator for parallel execution. Defaults to None. + """ + + super(Parallel_RIDC, self).__init__(base_scheme, domain, M, K, field_name, + linear_solver_parameters, nonlinear_solver_parameters, + limiter, options, reduced=True) + self.comm = communicator + + + def setup(self, equation, apply_bcs=True, *active_labels): + """ + Set up the SDC time discretisation based on the equation.n + + Args: + equation (:class:`PrognosticEquation`): the model's equation. + apply_bcs (bool, optional): whether to apply the equation's boundary + conditions. Defaults to True. + *active_labels (:class:`Label`): labels indicating which terms of + the equation to include. + """ + super(Parallel_RIDC, self).setup(equation, apply_bcs, *active_labels) + + self.Uk_mp1 = Function(self.W) + self.Uk_m = Function(self.W) + self.Ukp1_m = Function(self.W) + + @wrapper_apply + def apply(self, x_out, x_in): + + # Time parallelised code + + # Set up varibles on this rank + x_out.assign(x_in) + self.kval = self.comm.ensemble_comm.rank + self.Un.assign(x_in) + self.Unodes[0].assign(self.Un) + # Loop through quadrature nodes and solve + self.Unodes1[0].assign(self.Unodes[0]) + for evaluate in self.evaluate_source: + evaluate(self.Unodes[0], self.base.dt, x_out=self.source_Uk[0]) + self.Uin.assign(self.Unodes[0]) + self.solver_rhs.solve() + self.fUnodes[0].assign(self.Urhs) + + # On first communicator, we do the base timestepper + if (self.comm.ensemble_comm.rank == 0): + # Base timestepper + for m in range(self.M): + self.base.dt = float(self.dt) + self.base.apply(self.Unodes[m+1], self.Unodes[m]) + for evaluate in self.evaluate_source: + evaluate(self.Unodes[m+1], self.base.dt, x_out=self.source_Uk[m+1]) + + # Send base guess to k+1 correction + self.comm.send(self.Unodes[m+1], dest=self.kval+1, tag=100+m+1) + else: + for m in range(1, self.kval + 1): + # Recieve and evaluate the stencil of guesses we need to correct + self.comm.recv(self.Unodes[m], source=self.kval-1, tag=100+m) + self.Uin.assign(self.Unodes[m]) + for evaluate in self.evaluate_source: + evaluate(self.Uin, self.base.dt, x_out=self.source_in) + self.solver_rhs.solve() + self.fUnodes[m].assign(self.Urhs) + for m in range(0, self.kval): + # Set S matrix + self.Q_.assign(self.compute_quad(self.Q[self.kval-1], self.fUnodes, m+1)) + + # Set initial guess for solver, and pick correct solver + self.U_start.assign(self.Unodes1[m]) + self.Ukp1_m.assign(self.Unodes1[m]) + self.Uk_mp1.assign(self.Unodes[m+1]) + self.Uk_m.assign(self.Unodes[m]) + self.source_Ukp1_m.assign(self.source_Ukp1[m]) + self.source_Uk_m.assign(self.source_Uk[m]) + self.U_DC.assign(self.Unodes[m+1]) + + # Compute + # y_m^(k+1) = y_(m-1)^(k+1) + dt*(F(y_(m)^(k+1)) - F(y_(m)^k) + # + S(y_(m-1)^(k+1)) - S(y_(m-1)^k)) + # + sum(j=1,M) s_mj*(F+S)(y^k) + self.solver.solve() + self.Unodes1[m+1].assign(self.U_DC) + + # Evaluate source terms + for evaluate in self.evaluate_source: + evaluate(self.Unodes1[m+1], self.base.dt, x_out=self.source_Ukp1[m+1]) + + # Apply limiter if required + if self.limiter is not None: + self.limiter.apply(self.Unodes1[m+1]) + # Send our updated value to next communicator + if self.kval < self.K: + self.comm.send(self.Unodes1[m+1], dest=self.kval+1, tag=100+m+1) + + for m in range(self.kval, self.M): + # Recieve the guess we need to correct and evaluate the rhs + self.comm.recv(self.Unodes[m+1], source=self.kval-1, tag=100+m+1) + self.Uin.assign(self.Unodes[m+1]) + for evaluate in self.evaluate_source: + evaluate(self.Uin, self.base.dt, x_out=self.source_in) + self.solver_rhs.solve() + self.fUnodes[m+1].assign(self.Urhs) + + # Set S matrix + self.Q_.assign(self.compute_quad_final(self.Q[self.kval-1], self.fUnodes, m+1)) + + # Set initial guess for solver, and pick correct solver + self.U_start.assign(self.Unodes1[m]) + self.Ukp1_m.assign(self.Unodes1[m]) + self.Uk_mp1.assign(self.Unodes[m+1]) + self.Uk_m.assign(self.Unodes[m]) + self.source_Ukp1_m.assign(self.source_Ukp1[m]) + self.source_Uk_m.assign(self.source_Uk[m]) + self.U_DC.assign(self.Unodes[m+1]) + + # y_m^(k+1) = y_(m-1)^(k+1) + dt*(F(y_(m)^(k+1)) - F(y_(m)^k) + # + S(y_(m-1)^(k+1)) - S(y_(m-1)^k)) + # + sum(j=1,M) s_mj*(F+S)(y^k) + self.solver.solve() + self.Unodes1[m+1].assign(self.U_DC) + + # Evaluate source terms + for evaluate in self.evaluate_source: + evaluate(self.Unodes1[m+1], self.base.dt, x_out=self.source_Ukp1[m+1]) + + # Apply limiter if required + if self.limiter is not None: + self.limiter.apply(self.Unodes1[m+1]) + + # Send our updated value to next communicator + if self.kval < self.K: + self.comm.send(self.Unodes1[m+1], dest=self.kval+1, tag=100+m+1) + + if (self.kval == self.K): + # Broadcast the final result to all other ranks + x_out.assign(self.Unodes1[-1]) + for i in range(self.K): + # Send the final result to all other ranks + self.comm.send(x_out, dest=i, tag=200) + else: + # Receive the final result from rank K + self.comm.recv(x_out, source=self.K, tag=200) + + +class Parallel_SDC(SDC): + """Class for Spectral Deferred Correction schemes.""" + + def __init__(self, base_scheme, domain, M, maxk, quad_type, node_type, qdelta_imp, qdelta_exp, + field_name=None, + linear_solver_parameters=None, nonlinear_solver_parameters=None, final_update=True, + limiter=None, options=None, initial_guess="base", communicator=None): + """ + Initialise SDC object + Args: + base_scheme (:class:`TimeDiscretisation`): Base time stepping scheme to get first guess of solution on + quadrature nodes. + domain (:class:`Domain`): the model's domain object, containing the + mesh and the compatible function spaces. + M (int): Number of quadrature nodes to compute spectral integration over + maxk (int): Max number of correction interations + quad_type (str): Type of quadrature to be used. Options are + GAUSS, RADAU-LEFT, RADAU-RIGHT and LOBATTO + node_type (str): Node type to be used. Options are + EQUID, LEGENDRE, CHEBY-1, CHEBY-2, CHEBY-3 and CHEBY-4 + qdelta_imp (str): Implicit Qdelta matrix to be used. Options are + BE, LU, TRAP, EXACT, PIC, OPT, WEIRD, MIN-SR-NS, MIN-SR-S + qdelta_exp (str): Explicit Qdelta matrix to be used. Options are + FE, EXACT, PIC + formulation (str, optional): Whether to use node-to-node or zero-to-node + formulation. Options are N2N and Z2N. Defaults to N2N + field_name (str, optional): name of the field to be evolved. + Defaults to None. + linear_solver_parameters (dict, optional): dictionary of parameters to + pass to the underlying linear solver. Defaults to None. + nonlinear_solver_parameters (dict, optional): dictionary of parameters to + pass to the underlying nonlinear solver. Defaults to None. + final_update (bool, optional): Whether to compute final update, or just take last + quadrature value. Defaults to True + limiter (:class:`Limiter` object, optional): a limiter to apply to + the evolving field to enforce monotonicity. Defaults to None. + options (:class:`AdvectionOptions`, optional): an object containing + options to either be passed to the spatial discretisation, or + to control the "wrapper" methods, such as Embedded DG or a + recovery method. Defaults to None. + initial_guess (str, optional): Initial guess to be base timestepper, or copy + communicator (MPI communicator, optional): communicator for parallel execution. Defaults to None. + """ + super().__init__(base_scheme, domain, M, maxk, quad_type, node_type, qdelta_imp, qdelta_exp, + formulation = "Z2N", field_name=field_name, + linear_solver_parameters=linear_solver_parameters, nonlinear_solver_parameters= nonlinear_solver_parameters, + final_update=final_update, + limiter=limiter, options=options, initial_guess=initial_guess) + self.comm = communicator + + def compute_quad(self): + """ + Computes integration of F(y) on quadrature nodes + """ + x = Function(self.W) + for j in range(self.M): + x.assign(float(self.Q[j, self.comm.ensemble_comm.rank])*self.fUnodes[self.comm.ensemble_comm.rank]) + self.comm.reduce(x, self.quad[j], root=j) + + def compute_quad_final(self): + """ + Computes final integration of F(y) on quadrature nodes + """ + x = Function(self.W) + x.assign(float(self.Qfin[self.comm.ensemble_comm.rank])*self.fUnodes[self.comm.ensemble_comm.rank]) + self.comm.allreduce(x, self.quad_final) + + @wrapper_apply + def apply(self, x_out, x_in): + self.Un.assign(x_in) + self.U_start.assign(self.Un) + solver_list = self.solvers + + # Compute initial guess on quadrature nodes with low-order + # base timestepper + self.Unodes[0].assign(self.Un) + if (self.base_flag): + for m in range(self.M): + self.base.dt = float(self.dtau[m]) + self.base.apply(self.Unodes[m+1], self.Unodes[m]) + else: + for m in range(self.M): + self.Unodes[m+1].assign(self.Un) + for m in range(self.M+1): + for evaluate in self.evaluate_source: + evaluate(self.Unodes[m], self.base.dt, x_out=self.source_Uk[m]) + + # Iterate through correction sweeps + k = 0 + while k < self.maxk: + k += 1 + + if self.qdelta_imp_type == "MIN-SR-FLEX": + # Recompute Implicit Q_delta matrix for each iteration k + self.Qdelta_imp = genQDeltaCoeffs( + self.qdelta_imp_type, + form=self.formulation, + nodes=self.nodes, + Q=self.Q, + nNodes=self.M, + nodeType=self.node_type, + quadType=self.quad_type, + k=k + ) + + # Compute for N2N: sum(j=1,M) (s_mj*F(y_m^k) + s_mj*S(y_m^k)) + # for Z2N: sum(j=1,M) (q_mj*F(y_m^k) + q_mj*S(y_m^k)) + self.Uin.assign(self.Unodes[self.comm.ensemble_comm.rank+1]) + # Include source terms + for evaluate in self.evaluate_source: + evaluate(self.Uin, self.base.dt, x_out=self.source_in) + self.solver_rhs.solve() + self.fUnodes[self.comm.ensemble_comm.rank].assign(self.Urhs) + + self.compute_quad() + + # Loop through quadrature nodes and solve + self.Unodes1[0].assign(self.Unodes[0]) + for evaluate in self.evaluate_source: + evaluate(self.Unodes[0], self.base.dt, x_out=self.source_Uk[0]) + + + # Set Q or S matrix + self.Q_.assign(self.quad[self.comm.ensemble_comm.rank]) + + # Set initial guess for solver, and pick correct solver + self.solver = solver_list[self.comm.ensemble_comm.rank] + self.U_DC.assign(self.Unodes[self.comm.ensemble_comm.rank+1]) + + # Compute + # for N2N: + # y_m^(k+1) = y_(m-1)^(k+1) + dtau_m*(F(y_(m)^(k+1)) - F(y_(m)^k) + # + S(y_(m-1)^(k+1)) - S(y_(m-1)^k)) + # + sum(j=1,M) s_mj*(F+S)(y^k) + # for Z2N: + # y_m^(k+1) = y^n + sum(j=1,m) Qdelta_imp[m,j]*(F(y_(m)^(k+1)) - F(y_(m)^k)) + # + sum(j=1,M) Q_delta_exp[m,j]*(S(y_(m-1)^(k+1)) - S(y_(m-1)^k)) + self.solver.solve() + self.Unodes1[self.comm.ensemble_comm.rank+1].assign(self.U_DC) + + # Evaluate source terms + for evaluate in self.evaluate_source: + evaluate(self.Unodes1[self.comm.ensemble_comm.rank+1], self.base.dt, x_out=self.source_Ukp1[self.comm.ensemble_comm.rank+1]) + + # Apply limiter if required + if self.limiter is not None: + self.limiter.apply(self.Unodes1[self.comm.ensemble_comm.rank+1]) + + self.Unodes[self.comm.ensemble_comm.rank+1].assign(self.Unodes1[self.comm.ensemble_comm.rank+1]) + self.source_Uk[self.comm.ensemble_comm.rank+1].assign(self.source_Ukp1[self.comm.ensemble_comm.rank+1]) + + if self.maxk > 0: + # Compute value at dt rather than final quadrature node tau_M + if self.final_update: + self.Uin.assign(self.Unodes1[self.comm.ensemble_comm.rank+1]) + self.source_in.assign(self.source_Ukp1[self.comm.ensemble_comm.rank+1]) + self.solver_rhs.solve() + self.fUnodes[self.comm.ensemble_comm.rank].assign(self.Urhs) + self.compute_quad_final() + # Compute y_(n+1) = y_n + sum(j=1,M) q_j*F(y_j) + if self.comm.ensemble_comm.rank == self.M-1: + self.U_fin.assign(self.Unodes[-1]) + self.comm.bcast(self.U_fin, self.M -1) + self.solver_fin.solve() + # Apply limiter if required + if self.limiter is not None: + self.limiter.apply(self.U_fin) + x_out.assign(self.U_fin) + else: + # Take value at final quadrature node dtau_M + if self.comm.ensemble_comm.rank == self.M-1: + x_out.assign(self.Unodes[-1]) + self.comm.bcast(x_out, self.M -1) + else: + x_out.assign(self.Unodes[-1]) diff --git a/gusto/time_discretisation/sdc.py b/gusto/time_discretisation/sdc.py deleted file mode 100644 index 354056879..000000000 --- a/gusto/time_discretisation/sdc.py +++ /dev/null @@ -1,2576 +0,0 @@ -u""" -Objects for discretising time derivatives using Spectral Deferred Correction -Methods. - -SDC objects discretise ∂y/∂t = F(y), for variable y, time t and -operator F. - -Written in Picard integral form this equation is -y(t) = y_n + int[t_n,t] F(y(s)) ds - -Using some quadrature rule, we can evaluate y on a temporal quadrature node as -y_m = y_n + sum[j=1,M] q_mj*F(y_j) -where q_mj can be found by integrating Lagrange polynomials. This is similar to -how Runge-Kutta methods are formed. - -In matrix form this equation is: -(I - dt*Q*F)(y)=y_n - -Computing y by Picard iteration through k we get: -y^(k+1)=y^k + (y_n - (I - dt*Q*F)(y^k)) - -Finally, to get our SDC method we precondition this system, using some approximation -of Q Q_delta: -(I - dt*Q_delta*F)(y^(k+1)) = y_n + dt*(Q - Q_delta)F(y^k) - -The zero-to-node (Z2N) formulation is then: -y_m^(k+1) = y_n + sum(j=1,M) q'_mj*(F(y_j^(k+1)) - F(y_j^k)) - + sum(j=1,M) q_mj*F(y_(m-1)^k) -for entires q_mj in Q and q'_mj in Q_delta. - -Node-wise from previous quadrature node (N2N formulation), the implicit SDC calculation is: -y_m^(k+1) = y_(m-1)^(k+1) + dtau_m*(F(y_(m)^(k+1)) - F(y_(m)^k)) - + sum(j=1,M) s_mj*F(y_(m-1)^k) -where s_mj = q_mj - q_(m-1)j for entires q_ik in Q. - - -Key choices in our SDC method are: -- Choice of quadrature node type (e.g. gauss-lobatto) -- Number of quadrature nodes -- Number of iterations - each iteration increases the order of accuracy up to - the order of the underlying quadrature -- Choice of Q_delta (e.g. Forward Euler, Backward Euler, LU-trick) -- How to get initial solution on quadrature nodes -""" - -from abc import ABCMeta -import numpy as np -from firedrake import ( - Function, NonlinearVariationalProblem, NonlinearVariationalSolver, Constant -) -from firedrake.fml import ( - replace_subject, all_terms, drop -) -from firedrake.utils import cached_property -from gusto.time_discretisation.time_discretisation import wrapper_apply -from gusto.core.labels import (time_derivative, implicit, explicit, source_label) - -from qmat import genQCoeffs, genQDeltaCoeffs - -from gusto.core.logging import ( - logger, DEBUG, logging_ksp_monitor_true_residual, - attach_custom_monitor -) - -__all__ = ["SDC", "IDC", "RIDC", "RIDC_Reduced", "Parallel_RIDC", "Parallel_SDC"] - - -class SDC(object, metaclass=ABCMeta): - """Class for Spectral Deferred Correction schemes.""" - - def __init__(self, base_scheme, domain, M, maxk, quad_type, node_type, qdelta_imp, qdelta_exp, - formulation="N2N", field_name=None, - linear_solver_parameters=None, nonlinear_solver_parameters=None, final_update=True, - limiter=None, options=None, initial_guess="base"): - """ - Initialise SDC object - Args: - base_scheme (:class:`TimeDiscretisation`): Base time stepping scheme to get first guess of solution on - quadrature nodes. - domain (:class:`Domain`): the model's domain object, containing the - mesh and the compatible function spaces. - M (int): Number of quadrature nodes to compute spectral integration over - maxk (int): Max number of correction interations - quad_type (str): Type of quadrature to be used. Options are - GAUSS, RADAU-LEFT, RADAU-RIGHT and LOBATTO - node_type (str): Node type to be used. Options are - EQUID, LEGENDRE, CHEBY-1, CHEBY-2, CHEBY-3 and CHEBY-4 - qdelta_imp (str): Implicit Qdelta matrix to be used. Options are - BE, LU, TRAP, EXACT, PIC, OPT, WEIRD, MIN-SR-NS, MIN-SR-S - qdelta_exp (str): Explicit Qdelta matrix to be used. Options are - FE, EXACT, PIC - formulation (str, optional): Whether to use node-to-node or zero-to-node - formulation. Options are N2N and Z2N. Defaults to N2N - field_name (str, optional): name of the field to be evolved. - Defaults to None. - linear_solver_parameters (dict, optional): dictionary of parameters to - pass to the underlying linear solver. Defaults to None. - nonlinear_solver_parameters (dict, optional): dictionary of parameters to - pass to the underlying nonlinear solver. Defaults to None. - final_update (bool, optional): Whether to compute final update, or just take last - quadrature value. Defaults to True - limiter (:class:`Limiter` object, optional): a limiter to apply to - the evolving field to enforce monotonicity. Defaults to None. - options (:class:`AdvectionOptions`, optional): an object containing - options to either be passed to the spatial discretisation, or - to control the "wrapper" methods, such as Embedded DG or a - recovery method. Defaults to None. - initial_guess (str, optional): Initial guess to be base timestepper, or copy - """ - # Check the configuration options - if (not (formulation == "N2N" or formulation == "Z2N")): - raise ValueError('Formulation not implemented') - - # Initialise parameters - self.base = base_scheme - self.base.dt = domain.dt - self.field_name = field_name - self.domain = domain - self.dt_coarse = domain.dt - self.M = M - self.maxk = maxk - self.final_update = final_update - self.formulation = formulation - self.limiter = limiter - self.augmentation = self.base.augmentation - self.wrapper = self.base.wrapper - - # Get quadrature nodes and weights - self.nodes, self.weights, self.Q = genQCoeffs("Collocation", nNodes=M, - nodeType=node_type, - quadType=quad_type, - form=formulation) - - # Rescale to be over [0,dt] rather than [0,1] - self.nodes = float(self.dt_coarse)*self.nodes - - self.dtau = np.diff(np.append(0, self.nodes)) - self.Q = float(self.dt_coarse)*self.Q - self.Qfin = float(self.dt_coarse)*self.weights - self.qdelta_imp_type = qdelta_imp - self.formulation = formulation - self.node_type = node_type - self.quad_type = quad_type - # breakpoint() - # prin - - # Get Q_delta matrices - self.Qdelta_imp = genQDeltaCoeffs(qdelta_imp, form=formulation, - nodes=self.nodes, Q=self.Q, nNodes=M, nodeType=node_type, quadType=quad_type) - self.Qdelta_exp = genQDeltaCoeffs(qdelta_exp, form=formulation, - nodes=self.nodes, Q=self.Q, nNodes=M, nodeType=node_type, quadType=quad_type) - # Set default linear and nonlinear solver options if none passed in - if linear_solver_parameters is None: - self.linear_solver_parameters = {'snes_type': 'ksponly', - 'ksp_type': 'cg', - 'pc_type': 'bjacobi', - 'sub_pc_type': 'ilu'} - else: - self.linear_solver_parameters = linear_solver_parameters - - if nonlinear_solver_parameters is None: - self.nonlinear_solver_parameters = {'snes_type': 'newtonls', - 'ksp_type': 'gmres', - 'pc_type': 'bjacobi', - 'sub_pc_type': 'ilu'} - else: - self.nonlinear_solver_parameters = nonlinear_solver_parameters - - # Flag to check wheter initial guess is generated using base time discretisation - # (i.e. Forward Euler) - if (initial_guess == "base"): - self.base_flag = True - else: - self.base_flag = False - - def setup(self, equation, apply_bcs=True, *active_labels): - """ - Set up the SDC time discretisation based on the equation.n - - Args: - equation (:class:`PrognosticEquation`): the model's equation. - apply_bcs (bool, optional): whether to apply the equation's boundary - conditions. Defaults to True. - *active_labels (:class:`Label`): labels indicating which terms of - the equation to include. - """ - # Inherit from base time discretisation - self.base.setup(equation, apply_bcs, *active_labels) - self.equation = self.base.equation - self.residual = self.base.residual - self.evaluate_source = self.base.evaluate_source - - for t in self.residual: - # Check all terms are labeled implicit or explicit - if ((not t.has_label(implicit)) and (not t.has_label(explicit)) - and (not t.has_label(time_derivative)) and (not t.has_label(source_label))): - raise NotImplementedError("Non time-derivative or source terms must be labeled as implicit or explicit") - - # Set up bcs - self.bcs = self.base.bcs - - # Set up SDC variables - if self.field_name is not None and hasattr(equation, "field_names"): - self.idx = equation.field_names.index(self.field_name) - W = equation.spaces[self.idx] - else: - self.field_name = equation.field_name - W = equation.function_space - self.idx = None - self.W = W - self.Unodes = [Function(W) for _ in range(self.M+1)] - self.Unodes1 = [Function(W) for _ in range(self.M+1)] - self.fUnodes = [Function(W) for _ in range(self.M)] - self.quad = [Function(W) for _ in range(self.M)] - self.source_Uk = [Function(W) for _ in range(self.M+1)] - self.source_Ukp1 = [Function(W) for _ in range(self.M+1)] - self.U_SDC = Function(W) - self.U_start = Function(W) - self.Un = Function(W) - self.Q_ = Function(W) - self.quad_final = Function(W) - self.U_fin = Function(W) - self.Urhs = Function(W) - self.Uin = Function(W) - self.source_in = Function(W) - - @property - def nlevels(self): - return 1 - - def compute_quad(self): - """ - Computes integration of F(y) on quadrature nodes - """ - for j in range(self.M): - self.quad[j].assign(0.) - for k in range(self.M): - self.quad[j] += float(self.Q[j, k])*self.fUnodes[k] - - def compute_quad_final(self): - """ - Computes final integration of F(y) on quadrature nodes - """ - self.quad_final.assign(0.) - for k in range(self.M): - self.quad_final += float(self.Qfin[k])*self.fUnodes[k] - - @property - def res_rhs(self): - """Set up the residual for the calculation of F(y).""" - a = self.residual.label_map(lambda t: t.has_label(time_derivative), - replace_subject(self.Urhs, old_idx=self.idx), - drop) - # F(y) - L = self.residual.label_map(lambda t: any(t.has_label(time_derivative, source_label)), - drop, - replace_subject(self.Uin, old_idx=self.idx)) - L_source = self.residual.label_map(lambda t: t.has_label(source_label), - replace_subject(self.source_in, old_idx=self.idx), - drop) - residual_rhs = a - (L + L_source) - return residual_rhs.form - - @property - def res_fin(self): - """Set up the residual for final solve.""" - # y_(n+1) - a = self.residual.label_map(lambda t: t.has_label(time_derivative), - replace_subject(self.U_fin, old_idx=self.idx), - drop) - # y_n - F_exp = self.residual.label_map(lambda t: t.has_label(time_derivative), - replace_subject(self.Un, old_idx=self.idx), - drop) - F_exp = F_exp.label_map(lambda t: t.has_label(time_derivative), - lambda t: -1*t) - - # sum(j=1,M) q_j*F(y_j) - Q = self.residual.label_map(lambda t: t.has_label(time_derivative), - replace_subject(self.quad_final, old_idx=self.idx), - drop) - - residual_final = a + F_exp + Q - return residual_final.form - - def res(self, m): - """Set up the discretisation's residual for a given node m.""" - # Add time derivative terms y^(k+1)_m - y_start for node m. y_start is y_n for Z2N formulation - # and y^(k)_m for N2N formulation - mass_form = self.residual.label_map( - lambda t: t.has_label(time_derivative), - map_if_false=drop) - residual = mass_form.label_map(all_terms, - map_if_true=replace_subject(self.U_SDC, old_idx=self.idx)) - residual -= mass_form.label_map(all_terms, - map_if_true=replace_subject(self.U_start, old_idx=self.idx)) - # Loop through nodes up to m-1 and calcualte - # sum(j=1,m-1) Qdelta_imp[m,j]*(F(y_(m)^(k+1)) - F(y_(m)^k)) - for i in range(m): - r_imp_kp1 = self.residual.label_map( - lambda t: t.has_label(implicit), - map_if_true=replace_subject(self.Unodes1[i+1], old_idx=self.idx), - map_if_false=drop) - r_imp_kp1 = r_imp_kp1.label_map( - all_terms, - lambda t: Constant(self.Qdelta_imp[m, i])*t) - residual += r_imp_kp1 - r_imp_k = self.residual.label_map( - lambda t: t.has_label(implicit), - map_if_true=replace_subject(self.Unodes[i+1], old_idx=self.idx), - map_if_false=drop) - r_imp_k = r_imp_k.label_map( - all_terms, - lambda t: Constant(self.Qdelta_imp[m, i])*t) - residual -= r_imp_k - # Loop through nodes up to m-1 and calcualte - # sum(j=1,M) Q_delta_exp[m,j]*(S(y_(m-1)^(k+1)) - S(y_(m-1)^k)) - for i in range(self.M): - r_exp_kp1 = self.residual.label_map( - lambda t: t.has_label(explicit), - map_if_true=replace_subject(self.Unodes1[i+1], old_idx=self.idx), - map_if_false=drop) - r_exp_kp1 = r_exp_kp1.label_map( - all_terms, - lambda t: Constant(self.Qdelta_exp[m, i])*t) - - residual += r_exp_kp1 - r_exp_k = self.residual.label_map( - lambda t: t.has_label(explicit), - map_if_true=replace_subject(self.Unodes[i+1], old_idx=self.idx), - map_if_false=drop) - r_exp_k = r_exp_k.label_map( - all_terms, - lambda t: Constant(self.Qdelta_exp[m, i])*t) - residual -= r_exp_k - - # Calculate source terms - r_source_kp1 = self.residual.label_map( - lambda t: t.has_label(source_label), - map_if_true=replace_subject(self.source_Ukp1[i+1], old_idx=self.idx), - map_if_false=drop) - r_source_kp1 = r_source_kp1.label_map( - all_terms, - lambda t: Constant(self.Qdelta_exp[m, i])*t) - residual += r_source_kp1 - - r_source_k = self.residual.label_map( - lambda t: t.has_label(source_label), - map_if_true=replace_subject(self.source_Uk[i+1], old_idx=self.idx), - map_if_false=drop) - r_source_k = r_source_k.label_map( - all_terms, - map_if_true=lambda t: Constant(self.Qdelta_exp[m, i])*t) - residual -= r_source_k - - # Add on final implicit terms - # Qdelta_imp[m,m]*(F(y_(m)^(k+1)) - F(y_(m)^k)) - r_imp_kp1 = self.residual.label_map( - lambda t: t.has_label(implicit), - map_if_true=replace_subject(self.U_SDC, old_idx=self.idx), - map_if_false=drop) - r_imp_kp1 = r_imp_kp1.label_map( - all_terms, - lambda t: Constant(self.Qdelta_imp[m, m])*t) - residual += r_imp_kp1 - r_imp_k = self.residual.label_map( - lambda t: t.has_label(implicit), - map_if_true=replace_subject(self.Unodes[m+1], old_idx=self.idx), - map_if_false=drop) - r_imp_k = r_imp_k.label_map( - all_terms, - lambda t: Constant(self.Qdelta_imp[m, m])*t) - residual -= r_imp_k - - # Add on error term. sum(j=1,M) q_mj*F(y_m^k) for Z2N formulation - # and sum(j=1,M) s_mj*F(y_m^k) for N2N formulation, where s_mj = q_mj-q_m-1j - # and s1j = q1j. - Q = self.residual.label_map(lambda t: t.has_label(time_derivative), - replace_subject(self.Q_, old_idx=self.idx), - drop) - residual += Q - return residual.form - - @cached_property - def solvers(self): - """Set up a list of solvers for each problem at a node m.""" - solvers = [] - for m in range(self.M): - # setup solver using residual defined in derived class - problem = NonlinearVariationalProblem(self.res(m), self.U_SDC, bcs=self.bcs) - solver_name = self.field_name+self.__class__.__name__ + "%s" % (m) - solvers.append(NonlinearVariationalSolver(problem, solver_parameters=self.nonlinear_solver_parameters, options_prefix=solver_name)) - return solvers - - @cached_property - def solver_fin(self): - """Set up the problem and the solver for final update.""" - # setup linear solver using final residual defined in derived class - prob_fin = NonlinearVariationalProblem(self.res_fin, self.U_fin, bcs=self.bcs) - solver_name = self.field_name+self.__class__.__name__+"_final" - return NonlinearVariationalSolver(prob_fin, solver_parameters=self.linear_solver_parameters, - options_prefix=solver_name) - - @cached_property - def solver_rhs(self): - """Set up the problem and the solver for mass matrix inversion.""" - # setup linear solver using rhs residual defined in derived class - prob_rhs = NonlinearVariationalProblem(self.res_rhs, self.Urhs, bcs=self.bcs) - solver_name = self.field_name+self.__class__.__name__+"_rhs" - return NonlinearVariationalSolver(prob_rhs, solver_parameters=self.linear_solver_parameters, - options_prefix=solver_name) - - @wrapper_apply - def apply(self, x_out, x_in): - self.Un.assign(x_in) - self.U_start.assign(self.Un) - solver_list = self.solvers - - # Compute initial guess on quadrature nodes with low-order - # base timestepper - self.Unodes[0].assign(self.Un) - if (self.base_flag): - for m in range(self.M): - self.base.dt = float(self.dtau[m]) - self.base.apply_cycle(self.Unodes[m+1], self.Unodes[m]) - else: - for m in range(self.M): - self.Unodes[m+1].assign(self.Un) - for m in range(self.M+1): - for evaluate in self.evaluate_source: - evaluate(self.Unodes[m], self.base.dt, x_out=self.source_Uk[m]) - - # Iterate through correction sweeps - k = 0 - while k < self.maxk: - k += 1 - - if self.qdelta_imp_type == "MIN-SR-FLEX": - # Recompute Implicit Q_delta matrix for each iteration k - self.Qdelta_imp = genQDeltaCoeffs( - self.qdelta_imp_type, - form=self.formulation, - nodes=self.nodes, - Q=self.Q, - nNodes=self.M, - nodeType=self.node_type, - quadType=self.quad_type, - k=k - ) - - # Compute for N2N: sum(j=1,M) (s_mj*F(y_m^k) + s_mj*S(y_m^k)) - # for Z2N: sum(j=1,M) (q_mj*F(y_m^k) + q_mj*S(y_m^k)) - for m in range(1, self.M+1): - self.Uin.assign(self.Unodes[m]) - # Include source terms - for evaluate in self.evaluate_source: - evaluate(self.Uin, self.base.dt, x_out=self.source_in) - self.solver_rhs.solve() - self.fUnodes[m-1].assign(self.Urhs) - self.compute_quad() - - # Loop through quadrature nodes and solve - self.Unodes1[0].assign(self.Unodes[0]) - for evaluate in self.evaluate_source: - evaluate(self.Unodes[0], self.base.dt, x_out=self.source_Uk[0]) - for m in range(1, self.M+1): - # Set Q or S matrix - self.Q_.assign(self.quad[m-1]) - - # Set initial guess for solver, and pick correct solver - if (self.formulation == "N2N"): - self.U_start.assign(self.Unodes1[m-1]) - self.solver = solver_list[m-1] - self.U_SDC.assign(self.Unodes[m]) - - # Compute - # for N2N: - # y_m^(k+1) = y_(m-1)^(k+1) + dtau_m*(F(y_(m)^(k+1)) - F(y_(m)^k) - # + S(y_(m-1)^(k+1)) - S(y_(m-1)^k)) - # + sum(j=1,M) s_mj*(F+S)(y^k) - # for Z2N: - # y_m^(k+1) = y^n + sum(j=1,m) Qdelta_imp[m,j]*(F(y_(m)^(k+1)) - F(y_(m)^k)) - # + sum(j=1,M) Q_delta_exp[m,j]*(S(y_(m-1)^(k+1)) - S(y_(m-1)^k)) - self.solver.solve() - self.Unodes1[m].assign(self.U_SDC) - - # Evaluate source terms - for evaluate in self.evaluate_source: - evaluate(self.Unodes1[m], self.base.dt, x_out=self.source_Ukp1[m]) - - # Apply limiter if required - if self.limiter is not None: - self.limiter.apply(self.Unodes1[m]) - for m in range(1, self.M+1): - self.Unodes[m].assign(self.Unodes1[m]) - self.source_Uk[m].assign(self.source_Ukp1[m]) - - if self.maxk > 0: - # Compute value at dt rather than final quadrature node tau_M - if self.final_update: - for m in range(1, self.M+1): - self.Uin.assign(self.Unodes1[m]) - self.source_in.assign(self.source_Ukp1[m]) - self.solver_rhs.solve() - self.fUnodes[m-1].assign(self.Urhs) - self.compute_quad_final() - # Compute y_(n+1) = y_n + sum(j=1,M) q_j*F(y_j) - - self.U_fin.assign(self.Unodes[-1]) - self.solver_fin.solve() - # Apply limiter if required - if self.limiter is not None: - self.limiter.apply(self.U_fin) - x_out.assign(self.U_fin) - else: - # Take value at final quadrature node dtau_M - x_out.assign(self.Unodes[-1]) - else: - x_out.assign(self.Unodes[-1]) - -class IDC(object, metaclass=ABCMeta): - """Class for Integral Deferred Correction schemes.""" - - def __init__(self, base_scheme, domain, M, K, field_name=None, - linear_solver_parameters=None, nonlinear_solver_parameters=None, - limiter=None, options=None): - """ - Initialise IDC object - Args: - base_scheme (:class:`TimeDiscretisation`): Base time stepping scheme to get first guess of solution on - quadrature nodes. - domain (:class:`Domain`): the model's domain object, containing the - mesh and the compatible function spaces. - M (int): Number of subintervals - K (int): Max number of correction interations - field_name (str, optional): name of the field to be evolved. - Defaults to None. - linear_solver_parameters (dict, optional): dictionary of parameters to - pass to the underlying linear solver. Defaults to None. - nonlinear_solver_parameters (dict, optional): dictionary of parameters to - pass to the underlying nonlinear solver. Defaults to None. - final_update (bool, optional): Whether to compute final update, or just take last - quadrature value. Defaults to True - limiter (:class:`Limiter` object, optional): a limiter to apply to - the evolving field to enforce monotonicity. Defaults to None. - options (:class:`AdvectionOptions`, optional): an object containing - - initial_guess (str, optional): Initial guess to be base timestepper, or copy - """ - self.base = base_scheme - self.field_name = field_name - self.domain = domain - self.dt_coarse = domain.dt - self.limiter = limiter - self.augmentation = self.base.augmentation - self.wrapper = self.base.wrapper - self.K = K - self.M = M - self.dt = Constant(float(self.dt_coarse)/(self.M)) - print("dt", float(self.dt)) - self.nodes = np.arange(0, (M+1)*float(self.dt), float(self.dt)) - integration_matrix = self.lagrange_integration_matrix(self.M+1) - - # Rescale nodes from [0, 1] to [0, dt] - self.Q = float(self.dt_coarse) * integration_matrix - - # Set default linear and nonlinear solver options if none passed in - if linear_solver_parameters is None: - self.linear_solver_parameters = {'snes_type': 'ksponly', - 'ksp_type': 'cg', - 'pc_type': 'bjacobi', - 'sub_pc_type': 'ilu'} - else: - self.linear_solver_parameters = linear_solver_parameters - - if nonlinear_solver_parameters is None: - self.nonlinear_solver_parameters = {'snes_type': 'newtonls', - 'ksp_type': 'gmres', - 'pc_type': 'bjacobi', - 'sub_pc_type': 'ilu'} - else: - self.nonlinear_solver_parameters = nonlinear_solver_parameters - - - def setup(self, equation, apply_bcs=True, *active_labels): - """ - Set up the SDC time discretisation based on the equation.n - - Args: - equation (:class:`PrognosticEquation`): the model's equation. - apply_bcs (bool, optional): whether to apply the equation's boundary - conditions. Defaults to True. - *active_labels (:class:`Label`): labels indicating which terms of - the equation to include. - """ - # Inherit from base time discretisation - self.base.setup(equation, apply_bcs, *active_labels) - self.equation = self.base.equation - self.residual = self.base.residual - self.evaluate_source = self.base.evaluate_source - - for t in self.residual: - # Check all terms are labeled implicit or explicit - if ((not t.has_label(implicit)) and (not t.has_label(explicit)) - and (not t.has_label(time_derivative)) and (not t.has_label(source_label))): - raise NotImplementedError("Non time-derivative or source terms must be labeled as implicit or explicit") - - # Set up bcs - self.bcs = self.base.bcs - - # Set up SDC variables - if self.field_name is not None and hasattr(equation, "field_names"): - self.idx = equation.field_names.index(self.field_name) - W = equation.spaces[self.idx] - else: - self.field_name = equation.field_name - W = equation.function_space - self.idx = None - self.W = W - self.Unodes = [Function(W) for _ in range(self.M+1)] - self.Unodes1 = [Function(W) for _ in range(self.M+1)] - self.fUnodes = [Function(W) for _ in range(self.M+1)] - self.quad = [Function(W) for _ in range(self.M+1)] - self.source_Uk = [Function(W) for _ in range(self.M+1)] - self.source_Ukp1 = [Function(W) for _ in range(self.M+1)] - self.U_SDC = Function(W) - self.U_start = Function(W) - self.Un = Function(W) - self.Q_ = Function(W) - self.quad_final = Function(W) - self.U_fin = Function(W) - self.Urhs = Function(W) - self.Uin = Function(W) - self.source_in = Function(W) - self.source_Ukp1_m = Function(W) - self.source_Uk_m = Function(W) - self.Uk_mp1 = Function(W) - self.Uk_m = Function(W) - self.Ukp1_m = Function(W) - - @property - def nlevels(self): - return 1 - - def equidistant_nodes(self ,M): - # This returns a grid of M equispaced nodes from -1 to 1 - grid = np.linspace(-1., 1., M) - return grid - - def lagrange_polynomial(self, index, nodes): - # This returns the coefficients of the Lagrange polynomial l_m with m=index - - M = len(nodes) - - # c is the denominator - c = 1. - for k in range(M): - if k != index: - c *= (nodes[index] - nodes[k]) - - coeffs = np.zeros(M) - coeffs[0] = 1. - m = 0 - - for k in range(M): - if k != index: - m += 1 - d1 = np.zeros(M) - d2 = np.zeros(M) - - d1 = (-1.)*nodes[k] * coeffs - d2[1:m+1] = coeffs[0:m] - - coeffs = d1+d2 - return coeffs / c - - def integrate_polynomial(self, p): - # given a list of coefficients of a polynomial p, this returns those of the integral of p - integral_coeffs = np.zeros(len(p)+1) - - for n, pn in enumerate(p): - integral_coeffs[n+1] = 1/(n+1) * pn - - return integral_coeffs - - def evaluate(self, p, a, b): - # given a list of coefficients of a polynomial p, this returns the value of p(b)-p(a) - value = 0. - for n, pn in enumerate(p): - value += pn * (b**n - a**n) - - return value - - def lagrange_integration_matrix(self, M): - # using the functions defined above, this returns the MxM integration matrix - - # set up equidistant nodes and initialise matrix to zero - nodes = self.equidistant_nodes(M) - nodes = 0.5 * (nodes + 1) - L = len(nodes) - int_matrix = np.zeros((L, L)) - - # fill in matrix values - for index in range(L): - coeff_p = self.lagrange_polynomial(index, nodes) - int_coeff = self.integrate_polynomial(coeff_p) - - for n in range(L-1): - int_matrix[n+1, index] = self.evaluate(int_coeff, nodes[n], nodes[n+1]) - - return int_matrix - - def compute_quad(self, Q, fUnodes, m): - """ - Computes integration of F(y) on quadrature nodes - """ - quad = Function(self.W) - quad.assign(0.) - for k in range(0, np.shape(Q)[0]): - quad += float(Q[m, k])*fUnodes[k] - return quad - - @property - def res_rhs(self): - """Set up the residual for the calculation of F(y).""" - a = self.residual.label_map(lambda t: t.has_label(time_derivative), - replace_subject(self.Urhs, old_idx=self.idx), - drop) - # F(y) - L = self.residual.label_map(lambda t: any(t.has_label(time_derivative, source_label)), - drop, - replace_subject(self.Uin, old_idx=self.idx)) - L_source = self.residual.label_map(lambda t: t.has_label(source_label), - replace_subject(self.source_in, old_idx=self.idx), - drop) - residual_rhs = a - (L + L_source) - return residual_rhs.form - - @property - def res(self): - """Set up the discretisation's residual for a given node m.""" - # Add time derivative terms y^(k+1)_m - y_start for node m. y_start is y_n for Z2N formulation - # and y^(k)_m for N2N formulation - mass_form = self.residual.label_map( - lambda t: t.has_label(time_derivative), - map_if_false=drop) - residual = mass_form.label_map(all_terms, - map_if_true=replace_subject(self.U_SDC, old_idx=self.idx)) - residual -= mass_form.label_map(all_terms, - map_if_true=replace_subject(self.U_start, old_idx=self.idx)) - - # Calculate source terms - r_source_kp1 = self.residual.label_map( - lambda t: t.has_label(source_label), - map_if_true=replace_subject(self.source_Ukp1_m, old_idx=self.idx), - map_if_false=drop) - r_source_kp1 = r_source_kp1.label_map( - all_terms, - lambda t: Constant(self.dt)*t) - residual += r_source_kp1 - - r_source_k = self.residual.label_map( - lambda t: t.has_label(source_label), - map_if_true=replace_subject(self.source_Uk_m, old_idx=self.idx), - map_if_false=drop) - r_source_k = r_source_k.label_map( - all_terms, - map_if_true=lambda t: Constant(self.dt)*t) - residual -= r_source_k - - # Add on final implicit terms - # Qdelta_imp[m,m]*(F(y_(m)^(k+1)) - F(y_(m)^k)) - r_imp_kp1 = self.residual.label_map( - lambda t: t.has_label(implicit), - map_if_true=replace_subject(self.U_SDC, old_idx=self.idx), - map_if_false=drop) - r_imp_kp1 = r_imp_kp1.label_map( - all_terms, - lambda t: Constant(self.dt)*t) - residual += r_imp_kp1 - r_imp_k = self.residual.label_map( - lambda t: t.has_label(implicit), - map_if_true=replace_subject(self.Uk_mp1, old_idx=self.idx), - map_if_false=drop) - r_imp_k = r_imp_k.label_map( - all_terms, - lambda t: Constant(self.dt)*t) - residual -= r_imp_k - - r_exp_kp1 = self.residual.label_map( - lambda t: t.has_label(explicit), - map_if_true=replace_subject(self.Ukp1_m, old_idx=self.idx), - map_if_false=drop) - r_exp_kp1 = r_exp_kp1.label_map( - all_terms, - lambda t: Constant(self.dt)*t) - residual += r_exp_kp1 - r_exp_k = self.residual.label_map( - lambda t: t.has_label(explicit), - map_if_true=replace_subject(self.Uk_m, old_idx=self.idx), - map_if_false=drop) - r_exp_k = r_exp_k.label_map( - all_terms, - lambda t: Constant(self.dt)*t) - residual -= r_exp_k - - - # Add on sum(j=1,M) s_mj*F(y_m^k) for N2N formulation, where s_mj = q_mj-q_m-1j - # and s1j = q1j. - Q = self.residual.label_map(lambda t: t.has_label(time_derivative), - replace_subject(self.Q_, old_idx=self.idx), - drop) - residual += Q - return residual.form - - @cached_property - def solver(self): - """Set up a list of solvers for each problem at a node m.""" - # setup solver using residual defined in derived class - problem = NonlinearVariationalProblem(self.res, self.U_SDC, bcs=self.bcs) - solver_name = self.field_name+self.__class__.__name__ - solver = NonlinearVariationalSolver(problem, solver_parameters=self.nonlinear_solver_parameters, options_prefix=solver_name) - return solver - - @cached_property - def solver_rhs(self): - """Set up the problem and the solver for mass matrix inversion.""" - # setup linear solver using rhs residual defined in derived class - prob_rhs = NonlinearVariationalProblem(self.res_rhs, self.Urhs, bcs=self.bcs) - solver_name = self.field_name+self.__class__.__name__+"_rhs" - return NonlinearVariationalSolver(prob_rhs, solver_parameters=self.linear_solver_parameters, - options_prefix=solver_name) - - @wrapper_apply - def apply(self, x_out, x_in): - self.Un.assign(x_in) - - # Compute initial guess on quadrature nodes with low-order - # base timestepper - self.Unodes[0].assign(self.Un) - - for m in range(self.M): - self.base.dt = float(self.dt) - self.base.apply(self.Unodes[m+1], self.Unodes[m]) - #breakpoint() - for m in range(self.M+1): - for evaluate in self.evaluate_source: - evaluate(self.Unodes[m], self.base.dt, x_out=self.source_Uk[m]) - - # Iterate through correction sweeps - for k in range(1, self.K+1): - # Compute for N2N: sum(j=1,M) (s_mj*F(y_m^k) + s_mj*S(y_m^k)) - for m in range(self.M+1): - self.Uin.assign(self.Unodes[m]) - # Include source terms - for evaluate in self.evaluate_source: - evaluate(self.Uin, self.base.dt, x_out=self.source_in) - self.solver_rhs.solve() - self.fUnodes[m].assign(self.Urhs) - - # Loop through quadrature nodes and solve - self.Unodes1[0].assign(self.Unodes[0]) - for evaluate in self.evaluate_source: - evaluate(self.Unodes[0], self.base.dt, x_out=self.source_Uk[0]) - for m in range(0, self.M): - # Set S matrix - self.Q_.assign(self.compute_quad(self.Q, self.fUnodes, m+1)) - - # Set initial guess for solver, and pick correct solver - self.U_start.assign(self.Unodes1[m]) - self.Ukp1_m.assign(self.Unodes1[m]) - self.Uk_mp1.assign(self.Unodes[m+1]) - self.Uk_m.assign(self.Unodes[m]) - self.source_Ukp1_m.assign(self.source_Ukp1[m]) - self.source_Uk_m.assign(self.source_Uk[m]) - self.U_SDC.assign(self.Unodes[m+1]) - - # Compute - # for N2N: - # y_m^(k+1) = y_(m-1)^(k+1) + dtau_m*(F(y_(m)^(k+1)) - F(y_(m)^k) - # + S(y_(m-1)^(k+1)) - S(y_(m-1)^k)) - # + sum(j=1,M) s_mj*(F+S)(y^k) - self.solver.solve() - self.Unodes1[m+1].assign(self.U_SDC) - - # Evaluate source terms - for evaluate in self.evaluate_source: - evaluate(self.Unodes1[m+1], self.base.dt, x_out=self.source_Ukp1[m+1]) - - # Apply limiter if required - if self.limiter is not None: - self.limiter.apply(self.Unodes1[m+1]) - - for m in range(self.M+1): - self.Unodes[m].assign(self.Unodes1[m]) - self.source_Uk[m].assign(self.source_Ukp1[m]) - - x_out.assign(self.Unodes[-1]) - -class RIDC(object, metaclass=ABCMeta): - """Class for Integral Deferred Correction schemes.""" - - def __init__(self, base_scheme, domain, M, K, field_name=None, - linear_solver_parameters=None, nonlinear_solver_parameters=None, - limiter=None, options=None): - """ - Initialise IDC object - Args: - base_scheme (:class:`TimeDiscretisation`): Base time stepping scheme to get first guess of solution on - quadrature nodes. - domain (:class:`Domain`): the model's domain object, containing the - mesh and the compatible function spaces. - M (int): Number of subintervals - K (int): Max number of correction interations - field_name (str, optional): name of the field to be evolved. - Defaults to None. - linear_solver_parameters (dict, optional): dictionary of parameters to - pass to the underlying linear solver. Defaults to None. - nonlinear_solver_parameters (dict, optional): dictionary of parameters to - pass to the underlying nonlinear solver. Defaults to None. - final_update (bool, optional): Whether to compute final update, or just take last - quadrature value. Defaults to True - limiter (:class:`Limiter` object, optional): a limiter to apply to - the evolving field to enforce monotonicity. Defaults to None. - options (:class:`AdvectionOptions`, optional): an object containing - - initial_guess (str, optional): Initial guess to be base timestepper, or copy - """ - self.base = base_scheme - self.field_name = field_name - self.domain = domain - self.dt_coarse = domain.dt - self.limiter = limiter - self.augmentation = self.base.augmentation - self.wrapper = self.base.wrapper - self.K = K - self.M = M - self.dt = Constant(float(self.dt_coarse)/(self.M)) - self.nodes = np.arange(0, (M+1)*float(self.dt), float(self.dt)) - integration_matrix = self.lagrange_integration_matrix(self.K+1) - # Rescale integration matrix to be over [0, self.dt_coarse] rather than [-1, 1] - self.Q = 0.5 * (self.K) * float(self.dt) * integration_matrix - - # integration_matrix = self.lagrange_integration_matrix(self.K) - # # Rescale integration matrix to be over [0, self.dt_coarse] rather than [-1, 1] - # self.Q = 0.5 * (self.K-1) * float(self.dt) * integration_matrix - - - #self.Q = float(self.dt_coarse) * integration_matrix - - # Set default linear and nonlinear solver options if none passed in - if linear_solver_parameters is None: - self.linear_solver_parameters = {'snes_type': 'ksponly', - 'ksp_type': 'cg', - 'pc_type': 'bjacobi', - 'sub_pc_type': 'ilu'} - else: - self.linear_solver_parameters = linear_solver_parameters - - if nonlinear_solver_parameters is None: - self.nonlinear_solver_parameters = {'snes_type': 'newtonls', - 'ksp_type': 'gmres', - 'pc_type': 'bjacobi', - 'sub_pc_type': 'ilu'} - else: - self.nonlinear_solver_parameters = nonlinear_solver_parameters - - - def setup(self, equation, apply_bcs=True, *active_labels): - """ - Set up the SDC time discretisation based on the equation.n - - Args: - equation (:class:`PrognosticEquation`): the model's equation. - apply_bcs (bool, optional): whether to apply the equation's boundary - conditions. Defaults to True. - *active_labels (:class:`Label`): labels indicating which terms of - the equation to include. - """ - # Inherit from base time discretisation - self.base.setup(equation, apply_bcs, *active_labels) - self.equation = self.base.equation - self.residual = self.base.residual - self.evaluate_source = self.base.evaluate_source - - for t in self.residual: - # Check all terms are labeled implicit or explicit - if ((not t.has_label(implicit)) and (not t.has_label(explicit)) - and (not t.has_label(time_derivative)) and (not t.has_label(source_label))): - raise NotImplementedError("Non time-derivative or source terms must be labeled as implicit or explicit") - - # Set up bcs - self.bcs = self.base.bcs - - # Set up SDC variables - if self.field_name is not None and hasattr(equation, "field_names"): - self.idx = equation.field_names.index(self.field_name) - W = equation.spaces[self.idx] - else: - self.field_name = equation.field_name - W = equation.function_space - self.idx = None - self.W = W - self.Unodes = [Function(W) for _ in range(self.M+1)] - self.Unodes1 = [Function(W) for _ in range(self.M+1)] - self.fUnodes = [Function(W) for _ in range(self.M+1)] - self.quad = [Function(W) for _ in range(self.M+1)] - self.source_Uk = [Function(W) for _ in range(self.M+1)] - self.source_Ukp1 = [Function(W) for _ in range(self.M+1)] - self.U_SDC = Function(W) - self.U_start = Function(W) - self.Un = Function(W) - self.Q_ = Function(W) - self.quad_final = Function(W) - self.U_fin = Function(W) - self.Urhs = Function(W) - self.Uin = Function(W) - self.source_in = Function(W) - self.source_Ukp1_m = Function(W) - self.source_Uk_m = Function(W) - self.Uk_mp1 = Function(W) - self.Uk_m = Function(W) - self.Ukp1_m = Function(W) - - @property - def nlevels(self): - return 1 - - def equidistant_nodes(self ,M): - # This returns a grid of M equispaced nodes from -1 to 1 - grid = np.linspace(-1., 1., M) - #grid = 0.5 * (grid + 1) - return grid - - def lagrange_polynomial(self, index, nodes): - # This returns the coefficients of the Lagrange polynomial l_m with m=index - - M = len(nodes) - - # c is the denominator - c = 1. - for k in range(M): - if k != index: - c *= (nodes[index] - nodes[k]) - - coeffs = np.zeros(M) - coeffs[0] = 1. - m = 0 - - for k in range(M): - if k != index: - m += 1 - d1 = np.zeros(M) - d2 = np.zeros(M) - - d1 = (-1.)*nodes[k] * coeffs - d2[1:m+1] = coeffs[0:m] - - coeffs = d1+d2 - return coeffs / c - - def integrate_polynomial(self, p): - # given a list of coefficients of a polynomial p, this returns those of the integral of p - integral_coeffs = np.zeros(len(p)+1) - - for n, pn in enumerate(p): - integral_coeffs[n+1] = 1/(n+1) * pn - - return integral_coeffs - - def evaluate(self, p, a, b): - # given a list of coefficients of a polynomial p, this returns the value of p(b)-p(a) - value = 0. - for n, pn in enumerate(p): - value += pn * (b**n - a**n) - - return value - - def lagrange_integration_matrix(self, M): - # using the functions defined above, this returns the MxM integration matrix - - # set up equidistant nodes and initialise matrix to zero - nodes = self.equidistant_nodes(M) - L = len(nodes) - int_matrix = np.zeros((L, L)) - - # fill in matrix values - for index in range(L): - coeff_p = self.lagrange_polynomial(index, nodes) - int_coeff = self.integrate_polynomial(coeff_p) - - for n in range(L-1): - int_matrix[n+1, index] = self.evaluate(int_coeff, nodes[n], nodes[n+1]) - - return int_matrix - - def compute_quad(self, Q, fUnodes, m): - """ - Computes integration of F(y) on quadrature nodes - """ - quad = Function(self.W) - quad.assign(0.) - for k in range(0, self.K+1): - quad += float(Q[m, k])*fUnodes[k] - return quad - - def compute_quad_final(self, Q, fUnodes, m): - """ - Computes final integration of F(y) on quadrature nodes - """ - quad = Function(self.W) - quad.assign(0.) - for k in range(0, self.K+1): - quad += float(Q[-1, k])*fUnodes[m - self.K + k] - return quad - - @property - def res_rhs(self): - """Set up the residual for the calculation of F(y).""" - a = self.residual.label_map(lambda t: t.has_label(time_derivative), - replace_subject(self.Urhs, old_idx=self.idx), - drop) - # F(y) - L = self.residual.label_map(lambda t: any(t.has_label(time_derivative, source_label)), - drop, - replace_subject(self.Uin, old_idx=self.idx)) - L_source = self.residual.label_map(lambda t: t.has_label(source_label), - replace_subject(self.source_in, old_idx=self.idx), - drop) - residual_rhs = a - (L + L_source) - return residual_rhs.form - - @property - def res(self): - """Set up the discretisation's residual for a given node m.""" - # Add time derivative terms y^(k+1)_m - y_start for node m. y_start is y_n for Z2N formulation - # and y^(k)_m for N2N formulation - mass_form = self.residual.label_map( - lambda t: t.has_label(time_derivative), - map_if_false=drop) - residual = mass_form.label_map(all_terms, - map_if_true=replace_subject(self.U_SDC, old_idx=self.idx)) - residual -= mass_form.label_map(all_terms, - map_if_true=replace_subject(self.U_start, old_idx=self.idx)) - - # Calculate source terms - r_source_kp1 = self.residual.label_map( - lambda t: t.has_label(source_label), - map_if_true=replace_subject(self.source_Ukp1_m, old_idx=self.idx), - map_if_false=drop) - r_source_kp1 = r_source_kp1.label_map( - all_terms, - lambda t: Constant(self.dt)*t) - residual += r_source_kp1 - - r_source_k = self.residual.label_map( - lambda t: t.has_label(source_label), - map_if_true=replace_subject(self.source_Uk_m, old_idx=self.idx), - map_if_false=drop) - r_source_k = r_source_k.label_map( - all_terms, - map_if_true=lambda t: Constant(self.dt)*t) - residual -= r_source_k - - # Add on final implicit terms - # Qdelta_imp[m,m]*(F(y_(m)^(k+1)) - F(y_(m)^k)) - r_imp_kp1 = self.residual.label_map( - lambda t: t.has_label(implicit), - map_if_true=replace_subject(self.U_SDC, old_idx=self.idx), - map_if_false=drop) - r_imp_kp1 = r_imp_kp1.label_map( - all_terms, - lambda t: Constant(self.dt)*t) - residual += r_imp_kp1 - r_imp_k = self.residual.label_map( - lambda t: t.has_label(implicit), - map_if_true=replace_subject(self.Uk_mp1, old_idx=self.idx), - map_if_false=drop) - r_imp_k = r_imp_k.label_map( - all_terms, - lambda t: Constant(self.dt)*t) - residual -= r_imp_k - - r_exp_kp1 = self.residual.label_map( - lambda t: t.has_label(explicit), - map_if_true=replace_subject(self.Ukp1_m, old_idx=self.idx), - map_if_false=drop) - r_exp_kp1 = r_exp_kp1.label_map( - all_terms, - lambda t: Constant(self.dt)*t) - residual += r_exp_kp1 - r_exp_k = self.residual.label_map( - lambda t: t.has_label(explicit), - map_if_true=replace_subject(self.Uk_m, old_idx=self.idx), - map_if_false=drop) - r_exp_k = r_exp_k.label_map( - all_terms, - lambda t: Constant(self.dt)*t) - residual -= r_exp_k - - - # Add on sum(j=1,M) s_mj*F(y_m^k) for N2N formulation, where s_mj = q_mj-q_m-1j - # and s1j = q1j. - Q = self.residual.label_map(lambda t: t.has_label(time_derivative), - replace_subject(self.Q_, old_idx=self.idx), - drop) - residual += Q - return residual.form - - @cached_property - def solver(self): - """Set up a list of solvers for each problem at a node m.""" - # setup solver using residual defined in derived class - problem = NonlinearVariationalProblem(self.res, self.U_SDC, bcs=self.bcs) - solver_name = self.field_name+self.__class__.__name__ - solver = NonlinearVariationalSolver(problem, solver_parameters=self.nonlinear_solver_parameters, options_prefix=solver_name) - return solver - - @cached_property - def solver_rhs(self): - """Set up the problem and the solver for mass matrix inversion.""" - # setup linear solver using rhs residual defined in derived class - prob_rhs = NonlinearVariationalProblem(self.res_rhs, self.Urhs, bcs=self.bcs) - solver_name = self.field_name+self.__class__.__name__+"_rhs" - return NonlinearVariationalSolver(prob_rhs, solver_parameters=self.linear_solver_parameters, - options_prefix=solver_name) - - @wrapper_apply - def apply(self, x_out, x_in): - self.Un.assign(x_in) - - # Compute initial guess on quadrature nodes with low-order - # base timestepper - self.Unodes[0].assign(self.Un) - - for m in range(self.M): - self.base.dt = float(self.dt) - self.base.apply(self.Unodes[m+1], self.Unodes[m]) - #breakpoint() - for m in range(self.M+1): - for evaluate in self.evaluate_source: - evaluate(self.Unodes[m], self.base.dt, x_out=self.source_Uk[m]) - - # Iterate through correction sweeps - for k in range(1, self.K+1): - # Compute for N2N: sum(j=1,M) (s_mj*F(y_m^k) + s_mj*S(y_m^k)) - for m in range(self.M+1): - self.Uin.assign(self.Unodes[m]) - # Include source terms - for evaluate in self.evaluate_source: - evaluate(self.Uin, self.base.dt, x_out=self.source_in) - self.solver_rhs.solve() - self.fUnodes[m].assign(self.Urhs) - - # Loop through quadrature nodes and solve - self.Unodes1[0].assign(self.Unodes[0]) - for evaluate in self.evaluate_source: - evaluate(self.Unodes[0], self.base.dt, x_out=self.source_Uk[0]) - for m in range(0, self.K): - # Set S matrix - self.Q_.assign(self.compute_quad(self.Q, self.fUnodes, m+1)) - - # Set initial guess for solver, and pick correct solver - self.U_start.assign(self.Unodes1[m]) - self.Ukp1_m.assign(self.Unodes1[m]) - self.Uk_mp1.assign(self.Unodes[m+1]) - self.Uk_m.assign(self.Unodes[m]) - self.source_Ukp1_m.assign(self.source_Ukp1[m]) - self.source_Uk_m.assign(self.source_Uk[m]) - self.U_SDC.assign(self.Unodes[m+1]) - - # Compute - # for N2N: - # y_m^(k+1) = y_(m-1)^(k+1) + dtau_m*(F(y_(m)^(k+1)) - F(y_(m)^k) - # + S(y_(m-1)^(k+1)) - S(y_(m-1)^k)) - # + sum(j=1,M) s_mj*(F+S)(y^k) - self.solver.solve() - self.Unodes1[m+1].assign(self.U_SDC) - - # Evaluate source terms - for evaluate in self.evaluate_source: - evaluate(self.Unodes1[m+1], self.base.dt, x_out=self.source_Ukp1[m+1]) - - # Apply limiter if required - if self.limiter is not None: - self.limiter.apply(self.Unodes1[m+1]) - for m in range(self.K, self.M): - # Set S matrix - self.Q_.assign(self.compute_quad_final(self.Q, self.fUnodes, m+1)) - - # Set initial guess for solver, and pick correct solver - self.U_start.assign(self.Unodes1[m]) - self.Ukp1_m.assign(self.Unodes1[m]) - self.Uk_mp1.assign(self.Unodes[m+1]) - self.Uk_m.assign(self.Unodes[m]) - self.source_Ukp1_m.assign(self.source_Ukp1[m]) - self.source_Uk_m.assign(self.source_Uk[m]) - self.U_SDC.assign(self.Unodes[m+1]) - - # Compute - # for N2N: - # y_m^(k+1) = y_(m-1)^(k+1) + dtau_m*(F(y_(m)^(k+1)) - F(y_(m)^k) - # + S(y_(m-1)^(k+1)) - S(y_(m-1)^k)) - # + sum(j=1,M) s_mj*(F+S)(y^k) - self.solver.solve() - self.Unodes1[m+1].assign(self.U_SDC) - - # Evaluate source terms - for evaluate in self.evaluate_source: - evaluate(self.Unodes1[m+1], self.base.dt, x_out=self.source_Ukp1[m+1]) - - # Apply limiter if required - if self.limiter is not None: - self.limiter.apply(self.Unodes1[m+1]) - - for m in range(self.M+1): - self.Unodes[m].assign(self.Unodes1[m]) - self.source_Uk[m].assign(self.source_Ukp1[m]) - - x_out.assign(self.Unodes[-1]) - -class RIDC_Reduced(object, metaclass=ABCMeta): - """Class for Integral Deferred Correction schemes.""" - - def __init__(self, base_scheme, domain, M, K, field_name=None, - linear_solver_parameters=None, nonlinear_solver_parameters=None, - limiter=None, options=None): - """ - Initialise IDC object - Args: - base_scheme (:class:`TimeDiscretisation`): Base time stepping scheme to get first guess of solution on - quadrature nodes. - domain (:class:`Domain`): the model's domain object, containing the - mesh and the compatible function spaces. - M (int): Number of subintervals - K (int): Max number of correction interations - field_name (str, optional): name of the field to be evolved. - Defaults to None. - linear_solver_parameters (dict, optional): dictionary of parameters to - pass to the underlying linear solver. Defaults to None. - nonlinear_solver_parameters (dict, optional): dictionary of parameters to - pass to the underlying nonlinear solver. Defaults to None. - final_update (bool, optional): Whether to compute final update, or just take last - quadrature value. Defaults to True - limiter (:class:`Limiter` object, optional): a limiter to apply to - the evolving field to enforce monotonicity. Defaults to None. - options (:class:`AdvectionOptions`, optional): an object containing - - initial_guess (str, optional): Initial guess to be base timestepper, or copy - """ - self.base = base_scheme - self.field_name = field_name - self.domain = domain - self.dt_coarse = domain.dt - self.limiter = limiter - self.augmentation = self.base.augmentation - self.wrapper = self.base.wrapper - self.K = K - self.M = M - self.dt = Constant(float(self.dt_coarse)/(self.M)) - self.nodes = np.arange(0, (M+1)*float(self.dt), float(self.dt)) - self.Q = [] - for l in range(1, self.K+1): - integration_matrix = self.lagrange_integration_matrix(l+1) - integration_matrix = 0.5 * (l) * float(self.dt) * integration_matrix - self.Q.append(integration_matrix) - # Rescale integration matrix to be over [0, self.dt_coarse] rather than [-1, 1] - # self.Q = 0.5 * (self.K) * float(self.dt) * integration_matrix - - # integration_matrix = self.lagrange_integration_matrix(self.K) - # # Rescale integration matrix to be over [0, self.dt_coarse] rather than [-1, 1] - # self.Q = 0.5 * (self.K-1) * float(self.dt) * integration_matrix - - - #self.Q = float(self.dt_coarse) * integration_matrix - - # Set default linear and nonlinear solver options if none passed in - if linear_solver_parameters is None: - self.linear_solver_parameters = {'snes_type': 'ksponly', - 'ksp_type': 'cg', - 'pc_type': 'bjacobi', - 'sub_pc_type': 'ilu'} - else: - self.linear_solver_parameters = linear_solver_parameters - - if nonlinear_solver_parameters is None: - self.nonlinear_solver_parameters = {'snes_type': 'newtonls', - 'ksp_type': 'gmres', - 'pc_type': 'bjacobi', - 'sub_pc_type': 'ilu'} - else: - self.nonlinear_solver_parameters = nonlinear_solver_parameters - - - def setup(self, equation, apply_bcs=True, *active_labels): - """ - Set up the SDC time discretisation based on the equation.n - - Args: - equation (:class:`PrognosticEquation`): the model's equation. - apply_bcs (bool, optional): whether to apply the equation's boundary - conditions. Defaults to True. - *active_labels (:class:`Label`): labels indicating which terms of - the equation to include. - """ - # Inherit from base time discretisation - self.base.setup(equation, apply_bcs, *active_labels) - self.equation = self.base.equation - self.residual = self.base.residual - self.evaluate_source = self.base.evaluate_source - - for t in self.residual: - # Check all terms are labeled implicit or explicit - if ((not t.has_label(implicit)) and (not t.has_label(explicit)) - and (not t.has_label(time_derivative)) and (not t.has_label(source_label))): - raise NotImplementedError("Non time-derivative or source terms must be labeled as implicit or explicit") - - # Set up bcs - self.bcs = self.base.bcs - - # Set up SDC variables - if self.field_name is not None and hasattr(equation, "field_names"): - self.idx = equation.field_names.index(self.field_name) - W = equation.spaces[self.idx] - else: - self.field_name = equation.field_name - W = equation.function_space - self.idx = None - self.W = W - self.Unodes = [Function(W) for _ in range(self.M+1)] - self.Unodes1 = [Function(W) for _ in range(self.M+1)] - self.fUnodes = [Function(W) for _ in range(self.M+1)] - self.quad = [Function(W) for _ in range(self.M+1)] - self.source_Uk = [Function(W) for _ in range(self.M+1)] - self.source_Ukp1 = [Function(W) for _ in range(self.M+1)] - self.U_SDC = Function(W) - self.U_start = Function(W) - self.Un = Function(W) - self.Q_ = Function(W) - self.quad_final = Function(W) - self.U_fin = Function(W) - self.Urhs = Function(W) - self.Uin = Function(W) - self.source_in = Function(W) - self.source_Ukp1_m = Function(W) - self.source_Uk_m = Function(W) - self.Uk_mp1 = Function(W) - self.Uk_m = Function(W) - self.Ukp1_m = Function(W) - - @property - def nlevels(self): - return 1 - - def equidistant_nodes(self ,M): - # This returns a grid of M equispaced nodes from -1 to 1 - grid = np.linspace(-1., 1., M) - #grid = 0.5 * (grid + 1) - return grid - - def lagrange_polynomial(self, index, nodes): - # This returns the coefficients of the Lagrange polynomial l_m with m=index - - M = len(nodes) - - # c is the denominator - c = 1. - for k in range(M): - if k != index: - c *= (nodes[index] - nodes[k]) - - coeffs = np.zeros(M) - coeffs[0] = 1. - m = 0 - - for k in range(M): - if k != index: - m += 1 - d1 = np.zeros(M) - d2 = np.zeros(M) - - d1 = (-1.)*nodes[k] * coeffs - d2[1:m+1] = coeffs[0:m] - - coeffs = d1+d2 - return coeffs / c - - def integrate_polynomial(self, p): - # given a list of coefficients of a polynomial p, this returns those of the integral of p - integral_coeffs = np.zeros(len(p)+1) - - for n, pn in enumerate(p): - integral_coeffs[n+1] = 1/(n+1) * pn - - return integral_coeffs - - def evaluate(self, p, a, b): - # given a list of coefficients of a polynomial p, this returns the value of p(b)-p(a) - value = 0. - for n, pn in enumerate(p): - value += pn * (b**n - a**n) - - return value - - def lagrange_integration_matrix(self, M): - # using the functions defined above, this returns the MxM integration matrix - - # set up equidistant nodes and initialise matrix to zero - nodes = self.equidistant_nodes(M) - L = len(nodes) - int_matrix = np.zeros((L, L)) - - # fill in matrix values - for index in range(L): - coeff_p = self.lagrange_polynomial(index, nodes) - int_coeff = self.integrate_polynomial(coeff_p) - - for n in range(L-1): - int_matrix[n+1, index] = self.evaluate(int_coeff, nodes[n], nodes[n+1]) - - return int_matrix - - def compute_quad(self, Q, fUnodes, m): - """ - Computes integration of F(y) on quadrature nodes - """ - quad = Function(self.W) - quad.assign(0.) - for k in range(0, np.shape(Q)[1]): - quad += float(Q[m, k])*fUnodes[k] - return quad - - def compute_quad_final(self, Q, fUnodes, m): - """ - Computes final integration of F(y) on quadrature nodes - """ - quad = Function(self.W) - quad.assign(0.) - l = np.shape(Q)[0] - 1 - for k in range(0, l+1): - quad += float(Q[-1, k])*fUnodes[m - l + k] - return quad - - @property - def res_rhs(self): - """Set up the residual for the calculation of F(y).""" - a = self.residual.label_map(lambda t: t.has_label(time_derivative), - replace_subject(self.Urhs, old_idx=self.idx), - drop) - # F(y) - L = self.residual.label_map(lambda t: any(t.has_label(time_derivative, source_label)), - drop, - replace_subject(self.Uin, old_idx=self.idx)) - L_source = self.residual.label_map(lambda t: t.has_label(source_label), - replace_subject(self.source_in, old_idx=self.idx), - drop) - residual_rhs = a - (L + L_source) - return residual_rhs.form - - @property - def res(self): - """Set up the discretisation's residual for a given node m.""" - # Add time derivative terms y^(k+1)_m - y_start for node m. y_start is y_n for Z2N formulation - # and y^(k)_m for N2N formulation - mass_form = self.residual.label_map( - lambda t: t.has_label(time_derivative), - map_if_false=drop) - residual = mass_form.label_map(all_terms, - map_if_true=replace_subject(self.U_SDC, old_idx=self.idx)) - residual -= mass_form.label_map(all_terms, - map_if_true=replace_subject(self.U_start, old_idx=self.idx)) - - # Calculate source terms - r_source_kp1 = self.residual.label_map( - lambda t: t.has_label(source_label), - map_if_true=replace_subject(self.source_Ukp1_m, old_idx=self.idx), - map_if_false=drop) - r_source_kp1 = r_source_kp1.label_map( - all_terms, - lambda t: Constant(self.dt)*t) - residual += r_source_kp1 - - r_source_k = self.residual.label_map( - lambda t: t.has_label(source_label), - map_if_true=replace_subject(self.source_Uk_m, old_idx=self.idx), - map_if_false=drop) - r_source_k = r_source_k.label_map( - all_terms, - map_if_true=lambda t: Constant(self.dt)*t) - residual -= r_source_k - - # Add on final implicit terms - # Qdelta_imp[m,m]*(F(y_(m)^(k+1)) - F(y_(m)^k)) - r_imp_kp1 = self.residual.label_map( - lambda t: t.has_label(implicit), - map_if_true=replace_subject(self.U_SDC, old_idx=self.idx), - map_if_false=drop) - r_imp_kp1 = r_imp_kp1.label_map( - all_terms, - lambda t: Constant(self.dt)*t) - residual += r_imp_kp1 - r_imp_k = self.residual.label_map( - lambda t: t.has_label(implicit), - map_if_true=replace_subject(self.Uk_mp1, old_idx=self.idx), - map_if_false=drop) - r_imp_k = r_imp_k.label_map( - all_terms, - lambda t: Constant(self.dt)*t) - residual -= r_imp_k - - r_exp_kp1 = self.residual.label_map( - lambda t: t.has_label(explicit), - map_if_true=replace_subject(self.Ukp1_m, old_idx=self.idx), - map_if_false=drop) - r_exp_kp1 = r_exp_kp1.label_map( - all_terms, - lambda t: Constant(self.dt)*t) - residual += r_exp_kp1 - r_exp_k = self.residual.label_map( - lambda t: t.has_label(explicit), - map_if_true=replace_subject(self.Uk_m, old_idx=self.idx), - map_if_false=drop) - r_exp_k = r_exp_k.label_map( - all_terms, - lambda t: Constant(self.dt)*t) - residual -= r_exp_k - - - # Add on sum(j=1,M) s_mj*F(y_m^k) for N2N formulation, where s_mj = q_mj-q_m-1j - # and s1j = q1j. - Q = self.residual.label_map(lambda t: t.has_label(time_derivative), - replace_subject(self.Q_, old_idx=self.idx), - drop) - residual += Q - return residual.form - - @cached_property - def solver(self): - """Set up a list of solvers for each problem at a node m.""" - # setup solver using residual defined in derived class - problem = NonlinearVariationalProblem(self.res, self.U_SDC, bcs=self.bcs) - solver_name = self.field_name+self.__class__.__name__ - solver = NonlinearVariationalSolver(problem, solver_parameters=self.nonlinear_solver_parameters, options_prefix=solver_name) - return solver - - @cached_property - def solver_rhs(self): - """Set up the problem and the solver for mass matrix inversion.""" - # setup linear solver using rhs residual defined in derived class - prob_rhs = NonlinearVariationalProblem(self.res_rhs, self.Urhs, bcs=self.bcs) - solver_name = self.field_name+self.__class__.__name__+"_rhs" - return NonlinearVariationalSolver(prob_rhs, solver_parameters=self.linear_solver_parameters, - options_prefix=solver_name) - - @wrapper_apply - def apply(self, x_out, x_in): - self.Un.assign(x_in) - - # Compute initial guess on quadrature nodes with low-order - # base timestepper - self.Unodes[0].assign(self.Un) - - for m in range(self.M): - self.base.dt = float(self.dt) - self.base.apply(self.Unodes[m+1], self.Unodes[m]) - #breakpoint() - for m in range(self.M+1): - for evaluate in self.evaluate_source: - evaluate(self.Unodes[m], self.base.dt, x_out=self.source_Uk[m]) - - # Iterate through correction sweeps - for k in range(1, self.K+1): - # Compute for N2N: sum(j=1,M) (s_mj*F(y_m^k) + s_mj*S(y_m^k)) - for m in range(self.M+1): - self.Uin.assign(self.Unodes[m]) - # Include source terms - for evaluate in self.evaluate_source: - evaluate(self.Uin, self.base.dt, x_out=self.source_in) - self.solver_rhs.solve() - self.fUnodes[m].assign(self.Urhs) - - # Loop through quadrature nodes and solve - self.Unodes1[0].assign(self.Unodes[0]) - for evaluate in self.evaluate_source: - evaluate(self.Unodes[0], self.base.dt, x_out=self.source_Uk[0]) - for m in range(0, k): - # Set S matrix - self.Q_.assign(self.compute_quad(self.Q[k-1], self.fUnodes, m+1)) - - # Set initial guess for solver, and pick correct solver - self.U_start.assign(self.Unodes1[m]) - self.Ukp1_m.assign(self.Unodes1[m]) - self.Uk_mp1.assign(self.Unodes[m+1]) - self.Uk_m.assign(self.Unodes[m]) - self.source_Ukp1_m.assign(self.source_Ukp1[m]) - self.source_Uk_m.assign(self.source_Uk[m]) - self.U_SDC.assign(self.Unodes[m+1]) - - # Compute - # for N2N: - # y_m^(k+1) = y_(m-1)^(k+1) + dtau_m*(F(y_(m)^(k+1)) - F(y_(m)^k) - # + S(y_(m-1)^(k+1)) - S(y_(m-1)^k)) - # + sum(j=1,M) s_mj*(F+S)(y^k) - self.solver.solve() - self.Unodes1[m+1].assign(self.U_SDC) - - # Evaluate source terms - for evaluate in self.evaluate_source: - evaluate(self.Unodes1[m+1], self.base.dt, x_out=self.source_Ukp1[m+1]) - - # Apply limiter if required - if self.limiter is not None: - self.limiter.apply(self.Unodes1[m+1]) - for m in range(k, self.M): - # Set S matrix - self.Q_.assign(self.compute_quad_final(self.Q[k-1], self.fUnodes, m+1)) - - # Set initial guess for solver, and pick correct solver - self.U_start.assign(self.Unodes1[m]) - self.Ukp1_m.assign(self.Unodes1[m]) - self.Uk_mp1.assign(self.Unodes[m+1]) - self.Uk_m.assign(self.Unodes[m]) - self.source_Ukp1_m.assign(self.source_Ukp1[m]) - self.source_Uk_m.assign(self.source_Uk[m]) - self.U_SDC.assign(self.Unodes[m+1]) - - # Compute - # for N2N: - # y_m^(k+1) = y_(m-1)^(k+1) + dtau_m*(F(y_(m)^(k+1)) - F(y_(m)^k) - # + S(y_(m-1)^(k+1)) - S(y_(m-1)^k)) - # + sum(j=1,M) s_mj*(F+S)(y^k) - self.solver.solve() - self.Unodes1[m+1].assign(self.U_SDC) - - # Evaluate source terms - for evaluate in self.evaluate_source: - evaluate(self.Unodes1[m+1], self.base.dt, x_out=self.source_Ukp1[m+1]) - - # Apply limiter if required - if self.limiter is not None: - self.limiter.apply(self.Unodes1[m+1]) - - for m in range(self.M+1): - self.Unodes[m].assign(self.Unodes1[m]) - self.source_Uk[m].assign(self.source_Ukp1[m]) - - x_out.assign(self.Unodes[-1]) - -class Parallel_RIDC(object, metaclass=ABCMeta): - """Class for Integral Deferred Correction schemes.""" - - def __init__(self, base_scheme, domain, M, K, field_name=None, - linear_solver_parameters=None, nonlinear_solver_parameters=None, - limiter=None, options=None, communicator=None): - """ - Initialise IDC object - Args: - base_scheme (:class:`TimeDiscretisation`): Base time stepping scheme to get first guess of solution on - quadrature nodes. - domain (:class:`Domain`): the model's domain object, containing the - mesh and the compatible function spaces. - M (int): Number of subintervals - K (int): Max number of correction interations - field_name (str, optional): name of the field to be evolved. - Defaults to None. - linear_solver_parameters (dict, optional): dictionary of parameters to - pass to the underlying linear solver. Defaults to None. - nonlinear_solver_parameters (dict, optional): dictionary of parameters to - pass to the underlying nonlinear solver. Defaults to None. - final_update (bool, optional): Whether to compute final update, or just take last - quadrature value. Defaults to True - limiter (:class:`Limiter` object, optional): a limiter to apply to - the evolving field to enforce monotonicity. Defaults to None. - options (:class:`AdvectionOptions`, optional): an object containing - - initial_guess (str, optional): Initial guess to be base timestepper, or copy - communicator (MPI communicator, optional): communicator for parallel execution. Defaults to None. - """ - self.base = base_scheme - self.field_name = field_name - self.domain = domain - self.dt_coarse = domain.dt - self.limiter = limiter - self.augmentation = self.base.augmentation - self.wrapper = self.base.wrapper - self.K = K - self.M = M - self.dt = Constant(float(self.dt_coarse)/(self.M)) - self.nodes = np.arange(0, (M+1)*float(self.dt), float(self.dt)) - self.Q = [] - for l in range(1, self.K+1): - integration_matrix = self.lagrange_integration_matrix(l+1) - integration_matrix = 0.5 * (l) * float(self.dt) * integration_matrix - self.Q.append(integration_matrix) - # Rescale integration matrix to be over [0, self.dt_coarse] rather than [-1, 1] - # self.Q = 0.5 * (self.K) * float(self.dt) * integration_matrix - - # integration_matrix = self.lagrange_integration_matrix(self.K) - # # Rescale integration matrix to be over [0, self.dt_coarse] rather than [-1, 1] - # self.Q = 0.5 * (self.K-1) * float(self.dt) * integration_matrix - - - #self.Q = float(self.dt_coarse) * integration_matrix - - # Set default linear and nonlinear solver options if none passed in - if linear_solver_parameters is None: - self.linear_solver_parameters = {'snes_type': 'ksponly', - 'ksp_type': 'cg', - 'pc_type': 'bjacobi', - 'sub_pc_type': 'ilu'} - else: - self.linear_solver_parameters = linear_solver_parameters - - if nonlinear_solver_parameters is None: - self.nonlinear_solver_parameters = {'snes_type': 'newtonls', - 'ksp_type': 'gmres', - 'pc_type': 'bjacobi', - 'sub_pc_type': 'ilu'} - else: - self.nonlinear_solver_parameters = nonlinear_solver_parameters - self.comm = communicator - - - def setup(self, equation, apply_bcs=True, *active_labels): - """ - Set up the SDC time discretisation based on the equation.n - - Args: - equation (:class:`PrognosticEquation`): the model's equation. - apply_bcs (bool, optional): whether to apply the equation's boundary - conditions. Defaults to True. - *active_labels (:class:`Label`): labels indicating which terms of - the equation to include. - """ - # Inherit from base time discretisation - self.base.setup(equation, apply_bcs, *active_labels) - self.equation = self.base.equation - self.residual = self.base.residual - self.evaluate_source = self.base.evaluate_source - - for t in self.residual: - # Check all terms are labeled implicit or explicit - if ((not t.has_label(implicit)) and (not t.has_label(explicit)) - and (not t.has_label(time_derivative)) and (not t.has_label(source_label))): - raise NotImplementedError("Non time-derivative or source terms must be labeled as implicit or explicit") - - # Set up bcs - self.bcs = self.base.bcs - - # Set up SDC variables - if self.field_name is not None and hasattr(equation, "field_names"): - self.idx = equation.field_names.index(self.field_name) - W = equation.spaces[self.idx] - else: - self.field_name = equation.field_name - W = equation.function_space - self.idx = None - self.W = W - self.Unodes = [Function(W) for _ in range(self.M+1)] - self.Unodes1 = [Function(W) for _ in range(self.M+1)] - self.fUnodes = [Function(W) for _ in range(self.M+1)] - self.quad = [Function(W) for _ in range(self.M+1)] - self.source_Uk = [Function(W) for _ in range(self.M+1)] - self.source_Ukp1 = [Function(W) for _ in range(self.M+1)] - self.U_SDC = Function(W) - self.U_start = Function(W) - self.Un = Function(W) - self.Q_ = Function(W) - self.quad_final = Function(W) - self.U_fin = Function(W) - self.Urhs = Function(W) - self.Uin = Function(W) - self.source_in = Function(W) - self.source_Ukp1_m = Function(W) - self.source_Uk_m = Function(W) - self.Uk_mp1 = Function(W) - self.Uk_m = Function(W) - self.Ukp1_m = Function(W) - self.node_count = 0 - - @property - def nlevels(self): - return 1 - - def equidistant_nodes(self ,M): - # This returns a grid of M equispaced nodes from -1 to 1 - grid = np.linspace(-1., 1., M) - #grid = 0.5 * (grid + 1) - return grid - - def lagrange_polynomial(self, index, nodes): - # This returns the coefficients of the Lagrange polynomial l_m with m=index - - M = len(nodes) - - # c is the denominator - c = 1. - for k in range(M): - if k != index: - c *= (nodes[index] - nodes[k]) - - coeffs = np.zeros(M) - coeffs[0] = 1. - m = 0 - - for k in range(M): - if k != index: - m += 1 - d1 = np.zeros(M) - d2 = np.zeros(M) - - d1 = (-1.)*nodes[k] * coeffs - d2[1:m+1] = coeffs[0:m] - - coeffs = d1+d2 - return coeffs / c - - def integrate_polynomial(self, p): - # given a list of coefficients of a polynomial p, this returns those of the integral of p - integral_coeffs = np.zeros(len(p)+1) - - for n, pn in enumerate(p): - integral_coeffs[n+1] = 1/(n+1) * pn - - return integral_coeffs - - def evaluate(self, p, a, b): - # given a list of coefficients of a polynomial p, this returns the value of p(b)-p(a) - value = 0. - for n, pn in enumerate(p): - value += pn * (b**n - a**n) - - return value - - def lagrange_integration_matrix(self, M): - # using the functions defined above, this returns the MxM integration matrix - - # set up equidistant nodes and initialise matrix to zero - nodes = self.equidistant_nodes(M) - L = len(nodes) - int_matrix = np.zeros((L, L)) - - # fill in matrix values - for index in range(L): - coeff_p = self.lagrange_polynomial(index, nodes) - int_coeff = self.integrate_polynomial(coeff_p) - - for n in range(L-1): - int_matrix[n+1, index] = self.evaluate(int_coeff, nodes[n], nodes[n+1]) - - return int_matrix - - def compute_quad(self, Q, fUnodes, m): - """ - Computes integration of F(y) on quadrature nodes - """ - quad = Function(self.W) - quad.assign(0.) - for k in range(0, np.shape(Q)[1]): - quad += float(Q[m, k])*fUnodes[k] - return quad - - def compute_quad_final(self, Q, fUnodes, m): - """ - Computes final integration of F(y) on quadrature nodes - """ - quad = Function(self.W) - quad.assign(0.) - l = np.shape(Q)[0] - 1 - for k in range(0, l+1): - quad += float(Q[-1, k])*fUnodes[m - l + k] - return quad - - @property - def res_rhs(self): - """Set up the residual for the calculation of F(y).""" - a = self.residual.label_map(lambda t: t.has_label(time_derivative), - replace_subject(self.Urhs, old_idx=self.idx), - drop) - # F(y) - L = self.residual.label_map(lambda t: any(t.has_label(time_derivative, source_label)), - drop, - replace_subject(self.Uin, old_idx=self.idx)) - L_source = self.residual.label_map(lambda t: t.has_label(source_label), - replace_subject(self.source_in, old_idx=self.idx), - drop) - residual_rhs = a - (L + L_source) - return residual_rhs.form - - @property - def res(self): - """Set up the discretisation's residual for a given node m.""" - # Add time derivative terms y^(k+1)_m - y_start for node m. y_start is y_n for Z2N formulation - # and y^(k)_m for N2N formulation - mass_form = self.residual.label_map( - lambda t: t.has_label(time_derivative), - map_if_false=drop) - residual = mass_form.label_map(all_terms, - map_if_true=replace_subject(self.U_SDC, old_idx=self.idx)) - residual -= mass_form.label_map(all_terms, - map_if_true=replace_subject(self.U_start, old_idx=self.idx)) - - # Calculate source terms - r_source_kp1 = self.residual.label_map( - lambda t: t.has_label(source_label), - map_if_true=replace_subject(self.source_Ukp1_m, old_idx=self.idx), - map_if_false=drop) - r_source_kp1 = r_source_kp1.label_map( - all_terms, - lambda t: Constant(self.dt)*t) - residual += r_source_kp1 - - r_source_k = self.residual.label_map( - lambda t: t.has_label(source_label), - map_if_true=replace_subject(self.source_Uk_m, old_idx=self.idx), - map_if_false=drop) - r_source_k = r_source_k.label_map( - all_terms, - map_if_true=lambda t: Constant(self.dt)*t) - residual -= r_source_k - - # Add on final implicit terms - # Qdelta_imp[m,m]*(F(y_(m)^(k+1)) - F(y_(m)^k)) - r_imp_kp1 = self.residual.label_map( - lambda t: t.has_label(implicit), - map_if_true=replace_subject(self.U_SDC, old_idx=self.idx), - map_if_false=drop) - r_imp_kp1 = r_imp_kp1.label_map( - all_terms, - lambda t: Constant(self.dt)*t) - residual += r_imp_kp1 - r_imp_k = self.residual.label_map( - lambda t: t.has_label(implicit), - map_if_true=replace_subject(self.Uk_mp1, old_idx=self.idx), - map_if_false=drop) - r_imp_k = r_imp_k.label_map( - all_terms, - lambda t: Constant(self.dt)*t) - residual -= r_imp_k - - r_exp_kp1 = self.residual.label_map( - lambda t: t.has_label(explicit), - map_if_true=replace_subject(self.Ukp1_m, old_idx=self.idx), - map_if_false=drop) - r_exp_kp1 = r_exp_kp1.label_map( - all_terms, - lambda t: Constant(self.dt)*t) - residual += r_exp_kp1 - r_exp_k = self.residual.label_map( - lambda t: t.has_label(explicit), - map_if_true=replace_subject(self.Uk_m, old_idx=self.idx), - map_if_false=drop) - r_exp_k = r_exp_k.label_map( - all_terms, - lambda t: Constant(self.dt)*t) - residual -= r_exp_k - - - # Add on sum(j=1,M) s_mj*F(y_m^k) for N2N formulation, where s_mj = q_mj-q_m-1j - # and s1j = q1j. - Q = self.residual.label_map(lambda t: t.has_label(time_derivative), - replace_subject(self.Q_, old_idx=self.idx), - drop) - residual += Q - return residual.form - - @cached_property - def solver(self): - """Set up a list of solvers for each problem at a node m.""" - # setup solver using residual defined in derived class - problem = NonlinearVariationalProblem(self.res, self.U_SDC, bcs=self.bcs) - solver_name = self.field_name+self.__class__.__name__ - solver = NonlinearVariationalSolver(problem, solver_parameters=self.nonlinear_solver_parameters, options_prefix=solver_name) - return solver - - @cached_property - def solver_rhs(self): - """Set up the problem and the solver for mass matrix inversion.""" - # setup linear solver using rhs residual defined in derived class - prob_rhs = NonlinearVariationalProblem(self.res_rhs, self.Urhs, bcs=self.bcs) - solver_name = self.field_name+self.__class__.__name__+"_rhs" - return NonlinearVariationalSolver(prob_rhs, solver_parameters=self.linear_solver_parameters, - options_prefix=solver_name) - - @wrapper_apply - def apply(self, x_out, x_in): - - # Parallelised code: - # Correction sweep - x_out.assign(x_in) - self.kval = self.comm.ensemble_comm.rank - #logger.info(f'Communicator: {self.kval:.2e}') - #breakpoint() - self.Un.assign(x_in) - self.Unodes[0].assign(self.Un) - # Loop through quadrature nodes and solve - self.Unodes1[0].assign(self.Unodes[0]) - for evaluate in self.evaluate_source: - evaluate(self.Unodes[0], self.base.dt, x_out=self.source_Uk[0]) - self.Uin.assign(self.Unodes[0]) - self.solver_rhs.solve() - self.fUnodes[0].assign(self.Urhs) - # On first communicator - if (self.comm.ensemble_comm.rank == 0): - logger.info(f'Starting base timestepper: {self.kval:.2e}') - # base timestepper - - for m in range(self.M): - self.base.dt = float(self.dt) - logger.info(f'Base stepper: {self.kval:.2e}') - self.base.apply(self.Unodes[m+1], self.Unodes[m]) - logger.info(f'Base stepper done: {self.kval:.2e}') - for evaluate in self.evaluate_source: - evaluate(self.Unodes[m+1], self.base.dt, x_out=self.source_Uk[m+1]) - - # Send base guess to k+1 correction - #breakpoint() - # for i in range(1, self.K): - # self.comm.send(self.node_count, dest=int(i), tag=11) # Send data to all other processes - self.comm.send(self.Unodes[m+1], dest=self.kval+1, tag=100+m+1) - logger.info(f'Sent data to process {self.kval+1:.2e} from process {self.kval:.2e}') - - else: - for m in range(1, self.kval + 1): - self.comm.recv(self.Unodes[m], source=self.kval-1, tag=100+m) - logger.info(f'Recieved data to process {self.kval:.2e} from process {self.kval-1:.2e}') - self.Uin.assign(self.Unodes[m]) - for evaluate in self.evaluate_source: - evaluate(self.Uin, self.base.dt, x_out=self.source_in) - self.solver_rhs.solve() - self.fUnodes[m].assign(self.Urhs) - # if k == 1: - # for m in range(0, self.kval*(self.kval+1)/2): - # self.Unodes[m] = self.comm.recv(source=0, tag=123) - # else: - # for m in range(0, self.kval*(self.kval+1)/2): - # self.Unodes1[m] = self.comm.recv(source=self.kval-1, tag=11) - # self.Unodes[m].assign(self.Unodes1[m]) - for m in range(0, self.kval): - # if (m >= self.kval*(self.kval+1)/2): - # self.Unodes1[m] = self.comm.recv(source=self.kval-1, tag=11) - # self.Unodes[m].assign(self.Unodes1[m]) - # Get f(u) - # if (m > self.kval+2): - # self.Uin.assign(self.Unodes[m]) - # # Include source terms - # for evaluate in self.evaluate_source: - # evaluate(self.Uin, self.base.dt, x_out=self.source_in) - # self.solver_rhs.solve() - # self.fUnodes[m].assign(self.Urhs) - # Set S matrix - self.Q_.assign(self.compute_quad(self.Q[self.kval-1], self.fUnodes, m+1)) - - # Set initial guess for solver, and pick correct solver - self.U_start.assign(self.Unodes1[m]) - self.Ukp1_m.assign(self.Unodes1[m]) - self.Uk_mp1.assign(self.Unodes[m+1]) - self.Uk_m.assign(self.Unodes[m]) - self.source_Ukp1_m.assign(self.source_Ukp1[m]) - self.source_Uk_m.assign(self.source_Uk[m]) - self.U_SDC.assign(self.Unodes[m+1]) - - # Compute - # for N2N: - # y_m^(k+1) = y_(m-1)^(k+1) + dtau_m*(F(y_(m)^(k+1)) - F(y_(m)^k) - # + S(y_(m-1)^(k+1)) - S(y_(m-1)^k)) - # + sum(j=1,M) s_mj*(F+S)(y^k) - self.solver.solve() - self.Unodes1[m+1].assign(self.U_SDC) - - # Evaluate source terms - for evaluate in self.evaluate_source: - evaluate(self.Unodes1[m+1], self.base.dt, x_out=self.source_Ukp1[m+1]) - - # Apply limiter if required - if self.limiter is not None: - self.limiter.apply(self.Unodes1[m+1]) - # Send our updated value to next communicator - if self.kval < self.K: - self.comm.send(self.Unodes1[m+1], dest=self.kval+1, tag=100+m+1) - logger.info(f'Sent data to process {self.kval+1:.2e} from process {self.kval:.2e}') - - for m in range(self.kval, self.M): - self.comm.recv(self.Unodes[m+1], source=self.kval-1, tag=100+m+1) - logger.info(f'Recieved data to process {self.kval:.2e} from process {self.kval-1:.2e}') - self.Uin.assign(self.Unodes[m+1]) - for evaluate in self.evaluate_source: - evaluate(self.Uin, self.base.dt, x_out=self.source_in) - self.solver_rhs.solve() - self.fUnodes[m+1].assign(self.Urhs) - - # Set S matrix - self.Q_.assign(self.compute_quad_final(self.Q[self.kval-1], self.fUnodes, m+1)) - - # Set initial guess for solver, and pick correct solver - self.U_start.assign(self.Unodes1[m]) - self.Ukp1_m.assign(self.Unodes1[m]) - self.Uk_mp1.assign(self.Unodes[m+1]) - self.Uk_m.assign(self.Unodes[m]) - self.source_Ukp1_m.assign(self.source_Ukp1[m]) - self.source_Uk_m.assign(self.source_Uk[m]) - self.U_SDC.assign(self.Unodes[m+1]) - - # Compute - # for N2N: - # y_m^(k+1) = y_(m-1)^(k+1) + dtau_m*(F(y_(m)^(k+1)) - F(y_(m)^k) - # + S(y_(m-1)^(k+1)) - S(y_(m-1)^k)) - # + sum(j=1,M) s_mj*(F+S)(y^k) - self.solver.solve() - self.Unodes1[m+1].assign(self.U_SDC) - - # Evaluate source terms - for evaluate in self.evaluate_source: - evaluate(self.Unodes1[m+1], self.base.dt, x_out=self.source_Ukp1[m+1]) - - # Apply limiter if required - if self.limiter is not None: - self.limiter.apply(self.Unodes1[m+1]) - - # Send our updated value to next communicator - if self.kval < self.K: - self.comm.send(self.Unodes1[m+1], dest=self.kval+1, tag=100+m+1) - logger.info(f'Sent data to process {self.kval+1:.2e} from process {self.kval:.2e}') - - if (self.kval == self.K): - logger.info(f'Broadcasting for: {self.kval:.2e}') - x_out.assign(self.Unodes1[-1]) - for i in range(self.K): - # Send the final result to all other ranks - self.comm.send(x_out, dest=i, tag=200) - else: - - # Receive the final result from Rank K - self.comm.recv(x_out, source=self.K, tag=200) - - -# class RIDC2(IDC): -# """Class for Revisionist Integral Deferred Correction schemes.""" - -# def __init__(self, base_scheme, domain, M, K, field_name=None, -# linear_solver_parameters=None, nonlinear_solver_parameters=None, -# limiter=None, options=None): -# """ -# Initialise IDC object -# Args: -# base_scheme (:class:`TimeDiscretisation`): Base time stepping scheme to get first guess of solution on -# quadrature nodes. -# domain (:class:`Domain`): the model's domain object, containing the -# mesh and the compatible function spaces. -# M (int): Number of quadrature nodes to compute spectral integration over -# K (int): Max number of correction interations -# field_name (str, optional): name of the field to be evolved. -# Defaults to None. -# linear_solver_parameters (dict, optional): dictionary of parameters to -# pass to the underlying linear solver. Defaults to None. -# nonlinear_solver_parameters (dict, optional): dictionary of parameters to -# pass to the underlying nonlinear solver. Defaults to None. -# final_update (bool, optional): Whether to compute final update, or just take last -# quadrature value. Defaults to True -# limiter (:class:`Limiter` object, optional): a limiter to apply to -# the evolving field to enforce monotonicity. Defaults to None. -# options (:class:`AdvectionOptions`, optional): an object containing - -# initial_guess (str, optional): Initial guess to be base timestepper, or copy -# """ -# super(RIDC, self).__init__(base_scheme, domain, M, K, -# field_name, linear_solver_parameters, nonlinear_solver_parameters, -# limiter, options) -# self.dt = Constant(float(self.dt_coarse)/(self.M - 1)) -# self.nodes = np.arange(0, M*float(self.dt), float(self.dt)) -# #integration_matrix = self.lagrange_integration_matrix(self.K) -# #self.Q = 0.5 * (self.K-1) * float(self.dt) * integration_matrix - -# def setup(self, equation, apply_bcs=True, *active_labels): -# """ -# Set up the SDC time discretisation based on the equation.n - -# Args: -# equation (:class:`PrognosticEquation`): the model's equation. -# apply_bcs (bool, optional): whether to apply the equation's boundary -# conditions. Defaults to True. -# *active_labels (:class:`Label`): labels indicating which terms of -# the equation to include. -# """ -# super(RIDC, self).setup(equation, apply_bcs, *active_labels) - -# def compute_quad(self, Q, fUnodes, m): -# """ -# Computes integration of F(y) on quadrature nodes -# """ -# quad = Function(self.W) -# quad.assign(0.) -# for k in range(0, np.shape(Q)[0]): -# quad += float(Q[m, k])*fUnodes[k] -# return quad - -# def compute_quad_final(self, Q, fUnodes, m): -# """ -# Computes final integration of F(y) on quadrature nodes -# """ -# quad = Function(self.W) -# quad.assign(0.) -# for k in range(0, self.K): -# quad += float(Q[-1, k])*fUnodes[m + 1 - self.K + k] -# return quad - -# @wrapper_apply -# def apply(self, x_out, x_in): -# self.Un.assign(x_in) - -# # Compute initial guess on quadrature nodes with low-order -# # base timestepper -# self.Unodes[0].assign(self.Un) - -# for m in range(self.M-1): -# self.base.dt = float(self.dt) -# self.base.apply(self.Unodes[m+1], self.Unodes[m]) -# for m in range(self.M): -# for evaluate in self.evaluate_source: -# evaluate(self.Unodes[m], self.base.dt, x_out=self.source_Uk[m]) - -# # Iterate through correction sweeps -# for k in range(1, self.K): -# print("Correction sweep", k) -# # Compute for N2N: sum(j=1,M) (s_mj*F(y_m^k) + s_mj*S(y_m^k)) -# for m in range(self.M): -# self.Uin.assign(self.Unodes[m]) -# # Include source terms -# for evaluate in self.evaluate_source: -# evaluate(self.Uin, self.base.dt, x_out=self.source_in) -# self.solver_rhs.solve() -# self.fUnodes[m].assign(self.Urhs) -# #self.compute_quad() -# # Loop through quadrature nodes and solve -# self.Unodes1[0].assign(self.Unodes[0]) -# for evaluate in self.evaluate_source: -# evaluate(self.Unodes[0], self.base.dt, x_out=self.source_Uk[0]) -# for m in range(0, self.M-1): -# # Set S matrix -# self.Q_.assign(self.compute_quad(self.Q, self.fUnodes, m+1)) -# #self.Q_.assign(self.quad[m-1]) - -# # Set initial guess for solver, and pick correct solver -# self.U_start.assign(self.Unodes1[m]) -# self.Ukp1_m.assign(self.Unodes1[m]) -# self.Uk_mp1.assign(self.Unodes[m+1]) -# self.Uk_m.assign(self.Unodes[m]) -# self.source_Ukp1_m.assign(self.source_Ukp1[m]) -# self.source_Uk_m.assign(self.source_Uk[m]) -# self.U_SDC.assign(self.Unodes[m+1]) - -# # Compute -# # for N2N: -# # y_m^(k+1) = y_(m-1)^(k+1) + dtau_m*(F(y_(m)^(k+1)) - F(y_(m)^k) -# # + S(y_(m-1)^(k+1)) - S(y_(m-1)^k)) -# # + sum(j=1,M) s_mj*(F+S)(y^k) -# self.solver.solve() -# self.Unodes1[m+1].assign(self.U_SDC) - -# # Evaluate source terms -# for evaluate in self.evaluate_source: -# evaluate(self.Unodes1[m+1], self.base.dt, x_out=self.source_Ukp1[m+1]) - -# # Apply limiter if required -# if self.limiter is not None: -# self.limiter.apply(self.Unodes1[m+1]) -# # for m in range(self.K-1, self.M-1): -# # # Set S matrix -# # self.Q_.assign(self.compute_quad_final(self.Q, self.fUnodes, m+1)) - -# # # Set initial guess for solver, and pick correct solver -# # self.U_start.assign(self.Unodes1[m]) -# # self.Ukp1_m.assign(self.Unodes1[m]) -# # self.Uk_mp1.assign(self.Unodes[m+1]) -# # self.Uk_m.assign(self.Unodes[m]) -# # self.source_Ukp1_m.assign(self.source_Ukp1[m]) -# # self.source_Uk_m.assign(self.source_Uk[m]) -# # self.U_SDC.assign(self.Unodes[m+1]) - -# # # Compute -# # # for N2N: -# # # y_m^(k+1) = y_(m-1)^(k+1) + dtau_m*(F(y_(m)^(k+1)) - F(y_(m)^k) -# # # + S(y_(m-1)^(k+1)) - S(y_(m-1)^k)) -# # # + sum(j=1,M) s_mj*(F+S)(y^k) -# # self.solver.solve() -# # self.Unodes1[m+1].assign(self.U_SDC) - -# # # Evaluate source terms -# # for evaluate in self.evaluate_source: -# # evaluate(self.Unodes1[m+1], self.base.dt, x_out=self.source_Ukp1[m+1]) - -# # # Apply limiter if required -# # if self.limiter is not None: -# # self.limiter.apply(self.Unodes1[m+1]) - -# for m in range(self.M): -# self.Unodes[m].assign(self.Unodes1[m]) -# self.source_Uk[m].assign(self.source_Ukp1[m]) - -# x_out.assign(self.Unodes[-1]) - -class Parallel_SDC(SDC): - """Class for Spectral Deferred Correction schemes.""" - - def __init__(self, base_scheme, domain, M, maxk, quad_type, node_type, qdelta_imp, qdelta_exp, - field_name=None, - linear_solver_parameters=None, nonlinear_solver_parameters=None, final_update=True, - limiter=None, options=None, initial_guess="base", communicator=None): - """ - Initialise SDC object - Args: - base_scheme (:class:`TimeDiscretisation`): Base time stepping scheme to get first guess of solution on - quadrature nodes. - domain (:class:`Domain`): the model's domain object, containing the - mesh and the compatible function spaces. - M (int): Number of quadrature nodes to compute spectral integration over - maxk (int): Max number of correction interations - quad_type (str): Type of quadrature to be used. Options are - GAUSS, RADAU-LEFT, RADAU-RIGHT and LOBATTO - node_type (str): Node type to be used. Options are - EQUID, LEGENDRE, CHEBY-1, CHEBY-2, CHEBY-3 and CHEBY-4 - qdelta_imp (str): Implicit Qdelta matrix to be used. Options are - BE, LU, TRAP, EXACT, PIC, OPT, WEIRD, MIN-SR-NS, MIN-SR-S - qdelta_exp (str): Explicit Qdelta matrix to be used. Options are - FE, EXACT, PIC - formulation (str, optional): Whether to use node-to-node or zero-to-node - formulation. Options are N2N and Z2N. Defaults to N2N - field_name (str, optional): name of the field to be evolved. - Defaults to None. - linear_solver_parameters (dict, optional): dictionary of parameters to - pass to the underlying linear solver. Defaults to None. - nonlinear_solver_parameters (dict, optional): dictionary of parameters to - pass to the underlying nonlinear solver. Defaults to None. - final_update (bool, optional): Whether to compute final update, or just take last - quadrature value. Defaults to True - limiter (:class:`Limiter` object, optional): a limiter to apply to - the evolving field to enforce monotonicity. Defaults to None. - options (:class:`AdvectionOptions`, optional): an object containing - options to either be passed to the spatial discretisation, or - to control the "wrapper" methods, such as Embedded DG or a - recovery method. Defaults to None. - initial_guess (str, optional): Initial guess to be base timestepper, or copy - """ - super().__init__(base_scheme, domain, M, maxk, quad_type, node_type, qdelta_imp, qdelta_exp, - formulation = "Z2N", field_name=field_name, - linear_solver_parameters=linear_solver_parameters, nonlinear_solver_parameters= nonlinear_solver_parameters, - final_update=final_update, - limiter=limiter, options=options, initial_guess=initial_guess) - self.comm = communicator - - def compute_quad(self): - """ - Computes integration of F(y) on quadrature nodes - """ - x = Function(self.W) - for j in range(self.M): - x.assign(float(self.Q[j, self.comm.ensemble_comm.rank])*self.fUnodes[self.comm.ensemble_comm.rank]) - self.comm.reduce(x, self.quad[j], root=j) - - def compute_quad_final(self): - """ - Computes final integration of F(y) on quadrature nodes - """ - x = Function(self.W) - x.assign(float(self.Qfin[self.comm.ensemble_comm.rank])*self.fUnodes[self.comm.ensemble_comm.rank]) - self.comm.allreduce(x, self.quad_final) - - - @wrapper_apply - def apply(self, x_out, x_in): - self.Un.assign(x_in) - self.U_start.assign(self.Un) - solver_list = self.solvers - - # Compute initial guess on quadrature nodes with low-order - # base timestepper - self.Unodes[0].assign(self.Un) - if (self.base_flag): - for m in range(self.M): - self.base.dt = float(self.dtau[m]) - self.base.apply(self.Unodes[m+1], self.Unodes[m]) - else: - for m in range(self.M): - self.Unodes[m+1].assign(self.Un) - for m in range(self.M+1): - for evaluate in self.evaluate_source: - evaluate(self.Unodes[m], self.base.dt, x_out=self.source_Uk[m]) - - # Iterate through correction sweeps - k = 0 - while k < self.maxk: - k += 1 - - if self.qdelta_imp_type == "MIN-SR-FLEX": - # Recompute Implicit Q_delta matrix for each iteration k - self.Qdelta_imp = genQDeltaCoeffs( - self.qdelta_imp_type, - form=self.formulation, - nodes=self.nodes, - Q=self.Q, - nNodes=self.M, - nodeType=self.node_type, - quadType=self.quad_type, - k=k - ) - - # Compute for N2N: sum(j=1,M) (s_mj*F(y_m^k) + s_mj*S(y_m^k)) - # for Z2N: sum(j=1,M) (q_mj*F(y_m^k) + q_mj*S(y_m^k)) - self.Uin.assign(self.Unodes[self.comm.ensemble_comm.rank+1]) - # Include source terms - for evaluate in self.evaluate_source: - evaluate(self.Uin, self.base.dt, x_out=self.source_in) - self.solver_rhs.solve() - self.fUnodes[self.comm.ensemble_comm.rank].assign(self.Urhs) - - self.compute_quad() - - # Loop through quadrature nodes and solve - self.Unodes1[0].assign(self.Unodes[0]) - for evaluate in self.evaluate_source: - evaluate(self.Unodes[0], self.base.dt, x_out=self.source_Uk[0]) - - - # Set Q or S matrix - self.Q_.assign(self.quad[self.comm.ensemble_comm.rank]) - - # Set initial guess for solver, and pick correct solver - self.solver = solver_list[self.comm.ensemble_comm.rank] - self.U_SDC.assign(self.Unodes[self.comm.ensemble_comm.rank+1]) - - # Compute - # for N2N: - # y_m^(k+1) = y_(m-1)^(k+1) + dtau_m*(F(y_(m)^(k+1)) - F(y_(m)^k) - # + S(y_(m-1)^(k+1)) - S(y_(m-1)^k)) - # + sum(j=1,M) s_mj*(F+S)(y^k) - # for Z2N: - # y_m^(k+1) = y^n + sum(j=1,m) Qdelta_imp[m,j]*(F(y_(m)^(k+1)) - F(y_(m)^k)) - # + sum(j=1,M) Q_delta_exp[m,j]*(S(y_(m-1)^(k+1)) - S(y_(m-1)^k)) - self.solver.solve() - self.Unodes1[self.comm.ensemble_comm.rank+1].assign(self.U_SDC) - - # Evaluate source terms - for evaluate in self.evaluate_source: - evaluate(self.Unodes1[self.comm.ensemble_comm.rank+1], self.base.dt, x_out=self.source_Ukp1[self.comm.ensemble_comm.rank+1]) - - # Apply limiter if required - if self.limiter is not None: - self.limiter.apply(self.Unodes1[self.comm.ensemble_comm.rank+1]) - - self.Unodes[self.comm.ensemble_comm.rank+1].assign(self.Unodes1[self.comm.ensemble_comm.rank+1]) - self.source_Uk[self.comm.ensemble_comm.rank+1].assign(self.source_Ukp1[self.comm.ensemble_comm.rank+1]) - - if self.maxk > 0: - # Compute value at dt rather than final quadrature node tau_M - if self.final_update: - self.Uin.assign(self.Unodes1[self.comm.ensemble_comm.rank+1]) - self.source_in.assign(self.source_Ukp1[self.comm.ensemble_comm.rank+1]) - self.solver_rhs.solve() - self.fUnodes[self.comm.ensemble_comm.rank].assign(self.Urhs) - self.compute_quad_final() - # Compute y_(n+1) = y_n + sum(j=1,M) q_j*F(y_j) - if self.comm.ensemble_comm.rank == self.M-1: - self.U_fin.assign(self.Unodes[-1]) - self.comm.bcast(self.U_fin, self.M -1) - self.solver_fin.solve() - # Apply limiter if required - if self.limiter is not None: - self.limiter.apply(self.U_fin) - x_out.assign(self.U_fin) - else: - # Take value at final quadrature node dtau_M - if self.comm.ensemble_comm.rank == self.M-1: - x_out.assign(self.Unodes[-1]) - self.comm.bcast(x_out, self.M -1) - else: - x_out.assign(self.Unodes[-1]) diff --git a/gusto/time_discretisation/time_discretisation.py b/gusto/time_discretisation/time_discretisation.py index 21a60fade..2e08b9102 100644 --- a/gusto/time_discretisation/time_discretisation.py +++ b/gusto/time_discretisation/time_discretisation.py @@ -22,7 +22,6 @@ from gusto.core.logging import logger, DEBUG, logging_ksp_monitor_true_residual from gusto.time_discretisation.wrappers import * from gusto.solvers import mass_parameters -import numpy as np __all__ = ["TimeDiscretisation", "ExplicitTimeDiscretisation", "BackwardEuler", diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py index 02b1addfc..bb9b14ed2 100644 --- a/integration-tests/conftest.py +++ b/integration-tests/conftest.py @@ -15,9 +15,16 @@ TracerSetup.__new__.__defaults__ = (None,)*len(opts) -def tracer_sphere(tmpdir, degree, small_dt): +def tracer_sphere(tmpdir, degree, small_dt, ensemble): radius = 1 - mesh = IcosahedralSphereMesh(radius=radius, + if ensemble is not None: + dirname = str(tmpdir)+"comm"+str(ensemble.ensemble_comm.rank) + mesh = IcosahedralSphereMesh(radius=radius, + refinement_level=3, + degree=1, comm=ensemble.comm) + else: + dirname = str(tmpdir) + mesh = IcosahedralSphereMesh(radius=radius, refinement_level=3, degree=1) x = SpatialCoordinate(mesh) @@ -30,7 +37,7 @@ def tracer_sphere(tmpdir, degree, small_dt): else: dt = pi/3. * 0.02 - output = OutputParameters(dirname=str(tmpdir), dumpfreq=15) + output = OutputParameters(dirname=dirname, dumpfreq=15) domain = Domain(mesh, dt, family="BDM", degree=degree) io = IO(domain, output) @@ -47,9 +54,14 @@ def tracer_sphere(tmpdir, degree, small_dt): uexpr, umax, radius, tol) -def tracer_slice(tmpdir, degree, small_dt): +def tracer_slice(tmpdir, degree, small_dt, ensemble): n = 30 if degree == 0 else 15 - m = PeriodicIntervalMesh(n, 1.) + if ensemble is not None: + dirname = str(tmpdir)+"comm"+str(ensemble.ensemble_comm.rank) + m = PeriodicIntervalMesh(n, 1., comm=ensemble.comm) + else: + dirname = str(tmpdir) + m = PeriodicIntervalMesh(n, 1.) mesh = ExtrudedMesh(m, layers=n, layer_height=1./n) # Parameters chosen so that dt != 1 and u != 1 @@ -61,7 +73,7 @@ def tracer_slice(tmpdir, degree, small_dt): else: dt = 0.01 tmax = 0.75 - output = OutputParameters(dirname=str(tmpdir), dumpfreq=25) + output = OutputParameters(dirname=dirname, dumpfreq=25) domain = Domain(mesh, dt, family="CG", degree=degree) io = IO(domain, output) @@ -83,16 +95,21 @@ def tracer_slice(tmpdir, degree, small_dt): return TracerSetup(domain, tmax, io, f_init, f_end, degree, uexpr, tol=tol) -def tracer_blob_slice(tmpdir, degree, small_dt): +def tracer_blob_slice(tmpdir, degree, small_dt, ensemble): if small_dt: dt = 0.002 else: dt = 0.01 L = 10. - m = PeriodicIntervalMesh(10, L) + if ensemble is not None: + dirname = str(tmpdir)+"comm"+str(ensemble.ensemble_comm.rank) + m = PeriodicIntervalMesh(10, L, comm=ensemble.comm) + else: + dirname = str(tmpdir) + m = PeriodicIntervalMesh(10, L) mesh = ExtrudedMesh(m, layers=10, layer_height=1.) - output = OutputParameters(dirname=str(tmpdir), dumpfreq=25) + output = OutputParameters(dirname=dirname, dumpfreq=25) domain = Domain(mesh, dt, family="CG", degree=degree) io = IO(domain, output) @@ -106,14 +123,14 @@ def tracer_blob_slice(tmpdir, degree, small_dt): @pytest.fixture() def tracer_setup(): - def _tracer_setup(tmpdir, geometry, blob=False, degree=1, small_dt=False): + def _tracer_setup(tmpdir, geometry, blob=False, degree=1, small_dt=False, ensemble=None): if geometry == "sphere": assert not blob - return tracer_sphere(tmpdir, degree, small_dt) + return tracer_sphere(tmpdir, degree, small_dt, ensemble=ensemble) elif geometry == "slice": if blob: - return tracer_blob_slice(tmpdir, degree, small_dt) + return tracer_blob_slice(tmpdir, degree, small_dt, ensemble=ensemble) else: - return tracer_slice(tmpdir, degree, small_dt) + return tracer_slice(tmpdir, degree, small_dt, ensemble=ensemble) return _tracer_setup diff --git a/integration-tests/model/test_sdc.py b/integration-tests/model/test_deferred_correction.py similarity index 68% rename from integration-tests/model/test_sdc.py rename to integration-tests/model/test_deferred_correction.py index fd1dba068..66e1eba77 100644 --- a/integration-tests/model/test_sdc.py +++ b/integration-tests/model/test_deferred_correction.py @@ -1,11 +1,14 @@ """ -This runs a simple transport test on the sphere using the SDC time discretisations to +This runs a simple transport test on the sphere using the DC time discretisations to test whether the errors are within tolerance. The test is run for the following schemes: - IMEX_SDC_Le(1,1) - IMEX SDC with 1 quadrature node of Gauss type (2nd order scheme) - IMEX_SDC_R(2,2) - IMEX SDC with 2 qaudrature nodes of Radau type (3rd order scheme) using LU decomposition for the implicit update - BE_SDC_Lo(3,3) - Implicit SDC with 3 quadrature nodes of Lobatto type (4th order scheme). - FE_SDC_Le(3,5) - Explicit SDC with 3 quadrature nodes of Gauss type (6th order scheme). +- IMEX_RIDC_R(3) - IMEX RIDC with 4 quadrature nodes of equidistant type, reduced stencils (3rd order scheme). +- BE_RIDC(4) - Implicit RIDC with 3 quadrature nodes of equidistant type, full stencils (4th order scheme). +- FE_RIDC(4) - Explicit RIDC with 3 quadrature nodes of equidistant type, full stencils (4th order scheme). """ from firedrake import norm @@ -19,8 +22,9 @@ def run(timestepper, tmax, f_end): @pytest.mark.parametrize( - "scheme", ["IMEX_SDC_Le(1,1)", "IMEX_SDC_R(2,2)", "BE_SDC_Lo(3,3)", "FE_SDC_Le(3,5)"]) -def test_sdc(tmpdir, scheme, tracer_setup): + "scheme", ["IMEX_SDC_Le(1,1)", "IMEX_SDC_R(2,2)", "BE_SDC_Lo(3,3)", "FE_SDC_Le(3,5)", "IMEX_RIDC_R(3)", + "BE_RIDC(4)", "FE_RIDC(4)"]) +def test_dc(tmpdir, scheme, tracer_setup): geometry = "sphere" setup = tracer_setup(tmpdir, geometry) domain = setup.domain @@ -66,11 +70,33 @@ def test_sdc(tmpdir, scheme, tracer_setup): elif scheme == "FE_SDC_Le(3,5)": quad_type = "GAUSS" M = 3 - k = 4 + k = 5 eqn.label_terms(lambda t: not t.has_label(time_derivative), explicit) base_scheme = ForwardEuler(domain) scheme = SDC(base_scheme, domain, M, k, quad_type, node_type, qdelta_imp, qdelta_exp, final_update=True, initial_guess="base") + elif scheme == "IMEX_RIDC_R(3)": + k = 2 + M = k*(k+1)//2 + 1 + eqn = ContinuityEquation(domain, V, "f") + # Split continuity term + eqn = split_continuity_form(eqn) + eqn.label_terms(lambda t: not any(t.has_label(time_derivative, transport)), implicit) + eqn.label_terms(lambda t: t.has_label(transport), explicit) + base_scheme = IMEX_Euler(domain) + scheme = RIDC(base_scheme, domain, M, k, reduced=True) + elif scheme == "BE_RIDC(4)": + k = 3 + M = 3 + eqn.label_terms(lambda t: not t.has_label(time_derivative), implicit) + base_scheme = BackwardEuler(domain) + scheme = RIDC(base_scheme, domain, M, k, reduced=False) + elif scheme == "FE_RIDC(4)": + M = 3 + k = 3 + eqn.label_terms(lambda t: not t.has_label(time_derivative), explicit) + base_scheme = ForwardEuler(domain) + scheme = RIDC(base_scheme, domain, M, k, reduced=False) transport_method = DGUpwind(eqn, 'f') diff --git a/integration-tests/model/test_parallel_dc.py b/integration-tests/model/test_parallel_dc.py new file mode 100644 index 000000000..a8305f4f8 --- /dev/null +++ b/integration-tests/model/test_parallel_dc.py @@ -0,0 +1,70 @@ +""" +This runs a simple transport test on the sphere using the DC time discretisations to +test whether the errors are within tolerance. The test is run for the following schemes: +- IMEX_SDC_R(2,2) - IMEX SDC with 2 qaudrature nodes of Radau type (3rd order scheme) using +- IMEX_RIDC_R(3) - IMEX RIDC with 4 quadrature nodes of equidistant type, reduced stencils (3rd order scheme). +""" + +from firedrake import norm, Ensemble, COMM_WORLD +from gusto import * +import pytest +from pytest_mpi.parallel_assert import parallel_assert + + +def run(timestepper, tmax, f_end): + timestepper.run(0, tmax) + print(norm(timestepper.fields("f") - f_end) / norm(f_end)) + return norm(timestepper.fields("f") - f_end) / norm(f_end) + +@pytest.mark.parallel(nprocs=[2,4]) +@pytest.mark.parametrize( + "scheme", ["IMEX_SDC_R(3,3)", "IMEX_RIDC_R(3)"]) +def test_parallel_dc(tmpdir, scheme, tracer_setup): + + if scheme == "IMEX_SDC_R(3,3)": + M = 2 + k = M + ensemble = Ensemble(COMM_WORLD, COMM_WORLD.size//(M)) + elif scheme == "IMEX_RIDC_R(3)": + k = 1 + ensemble = Ensemble(COMM_WORLD, COMM_WORLD.size//(k+1)) + geometry = "sphere" + setup = tracer_setup(tmpdir, geometry, ensemble=ensemble) + domain = setup.domain + V = domain.spaces("DG") + + if scheme == "IMEX_SDC_R(3,3)": + quad_type = "RADAU-RIGHT" + node_type = "LEGENDRE" + qdelta_imp = "MIN-SR-FLEX" + qdelta_exp = "MIN-SR-NS" + eqn = ContinuityEquation(domain, V, "f") + # Split continuity term + eqn = split_continuity_form(eqn) + eqn.label_terms(lambda t: not any(t.has_label(time_derivative, transport)), implicit) + eqn.label_terms(lambda t: t.has_label(transport), explicit) + base_scheme = IMEX_Euler(domain) + scheme = Parallel_SDC(base_scheme, domain, M, k, quad_type, node_type, qdelta_imp, + qdelta_exp, final_update=False, initial_guess="copy", communicator=ensemble) + elif scheme == "IMEX_RIDC_R(3)": + M = k*(k+1)//2 + 1 + eqn = ContinuityEquation(domain, V, "f") + # Split continuity term + eqn = split_continuity_form(eqn) + eqn.label_terms(lambda t: not any(t.has_label(time_derivative, transport)), implicit) + eqn.label_terms(lambda t: t.has_label(transport), explicit) + base_scheme = IMEX_Euler(domain) + scheme = Parallel_RIDC(base_scheme, domain, M, k, communicator=ensemble) + + transport_method = DGUpwind(eqn, 'f') + + time_varying_velocity = False + timestepper = PrescribedTransport( + eqn, scheme, setup.io, time_varying_velocity, transport_method + ) + + # Initial conditions + timestepper.fields("f").interpolate(setup.f_init) + timestepper.fields("u").project(setup.uexpr) + #run(timestepper, setup.tmax, setup.f_end) + #parallel_assert(run(timestepper, setup.tmax, setup.f_end) < setup.tol, "Error too large") From 86fb713c73b8508284c240ff81dae3f0c9f8dc26 Mon Sep 17 00:00:00 2001 From: atb1995 Date: Tue, 13 May 2025 12:10:36 +0100 Subject: [PATCH 09/43] Tidy up of code --- gusto/core/io.py | 2 +- .../deferred_correction.py | 101 ++++++++++-------- .../explicit_runge_kutta.py | 1 + gusto/time_discretisation/parallel_dc.py | 30 ++++-- integration-tests/conftest.py | 19 ++-- .../model/test_deferred_correction.py | 2 +- integration-tests/model/test_parallel_dc.py | 8 +- 7 files changed, 94 insertions(+), 69 deletions(-) diff --git a/gusto/core/io.py b/gusto/core/io.py index a375f6cf3..963f52bdc 100644 --- a/gusto/core/io.py +++ b/gusto/core/io.py @@ -55,7 +55,7 @@ def pick_up_mesh(output, mesh_name, comm=COMM_WORLD): else: dumpdir = path.join("results", output.dirname) chkfile = path.join(dumpdir, "chkpt.h5") - with CheckpointFile(chkfile, 'r', comm=mesh.comm) as chk: + with CheckpointFile(chkfile, 'r') as chk: mesh = chk.load_mesh(mesh_name) if dumpdir: diff --git a/gusto/time_discretisation/deferred_correction.py b/gusto/time_discretisation/deferred_correction.py index 99a5e1968..9272d15ef 100644 --- a/gusto/time_discretisation/deferred_correction.py +++ b/gusto/time_discretisation/deferred_correction.py @@ -1,67 +1,76 @@ -u""" +""" Objects for discretising time derivatives using Deferred Correction (DC) -Methods. We have Spectral Deferred Correction (SDC) and Serial Revisionist Integral -Deferred Correction (RIDC) methods. +Methods. This includes Spectral Deferred Correction (SDC) and Serial Revisionist +Integral Deferred Correction (RIDC) methods. -SDC and RIDC objects discretise ∂y/∂t = F(y), for variable y, time t and -operator F. +These methods discretise ∂y/∂t = F(y), for variable y, time t, and operator F. -Written in Picard integral form this equation is -y(t) = y_n + int[t_n,t] F(y(s)) ds +In Picard integral form, this equation is: +y(t) = y_n + ∫[t_n, t] F(y(s)) ds ================================================================================ -SDC Formulation: +Spectral Deferred Correction (SDC) Formulation: ================================================================================ -SDC methods are based on the idea of integrating the function F(y) over the -interval [t_n, t_n+1] using quadrature. We can then evaluate the function -using some quadrature rule, we can evaluate y on a temporal quadrature node as -y_m = y_n + sum[j=1,M] q_mj*F(y_j) -where q_mj can be found by integrating Lagrange polynomials. This is similar to -how Runge-Kutta methods are formed. +SDC methods integrate the function F(y) over the interval [t_n, t_n+1] using +quadrature. Evaluating y on temporal quadrature nodes gives: +y_m = y_n + Σ[j=1,M] q_mj * F(y_j) +where q_mj are derived from integrating Lagrange polynomials, similar to how +Runge-Kutta methods are constructed. -In matrix form this equation is: -(I - dt*Q*F)(y)=y_n +In matrix form: +(I - dt * Q * F)(y) = y_n -Computing y by Picard iteration through k we get: -y^(k+1)=y^k + (y_n - (I - dt*Q*F)(y^k)) +Using Picard iteration: +y^(k+1) = y^k + (y_n - (I - dt * Q * F)(y^k)) -Finally, to get our SDC method we precondition this system, using some approximation -of Q, Q_delta: -(I - dt*Q_delta*F)(y^(k+1)) = y_n + dt*(Q - Q_delta)F(y^k) +Preconditioning this system with an approximation Q_delta gives: +(I - dt * Q_delta * F)(y^(k+1)) = y_n + dt * (Q - Q_delta) * F(y^k) -The zero-to-node (Z2N) formulation is then: -y_m^(k+1) = y_n + sum(j=1,M) q'_mj*(F(y_j^(k+1)) - F(y_j^k)) - + sum(j=1,M) q_mj*F(y_(m-1)^k) -for entires q_mj in Q and q'_mj in Q_delta. +Two formulations are commonly used: +1. Zero-to-node (Z2N): + y_m^(k+1) = y_n + Σ[j=1,M] q'_mj * (F(y_j^(k+1)) - F(y_j^k)) + + Σ[j=1,M] q_mj * F(y_(j)^k) + where q_mj are entries in Q and q'_mj are entries in Q_delta. -Node-wise from previous quadrature node (N2N formulation), the implicit SDC calculation is: -y_m^(k+1) = y_(m-1)^(k+1) + dtau_m*(F(y_(m)^(k+1)) - F(y_(m)^k)) - + sum(j=1,M) s_mj*F(y_(m-1)^k) -where s_mj = q_mj - q_(m-1)j for entires q_ik in Q. +2. Node-to-node (N2N): + y_m^(k+1) = y_(m-1)^(k+1) + dtau_m * (F(y_(m)^(k+1)) - F(y_(m)^k)) + + Σ[j=1,M] s_mj * F(y_(j)^k) + where s_mj = q_mj - q_(m-1)j for entries q_ik in Q. -Key choices in our SDC method are: -- Choice of quadrature node type (e.g. gauss-lobatto) +Key choices in SDC: +- Quadrature node type (e.g., Gauss-Lobatto) - Number of quadrature nodes -- Number of iterations - each iteration increases the order of accuracy up to - the order of the underlying quadrature -- Choice of Q_delta (e.g. Forward Euler, Backward Euler, LU-trick) -- How to get initial solution on quadrature nodes +- Number of iterations (each iteration increases accuracy up to the quadrature order) +- Choice of Q_delta (e.g., Forward Euler, Backward Euler, LU-trick) +- Initial solution on quadrature nodes ================================================================================ -RIDC Formulation: +Revisionist Integral Deferred Correction (RIDC) Formulation: ================================================================================ -RIDC methods are closely related to SDC methods, but use equidistant nodes and -a slightly different formulation, discretising the error equation in a different way. -The idea is to use a low-order method to get an initial guess of the solution, and then -correct this solution using a high-order method. The correction is done by solving -the error equation, which is derived from the original equation. -SDC can also be thought of in this way. +RIDC methods are similar to SDC but use equidistant nodes and a different +formulation for the error equation. The process involves: +1. Using a low-order method (predictor) to compute an initial solution: + y_m^(0) = y_(m-1)^(0) + dt * F(y_(m)^(0)) + +2. Performing K correction steps: + y_m^(k+1) = y_(m-1)^(k+1) + dt * (F(y_(m)^(k+1)) - F(y_(m)^k)) + + Σ[j=1,M] s_mj * F(y_(j)^k) +We solve on N equispaced nodes on the interval [0, T] divided into J intervals, +each further divided into M subintervals: -The error equation is: + 0 * * * * * | * * * * * | * * * * * | * * * * * | * * * * * T + | J intervals, each with M subintervals | +Here, M >> K, and M must be at least K * (K+1) / 2 for the reduced stencil RIDC method. +dt = T / N, N = J * M. +Each correction sweep increases accuracy up to the quadrature order. +Key choices in RIDC: +- Number of subintervals J +- Number of quadrature nodes M + 1 +- Number of correction iterations K """ from abc import ABCMeta @@ -79,6 +88,7 @@ __all__ = ["SDC", "RIDC"] + class SDC(object, metaclass=ABCMeta): """Class for Spectral Deferred Correction schemes.""" @@ -530,6 +540,7 @@ def apply(self, x_out, x_in): else: x_out.assign(self.Unodes[-1]) + class RIDC(object, metaclass=ABCMeta): """Class for Revisionist Integral Deferred Correction schemes.""" @@ -600,7 +611,6 @@ def __init__(self, base_scheme, domain, M, K, field_name=None, else: self.nonlinear_solver_parameters = nonlinear_solver_parameters - def setup(self, equation, apply_bcs=True, *active_labels): """ Set up the RIDC time discretisation based on the equation. @@ -661,7 +671,7 @@ def setup(self, equation, apply_bcs=True, *active_labels): def nlevels(self): return 1 - def equidistant_nodes(self ,M): + def equidistant_nodes(self, M): """ Returns a grid of M equispaced nodes from -1 to 1 """ @@ -846,7 +856,6 @@ def res(self): lambda t: Constant(self.dt)*t) residual -= r_exp_k - # Add on sum(j=1,M) s_mj*F(y_m^k), where s_mj = q_mj-q_m-1j # and s1j = q1j. Q = self.residual.label_map(lambda t: t.has_label(time_derivative), diff --git a/gusto/time_discretisation/explicit_runge_kutta.py b/gusto/time_discretisation/explicit_runge_kutta.py index 34c47b977..9cedee0cf 100644 --- a/gusto/time_discretisation/explicit_runge_kutta.py +++ b/gusto/time_discretisation/explicit_runge_kutta.py @@ -449,6 +449,7 @@ def apply_cycle(self, x_out, x_in): self.solve_stage(x_in, i) x_out.assign(self.x1) + class ForwardEuler(ExplicitRungeKutta): """ Implements the forward Euler timestepping scheme. diff --git a/gusto/time_discretisation/parallel_dc.py b/gusto/time_discretisation/parallel_dc.py index f27844010..5c3f2b1c9 100644 --- a/gusto/time_discretisation/parallel_dc.py +++ b/gusto/time_discretisation/parallel_dc.py @@ -19,6 +19,7 @@ __all__ = ["Parallel_RIDC", "Parallel_SDC"] + class Parallel_RIDC(RIDC): """Class for Parallel Revisionist Integral Deferred Correction schemes.""" @@ -51,6 +52,13 @@ def __init__(self, base_scheme, domain, M, K, field_name=None, limiter, options, reduced=True) self.comm = communicator + # Checks for parallel RIDC + if self.comm is None: + raise ValueError("No communicator provided. Please provide a valid MPI communicator.") + if self.comm.ensemble_comm.size != self.K - 1: + raise ValueError("Number of ranks must be equal to K-1 for Parallel RIDC.") + if self.M < self.K*(self.K+1)//2: + raise ValueError("Number of subintervals M must be greater than K*(K+1)/2 for Parallel RIDC.") def setup(self, equation, apply_bcs=True, *active_labels): """ @@ -71,9 +79,6 @@ def setup(self, equation, apply_bcs=True, *active_labels): @wrapper_apply def apply(self, x_out, x_in): - - # Time parallelised code - # Set up varibles on this rank x_out.assign(x_in) self.kval = self.comm.ensemble_comm.rank @@ -87,7 +92,7 @@ def apply(self, x_out, x_in): self.solver_rhs.solve() self.fUnodes[0].assign(self.Urhs) - # On first communicator, we do the base timestepper + # On first communicator, we do the predictor step if (self.comm.ensemble_comm.rank == 0): # Base timestepper for m in range(self.M): @@ -123,7 +128,7 @@ def apply(self, x_out, x_in): # Compute # y_m^(k+1) = y_(m-1)^(k+1) + dt*(F(y_(m)^(k+1)) - F(y_(m)^k) # + S(y_(m-1)^(k+1)) - S(y_(m-1)^k)) - # + sum(j=1,M) s_mj*(F+S)(y^k) + # + sum(j=1,M) s_mj*(F+S)(y_j^k) self.solver.solve() self.Unodes1[m+1].assign(self.U_DC) @@ -232,12 +237,18 @@ def __init__(self, base_scheme, domain, M, maxk, quad_type, node_type, qdelta_im communicator (MPI communicator, optional): communicator for parallel execution. Defaults to None. """ super().__init__(base_scheme, domain, M, maxk, quad_type, node_type, qdelta_imp, qdelta_exp, - formulation = "Z2N", field_name=field_name, - linear_solver_parameters=linear_solver_parameters, nonlinear_solver_parameters= nonlinear_solver_parameters, + formulation="Z2N", field_name=field_name, + linear_solver_parameters=linear_solver_parameters, nonlinear_solver_parameters=nonlinear_solver_parameters, final_update=final_update, limiter=limiter, options=options, initial_guess=initial_guess) self.comm = communicator + # Checks for parallel SDC + if self.comm is None: + raise ValueError("No communicator provided. Please provide a valid MPI communicator.") + if self.comm.ensemble_comm.size != self.M: + raise ValueError("Number of ranks must be equal to the number of nodes M for Parallel SDC.") + def compute_quad(self): """ Computes integration of F(y) on quadrature nodes @@ -309,7 +320,6 @@ def apply(self, x_out, x_in): for evaluate in self.evaluate_source: evaluate(self.Unodes[0], self.base.dt, x_out=self.source_Uk[0]) - # Set Q or S matrix self.Q_.assign(self.quad[self.comm.ensemble_comm.rank]) @@ -350,7 +360,7 @@ def apply(self, x_out, x_in): # Compute y_(n+1) = y_n + sum(j=1,M) q_j*F(y_j) if self.comm.ensemble_comm.rank == self.M-1: self.U_fin.assign(self.Unodes[-1]) - self.comm.bcast(self.U_fin, self.M -1) + self.comm.bcast(self.U_fin, self.M-1) self.solver_fin.solve() # Apply limiter if required if self.limiter is not None: @@ -360,6 +370,6 @@ def apply(self, x_out, x_in): # Take value at final quadrature node dtau_M if self.comm.ensemble_comm.rank == self.M-1: x_out.assign(self.Unodes[-1]) - self.comm.bcast(x_out, self.M -1) + self.comm.bcast(x_out, self.M-1) else: x_out.assign(self.Unodes[-1]) diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py index bb9b14ed2..5975b7792 100644 --- a/integration-tests/conftest.py +++ b/integration-tests/conftest.py @@ -18,15 +18,20 @@ def tracer_sphere(tmpdir, degree, small_dt, ensemble): radius = 1 if ensemble is not None: - dirname = str(tmpdir)+"comm"+str(ensemble.ensemble_comm.rank) - mesh = IcosahedralSphereMesh(radius=radius, - refinement_level=3, - degree=1, comm=ensemble.comm) + dirname = str(tmpdir) + "comm" + str(ensemble.ensemble_comm.rank) + mesh = IcosahedralSphereMesh( + radius=radius, + refinement_level=3, + degree=1, + comm=ensemble.comm + ) else: dirname = str(tmpdir) - mesh = IcosahedralSphereMesh(radius=radius, - refinement_level=3, - degree=1) + mesh = IcosahedralSphereMesh( + radius=radius, + refinement_level=3, + degree=1 + ) x = SpatialCoordinate(mesh) # Parameters chosen so that dt != 1 diff --git a/integration-tests/model/test_deferred_correction.py b/integration-tests/model/test_deferred_correction.py index 66e1eba77..7db9051f1 100644 --- a/integration-tests/model/test_deferred_correction.py +++ b/integration-tests/model/test_deferred_correction.py @@ -23,7 +23,7 @@ def run(timestepper, tmax, f_end): @pytest.mark.parametrize( "scheme", ["IMEX_SDC_Le(1,1)", "IMEX_SDC_R(2,2)", "BE_SDC_Lo(3,3)", "FE_SDC_Le(3,5)", "IMEX_RIDC_R(3)", - "BE_RIDC(4)", "FE_RIDC(4)"]) + "BE_RIDC(4)", "FE_RIDC(4)"]) def test_dc(tmpdir, scheme, tracer_setup): geometry = "sphere" setup = tracer_setup(tmpdir, geometry) diff --git a/integration-tests/model/test_parallel_dc.py b/integration-tests/model/test_parallel_dc.py index a8305f4f8..8218feaf1 100644 --- a/integration-tests/model/test_parallel_dc.py +++ b/integration-tests/model/test_parallel_dc.py @@ -16,7 +16,8 @@ def run(timestepper, tmax, f_end): print(norm(timestepper.fields("f") - f_end) / norm(f_end)) return norm(timestepper.fields("f") - f_end) / norm(f_end) -@pytest.mark.parallel(nprocs=[2,4]) + +@pytest.mark.parallel(nprocs=[2, 4]) @pytest.mark.parametrize( "scheme", ["IMEX_SDC_R(3,3)", "IMEX_RIDC_R(3)"]) def test_parallel_dc(tmpdir, scheme, tracer_setup): @@ -45,7 +46,7 @@ def test_parallel_dc(tmpdir, scheme, tracer_setup): eqn.label_terms(lambda t: t.has_label(transport), explicit) base_scheme = IMEX_Euler(domain) scheme = Parallel_SDC(base_scheme, domain, M, k, quad_type, node_type, qdelta_imp, - qdelta_exp, final_update=False, initial_guess="copy", communicator=ensemble) + qdelta_exp, final_update=False, initial_guess="copy", communicator=ensemble) elif scheme == "IMEX_RIDC_R(3)": M = k*(k+1)//2 + 1 eqn = ContinuityEquation(domain, V, "f") @@ -66,5 +67,4 @@ def test_parallel_dc(tmpdir, scheme, tracer_setup): # Initial conditions timestepper.fields("f").interpolate(setup.f_init) timestepper.fields("u").project(setup.uexpr) - #run(timestepper, setup.tmax, setup.f_end) - #parallel_assert(run(timestepper, setup.tmax, setup.f_end) < setup.tol, "Error too large") + parallel_assert(run(timestepper, setup.tmax, setup.f_end) < setup.tol, "Error too large") From 0ea9ed40ee0e6a29b55acd97c8b47539242952c7 Mon Sep 17 00:00:00 2001 From: atb1995 Date: Tue, 13 May 2025 12:44:59 +0100 Subject: [PATCH 10/43] Fixing error check --- gusto/time_discretisation/parallel_dc.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gusto/time_discretisation/parallel_dc.py b/gusto/time_discretisation/parallel_dc.py index 5c3f2b1c9..58696806b 100644 --- a/gusto/time_discretisation/parallel_dc.py +++ b/gusto/time_discretisation/parallel_dc.py @@ -55,8 +55,8 @@ def __init__(self, base_scheme, domain, M, K, field_name=None, # Checks for parallel RIDC if self.comm is None: raise ValueError("No communicator provided. Please provide a valid MPI communicator.") - if self.comm.ensemble_comm.size != self.K - 1: - raise ValueError("Number of ranks must be equal to K-1 for Parallel RIDC.") + if self.comm.ensemble_comm.size != self.K + 1: + raise ValueError("Number of ranks must be equal to K+1 for Parallel RIDC.") if self.M < self.K*(self.K+1)//2: raise ValueError("Number of subintervals M must be greater than K*(K+1)/2 for Parallel RIDC.") From 52166b601b97ffa1db459133ff3f0037e740a083 Mon Sep 17 00:00:00 2001 From: atb1995 Date: Tue, 13 May 2025 18:07:28 +0100 Subject: [PATCH 11/43] Fixed IO in parallel test --- gusto/core/io.py | 2 +- integration-tests/conftest.py | 53 ++++++---------- integration-tests/model/test_parallel_dc.py | 67 +++++++++++++++------ 3 files changed, 67 insertions(+), 55 deletions(-) diff --git a/gusto/core/io.py b/gusto/core/io.py index 963f52bdc..54c3573e4 100644 --- a/gusto/core/io.py +++ b/gusto/core/io.py @@ -55,7 +55,7 @@ def pick_up_mesh(output, mesh_name, comm=COMM_WORLD): else: dumpdir = path.join("results", output.dirname) chkfile = path.join(dumpdir, "chkpt.h5") - with CheckpointFile(chkfile, 'r') as chk: + with CheckpointFile(chkfile, 'r', comm) as chk: mesh = chk.load_mesh(mesh_name) if dumpdir: diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py index 5975b7792..7fce74377 100644 --- a/integration-tests/conftest.py +++ b/integration-tests/conftest.py @@ -15,23 +15,15 @@ TracerSetup.__new__.__defaults__ = (None,)*len(opts) -def tracer_sphere(tmpdir, degree, small_dt, ensemble): +def tracer_sphere(tmpdir, degree, small_dt): radius = 1 - if ensemble is not None: - dirname = str(tmpdir) + "comm" + str(ensemble.ensemble_comm.rank) - mesh = IcosahedralSphereMesh( - radius=radius, - refinement_level=3, - degree=1, - comm=ensemble.comm - ) - else: - dirname = str(tmpdir) - mesh = IcosahedralSphereMesh( - radius=radius, - refinement_level=3, - degree=1 - ) + + dirname = str(tmpdir) + mesh = IcosahedralSphereMesh( + radius=radius, + refinement_level=3, + degree=1 + ) x = SpatialCoordinate(mesh) # Parameters chosen so that dt != 1 @@ -59,14 +51,11 @@ def tracer_sphere(tmpdir, degree, small_dt, ensemble): uexpr, umax, radius, tol) -def tracer_slice(tmpdir, degree, small_dt, ensemble): +def tracer_slice(tmpdir, degree, small_dt): n = 30 if degree == 0 else 15 - if ensemble is not None: - dirname = str(tmpdir)+"comm"+str(ensemble.ensemble_comm.rank) - m = PeriodicIntervalMesh(n, 1., comm=ensemble.comm) - else: - dirname = str(tmpdir) - m = PeriodicIntervalMesh(n, 1.) + + dirname = str(tmpdir) + m = PeriodicIntervalMesh(n, 1.) mesh = ExtrudedMesh(m, layers=n, layer_height=1./n) # Parameters chosen so that dt != 1 and u != 1 @@ -100,18 +89,14 @@ def tracer_slice(tmpdir, degree, small_dt, ensemble): return TracerSetup(domain, tmax, io, f_init, f_end, degree, uexpr, tol=tol) -def tracer_blob_slice(tmpdir, degree, small_dt, ensemble): +def tracer_blob_slice(tmpdir, degree, small_dt): if small_dt: dt = 0.002 else: dt = 0.01 L = 10. - if ensemble is not None: - dirname = str(tmpdir)+"comm"+str(ensemble.ensemble_comm.rank) - m = PeriodicIntervalMesh(10, L, comm=ensemble.comm) - else: - dirname = str(tmpdir) - m = PeriodicIntervalMesh(10, L) + dirname = str(tmpdir) + m = PeriodicIntervalMesh(10, L) mesh = ExtrudedMesh(m, layers=10, layer_height=1.) output = OutputParameters(dirname=dirname, dumpfreq=25) @@ -128,14 +113,14 @@ def tracer_blob_slice(tmpdir, degree, small_dt, ensemble): @pytest.fixture() def tracer_setup(): - def _tracer_setup(tmpdir, geometry, blob=False, degree=1, small_dt=False, ensemble=None): + def _tracer_setup(tmpdir, geometry, blob=False, degree=1, small_dt=False): if geometry == "sphere": assert not blob - return tracer_sphere(tmpdir, degree, small_dt, ensemble=ensemble) + return tracer_sphere(tmpdir, degree, small_dt) elif geometry == "slice": if blob: - return tracer_blob_slice(tmpdir, degree, small_dt, ensemble=ensemble) + return tracer_blob_slice(tmpdir, degree, small_dt) else: - return tracer_slice(tmpdir, degree, small_dt, ensemble=ensemble) + return tracer_slice(tmpdir, degree, small_dt) return _tracer_setup diff --git a/integration-tests/model/test_parallel_dc.py b/integration-tests/model/test_parallel_dc.py index 8218feaf1..29e61eea9 100644 --- a/integration-tests/model/test_parallel_dc.py +++ b/integration-tests/model/test_parallel_dc.py @@ -1,37 +1,67 @@ """ -This runs a simple transport test on the sphere using the DC time discretisations to +This runs a simple transport test on the sphere using the parallel DC time discretisations to test whether the errors are within tolerance. The test is run for the following schemes: - IMEX_SDC_R(2,2) - IMEX SDC with 2 qaudrature nodes of Radau type (3rd order scheme) using - IMEX_RIDC_R(3) - IMEX RIDC with 4 quadrature nodes of equidistant type, reduced stencils (3rd order scheme). """ -from firedrake import norm, Ensemble, COMM_WORLD +from firedrake import (norm, Ensemble, COMM_WORLD, SpatialCoordinate, + as_vector, pi, exp, IcosahedralSphereMesh) + from gusto import * import pytest from pytest_mpi.parallel_assert import parallel_assert - def run(timestepper, tmax, f_end): timestepper.run(0, tmax) print(norm(timestepper.fields("f") - f_end) / norm(f_end)) return norm(timestepper.fields("f") - f_end) / norm(f_end) -@pytest.mark.parallel(nprocs=[2, 4]) +@pytest.mark.parallel(nprocs=[3,6]) @pytest.mark.parametrize( "scheme", ["IMEX_SDC_R(3,3)", "IMEX_RIDC_R(3)"]) -def test_parallel_dc(tmpdir, scheme, tracer_setup): +def test_parallel_dc(tmpdir, scheme): if scheme == "IMEX_SDC_R(3,3)": - M = 2 + M = 3 k = M ensemble = Ensemble(COMM_WORLD, COMM_WORLD.size//(M)) elif scheme == "IMEX_RIDC_R(3)": - k = 1 + k = 2 ensemble = Ensemble(COMM_WORLD, COMM_WORLD.size//(k+1)) - geometry = "sphere" - setup = tracer_setup(tmpdir, geometry, ensemble=ensemble) - domain = setup.domain + + # Get the tracer setup + radius = 1 + dirname = str(tmpdir) + mesh = IcosahedralSphereMesh( + radius=radius, + refinement_level=3, + degree=1, + comm=ensemble.comm + ) + x = SpatialCoordinate(mesh) + + # Parameters chosen so that dt != 1 + # Gaussian is translated from (lon=pi/2, lat=0) to (lon=0, lat=0) + # to demonstrate that transport is working correctly + + dt = pi/3. * 0.02 + + output = OutputParameters(dirname=dirname, dump_vtus=False, dump_nc=True, dumpfreq=15) + domain = Domain(mesh, dt, family="BDM", degree=1) + io = IO(domain, output) + + umax = 1.0 + uexpr = as_vector([- umax * x[1] / radius, umax * x[0] / radius, 0.0]) + + tmax = pi/2 + f_init = exp(-x[2]**2 - x[0]**2) + f_end = exp(-x[2]**2 - x[1]**2) + + tol = 0.05 + + domain = domain V = domain.spaces("DG") if scheme == "IMEX_SDC_R(3,3)": @@ -40,31 +70,28 @@ def test_parallel_dc(tmpdir, scheme, tracer_setup): qdelta_imp = "MIN-SR-FLEX" qdelta_exp = "MIN-SR-NS" eqn = ContinuityEquation(domain, V, "f") - # Split continuity term eqn = split_continuity_form(eqn) eqn.label_terms(lambda t: not any(t.has_label(time_derivative, transport)), implicit) eqn.label_terms(lambda t: t.has_label(transport), explicit) base_scheme = IMEX_Euler(domain) - scheme = Parallel_SDC(base_scheme, domain, M, k, quad_type, node_type, qdelta_imp, - qdelta_exp, final_update=False, initial_guess="copy", communicator=ensemble) + time_scheme = Parallel_SDC(base_scheme, domain, M, k, quad_type, node_type, qdelta_imp, + qdelta_exp, final_update=False, initial_guess="copy", communicator=ensemble) elif scheme == "IMEX_RIDC_R(3)": M = k*(k+1)//2 + 1 eqn = ContinuityEquation(domain, V, "f") - # Split continuity term eqn = split_continuity_form(eqn) eqn.label_terms(lambda t: not any(t.has_label(time_derivative, transport)), implicit) eqn.label_terms(lambda t: t.has_label(transport), explicit) base_scheme = IMEX_Euler(domain) - scheme = Parallel_RIDC(base_scheme, domain, M, k, communicator=ensemble) + time_scheme = Parallel_RIDC(base_scheme, domain, M, k, communicator=ensemble) transport_method = DGUpwind(eqn, 'f') time_varying_velocity = False timestepper = PrescribedTransport( - eqn, scheme, setup.io, time_varying_velocity, transport_method + eqn, time_scheme, io, time_varying_velocity, transport_method ) - # Initial conditions - timestepper.fields("f").interpolate(setup.f_init) - timestepper.fields("u").project(setup.uexpr) - parallel_assert(run(timestepper, setup.tmax, setup.f_end) < setup.tol, "Error too large") + timestepper.fields("f").interpolate(f_init) + timestepper.fields("u").project(uexpr) + parallel_assert(run(timestepper, tmax, f_end) < tol, "Error too large") From a1abc48b33e01ce0de1257af9a6de2838b6ae740 Mon Sep 17 00:00:00 2001 From: atb1995 Date: Wed, 14 May 2025 09:41:52 +0100 Subject: [PATCH 12/43] Fix parallel test --- integration-tests/model/test_parallel_dc.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/integration-tests/model/test_parallel_dc.py b/integration-tests/model/test_parallel_dc.py index 29e61eea9..0195ee4bc 100644 --- a/integration-tests/model/test_parallel_dc.py +++ b/integration-tests/model/test_parallel_dc.py @@ -12,23 +12,24 @@ import pytest from pytest_mpi.parallel_assert import parallel_assert + def run(timestepper, tmax, f_end): timestepper.run(0, tmax) print(norm(timestepper.fields("f") - f_end) / norm(f_end)) return norm(timestepper.fields("f") - f_end) / norm(f_end) -@pytest.mark.parallel(nprocs=[3,6]) +@pytest.mark.parallel(nprocs=[2, 4]) @pytest.mark.parametrize( - "scheme", ["IMEX_SDC_R(3,3)", "IMEX_RIDC_R(3)"]) + "scheme", ["IMEX_SDC_R(2,2)", "IMEX_RIDC_R(2)"]) def test_parallel_dc(tmpdir, scheme): if scheme == "IMEX_SDC_R(3,3)": - M = 3 + M = 2 k = M ensemble = Ensemble(COMM_WORLD, COMM_WORLD.size//(M)) elif scheme == "IMEX_RIDC_R(3)": - k = 2 + k = 1 ensemble = Ensemble(COMM_WORLD, COMM_WORLD.size//(k+1)) # Get the tracer setup @@ -75,7 +76,7 @@ def test_parallel_dc(tmpdir, scheme): eqn.label_terms(lambda t: t.has_label(transport), explicit) base_scheme = IMEX_Euler(domain) time_scheme = Parallel_SDC(base_scheme, domain, M, k, quad_type, node_type, qdelta_imp, - qdelta_exp, final_update=False, initial_guess="copy", communicator=ensemble) + qdelta_exp, final_update=False, initial_guess="copy", communicator=ensemble) elif scheme == "IMEX_RIDC_R(3)": M = k*(k+1)//2 + 1 eqn = ContinuityEquation(domain, V, "f") From c2803baa98cfd9d4443af76e5a26fb96bba7551a Mon Sep 17 00:00:00 2001 From: atb1995 Date: Wed, 14 May 2025 09:53:09 +0100 Subject: [PATCH 13/43] Docs fix --- .github/workflows/docs.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index e77dcca3f..98c95aa7b 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -22,7 +22,7 @@ jobs: runs-on: ubuntu-latest # The docker container to use. container: - image: firedrakeproject/firedrake-docdeps:latest + image: firedrakeproject/firedrake-vanilla-default:latest steps: - uses: actions/checkout@v4 From eff003b57c94596607de53191ca673ffd8ab9b5b Mon Sep 17 00:00:00 2001 From: atb1995 Date: Wed, 14 May 2025 10:17:20 +0100 Subject: [PATCH 14/43] Fix test again.. --- integration-tests/model/test_parallel_dc.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/integration-tests/model/test_parallel_dc.py b/integration-tests/model/test_parallel_dc.py index 0195ee4bc..5b06cb81e 100644 --- a/integration-tests/model/test_parallel_dc.py +++ b/integration-tests/model/test_parallel_dc.py @@ -24,11 +24,11 @@ def run(timestepper, tmax, f_end): "scheme", ["IMEX_SDC_R(2,2)", "IMEX_RIDC_R(2)"]) def test_parallel_dc(tmpdir, scheme): - if scheme == "IMEX_SDC_R(3,3)": + if scheme == "IMEX_SDC_R(2,2)": M = 2 k = M ensemble = Ensemble(COMM_WORLD, COMM_WORLD.size//(M)) - elif scheme == "IMEX_RIDC_R(3)": + elif scheme == "IMEX_RIDC_R(2)": k = 1 ensemble = Ensemble(COMM_WORLD, COMM_WORLD.size//(k+1)) @@ -65,7 +65,7 @@ def test_parallel_dc(tmpdir, scheme): domain = domain V = domain.spaces("DG") - if scheme == "IMEX_SDC_R(3,3)": + if scheme == "IMEX_SDC_R(2,2)": quad_type = "RADAU-RIGHT" node_type = "LEGENDRE" qdelta_imp = "MIN-SR-FLEX" @@ -77,7 +77,7 @@ def test_parallel_dc(tmpdir, scheme): base_scheme = IMEX_Euler(domain) time_scheme = Parallel_SDC(base_scheme, domain, M, k, quad_type, node_type, qdelta_imp, qdelta_exp, final_update=False, initial_guess="copy", communicator=ensemble) - elif scheme == "IMEX_RIDC_R(3)": + elif scheme == "IMEX_RIDC_R(2)": M = k*(k+1)//2 + 1 eqn = ContinuityEquation(domain, V, "f") eqn = split_continuity_form(eqn) From 2349b1dca1a27e073b0302e7f520a50d73769d95 Mon Sep 17 00:00:00 2001 From: atb1995 Date: Wed, 14 May 2025 13:01:44 +0100 Subject: [PATCH 15/43] Updated tests --- .github/workflows/build.yml | 2 +- gusto/time_discretisation/parallel_dc.py | 7 +++++-- integration-tests/model/test_parallel_dc.py | 18 +++++++++--------- 3 files changed, 15 insertions(+), 12 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index e7d5a153b..f055b1176 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -33,7 +33,7 @@ jobs: GUSTO_PARALLEL_LOG: CONSOLE PYOP2_CFLAGS: -O0 # Make sure that tests with >4 processes are not silently skipped - PYTEST_MPI_MAX_NPROCS: 4 + PYTEST_MPI_MAX_NPROCS: 6 EXTRA_PYTEST_ARGS: --durations=100 --timeout=3600 --timeout-method=thread -o faulthandler_timeout=3660 --show-capture=no --verbose gusto-repo/unit-tests gusto-repo/integration-tests gusto-repo/examples steps: - name: Fix HOME diff --git a/gusto/time_discretisation/parallel_dc.py b/gusto/time_discretisation/parallel_dc.py index 58696806b..0581d7c6b 100644 --- a/gusto/time_discretisation/parallel_dc.py +++ b/gusto/time_discretisation/parallel_dc.py @@ -62,7 +62,7 @@ def __init__(self, base_scheme, domain, M, K, field_name=None, def setup(self, equation, apply_bcs=True, *active_labels): """ - Set up the SDC time discretisation based on the equation.n + Set up the RIDC time discretisation based on the equation. Args: equation (:class:`PrognosticEquation`): the model's equation. @@ -372,4 +372,7 @@ def apply(self, x_out, x_in): x_out.assign(self.Unodes[-1]) self.comm.bcast(x_out, self.M-1) else: - x_out.assign(self.Unodes[-1]) + # Take value at final quadrature node dtau_M + if self.comm.ensemble_comm.rank == self.M-1: + x_out.assign(self.Unodes[-1]) + self.comm.bcast(x_out, self.M-1) diff --git a/integration-tests/model/test_parallel_dc.py b/integration-tests/model/test_parallel_dc.py index 5b06cb81e..70bf448c0 100644 --- a/integration-tests/model/test_parallel_dc.py +++ b/integration-tests/model/test_parallel_dc.py @@ -19,17 +19,17 @@ def run(timestepper, tmax, f_end): return norm(timestepper.fields("f") - f_end) / norm(f_end) -@pytest.mark.parallel(nprocs=[2, 4]) +@pytest.mark.parallel(nprocs=[3, 6]) @pytest.mark.parametrize( - "scheme", ["IMEX_SDC_R(2,2)", "IMEX_RIDC_R(2)"]) + "scheme", ["IMEX_SDC_R(3,3)", "IMEX_RIDC_R(3)"]) def test_parallel_dc(tmpdir, scheme): - if scheme == "IMEX_SDC_R(2,2)": - M = 2 + if scheme == "IMEX_SDC_R(3,3)": + M = 3 k = M ensemble = Ensemble(COMM_WORLD, COMM_WORLD.size//(M)) - elif scheme == "IMEX_RIDC_R(2)": - k = 1 + elif scheme == "IMEX_RIDC_R(3)": + k = 2 ensemble = Ensemble(COMM_WORLD, COMM_WORLD.size//(k+1)) # Get the tracer setup @@ -65,7 +65,7 @@ def test_parallel_dc(tmpdir, scheme): domain = domain V = domain.spaces("DG") - if scheme == "IMEX_SDC_R(2,2)": + if scheme == "IMEX_SDC_R(3,3)": quad_type = "RADAU-RIGHT" node_type = "LEGENDRE" qdelta_imp = "MIN-SR-FLEX" @@ -77,8 +77,8 @@ def test_parallel_dc(tmpdir, scheme): base_scheme = IMEX_Euler(domain) time_scheme = Parallel_SDC(base_scheme, domain, M, k, quad_type, node_type, qdelta_imp, qdelta_exp, final_update=False, initial_guess="copy", communicator=ensemble) - elif scheme == "IMEX_RIDC_R(2)": - M = k*(k+1)//2 + 1 + elif scheme == "IMEX_RIDC_R(3)": + M = k*(k+1)//2 + 4 eqn = ContinuityEquation(domain, V, "f") eqn = split_continuity_form(eqn) eqn.label_terms(lambda t: not any(t.has_label(time_derivative, transport)), implicit) From be9a6e17b2963035d2658b83875bb649bdfdd28f Mon Sep 17 00:00:00 2001 From: atb1995 Date: Fri, 16 May 2025 13:07:29 +0100 Subject: [PATCH 16/43] Add call to 6 processors --- .github/workflows/build.yml | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index f055b1176..d4cd1c8d1 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -118,6 +118,13 @@ jobs: firedrake-run-split-tests 4 3 "$EXTRA_PYTEST_ARGS" "--log-file=gusto4_{#}.log" timeout-minutes: 10 + - name: Run tests (nprocs = 6) + if: success() || steps.install-two.conclusion == 'success' + run: | + . venv-gusto/bin/activate + firedrake-run-split-tests 6 2 "$EXTRA_PYTEST_ARGS" "--log-file=gusto6_{#}.log" + timeout-minutes: 10 + - name: Upload pytest log files uses: actions/upload-artifact@v4 if: success() || steps.install-two.conclusion == 'success' From 0222e55a2f143d1f016e73dfab032d2f0fb3886a Mon Sep 17 00:00:00 2001 From: atb1995 Date: Fri, 16 May 2025 15:07:36 +0100 Subject: [PATCH 17/43] Trying longer time.. --- .github/workflows/build.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index d4cd1c8d1..0c7426dea 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -123,7 +123,7 @@ jobs: run: | . venv-gusto/bin/activate firedrake-run-split-tests 6 2 "$EXTRA_PYTEST_ARGS" "--log-file=gusto6_{#}.log" - timeout-minutes: 10 + timeout-minutes: 30 - name: Upload pytest log files uses: actions/upload-artifact@v4 From 801f7830b9a6f57100b8bf84edc1c6da4a3ba727 Mon Sep 17 00:00:00 2001 From: atb1995 Date: Mon, 21 Jul 2025 11:10:23 +0100 Subject: [PATCH 18/43] Update using qmat --- .../deferred_correction.py | 20 ++++++++++--------- gusto/time_discretisation/parallel_dc.py | 19 ++++++++++++++---- 2 files changed, 26 insertions(+), 13 deletions(-) diff --git a/gusto/time_discretisation/deferred_correction.py b/gusto/time_discretisation/deferred_correction.py index 9272d15ef..5c88e5ed0 100644 --- a/gusto/time_discretisation/deferred_correction.py +++ b/gusto/time_discretisation/deferred_correction.py @@ -579,20 +579,22 @@ def __init__(self, base_scheme, domain, M, K, field_name=None, self.reduced = reduced self.dt = Constant(float(self.dt_coarse)/(self.M)) - # Use equidistant nodes - self.nodes = np.arange(0, (M+1)*float(self.dt), float(self.dt)) - if reduced: self.Q = [] for l in range(1, self.K+1): - integration_matrix = self.lagrange_integration_matrix(l+1) - integration_matrix = 0.5 * (l) * float(self.dt) * integration_matrix - self.Q.append(integration_matrix) + _, _, Q = genQCoeffs("Collocation", nNodes=l+1, + nodeType="EQUID", + quadType="LOBATTO", + form="N2N") + Q = l* float(self.dt) * Q + self.Q.append(Q) else: # Get integration weights - integration_matrix = self.lagrange_integration_matrix(self.K+1) - # Rescale integration matrix to be over [0, self.dt_coarse] rather than [-1, 1] - self.Q = 0.5 * (self.K) * float(self.dt) * integration_matrix + _, _, self.Q = genQCoeffs("Collocation", nNodes=K+1, + nodeType="EQUID", + quadType="LOBATTO", + form="N2N") + self.Q = self.K*float(self.dt)*self.Q # Set default linear and nonlinear solver options if none passed in if linear_solver_parameters is None: diff --git a/gusto/time_discretisation/parallel_dc.py b/gusto/time_discretisation/parallel_dc.py index 0581d7c6b..12ef5ecf7 100644 --- a/gusto/time_discretisation/parallel_dc.py +++ b/gusto/time_discretisation/parallel_dc.py @@ -103,10 +103,12 @@ def apply(self, x_out, x_in): # Send base guess to k+1 correction self.comm.send(self.Unodes[m+1], dest=self.kval+1, tag=100+m+1) + self.comm.send(self.source_Uk[m+1], dest=self.kval+1, tag=200+m+1) else: for m in range(1, self.kval + 1): # Recieve and evaluate the stencil of guesses we need to correct self.comm.recv(self.Unodes[m], source=self.kval-1, tag=100+m) + self.comm.recv(self.source_Uk[m], source=self.kval-1, tag=200+m) self.Uin.assign(self.Unodes[m]) for evaluate in self.evaluate_source: evaluate(self.Uin, self.base.dt, x_out=self.source_in) @@ -121,8 +123,11 @@ def apply(self, x_out, x_in): self.Ukp1_m.assign(self.Unodes1[m]) self.Uk_mp1.assign(self.Unodes[m+1]) self.Uk_m.assign(self.Unodes[m]) - self.source_Ukp1_m.assign(self.source_Ukp1[m]) - self.source_Uk_m.assign(self.source_Uk[m]) + for evaluate in self.evaluate_source: + evaluate(self.Unodes[m], self.base.dt, x_out=self.source_Uk_m) + evaluate(self.Unodes1[m], self.base.dt, x_out=self.source_Ukp1_m) + # self.source_Ukp1_m.assign(self.source_Ukp1[m]) + # self.source_Uk_m.assign(self.source_Uk[m]) self.U_DC.assign(self.Unodes[m+1]) # Compute @@ -142,10 +147,12 @@ def apply(self, x_out, x_in): # Send our updated value to next communicator if self.kval < self.K: self.comm.send(self.Unodes1[m+1], dest=self.kval+1, tag=100+m+1) + self.comm.send(self.source_Ukp1[m+1], dest=self.kval+1, tag=200+m+1) for m in range(self.kval, self.M): # Recieve the guess we need to correct and evaluate the rhs self.comm.recv(self.Unodes[m+1], source=self.kval-1, tag=100+m+1) + self.comm.recv(self.source_Uk[m+1], source=self.kval-1, tag=200+m+1) self.Uin.assign(self.Unodes[m+1]) for evaluate in self.evaluate_source: evaluate(self.Uin, self.base.dt, x_out=self.source_in) @@ -160,8 +167,11 @@ def apply(self, x_out, x_in): self.Ukp1_m.assign(self.Unodes1[m]) self.Uk_mp1.assign(self.Unodes[m+1]) self.Uk_m.assign(self.Unodes[m]) - self.source_Ukp1_m.assign(self.source_Ukp1[m]) - self.source_Uk_m.assign(self.source_Uk[m]) + # self.source_Ukp1_m.assign(self.source_Ukp1[m]) + # self.source_Uk_m.assign(self.source_Uk[m]) + for evaluate in self.evaluate_source: + evaluate(self.Unodes[m], self.base.dt, x_out=self.source_Uk_m) + evaluate(self.Unodes1[m], self.base.dt, x_out=self.source_Ukp1_m) self.U_DC.assign(self.Unodes[m+1]) # y_m^(k+1) = y_(m-1)^(k+1) + dt*(F(y_(m)^(k+1)) - F(y_(m)^k) @@ -181,6 +191,7 @@ def apply(self, x_out, x_in): # Send our updated value to next communicator if self.kval < self.K: self.comm.send(self.Unodes1[m+1], dest=self.kval+1, tag=100+m+1) + self.comm.send(self.source_Ukp1[m+1], dest=self.kval+1, tag=200+m+1) if (self.kval == self.K): # Broadcast the final result to all other ranks From 90d175a848a02d8f640899b14d1ae6693405c2db Mon Sep 17 00:00:00 2001 From: atb1995 Date: Mon, 21 Jul 2025 11:13:22 +0100 Subject: [PATCH 19/43] Tidy up --- .../deferred_correction.py | 78 ------------------- gusto/time_discretisation/parallel_dc.py | 14 +--- 2 files changed, 4 insertions(+), 88 deletions(-) diff --git a/gusto/time_discretisation/deferred_correction.py b/gusto/time_discretisation/deferred_correction.py index 5c88e5ed0..32ca42cd7 100644 --- a/gusto/time_discretisation/deferred_correction.py +++ b/gusto/time_discretisation/deferred_correction.py @@ -673,84 +673,6 @@ def setup(self, equation, apply_bcs=True, *active_labels): def nlevels(self): return 1 - def equidistant_nodes(self, M): - """ - Returns a grid of M equispaced nodes from -1 to 1 - """ - grid = np.linspace(-1., 1., M) - return grid - - def lagrange_polynomial(self, index, nodes): - """ - Returns the coefficients of the Lagrange polynomial l_m with m=index - """ - - M = len(nodes) - - # c is the denominator - c = 1. - for k in range(M): - if k != index: - c *= (nodes[index] - nodes[k]) - - coeffs = np.zeros(M) - coeffs[0] = 1. - m = 0 - - for k in range(M): - if k != index: - m += 1 - d1 = np.zeros(M) - d2 = np.zeros(M) - - d1 = (-1.)*nodes[k] * coeffs - d2[1:m+1] = coeffs[0:m] - - coeffs = d1+d2 - return coeffs / c - - def integrate_polynomial(self, p): - """ - Given a list of coefficients of a polynomial p, - this returns those of the integral of p - """ - integral_coeffs = np.zeros(len(p)+1) - - for n, pn in enumerate(p): - integral_coeffs[n+1] = 1/(n+1) * pn - - return integral_coeffs - - def evaluate(self, p, a, b): - """ - Given a list of coefficients of a polynomial p, this returns the value of p(b)-p(a) - """ - value = 0. - for n, pn in enumerate(p): - value += pn * (b**n - a**n) - - return value - - def lagrange_integration_matrix(self, M): - """ - Returns the integration matrix for the Lagrange polynomial of order M - """ - - # Set up equidistant nodes and initialise matrix to zero - nodes = self.equidistant_nodes(M) - L = len(nodes) - int_matrix = np.zeros((L, L)) - - # Fill in matrix values - for index in range(L): - coeff_p = self.lagrange_polynomial(index, nodes) - int_coeff = self.integrate_polynomial(coeff_p) - - for n in range(L-1): - int_matrix[n+1, index] = self.evaluate(int_coeff, nodes[n], nodes[n+1]) - - return int_matrix - def compute_quad(self, Q, fUnodes, m): """ Computes integration of F(y) on quadrature nodes diff --git a/gusto/time_discretisation/parallel_dc.py b/gusto/time_discretisation/parallel_dc.py index 12ef5ecf7..538302b30 100644 --- a/gusto/time_discretisation/parallel_dc.py +++ b/gusto/time_discretisation/parallel_dc.py @@ -123,11 +123,8 @@ def apply(self, x_out, x_in): self.Ukp1_m.assign(self.Unodes1[m]) self.Uk_mp1.assign(self.Unodes[m+1]) self.Uk_m.assign(self.Unodes[m]) - for evaluate in self.evaluate_source: - evaluate(self.Unodes[m], self.base.dt, x_out=self.source_Uk_m) - evaluate(self.Unodes1[m], self.base.dt, x_out=self.source_Ukp1_m) - # self.source_Ukp1_m.assign(self.source_Ukp1[m]) - # self.source_Uk_m.assign(self.source_Uk[m]) + self.source_Ukp1_m.assign(self.source_Ukp1[m]) + self.source_Uk_m.assign(self.source_Uk[m]) self.U_DC.assign(self.Unodes[m+1]) # Compute @@ -167,11 +164,8 @@ def apply(self, x_out, x_in): self.Ukp1_m.assign(self.Unodes1[m]) self.Uk_mp1.assign(self.Unodes[m+1]) self.Uk_m.assign(self.Unodes[m]) - # self.source_Ukp1_m.assign(self.source_Ukp1[m]) - # self.source_Uk_m.assign(self.source_Uk[m]) - for evaluate in self.evaluate_source: - evaluate(self.Unodes[m], self.base.dt, x_out=self.source_Uk_m) - evaluate(self.Unodes1[m], self.base.dt, x_out=self.source_Ukp1_m) + self.source_Ukp1_m.assign(self.source_Ukp1[m]) + self.source_Uk_m.assign(self.source_Uk[m]) self.U_DC.assign(self.Unodes[m+1]) # y_m^(k+1) = y_(m-1)^(k+1) + dt*(F(y_(m)^(k+1)) - F(y_(m)^k) From 97e81c38723589cfc246b751cced19a628479380 Mon Sep 17 00:00:00 2001 From: atb1995 Date: Mon, 21 Jul 2025 11:17:09 +0100 Subject: [PATCH 20/43] lint fix --- .../deferred_correction.py | 28 +++++++++++-------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/gusto/time_discretisation/deferred_correction.py b/gusto/time_discretisation/deferred_correction.py index 32ca42cd7..1f515b6f3 100644 --- a/gusto/time_discretisation/deferred_correction.py +++ b/gusto/time_discretisation/deferred_correction.py @@ -581,20 +581,26 @@ def __init__(self, base_scheme, domain, M, K, field_name=None, if reduced: self.Q = [] - for l in range(1, self.K+1): - _, _, Q = genQCoeffs("Collocation", nNodes=l+1, - nodeType="EQUID", - quadType="LOBATTO", - form="N2N") - Q = l* float(self.dt) * Q + for l in range(1, self.K + 1): + _, _, Q = genQCoeffs( + "Collocation", + nNodes=l + 1, + nodeType="EQUID", + quadType="LOBATTO", + form="N2N" + ) + Q = l * float(self.dt) * Q self.Q.append(Q) else: # Get integration weights - _, _, self.Q = genQCoeffs("Collocation", nNodes=K+1, - nodeType="EQUID", - quadType="LOBATTO", - form="N2N") - self.Q = self.K*float(self.dt)*self.Q + _, _, self.Q = genQCoeffs( + "Collocation", + nNodes=self.K + 1, + nodeType="EQUID", + quadType="LOBATTO", + form="N2N" + ) + self.Q = self.K * float(self.dt) * self.Q # Set default linear and nonlinear solver options if none passed in if linear_solver_parameters is None: From ebfc09f253901714aae04d8d9e715213a2ba4afe Mon Sep 17 00:00:00 2001 From: atb1995 Date: Mon, 21 Jul 2025 13:03:47 +0100 Subject: [PATCH 21/43] Final update --- integration-tests/model/test_parallel_dc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integration-tests/model/test_parallel_dc.py b/integration-tests/model/test_parallel_dc.py index 70bf448c0..fd99bd5f1 100644 --- a/integration-tests/model/test_parallel_dc.py +++ b/integration-tests/model/test_parallel_dc.py @@ -76,7 +76,7 @@ def test_parallel_dc(tmpdir, scheme): eqn.label_terms(lambda t: t.has_label(transport), explicit) base_scheme = IMEX_Euler(domain) time_scheme = Parallel_SDC(base_scheme, domain, M, k, quad_type, node_type, qdelta_imp, - qdelta_exp, final_update=False, initial_guess="copy", communicator=ensemble) + qdelta_exp, final_update=True, initial_guess="copy", communicator=ensemble) elif scheme == "IMEX_RIDC_R(3)": M = k*(k+1)//2 + 4 eqn = ContinuityEquation(domain, V, "f") From c5bb96bc8297a380d1c459533ba6470606e14192 Mon Sep 17 00:00:00 2001 From: atb1995 Date: Mon, 21 Jul 2025 14:57:36 +0100 Subject: [PATCH 22/43] small changes to test --- integration-tests/model/test_parallel_dc.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/integration-tests/model/test_parallel_dc.py b/integration-tests/model/test_parallel_dc.py index fd99bd5f1..f9e83860d 100644 --- a/integration-tests/model/test_parallel_dc.py +++ b/integration-tests/model/test_parallel_dc.py @@ -1,8 +1,8 @@ """ This runs a simple transport test on the sphere using the parallel DC time discretisations to test whether the errors are within tolerance. The test is run for the following schemes: -- IMEX_SDC_R(2,2) - IMEX SDC with 2 qaudrature nodes of Radau type (3rd order scheme) using -- IMEX_RIDC_R(3) - IMEX RIDC with 4 quadrature nodes of equidistant type, reduced stencils (3rd order scheme). +- IMEX_SDC(2,2) - IMEX SDC with 2 qaudrature nodes of Radau type (3rd order scheme) using +- IMEX_RIDC(3) - IMEX RIDC with 4 quadrature nodes of equidistant type, reduced stencils (3rd order scheme). """ from firedrake import (norm, Ensemble, COMM_WORLD, SpatialCoordinate, @@ -19,16 +19,16 @@ def run(timestepper, tmax, f_end): return norm(timestepper.fields("f") - f_end) / norm(f_end) -@pytest.mark.parallel(nprocs=[3, 6]) +@pytest.mark.parallel(nprocs=[6]) @pytest.mark.parametrize( - "scheme", ["IMEX_SDC_R(3,3)", "IMEX_RIDC_R(3)"]) + "scheme", ["IMEX_SDC(3,3)", "IMEX_RIDC(3)"]) def test_parallel_dc(tmpdir, scheme): - if scheme == "IMEX_SDC_R(3,3)": + if scheme == "IMEX_SDC(3,3)": M = 3 k = M ensemble = Ensemble(COMM_WORLD, COMM_WORLD.size//(M)) - elif scheme == "IMEX_RIDC_R(3)": + elif scheme == "IMEX_RIDC(3)": k = 2 ensemble = Ensemble(COMM_WORLD, COMM_WORLD.size//(k+1)) @@ -65,7 +65,7 @@ def test_parallel_dc(tmpdir, scheme): domain = domain V = domain.spaces("DG") - if scheme == "IMEX_SDC_R(3,3)": + if scheme == "IMEX_SDC(3,3)": quad_type = "RADAU-RIGHT" node_type = "LEGENDRE" qdelta_imp = "MIN-SR-FLEX" @@ -77,7 +77,7 @@ def test_parallel_dc(tmpdir, scheme): base_scheme = IMEX_Euler(domain) time_scheme = Parallel_SDC(base_scheme, domain, M, k, quad_type, node_type, qdelta_imp, qdelta_exp, final_update=True, initial_guess="copy", communicator=ensemble) - elif scheme == "IMEX_RIDC_R(3)": + elif scheme == "IMEX_RIDC(3)": M = k*(k+1)//2 + 4 eqn = ContinuityEquation(domain, V, "f") eqn = split_continuity_form(eqn) From 727f1b53ce21323cc9d1d6040c1808bfb36abdaf Mon Sep 17 00:00:00 2001 From: atb1995 Date: Wed, 30 Jul 2025 13:43:20 +0100 Subject: [PATCH 23/43] printing error in integration test --- integration-tests/model/test_parallel_dc.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/integration-tests/model/test_parallel_dc.py b/integration-tests/model/test_parallel_dc.py index f9e83860d..68b05b2f7 100644 --- a/integration-tests/model/test_parallel_dc.py +++ b/integration-tests/model/test_parallel_dc.py @@ -64,25 +64,21 @@ def test_parallel_dc(tmpdir, scheme): domain = domain V = domain.spaces("DG") + eqn = ContinuityEquation(domain, V, "f") + eqn = split_continuity_form(eqn) + eqn.label_terms(lambda t: not any(t.has_label(time_derivative, transport)), implicit) + eqn.label_terms(lambda t: t.has_label(transport), explicit) if scheme == "IMEX_SDC(3,3)": quad_type = "RADAU-RIGHT" node_type = "LEGENDRE" qdelta_imp = "MIN-SR-FLEX" qdelta_exp = "MIN-SR-NS" - eqn = ContinuityEquation(domain, V, "f") - eqn = split_continuity_form(eqn) - eqn.label_terms(lambda t: not any(t.has_label(time_derivative, transport)), implicit) - eqn.label_terms(lambda t: t.has_label(transport), explicit) base_scheme = IMEX_Euler(domain) time_scheme = Parallel_SDC(base_scheme, domain, M, k, quad_type, node_type, qdelta_imp, qdelta_exp, final_update=True, initial_guess="copy", communicator=ensemble) elif scheme == "IMEX_RIDC(3)": M = k*(k+1)//2 + 4 - eqn = ContinuityEquation(domain, V, "f") - eqn = split_continuity_form(eqn) - eqn.label_terms(lambda t: not any(t.has_label(time_derivative, transport)), implicit) - eqn.label_terms(lambda t: t.has_label(transport), explicit) base_scheme = IMEX_Euler(domain) time_scheme = Parallel_RIDC(base_scheme, domain, M, k, communicator=ensemble) @@ -95,4 +91,5 @@ def test_parallel_dc(tmpdir, scheme): timestepper.fields("f").interpolate(f_init) timestepper.fields("u").project(uexpr) - parallel_assert(run(timestepper, tmax, f_end) < tol, "Error too large") + error = run(timestepper, tmax, f_end) + parallel_assert(error < tol, f"Error too large, Error: {error}, tol: {tol}") From 97afce6ec150520c9b88c8973a85d0287da03ce9 Mon Sep 17 00:00:00 2001 From: atb1995 Date: Thu, 31 Jul 2025 09:45:48 +0100 Subject: [PATCH 24/43] Switch to 4 and 2, still seems to pass locally.. --- .github/workflows/build.yml | 9 +-------- integration-tests/model/test_parallel_dc.py | 16 ++++++++-------- 2 files changed, 9 insertions(+), 16 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 0c7426dea..e7d5a153b 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -33,7 +33,7 @@ jobs: GUSTO_PARALLEL_LOG: CONSOLE PYOP2_CFLAGS: -O0 # Make sure that tests with >4 processes are not silently skipped - PYTEST_MPI_MAX_NPROCS: 6 + PYTEST_MPI_MAX_NPROCS: 4 EXTRA_PYTEST_ARGS: --durations=100 --timeout=3600 --timeout-method=thread -o faulthandler_timeout=3660 --show-capture=no --verbose gusto-repo/unit-tests gusto-repo/integration-tests gusto-repo/examples steps: - name: Fix HOME @@ -118,13 +118,6 @@ jobs: firedrake-run-split-tests 4 3 "$EXTRA_PYTEST_ARGS" "--log-file=gusto4_{#}.log" timeout-minutes: 10 - - name: Run tests (nprocs = 6) - if: success() || steps.install-two.conclusion == 'success' - run: | - . venv-gusto/bin/activate - firedrake-run-split-tests 6 2 "$EXTRA_PYTEST_ARGS" "--log-file=gusto6_{#}.log" - timeout-minutes: 30 - - name: Upload pytest log files uses: actions/upload-artifact@v4 if: success() || steps.install-two.conclusion == 'success' diff --git a/integration-tests/model/test_parallel_dc.py b/integration-tests/model/test_parallel_dc.py index 68b05b2f7..7d877896d 100644 --- a/integration-tests/model/test_parallel_dc.py +++ b/integration-tests/model/test_parallel_dc.py @@ -19,17 +19,17 @@ def run(timestepper, tmax, f_end): return norm(timestepper.fields("f") - f_end) / norm(f_end) -@pytest.mark.parallel(nprocs=[6]) +@pytest.mark.parallel(nprocs=[2,4]) @pytest.mark.parametrize( - "scheme", ["IMEX_SDC(3,3)", "IMEX_RIDC(3)"]) + "scheme", ["IMEX_SDC(2,2)", "IMEX_RIDC(2)"]) def test_parallel_dc(tmpdir, scheme): - if scheme == "IMEX_SDC(3,3)": - M = 3 + if scheme == "IMEX_SDC(2,2)": + M = 2 k = M ensemble = Ensemble(COMM_WORLD, COMM_WORLD.size//(M)) - elif scheme == "IMEX_RIDC(3)": - k = 2 + elif scheme == "IMEX_RIDC(2)": + k = 1 ensemble = Ensemble(COMM_WORLD, COMM_WORLD.size//(k+1)) # Get the tracer setup @@ -69,7 +69,7 @@ def test_parallel_dc(tmpdir, scheme): eqn.label_terms(lambda t: not any(t.has_label(time_derivative, transport)), implicit) eqn.label_terms(lambda t: t.has_label(transport), explicit) - if scheme == "IMEX_SDC(3,3)": + if scheme == "IMEX_SDC(2,2)": quad_type = "RADAU-RIGHT" node_type = "LEGENDRE" qdelta_imp = "MIN-SR-FLEX" @@ -77,7 +77,7 @@ def test_parallel_dc(tmpdir, scheme): base_scheme = IMEX_Euler(domain) time_scheme = Parallel_SDC(base_scheme, domain, M, k, quad_type, node_type, qdelta_imp, qdelta_exp, final_update=True, initial_guess="copy", communicator=ensemble) - elif scheme == "IMEX_RIDC(3)": + elif scheme == "IMEX_RIDC(2)": M = k*(k+1)//2 + 4 base_scheme = IMEX_Euler(domain) time_scheme = Parallel_RIDC(base_scheme, domain, M, k, communicator=ensemble) From 0d9621f0f3c146fa4e42674ce84c834541e4d7d8 Mon Sep 17 00:00:00 2001 From: atb1995 Date: Thu, 31 Jul 2025 10:29:11 +0100 Subject: [PATCH 25/43] make sure IO is correct for test and use base for SDC --- integration-tests/model/test_parallel_dc.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/integration-tests/model/test_parallel_dc.py b/integration-tests/model/test_parallel_dc.py index 7d877896d..0396c9cb4 100644 --- a/integration-tests/model/test_parallel_dc.py +++ b/integration-tests/model/test_parallel_dc.py @@ -19,7 +19,7 @@ def run(timestepper, tmax, f_end): return norm(timestepper.fields("f") - f_end) / norm(f_end) -@pytest.mark.parallel(nprocs=[2,4]) +@pytest.mark.parallel(nprocs=[2, 4]) @pytest.mark.parametrize( "scheme", ["IMEX_SDC(2,2)", "IMEX_RIDC(2)"]) def test_parallel_dc(tmpdir, scheme): @@ -76,9 +76,9 @@ def test_parallel_dc(tmpdir, scheme): qdelta_exp = "MIN-SR-NS" base_scheme = IMEX_Euler(domain) time_scheme = Parallel_SDC(base_scheme, domain, M, k, quad_type, node_type, qdelta_imp, - qdelta_exp, final_update=True, initial_guess="copy", communicator=ensemble) + qdelta_exp, final_update=True, initial_guess="base", communicator=ensemble) elif scheme == "IMEX_RIDC(2)": - M = k*(k+1)//2 + 4 + M = 5 base_scheme = IMEX_Euler(domain) time_scheme = Parallel_RIDC(base_scheme, domain, M, k, communicator=ensemble) From 575a7e2fb387030508e6da8330cf382280b0dfa5 Mon Sep 17 00:00:00 2001 From: atb1995 Date: Thu, 31 Jul 2025 12:02:26 +0100 Subject: [PATCH 26/43] Make sure non Picard integral on advection term --- integration-tests/model/test_parallel_dc.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/integration-tests/model/test_parallel_dc.py b/integration-tests/model/test_parallel_dc.py index 0396c9cb4..9e8bbb775 100644 --- a/integration-tests/model/test_parallel_dc.py +++ b/integration-tests/model/test_parallel_dc.py @@ -26,7 +26,7 @@ def test_parallel_dc(tmpdir, scheme): if scheme == "IMEX_SDC(2,2)": M = 2 - k = M + k = 3 ensemble = Ensemble(COMM_WORLD, COMM_WORLD.size//(M)) elif scheme == "IMEX_RIDC(2)": k = 1 @@ -65,11 +65,11 @@ def test_parallel_dc(tmpdir, scheme): domain = domain V = domain.spaces("DG") eqn = ContinuityEquation(domain, V, "f") - eqn = split_continuity_form(eqn) - eqn.label_terms(lambda t: not any(t.has_label(time_derivative, transport)), implicit) - eqn.label_terms(lambda t: t.has_label(transport), explicit) if scheme == "IMEX_SDC(2,2)": + eqn = ContinuityEquation(domain, V, "f") + eqn.label_terms(lambda t: not t.has_label(time_derivative), implicit) + quad_type = "RADAU-RIGHT" node_type = "LEGENDRE" qdelta_imp = "MIN-SR-FLEX" @@ -78,6 +78,10 @@ def test_parallel_dc(tmpdir, scheme): time_scheme = Parallel_SDC(base_scheme, domain, M, k, quad_type, node_type, qdelta_imp, qdelta_exp, final_update=True, initial_guess="base", communicator=ensemble) elif scheme == "IMEX_RIDC(2)": + eqn = split_continuity_form(eqn) + eqn.label_terms(lambda t: not any(t.has_label(time_derivative, transport)), implicit) + eqn.label_terms(lambda t: t.has_label(transport), explicit) + M = 5 base_scheme = IMEX_Euler(domain) time_scheme = Parallel_RIDC(base_scheme, domain, M, k, communicator=ensemble) From a5420200c33e591ba9ab4d38e24cde2e4b8d1318 Mon Sep 17 00:00:00 2001 From: atb1995 Date: Tue, 5 Aug 2025 10:08:03 +0100 Subject: [PATCH 27/43] Extend timeout and alter test set up --- .github/workflows/build.yml | 2 +- integration-tests/model/test_parallel_dc.py | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index e7d5a153b..8707f0720 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -116,7 +116,7 @@ jobs: run: | . venv-gusto/bin/activate firedrake-run-split-tests 4 3 "$EXTRA_PYTEST_ARGS" "--log-file=gusto4_{#}.log" - timeout-minutes: 10 + timeout-minutes: 15 - name: Upload pytest log files uses: actions/upload-artifact@v4 diff --git a/integration-tests/model/test_parallel_dc.py b/integration-tests/model/test_parallel_dc.py index 9e8bbb775..e478a3528 100644 --- a/integration-tests/model/test_parallel_dc.py +++ b/integration-tests/model/test_parallel_dc.py @@ -26,7 +26,7 @@ def test_parallel_dc(tmpdir, scheme): if scheme == "IMEX_SDC(2,2)": M = 2 - k = 3 + k = 2 ensemble = Ensemble(COMM_WORLD, COMM_WORLD.size//(M)) elif scheme == "IMEX_RIDC(2)": k = 1 @@ -67,7 +67,6 @@ def test_parallel_dc(tmpdir, scheme): eqn = ContinuityEquation(domain, V, "f") if scheme == "IMEX_SDC(2,2)": - eqn = ContinuityEquation(domain, V, "f") eqn.label_terms(lambda t: not t.has_label(time_derivative), implicit) quad_type = "RADAU-RIGHT" @@ -76,7 +75,7 @@ def test_parallel_dc(tmpdir, scheme): qdelta_exp = "MIN-SR-NS" base_scheme = IMEX_Euler(domain) time_scheme = Parallel_SDC(base_scheme, domain, M, k, quad_type, node_type, qdelta_imp, - qdelta_exp, final_update=True, initial_guess="base", communicator=ensemble) + qdelta_exp, final_update=False, initial_guess="base", communicator=ensemble) elif scheme == "IMEX_RIDC(2)": eqn = split_continuity_form(eqn) eqn.label_terms(lambda t: not any(t.has_label(time_derivative, transport)), implicit) From dff1198f094f25de5b20ba0df81a2142c2ef25fe Mon Sep 17 00:00:00 2001 From: atb1995 Date: Tue, 5 Aug 2025 10:31:37 +0100 Subject: [PATCH 28/43] minor change --- integration-tests/model/test_parallel_dc.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/integration-tests/model/test_parallel_dc.py b/integration-tests/model/test_parallel_dc.py index e478a3528..392ff68c6 100644 --- a/integration-tests/model/test_parallel_dc.py +++ b/integration-tests/model/test_parallel_dc.py @@ -25,8 +25,8 @@ def run(timestepper, tmax, f_end): def test_parallel_dc(tmpdir, scheme): if scheme == "IMEX_SDC(2,2)": - M = 2 - k = 2 + M = 4 + k = 4 ensemble = Ensemble(COMM_WORLD, COMM_WORLD.size//(M)) elif scheme == "IMEX_RIDC(2)": k = 1 @@ -75,7 +75,7 @@ def test_parallel_dc(tmpdir, scheme): qdelta_exp = "MIN-SR-NS" base_scheme = IMEX_Euler(domain) time_scheme = Parallel_SDC(base_scheme, domain, M, k, quad_type, node_type, qdelta_imp, - qdelta_exp, final_update=False, initial_guess="base", communicator=ensemble) + qdelta_exp, final_update=True, initial_guess="base", communicator=ensemble) elif scheme == "IMEX_RIDC(2)": eqn = split_continuity_form(eqn) eqn.label_terms(lambda t: not any(t.has_label(time_derivative, transport)), implicit) From 5c795096b3989df906fa051e1e8ad1189ebf8c4b Mon Sep 17 00:00:00 2001 From: atb1995 Date: Tue, 5 Aug 2025 11:44:41 +0100 Subject: [PATCH 29/43] Alter of test for only 4 processors --- .github/workflows/build.yml | 2 +- integration-tests/model/test_parallel_dc.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 8707f0720..14f104097 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -116,7 +116,7 @@ jobs: run: | . venv-gusto/bin/activate firedrake-run-split-tests 4 3 "$EXTRA_PYTEST_ARGS" "--log-file=gusto4_{#}.log" - timeout-minutes: 15 + timeout-minutes: 20 - name: Upload pytest log files uses: actions/upload-artifact@v4 diff --git a/integration-tests/model/test_parallel_dc.py b/integration-tests/model/test_parallel_dc.py index 392ff68c6..a35afbf21 100644 --- a/integration-tests/model/test_parallel_dc.py +++ b/integration-tests/model/test_parallel_dc.py @@ -19,9 +19,9 @@ def run(timestepper, tmax, f_end): return norm(timestepper.fields("f") - f_end) / norm(f_end) -@pytest.mark.parallel(nprocs=[2, 4]) +@pytest.mark.parallel(nprocs=[4]) @pytest.mark.parametrize( - "scheme", ["IMEX_SDC(2,2)", "IMEX_RIDC(2)"]) + "scheme", ["IMEX_SDC(4,4)", "IMEX_RIDC(2)"]) def test_parallel_dc(tmpdir, scheme): if scheme == "IMEX_SDC(2,2)": From e0d843f889e08f1565b2ed47ed4ddb15176ed8a3 Mon Sep 17 00:00:00 2001 From: atb1995 Date: Tue, 5 Aug 2025 12:47:32 +0100 Subject: [PATCH 30/43] Fix errors in unit test.. --- integration-tests/model/test_parallel_dc.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/integration-tests/model/test_parallel_dc.py b/integration-tests/model/test_parallel_dc.py index a35afbf21..7f11befff 100644 --- a/integration-tests/model/test_parallel_dc.py +++ b/integration-tests/model/test_parallel_dc.py @@ -24,7 +24,7 @@ def run(timestepper, tmax, f_end): "scheme", ["IMEX_SDC(4,4)", "IMEX_RIDC(2)"]) def test_parallel_dc(tmpdir, scheme): - if scheme == "IMEX_SDC(2,2)": + if scheme == "IMEX_SDC(4,4)": M = 4 k = 4 ensemble = Ensemble(COMM_WORLD, COMM_WORLD.size//(M)) @@ -66,7 +66,7 @@ def test_parallel_dc(tmpdir, scheme): V = domain.spaces("DG") eqn = ContinuityEquation(domain, V, "f") - if scheme == "IMEX_SDC(2,2)": + if scheme == "IMEX_SDC(4,4)": eqn.label_terms(lambda t: not t.has_label(time_derivative), implicit) quad_type = "RADAU-RIGHT" From b2cceb10a2b6d72b7c37fdf78985291360951c12 Mon Sep 17 00:00:00 2001 From: atb1995 Date: Mon, 11 Aug 2025 09:50:01 +0100 Subject: [PATCH 31/43] Fix to scaling of Qdelta matrices for diagonal SDC --- .../deferred_correction.py | 19 +++++++++++-------- gusto/time_discretisation/parallel_dc.py | 2 +- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/gusto/time_discretisation/deferred_correction.py b/gusto/time_discretisation/deferred_correction.py index 1f515b6f3..aad9675cf 100644 --- a/gusto/time_discretisation/deferred_correction.py +++ b/gusto/time_discretisation/deferred_correction.py @@ -155,12 +155,6 @@ def __init__(self, base_scheme, domain, M, maxk, quad_type, node_type, qdelta_im quadType=quad_type, form=formulation) - # Rescale to be over [0,dt] rather than [0,1] - self.nodes = float(self.dt_coarse)*self.nodes - - self.dtau = np.diff(np.append(0, self.nodes)) - self.Q = float(self.dt_coarse)*self.Q - self.Qfin = float(self.dt_coarse)*self.weights self.qdelta_imp_type = qdelta_imp self.formulation = formulation self.node_type = node_type @@ -168,9 +162,18 @@ def __init__(self, base_scheme, domain, M, maxk, quad_type, node_type, qdelta_im # Get Q_delta matrices self.Qdelta_imp = genQDeltaCoeffs(qdelta_imp, form=formulation, - nodes=self.nodes, Q=self.Q, nNodes=M, nodeType=node_type, quadType=quad_type) + nodes=self.nodes, Q=self.Q, nNodes=M, nodeType=node_type, quadType=quad_type, k=1) self.Qdelta_exp = genQDeltaCoeffs(qdelta_exp, form=formulation, nodes=self.nodes, Q=self.Q, nNodes=M, nodeType=node_type, quadType=quad_type) + + # Rescale to be over [0,dt] rather than [0,1] + self.nodes = float(self.dt_coarse)*self.nodes + self.dtau = np.diff(np.append(0, self.nodes)) + self.Q = float(self.dt_coarse)*self.Q + self.Qfin = float(self.dt_coarse)*self.weights + self.Qdelta_imp = float(self.dt_coarse)*self.Qdelta_imp + self.Qdelta_exp = float(self.dt_coarse)*self.Qdelta_exp + # Set default linear and nonlinear solver options if none passed in if linear_solver_parameters is None: self.linear_solver_parameters = {'snes_type': 'ksponly', @@ -459,7 +462,7 @@ def apply(self, x_out, x_in): if self.qdelta_imp_type == "MIN-SR-FLEX": # Recompute Implicit Q_delta matrix for each iteration k - self.Qdelta_imp = genQDeltaCoeffs( + self.Qdelta_imp = float(self.dt_coarse)*genQDeltaCoeffs( self.qdelta_imp_type, form=self.formulation, nodes=self.nodes, diff --git a/gusto/time_discretisation/parallel_dc.py b/gusto/time_discretisation/parallel_dc.py index 538302b30..8f2305173 100644 --- a/gusto/time_discretisation/parallel_dc.py +++ b/gusto/time_discretisation/parallel_dc.py @@ -298,7 +298,7 @@ def apply(self, x_out, x_in): if self.qdelta_imp_type == "MIN-SR-FLEX": # Recompute Implicit Q_delta matrix for each iteration k - self.Qdelta_imp = genQDeltaCoeffs( + self.Qdelta_imp = float(self.dt_coarse)*genQDeltaCoeffs( self.qdelta_imp_type, form=self.formulation, nodes=self.nodes, From 5cea6d824f7110209c537c6e51527b7d51d6ea74 Mon Sep 17 00:00:00 2001 From: atb1995 Date: Tue, 12 Aug 2025 08:58:26 +0100 Subject: [PATCH 32/43] Updated integration test again.. hopefully will not time out --- integration-tests/model/test_parallel_dc.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/integration-tests/model/test_parallel_dc.py b/integration-tests/model/test_parallel_dc.py index 7f11befff..88835a3e4 100644 --- a/integration-tests/model/test_parallel_dc.py +++ b/integration-tests/model/test_parallel_dc.py @@ -1,8 +1,8 @@ """ This runs a simple transport test on the sphere using the parallel DC time discretisations to test whether the errors are within tolerance. The test is run for the following schemes: -- IMEX_SDC(2,2) - IMEX SDC with 2 qaudrature nodes of Radau type (3rd order scheme) using -- IMEX_RIDC(3) - IMEX RIDC with 4 quadrature nodes of equidistant type, reduced stencils (3rd order scheme). +- IMEX_SDC(2,2) - IMEX SDC with 2 qaudrature nodes of Radau type and 2 correction sweeps (2nd order scheme) +- IMEX_RIDC(2) - IMEX RIDC with 3 quadrature nodes of equidistant type, 1 correction sweep, reduced stencils (2nd order scheme). """ from firedrake import (norm, Ensemble, COMM_WORLD, SpatialCoordinate, @@ -19,14 +19,14 @@ def run(timestepper, tmax, f_end): return norm(timestepper.fields("f") - f_end) / norm(f_end) -@pytest.mark.parallel(nprocs=[4]) +@pytest.mark.parallel(nprocs=[2, 4]) @pytest.mark.parametrize( - "scheme", ["IMEX_SDC(4,4)", "IMEX_RIDC(2)"]) + "scheme", ["IMEX_SDC(2,2)", "IMEX_RIDC(2)"]) def test_parallel_dc(tmpdir, scheme): - if scheme == "IMEX_SDC(4,4)": - M = 4 - k = 4 + if scheme == "IMEX_SDC(2,2)": + M = 2 + k = 2 ensemble = Ensemble(COMM_WORLD, COMM_WORLD.size//(M)) elif scheme == "IMEX_RIDC(2)": k = 1 @@ -66,7 +66,7 @@ def test_parallel_dc(tmpdir, scheme): V = domain.spaces("DG") eqn = ContinuityEquation(domain, V, "f") - if scheme == "IMEX_SDC(4,4)": + if scheme == "IMEX_SDC(2,2)": eqn.label_terms(lambda t: not t.has_label(time_derivative), implicit) quad_type = "RADAU-RIGHT" From e14513ad3a557d8dbdfb8e6a92a10e76434716a2 Mon Sep 17 00:00:00 2001 From: atb1995 Date: Tue, 12 Aug 2025 12:04:58 +0100 Subject: [PATCH 33/43] Only test 4 processors --- integration-tests/model/test_parallel_dc.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/integration-tests/model/test_parallel_dc.py b/integration-tests/model/test_parallel_dc.py index 88835a3e4..3730bb232 100644 --- a/integration-tests/model/test_parallel_dc.py +++ b/integration-tests/model/test_parallel_dc.py @@ -19,13 +19,13 @@ def run(timestepper, tmax, f_end): return norm(timestepper.fields("f") - f_end) / norm(f_end) -@pytest.mark.parallel(nprocs=[2, 4]) +@pytest.mark.parallel(nprocs=[4]) @pytest.mark.parametrize( - "scheme", ["IMEX_SDC(2,2)", "IMEX_RIDC(2)"]) + "scheme", ["IMEX_SDC(4,2)", "IMEX_RIDC(2)"]) def test_parallel_dc(tmpdir, scheme): - if scheme == "IMEX_SDC(2,2)": - M = 2 + if scheme == "IMEX_SDC(4,2)": + M = 4 k = 2 ensemble = Ensemble(COMM_WORLD, COMM_WORLD.size//(M)) elif scheme == "IMEX_RIDC(2)": @@ -66,7 +66,7 @@ def test_parallel_dc(tmpdir, scheme): V = domain.spaces("DG") eqn = ContinuityEquation(domain, V, "f") - if scheme == "IMEX_SDC(2,2)": + if scheme == "IMEX_SDC(4,2)": eqn.label_terms(lambda t: not t.has_label(time_derivative), implicit) quad_type = "RADAU-RIGHT" From d64c8386d4b379ea2b01a3f593bfa9f06b40a768 Mon Sep 17 00:00:00 2001 From: atb1995 Date: Tue, 12 Aug 2025 13:49:07 +0100 Subject: [PATCH 34/43] Test alterations.. --- .github/workflows/build.yml | 2 +- integration-tests/model/test_parallel_dc.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 14f104097..8979e488a 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -116,7 +116,7 @@ jobs: run: | . venv-gusto/bin/activate firedrake-run-split-tests 4 3 "$EXTRA_PYTEST_ARGS" "--log-file=gusto4_{#}.log" - timeout-minutes: 20 + timeout-minutes: 30 - name: Upload pytest log files uses: actions/upload-artifact@v4 diff --git a/integration-tests/model/test_parallel_dc.py b/integration-tests/model/test_parallel_dc.py index 3730bb232..1d6daef12 100644 --- a/integration-tests/model/test_parallel_dc.py +++ b/integration-tests/model/test_parallel_dc.py @@ -21,7 +21,7 @@ def run(timestepper, tmax, f_end): @pytest.mark.parallel(nprocs=[4]) @pytest.mark.parametrize( - "scheme", ["IMEX_SDC(4,2)", "IMEX_RIDC(2)"]) + "scheme", ["IMEX_RIDC(2)", "IMEX_SDC(4,2)"]) def test_parallel_dc(tmpdir, scheme): if scheme == "IMEX_SDC(4,2)": From 1d02752b44890f7c873bfe3c09f419181a3a2960 Mon Sep 17 00:00:00 2001 From: atb1995 Date: Tue, 12 Aug 2025 14:56:24 +0100 Subject: [PATCH 35/43] Alter to work on 2 cores so see if it works.. --- .github/workflows/build.yml | 2 +- integration-tests/model/test_parallel_dc.py | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 8979e488a..e7d5a153b 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -116,7 +116,7 @@ jobs: run: | . venv-gusto/bin/activate firedrake-run-split-tests 4 3 "$EXTRA_PYTEST_ARGS" "--log-file=gusto4_{#}.log" - timeout-minutes: 30 + timeout-minutes: 10 - name: Upload pytest log files uses: actions/upload-artifact@v4 diff --git a/integration-tests/model/test_parallel_dc.py b/integration-tests/model/test_parallel_dc.py index 1d6daef12..5ada1792d 100644 --- a/integration-tests/model/test_parallel_dc.py +++ b/integration-tests/model/test_parallel_dc.py @@ -19,13 +19,13 @@ def run(timestepper, tmax, f_end): return norm(timestepper.fields("f") - f_end) / norm(f_end) -@pytest.mark.parallel(nprocs=[4]) +@pytest.mark.parallel(nprocs=[2]) @pytest.mark.parametrize( - "scheme", ["IMEX_RIDC(2)", "IMEX_SDC(4,2)"]) + "scheme", ["IMEX_RIDC(2)", "IMEX_SDC(2,2)"]) def test_parallel_dc(tmpdir, scheme): - if scheme == "IMEX_SDC(4,2)": - M = 4 + if scheme == "IMEX_SDC(2,2)": + M = 2 k = 2 ensemble = Ensemble(COMM_WORLD, COMM_WORLD.size//(M)) elif scheme == "IMEX_RIDC(2)": @@ -66,7 +66,7 @@ def test_parallel_dc(tmpdir, scheme): V = domain.spaces("DG") eqn = ContinuityEquation(domain, V, "f") - if scheme == "IMEX_SDC(4,2)": + if scheme == "IMEX_SDC(2,2)": eqn.label_terms(lambda t: not t.has_label(time_derivative), implicit) quad_type = "RADAU-RIGHT" From 0ca07a41970b23b96c00bef2d4b30c97c34c6592 Mon Sep 17 00:00:00 2001 From: Alex Brown <81297297+atb1995@users.noreply.github.com> Date: Thu, 21 Aug 2025 09:09:17 +0100 Subject: [PATCH 36/43] Update gusto/time_discretisation/parallel_dc.py Co-authored-by: Thomas Bendall --- gusto/time_discretisation/parallel_dc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gusto/time_discretisation/parallel_dc.py b/gusto/time_discretisation/parallel_dc.py index 8f2305173..ea2467578 100644 --- a/gusto/time_discretisation/parallel_dc.py +++ b/gusto/time_discretisation/parallel_dc.py @@ -1,4 +1,4 @@ -u""" +""" Objects for discretising time derivatives using time-parallel Deferred Correction Methods. From fc67e8f9164c97e8c12bfb927d1fa3be4619563b Mon Sep 17 00:00:00 2001 From: atb1995 Date: Thu, 11 Sep 2025 16:47:49 +0100 Subject: [PATCH 37/43] Build in flush option --- gusto/time_discretisation/parallel_dc.py | 68 +++++++++++++++------ integration-tests/model/test_parallel_dc.py | 7 ++- 2 files changed, 52 insertions(+), 23 deletions(-) diff --git a/gusto/time_discretisation/parallel_dc.py b/gusto/time_discretisation/parallel_dc.py index ea2467578..c201a65b5 100644 --- a/gusto/time_discretisation/parallel_dc.py +++ b/gusto/time_discretisation/parallel_dc.py @@ -23,7 +23,7 @@ class Parallel_RIDC(RIDC): """Class for Parallel Revisionist Integral Deferred Correction schemes.""" - def __init__(self, base_scheme, domain, M, K, field_name=None, + def __init__(self, base_scheme, domain, M, K, J, output_freq, flush_freq = None, field_name=None, linear_solver_parameters=None, nonlinear_solver_parameters=None, limiter=None, options=None, communicator=None): """ @@ -35,6 +35,9 @@ def __init__(self, base_scheme, domain, M, K, field_name=None, mesh and the compatible function spaces. M (int): Number of subintervals K (int): Max number of correction interations + J (int): Number of intervals + output_freq (int): Frequency at which output is done + flush_freq (int): Frequency at which to flush the pipeline field_name (str, optional): name of the field to be evolved. Defaults to None. linear_solver_parameters (dict, optional): dictionary of parameters to @@ -51,6 +54,24 @@ def __init__(self, base_scheme, domain, M, K, field_name=None, linear_solver_parameters, nonlinear_solver_parameters, limiter, options, reduced=True) self.comm = communicator + self.TAG_EXCHANGE_FIELD = 11 # Tag for sending nodal fields (Firedrake Functions) + self.TAG_EXCHANGE_SOURCE = self.TAG_EXCHANGE_FIELD + J # Tag for sending nodal source fields (Firedrake Functions) + self.TAG_FLUSH_PIPE = self.TAG_EXCHANGE_SOURCE + J # Tag for flushing pipe and restarting + self.TAG_FINAL_OUT = self.TAG_FLUSH_PIPE + J # Tag for the final broadcast and output + self.TAG_END_INTERVAL = self.TAG_FINAL_OUT + J # Tag for telling the rank above you that you have ended interval j + + + if flush_freq is None: + self.flush_freq = 1 + else: + self.flush_freq = flush_freq + + self.J = J + self.step = 1 + self.output_freq = output_freq + + if self.output_freq%self.flush_freq != 0: + raise Warning("Output on all parallel in time ranks will not be the same!") # Checks for parallel RIDC if self.comm is None: @@ -102,13 +123,13 @@ def apply(self, x_out, x_in): evaluate(self.Unodes[m+1], self.base.dt, x_out=self.source_Uk[m+1]) # Send base guess to k+1 correction - self.comm.send(self.Unodes[m+1], dest=self.kval+1, tag=100+m+1) - self.comm.send(self.source_Uk[m+1], dest=self.kval+1, tag=200+m+1) + self.comm.send(self.Unodes[m+1], dest=self.kval+1, tag=self.TAG_EXCHANGE_FIELD + self.step) + self.comm.send(self.source_Uk[m+1], dest=self.kval+1, tag=self.TAG_EXCHANGE_SOURCE + self.step) else: for m in range(1, self.kval + 1): # Recieve and evaluate the stencil of guesses we need to correct - self.comm.recv(self.Unodes[m], source=self.kval-1, tag=100+m) - self.comm.recv(self.source_Uk[m], source=self.kval-1, tag=200+m) + self.comm.recv(self.Unodes[m], source=self.kval-1, tag=self.TAG_EXCHANGE_FIELD + self.step) + self.comm.recv(self.source_Uk[m], source=self.kval-1, tag=self.TAG_EXCHANGE_SOURCE + self.step) self.Uin.assign(self.Unodes[m]) for evaluate in self.evaluate_source: evaluate(self.Uin, self.base.dt, x_out=self.source_in) @@ -143,13 +164,13 @@ def apply(self, x_out, x_in): self.limiter.apply(self.Unodes1[m+1]) # Send our updated value to next communicator if self.kval < self.K: - self.comm.send(self.Unodes1[m+1], dest=self.kval+1, tag=100+m+1) - self.comm.send(self.source_Ukp1[m+1], dest=self.kval+1, tag=200+m+1) + self.comm.send(self.Unodes1[m+1], dest=self.kval+1, tag=self.TAG_EXCHANGE_FIELD + self.step) + self.comm.send(self.source_Ukp1[m+1], dest=self.kval+1, tag=self.TAG_EXCHANGE_SOURCE + self.step) for m in range(self.kval, self.M): # Recieve the guess we need to correct and evaluate the rhs - self.comm.recv(self.Unodes[m+1], source=self.kval-1, tag=100+m+1) - self.comm.recv(self.source_Uk[m+1], source=self.kval-1, tag=200+m+1) + self.comm.recv(self.Unodes[m+1], source=self.kval-1, tag=self.TAG_EXCHANGE_FIELD + self.step) + self.comm.recv(self.source_Uk[m+1], source=self.kval-1, tag=self.TAG_EXCHANGE_SOURCE + self.step) self.Uin.assign(self.Unodes[m+1]) for evaluate in self.evaluate_source: evaluate(self.Uin, self.base.dt, x_out=self.source_in) @@ -184,18 +205,25 @@ def apply(self, x_out, x_in): # Send our updated value to next communicator if self.kval < self.K: - self.comm.send(self.Unodes1[m+1], dest=self.kval+1, tag=100+m+1) - self.comm.send(self.source_Ukp1[m+1], dest=self.kval+1, tag=200+m+1) - - if (self.kval == self.K): - # Broadcast the final result to all other ranks - x_out.assign(self.Unodes1[-1]) - for i in range(self.K): - # Send the final result to all other ranks - self.comm.send(x_out, dest=i, tag=200) + self.comm.send(self.Unodes1[m+1], dest=self.kval+1, tag=self.TAG_EXCHANGE_FIELD + self.step) + self.comm.send(self.source_Ukp1[m+1], dest=self.kval+1, tag=self.TAG_EXCHANGE_SOURCE + self.step) + + for m in range(self.M+1): + self.Unodes[m].assign(self.Unodes1[m]) + self.source_Uk[m].assign(self.source_Ukp1[m]) + + if (self.flush_freq > 0 and self.step % self.flush_freq == 0) or self.step == self.J: + # Flush the pipe to ensure all ranks have the same data + if (self.kval == self.K): + x_out.assign(self.Unodes[-1]) + for i in range(self.K): + self.comm.send(x_out, dest=i, tag=self.TAG_FLUSH_PIPE + self.step) + else: + self.comm.recv(x_out, source=self.K, tag=self.TAG_FLUSH_PIPE + self.step) else: - # Receive the final result from rank K - self.comm.recv(x_out, source=self.K, tag=200) + x_out.assign(self.Unodes[-1]) + + self.step += 1 class Parallel_SDC(SDC): diff --git a/integration-tests/model/test_parallel_dc.py b/integration-tests/model/test_parallel_dc.py index 5ada1792d..3ace723aa 100644 --- a/integration-tests/model/test_parallel_dc.py +++ b/integration-tests/model/test_parallel_dc.py @@ -48,8 +48,8 @@ def test_parallel_dc(tmpdir, scheme): # to demonstrate that transport is working correctly dt = pi/3. * 0.02 - - output = OutputParameters(dirname=dirname, dump_vtus=False, dump_nc=True, dumpfreq=15) + dumpfreq = 15 + output = OutputParameters(dirname=dirname, dump_vtus=False, dump_nc=True, dumpfreq=dumpfreq) domain = Domain(mesh, dt, family="BDM", degree=1) io = IO(domain, output) @@ -82,8 +82,9 @@ def test_parallel_dc(tmpdir, scheme): eqn.label_terms(lambda t: t.has_label(transport), explicit) M = 5 + J = int(tmax/dt) base_scheme = IMEX_Euler(domain) - time_scheme = Parallel_RIDC(base_scheme, domain, M, k, communicator=ensemble) + time_scheme = Parallel_RIDC(base_scheme, domain, M, k, J, output_freq=dumpfreq, communicator=ensemble) transport_method = DGUpwind(eqn, 'f') From 6922291fa0e00fc7f31eea7e04e75414525be268 Mon Sep 17 00:00:00 2001 From: atb1995 Date: Fri, 12 Sep 2025 11:32:57 +0100 Subject: [PATCH 38/43] Fix to split common form --- gusto/equations/common_forms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gusto/equations/common_forms.py b/gusto/equations/common_forms.py index b8caf2e5c..5f21fd521 100644 --- a/gusto/equations/common_forms.py +++ b/gusto/equations/common_forms.py @@ -367,7 +367,7 @@ def split_continuity_form(equation): u_trial = TrialFunctions(W)[u_idx] qbar = split(equation.X_ref)[idx] # Add linearisation to adv_term - linear_adv_term = linear_advection_form(test, qbar, u_trial) + linear_adv_term = linear_advection_form(test, qbar, u_trial, qbar, uadv) adv_term = linearisation(adv_term, linear_adv_term) # Add linearisation to div_term linear_div_term = transporting_velocity(qbar*test*div(u_trial)*dx, u_trial) From 665bb17c9b5e21b6e60d28525aace16a2bdc5073 Mon Sep 17 00:00:00 2001 From: atb1995 Date: Fri, 12 Sep 2025 13:31:42 +0100 Subject: [PATCH 39/43] Lint fix --- gusto/time_discretisation/parallel_dc.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/gusto/time_discretisation/parallel_dc.py b/gusto/time_discretisation/parallel_dc.py index c201a65b5..5cad75bab 100644 --- a/gusto/time_discretisation/parallel_dc.py +++ b/gusto/time_discretisation/parallel_dc.py @@ -23,7 +23,7 @@ class Parallel_RIDC(RIDC): """Class for Parallel Revisionist Integral Deferred Correction schemes.""" - def __init__(self, base_scheme, domain, M, K, J, output_freq, flush_freq = None, field_name=None, + def __init__(self, base_scheme, domain, M, K, J, output_freq, flush_freq=None, field_name=None, linear_solver_parameters=None, nonlinear_solver_parameters=None, limiter=None, options=None, communicator=None): """ @@ -37,7 +37,7 @@ def __init__(self, base_scheme, domain, M, K, J, output_freq, flush_freq = None, K (int): Max number of correction interations J (int): Number of intervals output_freq (int): Frequency at which output is done - flush_freq (int): Frequency at which to flush the pipeline + flush_freq (int): Frequency at which to flush the pipeline field_name (str, optional): name of the field to be evolved. Defaults to None. linear_solver_parameters (dict, optional): dictionary of parameters to @@ -54,12 +54,11 @@ def __init__(self, base_scheme, domain, M, K, J, output_freq, flush_freq = None, linear_solver_parameters, nonlinear_solver_parameters, limiter, options, reduced=True) self.comm = communicator - self.TAG_EXCHANGE_FIELD = 11 # Tag for sending nodal fields (Firedrake Functions) - self.TAG_EXCHANGE_SOURCE = self.TAG_EXCHANGE_FIELD + J # Tag for sending nodal source fields (Firedrake Functions) - self.TAG_FLUSH_PIPE = self.TAG_EXCHANGE_SOURCE + J # Tag for flushing pipe and restarting - self.TAG_FINAL_OUT = self.TAG_FLUSH_PIPE + J # Tag for the final broadcast and output - self.TAG_END_INTERVAL = self.TAG_FINAL_OUT + J # Tag for telling the rank above you that you have ended interval j - + self.TAG_EXCHANGE_FIELD = 11 # Tag for sending nodal fields (Firedrake Functions) + self.TAG_EXCHANGE_SOURCE = self.TAG_EXCHANGE_FIELD + J # Tag for sending nodal source fields (Firedrake Functions) + self.TAG_FLUSH_PIPE = self.TAG_EXCHANGE_SOURCE + J # Tag for flushing pipe and restarting + self.TAG_FINAL_OUT = self.TAG_FLUSH_PIPE + J # Tag for the final broadcast and output + self.TAG_END_INTERVAL = self.TAG_FINAL_OUT + J # Tag for telling the rank above you that you have ended interval j if flush_freq is None: self.flush_freq = 1 @@ -70,7 +69,7 @@ def __init__(self, base_scheme, domain, M, K, J, output_freq, flush_freq = None, self.step = 1 self.output_freq = output_freq - if self.output_freq%self.flush_freq != 0: + if self.output_freq % self.flush_freq != 0: raise Warning("Output on all parallel in time ranks will not be the same!") # Checks for parallel RIDC @@ -207,7 +206,7 @@ def apply(self, x_out, x_in): if self.kval < self.K: self.comm.send(self.Unodes1[m+1], dest=self.kval+1, tag=self.TAG_EXCHANGE_FIELD + self.step) self.comm.send(self.source_Ukp1[m+1], dest=self.kval+1, tag=self.TAG_EXCHANGE_SOURCE + self.step) - + for m in range(self.M+1): self.Unodes[m].assign(self.Unodes1[m]) self.source_Uk[m].assign(self.source_Ukp1[m]) From cc93be9d212f11ed8b0bc97cf0ddb3d994521c51 Mon Sep 17 00:00:00 2001 From: atb1995 Date: Fri, 12 Sep 2025 16:47:13 +0100 Subject: [PATCH 40/43] Correct warning through logger --- gusto/time_discretisation/parallel_dc.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/gusto/time_discretisation/parallel_dc.py b/gusto/time_discretisation/parallel_dc.py index 5cad75bab..7f9e9647b 100644 --- a/gusto/time_discretisation/parallel_dc.py +++ b/gusto/time_discretisation/parallel_dc.py @@ -16,6 +16,7 @@ from gusto.time_discretisation.time_discretisation import wrapper_apply from qmat import genQDeltaCoeffs from gusto.time_discretisation.deferred_correction import SDC, RIDC +from gusto.core.logging import logger __all__ = ["Parallel_RIDC", "Parallel_SDC"] @@ -69,8 +70,8 @@ def __init__(self, base_scheme, domain, M, K, J, output_freq, flush_freq=None, f self.step = 1 self.output_freq = output_freq - if self.output_freq % self.flush_freq != 0: - raise Warning("Output on all parallel in time ranks will not be the same!") + if self.flush_freq == 0 or (self.flush_freq != 0 and self.output_freq % self.flush_freq != 0): + logger.warn("Output on all parallel in time ranks will not be the same until end of run!") # Checks for parallel RIDC if self.comm is None: From 4df3480c7806a40bc6f9d9e6ae00f8f2e8d4dd91 Mon Sep 17 00:00:00 2001 From: atb1995 Date: Mon, 15 Sep 2025 10:21:57 +0100 Subject: [PATCH 41/43] Add flush frequency to testing --- integration-tests/model/test_parallel_dc.py | 22 ++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/integration-tests/model/test_parallel_dc.py b/integration-tests/model/test_parallel_dc.py index 3ace723aa..f1b3fe585 100644 --- a/integration-tests/model/test_parallel_dc.py +++ b/integration-tests/model/test_parallel_dc.py @@ -2,7 +2,10 @@ This runs a simple transport test on the sphere using the parallel DC time discretisations to test whether the errors are within tolerance. The test is run for the following schemes: - IMEX_SDC(2,2) - IMEX SDC with 2 qaudrature nodes of Radau type and 2 correction sweeps (2nd order scheme) -- IMEX_RIDC(2) - IMEX RIDC with 3 quadrature nodes of equidistant type, 1 correction sweep, reduced stencils (2nd order scheme). +- IMEX_RIDC(2,1) - IMEX RIDC with 3 quadrature nodes of equidistant type, 1 correction sweep, reduced stencils (2nd order scheme). + Has a pipeline flush frequency of 1 (every timestep). +- IMEX_RIDC(2,5) - IMEX RIDC with 3 quadrature nodes of equidistant type, 1 correction sweep, reduced stencils (2nd order scheme). + Has a pipeline flush frequency of 5 (every 5 timesteps). """ from firedrake import (norm, Ensemble, COMM_WORLD, SpatialCoordinate, @@ -21,14 +24,14 @@ def run(timestepper, tmax, f_end): @pytest.mark.parallel(nprocs=[2]) @pytest.mark.parametrize( - "scheme", ["IMEX_RIDC(2)", "IMEX_SDC(2,2)"]) + "scheme", ["IMEX_RIDC(2,1)", "IMEX_RIDC(2,5)", "IMEX_SDC(2,2)"]) def test_parallel_dc(tmpdir, scheme): if scheme == "IMEX_SDC(2,2)": M = 2 k = 2 ensemble = Ensemble(COMM_WORLD, COMM_WORLD.size//(M)) - elif scheme == "IMEX_RIDC(2)": + elif scheme == "IMEX_RIDC(2,1)" or scheme == "IMEX_RIDC(2,5)": k = 1 ensemble = Ensemble(COMM_WORLD, COMM_WORLD.size//(k+1)) @@ -76,7 +79,7 @@ def test_parallel_dc(tmpdir, scheme): base_scheme = IMEX_Euler(domain) time_scheme = Parallel_SDC(base_scheme, domain, M, k, quad_type, node_type, qdelta_imp, qdelta_exp, final_update=True, initial_guess="base", communicator=ensemble) - elif scheme == "IMEX_RIDC(2)": + elif scheme == "IMEX_RIDC(2,1)": eqn = split_continuity_form(eqn) eqn.label_terms(lambda t: not any(t.has_label(time_derivative, transport)), implicit) eqn.label_terms(lambda t: t.has_label(transport), explicit) @@ -84,7 +87,16 @@ def test_parallel_dc(tmpdir, scheme): M = 5 J = int(tmax/dt) base_scheme = IMEX_Euler(domain) - time_scheme = Parallel_RIDC(base_scheme, domain, M, k, J, output_freq=dumpfreq, communicator=ensemble) + time_scheme = Parallel_RIDC(base_scheme, domain, M, k, J, output_freq=dumpfreq, flush_freq=1, communicator=ensemble) + elif scheme == "IMEX_RIDC(2,5)": + eqn = split_continuity_form(eqn) + eqn.label_terms(lambda t: not any(t.has_label(time_derivative, transport)), implicit) + eqn.label_terms(lambda t: t.has_label(transport), explicit) + + M = 5 + J = int(tmax/dt) + base_scheme = IMEX_Euler(domain) + time_scheme = Parallel_RIDC(base_scheme, domain, M, k, J, output_freq=dumpfreq, flush_freq=5, communicator=ensemble) transport_method = DGUpwind(eqn, 'f') From 7eeeab89ad1e53b4fc7182d6526caa0741283336 Mon Sep 17 00:00:00 2001 From: atb1995 Date: Wed, 15 Oct 2025 17:01:03 +0100 Subject: [PATCH 42/43] Fix to common forms --- gusto/equations/common_forms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gusto/equations/common_forms.py b/gusto/equations/common_forms.py index 5f21fd521..360866d8f 100644 --- a/gusto/equations/common_forms.py +++ b/gusto/equations/common_forms.py @@ -446,7 +446,7 @@ def split_linear_advection_form(test, qbar, ubar, ubar_full): :class:`LabelledForm`: a labelled transport form. """ - L = test*dot(ubar, grad(qbar))*dx + L = inner(test, dot(ubar, grad(qbar)))*dx form = transporting_velocity(L, ubar_full) return transport(form, TransportEquationType.advective) From 630a61e4dd0a2b60d766f827dd4cb5434cd3e77e Mon Sep 17 00:00:00 2001 From: atb1995 Date: Fri, 14 Nov 2025 11:16:22 +0000 Subject: [PATCH 43/43] minor changes in logging and imex_rk --- gusto/time_discretisation/imex_runge_kutta.py | 38 ++++++++++++------- gusto/time_discretisation/parallel_dc.py | 6 +++ 2 files changed, 30 insertions(+), 14 deletions(-) diff --git a/gusto/time_discretisation/imex_runge_kutta.py b/gusto/time_discretisation/imex_runge_kutta.py index ba3911fd4..d6f500227 100644 --- a/gusto/time_discretisation/imex_runge_kutta.py +++ b/gusto/time_discretisation/imex_runge_kutta.py @@ -93,6 +93,10 @@ def __init__(self, domain, butcher_imp, butcher_exp, field_name=None, self.butcher_exp = butcher_exp self.nStages = int(np.shape(self.butcher_imp)[1]) + self.ksp_max = 0 + self.ksp_it = 0 + self.newton_it = 0 + # Set default linear and nonlinear solver options if none passed in if linear_solver_parameters is None: self.linear_solver_parameters = {'snes_type': 'ksponly', @@ -231,7 +235,7 @@ def final_res(self): def solvers(self): """Set up a list of solvers for each problem at a stage.""" solvers = [] - for stage in range(self.nStages): + for stage in range(1, self.nStages): # setup solver using residual defined in derived class problem = NonlinearVariationalProblem(self.res(stage), self.x_out, bcs=self.bcs) solver_name = self.field_name+self.__class__.__name__ + "%s" % (stage) @@ -253,19 +257,25 @@ def apply(self, x_out, x_in): solver_list = self.solvers for stage in range(self.nStages): - self.solver = solver_list[stage] - # Set initial solver guess - if (stage > 0): - self.x_out.assign(self.xs[stage-1]) - # Evaluate source terms - for evaluate in self.evaluate_source: - evaluate(self.xs[stage-1], self.dt, x_out=self.source[stage-1]) - self.solver.solve() - - # Apply limiter - if self.limiter is not None: - self.limiter.apply(self.x_out) - self.xs[stage].assign(self.x_out) + if stage == 0: + self.xs[stage].assign(x_in) + else: + self.solver = solver_list[stage-1] + # Set initial solver guess + if (stage > 0): + self.x_out.assign(self.xs[stage-1]) + # Evaluate source terms + for evaluate in self.evaluate_source: + evaluate(self.xs[stage-1], self.dt, x_out=self.source[stage-1]) + self.solver.solve() + self.newton_it += self.solver.snes.getIterationNumber() + self.ksp_it += self.solver.snes.getLinearSolveIterations() + self.ksp_max = max(self.ksp_max, self.solver.snes.getLinearSolveIterations()) + + # Apply limiter + if self.limiter is not None: + self.limiter.apply(self.x_out) + self.xs[stage].assign(self.x_out) # Solve final stage for evaluate in self.evaluate_source: diff --git a/gusto/time_discretisation/parallel_dc.py b/gusto/time_discretisation/parallel_dc.py index 7f9e9647b..991cc3a6e 100644 --- a/gusto/time_discretisation/parallel_dc.py +++ b/gusto/time_discretisation/parallel_dc.py @@ -69,6 +69,9 @@ def __init__(self, base_scheme, domain, M, K, J, output_freq, flush_freq=None, f self.J = J self.step = 1 self.output_freq = output_freq + self.newton_it = 0 + self.ksp_it = 0 + self.ksp_max = 0 if self.flush_freq == 0 or (self.flush_freq != 0 and self.output_freq % self.flush_freq != 0): logger.warn("Output on all parallel in time ranks will not be the same until end of run!") @@ -154,6 +157,9 @@ def apply(self, x_out, x_in): # + sum(j=1,M) s_mj*(F+S)(y_j^k) self.solver.solve() self.Unodes1[m+1].assign(self.U_DC) + self.newton_it += self.solver.snes.getIterationNumber() + self.ksp_it += self.solver.snes.getLinearSolveIterations() + self.ksp_max = max(self.ksp_max, self.solver.snes.getLinearSolveIterations()) # Evaluate source terms for evaluate in self.evaluate_source: