Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 50 additions & 2 deletions src/anemoi/datasets/create/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,17 +362,19 @@ def set_to_test_mode(cfg: dict) -> None:
group_by=NUMBER_OF_DATES,
)

num_ensembles = count_ensembles(cfg)

def set_element_to_test(obj):
if isinstance(obj, (list, tuple)):
for v in obj:
set_element_to_test(v)
return
if isinstance(obj, (dict, DotDict)):
if "grid" in obj:
if "grid" in obj and num_ensembles > 1:
previous = obj["grid"]
obj["grid"] = "20./20."
LOG.warning(f"Running in test mode. Setting grid to {obj['grid']} instead of {previous}")
if "number" in obj:
if "number" in obj and num_ensembles > 1:
if isinstance(obj["number"], (list, tuple)):
previous = obj["number"]
obj["number"] = previous[0:3]
Expand Down Expand Up @@ -444,3 +446,49 @@ def build_output(*args, **kwargs) -> OutputSpecs:
The output specifications object.
"""
return OutputSpecs(*args, **kwargs)


def flatten_list_of_sets(list_of_sets: list[set]) -> set:
return {element for subset in list_of_sets for element in subset}


def mars_str_to_set(s: str) -> set[str]:
"""Mars strings are like 1/to/2 or 1/to/2/by/1

Returns a set of strings, e.g. {'1', '2'}
"""
assert "/" in s, "mars_str_to_set expects a string with '/'"
lst = s.split("/")
assert len(lst) in (3, 5), f"mars_str_to_set expects a string like 1/to/2 or 1/to/4/by/1, got {s}"
if len(lst) == 3:
assert "to" in lst
start, _, end = lst
step = 1
elif len(lst) == 5:
assert "by" in lst and "to" in lst
start, _, end, _, step = lst
return {str(i) for i in range(int(start), int(end) + 1, int(step))}


def get_ensembles_set(obj):
"""Counts the number of ensembles in the configuration."""
if isinstance(obj, dict):
if "number" in obj:
if isinstance(obj["number"], (list, tuple)):
return set([str(element) for element in obj["number"]])
if isinstance(obj["number"], (str, int)):
if "/" in str(obj["number"]):
return mars_str_to_set(obj["number"])
else:
return {str(obj["number"])}
if isinstance(obj, (dict)):
return flatten_list_of_sets([get_ensembles_set(v) for v in obj.values()])
if isinstance(obj, (list, tuple)):
return flatten_list_of_sets([get_ensembles_set(v) for v in obj])
return {}


def count_ensembles(config: Config) -> int:
"""Counts the number of ensembles in the configuration."""
ensembles = get_ensembles_set(config.input)
return len(ensembles) if ensembles else 1
Loading