diff --git a/.github/workflows/test_petab_sciml.yml b/.github/workflows/test_petab_sciml.yml new file mode 100644 index 0000000000..adf04e2f05 --- /dev/null +++ b/.github/workflows/test_petab_sciml.yml @@ -0,0 +1,92 @@ +name: PEtab SciML +on: + push: + branches: + - develop + - main + pull_request: + branches: + - main + - develop + - jax_sciml + merge_group: + workflow_dispatch: + +jobs: + build: + name: PEtab SciML Testsuite + + runs-on: ubuntu-latest + + env: + ENABLE_GCOV_COVERAGE: TRUE + + strategy: + matrix: + python-version: ["3.12"] + + steps: + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - uses: actions/checkout@v4 + with: + fetch-depth: 20 + + # todo, update after https://github.com/sebapersson/petab_sciml_testsuite/issues/14 is resolved + - name: Download PEtab SciML test suite + run: | + git clone --depth 1 --branch main \ + https://github.com/FFroehlich/petab_sciml_testsuite \ + tests/sciml/testsuite + + - name: Install apt dependencies + uses: ./.github/actions/install-apt-dependencies + + # install dependencies + - name: apt + run: | + sudo apt-get update \ + && sudo apt-get install -y python3-venv + + - run: | + echo "${HOME}/.local/bin/" >> $GITHUB_PATH + + # install AMICI + - name: Install python package + run: scripts/installAmiciSource.sh + + - name: Install pytest dependencies + run: | + source ./venv/bin/activate \ + && pip3 install wheel pytest shyaml pytest-cov + + # install the petab_sciml python package + - name: Download and install PEtab SciML + run: | + source ./venv/bin/activate \ + && python -m pip install git+https://github.com/petab-dev/petab_sciml.git@main#subdirectory=src/python + + + - name: Install petab + run: | + source ./venv/bin/activate \ + && python3 -m pip uninstall -y petab \ + && python3 -m pip install git+https://github.com/petab-dev/libpetab-python.git@sciml + + - name: Run PEtab SciML testsuite + run: | + source ./venv/bin/activate \ + && pytest --cov-report=xml:coverage_petab_sciml.xml \ + --cov=amici tests/sciml/test_sciml.py + + - name: Codecov + if: github.event_name == 'pull_request' || github.repository_owner == 'AMICI-dev' + uses: codecov/codecov-action@v5 + with: + token: ${{ secrets.CODECOV_TOKEN }} + file: coverage_petab_sciml.xml + flags: petab_sciml + fail_ci_if_error: true diff --git a/.gitignore b/.gitignore index 0123a81757..1499441dc4 100644 --- a/.gitignore +++ b/.gitignore @@ -38,6 +38,9 @@ models/model_calvetti/build/* amici_models/ +# PEtab SciML test suite (downloaded dynamically) +tests/sciml/testsuite/ + simulate_model_*_hdf.m simulate_model_*.m @@ -196,3 +199,4 @@ debug/* tests/benchmark_models/cache_fiddy/* venv/* .coverage +tests/sciml/models/* diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000000..e69de29bb2 diff --git a/CHANGELOG.md b/CHANGELOG.md index 880b7f4a35..3305b344df 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -74,6 +74,8 @@ See also our [versioning policy](https://amici.readthedocs.io/en/latest/versioni This only works on shared file systems, as the solver state is stored in a temporary HDF5 file. * `amici.ExpData` is now picklable. +* Implemented support for the [PEtab SciML](https://github.com/PEtab-dev/petab_sciml) + extension for the JAX interface. * The import function `sbml2amici`, `pysb2amici`, and `antimony2amici` now return an instance of the generated model class if called with `compile=True` (default). 
diff --git a/doc/conf.py b/doc/conf.py index a5e620384b..fa58b48cb0 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -31,6 +31,7 @@ import exhale_multiproject_monkeypatch # noqa: F401 # need to import before setting typing.TYPE_CHECKING=True, fails otherwise + import amici import pandas as pd # noqa: F401 import sympy as sp # noqa: F401 @@ -365,6 +366,11 @@ def install_doxygen(): "ExpDataPtrVector": ":class:`amici.amici.ExpData`", } +# TODO: alias for forward type definition, remove after release of petab_sciml +autodoc_type_aliases = { + "NNModel": "petab_sciml.NNModel", +} + def process_docstring(app, what, name, obj, options, lines): # only apply in the amici.amici module diff --git a/doc/examples/example_jax_petab/ExampleJaxPEtab.ipynb b/doc/examples/example_jax_petab/ExampleJaxPEtab.ipynb index c21d1f1f67..bfd26757d4 100644 --- a/doc/examples/example_jax_petab/ExampleJaxPEtab.ipynb +++ b/doc/examples/example_jax_petab/ExampleJaxPEtab.ipynb @@ -93,8 +93,13 @@ "metadata": {}, "outputs": [], "source": [ - "# Access the results\n", - "results" + "# Define the simulation condition\n", + "simulation_condition = (\"model1_data1\",)\n", + "\n", + "# Access the results for the specified condition\n", + "ic = results[\"simulation_conditions\"].index(simulation_condition)\n", + "print(\"llh: \", results[\"llh\"][ic])\n", + "print(\"state variables: \", results[\"x\"][ic, :])" ] }, { @@ -356,7 +361,7 @@ "metadata": {}, "outputs": [], "source": [ - "grad._my" + "grad._my[ic, :]" ] }, { @@ -393,7 +398,7 @@ "nps = jax_problem._np_numeric[ic, :]\n", "\n", "# Load parameters for the specified condition\n", - "p = jax_problem.load_parameters(simulation_condition[0])\n", + "p = jax_problem.load_model_parameters(simulation_condition[0])\n", "\n", "\n", "# Define a function to compute the gradient with respect to dynamic timepoints\n", @@ -612,16 +617,16 @@ ] }, { - "cell_type": "code", - "execution_count": null, - "id": "b8382b0b2b68f49e", "metadata": {}, - "outputs": [], + "cell_type": "code", "source": [ "# Profile gradient computation using forward sensitivity analysis\n", "solver.set_sensitivity_order(amici.SensitivityOrder.first)\n", "solver.set_sensitivity_method(amici.SensitivityMethod.forward)" - ] + ], + "id": "81fe95a6e7f613f1", + "outputs": [], + "execution_count": null }, { "cell_type": "code", @@ -687,8 +692,7 @@ "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.13.0" + "pygments_lexer": "ipython3" } }, "nbformat": 4, diff --git a/doc/rtd_requirements.txt b/doc/rtd_requirements.txt index d014b27755..7c7c1b5d93 100644 --- a/doc/rtd_requirements.txt +++ b/doc/rtd_requirements.txt @@ -7,6 +7,8 @@ setuptools>=67.7.2 # https://github.com/pysb/pysb/pull/599 # for building the documentation, we don't care whether this fully works git+https://github.com/pysb/pysb@0afeaab385e9a1d813ecf6fdaf0153f4b91358af +# For forward type definition in generate_equinox +git+https://github.com/PEtab-dev/petab_sciml.git@727d177fd3f85509d0bdcc278b672e9eeafd2384#subdirectory=src/python matplotlib>=3.7.1 optax nbsphinx @@ -16,6 +18,7 @@ sphinx_rtd_theme>=1.2.0 petab[vis]>=0.2.0 sphinx-autodoc-typehints ipython>=8.13.2 +h5py>=3.14.0 breathe>=4.35.0 exhale>=0.3.7 -e git+https://github.com/mithro/sphinx-contrib-mithro#egg=sphinx-contrib-exhale-multiproject&subdirectory=sphinx-contrib-exhale-multiproject diff --git a/python/sdist/amici/de_export.py b/python/sdist/amici/de_export.py index 0defef2fc8..d87638423d 100644 --- 
a/python/sdist/amici/de_export.py +++ b/python/sdist/amici/de_export.py @@ -165,6 +165,7 @@ def __init__( allow_reinit_fixpar_initcond: bool | None = True, generate_sensitivity_code: bool | None = True, model_name: str | None = "model", + hybridization: dict | None = None, ): """ Generate AMICI C++ files for the DE provided to the constructor. @@ -196,6 +197,10 @@ def __init__( :param model_name: name of the model to be used during code generation + + :param hybridization: + dict representation of the hybridization information in the PEtab YAML file, see + https://petab-sciml.readthedocs.io/latest/format.html#problem-yaml-file """ set_log_level(logger, verbose) @@ -237,6 +242,7 @@ def __init__( self.allow_reinit_fixpar_initcond: bool = allow_reinit_fixpar_initcond self._build_hints = set() self.generate_sensitivity_code: bool = generate_sensitivity_code + self.hybridisation = hybridization @log_execution_time("generating cpp code", logger) def generate_model_code(self) -> None: diff --git a/python/sdist/amici/de_model.py b/python/sdist/amici/de_model.py index b7fafb299a..7d8508a658 100644 --- a/python/sdist/amici/de_model.py +++ b/python/sdist/amici/de_model.py @@ -31,6 +31,7 @@ Event, EventObservable, Expression, + LogLikelihood, LogLikelihoodRZ, LogLikelihoodY, LogLikelihoodZ, @@ -39,6 +40,7 @@ Observable, ObservableParameter, Parameter, + Sigma, SigmaY, SigmaZ, State, @@ -204,6 +206,7 @@ def __init__( verbose: bool | int | None = False, simplify: Callable | None = _default_simplify, cache_simplify: bool = False, + hybridisation: bool = False, ): """ Create a new DEModel instance. @@ -2540,6 +2543,186 @@ def _process_heavisides( return dxdt + @property + def _components(self) -> list[ModelQuantity]: + """ + Returns the components of the model + + :return: + components of the model + """ + return ( + self._algebraic_states + + self._algebraic_equations + + self._conservation_laws + + self._constants + + self._differential_states + + self._event_observables + + self._events + + self._expressions + + self._log_likelihood_ys + + self._log_likelihood_zs + + self._log_likelihood_rzs + + self._observables + + self._parameters + + self._sigma_ys + + self._sigma_zs + + self._splines + ) + + def _process_hybridization(self, hybridization: dict) -> None: + """ + Parses the hybridization information and updates the model accordingly + + :param hybridization: + dict representation of the hybridization information in the PEtab YAML file, see + https://petab-sciml.readthedocs.io/latest/format.html#problem-yaml-file + """ + added_expressions = False + orig_obs = tuple([s.get_id() for s in self._observables]) + for net_id, net in hybridization.items(): + if net["static"]: + continue # do not integrate into ODEs, handle in amici.jax.petab + inputs = [ + comp + for comp in self._components + if str(comp.get_id()) in net["input_vars"] + ] + # sort inputs by order in input_vars + inputs = sorted( + inputs, + key=lambda comp: net["input_vars"].index(str(comp.get_id())), + ) + if len(inputs) != len(net["input_vars"]): + found_vars = {str(comp.get_id()) for comp in inputs} + missing_vars = set(net["input_vars"]) - found_vars + raise ValueError( + f"Could not find all input variables for neural network {net_id}. " + f"Missing variables: {sorted(missing_vars)}" + ) + for inp in inputs: + if isinstance( + inp, + Sigma + | LogLikelihood + | Event + | ConservationLaw + | Observable, + ): + raise NotImplementedError( + f"{inp.get_name()} ({type(inp)}) is not supported as neural network input." 
+ ) + + outputs = { + out_var: {"comp": comp, "ind": net["output_vars"][out_var]} + for comp in self._components + if (out_var := str(comp.get_id())) in net["output_vars"] + # TODO: SYNTAX NEEDS to CHANGE + or (out_var := str(comp.get_id()) + "_dot") + in net["output_vars"] + } + if len(outputs.keys()) != len(net["output_vars"]): + found_vars = set(outputs.keys()) + missing_vars = set(net["output_vars"]) - found_vars + raise ValueError( + f"Could not find all output variables for neural network {net_id}. " + f"Missing variables: {sorted(missing_vars)}" + ) + + for out_var, parts in outputs.items(): + comp = parts["comp"] + # remove output from model components + if isinstance(comp, Parameter): + self._parameters.remove(comp) + elif isinstance(comp, Expression): + self._expressions.remove(comp) + elif isinstance(comp, DifferentialState): + pass + else: + raise NotImplementedError( + f"{comp.get_name()} ({type(comp)}) is not supported as neural network output." + ) + + # generate dummy Function + out_val = sp.Function(net_id)( + *[input.get_id() for input in inputs], parts["ind"] + ) + + # add to the model + if isinstance(comp, DifferentialState): + ix = self._differential_states.index(comp) + # TODO: SYNTAX NEEDS to CHANGE + if out_var.endswith("_dot"): + self._differential_states[ix].set_dt(out_val) + else: + self._differential_states[ix].set_val(out_val) + else: + self.add_component( + Expression( + identifier=comp.get_id(), + name=net_id, + value=out_val, + ) + ) + added_expressions = True + + observables = { + ob_var: {"comp": comp, "ind": net["observable_vars"][ob_var]} + for comp in self._components + if (ob_var := str(comp.get_id())) in net["observable_vars"] + # # TODO: SYNTAX NEEDS to CHANGE + # or (ob_var := str(comp.get_id()) + "_dot") + # in net["observable_vars"] + } + if len(observables.keys()) != len(net["observable_vars"]): + found_vars = set(observables.keys()) + missing_vars = set(net["observable_vars"]) - found_vars + raise ValueError( + f"Could not find all observable variables for neural network {net_id}. " + f"Missing variables: {sorted(missing_vars)}" + ) + + for ob_var, parts in observables.items(): + comp = parts["comp"] + if isinstance(comp, Observable): + self._observables.remove(comp) + else: + raise ValueError( + f"{comp.get_name()} ({type(comp)}) is not an observable." 
+ ) + out_val = sp.Function(net_id)( + *[input.get_id() for input in inputs], parts["ind"] + ) + # add to the model + self.add_component( + Observable( + identifier=comp.get_id(), + name=net_id, + value=out_val, + ) + ) + + new_order = [orig_obs.index(s.get_id()) for s in self._observables] + self._observables = [self._observables[i] for i in new_order] + + if added_expressions: + # toposort expressions + w_sorted = toposort_symbols( + dict( + zip( + self.sym("w"), + self.eq("w"), + strict=True, + ) + ) + ) + old_syms = tuple(self._syms["w"]) + topo_expr_syms = tuple(w_sorted.keys()) + new_order = [old_syms.index(s) for s in topo_expr_syms] + self._expressions = [self._expressions[i] for i in new_order] + self._syms["w"] = sp.Matrix(topo_expr_syms) + self._eqs["w"] = sp.Matrix(list(w_sorted.values())) + def get_explicit_roots(self) -> set[sp.Expr]: """ Returns explicit formulas for all discontinuities (events) diff --git a/python/sdist/amici/jax/__init__.py b/python/sdist/amici/jax/__init__.py index b05de6c87e..5c2e24fb31 100644 --- a/python/sdist/amici/jax/__init__.py +++ b/python/sdist/amici/jax/__init__.py @@ -10,6 +10,7 @@ from warnings import warn from amici.jax.model import JAXModel +from amici.jax.nn import Flatten, cat, generate_equinox, tanhshrink from amici.jax.petab import ( JAXProblem, ReturnValue, @@ -26,7 +27,11 @@ __all__ = [ "JAXModel", "JAXProblem", + "Flatten", + "generate_equinox", "run_simulations", "petab_simulate", "ReturnValue", + "tanhshrink", + "cat", ] diff --git a/python/sdist/amici/jax/jax.template.py b/python/sdist/amici/jax/jax.template.py index 804b06f5a4..b5247d2eab 100644 --- a/python/sdist/amici/jax/jax.template.py +++ b/python/sdist/amici/jax/jax.template.py @@ -3,18 +3,24 @@ import equinox as eqx import jax.numpy as jnp +import jax.random as jr +import jaxtyping as jt from interpax import interp1d from jax.numpy import inf as oo from jax.numpy import nan as nan +from amici import _module_from_path from amici.jax.model import JAXModel, safe_div, safe_log +TPL_NET_IMPORTS + class JAXModel_TPL_MODEL_NAME(JAXModel): api_version = TPL_MODEL_API_VERSION def __init__(self): self.jax_py_file = Path(__file__).resolve() + self.nns = {TPL_NETS} self.parameters = TPL_P_VALUES super().__init__() diff --git a/python/sdist/amici/jax/jaxcodeprinter.py b/python/sdist/amici/jax/jaxcodeprinter.py index b20697ef26..dc50faa3aa 100644 --- a/python/sdist/amici/jax/jaxcodeprinter.py +++ b/python/sdist/amici/jax/jaxcodeprinter.py @@ -5,6 +5,7 @@ from logging import warning import sympy as sp +from sympy.core.function import UndefinedFunction from sympy.printing.numpy import NumPyPrinter @@ -42,6 +43,12 @@ def _print_Mul(self, expr: sp.Expr) -> str: return super()._print_Mul(expr) return f"safe_div({self.doprint(numer)}, {self.doprint(denom)})" + def _print_Function(self, expr: sp.Expr) -> str: + if isinstance(expr.func, UndefinedFunction): + return f"self.nns['{expr.func.__name__}'].forward(jnp.array([{', '.join(self.doprint(a) for a in expr.args[:-1])}]))[{expr.args[-1]}]" + else: + return super()._print_Function(expr) + def _print_Max(self, expr: sp.Expr) -> str: """ Print the max function, replacing it with jnp.max. 
diff --git a/python/sdist/amici/jax/model.py b/python/sdist/amici/jax/model.py index 24be436bdf..c2134bf9f6 100644 --- a/python/sdist/amici/jax/model.py +++ b/python/sdist/amici/jax/model.py @@ -49,6 +49,7 @@ class JAXModel(eqx.Module): MODEL_API_VERSION = "0.0.4" api_version: str jax_py_file: Path + nns: dict parameters: jnp.ndarray = field(default_factory=lambda: jnp.array([])) def __init__(self): @@ -544,6 +545,8 @@ def simulate_condition( x_preeq: jt.Float[jt.Array, "*nx"] = jnp.array([]), mask_reinit: jt.Bool[jt.Array, "*nx"] = jnp.array([]), x_reinit: jt.Float[jt.Array, "*nx"] = jnp.array([]), + init_override: jt.Float[jt.Array, "*nx"] = jnp.array([]), + init_override_mask: jt.Bool[jt.Array, "*nx"] = jnp.array([]), ts_mask: jt.Bool[jt.Array, "nt"] = jnp.array([]), ret: ReturnValue = ReturnValue.llh, ) -> tuple[jt.Float[jt.Array, "nt *nx"] | jnp.float_, dict]: @@ -587,6 +590,10 @@ def simulate_condition( mask for re-initialization. If `True`, the corresponding state variable is re-initialized. :param x_reinit: re-initialized state vector. If not provided, the state vector is not re-initialized. + :param init_override: + override model input e.g. with neural net outputs. If not provided, the inputs are not overridden. + :param init_override_mask: + mask for input override. If `True`, the corresponding input is replaced with the corresponding value from `init_override`. :param ts_mask: mask to remove (padded) time points. If `True`, the corresponding time point is used for the evaluation of the output. Only applied if ret is ReturnValue.llh, ReturnValue.nllhs, ReturnValue.res, or ReturnValue.chi2. @@ -601,6 +608,11 @@ def simulate_condition( if x_preeq.shape[0]: x = x_preeq + elif init_override.shape[0]: + x_def = self._x0(t0, p) + x = jnp.squeeze( + jnp.where(init_override_mask, init_override, x_def) + ) else: x = self._x0(t0, p) diff --git a/python/sdist/amici/jax/nn.py b/python/sdist/amici/jax/nn.py new file mode 100644 index 0000000000..0bd16b4655 --- /dev/null +++ b/python/sdist/amici/jax/nn.py @@ -0,0 +1,419 @@ +from pathlib import Path + +import equinox as eqx +import jax.numpy as jnp + +from amici import amiciModulePath +from amici._codegen.template import apply_template + + +class Flatten(eqx.Module): + """Custom implementation of a `torch.flatten` layer for Equinox.""" + + start_dim: int + end_dim: int + + def __init__(self, start_dim: int, end_dim: int): + super().__init__() + self.start_dim = start_dim + self.end_dim = end_dim + + def __call__(self, x): + if self.end_dim == -1: + return jnp.reshape(x, x.shape[: self.start_dim] + (-1,)) + else: + return jnp.reshape( + x, x.shape[: self.start_dim] + (-1,) + x.shape[self.end_dim :] + ) + + +def tanhshrink(x: jnp.ndarray) -> jnp.ndarray: + """Custom implementation of the torch.nn.Tanhshrink activation function for JAX.""" + return x - jnp.tanh(x) + + +def cat(tensors, axis: int = 0): + """Alias for torch.cat using JAX's concatenate/stack function. + + Handles both regular arrays and zero-dimensional (scalar) arrays by + using stack instead of concatenate for 0D arrays. 
+ + :param tensors: + List of arrays to concatenate + :param axis: + Dimension along which to concatenate (default: 0) + + :return: + Concatenated array + """ + # Check if all tensors are 0-dimensional (scalars) + if all(jnp.ndim(t) == 0 for t in tensors): + # For 0D arrays, use stack instead of concatenate + return jnp.stack(tensors, axis=axis) + return jnp.concatenate(tensors, axis=axis) + + +def generate_equinox( + nn_model: "NNModel", # noqa: F821 + filename: Path | str, + frozen_layers: dict[str, bool] | None = None, +) -> None: + """ + Generate Equinox model file from petab_sciml neural network object. + + :param nn_model: + Neural network model in petab_sciml format + :param filename: + output filename for generated Equinox model + :param frozen_layers: + list of layer names to freeze during training + """ + # TODO: move to top level import and replace forward type definitions + from petab_sciml import Layer + + if frozen_layers is None: + frozen_layers = {} + + filename = Path(filename) + layer_indent = 12 + node_indent = 8 + + layers = {layer.layer_id: layer for layer in nn_model.layers} + + # Collect placeholder nodes to determine input handling + placeholder_nodes = [ + node for node in nn_model.forward if node.op == "placeholder" + ] + input_names = [node.name for node in placeholder_nodes] + + # Generate input unpacking line + if len(input_names) == 1: + input_unpack = f"{input_names[0]} = input" + else: + input_unpack = f"{', '.join(input_names)} = input" + + # Generate forward pass lines (excluding placeholder nodes) + forward_lines = [ + _generate_forward( + node, + node_indent, + frozen_layers, + layers.get( + node.target, + Layer(layer_id="dummy", layer_type="Linear"), + ).layer_type, + ) + for node in nn_model.forward + ] + # Filter out empty lines from placeholder processing + forward_lines = [line for line in forward_lines if line] + # Prepend input unpacking + forward_code = f"{' ' * node_indent}{input_unpack}\n" + "\n".join( + forward_lines + ) + + tpl_data = { + "MODEL_ID": nn_model.nn_model_id, + "LAYERS": ",\n".join( + [ + _generate_layer(layer, layer_indent, ilayer) + for ilayer, layer in enumerate(nn_model.layers) + ] + )[layer_indent:], + "FORWARD": forward_code[node_indent:], + "INPUT": ", ".join([f"'{inp.input_id}'" for inp in nn_model.inputs]), + "OUTPUT": ", ".join( + [ + f"'{arg}'" + for arg in next( + node for node in nn_model.forward if node.op == "output" + ).args + ] + ), + "N_LAYERS": len(nn_model.layers), + } + + filename.parent.mkdir(parents=True, exist_ok=True) + + apply_template( + Path(amiciModulePath) / "jax" / "nn.template.py", + filename, + tpl_data, + ) + + +def _process_argval(v): + """ + Process argument value for layer instantiation string + """ + if isinstance(v, str): + return f"'{v}'" + if isinstance(v, bool): + return str(v) + return str(v) + + +def _generate_layer(layer: "Layer", indent: int, ilayer: int) -> str: # noqa: F821 + """ + Generate layer definition string for a given layer + + :param layer: + petab_sciml Layer object + :param indent: + indentation level for generated string + :param ilayer: + layer index for key generation + + :return: + string defining the layer in equinox syntax + """ + if layer.layer_type.startswith( + ("BatchNorm", "AlphaDropout", "InstanceNorm") + ): + raise NotImplementedError( + f"{layer.layer_type} layers currently not supported" + ) + if layer.layer_type.startswith("MaxPool") and "dilation" in layer.args: + raise NotImplementedError("MaxPool layers with dilation not supported") + if 
layer.layer_type.startswith("Dropout") and "inplace" in layer.args: + raise NotImplementedError("Dropout layers with inplace not supported") + if layer.layer_type == "Bilinear": + raise NotImplementedError("Bilinear layers not supported") + + # mapping of layer names in sciml yaml format to equinox/custom amici implementations + layer_map = { + "Dropout1d": "eqx.nn.Dropout", + "Dropout2d": "eqx.nn.Dropout", + "Flatten": "amici.jax.Flatten", + } + + # mapping of keyword argument names in sciml yaml format to equinox/custom amici implementations + kwarg_map = { + "Linear": { + "bias": "use_bias", + }, + "Conv1d": { + "bias": "use_bias", + }, + "Conv2d": { + "bias": "use_bias", + }, + "LayerNorm": { + "elementwise_affine": "use_bias", # Deprecation warning - replace LayerNorm(elementwise_affine) with LayerNorm(use_bias) + "normalized_shape": "shape", + }, + } + # list of keyword arguments to ignore when generating layer, as they are not supported in equinox (see above) + kwarg_ignore = { + "Dropout1d": ("inplace",), + "Dropout2d": ("inplace",), + } + # construct argument string for layer instantiation + kwargs = [ + f"{kwarg_map.get(layer.layer_type, {}).get(k, k)}={_process_argval(v)}" + for k, v in layer.args.items() + if k not in kwarg_ignore.get(layer.layer_type, ()) + ] + # add key for initialization + if layer.layer_type in ( + "Linear", + "Conv1d", + "Conv2d", + "Conv3d", + "ConvTranspose1d", + "ConvTranspose2d", + "ConvTranspose3d", + ): + kwargs += [f"key=keys[{ilayer}]"] + type_str = layer_map.get(layer.layer_type, f"eqx.nn.{layer.layer_type}") + layer_str = f"{type_str}({', '.join(kwargs)})" + return f"{' ' * indent}'{layer.layer_id}': {layer_str}" + + +def _format_function_call( + var_name: str, fun_str: str, args: list, kwargs: list[str], indent: int +) -> str: + """ + Utility function to format a function call assignment string. + + :param var_name: + name of the variable to assign the result to + :param fun_str: + string representation of the function to call + :param args: + list of positional arguments + :param kwargs: + list of keyword arguments as strings + :param indent: + indentation level for generated string + + :return: + formatted string representing the function call assignment + """ + args_str = ", ".join([f"{arg}" for arg in args]) + kwargs_str = ", ".join(kwargs) + all_args = ", ".join(filter(None, [args_str, kwargs_str])) + return f"{' ' * indent}{var_name} = {fun_str}({all_args})" + + +def _process_layer_call( + node: "Node", # noqa: F821 + layer_type: str, + frozen_layers: dict[str, bool], +) -> tuple[str, str]: + """ + Process a layer (call_module) node and return function string and optional tree string. 
+ + :param node: + petab sciml Node object representing a layer call + :param layer_type: + petab sciml layer type of the node + :param frozen_layers: + dict of layer names to boolean indicating whether layer is frozen + + :return: + tuple of (function_string, tree_string) where tree_string is empty if no tree is needed + """ + fun_str = f"self.layers['{node.target}']" + tree_string = "" + + # Handle frozen layers + if node.name in frozen_layers: + if frozen_layers[node.name]: + arr_attr = frozen_layers[node.name] + get_lambda = f"lambda layer: getattr(layer, '{arr_attr}')" + replacer = "replace_fn = lambda arr: jax.lax.stop_gradient(arr)" + tree_string = f"tree_{node.name} = eqx.tree_at({get_lambda}, {fun_str}, {replacer})" + fun_str = f"tree_{node.name}" + else: + fun_str = f"jax.lax.stop_gradient({fun_str})" + + # Handle vmap for certain layer types + if layer_type.startswith(("Conv", "Linear", "LayerNorm")): + if layer_type in ("LayerNorm",): + dims = f"len({fun_str}.shape)+1" + elif layer_type == "Linear": + dims = 2 + elif layer_type.endswith("1d"): + dims = 3 + elif layer_type.endswith("2d"): + dims = 4 + elif layer_type.endswith("3d"): + dims = 5 + fun_str = f"(jax.vmap({fun_str}) if len({node.args[0]}.shape) == {dims} else {fun_str})" + + return fun_str, tree_string + + +def _process_activation_call(node: "Node") -> str: # noqa: F821 + """ + Process an activation function (call_function/call_method) node and return function string. + + :param node: + petab sciml Node object representing an activation function call + + :return: + string representation of the activation function + """ + # Mapping of function names in sciml yaml format to equinox/custom amici implementations + activation_map = { + "hardtanh": "jax.nn.hard_tanh", + "hardsigmoid": "jax.nn.hard_sigmoid", + "hardswish": "jax.nn.hard_swish", + "tanhshrink": "amici.jax.tanhshrink", + "softsign": "jax.nn.soft_sign", + "cat": "amici.jax.cat", + } + + # Validate hardtanh parameters + if node.target == "hardtanh": + if node.kwargs.pop("min_val", -1.0) != -1.0: + raise NotImplementedError( + "min_val != -1.0 not supported for hardtanh" + ) + if node.kwargs.pop("max_val", 1.0) != 1.0: + raise NotImplementedError( + "max_val != 1.0 not supported for hardtanh" + ) + + # Handle kwarg aliasing for cat (dim -> axis) + if node.target == "cat": + if "dim" in node.kwargs: + node.kwargs["axis"] = node.kwargs.pop("dim") + # Convert list of variable names to proper bracket-enclosed list + if isinstance(node.args[0], list): + # node.args[0] is a list like ['net_input1', 'net_input2'] + # We need to convert it to a single string representing the list: [net_input1, net_input2] + node.args = tuple( + ["[" + ", ".join(node.args[0]) + "]"] + list(node.args[1:]) + ) + + return activation_map.get(node.target, f"jax.nn.{node.target}") + + +def _generate_forward( + node: "Node", # noqa: F821 + indent, + frozen_layers: dict[str, bool] | None = None, + layer_type: str = "", +) -> str: + """ + Generate forward pass line for a given node + + :param node: + petab sciml Node object representing a step in the forward pass + :param indent: + indentation level for generated string + :param frozen_layers: + dict of layer names to boolean indicating whether layer is frozen + :param layer_type: + petab sciml layer type of the node (only relevant for call_module nodes) + + :return: + string defining the forward pass implementation for the given node in equinox syntax + """ + if frozen_layers is None: + frozen_layers = {} + + # Handle placeholder nodes - skip 
individual processing, handled collectively in generate_equinox + if node.op == "placeholder": + return "" + + # Handle output nodes + if node.op == "output": + args_str = ", ".join([f"{arg}" for arg in node.args]) + return f"{' ' * indent}{node.target} = {args_str}" + + # Process layer calls + tree_string = "" + if node.op == "call_module": + fun_str, tree_string = _process_layer_call( + node, layer_type, frozen_layers + ) + + # Process activation function calls + if node.op in ("call_function", "call_method"): + fun_str = _process_activation_call(node) + + # Build kwargs list, filtering out unsupported arguments + kwargs = [ + f"{k}={item}" + for k, item in node.kwargs.items() + if k not in ("inplace",) + ] + + # Add key parameter for Dropout layers + if layer_type.startswith("Dropout"): + kwargs += ["key=key"] + + # Format the function call + if node.op in ("call_module", "call_function", "call_method"): + result = _format_function_call( + node.name, fun_str, node.args, kwargs, indent + ) + # Prepend tree_string if needed for frozen layers + if tree_string: + return f"{' ' * indent}{tree_string}\n{result}" + return result + + raise NotImplementedError(f"Operation {node.op} not supported") diff --git a/python/sdist/amici/jax/nn.template.py b/python/sdist/amici/jax/nn.template.py new file mode 100644 index 0000000000..6b20a39f1b --- /dev/null +++ b/python/sdist/amici/jax/nn.template.py @@ -0,0 +1,28 @@ +# ruff: noqa: F401, F821, F841 +import equinox as eqx +import jax +import jax.nn +import jax.numpy as jnp +import jax.random as jr + +import amici.jax.nn + + +class TPL_MODEL_ID(eqx.Module): + layers: dict + inputs: list[str] + outputs: list[str] + + def __init__(self, key): + super().__init__() + keys = jr.split(key, TPL_N_LAYERS) + self.layers = {TPL_LAYERS} + self.inputs = [TPL_INPUT] + self.outputs = [TPL_OUTPUT] + + def forward(self, input, key=None): + TPL_FORWARD + return output + + +net = TPL_MODEL_ID diff --git a/python/sdist/amici/jax/ode_export.py b/python/sdist/amici/jax/ode_export.py index cbe9b2b310..fa8fa259d6 100644 --- a/python/sdist/amici/jax/ode_export.py +++ b/python/sdist/amici/jax/ode_export.py @@ -6,7 +6,7 @@ The user generally won't have to directly call any function from this module as this will be done by :py:func:`amici.pysb_import.pysb2jax`, -:py:func:`amici.sbml_import.SbmlImporter.sbml2jax` and +:py:func:`amici.sbml_import.SbmlImporter.sbml2jax` and :py:func:`amici.petab_import.import_model`. """ @@ -29,6 +29,7 @@ ) from amici.jax.jaxcodeprinter import AmiciJaxCodePrinter, _jnp_array_str from amici.jax.model import JAXModel +from amici.jax.nn import generate_equinox from amici.logging import get_logger, log_execution_time, set_log_level from amici.sympy_utils import ( _custom_pow_eval_derivative, @@ -123,6 +124,7 @@ def __init__( outdir: Path | str | None = None, verbose: bool | int | None = False, model_name: str | None = "model", + hybridization: dict[str, dict] | None = None, ): """ Generate AMICI jax files for the ODE provided to the constructor. 
@@ -139,6 +141,10 @@ def __init__( :param model_name: name of the model to be used during code generation + + :param hybridization: + dict representation of the hybridization information in the PEtab YAML file, see + https://petab-sciml.readthedocs.io/latest/format.html#problem-yaml-file """ set_log_level(logger, verbose) @@ -161,6 +167,8 @@ def __init__( self.model: DEModel = ode_model + self.hybridization = hybridization if hybridization is not None else {} + self._code_printer = AmiciJaxCodePrinter() @log_execution_time("generating jax code", logger) @@ -173,6 +181,7 @@ def generate_model_code(self) -> None: ): self._prepare_model_folder() self._generate_jax_code() + self._generate_nn_code() def _prepare_model_folder(self) -> None: """ @@ -261,6 +270,14 @@ def _generate_jax_code(self) -> None: # can flag conflicts in the future "MODEL_API_VERSION": f"'{JAXModel.MODEL_API_VERSION}'", }, + "NET_IMPORTS": "\n".join( + f"{net} = _module_from_path('{net}', Path(__file__).parent / '{net}.py')" + for net in self.hybridization.keys() + ), + "NETS": ",\n".join( + f'"{net}": {net}.net(jr.PRNGKey(0))' + for net in self.hybridization.keys() + ), } apply_template( @@ -269,6 +286,14 @@ def _generate_jax_code(self) -> None: tpl_data, ) + def _generate_nn_code(self) -> None: + for net_name, net in self.hybridization.items(): + generate_equinox( + net["model"], + self.model_path / f"{net_name}.py", + net["frozen_layers"], + ) + def _implicit_roots(self) -> list[sp.Expr]: """Return root functions that require rootfinding.""" roots = [] diff --git a/python/sdist/amici/jax/petab.py b/python/sdist/amici/jax/petab.py index 0335851271..bb4a732e16 100644 --- a/python/sdist/amici/jax/petab.py +++ b/python/sdist/amici/jax/petab.py @@ -1,6 +1,8 @@ """PEtab wrappers for JAX models.""" "" import copy +import logging +import re import shutil from collections.abc import Callable, Iterable, Sized from numbers import Number @@ -8,6 +10,7 @@ import diffrax import equinox as eqx +import h5py import jax.lax import jax.numpy as jnp import jaxtyping as jt @@ -19,6 +22,7 @@ from amici import _module_from_path from amici.jax.model import JAXModel, ReturnValue +from amici.logging import get_logger from amici.petab.parameter_mapping import ( ParameterMappingForCondition, create_parameter_mapping, @@ -43,6 +47,8 @@ petab.LOG10: 2, } +logger = get_logger(__name__, logging.WARNING) + def jax_unscale( parameter: jnp.float_, @@ -68,6 +74,31 @@ def jax_unscale( raise ValueError(f"Invalid parameter scaling: {scale_str}") +# IDEA: Implement this class in petab-sciml instead? +class HybridProblem(petab.Problem): + hybridization_df: pd.DataFrame + + def __init__(self, petab_problem: petab.Problem): + self.__dict__.update(petab_problem.__dict__) + self.hybridization_df = _get_hybridization_df(petab_problem) + + +def _get_hybridization_df(petab_problem): + if "sciml" in petab_problem.extensions_config: + hybridizations = [ + pd.read_csv(hf, sep="\t", index_col=0) + for hf in petab_problem.extensions_config["sciml"][ + "hybridization_files" + ] + ] + hybridization_df = pd.concat(hybridizations) + return hybridization_df + + +def _get_hybrid_petab_problem(petab_problem: petab.Problem): + return HybridProblem(petab_problem) + + class JAXProblem(eqx.Module): """ PEtab problem wrapper for JAX models. 
@@ -101,7 +132,7 @@ class JAXProblem(eqx.Module): _np_mask: np.ndarray _np_indices: np.ndarray _petab_measurement_indices: np.ndarray - _petab_problem: petab.Problem + _petab_problem: petab.Problem | HybridProblem def __init__(self, model: JAXModel, petab_problem: petab.Problem): """ @@ -112,10 +143,12 @@ def __init__(self, model: JAXModel, petab_problem: petab.Problem): :param petab_problem: PEtab problem to simulate. """ - self.model = model scs = petab_problem.get_simulation_conditions_from_measurement_df() self.simulation_conditions = tuple(tuple(sc) for sc in scs.values) - self._petab_problem = petab_problem + self._petab_problem = _get_hybrid_petab_problem(petab_problem) + self.parameters, self.model = ( + self._initialize_model_with_nominal_values(model) + ) self._parameter_mappings = self._get_parameter_mappings(scs) ( self._ts_dyn, @@ -133,8 +166,6 @@ def __init__(self, model: JAXModel, petab_problem: petab.Problem): self._np_indices, ) = self._get_measurements(scs) - self.parameters = self._get_nominal_parameter_values() - def save(self, directory: Path): """ Save the problem to a directory. @@ -496,19 +527,238 @@ def get_all_simulation_conditions(self) -> tuple[tuple[str, ...], ...]: ) return tuple(tuple(row) for _, row in simulation_conditions.iterrows()) - def _get_nominal_parameter_values(self) -> jt.Float[jt.Array, "np"]: + def _initialize_model_parameters(self, model: JAXModel) -> dict: + """ + Initialize model parameter structure with zeros. + + :param model: + JAX model with neural networks + + :return: + Nested dictionary structure for model parameters + """ + return { + net_id: { + layer_id: { + attribute: jnp.zeros_like(getattr(layer, attribute)) + for attribute in ["weight", "bias"] + if hasattr(layer, attribute) + } + for layer_id, layer in nn.layers.items() + } + for net_id, nn in model.nns.items() + } + + def _load_parameter_arrays_from_files(self) -> dict: + """ + Load neural network parameter arrays from HDF5 files. + + :return: + Dictionary mapping network IDs to parameter arrays + """ + if not self._petab_problem.extensions_config: + return {} + + array_files = self._petab_problem.extensions_config["sciml"].get( + "array_files", [] + ) + + return { + file_spec.split("_")[0]: h5py.File(file_spec, "r")["parameters"][ + file_spec.split("_")[0] + ] + for file_spec in array_files + if "parameters" in h5py.File(file_spec, "r").keys() + } + + def _load_input_arrays_from_files(self) -> dict: + """ + Load neural network input arrays from HDF5 files. + + :return: + Dictionary mapping network IDs to input arrays + """ + if not self._petab_problem.extensions_config: + return {} + + array_files = self._petab_problem.extensions_config["sciml"].get( + "array_files", [] + ) + + return { + file_spec.split("_")[0]: h5py.File(file_spec, "r")["inputs"] + for file_spec in array_files + if "inputs" in h5py.File(file_spec, "r").keys() + } + + def _parse_parameter_name( + self, pname: str, model_pars: dict + ) -> list[tuple[str, str]]: + """ + Parse parameter name to determine which layers and attributes to set. 
+ + :param pname: + Parameter name from PEtab (format: net.layer.attribute) + :param model_pars: + Model parameters dictionary + + :return: + List of (layer_name, attribute_name) tuples to set + """ + net = pname.split("_")[0] + nn = model_pars[net] + to_set = [] + + name_parts = pname.split(".") + + if len(name_parts) > 1: + layer_name = name_parts[1] + layer = nn[layer_name] + if len(name_parts) > 2: + # Specific attribute specified + attribute_name = name_parts[2] + to_set.append((layer_name, attribute_name)) + else: + # All attributes of the layer + to_set.extend( + [(layer_name, attribute) for attribute in layer.keys()] + ) + else: + # All layers and attributes + to_set.extend( + [ + (layer_name, attribute) + for layer_name, layer in nn.items() + for attribute in layer.keys() + ] + ) + + return to_set + + def _extract_nominal_values_from_petab( + self, model: JAXModel, model_pars: dict, par_arrays: dict + ) -> None: + """ + Extract nominal parameter values from PEtab problem and populate model_pars. + + :param model: + JAX model + :param model_pars: + Model parameters dictionary to populate (modified in place) + :param par_arrays: + Parameter arrays loaded from files + """ + for pname, row in self._petab_problem.parameter_df.iterrows(): + net = pname.split("_")[0] + if net not in model.nns: + continue + + nn = model_pars[net] + scalar = True + + # Determine value source (scalar from PEtab or array from file) + if np.isnan(row[petab.NOMINAL_VALUE]): + value = par_arrays[net] + scalar = False + else: + value = float(row[petab.NOMINAL_VALUE]) + + # Parse parameter name and set values + to_set = self._parse_parameter_name(pname, model_pars) + + for layer, attribute in to_set: + if scalar: + nn[layer][attribute] = value * jnp.ones_like( + getattr(model.nns[net].layers[layer], attribute) + ) + else: + nn[layer][attribute] = jnp.array( + value[layer][attribute][:] + ) + + def _set_model_parameters( + self, model: JAXModel, model_pars: dict + ) -> JAXModel: + """ + Set parameter values in the model using equinox tree_at. + + :param model: + JAX model to update + :param model_pars: + Dictionary of parameter values to set + + :return: + Updated JAX model + """ + for net_id in model_pars: + for layer_id in model_pars[net_id]: + for attribute in model_pars[net_id][layer_id]: + logger.debug( + f"Setting {attribute} of layer {layer_id} in network " + f"{net_id} to {model_pars[net_id][layer_id][attribute]}" + ) + model = eqx.tree_at( + lambda model: getattr( + model.nns[net_id].layers[layer_id], attribute + ), + model, + model_pars[net_id][layer_id][attribute], + ) + return model + + def _set_input_arrays( + self, model: JAXModel, nn_input_arrays: dict, model_pars: dict + ) -> JAXModel: """ - Get the nominal parameter values for the model based on the nominal values in the PEtab problem. + Set input arrays in the model if provided. 
+ + :param model: + JAX model to update + :param nn_input_arrays: + Input arrays loaded from files + :param model_pars: + Model parameters dictionary (for network IDs) :return: - jax array with nominal parameter values + Updated JAX model + """ + if len(nn_input_arrays) == 0: + return model + + for net_id in model_pars: + input_array = { + input: { + k: jnp.array( + arr[:], + dtype=jnp.float64 + if jax.config.jax_enable_x64 + else jnp.float32, + ) + for k, arr in nn_input_arrays[net_id][input].items() + } + for input in model.nns[net_id].inputs + } + model = eqx.tree_at( + lambda model: model.nns[net_id].inputs, model, input_array + ) + + return model + + def _create_scaled_parameter_array(self) -> jt.Float[jt.Array, "np"]: + """ + Create array of scaled nominal parameter values for estimation. + + :return: + JAX array of scaled parameter values """ return jnp.array( [ petab.scale( - self._petab_problem.parameter_df.loc[ - pval, petab.NOMINAL_VALUE - ], + float( + self._petab_problem.parameter_df.loc[ + pval, petab.NOMINAL_VALUE + ] + ), self._petab_problem.parameter_df.loc[ pval, petab.PARAMETER_SCALE ], @@ -517,6 +767,70 @@ def _get_nominal_parameter_values(self) -> jt.Float[jt.Array, "np"]: ] ) + def _initialize_model_with_nominal_values( + self, model: JAXModel + ) -> tuple[jt.Float[jt.Array, "np"], JAXModel]: + """ + Initialize the model with nominal parameter values and inputs from the PEtab problem. + + This method: + - Initializes model parameter structure + - Loads parameter and input arrays from HDF5 files + - Extracts nominal values from PEtab problem + - Sets parameter values in the model + - Sets input arrays in the model + - Creates a scaled parameter array initialized to the nominal values + + :param model: + JAX model to initialize + + :return: + Tuple of (scaled parameter array, initialized model) + """ + # Initialize model parameters structure + model_pars = self._initialize_model_parameters(model) + + # Load arrays from files (getters) + par_arrays = self._load_parameter_arrays_from_files() + nn_input_arrays = self._load_input_arrays_from_files() + + # Extract nominal values from PEtab problem + self._extract_nominal_values_from_petab(model, model_pars, par_arrays) + + # Set values in model (setters) + model = self._set_model_parameters(model, model_pars) + model = self._set_input_arrays(model, nn_input_arrays, model_pars) + + # Create scaled parameter array + parameter_array = self._create_scaled_parameter_array() + + return parameter_array, model + + def _get_inputs(self) -> dict: + if self._petab_problem.mapping_df is None: + return {} + inputs = {net: {} for net in self.model.nns.keys()} + for petab_id, row in self._petab_problem.mapping_df.iterrows(): + if (filepath := Path(petab_id)).is_file(): + data_flat = pd.read_csv(filepath, sep="\t").sort_values( + by="ix" + ) + shape = tuple( + np.stack( + data_flat["ix"] + .astype(str) + .str.split(";") + .apply(np.array) + ) + .astype(int) + .max(axis=0) + + 1 + ) + inputs[row["netId"]][row[petab.MODEL_ENTITY_ID]] = data_flat[ + "value" + ].values.reshape(shape) + return inputs + @property def parameter_ids(self) -> list[str]: """ @@ -526,7 +840,29 @@ def parameter_ids(self) -> list[str]: PEtab parameter ids """ return self._petab_problem.parameter_df[ - self._petab_problem.parameter_df[petab.ESTIMATE] == 1 + (self._petab_problem.parameter_df[petab.ESTIMATE] + == 1) + & pd.to_numeric( + self._petab_problem.parameter_df[petab.NOMINAL_VALUE], + errors="coerce", + ).notna() + ].index.tolist() + + @property + def 
nn_output_ids(self) -> list[str]: + """ + Parameter ids that are estimated in the PEtab problem. Same ordering as values in :attr:`parameters`. + + :return: + PEtab parameter ids + """ + if self._petab_problem.mapping_df is None: + return [] + return self._petab_problem.mapping_df[ + self._petab_problem.mapping_df[petab.MODEL_ENTITY_ID] + .str.split(".") + .str[1] + .str.startswith("output") ].index.tolist() def get_petab_parameter_by_id(self, name: str) -> jnp.float_: @@ -557,7 +893,128 @@ def _unscale( [jax_unscale(pval, scale) for pval, scale in zip(p, scales)] ) - def load_parameters( + def _eval_nn(self, output_par: str, condition_id: str): + net_id = self._petab_problem.mapping_df.loc[ + output_par, petab.MODEL_ENTITY_ID + ].split(".")[0] + nn = self.model.nns[net_id] + + def _is_net_input(model_id): + comps = model_id.split(".") + return comps[0] == net_id and comps[1].startswith("inputs") + + model_id_map = ( + self._petab_problem.mapping_df[ + self._petab_problem.mapping_df[petab.MODEL_ENTITY_ID].apply( + _is_net_input + ) + ] + .reset_index() + .set_index(petab.MODEL_ENTITY_ID)[petab.PETAB_ENTITY_ID] + .to_dict() + ) + + condition_input_map = ( + dict( + [ + ( + petab_id, + self._petab_problem.parameter_df.loc[ + self._petab_problem.condition_df.loc[ + condition_id, petab_id + ], + petab.NOMINAL_VALUE, + ], + ) + if self._petab_problem.condition_df.loc[ + condition_id, petab_id + ] + in self._petab_problem.parameter_df.index + else ( + petab_id, + np.float64( + self._petab_problem.condition_df.loc[ + condition_id, petab_id + ] + ), + ) + for petab_id in model_id_map.values() + ] + ) + if not self._petab_problem.condition_df.empty + else {} + ) + + hybridization_parameter_map = { + petab_id: self._petab_problem.hybridization_df.loc[ + petab_id, "targetValue" + ] + for petab_id in model_id_map.values() + if petab_id in set(self._petab_problem.hybridization_df.index) + } + + # handle conditions + if len(condition_input_map) > 0: + net_input = jnp.array( + [ + condition_input_map[petab_id] + for _, petab_id in model_id_map.items() + ] + ) + return nn.forward(net_input).squeeze() + + # handle array inputs + if isinstance(self.model.nns[net_id].inputs, dict): + net_input = jnp.array( + [ + self.model.nns[net_id].inputs[petab_id][condition_id] + if condition_id in self.model.nns[net_id].inputs[petab_id] + else self.model.nns[net_id].inputs[petab_id]["0"] + for _, petab_id in model_id_map.items() + ] + ) + return nn.forward(net_input).squeeze() + + net_input = jnp.array( + [ + jax.lax.stop_gradient(self.model.nns[net_id][model_id]) + if model_id in self.model.nns[net_id].inputs + else self.get_petab_parameter_by_id(petab_id) + if petab_id in self.parameter_ids + else self._petab_problem.parameter_df.loc[ + petab_id, petab.NOMINAL_VALUE + ] + if petab_id in set(self._petab_problem.parameter_df.index) + else self._petab_problem.parameter_df.loc[ + hybridization_parameter_map[petab_id], petab.NOMINAL_VALUE + ] + for model_id, petab_id in model_id_map.items() + ] + ) + return nn.forward(net_input).squeeze() + + def _map_model_parameter_value( + self, + mapping: ParameterMappingForCondition, + pname: str, + condition_id: str, + ) -> jt.Float[jt.Scalar, ""] | float: # noqa: F722 + pval = mapping.map_sim_var[pname] + if hasattr(self, "nn_output_ids") and pval in self.nn_output_ids: + nn_output = self._eval_nn(pval, condition_id) + if nn_output.size > 1: + entityId = self._petab_problem.mapping_df.loc[ + pval, petab.MODEL_ENTITY_ID + ] + ind = int(re.search(r"\[\d+\]\[(\d+)\]", 
entityId).group(1)) + return nn_output[ind] + else: + return nn_output + if isinstance(pval, Number): + return pval + return self.get_petab_parameter_by_id(pval) + + def load_model_parameters( self, simulation_condition: str ) -> jt.Float[jt.Array, "np"]: """ @@ -569,17 +1026,21 @@ def load_parameters( Parameters for the simulation condition. """ mapping = self._parameter_mappings[simulation_condition] + p = jnp.array( [ - pval - if isinstance(pval := mapping.map_sim_var[pname], Number) - else self.get_petab_parameter_by_id(pval) + self._map_model_parameter_value( + mapping, pname, simulation_condition + ) for pname in self.model.parameter_ids ] ) pscale = tuple( [ - mapping.scale_map_sim_var[pname] + petab.LIN + if self._petab_problem.mapping_df is not None + and pname in self._petab_problem.mapping_df.index + else mapping.scale_map_sim_var[pname] for pname in self.model.parameter_ids ] ) @@ -600,6 +1061,9 @@ def _state_needs_reinitialisation( :return: True if state needs reinitialisation, False otherwise """ + if state_id in self.nn_output_ids: + return True + if state_id not in self._petab_problem.condition_df: return False xval = self._petab_problem.condition_df.loc[ @@ -627,6 +1091,9 @@ def _state_reinitialisation_value( :return: reinitialisation value for the state """ + if state_id in self.nn_output_ids: + return self._eval_nn(state_id) + if state_id not in self._petab_problem.condition_df: # no reinitialisation, return dummy value return 0.0 @@ -671,6 +1138,8 @@ def load_reinitialisation( """ if not any( x_id in self._petab_problem.condition_df + or hasattr(self, "nn_output_ids") + and x_id in self.nn_output_ids for x_id in self.model.state_ids ): return jnp.array([]), jnp.array([]) @@ -737,19 +1206,25 @@ def _prepare_conditions( Tuple of parameter arrays, reinitialisation masks and reinitialisation values, observable parameters and noise parameters. 
""" - p_array = jnp.stack([self.load_parameters(sc) for sc in conditions]) - unscaled_parameters = jnp.stack( - [ - jax_unscale( - self.parameters[ip], - self._petab_problem.parameter_df.loc[ - p_id, petab.PARAMETER_SCALE - ], - ) - for ip, p_id in enumerate(self.parameter_ids) - ] + p_array = jnp.stack( + [self.load_model_parameters(sc) for sc in conditions] ) + if self.parameters.size: + unscaled_parameters = jnp.stack( + [ + jax_unscale( + self.parameters[ip], + self._petab_problem.parameter_df.loc[ + p_id, petab.PARAMETER_SCALE + ], + ) + for ip, p_id in enumerate(self.parameter_ids) + ] + ) + else: + unscaled_parameters = jnp.zeros((*self._ts_masks.shape[:2], 0)) + if op_numeric is not None and op_numeric.size: op_array = jnp.where( op_mask, @@ -804,6 +1279,8 @@ def run_simulation( nps: jt.Float[jt.Array, "nt *nnp"], # noqa: F821, F722 mask_reinit: jt.Bool[jt.Array, "nx"], # noqa: F821, F722 x_reinit: jt.Float[jt.Array, "nx"], # noqa: F821, F722 + init_override: jt.Float[jt.Array, "nx"], # noqa: F821, F722 + init_override_mask: jt.Bool[jt.Array, "nx"], # noqa: F821, F722 solver: diffrax.AbstractSolver, controller: diffrax.AbstractStepSizeController, root_finder: AbstractRootFinder, @@ -869,6 +1346,8 @@ def run_simulation( x_preeq=x_preeq, mask_reinit=jax.lax.stop_gradient(mask_reinit), x_reinit=x_reinit, + init_override=init_override, + init_override_mask=jax.lax.stop_gradient(init_override_mask), ts_mask=jax.lax.stop_gradient(jnp.array(ts_mask)), solver=solver, controller=controller, @@ -928,6 +1407,36 @@ def run_simulations( self._np_indices, ) ) + + init_override_mask = jnp.stack( + [ + jnp.array( + [ + p + in set(self._parameter_mappings[sc].map_sim_var.keys()) + for p in self.model.state_ids + ] + ) + for sc in simulation_conditions + ] + ) + init_override = jnp.stack( + [ + jnp.array( + [ + self._eval_nn( + self._parameter_mappings[sc].map_sim_var[p], sc + ) + if p + in set(self._parameter_mappings[sc].map_sim_var.keys()) + else 1.0 + for p in self.model.state_ids + ] + ) + for sc in simulation_conditions + ] + ) + return self.run_simulation( p_array, self._ts_dyn, @@ -939,6 +1448,8 @@ def run_simulations( np_array, mask_reinit_array, x_reinit_array, + init_override, + init_override_mask, solver, controller, root_finder, diff --git a/python/sdist/amici/petab/parameter_mapping.py b/python/sdist/amici/petab/parameter_mapping.py index 6eae4da380..aa3d57f108 100644 --- a/python/sdist/amici/petab/parameter_mapping.py +++ b/python/sdist/amici/petab/parameter_mapping.py @@ -352,7 +352,7 @@ def create_parameter_mapping( if petab_problem.model.type_id == MODEL_TYPE_SBML: import libsbml - if petab_problem.sbml_document: + if petab_problem.model.sbml_document: converter_config = ( libsbml.SBMLLocalParameterConverter().getDefaultProperties() ) @@ -374,13 +374,20 @@ def create_parameter_mapping( if parameter_mapping_kwargs is None: parameter_mapping_kwargs = {} + # TODO: Add support for conditions with sciml mappings in petab library + mapping = ( + None + if "sciml" in petab_problem.extensions_config + else petab_problem.mapping_df + ) + prelim_parameter_mapping = ( petab.get_optimization_to_simulation_parameter_mapping( condition_df=petab_problem.condition_df, measurement_df=petab_problem.measurement_df, parameter_df=petab_problem.parameter_df, observable_df=petab_problem.observable_df, - mapping_df=petab_problem.mapping_df, + mapping_df=mapping, model=petab_problem.model, simulation_conditions=simulation_conditions, fill_fixed_parameters=fill_fixed_parameters, @@ -581,6 +588,24 @@ def 
create_parameter_mapping_for_condition( ) logger.debug(f"Merged: {condition_map_sim_var}") + if "sciml" in petab_problem.extensions_config: + hybridizations = [ + pd.read_csv(hf, sep="\t") + for hf in petab_problem.extensions_config["sciml"][ + "hybridization_files" + ] + ] + hybridization_df = pd.concat(hybridizations) + for net_id, config in petab_problem.extensions_config["sciml"][ + "neural_nets" + ].items(): + if config["static"]: + for _, row in hybridization_df.iterrows(): + if row["targetValue"].startswith(net_id): + condition_map_sim_var[row["targetId"]] = row[ + "targetValue" + ] + parameter_mapping_for_condition = ParameterMappingForCondition( map_preeq_fix=condition_map_preeq_fix, map_sim_fix=condition_map_sim_fix, diff --git a/python/sdist/amici/petab/petab_import.py b/python/sdist/amici/petab/petab_import.py index 88e0044740..32cefb0845 100644 --- a/python/sdist/amici/petab/petab_import.py +++ b/python/sdist/amici/petab/petab_import.py @@ -7,9 +7,11 @@ import logging import os +import re import shutil from pathlib import Path +import pandas as pd import petab.v1 as petab from petab.v1.models import MODEL_TYPE_PYSB, MODEL_TYPE_SBML @@ -87,12 +89,6 @@ def import_petab_problem( "Unsupported model type " + petab_problem.model.type_id ) - if petab_problem.mapping_df is not None: - # It's partially supported. Remove at your own risk... - raise NotImplementedError( - "PEtab v2.0.0 mapping tables are not yet supported." - ) - model_name = model_name or petab_problem.model.model_id if petab_problem.model.type_id == MODEL_TYPE_PYSB and model_name is None: @@ -134,6 +130,113 @@ def import_petab_problem( shutil.rmtree(model_output_dir) logger.info(f"Compiling model {model_name} to {model_output_dir}.") + + if "sciml" in petab_problem.extensions_config: + from petab_sciml.standard import NNModelStandard + + config = petab_problem.extensions_config["sciml"] + # TODO: only accept YAML format for now + hybridizations = [ + pd.read_csv(hf, sep="\t") + for hf in config["hybridization_files"] + ] + hybridization_table = pd.concat(hybridizations) + + input_mapping = dict( + zip( + hybridization_table["targetId"], + hybridization_table["targetValue"], + ) + ) + output_mapping = dict( + zip( + hybridization_table["targetValue"], + hybridization_table["targetId"], + ) + ) + observable_mapping = dict( + zip( + petab_problem.observable_df["observableFormula"], + petab_problem.observable_df.index, + ) + ) + hybridization = { + net_id: { + "model": NNModelStandard.load_data( + Path(net_config["location"]) + ), + "input_vars": [ + input_mapping[petab_id] + for petab_id, model_id in petab_problem.mapping_df.loc[ + petab_problem.mapping_df[petab.MODEL_ENTITY_ID] + .str.split(".") + .str[0] + == net_id, + petab.MODEL_ENTITY_ID, + ] + .to_dict() + .items() + if model_id.split(".")[1].startswith("input") + and petab_id in input_mapping.keys() + ], + "output_vars": { + output_mapping[petab_id]: _get_net_index(model_id) + for petab_id, model_id in petab_problem.mapping_df.loc[ + petab_problem.mapping_df[petab.MODEL_ENTITY_ID] + .str.split(".") + .str[0] + == net_id, + petab.MODEL_ENTITY_ID, + ] + .to_dict() + .items() + if model_id.split(".")[1].startswith("output") + and petab_id in output_mapping.keys() + }, + "observable_vars": { + observable_mapping[petab_id]: _get_net_index(model_id) + for petab_id, model_id in petab_problem.mapping_df.loc[ + petab_problem.mapping_df[petab.MODEL_ENTITY_ID] + .str.split(".") + .str[0] + == net_id, + petab.MODEL_ENTITY_ID, + ] + .to_dict() + .items() + if 
model_id.split(".")[1].startswith("output") + and petab_id in observable_mapping.keys() + }, + "frozen_layers": dict( + [ + _get_frozen_layers(model_id) + for petab_id, model_id in petab_problem.mapping_df.loc[ + petab_problem.mapping_df[petab.MODEL_ENTITY_ID] + .str.split(".") + .str[0] + == net_id, + petab.MODEL_ENTITY_ID, + ] + .to_dict() + .items() + if petab_id in petab_problem.parameter_df.index + and petab_problem.parameter_df.loc[ + petab_id, petab.ESTIMATE + ] + == 0 + ] + ), + **net_config, + } + for net_id, net_config in config["neural_nets"].items() + } + if not jax or petab_problem.model.type_id != MODEL_TYPE_SBML: + raise NotImplementedError( + "petab_sciml extension is currently only supported for sbml models" + ) + else: + hybridization = None + # compile the model if petab_problem.model.type_id == MODEL_TYPE_PYSB: import_model_pysb( @@ -149,6 +252,7 @@ def import_petab_problem( model_name=model_name, model_output_dir=model_output_dir, non_estimated_parameters_as_constants=non_estimated_parameters_as_constants, + hybridization=hybridization, jax=jax, **kwargs, ) @@ -175,3 +279,17 @@ def import_petab_problem( ) return model + + +def _get_net_index(model_id: str): + matches = re.findall(r"\[(\d+)\]", model_id) + if matches: + return int(matches[-1]) + + +def _get_frozen_layers(model_id): + layers = re.findall(r"\[(.*?)\]", model_id) + array_attr = model_id.split(".")[-1] + layer_id = layers[0] if len(layers) else None + array_attr = array_attr if array_attr in ("weight", "bias") else None + return layer_id, array_attr diff --git a/python/sdist/amici/petab/sbml_import.py b/python/sdist/amici/petab/sbml_import.py index 1ca879d280..801cb6227f 100644 --- a/python/sdist/amici/petab/sbml_import.py +++ b/python/sdist/amici/petab/sbml_import.py @@ -235,6 +235,7 @@ def import_model_sbml( non_estimated_parameters_as_constants=True, output_parameter_defaults: dict[str, float] | None = None, discard_sbml_annotations: bool = False, + hybridization: dict = None, jax: bool = False, **kwargs, ) -> amici.SbmlImporter: @@ -392,10 +393,15 @@ def import_model_sbml( output_dir=model_output_dir, observation_model=observation_model, verbose=verbose, + hybridization=hybridization, **kwargs, ) return sbml_importer else: + if hybridization: + raise NotImplementedError( + "Hybridization is currently only supported for JAX models." 
+ ) sbml_importer.sbml2amici( model_name=model_name, output_dir=model_output_dir, diff --git a/python/sdist/amici/petab/util.py b/python/sdist/amici/petab/util.py index c6ec96945e..6b8b45844e 100644 --- a/python/sdist/amici/petab/util.py +++ b/python/sdist/amici/petab/util.py @@ -28,7 +28,7 @@ def get_states_in_condition_table( species_check_fun = { MODEL_TYPE_SBML: lambda x: _element_is_sbml_state( - petab_problem.sbml_model, x + petab_problem.model.sbml_model, x ), MODEL_TYPE_PYSB: lambda x: _element_is_pysb_pattern( petab_problem.model.model, x diff --git a/python/sdist/amici/sbml_import.py b/python/sdist/amici/sbml_import.py index d3ef971aba..4824f38de7 100644 --- a/python/sdist/amici/sbml_import.py +++ b/python/sdist/amici/sbml_import.py @@ -436,6 +436,7 @@ def sbml2jax( compute_conservation_laws: bool = True, simplify: Callable | None = _default_simplify, cache_simplify: bool = False, + hybridization: dict = None, ) -> None: """ Generate and compile AMICI jax files for the model provided to the @@ -492,7 +493,11 @@ def sbml2jax( see :attr:`amici.DEModel._simplify` :param cache_simplify: - see :meth:`amici.DEModel.__init__` + see :meth:`amici.DEModel.__init__` + + :param hybridization: + dict representation of the hybridization information in the PEtab YAML file, see + https://petab-sciml.readthedocs.io/latest/format.html#problem-yaml-file """ set_log_level(logger, verbose) @@ -502,6 +507,7 @@ def sbml2jax( compute_conservation_laws=compute_conservation_laws, simplify=simplify, cache_simplify=cache_simplify, + hybridization=hybridization, ) from amici.jax.ode_export import ODEExporter @@ -511,6 +517,7 @@ def sbml2jax( model_name=model_name, outdir=output_dir, verbose=verbose, + hybridization=hybridization, ) exporter.generate_model_code() @@ -523,6 +530,7 @@ def _build_ode_model( simplify: Callable | None = _default_simplify, cache_simplify: bool = False, hardcode_symbols: Sequence[str] = None, + hybridization: dict = None, ) -> DEModel: """Generate a DEModel from this SBML model. 
@@ -670,6 +678,9 @@ def _build_ode_model( if compute_conservation_laws: self._process_conservation_laws(ode_model) + if hybridization: + ode_model._process_hybridization(hybridization) + # fill in 'self._sym' based on prototypes and components in ode_model ode_model.generate_basic_variables() diff --git a/python/sdist/pyproject.toml b/python/sdist/pyproject.toml index b6daa81c70..fd31c200dd 100644 --- a/python/sdist/pyproject.toml +++ b/python/sdist/pyproject.toml @@ -88,13 +88,15 @@ examples = [ "scipy", ] jax = [ - "jax>=0.4.36", - "jaxlib>=0.4.36", + "jax>=0.7.2", "diffrax>=0.7.0", "jaxtyping>=0.2.34", - "equinox>=0.11.10", + "equinox>=0.13.2", "optimistix>=0.0.9", - "interpax>=0.3.3,<=0.3.6", + "interpax>=0.3.9", +] +sciml = [ + "h5py" ] [project.scripts] diff --git a/python/tests/test_sciml.py b/python/tests/test_sciml.py new file mode 100644 index 0000000000..5756abe1e9 --- /dev/null +++ b/python/tests/test_sciml.py @@ -0,0 +1,505 @@ +"""Tests for SBML/SciML functionality, including JAX neural network code generation.""" + +from unittest.mock import Mock + +import pytest + +pytest.importorskip("jax") +pytest.importorskip("equinox") + +import pytest +from amici.jax.nn import ( + _format_function_call, + _generate_forward, + _process_activation_call, + _process_layer_call, +) + + +class TestFormatFunctionCall: + """Test the utility function for formatting function calls.""" + + def test_format_with_args_only(self): + """Test formatting with only positional arguments.""" + result = _format_function_call( + var_name="output", + fun_str="my_function", + args=["x", "y"], + kwargs=[], + indent=4, + ) + assert result == " output = my_function(x, y)" + + def test_format_with_kwargs_only(self): + """Test formatting with only keyword arguments.""" + result = _format_function_call( + var_name="output", + fun_str="my_function", + args=[], + kwargs=["a=1", "b=2"], + indent=4, + ) + assert result == " output = my_function(a=1, b=2)" + + def test_format_with_args_and_kwargs(self): + """Test formatting with both positional and keyword arguments.""" + result = _format_function_call( + var_name="result", + fun_str="jax.nn.relu", + args=["input_tensor"], + kwargs=["axis=1"], + indent=8, + ) + assert result == " result = jax.nn.relu(input_tensor, axis=1)" + + def test_format_with_no_args(self): + """Test formatting with no arguments.""" + result = _format_function_call( + var_name="output", + fun_str="get_value", + args=[], + kwargs=[], + indent=0, + ) + assert result == "output = get_value()" + + def test_format_with_zero_indent(self): + """Test formatting with zero indentation.""" + result = _format_function_call( + var_name="x", + fun_str="func", + args=["a"], + kwargs=["b=2"], + indent=0, + ) + assert result == "x = func(a, b=2)" + + +class TestProcessLayerCall: + """Test layer-specific processing logic.""" + + def test_simple_layer_no_freezing(self): + """Test processing a simple layer without freezing.""" + node = Mock() + node.target = "layer1" + node.name = "conv1" + node.args = ["input"] + + fun_str, tree_string = _process_layer_call( + node, layer_type="Conv2d", frozen_layers={} + ) + + assert fun_str.startswith("(jax.vmap(self.layers['layer1'])") + assert tree_string == "" + + def test_frozen_layer_with_attribute(self): + """Test processing a frozen layer with specific attribute.""" + node = Mock() + node.target = "layer1" + node.name = "conv1" + node.args = ["input"] + + fun_str, tree_string = _process_layer_call( + node, layer_type="Conv2d", frozen_layers={"conv1": "weight"} + ) + + assert 
"tree_conv1" in fun_str + assert "tree_conv1 = eqx.tree_at(" in tree_string + assert "'weight'" in tree_string + + def test_frozen_layer_full_stop_gradient(self): + """Test processing a fully frozen layer.""" + node = Mock() + node.target = "layer1" + node.name = "linear1" + node.args = ["input"] + + fun_str, tree_string = _process_layer_call( + node, layer_type="Linear", frozen_layers={"linear1": False} + ) + + assert "jax.lax.stop_gradient(self.layers['layer1'])" in fun_str + assert tree_string == "" + + def test_linear_layer_vmap(self): + """Test that Linear layer gets vmap wrapper.""" + node = Mock() + node.target = "fc1" + node.name = "fc1" + node.args = ["x"] + + fun_str, tree_string = _process_layer_call( + node, layer_type="Linear", frozen_layers={} + ) + + assert "jax.vmap" in fun_str + assert "if len(x.shape) == 2" in fun_str + + def test_conv1d_layer_vmap(self): + """Test that Conv1d layer gets vmap wrapper with correct dimensions.""" + node = Mock() + node.target = "conv" + node.name = "conv" + node.args = ["x"] + + fun_str, tree_string = _process_layer_call( + node, layer_type="Conv1d", frozen_layers={} + ) + + assert "jax.vmap" in fun_str + assert "if len(x.shape) == 3" in fun_str + + def test_conv2d_layer_vmap(self): + """Test that Conv2d layer gets vmap wrapper with correct dimensions.""" + node = Mock() + node.target = "conv" + node.name = "conv" + node.args = ["x"] + + fun_str, tree_string = _process_layer_call( + node, layer_type="Conv2d", frozen_layers={} + ) + + assert "jax.vmap" in fun_str + assert "if len(x.shape) == 4" in fun_str + + def test_layernorm_vmap(self): + """Test that LayerNorm layer gets vmap wrapper.""" + node = Mock() + node.target = "norm" + node.name = "norm" + node.args = ["x"] + + fun_str, tree_string = _process_layer_call( + node, layer_type="LayerNorm", frozen_layers={} + ) + + assert "jax.vmap" in fun_str + assert "len(self.layers['norm'].shape)+1" in fun_str + + def test_non_vmap_layer(self): + """Test layer that doesn't require vmap.""" + node = Mock() + node.target = "dropout" + node.name = "dropout" + node.args = ["x"] + + fun_str, tree_string = _process_layer_call( + node, layer_type="Dropout", frozen_layers={} + ) + + assert "jax.vmap" not in fun_str + assert fun_str == "self.layers['dropout']" + + +class TestProcessActivationCall: + """Test activation function processing logic.""" + + def test_standard_activation(self): + """Test standard JAX activation function.""" + node = Mock() + node.target = "relu" + node.kwargs = {} + + fun_str = _process_activation_call(node) + assert fun_str == "jax.nn.relu" + + def test_mapped_activation_hardtanh(self): + """Test hardtanh activation with custom mapping.""" + node = Mock() + node.target = "hardtanh" + node.kwargs = {} + + fun_str = _process_activation_call(node) + assert fun_str == "jax.nn.hard_tanh" + + def test_mapped_activation_hardsigmoid(self): + """Test hardsigmoid activation with custom mapping.""" + node = Mock() + node.target = "hardsigmoid" + node.kwargs = {} + + fun_str = _process_activation_call(node) + assert fun_str == "jax.nn.hard_sigmoid" + + def test_mapped_activation_tanhshrink(self): + """Test tanhshrink activation with custom mapping.""" + node = Mock() + node.target = "tanhshrink" + node.kwargs = {} + + fun_str = _process_activation_call(node) + assert fun_str == "amici.jax.tanhshrink" + + def test_hardtanh_valid_params(self): + """Test hardtanh with valid default parameters.""" + node = Mock() + node.target = "hardtanh" + node.kwargs = {"min_val": -1.0, "max_val": 1.0} + + 
fun_str = _process_activation_call(node) + assert fun_str == "jax.nn.hard_tanh" + + def test_hardtanh_invalid_min_val(self): + """Test hardtanh raises error for non-default min_val.""" + node = Mock() + node.target = "hardtanh" + node.kwargs = {"min_val": -2.0} + + with pytest.raises(NotImplementedError, match="min_val != -1.0"): + _process_activation_call(node) + + def test_hardtanh_invalid_max_val(self): + """Test hardtanh raises error for non-default max_val.""" + node = Mock() + node.target = "hardtanh" + node.kwargs = {"max_val": 2.0} + + with pytest.raises(NotImplementedError, match="max_val != 1.0"): + _process_activation_call(node) + + +class TestGenerateForward: + """Test the main forward pass generation function.""" + + def test_placeholder_node(self): + """Test generation for placeholder nodes.""" + node = Mock() + node.op = "placeholder" + node.name = "input_x" + + result = _generate_forward(node, indent=4) + assert result == "" + + def test_output_node(self): + """Test generation for output nodes.""" + node = Mock() + node.op = "output" + node.target = "output" + node.args = ["y1", "y2"] + + result = _generate_forward(node, indent=8) + assert result == " output = y1, y2" + + def test_call_module_simple(self): + """Test generation for simple module call.""" + node = Mock() + node.op = "call_module" + node.name = "x1" + node.target = "layer1" + node.args = ["input"] + node.kwargs = {} + + result = _generate_forward( + node, indent=4, frozen_layers={}, layer_type="Dropout" + ) + assert "x1 = self.layers['layer1'](input, key=key)" in result + + def test_call_function_activation(self): + """Test generation for activation function call.""" + node = Mock() + node.op = "call_function" + node.name = "act1" + node.target = "relu" + node.args = ["x"] + node.kwargs = {} + + result = _generate_forward( + node, indent=4, frozen_layers={}, layer_type="" + ) + assert result == " act1 = jax.nn.relu(x)" + + def test_call_module_with_frozen_layer(self): + """Test generation for frozen layer with tree_string.""" + node = Mock() + node.op = "call_module" + node.name = "conv1" + node.target = "layer1" + node.args = ["input"] + node.kwargs = {} + + result = _generate_forward( + node, + indent=4, + frozen_layers={"conv1": "weight"}, + layer_type="Conv2d", + ) + + assert "tree_conv1 = eqx.tree_at(" in result + assert "conv1 = " in result + assert "\n" in result # Should have tree_string on separate line + + def test_unsupported_operation(self): + """Test that unsupported operations raise NotImplementedError.""" + node = Mock() + node.op = "unknown_op" + node.kwargs = {} + + with pytest.raises( + NotImplementedError, match="Operation unknown_op not supported" + ): + _generate_forward(node, indent=4) + + def test_kwargs_filtering(self): + """Test that 'inplace' kwarg is filtered out.""" + node = Mock() + node.op = "call_function" + node.name = "act1" + node.target = "relu" + node.args = ["x"] + node.kwargs = {"inplace": True, "other": "value"} + + result = _generate_forward(node, indent=4, layer_type="") + assert "inplace" not in result + assert "other=value" in result + + def test_dropout_layer_adds_key(self): + """Test that Dropout layers get key parameter added.""" + node = Mock() + node.op = "call_module" + node.name = "drop1" + node.target = "dropout1" + node.args = ["x"] + node.kwargs = {} + + result = _generate_forward( + node, indent=4, frozen_layers={}, layer_type="Dropout1d" + ) + assert "key=key" in result + + +class TestProcessHybridizationErrors: + """Test the improved error messages in 
_process_hybridization.""" + + @pytest.fixture + def mock_de_model(self): + """Create a mock DEModel instance for testing.""" + import sympy as sp + from amici.de_model import DEModel + from amici.de_model_components import ( + DifferentialState, + Expression, + Observable, + Parameter, + ) + + model = DEModel() + + # Add some parameters + model._parameters = [ + Parameter(sp.Symbol("p1"), "param1", sp.Float(1.0)), + Parameter(sp.Symbol("p2"), "param2", sp.Float(2.0)), + ] + + # Add some expressions + model._expressions = [ + Expression(sp.Symbol("expr1"), "expression1", sp.Float(0.5)), + Expression(sp.Symbol("expr2"), "expression2", sp.Float(0.7)), + ] + + # Add some differential states + model._differential_states = [ + DifferentialState( + sp.Symbol("x1"), "state1", sp.Float(0.0), sp.Float(0.1) + ), + DifferentialState( + sp.Symbol("x2"), "state2", sp.Float(0.0), sp.Float(0.2) + ), + ] + + # Add some observables + model._observables = [ + Observable(sp.Symbol("obs1"), "observable1", sp.Symbol("x1")), + Observable(sp.Symbol("obs2"), "observable2", sp.Symbol("x2")), + ] + + return model + + def test_missing_input_variables(self, mock_de_model): + """Test error message for missing input variables.""" + hybridization = { + "neural_net1": { + "static": False, + "input_vars": [ + "p1", + "p2", + "p_missing", + ], # p_missing doesn't exist + "output_vars": {"expr1": 0}, + "observable_vars": {}, + } + } + + with pytest.raises(ValueError) as exc_info: + mock_de_model._process_hybridization(hybridization) + + error_msg = str(exc_info.value) + assert ( + "Could not find all input variables for neural network neural_net1" + in error_msg + ) + assert "Missing variables:" in error_msg + assert "p_missing" in error_msg + + def test_missing_output_variables(self, mock_de_model): + """Test error message for missing output variables.""" + hybridization = { + "neural_net2": { + "static": False, + "input_vars": ["p1", "p2"], + "output_vars": { + "expr1": 0, + "expr_missing": 1, + "expr_also_missing": 2, + }, + "observable_vars": {}, + } + } + + with pytest.raises(ValueError) as exc_info: + mock_de_model._process_hybridization(hybridization) + + error_msg = str(exc_info.value) + assert ( + "Could not find all output variables for neural network neural_net2" + in error_msg + ) + assert "Missing variables:" in error_msg + # Check that missing variables are in the message + assert "expr_missing" in error_msg or "expr_also_missing" in error_msg + + def test_missing_observable_variables(self, mock_de_model): + """Test error message for missing observable variables.""" + hybridization = { + "neural_net3": { + "static": False, + "input_vars": ["p1"], + "output_vars": {"expr1": 0}, + "observable_vars": {"obs1": 0, "obs_missing": 1}, + } + } + + with pytest.raises(ValueError) as exc_info: + mock_de_model._process_hybridization(hybridization) + + error_msg = str(exc_info.value) + assert ( + "Could not find all observable variables for neural network neural_net3" + in error_msg + ) + assert "Missing variables:" in error_msg + assert "obs_missing" in error_msg + + def test_valid_hybridization_no_error(self, mock_de_model): + """Test that valid hybridization doesn't raise errors.""" + hybridization = { + "valid_net": { + "static": False, + "input_vars": ["p1", "p2"], + "output_vars": {"expr1": 0}, + "observable_vars": {"obs1": 0}, + } + } + + # Should not raise any errors + mock_de_model._process_hybridization(hybridization) diff --git a/tests/petab_test_suite/test_petab_suite.py 
b/tests/petab_test_suite/test_petab_suite.py index bba279249f..43c503c6d4 100755 --- a/tests/petab_test_suite/test_petab_suite.py +++ b/tests/petab_test_suite/test_petab_suite.py @@ -52,6 +52,11 @@ def _test_case(case, model_type, version, jax): yaml_file = case_dir / petabtests.problem_yaml_name(case) problem = petab.Problem.from_yaml(yaml_file) + if problem.mapping_df is not None: + pytest.skip( + "PEtab test suite cases with mapping_df are not supported yet." + ) + # compile amici model if case.startswith("0006") and not jax: petab.flatten_timepoint_specific_output_overrides(problem) @@ -143,11 +148,12 @@ def _test_case(case, model_type, version, jax): "display.width", 200, ): - logger.log( - logging.DEBUG, - f"x_ss: {model.get_state_ids()} " - f"{[rdata.x_ss for rdata in rdatas]}", - ) + if not jax: + logger.log( + logging.DEBUG, + f"x_ss: {model.state_ids} " + f"{[rdata.x_ss for rdata in rdatas]}", + ) logger.log( logging.ERROR, f"Expected simulations:\n{gt_simulation_dfs}" ) diff --git a/tests/sciml/test_sciml.py b/tests/sciml/test_sciml.py new file mode 100644 index 0000000000..ae541b0d36 --- /dev/null +++ b/tests/sciml/test_sciml.py @@ -0,0 +1,301 @@ +import os +from contextlib import contextmanager +from pathlib import Path + +import amici +import diffrax +import equinox as eqx +import h5py +import jax +import jax.numpy as jnp +import jax.random as jr +import numpy as np +import pandas as pd +import petab.v1 as petab +import pytest +from amici.jax import ( + JAXProblem, + generate_equinox, + petab_simulate, + run_simulations, +) +from amici.petab import import_petab_problem +from petab_sciml import NNModelStandard +from yaml import safe_load + + +@contextmanager +def change_directory(destination): + # Save the current working directory + original_directory = os.getcwd() + try: + # Change to the new directory + os.chdir(destination) + yield + finally: + # Change back to the original directory + os.chdir(original_directory) + + +jax.config.update("jax_enable_x64", True) + + +# pip install git+https://github.com/sebapersson/petab_sciml@add_standard#egg=petab_sciml\&subdirectory=src/python + +cases_dir = Path(__file__).parent / "testsuite" / "test_cases" +net_cases_dir = cases_dir / "net_import" +ude_cases_dir = cases_dir / "sciml_problem_import" +initialization_cases_dir = cases_dir / "initialization" + + +def _reshape_flat_array(array_flat): + array_flat["ix"] = array_flat["ix"].astype(str) + ix_cols = [ + f"ix_{i}" for i in range(len(array_flat["ix"].values[0].split(";"))) + ] + if len(ix_cols) == 1: + array_flat[ix_cols[0]] = array_flat["ix"].apply(int) + else: + array_flat[ix_cols] = pd.DataFrame( + array_flat["ix"].str.split(";").apply(np.array).to_list(), + index=array_flat.index, + ).astype(int) + array_flat.sort_values(by=ix_cols, inplace=True) + array_shape = tuple(array_flat[ix_cols].max().astype(int) + 1) + array = np.array(array_flat["value"].values).reshape(array_shape) + return array + + +@pytest.mark.parametrize( + "test", sorted(d.stem for d in net_cases_dir.glob("[0-9]*")) +) +def test_net(test): + test_dir = net_cases_dir / test + with open(test_dir / "solutions.yaml") as f: + solutions = safe_load(f) + + if test.endswith("_alt"): + net_file = cases_dir / test.replace("_alt", "") / solutions["net_file"] + else: + net_file = test_dir / solutions["net_file"] + ml_model = NNModelStandard.load_data(net_file) + + nets = {} + outdir = Path(__file__).parent / "models" / test + module_dir = outdir / f"{ml_model.nn_model_id}.py" + if test in ( + "002", + "009", + "018", + 
"019", + "020", + "021", + "022", + "042", + "043", + "044", + "045", + "046", + "047", + "048", + ): + with pytest.raises(NotImplementedError): + generate_equinox(ml_model, module_dir) + return + generate_equinox(ml_model, module_dir) + nets[ml_model.nn_model_id] = amici._module_from_path( + ml_model.nn_model_id, module_dir + ).net + + for input_file, par_file, output_file in zip( + solutions["net_input"], + solutions.get("net_ps", solutions["net_input"]), + solutions["net_output"], + ): + input = h5py.File(test_dir / input_file, "r")["inputs"]["input0"][ + "data" + ][:] + output = h5py.File(test_dir / output_file, "r")["outputs"]["output0"][ + "data" + ][:] + + if "net_ps" in solutions: + par = h5py.File(test_dir / par_file, "r") + net = nets[ml_model.nn_model_id](jr.PRNGKey(0)) + for layer in net.layers.keys(): + if ( + isinstance(net.layers[layer], eqx.Module) + and hasattr(net.layers[layer], "weight") + and net.layers[layer].weight is not None + ): + w = par["parameters"][ml_model.nn_model_id][layer][ + "weight" + ][:] + if isinstance(net.layers[layer], eqx.nn.ConvTranspose): + # see FAQ in https://docs.kidger.site/equinox/api/nn/conv/#equinox.nn.ConvTranspose + w = np.flip(w, axis=tuple(range(2, w.ndim))).swapaxes( + 0, 1 + ) + assert w.shape == net.layers[layer].weight.shape + net = eqx.tree_at( + lambda x: x.layers[layer].weight, + net, + jnp.array(w), + ) + if ( + isinstance(net.layers[layer], eqx.Module) + and hasattr(net.layers[layer], "bias") + and net.layers[layer].bias is not None + ): + b = par["parameters"][ml_model.nn_model_id][layer]["bias"][ + : + ] + if isinstance( + net.layers[layer], + eqx.nn.Conv | eqx.nn.ConvTranspose, + ): + b = np.expand_dims( + b, + tuple( + range( + 1, + net.layers[layer].num_spatial_dims + 1, + ) + ), + ) + assert b.shape == net.layers[layer].bias.shape + net = eqx.tree_at( + lambda x: x.layers[layer].bias, + net, + jnp.array(b), + ) + net = eqx.nn.inference_mode(net) + + if test == "net_004_alt": + return # skipping, no support for non-cross-correlation in equinox + + np.testing.assert_allclose( + net.forward(input), + output, + atol=1e-3, + rtol=1e-3, + ) + + +@pytest.mark.parametrize( + "test", sorted([d.stem for d in ude_cases_dir.glob("[0-9]*")]) +) +def test_ude(test): + test_dir = ude_cases_dir / test + with open(test_dir / "petab" / "problem.yaml") as f: + petab_yaml = safe_load(f) + with open(test_dir / "solutions.yaml") as f: + solutions = safe_load(f) + + with change_directory(test_dir / "petab"): + from petab.v2 import Problem + + petab_yaml["format_version"] = "2.0.0" # TODO: fixme + petab_problem = Problem.from_yaml(petab_yaml) + jax_model = import_petab_problem( + petab_problem, + model_output_dir=Path(__file__).parent / "models" / test, + compile_=True, + jax=True, + ) + + jax_problem = JAXProblem(jax_model, petab_problem) + + # llh + llh, r = run_simulations(jax_problem) + np.testing.assert_allclose( + llh, + solutions["llh"], + atol=solutions["tol_llh"], + rtol=solutions["tol_llh"], + ) + simulations = pd.concat( + [ + pd.read_csv(test_dir / simulation, sep="\t") + for simulation in solutions["simulation_files"] + ] + ) + + # simulations + sort_by = [petab.OBSERVABLE_ID, petab.TIME, petab.SIMULATION_CONDITION_ID] + actual = petab_simulate(jax_problem).sort_values(by=sort_by) + expected = simulations.sort_values(by=sort_by) + np.testing.assert_allclose( + actual[petab.SIMULATION].values, + expected[petab.SIMULATION].values, + atol=solutions["tol_simulations"], + rtol=solutions["tol_simulations"], + ) + + # gradient + sllh, _ = 
eqx.filter_grad(run_simulations, has_aux=True)( + jax_problem, + solver=diffrax.Kvaerno5(), + controller=diffrax.PIDController(atol=1e-14, rtol=1e-14), + max_steps=2**16, + ) + for component, file in solutions["grad_files"].items(): + actual_dict = {} + if component == "mech": + expected = pd.read_csv(test_dir / file, sep="\t").set_index( + petab.PARAMETER_ID + ) + + for ip in expected.index: + if ip in jax_problem.parameter_ids: + actual_dict[ip] = sllh.parameters[ + jax_problem.parameter_ids.index(ip) + ].item() + actual = pd.Series(actual_dict).loc[expected.index].values + np.testing.assert_allclose( + actual, + expected["value"].values, + atol=solutions["tol_grad_llh"], + rtol=solutions["tol_grad_llh"], + ) + else: + expected = h5py.File(test_dir / file, "r") + for layer_name, layer in jax_problem.model.nns[ + component + ].layers.items(): + for attribute in dir(layer): + if not isinstance( + getattr(layer, attribute), jax.numpy.ndarray + ): + continue + actual = getattr( + sllh.model.nns[component].layers[layer_name], attribute + ) + if ( + isinstance(layer, eqx.nn.ConvTranspose) + and attribute == "weight" + ): + actual = np.flip( + actual.swapaxes(0, 1), + axis=tuple(range(2, actual.ndim)), + ) + if ( + np.squeeze( + expected["parameters"][component][layer_name][ + attribute + ][:] + ).size + == 0 + ): + assert np.all(actual == 0.0) + else: + np.testing.assert_allclose( + np.squeeze(actual), + np.squeeze( + expected["parameters"][component][layer_name][ + attribute + ][:] + ), + atol=solutions["tol_grad_llh"], + rtol=solutions["tol_grad_llh"], + )
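Illustrative usage note (not part of the patch): the short sketch below mirrors how the new PEtab SciML / JAX support is exercised in tests/sciml/test_sciml.py above. The problem path is a placeholder, and the solver and tolerance settings are illustrative choices rather than AMICI defaults.

# Minimal sketch, assuming a PEtab SciML problem YAML on disk (path is a placeholder).
import diffrax
import equinox as eqx
import jax
from petab.v2 import Problem

from amici.jax import JAXProblem, run_simulations
from amici.petab import import_petab_problem

jax.config.update("jax_enable_x64", True)  # as in the test suite above

# Import the hybrid (mechanistic + neural network) problem as a JAX model.
petab_problem = Problem.from_yaml("path/to/sciml_problem.yaml")  # placeholder path
jax_model = import_petab_problem(petab_problem, jax=True, compile_=True)
jax_problem = JAXProblem(jax_model, petab_problem)

# Log-likelihood and per-condition simulation results.
llh, results = run_simulations(jax_problem)

# Gradients w.r.t. both mechanistic and neural-network parameters,
# mirroring the gradient check in tests/sciml/test_sciml.py
# (tolerances here are illustrative, not the values used in the tests).
sllh, _ = eqx.filter_grad(run_simulations, has_aux=True)(
    jax_problem,
    solver=diffrax.Kvaerno5(),
    controller=diffrax.PIDController(atol=1e-8, rtol=1e-8),
    max_steps=2**16,
)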