
Commit ecaf28e

tjhunter authored and clessig committed
[1144] Extra fixes (#1148)
* Fixed problem in inference
* more fixes
* fixes
* lint
* lint

---------

Co-authored-by: Christian Lessig <[email protected]>
1 parent 8862303 commit ecaf28e

File tree

5 files changed, +49 −39 lines changed

integration_tests/small1_test.py

Lines changed: 3 additions & 2 deletions
@@ -172,8 +172,9 @@ def assert_train_loss_below_threshold(run_id):
     )
     # Check that the loss does not explode in a single epoch
     # This is meant to be a quick test, not a convergence test
-    assert loss_metric < 1.25, (
-        f"'stream.ERA5.loss_mse.loss_avg' is {loss_metric}, expected to be below 0.25"
+    target = 1.5
+    assert loss_metric < target, (
+        f"'stream.ERA5.loss_mse.loss_avg' is {loss_metric}, expected to be below {target}"
     )
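The change also repairs a mismatch in the old assertion, whose message claimed 0.25 while the condition checked 1.25. Binding the threshold to a variable keeps the two in sync; a minimal standalone sketch of the pattern (loss value assumed for illustration):

```python
# Sketch: bind the threshold once so the condition and the failure
# message cannot drift apart.
loss_metric = 1.1  # hypothetical value read from the run's metrics
target = 1.5
assert loss_metric < target, (
    f"'stream.ERA5.loss_mse.loss_avg' is {loss_metric}, expected to be below {target}"
)
```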

packages/common/src/weathergen/common/config.py

Lines changed: 2 additions & 1 deletion
@@ -86,7 +86,8 @@ def load_model_config(run_id: str, epoch: int | None, model_path: str | None) ->
     path = Path(model_path)
     fname = path / run_id / _get_model_config_file_name(run_id, epoch)
     assert fname.exists(), (
-        "The fallback path to the model does not exist. Please provide a `model_path`."
+        "The fallback path to the model does not exist. Please provide a `model_path`.",
+        fname,
     )

     _logger.info(f"Loading config from specified run_id and epoch: {fname}")
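Passing a tuple as the assertion message makes Python print the tuple's repr on failure, which surfaces the offending path. An equivalent sketch (not the committed code) embeds the path in an f-string instead:

```python
# Hypothetical alternative with the same effect, assuming fname is a Path:
assert fname.exists(), (
    f"The fallback path to the model does not exist: {fname}. "
    "Please provide a `model_path`."
)
```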

packages/common/src/weathergen/common/io.py

Lines changed: 27 additions & 28 deletions
@@ -30,6 +30,11 @@
 _logger = logging.getLogger(__name__)


+def is_ndarray(obj: typing.Any) -> bool:
+    """Check if object is an ndarray (wraps the linter warning)."""
+    return isinstance(obj, (np.ndarray))  # noqa: TID251
+
+
 @dataclasses.dataclass
 class IOReaderData:
     """
@@ -58,10 +63,10 @@ def create(cls, other: typing.Any) -> "IOReaderData":

         other should be such an instance.
         """
-        coords = other.coords
-        geoinfos = other.geoinfos
-        data = other.data
-        datetimes = other.datetimes
+        coords = np.asarray(other.coords)
+        geoinfos = np.asarray(other.geoinfos)
+        data = np.asarray(other.data)
+        datetimes = np.asarray(other.datetimes)

         n_datapoints = len(data)
@@ -130,22 +135,22 @@ class OutputDataset:
     item_key: ItemKey

     # (datapoints, channels, ens)
-    data: zarr.Array  # wrong type => array like
+    data: zarr.Array | NDArray  # wrong type => array like

     # (datapoints,)
-    times: zarr.Array
+    times: zarr.Array | NDArray

     # (datapoints, 2)
-    coords: zarr.Array
+    coords: zarr.Array | NDArray

     # (datapoints, geoinfos) geoinfos are stream dependent => 0 for most gridded data
-    geoinfo: zarr.Array
+    geoinfo: zarr.Array | NDArray

     channels: list[str]
     geoinfo_channels: list[str]

     @functools.cached_property
-    def arrays(self) -> dict[str, zarr.Array]:
+    def arrays(self) -> dict[str, zarr.Array | NDArray]:
         """Iterate over the arrays and their names."""
         return {
             "data": self.data,
@@ -236,7 +241,8 @@ def write_zarr(self, item: OutputItem):
         """Write one output item to the zarr store."""
         group = self._get_group(item.key, create=True)
         for dataset in item.datasets:
-            self._write_dataset(group, dataset)
+            if dataset is not None:
+                self._write_dataset(group, dataset)

     def get_data(self, sample: int, stream: str, forecast_step: int) -> OutputItem:
         """Get datasets for the output item matching the arguments."""
@@ -285,6 +291,7 @@ def _write_arrays(self, dataset_group: zarr.Group, dataset: OutputDataset):
             self._create_dataset(dataset_group, array_name, array)

     def _create_dataset(self, group: zarr.Group, name: str, array: NDArray):
+        assert is_ndarray(array), f"Expected ndarray but got: {type(array)}"
         if array.size == 0:  # sometimes for geoinfo
             chunks = None
         else:
@@ -394,20 +401,10 @@ def extract(self, key: ItemKey) -> OutputItem:
             target_data = np.zeros((0, len(self.target_channels[stream_idx])), dtype=np.float32)
             preds_data = np.zeros((0, len(self.target_channels[stream_idx])), dtype=np.float32)
         else:
-            target_data = (
-                self.targets[offset_key.forecast_step][stream_idx][0][datapoints]
-                .cpu()
-                .detach()
-                .numpy()
-            )
-            preds_data = (
-                self.predictions[offset_key.forecast_step][stream_idx][0]
-                .transpose(1, 0)
-                .transpose(1, 2)[datapoints]
-                .cpu()
-                .detach()
-                .numpy()
-            )
+            target_data = self.targets[offset_key.forecast_step][stream_idx][0][datapoints]
+            preds_data = self.predictions[offset_key.forecast_step][stream_idx][0].transpose(
+                1, 2, 0
+            )[datapoints]

         data_coords = self._extract_coordinates(stream_idx, offset_key, datapoints)
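Since the stored targets and predictions are now numpy arrays, the two chained torch transposes collapse into a single numpy axes permutation. A quick sketch (shapes assumed for illustration) verifying the two forms agree:

```python
import numpy as np
import torch

x = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)  # assumed (ens, points, channels)

old = x.transpose(1, 0).transpose(1, 2).numpy()  # old code: two pairwise swaps
new = np.asarray(x).transpose(1, 2, 0)           # new code: one permutation

assert old.shape == new.shape == (3, 4, 2)
assert np.array_equal(old, new)
```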

@@ -423,6 +420,8 @@ def extract(self, key: ItemKey) -> OutputItem:
         else:
             source_dataset = None

+        assert is_ndarray(target_data), f"Expected ndarray but got: {type(target_data)}"
+        assert is_ndarray(preds_data), f"Expected ndarray but got: {type(preds_data)}"
         return OutputItem(
             key=key,
             source=source_dataset,
@@ -501,10 +500,10 @@ def _extract_sources(self, sample, stream_idx, key):
         source_dataset = OutputDataset(
             "source",
             key,
-            source.data,
-            source.datetimes,
-            source.coords,
-            source.geoinfos,
+            np.asarray(source.data),
+            np.asarray(source.datetimes),
+            np.asarray(source.coords),
+            np.asarray(source.geoinfos),
             channels,
             geoinfo_channels,
         )

src/weathergen/datasets/multi_stream_data_sampler.py

Lines changed: 2 additions & 2 deletions
@@ -313,12 +313,12 @@ def reset(self):
         self.tokenizer.reset_rng(self.rng)

     ###################################################
-    def denormalize_source_channels(self, stream_id, data):
+    def denormalize_source_channels(self, stream_id, data) -> torch.Tensor:
         # TODO: with multiple ds per stream we need to distinguish these here
         return self.streams_datasets[stream_id][0].denormalize_source_channels(data)

     ###################################################
-    def denormalize_target_channels(self, stream_id, data):
+    def denormalize_target_channels(self, stream_id, data) -> torch.Tensor:
         # TODO: with multiple ds per stream we need to distinguish these here
         return self.streams_datasets[stream_id][0].denormalize_target_channels(data)

src/weathergen/train/trainer.py

Lines changed: 15 additions & 6 deletions
@@ -19,6 +19,7 @@
 import torch
 import torch.nn as nn
 import tqdm
+from numpy.typing import NDArray
 from omegaconf import OmegaConf
 from torch import Tensor

@@ -240,7 +241,7 @@ def init_model_and_shard(self, cf, devices):
         for tensor in itertools.chain(model.parameters(), model.buffers()):
             assert tensor.device == torch.device("meta")

-        # For reasons we do not yet fully understand, when using train continue in some 
+        # For reasons we do not yet fully understand, when using train continue in some
         # instances, FSDP2 does not register the forward_channels and forward_columns
         # functions in the embedding engine as forward functions. Thus, yielding a crash
         # because the input tensors are not converted to DTensors. This seems to primarily
@@ -518,9 +519,13 @@ def _prepare_logging(

         # assert len(targets_rt) == len(preds) and len(preds) == len(self.cf.streams)
         fsteps = len(targets_rt)
-        preds_all = [[[] for _ in self.cf.streams] for _ in range(fsteps)]
-        targets_all = [[[] for _ in self.cf.streams] for _ in range(fsteps)]
-        targets_lens = [[[] for _ in self.cf.streams] for _ in range(fsteps)]
+        preds_all: list[list[list[NDArray]]] = [
+            [[] for _ in self.cf.streams] for _ in range(fsteps)
+        ]
+        targets_all: list[list[list[NDArray]]] = [
+            [[] for _ in self.cf.streams] for _ in range(fsteps)
+        ]
+        targets_lens: list[list[list[int]]] = [[[] for _ in self.cf.streams] for _ in range(fsteps)]

         # TODO: iterate over batches here in future, and change loop order to batch, stream, fstep
         for fstep in range(len(targets_rt)):
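The annotations document that the innermost lists accumulate numpy arrays (per stream, per forecast step). A minimal sketch of the same annotation pattern, with stream and step counts assumed:

```python
import numpy as np
from numpy.typing import NDArray

n_streams, fsteps = 2, 3  # assumed counts for illustration
preds_all: list[list[list[NDArray]]] = [
    [[] for _ in range(n_streams)] for _ in range(fsteps)
]
preds_all[0][1].append(np.zeros((4, 8), dtype=np.float32))
assert isinstance(preds_all[0][1][0], np.ndarray)
```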
@@ -542,8 +547,12 @@ def _prepare_logging(
             dn_data = self.dataset_val.denormalize_target_channels

             f32 = torch.float32
-            preds_all[fstep][i_strm] += [dn_data(i_strm, pred.to(f32)).detach().cpu()]
-            targets_all[fstep][i_strm] += [dn_data(i_strm, target.to(f32)).detach().cpu()]
+            preds_all[fstep][i_strm] += [
+                np.asarray(dn_data(i_strm, pred.to(f32)).detach().cpu())
+            ]
+            targets_all[fstep][i_strm] += [
+                np.asarray(dn_data(i_strm, target.to(f32)).detach().cpu())
+            ]

         return (
             preds_all,
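With the accumulators typed as NDArray, tensors are converted to numpy at the point of collection rather than downstream. A minimal sketch of the conversion chain (random tensor assumed):

```python
import numpy as np
import torch

t = torch.randn(4, 3, requires_grad=True)
arr = np.asarray(t.detach().cpu())  # detach from autograd, move to CPU, view as numpy
assert isinstance(arr, np.ndarray) and arr.shape == (4, 3)
```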
