Skip to content

Commit a70af48

Browse files
committed
feat(multiscales): add encoding utilities, helpers and filesystem abstractions
1 parent 779e812 commit a70af48

File tree

3 files changed

+273
-34
lines changed

3 files changed

+273
-34
lines changed
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import os
2+
from typing import Any, Dict, Tuple
3+
4+
import xarray as xr
5+
6+
from . import utils
7+
from .helpers import iter_str
8+
9+
10+
def create_geozarr_encoding(
    ds: xr.Dataset, compressor: Any, spatial_chunk: int
) -> Dict[str, Dict[str, Any]]:
    """Create encoding for GeoZarr dataset variables (spec-aware).

    Parameters
    ----------
    ds : xr.Dataset
        Dataset whose data variables and coordinates will be encoded.
    compressor : Any
        Zarr compressor applied to data variables; ``None`` disables
        compression.
    spatial_chunk : int
        Requested chunk size for the two trailing (spatial) dimensions.

    Returns
    -------
    Dict[str, Dict[str, Any]]
        Per-variable encoding dicts (``"chunks"`` / ``"compressors"``)
        suitable for passing to ``Dataset.to_zarr``.
    """
    # Hoisted out of the per-variable loop: the import is loop-invariant.
    from math import prod

    encoding: Dict[str, Dict[str, Any]] = {}

    # Optional safety cap for chunk bytes to avoid OOM. Default ~8 MiB unless overridden.
    default_cap = 8 * 1024 * 1024
    try:
        max_chunk_bytes = int(os.environ.get("EOPF_MAX_CHUNK_BYTES", str(default_cap)))
    except (TypeError, ValueError):
        # Malformed env value: fall back to the default cap instead of crashing.
        max_chunk_bytes = default_cap

    for var in iter_str(ds.data_vars):
        if utils.is_grid_mapping_variable(ds, var):
            # Grid-mapping variables are tiny metadata carriers: no compression.
            encoding[var] = {"compressors": None}
            continue

        data_shape = ds[var].shape
        dtype_size = getattr(ds[var].dtype, "itemsize", 1) or 1

        # Align the spatial chunk to the raster extent when the variable is 2-D+.
        if len(data_shape) >= 2:
            height, width = data_shape[-2:]
            spatial_chunk_aligned = min(
                spatial_chunk,
                utils.calculate_aligned_chunk_size(width, spatial_chunk),
                utils.calculate_aligned_chunk_size(height, spatial_chunk),
            )
        else:
            spatial_chunk_aligned = spatial_chunk

        # Build chunk tuple matching variable dimensionality.
        # Use 1 for all leading (non-spatial) dims, and spatial_chunk_aligned for the last two.
        if len(data_shape) == 1:
            chunks: Tuple[int, ...] = (min(spatial_chunk_aligned, data_shape[0]),)
        elif len(data_shape) == 2:
            chunks = (spatial_chunk_aligned, spatial_chunk_aligned)
        else:
            leading: Tuple[int, ...] = tuple(1 for _ in range(len(data_shape) - 2))
            chunks = leading + (spatial_chunk_aligned, spatial_chunk_aligned)

        # Enforce max_chunk_bytes by reducing the spatial chunk if needed.
        # Estimated bytes per chunk = product(chunks) * dtype_size.
        est_bytes = prod(chunks) * dtype_size
        if max_chunk_bytes and est_bytes > max_chunk_bytes:
            # Reduce spatial chunks proportionally (keep leading dims as-is).
            lead = chunks[:-2] if len(chunks) > 2 else tuple()
            yc, xc = (chunks[-2], chunks[-1]) if len(chunks) >= 2 else (chunks[-1], 1)
            # sqrt of the overshoot so the 2-D chunk area shrinks by ~the overshoot.
            factor = (est_bytes / max_chunk_bytes) ** 0.5
            new_y = max(1, int(yc / factor))
            new_x = max(1, int(xc / factor))
            chunks = lead + (new_y, new_x) if len(chunks) >= 2 else (new_y,)

        encoding[var] = {
            "chunks": chunks,
            "compressors": ([compressor] if compressor is not None else None),
        }

    # Coordinates are small; store them uncompressed.
    for coord in iter_str(ds.coords):
        encoding[coord] = {"compressors": None}

    return encoding

src/eopf_geozarr/conversion/fs_utils.py

Lines changed: 81 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
1-
"""S3 utilities for GeoZarr conversion."""
1+
"""S3 utilities for GeoZarr conversion.
2+
3+
Note: Optional dependencies may lack type stubs; suppress their missing-import
4+
noise locally to keep global mypy strictness intact.
5+
"""
6+
7+
# mypy: disable-error-code=import-not-found
28

39
import json
410
import os
5-
from typing import Any, Dict, Optional
11+
from typing import Any, Dict, Literal, Optional
612
from urllib.parse import urlparse
713

814
import s3fs
@@ -108,9 +114,7 @@ def get_s3_storage_options(s3_path: str, **s3_kwargs: Any) -> Dict[str, Any]:
108114
default_s3_kwargs = {
109115
"anon": False, # Use credentials
110116
"use_ssl": True,
111-
"client_kwargs": {
112-
"region_name": os.environ.get("AWS_DEFAULT_REGION", "us-east-1")
113-
},
117+
"client_kwargs": {"region_name": os.environ.get("AWS_DEFAULT_REGION", "us-east-1")},
114118
}
115119

116120
# Add custom endpoint support (e.g., for OVH Cloud)
@@ -147,6 +151,36 @@ def get_storage_options(path: str, **kwargs: Any) -> Optional[Dict[str, Any]]:
147151
"""
148152
if is_s3_path(path):
149153
return get_s3_storage_options(path, **kwargs)
154+
# For HTTP(S) paths, ensure servers don't apply content-encoding (e.g., gzip)
155+
# to chunk responses which would corrupt codec bytes (e.g., Blosc) and
156+
# trigger decompression errors. Force identity encoding and set a sane
157+
# default block size for ranged requests.
158+
if path.startswith(("http://", "https://")):
159+
headers = {"Accept-Encoding": "identity"}
160+
# Merge user headers if provided
161+
user_headers = kwargs.get("headers")
162+
if isinstance(user_headers, dict):
163+
headers.update(user_headers)
164+
http_opts: Dict[str, Any] = {
165+
"headers": headers,
166+
"block_size": kwargs.get("block_size", 0),
167+
"simple_links": kwargs.get("simple_links", True),
168+
}
169+
# Add conservative aiohttp client settings to mitigate disconnects
170+
try:
171+
import aiohttp
172+
173+
timeout = kwargs.get("timeout") or aiohttp.ClientTimeout(total=120)
174+
connector = kwargs.get("connector") or aiohttp.TCPConnector(limit=8)
175+
client_kwargs = kwargs.get("client_kwargs", {}) or {}
176+
if not isinstance(client_kwargs, dict):
177+
client_kwargs = {}
178+
client_kwargs.setdefault("timeout", timeout)
179+
client_kwargs.setdefault("connector", connector)
180+
http_opts["client_kwargs"] = client_kwargs
181+
except Exception:
182+
pass
183+
return http_opts
150184
# For local paths, return None (no storage options needed)
151185
# Future protocols (gcs://, azure://, etc.) can be added here
152186
return None
@@ -201,9 +235,7 @@ def create_s3_store(s3_path: str, **s3_kwargs: Any) -> str:
201235
return s3_path
202236

203237

204-
def write_s3_json_metadata(
205-
s3_path: str, metadata: Dict[str, Any], **s3_kwargs: Any
206-
) -> None:
238+
def write_s3_json_metadata(s3_path: str, metadata: Dict[str, Any], **s3_kwargs: Any) -> None:
207239
"""
208240
Write JSON metadata directly to S3.
209241
@@ -224,9 +256,7 @@ def write_s3_json_metadata(
224256
"anon": False,
225257
"use_ssl": True,
226258
"asynchronous": False, # Force synchronous mode
227-
"client_kwargs": {
228-
"region_name": os.environ.get("AWS_DEFAULT_REGION", "us-east-1")
229-
},
259+
"client_kwargs": {"region_name": os.environ.get("AWS_DEFAULT_REGION", "us-east-1")},
230260
}
231261

232262
# Add custom endpoint support (e.g., for OVH Cloud)
@@ -266,9 +296,7 @@ def read_s3_json_metadata(s3_path: str, **s3_kwargs: Any) -> Dict[str, Any]:
266296
"anon": False,
267297
"use_ssl": True,
268298
"asynchronous": False, # Force synchronous mode
269-
"client_kwargs": {
270-
"region_name": os.environ.get("AWS_DEFAULT_REGION", "us-east-1")
271-
},
299+
"client_kwargs": {"region_name": os.environ.get("AWS_DEFAULT_REGION", "us-east-1")},
272300
}
273301

274302
# Add custom endpoint support (e.g., for OVH Cloud)
@@ -308,9 +336,7 @@ def s3_path_exists(s3_path: str, **s3_kwargs: Any) -> bool:
308336
"anon": False,
309337
"use_ssl": True,
310338
"asynchronous": False, # Force synchronous mode
311-
"client_kwargs": {
312-
"region_name": os.environ.get("AWS_DEFAULT_REGION", "us-east-1")
313-
},
339+
"client_kwargs": {"region_name": os.environ.get("AWS_DEFAULT_REGION", "us-east-1")},
314340
}
315341

316342
# Add custom endpoint support (e.g., for OVH Cloud)
@@ -327,7 +353,9 @@ def s3_path_exists(s3_path: str, **s3_kwargs: Any) -> bool:
327353
return result
328354

329355

330-
def open_s3_zarr_group(s3_path: str, mode: str = "r", **s3_kwargs: Any) -> zarr.Group:
356+
def open_s3_zarr_group(
357+
s3_path: str, mode: Literal["r", "r+", "w", "a", "w-"] = "r", **s3_kwargs: Any
358+
) -> zarr.Group:
331359
"""
332360
Open a Zarr group from S3 using storage_options.
333361
@@ -346,9 +374,7 @@ def open_s3_zarr_group(s3_path: str, mode: str = "r", **s3_kwargs: Any) -> zarr.
346374
Zarr group
347375
"""
348376
storage_options = get_s3_storage_options(s3_path, **s3_kwargs)
349-
return zarr.open_group(
350-
s3_path, mode=mode, zarr_format=3, storage_options=storage_options
351-
)
377+
return zarr.open_group(s3_path, mode=mode, zarr_format=3, storage_options=storage_options)
352378

353379

354380
def get_s3_credentials_info() -> Dict[str, Optional[str]]:
@@ -362,9 +388,7 @@ def get_s3_credentials_info() -> Dict[str, Optional[str]]:
362388
"""
363389
return {
364390
"aws_access_key_id": os.environ.get("AWS_ACCESS_KEY_ID"),
365-
"aws_secret_access_key": "***"
366-
if os.environ.get("AWS_SECRET_ACCESS_KEY")
367-
else None,
391+
"aws_secret_access_key": "***" if os.environ.get("AWS_SECRET_ACCESS_KEY") else None,
368392
"aws_session_token": "***" if os.environ.get("AWS_SESSION_TOKEN") else None,
369393
"aws_default_region": os.environ.get("AWS_DEFAULT_REGION", "us-east-1"),
370394
"aws_profile": os.environ.get("AWS_PROFILE"),
@@ -395,9 +419,7 @@ def validate_s3_access(s3_path: str, **s3_kwargs: Any) -> tuple[bool, Optional[s
395419
"anon": False,
396420
"use_ssl": True,
397421
"asynchronous": False, # Force synchronous mode
398-
"client_kwargs": {
399-
"region_name": os.environ.get("AWS_DEFAULT_REGION", "us-east-1")
400-
},
422+
"client_kwargs": {"region_name": os.environ.get("AWS_DEFAULT_REGION", "us-east-1")},
401423
}
402424

403425
# Add custom endpoint support (e.g., for OVH Cloud)
@@ -441,9 +463,34 @@ def get_filesystem(path: str, **kwargs: Any) -> Any:
441463
# Get S3 storage options and use them for fsspec
442464
storage_options = get_s3_storage_options(path, **kwargs)
443465
return fsspec.filesystem("s3", **storage_options)
444-
else:
445-
# For local paths, use the local filesystem
446-
return fsspec.filesystem("file")
466+
if path.startswith(("http://", "https://")):
467+
# Ensure identity encoding for raw chunk bytes over HTTP(S)
468+
headers = {"Accept-Encoding": "identity"}
469+
user_headers = kwargs.get("headers")
470+
if isinstance(user_headers, dict):
471+
headers.update(user_headers)
472+
http_opts: Dict[str, Any] = {
473+
"headers": headers,
474+
"block_size": kwargs.get("block_size", 0),
475+
"simple_links": kwargs.get("simple_links", True),
476+
}
477+
# Add conservative aiohttp client settings to mitigate disconnects
478+
try:
479+
import aiohttp
480+
481+
timeout = kwargs.get("timeout") or aiohttp.ClientTimeout(total=120)
482+
connector = kwargs.get("connector") or aiohttp.TCPConnector(limit=8)
483+
client_kwargs = kwargs.get("client_kwargs", {}) or {}
484+
if not isinstance(client_kwargs, dict):
485+
client_kwargs = {}
486+
client_kwargs.setdefault("timeout", timeout)
487+
client_kwargs.setdefault("connector", connector)
488+
http_opts["client_kwargs"] = client_kwargs
489+
except Exception:
490+
pass
491+
return fsspec.filesystem("http", **http_opts)
492+
# For local paths, use the local filesystem
493+
return fsspec.filesystem("file")
447494

448495

449496
def write_json_metadata(path: str, metadata: Dict[str, Any], **kwargs: Any) -> None:
@@ -519,7 +566,9 @@ def path_exists(path: str, **kwargs: Any) -> bool:
519566
return result
520567

521568

522-
def open_zarr_group(path: str, mode: str = "r", **kwargs: Any) -> zarr.Group:
569+
def open_zarr_group(
570+
path: str, mode: Literal["r", "r+", "w", "a", "w-"] = "r", **kwargs: Any
571+
) -> zarr.Group:
523572
"""
524573
Open a Zarr group from any path type using unified storage options.
525574
@@ -538,6 +587,4 @@ def open_zarr_group(path: str, mode: str = "r", **kwargs: Any) -> zarr.Group:
538587
Zarr group
539588
"""
540589
storage_options = get_storage_options(path, **kwargs)
541-
return zarr.open_group(
542-
path, mode=mode, zarr_format=3, storage_options=storage_options
543-
)
590+
return zarr.open_group(path, mode=mode, zarr_format=3, storage_options=storage_options)

0 commit comments

Comments
 (0)