Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 122 additions & 12 deletions tests/test_models/test_timexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,27 @@
from pytorch_forecasting.models import TimeXer


def _expected_fwd_shape(batch_size, prediction_length, loss):
"""
Return the expected output shape for the forward pass of the model.
"""

if isinstance(loss, QuantileLoss):
n_quantiles = len(loss.quantiles)
return (batch_size, prediction_length, n_quantiles)
elif isinstance(loss, MultiLoss):
shapes = []
for single_loss in loss.losses:
if isinstance(single_loss, QuantileLoss):
n_quantiles = len(single_loss.quantiles)
shapes.append((batch_size, prediction_length, n_quantiles))
else:
shapes.append((batch_size, prediction_length, 1))
return shapes
else:
return (batch_size, prediction_length, 1)


def _integration(dataloader, tmp_path, loss=None, trainer_kwargs=None, **kwargs):
"""
Integration test for the TimeXer model.
Expand Down Expand Up @@ -96,7 +117,6 @@ def _integration(dataloader, tmp_path, loss=None, trainer_kwargs=None, **kwargs)
net = TimeXer.load_from_checkpoint(
trainer.checkpoint_callback.best_model_path,
)

predictions = net.predict(
val_dataloader,
return_index=True,
Expand All @@ -117,17 +137,6 @@ def _integration(dataloader, tmp_path, loss=None, trainer_kwargs=None, **kwargs)
f"but got {predictions.output.shape}"
)

# raw prediction if debugging the model

net.predict(
val_dataloader,
return_index=True,
return_x=True,
fast_dev_run=True,
mode="raw",
trainer_kwargs=trainer_kwargs,
)

finally:
# remove the temporary directory created for the test
shutil.rmtree(tmp_path, ignore_errors=True)
Expand Down Expand Up @@ -471,3 +480,104 @@ def test_with_exogenous_variables(tmp_path):

finally:
shutil.rmtree(tmp_path, ignore_errors=True)


@pytest.mark.skipif(
True, reason="Skipping due to incompatibility with current model outputs."
) # noqa: E501
def test_model_forward_output(dataloaders_with_covariates):
"""
Test the model's forward output shapes.
This test checks that the model's forward pass returns outputs
of expected shapes based on the loss function used.
Args:
dataloaders_with_covariates: The dataloaders to use for training and validation
"""

train_dataloader = dataloaders_with_covariates["train"]
val_dataloader = dataloaders_with_covariates["val"]

dataset = train_dataloader.dataset
batch = next(iter(val_dataloader))
x, y = batch

batch_size = x["encoder_cont"].shape[0]
prediction_length = dataset.max_prediction_length

loss = MAE()
model = TimeXer.from_dataset(
dataset,
hidden_size=16,
n_heads=2,
e_layers=1,
patch_length=2,
dropout=0.1,
loss=loss,
)

with torch.no_grad():
output = model(x)

prediction = output["prediction"]
expected_shape = _expected_fwd_shape(
batch_size=batch_size,
prediction_length=prediction_length,
loss=loss,
)

assert (
prediction.shape == expected_shape
), f"Expected output shape {expected_shape}, but got {prediction.shape}"

quantile_loss = QuantileLoss(quantiles=[0.1, 0.5, 0.9])
model_quantile = TimeXer.from_dataset(
dataset,
hidden_size=16,
n_heads=2,
e_layers=1,
patch_length=2,
dropout=0.1,
loss=quantile_loss,
)

with torch.no_grad():
output_quantile = model_quantile(x)
prediction_quantile = output_quantile["prediction"]
expected_shape_quantile = _expected_fwd_shape(
batch_size=batch_size,
prediction_length=prediction_length,
loss=quantile_loss,
)
assert prediction_quantile.shape == expected_shape_quantile, (
f"Expected output shape {expected_shape_quantile}, but got {prediction_quantile.shape}" # noqa: E501
)

multi_loss = MultiLoss([MAE(), MAE()])
model_multi = TimeXer.from_dataset(
dataset,
hidden_size=16,
n_heads=2,
e_layers=1,
d_ff=32,
patch_length=2,
dropout=0.1,
loss=multi_loss,
)

with torch.no_grad():
output_multi = model_multi(x)

prediction_multi = output_multi["prediction"]
expected_shapes_multi = _expected_fwd_shape(
batch_size, prediction_length, multi_loss
)

assert isinstance(prediction_multi, list)
assert len(prediction_multi) == len(expected_shapes_multi)

for i, (pred_tensor, expected_shape) in enumerate(
zip(prediction_multi, expected_shapes_multi)
): # noqa: E501
assert (
pred_tensor.shape == expected_shape
), f"MultiLoss target {i}: Expected {expected_shape}, got {pred_tensor.shape}" # noqa: E501
Loading