diff --git a/python/sdist/amici/cxxcodeprinter.py b/python/sdist/amici/cxxcodeprinter.py index b5b4ed809c..47290e69dc 100644 --- a/python/sdist/amici/cxxcodeprinter.py +++ b/python/sdist/amici/cxxcodeprinter.py @@ -10,6 +10,8 @@ import sympy as sp from sympy.codegen.rewriting import Optimization, optimize from sympy.printing.cxx import CXX11CodePrinter + +from .import_utils import RESERVED_SYMBOLS from sympy.utilities.iterables import numbered_symbols from toposort import toposort @@ -50,6 +52,12 @@ def __init__(self): else: self._fpoptimizer = None + def _print_Symbol(self, expr: sp.Symbol) -> str: + name = super()._print_Symbol(expr) + if name in RESERVED_SYMBOLS and name != "t": + return f"amici_{name}" + return name + def doprint(self, expr: sp.Expr, assign_to: str | None = None) -> str: if self._fpoptimizer: if isinstance(expr, list): diff --git a/python/sdist/amici/de_export.py b/python/sdist/amici/de_export.py index bf94a8e983..96bf63d7e7 100644 --- a/python/sdist/amici/de_export.py +++ b/python/sdist/amici/de_export.py @@ -58,6 +58,7 @@ from .de_model_components import * from .de_model import DEModel from .import_utils import ( + RESERVED_SYMBOLS, strip_pysb, ) from .logging import get_logger, log_execution_time, set_log_level @@ -404,7 +405,12 @@ def _write_index_files(self, name: str) -> None: continue if str(symbol_name) == "": raise ValueError(f'{name} contains a symbol called ""') - lines.append(f"#define {symbol_name} {name}[{index}]") + sanitized_name = ( + f"amici_{symbol_name}" + if str(symbol_name) in RESERVED_SYMBOLS + else str(symbol_name) + ) + lines.append(f"#define {sanitized_name} {name}[{index}]") if name == "stau": # we only need a single macro, as all entries have the same symbol break @@ -1243,7 +1249,7 @@ def _get_symbol_id_initializer_list(self, name: str) -> str: Template initializer list of ids """ return "\n".join( - f'"{self._code_printer.doprint(symbol)}", // {name}[{idx}]' + f'"{strip_pysb(symbol)}", // {name}[{idx}]' for idx, symbol in enumerate(self.model.sym(name)) ) diff --git a/python/sdist/amici/de_model_components.py b/python/sdist/amici/de_model_components.py index 85af284716..866a46d9b4 100644 --- a/python/sdist/amici/de_model_components.py +++ b/python/sdist/amici/de_model_components.py @@ -7,7 +7,6 @@ import sympy as sp from .import_utils import ( - RESERVED_SYMBOLS, ObservableTransformation, amici_time_symbol, cast_to_sym, @@ -66,12 +65,6 @@ def __init__( f"identifier must be sympy.Symbol, was {type(identifier)}" ) - if str(identifier) in RESERVED_SYMBOLS or ( - hasattr(identifier, "name") and identifier.name in RESERVED_SYMBOLS - ): - raise ValueError( - f'Cannot add model quantity with name "{name}", please rename.' - ) self._identifier: sp.Symbol = identifier if not isinstance(name, str): diff --git a/python/sdist/amici/jax/jaxcodeprinter.py b/python/sdist/amici/jax/jaxcodeprinter.py index b20697ef26..4ca7c927ff 100644 --- a/python/sdist/amici/jax/jaxcodeprinter.py +++ b/python/sdist/amici/jax/jaxcodeprinter.py @@ -4,6 +4,8 @@ from collections.abc import Iterable from logging import warning +from ..import_utils import RESERVED_SYMBOLS + import sympy as sp from sympy.printing.numpy import NumPyPrinter @@ -17,6 +19,12 @@ def _jnp_array_str(array) -> str: class AmiciJaxCodePrinter(NumPyPrinter): """JAX code printer""" + def _print_Symbol(self, expr: sp.Symbol) -> str: + name = super()._print_Symbol(expr) + if name in RESERVED_SYMBOLS and name != "t": + return f"amici_{name}" + return name + def doprint(self, expr: sp.Expr, assign_to: str | None = None) -> str: try: code = super().doprint(expr, assign_to) diff --git a/python/sdist/amici/jax/ode_export.py b/python/sdist/amici/jax/ode_export.py index 5c3ebd1456..8fbf3df3e8 100644 --- a/python/sdist/amici/jax/ode_export.py +++ b/python/sdist/amici/jax/ode_export.py @@ -28,6 +28,7 @@ from amici.de_export import is_valid_identifier from amici.import_utils import ( + RESERVED_SYMBOLS, strip_pysb, ) from amici.logging import get_logger, log_execution_time, set_log_level @@ -40,13 +41,20 @@ logger = get_logger(__name__, logging.ERROR) +def _sanitize(name: str) -> str: + return ( + f"amici_{name}" if name in RESERVED_SYMBOLS and name != "t" else name + ) + + def _jax_variable_assignments( model: DEModel, sym_names: tuple[str, ...] ) -> dict: return { - f"{sym_name.upper()}_SYMS": "".join( - str(strip_pysb(s)) + ", " for s in model.sym(sym_name) + f"{sym_name.upper()}_SYMS": ", ".join( + _sanitize(str(strip_pysb(s))) for s in model.sym(sym_name) ) + + ", " if model.sym(sym_name) else "_" for sym_name in sym_names @@ -63,7 +71,7 @@ def _jax_variable_equations( return { f"{eq_name.upper()}_EQ": "\n".join( code_printer._get_sym_lines( - (str(strip_pysb(s)) for s in model.sym(eq_name)), + (_sanitize(str(strip_pysb(s))) for s in model.sym(eq_name)), model.eq(eq_name).subs(subs), indent, ) @@ -78,7 +86,7 @@ def _jax_return_variables( ) -> dict: return { f"{eq_name.upper()}_RET": _jnp_array_str( - strip_pysb(s) for s in model.sym(eq_name) + _sanitize(str(strip_pysb(s))) for s in model.sym(eq_name) ) if model.sym(eq_name) else "jnp.array([])" diff --git a/python/sdist/amici/sbml_import.py b/python/sdist/amici/sbml_import.py index 131db1a041..6b392ce519 100644 --- a/python/sdist/amici/sbml_import.py +++ b/python/sdist/amici/sbml_import.py @@ -35,7 +35,6 @@ from .de_model_components import symbol_to_type, Expression from .sympy_utils import smart_is_zero_matrix, smart_multiply from .import_utils import ( - RESERVED_SYMBOLS, _check_unsupported_functions, _get_str_symbol_identifiers, amici_time_symbol, @@ -691,7 +690,6 @@ def _build_ode_model( ) self._replace_compartments_with_volumes() - self._clean_reserved_symbols() self._process_time() ode_model = DEModel( @@ -2783,24 +2781,6 @@ def _replace_in_all_expressions( for spline in self.splines: spline._replace_in_all_expressions(old, new) - def _clean_reserved_symbols(self) -> None: - """ - Remove all reserved symbols from self.symbols - """ - for sym in RESERVED_SYMBOLS: - old_symbol = symbol_with_assumptions(sym) - new_symbol = symbol_with_assumptions(f"amici_{sym}") - self._replace_in_all_expressions( - old_symbol, new_symbol, replace_identifiers=True - ) - for symbols_ids, symbols in self.symbols.items(): - if old_symbol in symbols: - # reconstitute the whole dict in order to keep the ordering - self.symbols[symbols_ids] = { - new_symbol if k is old_symbol else k: v - for k, v in symbols.items() - } - def _sympify( self, var_or_math: libsbml.SBase diff --git a/python/tests/test_bngl.py b/python/tests/test_bngl.py index 42926e379a..52ad6dc8a5 100644 --- a/python/tests/test_bngl.py +++ b/python/tests/test_bngl.py @@ -77,12 +77,7 @@ def test_compare_to_pysb_simulation(example): with pytest.raises(ValueError, match="Conservation laws"): bngl2amici(model_file, outdir, compute_conservation_laws=True) - if example in ["empty_compartments_block", "motor"]: - with pytest.raises(ValueError, match="Cannot add"): - bngl2amici(model_file, outdir, **kwargs) - return - else: - bngl2amici(model_file, outdir, **kwargs) + bngl2amici(model_file, outdir, **kwargs) amici_model_module = amici.import_model_module(pysb_model.name, outdir) diff --git a/tests/sbml/utils.py b/tests/sbml/utils.py index cdc8153921..293cd2405e 100644 --- a/tests/sbml/utils.py +++ b/tests/sbml/utils.py @@ -34,7 +34,7 @@ def verify_results(settings, rdata, expected, wrapper, model, atol, rtol): new_key = expr_id.removeprefix("flux_") else: new_key = expr_id - if expr_id.removeprefix("amici_") in simulated.columns: + if expr_id in simulated.columns: continue # skip if already present expression_data[new_key] = rdata.w[:, expr_idx] @@ -48,12 +48,6 @@ def verify_results(settings, rdata, expected, wrapper, model, atol, rtol): axis=1, ) - # handle renamed reserved symbols - simulated.rename( - columns={c: c.replace("amici_", "") for c in simulated.columns}, - inplace=True, - ) - # SBML test suite case 01308 defines species with initialAmount and # hasOnlySubstanceUnits="true", but then request results as concentrations. requested_concentrations = [ @@ -150,9 +144,9 @@ def concentrations_to_amounts( ) or comp is None: continue - simulated.loc[:, species] *= simulated.loc[ - :, comp if comp in simulated.columns else f"amici_{comp}" - ] + if comp not in simulated.columns: + continue + simulated.loc[:, species] *= simulated.loc[:, comp] def write_result_file(