From b8f77fec04e5eac61e435b98ff3fd862f957bd7e Mon Sep 17 00:00:00 2001 From: PranavBhatP Date: Thu, 21 Aug 2025 01:30:00 +0530 Subject: [PATCH 1/3] add test for model_forward_output to check the modl's output shape this is skipped for now since the version of timexer on main isn't compatible with the contract outlined by the test --- tests/test_models/test_timexer.py | 145 +++++++++++++++++++++++++----- 1 file changed, 122 insertions(+), 23 deletions(-) diff --git a/tests/test_models/test_timexer.py b/tests/test_models/test_timexer.py index 2f5518026..d52cb358e 100644 --- a/tests/test_models/test_timexer.py +++ b/tests/test_models/test_timexer.py @@ -17,6 +17,27 @@ from pytorch_forecasting.models import TimeXer +def _expected_fwd_shape(batch_size, prediction_length, loss): + """ + Return the expected output shape for the forward pass of the model. + """ + + if isinstance(loss, QuantileLoss): + n_quantiles = len(loss.quantiles) + return (batch_size, prediction_length, n_quantiles) + elif isinstance(loss, MultiLoss): + shapes = [] + for single_loss in loss.losses: + if isinstance(single_loss, QuantileLoss): + n_quantiles = len(single_loss.quantiles) + shapes.append((batch_size, prediction_length, n_quantiles)) + else: + shapes.append((batch_size, prediction_length, 1)) + return shapes + else: + return (batch_size, prediction_length, 1) + + def _integration(dataloader, tmp_path, loss=None, trainer_kwargs=None, **kwargs): """ Integration test for the TimeXer model. @@ -97,34 +118,11 @@ def _integration(dataloader, tmp_path, loss=None, trainer_kwargs=None, **kwargs) trainer.checkpoint_callback.best_model_path, ) - predictions = net.predict( - val_dataloader, - return_index=True, - return_x=True, - return_y=True, - fast_dev_run=True, - trainer_kwargs=trainer_kwargs, - ) - - if isinstance(predictions.output, torch.Tensor): - assert predictions.output.ndim == 2, ( - f"shapes of the output should be [batch_size, n_targets], " - f"but got {predictions.output.shape}" - ) - else: - assert all(p.ndim for p in predictions.output), ( - f"shapes of the output should be [batch_size, n_targets], " - f"but got {predictions.output.shape}" - ) - - # raw prediction if debugging the model - net.predict( val_dataloader, return_index=True, return_x=True, fast_dev_run=True, - mode="raw", trainer_kwargs=trainer_kwargs, ) @@ -471,3 +469,104 @@ def test_with_exogenous_variables(tmp_path): finally: shutil.rmtree(tmp_path, ignore_errors=True) + + +@pytest.mark.skipif( + True, reason="Skipping due to incompatibility with current model outputs." +) # noqa: E501 +def test_model_forward_output(dataloaders_with_covariates): + """ + Test the model's forward output shapes. + This test checks that the model's forward pass returns outputs + of expected shapes based on the loss function used. + Args: + dataloaders_with_covariates: The dataloaders to use for training and validation + """ + + train_dataloader = dataloaders_with_covariates["train"] + val_dataloader = dataloaders_with_covariates["val"] + + dataset = train_dataloader.dataset + batch = next(iter(val_dataloader)) + x, y = batch + + batch_size = x["encoder_cont"].shape[0] + prediction_length = dataset.max_prediction_length + + loss = MAE() + model = TimeXer.from_dataset( + dataset, + hidden_size=16, + n_heads=2, + e_layers=1, + patch_length=2, + dropout=0.1, + loss=loss, + ) + + with torch.no_grad(): + output = model(x) + + prediction = output["prediction"] + expected_shape = _expected_fwd_shape( + batch_size=batch_size, + prediction_length=prediction_length, + loss=loss, + ) + + assert ( + prediction.shape == expected_shape + ), f"Expected output shape {expected_shape}, but got {prediction.shape}" + + quantile_loss = QuantileLoss(quantiles=[0.1, 0.5, 0.9]) + model_quantile = TimeXer.from_dataset( + dataset, + hidden_size=16, + n_heads=2, + e_layers=1, + patch_length=2, + dropout=0.1, + loss=quantile_loss, + ) + + with torch.no_grad(): + output_quantile = model_quantile(x) + prediction_quantile = output_quantile["prediction"] + expected_shape_quantile = _expected_fwd_shape( + batch_size=batch_size, + prediction_length=prediction_length, + loss=quantile_loss, + ) + assert prediction_quantile.shape == expected_shape_quantile, ( + f"Expected output shape {expected_shape_quantile}, but got {prediction_quantile.shape}" # noqa: E501 + ) + + multi_loss = MultiLoss([MAE(), MAE()]) + model_multi = TimeXer.from_dataset( + dataset, + hidden_size=16, + n_heads=2, + e_layers=1, + d_ff=32, + patch_length=2, + dropout=0.1, + loss=multi_loss, + ) + + with torch.no_grad(): + output_multi = model_multi(x) + + prediction_multi = output_multi["prediction"] + expected_shapes_multi = _expected_fwd_shape( + batch_size, prediction_length, multi_loss + ) + + assert isinstance(prediction_multi, list) + assert len(prediction_multi) == len(expected_shapes_multi) + + for i, (pred_tensor, expected_shape) in enumerate( + zip(prediction_multi, expected_shapes_multi) + ): # noqa: E501 + assert ( + pred_tensor.shape == expected_shape + ), f"MultiLoss target {i}: Expected {expected_shape}, got {pred_tensor.shape}" # noqa: E501 From ff761477be2ace32bd38ca1f167470b397b8a80a Mon Sep 17 00:00:00 2001 From: PranavBhatP Date: Thu, 21 Aug 2025 01:31:38 +0530 Subject: [PATCH 2/3] restore few deleted lines from original _integration --- tests/test_models/test_timexer.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/tests/test_models/test_timexer.py b/tests/test_models/test_timexer.py index d52cb358e..f7970f42a 100644 --- a/tests/test_models/test_timexer.py +++ b/tests/test_models/test_timexer.py @@ -126,6 +126,26 @@ def _integration(dataloader, tmp_path, loss=None, trainer_kwargs=None, **kwargs) trainer_kwargs=trainer_kwargs, ) + predictions = net.predict( + val_dataloader, + return_index=True, + return_x=True, + return_y=True, + fast_dev_run=True, + trainer_kwargs=trainer_kwargs, + ) + + if isinstance(predictions.output, torch.Tensor): + assert predictions.output.ndim == 2, ( + f"shapes of the output should be [batch_size, n_targets], " + f"but got {predictions.output.shape}" + ) + else: + assert all(p.ndim for p in predictions.output), ( + f"shapes of the output should be [batch_size, n_targets], " + f"but got {predictions.output.shape}" + ) + finally: # remove the temporary directory created for the test shutil.rmtree(tmp_path, ignore_errors=True) From 0fd21d7ec0acb95c828f881ea44efb6d85049960 Mon Sep 17 00:00:00 2001 From: PranavBhatP Date: Thu, 21 Aug 2025 12:42:18 +0530 Subject: [PATCH 3/3] remove unecessary predict call in _integration --- tests/test_models/test_timexer.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/tests/test_models/test_timexer.py b/tests/test_models/test_timexer.py index f7970f42a..abf42ec9a 100644 --- a/tests/test_models/test_timexer.py +++ b/tests/test_models/test_timexer.py @@ -117,15 +117,6 @@ def _integration(dataloader, tmp_path, loss=None, trainer_kwargs=None, **kwargs) net = TimeXer.load_from_checkpoint( trainer.checkpoint_callback.best_model_path, ) - - net.predict( - val_dataloader, - return_index=True, - return_x=True, - fast_dev_run=True, - trainer_kwargs=trainer_kwargs, - ) - predictions = net.predict( val_dataloader, return_index=True,