diff --git a/benchmarks/benchmarks/utils.py b/benchmarks/benchmarks/utils.py index 9e983498f..352388570 100644 --- a/benchmarks/benchmarks/utils.py +++ b/benchmarks/benchmarks/utils.py @@ -113,3 +113,161 @@ def gen_adata(n_obs, n_var, attr_set): index=[f"gene{i}" for i in range(n_var)], ) return adata + + +# from __future__ import annotations + +# import gc +# import sys +# from string import ascii_lowercase +# from time import sleep + +# import jax +# import numpy as np +# import pandas as pd +# from array_api_compat import get_namespace as array_api_get_namespace +# from memory_profiler import memory_usage +# from scipy import sparse + +# from anndata import AnnData + + +# def get_namespace(x=None): +# return array_api_get_namespace(x) + + +# def get_rng(xp, seed=None): +# """Return a backend-specific random number generator.""" +# # RNG isn't standardized in the Array API spec, +# # so backends like JAX, PyTorch, and NumPy each handle randomness differently. +# if xp.__name__.startswith("jax"): +# return jax.random.PRNGKey(seed or 0) +# elif xp.__name__.startswith("numpy"): +# return np.random.default_rng(seed) +# else: +# raise NotImplementedError(f"RNG not implemented for backend: {xp.__name__}") + + +# def get_actualsize(input_obj): +# """Using Python Garbage Collector to calculate the size of all elements attached to an object""" + +# memory_size = 0 +# ids = set() +# objects = [input_obj] +# while objects: +# new = [] +# for obj in objects: +# if id(obj) not in ids: +# ids.add(id(obj)) +# memory_size += sys.getsizeof(obj) +# new.append(obj) +# objects = gc.get_referents(*new) +# return memory_size + + +# def get_anndata_memsize(adata): +# recording = memory_usage( +# (sedate(adata.copy, naplength=0.005), (adata,)), interval=0.001 +# ) +# diff = recording[-1] - recording[0] +# return diff + + +# def get_peak_mem(op, interval=0.001): +# recording = memory_usage(op, interval=interval) +# xp = get_namespace() +# return xp.max(recording) - xp.min(recording) + + +# def sedate(func, naplength=0.05): +# """Make a function sleepy, so we can sample the start and end state.""" + +# def wrapped_function(*args, **kwargs): +# sleep(naplength) +# val = func(*args, **kwargs) +# sleep(naplength) +# return val + +# return wrapped_function + + +# # TODO: Factor out the time it takes to generate these + + +# def gen_indexer(adata, dim, index_kind, ratio, seed=None): +# dimnames = ("obs", "var") +# index_kinds = {"slice", "intarray", "boolarray", "strarray"} + +# if index_kind not in index_kinds: +# msg = f"Argument 'index_kind' must be one of {index_kinds}. Was {index_kind}." +# raise ValueError(msg) + +# xp = get_namespace(adata.X) +# rng = get_rng(xp, seed) +# axis = dimnames.index(dim) +# subset = [slice(None), slice(None)] +# axis_size = adata.shape[axis] +# n = int(xp.round(axis_size * ratio)) + +# if index_kind == "slice": +# subset[axis] = slice(0, n) +# elif index_kind == "intarray": +# if xp.__name__.startswith("jax"): +# subset[axis] = jax.random.choice( +# rng, xp.arange(axis_size), shape=(n,), replace=False +# ) +# elif xp.__name__.startswith("numpy"): +# subset[axis] = xp.asarray(rng.choice(axis_size, n, replace=False)) + +# elif index_kind == "boolarray": +# mask = xp.zeros(axis_size, dtype=bool) +# if xp.__name__.startswith("jax"): +# idx = jax.random.choice( +# rng, xp.arange(axis_size), shape=(n,), replace=False +# ) +# elif xp.__name__.startswith("numpy"): +# idx = rng.choice(axis_size, n, replace=False) +# mask[idx] = True +# subset[axis] = mask + +# elif index_kind == "strarray": +# subset[axis] = rng.choice(getattr(adata, dim).index, n, replace=False) +# else: +# raise ValueError() +# return tuple(subset) + + +# def gen_adata(n_obs, n_var, attr_set, seed=None): +# if "X-csr" in attr_set: +# X_sparse = sparse.random(n_obs, n_var, density=0.1, format="csr") +# xp = get_namespace(X_sparse.toarray()) +# X = X_sparse +# elif "X-dense" in attr_set: +# dense_X = sparse.random(n_obs, n_var, density=0.1, format="csr") +# xp = get_namespace(dense_X) +# X = xp.asarray(dense_X) +# else: +# # TODO: There's probably a better way to do this +# # fallback to use just numpy +# import numpy as np + +# X_dense = np.zeros((n_obs, n_var)) +# # X = sparse.random(n_obs, n_var, density=0, format="csr") +# xp = get_namespace(X_dense) +# X = xp.asarray(X_dense) +# rng = get_rng(xp, seed) +# adata = AnnData(X) +# if "obs,var" in attr_set: +# if xp.__name__.startswith("jax"): +# obs = { +# k: jax.random.randint(rng, (n_obs,), 0, 100) for k in ascii_lowercase +# } +# var = { +# k: jax.random.randint(rng, (n_var,), 0, 100) for k in ascii_lowercase +# } +# elif xp.__name__.startswith("numpy"): +# obs = {k: rng.integers(0, 100, size=n_obs) for k in ascii_lowercase} +# var = {k: rng.integers(0, 100, size=n_var) for k in ascii_lowercase} +# adata.obs = pd.DataFrame(obs, index=[f"cell{i}" for i in range(n_obs)]) +# adata.var = pd.DataFrame(var, index=[f"gene{i}" for i in range(n_var)]) +# return adata diff --git a/pyproject.toml b/pyproject.toml index 64ddb63cc..1085350f8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,6 +62,8 @@ dev = [ # runtime dev version generation "hatch-vcs", "anndata[dev-doc]", + "jax", + "jaxlib", ] doc = [ "sphinx>=8.2.1", @@ -105,6 +107,8 @@ test-min = [ test = [ "anndata[test-min,lazy]", "pandas>=2.1.0", + "jax", + "jaxlib", ] # pandas 2.1.0 needs to be specified for xarray to work with min-deps script gpu = [ "cupy" ] cu12 = [ "cupy-cuda12x" ] diff --git a/src/anndata/_core/index_modified.py b/src/anndata/_core/index_modified.py new file mode 100644 index 000000000..be2ea696d --- /dev/null +++ b/src/anndata/_core/index_modified.py @@ -0,0 +1,307 @@ +from __future__ import annotations + +from collections.abc import Iterable, Sequence +from functools import singledispatch +from itertools import repeat +from typing import TYPE_CHECKING, Any + +import h5py +import numpy as np +import pandas as pd +from array_api_compat import get_namespace as array_api_get_namespace +from scipy.sparse import issparse + +from ..compat import AwkArray, CSArray, CSMatrix, DaskArray, XDataArray +from .xarray import Dataset2D + +if TYPE_CHECKING: + from ..compat import Index, Index1D + + +def get_xp(x): + # to fall back to numpy if no array API is available + try: + return array_api_get_namespace(x) + except Exception: + return np + + +def is_array_api_obj(x): + # check if object supports _array_namespace__, else fall back to numpy + return hasattr(x, "__array_namespace__") + + +def get_numeric_dtypes(xp): + return ( + xp.dtype(xp.int32), + xp.dtype(xp.int64), + xp.dtype(xp.float32), + xp.dtype(xp.float64), + ) + + +def get_integer_dtypes(xp): + return (xp.dtype(xp.int32), xp.dtype(xp.int64)) + + +def get_floating_dtypes(xp): + return (xp.dtype(xp.float32), xp.dtype(xp.float64)) + + +def get_boolean_dtype(xp): + return xp.dtype(xp.bool_) + + +def _normalize_indices( + index: Index | None, names0: pd.Index, names1: pd.Index +) -> tuple[slice, slice]: + # deal with tuples of length 1 + if isinstance(index, tuple) and len(index) == 1: + index = index[0] + # deal with pd.Series + if isinstance(index, pd.Series): + index: Index = index.values + if isinstance(index, tuple): + # TODO: The series should probably be aligned first + index = tuple(i.values if isinstance(i, pd.Series) else i for i in index) + ax0, ax1 = unpack_index(index) + ax0 = _normalize_index(ax0, names0) + ax1 = _normalize_index(ax1, names1) + return ax0, ax1 + + +def _normalize_index( # noqa: PLR0911, PLR0912 + indexer, + index: pd.Index, +) -> ( + slice | int | Any +): # ndarray of int or bool, switched to Any to make it compatible with array API objects + # TODO: why is this here? All tests pass without it and it seems at the minimum not strict enough. + xp = get_xp(indexer) + if not isinstance(index, pd.RangeIndex) and index.dtype in ( + xp.float64, + xp.int64, + ): + msg = f"Don’t call _normalize_index with non-categorical/string names and non-range index {index}" + raise TypeError(msg) + + # the following is insanely slow for sequences, + # we replaced it using pandas below + def name_idx(i): + if isinstance(i, str): + i = index.get_loc(i) + return i + + if isinstance(indexer, slice): + start = name_idx(indexer.start) + stop = name_idx(indexer.stop) + # string slices can only be inclusive, so +1 in that case + if isinstance(indexer.stop, str): + stop = None if stop is None else stop + 1 + step = indexer.step + return slice(start, stop, step) + if isinstance(indexer, (int,)) or ( + is_array_api_obj(indexer) and isinstance(indexer, xp.integer) + ): + return indexer + elif isinstance(indexer, str): + return index.get_loc(indexer) # int + elif isinstance( + indexer, (Sequence, pd.Index, CSMatrix, CSArray) + ) or is_array_api_obj(indexer): + if hasattr(indexer, "shape") and ( + (indexer.shape == (index.shape[0], 1)) + or (indexer.shape == (1, index.shape[0])) + ): + if isinstance(indexer, CSMatrix | CSArray): + indexer = indexer.toarray() + indexer = xp.ravel(indexer) + if not isinstance(indexer, (pd.Index,)) and not is_array_api_obj(indexer): + indexer = xp.array(indexer) + if len(indexer) == 0: + indexer = indexer.astype(int) + + if get_xp(indexer).issubdtype(indexer.dtype, get_xp(indexer).floating): + indexer_int = indexer.astype(int) + if xp.all((indexer - indexer_int) != 0): + msg = f"Indexer {indexer!r} has floating point values." + raise IndexError(msg) + if get_xp(indexer).issubdtype( + indexer.dtype, get_xp(indexer).integer | get_xp(indexer).floating + ): + return indexer # Might not work for range indexes + elif get_xp(indexer).issubdtype(indexer.dtype, get_xp(indexer).bool_): + if indexer.shape != index.shape: + msg = ( + f"Boolean index does not match AnnData’s shape along this " + f"dimension. Boolean index has shape {indexer.shape} while " + f"AnnData index has shape {index.shape}." + ) + raise IndexError(msg) + return indexer + else: # indexer should be string array + positions = index.get_indexer(indexer) + if get_xp(positions).any(positions < 0): + not_found = indexer[positions < 0] + msg = ( + f"Values {list(not_found)}, from {list(indexer)}, " + "are not valid obs/ var names or indices." + ) + raise KeyError(msg) + return positions # np.ndarray[int] + elif isinstance(indexer, XDataArray): + if isinstance(indexer.data, DaskArray): + return indexer.data.compute() + return indexer.data + msg = f"Unknown indexer {indexer!r} of type {type(indexer)}" + raise IndexError() + + +def _fix_slice_bounds(s: slice, length: int) -> slice: + """The slice will be clipped to length, and the step won't be None. + + E.g. infer None valued attributes. + """ + step = s.step if s.step is not None else 1 + + # slice constructor would have errored if step was 0 + if step > 0: + start = s.start if s.start is not None else 0 + stop = s.stop if s.stop is not None else length + elif step < 0: + # Reverse + start = s.start if s.start is not None else length + stop = s.stop if s.stop is not None else 0 + + return slice(start, stop, step) + + +def unpack_index(index: Index) -> tuple[Index1D, Index1D]: + if not isinstance(index, tuple): + if index is Ellipsis: + index = slice(None) + return index, slice(None) + num_ellipsis = sum(i is Ellipsis for i in index) + if num_ellipsis > 1: + msg = "an index can only have a single ellipsis ('...')" + raise IndexError(msg) + # If index has Ellipsis, filter it out (and if not, error) + if len(index) > 2: + if not num_ellipsis: + msg = "Received a length 3 index without an ellipsis" + raise IndexError(msg) + index = tuple(i for i in index if i is not Ellipsis) + return index + # If index has Ellipsis, replace it with slice + if len(index) == 2: + index = tuple(slice(None) if i is Ellipsis else i for i in index) + return index + if len(index) == 1: + index = index[0] + if index is Ellipsis: + index = slice(None) + return index, slice(None) + msg = "invalid number of indices" + raise IndexError(msg) + + +@singledispatch +def _subset(a, subset_idx: Index): + # Select as combination of indexes, not coordinates + # Correcting for indexing behaviour of np.ndarray + xp = get_xp(a) + if all( + isinstance(x, Iterable) and not isinstance(x, (str, bytes)) for x in subset_idx + ): + subset_idx = xp.ix_(*subset_idx) + return a[subset_idx] + + +@_subset.register(DaskArray) +def _subset_dask(a: DaskArray, subset_idx: Index): + if len(subset_idx) > 1 and all(isinstance(x, Iterable) for x in subset_idx): + if issparse(a._meta) and a._meta.format == "csc": + return a[:, subset_idx[1]][subset_idx[0], :] + return a[subset_idx[0], :][:, subset_idx[1]] + return a[subset_idx] + + +@_subset.register(CSMatrix) +@_subset.register(CSArray) +def _subset_sparse(a: CSMatrix | CSArray, subset_idx: Index): + # Correcting for indexing behaviour of sparse.spmatrix + xp = get_xp(a) + if len(subset_idx) > 1 and all( + isinstance(x, Iterable) and not isinstance(x, (str, bytes)) for x in subset_idx + ): + first_idx = subset_idx[0] + if hasattr(first_idx, "dtype") and first_idx.dtype == bool: + first_idx = xp.where(first_idx)[0] + subset_idx = (first_idx.reshape(-1, 1), *subset_idx[1:]) + return a[subset_idx] + + +@_subset.register(pd.DataFrame) +@_subset.register(Dataset2D) +def _subset_df(df: pd.DataFrame | Dataset2D, subset_idx: Index): + return df.iloc[subset_idx] + + +@_subset.register(AwkArray) +def _subset_awkarray(a: AwkArray, subset_idx: Index): + xp = get_xp(a) + if all( + isinstance(x, Iterable) and not isinstance(x, (str, bytes)) for x in subset_idx + ): + subset_idx = xp.ix_(*subset_idx) + return a[subset_idx] + + +# Registration for SparseDataset occurs in sparse_dataset.py +@_subset.register(h5py.Dataset) +def _subset_dataset(d, subset_idx): + if not isinstance(subset_idx, tuple): + subset_idx = (subset_idx,) + xp = get_xp(subset_idx[0]) + ordered = list(subset_idx) + rev_order = [slice(None) for _ in range(len(subset_idx))] + for axis, axis_idx in enumerate(ordered.copy()): + if hasattr(axis_idx, "dtype") and axis_idx.dtype == bool: + axis_idx = xp.where(axis_idx)[0] + order = xp.argsort(axis_idx) + ordered[axis] = axis_idx[order] + rev_order[axis] = xp.argsort(order) + # from hdf5, then to real order + return d[tuple(ordered)][tuple(rev_order)] + + +def make_slice(idx, dimidx, n=2): + mut = list(repeat(slice(None), n)) + mut[dimidx] = idx + return tuple(mut) + + +def get_vector(adata, k, coldim, idxdim, layer=None): + # adata could be self if Raw and AnnData shared a parent + dims = ("obs", "var") + col = getattr(adata, coldim).columns + idx = getattr(adata, f"{idxdim}_names") + + in_col = k in col + in_idx = k in idx + + if (in_col + in_idx) == 2: + msg = f"Key {k} could be found in both .{idxdim}_names and .{coldim}.columns" + raise ValueError(msg) + elif (in_col + in_idx) == 0: + msg = f"Could not find key {k} in .{idxdim}_names or .{coldim}.columns." + raise KeyError(msg) + elif in_col: + return getattr(adata, coldim)[k].values + elif in_idx: + selected_dim = dims.index(idxdim) + idx = adata._normalize_indices(make_slice(k, selected_dim)) + a = adata._get_X(layer=layer)[idx] + if issparse(a): + a = a.toarray() + return get_xp(a).ravel(a) diff --git a/src/anndata/tests/helpers.py b/src/anndata/tests/helpers.py index 54906ab6b..233d0b20b 100644 --- a/src/anndata/tests/helpers.py +++ b/src/anndata/tests/helpers.py @@ -10,10 +10,20 @@ from string import ascii_letters from typing import TYPE_CHECKING +# trying to import jax as it is not part of the default dependencies +try: + import jax + import jax.numpy as jnp + + JAX_AVAILABLE = True +except ImportError: + JAX_AVAILABLE = False + import h5py import numpy as np import pandas as pd import pytest +from array_api_compat import get_namespace as array_api_get_namespace from pandas.api.types import is_numeric_dtype from scipy import sparse @@ -50,7 +60,6 @@ DT = TypeVar("DT") - try: from pandas.core.arrays.integer import IntegerDtype except ImportError: @@ -66,6 +75,9 @@ pd.DataFrame, sparse.csr_array, ) +# Add JAX array type to supported backends if JAX is installed +if JAX_AVAILABLE: + DEFAULT_KEY_TYPES += (jnp.ndarray,) DEFAULT_COL_TYPES = ( @@ -79,19 +91,48 @@ pd.Int32Dtype, ) - -# Give this to gen_adata when dask array support is expected. +# preset for testing with Dask arrays +# includes DaskArray in obsm/varm/layers to test lazy evaluation and chunked storage +# useful when testing backend="backed"/"memory"/"zarr" GEN_ADATA_DASK_ARGS = dict( obsm_types=(*DEFAULT_KEY_TYPES, DaskArray), varm_types=(*DEFAULT_KEY_TYPES, DaskArray), layers_types=(*DEFAULT_KEY_TYPES, DaskArray), ) +# for testing without xarray (XDataset) types +# excludes XDataset to avoid xarray dependency or reduce test complexity +# useful for minimal tests or when xarray is not available GEN_ADATA_NO_XARRAY_ARGS = dict( - obsm_types=(*DEFAULT_KEY_TYPES, AwkArray), varm_types=(*DEFAULT_KEY_TYPES, AwkArray) + obsm_types=(*DEFAULT_KEY_TYPES, AwkArray), + varm_types=(*DEFAULT_KEY_TYPES, AwkArray), +) + +# Optional: define args for JAX-specific backend tests +# preset for testing with JAX arrays +# useful when running tests using jax.numpy as the backend for X, obsm, varm, or layers. +GEN_ADATA_JAX_ARGS = ( + dict( + obsm_types=(*DEFAULT_KEY_TYPES,), + varm_types=(*DEFAULT_KEY_TYPES,), + layers_types=(*DEFAULT_KEY_TYPES,), + ) + if JAX_AVAILABLE + else {} ) +def get_xp(x): + try: + return array_api_get_namespace(x) + # in case if getting other errors but ImportError + except Exception: # noqa: BLE001 + # default to numpy if array_api_compat is not installed + return np + + +# Only use structured arrays like gen_vstr_recarray under NumPy. +# These are not supported under the Array API (e.g. JAX, CuPy). def gen_vstr_recarray(m, n, dtype=None): size = m * n lengths = np.random.randint(3, 5, size) @@ -285,6 +326,8 @@ def gen_adata( # noqa: PLR0913 layers_types: Collection[type] = DEFAULT_KEY_TYPES, random_state: np.random.Generator | None = None, sparse_fmt: Literal["csr", "csc"] = "csr", + xp=None, # setting default to None for now but maybe np is better (review) + rng=None, ) -> AnnData: """\ Helper function to generate a random AnnData for testing purposes. @@ -314,12 +357,34 @@ def gen_adata( # noqa: PLR0913 """ import dask.array as da + # from anndata.experimental import gen_vstr_recarray + if random_state is None: random_state = np.random.default_rng() M, N = shape obs_names = pd.Index(f"cell{i}" for i in range(shape[0])) var_names = pd.Index(f"gene{i}" for i in range(shape[1])) + + # initialize backends + if xp is None: + try: + xp = get_xp(X_type(xp.ones((1, 1), dtype=X_dtype))) + except Exception: + xp = np # default to numpy if X_type is not compatible + if rng is None: + if xp.__name__.startswith("jax"): + rng = (jax.random, jax.random.PRNGKey(42)) + else: + rng = np.random.default_rng() + + is_jax = xp.__name__.startswith("jax") + if is_jax: + jrandom, key = rng + else: + jrandom, key = None, None + + # generate obs and var dataframes obs = gen_typed_df(M, obs_names, dtypes=obs_dtypes) var = gen_typed_df(N, var_names, dtypes=var_dtypes) # For #147 @@ -332,91 +397,160 @@ def gen_adata( # noqa: PLR0913 if var_xdataset: var = XDataset.from_dataframe(var) + # generating X and including jax support if X_type is None: - X = None + X = None # if no data matrix is requested, skip creation + elif is_jax and isinstance(X_type, type) and issubclass(X_type, sparse.spmatrix): + # JAX does not support scipy sparse matrices (double check) + msg = "JAX does not support sparse matrices" + raise ValueError(msg) else: - X = X_type(random_state.binomial(100, 0.005, (M, N)).astype(X_dtype)) + if is_jax: + # split the JAX PRNG key to create a new, independent subkey + key, subkey = jrandom.split(key) + X_array = jrandom.binomial(subkey, 100.0, 0.005, shape=(M, N)).astype( + X_dtype + ) + rng = (jrandom, key) + else: + # if using numpy, generate a binomial distribution + X_array = rng.binomial(100, 0.005, (M, N)).astype(X_dtype) + X = X_type(X_array) + + # Generate obsm + if is_jax: + key, subkey = jrandom.split(key) + obsm_array = jrandom.uniform(subkey, shape=(M, 50)) + rng = (jrandom, key) + else: + obsm_array = rng.random((M, 50)) + # sparse was moved out due to JAX compatibility issues obsm = dict( - array=np.random.random((M, 50)), - sparse=sparse.random(M, 100, format=sparse_fmt, random_state=random_state), + array=obsm_array, df=gen_typed_df(M, obs_names, dtypes=obs_dtypes), awk_2d_ragged=gen_awkward((M, None)), da=da.random.random((M, 50)), ) + + if not is_jax: + obsm["sparse"] = sparse.random( + M, 100, format=sparse_fmt, random_state=random_state + ) + if has_xr: + obsm["xdataset"] = XDataset.from_dataframe( + gen_typed_df(M, obs_names, dtypes=obs_dtypes) + ) + obsm = {k: v for k, v in obsm.items() if type(v) in obsm_types} + + # generating varm + if is_jax: + key, subkey = jrandom.split(key) + varm_array = jrandom.uniform(subkey, shape=(N, 50)) + rng = (jrandom, key) + else: + varm_array = rng.random((N, 50)) + varm = dict( - array=np.random.random((N, 50)), - sparse=sparse.random(N, 100, format=sparse_fmt, random_state=random_state), + array=varm_array, df=gen_typed_df(N, var_names, dtypes=var_dtypes), awk_2d_ragged=gen_awkward((N, None)), da=da.random.random((N, 50)), ) - if has_xr: - obsm["xdataset"] = XDataset.from_dataframe( - gen_typed_df(M, obs_names, dtypes=obs_dtypes) + if not is_jax: + varm["sparse"] = sparse.random( + N, 100, format=sparse_fmt, random_state=random_state ) + if has_xr: varm["xdataset"] = XDataset.from_dataframe( gen_typed_df(N, var_names, dtypes=var_dtypes) ) + varm = {k: v for k, v in varm.items() if type(v) in varm_types} + + if has_xr: + if XDataset in obsm_types: + obsm["xdataset"] = XDataset.from_dataframe( + gen_typed_df(M, obs_names, dtypes=obs_dtypes) + ) + if XDataset in varm_types: + varm["xdataset"] = XDataset.from_dataframe( + gen_typed_df(N, var_names, dtypes=var_dtypes) + ) obsm = {k: v for k, v in obsm.items() if type(v) in obsm_types} - obsm = maybe_add_sparse_array( - mapping=obsm, - types=obsm_types, - format=sparse_fmt, - random_state=random_state, - shape=(M, 100), - ) + # JAX does not support scipy.sparse, so skip this step to avoid compatibility issues + if not is_jax: + obsm = maybe_add_sparse_array( + mapping=obsm, + types=obsm_types, + format=sparse_fmt, + random_state=random_state, + shape=(M, 100), + ) varm = {k: v for k, v in varm.items() if type(v) in varm_types} - varm = maybe_add_sparse_array( - mapping=varm, - types=varm_types, - format=sparse_fmt, - random_state=random_state, - shape=(N, 100), - ) - layers = dict( - array=np.random.random((M, N)), - sparse=sparse.random(M, N, format=sparse_fmt, random_state=random_state), - da=da.random.random((M, N)), - ) - layers = maybe_add_sparse_array( - mapping=layers, - types=layers_types, - format=sparse_fmt, - random_state=random_state, - shape=(M, N), - ) + if not is_jax: + varm = maybe_add_sparse_array( + mapping=varm, + types=varm_types, + format=sparse_fmt, + random_state=random_state, + shape=(N, 100), + ) + + if is_jax: + key, subkey = jrandom.split(key) + layer_array = jrandom.uniform(subkey, shape=(M, N)) + rng = (jrandom, key) + else: + layer_array = rng.random((M, N)) + + layers = dict(array=layer_array, da=da.random.random((M, N))) + if not is_jax: + layers["sparse"] = sparse.random( + M, N, format=sparse_fmt, random_state=random_state + ) + layers = maybe_add_sparse_array( + mapping=layers, + types=layers_types, + format=sparse_fmt, + random_state=random_state, + shape=(M, N), + ) layers = {k: v for k, v in layers.items() if type(v) in layers_types} - obsp = dict( - array=np.random.random((M, M)), - sparse=sparse.random(M, M, format=sparse_fmt, random_state=random_state), - ) - obsp["sparse_array"] = sparse.csr_array( - sparse.random(M, M, format=sparse_fmt, random_state=random_state) - ) - varp = dict( - array=np.random.random((N, N)), - sparse=sparse.random(N, N, format=sparse_fmt, random_state=random_state), - ) - varp["sparse_array"] = sparse.csr_array( - sparse.random(N, N, format=sparse_fmt, random_state=random_state) - ) - uns = dict( - O_recarray=gen_vstr_recarray(N, 5), - nested=dict( - scalar_str="str", - scalar_int=42, - scalar_float=3.0, - nested_further=dict(array=np.arange(5)), - ), - awkward_regular=gen_awkward((10, 5)), - awkward_ragged=gen_awkward((12, None, None)), - # U_recarray=gen_vstr_recarray(N, 5, "U4") - ) - # https://github.com/zarr-developers/zarr-python/issues/2134 - # zarr v3 on-disk does not write structured dtypes - if anndata.settings.zarr_write_format == 3: - del uns["O_recarray"] + + # skiping obsp/vasp and uns if using JAX + obsp, varp, uns = {}, {}, {} + if not is_jax: + obsp = dict( + array=np.random.random((M, M)), + sparse=sparse.random(M, M, format=sparse_fmt, random_state=random_state), + ) + obsp["sparse_array"] = sparse.csr_array( + sparse.random(M, M, format=sparse_fmt, random_state=random_state) + ) + varp = dict( + array=np.random.random((N, N)), + sparse=sparse.random(N, N, format=sparse_fmt, random_state=random_state), + ) + varp["sparse_array"] = sparse.csr_array( + sparse.random(N, N, format=sparse_fmt, random_state=random_state) + ) + uns = dict( + O_recarray=gen_vstr_recarray(N, 5), + nested=dict( + scalar_str="str", + scalar_int=42, + scalar_float=3.0, + nested_further=dict(array=np.arange(5)), + ), + awkward_regular=gen_awkward((10, 5)), + awkward_ragged=gen_awkward((12, None, None)), + # U_recarray=gen_vstr_recarray(N, 5, "U4") + ) + # https://github.com/zarr-developers/zarr-python/issues/2134 + # zarr v3 on-disk does not write structured dtypes + if anndata.settings.zarr_write_format == 3: + del uns["O_recarray"] + # anndata constuction with warnings.catch_warnings(): warnings.simplefilter("ignore", ExperimentalFeatureWarning) adata = AnnData( diff --git a/tests/test_backend_agnostic_utils.py b/tests/test_backend_agnostic_utils.py new file mode 100644 index 000000000..b8995466d --- /dev/null +++ b/tests/test_backend_agnostic_utils.py @@ -0,0 +1,116 @@ +# from __future__ import annotations + +# import jax +# import numpy as np +# import pytest +# from benchmarks.tests.helpers import gen_adata, gen_indexer + +# from anndata import AnnData + + +# @pytest.mark.parametrize("backend", ["numpy", "jax"]) +# def test_gen_adata_and_indexing(backend): +# # Generate AnnData using backend +# if backend == "numpy": +# pass # default backend used by gen_adata +# elif backend == "jax": +# jnp = jax.numpy +# _ = jnp.ones((1,)) +# else: +# raise ValueError(f"Unsupported backend: {backend}") + +# adata: AnnData = gen_adata(100, 50, {"X-dense", "obs,var"}, seed=42) + +# # Check structure +# assert adata.shape == (100, 50) +# assert "a" in adata.obs.columns +# assert "a" in adata.var.columns + +# # Test each index kind +# for kind in ["slice", "intarray", "boolarray", "strarray"]: +# subset = gen_indexer(adata, "obs", kind, 0.3, seed=123) +# assert isinstance(subset, tuple) +# assert len(subset) == 2 + +# index = subset[0] +# if kind == "slice": +# assert isinstance(index, slice) +# elif kind == "intarray": +# assert hasattr(index, "shape") +# assert 0 < index.shape[0] <= 100 +# elif kind == "boolarray": +# assert index.shape == (100,) +# assert index.dtype == bool +# elif kind == "strarray": +# assert isinstance(index, (list, np.ndarray)) +# assert all(isinstance(i, str) for i in index) + +from __future__ import annotations + +import numpy as np +import pytest +from scipy import sparse + +from anndata._core.anndata import AnnData +from anndata.tests.helpers import gen_adata + +# Try to import JAX if available +# flagging it as a separate import to avoid issues if JAX is not installed +try: + import jax + import jax.numpy as jnp + + jax_available = True +except ImportError: + jax_available = False + + +# testing gen_adata with NumPy backend and various X types (dense, CSR, CSC) to ensure correct shapes and AnnData validity +@pytest.mark.parametrize("X_type", [np.array, sparse.csr_matrix, sparse.csc_matrix]) +def test_gen_adata_numpy_backends(X_type): + adata = gen_adata( + shape=(20, 30), + X_type=X_type, + X_dtype=np.float32, + xp=np, + rng=np.random.default_rng(0), + ) + assert isinstance(adata, AnnData) + assert adata.X.shape == (20, 30) + assert adata.obs.shape[0] == 20 + assert adata.var.shape[0] == 30 + + +# testing that gen_adata works with JAX backend when X is omitted +@pytest.mark.skipif(not jax_available, reason="JAX is not available") +def test_gen_adata_jax_backend_no_X(): + adata = gen_adata( + shape=(20, 30), + X_type=None, # skip X + xp=jax.numpy, + rng=(jax.random, jax.random.PRNGKey(0)), + ) + assert isinstance(adata, AnnData) + assert adata.X is None + assert adata.obs.shape[0] == 20 + assert adata.var.shape[0] == 30 + + +# checking if function correctly returns an Anndata with X as None +def test_gen_adata_X_none(): + adata = gen_adata(shape=(10, 10), X_type=None) + assert isinstance(adata, AnnData) + assert adata.X is None + + +# testing that passing a sparse matrix format to JAX correctly raises a ValueError +def test_gen_adata_invalid_jax_sparse(): + if not jax_available: + pytest.skip("JAX not available") + with pytest.raises(ValueError, match="JAX does not support sparse matrices"): + gen_adata( + shape=(5, 5), + X_type=sparse.csr_matrix, + xp=jax.numpy, + rng=(jax.random, jax.random.PRNGKey(0)), + )