From 3bf87ae1ee2d1b4716ac0c306ea4969b4a8490b6 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Fri, 25 Jul 2025 17:29:03 +0100 Subject: [PATCH 1/6] split config and settings --- pyproject.toml | 1 + src/anemoi/utils/cli.py | 4 + src/anemoi/utils/commands/config.py | 8 +- src/anemoi/utils/config.py | 344 +++++----------------------- src/anemoi/utils/settings.py | 312 +++++++++++++++++++++++++ 5 files changed, 383 insertions(+), 286 deletions(-) create mode 100644 src/anemoi/utils/settings.py diff --git a/pyproject.toml b/pyproject.toml index 39a96716..e4bfa74f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,7 @@ dependencies = [ "importlib-metadata; python_version<'3.10'", "multiurl", "numpy", + "omegaconf", "pydantic>=2.9", "python-dateutil", "pyyaml", diff --git a/src/anemoi/utils/cli.py b/src/anemoi/utils/cli.py index 6711d919..4ef87580 100644 --- a/src/anemoi/utils/cli.py +++ b/src/anemoi/utils/cli.py @@ -14,6 +14,7 @@ import os import sys import traceback +import warnings from typing import Callable from typing import Optional @@ -202,6 +203,9 @@ def cli_main( test_arguments : list[str], optional The command line arguments to parse, used for testing purposes, by default None """ + + warnings.filterwarnings("default", category=DeprecationWarning) + parser = make_parser(description, commands) args, unknown = parser.parse_known_args(test_arguments) if argcomplete: diff --git a/src/anemoi/utils/commands/config.py b/src/anemoi/utils/commands/config.py index c0c80abd..ef3f1906 100644 --- a/src/anemoi/utils/commands/config.py +++ b/src/anemoi/utils/commands/config.py @@ -12,8 +12,8 @@ from argparse import ArgumentParser from argparse import Namespace -from ..config import config_path -from ..config import load_config +from ..settings import load_settings +from ..settings import settings_path from . import Command @@ -39,9 +39,9 @@ def run(self, args: Namespace) -> None: The arguments passed to the command. """ if args.path: - print(config_path()) + print(settings_path()) else: - print(json.dumps(load_config(), indent=4)) + print(json.dumps(load_settings(), indent=4)) command = Config diff --git a/src/anemoi/utils/config.py b/src/anemoi/utils/config.py index d383d8e9..93a4589c 100644 --- a/src/anemoi/utils/config.py +++ b/src/anemoi/utils/config.py @@ -10,17 +10,17 @@ from __future__ import annotations -import contextlib import json import logging import os -import threading from typing import Any -from typing import Optional from typing import Union +import deprecation import yaml +from anemoi.utils._version import __version__ + try: import tomllib # Only available since 3.11 except ImportError: @@ -62,14 +62,14 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) for k, v in self.items(): - if isinstance(v, dict) or is_omegaconf_dict(v): + if isinstance(v, dict) or _is_omegaconf_dict(v): self[k] = DotDict(v) - if isinstance(v, list) or is_omegaconf_list(v): - self[k] = [DotDict(i) if isinstance(i, dict) or is_omegaconf_dict(i) else i for i in v] + if isinstance(v, list) or _is_omegaconf_list(v): + self[k] = [DotDict(i) if isinstance(i, dict) or _is_omegaconf_dict(i) else i for i in v] if isinstance(v, tuple): - self[k] = [DotDict(i) if isinstance(i, dict) or is_omegaconf_dict(i) else i for i in v] + self[k] = [DotDict(i) if isinstance(i, dict) or _is_omegaconf_dict(i) else i for i in v] @classmethod def from_file(cls, path: str) -> DotDict: @@ -194,7 +194,7 @@ def __repr__(self) -> str: return f"DotDict({super().__repr__()})" -def is_omegaconf_dict(value: Any) -> bool: +def _is_omegaconf_dict(value: Any) -> bool: """Check if a value is an OmegaConf DictConfig. Parameters @@ -215,7 +215,7 @@ def is_omegaconf_dict(value: Any) -> bool: return False -def is_omegaconf_list(value: Any) -> bool: +def _is_omegaconf_list(value: Any) -> bool: """Check if a value is an OmegaConf ListConfig. Parameters @@ -236,13 +236,6 @@ def is_omegaconf_list(value: Any) -> bool: return False -CONFIG = {} -CHECKED = {} -CONFIG_LOCK = threading.RLock() -QUIET = False -CONFIG_PATCH = None - - def _find(config: Union[dict, list], what: str, result: list = None) -> list: """Find all occurrences of a key in a nested dictionary or list. @@ -312,50 +305,6 @@ def _set_defaults(a: dict, b: dict) -> None: a.setdefault(k, v) -def config_path(name: str = "settings.toml") -> str: - """Get the path to a configuration file. - - Parameters - ---------- - name : str, optional - The name of the configuration file, by default "settings.toml". - - Returns - ------- - str - The path to the configuration file. - """ - global QUIET - - if name.startswith("/") or name.startswith("."): - return name - - if name.startswith("~"): - return os.path.expanduser(name) - - full = os.path.join(os.path.expanduser("~"), ".config", "anemoi", name) - os.makedirs(os.path.dirname(full), exist_ok=True) - - if name == "settings.toml": - old = os.path.join(os.path.expanduser("~"), ".anemoi.toml") - if not os.path.exists(full) and os.path.exists(old): - if not QUIET: - LOG.warning( - "Configuration file found at ~/.anemoi.toml. Please move it to ~/.config/anemoi/settings.toml" - ) - QUIET = True - return old - else: - if os.path.exists(old): - if not QUIET: - LOG.warning( - "Configuration file found at ~/.anemoi.toml and ~/.config/anemoi/settings.toml, ignoring the former" - ) - QUIET = True - - return full - - def load_any_dict_format(path: str) -> dict: """Load a configuration file in any supported format: JSON, YAML and TOML. @@ -406,217 +355,6 @@ def load_any_dict_format(path: str) -> dict: return open(path).read() -def _load_config( - name: str = "settings.toml", - secrets: Optional[Union[str, list[str]]] = None, - defaults: Optional[Union[str, dict]] = None, -) -> DotDict: - """Load a configuration file. - - Parameters - ---------- - name : str, optional - The name of the configuration file, by default "settings.toml". - secrets : str or list, optional - The name of the secrets file, by default None. - defaults : str or dict, optional - The name of the defaults file, by default None. - - Returns - ------- - DotDict - The loaded configuration. - """ - key = json.dumps((name, secrets, defaults), sort_keys=True, default=str) - if key in CONFIG: - return CONFIG[key] - - path = config_path(name) - if os.path.exists(path): - config = load_any_dict_format(path) - else: - config = {} - - if defaults is not None: - if isinstance(defaults, str): - defaults = load_raw_config(defaults) - _set_defaults(config, defaults) - - if secrets is not None: - if isinstance(secrets, str): - secrets = [secrets] - - base, ext = os.path.splitext(path) - secret_name = base + ".secrets" + ext - - found = set() - for secret in secrets: - if _find(config, secret): - found.add(secret) - - if found: - check_config_mode(name, secret_name, found) - - check_config_mode(secret_name, None) - secret_config = _load_config(secret_name) - _merge_dicts(config, secret_config) - - for env, value in os.environ.items(): - - if not env.startswith("ANEMOI_CONFIG_"): - continue - rest = env[len("ANEMOI_CONFIG_") :] - - package = rest.split("_")[0] - sub = rest[len(package) + 1 :] - - package = package.lower() - sub = sub.lower() - - LOG.info(f"Using environment variable {env} to override the anemoi config key '{package}.{sub}'") - - if package not in config: - config[package] = {} - config[package][sub] = value - - CONFIG[key] = DotDict(config) - return CONFIG[key] - - -def _save_config(name: str, data: Any) -> None: - """Save a configuration file. - - Parameters - ---------- - name : str - The name of the configuration file. - data : Any - The data to save. - """ - CONFIG.pop(name, None) - - conf = config_path(name) - - if conf.endswith(".json"): - with open(conf, "w") as f: - json.dump(data, f, indent=4) - return - - if conf.endswith(".yaml") or conf.endswith(".yml"): - with open(conf, "w") as f: - yaml.dump(data, f) - return - - if conf.endswith(".toml"): - raise NotImplementedError("Saving to TOML is not implemented yet") - - with open(conf, "w") as f: - f.write(data) - - -def save_config(name: str, data: Any) -> None: - """Save a configuration file. - - Parameters - ---------- - name : str - The name of the configuration file to save. - - data : Any - The data to save. - """ - with CONFIG_LOCK: - _save_config(name, data) - - -def load_config( - name: str = "settings.toml", - secrets: Optional[Union[str, list[str]]] = None, - defaults: Optional[Union[str, dict]] = None, -) -> DotDict | str: - """Read a configuration file. - - Parameters - ---------- - name : str, optional - The name of the config file to read, by default "settings.toml" - secrets : str or list, optional - The name of the secrets file, by default None - defaults : str or dict, optional - The name of the defaults file, by default None - - Returns - ------- - DotDict or str - Return DotDict if it is a dictionary, otherwise the raw data - """ - - with CONFIG_LOCK: - config = _load_config(name, secrets, defaults) - if CONFIG_PATCH is not None: - config = CONFIG_PATCH(config) - return config - - -def load_raw_config(name: str, default: Any = None) -> Union[DotDict, str]: - """Load a raw configuration file. - - Parameters - ---------- - name : str - The name of the configuration file. - default : Any, optional - The default value if the file does not exist, by default None. - - Returns - ------- - DotDict or str - The loaded configuration or the default value. - """ - path = config_path(name) - if os.path.exists(path): - return load_any_dict_format(path) - - return default - - -def check_config_mode(name: str = "settings.toml", secrets_name: str = None, secrets: list[str] = None) -> None: - """Check that a configuration file is secure. - - Parameters - ---------- - name : str, optional - The name of the configuration file, by default "settings.toml" - secrets_name : str, optional - The name of the secrets file, by default None - secrets : list, optional - The list of secrets to check, by default None - - Raises - ------ - SystemError - If the configuration file is not secure. - """ - with CONFIG_LOCK: - if name in CHECKED: - return - - conf = config_path(name) - if not os.path.exists(conf): - return - mode = os.stat(conf).st_mode - if mode & 0o777 != 0o600: - if secrets_name: - secret_path = config_path(secrets_name) - raise SystemError( - f"Configuration file {conf} should not hold entries {secrets}.\n" - f"Please move them to {secret_path}." - ) - raise SystemError(f"Configuration file {conf} is not secure.\n" f"Please run `chmod 600 {conf}`.") - - CHECKED[name] = True - - def find(metadata: Union[dict, list], what: str, result: list = None, *, select: callable = None) -> list: """Find all occurrences of a key in a nested dictionary or list with an optional selector. @@ -675,19 +413,61 @@ def merge_configs(*configs: dict) -> dict: return result -@contextlib.contextmanager -def temporary_config(tmp: dict) -> None: +@deprecation.deprecated( + deprecated_in="0.4.30", + removed_in="0.5.0", + current_version=__version__, + details="Use anemoi.utils.settings.temporary_settings instead.", +) +def temporary_config(*args, **kwargs) -> None: + from .settings import temporary_settings - global CONFIG_PATCH + return temporary_settings(*args, **kwargs) - def patch_config(config: dict) -> dict: - return merge_configs(config, tmp) - with CONFIG_LOCK: +@deprecation.deprecated( + deprecated_in="0.4.30", + removed_in="0.5.0", + current_version=__version__, + details="Use anemoi.utils.settings.load_settings instead.", +) +def load_config(*args, **kwargs) -> DotDict | str: + from .settings import load_settings - CONFIG_PATCH = patch_config + return load_settings(*args, **kwargs) - try: - yield - finally: - CONFIG_PATCH = None + +@deprecation.deprecated( + deprecated_in="0.4.30", + removed_in="0.5.0", + current_version=__version__, + details="Use anemoi.utils.settings.settings_path instead.", +) +def config_path(*args, **kwargs) -> str: + from .settings import settings_path + + return settings_path(*args, **kwargs) + + +@deprecation.deprecated( + deprecated_in="0.4.30", + removed_in="0.5.0", + current_version=__version__, + details="Use anemoi.utils.settings.save_settings instead.", +) +def save_config(*args, **kwargs) -> None: + from .settings import save_settings + + save_settings(*args, **kwargs) + + +@deprecation.deprecated( + deprecated_in="0.4.30", + removed_in="0.5.0", + current_version=__version__, + details="Use anemoi.utils.settings.load_settings instead.", +) +def check_config_mode(*args, **kwargs) -> None: + from .settings import check_settings_mode + + check_settings_mode(*args, **kwargs) diff --git a/src/anemoi/utils/settings.py b/src/anemoi/utils/settings.py new file mode 100644 index 00000000..689d8714 --- /dev/null +++ b/src/anemoi/utils/settings.py @@ -0,0 +1,312 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +from __future__ import annotations + +import contextlib +import json +import logging +import os +import threading +from typing import Any +from typing import Optional +from typing import Union + +import yaml + +from .config import DotDict +from .config import _find +from .config import _merge_dicts +from .config import _set_defaults +from .config import load_any_dict_format +from .config import load_raw_config +from .config import merge_configs + +LOG = logging.getLogger(__name__) + + +CONFIG = {} +CHECKED = {} +CONFIG_LOCK = threading.RLock() +QUIET = False +CONFIG_PATCH = None + + +def settings_path(name: str = "settings.toml") -> str: + """Get the path to a configuration file. + + Parameters + ---------- + name : str, optional + The name of the configuration file, by default "settings.toml". + + Returns + ------- + str + The path to the configuration file. + """ + global QUIET + + if name.startswith("/") or name.startswith("."): + return name + + if name.startswith("~"): + return os.path.expanduser(name) + + full = os.path.join(os.path.expanduser("~"), ".config", "anemoi", name) + os.makedirs(os.path.dirname(full), exist_ok=True) + + if name == "settings.toml": + old = os.path.join(os.path.expanduser("~"), ".anemoi.toml") + if not os.path.exists(full) and os.path.exists(old): + if not QUIET: + LOG.warning( + "Configuration file found at ~/.anemoi.toml. Please move it to ~/.config/anemoi/settings.toml" + ) + QUIET = True + return old + else: + if os.path.exists(old): + if not QUIET: + LOG.warning( + "Configuration file found at ~/.anemoi.toml and ~/.config/anemoi/settings.toml, ignoring the former" + ) + QUIET = True + + return full + + +def _load_settings( + name: str = "settings.toml", + secrets: Optional[Union[str, list[str]]] = None, + defaults: Optional[Union[str, dict]] = None, +) -> DotDict: + """Load a configuration file. + + Parameters + ---------- + name : str, optional + The name of the configuration file, by default "settings.toml". + secrets : str or list, optional + The name of the secrets file, by default None. + defaults : str or dict, optional + The name of the defaults file, by default None. + + Returns + ------- + DotDict + The loaded configuration. + """ + key = json.dumps((name, secrets, defaults), sort_keys=True, default=str) + if key in CONFIG: + return CONFIG[key] + + path = settings_path(name) + if os.path.exists(path): + config = load_any_dict_format(path) + else: + config = {} + + if defaults is not None: + if isinstance(defaults, str): + defaults = load_raw_config(defaults) + _set_defaults(config, defaults) + + if secrets is not None: + if isinstance(secrets, str): + secrets = [secrets] + + base, ext = os.path.splitext(path) + secret_name = base + ".secrets" + ext + + found = set() + for secret in secrets: + if _find(config, secret): + found.add(secret) + + if found: + check_settings_mode(name, secret_name, found) + + check_settings_mode(secret_name, None) + secret_config = _load_settings(secret_name) + _merge_dicts(config, secret_config) + + for env, value in os.environ.items(): + + if not env.startswith("ANEMOI_CONFIG_"): + continue + rest = env[len("ANEMOI_CONFIG_") :] + + package = rest.split("_")[0] + sub = rest[len(package) + 1 :] + + package = package.lower() + sub = sub.lower() + + LOG.info(f"Using environment variable {env} to override the anemoi config key '{package}.{sub}'") + + if package not in config: + config[package] = {} + config[package][sub] = value + + CONFIG[key] = DotDict(config) + return CONFIG[key] + + +def _save_settings(name: str, data: Any) -> None: + """Save a configuration file. + + Parameters + ---------- + name : str + The name of the configuration file. + data : Any + The data to save. + """ + CONFIG.pop(name, None) + + conf = settings_path(name) + + if conf.endswith(".json"): + with open(conf, "w") as f: + json.dump(data, f, indent=4) + return + + if conf.endswith(".yaml") or conf.endswith(".yml"): + with open(conf, "w") as f: + yaml.dump(data, f) + return + + if conf.endswith(".toml"): + raise NotImplementedError("Saving to TOML is not implemented yet") + + with open(conf, "w") as f: + f.write(data) + + +def save_settings(name: str, data: Any) -> None: + """Save a configuration file. + + Parameters + ---------- + name : str + The name of the configuration file to save. + + data : Any + The data to save. + """ + with CONFIG_LOCK: + _save_settings(name, data) + + +def load_settings( + name: str = "settings.toml", + secrets: Optional[Union[str, list[str]]] = None, + defaults: Optional[Union[str, dict]] = None, +) -> DotDict | str: + """Read a configuration file. + + Parameters + ---------- + name : str, optional + The name of the config file to read, by default "settings.toml" + secrets : str or list, optional + The name of the secrets file, by default None + defaults : str or dict, optional + The name of the defaults file, by default None + + Returns + ------- + DotDict or str + Return DotDict if it is a dictionary, otherwise the raw data + """ + + with CONFIG_LOCK: + config = _load_settings(name, secrets, defaults) + if CONFIG_PATCH is not None: + config = CONFIG_PATCH(config) + return config + + +def _load_raw_settings(name: str, default: Any = None) -> Union[DotDict, str]: + """Load a raw configuration file. + + Parameters + ---------- + name : str + The name of the configuration file. + default : Any, optional + The default value if the file does not exist, by default None. + + Returns + ------- + DotDict or str + The loaded configuration or the default value. + """ + path = settings_path(name) + if os.path.exists(path): + return load_any_dict_format(path) + + return default + + +def check_settings_mode(name: str = "settings.toml", secrets_name: str = None, secrets: list[str] = None) -> None: + """Check that a configuration file is secure. + + Parameters + ---------- + name : str, optional + The name of the configuration file, by default "settings.toml" + secrets_name : str, optional + The name of the secrets file, by default None + secrets : list, optional + The list of secrets to check, by default None + + Raises + ------ + SystemError + If the configuration file is not secure. + """ + with CONFIG_LOCK: + if name in CHECKED: + return + + conf = settings_path(name) + if not os.path.exists(conf): + return + mode = os.stat(conf).st_mode + if mode & 0o777 != 0o600: + if secrets_name: + secret_path = settings_path(secrets_name) + raise SystemError( + f"Configuration file {conf} should not hold entries {secrets}.\n" + f"Please move them to {secret_path}." + ) + raise SystemError(f"Configuration file {conf} is not secure.\n" f"Please run `chmod 600 {conf}`.") + + CHECKED[name] = True + + +@contextlib.contextmanager +def temporary_settings(tmp: dict) -> None: + + global CONFIG_PATCH + + def patch_config(config: dict) -> dict: + return merge_configs(config, tmp) + + with CONFIG_LOCK: + + CONFIG_PATCH = patch_config + + try: + yield + finally: + CONFIG_PATCH = None From 5cc718d5e6b72827af3a71a7fbbdf0d1b36054cd Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Fri, 25 Jul 2025 17:42:13 +0100 Subject: [PATCH 2/6] add command --- src/anemoi/utils/commands/config.py | 10 +++ src/anemoi/utils/commands/settings.py | 47 ++++++++++ src/anemoi/utils/config.py | 89 ------------------- src/anemoi/utils/settings.py | 118 ++++++++++++++++++++++++-- tests/test_utils.py | 4 +- 5 files changed, 171 insertions(+), 97 deletions(-) create mode 100644 src/anemoi/utils/commands/settings.py diff --git a/src/anemoi/utils/commands/config.py b/src/anemoi/utils/commands/config.py index ef3f1906..cb196a95 100644 --- a/src/anemoi/utils/commands/config.py +++ b/src/anemoi/utils/commands/config.py @@ -12,6 +12,10 @@ from argparse import ArgumentParser from argparse import Namespace +import deprecation + +from anemoi.utils._version import __version__ + from ..settings import load_settings from ..settings import settings_path from . import Command @@ -30,6 +34,12 @@ def add_arguments(self, command_parser: ArgumentParser) -> None: """ command_parser.add_argument("--path", help="Print path to config file") + @deprecation.deprecated( + deprecated_in="0.1.0", + removed_in="0.2.0", + current_version=__version__, + details="Use `anemoi settings` instead.", + ) def run(self, args: Namespace) -> None: """Execute the command with the provided arguments. diff --git a/src/anemoi/utils/commands/settings.py b/src/anemoi/utils/commands/settings.py new file mode 100644 index 00000000..992124c5 --- /dev/null +++ b/src/anemoi/utils/commands/settings.py @@ -0,0 +1,47 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +import json +from argparse import ArgumentParser +from argparse import Namespace + +from ..settings import load_settings +from ..settings import settings_path +from . import Command + + +class Settings(Command): + """Handle settings related commands.""" + + def add_arguments(self, command_parser: ArgumentParser) -> None: + """Add arguments to the command parser. + + Parameters + ---------- + command_parser : ArgumentParser + The argument parser to which the arguments will be added. + """ + command_parser.add_argument("--path", help="Print path to config file") + + def run(self, args: Namespace) -> None: + """Execute the command with the provided arguments. + + Parameters + ---------- + args : Namespace + The arguments passed to the command. + """ + if args.path: + print(settings_path()) + else: + print(json.dumps(load_settings(), indent=4)) + + +command = Settings diff --git a/src/anemoi/utils/config.py b/src/anemoi/utils/config.py index 93a4589c..c6df1f18 100644 --- a/src/anemoi/utils/config.py +++ b/src/anemoi/utils/config.py @@ -236,75 +236,6 @@ def _is_omegaconf_list(value: Any) -> bool: return False -def _find(config: Union[dict, list], what: str, result: list = None) -> list: - """Find all occurrences of a key in a nested dictionary or list. - - Parameters - ---------- - config : dict or list - The configuration to search. - what : str - The key to search for. - result : list, optional - The list to store results, by default None. - - Returns - ------- - list - The list of found values. - """ - if result is None: - result = [] - - if isinstance(config, list): - for i in config: - _find(i, what, result) - return result - - if isinstance(config, dict): - if what in config: - result.append(config[what]) - - for k, v in config.items(): - _find(v, what, result) - - return result - - -def _merge_dicts(a: dict, b: dict) -> None: - """Merge two dictionaries recursively. - - Parameters - ---------- - a : dict - The first dictionary. - b : dict - The second dictionary. - """ - for k, v in b.items(): - if k in a and isinstance(a[k], dict) and isinstance(v, dict): - _merge_dicts(a[k], v) - else: - a[k] = v - - -def _set_defaults(a: dict, b: dict) -> None: - """Set default values in a dictionary. - - Parameters - ---------- - a : dict - The dictionary to set defaults in. - b : dict - The dictionary with default values. - """ - for k, v in b.items(): - if k in a and isinstance(a[k], dict) and isinstance(v, dict): - _set_defaults(a[k], v) - else: - a.setdefault(k, v) - - def load_any_dict_format(path: str) -> dict: """Load a configuration file in any supported format: JSON, YAML and TOML. @@ -393,26 +324,6 @@ def find(metadata: Union[dict, list], what: str, result: list = None, *, select: return result -def merge_configs(*configs: dict) -> dict: - """Merge multiple configuration dictionaries. - - Parameters - ---------- - *configs : dict - The configuration dictionaries to merge. - - Returns - ------- - dict - The merged configuration dictionary. - """ - result = {} - for config in configs: - _merge_dicts(result, config) - - return result - - @deprecation.deprecated( deprecated_in="0.4.30", removed_in="0.5.0", diff --git a/src/anemoi/utils/settings.py b/src/anemoi/utils/settings.py index 689d8714..638e8685 100644 --- a/src/anemoi/utils/settings.py +++ b/src/anemoi/utils/settings.py @@ -22,12 +22,7 @@ import yaml from .config import DotDict -from .config import _find -from .config import _merge_dicts -from .config import _set_defaults from .config import load_any_dict_format -from .config import load_raw_config -from .config import merge_configs LOG = logging.getLogger(__name__) @@ -39,6 +34,95 @@ CONFIG_PATCH = None +def _find(config: Union[dict, list], what: str, result: list = None) -> list: + """Find all occurrences of a key in a nested dictionary or list. + + Parameters + ---------- + config : dict or list + The configuration to search. + what : str + The key to search for. + result : list, optional + The list to store results, by default None. + + Returns + ------- + list + The list of found values. + """ + if result is None: + result = [] + + if isinstance(config, list): + for i in config: + _find(i, what, result) + return result + + if isinstance(config, dict): + if what in config: + result.append(config[what]) + + for k, v in config.items(): + _find(v, what, result) + + return result + + +def _merge_dicts(a: dict, b: dict) -> None: + """Merge two dictionaries recursively. + + Parameters + ---------- + a : dict + The first dictionary. + b : dict + The second dictionary. + """ + for k, v in b.items(): + if k in a and isinstance(a[k], dict) and isinstance(v, dict): + _merge_dicts(a[k], v) + else: + a[k] = v + + +def _set_defaults(a: dict, b: dict) -> None: + """Set default values in a dictionary. + + Parameters + ---------- + a : dict + The dictionary to set defaults in. + b : dict + The dictionary with default values. + """ + for k, v in b.items(): + if k in a and isinstance(a[k], dict) and isinstance(v, dict): + _set_defaults(a[k], v) + else: + a.setdefault(k, v) + + +def merge_configs(*configs: dict) -> dict: + """Merge multiple configuration dictionaries. + + Parameters + ---------- + *configs : dict + The configuration dictionaries to merge. + + Returns + ------- + dict + The merged configuration dictionary. + """ + result = {} + for config in configs: + _merge_dicts(result, config) + + return result + + def settings_path(name: str = "settings.toml") -> str: """Get the path to a configuration file. @@ -116,7 +200,7 @@ def _load_settings( if defaults is not None: if isinstance(defaults, str): - defaults = load_raw_config(defaults) + defaults = load_raw_settings(defaults) _set_defaults(config, defaults) if secrets is not None: @@ -206,6 +290,28 @@ def save_settings(name: str, data: Any) -> None: _save_settings(name, data) +def load_raw_settings(name: str, default: Any = None) -> Union[DotDict, str]: + """Load a raw configuration file. + + Parameters + ---------- + name : str + The name of the configuration file. + default : Any, optional + The default value if the file does not exist, by default None. + + Returns + ------- + DotDict or str + The loaded configuration or the default value. + """ + path = settings_path(name) + if os.path.exists(path): + return load_any_dict_format(path) + + return default + + def load_settings( name: str = "settings.toml", secrets: Optional[Union[str, list[str]]] = None, diff --git a/tests/test_utils.py b/tests/test_utils.py index bc883332..24797694 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -9,10 +9,10 @@ from anemoi.utils.config import DotDict -from anemoi.utils.config import _merge_dicts -from anemoi.utils.config import _set_defaults from anemoi.utils.grib import paramid_to_shortname from anemoi.utils.grib import shortname_to_paramid +from anemoi.utils.settings import _merge_dicts +from anemoi.utils.settings import _set_defaults def test_dotdict() -> None: From bd26bca22aea5ce655f718459b7ca8fac4222c78 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Fri, 25 Jul 2025 18:01:38 +0100 Subject: [PATCH 3/6] update --- src/anemoi/utils/config.py | 117 +++---------------------------------- 1 file changed, 9 insertions(+), 108 deletions(-) diff --git a/src/anemoi/utils/config.py b/src/anemoi/utils/config.py index c6df1f18..e6a8f971 100644 --- a/src/anemoi/utils/config.py +++ b/src/anemoi/utils/config.py @@ -13,10 +13,10 @@ import json import logging import os -from typing import Any from typing import Union import deprecation +import omegaconf.dictconfig import yaml from anemoi.utils._version import __version__ @@ -30,7 +30,7 @@ LOG = logging.getLogger(__name__) -class DotDict(dict): +class DotDict(omegaconf.dictconfig.DictConfig): """A dictionary that allows access to its keys as attributes. >>> d = DotDict({"a": 1, "b": {"c": 2}}) @@ -49,28 +49,6 @@ class DotDict(dict): >>> d = DotDict(a=1, b=2) """ - def __init__(self, *args, **kwargs): - """Initialize a DotDict instance. - - Parameters - ---------- - *args : tuple - Positional arguments for the dict constructor. - **kwargs : dict - Keyword arguments for the dict constructor. - """ - super().__init__(*args, **kwargs) - - for k, v in self.items(): - if isinstance(v, dict) or _is_omegaconf_dict(v): - self[k] = DotDict(v) - - if isinstance(v, list) or _is_omegaconf_list(v): - self[k] = [DotDict(i) if isinstance(i, dict) or _is_omegaconf_dict(i) else i for i in v] - - if isinstance(v, tuple): - self[k] = [DotDict(i) if isinstance(i, dict) or _is_omegaconf_dict(i) else i for i in v] - @classmethod def from_file(cls, path: str) -> DotDict: """Create a DotDict from a file. @@ -151,90 +129,6 @@ def from_toml_file(cls, path: str) -> DotDict: data = tomllib.load(file) return cls(data) - def __getattr__(self, attr: str) -> Any: - """Get an attribute. - - Parameters - ---------- - attr : str - The attribute name. - - Returns - ------- - Any - The attribute value. - """ - try: - return self[attr] - except KeyError: - raise AttributeError(attr) - - def __setattr__(self, attr: str, value: Any) -> None: - """Set an attribute. - - Parameters - ---------- - attr : str - The attribute name. - value : Any - The attribute value. - """ - if isinstance(value, dict): - value = DotDict(value) - self[attr] = value - - def __repr__(self) -> str: - """Return a string representation of the DotDict. - - Returns - ------- - str - The string representation. - """ - return f"DotDict({super().__repr__()})" - - -def _is_omegaconf_dict(value: Any) -> bool: - """Check if a value is an OmegaConf DictConfig. - - Parameters - ---------- - value : Any - The value to check. - - Returns - ------- - bool - True if the value is a DictConfig, False otherwise. - """ - try: - from omegaconf import DictConfig - - return isinstance(value, DictConfig) - except ImportError: - return False - - -def _is_omegaconf_list(value: Any) -> bool: - """Check if a value is an OmegaConf ListConfig. - - Parameters - ---------- - value : Any - The value to check. - - Returns - ------- - bool - True if the value is a ListConfig, False otherwise. - """ - try: - from omegaconf import ListConfig - - return isinstance(value, ListConfig) - except ImportError: - return False - def load_any_dict_format(path: str) -> dict: """Load a configuration file in any supported format: JSON, YAML and TOML. @@ -382,3 +276,10 @@ def check_config_mode(*args, **kwargs) -> None: from .settings import check_settings_mode check_settings_mode(*args, **kwargs) + + +if __name__ == "__main__": + a = DotDict({"a": 1, "b": {"c": 2}, "user": "${oc.env:HOME}"}) + print(a) + print(a.a) + print(a.user) From 41afbcfae2cd3f17347d9d0be6593a25a3466ab1 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Thu, 2 Oct 2025 10:53:46 +0100 Subject: [PATCH 4/6] code tidy --- src/anemoi/utils/config.py | 317 +++++++++++++++++------ src/anemoi/utils/mlflow/auth.py | 20 +- src/anemoi/utils/settings.py | 10 +- tests/{test_utils.py => test_dotdict.py} | 67 ++++- tests/test_grib.py | 31 +++ tests/test_mlflow_auth.py | 16 +- 6 files changed, 344 insertions(+), 117 deletions(-) rename tests/{test_utils.py => test_dotdict.py} (65%) create mode 100644 tests/test_grib.py diff --git a/src/anemoi/utils/config.py b/src/anemoi/utils/config.py index f49143c1..3ff668ed 100644 --- a/src/anemoi/utils/config.py +++ b/src/anemoi/utils/config.py @@ -49,127 +49,219 @@ class DotDict(omegaconf.dictconfig.DictConfig): >>> d = DotDict(a=1, b=2) """ - @classmethod - def from_file(cls, path: str) -> DotDict: - """Create a DotDict from a file. + def __init__( + self, + *args: Any, + resolve_interpolations: bool = False, + cli_arguments: list[str] | None = None, + **kwargs: Any, + ) -> None: + """Initialise a DotDict instance. Parameters ---------- - path : str - The path to the file. - - Returns - ------- - DotDict - The created DotDict. + *args : Any + Arguments to construct the dictionary. + resolve_interpolations : bool, optional + Whether to resolve interpolations, by default False. + cli_arguments : list of str, optional + CLI arguments to override values, by default None. + **kwargs : Any + Keyword arguments to construct the dictionary. """ - _, ext = os.path.splitext(path) - if ext == ".yaml" or ext == ".yml": - return cls.from_yaml_file(path) - elif ext == ".json": - return cls.from_json_file(path) - elif ext == ".toml": - return cls.from_toml_file(path) - else: - raise ValueError(f"Unknown file extension {ext}") + + d = omegaconf.OmegaConf.create(dict(*args, **kwargs)) + + if cli_arguments: + d = omegaconf.OmegaConf.merge(d, omegaconf.OmegaConf.from_cli(cli_arguments)) + + if resolve_interpolations: + d = omegaconf.OmegaConf.to_container(d, resolve=True) + + return super().__init__(d) @classmethod - def from_yaml_file(cls, path: str) -> DotDict: - """Create a DotDict from a YAML file. + def from_file( + cls: type["DotDict"], + path: str, + *args: Any, + resolve_interpolations: bool = False, + cli_arguments: list[str] | None = None, + **kwargs: Any, + ) -> "DotDict": + """Create a DotDict from a file.""" + _, ext = os.path.splitext(path) - Parameters - ---------- - path : str - The path to the YAML file. + match ext: + case ".yaml" | ".yml": + return cls.from_yaml_file( + path, + *args, + resolve_interpolations=resolve_interpolations, + cli_arguments=cli_arguments, + **kwargs, + ) + case ".json": + return cls.from_json_file( + path, + *args, + resolve_interpolations=resolve_interpolations, + cli_arguments=cli_arguments, + **kwargs, + ) + case ".toml": + return cls.from_toml_file( + path, + *args, + resolve_interpolations=resolve_interpolations, + cli_arguments=cli_arguments, + **kwargs, + ) + case _: + raise ValueError(f"Unknown file extension {ext}") - Returns - ------- - DotDict - The created DotDict. - """ + @classmethod + def from_yaml_file( + cls: type["DotDict"], + path: str, + *args: Any, + resolve_interpolations: bool = False, + cli_arguments: list[str] | None = None, + **kwargs: Any, + ) -> "DotDict": + """Create a DotDict from a YAML file.""" with open(path) as file: data = yaml.safe_load(file) - return cls(data) + return cls( + data, + *args, + resolve_interpolations=resolve_interpolations, + cli_arguments=cli_arguments, + **kwargs, + ) @classmethod - def from_json_file(cls, path: str) -> DotDict: - """Create a DotDict from a JSON file. - - Parameters - ---------- - path : str - The path to the JSON file. - - Returns - ------- - DotDict - The created DotDict. - """ + def from_json_file( + cls: type["DotDict"], + path: str, + *args: Any, + resolve_interpolations: bool = False, + cli_arguments: list[str] | None = None, + **kwargs: Any, + ) -> "DotDict": + """Create a DotDict from a JSON file.""" with open(path) as file: data = json.load(file) - return cls(data) + return cls( + data, + *args, + resolve_interpolations=resolve_interpolations, + cli_arguments=cli_arguments, + **kwargs, + ) @classmethod - def from_toml_file(cls, path: str) -> DotDict: - """Create a DotDict from a TOML file. + def from_toml_file( + cls: type["DotDict"], + path: str, + *args: Any, + resolve_interpolations: bool = False, + cli_arguments: list[str] | None = None, + **kwargs: Any, + ) -> "DotDict": + """Create a DotDict from a TOML file.""" + with open(path) as file: + data = tomllib.load(file) + + return cls( + data, + *args, + resolve_interpolations=resolve_interpolations, + cli_arguments=cli_arguments, + **kwargs, + ) + + def __repr__(self) -> str: + return f"DotDict({super().__repr__()})" + + def to_dict(self, *, resolve_interpolations: bool = True) -> dict: + """Convert the DotDict to a standard dictionary. Parameters ---------- - path : str - The path to the TOML file. + resolve_interpolations : bool, optional + Whether to resolve any interpolations, by default True. Returns ------- - DotDict - The created DotDict. + dict + The converted dictionary. """ - with open(path) as file: - data = tomllib.load(file) - return cls(data) - - def __getattr__(self, attr: str) -> Any: - """Get an attribute. + """Convert the DotDict to a standard dictionary. Parameters ---------- - attr : str - The attribute name. + resolve : bool, optional + Whether to resolve any interpolations, by default False. Returns ------- - Any - The attribute value. + dict + The converted dictionary. """ - try: - return self[attr] - except KeyError: - raise AttributeError(attr) + return omegaconf.OmegaConf.to_container(self, resolve=resolve_interpolations) - def __setattr__(self, attr: str, value: Any) -> None: - """Set an attribute. - Parameters - ---------- - attr : str - The attribute name. - value : Any - The attribute value. - """ - if isinstance(value, dict): - value = DotDict(value) - self[attr] = value +def load_any_dict_format(path: str) -> dict: + """Load a configuration file in any supported format: JSON, YAML, or TOML. - def __repr__(self) -> str: - """Return a string representation of the DotDict. + Parameters + ---------- + path : str + The path to the configuration file. - Returns - ------- - str - The string representation. - """ - return f"DotDict({super().__repr__()})" + Returns + ------- + dict + The decoded configuration file. + """ + + try: + if path.endswith(".json"): + with open(path, "rb") as f: + return json.load(f) + + if path.endswith(".yaml") or path.endswith(".yml"): + with open(path, "rb") as f: + return yaml.safe_load(f) + + if path.endswith(".toml"): + with open(path, "rb") as f: + return tomllib.load(f) + + if path == "-": + import sys + + config = sys.stdin.read() + + parsers = [(yaml.safe_load, "yaml"), (json.loads, "json"), (tomllib.loads, "toml")] + + for parser, parser_type in parsers: + try: + LOG.debug(f"Trying {parser_type} parser for stdin") + return parser(config) + except Exception: + pass + + raise ValueError("Failed to parse configuration from stdin") + + except (json.JSONDecodeError, yaml.YAMLError, tomllib.TOMLDecodeError) as e: + LOG.warning(f"Failed to parse config file {path}", exc_info=e) + raise ValueError(f"Failed to parse config file {path} [{e}]") + + return open(path).read() def find(metadata: dict | list, what: str, result: list = None, *, select: callable = None) -> list: @@ -217,6 +309,15 @@ def find(metadata: dict | list, what: str, result: list = None, *, select: calla details="Use anemoi.utils.settings.temporary_settings instead.", ) def temporary_config(*args, **kwargs) -> None: + """Deprecated. Use anemoi.utils.settings.temporary_settings instead. + + Parameters + ---------- + *args : Any + Arguments to pass to temporary_settings. + **kwargs : Any + Keyword arguments to pass to temporary_settings. + """ from .settings import temporary_settings return temporary_settings(*args, **kwargs) @@ -229,6 +330,20 @@ def temporary_config(*args, **kwargs) -> None: details="Use anemoi.utils.settings.load_settings instead.", ) def load_config(*args, **kwargs) -> DotDict | str: + """Deprecated. Use anemoi.utils.settings.load_settings instead. + + Parameters + ---------- + *args : Any + Arguments to pass to load_settings. + **kwargs : Any + Keyword arguments to pass to load_settings. + + Returns + ------- + DotDict or str + The loaded configuration. + """ from .settings import load_settings return load_settings(*args, **kwargs) @@ -241,6 +356,20 @@ def load_config(*args, **kwargs) -> DotDict | str: details="Use anemoi.utils.settings.settings_path instead.", ) def config_path(*args, **kwargs) -> str: + """Deprecated. Use anemoi.utils.settings.settings_path instead. + + Parameters + ---------- + *args : Any + Arguments to pass to settings_path. + **kwargs : Any + Keyword arguments to pass to settings_path. + + Returns + ------- + str + The settings path. + """ from .settings import settings_path return settings_path(*args, **kwargs) @@ -253,6 +382,15 @@ def config_path(*args, **kwargs) -> str: details="Use anemoi.utils.settings.save_settings instead.", ) def save_config(*args, **kwargs) -> None: + """Deprecated. Use anemoi.utils.settings.save_settings instead. + + Parameters + ---------- + *args : Any + Arguments to pass to save_settings. + **kwargs : Any + Keyword arguments to pass to save_settings. + """ from .settings import save_settings save_settings(*args, **kwargs) @@ -265,6 +403,15 @@ def save_config(*args, **kwargs) -> None: details="Use anemoi.utils.settings.load_settings instead.", ) def check_config_mode(*args, **kwargs) -> None: + """Deprecated. Use anemoi.utils.settings.check_settings_mode instead. + + Parameters + ---------- + *args : Any + Arguments to pass to check_settings_mode. + **kwargs : Any + Keyword arguments to pass to check_settings_mode. + """ from .settings import check_settings_mode check_settings_mode(*args, **kwargs) diff --git a/src/anemoi/utils/mlflow/auth.py b/src/anemoi/utils/mlflow/auth.py index 371ee614..50770383 100644 --- a/src/anemoi/utils/mlflow/auth.py +++ b/src/anemoi/utils/mlflow/auth.py @@ -29,11 +29,11 @@ from pydantic import model_validator from requests.exceptions import HTTPError -from ..config import CONFIG_LOCK -from ..config import config_path -from ..config import load_raw_config -from ..config import save_config from ..remote import robust +from ..settings import SETTINGS_LOCK +from ..settings import load_raw_settings +from ..settings import save_settings +from ..settings import settings_path from ..timer import Timer REFRESH_EXPIRE_DAYS = 29 @@ -189,17 +189,17 @@ def refresh_token(self, value: str) -> None: @staticmethod def _get_store() -> ServerStore: """Read the server store from disk.""" - with CONFIG_LOCK: + with SETTINGS_LOCK: file = TokenAuth._config_file - path = config_path(file) + path = settings_path(file) if not os.path.exists(path): - save_config(file, {}) + save_settings(file, {}) if os.path.exists(path) and os.stat(path).st_mode & 0o777 != 0o600: os.chmod(path, 0o600) - return ServerStore(**load_raw_config(file)) + return ServerStore(**load_raw_settings(file)) @staticmethod def get_servers() -> list[tuple[str, int]]: @@ -337,10 +337,10 @@ def save(self, **kwargs: dict) -> None: refresh_expires=self.refresh_expires, ) - with CONFIG_LOCK: + with SETTINGS_LOCK: store = self._get_store() store.update(self.url, server_config) - save_config(self._config_file, store.model_dump()) + save_settings(self._config_file, store.model_dump()) expire_date = datetime.fromtimestamp(self.refresh_expires, tz=timezone.utc) self.log.info( diff --git a/src/anemoi/utils/settings.py b/src/anemoi/utils/settings.py index c4774a72..40ed495f 100644 --- a/src/anemoi/utils/settings.py +++ b/src/anemoi/utils/settings.py @@ -31,7 +31,7 @@ CONFIG = {} CHECKED = {} -CONFIG_LOCK = threading.RLock() +SETTINGS_LOCK = threading.RLock() QUIET = False CONFIG_PATCH = None @@ -292,7 +292,7 @@ def save_settings(name: str, data: Any) -> None: data : Any The data to save. """ - with CONFIG_LOCK: + with SETTINGS_LOCK: _save_settings(name, data) @@ -340,7 +340,7 @@ def load_settings( Return DotDict if it is a dictionary, otherwise the raw data """ - with CONFIG_LOCK: + with SETTINGS_LOCK: config = _load_settings(name, secrets, defaults) if CONFIG_PATCH is not None: config = CONFIG_PATCH(config) @@ -386,7 +386,7 @@ def check_settings_mode(name: str = "settings.toml", secrets_name: str = None, s SystemError If the configuration file is not secure. """ - with CONFIG_LOCK: + with SETTINGS_LOCK: if name in CHECKED: return @@ -414,7 +414,7 @@ def temporary_settings(tmp: dict) -> None: def patch_config(config: dict) -> dict: return merge_configs(config, tmp) - with CONFIG_LOCK: + with SETTINGS_LOCK: CONFIG_PATCH = patch_config diff --git a/tests/test_utils.py b/tests/test_dotdict.py similarity index 65% rename from tests/test_utils.py rename to tests/test_dotdict.py index a81eefc0..45ec4efa 100644 --- a/tests/test_utils.py +++ b/tests/test_dotdict.py @@ -9,8 +9,6 @@ from anemoi.utils.config import DotDict -from anemoi.utils.grib import paramid_to_shortname -from anemoi.utils.grib import shortname_to_paramid from anemoi.utils.settings import _merge_dicts from anemoi.utils.settings import _set_defaults @@ -53,6 +51,13 @@ def test_add_nested_dict_via_setitem(): def test_adding_list_of_dicts_via_setitem(): + """Test that assigning a list of dicts via item access results in recursive DotDict conversion. + + Tests + ----- + - Assigning a list of dicts to a DotDict key. + - Ensuring each dict in the list is converted to DotDict. + """ d = DotDict(a=1) d["b"] = [ { @@ -65,6 +70,13 @@ def test_adding_list_of_dicts_via_setitem(): def test_adding_list_of_dicts_via_setattr(): + """Test that assigning a list of dicts via attribute access results in recursive DotDict conversion. + + Tests + ----- + - Assigning a list of dicts to a DotDict attribute. + - Ensuring each dict in the list is converted to DotDict. + """ d = DotDict(a=1) d.b = [ { @@ -102,15 +114,52 @@ def test_set_defaults() -> None: assert a == {"a": 1, "b": 2, "c": {"d": 3, "e": 4, "a": 30}, "d": 9} -def test_grib() -> None: - """Test the GRIB utility functions. +def test_interpolation() -> None: + """Test interpolation in DotDict using OmegaConf-like syntax. - Tests: - - Converting short names to parameter IDs. - - Converting parameter IDs to short names. + Tests + ----- + - Interpolating values from other keys using nested references. + """ + # From omegaconf documentation + + d = DotDict( + { + "plans": { + "A": "plan A", + "B": "plan B", + }, + "selected_plan": "A", + "plan": "${plans[${selected_plan}]}", + } + ) + + assert d.to_dict() == {"plan": "plan A", "plans": {"A": "plan A", "B": "plan B"}, "selected_plan": "A"} + + +def test_cli_arguments() -> None: + """Test that DotDict correctly applies CLI arguments to override values. + + Tests + ----- + - Overriding top-level and nested values using CLI arguments. """ - assert shortname_to_paramid("2t") == 167 - assert paramid_to_shortname(167) == "2t" + d = DotDict( + { + "a": 1, + "b": 2, + "c": { + "d": 3, + "e": 4, + }, + }, + cli_arguments=["a=10", "c.d=30"], + ) + + assert d.a == 10 + assert d.b == 2 + assert d.c.d == 30 + assert d.c.e == 4 if __name__ == "__main__": diff --git a/tests/test_grib.py b/tests/test_grib.py new file mode 100644 index 00000000..d367a14a --- /dev/null +++ b/tests/test_grib.py @@ -0,0 +1,31 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +from anemoi.utils.grib import paramid_to_shortname +from anemoi.utils.grib import shortname_to_paramid + + +def test_grib() -> None: + """Test the GRIB utility functions. + + Tests: + - Converting short names to parameter IDs. + - Converting parameter IDs to short names. + """ + assert shortname_to_paramid("2t") == 167 + assert paramid_to_shortname(167) == "2t" + + +if __name__ == "__main__": + """Run all test functions.""" + for name, obj in list(globals().items()): + if name.startswith("test_") and callable(obj): + print(f"Running {name}...") + obj() diff --git a/tests/test_mlflow_auth.py b/tests/test_mlflow_auth.py index 25140f00..42553c4e 100644 --- a/tests/test_mlflow_auth.py +++ b/tests/test_mlflow_auth.py @@ -49,11 +49,11 @@ def mocks( return_value=response, ) mocker.patch( - "anemoi.utils.mlflow.auth.load_raw_config", + "anemoi.utils.mlflow.auth.load_raw_settings", return_value=config, ) mocker.patch( - "anemoi.utils.mlflow.auth.save_config", + "anemoi.utils.mlflow.auth.save_settings", ) mocker.patch( "anemoi.utils.mlflow.auth.getpass", @@ -198,7 +198,7 @@ def test_legacy_format(mocker: pytest.MockerFixture) -> None: } } mocker.patch( - "anemoi.utils.mlflow.auth.load_raw_config", + "anemoi.utils.mlflow.auth.load_raw_settings", return_value=legacy_config, ) @@ -247,7 +247,7 @@ def test_multi_server_format(mocker: pytest.MockerFixture, url: str, unknown: bo mocks(mocker) mocker.patch( - "anemoi.utils.mlflow.auth.load_raw_config", + "anemoi.utils.mlflow.auth.load_raw_settings", return_value=multi_config, ) @@ -278,11 +278,11 @@ def test_server_store() -> None: def test_utils_interface(): - """TokenAuth uses the utils CONFIG_LOCK when reading and writing the server store to ensure thread safety. - Ensure that CONFIG_LOCK stays a reentrant lock, if it were a normal lock it would deadlock itself. + """TokenAuth uses the utils SETTINGS_LOCK when reading and writing the server store to ensure thread safety. + Ensure that SETTINGS_LOCK stays a reentrant lock, if it were a normal lock it would deadlock itself. """ from threading import RLock - from anemoi.utils.config import CONFIG_LOCK + from anemoi.utils.settings import SETTINGS_LOCK - assert isinstance(CONFIG_LOCK, type(RLock())) + assert isinstance(SETTINGS_LOCK, type(RLock())) From eda24d1949ac07f84255e6cf3317666349b7172c Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Sat, 4 Oct 2025 14:09:32 +0100 Subject: [PATCH 5/6] support datatime in omegaconf --- src/anemoi/utils/config.py | 4 +++- tests/test_dotdict.py | 21 +++++++++++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/src/anemoi/utils/config.py b/src/anemoi/utils/config.py index 3ff668ed..7cbe90c3 100644 --- a/src/anemoi/utils/config.py +++ b/src/anemoi/utils/config.py @@ -70,7 +70,9 @@ def __init__( Keyword arguments to construct the dictionary. """ - d = omegaconf.OmegaConf.create(dict(*args, **kwargs)) + # Allow non-primitive types like datetime by enabling allow_objects + + d = omegaconf.OmegaConf.create(dict(*args, **kwargs), flags={"allow_objects": True}) if cli_arguments: d = omegaconf.OmegaConf.merge(d, omegaconf.OmegaConf.from_cli(cli_arguments)) diff --git a/tests/test_dotdict.py b/tests/test_dotdict.py index 45ec4efa..b17aea8d 100644 --- a/tests/test_dotdict.py +++ b/tests/test_dotdict.py @@ -162,6 +162,27 @@ def test_cli_arguments() -> None: assert d.c.e == 4 +def test_non_primitive_types() -> None: + """Test that DotDict can handle non-primitive types like datetime. + + Tests + ----- + - Assigning and accessing datetime objects in DotDict. + """ + from datetime import datetime + + now = datetime.now() + d = DotDict(a=now) + assert d.a == now + + d.b = {"time": now} + assert d.b.time == now + + d = DotDict() + d.c = [1, 2, {"time": now}] + assert d.c[2].time == now + + if __name__ == "__main__": """Run all test functions.""" for name, obj in list(globals().items()): From 102725258ebdb7e5ae3d2ddd373807d3b60dbc3d Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Sat, 4 Oct 2025 14:38:02 +0100 Subject: [PATCH 6/6] support omegaconf in registry --- src/anemoi/utils/config.py | 6 +++--- src/anemoi/utils/registry.py | 8 +++++++- tests/test_dotdict.py | 2 +- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/src/anemoi/utils/config.py b/src/anemoi/utils/config.py index 7cbe90c3..329ba9fd 100644 --- a/src/anemoi/utils/config.py +++ b/src/anemoi/utils/config.py @@ -16,7 +16,7 @@ from typing import Any import deprecation -import omegaconf.dictconfig +import omegaconf import yaml from anemoi.utils._version import __version__ @@ -30,7 +30,7 @@ LOG = logging.getLogger(__name__) -class DotDict(omegaconf.dictconfig.DictConfig): +class DotDict(omegaconf.DictConfig): """A dictionary that allows access to its keys as attributes. >>> d = DotDict({"a": 1, "b": {"c": 2}}) @@ -188,7 +188,7 @@ def from_toml_file( def __repr__(self) -> str: return f"DotDict({super().__repr__()})" - def to_dict(self, *, resolve_interpolations: bool = True) -> dict: + def as_dict(self, *, resolve_interpolations: bool = True) -> dict: """Convert the DotDict to a standard dictionary. Parameters diff --git a/src/anemoi/utils/registry.py b/src/anemoi/utils/registry.py index 6a7bbf6d..d09f66ac 100644 --- a/src/anemoi/utils/registry.py +++ b/src/anemoi/utils/registry.py @@ -350,11 +350,17 @@ def from_config(self, config: str | dict[str, Any], *args: Any, **kwargs: Any) - Any The created instance. """ + import omegaconf + if isinstance(config, str): config = {config: {}} + if isinstance(config, omegaconf.DictConfig): + # Allow DotDict and OmegaConf objects + config = omegaconf.OmegaConf.to_container(config, resolve=True) + if not isinstance(config, dict): - raise ValueError(f"Invalid config: {config}") + raise ValueError(f"Invalid config: {config} (type {type(config)})") if self.key in config: config = config.copy() diff --git a/tests/test_dotdict.py b/tests/test_dotdict.py index b17aea8d..2eda0c46 100644 --- a/tests/test_dotdict.py +++ b/tests/test_dotdict.py @@ -134,7 +134,7 @@ def test_interpolation() -> None: } ) - assert d.to_dict() == {"plan": "plan A", "plans": {"A": "plan A", "B": "plan B"}, "selected_plan": "A"} + assert d.as_dict() == {"plan": "plan A", "plans": {"A": "plan A", "B": "plan B"}, "selected_plan": "A"} def test_cli_arguments() -> None: