Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion packages/common/src/weathergen/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,18 @@
import os
import subprocess
from pathlib import Path
import string
import random

import yaml
from omegaconf import DictConfig, ListConfig, OmegaConf
from omegaconf.omegaconf import open_dict

from weathergen.train.utils import get_run_id
def get_run_id():
    """
    Build a random 8-character run identifier.

    The first character is a lowercase ASCII letter; the remaining seven are
    drawn from lowercase ASCII letters and digits. Both draws use
    ``random.sample``, i.e. sampling without replacement, so the seven trailing
    characters are pairwise distinct.
    """
    letters = string.ascii_lowercase
    alphanumerics = letters + string.digits
    # Draw the leading letter first, then the suffix, so the identifier always
    # starts with a letter.
    head = random.sample(letters, 1)
    tail = random.sample(alphanumerics, 7)
    return "".join(head + tail)


_REPO_ROOT = Path(
__file__
Expand Down Expand Up @@ -281,6 +287,7 @@ def _load_overwrite_conf(overwrite: Path | dict | DictConfig) -> DictConfig:
return overwrite_config



def _load_private_conf(private_home: Path | None = None) -> DictConfig:
"Return the private configuration."
"If none, take it from the environment variable WEATHERGEN_PRIVATE_CONF."
Expand Down
4 changes: 3 additions & 1 deletion packages/common/src/weathergen/common/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import typing
from copy import deepcopy

import astropy_healpix as hp
import dask.array as da
import numpy as np
import xarray as xr
Expand Down Expand Up @@ -75,6 +74,9 @@ def spoof(
Spoof an instance from data_reader_base.ReaderData instance.
other should be such an instance.
"""
# TODO: do we need it in common package?
import astropy_healpix as hp


hl = 5
dx = 0.5
Expand Down
55 changes: 55 additions & 0 deletions packages/common/src/weathergen/common/platform_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
"""
Platform environment configuration for WeatherGenerator.

These are loaded from secrets in the private repository.
"""

import importlib
from pathlib import Path
from typing import Protocol
from omegaconf import OmegaConf
from weathergen.common.config import _REPO_ROOT
from functools import lru_cache


class PlatformEnv(Protocol):
"""
Interface for platform environment configuration.
"""

def get_hpc(self) -> str | None:
...

def get_hpc_user(self) -> str | None:
...

def get_hpc_config(self) -> str | None:
...

def get_hpc_certificate(self) -> str | None:
...


# def get_private_conf() -> OmegaConf:
# """
# Loads the private configuration from the private repository.
# Excludes secrets.

# When in doubt, use this function.
# """


@lru_cache(maxsize=1)
def get_platform_env() -> PlatformEnv:
    """
    Load the platform environment module from the private repository.

    The script ``WeatherGenerator-private/hpc/platform-env.py`` is looked up next to
    the repository root (``_REPO_ROOT.parent``), executed as a module named
    ``platform_env`` and returned. The result is cached, so the script is executed
    at most once per process.

    Returns:
        The loaded module object, used through the ``PlatformEnv`` interface.

    Raises:
        FileNotFoundError: if the private platform script does not exist.
        ImportError: if no import spec or loader can be created for the script.
    """
    # `import importlib` alone does not guarantee that the `importlib.util`
    # submodule is loaded; import it explicitly before using it.
    import importlib.util

    env_script_path = _REPO_ROOT.parent / "WeatherGenerator-private" / "hpc" / "platform-env.py"
    if not env_script_path.is_file():
        raise FileNotFoundError(
            f"Platform environment script not found: {env_script_path}. "
            "Is the WeatherGenerator-private repository checked out next to this repository?"
        )

    spec = importlib.util.spec_from_file_location("platform_env", env_script_path)
    # spec_from_file_location may return None, and a spec may carry no loader.
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot create an import spec for {env_script_path}")

    platform_env = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(platform_env)
    return platform_env


if __name__ == "__main__":
    # Manual smoke check: load the private platform module and print the value
    # returned by its get_hpc() accessor.
    env = get_platform_env()
    print(f"Loaded platform environment: {env.get_hpc()}")
3 changes: 2 additions & 1 deletion packages/evaluate/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@ dependencies = [
"xhistogram",
"panel",
"omegaconf",
"weathergen-common",
"plotly>=6.2.0",
"weathergen-common",
"weathergen-metrics",
]

[dependency-groups]
Expand Down
83 changes: 81 additions & 2 deletions packages/evaluate/src/weathergen/evaluate/run_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# dependencies = [
# "weathergen-evaluate",
# "weathergen-common",
# "weathergen-metrics",
# ]
# [tool.uv.sources]
# weathergen-evaluate = { path = "../../../../../packages/evaluate" }
Expand All @@ -14,22 +15,34 @@
from collections import defaultdict
from pathlib import Path

import mlflow
import numpy as np
from omegaconf import OmegaConf

from weathergen.common.config import _REPO_ROOT
from weathergen.common.platform_env import get_platform_env
from weathergen.evaluate.io_reader import WeatherGenReader
from weathergen.evaluate.plot_utils import collect_channels, collect_streams
from weathergen.evaluate.utils import (
calc_scores_per_stream,
metric_list_to_json,
plot_data,
plot_summary,
retrieve_metric_from_json,
)
from weathergen.metrics.mlflow_utils import (
MlFlowUpload,
get_or_create_mlflow_parent_run,
log_scores,
setup_mlflow,
)

_logger = logging.getLogger(__name__)

_DEFAULT_PLOT_DIR = _REPO_ROOT / "plots"

_platform_env = get_platform_env()


def evaluate() -> None:
# By default, arguments from the command line are read.
Expand All @@ -44,6 +57,12 @@ def evaluate_from_args(argl: list[str]) -> None:
default=None,
help="Path to the configuration yaml file for plotting. e.g. config/plottig_config.yaml",
)
parser.add_argument(
"--push-metrics",
required=False,
action="store_true",
help="(optional) Upload scores to MLFlow.",
)

args = parser.parse_args(argl)
if args.config:
Expand All @@ -53,10 +72,20 @@ def evaluate_from_args(argl: list[str]) -> None:
"No config file provided, using the default template config (please edit accordingly)"
)
config = Path(_REPO_ROOT / "config" / "evaluate" / "eval_config.yml")
evaluate_from_config(OmegaConf.load(config))
mlflow_client = None
if args.push_metrics:
# logging.basicConfig(level=logging.INFO)
hpc_conf = _platform_env.get_hpc_config()
assert hpc_conf is not None
private_home = Path(hpc_conf)
private_cf = OmegaConf.load(private_home)
mlflow_client = setup_mlflow(private_cf)
_logger.info(f"MLFlow client set up: {mlflow_client}")

evaluate_from_config(OmegaConf.load(config), mlflow_client)


def evaluate_from_config(cfg):
def evaluate_from_config(cfg, mlflow_client):
# configure logging
logging.basicConfig(level=logging.INFO)

Expand Down Expand Up @@ -148,6 +177,56 @@ def evaluate_from_config(cfg):
{"metric": metric}
)

if mlflow_client:
parent_run = get_or_create_mlflow_parent_run(mlflow_client, run_id)
_logger.info(f"MLFlow parent run: {parent_run}")
phase = "eval"

for region in regions:
for metric in metrics:
streams_set = collect_streams(runs)
channels_set = collect_channels(scores_dict, metric, region, runs)

for stream in streams_set:
for ch in channels_set:
data = scores_dict[metric][region][stream][run_id]
# skip if channel is missing or contains NaN
if ch not in np.atleast_1d(data.channel.values) or data.isnull().all():
continue
_logger.info(
f"Uploading data for {metric} - {region} - {stream} - {ch}."
)

x_dim = "forecast_step"
non_zero_dims = [
dim for dim in data.dims if dim != x_dim and data[dim].shape[0] > 1
]
if "ens" in non_zero_dims:
_logger.info("Uploading ensembles not yet imnplemented")
else:
if non_zero_dims:
_logger.info(
f"LinePlot:: Found multiple entries for dimensions: {non_zero_dims}. Averaging..."
)
averaged = data.mean(
dim=[dim for dim in data.dims if dim != x_dim], skipna=True
).sortby(x_dim)
label = f"score.{region}.{metric}.{stream}.{ch}"
with mlflow.start_run(run_id=parent_run.info.run_id):
with mlflow.start_run(
run_name=f"{phase}_{run_id}",
parent_run_id=parent_run.info.run_id,
nested=True,
) as run:
mlflow.set_tags(MlFlowUpload.run_tags(run_id, phase))
log_scores(
averaged[x_dim].values[:4],
averaged.values[:4],
label,
mlflow_client,
run.info.run_id,
)

# plot summary
if scores_dict and cfg.evaluation.get("summary_plots", True):
_logger.info("Started creating summary plots..")
Expand Down
102 changes: 102 additions & 0 deletions packages/metrics/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
[project]
name = "weathergen-metrics"
version = "0.1.0"
description = "The WeatherGenerator Machine Learning Earth System Model"
readme = "../../README.md"
requires-python = ">=3.12,<3.13"
dependencies = [
"mlflow",
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe mlflow-skinny would work

"weathergen-common",
]

[dependency-groups]
dev = [
"pytest~=8.3.5",
"pytest-mock>=3.14.1",
"ruff==0.9.7",
"pyrefly==0.33.0",
]


[tool.pyrefly]
project-includes = ["src/"]
project-excludes = [
]

# Per-error-kind switches for the pyrefly type checker.
# Every kind listed here is set to `false`, i.e. not reported for this package.
[tool.pyrefly.errors]
bad-argument-type = false
unsupported-operation = false
missing-attribute = false
no-matching-overload = false
bad-context-manager = false

# To do:
bad-assignment = false
bad-return = false
index-error = false
not-iterable = false
not-callable = false




# The linting configuration
[tool.ruff]

# Wide rows
line-length = 100

[tool.ruff.lint]
# All disabled until the code is formatted.
select = [
# pycodestyle
"E",
# Pyflakes
"F",
# pyupgrade
"UP",
# flake8-bugbear
"B",
# flake8-simplify
"SIM",
# isort
"I",
# Banned imports
"TID",
# Naming conventions
"N",
# print
"T201"
]

# These rules are sensible and should be enabled at a later stage.
ignore = [
# "B006",
"B011",
"UP008",
"SIM117",
"SIM118",
"SIM102",
"SIM401",
# To ignore, not relevant for us
"SIM108", # in case additional norm layer supports are added in future
"N817", # we use heavy acronyms, e.g., allowing 'import LongModuleName as LMN' (LMN is accepted)
"E731", # overly restrictive and less readable code
"N812", # prevents us following the convention for importing torch.nn.functional as F
]

[tool.ruff.lint.flake8-tidy-imports.banned-api]
"numpy.ndarray".msg = "Do not use 'ndarray' to describe a numpy array type, it is a function. Use numpy.typing.NDArray or numpy.typing.NDArray[np.float32] for example"

[tool.ruff.format]
# Use Unix `\n` line endings for all files
line-ending = "lf"



[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.hatch.build.targets.wheel]
packages = ["src/weathergen"]
Empty file.
Loading