Skip to content
Draft
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 158 additions & 0 deletions benchmarks/benchmarks/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,161 @@ def gen_adata(n_obs, n_var, attr_set):
index=[f"gene{i}" for i in range(n_var)],
)
return adata


# from __future__ import annotations

# import gc
# import sys
# from string import ascii_lowercase
# from time import sleep

# import jax
# import numpy as np
# import pandas as pd
# from array_api_compat import get_namespace as array_api_get_namespace
# from memory_profiler import memory_usage
# from scipy import sparse

# from anndata import AnnData


# def get_namespace(x=None):
# return array_api_get_namespace(x)


# def get_rng(xp, seed=None):
# """Return a backend-specific random number generator."""
# # RNG isn't standardized in the Array API spec,
# # so backends like JAX, PyTorch, and NumPy each handle randomness differently.
# if xp.__name__.startswith("jax"):
# return jax.random.PRNGKey(seed or 0)
# elif xp.__name__.startswith("numpy"):
# return np.random.default_rng(seed)
# else:
# raise NotImplementedError(f"RNG not implemented for backend: {xp.__name__}")


# def get_actualsize(input_obj):
# """Using Python Garbage Collector to calculate the size of all elements attached to an object"""

# memory_size = 0
# ids = set()
# objects = [input_obj]
# while objects:
# new = []
# for obj in objects:
# if id(obj) not in ids:
# ids.add(id(obj))
# memory_size += sys.getsizeof(obj)
# new.append(obj)
# objects = gc.get_referents(*new)
# return memory_size


# def get_anndata_memsize(adata):
# recording = memory_usage(
# (sedate(adata.copy, naplength=0.005), (adata,)), interval=0.001
# )
# diff = recording[-1] - recording[0]
# return diff


# def get_peak_mem(op, interval=0.001):
# recording = memory_usage(op, interval=interval)
# xp = get_namespace()
# return xp.max(recording) - xp.min(recording)


# def sedate(func, naplength=0.05):
# """Make a function sleepy, so we can sample the start and end state."""

# def wrapped_function(*args, **kwargs):
# sleep(naplength)
# val = func(*args, **kwargs)
# sleep(naplength)
# return val

# return wrapped_function


# # TODO: Factor out the time it takes to generate these


# def gen_indexer(adata, dim, index_kind, ratio, seed=None):
# dimnames = ("obs", "var")
# index_kinds = {"slice", "intarray", "boolarray", "strarray"}

# if index_kind not in index_kinds:
# msg = f"Argument 'index_kind' must be one of {index_kinds}. Was {index_kind}."
# raise ValueError(msg)

# xp = get_namespace(adata.X)
# rng = get_rng(xp, seed)
# axis = dimnames.index(dim)
# subset = [slice(None), slice(None)]
# axis_size = adata.shape[axis]
# n = int(xp.round(axis_size * ratio))

# if index_kind == "slice":
# subset[axis] = slice(0, n)
# elif index_kind == "intarray":
# if xp.__name__.startswith("jax"):
# subset[axis] = jax.random.choice(
# rng, xp.arange(axis_size), shape=(n,), replace=False
# )
# elif xp.__name__.startswith("numpy"):
# subset[axis] = xp.asarray(rng.choice(axis_size, n, replace=False))

# elif index_kind == "boolarray":
# mask = xp.zeros(axis_size, dtype=bool)
# if xp.__name__.startswith("jax"):
# idx = jax.random.choice(
# rng, xp.arange(axis_size), shape=(n,), replace=False
# )
# elif xp.__name__.startswith("numpy"):
# idx = rng.choice(axis_size, n, replace=False)
# mask[idx] = True
# subset[axis] = mask

# elif index_kind == "strarray":
# subset[axis] = rng.choice(getattr(adata, dim).index, n, replace=False)
# else:
# raise ValueError()
# return tuple(subset)


# def gen_adata(n_obs, n_var, attr_set, seed=None):
# if "X-csr" in attr_set:
# X_sparse = sparse.random(n_obs, n_var, density=0.1, format="csr")
# xp = get_namespace(X_sparse.toarray())
# X = X_sparse
# elif "X-dense" in attr_set:
# dense_X = sparse.random(n_obs, n_var, density=0.1, format="csr")
# xp = get_namespace(dense_X)
# X = xp.asarray(dense_X)
# else:
# # TODO: There's probably a better way to do this
# # fallback to use just numpy
# import numpy as np

# X_dense = np.zeros((n_obs, n_var))
# # X = sparse.random(n_obs, n_var, density=0, format="csr")
# xp = get_namespace(X_dense)
# X = xp.asarray(X_dense)
# rng = get_rng(xp, seed)
# adata = AnnData(X)
# if "obs,var" in attr_set:
# if xp.__name__.startswith("jax"):
# obs = {
# k: jax.random.randint(rng, (n_obs,), 0, 100) for k in ascii_lowercase
# }
# var = {
# k: jax.random.randint(rng, (n_var,), 0, 100) for k in ascii_lowercase
# }
# elif xp.__name__.startswith("numpy"):
# obs = {k: rng.integers(0, 100, size=n_obs) for k in ascii_lowercase}
# var = {k: rng.integers(0, 100, size=n_var) for k in ascii_lowercase}
# adata.obs = pd.DataFrame(obs, index=[f"cell{i}" for i in range(n_obs)])
# adata.var = pd.DataFrame(var, index=[f"gene{i}" for i in range(n_var)])
# return adata
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ dev = [
# runtime dev version generation
"hatch-vcs",
"anndata[dev-doc]",
"jax",
"jaxlib",
]
doc = [
"sphinx>=8.2.1",
Expand Down Expand Up @@ -105,6 +107,8 @@ test-min = [
test = [
"anndata[test-min,lazy]",
"pandas>=2.1.0",
"jax",
"jaxlib",
] # pandas 2.1.0 needs to be specified for xarray to work with min-deps script
gpu = [ "cupy" ]
cu12 = [ "cupy-cuda12x" ]
Expand Down
123 changes: 85 additions & 38 deletions src/anndata/_core/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
from collections.abc import Iterable, Sequence
from functools import singledispatch
from itertools import repeat
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

import h5py
import numpy as np
import pandas as pd
from array_api_compat import get_namespace as array_api_get_namespace
from scipy.sparse import issparse

from ..compat import AwkArray, CSArray, CSMatrix, DaskArray, XDataArray
Expand All @@ -17,6 +18,40 @@
from ..compat import Index, Index1D


def get_xp(x):
    """Resolve the array-API namespace for ``x``, defaulting to NumPy.

    ``array_api_compat.get_namespace`` raises for objects it does not
    recognise; any such failure falls back to plain NumPy semantics.
    """
    try:
        namespace = array_api_get_namespace(x)
    except Exception:
        # Anything unrecognised by array_api_compat is treated as NumPy-like.
        return np
    else:
        return namespace


def is_array_api_obj(x):
    """Return ``True`` when ``x`` advertises the array API standard.

    Conforming objects expose an ``__array_namespace__`` method; anything
    else is handled through the NumPy fallback path.
    """
    return hasattr(x, "__array_namespace__")


def get_numeric_dtypes(xp):
    """Return the supported numeric dtypes (ints then floats) for ``xp``."""
    names = ("int32", "int64", "float32", "float64")
    return tuple(xp.dtype(getattr(xp, name)) for name in names)


def get_integer_dtypes(xp):
    """Return the supported integer dtypes for namespace ``xp``."""
    return tuple(xp.dtype(getattr(xp, name)) for name in ("int32", "int64"))


def get_floating_dtypes(xp):
    """Return the supported floating-point dtypes for namespace ``xp``."""
    return tuple(xp.dtype(getattr(xp, name)) for name in ("float32", "float64"))


def get_boolean_dtype(xp):
    """Return the boolean dtype for namespace ``xp``.

    NOTE(review): ``bool_`` is a NumPy spelling; the array API standard
    names this ``xp.bool`` — confirm before using with strict namespaces.
    """
    return xp.dtype(xp.bool_)


def _normalize_indices(
index: Index | None, names0: pd.Index, names1: pd.Index
) -> tuple[slice, slice]:
Expand All @@ -36,17 +71,17 @@ def _normalize_indices(


def _normalize_index( # noqa: PLR0911, PLR0912
indexer: slice
| np.integer
| int
| str
| Sequence[bool | int | np.integer]
| np.ndarray
| pd.Index,
indexer,
index: pd.Index,
) -> slice | int | np.ndarray: # ndarray of int or bool
) -> (
slice | int | Any
): # ndarray of int or bool, switched to Any to make it compatible with array API objects
# TODO: why is this here? All tests pass without it and it seems at the minimum not strict enough.
if not isinstance(index, pd.RangeIndex) and index.dtype in (np.float64, np.int64):
xp = get_xp(indexer)
if not isinstance(index, pd.RangeIndex) and index.dtype in (
xp.float64,
xp.int64,
):
msg = f"Don’t call _normalize_index with non-categorical/string names and non-range index {index}"
raise TypeError(msg)

Expand All @@ -65,34 +100,37 @@ def name_idx(i):
stop = None if stop is None else stop + 1
step = indexer.step
return slice(start, stop, step)
elif isinstance(indexer, np.integer | int):
if isinstance(indexer, (int,)) or (
is_array_api_obj(indexer) and isinstance(indexer, xp.integer)
):
return indexer
elif isinstance(indexer, str):
return index.get_loc(indexer) # int
elif isinstance(
indexer, Sequence | np.ndarray | pd.Index | CSMatrix | np.matrix | CSArray
):
indexer, (Sequence, pd.Index, CSMatrix, CSArray)
) or is_array_api_obj(indexer):
if hasattr(indexer, "shape") and (
(indexer.shape == (index.shape[0], 1))
or (indexer.shape == (1, index.shape[0]))
):
if isinstance(indexer, CSMatrix | CSArray):
indexer = indexer.toarray()
indexer = np.ravel(indexer)
if not isinstance(indexer, np.ndarray | pd.Index):
indexer = np.array(indexer)
indexer = xp.ravel(indexer)
if not isinstance(indexer, (pd.Index,)) and not is_array_api_obj(indexer):
indexer = xp.array(indexer)
if len(indexer) == 0:
indexer = indexer.astype(int)
if isinstance(indexer, np.ndarray) and np.issubdtype(
indexer.dtype, np.floating
):

if get_xp(indexer).issubdtype(indexer.dtype, get_xp(indexer).floating):
indexer_int = indexer.astype(int)
if np.all((indexer - indexer_int) != 0):
if xp.all((indexer - indexer_int) != 0):
msg = f"Indexer {indexer!r} has floating point values."
raise IndexError(msg)
if issubclass(indexer.dtype.type, np.integer | np.floating):
if get_xp(indexer).issubdtype(
indexer.dtype, get_xp(indexer).integer | get_xp(indexer).floating
):
return indexer # Might not work for range indexes
elif issubclass(indexer.dtype.type, np.bool_):
elif get_xp(indexer).issubdtype(indexer.dtype, get_xp(indexer).bool_):
if indexer.shape != index.shape:
msg = (
f"Boolean index does not match AnnData’s shape along this "
Expand All @@ -103,7 +141,7 @@ def name_idx(i):
return indexer
else: # indexer should be string array
positions = index.get_indexer(indexer)
if np.any(positions < 0):
if get_xp(positions).any(positions < 0):
not_found = indexer[positions < 0]
msg = (
f"Values {list(not_found)}, from {list(indexer)}, "
Expand Down Expand Up @@ -168,11 +206,14 @@ def unpack_index(index: Index) -> tuple[Index1D, Index1D]:


@singledispatch
def _subset(a, subset_idx: Index):
    """Subset ``a`` by per-axis element indices (not coordinate pairs).

    Default (dense/ndarray-like) implementation. When every entry of
    ``subset_idx`` is an iterable of indices, build an open mesh with
    ``ix_`` so fancy indexing selects the cross product of the per-axis
    indices rather than zipped coordinates.

    This span of the pasted diff contained both the pre-change NumPy-only
    lines and the post-change array-API lines (duplicate ``def``); this is
    the reconstructed post-change implementation.
    """
    xp = get_xp(a)
    # str/bytes are iterable but must be treated as scalar labels, not
    # per-axis index arrays.
    if all(
        isinstance(x, Iterable) and not isinstance(x, (str, bytes)) for x in subset_idx
    ):
        subset_idx = xp.ix_(*subset_idx)
    return a[subset_idx]


Expand All @@ -189,10 +230,13 @@ def _subset_dask(a: DaskArray, subset_idx: Index):
@_subset.register(CSArray)
def _subset_sparse(a: CSMatrix | CSArray, subset_idx: Index):
    """Subset a sparse matrix/array, correcting spmatrix index semantics.

    Sparse classes interpret a pair of index arrays as coordinates; when
    both entries are iterables, reshape the first to a column vector so the
    cross product of rows and columns is selected instead.

    Reconstructed from a garbled diff span that interleaved the old
    NumPy-only lines with the new array-API lines.
    """
    xp = get_xp(a)
    if len(subset_idx) > 1 and all(
        isinstance(x, Iterable) and not isinstance(x, (str, bytes)) for x in subset_idx
    ):
        first_idx = subset_idx[0]
        # Boolean masks become integer positions before reshaping.
        if hasattr(first_idx, "dtype") and first_idx.dtype == bool:
            first_idx = xp.where(first_idx)[0]
        subset_idx = (first_idx.reshape(-1, 1), *subset_idx[1:])
    return a[subset_idx]

Expand All @@ -205,8 +249,11 @@ def _subset_df(df: pd.DataFrame | Dataset2D, subset_idx: Index):

@_subset.register(AwkArray)
def _subset_awkarray(a: AwkArray, subset_idx: Index):
    """Subset an awkward array, building an open mesh for paired indices.

    Reconstructed from a garbled diff span that kept both the old
    ``np.ix_`` lines and the new namespace-aware lines.
    """
    xp = get_xp(a)
    # str/bytes iterate character-wise; they are not per-axis index arrays.
    if all(
        isinstance(x, Iterable) and not isinstance(x, (str, bytes)) for x in subset_idx
    ):
        subset_idx = xp.ix_(*subset_idx)
    return a[subset_idx]


Expand All @@ -215,15 +262,15 @@ def _subset_awkarray(a: AwkArray, subset_idx: Index):
def _subset_dataset(d, subset_idx):
    """Subset a backed dataset whose indices may be unsorted or boolean.

    HDF5-style datasets require sorted indices per axis, so each array
    index is sorted for the read and the result is permuted back to the
    requested order afterwards via the inverse permutation.

    Reconstructed from a garbled diff span; the visible added lines had
    dropped the array-only guard, which would send ``slice`` objects into
    ``xp.argsort`` — the guard is restored here via ``hasattr(..., "dtype")``
    so it also accepts non-NumPy array-API objects.
    """
    if not isinstance(subset_idx, tuple):
        subset_idx = (subset_idx,)
    xp = get_xp(subset_idx[0])
    ordered = list(subset_idx)
    rev_order = [slice(None) for _ in range(len(subset_idx))]
    for axis, axis_idx in enumerate(ordered.copy()):
        # Only array-like indices need sorting; slices pass through as-is.
        if hasattr(axis_idx, "dtype"):
            if axis_idx.dtype == bool:
                axis_idx = xp.where(axis_idx)[0]
            order = xp.argsort(axis_idx)
            ordered[axis] = axis_idx[order]
            rev_order[axis] = xp.argsort(order)
    # from hdf5, then to real order
    return d[tuple(ordered)][tuple(rev_order)]

Expand Down Expand Up @@ -257,4 +304,4 @@ def get_vector(adata, k, coldim, idxdim, layer=None):
a = adata._get_X(layer=layer)[idx]
if issparse(a):
a = a.toarray()
return np.ravel(a)
return get_xp(a).ravel(a)
Loading
Loading