Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/zarr/core/buffer/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,8 +462,12 @@ def __len__(self) -> int:
def __repr__(self) -> str:
return f"<NDBuffer shape={self.shape} dtype={self.dtype} {self._data!r}>"

def all_equal(self, other: Any) -> bool:
return bool((self._data == other).all())
def all_equal(self, other: Any, equal_nan: bool = True) -> bool:
"""Compare to `other` using np.array_equal."""
# use array_equal to obtain equal_nan=True functionality
data, other = np.broadcast_arrays(self._data, other)
result = np.array_equal(self._data, other, equal_nan=equal_nan)
return result

def fill(self, value: Any) -> None:
self._data.fill(value)
Expand Down
19 changes: 19 additions & 0 deletions tests/v3/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,25 @@ def test_array_v3_fill_value(store: MemoryStore, fill_value: int, dtype_str: str
assert arr.fill_value.dtype == arr.dtype


@pytest.mark.parametrize("store", ["memory"], indirect=True)
async def test_array_v3_nan_fill_value(store: MemoryStore) -> None:
shape = (10,)
arr = Array.create(
store=store,
shape=shape,
dtype=np.float64,
zarr_format=3,
chunk_shape=shape,
fill_value=np.nan,
)
arr[:] = np.nan

assert np.isnan(arr.fill_value)
assert arr.fill_value.dtype == arr.dtype
# all fill value chunk is an empty chunk, and should not be written
assert not [a async for a in store.list_prefix("/")]


@pytest.mark.parametrize("store", ("local",), indirect=["store"])
@pytest.mark.parametrize("zarr_format", (2, 3))
async def test_serializable_async_array(
Expand Down
10 changes: 0 additions & 10 deletions tests/v3/test_properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,6 @@ def test_roundtrip(data: st.DataObject) -> None:


@given(data=st.data())
# The filter warning here is to silence an occasional warning in NDBuffer.all_equal
# See https://github.com/zarr-developers/zarr-python/pull/2118#issuecomment-2310280899
# Uncomment the next line to reproduce the original failure.
# @reproduce_failure('6.111.2', b'AXicY2FgZGRAB/8/ndR2z7nkDZEDADWpBL4=')
@pytest.mark.filterwarnings("ignore::RuntimeWarning")
def test_basic_indexing(data: st.DataObject) -> None:
zarray = data.draw(arrays())
nparray = zarray[:]
Expand All @@ -37,11 +32,6 @@ def test_basic_indexing(data: st.DataObject) -> None:


@given(data=st.data())
# The filter warning here is to silence an occasional warning in NDBuffer.all_equal
# See https://github.com/zarr-developers/zarr-python/pull/2118#issuecomment-2310280899
# Uncomment the next line to reproduce the original failure.
# @reproduce_failure('6.111.2', b'AXicY2FgZGRAB/8/eLmF7qr/C5EDADZUBRM=')
@pytest.mark.filterwarnings("ignore::RuntimeWarning")
def test_vindex(data: st.DataObject) -> None:
zarray = data.draw(arrays())
nparray = zarray[:]
Expand Down