Skip to content

Commit 5346640

Browse files
committed
added some comments
1 parent 23e8a5f commit 5346640

File tree

3 files changed

+103
-11
lines changed

3 files changed

+103
-11
lines changed

src/anemoi/datasets/data/dataset.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -410,6 +410,9 @@ def _select_to_columns(self, vars: str | list[str] | tuple[str] | set) -> list[i
410410
if not isinstance(vars, (list, tuple)):
411411
vars = [vars]
412412

413+
for v in vars:
414+
if v not in self.name_to_index:
415+
raise ValueError(f"select: unknown variable: {v}, available: {list(self.name_to_index)}")
413416
return [self.name_to_index[v] for v in vars]
414417

415418
def _drop_to_columns(self, vars: str | Sequence[str]) -> list[int]:

src/anemoi/datasets/data/records/__init__.py

Lines changed: 84 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,20 @@ def _to_numpy_dates(d):
7878

7979

8080
class BaseRecordsDataset:
81+
"""This is the base class for all datasets based on records.
82+
Records datasets are datasets that can be indexed by time (int) or by group (str).
83+
A record dataset is designed for observations, where multiple arrays of different shapes need to be stored for each date.
84+
They have the same concepts of start_date, end_date and frequency as fields datasets, but each date corresponds to a window.
85+
All windows have the same size (the window span can be different from the dataset frequency)
8186
82-
def __getitem__(self, i):
87+
Variables in a record dataset are identified by a group and a name.
88+
"""
89+
90+
# Depending on the context, a variable is identified by "group.name",
91+
# or using a dict with keys as groups and values as list of names.
92+
# most of the code should be agnostic and transform one format to the other when needed.
93+
94+
def __getitem__(self, i: int | str):
8395
if isinstance(i, str):
8496
return self._getgroup(i)
8597

@@ -90,15 +102,31 @@ def __getitem__(self, i):
90102

91103
@cached_property
92104
def window(self):
105+
"""Returns a string representation of the relative window of the dataset, such as '(-3h, 3h]'."""
93106
return str(self._window)
94107

95-
def _getgroup(self, i):
96-
return Tabular(self, i)
108+
def _getgroup(self, group: str):
109+
"""Returns a Tabular object for the group. As a partial function when argument group is given but i is not."""
110+
return Tabular(self, group)
97111

98-
def _getrecord(self, i):
112+
def _getrecord(self, i: int):
113+
"""Returns a Record object for the time step i. As a partial function when argument i is given but group is not."""
99114
return Record(self, i)
100115

101-
def _load_data(self, i):
116+
def _load_data(self, i: int) -> dict:
117+
"""
118+
Load the data for a specific time step or window (i).
119+
It is expected to return a dict containing keys of the form:
120+
121+
- "data:group1" : numpy array
122+
- "latitudes:group1" : numpy array
123+
- "longitudes:group1" : numpy array
124+
- "metadata:group1" :
125+
- ...
126+
- "data:group2" : numpy array
127+
- "latitudes:group2" : numpy array
128+
- ...
129+
"""
102130
raise NotImplementedError("Must be implemented in subclass")
103131

104132
@property
@@ -221,6 +249,13 @@ class FieldsRecords(RecordsForward):
221249
"""A wrapper around a FieldsDataset to provide a consistent interface for records datasets."""
222250

223251
def __init__(self, fields_dataset, name):
252+
"""wrapper around a fields dataset to provide a consistent interface for records datasets.
253+
A FieldsRecords appears as a RecordsDataset with a single group.
254+
This allows merging fields datasets with other records datasets.
255+
Parameters:
256+
fields_dataset: must be a regular fields dataset
257+
name: the name of the group
258+
"""
224259
self.forward = fields_dataset
225260
from anemoi.datasets.data.dataset import Dataset
226261

@@ -293,7 +328,9 @@ def __len__(self):
293328
return len(self.forward.dates)
294329

295330

296-
class GenericRename(RecordsForward):
331+
class BaseRename(RecordsForward):
332+
"""Renames variables in a records dataset."""
333+
297334
def __init__(self, dataset, rename):
298335
self.forward = dataset
299336
assert isinstance(rename, dict)
@@ -320,16 +357,16 @@ def groups(self):
320357
return [self.rename.get(k, k) for k in self.forward.groups]
321358

322359

323-
class Rename(GenericRename):
360+
class Rename(BaseRename):
324361
pass
325362

326363

327-
class SetGroup(GenericRename):
364+
class SetGroup(BaseRename):
328365
def __init__(self, dataset, set_group):
329366
if len(dataset.groups) != 1:
330367
raise ValueError(f"{self.__class__.__name__} can only be used with datasets containing a single group.")
331368

332-
super.__init__(dataset, {dataset.groups[0]: set_group})
369+
super().__init__(dataset, {dataset.groups[0]: set_group})
333370

334371
def _load_data(self, i):
335372
return self.dataset._load_data(i)
@@ -411,6 +448,7 @@ def _to_timedelta(t):
411448

412449

413450
class AbsoluteWindow:
451+
# not used but expected to be useful when building datasets. And used in tests
414452
def __init__(self, start, end, include_start=True, include_end=True):
415453
assert isinstance(start, datetime.datetime), f"start must be a datetime.datetime, got {type(start)}"
416454
assert isinstance(end, datetime.datetime), f"end must be a datetime.datetime, got {type(end)}"
@@ -428,6 +466,14 @@ def __repr__(self):
428466

429467

430468
class WindowsSpec:
469+
# A window specified by relative timedeltas, such as (-6h, 0h]
470+
#
471+
# the term "WindowSpec" is used here to avoid confusion between
472+
# - a relative window, such as (-6h, 0h] which this class represents (WindowsSpec)
473+
# - an actual time interval, such as [2023-01-01 00:00, 2023-01-01 06:00] which is an (AbsoluteWindow)
474+
#
475+
# but if it is more confusing, it should be renamed to Window.
476+
431477
def __init__(self, *, start, end, include_start=False, include_end=True):
432478
assert isinstance(start, (str, datetime.timedelta)), f"start must be a str or timedelta, got {type(start)}"
433479
assert isinstance(end, (str, datetime.timedelta)), f"end must be a str or timedelta, got {type(end)}"
@@ -447,6 +493,7 @@ def __init__(self, *, start, end, include_start=False, include_end=True):
447493

448494
def to_absolute_window(self, date):
449495
"""Convert the window to an absolute window based on a date."""
496+
# not used but expected to be useful when building datasets. And used in tests
450497
assert isinstance(date, datetime.datetime), f"date must be a datetime.datetime, got {type(date)}"
451498
start = date + self.start
452499
end = date + self.end
@@ -466,6 +513,8 @@ def _frequency_to_string(t):
466513
return f"{first}{_frequency_to_string(self.start)},{_frequency_to_string(self.end)}{last}"
467514

468515
def compute_mask(self, timedeltas):
516+
"""Returns a boolean numpy array of the same shape as timedeltas."""
517+
469518
assert timedeltas.dtype == "timedelta64[s]", f"expecting np.timedelta64[s], got {timedeltas.dtype}"
470519
if self.include_start:
471520
lower_mask = timedeltas >= self._start_np
@@ -480,6 +529,9 @@ def compute_mask(self, timedeltas):
480529
return lower_mask & upper_mask
481530

482531
def starts_before(self, my_dates, other_dates, other_window):
532+
# apply this window to my_dates[0] and the other_window to other_dates[0]
533+
# return True if this window starts before the other window
534+
483535
assert my_dates.dtype == "datetime64[s]", f"expecting np.datetime64[s], got {my_dates.dtype}"
484536
assert other_dates.dtype == "datetime64[s]", f"expecting np.datetime64[s], got {other_dates.dtype}"
485537
assert isinstance(other_window, WindowsSpec), f"other_window must be a WindowsSpec, got {type(other_window)}"
@@ -492,6 +544,7 @@ def starts_before(self, my_dates, other_dates, other_window):
492544
return my_start <= other_start
493545

494546
def ends_after(self, my_dates, other_dates, other_window):
547+
# same as starts_before
495548
assert my_dates.dtype == "datetime64[s]", f"expecting np.datetime64[s], got {my_dates.dtype}"
496549
assert other_dates.dtype == "datetime64[s]", f"expecting np.datetime64[s], got {other_dates.dtype}"
497550
assert isinstance(other_window, WindowsSpec), f"other_window must be a WindowsSpec, got {type(other_window)}"
@@ -507,13 +560,15 @@ def ends_after(self, my_dates, other_dates, other_window):
507560

508561

509562
class Rewindowed(RecordsForward):
563+
# change the window of a records dataset
564+
# similar to changing the frequency of a dataset
565+
510566
def __init__(self, dataset, window):
511567
super().__init__(dataset)
512568
self.dataset = dataset
513569

514570
# in this class anything with 1 refers to the original window/dataset
515571
# and anything with 2 refers to the new window/dataset
516-
# and we use _Δ for timedeltas
517572

518573
self._window1 = self.forward._window
519574
self._window2 = window_from_str(window)
@@ -602,6 +657,13 @@ def _load_data(self, i):
602657

603658

604659
class Select(RecordsForward):
660+
# Select a subset of variables from a records dataset
661+
# select can be a list of strings with dots (or a dict with keys as groups and values as list of strings)
662+
#
663+
# the selection is a filter, not a reordering, which is different from fields datasets and should be documented/fixed
664+
#
665+
# Drop should be implemented
666+
605667
def __init__(self, dataset, select):
606668
super().__init__(dataset)
607669

@@ -693,6 +755,8 @@ def statistics(self):
693755

694756

695757
class RecordsSubset(RecordsForward):
758+
"""Subset of a records dataset based on a list of integer indices."""
759+
696760
def __init__(self, dataset, indices, reason):
697761
super().__init__(dataset)
698762
self.dataset = dataset
@@ -711,6 +775,7 @@ def __len__(self):
711775

712776

713777
class RecordsDataset(BaseRecordsDataset):
778+
"""This is the base class for all datasets based on records stored on disk."""
714779

715780
def __init__(self, path, backend=None, **kwargs):
716781
if kwargs:
@@ -806,7 +871,13 @@ def tree(self):
806871

807872

808873
class Record:
809-
def __init__(self, dataset, n):
874+
"""A record corresponds to a single time step in a record dataset."""
875+
876+
def __init__(self, dataset: RecordsDataset, n: int):
877+
"""A record corresponds to a single time step in a record dataset.
878+
n : int, the index of the time step in the dataset.
879+
dataset : RecordsDataset, the dataset this record belongs to.
880+
"""
810881
self.dataset = dataset
811882
self.n = n
812883

@@ -867,6 +938,8 @@ def as_dict(self):
867938

868939

869940
class Tabular:
941+
"""A RecordsDataset for a single group, similar to a fields dataset, but allowing different shapes for each date."""
942+
870943
def __init__(self, dataset, name):
871944
self.dataset = dataset
872945
self.name = name

src/anemoi/datasets/data/records/backends/__init__.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,19 @@ def __init__(self, path, **kwargs):
2727
self.kwargs = kwargs
2828

2929
def read(self, i, **kwargs):
30+
""" Read the i-th record and return a dictionary of numpy arrays."""
3031
raise NotImplementedError("Must be implemented in subclass")
3132

3233
def read_metadata(self):
34+
""" Read the metadata of a record dataset. The metadata does not depend on the record index."""
3335
raise NotImplementedError("Must be implemented in subclass")
3436

3537
def read_statistics(self):
38+
"""Read the statistics of a record dataset. The statistics do not depend on the record index."""
3639
raise NotImplementedError("Must be implemented in subclass")
3740

3841
def _check_data(self, data):
42+
# Check that all keys are normalised
3943
for k in list(data.keys()):
4044
k = k.split(":")[-1]
4145
if k != normalise_key(k):
@@ -139,16 +143,22 @@ def backend_factory(name, *args, **kwargs):
139143

140144

141145
class WriteBackend(Backend):
146+
# Write backend base class, not used for reading
147+
# provides implementation to write data
142148
def __init__(self, *, target, **kwargs):
143149
super().__init__(target, **kwargs)
144150

145151
def write(self, i, data, **kwargs):
152+
# expects data to be a dict of numpy arrays
146153
raise NotImplementedError("Must be implemented in subclass")
147154

148155
def write_metadata(self, metadata):
156+
# expects metadata to be a dict
149157
raise NotImplementedError("Must be implemented in subclass")
150158

151159
def write_statistics(self, statistics):
160+
# expects statistics to be a dict of dicts with the right keys:
161+
# {group: {mean:..., std:..., min:..., max:...}}
152162
raise NotImplementedError("Must be implemented in subclass")
153163

154164
def _check_data(self, data):
@@ -158,6 +168,8 @@ def _check_data(self, data):
158168
raise ValueError(f"{k} must be alphanumerical and '_' only.")
159169

160170
def _dataframes_to_record(self, i, data, variables, **kwargs):
171+
# Convert data from pandas DataFrames to a record format
172+
# will be used for writing, building obs datasets
161173

162174
assert isinstance(data, (dict)), type(data)
163175
if not data:
@@ -174,6 +186,8 @@ def _dataframes_to_record(self, i, data, variables, **kwargs):
174186
return data
175187

176188
def _dataframe_to_dict(self, name, df, **kwargs):
189+
# will be used for writing, building obs datasets
190+
177191
d = {}
178192
d["timedeltas:" + name] = df["timedeltas"]
179193
d["latitudes:" + name] = df["latitudes"]
@@ -304,6 +318,8 @@ def write_statistics(self, statistics):
304318

305319

306320
def writer_backend_factory(name, **kwargs):
321+
# choose the right backend for writing
322+
# this is intended to make benchmarking easier
307323
WRITE_BACKENDS = dict(
308324
npz1=Npz1WriteBackend,
309325
npz2=Npz2WriteBackend,

0 commit comments

Comments
 (0)