Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 88 additions & 7 deletions python/sdist/amici/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import sympy as sp
from sympy.functions.elementary.piecewise import ExprCondPair
from sympy.logic.boolalg import BooleanAtom
from sympy.core.relational import Relational
from toposort import toposort

RESERVED_SYMBOLS = ["x", "k", "p", "y", "w", "h", "t", "AMICI_EMPTY_BOLUS"]
Expand Down Expand Up @@ -347,7 +348,11 @@ def toposort_symbols(
}


def _parse_special_functions(sym: sp.Expr, toplevel: bool = True) -> sp.Expr:
def _parse_special_functions(
sym: sp.Expr,
parameters: Iterable[sp.Symbol],
toplevel: bool = True,
) -> sp.Expr:
"""
Recursively checks the symbolic expression for functions which have be
to parsed in a special way, such as piecewise functions
Expand All @@ -358,11 +363,12 @@ def _parse_special_functions(sym: sp.Expr, toplevel: bool = True) -> sp.Expr:
:param toplevel:
as this is called recursively, are we in the top level expression?
"""

args = tuple(
arg
if arg.__class__.__name__ == "piecewise"
and sym.__class__.__name__ == "piecewise"
else _parse_special_functions(arg, False)
else _parse_special_functions(arg, parameters, False)
for arg in sym.args
)

Expand Down Expand Up @@ -404,8 +410,8 @@ def _parse_special_functions(sym: sp.Expr, toplevel: bool = True) -> sp.Expr:
denested_args = args
else:
# this is sbml piecewise, can be nested
denested_args = _denest_piecewise(args)
return _parse_piecewise_to_heaviside(denested_args)
denested_args = _denest_piecewise(args, parameters)
return _parse_piecewise_to_heaviside(denested_args, parameters)

if sym.__class__.__name__ == "plus" and not sym.args:
return sp.Float(0.0)
Expand All @@ -424,6 +430,7 @@ def _parse_special_functions(sym: sp.Expr, toplevel: bool = True) -> sp.Expr:

def _denest_piecewise(
args: Sequence[sp.Expr | sp.logic.boolalg.Boolean | bool],
parameters: Sequence[sp.Symbol] | None = None,
) -> tuple[sp.Expr | sp.logic.boolalg.Boolean | bool]:
"""
Denest piecewise functions that contain piecewise as condition
Expand All @@ -435,13 +442,16 @@ def _denest_piecewise(
Arguments where conditions no longer contain piecewise functions and
the conditional dependency is flattened out
"""
if parameters is None:
parameters = []

args_out = []
for coeff, cond in grouper(args, 2, True):
# handling of this case is explicitely disabled in
# _parse_special_functions as keeping track of coeff/cond
# arguments is tricky. Simpler to just parse them out here
if coeff.__class__.__name__ == "piecewise":
coeff = _parse_special_functions(coeff, False)
coeff = _parse_special_functions(coeff, parameters, False)

# we can have conditions that are piecewise function
# returning True or False
Expand All @@ -451,7 +461,7 @@ def _denest_piecewise(
previous_was_picked = sp.false
# recursively denest those first
for sub_coeff, sub_cond in grouper(
_denest_piecewise(cond.args), 2, True
_denest_piecewise(cond.args, parameters), 2, True
):
# flatten the individual pieces
pick_this = sp.And(sp.Not(previous_was_picked), sub_cond)
Expand All @@ -465,7 +475,70 @@ def _denest_piecewise(
return tuple(args_out[:-1])


def _parse_piecewise_to_heaviside(args: Iterable[sp.Expr]) -> sp.Expr:
def _is_c1_piecewise(
pw: sp.Piecewise, parameters: Sequence[sp.Symbol]
) -> bool:
"""Return ``True`` if ``pw`` is continuously differentiable with respect to
``parameters``.

This check ensures that piecewise expressions which are already smooth are
not transformed into events.
"""

from sympy.calculus.util import continuous_domain

pieces = pw.args

# collect boundaries appearing in the conditions
boundaries: list[tuple[sp.Symbol, sp.Expr]] = []
for _, cond in pieces:
if cond in (True, False, sp.true, sp.false):
continue
if not isinstance(cond, Relational):
return False
if isinstance(cond.lhs, sp.Symbol) and not cond.rhs.has(cond.lhs):
boundaries.append((cond.lhs, cond.rhs))
elif isinstance(cond.rhs, sp.Symbol) and not cond.lhs.has(cond.rhs):
boundaries.append((cond.rhs, cond.lhs))
else:
return False

# check that each piece and its derivatives are continuous on R
for expr, _ in pieces:
for var in parameters:
try:
if continuous_domain(expr, var, sp.S.Reals) != sp.S.Reals:
return False
except NotImplementedError:
return False
try:
dexpr = sp.diff(expr, var)
if continuous_domain(dexpr, var, sp.S.Reals) != sp.S.Reals:
return False
except NotImplementedError:
return False

# check continuity and derivative continuity at boundaries
for (sym, boundary), (expr_left, _), (expr_right, _) in zip(
boundaries, pieces[:-1], pieces[1:]
):
if not sp.simplify(
expr_left.subs(sym, boundary) - expr_right.subs(sym, boundary)
).is_zero:
return False
for var in parameters:
if not sp.simplify(
sp.diff(expr_left, var).subs(sym, boundary)
- sp.diff(expr_right, var).subs(sym, boundary)
).is_zero:
return False

return True


def _parse_piecewise_to_heaviside(
args: Iterable[sp.Expr], parameters: Sequence[sp.Symbol] | None = None
) -> sp.Expr:
"""
Piecewise functions cannot be transformed into C++ right away, but AMICI
has a special interface for Heaviside functions, so we transform them.
Expand All @@ -484,6 +557,14 @@ def _parse_piecewise_to_heaviside(args: Iterable[sp.Expr]) -> sp.Expr:
# smbl piecewise
grouped_args = grouper(args, 2, True)

pw = sp.Piecewise(*grouped_args)

if parameters is None:
parameters = []

if _is_c1_piecewise(pw, parameters):
return pw

for coeff, trigger in grouped_args:
if isinstance(coeff, BooleanAtom):
coeff = sp.Integer(int(bool(coeff)))
Expand Down
9 changes: 7 additions & 2 deletions python/sdist/amici/pysb_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,7 +694,9 @@ def _add_expression(
else:
component = Expression
ode_model.add_component(
component(sym, name, _parse_special_functions(expr))
component(
sym, name, _parse_special_functions(expr, ode_model.sym("p"))
)
)

if name in observables:
Expand All @@ -711,7 +713,10 @@ def _add_expression(
# changes, I would expect symbol redefinition warnings in CPP models and overwriting in JAX models, but as both
# symbols refer to the same symbolic entity, this should not be a problem (untested)
obs = Observable(
y, name, _parse_special_functions(expr), transformation=trafo
y,
name,
_parse_special_functions(expr, ode_model.sym("p")),
transformation=trafo,
)
ode_model.add_component(obs)

Expand Down
19 changes: 16 additions & 3 deletions python/sdist/amici/sbml_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -822,11 +822,12 @@ def _process_sbml(
if not self._discard_annotations:
self._process_annotations()
self.check_support()
self._gather_locals(hardcode_symbols=hardcode_symbols)
self._process_parameters(
constant_parameters=constant_parameters,
hardcode_symbols=hardcode_symbols,
)
) # needs to be processed first such that we can use parameters to filter Piecewise functions
self._gather_locals(hardcode_symbols=hardcode_symbols)
self._convert_parameters()
self._process_compartments()
self._process_species()
self._process_reactions()
Expand Down Expand Up @@ -1383,6 +1384,16 @@ def _process_parameters(
),
}

@log_execution_time("converting SBML parameters", logger)
def _convert_parameters(self):
# parameter ID => initial assignment sympy expression
par_id_to_ia = {
par.getId(): _try_evalf(ia)
for par in self.sbml.getListOfParameters()
if (ia := self._get_element_initial_assignment(par.getId()))
is not None
}

# Parameters that need to be turned into expressions or species
# so far, this concerns parameters with symbolic initial assignments
# (those have been skipped above) that are not rate rule targets
Expand Down Expand Up @@ -2940,7 +2951,9 @@ def subs_locals(expr: sp.Basic) -> sp.Basic:
try:
expr = expr.replace(
sp.Piecewise,
lambda *args: _parse_piecewise_to_heaviside(args),
lambda *args: _parse_piecewise_to_heaviside(
args, list(self.symbols[SymbolId.PARAMETER].keys())
),
)
except RuntimeError as err:
raise SBMLException(str(err)) from err
Expand Down
8 changes: 7 additions & 1 deletion python/sdist/amici/sbml_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,7 +510,13 @@ def mathml2sympy(
with sp.core.parameters.evaluate(False):
expr = sp.sympify(formula, locals=locals)

expr = _parse_special_functions(expr)
# no easy way of accessing proper parameters here, so passing empty list + error
# if you really want to do this, replace this function by sbmlmath
if "Piecewise" in str(expr):
raise SBMLException(
"Piecewise expressions are not supported as part of spline functions"
)
expr = _parse_special_functions(expr, [])

if expression_type is not None:
_check_unsupported_functions(expr, expression_type)
Expand Down
74 changes: 74 additions & 0 deletions python/tests/test_heavisides.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,3 +355,77 @@ def sx_expected(t, x_1_0):
x_expected,
sx_expected,
)


def test_parse_piecewise_c1_no_heaviside():
"""_parse_piecewise_to_heaviside should keep C1 piecewise expressions."""

import sympy as sp
from amici.import_utils import (
_parse_piecewise_to_heaviside,
amici_time_symbol,
symbol_with_assumptions,
)

t = amici_time_symbol
x = symbol_with_assumptions("x_1")
pw = sp.Piecewise((t * x, t < 1), (x + (t - 1) * x, True))

res = _parse_piecewise_to_heaviside(pw.args, [])
assert isinstance(res, sp.Piecewise)
assert sp.simplify(res - pw) == 0

p = symbol_with_assumptions("p1")
pw_param = sp.Piecewise(
(p**2, p < 1),
((p - 1) ** 2 + 2 * p - 1, True),
)

res_param = _parse_piecewise_to_heaviside(pw_param.args, [p])
assert isinstance(res_param, sp.Piecewise)
assert sp.simplify(res_param - pw_param) == 0


def test_parse_piecewise_discontinuous_to_heaviside():
"""_parse_piecewise_to_heaviside should convert discontinuous piecewise."""

import sympy as sp
from amici.import_utils import (
_parse_piecewise_to_heaviside,
amici_time_symbol,
symbol_with_assumptions,
)

t = amici_time_symbol
x = symbol_with_assumptions("x_1")

pw_state = sp.Piecewise((t * x, x < 1), (2 * t * x, True))
res_state = _parse_piecewise_to_heaviside(pw_state.args, [])
assert not isinstance(res_state, sp.Piecewise)
expected_state = t * x * (
1 - sp.Heaviside(x - 1, 1)
) + 2 * t * x * sp.Heaviside(x - 1, 1)
assert sp.simplify(res_state - expected_state) == 0

p = symbol_with_assumptions("p1")
pw_param = sp.Piecewise((0, p < 1), (1, True))
res_param = _parse_piecewise_to_heaviside(pw_param.args, [p])
assert not isinstance(res_param, sp.Piecewise)
expected_param = sp.Heaviside(p - 1, 1)
assert sp.simplify(res_param - expected_param) == 0


def test_parse_piecewise_c1_constant_zero():
"""Piecewise expressions evaluating to zero should simplify to zero."""

import sympy as sp
from amici.import_utils import (
_parse_piecewise_to_heaviside,
symbol_with_assumptions,
)

p = symbol_with_assumptions("p1")
pw_zero = sp.Piecewise((p - p, p < 1), (0, True), evaluate=False)

res_zero = _parse_piecewise_to_heaviside(pw_zero.args, [p])
assert sp.simplify(res_zero) == 0
Loading