-
Notifications
You must be signed in to change notification settings - Fork 45
feat: track origins of variable from dataset creation to inference #437
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
9cb8c52
01774dd
33793a6
920d523
ed5d190
2562baf
f048a4b
8ad5eb0
5db7a97
b33f3ac
71d8180
6d23027
6933574
9179dae
79a391b
203e09b
b4433bd
6f3fdb0
45365a1
d7cc82c
f381b00
da93ad6
3082edf
ef4a5c9
93410d5
1df0ef7
18df4eb
3341d4c
58dc8a2
3e180f9
255c22d
83936f7
5209f26
c7a0e5d
b78a098
d641ea7
3754eb2
39ebc13
5d32745
6c1f146
37de369
24f2c2a
db4d895
55f740d
014dbbc
8ad9396
9756618
a493a96
1cde9f8
cdb1a9a
e69eb10
92165b4
a044e14
cb9c576
99a5fb7
ce027f4
3bf7c35
96dfe3d
f68a11e
3d5f0ef
70272f6
38ced18
cb3847e
00477c9
b0508a9
7d494b9
dd62e77
21208ad
53f915c
b0348bd
28d6ffa
1a6a3e4
7b332b5
c40025a
81e355b
06850d8
e09ed7e
cd06c98
f9fd3a0
c9259c6
8df9cc5
4c06588
a917c49
39bedac
56ac3ca
9309cfe
437e4aa
79f1706
ef431f8
48c9d07
caf5b90
9a7a8b7
57901c5
9343e66
095d57d
5148b61
0cd3bf6
ae1bb2f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
.. _filters: | ||
|
||
######### | ||
Filters | ||
######### | ||
|
||
.. warning:: | ||
|
||
This is still a work-in-progress. Some of the filters may be renamed | ||
later. | ||
|
||
Filters are used to modify the data or metadata in a dataset. | ||
|
||
See :ref:`filters <anemoi-transform:filters>` for more information. |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,20 +8,15 @@ | |
# nor does it submit to any jurisdiction. | ||
|
||
import logging | ||
from abc import ABC | ||
from abc import abstractmethod | ||
|
||
from anemoi.datasets.dates import DatesProvider | ||
|
||
LOG = logging.getLogger(__name__) | ||
|
||
|
||
class Action: | ||
"""An "Action" represents a single operation described in the yaml configuration, e.g. a source, a filter, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I thought this docstring was actually quite useful when I first went through the code. Unless there's a good reason, I think we should keep it (even if we don't want to expose it in the docs) |
||
pipe, join, etc. | ||
|
||
See :ref:`operations` for more details. | ||
|
||
""" | ||
|
||
class Action(ABC): | ||
def __init__(self, config, *path): | ||
self.config = config | ||
self.path = path | ||
|
@@ -30,6 +25,13 @@ def __init__(self, config, *path): | |
"data_sources", | ||
), f"{self.__class__.__name__}: path must start with 'input' or 'data_sources': {path}" | ||
|
||
@abstractmethod | ||
def __call__(self, context, argument): | ||
pass | ||
|
||
def __repr__(self): | ||
return f"{self.__class__.__name__}({'.'.join(str(x) for x in self.path)}, {self.config})" | ||
|
||
|
||
class Concat(Action): | ||
"""The Concat construct is used to concat different actions that are responsible | ||
|
@@ -65,6 +67,7 @@ def __init__(self, config, *path): | |
|
||
for i, item in enumerate(config): | ||
|
||
assert "dates" in item, f"Value must contain the key 'dates' {item}" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Unrelated to tracking origin work? |
||
dates = item["dates"] | ||
filtering_dates = DatesProvider.from_config(**dates) | ||
action = action_factory({k: v for k, v in item.items() if k != "dates"}, *self.path, str(i)) | ||
|
@@ -186,7 +189,13 @@ def create_object(self, context, config): | |
return create_datasets_source(context, config) | ||
|
||
def call_object(self, context, source, argument): | ||
return source.execute(context.source_argument(argument)) | ||
result = source.execute(context.source_argument(argument)) | ||
return context.origin(result, self, argument) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Based on the parameter names, the method call looks correct - but I have no idea what |
||
|
||
def origin(self): | ||
from .origin import Source | ||
|
||
return Source(self.path[-1], self.config) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe it's better to have There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also, looking in the In other words, I think the link between |
||
|
||
|
||
class TransformSourceMixin: | ||
|
@@ -197,6 +206,15 @@ def create_object(self, context, config): | |
|
||
return create_transform_source(context, config) | ||
|
||
def combine_origins(self, current, previous): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Both the transform mixins have |
||
assert previous is None, f"Cannot combine origins, previous already exists: {previous}" | ||
return current | ||
|
||
def origin(self): | ||
from .origin import Source | ||
|
||
return Source(self.path[-1], self.config) | ||
|
||
|
||
class TransformFilterMixin: | ||
"""Mixin class for filters defined in anemoi-transform""" | ||
|
@@ -207,14 +225,16 @@ def create_object(self, context, config): | |
return create_transform_filter(context, config) | ||
|
||
def call_object(self, context, filter, argument): | ||
return filter.forward(context.filter_argument(argument)) | ||
result = filter.forward(context.filter_argument(argument)) | ||
return context.origin(result, self, argument) | ||
|
||
def origin(self): | ||
from .origin import Filter | ||
|
||
class FilterFunction(Function): | ||
"""Action to call a filter on the argument (e.g. rename, regrid, etc.).""" | ||
return Filter(self.path[-1], self.config) | ||
|
||
def __call__(self, context, argument): | ||
return self.call(context, argument, context.filter_argument) | ||
def combine_origins(self, current, previous): | ||
return {"_apply": current, **(previous or {})} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure I understand this one |
||
|
||
|
||
def _make_name(name, what): | ||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -10,6 +10,8 @@ | |||||
|
||||||
from typing import Any | ||||||
|
||||||
from anemoi.transform.fields import new_field_with_metadata | ||||||
from anemoi.transform.fields import new_fieldlist_from_list | ||||||
from earthkit.data.core.order import build_remapping | ||||||
|
||||||
from ..result.field import FieldResult | ||||||
|
@@ -52,3 +54,20 @@ def matching_dates(self, filtering_dates, group_of_dates: Any) -> Any: | |||||
from anemoi.datasets.dates.groups import GroupOfDates | ||||||
|
||||||
return GroupOfDates(sorted(set(group_of_dates) & set(filtering_dates)), group_of_dates.provider) | ||||||
|
||||||
def origin(self, data: Any, action: Any, action_arguments: Any) -> Any: | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If we're going to have type hints, can they be more specific? e.g. |
||||||
|
||||||
origin = action.origin() | ||||||
|
||||||
result = [] | ||||||
for fs in data: | ||||||
previous = fs.metadata("anemoi_origin", default=None) | ||||||
fall_through = fs.metadata("anemoi_fall_through", default=False) | ||||||
if fall_through: | ||||||
# The field has passed unchanged through a filter | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
result.append(fs) | ||||||
else: | ||||||
anemoi_origin = origin.combine(previous, action, action_arguments) | ||||||
result.append(new_field_with_metadata(fs, anemoi_origin=anemoi_origin)) | ||||||
|
||||||
return new_fieldlist_from_list(result) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,159 @@ | ||
# (C) Copyright 2025 Anemoi contributors. | ||
# | ||
# This software is licensed under the terms of the Apache Licence Version 2.0 | ||
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. | ||
# | ||
# In applying this licence, ECMWF does not waive the privileges and immunities | ||
# granted to it by virtue of its status as an intergovernmental organisation | ||
# nor does it submit to any jurisdiction. | ||
|
||
import logging | ||
from abc import ABC | ||
|
||
LOG = logging.getLogger(__name__) | ||
|
||
|
||
class Origin(ABC): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If we're creating an abstract base class, are there any abstract methods or properties we should be requiring? |
||
|
||
def __init__(self, when="dataset-create"): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is it anticipated that other packages will create |
||
self.when = when | ||
|
||
def __eq__(self, other): | ||
if not isinstance(other, Origin): | ||
return False | ||
return self is other | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If it's just doing object equality, isn't the type check above redundant? |
||
|
||
def __hash__(self): | ||
return id(self) | ||
|
||
|
||
def _un_dotdict(x): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Isn't this already in anemoi-utils? |
||
if isinstance(x, dict): | ||
return {k: _un_dotdict(v) for k, v in x.items()} | ||
|
||
if isinstance(x, (list, tuple, set)): | ||
return [_un_dotdict(a) for a in x] | ||
|
||
return x | ||
|
||
|
||
class Pipe(Origin): | ||
def __init__(self, s1, s2, when="dataset-create"): | ||
super().__init__(when) | ||
self.steps = [s1, s2] | ||
|
||
assert s1 is not None, (s1, s2) | ||
assert s2 is not None, (s1, s2) | ||
|
||
if isinstance(s1, Pipe): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What is |
||
assert not isinstance(s2, Pipe), (s1, s2) | ||
self.steps = s1.steps + [s2] | ||
|
||
def combine(self, previous, action, action_arguments): | ||
assert False, (self, previous) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It would be better to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (This is true in many places in the code) |
||
|
||
def as_dict(self): | ||
return { | ||
"type": "pipe", | ||
"steps": [s.as_dict() for s in self.steps], | ||
"when": self.when, | ||
} | ||
|
||
def __repr__(self): | ||
return " | ".join(repr(s) for s in self.steps) | ||
|
||
|
||
class Join(Origin): | ||
def __init__(self, origins, when="dataset-create"): | ||
assert isinstance(origins, (list, tuple, set)), origins | ||
super().__init__(when) | ||
self.steps = list(origins) | ||
|
||
assert all(o is not None for o in origins), origins | ||
|
||
def combine(self, previous, action, action_arguments): | ||
assert False, (self, previous) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As above – better to raise a more specific exception |
||
|
||
def as_dict(self): | ||
return { | ||
"type": "join", | ||
"steps": [s.as_dict() for s in self.steps], | ||
"when": self.when, | ||
} | ||
|
||
def __repr__(self): | ||
return " & ".join(repr(s) for s in self.steps) | ||
|
||
|
||
class Source(Origin): | ||
def __init__(self, name, config, when="dataset-create"): | ||
super().__init__(when) | ||
assert isinstance(config, dict), f"Config must be a dictionary {config}" | ||
self.name = name | ||
self.config = _un_dotdict(config) | ||
|
||
def combine(self, previous, action, action_arguments): | ||
assert previous is None, f"Cannot combine origins, previous already exists: {previous}" | ||
return self | ||
|
||
def as_dict(self): | ||
return { | ||
"type": "source", | ||
"name": self.name, | ||
"config": self.config, | ||
"when": self.when, | ||
} | ||
|
||
def __repr__(self): | ||
return f"{self.name}({id(self)})" | ||
|
||
|
||
class Filter(Origin): | ||
def __init__(self, name, config, when="dataset-create"): | ||
super().__init__(when) | ||
assert isinstance(config, dict), f"Config must be a dictionary {config}" | ||
self.name = name | ||
self.config = _un_dotdict(config) | ||
self._cache = {} | ||
|
||
def combine(self, previous, action, action_arguments): | ||
|
||
if previous is None: | ||
# This can happen if the filter does not tag its output with an origin | ||
# (e.g. a user plugin). In that case we try to get the origin from the action arguments | ||
key = (id(action), id(action_arguments)) | ||
if key not in self._cache: | ||
|
||
LOG.warning(f"No previous origin to combine with: {self}. Action: {action}") | ||
LOG.warning(f"Connecting to action arguments {action_arguments}") | ||
origins = set() | ||
for k in action_arguments: | ||
o = k.metadata("anemoi_origin", default=None) | ||
if o is None: | ||
raise ValueError( | ||
f"Cannot combine origins, previous is None and action_arguments {action_arguments} has no origin" | ||
) | ||
origins.add(o) | ||
if len(origins) == 1: | ||
self._cache[key] = origins.pop() | ||
else: | ||
self._cache[key] = Join(origins) | ||
previous = self._cache[key] | ||
|
||
if previous in self._cache: | ||
# We use a cache to avoid recomputing the same combination | ||
return self._cache[previous] | ||
|
||
self._cache[previous] = Pipe(previous, self) | ||
return self._cache[previous] | ||
|
||
def as_dict(self): | ||
return { | ||
"type": "filter", | ||
"name": self.name, | ||
"config": self.config, | ||
"when": self.when, | ||
} | ||
|
||
def __repr__(self): | ||
return f"{self.name}({id(self)})" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is this being removed?