Skip to content

Commit 4fb2c3f

Browse files
authored
keep attrs in idata/datatree conversions (#2476)
* keep attrs in idata/datatree conversions * black * update changelog
1 parent 71130f4 commit 4fb2c3f

File tree

3 files changed

+15
-2
lines changed

3 files changed

+15
-2
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
- Fix numpyro jax incompatibility. ([2465](https://github.com/arviz-devs/arviz/pull/2465))
99
- Avoid closing unloaded files in `from_netcdf()` ([2463](https://github.com/arviz-devs/arviz/issues/2463))
1010
- Fix sign error in lp parsed in from_numpyro ([2468](https://github.com/arviz-devs/arviz/issues/2468))
11+
- Fix attrs persistance in idata-datatree conversions ([2476](https://github.com/arviz-devs/arviz/issues/2476))
1112

1213
### Deprecation
1314

arviz/data/inference_data.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -541,7 +541,9 @@ def to_datatree(self):
541541
"xarray must be have DataTree in order to use InferenceData.to_datatree. "
542542
"Update to xarray>=2024.11.0"
543543
) from err
544-
return DataTree.from_dict({group: ds for group, ds in self.items()})
544+
dt = DataTree.from_dict({group: ds for group, ds in self.items()})
545+
dt.attrs = self.attrs
546+
return dt
545547

546548
@staticmethod
547549
def from_datatree(datatree):
@@ -552,7 +554,8 @@ def from_datatree(datatree):
552554
datatree : DataTree
553555
"""
554556
return InferenceData(
555-
**{group: child.to_dataset() for group, child in datatree.children.items()}
557+
attrs=datatree.attrs,
558+
**{group: child.to_dataset() for group, child in datatree.children.items()},
556559
)
557560

558561
def to_dict(self, groups=None, filter_groups=None):

arviz/tests/base_tests/test_data.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1510,6 +1510,15 @@ def test_datatree(self):
15101510
assert_identical(ds, idata_back[group])
15111511
assert all(group in dt.children for group in idata.groups())
15121512

1513+
def test_datatree_attrs(self):
1514+
idata = load_arviz_data("centered_eight")
1515+
idata.attrs = {"not": "empty"}
1516+
assert idata.attrs
1517+
dt = idata.to_datatree()
1518+
idata_back = from_datatree(dt)
1519+
assert dt.attrs == idata.attrs
1520+
assert idata_back.attrs == idata.attrs
1521+
15131522

15141523
class TestConversions:
15151524
def test_id_conversion_idempotent(self):

0 commit comments

Comments
 (0)