Skip to content

Commit c23df17

Browse files
committed
(refactor): remove read_dispatched_async
1 parent eb08bd5 commit c23df17

File tree

9 files changed

+47
-73
lines changed

9 files changed

+47
-73
lines changed

src/anndata/_io/h5ad.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
_decode_structured_array,
2525
_from_fixed_length_strings,
2626
)
27-
from ..experimental import read_dispatched_async
27+
from ..experimental import read_dispatched
2828
from .specs import write_elem
2929
from .specs.methods import sync_async_to_async
3030
from .specs.registry import IOSpec, read_elem_async, write_spec
@@ -243,9 +243,7 @@ async def callback(func, elem_name: str, elem, iospec):
243243
*(
244244
# This is covering up backwards compat in the anndata initializer
245245
# In most cases we should be able to call `func(elen[k])` instead
246-
sync_async_to_async(
247-
k, read_dispatched_async(elem[k], callback)
248-
)
246+
sync_async_to_async(k, read_dispatched(elem[k], callback))
249247
for k in elem.keys()
250248
if not k.startswith("raw.")
251249
)
@@ -263,7 +261,7 @@ async def callback(func, elem_name: str, elem, iospec):
263261
return await read_dataframe(elem)
264262
return await func(elem)
265263

266-
adata = asyncio.run(read_dispatched_async(f, callback=callback))
264+
adata = asyncio.run(read_dispatched(f, callback=callback))
267265

268266
# Backwards compat (should figure out which version)
269267
if "raw.X" in f:

src/anndata/_io/specs/methods.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -313,12 +313,6 @@ def write_null_zarr(f, k, _v, _writer, dataset_kwargs=MappingProxyType({})):
313313
async def read_mapping(
314314
elem: GroupStorageType, *, _reader: Reader
315315
) -> dict[str, AxisStorable]:
316-
print(
317-
(
318-
sync_async_to_async(k, _reader.read_elem_async(v))
319-
for k, v in dict(elem).items()
320-
)
321-
)
322316
return dict(
323317
await asyncio.gather(
324318
*(

src/anndata/_io/zarr.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from .._settings import settings
1515
from .._warnings import OldFormatWarning
1616
from ..compat import _clean_uns, _from_fixed_length_strings, is_zarr_v2
17-
from ..experimental import read_dispatched_async, write_dispatched
17+
from ..experimental import read_dispatched, write_dispatched
1818
from .specs import read_elem_async
1919
from .specs.methods import sync_async_to_async
2020
from .utils import _read_legacy_raw, report_read_key_on_error
@@ -84,7 +84,7 @@ async def callback(func, elem_name: str, elem, iospec):
8484
*(
8585
# This is covering up backwards compat in the anndata initializer
8686
# In most cases we should be able to call `func(elen[k])` instead
87-
sync_async_to_async(k, read_dispatched_async(elem[k], callback))
87+
sync_async_to_async(k, read_dispatched(elem[k], callback))
8888
for k in elem.keys()
8989
if not k.startswith("raw.")
9090
)
@@ -102,7 +102,7 @@ async def callback(func, elem_name: str, elem, iospec):
102102
)
103103
return await func(elem)
104104

105-
adata = asyncio.run(read_dispatched_async(f, callback=callback))
105+
adata = asyncio.run(read_dispatched(f, callback=callback))
106106

107107
# Backwards compat (should figure out which version)
108108
if "raw.X" in f:

src/anndata/experimental/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from .._io.specs import IOSpec, read_elem_as_dask
77
from .._types import Read, ReadCallback, StorageType, Write, WriteCallback
88
from ..utils import module_get_attr_redirect
9-
from ._dispatch_io import read_dispatched, read_dispatched_async, write_dispatched
9+
from ._dispatch_io import read_dispatched, write_dispatched
1010
from .merge import concat_on_disk
1111
from .multi_files import AnnCollection
1212
from .pytorch import AnnLoader
@@ -43,7 +43,6 @@ def __getattr__(attr_name: str) -> Any:
4343
"AnnLoader",
4444
"read_elem_as_dask",
4545
"read_dispatched",
46-
"read_dispatched_async",
4746
"write_dispatched",
4847
"IOSpec",
4948
"concat_on_disk",

src/anndata/experimental/_dispatch_io.py

Lines changed: 1 addition & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -10,40 +10,13 @@
1010
from anndata._types import (
1111
GroupStorageType,
1212
ReadAsyncCallback,
13-
ReadCallback,
1413
StorageType,
1514
WriteCallback,
1615
)
1716
from anndata.typing import RWAble
1817

1918

20-
def read_dispatched(
21-
elem: StorageType,
22-
callback: ReadCallback,
23-
) -> RWAble:
24-
"""
25-
Read elem, calling the callback at each sub-element.
26-
27-
Params
28-
------
29-
elem
30-
Storage container (e.g. `h5py.Group`, `zarr.Group`).
31-
This must have anndata element specifications.
32-
callback
33-
Function to call at each anndata encoded element.
34-
35-
See Also
36-
--------
37-
:doc:`/tutorials/notebooks/{read,write}_dispatched`
38-
"""
39-
from anndata._io.specs import _REGISTRY, Reader
40-
41-
reader = Reader(_REGISTRY, callback=callback)
42-
43-
return reader.read_elem(elem)
44-
45-
46-
async def read_dispatched_async(
19+
async def read_dispatched(
4720
elem: StorageType,
4821
callback: ReadAsyncCallback,
4922
) -> RWAble:

src/anndata/experimental/merge.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import asyncio
34
import os
45
import shutil
56
from collections.abc import Mapping
@@ -26,6 +27,8 @@
2627
)
2728
from .._core.sparse_dataset import BaseCompressedSparseDataset, sparse_dataset
2829
from .._io.specs import read_elem, write_elem
30+
from .._io.specs.methods import sync_async_to_async
31+
from .._io.specs.registry import read_elem_async
2932
from ..compat import H5Array, H5Group, ZarrArray, ZarrGroup, _map_cat_to_str
3033
from . import read_dispatched
3134

@@ -142,19 +145,26 @@ def read_as_backed(group: ZarrGroup | H5Group):
142145
BaseCompressedSparseDataset, Array or EAGER_TYPES are encountered.
143146
"""
144147

145-
def callback(func, elem_name: str, elem, iospec):
148+
async def callback(func, elem_name: str, elem, iospec):
146149
if iospec.encoding_type in SPARSE_MATRIX:
147150
return sparse_dataset(elem)
148151
elif iospec.encoding_type in EAGER_TYPES:
149-
return read_elem(elem)
152+
return await read_elem_async(elem)
150153
elif iospec.encoding_type == "array":
151154
return elem
152155
elif iospec.encoding_type == "dict":
153-
return {k: read_as_backed(v) for k, v in dict(elem).items()}
156+
return dict(
157+
await asyncio.gather(
158+
*(
159+
sync_async_to_async(k, read_dispatched(v, callback=callback))
160+
for k, v in dict(elem).items()
161+
)
162+
)
163+
)
154164
else:
155-
return func(elem)
165+
return await func(elem)
156166

157-
return read_dispatched(group, callback=callback)
167+
return asyncio.run(read_dispatched(group, callback=callback))
158168

159169

160170
def _df_index(df: ZarrGroup | H5Group) -> pd.Index:

src/anndata/io.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
read_text,
1515
read_umi_tools,
1616
)
17-
from ._io.specs import read_elem, write_elem
17+
from ._io.specs import read_elem, read_elem_async, write_elem
1818
from ._io.write import write_csvs, write_loom
1919

2020
if find_spec("zarr") or TYPE_CHECKING:
@@ -46,5 +46,6 @@ def write_zarr(*args, **kw):
4646
"write_zarr",
4747
"write_elem",
4848
"read_elem",
49+
"read_elem_async",
4950
"sparse_dataset",
5051
]

tests/test_backed_sparse.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import asyncio
34
from functools import partial
45
from itertools import product
56
from typing import TYPE_CHECKING, Literal, get_args
@@ -84,19 +85,12 @@ def read_zarr_backed(path):
8485
f = zarr.open(path, mode="r")
8586

8687
# Read with handling for backwards compat
87-
def callback(func, elem_name, elem, iospec):
88-
if iospec.encoding_type == "anndata" or elem_name.endswith("/"):
89-
return AnnData(
90-
**{
91-
k: read_dispatched(v, callback)
92-
for k, v in dict(elem).items()
93-
}
94-
)
88+
async def callback(func, elem_name, elem, iospec):
9589
if iospec.encoding_type in {"csc_matrix", "csr_matrix"}:
9690
return sparse_dataset(elem)
97-
return func(elem)
91+
return await func(elem)
9892

99-
adata = read_dispatched(f, callback=callback)
93+
adata = asyncio.run(read_dispatched(f, callback=callback))
10094

10195
return adata
10296

tests/test_io_dispatched.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import asyncio
34
import re
45
from typing import TYPE_CHECKING
56

@@ -18,11 +19,11 @@
1819

1920

2021
def test_read_dispatched_w_regex(tmp_path: Path):
21-
def read_only_axis_dfs(func, elem_name: str, elem, iospec):
22+
async def read_only_axis_dfs(func, elem_name: str, elem, iospec):
2223
if iospec.encoding_type == "anndata":
23-
return func(elem)
24+
return await func(elem)
2425
elif re.match(r"^/((obs)|(var))?(/.*)?$", elem_name):
25-
return func(elem)
26+
return await func(elem)
2627
else:
2728
return None
2829

@@ -35,27 +36,27 @@ def read_only_axis_dfs(func, elem_name: str, elem, iospec):
3536
z = zarr.open(z.store)
3637

3738
expected = ad.AnnData(obs=adata.obs, var=adata.var)
38-
actual = read_dispatched(z, read_only_axis_dfs)
39+
actual = asyncio.run(read_dispatched(z, read_only_axis_dfs))
3940

4041
assert_equal(expected, actual)
4142

4243

4344
def test_read_dispatched_dask(tmp_path: Path):
4445
import dask.array as da
4546

46-
def read_as_dask_array(func, elem_name: str, elem, iospec):
47+
async def read_as_dask_array(func, elem_name: str, elem, iospec):
4748
if iospec.encoding_type in {
4849
"dataframe",
4950
"csr_matrix",
5051
"csc_matrix",
5152
"awkward-array",
5253
}:
5354
# Preventing recursing inside of these types
54-
return func(elem)
55+
return await func(elem)
5556
elif iospec.encoding_type == "array":
5657
return da.from_zarr(elem)
5758
else:
58-
return func(elem)
59+
return await func(elem)
5960

6061
adata = gen_adata((1000, 100))
6162
z = open_write_group(tmp_path)
@@ -64,7 +65,7 @@ def read_as_dask_array(func, elem_name: str, elem, iospec):
6465
if not is_zarr_v2() and isinstance(z, ZarrGroup):
6566
z = zarr.open(z.store)
6667

67-
dask_adata = read_dispatched(z, read_as_dask_array)
68+
dask_adata = asyncio.run(read_dispatched(z, read_as_dask_array))
6869

6970
assert isinstance(dask_adata.layers["array"], da.Array)
7071
assert isinstance(dask_adata.obsm["array"], da.Array)
@@ -84,7 +85,11 @@ def test_read_dispatched_null_case(tmp_path: Path):
8485
if not is_zarr_v2() and isinstance(z, ZarrGroup):
8586
z = zarr.open(z.store)
8687
expected = ad.io.read_elem(z)
87-
actual = read_dispatched(z, lambda _, __, x, **___: ad.io.read_elem(x))
88+
89+
async def callback(_, __, x, **___):
90+
return await ad.io.read_elem_async(x)
91+
92+
actual = asyncio.run(read_dispatched(z, callback))
8893

8994
assert_equal(expected, actual)
9095

@@ -186,23 +191,23 @@ def zarr_writer(func, store, k, elem, dataset_kwargs, iospec):
186191
)
187192
func(store, k, elem, dataset_kwargs=dataset_kwargs)
188193

189-
def h5ad_reader(func, elem_name: str, elem, iospec):
194+
async def h5ad_reader(func, elem_name: str, elem, iospec):
190195
h5ad_read_keys.append(elem_name if is_zarr_v2() else elem_name.strip("/"))
191196
return func(elem)
192197

193-
def zarr_reader(func, elem_name: str, elem, iospec):
198+
async def zarr_reader(func, elem_name: str, elem, iospec):
194199
zarr_read_keys.append(elem_name if is_zarr_v2() else elem_name.strip("/"))
195200
return func(elem)
196201

197202
adata = gen_adata((50, 100))
198203

199204
with h5py.File(h5ad_path, "w") as f:
200205
write_dispatched(f, "/", adata, callback=h5ad_writer)
201-
_ = read_dispatched(f, h5ad_reader)
206+
_ = asyncio.run(read_dispatched(f, h5ad_reader))
202207

203208
f = open_write_group(zarr_path)
204209
write_dispatched(f, "/", adata, callback=zarr_writer)
205-
_ = read_dispatched(f, zarr_reader)
210+
_ = asyncio.run(read_dispatched(f, zarr_reader))
206211

207212
assert sorted(h5ad_read_keys) == sorted(zarr_read_keys)
208213
assert sorted(h5ad_write_keys) == sorted(zarr_write_keys)

0 commit comments

Comments
 (0)