-
Notifications
You must be signed in to change notification settings - Fork 2
feat/geozarr model #10
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
d277d01
e798d98
c940bd8
553c7c7
de36bf6
105a3a5
6a88bcb
bca5f5b
38a721f
ad084c8
dfdeff2
2088e33
cfcf7e1
0faa082
f8c5722
d1a2e2d
ad585a0
3d11af4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,4 +26,4 @@ repos: | |
additional_dependencies: | ||
- types-simplejson | ||
- types-attrs | ||
- pydantic~=2.0 | ||
- pydantic>=2.11 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,161 @@ | ||
"""Common utilities for GeoZarr data API.""" | ||
|
||
import io | ||
import urllib | ||
import urllib.request | ||
from typing import Annotated, Final, Literal | ||
|
||
from cf_xarray.utils import parse_cf_standard_name_table | ||
from pydantic import AfterValidator, BaseModel | ||
|
||
XARRAY_DIMS_KEY: Final = "_ARRAY_DIMENSIONS" | ||
|
||
|
||
def get_cf_standard_names(url: str) -> tuple[str, ...]: | ||
"""Retrieve the set of CF standard names and return them as a tuple.""" | ||
|
||
headers = {"User-Agent": "eopf_geozarr"} | ||
|
||
req = urllib.request.Request(url, headers=headers) | ||
|
||
try: | ||
with urllib.request.urlopen(req) as response: | ||
content = response.read() # Read the entire response body into memory | ||
content_fobj = io.BytesIO(content) | ||
except urllib.error.URLError as e: | ||
raise e | ||
|
||
_info, table, _aliases = parse_cf_standard_name_table(source=content_fobj) | ||
return tuple(table.keys()) | ||
|
||
|
||
# This is a URL to the CF standard names table. | ||
CF_STANDARD_NAME_URL = ( | ||
"https://raw.githubusercontent.com/cf-convention/cf-convention.github.io/" | ||
"master/Data/cf-standard-names/current/src/cf-standard-name-table.xml" | ||
) | ||
|
||
# this does IO against github. consider locally storing this data instead if fetching every time | ||
# is problematic. | ||
CF_STANDARD_NAMES = get_cf_standard_names(url=CF_STANDARD_NAME_URL) | ||
|
||
|
||
def check_standard_name(name: str) -> str: | ||
""" | ||
Check if the standard name is valid according to the CF conventions. | ||
|
||
Parameters | ||
---------- | ||
name : str | ||
The standard name to check. | ||
|
||
Returns | ||
------- | ||
str | ||
The validated standard name. | ||
|
||
Raises | ||
------ | ||
ValueError | ||
If the standard name is not valid. | ||
""" | ||
|
||
if name in CF_STANDARD_NAMES: | ||
return name | ||
raise ValueError( | ||
f"Invalid standard name: {name}. This name was not found in the list of CF standard names." | ||
) | ||
|
||
|
||
CFStandardName = Annotated[str, AfterValidator(check_standard_name)] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we miss the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I can add these, and if they are required we should make that more clear in the spec. right now the spec says
but it isn't clear which CF attributes are required, optional, etc |
||
|
||
ResamplingMethod = Literal[ | ||
"nearest", | ||
"average", | ||
"bilinear", | ||
"cubic", | ||
"cubic_spline", | ||
"lanczos", | ||
"mode", | ||
"max", | ||
"min", | ||
"med", | ||
"sum", | ||
"q1", | ||
"q3", | ||
"rms", | ||
"gauss", | ||
] | ||
"""A string literal indicating a resampling method""" | ||
|
||
|
||
class TileMatrixLimit(BaseModel): | ||
"""""" | ||
|
||
tileMatrix: str | ||
minTileCol: int | ||
minTileRow: int | ||
maxTileCol: int | ||
maxTileRow: int | ||
|
||
|
||
class TileMatrix(BaseModel): | ||
id: str | ||
scaleDenominator: float | ||
cellSize: float | ||
pointOfOrigin: tuple[float, float] | ||
tileWidth: int | ||
tileHeight: int | ||
matrixWidth: int | ||
matrixHeight: int | ||
|
||
|
||
class TileMatrixSet(BaseModel): | ||
id: str | ||
title: str | None = None | ||
crs: str | None = None | ||
supportedCRS: str | None = None | ||
orderedAxes: tuple[str, str] | None = None | ||
tileMatrices: tuple[TileMatrix, ...] | ||
|
||
|
||
class Multiscales(BaseModel, extra="allow"): | ||
""" | ||
Multiscale metadata for a GeoZarr dataset. | ||
|
||
Attributes | ||
---------- | ||
tile_matrix_set : str | ||
The tile matrix set identifier for the multiscale dataset. | ||
resampling_method : ResamplingMethod | ||
The name of the resampling method for the multiscale dataset. | ||
tile_matrix_set_limits : dict[str, TileMatrixSetLimits] | None, optional | ||
The tile matrix set limits for the multiscale dataset. | ||
""" | ||
|
||
tile_matrix_set: TileMatrixSet | ||
resampling_method: ResamplingMethod | ||
# TODO: ensure that the keys match tile_matrix_set.tileMatrices[$index].id | ||
# TODO: ensure that the keys match the tileMatrix attribute | ||
tile_matrix_limits: dict[str, TileMatrixLimit] | None = None | ||
|
||
|
||
class DatasetAttrs(BaseModel, extra="allow"): | ||
""" | ||
Attributes for a GeoZarr dataset. | ||
|
||
Attributes | ||
---------- | ||
multiscales: MultiscaleAttrs | ||
""" | ||
|
||
multiscales: Multiscales | ||
|
||
|
||
class BaseDataArrayAttrs(BaseModel, extra="allow"): | ||
""" | ||
Base attributes for a GeoZarr DataArray. | ||
|
||
Attributes | ||
---------- | ||
""" |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,170 @@ | ||
"""GeoZarr data API for Zarr V2.""" | ||
|
||
from __future__ import annotations | ||
|
||
from collections.abc import Mapping | ||
from typing import Any, Iterable, Literal, Self, TypeVar | ||
|
||
from pydantic import BaseModel, ConfigDict, Field, model_validator | ||
from pydantic_zarr.v2 import ArraySpec, GroupSpec, auto_attributes | ||
|
||
from eopf_geozarr.data_api.geozarr.common import ( | ||
XARRAY_DIMS_KEY, | ||
BaseDataArrayAttrs, | ||
Multiscales, | ||
) | ||
|
||
|
||
class DataArrayAttrs(BaseDataArrayAttrs): | ||
""" | ||
Attributes for a GeoZarr DataArray. | ||
|
||
Attributes | ||
---------- | ||
array_dimensions : tuple[str, ...] | ||
Alias for the _ARRAY_DIMENSIONS attribute, which lists the dimension names for this array. | ||
""" | ||
|
||
# todo: validate that this names listed here are the names of zarr arrays | ||
# unless the variable is an auxiliary variable | ||
# see https://github.com/zarr-developers/geozarr-spec/blob/main/geozarr-spec.md#geozarr-coordinates | ||
array_dimensions: tuple[str, ...] = Field(alias="_ARRAY_DIMENSIONS") | ||
|
||
model_config = ConfigDict(serialize_by_alias=True) | ||
|
||
|
||
class DataArray(ArraySpec[DataArrayAttrs]): | ||
""" | ||
A GeoZarr DataArray variable. It must have attributes that contain an `"_ARRAY_DIMENSIONS"` | ||
key, with a length that matches the dimensionality of the array. | ||
|
||
References | ||
---------- | ||
https://github.com/zarr-developers/geozarr-spec/blob/main/geozarr-spec.md#geozarr-dataarray | ||
""" | ||
|
||
@classmethod | ||
def from_array( | ||
cls, | ||
array: Any, | ||
chunks: tuple[int, ...] | Literal["auto"] = "auto", | ||
attributes: Mapping[str, object] | Literal["auto"] = "auto", | ||
fill_value: object | Literal["auto"] = "auto", | ||
order: Literal["C", "F"] | Literal["auto"] = "auto", | ||
filters: tuple[Any, ...] | Literal["auto"] = "auto", | ||
dimension_separator: Literal[".", "/"] | Literal["auto"] = "auto", | ||
compressor: Any | Literal["auto"] = "auto", | ||
dimension_names: Iterable[str] | Literal["auto"] = "auto", | ||
) -> Self: | ||
if attributes == "auto": | ||
auto_attrs = dict(auto_attributes(array)) | ||
else: | ||
auto_attrs = dict(attributes) | ||
if dimension_names != "auto": | ||
auto_attrs = auto_attrs | {XARRAY_DIMS_KEY: tuple(dimension_names)} | ||
model = super().from_array( | ||
array=array, | ||
chunks=chunks, | ||
attributes=auto_attrs, | ||
fill_value=fill_value, | ||
order=order, | ||
filters=filters, | ||
dimension_separator=dimension_separator, | ||
compressor=compressor, | ||
) | ||
return model # type: ignore[no-any-return] | ||
|
||
@model_validator(mode="after") | ||
def check_array_dimensions(self) -> Self: | ||
if (len_dim := len(self.attributes.array_dimensions)) != ( | ||
ndim := len(self.shape) | ||
): | ||
msg = ( | ||
f"The {XARRAY_DIMS_KEY} attribute has length {len_dim}, which does not " | ||
f"match the number of dimensions for this array (got {ndim})." | ||
) | ||
raise ValueError(msg) | ||
return self | ||
|
||
@property | ||
def array_dimensions(self) -> tuple[str, ...]: | ||
return self.attributes.array_dimensions # type: ignore[no-any-return] | ||
|
||
|
||
T = TypeVar("T", bound=GroupSpec[Any, Any]) | ||
|
||
|
||
def check_valid_coordinates(model: T) -> T: | ||
""" | ||
Check if the coordinates of the DataArrays listed in a GeoZarr DataSet are valid. | ||
|
||
For each DataArray in the model, we check the dimensions associated with the DataArray. | ||
For each dimension associated with a data variable, an array with the name of that data variable | ||
must be present in the members of the group. | ||
|
||
Parameters | ||
---------- | ||
model : GroupSpec[Any, Any] | ||
The GeoZarr DataArray model to check. | ||
|
||
Returns | ||
------- | ||
GroupSpec[Any, Any] | ||
The validated GeoZarr DataArray model. | ||
""" | ||
if model.members is None: | ||
raise ValueError("Model members cannot be None") | ||
|
||
arrays: dict[str, DataArray] = { | ||
k: v for k, v in model.members.items() if isinstance(v, DataArray) | ||
} | ||
for key, array in arrays.items(): | ||
for idx, dim in enumerate(array.array_dimensions): | ||
if dim not in model.members: | ||
raise ValueError( | ||
f"Dimension '{dim}' for array '{key}' is not defined in the model members." | ||
) | ||
member = model.members[dim] | ||
if isinstance(member, GroupSpec): | ||
raise ValueError( | ||
f"Dimension '{dim}' for array '{key}' should be a group. Found an array instead." | ||
) | ||
if member.shape[0] != array.shape[idx]: | ||
raise ValueError( | ||
f"Dimension '{dim}' for array '{key}' has a shape mismatch: " | ||
f"{member.shape[0]} != {array.shape[idx]}." | ||
) | ||
return model | ||
|
||
|
||
class DatasetAttrs(BaseModel): | ||
""" | ||
Attributes for a GeoZarr dataset. | ||
|
||
Attributes | ||
---------- | ||
multiscales: MultiscaleAttrs | ||
""" | ||
|
||
multiscales: Multiscales | ||
|
||
|
||
class Dataset(GroupSpec[DatasetAttrs, GroupSpec[Any, Any] | DataArray]): | ||
""" | ||
A GeoZarr Dataset. | ||
""" | ||
|
||
@model_validator(mode="after") | ||
def check_valid_coordinates(self) -> Self: | ||
""" | ||
Validate the coordinates of the GeoZarr DataSet. | ||
|
||
This method checks that all DataArrays in the dataset have valid coordinates | ||
according to the GeoZarr specification. | ||
|
||
Returns | ||
------- | ||
GroupSpec[Any, Any] | ||
The validated GeoZarr DataSet. | ||
""" | ||
return check_valid_coordinates(self) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please bear with my limited knowledge of pydantic but how is made the link with the actual
standard_name
field name?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pydantic does most of its validation routines based on type annotations. When we annotate an attribute on a pydantic model with this type: https://github.com/d-v-b/data-model/blob/3d11af412e460993f8e603dcff0555c5342c4e8f/src/eopf_geozarr/data_api/geozarr/common.py#L70, then pydantic will run the
check_standard_name
function after checking that the input is a string.