From af0d3f7c6a4c3fefecd24615b3f4eb8b6c3561ee Mon Sep 17 00:00:00 2001 From: Daniel Weindl Date: Wed, 2 Jul 2025 12:46:27 +0200 Subject: [PATCH] `ReturnData` fields as `xarray.DataArray` Make relevant `ReturnData` fields available as `xarray.DataArray`. This includes the identifiers and is often more convenient than the plain arrays, allows for easy subselection and plotting of the results, and conversion to DataFrames. --- CHANGELOG.md | 11 +- .../GettingStartedExtended.ipynb | 19 +++ python/sdist/amici/numpy.py | 128 +++++++++++++++++- python/sdist/pyproject.toml | 3 +- python/tests/test_swig_interface.py | 12 +- 5 files changed, 168 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4902462052..53bb4142c4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,7 +6,7 @@ See also our [versioning policy](https://amici.readthedocs.io/en/latest/versioni ### v1.0.0 (unreleased) -BREAKING CHANGES +**BREAKING CHANGES** * `ReturnDataView.posteq_numsteps` and `ReturnDataView.posteq_numsteps` now return a one-dimensional array of shape `(num_timepoints,)` instead of a @@ -27,6 +27,15 @@ BREAKING CHANGES * The `force_compile` argument to `import_petab_problem` has been removed. See the `compile_` argument. +**Features** + +* Many relevant `ReturnData` fields are now available as `xarray.DataArray` + via `ReturnData.xr.{x,y,w,x0,sx,...}`. + `DataArray`s include the identifiers and are often more convenient than the + plain numpy arrays. This allows for easy subselection and plotting of the + results, and conversion to DataFrames. + + ## v0.X Series ### v0.34.1 (2025-08-25) diff --git a/doc/examples/getting_started_extended/GettingStartedExtended.ipynb b/doc/examples/getting_started_extended/GettingStartedExtended.ipynb index aba2d94f06..d81f5d30ea 100644 --- a/doc/examples/getting_started_extended/GettingStartedExtended.ipynb +++ b/doc/examples/getting_started_extended/GettingStartedExtended.ipynb @@ -877,6 +877,25 @@ "print(f\"{rdata.by_id('x2')=}\")" ] }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "Alternatively, those data can be accessed through `ReturnData.xr.*` as [xarray.DataArray](https://docs.xarray.dev/en/stable/index.html) objects, that contain additional metadata such as timepoints and identifiers. This allows for more convenient indexing and plotting of the results." + }, + { + "metadata": {}, + "cell_type": "code", + "source": "rdata.xr.x", + "outputs": [], + "execution_count": null + }, + { + "metadata": {}, + "cell_type": "code", + "source": "rdata.xr.x.to_pandas()", + "outputs": [], + "execution_count": null + }, { "cell_type": "markdown", "metadata": {}, diff --git a/python/sdist/amici/numpy.py b/python/sdist/amici/numpy.py index f7b85d8b96..a7b195e26c 100644 --- a/python/sdist/amici/numpy.py +++ b/python/sdist/amici/numpy.py @@ -4,10 +4,11 @@ This module provides views on C++ objects for efficient access. """ +from __future__ import annotations import collections import copy import itertools -from typing import Literal, Union +from typing import Literal from collections.abc import Iterator from numbers import Number import amici @@ -22,10 +23,131 @@ ReturnDataPtr, SteadyStateStatus, ) +import xarray as xr + + +__all__ = [ + "ReturnDataView", + "ExpDataView", + "evaluate", +] StrOrExpr = str | sp.Expr +class XArrayFactory: + """ + Factory class to create xarray DataArrays for fields of a + SwigPtrView instance. + + Currently, only ReturnDataView is supported. + """ + + def __init__(self, svp: SwigPtrView): + """ + Constructor + + :param svp: SwigPtrView instance to create DataArrays from. + """ + self._svp = svp + + def __getattr__(self, name: str) -> xr.DataArray: + """ + Create xarray DataArray for field name + + :param name: field name + + :returns: xarray DataArray + """ + data = getattr(self._svp, name) + if data is None: + return xr.DataArray(name=name) + + dims = None + + match name: + case "x": + coords = { + "time": self._svp.ts, + "state": list(self._svp.state_ids), + } + case "x0" | "x_ss": + coords = { + "state": list(self._svp.state_ids), + } + case "xdot": + coords = { + "state": list(self._svp.state_ids_solver), + } + case "y" | "sigmay": + coords = { + "time": self._svp.ts, + "observable": list(self._svp.observable_ids), + } + case "sy" | "ssigmay": + coords = { + "time": self._svp.ts, + "parameter": [ + self._svp.parameter_ids[i] for i in self._svp.plist + ], + "observable": list(self._svp.observable_ids), + } + case "w": + coords = { + "time": self._svp.ts, + "expression": list(self._svp.expression_ids), + } + case "sx0": + coords = { + "parameter": [ + self._svp.parameter_ids[i] for i in self._svp.plist + ], + "state": list(self._svp.state_ids), + } + case "sx": + coords = { + "time": self._svp.ts, + "parameter": [ + self._svp.parameter_ids[i] for i in self._svp.plist + ], + "state": list(self._svp.state_ids), + } + dims = ("time", "parameter", "state") + case "sllh": + coords = { + "parameter": [ + self._svp.parameter_ids[i] for i in self._svp.plist + ] + } + case "FIM": + coords = { + "parameter1": [ + self._svp.parameter_ids[i] for i in self._svp.plist + ], + "parameter2": [ + self._svp.parameter_ids[i] for i in self._svp.plist + ], + } + case "J": + coords = { + "state1": list(self._svp.state_ids_solver), + "state2": list(self._svp.state_ids_solver), + } + case _: + dims = tuple(f"dim_{i}" for i in range(data.ndim)) + coords = { + f"dim_{i}": np.arange(dim) + for i, dim in enumerate(data.shape) + } + arr = xr.DataArray( + data, + dims=dims, + coords=coords, + name=name, + ) + return arr + + class SwigPtrView(collections.abc.Mapping): """ Interface class to expose ``std::vector`` and scalar members of @@ -104,6 +226,7 @@ def __init__(self, swigptr): """ self._swigptr = swigptr self._cache = {} + super().__init__() def __len__(self) -> int: @@ -310,6 +433,7 @@ def __init__(self, rdata: ReturnDataPtr | ReturnData): "numerrtestfailsB": [rdata.nt], "numnonlinsolvconvfailsB": [rdata.nt], } + self.xr = XArrayFactory(self) super().__init__(rdata) def __getitem__( @@ -461,7 +585,7 @@ def _field_as_numpy( def _entity_type_from_id( entity_id: str, - rdata: Union[amici.ReturnData, "amici.ReturnDataView"] = None, + rdata: amici.ReturnData | amici.ReturnDataView = None, model: amici.Model = None, ) -> Literal["x", "y", "w", "p", "k"]: """Guess the type of some entity by its ID.""" diff --git a/python/sdist/pyproject.toml b/python/sdist/pyproject.toml index 228fa76ac5..76d62a8853 100644 --- a/python/sdist/pyproject.toml +++ b/python/sdist/pyproject.toml @@ -29,7 +29,8 @@ dependencies = [ "toposort", "setuptools>=48", "mpmath", - "swig" + "swig", + "xarray>=2025.01.0", ] license = "BSD-3-Clause" authors = [ diff --git a/python/tests/test_swig_interface.py b/python/tests/test_swig_interface.py index 5e7264c722..b58b9e8b23 100644 --- a/python/tests/test_swig_interface.py +++ b/python/tests/test_swig_interface.py @@ -7,6 +7,7 @@ import numbers from math import nan import pytest +import xarray import amici import numpy as np @@ -531,6 +532,7 @@ def test_rdataview(sbml_example_presimulation_module): """Test some SwigPtrView functionality via ReturnDataView.""" model_module = sbml_example_presimulation_module model = model_module.getModel() + model.setTimepoints([1, 2, 3]) rdata = amici.runAmiciSimulation(model, model.getSolver()) assert isinstance(rdata, amici.ReturnDataView) @@ -547,11 +549,19 @@ def test_rdataview(sbml_example_presimulation_module): assert not hasattr(rdata, "nonexisting_attribute") assert "x" in rdata - assert rdata.x == rdata["x"] + assert (rdata.x == rdata["x"]).all() # field names are included by dir() assert "x" in dir(rdata) + # Test xarray conversion + xr_x = rdata.xr.x + assert isinstance(xr_x, xarray.DataArray) + assert (rdata.x == xr_x).all() + assert xr_x.dims == ("time", "state") + assert (xr_x.coords["time"].data == rdata.ts).all() + assert (xr_x.coords["state"].data == model.getStateIds()).all() + def test_python_exceptions(sbml_example_presimulation_module): """Test that C++ exceptions are correctly caught and re-raised in Python."""