Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions config/eval_config.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
evaluation:
metrics : ["rmse", "l1", "mse"]

run_ids :
dw3acg5s: #uwbvdtge:
streams:
ERA5:
channels: ["2t"] #, "10u", "z_500"]
evaluation:
forecast_step: "all"
sample: "all"
label: "MTM ERA5"
epoch: 0
rank: 0
24 changes: 17 additions & 7 deletions packages/evaluate/src/weathergen/evaluate/run_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@
from collections import defaultdict
from pathlib import Path

from dask_jobqueue import SLURMCluster
from dask.distributed import Client



from omegaconf import OmegaConf

from weathergen.common.config import _REPO_ROOT
Expand All @@ -31,12 +36,12 @@
_DEFAULT_PLOT_DIR = _REPO_ROOT / "plots"


def evaluate() -> None:
def evaluate(client=None) -> None:
# By default, arguments from the command line are read.
evaluate_from_args(sys.argv[1:])
evaluate_from_args(sys.argv[1:], client = client)


def evaluate_from_args(argl: list[str]) -> None:
def evaluate_from_args(argl: list[str], client=None) -> None:
parser = argparse.ArgumentParser(
description="Fast evaluation of WeatherGenerator runs."
)
Expand All @@ -47,10 +52,10 @@ def evaluate_from_args(argl: list[str]) -> None:
)

args = parser.parse_args(argl)
evaluate_from_config(OmegaConf.load(args.config))
evaluate_from_config(OmegaConf.load(args.config), client = client)


def evaluate_from_config(cfg):
def evaluate_from_config(cfg, client = None):
# configure logging
logging.basicConfig(level=logging.INFO)

Expand Down Expand Up @@ -101,6 +106,7 @@ def evaluate_from_config(cfg):
stream,
region,
metric,
client=client
)

available_data = reader.check_availability(
Expand All @@ -123,7 +129,7 @@ def evaluate_from_config(cfg):

if metrics_to_compute:
all_metrics, points_per_sample = calc_scores_per_stream(
reader, stream, region, metrics_to_compute
reader, stream, region, metrics_to_compute, client = client
)

metric_list_to_json(
Expand All @@ -146,4 +152,8 @@ def evaluate_from_config(cfg):


if __name__ == "__main__":
evaluate()
cluster = SLURMCluster(
queue='devel', account='weatherai', processes=24, cores=48, memory='96GB', interface='ib0', walltime='00:20:00')
client = Client(cluster)
cluster.scale(72)
evaluate(client = client)
44 changes: 32 additions & 12 deletions packages/evaluate/src/weathergen/evaluate/score.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@


def _get_skill_score(
score_fcst: xr.DataArray, score_ref: xr.DataArray, score_perf: float
score_fcst: xr.DataArray, score_ref: xr.DataArray, score_perf: float, client=None
) -> xr.DataArray:
"""
Calculate the skill score of a forecast data array w.r.t. a reference and a perfect score.
Expand Down Expand Up @@ -104,6 +104,7 @@ def get_score(
group_by_coord: str | None = None,
ens_dim: str = "ens",
compute: bool = False,
client=None,
**kwargs,
) -> xr.DataArray:
"""
Expand Down Expand Up @@ -135,10 +136,11 @@ def get_score(
sc = Scores(agg_dims=agg_dims, ens_dim=ens_dim)

score_data = sc.get_score(data, score_name, group_by_coord, **kwargs)

if compute:
# If compute is True, compute the score immediately
return score_data.compute()

return score_data


Expand Down Expand Up @@ -198,6 +200,7 @@ def get_score(
score_name: str,
group_by_coord: str | None = None,
compute: bool = False,
client=None,
**kwargs,
):
"""
Expand Down Expand Up @@ -291,7 +294,7 @@ def get_score(
# Return lazy evaluation result
return result

def _validate_agg_dims(self, dims: str | list[str]) -> list[str] | str:
def _validate_agg_dims(self, dims: str | list[str], client=None) -> list[str] | str:
if dims == "all":
return dims
if isinstance(dims, str):
Expand All @@ -300,13 +303,13 @@ def _validate_agg_dims(self, dims: str | list[str]) -> list[str] | str:
return dims
raise ValueError("agg_dims must be 'all', a string, or list of strings.")

def _validate_ens_dim(self, dim: str) -> str:
def _validate_ens_dim(self, dim: str, client=None) -> str:
if not isinstance(dim, str):
raise ValueError("ens_dim must be a string.")
return dim

def _validate_groupby_coord(
self, data: VerifiedData, group_by_coord: str | None
self, data: VerifiedData, group_by_coord: str | None, client=None
) -> bool:
"""
Check if the group_by_coord is present in both prediction and ground truth data and compatible.
Expand Down Expand Up @@ -350,7 +353,7 @@ def _validate_groupby_coord(
)
return False

def _sum(self, data: xr.DataArray) -> xr.DataArray:
def _sum(self, data: xr.DataArray, client=None) -> xr.DataArray:
"""
Sum data over aggregation dimensions.

Expand All @@ -366,7 +369,7 @@ def _sum(self, data: xr.DataArray) -> xr.DataArray:
"""
return data.sum(dim=self._agg_dims)

def _mean(self, data: xr.DataArray) -> xr.DataArray:
def _mean(self, data: xr.DataArray, client=None) -> xr.DataArray:
"""
Average data over aggregation dimensions.

Expand All @@ -388,6 +391,7 @@ def get_2x2_event_counts(
gt: xr.DataArray,
thresh: float,
group_by_coord: str | None = None,
client=None
) -> tuple[xr.DataArray, xr.DataArray, xr.DataArray, xr.DataArray]:
"""
Get counts of 2x2 contingency tables
Expand All @@ -411,6 +415,7 @@ def calc_ets(
gt: xr.DataArray,
group_by_coord: str | None = None,
thresh: float = 0.1,
client=None
):
a, b, c, d = self.get_2x2_event_counts(p, gt, thresh, group_by_coord)
n = a + b + c + d
Expand All @@ -429,6 +434,7 @@ def calc_fbi(
gt: xr.DataArray,
group_by_coord: str | None = None,
thresh: float = 0.1,
client=None
):
a, b, c, _ = self.get_2x2_event_counts(p, gt, thresh, group_by_coord)

Expand All @@ -445,6 +451,7 @@ def calc_pss(
gt: xr.DataArray,
group_by_coord: str | None = None,
thresh: float = 0.1,
client=None
):
a, b, c, d = self.get_2x2_event_counts(p, gt, thresh, group_by_coord)

Expand All @@ -461,6 +468,7 @@ def calc_l1(
gt: xr.DataArray,
group_by_coord: str | None = None,
scale_dims: list | None = None,
client=None
):
"""
Calculate the L1 error norm of forecast data w.r.t. reference data.
Expand Down Expand Up @@ -493,6 +501,7 @@ def calc_l2(
group_by_coord: str | None = None,
scale_dims: list | None = None,
squared_l2: bool = False,
client=None
):
"""
Calculate the L2 error norm of forecast data w.r.t. reference data.
Expand Down Expand Up @@ -537,7 +546,7 @@ def calc_l2(
return l2

def calc_mae(
self, p: xr.DataArray, gt: xr.DataArray, group_by_coord: str | None = None
self, p: xr.DataArray, gt: xr.DataArray, group_by_coord: str | None = None, client=None
):
"""
Calculate mean absolute error (MAE) of forecast data w.r.t. reference data.
Expand Down Expand Up @@ -566,7 +575,7 @@ def calc_mae(
return mae

def calc_mse(
self, p: xr.DataArray, gt: xr.DataArray, group_by_coord: str | None = None
self, p: xr.DataArray, gt: xr.DataArray, group_by_coord: str | None = None, client=None
):
"""
Calculate mean squared error (MSE) of forecast data w.r.t. reference data.
Expand Down Expand Up @@ -595,7 +604,7 @@ def calc_mse(
return mse

def calc_rmse(
self, p: xr.DataArray, gt: xr.DataArray, group_by_coord: str | None = None
self, p: xr.DataArray, gt: xr.DataArray, group_by_coord: str | None = None, client=None
):
"""
Calculate root mean squared error (RMSE) of forecast data w.r.t. reference data
Expand All @@ -622,6 +631,7 @@ def calc_change_rate(
self,
s0: xr.DataArray,
s1: xr.DataArray,
client=None
):
"""
Calculate the "change rate" of a data array as the mean absolute difference between two consecutive time steps.
Expand Down Expand Up @@ -652,6 +662,7 @@ def calc_froct(
p_next: xr.DataArray,
gt_next: xr.DataArray,
group_by_coord: str | None = None,
client=None
):
"""
Calculate forecast rate of change over time
Expand Down Expand Up @@ -691,6 +702,7 @@ def calc_troct(
gt_next: xr.DataArray,
p_next: xr.DataArray,
group_by_coord: str | None = None,
client=None
):
"""
Calculate target rate of change over time
Expand Down Expand Up @@ -730,6 +742,7 @@ def calc_acc(
clim_mean: xr.DataArray,
group_by_coord: str | None = None,
spatial_dims: list = None,
client=None
):
"""
Calculate anomaly correlation coefficient (ACC).
Expand Down Expand Up @@ -781,7 +794,8 @@ def calc_acc(
return acc

def calc_bias(
self, p: xr.DataArray, gt: xr.DataArray, group_by_coord: str | None = None
self, p: xr.DataArray, gt: xr.DataArray, group_by_coord: str | None = None,
client=None
):
"""
Calculate mean bias of forecast data w.r.t. reference data
Expand Down Expand Up @@ -811,6 +825,7 @@ def calc_psnr(
gt: xr.DataArray,
group_by_coord: str | None = None,
pixel_max: float = 1.0,
client=None
):
"""
Calculate PSNR of forecast data w.r.t. reference data
Expand Down Expand Up @@ -844,6 +859,7 @@ def calc_spatial_variability(
group_by_coord: str | None = None,
order: int = 1,
non_spatial_avg_dims: list[str] = None,
client=None
):
"""
Calculates the ratio between the spatial variability of differential operator with order 1 (higher values unsupported yet)
Expand Down Expand Up @@ -892,6 +908,7 @@ def calc_seeps(
t3: xr.DataArray,
spatial_dims: list,
group_by_coord: str | None = None,
client=None
):
"""
Calculates stable equitable error in probability space (SEEPS), see Rodwell et al., 2011
Expand Down Expand Up @@ -1019,7 +1036,7 @@ def calc_spread(self, p: xr.DataArray, group_by_coord: str | None = None):
return self._mean(np.sqrt(ens_std**2))

def calc_ssr(
self, p: xr.DataArray, gt: xr.DataArray, group_by_coord: str | None = None
self, p: xr.DataArray, gt: xr.DataArray, group_by_coord: str | None = None, client=None
):
"""
Calculate the Spread-Skill Ratio (SSR) of the forecast ensemble data w.r.t. reference data
Expand All @@ -1046,6 +1063,7 @@ def calc_crps(
gt: xr.DataArray,
group_by_coord: str | None = None,
method: str = "ensemble",
client=None,
**kwargs,
):
"""
Expand Down Expand Up @@ -1110,6 +1128,7 @@ def calc_rank_histogram(
norm: bool = True,
add_noise: bool = True,
noise_fac=1.0e-03,
client=None
):
"""
Calculate the rank histogram of the forecast data w.r.t. reference data.
Expand Down Expand Up @@ -1213,6 +1232,7 @@ def calc_geo_spatial_diff(
order: int = 1,
r_e: float = 6371.0e3,
dom_avg: bool = True,
client=None
):
"""
Calculates the amplitude of the gradient (order=1) or the Laplacian (order=2)
Expand Down
Loading
Loading