feat/geozarr model #10

Open · wants to merge 18 commits into base: main
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -26,4 +26,4 @@ repos:
additional_dependencies:
- types-simplejson
- types-attrs
- pydantic~=2.0
- pydantic>=2.11
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -28,6 +28,7 @@ classifiers = [
requires-python = ">=3.11"
dependencies = [
"pydantic-zarr>=0.8.0",
"pydantic>=2.11",
"zarr>=3.1.1",
"xarray>=2025.7.1",
"dask[array,distributed]>=2025.5.1",
@@ -111,7 +112,7 @@ use_parentheses = true
ensure_newline_before_comments = true

[tool.mypy]
python_version = "3.10"
python_version = "3.11"
warn_return_any = true
warn_unused_configs = true
disallow_untyped_defs = true
Empty file.
161 changes: 161 additions & 0 deletions src/eopf_geozarr/data_api/geozarr/common.py
@@ -0,0 +1,161 @@
"""Common utilities for GeoZarr data API."""

import io
import urllib
import urllib.request
from typing import Annotated, Final, Literal

from cf_xarray.utils import parse_cf_standard_name_table
from pydantic import AfterValidator, BaseModel

XARRAY_DIMS_KEY: Final = "_ARRAY_DIMENSIONS"


def get_cf_standard_names(url: str) -> tuple[str, ...]:
"""Retrieve the set of CF standard names and return them as a tuple."""

headers = {"User-Agent": "eopf_geozarr"}

req = urllib.request.Request(url, headers=headers)

try:
with urllib.request.urlopen(req) as response:
content = response.read() # Read the entire response body into memory
content_fobj = io.BytesIO(content)
except urllib.error.URLError as e:
raise e

_info, table, _aliases = parse_cf_standard_name_table(source=content_fobj)
return tuple(table.keys())


# This is a URL to the CF standard names table.
CF_STANDARD_NAME_URL = (
"https://raw.githubusercontent.com/cf-convention/cf-convention.github.io/"
"master/Data/cf-standard-names/current/src/cf-standard-name-table.xml"
)

# This does I/O against GitHub at import time. Consider storing the table locally if fetching it
# on every import is problematic.
CF_STANDARD_NAMES = get_cf_standard_names(url=CF_STANDARD_NAME_URL)


def check_standard_name(name: str) -> str:

Review comment (Contributor):
Please bear with my limited knowledge of pydantic, but how is the link made with the actual standard_name field name?

Reply (author):
Pydantic drives most of its validation from type annotations. When we annotate an attribute on a pydantic model with this type: https://github.com/d-v-b/data-model/blob/3d11af412e460993f8e603dcff0555c5342c4e8f/src/eopf_geozarr/data_api/geozarr/common.py#L70, pydantic will run the check_standard_name function after checking that the input is a string.
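
For illustration, a minimal self-contained sketch of that pattern (the `ExampleAttrs` model and the `_KNOWN_NAMES` whitelist below are hypothetical stand-ins, not code from this PR):

```python
from typing import Annotated

from pydantic import AfterValidator, BaseModel, ValidationError

# Hypothetical whitelist standing in for the CF standard name table.
_KNOWN_NAMES = {"air_temperature", "sea_surface_temperature"}


def check_standard_name(name: str) -> str:
    if name in _KNOWN_NAMES:
        return name
    raise ValueError(f"Invalid standard name: {name}")


CFStandardName = Annotated[str, AfterValidator(check_standard_name)]


class ExampleAttrs(BaseModel):
    # Pydantic first checks that the input is a str, then runs check_standard_name on it.
    standard_name: CFStandardName


ExampleAttrs(standard_name="air_temperature")  # passes
try:
    ExampleAttrs(standard_name="not_a_cf_name")
except ValidationError as err:
    print(err)  # surfaces the ValueError raised by check_standard_name
```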

"""
Check if the standard name is valid according to the CF conventions.

Parameters
----------
name : str
The standard name to check.

Returns
-------
str
The validated standard name.

Raises
------
ValueError
If the standard name is not valid.
"""

if name in CF_STANDARD_NAMES:
return name
raise ValueError(
f"Invalid standard name: {name}. This name was not found in the list of CF standard names."
)


CFStandardName = Annotated[str, AfterValidator(check_standard_name)]

Review comment (Contributor):
I think we are missing the grid_mapping attribute verification, defaulted to a spatial_ref scalar with the EPSG code. https://zarr.dev/geozarr-spec/documents/standard/template/geozarr-spec.html#_e15d59bd-f2ec-28e8-8016-4e541c95e10f

Reply (author):
I can add these, and if they are required we should make that clearer in the spec. Right now the spec says:

CF Conventions – Including attributes such as standard_name, units, axis, and grid_mapping to express spatiotemporal semantics and coordinate system properties.

but it isn't clear which CF attributes are required, which are optional, etc.
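
If that verification is added, one possible shape for it, purely as a sketch based on the CF terms quoted above (none of these class names or field defaults exist in this PR), might be:

```python
from pydantic import BaseModel


class GridMappingDataArrayAttrs(BaseModel, extra="allow"):
    """Sketch: attributes of a data variable that points at a grid mapping variable."""

    # Name of the (typically scalar) grid mapping variable, per CF conventions.
    grid_mapping: str = "spatial_ref"


class GridMappingVariableAttrs(BaseModel, extra="allow"):
    """Sketch: attributes carried by the grid mapping variable itself."""

    grid_mapping_name: str  # e.g. "latitude_longitude"
    crs_wkt: str | None = None  # WKT description of the CRS, which may encode the EPSG code
```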


ResamplingMethod = Literal[
"nearest",
"average",
"bilinear",
"cubic",
"cubic_spline",
"lanczos",
"mode",
"max",
"min",
"med",
"sum",
"q1",
"q3",
"rms",
"gauss",
]
"""A string literal indicating a resampling method"""


class TileMatrixLimit(BaseModel):
""""""

tileMatrix: str
minTileCol: int
minTileRow: int
maxTileCol: int
maxTileRow: int


class TileMatrix(BaseModel):
"""A single tile matrix (zoom level) in a tile matrix set."""

id: str
scaleDenominator: float
cellSize: float
pointOfOrigin: tuple[float, float]
tileWidth: int
tileHeight: int
matrixWidth: int
matrixHeight: int


class TileMatrixSet(BaseModel):
"""A tile matrix set describing the tile pyramid for a multiscale dataset."""

id: str
title: str | None = None
crs: str | None = None
supportedCRS: str | None = None
orderedAxes: tuple[str, str] | None = None
tileMatrices: tuple[TileMatrix, ...]


class Multiscales(BaseModel, extra="allow"):
"""
Multiscale metadata for a GeoZarr dataset.

Attributes
----------
tile_matrix_set : TileMatrixSet
The tile matrix set for the multiscale dataset.
resampling_method : ResamplingMethod
The name of the resampling method for the multiscale dataset.
tile_matrix_limits : dict[str, TileMatrixLimit] | None, optional
The tile matrix limits for the multiscale dataset, keyed by tile matrix id.
"""

tile_matrix_set: TileMatrixSet
resampling_method: ResamplingMethod
# TODO: ensure that the keys match tile_matrix_set.tileMatrices[$index].id
# TODO: ensure that the keys match the tileMatrix attribute
tile_matrix_limits: dict[str, TileMatrixLimit] | None = None


class DatasetAttrs(BaseModel, extra="allow"):
"""
Attributes for a GeoZarr dataset.

Attributes
----------
multiscales : Multiscales
The multiscale metadata for the dataset.
"""

multiscales: Multiscales


class BaseDataArrayAttrs(BaseModel, extra="allow"):
"""
Base attributes for a GeoZarr DataArray.

This base model declares no required attributes; subclasses add them.
"""
170 changes: 170 additions & 0 deletions src/eopf_geozarr/data_api/geozarr/v2.py
@@ -0,0 +1,170 @@
"""GeoZarr data API for Zarr V2."""

from __future__ import annotations

from collections.abc import Mapping
from typing import Any, Iterable, Literal, Self, TypeVar

from pydantic import BaseModel, ConfigDict, Field, model_validator
from pydantic_zarr.v2 import ArraySpec, GroupSpec, auto_attributes

from eopf_geozarr.data_api.geozarr.common import (
XARRAY_DIMS_KEY,
BaseDataArrayAttrs,
Multiscales,
)


class DataArrayAttrs(BaseDataArrayAttrs):
"""
Attributes for a GeoZarr DataArray.

Attributes
----------
array_dimensions : tuple[str, ...]
Alias for the _ARRAY_DIMENSIONS attribute, which lists the dimension names for this array.
"""

# TODO: validate that the names listed here are the names of zarr arrays
# unless the variable is an auxiliary variable
# see https://github.com/zarr-developers/geozarr-spec/blob/main/geozarr-spec.md#geozarr-coordinates
array_dimensions: tuple[str, ...] = Field(alias="_ARRAY_DIMENSIONS")

model_config = ConfigDict(serialize_by_alias=True)


class DataArray(ArraySpec[DataArrayAttrs]):
"""
A GeoZarr DataArray variable. Its attributes must contain an `"_ARRAY_DIMENSIONS"`
key whose value has a length that matches the dimensionality of the array.

References
----------
https://github.com/zarr-developers/geozarr-spec/blob/main/geozarr-spec.md#geozarr-dataarray
"""

@classmethod
def from_array(
cls,
array: Any,
chunks: tuple[int, ...] | Literal["auto"] = "auto",
attributes: Mapping[str, object] | Literal["auto"] = "auto",
fill_value: object | Literal["auto"] = "auto",
order: Literal["C", "F"] | Literal["auto"] = "auto",
filters: tuple[Any, ...] | Literal["auto"] = "auto",
dimension_separator: Literal[".", "/"] | Literal["auto"] = "auto",
compressor: Any | Literal["auto"] = "auto",
dimension_names: Iterable[str] | Literal["auto"] = "auto",
) -> Self:
if attributes == "auto":
auto_attrs = dict(auto_attributes(array))
else:
auto_attrs = dict(attributes)
if dimension_names != "auto":
auto_attrs = auto_attrs | {XARRAY_DIMS_KEY: tuple(dimension_names)}
model = super().from_array(
array=array,
chunks=chunks,
attributes=auto_attrs,
fill_value=fill_value,
order=order,
filters=filters,
dimension_separator=dimension_separator,
compressor=compressor,
)
return model # type: ignore[no-any-return]

@model_validator(mode="after")
def check_array_dimensions(self) -> Self:
if (len_dim := len(self.attributes.array_dimensions)) != (
ndim := len(self.shape)
):
msg = (
f"The {XARRAY_DIMS_KEY} attribute has length {len_dim}, which does not "
f"match the number of dimensions for this array (got {ndim})."
)
raise ValueError(msg)
return self

@property
def array_dimensions(self) -> tuple[str, ...]:
return self.attributes.array_dimensions # type: ignore[no-any-return]


T = TypeVar("T", bound=GroupSpec[Any, Any])


def check_valid_coordinates(model: T) -> T:
"""
Check if the coordinates of the DataArrays listed in a GeoZarr Dataset are valid.

For each DataArray in the model, we check the dimensions associated with that DataArray.
For each dimension of a data variable, an array with the name of that dimension must be
present in the members of the group, and its length must match the corresponding axis of
the data variable.

Parameters
----------
model : GroupSpec[Any, Any]
The GeoZarr group model to check.

Returns
-------
GroupSpec[Any, Any]
The validated GeoZarr group model.
"""
if model.members is None:
raise ValueError("Model members cannot be None")

arrays: dict[str, DataArray] = {
k: v for k, v in model.members.items() if isinstance(v, DataArray)
}
for key, array in arrays.items():
for idx, dim in enumerate(array.array_dimensions):
if dim not in model.members:
raise ValueError(
f"Dimension '{dim}' for array '{key}' is not defined in the model members."
)
member = model.members[dim]
if isinstance(member, GroupSpec):
raise ValueError(
f"Dimension '{dim}' for array '{key}' should be a group. Found an array instead."
)
if member.shape[0] != array.shape[idx]:
raise ValueError(
f"Dimension '{dim}' for array '{key}' has a shape mismatch: "
f"{member.shape[0]} != {array.shape[idx]}."
)
return model


class DatasetAttrs(BaseModel):
"""
Attributes for a GeoZarr dataset.

Attributes
----------
multiscales : Multiscales
The multiscale metadata for the dataset.
"""

multiscales: Multiscales


class Dataset(GroupSpec[DatasetAttrs, GroupSpec[Any, Any] | DataArray]):
"""
A GeoZarr Dataset.
"""

@model_validator(mode="after")
def check_valid_coordinates(self) -> Self:
"""
Validate the coordinates of the GeoZarr Dataset.

This method checks that all DataArrays in the dataset have valid coordinates
according to the GeoZarr specification.

Returns
-------
Self
The validated GeoZarr Dataset.
"""
return check_valid_coordinates(self)