diff --git a/config/default_config.yml b/config/default_config.yml
index 679f58dd3..2197e75f4 100644
--- a/config/default_config.yml
+++ b/config/default_config.yml
@@ -146,7 +146,7 @@
 val_initial: False
 loader_num_workers: 8
 log_validation: 0
-analysis_streams_output: ["ERA5"]
+analysis_streams_output: ["ERA5", "tokens"]
 istep: 0
 run_history: []
diff --git a/packages/common/src/weathergen/common/io.py b/packages/common/src/weathergen/common/io.py
index c24441e4c..4d43fa2cb 100644
--- a/packages/common/src/weathergen/common/io.py
+++ b/packages/common/src/weathergen/common/io.py
@@ -67,45 +67,6 @@ def create(cls, other: typing.Any) -> "IOReaderData":
         return cls(**dataclasses.asdict(other))
 
-    @classmethod
-    def spoof(
-        cls, other: typing.Any, nchannels, datetime, geoinfo_size, mean_of_data
-    ) -> typing.Self:
-        """
-        Spoof an instance from data_reader_base.ReaderData instance.
-        other should be such an instance.
-        """
-
-        hl = 5
-        dx = 0.5
-        dy = 0.5
-        other_copy = deepcopy(other)
-        num_healpix_cells = 12 * 4**hl
-        lons, lats = hp.healpix_to_lonlat(
-            np.arange(0, num_healpix_cells), 2**hl, dx=dx, dy=dy, order="nested"
-        )
-        other_copy.coords = np.stack([lats.deg, lons.deg], axis=-1, dtype=np.float32)
-        other_copy.geoinfos = np.zeros((other_copy.coords.shape[0], geoinfo_size), dtype=np.float32)
-
-        other_copy.data = np.expand_dims(mean_of_data.astype(np.float32), axis=0).repeat(
-            other_copy.coords.shape[0], axis=0
-        )
-        other_copy.datetimes = np.array(datetime).repeat(other_copy.coords.shape[0])
-
-        n_datapoints = len(other_copy.data)
-
-        assert other_copy.coords.shape == (n_datapoints, 2), (
-            "number of datapoints do not match data"
-        )
-        assert other_copy.geoinfos.shape[0] == n_datapoints, (
-            "number of datapoints do not match data"
-        )
-        assert other_copy.datetimes.shape[0] == n_datapoints, (
-            "number of datapoints do not match data"
-        )
-
-        return cls(**dataclasses.asdict(other_copy))
-
 
 @dataclasses.dataclass
 class ItemKey:
@@ -418,10 +379,12 @@ def extract(self, key: ItemKey) -> OutputItem:
         data_coords = self._extract_coordinates(stream_idx, offset_key, datapoints)
 
         assert len(data_coords.channels) == target_data.shape[1], (
-            "Number of channel names does not align with target data."
+            f"Number of channel names {len(data_coords.channels)} does not align "
+            f"with target data shape {target_data.shape}."
         )
         assert len(data_coords.channels) == preds_data.shape[1], (
-            "Number of channel names does not align with prediction data."
+            f"Number of channel names {len(data_coords.channels)} does not align "
+            f"with prediction data shape {preds_data.shape}."
         )
 
         if key.with_source:
+ "Number of channel names does not align with prediction data.", len(data_coords.channels), preds_data.shape ) if key.with_source: diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 83adb8128..a6fe21b60 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -12,6 +12,7 @@ import numpy as np import torch +from torch import Tensor from weathergen.common.io import IOReaderData from weathergen.datasets.data_reader_anemoi import DataReaderAnemoi @@ -284,7 +285,7 @@ def denormalize_source_channels(self, obs_id, data): return self.streams_datasets[obs_id][0].denormalize_source_channels(data) ################################################### - def denormalize_target_channels(self, obs_id, data): + def denormalize_target_channels(self, obs_id, data) -> Tensor: # TODO: with multiple ds per stream we need to distinguish these here return self.streams_datasets[obs_id][0].denormalize_target_channels(data) diff --git a/src/weathergen/datasets/stream_data.py b/src/weathergen/datasets/stream_data.py index 7be504aa5..8a5d4c7ac 100644 --- a/src/weathergen/datasets/stream_data.py +++ b/src/weathergen/datasets/stream_data.py @@ -10,6 +10,7 @@ import astropy_healpix as hp import numpy as np import torch +from numpy.typing import NDArray from weathergen.common.io import IOReaderData @@ -373,13 +374,7 @@ def spoof(healpix_level: int, datetime, geoinfo_size, mean_of_data) -> IOReaderD other should be such an instance. """ - dx = 0.5 - dy = 0.5 - num_healpix_cells = 12 * 4**healpix_level - lons, lats = hp.healpix_to_lonlat( - np.arange(0, num_healpix_cells), 2**healpix_level, dx=dx, dy=dy, order="nested" - ) - coords = np.stack([lats.deg, lons.deg], axis=-1, dtype=np.float32) + coords = healpix_coords(healpix_level) geoinfos = np.zeros((coords.shape[0], geoinfo_size), dtype=np.float32) data = np.expand_dims(mean_of_data.astype(np.float32), axis=0).repeat(coords.shape[0], axis=0) @@ -404,3 +399,13 @@ def spoof(healpix_level: int, datetime, geoinfo_size, mean_of_data) -> IOReaderD ) return IOReaderData(coords, geoinfos, data, datetimes) + +def healpix_coords(healpix_level: int) -> NDArray[np.float32]: + dx = 0.5 + dy = 0.5 + num_healpix_cells = 12 * 4**healpix_level + lons, lats = hp.healpix_to_lonlat( + np.arange(0, num_healpix_cells), 2**healpix_level, dx=dx, dy=dy, order="nested" + ) + coords = np.stack([lats.deg, lons.deg], axis=-1, dtype=np.float32) + return coords diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 803c0312b..832a70e74 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -13,14 +13,15 @@ import math import warnings from pathlib import Path +from dataclasses import dataclass import astropy_healpix as hp -import astropy_healpix.healpy import numpy as np import torch import torch.nn as nn from astropy_healpix import healpy from torch.utils.checkpoint import checkpoint +from torch import Tensor from weathergen.common.config import Config from weathergen.model.engines import ( @@ -196,6 +197,15 @@ def reset_parameters(self, cf: Config) -> "ModelParams": return +@dataclass +class ForwardOutput: + """ + Output of the model forward pass. 
+ """ + tokens: list[Tensor] + posteriors: Tensor + predictions: list[list[Tensor]] + #################################################################################################### class Model(torch.nn.Module): """WeatherGenerator model architecture @@ -545,10 +555,10 @@ def forward_jac(self, *args): preds_all = self.forward(sources, sources_lens) - return tuple(preds_all[0]) + return tuple(preds_all.targets) ######################################### - def forward(self, model_params: ModelParams, batch, forecast_offset: int, forecast_steps: int): + def forward(self, model_params: ModelParams, batch, forecast_offset: int, forecast_steps: int) -> ForwardOutput: """Performs the forward pass of the model to generate forecasts Tokens are processed through the model components, which were defined in the create method. @@ -572,12 +582,13 @@ def forward(self, model_params: ModelParams, batch, forecast_offset: int, foreca tokens = self.embed_cells(model_params, streams_data) # local assimilation engine and adapter - tokens, posteriors = self.assimilate_local(model_params, tokens, source_cell_lens) + local_tokens, posteriors = self.assimilate_local(model_params, tokens, source_cell_lens) - tokens = self.assimilate_global(model_params, tokens) + tokens = self.assimilate_global(model_params, local_tokens) # roll-out in latent space preds_all = [] + tokens_all = [tokens] for fstep in range(forecast_offset, forecast_offset + forecast_steps): # prediction preds_all += [ @@ -590,13 +601,15 @@ def forward(self, model_params: ModelParams, batch, forecast_offset: int, foreca ) ] - if self.training: - # Impute noise to the latent state - noise_std = self.cf.get("impute_latent_noise_std", 0.0) - if noise_std > 0.0: - tokens = tokens + torch.randn_like(tokens) * torch.norm(tokens) * noise_std + noise_std = self.cf.get("impute_latent_noise_std", 0.0) + if self.training and noise_std > 0.0: + noisy_tokens = tokens + torch.randn_like(tokens) * torch.norm(tokens) * noise_std + else: + noisy_tokens = tokens + - tokens = self.forecast(model_params, tokens) + tokens = self.forecast(model_params, noisy_tokens) + tokens_all.append(tokens) # prediction for final step preds_all += [ @@ -609,7 +622,8 @@ def forward(self, model_params: ModelParams, batch, forecast_offset: int, foreca ) ] - return preds_all, posteriors + + return ForwardOutput(tokens=tokens_all, posteriors=posteriors, predictions=preds_all) ######################################### def embed_cells(self, model_params: ModelParams, streams_data) -> torch.Tensor: diff --git a/src/weathergen/run_train.py b/src/weathergen/run_train.py index eb2cab895..a21b75298 100644 --- a/src/weathergen/run_train.py +++ b/src/weathergen/run_train.py @@ -195,9 +195,10 @@ def train_with_args(argl: list[str], stream_dir: str | None): if __name__ == "__main__": - # Entry point for slurm script. - # Check whether --from_run_id passed as argument. - if next((True for arg in sys.argv if "--from_run_id" in arg), False): - train_continue() - else: - train() + inference() + # # Entry point for slurm script. + # # Check whether --from_run_id passed as argument. 
diff --git a/src/weathergen/train/structures.py b/src/weathergen/train/structures.py
new file mode 100644
index 000000000..3582e0279
--- /dev/null
+++ b/src/weathergen/train/structures.py
@@ -0,0 +1,24 @@
+from dataclasses import dataclass
+
+from numpy.typing import NDArray
+from torch import Tensor
+
+
+@dataclass
+class TrainerPredictions:
+    # Denormalized predictions
+    preds_all: list[list[list[Tensor]]]
+    # Denormalized targets
+    targets_all: list[list[list[Tensor]]]
+    # Raw target coordinates
+    targets_coords_raw: list[list[Tensor]]
+    # Raw target timestamps
+    targets_times_raw: list[list[NDArray]]
+    # Target lengths
+    targets_lens: list[list[list[int]]]
+    # Latent tokens: initial state plus one entry per forecast step
+    tokens_all: list[Tensor]
+    # HEALPix (lat, lon) coordinates of the latent tokens
+    tokens_coords_raw: list[NDArray]
+    # Timestamps of the latent tokens (TODO: currently unpopulated)
+    tokens_times_raw: list[NDArray]
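The nesting convention of the *_all fields follows _prepare_logging below: forecast step, then stream, then sample. An illustrative accessor (not part of the diff):

    from torch import Tensor

    from weathergen.train.structures import TrainerPredictions

    def pred_tensor(tp: TrainerPredictions, fstep: int, stream: int, sample: int) -> Tensor:
        # Denormalized CPU tensor for one forecast step, stream, and batch sample.
        return tp.preds_all[fstep][stream][sample]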
""" + assert self.cf is not None + sd = streams_data[0][0] + sd.source_centroids #''' # TODO: Remove this function and port functionality to write_validation(), which then @@ -535,15 +545,25 @@ def _prepare_logging( preds_all[fstep][i_strm] += [dn_data(i_strm, pred.to(f32)).detach().cpu()] targets_all[fstep][i_strm] += [dn_data(i_strm, target.to(f32)).detach().cpu()] - return ( - preds_all, - targets_all, - targets_coords_raw, - targets_times_raw, - targets_lens, + x = targets_all[0][0][0] + x.shape + + tokens_cpu = [t.detach().cpu() for t in tokens] + tokens_coords_raw = healpix_coords(self.cf.healpix_level) + + return TrainerPredictions( + preds_all=preds_all, + targets_all=targets_all, + targets_coords_raw=targets_coords_raw, + targets_times_raw=targets_times_raw, + targets_lens=targets_lens, + tokens_all=tokens_cpu, + tokens_coords_raw=[tokens_coords_raw], + tokens_times_raw=[], # TODO ) def train(self, epoch): + assert self.cf is not None cf = self.cf self.model.train() # torch.autograd.set_detect_anomaly(True) @@ -623,6 +643,7 @@ def train(self, epoch): self.dataset.advance() def validate(self, epoch): + assert self.cf is not None cf = self.cf self.model.eval() @@ -649,9 +670,10 @@ def validate(self, epoch): if self.ema_model is None else self.ema_model.forward_eval ) - preds, _ = model_forward( + forward_out: ForwardOutput = model_forward( self.model_params, batch, cf.forecast_offset, forecast_steps ) + preds = forward_out.predictions # compute loss and log output if bidx < cf.log_validation: @@ -661,14 +683,9 @@ def validate(self, epoch): ) # TODO: Move _prepare_logging into write_validation by passing streams_data - ( - preds_all, - targets_all, - targets_coords_all, - targets_times_all, - targets_lens, - ) = self._prepare_logging( + denorm_predictions = self._prepare_logging( preds=preds, + tokens=forward_out.tokens, forecast_offset=cf.forecast_offset, forecast_steps=cf.forecast_steps, streams_data=batch[0], @@ -679,11 +696,12 @@ def validate(self, epoch): epoch, bidx, sources, - preds_all, - targets_all, - targets_coords_all, - targets_times_all, - targets_lens, + denorm_predictions.preds_all, + denorm_predictions.targets_all, + denorm_predictions.targets_coords_raw, + denorm_predictions.targets_times_raw, + denorm_predictions.targets_lens, + denorm_predictions, ) else: @@ -721,6 +739,7 @@ def load_model(self, run_id: str, epoch=-1): run_id : model_id of the trained model epoch : The epoch to load. Default (-1) is the latest epoch """ + assert self.cf is not None path_run = Path(self.cf.model_path) / run_id epoch_id = f"epoch{epoch:05d}" if epoch != -1 and epoch is not None else "latest" diff --git a/src/weathergen/utils/validation_io.py b/src/weathergen/utils/validation_io.py index e28563132..16e74564d 100644 --- a/src/weathergen/utils/validation_io.py +++ b/src/weathergen/utils/validation_io.py @@ -8,45 +8,86 @@ # nor does it submit to any jurisdiction. 
 
 import logging
+
+import numpy as np
+import torch
 
 import weathergen.common.config as config
 import weathergen.common.io as io
+from weathergen.train.structures import TrainerPredictions
 
 _logger = logging.getLogger(__name__)
 
 
 def write_output(
     cf,
-    epoch,
+    epoch: int,
     batch_idx,
-    sources,
+    sources: list[list[io.IOReaderData]],
     preds_all,
     targets_all,
     targets_coords_all,
     targets_times_all,
     targets_lens,
+    preds: TrainerPredictions,
 ):
-    stream_names = [stream.name for stream in cf.streams]
     output_stream_names = cf.analysis_streams_output
+    export_tokens = output_stream_names is not None and "tokens" in output_stream_names
+    stream_names = [stream.name for stream in cf.streams] + (["tokens"] if export_tokens else [])
     if output_stream_names is None:
         output_stream_names = stream_names
     output_streams = {name: stream_names.index(name) for name in output_stream_names}
     _logger.debug(f"Using output streams: {output_streams} from streams: {stream_names}")
 
     target_channels: list[list[str]] = [list(stream.val_target_channels) for stream in cf.streams]
+    if export_tokens:
+        # One pseudo-channel per latent dimension
+        token_channels = [f"tokens_{i}" for i in range(preds.tokens_all[0].shape[-1])]
+        target_channels.append(token_channels)
     source_channels: list[list[str]] = [list(stream.val_source_channels) for stream in cf.streams]
+    # Note: no source channels are added for the token pseudo-stream.
 
     geoinfo_channels = [[] for _ in cf.streams]  # TODO obtain channels
+    if export_tokens:
+        geoinfo_channels.append([])
 
     # assume: is batch size guarnteed and constant:
     # => calculate global sample indices for this batch by offsetting by sample_start
     sample_start = batch_idx * cf.batch_size_validation_per_gpu
 
+    if export_tokens:
+        num_tokens = preds.tokens_all[0].shape[1]
+        # Append the tokens as an additional pseudo-stream to the output.
+        # Add a source entry for the tokens.
+        # TODO: is this correct? It should probably not be needed.
+        iord = io.IOReaderData(
+            coords=preds.tokens_coords_raw[0],
+            geoinfos=np.array([]),  # no geoinfo for tokens
+            data=preds.tokens_all[0].numpy(),
+            datetimes=np.array([]),  # no times for tokens
+        )
+        sources = [fc_sources + [iord] for fc_sources in sources]
+        # TODO: not sure if this is needed; the tokens are not a real target.
+        for fc_target, fc_token in zip(targets_all, preds.tokens_all):
+            fc_target.append([fc_token.squeeze().unsqueeze(-1)])
+        for fc_preds, fc_token in zip(preds_all, preds.tokens_all):
+            fc_preds.append([fc_token.squeeze().unsqueeze(0)])
+        for fc_coords_target, fc_coords_token in zip(targets_coords_all, preds.tokens_coords_raw):
+            fc_coords_target.append(torch.tensor(fc_coords_token))
+        for fc_lens in targets_lens:
+            fc_lens.append([num_tokens])
+        # TODO: use the correct target times; for now the first stream's times are repeated
+        for fc_times_target in targets_times_all:
+            fc_times_target.append(fc_times_target[0])
+
-    assert len(stream_names) == len(targets_all[0]), "data does not match number of streams"
-    assert len(stream_names) == len(preds_all[0]), "data does not match number of streams"
-    assert len(stream_names) == len(sources[0]), "data does not match number of streams"
+    assert len(stream_names) == len(targets_all[0]), (
+        f"data does not match number of streams: {len(stream_names)} vs {len(targets_all[0])}"
+    )
+    assert len(stream_names) == len(preds_all[0]), (
+        f"data does not match number of streams: {len(stream_names)} vs {len(preds_all[0])}"
+    )
+    assert len(stream_names) == len(sources[0]), (
+        f"data does not match number of streams: {len(stream_names)} vs {len(sources[0])}"
+    )
 
     data = io.OutputBatchData(
         sources,
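Taken together, write_output now treats the latent tokens as one extra pseudo-stream appended to every per-stream container. A minimal consistency check for that invariant (illustrative only; argument names mirror the function's locals):

    def check_stream_alignment(stream_names, targets_all, preds_all, sources):
        # After the token pseudo-stream is appended, every per-step list must
        # carry exactly one entry per named stream, including "tokens".
        n = len(stream_names)
        assert all(len(step_targets) == n for step_targets in targets_all)
        assert all(len(step_preds) == n for step_preds in preds_all)
        assert all(len(fc_sources) == n for fc_sources in sources)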