Skip to content

Commit 124f183

Browse files
authored
Backport PR #2167 on branch 0.12.x (perf: auto sharding for zarr v3) (#2188)
1 parent cd9d708 commit 124f183

File tree

12 files changed

+188
-34
lines changed

12 files changed

+188
-34
lines changed

docs/release-notes/2167.perf.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Enable automatic sharding in zarr v3 via {attr}`anndata.settings.auto_shard_zarr_v3`, which uses {mod}`zarr`'s own auto-sharding mechanism (i.e., `shards="auto"`), for all types except {class}`numpy.recarray` {user}`ilan-gold`

docs/tutorials/zarr-v3.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,8 @@ There are two ways of opening remote `zarr` stores from the `zarr-python` packag
3838
Local data generally poses a different set of challenges.
3939
First, write speeds can be somewhat slow; second, creating many small files can slow down the filesystem.
4040
For the "many small files" problem, `zarr` has introduced {ref}`sharding <zarr:user-guide-sharding>` in the v3 file format.
41-
Sharding requires knowledge of the array element you are writing (such as shape or data type), though, and therefore you will need to use {func}`anndata.experimental.write_dispatched` to use sharding.
41+
We offer {attr}`anndata.settings.auto_shard_zarr_v3` to hook into zarr's ability to automatically compute shards; note that this zarr feature is experimental at the moment.
42+
Manual sharding, however, requires knowledge of the array element you are writing (such as shape or data type), so you will need to use {func}`anndata.experimental.write_dispatched` to apply custom sharding.
4243
For example, you cannot shard a 1D array with `shard` sizes `(256, 256)`.
4344
Here is a short example, although you should tune the sizes to your own use-case and also use the compression that makes the most sense for you:
4445

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ filterwarnings_when_strict = [
164164
"default:Consolidated metadata is:UserWarning",
165165
"default:.*Structured:zarr.core.dtype.common.UnstableSpecificationWarning",
166166
"default:.*FixedLengthUTF32:zarr.core.dtype.common.UnstableSpecificationWarning",
167+
"default:Automatic shard shape inference is experimental",
167168
]
168169
python_files = "test_*.py"
169170
testpaths = [

src/anndata/_io/specs/methods.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,12 @@ def zarr_v3_compressor_compat(dataset_kwargs) -> dict:
102102
return dataset_kwargs
103103

104104

105+
def zarr_v3_sharding(dataset_kwargs) -> dict:
    """Add ``shards="auto"`` to the array-creation kwargs when auto sharding is on.

    Parameters
    ----------
    dataset_kwargs
        Keyword arguments destined for the zarr array-creation call.

    Returns
    -------
    ``dataset_kwargs`` itself when it already contains a ``"shards"`` entry or
    when ``ad.settings.auto_shard_zarr_v3`` is falsy; otherwise a new dict with
    the same entries plus ``"shards": "auto"``.
    """
    # A user-supplied ``shards`` always wins; the setting only fills the gap.
    if "shards" not in dataset_kwargs and ad.settings.auto_shard_zarr_v3:
        # Build a new dict instead of mutating the caller's kwargs in place.
        dataset_kwargs = {**dataset_kwargs, "shards": "auto"}
    return dataset_kwargs
109+
110+
105111
def _to_cpu_mem_wrapper(write_func):
106112
"""
107113
Wrapper to bring cupy types into cpu memory before writing.
@@ -432,6 +438,7 @@ def write_basic(
432438
f.create_dataset(k, data=elem, shape=elem.shape, dtype=dtype, **dataset_kwargs)
433439
else:
434440
dataset_kwargs = zarr_v3_compressor_compat(dataset_kwargs)
441+
dataset_kwargs = zarr_v3_sharding(dataset_kwargs)
435442
f.create_array(k, shape=elem.shape, dtype=dtype, **dataset_kwargs)
436443
# see https://github.com/zarr-developers/zarr-python/discussions/2712
437444
if isinstance(elem, ZarrArray | H5Array):
@@ -511,6 +518,7 @@ def write_basic_dask_dask_dense(
511518
is_h5 = isinstance(f, H5Group)
512519
if not is_h5:
513520
dataset_kwargs = zarr_v3_compressor_compat(dataset_kwargs)
521+
dataset_kwargs = zarr_v3_sharding(dataset_kwargs)
514522
if is_zarr_v2() or is_h5:
515523
g = f.require_dataset(k, shape=elem.shape, dtype=elem.dtype, **dataset_kwargs)
516524
else:
@@ -616,6 +624,7 @@ def write_vlen_string_array_zarr(
616624
filters, fill_value = None, None
617625
if f.metadata.zarr_format == 2:
618626
filters, fill_value = [VLenUTF8()], ""
627+
dataset_kwargs = zarr_v3_sharding(dataset_kwargs)
619628
f.create_array(
620629
k,
621630
shape=elem.shape,
@@ -684,6 +693,9 @@ def write_recarray_zarr(
684693
else:
685694
dataset_kwargs = dataset_kwargs.copy()
686695
dataset_kwargs = zarr_v3_compressor_compat(dataset_kwargs)
696+
# https://github.com/zarr-developers/zarr-python/issues/3546
697+
# if "shards" not in dataset_kwargs and ad.settings.auto_shard_zarr_v3:
698+
# dataset_kwargs = {**dataset_kwargs, "shards": "auto"}
687699
f.create_array(k, shape=elem.shape, dtype=elem.dtype, **dataset_kwargs)
688700
f[k][...] = elem
689701

@@ -720,6 +732,7 @@ def write_sparse_compressed(
720732
attr_name, data=attr, shape=attr.shape, dtype=dtype, **dataset_kwargs
721733
)
722734
else:
735+
dataset_kwargs = zarr_v3_sharding(dataset_kwargs)
723736
arr = g.create_array(
724737
attr_name, shape=attr.shape, dtype=dtype, **dataset_kwargs
725738
)

src/anndata/_settings.py

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
if TYPE_CHECKING:
1919
from collections.abc import Callable, Sequence
20-
from typing import Any, TypeGuard
20+
from typing import Any, Self, TypeGuard
2121

2222
T = TypeVar("T")
2323

@@ -55,7 +55,7 @@ class RegisteredOption(NamedTuple, Generic[T]):
5555
option: str
5656
default_value: T
5757
description: str
58-
validate: Callable[[T], None]
58+
validate: Callable[[T, SettingsManager], None]
5959
type: object
6060

6161
describe = describe
@@ -206,7 +206,7 @@ def register(
206206
*,
207207
default_value: T,
208208
description: str,
209-
validate: Callable[[T], None],
209+
validate: Callable[[T, Self], None],
210210
option_type: object | None = None,
211211
get_from_env: Callable[[str, T], T] = lambda x, y: y,
212212
) -> None:
@@ -229,7 +229,7 @@ def register(
229229
Default behavior is to return `default_value` without checking the environment.
230230
"""
231231
try:
232-
validate(default_value)
232+
validate(default_value, self)
233233
except (ValueError, TypeError) as e:
234234
e.add_note(f"for option {option!r}")
235235
raise e
@@ -307,7 +307,7 @@ def __setattr__(self, option: str, val: object) -> None:
307307
)
308308
raise AttributeError(msg)
309309
registered_option = self._registered_options[option]
310-
registered_option.validate(val)
310+
registered_option.validate(val, self)
311311
self._config[option] = val
312312

313313
def __getattr__(self, option: str) -> object:
@@ -364,10 +364,13 @@ def override(self, **overrides):
364364
"""
365365
restore = {a: getattr(self, a) for a in overrides}
366366
try:
367-
for attr, value in overrides.items():
368-
setattr(self, attr, value)
367+
# Preserve order so that settings that depend on each other can be overridden together i.e., always override zarr version before sharding
368+
for k in self._config:
369+
if k in overrides:
370+
setattr(self, k, overrides.get(k))
369371
yield None
370372
finally:
373+
# TODO: does the order need to be preserved when restoring?
371374
for attr, value in restore.items():
372375
setattr(self, attr, value)
373376

@@ -395,7 +398,7 @@ def __doc__(self):
395398

396399

397400
def gen_validator(_type: type[V]) -> Callable[[V], None]:
398-
def validate_type(val: V) -> None:
401+
def validate_type(val: V, settings: SettingsManager) -> None:
399402
if not isinstance(val, _type):
400403
msg = f"{val} not valid {_type}"
401404
raise TypeError(msg)
@@ -434,14 +437,28 @@ def validate_type(val: V) -> None:
434437
)
435438

436439

437-
def validate_zarr_write_format(format: int):
438-
validate_int(format)
440+
def validate_zarr_write_format(format: int, settings: SettingsManager) -> None:
    """Validate a candidate value for the ``zarr_write_format`` option.

    Parameters
    ----------
    format
        The zarr on-disk format version being set.
    settings
        The settings manager the option belongs to; consulted for the current
        ``auto_shard_zarr_v3`` value.

    Raises
    ------
    ValueError
        If ``format`` is not 2 or 3, if 3 is requested while the zarr v2
        package is installed, or if 2 is requested while auto sharding is on.
    """
    validate_int(format, settings)
    if format not in {2, 3}:
        msg = "non-v2 zarr on-disk format not supported"
        raise ValueError(msg)
    if format == 3 and is_zarr_v2():
        msg = "Cannot write v3 format against v2 package"
        raise ValueError(msg)
    # `getattr` with a default: presumably `auto_shard_zarr_v3` may not be
    # registered yet when this validator runs on its own default value
    # (its registration appears later in the file) -- TODO confirm.
    if format == 2 and getattr(settings, "auto_shard_zarr_v3", False):
        msg = "Cannot set `zarr_write_format` to 2 with autosharding on. Please set to `False` `anndata.settings.auto_shard_zarr_v3`"
        raise ValueError(msg)
451+
452+
453+
def validate_zarr_sharding(auto_shard: bool, settings: SettingsManager) -> None:  # noqa: FBT001
    """Validate a candidate value for the ``auto_shard_zarr_v3`` option.

    Parameters
    ----------
    auto_shard
        Whether automatic sharding is being switched on.
    settings
        The settings manager the option belongs to; its ``zarr_write_format``
        is read to make sure sharding is only enabled for v3 output.

    Raises
    ------
    ValueError
        If sharding is requested while the zarr v2 package is installed, or
        while ``settings.zarr_write_format`` is 2.
    """
    validate_bool(auto_shard, settings)
    if not auto_shard:
        # Turning the option off is always allowed.
        return
    if is_zarr_v2():
        msg = "Cannot use sharding with `zarr-python<3`. Please upgrade package and set `anndata.settings.zarr_write_format` to 3."
        raise ValueError(msg)
    if settings.zarr_write_format == 2:
        msg = "Cannot shard v2 format data. Please set `anndata.settings.zarr_write_format` to 3."
        raise ValueError(msg)
445462

446463

447464
settings.register(
@@ -458,8 +475,8 @@ def validate_zarr_write_format(format: int):
458475
)
459476

460477

461-
def validate_sparse_settings(val: Any) -> None:
462-
validate_bool(val)
478+
def validate_sparse_settings(val: Any, settings: SettingsManager) -> None:
479+
validate_bool(val, settings)
463480

464481

465482
settings.register(
@@ -486,6 +503,14 @@ def validate_sparse_settings(val: Any) -> None:
486503
get_from_env=check_and_get_bool,
487504
)
488505

506+
# Register the opt-in switch for zarr v3 auto sharding (off by default).
# NOTE(review): the description says the setting is "ignored" for v2, but
# `validate_zarr_sharding` / `validate_zarr_write_format` raise `ValueError`
# when it is enabled together with `zarr_write_format == 2` -- confirm which
# behaviour is intended and align the wording.
settings.register(
    "auto_shard_zarr_v3",
    default_value=False,
    description="Whether or not to use zarr's auto computation of sharding for v3. For v2 this setting will be ignored. The setting will apply to all calls to anndata's writing mechanism (write_zarr / write_elem) and will **not** override any user-defined kwargs for shards.",
    validate=validate_zarr_sharding,
    get_from_env=check_and_get_bool,
)
513+
489514

490515
##################################################################################
491516
##################################################################################

src/anndata/_settings.pyi

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ from collections.abc import Callable as Callable
22
from collections.abc import Generator, Iterable
33
from contextlib import contextmanager
44
from dataclasses import dataclass
5-
from typing import Literal, TypeVar
5+
from typing import Literal, Self, TypeVar
66

77
_T = TypeVar("_T")
88

@@ -25,7 +25,7 @@ class SettingsManager:
2525
*,
2626
default_value: _T,
2727
description: str,
28-
validate: Callable[[_T], None],
28+
validate: Callable[[_T, Self], None],
2929
option_type: object | None = None,
3030
get_from_env: Callable[[str, _T], _T] = ...,
3131
) -> None: ...
@@ -46,5 +46,6 @@ class _AnnDataSettingsManager(SettingsManager):
4646
use_sparse_array_on_read: bool = False
4747
min_rows_for_chunked_h5_copy: int = 1000
4848
disallow_forward_slash_in_h5ad: bool = False
49+
auto_shard_zarr_v3: bool = False
4950

5051
settings: _AnnDataSettingsManager

src/anndata/tests/helpers.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import numpy as np
1515
import pandas as pd
1616
import pytest
17+
import zarr
1718
from pandas.api.types import is_numeric_dtype
1819
from scipy import sparse
1920

@@ -34,6 +35,7 @@
3435
XDataArray,
3536
XDataset,
3637
ZarrArray,
38+
ZarrGroup,
3739
is_zarr_v2,
3840
)
3941
from anndata.utils import asarray
@@ -1187,3 +1189,23 @@ def get_multiindex_columns_df(shape: tuple[int, int]) -> pd.DataFrame:
11871189
+ list(itertools.product(["b"], range(shape[1] // 2)))
11881190
),
11891191
)
1192+
1193+
1194+
def visititems_zarr(
    z: ZarrGroup, visitor: Callable[[str, ZarrGroup | zarr.Array], None]
) -> None:
    """Recursively call ``visitor`` on every non-group member below ``z``.

    Parameters
    ----------
    z
        Group whose members are walked depth-first.
    visitor
        Callback invoked as ``visitor(key, member)`` for each member that is
        not a :class:`ZarrGroup`. ``key`` is the member's name within its
        immediate parent group, not a full path.
    """
    for key in z:
        maybe_group = z[key]
        if isinstance(maybe_group, ZarrGroup):
            # Descend into sub-groups; the groups themselves are not visited.
            visititems_zarr(maybe_group, visitor)
        else:
            visitor(key, maybe_group)
1203+
1204+
1205+
def check_all_sharded(g: ZarrGroup):
    """Assert that every eligible array below ``g`` was written with shards.

    Arrays with shape ``()`` and arrays whose dtype has named fields
    (structured / recarray data) are exempt from the check.

    Raises
    ------
    AssertionError
        If an eligible array has ``shards`` set to ``None``.
    """

    def visit(key: str, arr: zarr.Array | zarr.Group):
        # Check for recarray via https://numpy.org/doc/stable/user/basics.rec.html#manipulating-and-displaying-structured-datatypes
        if isinstance(arr, zarr.Array) and arr.shape != () and arr.dtype.names is None:
            # Name the offending array so a failure is actionable.
            assert arr.shards is not None, f"array {key!r} was not sharded"

    visititems_zarr(g, visitor=visit)

tests/test_concatenate_disk.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,13 @@
88
import pytest
99
from scipy import sparse
1010

11-
from anndata import AnnData, concat
11+
from anndata import AnnData, concat, settings
1212
from anndata._core import merge
1313
from anndata._core.merge import _resolve_axis
14+
from anndata.compat import is_zarr_v2
1415
from anndata.experimental.merge import as_group, concat_on_disk
1516
from anndata.io import read_elem, write_elem
16-
from anndata.tests.helpers import assert_equal, gen_adata
17+
from anndata.tests.helpers import assert_equal, check_all_sharded, gen_adata
1718
from anndata.utils import asarray
1819

1920
if TYPE_CHECKING:
@@ -230,7 +231,7 @@ def gen_index(n):
230231
X=sparse.csr_matrix((2, 100)),
231232
obs=pd.DataFrame(index=gen_index(2)),
232233
obsm={
233-
"sparse": np.arange(8).reshape(2, 4),
234+
"sparse": sparse.csr_matrix(np.arange(8).reshape(2, 4)),
234235
"dense": np.arange(4, 8).reshape(2, 2),
235236
"df": pd.DataFrame(
236237
{
@@ -253,6 +254,22 @@ def test_concatenate_xxxm(xxxm_adatas, tmp_path, file_format, join_type):
253254
assert_eq_concat_on_disk(xxxm_adatas, tmp_path, file_format, join=join_type)
254255

255256

257+
@pytest.mark.skipif(is_zarr_v2(), reason="auto sharding is allowed only for zarr v3.")
def test_concatenate_zarr_v3_shard(xxxm_adatas, tmp_path):
    """On-disk concatenation with auto sharding on yields a sharded v3 store."""
    import zarr

    # Both options are overridden together: sharding is only valid for format 3.
    with settings.override(auto_shard_zarr_v3=True, zarr_write_format=3):
        assert_eq_concat_on_disk(xxxm_adatas, tmp_path, file_format="zarr")
    g = zarr.open(tmp_path)
    assert g.metadata.zarr_format == 3

    # `check_all_sharded` walks the store and performs the per-array check,
    # so no local visitor is needed here.
    check_all_sharded(g)
271+
272+
256273
def test_output_dir_exists(tmp_path):
257274
in_pth = tmp_path / "in.h5ad"
258275
out_pth = tmp_path / "does_not_exist" / "out.h5ad"

tests/test_dask.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
as_dense_dask_array,
2424
as_sparse_dask_array,
2525
assert_equal,
26+
check_all_sharded,
2627
gen_adata,
2728
)
2829

@@ -109,12 +110,20 @@ def test_dask_write(adata, tmp_path, diskfmt):
109110

110111
@pytest.mark.xdist_group("dask")
111112
@pytest.mark.dask_distributed
113+
@pytest.mark.parametrize(
114+
"auto_shard_zarr_v3",
115+
[pytest.param(True, id="shard"), pytest.param(False, id="no-shard")],
116+
)
112117
def test_dask_distributed_write(
113118
adata: AnnData,
114119
tmp_path: Path,
115120
diskfmt: Literal["h5ad", "zarr"],
116121
local_cluster_addr: str,
122+
*,
123+
auto_shard_zarr_v3: bool,
117124
) -> None:
125+
if auto_shard_zarr_v3 and ad.settings.zarr_write_format == 2:
126+
pytest.skip(reason="Cannot shard v2 data")
118127
import dask.array as da
119128
import dask.distributed as dd
120129
import numpy as np
@@ -128,9 +137,12 @@ def test_dask_distributed_write(
128137
adata.obsm["b"] = da.random.random((M, 10))
129138
adata.varm["a"] = da.random.random((N, 10))
130139
orig = adata
131-
ad.io.write_elem(g, "", orig)
140+
with ad.settings.override(auto_shard_zarr_v3=auto_shard_zarr_v3):
141+
ad.io.write_elem(g, "", orig)
132142
# TODO: See https://github.com/zarr-developers/zarr-python/issues/2716
133143
g = as_group(pth, mode="r")
144+
if auto_shard_zarr_v3:
145+
check_all_sharded(g)
134146
curr = ad.io.read_elem(g)
135147

136148
with pytest.raises(AssertionError):

tests/test_io_dispatched.py

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,14 @@
1212
from anndata._io.zarr import open_write_group
1313
from anndata.compat import CSArray, CSMatrix, ZarrGroup, is_zarr_v2
1414
from anndata.experimental import read_dispatched, write_dispatched
15-
from anndata.tests.helpers import GEN_ADATA_NO_XARRAY_ARGS, assert_equal, gen_adata
15+
from anndata.tests.helpers import (
16+
GEN_ADATA_NO_XARRAY_ARGS,
17+
assert_equal,
18+
gen_adata,
19+
visititems_zarr,
20+
)
1621

1722
if TYPE_CHECKING:
18-
from collections.abc import Callable
1923
from pathlib import Path
2024
from typing import Literal
2125

@@ -180,18 +184,7 @@ def check_chunking(k: str, v: ZarrGroup | zarr.Array):
180184
if is_zarr_v2():
181185
z.visititems(check_chunking)
182186
else:
183-
184-
def visititems(
185-
z: ZarrGroup, visitor: Callable[[str, ZarrGroup | zarr.Array], None]
186-
) -> None:
187-
for key in z:
188-
maybe_group = z[key]
189-
if isinstance(maybe_group, ZarrGroup):
190-
visititems(maybe_group, visitor)
191-
else:
192-
visitor(key, maybe_group)
193-
194-
visititems(z, check_chunking)
187+
visititems_zarr(z, check_chunking)
195188

196189

197190
@pytest.mark.zarr_io

0 commit comments

Comments
 (0)