diff --git a/config/eval_config.yml b/config/eval_config.yml
new file mode 100644
index 000000000..9060da328
--- /dev/null
+++ b/config/eval_config.yml
@@ -0,0 +1,14 @@
+evaluation:
+  metrics: ["rmse", "l1", "mse"]
+
+run_ids:
+  dw3acg5s: #uwbvdtge:
+    streams:
+      ERA5:
+        channels: ["2t"] #, "10u", "z_500"]
+        evaluation:
+          forecast_step: "all"
+          sample: "all"
+    label: "MTM ERA5"
+    epoch: 0
+    rank: 0
diff --git a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py
index adb04aa54..bc0888e81 100755
--- a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py
+++ b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py
@@ -14,6 +14,11 @@
 from collections import defaultdict
 from pathlib import Path
 
+from dask_jobqueue import SLURMCluster
+from dask.distributed import Client
+
+
+
 from omegaconf import OmegaConf
 
 from weathergen.common.config import _REPO_ROOT
@@ -31,12 +36,12 @@
 _DEFAULT_PLOT_DIR = _REPO_ROOT / "plots"
 
 
-def evaluate() -> None:
+def evaluate(client=None) -> None:
     # By default, arguments from the command line are read.
-    evaluate_from_args(sys.argv[1:])
+    evaluate_from_args(sys.argv[1:], client=client)
 
 
-def evaluate_from_args(argl: list[str]) -> None:
+def evaluate_from_args(argl: list[str], client=None) -> None:
     parser = argparse.ArgumentParser(
         description="Fast evaluation of WeatherGenerator runs."
     )
@@ -47,10 +52,10 @@ def evaluate_from_args(argl: list[str]) -> None:
     )
     args = parser.parse_args(argl)
 
-    evaluate_from_config(OmegaConf.load(args.config))
+    evaluate_from_config(OmegaConf.load(args.config), client=client)
 
 
-def evaluate_from_config(cfg):
+def evaluate_from_config(cfg, client=None):
     # configure logging
     logging.basicConfig(level=logging.INFO)
 
@@ -101,6 +106,7 @@ def evaluate_from_config(cfg):
                         stream,
                         region,
                         metric,
+                        client=client,
                     )
 
                     available_data = reader.check_availability(
@@ -123,7 +129,7 @@ def evaluate_from_config(cfg):
 
                 if metrics_to_compute:
                     all_metrics, points_per_sample = calc_scores_per_stream(
-                        reader, stream, region, metrics_to_compute
+                        reader, stream, region, metrics_to_compute, client=client
                     )
 
                     metric_list_to_json(
@@ -146,4 +152,8 @@ def evaluate_from_config(cfg):
 
 
 if __name__ == "__main__":
-    evaluate()
+    cluster = SLURMCluster(queue="devel", account="weatherai", processes=24, cores=48,
+                           memory="96GB", interface="ib0", walltime="00:20:00")
+    client = Client(cluster)
+    cluster.scale(72)
+    evaluate(client=client)
diff --git a/packages/evaluate/src/weathergen/evaluate/score.py b/packages/evaluate/src/weathergen/evaluate/score.py
index 439a2868a..40fef2dec 100755
--- a/packages/evaluate/src/weathergen/evaluate/score.py
+++ b/packages/evaluate/src/weathergen/evaluate/score.py
@@ -35,7 +35,7 @@
 def _get_skill_score(
-    score_fcst: xr.DataArray, score_ref: xr.DataArray, score_perf: float
+    score_fcst: xr.DataArray, score_ref: xr.DataArray, score_perf: float, client=None
 ) -> xr.DataArray:
     """
     Calculate the skill score of a forecast data array w.r.t. a reference and a perfect score.
@@ -104,6 +104,7 @@ def get_score(
     group_by_coord: str | None = None,
     ens_dim: str = "ens",
     compute: bool = False,
+    client=None,
     **kwargs,
 ) -> xr.DataArray:
     """
@@ -135,10 +136,11 @@ def get_score(
     sc = Scores(agg_dims=agg_dims, ens_dim=ens_dim)
     score_data = sc.get_score(data, score_name, group_by_coord, **kwargs)
+
     if compute:
         # If compute is True, compute the score immediately
         return score_data.compute()
-
+
     return score_data
 
 
@@ -198,6 +200,7 @@ def get_score(
         score_name: str,
         group_by_coord: str | None = None,
         compute: bool = False,
+        client=None,
         **kwargs,
     ):
         """
@@ -291,7 +294,7 @@ def get_score(
         # Return lazy evaluation result
         return result
 
-    def _validate_agg_dims(self, dims: str | list[str]) -> list[str] | str:
+    def _validate_agg_dims(self, dims: str | list[str], client=None) -> list[str] | str:
         if dims == "all":
             return dims
         if isinstance(dims, str):
@@ -300,13 +303,13 @@ def _validate_agg_dims(self, dims: str | list[str]) -> list[str] | str:
             return dims
         raise ValueError("agg_dims must be 'all', a string, or list of strings.")
 
-    def _validate_ens_dim(self, dim: str) -> str:
+    def _validate_ens_dim(self, dim: str, client=None) -> str:
         if not isinstance(dim, str):
            raise ValueError("ens_dim must be a string.")
         return dim
 
     def _validate_groupby_coord(
-        self, data: VerifiedData, group_by_coord: str | None
+        self, data: VerifiedData, group_by_coord: str | None, client=None
     ) -> bool:
         """
         Check if the group_by_coord is present in both prediction and ground truth data and is compatible.
@@ -350,7 +353,7 @@ def _validate_groupby_coord(
             )
             return False
 
-    def _sum(self, data: xr.DataArray) -> xr.DataArray:
+    def _sum(self, data: xr.DataArray, client=None) -> xr.DataArray:
         """
         Sum data over aggregation dimensions.
 
@@ -366,7 +369,7 @@ def _sum(self, data: xr.DataArray) -> xr.DataArray:
         """
         return data.sum(dim=self._agg_dims)
 
-    def _mean(self, data: xr.DataArray) -> xr.DataArray:
+    def _mean(self, data: xr.DataArray, client=None) -> xr.DataArray:
         """
         Average data over aggregation dimensions.
 
@@ -388,6 +391,7 @@ def get_2x2_event_counts(
         gt: xr.DataArray,
         thresh: float,
         group_by_coord: str | None = None,
+        client=None,
     ) -> tuple[xr.DataArray, xr.DataArray, xr.DataArray, xr.DataArray]:
         """
         Get counts of 2x2 contingency tables
@@ -411,6 +415,7 @@ def calc_ets(
         gt: xr.DataArray,
         group_by_coord: str | None = None,
         thresh: float = 0.1,
+        client=None,
     ):
         a, b, c, d = self.get_2x2_event_counts(p, gt, thresh, group_by_coord)
         n = a + b + c + d
@@ -429,6 +434,7 @@ def calc_fbi(
         gt: xr.DataArray,
         group_by_coord: str | None = None,
         thresh: float = 0.1,
+        client=None,
     ):
 
         a, b, c, _ = self.get_2x2_event_counts(p, gt, thresh, group_by_coord)
@@ -445,6 +451,7 @@ def calc_pss(
         gt: xr.DataArray,
         group_by_coord: str | None = None,
         thresh: float = 0.1,
+        client=None,
     ):
 
         a, b, c, d = self.get_2x2_event_counts(p, gt, thresh, group_by_coord)
@@ -461,6 +468,7 @@ def calc_l1(
         gt: xr.DataArray,
         group_by_coord: str | None = None,
         scale_dims: list | None = None,
+        client=None,
     ):
         """
         Calculate the L1 error norm of forecast data w.r.t. reference data.
@@ -493,6 +501,7 @@ def calc_l2(
         group_by_coord: str | None = None,
         scale_dims: list | None = None,
         squared_l2: bool = False,
+        client=None,
     ):
         """
         Calculate the L2 error norm of forecast data w.r.t. reference data.
@@ -537,7 +546,7 @@ def calc_l2(
         return l2
 
     def calc_mae(
-        self, p: xr.DataArray, gt: xr.DataArray, group_by_coord: str | None = None
+        self, p: xr.DataArray, gt: xr.DataArray, group_by_coord: str | None = None, client=None
     ):
         """
         Calculate mean absolute error (MAE) of forecast data w.r.t. reference data.
@@ -566,7 +575,7 @@
         return mae
 
     def calc_mse(
-        self, p: xr.DataArray, gt: xr.DataArray, group_by_coord: str | None = None
+        self, p: xr.DataArray, gt: xr.DataArray, group_by_coord: str | None = None, client=None
     ):
         """
         Calculate mean squared error (MSE) of forecast data w.r.t. reference data.
@@ -595,7 +604,7 @@
         return mse
 
     def calc_rmse(
-        self, p: xr.DataArray, gt: xr.DataArray, group_by_coord: str | None = None
+        self, p: xr.DataArray, gt: xr.DataArray, group_by_coord: str | None = None, client=None
     ):
         """
         Calculate root mean squared error (RMSE) of forecast data w.r.t. reference data
@@ -622,6 +631,7 @@ def calc_change_rate(
         self,
         s0: xr.DataArray,
         s1: xr.DataArray,
+        client=None,
     ):
         """
         Calculate the "change rate" of a data array as the mean absolute difference between two consecutive time steps.
@@ -652,6 +662,7 @@ def calc_froct(
         p_next: xr.DataArray,
         gt_next: xr.DataArray,
         group_by_coord: str | None = None,
+        client=None,
     ):
         """
         Calculate forecast rate of change over time
@@ -691,6 +702,7 @@ def calc_troct(
         gt_next: xr.DataArray,
         p_next: xr.DataArray,
         group_by_coord: str | None = None,
+        client=None,
     ):
         """
         Calculate target rate of change over time
@@ -730,6 +742,7 @@ def calc_acc(
         clim_mean: xr.DataArray,
         group_by_coord: str | None = None,
         spatial_dims: list = None,
+        client=None,
     ):
         """
         Calculate anomaly correlation coefficient (ACC).
@@ -781,7 +794,8 @@ def calc_acc(
         return acc
 
     def calc_bias(
-        self, p: xr.DataArray, gt: xr.DataArray, group_by_coord: str | None = None
+        self, p: xr.DataArray, gt: xr.DataArray, group_by_coord: str | None = None,
+        client=None,
     ):
         """
         Calculate mean bias of forecast data w.r.t. reference data
@@ -811,6 +825,7 @@ def calc_psnr(
         gt: xr.DataArray,
         group_by_coord: str | None = None,
         pixel_max: float = 1.0,
+        client=None,
     ):
         """
         Calculate PSNR of forecast data w.r.t. reference data
@@ -844,6 +859,7 @@ def calc_spatial_variability(
         group_by_coord: str | None = None,
         order: int = 1,
         non_spatial_avg_dims: list[str] = None,
+        client=None,
     ):
         """
         Calculates the ratio between the spatial variability of the differential operator with order 1 (higher values unsupported yet)
@@ -892,6 +908,7 @@ def calc_seeps(
         t3: xr.DataArray,
         spatial_dims: list,
         group_by_coord: str | None = None,
+        client=None,
     ):
         """
         Calculates stable equitable error in probability space (SEEPS), see Rodwell et al., 2011
@@ -1019,7 +1036,7 @@ def calc_spread(self, p: xr.DataArray, group_by_coord: str | None = None):
         return self._mean(np.sqrt(ens_std**2))
 
     def calc_ssr(
-        self, p: xr.DataArray, gt: xr.DataArray, group_by_coord: str | None = None
+        self, p: xr.DataArray, gt: xr.DataArray, group_by_coord: str | None = None, client=None
     ):
         """
         Calculate the Spread-Skill Ratio (SSR) of the forecast ensemble data w.r.t. reference data
@@ -1046,6 +1063,7 @@ def calc_crps(
         gt: xr.DataArray,
         group_by_coord: str | None = None,
         method: str = "ensemble",
+        client=None,
         **kwargs,
     ):
         """
@@ -1110,6 +1128,7 @@ def calc_rank_histogram(
         norm: bool = True,
         add_noise: bool = True,
         noise_fac=1.0e-03,
+        client=None,
     ):
         """
         Calculate the rank histogram of the forecast data w.r.t. reference data.
@@ -1213,6 +1232,7 @@ def calc_geo_spatial_diff(
         order: int = 1,
         r_e: float = 6371.0e3,
         dom_avg: bool = True,
+        client=None,
     ):
         """
         Calculates the amplitude of the gradient (order=1) or the Laplacian (order=2)
diff --git a/packages/evaluate/src/weathergen/evaluate/utils.py b/packages/evaluate/src/weathergen/evaluate/utils.py
index 7b482ad99..ccc4c87e0 100644
--- a/packages/evaluate/src/weathergen/evaluate/utils.py
+++ b/packages/evaluate/src/weathergen/evaluate/utils.py
@@ -10,7 +10,8 @@
 import json
 import logging
+import time
 from pathlib import Path
-
+import dask
 import numpy as np
 import omegaconf as oc
 import xarray as xr
@@ -25,7 +26,7 @@
 _logger.setLevel(logging.INFO)
 
 
-def get_next_data(fstep, da_preds, da_tars, fsteps):
+def get_next_data(fstep, da_preds, da_tars, fsteps, client=None):
     """
     Get the next forecast step data for the given forecast step.
     """
@@ -44,7 +45,7 @@
 def calc_scores_per_stream(
-    reader: Reader, stream: str, region: str, metrics: list[str]
+    reader: Reader, stream: str, region: str, metrics: list[str], client=None
 ) -> tuple[xr.DataArray, xr.DataArray]:
     """
     Calculate scores for a given run and stream using the specified metrics.
@@ -70,7 +71,7 @@
     )
     available_data = reader.check_availability(stream, mode="evaluation")
-
+    start = time.time()
     output_data = reader.get_data(
         stream,
         region=region,
@@ -107,7 +108,11 @@
             "metric": metrics,
         },
     )
-
+
+    end = time.time()
+    _logger.info(f"Data loading for stream {stream} took {end - start:.2f} s")
+    start = time.time()
+
     for (fstep, tars), (_, preds) in zip(
         da_tars.items(), da_preds.items(), strict=False
     ):
@@ -121,6 +126,7 @@
         _logger.debug(
             f"Build computation graphs for metrics for stream {stream}..."
         )
+
         combined_metrics = [
             get_score(
@@ -128,10 +134,12 @@
                 verified_data,
                 metric,
                 agg_dims="ipoint",
                 group_by_coord="sample",
+                compute=True,
+                client=client,
             )
             for metric in metrics
         ]
-
+
         combined_metrics = xr.concat(combined_metrics, dim="metric")
         combined_metrics["metric"] = metrics
@@ -154,11 +162,13 @@
     metric_stream = xr.concat(metric_list, dim="forecast_step")
     metric_stream = metric_stream.assign_coords({"forecast_step": fsteps})
 
+    end = time.time()
+    _logger.info(f"Computed metrics for stream {stream} in {end - start:.2f} s")
     return metric_stream, points_per_sample
 
 
-def plot_data(reader: Reader, stream: str, global_plotting_opts: dict) -> list[str]:
+def plot_data(reader: Reader, stream: str, global_plotting_opts: dict, client=None) -> list[str]:
     """
     Plot the data for a given run and stream.
 
@@ -299,6 +309,7 @@ def metric_list_to_json(
     npoints_sample_list: list[xr.DataArray],
     streams: list[str],
     region: str,
+    client=None,
 ):
     """
     Write the evaluation results collected in a list of xarray DataArrays for the metrics
@@ -359,7 +370,7 @@
     )
 
 
-def retrieve_metric_from_json(reader: Reader, stream: str, region: str, metric: str):
+def retrieve_metric_from_json(reader: Reader, stream: str, region: str, metric: str, client=None):
     """
     Retrieve the score for a given run, stream, metric, epoch, and rank from a JSON file.
 
@@ -393,7 +404,7 @@
         raise FileNotFoundError(f"File {score_path} not found in the archive.")
 
 
-def plot_summary(cfg: dict, scores_dict: dict, summary_dir: Path):
+def plot_summary(cfg: dict, scores_dict: dict, summary_dir: Path, client=None):
     """
     Plot summary of the evaluation results.
     This function is a placeholder for future implementation.
@@ -439,6 +450,7 @@ def common_ranges(
     data_preds: list[dict],
     plot_chs: list[str],
     maps_config: oc.dictconfig.DictConfig,
+    client=None,
 ) -> oc.dictconfig.DictConfig:
     """
     Calculate common ranges per stream and variable.
@@ -483,7 +495,7 @@ def common_ranges(
     return maps_config
 

-def calc_val(x: xr.DataArray, bound: str) -> list[float]:
+def calc_val(x: xr.DataArray, bound: str, client=None) -> list[float]:
     """
     Calculate the maximum or minimum value per variable for all forecast steps.
     Parameters
@@ -509,6 +521,7 @@ def calc_bounds(
     data_preds,
     var,
     bound,
+    client=None,
 ):
     """
     Calculate the minimum and maximum values per variable for all forecast steps for both targets and predictions
@@ -536,7 +549,7 @@ def calc_bounds(
     return list_bound
 

-def scalar_coord_to_dim(da: xr.DataArray, name: str, axis: int = -1) -> xr.DataArray:
+def scalar_coord_to_dim(da: xr.DataArray, name: str, axis: int = -1, client=None) -> xr.DataArray:
     """
     Convert a scalar coordinate to a dimension in an xarray DataArray.
     If the coordinate is already a dimension, it is returned unchanged.
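
Usage note: a minimal local sketch of the client plumbing introduced above, for running the evaluation without a SLURM allocation. It assumes dask.distributed's LocalCluster and reuses the config file added in this patch; worker counts and the config path are illustrative, not part of the change.

# Illustrative only: drive the new client= parameter with a local Dask cluster.
from dask.distributed import Client, LocalCluster
from omegaconf import OmegaConf

from weathergen.evaluate.run_evaluation import evaluate_from_config

if __name__ == "__main__":
    # Small local cluster instead of SLURMCluster; size it to the workstation.
    cluster = LocalCluster(n_workers=4, threads_per_worker=2)
    client = Client(cluster)
    try:
        cfg = OmegaConf.load("config/eval_config.yml")
        evaluate_from_config(cfg, client=client)
    finally:
        client.close()
        cluster.close()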