Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions python/sdist/amici/cxxcodeprinter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import sympy as sp
from sympy.codegen.rewriting import Optimization, optimize
from sympy.printing.cxx import CXX11CodePrinter

from .import_utils import RESERVED_SYMBOLS
from sympy.utilities.iterables import numbered_symbols
from toposort import toposort

Expand Down Expand Up @@ -50,6 +52,12 @@
else:
self._fpoptimizer = None

def _print_Symbol(self, expr: sp.Symbol) -> str:
name = super()._print_Symbol(expr)
if name in RESERVED_SYMBOLS and name != "t":
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens if a model has a parameter t that is not time?

return f"amici_{name}"

Check warning on line 58 in python/sdist/amici/cxxcodeprinter.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/cxxcodeprinter.py#L58

Added line #L58 was not covered by tests
return name

def doprint(self, expr: sp.Expr, assign_to: str | None = None) -> str:
if self._fpoptimizer:
if isinstance(expr, list):
Expand Down
10 changes: 8 additions & 2 deletions python/sdist/amici/de_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
from .de_model_components import *
from .de_model import DEModel
from .import_utils import (
RESERVED_SYMBOLS,
strip_pysb,
)
from .logging import get_logger, log_execution_time, set_log_level
Expand Down Expand Up @@ -404,7 +405,12 @@ def _write_index_files(self, name: str) -> None:
continue
if str(symbol_name) == "":
raise ValueError(f'{name} contains a symbol called ""')
lines.append(f"#define {symbol_name} {name}[{index}]")
sanitized_name = (
f"amici_{symbol_name}"
if str(symbol_name) in RESERVED_SYMBOLS
else str(symbol_name)
)
lines.append(f"#define {sanitized_name} {name}[{index}]")
if name == "stau":
# we only need a single macro, as all entries have the same symbol
break
Expand Down Expand Up @@ -1243,7 +1249,7 @@ def _get_symbol_id_initializer_list(self, name: str) -> str:
Template initializer list of ids
"""
return "\n".join(
f'"{self._code_printer.doprint(symbol)}", // {name}[{idx}]'
f'"{strip_pysb(symbol)}", // {name}[{idx}]'
for idx, symbol in enumerate(self.model.sym(name))
)

Expand Down
7 changes: 0 additions & 7 deletions python/sdist/amici/de_model_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import sympy as sp

from .import_utils import (
RESERVED_SYMBOLS,
ObservableTransformation,
amici_time_symbol,
cast_to_sym,
Expand Down Expand Up @@ -66,12 +65,6 @@ def __init__(
f"identifier must be sympy.Symbol, was {type(identifier)}"
)

if str(identifier) in RESERVED_SYMBOLS or (
hasattr(identifier, "name") and identifier.name in RESERVED_SYMBOLS
):
raise ValueError(
f'Cannot add model quantity with name "{name}", please rename.'
)
self._identifier: sp.Symbol = identifier

if not isinstance(name, str):
Expand Down
8 changes: 8 additions & 0 deletions python/sdist/amici/jax/jaxcodeprinter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from collections.abc import Iterable
from logging import warning

from ..import_utils import RESERVED_SYMBOLS

import sympy as sp
from sympy.printing.numpy import NumPyPrinter

Expand All @@ -17,6 +19,12 @@ def _jnp_array_str(array) -> str:
class AmiciJaxCodePrinter(NumPyPrinter):
"""JAX code printer"""

def _print_Symbol(self, expr: sp.Symbol) -> str:
name = super()._print_Symbol(expr)
if name in RESERVED_SYMBOLS and name != "t":
return f"amici_{name}"
return name

def doprint(self, expr: sp.Expr, assign_to: str | None = None) -> str:
try:
code = super().doprint(expr, assign_to)
Expand Down
16 changes: 12 additions & 4 deletions python/sdist/amici/jax/ode_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

from amici.de_export import is_valid_identifier
from amici.import_utils import (
RESERVED_SYMBOLS,
strip_pysb,
)
from amici.logging import get_logger, log_execution_time, set_log_level
Expand All @@ -40,13 +41,20 @@
logger = get_logger(__name__, logging.ERROR)


def _sanitize(name: str) -> str:
return (
f"amici_{name}" if name in RESERVED_SYMBOLS and name != "t" else name
)


def _jax_variable_assignments(
model: DEModel, sym_names: tuple[str, ...]
) -> dict:
return {
f"{sym_name.upper()}_SYMS": "".join(
str(strip_pysb(s)) + ", " for s in model.sym(sym_name)
f"{sym_name.upper()}_SYMS": ", ".join(
_sanitize(str(strip_pysb(s))) for s in model.sym(sym_name)
)
+ ", "
if model.sym(sym_name)
else "_"
for sym_name in sym_names
Expand All @@ -63,7 +71,7 @@ def _jax_variable_equations(
return {
f"{eq_name.upper()}_EQ": "\n".join(
code_printer._get_sym_lines(
(str(strip_pysb(s)) for s in model.sym(eq_name)),
(_sanitize(str(strip_pysb(s))) for s in model.sym(eq_name)),
model.eq(eq_name).subs(subs),
indent,
)
Expand All @@ -78,7 +86,7 @@ def _jax_return_variables(
) -> dict:
return {
f"{eq_name.upper()}_RET": _jnp_array_str(
strip_pysb(s) for s in model.sym(eq_name)
_sanitize(str(strip_pysb(s))) for s in model.sym(eq_name)
)
if model.sym(eq_name)
else "jnp.array([])"
Expand Down
20 changes: 0 additions & 20 deletions python/sdist/amici/sbml_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
from .de_model_components import symbol_to_type, Expression
from .sympy_utils import smart_is_zero_matrix, smart_multiply
from .import_utils import (
RESERVED_SYMBOLS,
_check_unsupported_functions,
_get_str_symbol_identifiers,
amici_time_symbol,
Expand Down Expand Up @@ -691,7 +690,6 @@ def _build_ode_model(
)
self._replace_compartments_with_volumes()

self._clean_reserved_symbols()
self._process_time()

ode_model = DEModel(
Expand Down Expand Up @@ -2783,24 +2781,6 @@ def _replace_in_all_expressions(
for spline in self.splines:
spline._replace_in_all_expressions(old, new)

def _clean_reserved_symbols(self) -> None:
"""
Remove all reserved symbols from self.symbols
"""
for sym in RESERVED_SYMBOLS:
old_symbol = symbol_with_assumptions(sym)
new_symbol = symbol_with_assumptions(f"amici_{sym}")
self._replace_in_all_expressions(
old_symbol, new_symbol, replace_identifiers=True
)
for symbols_ids, symbols in self.symbols.items():
if old_symbol in symbols:
# reconstitute the whole dict in order to keep the ordering
self.symbols[symbols_ids] = {
new_symbol if k is old_symbol else k: v
for k, v in symbols.items()
}

def _sympify(
self,
var_or_math: libsbml.SBase
Expand Down
7 changes: 1 addition & 6 deletions python/tests/test_bngl.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,7 @@ def test_compare_to_pysb_simulation(example):
with pytest.raises(ValueError, match="Conservation laws"):
bngl2amici(model_file, outdir, compute_conservation_laws=True)

if example in ["empty_compartments_block", "motor"]:
with pytest.raises(ValueError, match="Cannot add"):
bngl2amici(model_file, outdir, **kwargs)
return
else:
bngl2amici(model_file, outdir, **kwargs)
bngl2amici(model_file, outdir, **kwargs)

amici_model_module = amici.import_model_module(pysb_model.name, outdir)

Expand Down
14 changes: 4 additions & 10 deletions tests/sbml/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def verify_results(settings, rdata, expected, wrapper, model, atol, rtol):
new_key = expr_id.removeprefix("flux_")
else:
new_key = expr_id
if expr_id.removeprefix("amici_") in simulated.columns:
if expr_id in simulated.columns:
continue # skip if already present
expression_data[new_key] = rdata.w[:, expr_idx]

Expand All @@ -48,12 +48,6 @@ def verify_results(settings, rdata, expected, wrapper, model, atol, rtol):
axis=1,
)

# handle renamed reserved symbols
simulated.rename(
columns={c: c.replace("amici_", "") for c in simulated.columns},
inplace=True,
)

# SBML test suite case 01308 defines species with initialAmount and
# hasOnlySubstanceUnits="true", but then request results as concentrations.
requested_concentrations = [
Expand Down Expand Up @@ -150,9 +144,9 @@ def concentrations_to_amounts(
) or comp is None:
continue

simulated.loc[:, species] *= simulated.loc[
:, comp if comp in simulated.columns else f"amici_{comp}"
]
if comp not in simulated.columns:
continue
simulated.loc[:, species] *= simulated.loc[:, comp]


def write_result_file(
Expand Down
Loading