From 6c0fe74117a6dd54990d025d1182d4448c7a74d1 Mon Sep 17 00:00:00 2001 From: amalia-k510 Date: Thu, 26 Jun 2025 13:39:02 +0200 Subject: [PATCH 1/6] utils switch test --- benchmarks/benchmarks/utils.py | 77 +++++++++++++++++++--------- tests/test_backend_agnostic_utils.py | 47 +++++++++++++++++ 2 files changed, 99 insertions(+), 25 deletions(-) create mode 100644 tests/test_backend_agnostic_utils.py diff --git a/benchmarks/benchmarks/utils.py b/benchmarks/benchmarks/utils.py index 9e983498f..23f07e785 100644 --- a/benchmarks/benchmarks/utils.py +++ b/benchmarks/benchmarks/utils.py @@ -5,14 +5,32 @@ from string import ascii_lowercase from time import sleep +import jax import numpy as np import pandas as pd +from array_api_compat import get_namespace as array_api_get_namespace from memory_profiler import memory_usage from scipy import sparse from anndata import AnnData +def get_namespace(x=None): + return array_api_get_namespace(x) + + +def get_rng(xp, seed=None): + """Return a backend-specific random number generator.""" + # RNG isn't standardized in the Array API spec, + # so backends like JAX, PyTorch, and NumPy each handle randomness differently. + if xp.__name__.startswith("jax"): + return jax.random.PRNGKey(seed or 0) + elif xp.__name__.startswith("numpy"): + return np.random.default_rng(seed) + else: + raise NotImplementedError(f"RNG not implemented for backend: {xp.__name__}") + + def get_actualsize(input_obj): """Using Python Garbage Collector to calculate the size of all elements attached to an object""" @@ -40,7 +58,8 @@ def get_anndata_memsize(adata): def get_peak_mem(op, interval=0.001): recording = memory_usage(op, interval=interval) - return np.max(recording) - np.min(recording) + xp = get_namespace() + return xp.max(recording) - xp.min(recording) def sedate(func, naplength=0.05): @@ -58,7 +77,7 @@ def wrapped_function(*args, **kwargs): # TODO: Factor out the time it takes to generate these -def gen_indexer(adata, dim, index_kind, ratio): +def gen_indexer(adata, dim, index_kind, ratio, seed=None): dimnames = ("obs", "var") index_kinds = {"slice", "intarray", "boolarray", "strarray"} @@ -66,50 +85,58 @@ def gen_indexer(adata, dim, index_kind, ratio): msg = f"Argument 'index_kind' must be one of {index_kinds}. Was {index_kind}." raise ValueError(msg) + xp = get_namespace(adata.X) + rng = get_rng(xp, seed) axis = dimnames.index(dim) subset = [slice(None), slice(None)] axis_size = adata.shape[axis] + n = int(xp.round(axis_size * ratio)) if index_kind == "slice": - subset[axis] = slice(0, int(np.round(axis_size * ratio))) + subset[axis] = slice(0, n)) elif index_kind == "intarray": - subset[axis] = np.random.choice( - np.arange(axis_size), int(np.round(axis_size * ratio)), replace=False - ) - subset[axis].sort() + if xp.__name__.startswith("jax"): + subset[axis] = jax.random.choice(rng, xp.arange(axis_size), shape=(n,), replace=False) + elif xp.__name__.startswith("numpy"): + subset[axis] = xp.asarray(rng.choice(axis_size, n, replace=False)) + elif index_kind == "boolarray": - pos = np.random.choice( - np.arange(axis_size), int(np.round(axis_size * ratio)), replace=False - ) - a = np.zeros(axis_size, dtype=bool) - a[pos] = True - subset[axis] = a + mask = xp.zeros(axis_size, dtype=bool) + if xp.__name__.startswith("jax"): + idx = jax.random.choice(rng, xp.arange(axis_size), shape=(n,), replace=False) + elif xp.__name__.startswith("numpy"): + idx = rng.choice(axis_size, n, replace=False) + mask[idx] = True + subset[axis] = mask + elif index_kind == "strarray": - subset[axis] = np.random.choice( - getattr(adata, dim).index, int(np.round(axis_size * ratio)), replace=False + subset[axis] = rng.choice( + getattr(adata, dim).index, n, replace=False ) else: raise ValueError() return tuple(subset) -def gen_adata(n_obs, n_var, attr_set): +def gen_adata(n_obs, n_var, attr_set, seed=None): + xp = get_namespace() + rng = get_rng(xp, seed) if "X-csr" in attr_set: X = sparse.random(n_obs, n_var, density=0.1, format="csr") elif "X-dense" in attr_set: X = sparse.random(n_obs, n_var, density=0.1, format="csr") - X = X.toarray() + X = xp.asarray(X.toarray()) else: # TODO: There's probably a better way to do this X = sparse.random(n_obs, n_var, density=0, format="csr") adata = AnnData(X) if "obs,var" in attr_set: - adata.obs = pd.DataFrame( - {k: np.random.randint(0, 100, n_obs) for k in ascii_lowercase}, - index=[f"cell{i}" for i in range(n_obs)], - ) - adata.var = pd.DataFrame( - {k: np.random.randint(0, 100, n_var) for k in ascii_lowercase}, - index=[f"gene{i}" for i in range(n_var)], - ) + if xp.__name__.startswith("jax"): + obs = {k: jax.random.randint(rng, (n_obs,), 0, 100) for k in ascii_lowercase} + var = {k: jax.random.randint(rng, (n_var,), 0, 100) for k in ascii_lowercase} + elif xp.__name__.startswith("numpy"): + obs = {k: rng.integers(0, 100, size=n_obs) for k in ascii_lowercase} + var = {k: rng.integers(0, 100, size=n_var) for k in ascii_lowercase} + adata.obs = pd.DataFrame(obs, index=[f"cell{i}" for i in range(n_obs)]) + adata.var = pd.DataFrame(var, index=[f"gene{i}" for i in range(n_var)]) return adata diff --git a/tests/test_backend_agnostic_utils.py b/tests/test_backend_agnostic_utils.py new file mode 100644 index 000000000..7b139b23a --- /dev/null +++ b/tests/test_backend_agnostic_utils.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +import jax +import numpy as np +import pytest + +from anndata import AnnData + +from .utils import gen_adata, gen_indexer + + +@pytest.mark.parametrize("backend", ["numpy", "jax"]) +def test_gen_adata_and_indexing(backend): + # Generate AnnData using backend + if backend == "numpy": + pass # default backend used by gen_adata + elif backend == "jax": + jnp = jax.numpy + _ = jnp.ones((1,)) # ensure JAX is available and triggers JAX namespace + else: + raise ValueError(f"Unsupported backend: {backend}") + + adata: AnnData = gen_adata(100, 50, {"X-dense", "obs,var"}, seed=42) + + # Check structure + assert adata.shape == (100, 50) + assert "a" in adata.obs.columns + assert "a" in adata.var.columns + + # Test each index kind + for kind in ["slice", "intarray", "boolarray", "strarray"]: + subset = gen_indexer(adata, "obs", kind, 0.3, seed=123) + assert isinstance(subset, tuple) + assert len(subset) == 2 + + index = subset[0] + if kind == "slice": + assert isinstance(index, slice) + elif kind == "intarray": + assert hasattr(index, "shape") + assert 0 < index.shape[0] <= 100 + elif kind == "boolarray": + assert index.shape == (100,) + assert index.dtype == bool + elif kind == "strarray": + assert isinstance(index, (list, np.ndarray)) + assert all(isinstance(i, str) for i in index) From fc763d48695957714f9ea9cdbb5e1e6dce4c1653 Mon Sep 17 00:00:00 2001 From: amalia-k510 Date: Thu, 26 Jun 2025 17:34:31 +0200 Subject: [PATCH 2/6] both the testing and index.py still need debugging --- pyproject.toml | 4 + src/anndata/_core/index.py | 123 ++++++++++++++++++--------- tests/test_backend_agnostic_utils.py | 94 ++++++++++---------- 3 files changed, 136 insertions(+), 85 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 64ddb63cc..1085350f8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,6 +62,8 @@ dev = [ # runtime dev version generation "hatch-vcs", "anndata[dev-doc]", + "jax", + "jaxlib", ] doc = [ "sphinx>=8.2.1", @@ -105,6 +107,8 @@ test-min = [ test = [ "anndata[test-min,lazy]", "pandas>=2.1.0", + "jax", + "jaxlib", ] # pandas 2.1.0 needs to be specified for xarray to work with min-deps script gpu = [ "cupy" ] cu12 = [ "cupy-cuda12x" ] diff --git a/src/anndata/_core/index.py b/src/anndata/_core/index.py index 5ed271add..be2ea696d 100644 --- a/src/anndata/_core/index.py +++ b/src/anndata/_core/index.py @@ -3,11 +3,12 @@ from collections.abc import Iterable, Sequence from functools import singledispatch from itertools import repeat -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import h5py import numpy as np import pandas as pd +from array_api_compat import get_namespace as array_api_get_namespace from scipy.sparse import issparse from ..compat import AwkArray, CSArray, CSMatrix, DaskArray, XDataArray @@ -17,6 +18,40 @@ from ..compat import Index, Index1D +def get_xp(x): + # to fall back to numpy if no array API is available + try: + return array_api_get_namespace(x) + except Exception: + return np + + +def is_array_api_obj(x): + # check if object supports _array_namespace__, else fall back to numpy + return hasattr(x, "__array_namespace__") + + +def get_numeric_dtypes(xp): + return ( + xp.dtype(xp.int32), + xp.dtype(xp.int64), + xp.dtype(xp.float32), + xp.dtype(xp.float64), + ) + + +def get_integer_dtypes(xp): + return (xp.dtype(xp.int32), xp.dtype(xp.int64)) + + +def get_floating_dtypes(xp): + return (xp.dtype(xp.float32), xp.dtype(xp.float64)) + + +def get_boolean_dtype(xp): + return xp.dtype(xp.bool_) + + def _normalize_indices( index: Index | None, names0: pd.Index, names1: pd.Index ) -> tuple[slice, slice]: @@ -36,17 +71,17 @@ def _normalize_indices( def _normalize_index( # noqa: PLR0911, PLR0912 - indexer: slice - | np.integer - | int - | str - | Sequence[bool | int | np.integer] - | np.ndarray - | pd.Index, + indexer, index: pd.Index, -) -> slice | int | np.ndarray: # ndarray of int or bool +) -> ( + slice | int | Any +): # ndarray of int or bool, switched to Any to make it compatible with array API objects # TODO: why is this here? All tests pass without it and it seems at the minimum not strict enough. - if not isinstance(index, pd.RangeIndex) and index.dtype in (np.float64, np.int64): + xp = get_xp(indexer) + if not isinstance(index, pd.RangeIndex) and index.dtype in ( + xp.float64, + xp.int64, + ): msg = f"Don’t call _normalize_index with non-categorical/string names and non-range index {index}" raise TypeError(msg) @@ -65,34 +100,37 @@ def name_idx(i): stop = None if stop is None else stop + 1 step = indexer.step return slice(start, stop, step) - elif isinstance(indexer, np.integer | int): + if isinstance(indexer, (int,)) or ( + is_array_api_obj(indexer) and isinstance(indexer, xp.integer) + ): return indexer elif isinstance(indexer, str): return index.get_loc(indexer) # int elif isinstance( - indexer, Sequence | np.ndarray | pd.Index | CSMatrix | np.matrix | CSArray - ): + indexer, (Sequence, pd.Index, CSMatrix, CSArray) + ) or is_array_api_obj(indexer): if hasattr(indexer, "shape") and ( (indexer.shape == (index.shape[0], 1)) or (indexer.shape == (1, index.shape[0])) ): if isinstance(indexer, CSMatrix | CSArray): indexer = indexer.toarray() - indexer = np.ravel(indexer) - if not isinstance(indexer, np.ndarray | pd.Index): - indexer = np.array(indexer) + indexer = xp.ravel(indexer) + if not isinstance(indexer, (pd.Index,)) and not is_array_api_obj(indexer): + indexer = xp.array(indexer) if len(indexer) == 0: indexer = indexer.astype(int) - if isinstance(indexer, np.ndarray) and np.issubdtype( - indexer.dtype, np.floating - ): + + if get_xp(indexer).issubdtype(indexer.dtype, get_xp(indexer).floating): indexer_int = indexer.astype(int) - if np.all((indexer - indexer_int) != 0): + if xp.all((indexer - indexer_int) != 0): msg = f"Indexer {indexer!r} has floating point values." raise IndexError(msg) - if issubclass(indexer.dtype.type, np.integer | np.floating): + if get_xp(indexer).issubdtype( + indexer.dtype, get_xp(indexer).integer | get_xp(indexer).floating + ): return indexer # Might not work for range indexes - elif issubclass(indexer.dtype.type, np.bool_): + elif get_xp(indexer).issubdtype(indexer.dtype, get_xp(indexer).bool_): if indexer.shape != index.shape: msg = ( f"Boolean index does not match AnnData’s shape along this " @@ -103,7 +141,7 @@ def name_idx(i): return indexer else: # indexer should be string array positions = index.get_indexer(indexer) - if np.any(positions < 0): + if get_xp(positions).any(positions < 0): not_found = indexer[positions < 0] msg = ( f"Values {list(not_found)}, from {list(indexer)}, " @@ -168,11 +206,14 @@ def unpack_index(index: Index) -> tuple[Index1D, Index1D]: @singledispatch -def _subset(a: np.ndarray | pd.DataFrame, subset_idx: Index): +def _subset(a, subset_idx: Index): # Select as combination of indexes, not coordinates # Correcting for indexing behaviour of np.ndarray - if all(isinstance(x, Iterable) for x in subset_idx): - subset_idx = np.ix_(*subset_idx) + xp = get_xp(a) + if all( + isinstance(x, Iterable) and not isinstance(x, (str, bytes)) for x in subset_idx + ): + subset_idx = xp.ix_(*subset_idx) return a[subset_idx] @@ -189,10 +230,13 @@ def _subset_dask(a: DaskArray, subset_idx: Index): @_subset.register(CSArray) def _subset_sparse(a: CSMatrix | CSArray, subset_idx: Index): # Correcting for indexing behaviour of sparse.spmatrix - if len(subset_idx) > 1 and all(isinstance(x, Iterable) for x in subset_idx): + xp = get_xp(a) + if len(subset_idx) > 1 and all( + isinstance(x, Iterable) and not isinstance(x, (str, bytes)) for x in subset_idx + ): first_idx = subset_idx[0] - if issubclass(first_idx.dtype.type, np.bool_): - first_idx = np.where(first_idx)[0] + if hasattr(first_idx, "dtype") and first_idx.dtype == bool: + first_idx = xp.where(first_idx)[0] subset_idx = (first_idx.reshape(-1, 1), *subset_idx[1:]) return a[subset_idx] @@ -205,8 +249,11 @@ def _subset_df(df: pd.DataFrame | Dataset2D, subset_idx: Index): @_subset.register(AwkArray) def _subset_awkarray(a: AwkArray, subset_idx: Index): - if all(isinstance(x, Iterable) for x in subset_idx): - subset_idx = np.ix_(*subset_idx) + xp = get_xp(a) + if all( + isinstance(x, Iterable) and not isinstance(x, (str, bytes)) for x in subset_idx + ): + subset_idx = xp.ix_(*subset_idx) return a[subset_idx] @@ -215,15 +262,15 @@ def _subset_awkarray(a: AwkArray, subset_idx: Index): def _subset_dataset(d, subset_idx): if not isinstance(subset_idx, tuple): subset_idx = (subset_idx,) + xp = get_xp(subset_idx[0]) ordered = list(subset_idx) rev_order = [slice(None) for _ in range(len(subset_idx))] for axis, axis_idx in enumerate(ordered.copy()): - if isinstance(axis_idx, np.ndarray): - if axis_idx.dtype == bool: - axis_idx = np.where(axis_idx)[0] - order = np.argsort(axis_idx) - ordered[axis] = axis_idx[order] - rev_order[axis] = np.argsort(order) + if hasattr(axis_idx, "dtype") and axis_idx.dtype == bool: + axis_idx = xp.where(axis_idx)[0] + order = xp.argsort(axis_idx) + ordered[axis] = axis_idx[order] + rev_order[axis] = xp.argsort(order) # from hdf5, then to real order return d[tuple(ordered)][tuple(rev_order)] @@ -257,4 +304,4 @@ def get_vector(adata, k, coldim, idxdim, layer=None): a = adata._get_X(layer=layer)[idx] if issparse(a): a = a.toarray() - return np.ravel(a) + return get_xp(a).ravel(a) diff --git a/tests/test_backend_agnostic_utils.py b/tests/test_backend_agnostic_utils.py index 7b139b23a..0af18b6b7 100644 --- a/tests/test_backend_agnostic_utils.py +++ b/tests/test_backend_agnostic_utils.py @@ -1,47 +1,47 @@ -from __future__ import annotations - -import jax -import numpy as np -import pytest - -from anndata import AnnData - -from .utils import gen_adata, gen_indexer - - -@pytest.mark.parametrize("backend", ["numpy", "jax"]) -def test_gen_adata_and_indexing(backend): - # Generate AnnData using backend - if backend == "numpy": - pass # default backend used by gen_adata - elif backend == "jax": - jnp = jax.numpy - _ = jnp.ones((1,)) # ensure JAX is available and triggers JAX namespace - else: - raise ValueError(f"Unsupported backend: {backend}") - - adata: AnnData = gen_adata(100, 50, {"X-dense", "obs,var"}, seed=42) - - # Check structure - assert adata.shape == (100, 50) - assert "a" in adata.obs.columns - assert "a" in adata.var.columns - - # Test each index kind - for kind in ["slice", "intarray", "boolarray", "strarray"]: - subset = gen_indexer(adata, "obs", kind, 0.3, seed=123) - assert isinstance(subset, tuple) - assert len(subset) == 2 - - index = subset[0] - if kind == "slice": - assert isinstance(index, slice) - elif kind == "intarray": - assert hasattr(index, "shape") - assert 0 < index.shape[0] <= 100 - elif kind == "boolarray": - assert index.shape == (100,) - assert index.dtype == bool - elif kind == "strarray": - assert isinstance(index, (list, np.ndarray)) - assert all(isinstance(i, str) for i in index) +# from __future__ import annotations + +# import jax +# import numpy as np +# import pytest + +# from anndata import AnnData + +# from .utils import gen_adata, gen_indexer + + +# @pytest.mark.parametrize("backend", ["numpy", "jax"]) +# def test_gen_adata_and_indexing(backend): +# # Generate AnnData using backend +# if backend == "numpy": +# pass # default backend used by gen_adata +# elif backend == "jax": +# jnp = jax.numpy +# _ = jnp.ones((1,)) # ensure JAX is available and triggers JAX namespace +# else: +# raise ValueError(f"Unsupported backend: {backend}") + +# adata: AnnData = gen_adata(100, 50, {"X-dense", "obs,var"}, seed=42) + +# # Check structure +# assert adata.shape == (100, 50) +# assert "a" in adata.obs.columns +# assert "a" in adata.var.columns + +# # Test each index kind +# for kind in ["slice", "intarray", "boolarray", "strarray"]: +# subset = gen_indexer(adata, "obs", kind, 0.3, seed=123) +# assert isinstance(subset, tuple) +# assert len(subset) == 2 + +# index = subset[0] +# if kind == "slice": +# assert isinstance(index, slice) +# elif kind == "intarray": +# assert hasattr(index, "shape") +# assert 0 < index.shape[0] <= 100 +# elif kind == "boolarray": +# assert index.shape == (100,) +# assert index.dtype == bool +# elif kind == "strarray": +# assert isinstance(index, (list, np.ndarray)) +# assert all(isinstance(i, str) for i in index) From 77bf36223f5336b7d229357059ed23d335c21e73 Mon Sep 17 00:00:00 2001 From: amalia-k510 Date: Mon, 30 Jun 2025 14:29:11 +0200 Subject: [PATCH 3/6] test --- benchmarks/benchmarks/utils.py | 42 ++++++++----- tests/test_backend_agnostic_utils.py | 93 ++++++++++++++-------------- 2 files changed, 74 insertions(+), 61 deletions(-) diff --git a/benchmarks/benchmarks/utils.py b/benchmarks/benchmarks/utils.py index 23f07e785..1a11ccf72 100644 --- a/benchmarks/benchmarks/utils.py +++ b/benchmarks/benchmarks/utils.py @@ -93,47 +93,61 @@ def gen_indexer(adata, dim, index_kind, ratio, seed=None): n = int(xp.round(axis_size * ratio)) if index_kind == "slice": - subset[axis] = slice(0, n)) + subset[axis] = slice(0, n) elif index_kind == "intarray": if xp.__name__.startswith("jax"): - subset[axis] = jax.random.choice(rng, xp.arange(axis_size), shape=(n,), replace=False) + subset[axis] = jax.random.choice( + rng, xp.arange(axis_size), shape=(n,), replace=False + ) elif xp.__name__.startswith("numpy"): subset[axis] = xp.asarray(rng.choice(axis_size, n, replace=False)) elif index_kind == "boolarray": mask = xp.zeros(axis_size, dtype=bool) if xp.__name__.startswith("jax"): - idx = jax.random.choice(rng, xp.arange(axis_size), shape=(n,), replace=False) + idx = jax.random.choice( + rng, xp.arange(axis_size), shape=(n,), replace=False + ) elif xp.__name__.startswith("numpy"): idx = rng.choice(axis_size, n, replace=False) mask[idx] = True subset[axis] = mask elif index_kind == "strarray": - subset[axis] = rng.choice( - getattr(adata, dim).index, n, replace=False - ) + subset[axis] = rng.choice(getattr(adata, dim).index, n, replace=False) else: raise ValueError() return tuple(subset) def gen_adata(n_obs, n_var, attr_set, seed=None): - xp = get_namespace() - rng = get_rng(xp, seed) if "X-csr" in attr_set: - X = sparse.random(n_obs, n_var, density=0.1, format="csr") + X_sparse = sparse.random(n_obs, n_var, density=0.1, format="csr") + xp = get_namespace(X_sparse.toarray()) + X = X_sparse elif "X-dense" in attr_set: - X = sparse.random(n_obs, n_var, density=0.1, format="csr") - X = xp.asarray(X.toarray()) + dense_X = sparse.random(n_obs, n_var, density=0.1, format="csr") + xp = get_namespace(dense_X) + X = xp.asarray(dense_X) else: # TODO: There's probably a better way to do this - X = sparse.random(n_obs, n_var, density=0, format="csr") + # fallback to use just numpy + import numpy as np + + X_dense = np.zeros((n_obs, n_var)) + # X = sparse.random(n_obs, n_var, density=0, format="csr") + xp = get_namespace(X_dense) + X = xp.asarray(X_dense) + rng = get_rng(xp, seed) adata = AnnData(X) if "obs,var" in attr_set: if xp.__name__.startswith("jax"): - obs = {k: jax.random.randint(rng, (n_obs,), 0, 100) for k in ascii_lowercase} - var = {k: jax.random.randint(rng, (n_var,), 0, 100) for k in ascii_lowercase} + obs = { + k: jax.random.randint(rng, (n_obs,), 0, 100) for k in ascii_lowercase + } + var = { + k: jax.random.randint(rng, (n_var,), 0, 100) for k in ascii_lowercase + } elif xp.__name__.startswith("numpy"): obs = {k: rng.integers(0, 100, size=n_obs) for k in ascii_lowercase} var = {k: rng.integers(0, 100, size=n_var) for k in ascii_lowercase} diff --git a/tests/test_backend_agnostic_utils.py b/tests/test_backend_agnostic_utils.py index 0af18b6b7..12b3be7a4 100644 --- a/tests/test_backend_agnostic_utils.py +++ b/tests/test_backend_agnostic_utils.py @@ -1,47 +1,46 @@ -# from __future__ import annotations - -# import jax -# import numpy as np -# import pytest - -# from anndata import AnnData - -# from .utils import gen_adata, gen_indexer - - -# @pytest.mark.parametrize("backend", ["numpy", "jax"]) -# def test_gen_adata_and_indexing(backend): -# # Generate AnnData using backend -# if backend == "numpy": -# pass # default backend used by gen_adata -# elif backend == "jax": -# jnp = jax.numpy -# _ = jnp.ones((1,)) # ensure JAX is available and triggers JAX namespace -# else: -# raise ValueError(f"Unsupported backend: {backend}") - -# adata: AnnData = gen_adata(100, 50, {"X-dense", "obs,var"}, seed=42) - -# # Check structure -# assert adata.shape == (100, 50) -# assert "a" in adata.obs.columns -# assert "a" in adata.var.columns - -# # Test each index kind -# for kind in ["slice", "intarray", "boolarray", "strarray"]: -# subset = gen_indexer(adata, "obs", kind, 0.3, seed=123) -# assert isinstance(subset, tuple) -# assert len(subset) == 2 - -# index = subset[0] -# if kind == "slice": -# assert isinstance(index, slice) -# elif kind == "intarray": -# assert hasattr(index, "shape") -# assert 0 < index.shape[0] <= 100 -# elif kind == "boolarray": -# assert index.shape == (100,) -# assert index.dtype == bool -# elif kind == "strarray": -# assert isinstance(index, (list, np.ndarray)) -# assert all(isinstance(i, str) for i in index) +from __future__ import annotations + +import jax +import numpy as np +import pytest +from benchmarks.benchmarks.utils import gen_adata, gen_indexer + +from anndata import AnnData + + +@pytest.mark.parametrize("backend", ["numpy", "jax"]) +def test_gen_adata_and_indexing(backend): + # Generate AnnData using backend + if backend == "numpy": + pass # default backend used by gen_adata + elif backend == "jax": + jnp = jax.numpy + _ = jnp.ones((1,)) + else: + raise ValueError(f"Unsupported backend: {backend}") + + adata: AnnData = gen_adata(100, 50, {"X-dense", "obs,var"}, seed=42) + + # Check structure + assert adata.shape == (100, 50) + assert "a" in adata.obs.columns + assert "a" in adata.var.columns + + # Test each index kind + for kind in ["slice", "intarray", "boolarray", "strarray"]: + subset = gen_indexer(adata, "obs", kind, 0.3, seed=123) + assert isinstance(subset, tuple) + assert len(subset) == 2 + + index = subset[0] + if kind == "slice": + assert isinstance(index, slice) + elif kind == "intarray": + assert hasattr(index, "shape") + assert 0 < index.shape[0] <= 100 + elif kind == "boolarray": + assert index.shape == (100,) + assert index.dtype == bool + elif kind == "strarray": + assert isinstance(index, (list, np.ndarray)) + assert all(isinstance(i, str) for i in index) From 61ae62a9f619a042099cff27f66eb571b1c53a93 Mon Sep 17 00:00:00 2001 From: amalia-k510 Date: Thu, 3 Jul 2025 14:45:08 +0200 Subject: [PATCH 4/6] reverting some changes and first changes to helpers. --- benchmarks/benchmarks/utils.py | 257 ++++++++++++++++++++++++--------- src/anndata/tests/helpers.py | 45 +++++- 2 files changed, 228 insertions(+), 74 deletions(-) diff --git a/benchmarks/benchmarks/utils.py b/benchmarks/benchmarks/utils.py index 1a11ccf72..352388570 100644 --- a/benchmarks/benchmarks/utils.py +++ b/benchmarks/benchmarks/utils.py @@ -5,32 +5,14 @@ from string import ascii_lowercase from time import sleep -import jax import numpy as np import pandas as pd -from array_api_compat import get_namespace as array_api_get_namespace from memory_profiler import memory_usage from scipy import sparse from anndata import AnnData -def get_namespace(x=None): - return array_api_get_namespace(x) - - -def get_rng(xp, seed=None): - """Return a backend-specific random number generator.""" - # RNG isn't standardized in the Array API spec, - # so backends like JAX, PyTorch, and NumPy each handle randomness differently. - if xp.__name__.startswith("jax"): - return jax.random.PRNGKey(seed or 0) - elif xp.__name__.startswith("numpy"): - return np.random.default_rng(seed) - else: - raise NotImplementedError(f"RNG not implemented for backend: {xp.__name__}") - - def get_actualsize(input_obj): """Using Python Garbage Collector to calculate the size of all elements attached to an object""" @@ -58,8 +40,7 @@ def get_anndata_memsize(adata): def get_peak_mem(op, interval=0.001): recording = memory_usage(op, interval=interval) - xp = get_namespace() - return xp.max(recording) - xp.min(recording) + return np.max(recording) - np.min(recording) def sedate(func, naplength=0.05): @@ -77,7 +58,7 @@ def wrapped_function(*args, **kwargs): # TODO: Factor out the time it takes to generate these -def gen_indexer(adata, dim, index_kind, ratio, seed=None): +def gen_indexer(adata, dim, index_kind, ratio): dimnames = ("obs", "var") index_kinds = {"slice", "intarray", "boolarray", "strarray"} @@ -85,72 +66,208 @@ def gen_indexer(adata, dim, index_kind, ratio, seed=None): msg = f"Argument 'index_kind' must be one of {index_kinds}. Was {index_kind}." raise ValueError(msg) - xp = get_namespace(adata.X) - rng = get_rng(xp, seed) axis = dimnames.index(dim) subset = [slice(None), slice(None)] axis_size = adata.shape[axis] - n = int(xp.round(axis_size * ratio)) if index_kind == "slice": - subset[axis] = slice(0, n) + subset[axis] = slice(0, int(np.round(axis_size * ratio))) elif index_kind == "intarray": - if xp.__name__.startswith("jax"): - subset[axis] = jax.random.choice( - rng, xp.arange(axis_size), shape=(n,), replace=False - ) - elif xp.__name__.startswith("numpy"): - subset[axis] = xp.asarray(rng.choice(axis_size, n, replace=False)) - + subset[axis] = np.random.choice( + np.arange(axis_size), int(np.round(axis_size * ratio)), replace=False + ) + subset[axis].sort() elif index_kind == "boolarray": - mask = xp.zeros(axis_size, dtype=bool) - if xp.__name__.startswith("jax"): - idx = jax.random.choice( - rng, xp.arange(axis_size), shape=(n,), replace=False - ) - elif xp.__name__.startswith("numpy"): - idx = rng.choice(axis_size, n, replace=False) - mask[idx] = True - subset[axis] = mask - + pos = np.random.choice( + np.arange(axis_size), int(np.round(axis_size * ratio)), replace=False + ) + a = np.zeros(axis_size, dtype=bool) + a[pos] = True + subset[axis] = a elif index_kind == "strarray": - subset[axis] = rng.choice(getattr(adata, dim).index, n, replace=False) + subset[axis] = np.random.choice( + getattr(adata, dim).index, int(np.round(axis_size * ratio)), replace=False + ) else: raise ValueError() return tuple(subset) -def gen_adata(n_obs, n_var, attr_set, seed=None): +def gen_adata(n_obs, n_var, attr_set): if "X-csr" in attr_set: - X_sparse = sparse.random(n_obs, n_var, density=0.1, format="csr") - xp = get_namespace(X_sparse.toarray()) - X = X_sparse + X = sparse.random(n_obs, n_var, density=0.1, format="csr") elif "X-dense" in attr_set: - dense_X = sparse.random(n_obs, n_var, density=0.1, format="csr") - xp = get_namespace(dense_X) - X = xp.asarray(dense_X) + X = sparse.random(n_obs, n_var, density=0.1, format="csr") + X = X.toarray() else: # TODO: There's probably a better way to do this - # fallback to use just numpy - import numpy as np - - X_dense = np.zeros((n_obs, n_var)) - # X = sparse.random(n_obs, n_var, density=0, format="csr") - xp = get_namespace(X_dense) - X = xp.asarray(X_dense) - rng = get_rng(xp, seed) + X = sparse.random(n_obs, n_var, density=0, format="csr") adata = AnnData(X) if "obs,var" in attr_set: - if xp.__name__.startswith("jax"): - obs = { - k: jax.random.randint(rng, (n_obs,), 0, 100) for k in ascii_lowercase - } - var = { - k: jax.random.randint(rng, (n_var,), 0, 100) for k in ascii_lowercase - } - elif xp.__name__.startswith("numpy"): - obs = {k: rng.integers(0, 100, size=n_obs) for k in ascii_lowercase} - var = {k: rng.integers(0, 100, size=n_var) for k in ascii_lowercase} - adata.obs = pd.DataFrame(obs, index=[f"cell{i}" for i in range(n_obs)]) - adata.var = pd.DataFrame(var, index=[f"gene{i}" for i in range(n_var)]) + adata.obs = pd.DataFrame( + {k: np.random.randint(0, 100, n_obs) for k in ascii_lowercase}, + index=[f"cell{i}" for i in range(n_obs)], + ) + adata.var = pd.DataFrame( + {k: np.random.randint(0, 100, n_var) for k in ascii_lowercase}, + index=[f"gene{i}" for i in range(n_var)], + ) return adata + + +# from __future__ import annotations + +# import gc +# import sys +# from string import ascii_lowercase +# from time import sleep + +# import jax +# import numpy as np +# import pandas as pd +# from array_api_compat import get_namespace as array_api_get_namespace +# from memory_profiler import memory_usage +# from scipy import sparse + +# from anndata import AnnData + + +# def get_namespace(x=None): +# return array_api_get_namespace(x) + + +# def get_rng(xp, seed=None): +# """Return a backend-specific random number generator.""" +# # RNG isn't standardized in the Array API spec, +# # so backends like JAX, PyTorch, and NumPy each handle randomness differently. +# if xp.__name__.startswith("jax"): +# return jax.random.PRNGKey(seed or 0) +# elif xp.__name__.startswith("numpy"): +# return np.random.default_rng(seed) +# else: +# raise NotImplementedError(f"RNG not implemented for backend: {xp.__name__}") + + +# def get_actualsize(input_obj): +# """Using Python Garbage Collector to calculate the size of all elements attached to an object""" + +# memory_size = 0 +# ids = set() +# objects = [input_obj] +# while objects: +# new = [] +# for obj in objects: +# if id(obj) not in ids: +# ids.add(id(obj)) +# memory_size += sys.getsizeof(obj) +# new.append(obj) +# objects = gc.get_referents(*new) +# return memory_size + + +# def get_anndata_memsize(adata): +# recording = memory_usage( +# (sedate(adata.copy, naplength=0.005), (adata,)), interval=0.001 +# ) +# diff = recording[-1] - recording[0] +# return diff + + +# def get_peak_mem(op, interval=0.001): +# recording = memory_usage(op, interval=interval) +# xp = get_namespace() +# return xp.max(recording) - xp.min(recording) + + +# def sedate(func, naplength=0.05): +# """Make a function sleepy, so we can sample the start and end state.""" + +# def wrapped_function(*args, **kwargs): +# sleep(naplength) +# val = func(*args, **kwargs) +# sleep(naplength) +# return val + +# return wrapped_function + + +# # TODO: Factor out the time it takes to generate these + + +# def gen_indexer(adata, dim, index_kind, ratio, seed=None): +# dimnames = ("obs", "var") +# index_kinds = {"slice", "intarray", "boolarray", "strarray"} + +# if index_kind not in index_kinds: +# msg = f"Argument 'index_kind' must be one of {index_kinds}. Was {index_kind}." +# raise ValueError(msg) + +# xp = get_namespace(adata.X) +# rng = get_rng(xp, seed) +# axis = dimnames.index(dim) +# subset = [slice(None), slice(None)] +# axis_size = adata.shape[axis] +# n = int(xp.round(axis_size * ratio)) + +# if index_kind == "slice": +# subset[axis] = slice(0, n) +# elif index_kind == "intarray": +# if xp.__name__.startswith("jax"): +# subset[axis] = jax.random.choice( +# rng, xp.arange(axis_size), shape=(n,), replace=False +# ) +# elif xp.__name__.startswith("numpy"): +# subset[axis] = xp.asarray(rng.choice(axis_size, n, replace=False)) + +# elif index_kind == "boolarray": +# mask = xp.zeros(axis_size, dtype=bool) +# if xp.__name__.startswith("jax"): +# idx = jax.random.choice( +# rng, xp.arange(axis_size), shape=(n,), replace=False +# ) +# elif xp.__name__.startswith("numpy"): +# idx = rng.choice(axis_size, n, replace=False) +# mask[idx] = True +# subset[axis] = mask + +# elif index_kind == "strarray": +# subset[axis] = rng.choice(getattr(adata, dim).index, n, replace=False) +# else: +# raise ValueError() +# return tuple(subset) + + +# def gen_adata(n_obs, n_var, attr_set, seed=None): +# if "X-csr" in attr_set: +# X_sparse = sparse.random(n_obs, n_var, density=0.1, format="csr") +# xp = get_namespace(X_sparse.toarray()) +# X = X_sparse +# elif "X-dense" in attr_set: +# dense_X = sparse.random(n_obs, n_var, density=0.1, format="csr") +# xp = get_namespace(dense_X) +# X = xp.asarray(dense_X) +# else: +# # TODO: There's probably a better way to do this +# # fallback to use just numpy +# import numpy as np + +# X_dense = np.zeros((n_obs, n_var)) +# # X = sparse.random(n_obs, n_var, density=0, format="csr") +# xp = get_namespace(X_dense) +# X = xp.asarray(X_dense) +# rng = get_rng(xp, seed) +# adata = AnnData(X) +# if "obs,var" in attr_set: +# if xp.__name__.startswith("jax"): +# obs = { +# k: jax.random.randint(rng, (n_obs,), 0, 100) for k in ascii_lowercase +# } +# var = { +# k: jax.random.randint(rng, (n_var,), 0, 100) for k in ascii_lowercase +# } +# elif xp.__name__.startswith("numpy"): +# obs = {k: rng.integers(0, 100, size=n_obs) for k in ascii_lowercase} +# var = {k: rng.integers(0, 100, size=n_var) for k in ascii_lowercase} +# adata.obs = pd.DataFrame(obs, index=[f"cell{i}" for i in range(n_obs)]) +# adata.var = pd.DataFrame(var, index=[f"gene{i}" for i in range(n_var)]) +# return adata diff --git a/src/anndata/tests/helpers.py b/src/anndata/tests/helpers.py index 54906ab6b..069d330fd 100644 --- a/src/anndata/tests/helpers.py +++ b/src/anndata/tests/helpers.py @@ -10,10 +10,19 @@ from string import ascii_letters from typing import TYPE_CHECKING +# trying to import jax as it is not part of the default dependencies +try: + import jax.numpy as jnp + + JAX_AVAILABLE = True +except ImportError: + JAX_AVAILABLE = False + import h5py import numpy as np import pandas as pd import pytest +from array_api_compat import get_namespace as array_api_get_namespace from pandas.api.types import is_numeric_dtype from scipy import sparse @@ -50,7 +59,6 @@ DT = TypeVar("DT") - try: from pandas.core.arrays.integer import IntegerDtype except ImportError: @@ -66,6 +74,9 @@ pd.DataFrame, sparse.csr_array, ) +# Add JAX array type to supported backends if JAX is installed +if JAX_AVAILABLE: + DEFAULT_KEY_TYPES += (jnp.ndarray,) DEFAULT_COL_TYPES = ( @@ -79,18 +90,44 @@ pd.Int32Dtype, ) - -# Give this to gen_adata when dask array support is expected. +# preset for testing with Dask arrays +# includes DaskArray in obsm/varm/layers to test lazy evaluation and chunked storage +# useful when testing backend="backed"/"memory"/"zarr" GEN_ADATA_DASK_ARGS = dict( obsm_types=(*DEFAULT_KEY_TYPES, DaskArray), varm_types=(*DEFAULT_KEY_TYPES, DaskArray), layers_types=(*DEFAULT_KEY_TYPES, DaskArray), ) +# for testing without xarray (XDataset) types +# excludes XDataset to avoid xarray dependency or reduce test complexity +# useful for minimal tests or when xarray is not available GEN_ADATA_NO_XARRAY_ARGS = dict( - obsm_types=(*DEFAULT_KEY_TYPES, AwkArray), varm_types=(*DEFAULT_KEY_TYPES, AwkArray) + obsm_types=(*DEFAULT_KEY_TYPES, AwkArray), + varm_types=(*DEFAULT_KEY_TYPES, AwkArray), ) +# Optional: define args for JAX-specific backend tests +# preset for testing with JAX arrays +# useful when running tests using jax.numpy as the backend for X, obsm, varm, or layers. +GEN_ADATA_JAX_ARGS = ( + dict( + obsm_types=(*DEFAULT_KEY_TYPES,), + varm_types=(*DEFAULT_KEY_TYPES,), + layers_types=(*DEFAULT_KEY_TYPES,), + ) + if JAX_AVAILABLE + else {} +) + + +def get_xp(x): + try: + return array_api_get_namespace(x) + except ImportError: + # default to numpy if array_api_compat is not installed + return np + def gen_vstr_recarray(m, n, dtype=None): size = m * n From ae9311a4f86f2addeb003de02be50d57caea6e89 Mon Sep 17 00:00:00 2001 From: amalia-k510 Date: Thu, 3 Jul 2025 17:11:38 +0200 Subject: [PATCH 5/6] in debug process --- src/anndata/tests/helpers.py | 272 ++++++++++++++++++++++++++--------- 1 file changed, 206 insertions(+), 66 deletions(-) diff --git a/src/anndata/tests/helpers.py b/src/anndata/tests/helpers.py index 069d330fd..58b33d3d8 100644 --- a/src/anndata/tests/helpers.py +++ b/src/anndata/tests/helpers.py @@ -12,6 +12,7 @@ # trying to import jax as it is not part of the default dependencies try: + import jax import jax.numpy as jnp JAX_AVAILABLE = True @@ -124,7 +125,8 @@ def get_xp(x): try: return array_api_get_namespace(x) - except ImportError: + # in case if getting other errors but ImportError + except Exception: # noqa: BLE001 # default to numpy if array_api_compat is not installed return np @@ -140,6 +142,54 @@ def gen_vstr_recarray(m, n, dtype=None): ) +# def gen_vstr_recarray( +# m: int, +# n: int, +# dtype=None, +# *, +# xp=np, # Backend namespace (e.g., np or jax.numpy) +# rng=None, # NumPy: np.random.Generator | JAX: (jax.random, key) +# ): +# """ +# Supports both NumPy and JAX via `xp`. Returns a NumPy-based recarray for compatibility +# with AnnData.uns (need to double check) +# """ + +# size = m * n +# letters = xp.array(list(ascii_letters)) +# if xp.__name__.startswith("jax"): +# jrandom, key = rng +# key, subkey = jrandom.split(key) +# lengths = jrandom.randint(subkey, shape=(size,), minval=3, maxval=5) +# lengths = xp.asarray(lengths) +# else: +# if rng is None: +# rng = np.random.default_rng() +# lengths = rng.integers(3, 5, size) + +# # generating worlds in different ways depending on the backend +# def gen_word_jax(l: int, key, jrandom, letters): +# key, subkey = jrandom.split(key) +# idxs = jrandom.randint(subkey, shape=(l,), minval=0, maxval=len(letters)) +# return "".join([ascii_letters[int(i)] for i in idxs]), key + +# def gen_word_numpy(l: int, rng, letters): +# return "".join(rng.choice(letters, l)) + +# words = [] +# if xp.__name__.startswith("jax"): +# for l in lengths: +# word, key = gen_word_jax(int(l), key, jrandom, letters) +# words.append(word) +# else: +# words.extend([gen_word_numpy(l, rng, letters) for l in lengths]) + +# arr = np.array(words).reshape(m, n) # Recarray must be NumPy-compatible +# columns = [f"col_{i}" for i in range(n)] +# df = pd.DataFrame(arr, columns=columns) +# return df.to_records(index=False, column_dtypes=dtype) + + def issubdtype( a: np.dtype | pd.api.extensions.ExtensionDtype | type, b: type[DT] | tuple[type[DT], ...], @@ -322,6 +372,8 @@ def gen_adata( # noqa: PLR0913 layers_types: Collection[type] = DEFAULT_KEY_TYPES, random_state: np.random.Generator | None = None, sparse_fmt: Literal["csr", "csc"] = "csr", + xp=None, # setting default to None for now but maybe np is better (review) + rng=None, ) -> AnnData: """\ Helper function to generate a random AnnData for testing purposes. @@ -351,12 +403,31 @@ def gen_adata( # noqa: PLR0913 """ import dask.array as da + from anndata.experimental import gen_vstr_recarray + if random_state is None: random_state = np.random.default_rng() M, N = shape obs_names = pd.Index(f"cell{i}" for i in range(shape[0])) var_names = pd.Index(f"gene{i}" for i in range(shape[1])) + + # initialize backends + if xp is None: + xp = get_xp(X_type(xp.ones((1, 1), dtype=X_dtype))) + if rng is None: + if xp.__name__.startswith("jax"): + rng = (jax.random, jax.random.PRNGKey(42)) + else: + rng = np.random.default_rng() + + is_jax = xp.__name__.startswith("jax") + if is_jax: + jrandom, key = rng + else: + jrandom, key = None, None + + # generate obs and var dataframes obs = gen_typed_df(M, obs_names, dtypes=obs_dtypes) var = gen_typed_df(N, var_names, dtypes=var_dtypes) # For #147 @@ -369,91 +440,160 @@ def gen_adata( # noqa: PLR0913 if var_xdataset: var = XDataset.from_dataframe(var) + # generating X and including jax support if X_type is None: - X = None + X = None # if no data matrix is requested, skip creation + elif is_jax and issubclass(X_type, sparse.spmatrix): + # JAX does not support scipy sparse matrices (double check) + msg = "JAX does not support sparse matrices" + raise ValueError(msg) + else: + if is_jax: + # split the JAX PRNG key to create a new, independent subkey + key, subkey = jrandom.split(key) + X_array = jrandom.binomial(subkey, 100.0, 0.005, shape=(M, N)).astype( + X_dtype + ) + rng = (jrandom, key) + else: + # if using numpy, generate a binomial distribution + X_array = rng.binomial(100, 0.005, (M, N)).astype(X_dtype) + X = X_type(X_array) + + # Generate obsm + if is_jax: + key, subkey = jrandom.split(key) + obsm_array = jrandom.uniform(subkey, shape=(M, 50)) + rng = (jrandom, key) else: - X = X_type(random_state.binomial(100, 0.005, (M, N)).astype(X_dtype)) + obsm_array = rng.random((M, 50)) + # sparse was moved out due to JAX compatibility issues obsm = dict( - array=np.random.random((M, 50)), - sparse=sparse.random(M, 100, format=sparse_fmt, random_state=random_state), + array=obsm_array, df=gen_typed_df(M, obs_names, dtypes=obs_dtypes), awk_2d_ragged=gen_awkward((M, None)), da=da.random.random((M, 50)), ) + + if not is_jax: + obsm["sparse"] = sparse.random( + M, 100, format=sparse_fmt, random_state=random_state + ) + if has_xr: + obsm["xdataset"] = XDataset.from_dataframe( + gen_typed_df(M, obs_names, dtypes=obs_dtypes) + ) + obsm = {k: v for k, v in obsm.items() if type(v) in obsm_types} + + # generating varm + if is_jax: + key, subkey = jrandom.split(key) + varm_array = jrandom.uniform(subkey, shape=(N, 50)) + rng = (jrandom, key) + else: + varm_array = rng.random((N, 50)) + varm = dict( - array=np.random.random((N, 50)), - sparse=sparse.random(N, 100, format=sparse_fmt, random_state=random_state), + array=varm_array, df=gen_typed_df(N, var_names, dtypes=var_dtypes), awk_2d_ragged=gen_awkward((N, None)), da=da.random.random((N, 50)), ) - if has_xr: - obsm["xdataset"] = XDataset.from_dataframe( - gen_typed_df(M, obs_names, dtypes=obs_dtypes) + if not is_jax: + varm["sparse"] = sparse.random( + N, 100, format=sparse_fmt, random_state=random_state ) + if has_xr: varm["xdataset"] = XDataset.from_dataframe( gen_typed_df(N, var_names, dtypes=var_dtypes) ) + varm = {k: v for k, v in varm.items() if type(v) in varm_types} + + if has_xr: + if XDataset in obsm_types: + obsm["xdataset"] = XDataset.from_dataframe( + gen_typed_df(M, obs_names, dtypes=obs_dtypes) + ) + if XDataset in varm_types: + varm["xdataset"] = XDataset.from_dataframe( + gen_typed_df(N, var_names, dtypes=var_dtypes) + ) obsm = {k: v for k, v in obsm.items() if type(v) in obsm_types} - obsm = maybe_add_sparse_array( - mapping=obsm, - types=obsm_types, - format=sparse_fmt, - random_state=random_state, - shape=(M, 100), - ) + # JAX does not support scipy.sparse, so skip this step to avoid compatibility issues + if not is_jax: + obsm = maybe_add_sparse_array( + mapping=obsm, + types=obsm_types, + format=sparse_fmt, + random_state=random_state, + shape=(M, 100), + ) varm = {k: v for k, v in varm.items() if type(v) in varm_types} - varm = maybe_add_sparse_array( - mapping=varm, - types=varm_types, - format=sparse_fmt, - random_state=random_state, - shape=(N, 100), - ) - layers = dict( - array=np.random.random((M, N)), - sparse=sparse.random(M, N, format=sparse_fmt, random_state=random_state), - da=da.random.random((M, N)), - ) - layers = maybe_add_sparse_array( - mapping=layers, - types=layers_types, - format=sparse_fmt, - random_state=random_state, - shape=(M, N), - ) + if not is_jax: + varm = maybe_add_sparse_array( + mapping=varm, + types=varm_types, + format=sparse_fmt, + random_state=random_state, + shape=(N, 100), + ) + + if is_jax: + key, subkey = jrandom.split(key) + layer_array = jrandom.uniform(subkey, shape=(M, N)) + rng = (jrandom, key) + else: + layer_array = rng.random((M, N)) + + layers = dict(array=layer_array, da=da.random.random((M, N))) + if not is_jax: + layers["sparse"] = sparse.random( + M, N, format=sparse_fmt, random_state=random_state + ) + layers = maybe_add_sparse_array( + mapping=layers, + types=layers_types, + format=sparse_fmt, + random_state=random_state, + shape=(M, N), + ) layers = {k: v for k, v in layers.items() if type(v) in layers_types} - obsp = dict( - array=np.random.random((M, M)), - sparse=sparse.random(M, M, format=sparse_fmt, random_state=random_state), - ) - obsp["sparse_array"] = sparse.csr_array( - sparse.random(M, M, format=sparse_fmt, random_state=random_state) - ) - varp = dict( - array=np.random.random((N, N)), - sparse=sparse.random(N, N, format=sparse_fmt, random_state=random_state), - ) - varp["sparse_array"] = sparse.csr_array( - sparse.random(N, N, format=sparse_fmt, random_state=random_state) - ) - uns = dict( - O_recarray=gen_vstr_recarray(N, 5), - nested=dict( - scalar_str="str", - scalar_int=42, - scalar_float=3.0, - nested_further=dict(array=np.arange(5)), - ), - awkward_regular=gen_awkward((10, 5)), - awkward_ragged=gen_awkward((12, None, None)), - # U_recarray=gen_vstr_recarray(N, 5, "U4") - ) - # https://github.com/zarr-developers/zarr-python/issues/2134 - # zarr v3 on-disk does not write structured dtypes - if anndata.settings.zarr_write_format == 3: - del uns["O_recarray"] + + # skiping obsp/vasp and uns if using JAX + obsp, varp, uns = {}, {}, {} + if not is_jax: + obsp = dict( + array=np.random.random((M, M)), + sparse=sparse.random(M, M, format=sparse_fmt, random_state=random_state), + ) + obsp["sparse_array"] = sparse.csr_array( + sparse.random(M, M, format=sparse_fmt, random_state=random_state) + ) + varp = dict( + array=np.random.random((N, N)), + sparse=sparse.random(N, N, format=sparse_fmt, random_state=random_state), + ) + varp["sparse_array"] = sparse.csr_array( + sparse.random(N, N, format=sparse_fmt, random_state=random_state) + ) + uns = dict( + O_recarray=gen_vstr_recarray(N, 5), + nested=dict( + scalar_str="str", + scalar_int=42, + scalar_float=3.0, + nested_further=dict(array=np.arange(5)), + ), + awkward_regular=gen_awkward((10, 5)), + awkward_ragged=gen_awkward((12, None, None)), + # U_recarray=gen_vstr_recarray(N, 5, "U4") + ) + # https://github.com/zarr-developers/zarr-python/issues/2134 + # zarr v3 on-disk does not write structured dtypes + if anndata.settings.zarr_write_format == 3: + del uns["O_recarray"] + # anndata constuction with warnings.catch_warnings(): warnings.simplefilter("ignore", ExperimentalFeatureWarning) adata = AnnData( From f5c387c514c2b9045bc1186acf21eb10040d8929 Mon Sep 17 00:00:00 2001 From: amalia-k510 Date: Mon, 7 Jul 2025 10:50:42 +0200 Subject: [PATCH 6/6] tests for gen_adata in helpers to make sure it works plus bugs fix and reverting some files to the original --- src/anndata/_core/index.py | 123 ++++------- src/anndata/_core/index_modified.py | 307 +++++++++++++++++++++++++++ src/anndata/tests/helpers.py | 59 +---- tests/test_backend_agnostic_utils.py | 154 ++++++++++---- 4 files changed, 465 insertions(+), 178 deletions(-) create mode 100644 src/anndata/_core/index_modified.py diff --git a/src/anndata/_core/index.py b/src/anndata/_core/index.py index be2ea696d..5ed271add 100644 --- a/src/anndata/_core/index.py +++ b/src/anndata/_core/index.py @@ -3,12 +3,11 @@ from collections.abc import Iterable, Sequence from functools import singledispatch from itertools import repeat -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING import h5py import numpy as np import pandas as pd -from array_api_compat import get_namespace as array_api_get_namespace from scipy.sparse import issparse from ..compat import AwkArray, CSArray, CSMatrix, DaskArray, XDataArray @@ -18,40 +17,6 @@ from ..compat import Index, Index1D -def get_xp(x): - # to fall back to numpy if no array API is available - try: - return array_api_get_namespace(x) - except Exception: - return np - - -def is_array_api_obj(x): - # check if object supports _array_namespace__, else fall back to numpy - return hasattr(x, "__array_namespace__") - - -def get_numeric_dtypes(xp): - return ( - xp.dtype(xp.int32), - xp.dtype(xp.int64), - xp.dtype(xp.float32), - xp.dtype(xp.float64), - ) - - -def get_integer_dtypes(xp): - return (xp.dtype(xp.int32), xp.dtype(xp.int64)) - - -def get_floating_dtypes(xp): - return (xp.dtype(xp.float32), xp.dtype(xp.float64)) - - -def get_boolean_dtype(xp): - return xp.dtype(xp.bool_) - - def _normalize_indices( index: Index | None, names0: pd.Index, names1: pd.Index ) -> tuple[slice, slice]: @@ -71,17 +36,17 @@ def _normalize_indices( def _normalize_index( # noqa: PLR0911, PLR0912 - indexer, + indexer: slice + | np.integer + | int + | str + | Sequence[bool | int | np.integer] + | np.ndarray + | pd.Index, index: pd.Index, -) -> ( - slice | int | Any -): # ndarray of int or bool, switched to Any to make it compatible with array API objects +) -> slice | int | np.ndarray: # ndarray of int or bool # TODO: why is this here? All tests pass without it and it seems at the minimum not strict enough. - xp = get_xp(indexer) - if not isinstance(index, pd.RangeIndex) and index.dtype in ( - xp.float64, - xp.int64, - ): + if not isinstance(index, pd.RangeIndex) and index.dtype in (np.float64, np.int64): msg = f"Don’t call _normalize_index with non-categorical/string names and non-range index {index}" raise TypeError(msg) @@ -100,37 +65,34 @@ def name_idx(i): stop = None if stop is None else stop + 1 step = indexer.step return slice(start, stop, step) - if isinstance(indexer, (int,)) or ( - is_array_api_obj(indexer) and isinstance(indexer, xp.integer) - ): + elif isinstance(indexer, np.integer | int): return indexer elif isinstance(indexer, str): return index.get_loc(indexer) # int elif isinstance( - indexer, (Sequence, pd.Index, CSMatrix, CSArray) - ) or is_array_api_obj(indexer): + indexer, Sequence | np.ndarray | pd.Index | CSMatrix | np.matrix | CSArray + ): if hasattr(indexer, "shape") and ( (indexer.shape == (index.shape[0], 1)) or (indexer.shape == (1, index.shape[0])) ): if isinstance(indexer, CSMatrix | CSArray): indexer = indexer.toarray() - indexer = xp.ravel(indexer) - if not isinstance(indexer, (pd.Index,)) and not is_array_api_obj(indexer): - indexer = xp.array(indexer) + indexer = np.ravel(indexer) + if not isinstance(indexer, np.ndarray | pd.Index): + indexer = np.array(indexer) if len(indexer) == 0: indexer = indexer.astype(int) - - if get_xp(indexer).issubdtype(indexer.dtype, get_xp(indexer).floating): + if isinstance(indexer, np.ndarray) and np.issubdtype( + indexer.dtype, np.floating + ): indexer_int = indexer.astype(int) - if xp.all((indexer - indexer_int) != 0): + if np.all((indexer - indexer_int) != 0): msg = f"Indexer {indexer!r} has floating point values." raise IndexError(msg) - if get_xp(indexer).issubdtype( - indexer.dtype, get_xp(indexer).integer | get_xp(indexer).floating - ): + if issubclass(indexer.dtype.type, np.integer | np.floating): return indexer # Might not work for range indexes - elif get_xp(indexer).issubdtype(indexer.dtype, get_xp(indexer).bool_): + elif issubclass(indexer.dtype.type, np.bool_): if indexer.shape != index.shape: msg = ( f"Boolean index does not match AnnData’s shape along this " @@ -141,7 +103,7 @@ def name_idx(i): return indexer else: # indexer should be string array positions = index.get_indexer(indexer) - if get_xp(positions).any(positions < 0): + if np.any(positions < 0): not_found = indexer[positions < 0] msg = ( f"Values {list(not_found)}, from {list(indexer)}, " @@ -206,14 +168,11 @@ def unpack_index(index: Index) -> tuple[Index1D, Index1D]: @singledispatch -def _subset(a, subset_idx: Index): +def _subset(a: np.ndarray | pd.DataFrame, subset_idx: Index): # Select as combination of indexes, not coordinates # Correcting for indexing behaviour of np.ndarray - xp = get_xp(a) - if all( - isinstance(x, Iterable) and not isinstance(x, (str, bytes)) for x in subset_idx - ): - subset_idx = xp.ix_(*subset_idx) + if all(isinstance(x, Iterable) for x in subset_idx): + subset_idx = np.ix_(*subset_idx) return a[subset_idx] @@ -230,13 +189,10 @@ def _subset_dask(a: DaskArray, subset_idx: Index): @_subset.register(CSArray) def _subset_sparse(a: CSMatrix | CSArray, subset_idx: Index): # Correcting for indexing behaviour of sparse.spmatrix - xp = get_xp(a) - if len(subset_idx) > 1 and all( - isinstance(x, Iterable) and not isinstance(x, (str, bytes)) for x in subset_idx - ): + if len(subset_idx) > 1 and all(isinstance(x, Iterable) for x in subset_idx): first_idx = subset_idx[0] - if hasattr(first_idx, "dtype") and first_idx.dtype == bool: - first_idx = xp.where(first_idx)[0] + if issubclass(first_idx.dtype.type, np.bool_): + first_idx = np.where(first_idx)[0] subset_idx = (first_idx.reshape(-1, 1), *subset_idx[1:]) return a[subset_idx] @@ -249,11 +205,8 @@ def _subset_df(df: pd.DataFrame | Dataset2D, subset_idx: Index): @_subset.register(AwkArray) def _subset_awkarray(a: AwkArray, subset_idx: Index): - xp = get_xp(a) - if all( - isinstance(x, Iterable) and not isinstance(x, (str, bytes)) for x in subset_idx - ): - subset_idx = xp.ix_(*subset_idx) + if all(isinstance(x, Iterable) for x in subset_idx): + subset_idx = np.ix_(*subset_idx) return a[subset_idx] @@ -262,15 +215,15 @@ def _subset_awkarray(a: AwkArray, subset_idx: Index): def _subset_dataset(d, subset_idx): if not isinstance(subset_idx, tuple): subset_idx = (subset_idx,) - xp = get_xp(subset_idx[0]) ordered = list(subset_idx) rev_order = [slice(None) for _ in range(len(subset_idx))] for axis, axis_idx in enumerate(ordered.copy()): - if hasattr(axis_idx, "dtype") and axis_idx.dtype == bool: - axis_idx = xp.where(axis_idx)[0] - order = xp.argsort(axis_idx) - ordered[axis] = axis_idx[order] - rev_order[axis] = xp.argsort(order) + if isinstance(axis_idx, np.ndarray): + if axis_idx.dtype == bool: + axis_idx = np.where(axis_idx)[0] + order = np.argsort(axis_idx) + ordered[axis] = axis_idx[order] + rev_order[axis] = np.argsort(order) # from hdf5, then to real order return d[tuple(ordered)][tuple(rev_order)] @@ -304,4 +257,4 @@ def get_vector(adata, k, coldim, idxdim, layer=None): a = adata._get_X(layer=layer)[idx] if issparse(a): a = a.toarray() - return get_xp(a).ravel(a) + return np.ravel(a) diff --git a/src/anndata/_core/index_modified.py b/src/anndata/_core/index_modified.py new file mode 100644 index 000000000..be2ea696d --- /dev/null +++ b/src/anndata/_core/index_modified.py @@ -0,0 +1,307 @@ +from __future__ import annotations + +from collections.abc import Iterable, Sequence +from functools import singledispatch +from itertools import repeat +from typing import TYPE_CHECKING, Any + +import h5py +import numpy as np +import pandas as pd +from array_api_compat import get_namespace as array_api_get_namespace +from scipy.sparse import issparse + +from ..compat import AwkArray, CSArray, CSMatrix, DaskArray, XDataArray +from .xarray import Dataset2D + +if TYPE_CHECKING: + from ..compat import Index, Index1D + + +def get_xp(x): + # to fall back to numpy if no array API is available + try: + return array_api_get_namespace(x) + except Exception: + return np + + +def is_array_api_obj(x): + # check if object supports _array_namespace__, else fall back to numpy + return hasattr(x, "__array_namespace__") + + +def get_numeric_dtypes(xp): + return ( + xp.dtype(xp.int32), + xp.dtype(xp.int64), + xp.dtype(xp.float32), + xp.dtype(xp.float64), + ) + + +def get_integer_dtypes(xp): + return (xp.dtype(xp.int32), xp.dtype(xp.int64)) + + +def get_floating_dtypes(xp): + return (xp.dtype(xp.float32), xp.dtype(xp.float64)) + + +def get_boolean_dtype(xp): + return xp.dtype(xp.bool_) + + +def _normalize_indices( + index: Index | None, names0: pd.Index, names1: pd.Index +) -> tuple[slice, slice]: + # deal with tuples of length 1 + if isinstance(index, tuple) and len(index) == 1: + index = index[0] + # deal with pd.Series + if isinstance(index, pd.Series): + index: Index = index.values + if isinstance(index, tuple): + # TODO: The series should probably be aligned first + index = tuple(i.values if isinstance(i, pd.Series) else i for i in index) + ax0, ax1 = unpack_index(index) + ax0 = _normalize_index(ax0, names0) + ax1 = _normalize_index(ax1, names1) + return ax0, ax1 + + +def _normalize_index( # noqa: PLR0911, PLR0912 + indexer, + index: pd.Index, +) -> ( + slice | int | Any +): # ndarray of int or bool, switched to Any to make it compatible with array API objects + # TODO: why is this here? All tests pass without it and it seems at the minimum not strict enough. + xp = get_xp(indexer) + if not isinstance(index, pd.RangeIndex) and index.dtype in ( + xp.float64, + xp.int64, + ): + msg = f"Don’t call _normalize_index with non-categorical/string names and non-range index {index}" + raise TypeError(msg) + + # the following is insanely slow for sequences, + # we replaced it using pandas below + def name_idx(i): + if isinstance(i, str): + i = index.get_loc(i) + return i + + if isinstance(indexer, slice): + start = name_idx(indexer.start) + stop = name_idx(indexer.stop) + # string slices can only be inclusive, so +1 in that case + if isinstance(indexer.stop, str): + stop = None if stop is None else stop + 1 + step = indexer.step + return slice(start, stop, step) + if isinstance(indexer, (int,)) or ( + is_array_api_obj(indexer) and isinstance(indexer, xp.integer) + ): + return indexer + elif isinstance(indexer, str): + return index.get_loc(indexer) # int + elif isinstance( + indexer, (Sequence, pd.Index, CSMatrix, CSArray) + ) or is_array_api_obj(indexer): + if hasattr(indexer, "shape") and ( + (indexer.shape == (index.shape[0], 1)) + or (indexer.shape == (1, index.shape[0])) + ): + if isinstance(indexer, CSMatrix | CSArray): + indexer = indexer.toarray() + indexer = xp.ravel(indexer) + if not isinstance(indexer, (pd.Index,)) and not is_array_api_obj(indexer): + indexer = xp.array(indexer) + if len(indexer) == 0: + indexer = indexer.astype(int) + + if get_xp(indexer).issubdtype(indexer.dtype, get_xp(indexer).floating): + indexer_int = indexer.astype(int) + if xp.all((indexer - indexer_int) != 0): + msg = f"Indexer {indexer!r} has floating point values." + raise IndexError(msg) + if get_xp(indexer).issubdtype( + indexer.dtype, get_xp(indexer).integer | get_xp(indexer).floating + ): + return indexer # Might not work for range indexes + elif get_xp(indexer).issubdtype(indexer.dtype, get_xp(indexer).bool_): + if indexer.shape != index.shape: + msg = ( + f"Boolean index does not match AnnData’s shape along this " + f"dimension. Boolean index has shape {indexer.shape} while " + f"AnnData index has shape {index.shape}." + ) + raise IndexError(msg) + return indexer + else: # indexer should be string array + positions = index.get_indexer(indexer) + if get_xp(positions).any(positions < 0): + not_found = indexer[positions < 0] + msg = ( + f"Values {list(not_found)}, from {list(indexer)}, " + "are not valid obs/ var names or indices." + ) + raise KeyError(msg) + return positions # np.ndarray[int] + elif isinstance(indexer, XDataArray): + if isinstance(indexer.data, DaskArray): + return indexer.data.compute() + return indexer.data + msg = f"Unknown indexer {indexer!r} of type {type(indexer)}" + raise IndexError() + + +def _fix_slice_bounds(s: slice, length: int) -> slice: + """The slice will be clipped to length, and the step won't be None. + + E.g. infer None valued attributes. + """ + step = s.step if s.step is not None else 1 + + # slice constructor would have errored if step was 0 + if step > 0: + start = s.start if s.start is not None else 0 + stop = s.stop if s.stop is not None else length + elif step < 0: + # Reverse + start = s.start if s.start is not None else length + stop = s.stop if s.stop is not None else 0 + + return slice(start, stop, step) + + +def unpack_index(index: Index) -> tuple[Index1D, Index1D]: + if not isinstance(index, tuple): + if index is Ellipsis: + index = slice(None) + return index, slice(None) + num_ellipsis = sum(i is Ellipsis for i in index) + if num_ellipsis > 1: + msg = "an index can only have a single ellipsis ('...')" + raise IndexError(msg) + # If index has Ellipsis, filter it out (and if not, error) + if len(index) > 2: + if not num_ellipsis: + msg = "Received a length 3 index without an ellipsis" + raise IndexError(msg) + index = tuple(i for i in index if i is not Ellipsis) + return index + # If index has Ellipsis, replace it with slice + if len(index) == 2: + index = tuple(slice(None) if i is Ellipsis else i for i in index) + return index + if len(index) == 1: + index = index[0] + if index is Ellipsis: + index = slice(None) + return index, slice(None) + msg = "invalid number of indices" + raise IndexError(msg) + + +@singledispatch +def _subset(a, subset_idx: Index): + # Select as combination of indexes, not coordinates + # Correcting for indexing behaviour of np.ndarray + xp = get_xp(a) + if all( + isinstance(x, Iterable) and not isinstance(x, (str, bytes)) for x in subset_idx + ): + subset_idx = xp.ix_(*subset_idx) + return a[subset_idx] + + +@_subset.register(DaskArray) +def _subset_dask(a: DaskArray, subset_idx: Index): + if len(subset_idx) > 1 and all(isinstance(x, Iterable) for x in subset_idx): + if issparse(a._meta) and a._meta.format == "csc": + return a[:, subset_idx[1]][subset_idx[0], :] + return a[subset_idx[0], :][:, subset_idx[1]] + return a[subset_idx] + + +@_subset.register(CSMatrix) +@_subset.register(CSArray) +def _subset_sparse(a: CSMatrix | CSArray, subset_idx: Index): + # Correcting for indexing behaviour of sparse.spmatrix + xp = get_xp(a) + if len(subset_idx) > 1 and all( + isinstance(x, Iterable) and not isinstance(x, (str, bytes)) for x in subset_idx + ): + first_idx = subset_idx[0] + if hasattr(first_idx, "dtype") and first_idx.dtype == bool: + first_idx = xp.where(first_idx)[0] + subset_idx = (first_idx.reshape(-1, 1), *subset_idx[1:]) + return a[subset_idx] + + +@_subset.register(pd.DataFrame) +@_subset.register(Dataset2D) +def _subset_df(df: pd.DataFrame | Dataset2D, subset_idx: Index): + return df.iloc[subset_idx] + + +@_subset.register(AwkArray) +def _subset_awkarray(a: AwkArray, subset_idx: Index): + xp = get_xp(a) + if all( + isinstance(x, Iterable) and not isinstance(x, (str, bytes)) for x in subset_idx + ): + subset_idx = xp.ix_(*subset_idx) + return a[subset_idx] + + +# Registration for SparseDataset occurs in sparse_dataset.py +@_subset.register(h5py.Dataset) +def _subset_dataset(d, subset_idx): + if not isinstance(subset_idx, tuple): + subset_idx = (subset_idx,) + xp = get_xp(subset_idx[0]) + ordered = list(subset_idx) + rev_order = [slice(None) for _ in range(len(subset_idx))] + for axis, axis_idx in enumerate(ordered.copy()): + if hasattr(axis_idx, "dtype") and axis_idx.dtype == bool: + axis_idx = xp.where(axis_idx)[0] + order = xp.argsort(axis_idx) + ordered[axis] = axis_idx[order] + rev_order[axis] = xp.argsort(order) + # from hdf5, then to real order + return d[tuple(ordered)][tuple(rev_order)] + + +def make_slice(idx, dimidx, n=2): + mut = list(repeat(slice(None), n)) + mut[dimidx] = idx + return tuple(mut) + + +def get_vector(adata, k, coldim, idxdim, layer=None): + # adata could be self if Raw and AnnData shared a parent + dims = ("obs", "var") + col = getattr(adata, coldim).columns + idx = getattr(adata, f"{idxdim}_names") + + in_col = k in col + in_idx = k in idx + + if (in_col + in_idx) == 2: + msg = f"Key {k} could be found in both .{idxdim}_names and .{coldim}.columns" + raise ValueError(msg) + elif (in_col + in_idx) == 0: + msg = f"Could not find key {k} in .{idxdim}_names or .{coldim}.columns." + raise KeyError(msg) + elif in_col: + return getattr(adata, coldim)[k].values + elif in_idx: + selected_dim = dims.index(idxdim) + idx = adata._normalize_indices(make_slice(k, selected_dim)) + a = adata._get_X(layer=layer)[idx] + if issparse(a): + a = a.toarray() + return get_xp(a).ravel(a) diff --git a/src/anndata/tests/helpers.py b/src/anndata/tests/helpers.py index 58b33d3d8..233d0b20b 100644 --- a/src/anndata/tests/helpers.py +++ b/src/anndata/tests/helpers.py @@ -131,6 +131,8 @@ def get_xp(x): return np +# Only use structured arrays like gen_vstr_recarray under NumPy. +# These are not supported under the Array API (e.g. JAX, CuPy). def gen_vstr_recarray(m, n, dtype=None): size = m * n lengths = np.random.randint(3, 5, size) @@ -142,54 +144,6 @@ def gen_vstr_recarray(m, n, dtype=None): ) -# def gen_vstr_recarray( -# m: int, -# n: int, -# dtype=None, -# *, -# xp=np, # Backend namespace (e.g., np or jax.numpy) -# rng=None, # NumPy: np.random.Generator | JAX: (jax.random, key) -# ): -# """ -# Supports both NumPy and JAX via `xp`. Returns a NumPy-based recarray for compatibility -# with AnnData.uns (need to double check) -# """ - -# size = m * n -# letters = xp.array(list(ascii_letters)) -# if xp.__name__.startswith("jax"): -# jrandom, key = rng -# key, subkey = jrandom.split(key) -# lengths = jrandom.randint(subkey, shape=(size,), minval=3, maxval=5) -# lengths = xp.asarray(lengths) -# else: -# if rng is None: -# rng = np.random.default_rng() -# lengths = rng.integers(3, 5, size) - -# # generating worlds in different ways depending on the backend -# def gen_word_jax(l: int, key, jrandom, letters): -# key, subkey = jrandom.split(key) -# idxs = jrandom.randint(subkey, shape=(l,), minval=0, maxval=len(letters)) -# return "".join([ascii_letters[int(i)] for i in idxs]), key - -# def gen_word_numpy(l: int, rng, letters): -# return "".join(rng.choice(letters, l)) - -# words = [] -# if xp.__name__.startswith("jax"): -# for l in lengths: -# word, key = gen_word_jax(int(l), key, jrandom, letters) -# words.append(word) -# else: -# words.extend([gen_word_numpy(l, rng, letters) for l in lengths]) - -# arr = np.array(words).reshape(m, n) # Recarray must be NumPy-compatible -# columns = [f"col_{i}" for i in range(n)] -# df = pd.DataFrame(arr, columns=columns) -# return df.to_records(index=False, column_dtypes=dtype) - - def issubdtype( a: np.dtype | pd.api.extensions.ExtensionDtype | type, b: type[DT] | tuple[type[DT], ...], @@ -403,7 +357,7 @@ def gen_adata( # noqa: PLR0913 """ import dask.array as da - from anndata.experimental import gen_vstr_recarray + # from anndata.experimental import gen_vstr_recarray if random_state is None: random_state = np.random.default_rng() @@ -414,7 +368,10 @@ def gen_adata( # noqa: PLR0913 # initialize backends if xp is None: - xp = get_xp(X_type(xp.ones((1, 1), dtype=X_dtype))) + try: + xp = get_xp(X_type(xp.ones((1, 1), dtype=X_dtype))) + except Exception: + xp = np # default to numpy if X_type is not compatible if rng is None: if xp.__name__.startswith("jax"): rng = (jax.random, jax.random.PRNGKey(42)) @@ -443,7 +400,7 @@ def gen_adata( # noqa: PLR0913 # generating X and including jax support if X_type is None: X = None # if no data matrix is requested, skip creation - elif is_jax and issubclass(X_type, sparse.spmatrix): + elif is_jax and isinstance(X_type, type) and issubclass(X_type, sparse.spmatrix): # JAX does not support scipy sparse matrices (double check) msg = "JAX does not support sparse matrices" raise ValueError(msg) diff --git a/tests/test_backend_agnostic_utils.py b/tests/test_backend_agnostic_utils.py index 12b3be7a4..b8995466d 100644 --- a/tests/test_backend_agnostic_utils.py +++ b/tests/test_backend_agnostic_utils.py @@ -1,46 +1,116 @@ +# from __future__ import annotations + +# import jax +# import numpy as np +# import pytest +# from benchmarks.tests.helpers import gen_adata, gen_indexer + +# from anndata import AnnData + + +# @pytest.mark.parametrize("backend", ["numpy", "jax"]) +# def test_gen_adata_and_indexing(backend): +# # Generate AnnData using backend +# if backend == "numpy": +# pass # default backend used by gen_adata +# elif backend == "jax": +# jnp = jax.numpy +# _ = jnp.ones((1,)) +# else: +# raise ValueError(f"Unsupported backend: {backend}") + +# adata: AnnData = gen_adata(100, 50, {"X-dense", "obs,var"}, seed=42) + +# # Check structure +# assert adata.shape == (100, 50) +# assert "a" in adata.obs.columns +# assert "a" in adata.var.columns + +# # Test each index kind +# for kind in ["slice", "intarray", "boolarray", "strarray"]: +# subset = gen_indexer(adata, "obs", kind, 0.3, seed=123) +# assert isinstance(subset, tuple) +# assert len(subset) == 2 + +# index = subset[0] +# if kind == "slice": +# assert isinstance(index, slice) +# elif kind == "intarray": +# assert hasattr(index, "shape") +# assert 0 < index.shape[0] <= 100 +# elif kind == "boolarray": +# assert index.shape == (100,) +# assert index.dtype == bool +# elif kind == "strarray": +# assert isinstance(index, (list, np.ndarray)) +# assert all(isinstance(i, str) for i in index) + from __future__ import annotations -import jax import numpy as np import pytest -from benchmarks.benchmarks.utils import gen_adata, gen_indexer - -from anndata import AnnData - - -@pytest.mark.parametrize("backend", ["numpy", "jax"]) -def test_gen_adata_and_indexing(backend): - # Generate AnnData using backend - if backend == "numpy": - pass # default backend used by gen_adata - elif backend == "jax": - jnp = jax.numpy - _ = jnp.ones((1,)) - else: - raise ValueError(f"Unsupported backend: {backend}") - - adata: AnnData = gen_adata(100, 50, {"X-dense", "obs,var"}, seed=42) - - # Check structure - assert adata.shape == (100, 50) - assert "a" in adata.obs.columns - assert "a" in adata.var.columns - - # Test each index kind - for kind in ["slice", "intarray", "boolarray", "strarray"]: - subset = gen_indexer(adata, "obs", kind, 0.3, seed=123) - assert isinstance(subset, tuple) - assert len(subset) == 2 - - index = subset[0] - if kind == "slice": - assert isinstance(index, slice) - elif kind == "intarray": - assert hasattr(index, "shape") - assert 0 < index.shape[0] <= 100 - elif kind == "boolarray": - assert index.shape == (100,) - assert index.dtype == bool - elif kind == "strarray": - assert isinstance(index, (list, np.ndarray)) - assert all(isinstance(i, str) for i in index) +from scipy import sparse + +from anndata._core.anndata import AnnData +from anndata.tests.helpers import gen_adata + +# Try to import JAX if available +# flagging it as a separate import to avoid issues if JAX is not installed +try: + import jax + import jax.numpy as jnp + + jax_available = True +except ImportError: + jax_available = False + + +# testing gen_adata with NumPy backend and various X types (dense, CSR, CSC) to ensure correct shapes and AnnData validity +@pytest.mark.parametrize("X_type", [np.array, sparse.csr_matrix, sparse.csc_matrix]) +def test_gen_adata_numpy_backends(X_type): + adata = gen_adata( + shape=(20, 30), + X_type=X_type, + X_dtype=np.float32, + xp=np, + rng=np.random.default_rng(0), + ) + assert isinstance(adata, AnnData) + assert adata.X.shape == (20, 30) + assert adata.obs.shape[0] == 20 + assert adata.var.shape[0] == 30 + + +# testing that gen_adata works with JAX backend when X is omitted +@pytest.mark.skipif(not jax_available, reason="JAX is not available") +def test_gen_adata_jax_backend_no_X(): + adata = gen_adata( + shape=(20, 30), + X_type=None, # skip X + xp=jax.numpy, + rng=(jax.random, jax.random.PRNGKey(0)), + ) + assert isinstance(adata, AnnData) + assert adata.X is None + assert adata.obs.shape[0] == 20 + assert adata.var.shape[0] == 30 + + +# checking if function correctly returns an Anndata with X as None +def test_gen_adata_X_none(): + adata = gen_adata(shape=(10, 10), X_type=None) + assert isinstance(adata, AnnData) + assert adata.X is None + + +# testing that passing a sparse matrix format to JAX correctly raises a ValueError +def test_gen_adata_invalid_jax_sparse(): + if not jax_available: + pytest.skip("JAX not available") + with pytest.raises(ValueError, match="JAX does not support sparse matrices"): + gen_adata( + shape=(5, 5), + X_type=sparse.csr_matrix, + xp=jax.numpy, + rng=(jax.random, jax.random.PRNGKey(0)), + )