Skip to content

Commit 9c2e659

Browse files
committed
properly version-gate xfailing test
1 parent 9235625 commit 9c2e659

File tree

2 files changed

+24
-9
lines changed

2 files changed

+24
-9
lines changed

src/anndata/_core/merge.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,11 @@ def equal(a, b) -> bool:
122122
b = asarray(b)
123123
if a.ndim == b.ndim == 0:
124124
return bool(a == b)
125-
return np.array_equal(a, b)
125+
a_na = (
126+
pd.isna(a) if a.dtype.names is None else np.False_
127+
) # pd.isna doesn't work for record arrays
128+
b_na = pd.isna(b) if b.dtype.names is None else np.False_
129+
return np.array_equal(a_na, b_na) and np.array_equal(a[~a_na], b[~b_na])
126130

127131

128132
@equal.register(pd.DataFrame)

tests/test_concatenate.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import pandas as pd
1313
import pytest
1414
import scipy
15+
import xarray as xr
1516
from boltons.iterutils import default_exit, remap, research
1617
from numpy import ma
1718
from packaging.version import Version
@@ -146,6 +147,14 @@ def fix_known_differences(
146147
orig = orig.copy()
147148
result = result.copy()
148149

150+
if backwards_compat:
151+
del orig.varm
152+
del orig.varp
153+
if isinstance(result.obs, XDataset):
154+
result.obs = result.obs.drop_vars(["batch"])
155+
else:
156+
result.obs.drop(columns=["batch"], inplace=True)
157+
149158
for attrname in ("obs", "var"):
150159
if isinstance(getattr(result, attrname), XDataset):
151160
for adata in (orig, result):
@@ -171,11 +180,6 @@ def fix_known_differences(
171180
# * merge obsp, but some information should be lost
172181
del orig.obsp # TODO
173182

174-
if backwards_compat:
175-
del orig.varm
176-
del orig.varp
177-
result.obs.drop(columns=["batch"], inplace=True)
178-
179183
# Possibly need to fix this, ordered categoricals lose orderedness
180184
for get_df in [lambda k: k.obs, lambda k: k.obsm["df"]]:
181185
str_to_df_converted = get_df(result)
@@ -234,9 +238,6 @@ def test_concatenate_roundtrip(
234238
**GEN_ADATA_DASK_ARGS,
235239
)
236240

237-
if backwards_compat and (obs_xdataset or var_xdataset):
238-
pytest.xfail("https://github.com/pydata/xarray/issues/10218")
239-
240241
remaining = adata.obs_names
241242
subsets = []
242243
while len(remaining) > 0:
@@ -246,6 +247,16 @@ def test_concatenate_roundtrip(
246247
remaining = remaining.difference(subset_idx)
247248

248249
result = concat_func(subsets, join=join_type, uns_merge="same", index_unique=None)
250+
if (
251+
backwards_compat
252+
and (obs_xdataset or var_xdataset)
253+
and Version(xr.__version__) < Version("2025.4.0")
254+
):
255+
pytest.xfail("https://github.com/pydata/xarray/issues/10218")
256+
if backwards_compat and var_xdataset:
257+
result.var = xr.Dataset.from_dataframe(
258+
result.var
259+
) # backwards compat always returns a dataframe
249260

250261
# Correcting for known differences
251262
orig, result = fix_known_differences(

0 commit comments

Comments
 (0)