diff --git a/pyproject.toml b/pyproject.toml index a56edc99..0234811e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,7 @@ dependencies = [ "importlib-metadata; python_version<'3.10'", "multiurl", "numpy", + "omegaconf", "pydantic>=2.9", "python-dateutil", "pyyaml", diff --git a/src/anemoi/utils/cli.py b/src/anemoi/utils/cli.py index 58698c7f..c2f822ee 100644 --- a/src/anemoi/utils/cli.py +++ b/src/anemoi/utils/cli.py @@ -14,6 +14,7 @@ import os import sys import traceback +import warnings from collections.abc import Callable from anemoi.utils import ENV @@ -225,6 +226,9 @@ def cli_main( test_arguments : list[str], optional The command line arguments to parse, used for testing purposes, by default None """ + + warnings.filterwarnings("default", category=DeprecationWarning) + parser = make_parser(description, commands) args, unknown = parser.parse_known_args(test_arguments) if argcomplete: diff --git a/src/anemoi/utils/commands/config.py b/src/anemoi/utils/commands/config.py index c0c80abd..cb196a95 100644 --- a/src/anemoi/utils/commands/config.py +++ b/src/anemoi/utils/commands/config.py @@ -12,8 +12,12 @@ from argparse import ArgumentParser from argparse import Namespace -from ..config import config_path -from ..config import load_config +import deprecation + +from anemoi.utils._version import __version__ + +from ..settings import load_settings +from ..settings import settings_path from . import Command @@ -30,6 +34,12 @@ def add_arguments(self, command_parser: ArgumentParser) -> None: """ command_parser.add_argument("--path", help="Print path to config file") + @deprecation.deprecated( + deprecated_in="0.4.30", + removed_in="0.5.0", + current_version=__version__, + details="Use `anemoi settings` instead.", + ) def run(self, args: Namespace) -> None: """Execute the command with the provided arguments. @@ -39,9 +49,9 @@ def run(self, args: Namespace) -> None: The arguments passed to the command. 
""" if args.path: - print(config_path()) + print(settings_path()) else: - print(json.dumps(load_config(), indent=4)) + print(json.dumps(load_settings(), indent=4)) command = Config diff --git a/src/anemoi/utils/commands/settings.py b/src/anemoi/utils/commands/settings.py new file mode 100644 index 00000000..992124c5 --- /dev/null +++ b/src/anemoi/utils/commands/settings.py @@ -0,0 +1,47 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +import json +from argparse import ArgumentParser +from argparse import Namespace + +from ..settings import load_settings +from ..settings import settings_path +from . import Command + + +class Settings(Command): + """Handle settings related commands.""" + + def add_arguments(self, command_parser: ArgumentParser) -> None: + """Add arguments to the command parser. + + Parameters + ---------- + command_parser : ArgumentParser + The argument parser to which the arguments will be added. + """ + command_parser.add_argument("--path", help="Print path to config file") + + def run(self, args: Namespace) -> None: + """Execute the command with the provided arguments. + + Parameters + ---------- + args : Namespace + The arguments passed to the command. 
+ """ + if args.path: + print(settings_path()) + else: + print(json.dumps(load_settings(), indent=4)) + + +command = Settings diff --git a/src/anemoi/utils/config.py b/src/anemoi/utils/config.py index 221eb4b8..329ba9fd 100644 --- a/src/anemoi/utils/config.py +++ b/src/anemoi/utils/config.py @@ -10,16 +10,16 @@ from __future__ import annotations -import contextlib import json import logging import os -import threading from typing import Any +import deprecation +import omegaconf import yaml -from anemoi.utils import ENV +from anemoi.utils._version import __version__ try: import tomllib # Only available since 3.11 @@ -30,7 +30,7 @@ LOG = logging.getLogger(__name__) -class DotDict(dict): +class DotDict(omegaconf.DictConfig): """A dictionary that allows access to its keys as attributes. >>> d = DotDict({"a": 1, "b": {"c": 2}}) @@ -49,334 +49,175 @@ class DotDict(dict): >>> d = DotDict(a=1, b=2) """ - def __init__(self, *args, **kwargs): - """Initialize a DotDict instance. + def __init__( + self, + *args: Any, + resolve_interpolations: bool = False, + cli_arguments: list[str] | None = None, + **kwargs: Any, + ) -> None: + """Initialise a DotDict instance. Parameters ---------- - *args : tuple - Positional arguments for the dict constructor. - **kwargs : dict - Keyword arguments for the dict constructor. + *args : Any + Arguments to construct the dictionary. + resolve_interpolations : bool, optional + Whether to resolve interpolations, by default False. + cli_arguments : list of str, optional + CLI arguments to override values, by default None. + **kwargs : Any + Keyword arguments to construct the dictionary. 
""" - super().__init__(*args, **kwargs) - for k, v in self.items(): - super().__setitem__(k, self.convert_to_nested_dot_dict(v)) + # Allow non-primitive types like datetime by enabling allow_objects - @staticmethod - def convert_to_nested_dot_dict(value): - if isinstance(value, dict) or is_omegaconf_dict(value): - return DotDict(value) + d = omegaconf.OmegaConf.create(dict(*args, **kwargs), flags={"allow_objects": True}) - if isinstance(value, list) or is_omegaconf_list(value): - return [DotDict(i) if isinstance(i, dict) or is_omegaconf_dict(i) else i for i in value] + if cli_arguments: + d = omegaconf.OmegaConf.merge(d, omegaconf.OmegaConf.from_cli(cli_arguments)) - if isinstance(value, tuple): - return [DotDict(i) if isinstance(i, dict) or is_omegaconf_dict(i) else i for i in value] + if resolve_interpolations: + d = omegaconf.OmegaConf.to_container(d, resolve=True) - return value + return super().__init__(d) @classmethod - def from_file(cls, path: str) -> DotDict: - """Create a DotDict from a file. - - Parameters - ---------- - path : str - The path to the file. - - Returns - ------- - DotDict - The created DotDict. - """ + def from_file( + cls: type["DotDict"], + path: str, + *args: Any, + resolve_interpolations: bool = False, + cli_arguments: list[str] | None = None, + **kwargs: Any, + ) -> "DotDict": + """Create a DotDict from a file.""" _, ext = os.path.splitext(path) - if ext == ".yaml" or ext == ".yml": - return cls.from_yaml_file(path) - elif ext == ".json": - return cls.from_json_file(path) - elif ext == ".toml": - return cls.from_toml_file(path) - else: - raise ValueError(f"Unknown file extension {ext}") - @classmethod - def from_yaml_file(cls, path: str) -> DotDict: - """Create a DotDict from a YAML file. - - Parameters - ---------- - path : str - The path to the YAML file. 
+ match ext: + case ".yaml" | ".yml": + return cls.from_yaml_file( + path, + *args, + resolve_interpolations=resolve_interpolations, + cli_arguments=cli_arguments, + **kwargs, + ) + case ".json": + return cls.from_json_file( + path, + *args, + resolve_interpolations=resolve_interpolations, + cli_arguments=cli_arguments, + **kwargs, + ) + case ".toml": + return cls.from_toml_file( + path, + *args, + resolve_interpolations=resolve_interpolations, + cli_arguments=cli_arguments, + **kwargs, + ) + case _: + raise ValueError(f"Unknown file extension {ext}") - Returns - ------- - DotDict - The created DotDict. - """ + @classmethod + def from_yaml_file( + cls: type["DotDict"], + path: str, + *args: Any, + resolve_interpolations: bool = False, + cli_arguments: list[str] | None = None, + **kwargs: Any, + ) -> "DotDict": + """Create a DotDict from a YAML file.""" with open(path) as file: data = yaml.safe_load(file) - return cls(data) + return cls( + data, + *args, + resolve_interpolations=resolve_interpolations, + cli_arguments=cli_arguments, + **kwargs, + ) @classmethod - def from_json_file(cls, path: str) -> DotDict: - """Create a DotDict from a JSON file. - - Parameters - ---------- - path : str - The path to the JSON file. - - Returns - ------- - DotDict - The created DotDict. - """ + def from_json_file( + cls: type["DotDict"], + path: str, + *args: Any, + resolve_interpolations: bool = False, + cli_arguments: list[str] | None = None, + **kwargs: Any, + ) -> "DotDict": + """Create a DotDict from a JSON file.""" with open(path) as file: data = json.load(file) - return cls(data) + return cls( + data, + *args, + resolve_interpolations=resolve_interpolations, + cli_arguments=cli_arguments, + **kwargs, + ) @classmethod - def from_toml_file(cls, path: str) -> DotDict: - """Create a DotDict from a TOML file. - - Parameters - ---------- - path : str - The path to the TOML file. - - Returns - ------- - DotDict - The created DotDict. 
- """ + def from_toml_file( + cls: type["DotDict"], + path: str, + *args: Any, + resolve_interpolations: bool = False, + cli_arguments: list[str] | None = None, + **kwargs: Any, + ) -> "DotDict": + """Create a DotDict from a TOML file.""" with open(path) as file: data = tomllib.load(file) - return cls(data) - def __getattr__(self, attr: str) -> Any: - """Get an attribute. + return cls( + data, + *args, + resolve_interpolations=resolve_interpolations, + cli_arguments=cli_arguments, + **kwargs, + ) + + def __repr__(self) -> str: + return f"DotDict({super().__repr__()})" + + def as_dict(self, *, resolve_interpolations: bool = True) -> dict: + """Convert the DotDict to a standard dictionary. Parameters ---------- - attr : str - The attribute name. + resolve_interpolations : bool, optional + Whether to resolve any interpolations, by default True. Returns ------- - Any - The attribute value. + dict + The converted dictionary. """ - try: - return self[attr] - except KeyError: - raise AttributeError(attr) - - def __setattr__(self, attr: str, value: Any) -> None: - """Set an attribute. - - Parameters - ---------- - attr : str - The attribute name. - value : Any - The attribute value. - """ - - value = self.convert_to_nested_dot_dict(value) - super().__setitem__(attr, value) - - def __setitem__(self, key: str, value: Any) -> None: - """Set an item in the dictionary. + """Convert the DotDict to a standard dictionary. Parameters ---------- - key : str - The key to set. - value : Any - The value to set. - """ - value = self.convert_to_nested_dot_dict(value) - super().__setitem__(key, value) - - def __repr__(self) -> str: - """Return a string representation of the DotDict. + resolve : bool, optional + Whether to resolve any interpolations, by default False. Returns ------- - str - The string representation. + dict + The converted dictionary. 
""" - return f"DotDict({super().__repr__()})" - - -def is_omegaconf_dict(value: Any) -> bool: - """Check if a value is an OmegaConf DictConfig. - - Parameters - ---------- - value : Any - The value to check. - - Returns - ------- - bool - True if the value is a DictConfig, False otherwise. - """ - try: - from omegaconf import DictConfig - - return isinstance(value, DictConfig) - except ImportError: - return False - - -def is_omegaconf_list(value: Any) -> bool: - """Check if a value is an OmegaConf ListConfig. - - Parameters - ---------- - value : Any - The value to check. - - Returns - ------- - bool - True if the value is a ListConfig, False otherwise. - """ - try: - from omegaconf import ListConfig - - return isinstance(value, ListConfig) - except ImportError: - return False - - -CONFIG = {} -CHECKED = {} -CONFIG_LOCK = threading.RLock() -QUIET = False -CONFIG_PATCH = None - - -def _find(config: dict | list, what: str, result: list = None) -> list: - """Find all occurrences of a key in a nested dictionary or list. - - Parameters - ---------- - config : dict or list - The configuration to search. - what : str - The key to search for. - result : list, optional - The list to store results, by default None. - - Returns - ------- - list - The list of found values. - """ - if result is None: - result = [] - - if isinstance(config, list): - for i in config: - _find(i, what, result) - return result - - if isinstance(config, dict): - if what in config: - result.append(config[what]) - - for k, v in config.items(): - _find(v, what, result) - - return result - - -def _merge_dicts(a: dict, b: dict) -> None: - """Merge two dictionaries recursively. - - Parameters - ---------- - a : dict - The first dictionary. - b : dict - The second dictionary. - """ - for k, v in b.items(): - if k in a and isinstance(a[k], dict) and isinstance(v, dict): - _merge_dicts(a[k], v) - else: - a[k] = v - - -def _set_defaults(a: dict, b: dict) -> None: - """Set default values in a dictionary. 
- - Parameters - ---------- - a : dict - The dictionary to set defaults in. - b : dict - The dictionary with default values. - """ - for k, v in b.items(): - if k in a and isinstance(a[k], dict) and isinstance(v, dict): - _set_defaults(a[k], v) - else: - a.setdefault(k, v) - - -def config_path(name: str = "settings.toml") -> str: - """Get the path to a configuration file. - - Parameters - ---------- - name : str, optional - The name of the configuration file, by default "settings.toml". - - Returns - ------- - str - The path to the configuration file. - """ - global QUIET - - if name.startswith("/") or name.startswith("."): - return name - - if name.startswith("~"): - return os.path.expanduser(name) - - full = os.path.join(os.path.expanduser("~"), ".config", "anemoi", name) - os.makedirs(os.path.dirname(full), exist_ok=True) - - if name == "settings.toml": - old = os.path.join(os.path.expanduser("~"), ".anemoi.toml") - if not os.path.exists(full) and os.path.exists(old): - if not QUIET: - LOG.warning( - "Configuration file found at ~/.anemoi.toml. Please move it to ~/.config/anemoi/settings.toml" - ) - QUIET = True - return old - else: - if os.path.exists(old): - if not QUIET: - LOG.warning( - "Configuration file found at ~/.anemoi.toml and ~/.config/anemoi/settings.toml, ignoring the former" - ) - QUIET = True - - return full + return omegaconf.OmegaConf.to_container(self, resolve=resolve_interpolations) def load_any_dict_format(path: str) -> dict: - """Load a configuration file in any supported format: JSON, YAML and TOML. + """Load a configuration file in any supported format: JSON, YAML, or TOML. Parameters ---------- @@ -425,292 +266,161 @@ def load_any_dict_format(path: str) -> dict: return open(path).read() -def _load_config( - name: str = "settings.toml", - secrets: str | list[str] | None = None, - defaults: str | dict | None = None, -) -> DotDict: - """Load a configuration file. 
+def find(metadata: dict | list, what: str, result: list = None, *, select: callable = None) -> list: + """Find all occurrences of a key in a nested dictionary or list with an optional selector. Parameters ---------- - name : str, optional - The name of the configuration file, by default "settings.toml". - secrets : str or list, optional - The name of the secrets file, by default None. - defaults : str or dict, optional - The name of the defaults file, by default None. + metadata : dict or list + The metadata to search. + what : str + The key to search for. + result : list, optional + The list to store results, by default None. + select : callable, optional + A function to filter the results, by default None. Returns ------- - DotDict - The loaded configuration. + list + The list of found values. """ - key = json.dumps((name, secrets, defaults), sort_keys=True, default=str) - if key in CONFIG: - return CONFIG[key] - - path = config_path(name) - if os.path.exists(path): - config = load_any_dict_format(path) - else: - config = {} - - if defaults is not None: - if isinstance(defaults, str): - defaults = load_raw_config(defaults) - _set_defaults(config, defaults) - - if secrets is not None: - if isinstance(secrets, str): - secrets = [secrets] - - base, ext = os.path.splitext(path) - secret_name = base + ".secrets" + ext - - found = set() - for secret in secrets: - if _find(config, secret): - found.add(secret) - - if found: - check_config_mode(name, secret_name, found) - - check_config_mode(secret_name, None) - secret_config = _load_config(secret_name) - _merge_dicts(config, secret_config) - - if ENV.ANEMOI_CONFIG_OVERRIDE_PATH is not None: - override_config = load_any_dict_format(os.path.abspath(ENV.ANEMOI_CONFIG_OVERRIDE_PATH)) - config = merge_configs(config, override_config) - - for env, value in os.environ.items(): - - if not env.startswith("ANEMOI_CONFIG_"): - continue - rest = env[len("ANEMOI_CONFIG_") :] - - package = rest.split("_")[0] - sub = rest[len(package) 
+ 1 :] + if result is None: + result = [] - package = package.lower() - sub = sub.lower() + if isinstance(metadata, list): + for i in metadata: + find(i, what, result, select=select) + return result - LOG.info(f"Using environment variable {env} to override the anemoi config key '{package}.{sub}'") + if isinstance(metadata, dict): + if what in metadata: + if select is None or select(metadata[what]): + result.append(metadata[what]) - if package not in config: - config[package] = {} - config[package][sub] = value + for k, v in metadata.items(): + find(v, what, result, select=select) - CONFIG[key] = DotDict(config) - return CONFIG[key] + return result -def _save_config(name: str, data: Any) -> None: - """Save a configuration file. +@deprecation.deprecated( + deprecated_in="0.4.30", + removed_in="0.5.0", + current_version=__version__, + details="Use anemoi.utils.settings.temporary_settings instead.", +) +def temporary_config(*args, **kwargs) -> None: + """Deprecated. Use anemoi.utils.settings.temporary_settings instead. Parameters ---------- - name : str - The name of the configuration file. - data : Any - The data to save. + *args : Any + Arguments to pass to temporary_settings. + **kwargs : Any + Keyword arguments to pass to temporary_settings. """ - CONFIG.pop(name, None) - - conf = config_path(name) + from .settings import temporary_settings - if conf.endswith(".json"): - with open(conf, "w") as f: - json.dump(data, f, indent=4) - return + return temporary_settings(*args, **kwargs) - if conf.endswith(".yaml") or conf.endswith(".yml"): - with open(conf, "w") as f: - yaml.dump(data, f) - return - if conf.endswith(".toml"): - raise NotImplementedError("Saving to TOML is not implemented yet") - - with open(conf, "w") as f: - f.write(data) - - -def save_config(name: str, data: Any) -> None: - """Save a configuration file. 
+@deprecation.deprecated( + deprecated_in="0.4.30", + removed_in="0.5.0", + current_version=__version__, + details="Use anemoi.utils.settings.load_settings instead.", +) +def load_config(*args, **kwargs) -> DotDict | str: + """Deprecated. Use anemoi.utils.settings.load_settings instead. Parameters ---------- - name : str - The name of the configuration file to save. - - data : Any - The data to save. - """ - with CONFIG_LOCK: - _save_config(name, data) - - -def load_config( - name: str = "settings.toml", - secrets: str | list[str] | None = None, - defaults: str | dict | None = None, -) -> DotDict | str: - """Read a configuration file. - - Parameters - ---------- - name : str, optional - The name of the config file to read, by default "settings.toml" - secrets : str or list, optional - The name of the secrets file, by default None - defaults : str or dict, optional - The name of the defaults file, by default None + *args : Any + Arguments to pass to load_settings. + **kwargs : Any + Keyword arguments to pass to load_settings. Returns ------- DotDict or str - Return DotDict if it is a dictionary, otherwise the raw data + The loaded configuration. """ + from .settings import load_settings - with CONFIG_LOCK: - config = _load_config(name, secrets, defaults) - if CONFIG_PATCH is not None: - config = CONFIG_PATCH(config) - return config + return load_settings(*args, **kwargs) -def load_raw_config(name: str, default: Any = None) -> DotDict | str: - """Load a raw configuration file. +@deprecation.deprecated( + deprecated_in="0.4.30", + removed_in="0.5.0", + current_version=__version__, + details="Use anemoi.utils.settings.settings_path instead.", +) +def config_path(*args, **kwargs) -> str: + """Deprecated. Use anemoi.utils.settings.settings_path instead. Parameters ---------- - name : str - The name of the configuration file. - default : Any, optional - The default value if the file does not exist, by default None. + *args : Any + Arguments to pass to settings_path. 
+ **kwargs : Any + Keyword arguments to pass to settings_path. Returns ------- - DotDict or str - The loaded configuration or the default value. + str + The settings path. """ - path = config_path(name) - if os.path.exists(path): - return load_any_dict_format(path) + from .settings import settings_path - return default + return settings_path(*args, **kwargs) -def check_config_mode(name: str = "settings.toml", secrets_name: str = None, secrets: list[str] = None) -> None: - """Check that a configuration file is secure. +@deprecation.deprecated( + deprecated_in="0.4.30", + removed_in="0.5.0", + current_version=__version__, + details="Use anemoi.utils.settings.save_settings instead.", +) +def save_config(*args, **kwargs) -> None: + """Deprecated. Use anemoi.utils.settings.save_settings instead. Parameters ---------- - name : str, optional - The name of the configuration file, by default "settings.toml" - secrets_name : str, optional - The name of the secrets file, by default None - secrets : list, optional - The list of secrets to check, by default None - - Raises - ------ - SystemError - If the configuration file is not secure. + *args : Any + Arguments to pass to save_settings. + **kwargs : Any + Keyword arguments to pass to save_settings. """ - with CONFIG_LOCK: - if name in CHECKED: - return - - conf = config_path(name) - if not os.path.exists(conf): - return - mode = os.stat(conf).st_mode - if mode & 0o777 != 0o600: - if secrets_name: - secret_path = config_path(secrets_name) - raise SystemError( - f"Configuration file {conf} should not hold entries {secrets}.\n" - f"Please move them to {secret_path}." 
- ) - raise SystemError(f"Configuration file {conf} is not secure.\n" f"Please run `chmod 600 {conf}`.") + from .settings import save_settings - CHECKED[name] = True + save_settings(*args, **kwargs) -def find(metadata: dict | list, what: str, result: list = None, *, select: callable = None) -> list: - """Find all occurrences of a key in a nested dictionary or list with an optional selector. +@deprecation.deprecated( + deprecated_in="0.4.30", + removed_in="0.5.0", + current_version=__version__, + details="Use anemoi.utils.settings.check_settings_mode instead.", +) +def check_config_mode(*args, **kwargs) -> None: + """Deprecated. Use anemoi.utils.settings.check_settings_mode instead. Parameters ---------- - metadata : dict or list - The metadata to search. - what : str - The key to search for. - result : list, optional - The list to store results, by default None. - select : callable, optional - A function to filter the results, by default None. - - Returns - ------- - list - The list of found values. + *args : Any + Arguments to pass to check_settings_mode. + **kwargs : Any + Keyword arguments to pass to check_settings_mode. """ - if result is None: - result = [] - - if isinstance(metadata, list): - for i in metadata: - find(i, what, result) - return result - - if isinstance(metadata, dict): - if what in metadata: - if select is None or select(metadata[what]): - result.append(metadata[what]) - - for k, v in metadata.items(): - find(v, what, result) - - return result - - -def merge_configs(*configs: dict) -> dict: - """Merge multiple configuration dictionaries. - - Parameters - ---------- - *configs : dict - The configuration dictionaries to merge. - - Returns - ------- - dict - The merged configuration dictionary. 
- """ - result = {} - for config in configs: - _merge_dicts(result, config) - - return result - - -@contextlib.contextmanager -def temporary_config(tmp: dict) -> None: - - global CONFIG_PATCH - - def patch_config(config: dict) -> dict: - return merge_configs(config, tmp) + from .settings import check_settings_mode - with CONFIG_LOCK: + check_settings_mode(*args, **kwargs) - CONFIG_PATCH = patch_config - try: - yield - finally: - CONFIG_PATCH = None +if __name__ == "__main__": + a = DotDict({"a": 1, "b": {"c": 2}, "user": "${oc.env:HOME}"}) + print(a) + print(a.a) + print(a.user) diff --git a/src/anemoi/utils/mlflow/auth.py b/src/anemoi/utils/mlflow/auth.py index 371ee614..50770383 100644 --- a/src/anemoi/utils/mlflow/auth.py +++ b/src/anemoi/utils/mlflow/auth.py @@ -29,11 +29,11 @@ from pydantic import model_validator from requests.exceptions import HTTPError -from ..config import CONFIG_LOCK -from ..config import config_path -from ..config import load_raw_config -from ..config import save_config from ..remote import robust +from ..settings import SETTINGS_LOCK +from ..settings import load_raw_settings +from ..settings import save_settings +from ..settings import settings_path from ..timer import Timer REFRESH_EXPIRE_DAYS = 29 @@ -189,17 +189,17 @@ def refresh_token(self, value: str) -> None: @staticmethod def _get_store() -> ServerStore: """Read the server store from disk.""" - with CONFIG_LOCK: + with SETTINGS_LOCK: file = TokenAuth._config_file - path = config_path(file) + path = settings_path(file) if not os.path.exists(path): - save_config(file, {}) + save_settings(file, {}) if os.path.exists(path) and os.stat(path).st_mode & 0o777 != 0o600: os.chmod(path, 0o600) - return ServerStore(**load_raw_config(file)) + return ServerStore(**load_raw_settings(file)) @staticmethod def get_servers() -> list[tuple[str, int]]: @@ -337,10 +337,10 @@ def save(self, **kwargs: dict) -> None: refresh_expires=self.refresh_expires, ) - with CONFIG_LOCK: + with SETTINGS_LOCK: store 
= self._get_store() store.update(self.url, server_config) - save_config(self._config_file, store.model_dump()) + save_settings(self._config_file, store.model_dump()) expire_date = datetime.fromtimestamp(self.refresh_expires, tz=timezone.utc) self.log.info( diff --git a/src/anemoi/utils/registry.py b/src/anemoi/utils/registry.py index 6a7bbf6d..d09f66ac 100644 --- a/src/anemoi/utils/registry.py +++ b/src/anemoi/utils/registry.py @@ -350,11 +350,17 @@ def from_config(self, config: str | dict[str, Any], *args: Any, **kwargs: Any) - Any The created instance. """ + import omegaconf + if isinstance(config, str): config = {config: {}} + if isinstance(config, omegaconf.DictConfig): + # Allow DotDict and OmegaConf objects + config = omegaconf.OmegaConf.to_container(config, resolve=True) + if not isinstance(config, dict): - raise ValueError(f"Invalid config: {config}") + raise ValueError(f"Invalid config: {config} (type {type(config)})") if self.key in config: config = config.copy() diff --git a/src/anemoi/utils/settings.py b/src/anemoi/utils/settings.py new file mode 100644 index 00000000..40ed495f --- /dev/null +++ b/src/anemoi/utils/settings.py @@ -0,0 +1,424 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+ + +from __future__ import annotations + +import contextlib +import json +import logging +import os +import threading +from typing import Any +from typing import Optional +from typing import Union + +import yaml + +from anemoi.utils import ENV + +from .config import DotDict +from .config import load_any_dict_format + +LOG = logging.getLogger(__name__) + + +CONFIG = {} +CHECKED = {} +SETTINGS_LOCK = threading.RLock() +QUIET = False +CONFIG_PATCH = None + + +def _find(config: Union[dict, list], what: str, result: list = None) -> list: + """Find all occurrences of a key in a nested dictionary or list. + + Parameters + ---------- + config : dict or list + The configuration to search. + what : str + The key to search for. + result : list, optional + The list to store results, by default None. + + Returns + ------- + list + The list of found values. + """ + if result is None: + result = [] + + if isinstance(config, list): + for i in config: + _find(i, what, result) + return result + + if isinstance(config, dict): + if what in config: + result.append(config[what]) + + for k, v in config.items(): + _find(v, what, result) + + return result + + +def _merge_dicts(a: dict, b: dict) -> None: + """Merge two dictionaries recursively. + + Parameters + ---------- + a : dict + The first dictionary. + b : dict + The second dictionary. + """ + for k, v in b.items(): + if k in a and isinstance(a[k], dict) and isinstance(v, dict): + _merge_dicts(a[k], v) + else: + a[k] = v + + +def _set_defaults(a: dict, b: dict) -> None: + """Set default values in a dictionary. + + Parameters + ---------- + a : dict + The dictionary to set defaults in. + b : dict + The dictionary with default values. + """ + for k, v in b.items(): + if k in a and isinstance(a[k], dict) and isinstance(v, dict): + _set_defaults(a[k], v) + else: + a.setdefault(k, v) + + +def merge_configs(*configs: dict) -> dict: + """Merge multiple configuration dictionaries. 
def settings_path(name: str = "settings.toml") -> str:
    """Get the path to a configuration file.

    Names starting with ``/`` or ``.`` are returned unchanged, and names
    starting with ``~`` only have the user directory expanded. Any other
    name is resolved inside ``~/.config/anemoi``; that directory is created
    if it does not exist yet.

    For the default ``settings.toml`` the legacy ``~/.anemoi.toml`` is still
    honoured: it is returned when the new file does not exist, and it is
    ignored when both files exist. Either situation logs a warning unless
    the module-level ``QUIET`` flag is already set; the flag is set after
    the first warning.

    Parameters
    ----------
    name : str, optional
        The name of the configuration file, by default "settings.toml".

    Returns
    -------
    str
        The path to the configuration file.
    """
    global QUIET

    # Explicit absolute/relative paths are taken as given.
    if name.startswith(("/", ".")):
        return name

    if name.startswith("~"):
        return os.path.expanduser(name)

    home = os.path.expanduser("~")
    full = os.path.join(home, ".config", "anemoi", name)
    os.makedirs(os.path.dirname(full), exist_ok=True)

    # Only the default settings file has a legacy location to consider.
    if name != "settings.toml":
        return full

    old = os.path.join(home, ".anemoi.toml")
    if not os.path.exists(old):
        return full

    if not os.path.exists(full):
        # Only the legacy file exists: keep using it, but ask the user to move it.
        if not QUIET:
            LOG.warning("Configuration file found at ~/.anemoi.toml. Please move it to ~/.config/anemoi/settings.toml")
            QUIET = True
        return old

    # Both files exist: the new location wins.
    if not QUIET:
        LOG.warning(
            "Configuration file found at ~/.anemoi.toml and ~/.config/anemoi/settings.toml, ignoring the former"
        )
        QUIET = True

    return full


def _load_settings(
    name: str = "settings.toml",
    secrets: Optional[Union[str, list[str]]] = None,
    defaults: Optional[Union[str, dict]] = None,
) -> DotDict:
    """Load a configuration file.

    The result is cached in ``CONFIG`` under the JSON encoding of
    ``(name, secrets, defaults)``. On a cache miss the configuration is
    built in this order: the file itself (an empty dict if it does not
    exist), the defaults, the companion secrets file, the file named by
    ``ENV.ANEMOI_CONFIG_OVERRIDE_PATH`` and finally ``ANEMOI_CONFIG_*``
    environment variables.

    Parameters
    ----------
    name : str, optional
        The name of the configuration file, by default "settings.toml".
    secrets : str or list, optional
        The name of the secrets file, by default None.
    defaults : str or dict, optional
        The name of the defaults file, by default None.

    Returns
    -------
    DotDict
        The loaded configuration.
    """
    key = json.dumps((name, secrets, defaults), sort_keys=True, default=str)
    if key in CONFIG:
        return CONFIG[key]

    path = settings_path(name)
    config = load_any_dict_format(path) if os.path.exists(path) else {}

    if defaults is not None:
        if isinstance(defaults, str):
            defaults = load_raw_settings(defaults)
        # A named defaults file that does not exist loads as None: nothing to apply.
        if defaults is not None:
            _set_defaults(config, defaults)

    if secrets is not None:
        if isinstance(secrets, str):
            secrets = [secrets]

        # The secrets file sits next to the main file: <base>.secrets<ext>.
        base, ext = os.path.splitext(path)
        secret_name = base + ".secrets" + ext

        # Secrets that `_find` locates in the main file require that file to be protected.
        found = {secret for secret in secrets if _find(config, secret)}
        if found:
            check_settings_mode(name, secret_name, found)

        check_settings_mode(secret_name, None)
        secret_config = _load_settings(secret_name)
        _merge_dicts(config, secret_config)

    if ENV.ANEMOI_CONFIG_OVERRIDE_PATH is not None:
        override_config = load_any_dict_format(os.path.abspath(ENV.ANEMOI_CONFIG_OVERRIDE_PATH))
        config = merge_configs(config, override_config)

    prefix = "ANEMOI_CONFIG_"
    for env, value in os.environ.items():
        if not env.startswith(prefix):
            continue

        # This variable names the override file handled above; it is not a
        # `<package>.<key>` override and must not end up as `override.path`.
        if env == "ANEMOI_CONFIG_OVERRIDE_PATH":
            continue

        # ANEMOI_CONFIG_<PACKAGE>_<KEY...> overrides `<package>.<key...>`.
        package, _, sub = env[len(prefix) :].partition("_")
        package = package.lower()
        sub = sub.lower()

        LOG.info(f"Using environment variable {env} to override the anemoi config key '{package}.{sub}'")

        if package not in config:
            config[package] = {}
        config[package][sub] = value

    CONFIG[key] = DotDict(config)
    return CONFIG[key]


def _forget_cached_settings(name: str) -> None:
    """Drop every entry of ``CONFIG`` that was loaded for ``name``."""
    # Entries possibly stored under the bare name.
    CONFIG.pop(name, None)

    # `_load_settings` keys the cache by the JSON encoding of
    # (name, secrets, defaults); decode each key to compare the name.
    for key in list(CONFIG):
        try:
            cached_name = json.loads(key)[0]
        except (TypeError, ValueError, IndexError, KeyError):
            continue
        if cached_name == name:
            del CONFIG[key]


def _save_settings(name: str, data: Any) -> None:
    """Save a configuration file.

    The output format follows the extension of the resolved path: JSON for
    ``.json``, YAML for ``.yaml``/``.yml``; any other extension writes
    ``data`` as-is. Cached copies of the configuration are discarded so the
    next load reads the new content.

    Parameters
    ----------
    name : str
        The name of the configuration file.
    data : Any
        The data to save.

    Raises
    ------
    NotImplementedError
        If the target is a ``.toml`` file.
    """
    _forget_cached_settings(name)

    conf = settings_path(name)

    if conf.endswith(".json"):
        with open(conf, "w") as f:
            json.dump(data, f, indent=4)
        return

    if conf.endswith((".yaml", ".yml")):
        with open(conf, "w") as f:
            yaml.dump(data, f)
        return

    if conf.endswith(".toml"):
        raise NotImplementedError("Saving to TOML is not implemented yet")

    with open(conf, "w") as f:
        f.write(data)


def save_settings(name: str, data: Any) -> None:
    """Save a configuration file while holding ``SETTINGS_LOCK``.

    Parameters
    ----------
    name : str
        The name of the configuration file to save.

    data : Any
        The data to save.
    """
    with SETTINGS_LOCK:
        _save_settings(name, data)


def load_raw_settings(name: str, default: Any = None) -> Union[DotDict, str]:
    """Load a raw configuration file while holding ``SETTINGS_LOCK``.

    Parameters
    ----------
    name : str
        The name of the configuration file.
    default : Any, optional
        The default value if the file does not exist, by default None.

    Returns
    -------
    DotDict or str
        The loaded configuration or the default value.
    """
    # Same locking discipline as `save_settings` and `load_settings`.
    # NOTE(review): `_load_settings` calls this while the lock is already
    # held, so this relies on SETTINGS_LOCK being reentrant.
    with SETTINGS_LOCK:
        return _load_raw_settings(name, default)


def load_settings(
    name: str = "settings.toml",
    secrets: Optional[Union[str, list[str]]] = None,
    defaults: Optional[Union[str, dict]] = None,
) -> Union[DotDict, str]:
    """Read a configuration file.

    When a temporary patch is active (``CONFIG_PATCH`` is not None), it is
    applied to the loaded configuration before returning.

    Parameters
    ----------
    name : str, optional
        The name of the config file to read, by default "settings.toml"
    secrets : str or list, optional
        The name of the secrets file, by default None
    defaults : str or dict, optional
        The name of the defaults file, by default None

    Returns
    -------
    DotDict or str
        Return DotDict if it is a dictionary, otherwise the raw data
    """
    with SETTINGS_LOCK:
        config = _load_settings(name, secrets, defaults)
        if CONFIG_PATCH is not None:
            config = CONFIG_PATCH(config)
        return config


def _load_raw_settings(name: str, default: Any = None) -> Union[DotDict, str]:
    """Load a raw configuration file, without caching or overrides.

    Parameters
    ----------
    name : str
        The name of the configuration file.
    default : Any, optional
        The default value if the file does not exist, by default None.

    Returns
    -------
    DotDict or str
        The loaded configuration or the default value.
    """
    path = settings_path(name)
    if os.path.exists(path):
        return load_any_dict_format(path)

    return default


def check_settings_mode(
    name: str = "settings.toml",
    secrets_name: Optional[str] = None,
    secrets: Optional[list[str]] = None,
) -> None:
    """Check that a configuration file is secure.

    A file passes when it does not exist or when its permission bits are
    exactly ``0o600``. A name that passed once is remembered in ``CHECKED``
    and not checked again.

    Parameters
    ----------
    name : str, optional
        The name of the configuration file, by default "settings.toml"
    secrets_name : str, optional
        The name of the secrets file, by default None
    secrets : list, optional
        The list of secrets to check, by default None

    Raises
    ------
    SystemError
        If the configuration file is not secure.
    """
    with SETTINGS_LOCK:
        if name in CHECKED:
            return

        conf = settings_path(name)
        if not os.path.exists(conf):
            return

        mode = os.stat(conf).st_mode
        if mode & 0o777 != 0o600:
            if secrets_name:
                # The unprotected file holds secrets: point to the dedicated secrets file.
                secret_path = settings_path(secrets_name)
                raise SystemError(
                    f"Configuration file {conf} should not hold entries {secrets}.\n"
                    f"Please move them to {secret_path}."
                )
            raise SystemError(f"Configuration file {conf} is not secure.\n" f"Please run `chmod 600 {conf}`.")

        CHECKED[name] = True


@contextlib.contextmanager
def temporary_settings(tmp: dict) -> None:
    """Temporarily overlay ``tmp`` on every configuration returned by ``load_settings``.

    ``SETTINGS_LOCK`` is held for the whole duration of the ``with`` body.
    On exit the patch that was active on entry is restored, so nested uses
    do not discard the outer patch.

    Parameters
    ----------
    tmp : dict
        The settings to merge on top of the loaded configuration.
    """
    global CONFIG_PATCH

    def patch_config(config: dict) -> dict:
        return merge_configs(config, tmp)

    with SETTINGS_LOCK:
        previous = CONFIG_PATCH
        CONFIG_PATCH = patch_config

        try:
            yield
        finally:
            # Restore rather than reset to None, so an enclosing patch survives.
            CONFIG_PATCH = previous
def test_interpolation() -> None:
    """Check that DotDict resolves OmegaConf-style interpolations.

    Tests
    -----
    - A value referring to another key through a nested reference is resolved.
    """
    # Example taken from the omegaconf documentation
    plans = {"A": "plan A", "B": "plan B"}
    source = {
        "plans": plans,
        "selected_plan": "A",
        "plan": "${plans[${selected_plan}]}",
    }
    expected = {"plan": "plan A", "plans": {"A": "plan A", "B": "plan B"}, "selected_plan": "A"}

    resolved = DotDict(source).as_dict()

    assert resolved == expected


def test_cli_arguments() -> None:
    """Check that CLI arguments passed to DotDict override its values.

    Tests
    -----
    - A top-level and a nested value are overridden.
    - Values not named on the command line are left alone.
    """
    overrides = ["a=10", "c.d=30"]
    base = {"a": 1, "b": 2, "c": {"d": 3, "e": 4}}

    d = DotDict(base, cli_arguments=overrides)

    # Overridden values
    assert d.a == 10
    assert d.b == 2
    # Nested values
    assert d.c.d == 30
    assert d.c.e == 4


def test_non_primitive_types() -> None:
    """Check that DotDict stores non-primitive values such as datetime.

    Tests
    -----
    - A datetime given to the constructor, set inside a dict attribute, and
      set inside a dict nested in a list is read back unchanged.
    """
    from datetime import datetime

    now = datetime.now()

    # Directly as a constructor keyword
    d = DotDict(a=now)
    assert d.a == now

    # Inside a dict assigned as an attribute
    d.b = {"time": now}
    assert d.b.time == now

    # Inside a dict nested in a list, on a fresh DotDict
    d = DotDict()
    d.c = [1, 2, {"time": now}]
    assert d.c[2].time == now
def test_grib() -> None:
    """Check the GRIB parameter lookups in both directions.

    Tests:
    - The short name "2t" maps to parameter ID 167.
    - Parameter ID 167 maps back to the short name "2t".
    """
    short_name, param_id = "2t", 167

    assert shortname_to_paramid(short_name) == param_id
    assert paramid_to_shortname(param_id) == short_name


if __name__ == "__main__":
    # Run all test functions.
    tests = [(n, f) for n, f in list(globals().items()) if n.startswith("test_") and callable(f)]
    for test_name, test_func in tests:
        print(f"Running {test_name}...")
        test_func()
- Ensure that CONFIG_LOCK stays a reentrant lock, if it were a normal lock it would deadlock itself. + """TokenAuth uses the utils SETTINGS_LOCK when reading and writing the server store to ensure thread safety. + Ensure that SETTINGS_LOCK stays a reentrant lock, if it were a normal lock it would deadlock itself. """ from threading import RLock - from anemoi.utils.config import CONFIG_LOCK + from anemoi.utils.settings import SETTINGS_LOCK - assert isinstance(CONFIG_LOCK, type(RLock())) + assert isinstance(SETTINGS_LOCK, type(RLock()))