diff --git a/pytorch_forecasting/models/base/_base_model_v2.py b/pytorch_forecasting/models/base/_base_model_v2.py index aceec0869..c144ba74b 100644 --- a/pytorch_forecasting/models/base/_base_model_v2.py +++ b/pytorch_forecasting/models/base/_base_model_v2.py @@ -5,6 +5,7 @@ ######################################################################################## +import inspect from typing import Optional, Union from warnings import warn @@ -46,6 +47,17 @@ def __init__( Parameters for the learning rate scheduler. """ super().__init__() + + # simple check for MultiLoss usage. + if inspect.isclass(loss) and loss.__class__.__name__ == "MultiLoss": + warn( + "\nIMPORTANT: Multi-target forecasting (MultiLoss) is NOT supported " + "in v2 base models. For multi-target forecasting, please use " + "pytorch_forecasting.models.base.BaseModel (v1) instead. " + "Attempting to use MultiLoss with v2 models will result in runtime errors.", # noqa: E501 + UserWarning, + stacklevel=2, + ) self.loss = loss self.logging_metrics = logging_metrics if logging_metrics is not None else [] self.optimizer = optimizer