Skip to content

Commit d0f775e

Browse files
fix: incorrect test for variable mask scaler (#649)
## Description The test was incorrect for the variable mask scaler. The scaler is setup properly. Added more description and fixed the test. ## What issue or task does this change relate to? #647 ## Additional notes ## ***As a contributor to the Anemoi framework, please ensure that your changes include unit tests, updates to any affected dependencies and documentation, and have been tested in a parallel setting (i.e., with multiple GPUs). As a reviewer, you are also responsible for verifying these aspects and requesting changes if they are not adequately addressed. For guidelines about those please refer to https://anemoi.readthedocs.io/en/latest/*** By opening this pull request, I affirm that all authors agree to the [Contributor License Agreement.](https://github.com/ecmwf/codex/blob/main/Legal/contributor_license_agreement.md)
1 parent ca1b542 commit d0f775e

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

training/tests/unit/train/test_loss_scaling.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -449,7 +449,12 @@ def test_variable_masking(
449449
)
450450
vars_to_mask = ["z", "other", "q"]
451451
indices_to_mask = [data_indices.model.output.name_to_index[v] for v in vars_to_mask]
452-
assert scalers["variable_masking"][0][0] == len(vars_to_mask)
452+
scaler = scalers["variable_masking"]
453+
assert scaler[0][0] == TensorDim.VARIABLE.value, "Expected scaler to be applied along variable dimension"
454+
# masked variables should have scaler of 0, unmasked 1
455+
assert int(scaler[1].sum().item()) == scaler[1].shape[0] - len(
456+
vars_to_mask,
457+
), "Sum of scaler values should be equal to number of unmasked variables"
453458
assert not scalers["variable_masking"][1][indices_to_mask].any(), "Expected scalers for masked variables to be zero"
454459

455460
config.training.scalers.builders["variable_masking"].update(invert=True)
@@ -462,8 +467,12 @@ def test_variable_masking(
462467
metadata_extractor=metadata_extractor,
463468
output_mask=NoOutputMask(),
464469
)
465-
assert scalers["variable_masking"][0][0] == len(vars_to_mask)
466-
assert scalers["variable_masking"][1][indices_to_mask].all(), "Expected scalers for unmasked variables to be one"
470+
inverted_scaler = scalers["variable_masking"]
471+
# dimension where scaler is applied is variable
472+
assert inverted_scaler[0][0] == TensorDim.VARIABLE.value
473+
# masked variables with inverted = True should have scaler of 1, unmasked 0
474+
assert int(inverted_scaler[1].sum().item()) == len(vars_to_mask)
475+
assert inverted_scaler[1][indices_to_mask].all(), "Expected scalers for unmasked variables to be one"
467476

468477

469478
def test_variable_loss_scaling_val_complex_variable_groups(

0 commit comments

Comments
 (0)