def prune(self, drop_size_zero_vars: bool = False) -> DataTree:
    """
    Remove empty nodes from the tree.

    Returns a new tree containing only the nodes that hold data variables.
    Intermediate nodes are kept if they are required to support non-empty
    children.

    Only data variables are inspected: a node that carries nothing but
    coordinates or attributes is treated as empty.

    Parameters
    ----------
    drop_size_zero_vars : bool, default False
        If True, a node is also considered empty when every one of its data
        variables has zero size.
        If False, a node is kept as soon as it has at least one data
        variable, even if that variable has zero size.

    Returns
    -------
    DataTree
        A new tree with empty nodes removed.

    See Also
    --------
    filter

    Examples
    --------
    >>> dt = xr.DataTree.from_dict(
    ...     {
    ...         "/a": xr.Dataset({"foo": ("x", [1, 2])}),
    ...         "/b": xr.Dataset({"bar": ("x", [])}),
    ...         "/c": xr.Dataset(),
    ...     }
    ... )
    >>> dt.prune()
    <xarray.DataTree>
    Group: /
    ├── Group: /a
    │       Dimensions:  (x: 2)
    │       Dimensions without coordinates: x
    │       Data variables:
    │           foo      (x) int64 16B 1 2
    └── Group: /b
            Dimensions:  (x: 0)
            Dimensions without coordinates: x
            Data variables:
                bar      (x) float64 0B

    The ``drop_size_zero_vars`` parameter controls whether variables
    with zero size are considered empty:

    >>> dt.prune(drop_size_zero_vars=True)
    <xarray.DataTree>
    Group: /
    └── Group: /a
            Dimensions:  (x: 2)
            Dimensions without coordinates: x
            Data variables:
                foo      (x) int64 16B 1 2
    """

    def has_data_vars(node: DataTree) -> bool:
        # Default rule: any data variable at all keeps the node.
        return len(node.data_vars) > 0

    def has_nonzero_size_var(node: DataTree) -> bool:
        # Strict rule: at least one data variable must actually hold
        # elements. ``any`` over an empty mapping is already False, so a
        # node without data variables needs no separate length check.
        return any(var.size > 0 for var in node.data_vars.values())

    non_empty_cond = has_nonzero_size_var if drop_size_zero_vars else has_data_vars

    # Node selection and rebuilding of the result tree are delegated to
    # ``filter``.
    return self.filter(non_empty_cond)
def test_prune_basic(self) -> None:
    """A node without data variables is dropped; a populated sibling survives unchanged."""
    source = DataTree.from_dict(
        {"/a": xr.Dataset({"foo": ("x", [1, 2])}), "/b": xr.Dataset()}
    )

    result = source.prune()

    assert "a" in result.children
    assert "b" not in result.children
    kept = result.children["a"].to_dataset()
    original = source.children["a"].to_dataset()
    assert_identical(kept, original)


def test_prune_with_zero_size_vars(self) -> None:
    """Zero-size variables keep a node by default and drop it in strict mode."""

    def build(*paths: str) -> DataTree:
        # Fresh datasets on every call so the input and expected trees
        # never share objects.
        contents = {
            "/a": xr.Dataset({"foo": ("x", [1, 2])}),
            "/b": xr.Dataset({"empty": ("dim", [])}),
            "/c": xr.Dataset(),
        }
        return DataTree.from_dict({path: contents[path] for path in paths})

    tree = build("/a", "/b", "/c")

    # Default: only the node with no data variables at all ("/c") goes away.
    assert_identical(tree.prune(), build("/a", "/b"))

    # Strict: the node whose only variable has zero size ("/b") goes too.
    assert_identical(tree.prune(drop_size_zero_vars=True), build("/a"))


def test_prune_with_intermediate_nodes(self) -> None:
    """Empty ancestors of a non-empty node are retained; empty leaves are not."""
    tree = DataTree.from_dict(
        {
            "/": xr.Dataset(),
            "/group1": xr.Dataset(),
            "/group1/subA": xr.Dataset({"temp": ("x", [1, 2])}),
            "/group1/subB": xr.Dataset(),
            "/group2": xr.Dataset({"empty": ("dim", [])}),
        }
    )

    actual = tree.prune()

    surviving = {
        "/group1/subA": xr.Dataset({"temp": ("x", [1, 2])}),
        "/group2": xr.Dataset({"empty": ("dim", [])}),
    }
    assert_identical(actual, DataTree.from_dict(surviving))


def test_prune_after_filtering(self) -> None:
    """Strict pruning removes a node left with zero-size variables by a time selection."""
    from pandas import date_range

    window = slice("2023-01-01", "2023-01-03")

    ds1 = xr.Dataset(
        {"foo": ("time", [1, 2, 3, 4, 5])},
        coords={"time": date_range("2023-01-01", periods=5, freq="D")},
    )
    # Starts after the selection window ends, so selecting leaves it empty.
    ds2 = xr.Dataset(
        {"var": ("time", [1, 2, 3, 4, 5])},
        coords={"time": date_range("2023-01-04", periods=5, freq="D")},
    )

    tree = DataTree.from_dict({"a": ds1, "b": ds2})
    selected = tree.sel(time=window)

    actual = selected.prune(drop_size_zero_vars=True)

    expected = DataTree.from_dict({"a": ds1.sel(time=window)})
    assert_identical(actual, expected)