Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion config/default_config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ val_initial: False

loader_num_workers: 8
log_validation: 0
analysis_streams_output: ["ERA5"]
analysis_streams_output: ["ERA5", "tokens"]

istep: 0
run_history: []
Expand Down
43 changes: 2 additions & 41 deletions packages/common/src/weathergen/common/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,45 +67,6 @@ def create(cls, other: typing.Any) -> "IOReaderData":

return cls(**dataclasses.asdict(other))

@classmethod
def spoof(
    cls, other: typing.Any, nchannels, datetime, geoinfo_size, mean_of_data
) -> typing.Self:
    """
    Create a placeholder instance modelled on a data_reader_base.ReaderData.

    `other` should be such an instance. It is deep-copied and the copy's
    fields are overwritten with synthetic values on a fixed HEALPix level-5
    grid (nested ordering):

    - coords: one (lat, lon) pair in degrees per HEALPix cell, float32
    - geoinfos: zeros with `geoinfo_size` columns, float32
    - data: `mean_of_data` (cast to float32) repeated once per cell
    - datetimes: `datetime` repeated once per cell

    `nchannels` is accepted but not used by this method.
    """

    level = 5
    nside = 2**level
    cell_ids = np.arange(0, 12 * 4**level)

    # work on a copy so that the caller's object is left untouched
    spoofed = deepcopy(other)

    # dx = dy = 0.5 are the in-cell offsets passed to astropy_healpix
    lons, lats = hp.healpix_to_lonlat(cell_ids, nside, dx=0.5, dy=0.5, order="nested")
    spoofed.coords = np.stack([lats.deg, lons.deg], axis=-1, dtype=np.float32)
    n_cells = spoofed.coords.shape[0]

    spoofed.geoinfos = np.zeros((n_cells, geoinfo_size), dtype=np.float32)

    mean_row = np.expand_dims(mean_of_data.astype(np.float32), axis=0)
    spoofed.data = mean_row.repeat(n_cells, axis=0)
    spoofed.datetimes = np.array(datetime).repeat(n_cells)

    # every per-datapoint field has to agree on the number of datapoints
    n_datapoints = len(spoofed.data)
    mismatch_msg = "number of datapoints do not match data"
    assert spoofed.coords.shape == (n_datapoints, 2), mismatch_msg
    assert spoofed.geoinfos.shape[0] == n_datapoints, mismatch_msg
    assert spoofed.datetimes.shape[0] == n_datapoints, mismatch_msg

    return cls(**dataclasses.asdict(spoofed))


@dataclasses.dataclass
class ItemKey:
Expand Down Expand Up @@ -418,10 +379,10 @@ def extract(self, key: ItemKey) -> OutputItem:
data_coords = self._extract_coordinates(stream_idx, offset_key, datapoints)

assert len(data_coords.channels) == target_data.shape[1], (
"Number of channel names does not align with target data."
"Number of channel names does not align with target data.", len(data_coords.channels), target_data.shape
)
assert len(data_coords.channels) == preds_data.shape[1], (
"Number of channel names does not align with prediction data."
"Number of channel names does not align with prediction data.", len(data_coords.channels), preds_data.shape
)

if key.with_source:
Expand Down
3 changes: 2 additions & 1 deletion src/weathergen/datasets/multi_stream_data_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import numpy as np
import torch
from torch import Tensor

from weathergen.common.io import IOReaderData
from weathergen.datasets.data_reader_anemoi import DataReaderAnemoi
Expand Down Expand Up @@ -284,7 +285,7 @@ def denormalize_source_channels(self, obs_id, data):
return self.streams_datasets[obs_id][0].denormalize_source_channels(data)

###################################################
def denormalize_target_channels(self, obs_id, data):
def denormalize_target_channels(self, obs_id, data) -> Tensor:
# TODO: with multiple ds per stream we need to distinguish these here
return self.streams_datasets[obs_id][0].denormalize_target_channels(data)

Expand Down
19 changes: 12 additions & 7 deletions src/weathergen/datasets/stream_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import astropy_healpix as hp
import numpy as np
import torch
from numpy.typing import NDArray

from weathergen.common.io import IOReaderData

Expand Down Expand Up @@ -373,13 +374,7 @@ def spoof(healpix_level: int, datetime, geoinfo_size, mean_of_data) -> IOReaderD
other should be such an instance.
"""

dx = 0.5
dy = 0.5
num_healpix_cells = 12 * 4**healpix_level
lons, lats = hp.healpix_to_lonlat(
np.arange(0, num_healpix_cells), 2**healpix_level, dx=dx, dy=dy, order="nested"
)
coords = np.stack([lats.deg, lons.deg], axis=-1, dtype=np.float32)
coords = healpix_coords(healpix_level)
geoinfos = np.zeros((coords.shape[0], geoinfo_size), dtype=np.float32)

data = np.expand_dims(mean_of_data.astype(np.float32), axis=0).repeat(coords.shape[0], axis=0)
Expand All @@ -404,3 +399,13 @@ def spoof(healpix_level: int, datetime, geoinfo_size, mean_of_data) -> IOReaderD
)

return IOReaderData(coords, geoinfos, data, datetimes)

def healpix_coords(healpix_level: int) -> NDArray[np.float32]:
    """
    Return the coordinates of all HEALPix cells at the given level.

    The grid has 12 * 4**healpix_level cells (nside = 2**healpix_level) and is
    queried in nested ordering with in-cell offsets dx = dy = 0.5.

    Returns a float32 array with one row per cell and two columns:
    latitude and longitude, both in degrees.
    """
    nside = 2**healpix_level
    cell_ids = np.arange(0, 12 * 4**healpix_level)
    lons, lats = hp.healpix_to_lonlat(cell_ids, nside, dx=0.5, dy=0.5, order="nested")
    # latitude first, longitude second
    return np.stack([lats.deg, lons.deg], axis=-1, dtype=np.float32)
38 changes: 26 additions & 12 deletions src/weathergen/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@
import math
import warnings
from pathlib import Path
from dataclasses import dataclass

import astropy_healpix as hp
import astropy_healpix.healpy
import numpy as np
import torch
import torch.nn as nn
from astropy_healpix import healpy
from torch.utils.checkpoint import checkpoint
from torch import Tensor

from weathergen.common.config import Config
from weathergen.model.engines import (
Expand Down Expand Up @@ -196,6 +197,15 @@ def reset_parameters(self, cf: Config) -> "ModelParams":
return


@dataclass
class ForwardOutput:
    """
    Container for the values returned by the model forward pass.
    """

    # Latent token tensors collected during the forward pass
    # (presumably the initial state plus one entry per forecast step; confirm against forward())
    tokens: list[Tensor]
    # Posteriors returned by the local assimilation step
    posteriors: Tensor
    # Predictions gathered over the roll-out, as nested lists of tensors
    # (NOTE(review): inner list is presumably per stream; confirm against predict())
    predictions: list[list[Tensor]]

####################################################################################################
class Model(torch.nn.Module):
"""WeatherGenerator model architecture
Expand Down Expand Up @@ -545,10 +555,10 @@ def forward_jac(self, *args):

preds_all = self.forward(sources, sources_lens)

return tuple(preds_all[0])
return tuple(preds_all.targets)

#########################################
def forward(self, model_params: ModelParams, batch, forecast_offset: int, forecast_steps: int):
def forward(self, model_params: ModelParams, batch, forecast_offset: int, forecast_steps: int) -> ForwardOutput:
"""Performs the forward pass of the model to generate forecasts

Tokens are processed through the model components, which were defined in the create method.
Expand All @@ -572,12 +582,13 @@ def forward(self, model_params: ModelParams, batch, forecast_offset: int, foreca
tokens = self.embed_cells(model_params, streams_data)

# local assimilation engine and adapter
tokens, posteriors = self.assimilate_local(model_params, tokens, source_cell_lens)
local_tokens, posteriors = self.assimilate_local(model_params, tokens, source_cell_lens)

tokens = self.assimilate_global(model_params, tokens)
tokens = self.assimilate_global(model_params, local_tokens)

# roll-out in latent space
preds_all = []
tokens_all = [tokens]
for fstep in range(forecast_offset, forecast_offset + forecast_steps):
# prediction
preds_all += [
Expand All @@ -590,13 +601,15 @@ def forward(self, model_params: ModelParams, batch, forecast_offset: int, foreca
)
]

if self.training:
# Impute noise to the latent state
noise_std = self.cf.get("impute_latent_noise_std", 0.0)
if noise_std > 0.0:
tokens = tokens + torch.randn_like(tokens) * torch.norm(tokens) * noise_std
noise_std = self.cf.get("impute_latent_noise_std", 0.0)
if self.training and noise_std > 0.0:
noisy_tokens = tokens + torch.randn_like(tokens) * torch.norm(tokens) * noise_std
else:
noisy_tokens = tokens


tokens = self.forecast(model_params, tokens)
tokens = self.forecast(model_params, noisy_tokens)
tokens_all.append(tokens)

# prediction for final step
preds_all += [
Expand All @@ -609,7 +622,8 @@ def forward(self, model_params: ModelParams, batch, forecast_offset: int, foreca
)
]

return preds_all, posteriors

return ForwardOutput(tokens=tokens_all, posteriors=posteriors, predictions=preds_all)

#########################################
def embed_cells(self, model_params: ModelParams, streams_data) -> torch.Tensor:
Expand Down
13 changes: 7 additions & 6 deletions src/weathergen/run_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,10 @@ def train_with_args(argl: list[str], stream_dir: str | None):


if __name__ == "__main__":
# Entry point for slurm script.
# Check whether --from_run_id passed as argument.
if next((True for arg in sys.argv if "--from_run_id" in arg), False):
train_continue()
else:
train()
inference()
# # Entry point for slurm script.
# # Check whether --from_run_id passed as argument.
# if next((True for arg in sys.argv if "--from_run_id" in arg), False):
# train_continue()
# else:
# train()
20 changes: 20 additions & 0 deletions src/weathergen/train/structures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from dataclasses import dataclass

from numpy.typing import NDArray
from torch import Tensor

@dataclass
class TrainerPredictions:
    """
    Bundle of predictions, targets and token data handled by the trainer.
    """

    # Denormalized predictions
    preds_all: list[list[list[Tensor]]]
    # Denormalized targets
    targets_all: list[list[list[Tensor]]]
    # Raw target coordinates
    targets_coords_raw: list[list[Tensor]]
    # Raw target timestamps
    targets_times_raw: list[list[NDArray]]
    # Target lengths
    targets_lens: list[list[list[int]]]
    # Token tensors
    # (NOTE(review): presumably the latent tokens from the model forward pass; confirm)
    tokens_all: list[Tensor]
    # Raw coordinates associated with the tokens
    tokens_coords_raw: list[NDArray]
    # Raw timestamps associated with the tokens
    tokens_times_raw: list[NDArray]
Loading
Loading