Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ See also our [versioning policy](https://amici.readthedocs.io/en/latest/versioni

### v1.0.0 (unreleased)

BREAKING CHANGES
**BREAKING CHANGES**

* `ReturnDataView.posteq_numsteps` and `ReturnDataView.posteq_numsteps` now
return a one-dimensional array of shape `(num_timepoints,)` instead of a
Expand All @@ -27,6 +27,15 @@ BREAKING CHANGES
* The `force_compile` argument to `import_petab_problem` has been removed.
See the `compile_` argument.

**Features**

* Many relevant `ReturnData` fields are now available as `xarray.DataArray`
via `ReturnData.xr.{x,y,w,x0,sx,...}`.
`DataArray`s include the identifiers and are often more convenient than the
plain numpy arrays. This allows for easy subselection and plotting of the
results, and conversion to DataFrames.


## v0.X Series

### v0.34.1 (2025-08-25)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -877,6 +877,25 @@
"print(f\"{rdata.by_id('x2')=}\")"
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "Alternatively, those data can be accessed through `ReturnData.xr.*` as [xarray.DataArray](https://docs.xarray.dev/en/stable/index.html) objects, that contain additional metadata such as timepoints and identifiers. This allows for more convenient indexing and plotting of the results."
},
{
"metadata": {},
"cell_type": "code",
"source": "rdata.xr.x",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"source": "rdata.xr.x.to_pandas()",
"outputs": [],
"execution_count": null
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down
128 changes: 126 additions & 2 deletions python/sdist/amici/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
This module provides views on C++ objects for efficient access.
"""

from __future__ import annotations
import collections
import copy
import itertools
from typing import Literal, Union
from typing import Literal
from collections.abc import Iterator
from numbers import Number
import amici
Expand All @@ -22,10 +23,131 @@
ReturnDataPtr,
SteadyStateStatus,
)
import xarray as xr


__all__ = [
"ReturnDataView",
"ExpDataView",
"evaluate",
]

StrOrExpr = str | sp.Expr


class XArrayFactory:
"""
Factory class to create xarray DataArrays for fields of a
SwigPtrView instance.

Currently, only ReturnDataView is supported.
"""

def __init__(self, svp: SwigPtrView):
"""
Constructor

:param svp: SwigPtrView instance to create DataArrays from.
"""
self._svp = svp

def __getattr__(self, name: str) -> xr.DataArray:
"""
Create xarray DataArray for field name

:param name: field name

:returns: xarray DataArray
"""
data = getattr(self._svp, name)
if data is None:
return xr.DataArray(name=name)

dims = None

match name:
case "x":
coords = {
"time": self._svp.ts,
"state": list(self._svp.state_ids),
}
case "x0" | "x_ss":
coords = {
"state": list(self._svp.state_ids),
}
case "xdot":
coords = {
"state": list(self._svp.state_ids_solver),
}
case "y" | "sigmay":
coords = {
"time": self._svp.ts,
"observable": list(self._svp.observable_ids),
}
case "sy" | "ssigmay":
coords = {
"time": self._svp.ts,
"parameter": [
self._svp.parameter_ids[i] for i in self._svp.plist
],
"observable": list(self._svp.observable_ids),
}
case "w":
coords = {
"time": self._svp.ts,
"expression": list(self._svp.expression_ids),
}
case "sx0":
coords = {
"parameter": [
self._svp.parameter_ids[i] for i in self._svp.plist
],
"state": list(self._svp.state_ids),
}
case "sx":
coords = {
"time": self._svp.ts,
"parameter": [
self._svp.parameter_ids[i] for i in self._svp.plist
],
"state": list(self._svp.state_ids),
}
dims = ("time", "parameter", "state")
case "sllh":
coords = {
"parameter": [
self._svp.parameter_ids[i] for i in self._svp.plist
]
}
case "FIM":
coords = {
"parameter1": [
self._svp.parameter_ids[i] for i in self._svp.plist
],
"parameter2": [
self._svp.parameter_ids[i] for i in self._svp.plist
],
}
case "J":
coords = {
"state1": list(self._svp.state_ids_solver),
"state2": list(self._svp.state_ids_solver),
}
case _:
dims = tuple(f"dim_{i}" for i in range(data.ndim))
coords = {
f"dim_{i}": np.arange(dim)
for i, dim in enumerate(data.shape)
}
arr = xr.DataArray(
data,
dims=dims,
coords=coords,
name=name,
)
return arr


class SwigPtrView(collections.abc.Mapping):
"""
Interface class to expose ``std::vector<double>`` and scalar members of
Expand Down Expand Up @@ -104,6 +226,7 @@ def __init__(self, swigptr):
"""
self._swigptr = swigptr
self._cache = {}

super().__init__()

def __len__(self) -> int:
Expand Down Expand Up @@ -310,6 +433,7 @@ def __init__(self, rdata: ReturnDataPtr | ReturnData):
"numerrtestfailsB": [rdata.nt],
"numnonlinsolvconvfailsB": [rdata.nt],
}
self.xr = XArrayFactory(self)
super().__init__(rdata)

def __getitem__(
Expand Down Expand Up @@ -461,7 +585,7 @@ def _field_as_numpy(

def _entity_type_from_id(
entity_id: str,
rdata: Union[amici.ReturnData, "amici.ReturnDataView"] = None,
rdata: amici.ReturnData | amici.ReturnDataView = None,
model: amici.Model = None,
) -> Literal["x", "y", "w", "p", "k"]:
"""Guess the type of some entity by its ID."""
Expand Down
3 changes: 2 additions & 1 deletion python/sdist/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ dependencies = [
"toposort",
"setuptools>=48",
"mpmath",
"swig"
"swig",
"xarray>=2025.01.0",
]
license = "BSD-3-Clause"
authors = [
Expand Down
12 changes: 11 additions & 1 deletion python/tests/test_swig_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import numbers
from math import nan
import pytest
import xarray

import amici
import numpy as np
Expand Down Expand Up @@ -531,6 +532,7 @@ def test_rdataview(sbml_example_presimulation_module):
"""Test some SwigPtrView functionality via ReturnDataView."""
model_module = sbml_example_presimulation_module
model = model_module.getModel()
model.setTimepoints([1, 2, 3])
rdata = amici.runAmiciSimulation(model, model.getSolver())
assert isinstance(rdata, amici.ReturnDataView)

Expand All @@ -547,11 +549,19 @@ def test_rdataview(sbml_example_presimulation_module):

assert not hasattr(rdata, "nonexisting_attribute")
assert "x" in rdata
assert rdata.x == rdata["x"]
assert (rdata.x == rdata["x"]).all()

# field names are included by dir()
assert "x" in dir(rdata)

# Test xarray conversion
xr_x = rdata.xr.x
assert isinstance(xr_x, xarray.DataArray)
assert (rdata.x == xr_x).all()
assert xr_x.dims == ("time", "state")
assert (xr_x.coords["time"].data == rdata.ts).all()
assert (xr_x.coords["state"].data == model.getStateIds()).all()


def test_python_exceptions(sbml_example_presimulation_module):
"""Test that C++ exceptions are correctly caught and re-raised in Python."""
Expand Down
Loading