Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -789,6 +789,7 @@ Index into all nodes in the subtree simultaneously.

DataTree.isel
DataTree.sel
DataTree.subset

.. DataTree.drop_sel
.. DataTree.drop_isel
Expand Down
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ v2025.05.0 (unreleased)

New Features
~~~~~~~~~~~~
- Added :py:meth:`~xarray.DataTree.subset` to index variables on all nodes of a datatree (:pull:`10400`)
By `Mathias Hauser <https://github.com/mathause>`_.
- Allow an Xarray index that uses multiple dimensions checking equality with another
index for only a subset of those dimensions (i.e., ignoring the dimensions
that are excluded from alignment).
Expand Down
30 changes: 30 additions & 0 deletions xarray/core/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
Iterable,
Iterator,
Mapping,
Sequence,
)
from html import escape
from typing import (
Expand Down Expand Up @@ -1014,6 +1015,35 @@ def __delitem__(self, key: str) -> None:
else:
raise KeyError(key)

def subset(
self, keys: str | Sequence[str], *, errors: ErrorOptions = "raise"
) -> DataTree:
"""Index DataArrays on each node

Parameters
----------
keys : str | Sequence[str]
Name of the data variables to index.
errors : "raise", "ignore"
Whether to raise a key error if a data variable is missing on a node.

Returns
-------
out : DataTree
"""

if isinstance(keys, str):
keys = [keys]

def getitem(ds):
keys_for_ds = keys
if errors == "ignore":
keys_for_ds = [key for key in keys if key in ds.data_vars]

return ds[keys_for_ds]

return map_over_datasets(getitem, self)

@overload
def update(self, other: Dataset) -> None: ...

Expand Down
28 changes: 28 additions & 0 deletions xarray/tests/test_datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,34 @@ def test_getitem_dict_like_selection_access_to_dataset(self) -> None:
assert_identical(results[{"temp": 1}], data[{"temp": 1}]) # type: ignore[index]


def test_subset():
ds1 = xr.Dataset(data_vars={"var1": ("x", [1, 2]), "var2": ("x", [0, 1])})
ds2 = xr.Dataset(data_vars={"var1": ("x", [1, 2])})
dt = xr.DataTree.from_dict({"ds1": ds1, "ds2": ds2})

dt_var1 = xr.DataTree.from_dict({"ds1": ds1[["var1"]], "ds2": ds2})

# errors as map_over_datasets does not skip empty nodes
with pytest.raises(KeyError, match="var1"):
dt.subset("var1")

# will still error if map_over_datasets will ever skip empty nodes
with pytest.raises(KeyError, match="var2"):
dt.subset("var2")

result = dt.subset("var1", errors="ignore")
expected = dt_var1
xr.testing.assert_equal(result, expected)

result = dt.subset(["var1"], errors="ignore")
expected = dt_var1
xr.testing.assert_equal(result, expected)

result = dt.subset(["var1", "var2"], errors="ignore")
expected = dt
xr.testing.assert_equal(result, expected)


class TestUpdate:
def test_update(self) -> None:
dt = DataTree()
Expand Down
Loading