-
-
Notifications
You must be signed in to change notification settings - Fork 366
Coalesce and parallelize partial shard reads #3004
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
1359aec
c726994
44d9ce4
009ce6a
c65cf82
c7ddb0e
501e7a5
12c3308
6322ca6
d9a7842
8469e9c
baf1062
50d8822
78313aa
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
Optimizes reading more than one, but not all, chunks from a shard. Chunks are now read in parallel | ||
and reads of nearby chunks within the same shard are combined to reduce the number of calls to the store. | ||
See :ref:`user-guide-config` for more details. |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -38,11 +38,13 @@ | |
from zarr.core.common import ( | ||
ChunkCoords, | ||
ChunkCoordsLike, | ||
concurrent_map, | ||
parse_enum, | ||
parse_named_configuration, | ||
parse_shapelike, | ||
product, | ||
) | ||
from zarr.core.config import config | ||
from zarr.core.dtype.npy.int import UInt64 | ||
from zarr.core.indexing import ( | ||
BasicIndexer, | ||
|
@@ -198,7 +200,9 @@ async def from_bytes( | |
|
||
@classmethod | ||
def create_empty( | ||
cls, chunks_per_shard: ChunkCoords, buffer_prototype: BufferPrototype | None = None | ||
cls, | ||
chunks_per_shard: ChunkCoords, | ||
buffer_prototype: BufferPrototype | None = None, | ||
) -> _ShardReader: | ||
if buffer_prototype is None: | ||
buffer_prototype = default_buffer_prototype() | ||
|
@@ -248,7 +252,9 @@ def merge_with_morton_order( | |
|
||
@classmethod | ||
def create_empty( | ||
cls, chunks_per_shard: ChunkCoords, buffer_prototype: BufferPrototype | None = None | ||
cls, | ||
chunks_per_shard: ChunkCoords, | ||
buffer_prototype: BufferPrototype | None = None, | ||
) -> _ShardBuilder: | ||
if buffer_prototype is None: | ||
buffer_prototype = default_buffer_prototype() | ||
|
@@ -329,9 +335,18 @@ async def finalize( | |
return await shard_builder.finalize(index_location, index_encoder) | ||
|
||
|
||
class _ChunkCoordsByteSlice(NamedTuple): | ||
"""Holds a chunk's coordinates and its byte range in a serialized shard.""" | ||
|
||
coords: ChunkCoords | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is |
||
byte_slice: slice | ||
|
||
|
||
@dataclass(frozen=True) | ||
class ShardingCodec( | ||
ArrayBytesCodec, ArrayBytesCodecPartialDecodeMixin, ArrayBytesCodecPartialEncodeMixin | ||
ArrayBytesCodec, | ||
ArrayBytesCodecPartialDecodeMixin, | ||
ArrayBytesCodecPartialEncodeMixin, | ||
): | ||
"""Sharding codec""" | ||
|
||
|
@@ -510,32 +525,21 @@ async def _decode_partial_single( | |
all_chunk_coords = {chunk_coords for chunk_coords, *_ in indexed_chunks} | ||
|
||
# reading bytes of all requested chunks | ||
shard_dict: ShardMapping = {} | ||
shard_dict_maybe: ShardMapping | None = {} | ||
if self._is_total_shard(all_chunk_coords, chunks_per_shard): | ||
# read entire shard | ||
shard_dict_maybe = await self._load_full_shard_maybe( | ||
byte_getter=byte_getter, | ||
prototype=chunk_spec.prototype, | ||
chunks_per_shard=chunks_per_shard, | ||
byte_getter, chunk_spec.prototype, chunks_per_shard | ||
) | ||
if shard_dict_maybe is None: | ||
return None | ||
shard_dict = shard_dict_maybe | ||
else: | ||
# read some chunks within the shard | ||
shard_index = await self._load_shard_index_maybe(byte_getter, chunks_per_shard) | ||
if shard_index is None: | ||
return None | ||
shard_dict = {} | ||
for chunk_coords in all_chunk_coords: | ||
chunk_byte_slice = shard_index.get_chunk_slice(chunk_coords) | ||
if chunk_byte_slice: | ||
chunk_bytes = await byte_getter.get( | ||
prototype=chunk_spec.prototype, | ||
byte_range=RangeByteRequest(chunk_byte_slice[0], chunk_byte_slice[1]), | ||
) | ||
if chunk_bytes: | ||
shard_dict[chunk_coords] = chunk_bytes | ||
shard_dict_maybe = await self._load_partial_shard_maybe( | ||
byte_getter, chunk_spec.prototype, chunks_per_shard, all_chunk_coords | ||
) | ||
|
||
if shard_dict_maybe is None: | ||
return None | ||
shard_dict = shard_dict_maybe | ||
|
||
# decoding chunks and writing them into the output buffer | ||
await self.codec_pipeline.read( | ||
|
@@ -617,7 +621,9 @@ async def _encode_partial_single( | |
|
||
indexer = list( | ||
get_indexer( | ||
selection, shape=shard_shape, chunk_grid=RegularChunkGrid(chunk_shape=chunk_shape) | ||
selection, | ||
shape=shard_shape, | ||
chunk_grid=RegularChunkGrid(chunk_shape=chunk_shape), | ||
) | ||
) | ||
|
||
|
@@ -691,7 +697,8 @@ def _shard_index_size(self, chunks_per_shard: ChunkCoords) -> int: | |
get_pipeline_class() | ||
.from_codecs(self.index_codecs) | ||
.compute_encoded_size( | ||
16 * product(chunks_per_shard), self._get_index_chunk_spec(chunks_per_shard) | ||
16 * product(chunks_per_shard), | ||
self._get_index_chunk_spec(chunks_per_shard), | ||
) | ||
) | ||
|
||
|
@@ -736,7 +743,8 @@ async def _load_shard_index_maybe( | |
) | ||
else: | ||
index_bytes = await byte_getter.get( | ||
prototype=numpy_buffer_prototype(), byte_range=SuffixByteRequest(shard_index_size) | ||
prototype=numpy_buffer_prototype(), | ||
byte_range=SuffixByteRequest(shard_index_size), | ||
) | ||
if index_bytes is not None: | ||
return await self._decode_shard_index(index_bytes, chunks_per_shard) | ||
|
@@ -750,7 +758,10 @@ async def _load_shard_index( | |
) or _ShardIndex.create_empty(chunks_per_shard) | ||
|
||
async def _load_full_shard_maybe( | ||
self, byte_getter: ByteGetter, prototype: BufferPrototype, chunks_per_shard: ChunkCoords | ||
self, | ||
byte_getter: ByteGetter, | ||
prototype: BufferPrototype, | ||
chunks_per_shard: ChunkCoords, | ||
) -> _ShardReader | None: | ||
shard_bytes = await byte_getter.get(prototype=prototype) | ||
|
||
|
@@ -760,6 +771,115 @@ async def _load_full_shard_maybe( | |
else None | ||
) | ||
|
||
async def _load_partial_shard_maybe( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What would happen if we got rid of |
||
self, | ||
byte_getter: ByteGetter, | ||
prototype: BufferPrototype, | ||
chunks_per_shard: ChunkCoords, | ||
all_chunk_coords: set[ChunkCoords], | ||
) -> ShardMapping | None: | ||
""" | ||
Read chunks from `byte_getter` for the case where the read is less than a full shard. | ||
Returns a mapping of chunk coordinates to bytes or None. | ||
""" | ||
shard_index = await self._load_shard_index_maybe(byte_getter, chunks_per_shard) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this will fetch the shard index every time. Should we push reading the shard index higher up in the stack, and have this function take the content of the index as a parameter? This might be out of scope for this PR. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (we could also build caching into the |
||
if shard_index is None: | ||
return None | ||
Comment on lines
+786
to
+787
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what does this mean? when would the shard index be There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i don't think your PR is adding these semantics btw, I'm just curious what this code path means |
||
|
||
chunks = [ | ||
_ChunkCoordsByteSlice(chunk_coords, slice(*chunk_byte_slice)) | ||
for chunk_coords in all_chunk_coords | ||
# Drop chunks where index lookup fails | ||
# e.g. empty chunks when write_empty_chunks = False | ||
if (chunk_byte_slice := shard_index.get_chunk_slice(chunk_coords)) | ||
] | ||
|
||
groups = self._coalesce_chunks(chunks) | ||
|
||
shard_dicts = await concurrent_map( | ||
[(group, byte_getter, prototype) for group in groups], | ||
self._get_group_bytes, | ||
config.get("async.concurrency"), | ||
) | ||
|
||
shard_dict: ShardMutableMapping = {} | ||
for d in shard_dicts: | ||
if d is None: | ||
return None | ||
Comment on lines
+806
to
+808
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why do we return |
||
shard_dict.update(d) | ||
|
||
return shard_dict | ||
|
||
def _coalesce_chunks( | ||
self, | ||
chunks: list[_ChunkCoordsByteSlice], | ||
) -> list[list[_ChunkCoordsByteSlice]]: | ||
""" | ||
Combine chunks from a single shard into groups that should be read together | ||
in a single request to the store. | ||
|
||
Respects the following configuration options: | ||
- `sharding.read.coalesce_max_gap_bytes`: The maximum gap between | ||
chunks to coalesce into a single group. | ||
- `sharding.read.coalesce_max_bytes`: The maximum number of bytes in a group. | ||
""" | ||
max_gap_bytes = config.get("sharding.read.coalesce_max_gap_bytes") | ||
coalesce_max_bytes = config.get("sharding.read.coalesce_max_bytes") | ||
Comment on lines
+826
to
+827
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. lets not fetch parameters from a global mutable variable (the config) inside this function. better to define these as parameters of the function, and pass them in. we should get values out of the config as early as possible in this process, and after that it's just regular function parameters There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we could make these keyword-only so that there's flexibility to change the configuration in the future |
||
|
||
sorted_chunks = sorted(chunks, key=lambda c: c.byte_slice.start) | ||
|
||
if len(sorted_chunks) == 0: | ||
return [] | ||
|
||
groups = [] | ||
current_group = [sorted_chunks[0]] | ||
|
||
for chunk in sorted_chunks[1:]: | ||
gap_to_chunk = chunk.byte_slice.start - current_group[-1].byte_slice.stop | ||
size_if_coalesced = chunk.byte_slice.stop - current_group[0].byte_slice.start | ||
if gap_to_chunk < max_gap_bytes and size_if_coalesced < coalesce_max_bytes: | ||
current_group.append(chunk) | ||
else: | ||
groups.append(current_group) | ||
current_group = [chunk] | ||
|
||
groups.append(current_group) | ||
|
||
return groups | ||
|
||
async def _get_group_bytes( | ||
self, | ||
group: list[_ChunkCoordsByteSlice], | ||
byte_getter: ByteGetter, | ||
prototype: BufferPrototype, | ||
) -> ShardMapping | None: | ||
""" | ||
Reads a possibly coalesced group of one or more chunks from a shard. | ||
Returns a mapping of chunk coordinates to bytes. | ||
""" | ||
# _coalesce_chunks ensures that the group is not empty. | ||
group_start = group[0].byte_slice.start | ||
group_end = group[-1].byte_slice.stop | ||
|
||
# A single call to retrieve the bytes for the entire group. | ||
group_bytes = await byte_getter.get( | ||
prototype=prototype, | ||
byte_range=RangeByteRequest(group_start, group_end), | ||
) | ||
if group_bytes is None: | ||
return None | ||
|
||
# Extract the bytes corresponding to each chunk in group from group_bytes. | ||
shard_dict = {} | ||
for chunk in group: | ||
chunk_slice = slice( | ||
chunk.byte_slice.start - group_start, | ||
chunk.byte_slice.stop - group_start, | ||
) | ||
shard_dict[chunk.coords] = group_bytes[chunk_slice] | ||
|
||
return shard_dict | ||
|
||
def compute_encoded_size(self, input_byte_length: int, shard_spec: ArraySpec) -> int: | ||
chunks_per_shard = self._get_chunks_per_shard(shard_spec) | ||
return input_byte_length + self._shard_index_size(chunks_per_shard) | ||
|
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm leaving these formatting changes in since they were produced by the pre-commit run