Skip to content

Commit 41bc3b5

Browse files
authored
perf: use name in map_blocks to bypass tokenization (#2121)
1 parent 3c90d3c commit 41bc3b5

File tree

3 files changed

+35
-1
lines changed

3 files changed

+35
-1
lines changed

benchmarks/benchmarks/sparse_dataset.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from dask.array.core import Array as DaskArray
88
from scipy import sparse
99

10-
from anndata import AnnData
10+
from anndata import AnnData, concat
1111
from anndata._core.sparse_dataset import sparse_dataset
1212
from anndata._io.specs import write_elem
1313
from anndata.experimental import read_elem_lazy
@@ -77,3 +77,34 @@ def peakmem_getitem_adata(self, *_):
7777
res = self.adata[self.index]
7878
if isinstance(res, DaskArray):
7979
res.compute()
80+
81+
82+
class SparseCSRDask:
83+
filepath = "data.zarr"
84+
85+
def setup_cache(self):
86+
X = sparse.random(
87+
10_000,
88+
10_000,
89+
density=0.01,
90+
format="csr",
91+
random_state=np.random.default_rng(42),
92+
)
93+
g = zarr.group(self.filepath)
94+
write_elem(g, "X", X)
95+
96+
def setup(self):
97+
self.group = zarr.group(self.filepath)
98+
self.adata = AnnData(X=read_elem_lazy(self.group["X"]))
99+
100+
def time_concat(self):
101+
concat([self.adata for i in range(100)])
102+
103+
def peakmem_concat(self):
104+
concat([self.adata for i in range(100)])
105+
106+
def time_read(self):
107+
AnnData(X=read_elem_lazy(self.group["X"]))
108+
109+
def peakmem_read(self):
110+
AnnData(X=read_elem_lazy(self.group["X"]))

docs/release-notes/2121.perf.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Use `name` on {func}`dask.array.map_blocks` internally when concatenating {class}`anndata.experimental.backed.Dataset2D` objects whose categoricals/nullable types must be converted to dask arrays {user}`ilan-gold`

src/anndata/_core/merge.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from __future__ import annotations
66

7+
import uuid
78
from collections import OrderedDict
89
from collections.abc import Callable, Mapping, MutableSet
910
from functools import partial, reduce, singledispatch
@@ -1250,6 +1251,7 @@ def get_chunk(block_info=None):
12501251
chunks=chunk_size,
12511252
meta=np.array([], dtype=dtype),
12521253
dtype=dtype,
1254+
name=f"{uuid.uuid4()}/{base_path_or_zarr_group}/{elem_name}-{dtype}",
12531255
)
12541256

12551257
return da.from_array(col.values, chunks=-1) # in-memory

0 commit comments

Comments
 (0)