11from __future__ import annotations
22
3+ import asyncio
34import re
45from typing import TYPE_CHECKING
56
1819
1920
2021def test_read_dispatched_w_regex (tmp_path : Path ):
21- def read_only_axis_dfs (func , elem_name : str , elem , iospec ):
22+ async def read_only_axis_dfs (func , elem_name : str , elem , iospec ):
2223 if iospec .encoding_type == "anndata" :
23- return func (elem )
24+ return await func (elem )
2425 elif re .match (r"^/((obs)|(var))?(/.*)?$" , elem_name ):
25- return func (elem )
26+ return await func (elem )
2627 else :
2728 return None
2829
@@ -35,27 +36,27 @@ def read_only_axis_dfs(func, elem_name: str, elem, iospec):
3536 z = zarr .open (z .store )
3637
3738 expected = ad .AnnData (obs = adata .obs , var = adata .var )
38- actual = read_dispatched (z , read_only_axis_dfs )
39+ actual = asyncio . run ( read_dispatched (z , read_only_axis_dfs ) )
3940
4041 assert_equal (expected , actual )
4142
4243
4344def test_read_dispatched_dask (tmp_path : Path ):
4445 import dask .array as da
4546
46- def read_as_dask_array (func , elem_name : str , elem , iospec ):
47+ async def read_as_dask_array (func , elem_name : str , elem , iospec ):
4748 if iospec .encoding_type in {
4849 "dataframe" ,
4950 "csr_matrix" ,
5051 "csc_matrix" ,
5152 "awkward-array" ,
5253 }:
5354 # Preventing recursing inside of these types
54- return func (elem )
55+ return await func (elem )
5556 elif iospec .encoding_type == "array" :
5657 return da .from_zarr (elem )
5758 else :
58- return func (elem )
59+ return await func (elem )
5960
6061 adata = gen_adata ((1000 , 100 ))
6162 z = open_write_group (tmp_path )
@@ -64,7 +65,7 @@ def read_as_dask_array(func, elem_name: str, elem, iospec):
6465 if not is_zarr_v2 () and isinstance (z , ZarrGroup ):
6566 z = zarr .open (z .store )
6667
67- dask_adata = read_dispatched (z , read_as_dask_array )
68+ dask_adata = asyncio . run ( read_dispatched (z , read_as_dask_array ) )
6869
6970 assert isinstance (dask_adata .layers ["array" ], da .Array )
7071 assert isinstance (dask_adata .obsm ["array" ], da .Array )
@@ -84,7 +85,11 @@ def test_read_dispatched_null_case(tmp_path: Path):
8485 if not is_zarr_v2 () and isinstance (z , ZarrGroup ):
8586 z = zarr .open (z .store )
8687 expected = ad .io .read_elem (z )
87- actual = read_dispatched (z , lambda _ , __ , x , ** ___ : ad .io .read_elem (x ))
88+
89+ async def callback (_ , __ , x , ** ___ ):
90+ return await ad .io .read_elem_async (x )
91+
92+ actual = asyncio .run (read_dispatched (z , callback ))
8893
8994 assert_equal (expected , actual )
9095
@@ -186,23 +191,23 @@ def zarr_writer(func, store, k, elem, dataset_kwargs, iospec):
186191 )
187192 func (store , k , elem , dataset_kwargs = dataset_kwargs )
188193
189- def h5ad_reader (func , elem_name : str , elem , iospec ):
194+ async def h5ad_reader (func , elem_name : str , elem , iospec ):
190195 h5ad_read_keys .append (elem_name if is_zarr_v2 () else elem_name .strip ("/" ))
191196 return func (elem )
192197
193- def zarr_reader (func , elem_name : str , elem , iospec ):
198+ async def zarr_reader (func , elem_name : str , elem , iospec ):
194199 zarr_read_keys .append (elem_name if is_zarr_v2 () else elem_name .strip ("/" ))
195200 return func (elem )
196201
197202 adata = gen_adata ((50 , 100 ))
198203
199204 with h5py .File (h5ad_path , "w" ) as f :
200205 write_dispatched (f , "/" , adata , callback = h5ad_writer )
201- _ = read_dispatched (f , h5ad_reader )
206+ _ = asyncio . run ( read_dispatched (f , h5ad_reader ) )
202207
203208 f = open_write_group (zarr_path )
204209 write_dispatched (f , "/" , adata , callback = zarr_writer )
205- _ = read_dispatched (f , zarr_reader )
210+ _ = asyncio . run ( read_dispatched (f , zarr_reader ) )
206211
207212 assert sorted (h5ad_read_keys ) == sorted (zarr_read_keys )
208213 assert sorted (h5ad_write_keys ) == sorted (zarr_write_keys )
0 commit comments