From 049befd6385897391574d927ed987fa5d7f4c3bd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20Fr=C3=B6hlich?= Date: Thu, 26 Jun 2025 17:29:07 +0100 Subject: [PATCH 1/5] Refine piecewise smoothness check --- python/sdist/amici/import_utils.py | 97 +++++++++++++++++++++++++++--- python/tests/test_heavisides.py | 74 +++++++++++++++++++++++ 2 files changed, 164 insertions(+), 7 deletions(-) diff --git a/python/sdist/amici/import_utils.py b/python/sdist/amici/import_utils.py index 7549eb6f1d..6e1e0ca599 100644 --- a/python/sdist/amici/import_utils.py +++ b/python/sdist/amici/import_utils.py @@ -15,6 +15,7 @@ import sympy as sp from sympy.functions.elementary.piecewise import ExprCondPair from sympy.logic.boolalg import BooleanAtom +from sympy.core.relational import Relational from toposort import toposort RESERVED_SYMBOLS = ["x", "k", "p", "y", "w", "h", "t", "AMICI_EMPTY_BOLUS"] @@ -347,7 +348,11 @@ def toposort_symbols( } -def _parse_special_functions(sym: sp.Expr, toplevel: bool = True) -> sp.Expr: +def _parse_special_functions( + sym: sp.Expr, + toplevel: bool = True, + parameters: Sequence[sp.Symbol] | None = None, +) -> sp.Expr: """ Recursively checks the symbolic expression for functions which have be to parsed in a special way, such as piecewise functions @@ -358,11 +363,14 @@ def _parse_special_functions(sym: sp.Expr, toplevel: bool = True) -> sp.Expr: :param toplevel: as this is called recursively, are we in the top level expression? """ + if parameters is None: + parameters = [] + args = tuple( arg if arg.__class__.__name__ == "piecewise" and sym.__class__.__name__ == "piecewise" - else _parse_special_functions(arg, False) + else _parse_special_functions(arg, False, parameters) for arg in sym.args ) @@ -404,8 +412,8 @@ def _parse_special_functions(sym: sp.Expr, toplevel: bool = True) -> sp.Expr: denested_args = args else: # this is sbml piecewise, can be nested - denested_args = _denest_piecewise(args) - return _parse_piecewise_to_heaviside(denested_args) + denested_args = _denest_piecewise(args, parameters) + return _parse_piecewise_to_heaviside(denested_args, parameters) if sym.__class__.__name__ == "plus" and not sym.args: return sp.Float(0.0) @@ -424,6 +432,7 @@ def _parse_special_functions(sym: sp.Expr, toplevel: bool = True) -> sp.Expr: def _denest_piecewise( args: Sequence[sp.Expr | sp.logic.boolalg.Boolean | bool], + parameters: Sequence[sp.Symbol] | None = None, ) -> tuple[sp.Expr | sp.logic.boolalg.Boolean | bool]: """ Denest piecewise functions that contain piecewise as condition @@ -435,13 +444,16 @@ def _denest_piecewise( Arguments where conditions no longer contain piecewise functions and the conditional dependency is flattened out """ + if parameters is None: + parameters = [] + args_out = [] for coeff, cond in grouper(args, 2, True): # handling of this case is explicitely disabled in # _parse_special_functions as keeping track of coeff/cond # arguments is tricky. Simpler to just parse them out here if coeff.__class__.__name__ == "piecewise": - coeff = _parse_special_functions(coeff, False) + coeff = _parse_special_functions(coeff, False, parameters) # we can have conditions that are piecewise function # returning True or False @@ -451,7 +463,7 @@ def _denest_piecewise( previous_was_picked = sp.false # recursively denest those first for sub_coeff, sub_cond in grouper( - _denest_piecewise(cond.args), 2, True + _denest_piecewise(cond.args, parameters), 2, True ): # flatten the individual pieces pick_this = sp.And(sp.Not(previous_was_picked), sub_cond) @@ -465,7 +477,70 @@ def _denest_piecewise( return tuple(args_out[:-1]) -def _parse_piecewise_to_heaviside(args: Iterable[sp.Expr]) -> sp.Expr: +def _is_c1_piecewise( + pw: sp.Piecewise, parameters: Sequence[sp.Symbol] +) -> bool: + """Return ``True`` if ``pw`` is continuously differentiable with respect to + ``parameters``. + + This check ensures that piecewise expressions which are already smooth are + not transformed into events. + """ + + from sympy.calculus.util import continuous_domain + + pieces = pw.args + + # collect boundaries appearing in the conditions + boundaries: list[tuple[sp.Symbol, sp.Expr]] = [] + for _, cond in pieces: + if cond in (True, False, sp.true, sp.false): + continue + if not isinstance(cond, Relational): + return False + if isinstance(cond.lhs, sp.Symbol) and not cond.rhs.has(cond.lhs): + boundaries.append((cond.lhs, cond.rhs)) + elif isinstance(cond.rhs, sp.Symbol) and not cond.lhs.has(cond.rhs): + boundaries.append((cond.rhs, cond.lhs)) + else: + return False + + # check that each piece and its derivatives are continuous on R + for expr, _ in pieces: + for var in parameters: + try: + if continuous_domain(expr, var, sp.S.Reals) != sp.S.Reals: + return False + except NotImplementedError: + return False + try: + dexpr = sp.diff(expr, var) + if continuous_domain(dexpr, var, sp.S.Reals) != sp.S.Reals: + return False + except NotImplementedError: + return False + + # check continuity and derivative continuity at boundaries + for (sym, boundary), (expr_left, _), (expr_right, _) in zip( + boundaries, pieces[:-1], pieces[1:] + ): + if not sp.simplify( + expr_left.subs(sym, boundary) - expr_right.subs(sym, boundary) + ).is_zero: + return False + for var in parameters: + if not sp.simplify( + sp.diff(expr_left, var).subs(sym, boundary) + - sp.diff(expr_right, var).subs(sym, boundary) + ).is_zero: + return False + + return True + + +def _parse_piecewise_to_heaviside( + args: Iterable[sp.Expr], parameters: Sequence[sp.Symbol] | None = None +) -> sp.Expr: """ Piecewise functions cannot be transformed into C++ right away, but AMICI has a special interface for Heaviside functions, so we transform them. @@ -484,6 +559,14 @@ def _parse_piecewise_to_heaviside(args: Iterable[sp.Expr]) -> sp.Expr: # smbl piecewise grouped_args = grouper(args, 2, True) + pw = sp.Piecewise(*grouped_args) + + if parameters is None: + parameters = [] + + if _is_c1_piecewise(pw, parameters): + return pw + for coeff, trigger in grouped_args: if isinstance(coeff, BooleanAtom): coeff = sp.Integer(int(bool(coeff))) diff --git a/python/tests/test_heavisides.py b/python/tests/test_heavisides.py index f0f26e470b..0c18e241b7 100644 --- a/python/tests/test_heavisides.py +++ b/python/tests/test_heavisides.py @@ -356,3 +356,77 @@ def sx_expected(t, x_1_0): x_expected, sx_expected, ) + + +def test_parse_piecewise_c1_no_heaviside(): + """_parse_piecewise_to_heaviside should keep C1 piecewise expressions.""" + + import sympy as sp + from amici.import_utils import ( + _parse_piecewise_to_heaviside, + amici_time_symbol, + symbol_with_assumptions, + ) + + t = amici_time_symbol + x = symbol_with_assumptions("x_1") + pw = sp.Piecewise((t * x, t < 1), (x + (t - 1) * x, True)) + + res = _parse_piecewise_to_heaviside(pw.args, []) + assert isinstance(res, sp.Piecewise) + assert sp.simplify(res - pw) == 0 + + p = symbol_with_assumptions("p1") + pw_param = sp.Piecewise( + (p**2, p < 1), + ((p - 1) ** 2 + 2 * p - 1, True), + ) + + res_param = _parse_piecewise_to_heaviside(pw_param.args, [p]) + assert isinstance(res_param, sp.Piecewise) + assert sp.simplify(res_param - pw_param) == 0 + + +def test_parse_piecewise_discontinuous_to_heaviside(): + """_parse_piecewise_to_heaviside should convert discontinuous piecewise.""" + + import sympy as sp + from amici.import_utils import ( + _parse_piecewise_to_heaviside, + amici_time_symbol, + symbol_with_assumptions, + ) + + t = amici_time_symbol + x = symbol_with_assumptions("x_1") + + pw_state = sp.Piecewise((t * x, x < 1), (2 * t * x, True)) + res_state = _parse_piecewise_to_heaviside(pw_state.args, []) + assert not isinstance(res_state, sp.Piecewise) + expected_state = t * x * ( + 1 - sp.Heaviside(x - 1, 1) + ) + 2 * t * x * sp.Heaviside(x - 1, 1) + assert sp.simplify(res_state - expected_state) == 0 + + p = symbol_with_assumptions("p1") + pw_param = sp.Piecewise((0, p < 1), (1, True)) + res_param = _parse_piecewise_to_heaviside(pw_param.args, [p]) + assert not isinstance(res_param, sp.Piecewise) + expected_param = sp.Heaviside(p - 1, 1) + assert sp.simplify(res_param - expected_param) == 0 + + +def test_parse_piecewise_c1_constant_zero(): + """Piecewise expressions evaluating to zero should simplify to zero.""" + + import sympy as sp + from amici.import_utils import ( + _parse_piecewise_to_heaviside, + symbol_with_assumptions, + ) + + p = symbol_with_assumptions("p1") + pw_zero = sp.Piecewise((p - p, p < 1), (0, True), evaluate=False) + + res_zero = _parse_piecewise_to_heaviside(pw_zero.args, [p]) + assert sp.simplify(res_zero) == 0 From 7c4b74ac842fcd85407282b7b00e49ccf76dea1e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20Fr=C3=B6hlich?= Date: Sun, 29 Jun 2025 13:14:52 +0100 Subject: [PATCH 2/5] pass parameters to piecewise filter --- python/sdist/amici/import_utils.py | 8 +++----- python/sdist/amici/pysb_import.py | 9 +++++++-- python/sdist/amici/sbml_import.py | 8 +++++--- 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/python/sdist/amici/import_utils.py b/python/sdist/amici/import_utils.py index 6e1e0ca599..5a00c7199c 100644 --- a/python/sdist/amici/import_utils.py +++ b/python/sdist/amici/import_utils.py @@ -350,8 +350,8 @@ def toposort_symbols( def _parse_special_functions( sym: sp.Expr, + parameters: Iterable[sp.Symbol], toplevel: bool = True, - parameters: Sequence[sp.Symbol] | None = None, ) -> sp.Expr: """ Recursively checks the symbolic expression for functions which have be @@ -363,14 +363,12 @@ def _parse_special_functions( :param toplevel: as this is called recursively, are we in the top level expression? """ - if parameters is None: - parameters = [] args = tuple( arg if arg.__class__.__name__ == "piecewise" and sym.__class__.__name__ == "piecewise" - else _parse_special_functions(arg, False, parameters) + else _parse_special_functions(arg, parameters, False) for arg in sym.args ) @@ -453,7 +451,7 @@ def _denest_piecewise( # _parse_special_functions as keeping track of coeff/cond # arguments is tricky. Simpler to just parse them out here if coeff.__class__.__name__ == "piecewise": - coeff = _parse_special_functions(coeff, False, parameters) + coeff = _parse_special_functions(coeff, parameters, False) # we can have conditions that are piecewise function # returning True or False diff --git a/python/sdist/amici/pysb_import.py b/python/sdist/amici/pysb_import.py index fb54dd745c..11ed8bf6e6 100644 --- a/python/sdist/amici/pysb_import.py +++ b/python/sdist/amici/pysb_import.py @@ -694,7 +694,9 @@ def _add_expression( else: component = Expression ode_model.add_component( - component(sym, name, _parse_special_functions(expr)) + component( + sym, name, _parse_special_functions(expr, ode_model.sym("p")) + ) ) if name in observables: @@ -711,7 +713,10 @@ def _add_expression( # changes, I would expect symbol redefinition warnings in CPP models and overwriting in JAX models, but as both # symbols refer to the same symbolic entity, this should not be a problem (untested) obs = Observable( - y, name, _parse_special_functions(expr), transformation=trafo + y, + name, + _parse_special_functions(expr, ode_model.sym("p")), + transformation=trafo, ) ode_model.add_component(obs) diff --git a/python/sdist/amici/sbml_import.py b/python/sdist/amici/sbml_import.py index 131db1a041..8f8ecf4c1c 100644 --- a/python/sdist/amici/sbml_import.py +++ b/python/sdist/amici/sbml_import.py @@ -822,11 +822,11 @@ def _process_sbml( if not self._discard_annotations: self._process_annotations() self.check_support() - self._gather_locals(hardcode_symbols=hardcode_symbols) self._process_parameters( constant_parameters=constant_parameters, hardcode_symbols=hardcode_symbols, - ) + ) # needs to be processed first such that we can use parameters to filter Piecewise functions + self._gather_locals(hardcode_symbols=hardcode_symbols) self._process_compartments() self._process_species() self._process_reactions() @@ -2943,7 +2943,9 @@ def subs_locals(expr: sp.Basic) -> sp.Basic: try: expr = expr.replace( sp.Piecewise, - lambda *args: _parse_piecewise_to_heaviside(args), + lambda *args: _parse_piecewise_to_heaviside( + args, list(self.symbols[SymbolId.PARAMETER].keys()) + ), ) except RuntimeError as err: raise SBMLException(str(err)) from err From edac39dbfb612470b2301c756bfc51bf6ad1f5da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20Fr=C3=B6hlich?= Date: Sun, 29 Jun 2025 13:58:51 +0100 Subject: [PATCH 3/5] fix splines --- python/sdist/amici/sbml_utils.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/python/sdist/amici/sbml_utils.py b/python/sdist/amici/sbml_utils.py index 068c9ffb6b..cc164a78ed 100644 --- a/python/sdist/amici/sbml_utils.py +++ b/python/sdist/amici/sbml_utils.py @@ -510,7 +510,13 @@ def mathml2sympy( with sp.core.parameters.evaluate(False): expr = sp.sympify(formula, locals=locals) - expr = _parse_special_functions(expr) + # no easy way of accessing proper parameters here, so passing empty list + error + # if you really want to do this, replace this function by sbmlmath + if "Piecewise" in str(expr): + raise SBMLException( + "Piecewise expressions are not supported as part of spline functions" + ) + expr = _parse_special_functions(expr, []) if expression_type is not None: _check_unsupported_functions(expr, expression_type) From 342d8c34314de324cf457336a289e200f8060907 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20Fr=C3=B6hlich?= Date: Mon, 30 Jun 2025 12:21:35 +0100 Subject: [PATCH 4/5] Update testSBMLSuite.py --- tests/sbml/testSBMLSuite.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/tests/sbml/testSBMLSuite.py b/tests/sbml/testSBMLSuite.py index f6d293c9f6..5f15df31dc 100755 --- a/tests/sbml/testSBMLSuite.py +++ b/tests/sbml/testSBMLSuite.py @@ -13,6 +13,8 @@ import shutil from pathlib import Path +import optimistix + import amici import pandas as pd import pytest @@ -21,7 +23,10 @@ import jax.numpy as jnp import numpy as np import diffrax -from amici.jax.petab import DEFAULT_CONTROLLER_SETTINGS +from amici.jax.petab import ( + DEFAULT_CONTROLLER_SETTINGS, + DEFAULT_ROOT_FINDER_SETTINGS, +) from utils import ( verify_results, @@ -72,11 +77,6 @@ def test_sbml_testsuite_case(test_id, result_path, sbml_semantic_cases_dir): atol, rtol = apply_settings(settings, solver, model, test_id) - if test_id in sensitivity_check_cases: - model.requireSensitivitiesForAllParameters() - solver.setSensitivityOrder(amici.SensitivityOrder.first) - solver.setSensitivityMethod(amici.SensitivityMethod.forward) - # simulate model rdata = amici.runAmiciSimulation(model, solver) if rdata["status"] != amici.AMICI_SUCCESS: @@ -208,6 +208,7 @@ def jax_sensitivity_check( icoeff=DEFAULT_CONTROLLER_SETTINGS["icoeff"], dcoeff=DEFAULT_CONTROLLER_SETTINGS["dcoeff"], ) + root_finder = optimistix.Newton(**DEFAULT_ROOT_FINDER_SETTINGS) def simulate(pars): x, _ = jax_model.simulate_condition( @@ -221,6 +222,7 @@ def simulate(pars): jnp.zeros((ts_jnp.shape[0], 0)), solver, controller, + root_finder, diffrax.DirectAdjoint(), diffrax.SteadyStateEvent(), 2**10, From f85577b4be67b0d4f5dfb637a995c578db402e6b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20Fr=C3=B6hlich?= Date: Wed, 2 Jul 2025 15:58:45 +0100 Subject: [PATCH 5/5] split _process_parameters --- python/sdist/amici/sbml_import.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/python/sdist/amici/sbml_import.py b/python/sdist/amici/sbml_import.py index 36739c42e4..efeee6281a 100644 --- a/python/sdist/amici/sbml_import.py +++ b/python/sdist/amici/sbml_import.py @@ -827,6 +827,7 @@ def _process_sbml( hardcode_symbols=hardcode_symbols, ) # needs to be processed first such that we can use parameters to filter Piecewise functions self._gather_locals(hardcode_symbols=hardcode_symbols) + self._convert_parameters() self._process_compartments() self._process_species() self._process_reactions() @@ -1383,6 +1384,16 @@ def _process_parameters( ), } + @log_execution_time("converting SBML parameters", logger) + def _convert_parameters(self): + # parameter ID => initial assignment sympy expression + par_id_to_ia = { + par.getId(): _try_evalf(ia) + for par in self.sbml.getListOfParameters() + if (ia := self._get_element_initial_assignment(par.getId())) + is not None + } + # Parameters that need to be turned into expressions or species # so far, this concerns parameters with symbolic initial assignments # (those have been skipped above) that are not rate rule targets