Skip to content

Commit 3c9ca8e

Browse files
VeraChristinafrazane
authored andcommitted
chore: remove duplicate plotting code (#639)
## Description Removing duplicate plotting code, see lines 936-968 ***As a contributor to the Anemoi framework, please ensure that your changes include unit tests, updates to any affected dependencies and documentation, and have been tested in a parallel setting (i.e., with multiple GPUs). As a reviewer, you are also responsible for verifying these aspects and requesting changes if they are not adequately addressed. For guidelines about those please refer to https://anemoi.readthedocs.io/en/latest/*** By opening this pull request, I affirm that all authors agree to the [Contributor License Agreement.](https://github.com/ecmwf/codex/blob/main/Legal/contributor_license_agreement.md)
1 parent a5d2b41 commit 3c9ca8e

File tree

1 file changed

+0
-35
lines changed
  • training/src/anemoi/training/diagnostics/callbacks

1 file changed

+0
-35
lines changed

training/src/anemoi/training/diagnostics/callbacks/plot.py

Lines changed: 0 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1070,41 +1070,6 @@ def _plot(
10701070
)
10711071

10721072

1073-
class BasePlotAdditionalMetrics(BasePerBatchPlotCallback):
1074-
"""Base processing class for additional metrics."""
1075-
1076-
def process(
1077-
self,
1078-
pl_module: pl.LightningModule,
1079-
outputs: list,
1080-
batch: torch.Tensor,
1081-
) -> tuple[np.ndarray, np.ndarray]:
1082-
if self.latlons is None:
1083-
self.latlons = np.rad2deg(pl_module.latlons_data.clone().detach().cpu().numpy())
1084-
1085-
input_tensor = (
1086-
batch[
1087-
:,
1088-
pl_module.multi_step - 1 : pl_module.multi_step + pl_module.rollout + 1,
1089-
...,
1090-
pl_module.data_indices.data.output.full,
1091-
]
1092-
.detach()
1093-
.cpu()
1094-
)
1095-
data = self.post_processors(input_tensor)[self.sample_idx]
1096-
output_tensor = torch.cat(
1097-
tuple(
1098-
self.post_processors(x[:, ...].detach().cpu(), in_place=False)[self.sample_idx : self.sample_idx + 1]
1099-
for x in outputs[1]
1100-
),
1101-
)
1102-
output_tensor = pl_module.output_mask.apply(output_tensor, dim=2, fill_value=np.nan).numpy()
1103-
data[1:, ...] = pl_module.output_mask.apply(data[1:, ...], dim=2, fill_value=np.nan)
1104-
data = data.numpy()
1105-
return data, output_tensor
1106-
1107-
11081073
class PlotSpectrum(BasePlotAdditionalMetrics):
11091074
"""Plots TP related metric comparing target and prediction.
11101075

0 commit comments

Comments
 (0)