Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 1 addition & 154 deletions src/anndata/_io/specs/methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
import anndata as ad
from anndata import AnnData, Raw
from anndata._core import views
from anndata._core.index import _normalize_indices
from anndata._core.merge import intersect_keys
from anndata._core.sparse_dataset import _CSCDataset, _CSRDataset, sparse_dataset
from anndata._io.utils import H5PY_V3, check_key, zero_dim_array_as_scalar
from anndata._warnings import OldFormatWarning
Expand All @@ -42,11 +40,10 @@

from ..._settings import settings
from ...compat import is_zarr_v2
from .registry import _REGISTRY, IOSpec, read_elem, read_elem_partial
from .registry import _REGISTRY, IOSpec, read_elem

if TYPE_CHECKING:
from collections.abc import Callable, Iterator
from os import PathLike
from typing import Any, Literal

from numpy import typing as npt
Expand Down Expand Up @@ -173,16 +170,6 @@ def read_basic_zarr(
return zarr.read_dataset(elem) # TODO: Handle legacy


# @_REGISTRY.register_read_partial(IOSpec("", ""))
# def read_basic_partial(elem, *, items=None, indices=(slice(None), slice(None))):
# if isinstance(elem, Mapping):
# return _read_partial(elem, items=items, indices=indices)
# elif indices != (slice(None), slice(None)):
# return elem[indices]
# else:
# return elem[()]


###########
# AnnData #
###########
Expand All @@ -198,77 +185,6 @@ def read_indices(group):
return obs_idx, var_idx


def read_partial(
pth: PathLike[str] | str,
*,
obs_idx=slice(None),
var_idx=slice(None),
X=True,
obs=None,
var=None,
obsm=None,
varm=None,
obsp=None,
varp=None,
layers=None,
uns=None,
) -> ad.AnnData:
result = {}
with h5py.File(pth, "r") as f:
obs_idx, var_idx = _normalize_indices((obs_idx, var_idx), *read_indices(f))
result["obs"] = read_elem_partial(
f["obs"], items=obs, indices=(obs_idx, slice(None))
)
result["var"] = read_elem_partial(
f["var"], items=var, indices=(var_idx, slice(None))
)
if X:
result["X"] = read_elem_partial(f["X"], indices=(obs_idx, var_idx))
else:
result["X"] = sparse.csr_matrix((len(result["obs"]), len(result["var"])))
if "obsm" in f:
result["obsm"] = _read_partial(
f["obsm"], items=obsm, indices=(obs_idx, slice(None))
)
if "varm" in f:
result["varm"] = _read_partial(
f["varm"], items=varm, indices=(var_idx, slice(None))
)
if "obsp" in f:
result["obsp"] = _read_partial(
f["obsp"], items=obsp, indices=(obs_idx, obs_idx)
)
if "varp" in f:
result["varp"] = _read_partial(
f["varp"], items=varp, indices=(var_idx, var_idx)
)
if "layers" in f:
result["layers"] = _read_partial(
f["layers"], items=layers, indices=(obs_idx, var_idx)
)
if "uns" in f:
result["uns"] = _read_partial(f["uns"], items=uns)

return ad.AnnData(**result)


def _read_partial(group, *, items=None, indices=(slice(None), slice(None))):
if group is None:
return None
if items is None:
keys = intersect_keys((group,))
else:
keys = intersect_keys((group, items))
result = {}
for k in keys:
if isinstance(items, Mapping):
next_items = items.get(k, None)
else:
next_items = None
result[k] = read_elem_partial(group[k], items=next_items, indices=indices)
return result


@_REGISTRY.register_write(ZarrGroup, AnnData, IOSpec("anndata", "0.1.0"))
@_REGISTRY.register_write(H5Group, AnnData, IOSpec("anndata", "0.1.0"))
def write_anndata(
Expand Down Expand Up @@ -543,28 +459,12 @@ def read_array(elem: ArrayStorageType, *, _reader: Reader) -> npt.NDArray:
return elem[()]


@_REGISTRY.register_read_partial(H5Array, IOSpec("array", "0.2.0"))
@_REGISTRY.register_read_partial(ZarrArray, IOSpec("string-array", "0.2.0"))
def read_array_partial(elem, *, items=None, indices=(slice(None, None))):
return elem[indices]


@_REGISTRY.register_read_partial(ZarrArray, IOSpec("array", "0.2.0"))
def read_zarr_array_partial(elem, *, items=None, indices=(slice(None, None))):
return elem.oindex[indices]


# arrays of strings
@_REGISTRY.register_read(H5Array, IOSpec("string-array", "0.2.0"))
def read_string_array(d: H5Array, *, _reader: Reader):
return read_array(d.asstr(), _reader=_reader)


@_REGISTRY.register_read_partial(H5Array, IOSpec("string-array", "0.2.0"))
def read_string_array_partial(d, items=None, indices=slice(None)):
return read_array_partial(d.asstr(), items=items, indices=indices)


@_REGISTRY.register_write(
H5Group, (views.ArrayView, "U"), IOSpec("string-array", "0.2.0")
)
Expand Down Expand Up @@ -896,14 +796,6 @@ def read_sparse(elem: GroupStorageType, *, _reader: Reader) -> CSMatrix | CSArra
return sparse_dataset(elem).to_memory()


@_REGISTRY.register_read_partial(H5Group, IOSpec("csc_matrix", "0.1.0"))
@_REGISTRY.register_read_partial(H5Group, IOSpec("csr_matrix", "0.1.0"))
@_REGISTRY.register_read_partial(ZarrGroup, IOSpec("csc_matrix", "0.1.0"))
@_REGISTRY.register_read_partial(ZarrGroup, IOSpec("csr_matrix", "0.1.0"))
def read_sparse_partial(elem, *, items=None, indices=(slice(None), slice(None))):
return sparse_dataset(elem)[indices]


#################
# Awkward array #
#################
Expand Down Expand Up @@ -1023,29 +915,6 @@ def read_dataframe(elem: GroupStorageType, *, _reader: Reader) -> pd.DataFrame:
return df


# TODO: Figure out what indices is allowed to be at each element
@_REGISTRY.register_read_partial(H5Group, IOSpec("dataframe", "0.2.0"))
@_REGISTRY.register_read_partial(ZarrGroup, IOSpec("dataframe", "0.2.0"))
def read_dataframe_partial(
elem, *, items=None, indices=(slice(None, None), slice(None, None))
):
if items is not None:
columns = [
col for col in _read_attr(elem.attrs, "column-order") if col in items
]
else:
columns = list(_read_attr(elem.attrs, "column-order"))
idx_key = _read_attr(elem.attrs, "_index")
df = pd.DataFrame(
{k: read_elem_partial(elem[k], indices=indices[0]) for k in columns},
index=read_elem_partial(elem[idx_key], indices=indices[0]),
columns=columns if len(columns) else None,
)
if idx_key != "_index":
df.index.name = idx_key
return df


# Backwards compat dataframe reading


Expand Down Expand Up @@ -1084,18 +953,6 @@ def read_series(dataset: h5py.Dataset) -> np.ndarray | pd.Categorical:
return read_elem(dataset)


@_REGISTRY.register_read_partial(H5Group, IOSpec("dataframe", "0.1.0"))
@_REGISTRY.register_read_partial(ZarrGroup, IOSpec("dataframe", "0.1.0"))
def read_partial_dataframe_0_1_0(
elem, *, items=None, indices=(slice(None), slice(None))
):
if items is None:
items = slice(None)
else:
items = list(items)
return read_elem(elem)[items].iloc[indices[0]]


###############
# Categorical #
###############
Expand Down Expand Up @@ -1130,16 +987,6 @@ def read_categorical(elem: GroupStorageType, *, _reader: Reader) -> pd.Categoric
)


@_REGISTRY.register_read_partial(H5Group, IOSpec("categorical", "0.2.0"))
@_REGISTRY.register_read_partial(ZarrGroup, IOSpec("categorical", "0.2.0"))
def read_partial_categorical(elem, *, items=None, indices=(slice(None),)):
return pd.Categorical.from_codes(
codes=read_elem_partial(elem["codes"], indices=indices),
categories=read_elem(elem["categories"]),
ordered=bool(_read_attr(elem.attrs, "ordered")),
)


####################
# Pandas nullables #
####################
Expand Down
39 changes: 0 additions & 39 deletions src/anndata/_io/specs/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ def wrapper(g: GroupStorageType, k: str, *args, **kwargs):
class IORegistry(Generic[_R, R]):
def __init__(self):
self.read: dict[tuple[type, IOSpec, frozenset[str]], _R] = {}
self.read_partial: dict[tuple[type, IOSpec, frozenset[str]], Callable] = {}
self.write: dict[
tuple[type, type | tuple[type, str], frozenset[str]], _WriteInternal
] = {}
Expand Down Expand Up @@ -181,29 +180,6 @@ def has_read(
) -> bool:
return (src_type, spec, modifiers) in self.read

def register_read_partial(
self,
src_type: type,
spec: IOSpec | Mapping[str, str],
modifiers: Iterable[str] = frozenset(),
):
spec = proc_spec(spec)
modifiers = frozenset(modifiers)

def _register(func):
self.read_partial[(src_type, spec, modifiers)] = func
return func

return _register

def get_partial_read(
self, src_type: type, spec: IOSpec, modifiers: frozenset[str] = frozenset()
):
if (src_type, spec, modifiers) in self.read_partial:
return self.read_partial[(src_type, spec, modifiers)]
name = "read_partial"
raise IORegistryError._from_read_parts(name, self.read_partial, src_type, spec)

def get_spec(self, elem: Any) -> IOSpec:
if isinstance(elem, DaskArray):
if (typ_meta := (DaskArray, type(elem._meta))) in self.write_specs:
Expand Down Expand Up @@ -507,18 +483,3 @@ def write_elem(
E.g. for zarr this would be `chunks`, `compressor`.
"""
Writer(_REGISTRY).write_elem(store, k, elem, dataset_kwargs=dataset_kwargs)


# TODO: If all items would be read, just call normal read method
def read_elem_partial(
elem,
*,
items=None,
indices=(slice(None), slice(None)),
modifiers: frozenset[str] = frozenset(),
):
"""Read part of an element from an on disk store."""
read_partial = _REGISTRY.get_partial_read(
type(elem), get_spec(elem), frozenset(modifiers)
)
return read_partial(elem, items=items, indices=indices)
100 changes: 0 additions & 100 deletions tests/test_io_partial.py

This file was deleted.