diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 06a3c2cb22d..78ef2875b31 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -13,6 +13,8 @@ v2025.07.2 (unreleased) New Features ~~~~~~~~~~~~ +- :py:meth:`DataTree.to_netcdf` can now write to a file-like object, or return bytes if called without a filepath. (:issue:`10570`) + By `Matthew Willson `_. Breaking changes ~~~~~~~~~~~~~~~~ diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 328b7568cdd..2a6476ea828 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -11,7 +11,7 @@ Sequence, ) from functools import partial -from io import BytesIO +from io import IOBase from itertools import starmap from numbers import Number from typing import ( @@ -31,6 +31,8 @@ from xarray.backends.common import ( AbstractDataStore, ArrayWriter, + BytesIOProxy, + T_PathFileOrDataStore, _find_absolute_paths, _normalize_path, ) @@ -503,7 +505,7 @@ def _datatree_from_backend_datatree( def open_dataset( - filename_or_obj: str | os.PathLike[Any] | ReadBuffer | AbstractDataStore, + filename_or_obj: T_PathFileOrDataStore, *, engine: T_Engine = None, chunks: T_Chunks = None, @@ -533,12 +535,13 @@ def open_dataset( Parameters ---------- - filename_or_obj : str, Path, file-like or DataStore + filename_or_obj : str, Path, file-like, bytes, memoryview or DataStore Strings and Path objects are interpreted as a path to a netCDF file or an OpenDAP URL and opened with python-netCDF4, unless the filename ends with .gz, in which case the file is gunzipped and opened with - scipy.io.netcdf (only netCDF3 supported). Byte-strings or file-like - objects are opened by scipy.io.netcdf (netCDF3) or h5py (netCDF4/HDF). + scipy.io.netcdf (only netCDF3 supported). Bytes, memoryview and + file-like objects are opened by scipy.io.netcdf (netCDF3) or h5netcdf + (netCDF4). engine : {"netcdf4", "scipy", "pydap", "h5netcdf", "zarr", None}\ , installed backend \ or subclass of xarray.backends.BackendEntrypoint, optional @@ -743,7 +746,7 @@ def open_dataset( def open_dataarray( - filename_or_obj: str | os.PathLike[Any] | ReadBuffer | AbstractDataStore, + filename_or_obj: T_PathFileOrDataStore, *, engine: T_Engine = None, chunks: T_Chunks = None, @@ -774,12 +777,13 @@ def open_dataarray( Parameters ---------- - filename_or_obj : str, Path, file-like or DataStore + filename_or_obj : str, Path, file-like, bytes, memoryview or DataStore Strings and Path objects are interpreted as a path to a netCDF file or an OpenDAP URL and opened with python-netCDF4, unless the filename ends with .gz, in which case the file is gunzipped and opened with - scipy.io.netcdf (only netCDF3 supported). Byte-strings or file-like - objects are opened by scipy.io.netcdf (netCDF3) or h5py (netCDF4/HDF). + scipy.io.netcdf (only netCDF3 supported). Bytes, memoryview and + file-like objects are opened by scipy.io.netcdf (netCDF3) or h5netcdf + (netCDF4). engine : {"netcdf4", "scipy", "pydap", "h5netcdf", "zarr", None}\ , installed backend \ or subclass of xarray.backends.BackendEntrypoint, optional @@ -970,7 +974,7 @@ def open_dataarray( def open_datatree( - filename_or_obj: str | os.PathLike[Any] | ReadBuffer | AbstractDataStore, + filename_or_obj: T_PathFileOrDataStore, *, engine: T_Engine = None, chunks: T_Chunks = None, @@ -1001,8 +1005,10 @@ def open_datatree( Parameters ---------- - filename_or_obj : str, Path, file-like, or DataStore - Strings and Path objects are interpreted as a path to a netCDF file or Zarr store. 
+ filename_or_obj : str, Path, file-like, bytes or DataStore + Strings and Path objects are interpreted as a path to a netCDF file or + Zarr store. Bytes and memoryview objects are interpreted as file + contents. engine : {"netcdf4", "h5netcdf", "zarr", None}, \ installed backend or xarray.backends.BackendEntrypoint, optional Engine to use when reading files. If not provided, the default engine @@ -1208,7 +1214,7 @@ def open_datatree( def open_groups( - filename_or_obj: str | os.PathLike[Any] | ReadBuffer | AbstractDataStore, + filename_or_obj: T_PathFileOrDataStore, *, engine: T_Engine = None, chunks: T_Chunks = None, @@ -1243,8 +1249,10 @@ def open_groups( Parameters ---------- - filename_or_obj : str, Path, file-like, or DataStore - Strings and Path objects are interpreted as a path to a netCDF file or Zarr store. + filename_or_obj : str, Path, file-like, bytes, memoryview or DataStore + Strings and Path objects are interpreted as a path to a netCDF file or + Zarr store. Bytes and memoryview objects are interpreted as file + contents. engine : {"netcdf4", "h5netcdf", "zarr", None}, \ installed backend or xarray.backends.BackendEntrypoint, optional Engine to use when reading files. If not provided, the default engine @@ -1780,7 +1788,7 @@ def to_netcdf( ) -> tuple[ArrayWriter, AbstractDataStore]: ... -# path=None writes to bytes +# path=None writes to bytes or memoryview, depending on store @overload def to_netcdf( dataset: Dataset, @@ -1795,7 +1803,7 @@ def to_netcdf( multifile: Literal[False] = False, invalid_netcdf: bool = False, auto_complex: bool | None = None, -) -> bytes: ... +) -> bytes | memoryview: ... # compute=False returns dask.Delayed @@ -1821,7 +1829,7 @@ def to_netcdf( @overload def to_netcdf( dataset: Dataset, - path_or_file: str | os.PathLike, + path_or_file: str | os.PathLike | IOBase, mode: NetcdfWriteModes = "w", format: T_NetcdfTypes | None = None, group: str | None = None, @@ -1877,7 +1885,7 @@ def to_netcdf( @overload def to_netcdf( dataset: Dataset, - path_or_file: str | os.PathLike | None, + path_or_file: str | os.PathLike | IOBase | None, mode: NetcdfWriteModes = "w", format: T_NetcdfTypes | None = None, group: str | None = None, @@ -1888,12 +1896,12 @@ def to_netcdf( multifile: bool = False, invalid_netcdf: bool = False, auto_complex: bool | None = None, -) -> tuple[ArrayWriter, AbstractDataStore] | bytes | Delayed | None: ... +) -> tuple[ArrayWriter, AbstractDataStore] | bytes | memoryview | Delayed | None: ... 
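A minimal sketch of what the in-memory write paths added to ``to_netcdf`` above allow, assuming this patch is applied (the tiny example Dataset is made up and not part of the patch):

    import io
    import xarray as xr

    ds = xr.Dataset({"foo": ("x", [1, 2, 3])})  # hypothetical sample data

    # Explicit in-memory target: both backends that can write to file-like
    # objects are accepted ("scipy" for netCDF3, "h5netcdf" for netCDF4).
    buf = io.BytesIO()
    ds.to_netcdf(buf, engine="h5netcdf")

    # No target: the engine still defaults to "scipy", which returns ``bytes``
    # (and, per this patch, emits a FutureWarning pointing at the explicit
    # BytesIO pattern above); engine="h5netcdf" returns a ``memoryview`` instead.
    payload = ds.to_netcdf()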
def to_netcdf( dataset: Dataset, - path_or_file: str | os.PathLike | None = None, + path_or_file: str | os.PathLike | IOBase | None = None, mode: NetcdfWriteModes = "w", format: T_NetcdfTypes | None = None, group: str | None = None, @@ -1904,7 +1912,7 @@ def to_netcdf( multifile: bool = False, invalid_netcdf: bool = False, auto_complex: bool | None = None, -) -> tuple[ArrayWriter, AbstractDataStore] | bytes | Delayed | None: +) -> tuple[ArrayWriter, AbstractDataStore] | bytes | memoryview | Delayed | None: """This function creates an appropriate datastore for writing a dataset to disk as a netCDF file @@ -1918,26 +1926,27 @@ def to_netcdf( if encoding is None: encoding = {} - if path_or_file is None: + if isinstance(path_or_file, str): + if engine is None: + engine = _get_default_engine(path_or_file) + path_or_file = _normalize_path(path_or_file) + else: + # writing to bytes/memoryview or a file-like object if engine is None: + # TODO: only use 'scipy' if format is None or a netCDF3 format engine = "scipy" - elif engine != "scipy": + elif engine not in ("scipy", "h5netcdf"): raise ValueError( - "invalid engine for creating bytes with " - f"to_netcdf: {engine!r}. Only the default engine " - "or engine='scipy' is supported" + "invalid engine for creating bytes/memoryview or writing to a " + f"file-like object with to_netcdf: {engine!r}. Only " + "engine=None, engine='scipy' and engine='h5netcdf' is " + "supported." ) if not compute: raise NotImplementedError( "to_netcdf() with compute=False is not yet implemented when " "returning bytes" ) - elif isinstance(path_or_file, str): - if engine is None: - engine = _get_default_engine(path_or_file) - path_or_file = _normalize_path(path_or_file) - else: # file-like object - engine = "scipy" # validate Dataset keys, DataArray names, and attr keys/values _validate_dataset_names(dataset) @@ -1962,7 +1971,11 @@ def to_netcdf( f"is not currently supported with dask's {scheduler} scheduler" ) - target = path_or_file if path_or_file is not None else BytesIO() + if path_or_file is None: + target = BytesIOProxy() + else: + target = path_or_file # type: ignore[assignment] + kwargs = dict(autoclose=True) if autoclose else {} if invalid_netcdf: if engine == "h5netcdf": @@ -2002,17 +2015,19 @@ def to_netcdf( writes = writer.sync(compute=compute) - if isinstance(target, BytesIO): - store.sync() - return target.getvalue() finally: if not multifile and compute: # type: ignore[redundant-expr] store.close() + if path_or_file is None: + assert isinstance(target, BytesIOProxy) # created in this function + return target.getvalue_or_getbuffer() + if not compute: import dask return dask.delayed(_finalize_store)(writes, store) + return None diff --git a/xarray/backends/common.py b/xarray/backends/common.py index 8b56c8a2bf9..542ca4c897b 100644 --- a/xarray/backends/common.py +++ b/xarray/backends/common.py @@ -4,9 +4,18 @@ import os import time import traceback -from collections.abc import Hashable, Iterable, Mapping, Sequence +from collections.abc import Callable, Hashable, Iterable, Mapping, Sequence +from dataclasses import dataclass from glob import glob -from typing import TYPE_CHECKING, Any, ClassVar, TypeVar, Union, overload +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Generic, + TypeVar, + Union, + overload, +) import numpy as np import pandas as pd @@ -188,6 +197,24 @@ def _normalize_path_list( return _normalize_path_list(paths) +BytesOrMemory = TypeVar("BytesOrMemory", bytes, memoryview) + + +@dataclass +class BytesIOProxy(Generic[BytesOrMemory]): 
+ """Proxy object for a write that returns either bytes or a memoryview.""" + + # TODO: remove this in favor of BytesIO when Dataset.to_netcdf() stops + # returning bytes from the scipy engine + getvalue: Callable[[], BytesOrMemory] | None = None + + def getvalue_or_getbuffer(self) -> BytesOrMemory: + """Get the value of this write as bytes or memory.""" + if self.getvalue is None: + raise ValueError("must set getvalue before fetching value") + return self.getvalue() + + def _open_remote_file(file, mode, storage_options=None): import fsspec @@ -324,6 +351,11 @@ def __exit__(self, exception_type, exception_value, traceback): self.close() +T_PathFileOrDataStore = ( + str | os.PathLike[Any] | ReadBuffer | bytes | memoryview | AbstractDataStore +) + + class ArrayWriter: __slots__ = ("lock", "regions", "sources", "targets") @@ -705,7 +737,12 @@ def __repr__(self) -> str: def open_dataset( self, - filename_or_obj: str | os.PathLike[Any] | ReadBuffer | AbstractDataStore, + filename_or_obj: str + | os.PathLike[Any] + | ReadBuffer + | bytes + | memoryview + | AbstractDataStore, *, drop_variables: str | Iterable[str] | None = None, ) -> Dataset: @@ -717,7 +754,12 @@ def open_dataset( def guess_can_open( self, - filename_or_obj: str | os.PathLike[Any] | ReadBuffer | AbstractDataStore, + filename_or_obj: str + | os.PathLike[Any] + | ReadBuffer + | bytes + | memoryview + | AbstractDataStore, ) -> bool: """ Backend open_dataset method used by Xarray in :py:func:`~xarray.open_dataset`. @@ -727,7 +769,12 @@ def guess_can_open( def open_datatree( self, - filename_or_obj: str | os.PathLike[Any] | ReadBuffer | AbstractDataStore, + filename_or_obj: str + | os.PathLike[Any] + | ReadBuffer + | bytes + | memoryview + | AbstractDataStore, *, drop_variables: str | Iterable[str] | None = None, ) -> DataTree: @@ -739,7 +786,12 @@ def open_datatree( def open_groups_as_dict( self, - filename_or_obj: str | os.PathLike[Any] | ReadBuffer | AbstractDataStore, + filename_or_obj: str + | os.PathLike[Any] + | ReadBuffer + | bytes + | memoryview + | AbstractDataStore, *, drop_variables: str | Iterable[str] | None = None, ) -> dict[str, Dataset]: diff --git a/xarray/backends/file_manager.py b/xarray/backends/file_manager.py index 77c6859650f..2a6f3691faf 100644 --- a/xarray/backends/file_manager.py +++ b/xarray/backends/file_manager.py @@ -339,8 +339,11 @@ def __hash__(self): class DummyFileManager(FileManager): """FileManager that simply wraps an open file in the FileManager interface.""" - def __init__(self, value): + def __init__(self, value, *, close=None): + if close is None: + close = value.close self._value = value + self._close = close def acquire(self, needs_lock=True): del needs_lock # ignored @@ -353,4 +356,4 @@ def acquire_context(self, needs_lock=True): def close(self, needs_lock=True): del needs_lock # ignored - self._value.close() + self._close() diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py index f3e434c6e5e..24a3324bf62 100644 --- a/xarray/backends/h5netcdf_.py +++ b/xarray/backends/h5netcdf_.py @@ -11,13 +11,19 @@ from xarray.backends.common import ( BACKEND_ENTRYPOINTS, BackendEntrypoint, + BytesIOProxy, + T_PathFileOrDataStore, WritableCFDataStore, _normalize_path, _open_remote_file, datatree_from_dict_with_io_cleanup, find_root_and_group, ) -from xarray.backends.file_manager import CachingFileManager, DummyFileManager +from xarray.backends.file_manager import ( + CachingFileManager, + DummyFileManager, + FileManager, +) from xarray.backends.locks import HDF5_LOCK, combine_locks, 
ensure_lock, get_write_lock from xarray.backends.netCDF4_ import ( BaseNetCDF4Array, @@ -40,6 +46,8 @@ from xarray.core.variable import Variable if TYPE_CHECKING: + import h5netcdf + from xarray.backends.common import AbstractDataStore from xarray.core.dataset import Dataset from xarray.core.datatree import DataTree @@ -109,7 +117,14 @@ class H5NetCDFStore(WritableCFDataStore): "lock", ) - def __init__(self, manager, group=None, mode=None, lock=HDF5_LOCK, autoclose=False): + def __init__( + self, + manager: FileManager | h5netcdf.File | h5netcdf.Group, + group=None, + mode=None, + lock=HDF5_LOCK, + autoclose=False, + ): import h5netcdf if isinstance(manager, h5netcdf.File | h5netcdf.Group): @@ -158,12 +173,12 @@ def open( filename, mode=mode_, storage_options=storage_options ) - if isinstance(filename, bytes): - raise ValueError( - "can't open netCDF4/HDF5 as bytes " - "try passing a path or file-like object" - ) - elif isinstance(filename, io.IOBase): + if isinstance(filename, BytesIOProxy): + source = filename + filename = io.BytesIO() + source.getvalue = filename.getbuffer + + if isinstance(filename, io.IOBase) and mode == "r": magic_number = read_magic_number_from_file(filename) if not magic_number.startswith(b"\211HDF\r\n\032\n"): raise ValueError( @@ -189,7 +204,11 @@ def open( else: lock = combine_locks([HDF5_LOCK, get_write_lock(filename)]) - manager = CachingFileManager(h5netcdf.File, filename, mode=mode, kwargs=kwargs) + manager = ( + CachingFileManager(h5netcdf.File, filename, mode=mode, kwargs=kwargs) + if isinstance(filename, str) + else h5netcdf.File(filename, mode=mode, **kwargs) + ) return cls(manager, group=group, mode=mode, lock=lock, autoclose=autoclose) def _acquire(self, needs_lock=True): @@ -388,6 +407,15 @@ def _emit_phony_dims_warning(): ) +def _normalize_filename_or_obj( + filename_or_obj: T_PathFileOrDataStore, +) -> str | ReadBuffer | AbstractDataStore: + if isinstance(filename_or_obj, bytes | memoryview): + return io.BytesIO(filename_or_obj) + else: + return _normalize_path(filename_or_obj) + + class H5netcdfBackendEntrypoint(BackendEntrypoint): """ Backend for netCDF files based on the h5netcdf package. 
@@ -415,10 +443,8 @@ class H5netcdfBackendEntrypoint(BackendEntrypoint): ) url = "https://docs.xarray.dev/en/stable/generated/xarray.backends.H5netcdfBackendEntrypoint.html" - def guess_can_open( - self, - filename_or_obj: str | os.PathLike[Any] | ReadBuffer | AbstractDataStore, - ) -> bool: + def guess_can_open(self, filename_or_obj: T_PathFileOrDataStore) -> bool: + filename_or_obj = _normalize_filename_or_obj(filename_or_obj) magic_number = try_read_magic_number_from_file_or_path(filename_or_obj) if magic_number is not None: return magic_number.startswith(b"\211HDF\r\n\032\n") @@ -431,7 +457,7 @@ def guess_can_open( def open_dataset( self, - filename_or_obj: str | os.PathLike[Any] | ReadBuffer | AbstractDataStore, + filename_or_obj: T_PathFileOrDataStore, *, mask_and_scale=True, decode_times=True, @@ -454,7 +480,7 @@ def open_dataset( # remove and set phony_dims="access" above emit_phony_dims_warning, phony_dims = _check_phony_dims(phony_dims) - filename_or_obj = _normalize_path(filename_or_obj) + filename_or_obj = _normalize_filename_or_obj(filename_or_obj) store = H5NetCDFStore.open( filename_or_obj, format=format, @@ -491,7 +517,7 @@ def open_dataset( def open_datatree( self, - filename_or_obj: str | os.PathLike[Any] | ReadBuffer | AbstractDataStore, + filename_or_obj: T_PathFileOrDataStore, *, mask_and_scale=True, decode_times=True, @@ -534,7 +560,7 @@ def open_datatree( def open_groups_as_dict( self, - filename_or_obj: str | os.PathLike[Any] | ReadBuffer | AbstractDataStore, + filename_or_obj: T_PathFileOrDataStore, *, mask_and_scale=True, decode_times=True, @@ -561,7 +587,7 @@ def open_groups_as_dict( # remove and set phony_dims="access" above emit_phony_dims_warning, phony_dims = _check_phony_dims(phony_dims) - filename_or_obj = _normalize_path(filename_or_obj) + filename_or_obj = _normalize_filename_or_obj(filename_or_obj) store = H5NetCDFStore.open( filename_or_obj, format=format, diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index 8c3a01eba66..ab1841461f4 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -13,6 +13,7 @@ BACKEND_ENTRYPOINTS, BackendArray, BackendEntrypoint, + T_PathFileOrDataStore, WritableCFDataStore, _normalize_path, datatree_from_dict_with_io_cleanup, @@ -49,10 +50,8 @@ from h5netcdf.core import EnumType as h5EnumType from netCDF4 import EnumType as ncEnumType - from xarray.backends.common import AbstractDataStore from xarray.core.dataset import Dataset from xarray.core.datatree import DataTree - from xarray.core.types import ReadBuffer # This lookup table maps from dtype.byteorder to a readable endian # string used by netCDF4. 
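A hedged sketch of the reader side that the h5netcdf backend changes above enable (the sample dataset is assumed, not taken from the patch):

    import xarray as xr

    ds = xr.Dataset({"foo": ("x", [1, 2, 3])})  # hypothetical sample data

    # engine="h5netcdf" with no target returns the written file as a memoryview...
    payload = ds.to_netcdf(engine="h5netcdf")

    # ...which open_dataset() now accepts directly: bytes and memoryview inputs
    # are wrapped in io.BytesIO by _normalize_filename_or_obj before opening.
    restored = xr.open_dataset(payload, engine="h5netcdf")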
@@ -629,10 +628,7 @@ class NetCDF4BackendEntrypoint(BackendEntrypoint): ) url = "https://docs.xarray.dev/en/stable/generated/xarray.backends.NetCDF4BackendEntrypoint.html" - def guess_can_open( - self, - filename_or_obj: str | os.PathLike[Any] | ReadBuffer | AbstractDataStore, - ) -> bool: + def guess_can_open(self, filename_or_obj: T_PathFileOrDataStore) -> bool: if isinstance(filename_or_obj, str) and is_remote_uri(filename_or_obj): return True magic_number = try_read_magic_number_from_path(filename_or_obj) @@ -648,7 +644,7 @@ def guess_can_open( def open_dataset( self, - filename_or_obj: str | os.PathLike[Any] | ReadBuffer | AbstractDataStore, + filename_or_obj: T_PathFileOrDataStore, *, mask_and_scale=True, decode_times=True, @@ -697,7 +693,7 @@ def open_dataset( def open_datatree( self, - filename_or_obj: str | os.PathLike[Any] | ReadBuffer | AbstractDataStore, + filename_or_obj: T_PathFileOrDataStore, *, mask_and_scale=True, decode_times=True, @@ -739,7 +735,7 @@ def open_datatree( def open_groups_as_dict( self, - filename_or_obj: str | os.PathLike[Any] | ReadBuffer | AbstractDataStore, + filename_or_obj: T_PathFileOrDataStore, *, mask_and_scale=True, decode_times=True, diff --git a/xarray/backends/plugins.py b/xarray/backends/plugins.py index 555538c2562..76df963621e 100644 --- a/xarray/backends/plugins.py +++ b/xarray/backends/plugins.py @@ -138,7 +138,12 @@ def refresh_engines() -> None: def guess_engine( - store_spec: str | os.PathLike[Any] | ReadBuffer | AbstractDataStore, + store_spec: str + | os.PathLike[Any] + | ReadBuffer + | bytes + | memoryview + | AbstractDataStore, ) -> str | type[BackendEntrypoint]: engines = list_engines() diff --git a/xarray/backends/pydap_.py b/xarray/backends/pydap_.py index 73b719f8260..4fbfe8ee210 100644 --- a/xarray/backends/pydap_.py +++ b/xarray/backends/pydap_.py @@ -10,6 +10,7 @@ AbstractDataStore, BackendArray, BackendEntrypoint, + T_PathFileOrDataStore, _normalize_path, datatree_from_dict_with_io_cleanup, robust_getitem, @@ -207,15 +208,14 @@ class PydapBackendEntrypoint(BackendEntrypoint): description = "Open remote datasets via OPeNDAP using pydap in Xarray" url = "https://docs.xarray.dev/en/stable/generated/xarray.backends.PydapBackendEntrypoint.html" - def guess_can_open( - self, - filename_or_obj: str | os.PathLike[Any] | ReadBuffer | AbstractDataStore, - ) -> bool: + def guess_can_open(self, filename_or_obj: T_PathFileOrDataStore) -> bool: return isinstance(filename_or_obj, str) and is_remote_uri(filename_or_obj) def open_dataset( self, - filename_or_obj: str | os.PathLike[Any] | ReadBuffer | AbstractDataStore, + filename_or_obj: ( + str | os.PathLike[Any] | ReadBuffer | bytes | memoryview | AbstractDataStore + ), *, mask_and_scale=True, decode_times=True, @@ -258,7 +258,7 @@ def open_dataset( def open_datatree( self, - filename_or_obj: str | os.PathLike[Any] | ReadBuffer | AbstractDataStore, + filename_or_obj: T_PathFileOrDataStore, *, mask_and_scale=True, decode_times=True, @@ -295,7 +295,7 @@ def open_datatree( def open_groups_as_dict( self, - filename_or_obj: str | os.PathLike[Any] | ReadBuffer | AbstractDataStore, + filename_or_obj: T_PathFileOrDataStore, *, mask_and_scale=True, decode_times=True, diff --git a/xarray/backends/scipy_.py b/xarray/backends/scipy_.py index b98d226cac6..a93c6465d49 100644 --- a/xarray/backends/scipy_.py +++ b/xarray/backends/scipy_.py @@ -12,6 +12,8 @@ BACKEND_ENTRYPOINTS, BackendArray, BackendEntrypoint, + BytesIOProxy, + T_PathFileOrDataStore, WritableCFDataStore, _normalize_path, ) @@ -28,12 +30,15 
@@ Frozen, FrozenDict, close_on_error, + emit_user_level_warning, module_available, try_read_magic_number_from_file_or_path, ) from xarray.core.variable import Variable if TYPE_CHECKING: + import scipy.io + from xarray.backends.common import AbstractDataStore from xarray.core.dataset import Dataset from xarray.core.types import ReadBuffer @@ -119,10 +124,6 @@ def _open_scipy_netcdf(filename, mode, mmap, version): else: raise - if isinstance(filename, bytes) and filename.startswith(b"CDF"): - # it's a NetCDF3 bytestring - filename = io.BytesIO(filename) - try: return scipy.io.netcdf_file(filename, mode=mode, mmap=mmap, version=version) except TypeError as e: # netcdf3 message is obscure in this case @@ -141,7 +142,7 @@ def _open_scipy_netcdf(filename, mode, mmap, version): class ScipyDataStore(WritableCFDataStore): - """Store for reading and writing data via scipy.io.netcdf. + """Store for reading and writing data via scipy.io.netcdf_file. This store has the advantage of being able to be initialized with a StringIO object, allow for serialization without writing to disk. @@ -167,7 +168,23 @@ def __init__( self.lock = ensure_lock(lock) - if isinstance(filename_or_obj, str): + if isinstance(filename_or_obj, BytesIOProxy): + emit_user_level_warning( + "return value of to_netcdf() without a target for " + "engine='scipy' is currently bytes, but will switch to " + "memoryview in a future version of Xarray. To silence this " + "warning, use the following pattern or switch to " + "to_netcdf(engine='h5netcdf'):\n" + " target = io.BytesIO()\n" + " dataset.to_netcdf(target)\n" + " result = target.getbuffer()", + FutureWarning, + ) + source = filename_or_obj + filename_or_obj = io.BytesIO() + source.getvalue = filename_or_obj.getvalue + + if isinstance(filename_or_obj, str): # path manager = CachingFileManager( _open_scipy_netcdf, filename_or_obj, @@ -175,16 +192,32 @@ def __init__( lock=lock, kwargs=dict(mmap=mmap, version=version), ) - else: + elif hasattr(filename_or_obj, "seek"): # file object + # Note: checking for .seek matches the check for file objects + # in scipy.io.netcdf_file scipy_dataset = _open_scipy_netcdf( filename_or_obj, mode=mode, mmap=mmap, version=version ) - manager = DummyFileManager(scipy_dataset) + # scipy.io.netcdf_file.close() incorrectly closes file objects that + # were passed in as constructor arguments: + # https://github.com/scipy/scipy/issues/13905 + # Instead of closing such files, only call flush(), which is + # equivalent as long as the netcdf_file object is not mmapped. + # This suffices to keep BytesIO objects open long enough to read + # their contents from to_netcdf(), but underlying files still get + # closed when the netcdf_file is garbage collected (via __del__), + # and will need to be fixed upstream in scipy. 
+ assert not scipy_dataset.use_mmap # no mmap for file objects + manager = DummyFileManager(scipy_dataset, close=scipy_dataset.flush) + else: + raise ValueError( + f"cannot open {filename_or_obj=} with scipy.io.netcdf_file" + ) self._manager = manager @property - def ds(self): + def ds(self) -> scipy.io.netcdf_file: return self._manager.acquire() def open_store_variable(self, name, var): @@ -265,6 +298,20 @@ def close(self): self._manager.close() +def _normalize_filename_or_obj( + filename_or_obj: str + | os.PathLike[Any] + | ReadBuffer + | bytes + | memoryview + | AbstractDataStore, +) -> str | ReadBuffer | AbstractDataStore: + if isinstance(filename_or_obj, bytes | memoryview): + return io.BytesIO(filename_or_obj) + else: + return _normalize_path(filename_or_obj) # type: ignore[return-value] + + class ScipyBackendEntrypoint(BackendEntrypoint): """ Backend for netCDF files based on the scipy package. @@ -291,8 +338,9 @@ class ScipyBackendEntrypoint(BackendEntrypoint): def guess_can_open( self, - filename_or_obj: str | os.PathLike[Any] | ReadBuffer | AbstractDataStore, + filename_or_obj: T_PathFileOrDataStore, ) -> bool: + filename_or_obj = _normalize_filename_or_obj(filename_or_obj) magic_number = try_read_magic_number_from_file_or_path(filename_or_obj) if magic_number is not None and magic_number.startswith(b"\x1f\x8b"): with gzip.open(filename_or_obj) as f: # type: ignore[arg-type] @@ -308,7 +356,7 @@ def guess_can_open( def open_dataset( self, - filename_or_obj: str | os.PathLike[Any] | ReadBuffer | AbstractDataStore, + filename_or_obj: T_PathFileOrDataStore, *, mask_and_scale=True, decode_times=True, @@ -323,7 +371,7 @@ def open_dataset( mmap=None, lock=None, ) -> Dataset: - filename_or_obj = _normalize_path(filename_or_obj) + filename_or_obj = _normalize_filename_or_obj(filename_or_obj) store = ScipyDataStore( filename_or_obj, mode=mode, format=format, group=group, mmap=mmap, lock=lock ) diff --git a/xarray/backends/store.py b/xarray/backends/store.py index de52aa193ed..2c3cd42ae92 100644 --- a/xarray/backends/store.py +++ b/xarray/backends/store.py @@ -1,36 +1,32 @@ from __future__ import annotations from collections.abc import Iterable -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING from xarray import conventions from xarray.backends.common import ( BACKEND_ENTRYPOINTS, AbstractDataStore, BackendEntrypoint, + T_PathFileOrDataStore, ) from xarray.core.coordinates import Coordinates from xarray.core.dataset import Dataset if TYPE_CHECKING: - import os - - from xarray.core.types import ReadBuffer + pass class StoreBackendEntrypoint(BackendEntrypoint): description = "Open AbstractDataStore instances in Xarray" url = "https://docs.xarray.dev/en/stable/generated/xarray.backends.StoreBackendEntrypoint.html" - def guess_can_open( - self, - filename_or_obj: str | os.PathLike[Any] | ReadBuffer | AbstractDataStore, - ) -> bool: + def guess_can_open(self, filename_or_obj: T_PathFileOrDataStore) -> bool: return isinstance(filename_or_obj, AbstractDataStore) def open_dataset( self, - filename_or_obj: str | os.PathLike[Any] | ReadBuffer | AbstractDataStore, + filename_or_obj: T_PathFileOrDataStore, *, mask_and_scale=True, decode_times=True, diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index 8b26a6b40ec..1b62a87d10c 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -17,6 +17,7 @@ AbstractWritableDataStore, BackendArray, BackendEntrypoint, + T_PathFileOrDataStore, _encode_variable_name, _normalize_path, 
datatree_from_dict_with_io_cleanup, @@ -39,10 +40,9 @@ from xarray.namedarray.utils import module_available if TYPE_CHECKING: - from xarray.backends.common import AbstractDataStore from xarray.core.dataset import Dataset from xarray.core.datatree import DataTree - from xarray.core.types import ReadBuffer, ZarrArray, ZarrGroup + from xarray.core.types import ZarrArray, ZarrGroup def _get_mappers(*, storage_options, store, chunk_store): @@ -1548,10 +1548,7 @@ class ZarrBackendEntrypoint(BackendEntrypoint): description = "Open zarr files (.zarr) using zarr in Xarray" url = "https://docs.xarray.dev/en/stable/generated/xarray.backends.ZarrBackendEntrypoint.html" - def guess_can_open( - self, - filename_or_obj: str | os.PathLike[Any] | ReadBuffer | AbstractDataStore, - ) -> bool: + def guess_can_open(self, filename_or_obj: T_PathFileOrDataStore) -> bool: if isinstance(filename_or_obj, str | os.PathLike): _, ext = os.path.splitext(filename_or_obj) return ext == ".zarr" @@ -1560,7 +1557,7 @@ def guess_can_open( def open_dataset( self, - filename_or_obj: str | os.PathLike[Any] | ReadBuffer | AbstractDataStore, + filename_or_obj: T_PathFileOrDataStore, *, mask_and_scale=True, decode_times=True, @@ -1615,7 +1612,7 @@ def open_dataset( def open_datatree( self, - filename_or_obj: str | os.PathLike[Any] | ReadBuffer | AbstractDataStore, + filename_or_obj: T_PathFileOrDataStore, *, mask_and_scale=True, decode_times=True, @@ -1657,7 +1654,7 @@ def open_datatree( def open_groups_as_dict( self, - filename_or_obj: str | os.PathLike[Any] | ReadBuffer | AbstractDataStore, + filename_or_obj: T_PathFileOrDataStore, *, mask_and_scale=True, decode_times=True, diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 73b0eb19a64..98979ce05d7 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -4015,7 +4015,7 @@ def to_netcdf( compute: bool = True, invalid_netcdf: bool = False, auto_complex: bool | None = None, - ) -> bytes: ... + ) -> bytes | memoryview: ... # compute=False returns dask.Delayed @overload @@ -4079,7 +4079,7 @@ def to_netcdf( compute: bool = True, invalid_netcdf: bool = False, auto_complex: bool | None = None, - ) -> bytes | Delayed | None: + ) -> bytes | memoryview | Delayed | None: """Write DataArray contents to a netCDF file. Parameters @@ -4149,8 +4149,7 @@ def to_netcdf( Returns ------- - store: bytes or Delayed or None - * ``bytes`` if path is None + * ``bytes`` or ``memoryview`` if path is None * ``dask.delayed.Delayed`` if compute is False * None otherwise diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index acc7d1f17f6..0b1d9835cf5 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -2,6 +2,7 @@ import copy import datetime +import io import math import sys import warnings @@ -1884,7 +1885,7 @@ def to_netcdf( compute: bool = True, invalid_netcdf: bool = False, auto_complex: bool | None = None, - ) -> bytes: ... + ) -> bytes | memoryview: ... 
# compute=False returns dask.Delayed @overload @@ -1907,7 +1908,7 @@ def to_netcdf( @overload def to_netcdf( self, - path: str | PathLike, + path: str | PathLike | io.IOBase, mode: NetcdfWriteModes = "w", format: T_NetcdfTypes | None = None, group: str | None = None, @@ -1938,7 +1939,7 @@ def to_netcdf( def to_netcdf( self, - path: str | PathLike | None = None, + path: str | PathLike | io.IOBase | None = None, mode: NetcdfWriteModes = "w", format: T_NetcdfTypes | None = None, group: str | None = None, @@ -1948,7 +1949,7 @@ def to_netcdf( compute: bool = True, invalid_netcdf: bool = False, auto_complex: bool | None = None, - ) -> bytes | Delayed | None: + ) -> bytes | memoryview | Delayed | None: """Write dataset contents to a netCDF file. Parameters @@ -2020,9 +2021,9 @@ def to_netcdf( Returns ------- - * ``bytes`` if path is None + * ``bytes`` or ``memoryview`` if path is None * ``dask.delayed.Delayed`` if compute is False - * None otherwise + * ``None`` otherwise See Also -------- diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index afef2f20094..bf82baccb31 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -1,6 +1,7 @@ from __future__ import annotations import functools +import io import itertools import textwrap from collections import ChainMap @@ -12,6 +13,7 @@ Mapping, ) from html import escape +from os import PathLike from typing import ( TYPE_CHECKING, Any, @@ -1659,9 +1661,11 @@ def _inplace_binary_op(self, other, f) -> Self: def __eq__(self, other: DtCompatible) -> Self: # type: ignore[override] return super().__eq__(other) + # filepath=None writes to a memoryview + @overload def to_netcdf( self, - filepath, + filepath: None = None, mode: NetcdfWriteModes = "w", encoding=None, unlimited_dims=None, @@ -1671,14 +1675,45 @@ def to_netcdf( write_inherited_coords: bool = False, compute: bool = True, **kwargs, - ): + ) -> memoryview: ... + + @overload + def to_netcdf( + self, + filepath: str | PathLike | io.IOBase, + mode: NetcdfWriteModes = "w", + encoding=None, + unlimited_dims=None, + format: T_DataTreeNetcdfTypes | None = None, + engine: T_DataTreeNetcdfEngine | None = None, + group: str | None = None, + write_inherited_coords: bool = False, + compute: bool = True, + **kwargs, + ) -> None: ... + + def to_netcdf( + self, + filepath: str | PathLike | io.IOBase | None = None, + mode: NetcdfWriteModes = "w", + encoding=None, + unlimited_dims=None, + format: T_DataTreeNetcdfTypes | None = None, + engine: T_DataTreeNetcdfEngine | None = None, + group: str | None = None, + write_inherited_coords: bool = False, + compute: bool = True, + **kwargs, + ) -> None | memoryview: """ Write datatree contents to a netCDF file. Parameters ---------- - filepath : str or Path - Path to which to save this datatree. + filepath : str or PathLike or file-like object or None + Path to which to save this datatree, or a file-like object to write + it to (which must support read and write and be seekable) or None + to return in-memory bytes as a memoryview. mode : {"w", "a"}, default: "w" Write ('w') or append ('a') mode. If mode='w', any existing file at this location will be overwritten. 
If mode='a', existing variables @@ -1717,6 +1752,11 @@ def to_netcdf( kwargs : Additional keyword arguments to be passed to ``xarray.Dataset.to_netcdf`` + Returns + ------- + * ``memoryview`` if path is None + * ``None`` otherwise + Note ---- Due to file format specifications the on-disk root group name @@ -1724,7 +1764,7 @@ def to_netcdf( """ from xarray.core.datatree_io import _datatree_to_netcdf - _datatree_to_netcdf( + return _datatree_to_netcdf( self, filepath, mode=mode, diff --git a/xarray/core/datatree_io.py b/xarray/core/datatree_io.py index cf3626dbb12..c586caaba89 100644 --- a/xarray/core/datatree_io.py +++ b/xarray/core/datatree_io.py @@ -1,5 +1,6 @@ from __future__ import annotations +import io from collections.abc import Mapping from os import PathLike from typing import TYPE_CHECKING, Any, Literal, get_args @@ -16,7 +17,7 @@ def _datatree_to_netcdf( dt: DataTree, - filepath: str | PathLike, + filepath: str | PathLike | io.IOBase | None = None, mode: NetcdfWriteModes = "w", encoding: Mapping[str, Any] | None = None, unlimited_dims: Mapping | None = None, @@ -26,18 +27,19 @@ def _datatree_to_netcdf( write_inherited_coords: bool = False, compute: bool = True, **kwargs, -) -> None: - """This function creates an appropriate datastore for writing a datatree to - disk as a netCDF file. - - See `DataTree.to_netcdf` for full API docs. - """ +) -> None | memoryview: + """Implementation of `DataTree.to_netcdf`.""" if format not in [None, *get_args(T_DataTreeNetcdfTypes)]: - raise ValueError("to_netcdf only supports the NETCDF4 format") + raise ValueError("DataTree.to_netcdf only supports the NETCDF4 format") if engine not in [None, *get_args(T_DataTreeNetcdfEngine)]: - raise ValueError("to_netcdf only supports the netcdf4 and h5netcdf engines") + raise ValueError( + "DataTree.to_netcdf only supports the netcdf4 and h5netcdf engines" + ) + + if engine is None: + engine = "h5netcdf" if group is not None: raise NotImplementedError( @@ -58,6 +60,13 @@ def _datatree_to_netcdf( f"unexpected encoding group name(s) provided: {set(encoding) - set(dt.groups)}" ) + if filepath is None: + # No need to use BytesIOProxy here because the legacy scipy backend + # cannot write netCDF files with groups + target = io.BytesIO() + else: + target = filepath # type: ignore[assignment] + if unlimited_dims is None: unlimited_dims = {} @@ -66,7 +75,7 @@ def _datatree_to_netcdf( ds = node.to_dataset(inherit=write_inherited_coords or at_root) group_path = None if at_root else "/" + node.relative_to(dt) ds.to_netcdf( - filepath, + target, group=group_path, mode=mode, encoding=encoding.get(node.path), @@ -78,6 +87,12 @@ def _datatree_to_netcdf( ) mode = "a" + if filepath is None: + assert isinstance(target, io.BytesIO) + return target.getbuffer() + + return None + def _datatree_to_zarr( dt: DataTree, @@ -90,11 +105,7 @@ def _datatree_to_zarr( compute: bool = True, **kwargs, ): - """This function creates an appropriate datastore for writing a datatree - to a zarr store. - - See `DataTree.to_zarr` for full API docs. 
- """ + """Implementation of `DataTree.to_zarr`.""" from zarr import consolidate_metadata diff --git a/xarray/core/utils.py b/xarray/core/utils.py index 386f1e346de..e490fc05c2f 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -694,15 +694,12 @@ def is_remote_uri(path: str) -> bool: def read_magic_number_from_file(filename_or_obj, count=8) -> bytes: # check byte header to determine file type - if isinstance(filename_or_obj, bytes): - magic_number = filename_or_obj[:count] - elif isinstance(filename_or_obj, io.IOBase): - if filename_or_obj.tell() != 0: - filename_or_obj.seek(0) - magic_number = filename_or_obj.read(count) - filename_or_obj.seek(0) - else: + if not isinstance(filename_or_obj, io.IOBase): raise TypeError(f"cannot read the magic number from {type(filename_or_obj)}") + if filename_or_obj.tell() != 0: + filename_or_obj.seek(0) + magic_number = filename_or_obj.read(count) + filename_or_obj.seek(0) return magic_number diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 2fe57f51d65..2ff73203580 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -4057,12 +4057,29 @@ def create_store(self): yield backends.ScipyDataStore(fobj, "w") def test_to_netcdf_explicit_engine(self) -> None: - # regression test for GH1321 - Dataset({"foo": 42}).to_netcdf(engine="scipy") + with pytest.warns( + FutureWarning, + match=re.escape("return value of to_netcdf() without a target"), + ): + Dataset({"foo": 42}).to_netcdf(engine="scipy") + + def test_roundtrip_via_bytes(self) -> None: + original = create_test_data() + with pytest.warns( + FutureWarning, + match=re.escape("return value of to_netcdf() without a target"), + ): + netcdf_bytes = original.to_netcdf(engine="scipy") + roundtrip = open_dataset(netcdf_bytes, engine="scipy") + assert_identical(roundtrip, original) def test_bytes_pickle(self) -> None: data = Dataset({"foo": ("x", [1, 2, 3])}) - fobj = data.to_netcdf() + with pytest.warns( + FutureWarning, + match=re.escape("return value of to_netcdf() without a target"), + ): + fobj = data.to_netcdf() with self.open(fobj) as ds: unpickled = pickle.loads(pickle.dumps(ds)) assert_identical(unpickled, data) @@ -4070,6 +4087,8 @@ def test_bytes_pickle(self) -> None: @requires_scipy class TestScipyFileObject(CFEncodedBase, NetCDF3Only): + # TODO: Consider consolidating some of these cases (e.g., + # test_file_remains_open) with TestH5NetCDFFileObject engine: T_NetcdfEngine = "scipy" @contextlib.contextmanager @@ -4092,6 +4111,20 @@ def roundtrip( with self.open(f, **open_kwargs) as ds: yield ds + @pytest.mark.xfail( + reason="scipy.io.netcdf_file closes files upon garbage collection" + ) + def test_file_remains_open(self) -> None: + data = Dataset({"foo": ("x", [1, 2, 3])}) + f = BytesIO() + data.to_netcdf(f, engine="scipy") + assert not f.closed + restored = open_dataset(f, engine="scipy") + assert not f.closed + assert_identical(restored, data) + restored.close() + assert not f.closed + @pytest.mark.skip(reason="cannot pickle file objects") def test_pickle(self) -> None: pass @@ -4216,9 +4249,10 @@ def test_engine(self) -> None: with pytest.raises(ValueError, match=r"unrecognized engine"): open_dataset(tmp_file, engine="foobar") - netcdf_bytes = data.to_netcdf() + bytes_io = BytesIO() + data.to_netcdf(bytes_io, engine="scipy") with pytest.raises(ValueError, match=r"unrecognized engine"): - open_dataset(BytesIO(netcdf_bytes), engine="foobar") + open_dataset(bytes_io, engine="foobar") def test_cross_engine_read_write_netcdf3(self) -> 
None: data = create_test_data() @@ -4265,6 +4299,32 @@ def test_encoding_unlimited_dims(self) -> None: assert actual.encoding["unlimited_dims"] == set("y") assert_equal(ds, actual) + @requires_scipy + def test_roundtrip_via_bytes(self) -> None: + original = create_test_data() + with pytest.warns( + FutureWarning, + match=re.escape("return value of to_netcdf() without a target"), + ): + netcdf_bytes = original.to_netcdf() + roundtrip = open_dataset(netcdf_bytes) + assert_identical(roundtrip, original) + + @pytest.mark.xfail( + reason="scipy.io.netcdf_file closes files upon garbage collection" + ) + @requires_scipy + def test_roundtrip_via_file_object(self) -> None: + original = create_test_data() + f = BytesIO() + original.to_netcdf(f) + assert not f.closed + restored = open_dataset(f) + assert not f.closed + assert_identical(restored, original) + restored.close() + assert not f.closed + @requires_h5netcdf @requires_netCDF4 @@ -4544,16 +4604,13 @@ class TestH5NetCDFFileObject(TestH5NetCDFData): engine: T_NetcdfEngine = "h5netcdf" def test_open_badbytes(self) -> None: - with pytest.raises(ValueError, match=r"HDF5 as bytes"): - with open_dataset(b"\211HDF\r\n\032\n", engine="h5netcdf"): # type: ignore[arg-type] - pass with pytest.raises( ValueError, match=r"match in any of xarray's currently installed IO" ): - with open_dataset(b"garbage"): # type: ignore[arg-type] + with open_dataset(b"garbage"): pass with pytest.raises(ValueError, match=r"can only read bytes"): - with open_dataset(b"garbage", engine="netcdf4"): # type: ignore[arg-type] + with open_dataset(b"garbage", engine="netcdf4"): pass with pytest.raises( ValueError, match=r"not the signature of a valid netCDF4 file" @@ -4604,6 +4661,32 @@ def test_open_fileobj(self) -> None: with open_dataset(f): # ensure file gets closed pass + def test_file_remains_open(self) -> None: + data = Dataset({"foo": ("x", [1, 2, 3])}) + f = BytesIO() + data.to_netcdf(f, engine="h5netcdf") + assert not f.closed + restored = open_dataset(f, engine="h5netcdf") + assert not f.closed + assert_identical(restored, data) + restored.close() + assert not f.closed + + +@requires_h5netcdf +class TestH5NetCDFInMemoryData: + def test_roundtrip_via_bytes(self) -> None: + original = create_test_data() + netcdf_bytes = original.to_netcdf(engine="h5netcdf") + roundtrip = open_dataset(netcdf_bytes, engine="h5netcdf") + assert_identical(roundtrip, original) + + def test_roundtrip_group_via_bytes(self) -> None: + original = create_test_data() + netcdf_bytes = original.to_netcdf(group="sub", engine="h5netcdf") + roundtrip = open_dataset(netcdf_bytes, group="sub", engine="h5netcdf") + assert_identical(roundtrip, original) + @requires_h5netcdf @requires_dask @@ -5973,7 +6056,11 @@ def test_open_dataarray_options(self) -> None: def test_dataarray_to_netcdf_return_bytes(self) -> None: # regression test for GH1410 data = xr.DataArray([1, 2, 3]) - output = data.to_netcdf() + with pytest.warns( + FutureWarning, + match=re.escape("return value of to_netcdf() without a target"), + ): + output = data.to_netcdf(engine="scipy") assert isinstance(output, bytes) def test_dataarray_to_netcdf_no_name_pathlib(self) -> None: @@ -6510,7 +6597,10 @@ def test_scipy_entrypoint(tmp_path: Path) -> None: with open(path, "rb") as f: _check_guess_can_open_and_open(entrypoint, f, engine="scipy", expected=ds) - contents = ds.to_netcdf(engine="scipy") + with pytest.warns( + FutureWarning, match=re.escape("return value of to_netcdf() without a target") + ): + contents = ds.to_netcdf(engine="scipy") 
_check_guess_can_open_and_open(entrypoint, contents, engine="scipy", expected=ds) _check_guess_can_open_and_open( entrypoint, BytesIO(contents), engine="scipy", expected=ds @@ -6525,7 +6615,7 @@ def test_scipy_entrypoint(tmp_path: Path) -> None: assert entrypoint.guess_can_open("something-local.nc") assert entrypoint.guess_can_open("something-local.nc.gz") assert not entrypoint.guess_can_open("not-found-and-no-extension") - assert not entrypoint.guess_can_open(b"not-a-netcdf-file") # type: ignore[arg-type] + assert not entrypoint.guess_can_open(b"not-a-netcdf-file") @requires_h5netcdf diff --git a/xarray/tests/test_backends_common.py b/xarray/tests/test_backends_common.py index 33da027ac97..a42381882ed 100644 --- a/xarray/tests/test_backends_common.py +++ b/xarray/tests/test_backends_common.py @@ -1,5 +1,6 @@ from __future__ import annotations +import io import re import numpy as np @@ -53,10 +54,11 @@ def test_infer_dtype_error_on_mixed_types(data): def test_encoding_failure_note(): # Create an arbitrary value that cannot be encoded in netCDF3 ds = xr.Dataset({"invalid": np.array([2**63 - 1], dtype=np.int64)}) + f = io.BytesIO() with pytest.raises( ValueError, match=re.escape( "Raised while encoding variable 'invalid' with value None: "phony_dim_3": 25, } + def test_roundtrip_via_bytes(self, simple_datatree): + original_dt = simple_datatree + roundtrip_dt = open_datatree(original_dt.to_netcdf()) + assert_equal(original_dt, roundtrip_dt) + + def test_roundtrip_via_bytes_engine_specified(self, simple_datatree): + original_dt = simple_datatree + roundtrip_dt = open_datatree(original_dt.to_netcdf(engine=self.engine)) + assert_equal(original_dt, roundtrip_dt) + + def test_roundtrip_using_filelike_object(self, tmpdir, simple_datatree): + original_dt = simple_datatree + filepath = tmpdir + "/test.nc" + # h5py requires both read and write access when writing, it will + # work with file-like objects provided they support both, and are + # seekable. + with open(filepath, "wb+") as file: + original_dt.to_netcdf(file, engine=self.engine) + with open(filepath, "rb") as file: + with open_datatree(file, engine=self.engine) as roundtrip_dt: + assert_equal(original_dt, roundtrip_dt) + @requires_zarr @parametrize_zarr_format