Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/3169.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix race condition when passing array data in ``create_array(data=..)`` for an array that has a set shard size.
65 changes: 55 additions & 10 deletions src/zarr/core/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1174,6 +1174,18 @@ def cdata_shape(self) -> ChunkCoords:
"""
return tuple(starmap(ceildiv, zip(self.shape, self.chunks, strict=False)))

@property
def _shard_data_shape(self) -> ChunkCoords:
    """
    The shape of the shard grid for this array.

    If the array is not sharded, ``self.shards`` is None and this falls back
    to the chunk grid, making the result identical to ``cdata_shape``.

    Returns
    -------
    Tuple[int]
        The shape of the shard grid for this array.
    """
    return tuple(starmap(ceildiv, zip(self.shape, self.shards or self.chunks, strict=False)))

@property
def nchunks(self) -> int:
"""
Expand Down Expand Up @@ -1216,7 +1228,11 @@ async def nbytes_stored(self) -> int:
return await self.store_path.store.getsize_prefix(self.store_path.path)

def _iter_chunk_coords(
self, *, origin: Sequence[int] | None = None, selection_shape: Sequence[int] | None = None
self,
*,
iter_shards: bool = False,
origin: Sequence[int] | None = None,
selection_shape: Sequence[int] | None = None,
) -> Iterator[ChunkCoords]:
"""
Create an iterator over the coordinates of chunks in chunk grid space. If the `origin`
Expand All @@ -1228,6 +1244,8 @@ def _iter_chunk_coords(

Parameters
----------
iter_shards : bool, default=False
Whether to iterate by shard (if True) or by chunk (if False).
origin : Sequence[int] | None, default=None
The origin of the selection relative to the array's chunk grid.
selection_shape : Sequence[int] | None, default=None
Expand All @@ -1238,7 +1256,11 @@ def _iter_chunk_coords(
chunk_coords: ChunkCoords
The coordinates of each chunk in the selection.
"""
return _iter_grid(self.cdata_shape, origin=origin, selection_shape=selection_shape)
return _iter_grid(
self._shard_data_shape if iter_shards else self.cdata_shape,
origin=origin,
selection_shape=selection_shape,
)

def _iter_chunk_keys(
self, *, origin: Sequence[int] | None = None, selection_shape: Sequence[int] | None = None
Expand All @@ -1265,13 +1287,19 @@ def _iter_chunk_keys(
yield self.metadata.encode_chunk_key(k)

def _iter_chunk_regions(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like _iter_chunk_regions should only iterate over the regions spanned by each chunk. Otherwise the name doesn't fit. So adding a flag to this function that makes it do something different (iterate over the regions spanned by each shard) seems worse than implementing a new _iter_shard_regions method, that does exactly what its name suggests.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

my general POV is that several well-defined functions is better than a smaller number of functions that try to do a lot. since this is private API, adding functions is cheap, so lets create new functions instead of adding functionality to existing ones in this case

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm.. "private API" might be the magic word there 😂 There are no other users of _iter_chunk_regions (ignoring the a few unit tests); so- may I directly rename the function to _iter_shard_regions in this case? 😁

self, *, origin: Sequence[int] | None = None, selection_shape: Sequence[int] | None = None
self,
*,
iter_shards: bool = False,
origin: Sequence[int] | None = None,
selection_shape: Sequence[int] | None = None,
) -> Iterator[tuple[slice, ...]]:
"""
Iterate over the regions spanned by each chunk.

Parameters
----------
iter_shards : bool, default=False
Whether to iterate by shard (if True) or by chunk (if False).
origin : Sequence[int] | None, default=None
The origin of the selection relative to the array's chunk grid.
selection_shape : Sequence[int] | None, default=None
Expand All @@ -1282,11 +1310,12 @@ def _iter_chunk_regions(
region: tuple[slice, ...]
A tuple of slice objects representing the region spanned by each chunk in the selection.
"""
region_size = (self.shards or self.chunks) if iter_shards else self.chunks
for cgrid_position in self._iter_chunk_coords(
origin=origin, selection_shape=selection_shape
iter_shards=iter_shards, origin=origin, selection_shape=selection_shape
):
out: tuple[slice, ...] = ()
for c_pos, c_shape in zip(cgrid_position, self.chunks, strict=False):
for c_pos, c_shape in zip(cgrid_position, region_size, strict=False):
start = c_pos * c_shape
stop = start + c_shape
out += (slice(start, stop, 1),)
Expand Down Expand Up @@ -2184,6 +2213,13 @@ def cdata_shape(self) -> ChunkCoords:
"""
return tuple(starmap(ceildiv, zip(self.shape, self.chunks, strict=False)))

@property
def _shard_data_shape(self) -> ChunkCoords:
    """
    The shape of the shard grid for this array.

    Falls back to the chunk grid when the array has no shards configured.
    """
    grid_unit = self.shards or self.chunks
    return tuple(
        ceildiv(dim, unit) for dim, unit in zip(self.shape, grid_unit, strict=False)
    )

@property
def nchunks(self) -> int:
"""
Expand Down Expand Up @@ -2271,7 +2307,10 @@ def nbytes_stored(self) -> int:
return sync(self._async_array.nbytes_stored())

def _iter_chunk_keys(
self, origin: Sequence[int] | None = None, selection_shape: Sequence[int] | None = None
self,
*,
origin: Sequence[int] | None = None,
selection_shape: Sequence[int] | None = None,
) -> Iterator[str]:
"""
Iterate over the storage keys of each chunk, relative to an optional origin, and optionally
Expand All @@ -2294,13 +2333,19 @@ def _iter_chunk_keys(
)

def _iter_chunk_regions(
self, origin: Sequence[int] | None = None, selection_shape: Sequence[int] | None = None
self,
*,
iter_shards: bool = False,
origin: Sequence[int] | None = None,
selection_shape: Sequence[int] | None = None,
) -> Iterator[tuple[slice, ...]]:
"""
Iterate over the regions spanned by each chunk.

Parameters
----------
iter_shards : bool, default=False
Whether to iterate by shard (if True) or by chunk (if False).
origin : Sequence[int] | None, default=None
The origin of the selection relative to the array's chunk grid.
selection_shape : Sequence[int] | None, default=None
Expand All @@ -2312,7 +2357,7 @@ def _iter_chunk_regions(
A tuple of slice objects representing the region spanned by each chunk in the selection.
"""
yield from self._async_array._iter_chunk_regions(
origin=origin, selection_shape=selection_shape
iter_shards=iter_shards, origin=origin, selection_shape=selection_shape
)

def __array__(
Expand Down Expand Up @@ -4100,7 +4145,7 @@ async def _copy_array_region(chunk_coords: ChunkCoords | slice, _data: Array) ->

# Stream data from the source array to the new array
await concurrent_map(
[(region, data) for region in result._iter_chunk_regions()],
[(region, data) for region in result._iter_chunk_regions(iter_shards=True)],
_copy_array_region,
zarr.core.config.config.get("async.concurrency"),
)
Expand All @@ -4111,7 +4156,7 @@ async def _copy_arraylike_region(chunk_coords: slice, _data: NDArrayLike) -> Non

# Stream data from the source array to the new array
await concurrent_map(
[(region, data) for region in result._iter_chunk_regions()],
[(region, data) for region in result._iter_chunk_regions(iter_shards=True)],
_copy_arraylike_region,
zarr.core.config.config.get("async.concurrency"),
)
Expand Down
24 changes: 24 additions & 0 deletions tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1644,6 +1644,30 @@ async def test_from_array(
assert result.chunks == new_chunks


@pytest.mark.parametrize("store", ["memory"], indirect=True)
@pytest.mark.parametrize("chunks", [(10, 10)])
@pytest.mark.parametrize("shards", [(60, 60)])
async def test_from_array_shards(
    store: Store,
    zarr_format: ZarrFormat,
    chunks: tuple[int, ...],
    shards: tuple[int, ...],
) -> None:
    # Regression test for https://github.com/zarr-developers/zarr-python/issues/3169
    expected = np.arange(3600).reshape((60, 60))

    zarr.create_array(
        store=store,
        data=expected,
        chunks=chunks,
        shards=shards,
    )

    # Re-open from the store so we read back the persisted bytes rather than
    # any in-memory state of the array we just created.
    round_tripped = zarr.open_array(store=store)
    assert np.array_equal(round_tripped[:], expected)


@pytest.mark.parametrize("store", ["local"], indirect=True)
@pytest.mark.parametrize("chunks", ["keep", "auto"])
@pytest.mark.parametrize("write_data", [True, False])
Expand Down