diff --git a/docs/source/tutorials/ptf_V2_example.ipynb b/docs/source/tutorials/ptf_V2_example.ipynb
new file mode 100644
index 000000000..2036a4784
--- /dev/null
+++ b/docs/source/tutorials/ptf_V2_example.ipynb
@@ -0,0 +1,1159 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "rzVbXsEBxnF-"
+ },
+ "source": [
+ "# `pytorch-forecasting v2` Model Training and Inference - Beta API"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "yt0uZV7Px-40"
+ },
+ "source": [
+ ":warning: The vignette showcased here is part of an experimental rework of the `pytorch-forecasting` data layer, planned for release in v2.0.0. The API is currently unstable and subject to change without prior notice.\n",
+ "\n",
+ "Feedback and suggestions are highly encouraged; please share them in [issue 1736](https://github.com/sktime/pytorch-forecasting/issues/1736).\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "6D9ARyp05R0t"
+ },
+ "source": [
+ "In this vignette, we demonstrate how to train and evaluate the **Temporal Fusion Transformer (TFT)** using the new `TimeSeries` and `DataModule` API from the v2 pipeline.\n",
+ "\n",
+ "\n",
+ "## Steps\n",
+ "\n",
+ "1. **Load Data** \n",
+ "2. **Create Dataset & DataModule** \n",
+ "3. **Initialize, Train & Run Inference with the Model**\n",
+ "\n",
+ "\n",
+ "\n",
+ "### Load Data\n",
+ "\n",
+ "We generate a synthetic dataset using `load_toydata` which returns a `pandas` DataFrame with purely numerical values. \n",
+ "*(Note: The current pipeline assumes all inputs are numerical only.)*\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "### Create Dataset & DataModule\n",
+ "\n",
+ "- `TimeSeries` returns the raw data as tensors.\n",
+ "- `DataModule` wraps the dataset, handles splits, preprocessing, and batching, and exposes `metadata` for model initialisation.\n",
+ "\n",
+ "\n",
+ "\n",
+ "### Initialize the Model\n",
+ "\n",
+ "We initialize the TFT model using the `metadata` provided by the `DataModule`. This metadata includes all required dimensional info for the encoder, decoder, and static inputs.\n",
+ "\n",
+ "\n",
+ "\n",
+ "### Train the Model\n",
+ "\n",
+ "We use a `Trainer` from PyTorch Lightning to train the model.\n",
+ "\n",
+ "### Run Inference\n",
+ "\n",
+ "After training, we can make predictions using the trained model.\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "QyMFNk4MyY_b"
+ },
+ "source": [
+ "# 1. Load Data\n",
+ "We generate a synthetic dataset using `load_toydata`, which returns a `pandas` DataFrame containing only numerical values, because **the pipeline currently assumes all data is numerical**."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {
+ "id": "RkgOT4kiy_RU"
+ },
+ "outputs": [],
+ "source": [
+ "from pytorch_forecasting.data.examples import load_toydata"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 206
+ },
+ "id": "WX-FRdusJSVN",
+ "outputId": "e481484c-b0c3-4026-c933-a9dc047617c5"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "<table border=\"1\" class=\"dataframe\">\n",
+ "  <thead>\n",
+ "    <tr><th></th><th>series_id</th><th>time_idx</th><th>x</th><th>y</th><th>category</th><th>future_known_feature</th><th>static_feature</th><th>static_feature_cat</th></tr>\n",
+ "  </thead>\n",
+ "  <tbody>\n",
+ "    <tr><th>0</th><td>0</td><td>0</td><td>-0.137988</td><td>0.016824</td><td>0</td><td>1.000000</td><td>0.927557</td><td>0</td></tr>\n",
+ "    <tr><th>1</th><td>0</td><td>1</td><td>0.016824</td><td>0.287291</td><td>0</td><td>0.995004</td><td>0.927557</td><td>0</td></tr>\n",
+ "    <tr><th>2</th><td>0</td><td>2</td><td>0.287291</td><td>0.599587</td><td>0</td><td>0.980067</td><td>0.927557</td><td>0</td></tr>\n",
+ "    <tr><th>3</th><td>0</td><td>3</td><td>0.599587</td><td>0.779352</td><td>0</td><td>0.955336</td><td>0.927557</td><td>0</td></tr>\n",
+ "    <tr><th>4</th><td>0</td><td>4</td><td>0.779352</td><td>0.876148</td><td>0</td><td>0.921061</td><td>0.927557</td><td>0</td></tr>\n",
+ "  </tbody>\n",
+ "</table>\n"
+ ],
+ "text/plain": [
+ " series_id time_idx x y category future_known_feature \\\n",
+ "0 0 0 -0.137988 0.016824 0 1.000000 \n",
+ "1 0 1 0.016824 0.287291 0 0.995004 \n",
+ "2 0 2 0.287291 0.599587 0 0.980067 \n",
+ "3 0 3 0.599587 0.779352 0 0.955336 \n",
+ "4 0 4 0.779352 0.876148 0 0.921061 \n",
+ "\n",
+ " static_feature static_feature_cat \n",
+ "0 0.927557 0 \n",
+ "1 0.927557 0 \n",
+ "2 0.927557 0 \n",
+ "3 0.927557 0 \n",
+ "4 0.927557 0 "
+ ]
+ },
+ "execution_count": 2,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "num_series = 100  # Number of individual time series to generate\n",
+ "seq_length = 50  # Points generated per series; each series yields seq_length - 1 rows\n",
+ "data_df = load_toydata(num_series, seq_length)\n",
+ "data_df.head()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "RYQ5CdNUyc2q"
+ },
+ "source": [
+ "# 2. Create the dataset and datamodule\n",
+ "We create a `TimeSeries` dataset instance that returns the raw data as tensors. This raw data is then passed to the `data_module`, which handles the dataloaders and preprocessing internally."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "ONe8Eo1zzvCH"
+ },
+ "source": [
+ "Key arguments of the `TimeSeries` dataset:\n",
+ "- `data`: DataFrame with sequence data.\n",
+ "- `time`: Integer-typed column denoting the time index within `data`.\n",
+ "- `target`: Column(s) in `data` denoting the forecasting target.\n",
+ "- `group`: List of column names identifying a time series instance within `data`.\n",
+ "- `num`: List of numerical features.\n",
+ "- `cat`: List of categorical features.\n",
+ "- `known`: Features known in the future.\n",
+ "- `unknown`: Features not known in the future.\n",
+ "- `static`: List of variables that do not change over time."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {
+ "id": "JPD3y3qny5Dx"
+ },
+ "outputs": [],
+ "source": [
+ "from pytorch_forecasting.data.timeseries import TimeSeries"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {
+ "id": "AxxPHK6AKSD2"
+ },
+ "outputs": [],
+ "source": [
+ "# create `TimeSeries` dataset that returns the raw data in terms of tensors\n",
+ "dataset = TimeSeries(\n",
+ " data=data_df,\n",
+ " time=\"time_idx\",\n",
+ " target=\"y\",\n",
+ " group=[\"series_id\"],\n",
+ " num=[\"x\", \"future_known_feature\", \"static_feature\"],\n",
+ " cat=[\"category\", \"static_feature_cat\"],\n",
+ " known=[\"future_known_feature\"],\n",
+ " unknown=[\"x\", \"category\"],\n",
+ " static=[\"static_feature\", \"static_feature_cat\"],\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "S-yHU46v1MhN"
+ },
+ "source": [
+ "Key arguments of `EncoderDecoderTimeSeriesDataModule`:\n",
+ "- `time_series_dataset`: `TimeSeries` dataset instance.\n",
+ "- `max_encoder_length`: Maximum length of the encoder input sequence.\n",
+ "- `max_prediction_length`: Maximum length of the decoder output sequence.\n",
+ "- `batch_size`: Batch size for the DataLoader.\n",
+ "- `categorical_encoders`: Dictionary of categorical encoders.\n",
+ "- `scalers`: Dictionary of feature scalers.\n",
+ "- `target_normalizer`: Normalizer for the target variable."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "DUWB4LrGyxrL"
+ },
+ "outputs": [],
+ "source": [
+ "from sklearn.preprocessing import StandardScaler\n",
+ "from pytorch_forecasting.data.data_module import EncoderDecoderTimeSeriesDataModule\n",
+ "from pytorch_forecasting.data.encoders import (\n",
+ " NaNLabelEncoder,\n",
+ " TorchNormalizer,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {
+ "id": "5U5Lr_ZFKX0s"
+ },
+ "outputs": [],
+ "source": [
+ "# create the `data_module` that handles the dataloaders and preprocessing\n",
+ "data_module = EncoderDecoderTimeSeriesDataModule(\n",
+ " time_series_dataset=dataset,\n",
+ " max_encoder_length=30,\n",
+ " max_prediction_length=1,\n",
+ " batch_size=32,\n",
+ " categorical_encoders={\n",
+ " \"category\": NaNLabelEncoder(add_nan=True),\n",
+ " \"static_feature_cat\": NaNLabelEncoder(add_nan=True),\n",
+ " },\n",
+ " scalers={\n",
+ " \"x\": StandardScaler(),\n",
+ " \"future_known_feature\": StandardScaler(),\n",
+ " \"static_feature\": StandardScaler(),\n",
+ " },\n",
+ " target_normalizer=TorchNormalizer(),\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "qykX7vQ7zWnC"
+ },
+ "source": [
+ "# 3. Initialise and train the model"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "_kz3MO362Tlo"
+ },
+ "source": [
+ "To initialise the model, you no longer have to pass arguments such as `encoder_cont` and `decoder_cont`, because they are computed internally from the `metadata` property [[source]](https://github.com/sktime/pytorch-forecasting/blob/4a34931e499c2b59de3939fcffcaabd75204b045/pytorch_forecasting/data/data_module.py#L264-L292) of `EncoderDecoderTimeSeriesDataModule`. You still have to pass other parameters such as `loss` and `optimizer`."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "XvwIuzD34Ytk"
+ },
+ "source": [
+ "\n",
+ "```python\n",
+ "model = TFT(\n",
+ " loss=nn.MSELoss(),\n",
+ " logging_metrics=[MAE(), SMAPE()],\n",
+ " metadata=data_module.metadata, # <-- crucial for model setup\n",
+ " ...\n",
+ ")\n",
+ "```\n",
+ "\n",
+ "The `metadata` includes:\n",
+ "- `max_encoder_length`, `max_prediction_length`\n",
+ "- number of continuous/categorical variables in encoder/decoder\n",
+ "- number of static features\n",
+ "\n",
+ "These values are used to set internal parameters such as `encoder_cont` and `decoder_cat`.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {
+ "id": "xOsEucZnzCkN"
+ },
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "from pytorch_forecasting.metrics import MAE, SMAPE\n",
+ "from pytorch_forecasting.models.temporal_fusion_transformer._tft_v2 import TFT"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "9qbjnTxnyh4H",
+ "outputId": "f59bf985-ffaa-4980-c890-39a80dfcc598"
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/content/pytorch-forecasting/pytorch_forecasting/models/base/_base_model_v2.py:58: UserWarning: The Model 'TFT' is part of an experimental reworkof the pytorch-forecasting model layer, scheduled for release with v2.0.0. The API is not stable and may change without prior warning. This class is intended for beta testing and as a basic skeleton, but not for stable production use. Feedback and suggestions are very welcome in pytorch-forecasting issue 1736, https://github.com/sktime/pytorch-forecasting/issues/1736\n",
+ " warn(\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Initialise the Model\n",
+ "model = TFT(\n",
+ " loss=nn.MSELoss(),\n",
+ " logging_metrics=[MAE(), SMAPE()],\n",
+ " optimizer=\"adam\",\n",
+ " optimizer_params={\"lr\": 1e-3},\n",
+ " lr_scheduler=\"reduce_lr_on_plateau\",\n",
+ " lr_scheduler_params={\"mode\": \"min\", \"factor\": 0.1, \"patience\": 10},\n",
+ " hidden_size=64,\n",
+ " num_layers=2,\n",
+ " attention_head_size=4,\n",
+ " dropout=0.1,\n",
+ " metadata=data_module.metadata, # pass the metadata from the datamodule to the model\n",
+ " # to initialise important params like `encoder_cont` etc\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "svdoye-d8F-z"
+ },
+ "source": [
+ "We use a `Trainer` from PyTorch Lightning to train the model:\n",
+ "\n",
+ "```python\n",
+ "trainer = Trainer(max_epochs=5, ...)\n",
+ "trainer.fit(model, data_module)\n",
+ "```\n",
+ "\n",
+ "The `Trainer`:\n",
+ "- Pulls data from `data_module`\n",
+ "- Handles device placement\n",
+ "- Logs training progress and metrics\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {
+ "id": "RTSmUu9RytS8"
+ },
+ "outputs": [],
+ "source": [
+ "from lightning.pytorch import Trainer"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 930,
+ "referenced_widgets": [
+ "edc196b34a0b49fa992fd89909a8414f",
+ "76e801c32da348788b0b92fd2c2849c1",
+ "cd6d925c2dcf4ae995f62d8f20219fde",
+ "283f67f9ef6a47f8b3d4a389a18b8ccc",
+ "8e19d683c674403b87cbb7c162a0ab33",
+ "525fe2cef0444d558d870c394cf67e81",
+ "640318d6256a4451957fce6481625f6b",
+ "0820bd0fd60f4e0180899f18615d9ee3",
+ "90cb2b96e18948679eda809a9123565c",
+ "5ffb5772ae4f4764b23643cfa547a615",
+ "04f140582780402a94e3391bfcbffa91",
+ "4abfcb0ade1b47e1b015e7efd75cd563",
+ "483c0919c7c740af89e6c7a7470591f7",
+ "18849509bcfc41dd89011efb5690e49b",
+ "098eb27809314e7ca15c5b9bce21b46d",
+ "53f7fc6124fb4dda8bfb147078212ada",
+ "8c2821dc533c4a1ba81d68836047dc54",
+ "cb303a9d13084d47b0714d717f482a80",
+ "491aa926179b41dba21a008dc6e9a4fc",
+ "363967a782db4627a3e16a726b4993c6",
+ "0d28163142aa465784871ec647efa504",
+ "a6d821b9d9c9453b87740c3e4d86ca77",
+ "2301684175b0454897af0fc5e534f52f",
+ "58fe6653d5e944be8fdd072cd3541b8c",
+ "00fb395e4cb5472faf221affc0244780",
+ "1cd0c86474a443b8aae6466405dd3807",
+ "cfcb344ba8cb49a78a8f934657252855",
+ "4f455332437a465bb300ccf50405358d",
+ "fde9c4616cd84e629998819c18adfd1f",
+ "405e6f4ab49b445ba7e942e0d35419eb",
+ "2ee2fe3ea3c04592851fef7e6dda5115",
+ "713bcea77e1744179502f6a187239ce1",
+ "531c73851dcd469791d6ebdc9aa9da52",
+ "b0ed53cd20bd47b5aac5fecaacb9bfed",
+ "da87d95793f84907819ab4c506f6936b",
+ "c8a64619bf574d43a01595e3c8690759",
+ "8d0c66b5cc984b8db7d151a970897266",
+ "a723c1055f4f4c0d9c7b2c8c3469fa9d",
+ "99dcc8dd7e88463a95ddcf86c67ec153",
+ "94c1a818e78a47ee9bb46f1c8b1dad34",
+ "903b0891a3014701b117a391d777e012",
+ "f6053af158304fc1b307a296cdbe7290",
+ "f6c93805c03545458f1af0fffaf77fb0",
+ "52bbe8a40baf4a8f989281bcc225a067",
+ "c9eb21c70ac243d1a21f50f287b4731f",
+ "100a61bf42b0498ba241ddd9ef7a96f3",
+ "1cce85eee5bc4700aa966c05e35ea042",
+ "2dcaeafc8bc84db786f9fa99aeda39db",
+ "7a9cf1446cc14d51a0a729822eb1cabc",
+ "1c6f0e1be6f8475485ba6354ccc68597",
+ "7ef67c160b8740e9932486f857470542",
+ "282e2d3a169a4e18b46ca85dbfe82371",
+ "84a96c68feba48b6880de08722de8743",
+ "c111af2e297641908cb1835d9a4d06ff",
+ "d94d97f0852b450c8ea56e6ef4e44ac5",
+ "1dd910df34e44cdc98586e3094dfddd4",
+ "86c6678f86d944eeba1fe982197cb7d3",
+ "b609e7acf93944a98f11b450e01222f8",
+ "0baf540b09454f61873cd6a3088a3c9d",
+ "7bc8c547554c481481d0653193dcd917",
+ "2d92b66364844b5491dabee7a0b90686",
+ "99a8fec0a22f4949a3d16fc7f48aeea6",
+ "0fb99264aa0644cda962905dfe1d6997",
+ "f3d688e29f4b4432ac370d6a1e76ec5c",
+ "d4cd0f3144784991ba9ec0b178a6c2e7",
+ "83a1e407cf60401b932c6ee6e72d90db",
+ "99443f56f20a4c6da81ddb107ab84490",
+ "2da428c9a3b541bc8f667d5715fdcb0d",
+ "f483a26167424f66b8b3c537537afd77",
+ "bc01c61ee2a14f3bbea23d04bb7e9dc9",
+ "b59e64ed8d9d46dd8803bfe25602c391",
+ "cc1d7b4924784a08999f41e1e3c989ac",
+ "8c6ad939d53942e48d8d6fb4c30b666d",
+ "49d44a5db92b4b3e979e88d14072a16c",
+ "36d0bfd0973c43e4958af452d2c7025c",
+ "8a90ddfc1fd446a6894c2771195a4419",
+ "aa94917a2f364a4ea9033000730d48b7"
+ ]
+ },
+ "id": "aB_ayE_eykXp",
+ "outputId": "02c49d3e-2124-4b0b-8ca4-b2d886662613"
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO: Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.\n",
+ "INFO:lightning.pytorch.utilities.rank_zero:Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.\n",
+ "INFO: GPU available: False, used: False\n",
+ "INFO:lightning.pytorch.utilities.rank_zero:GPU available: False, used: False\n",
+ "INFO: TPU available: False, using: 0 TPU cores\n",
+ "INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores\n",
+ "INFO: HPU available: False, using: 0 HPUs\n",
+ "INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs\n",
+ "INFO: \n",
+ " | Name | Type | Params | Mode\n",
+ "--------------------------------------------------------------------\n",
+ "0 | loss | MSELoss | 0 | eval\n",
+ "1 | encoder_var_selection | Sequential | 709 | eval\n",
+ "2 | decoder_var_selection | Sequential | 193 | eval\n",
+ "3 | static_context_linear | Linear | 192 | eval\n",
+ "4 | lstm_encoder | LSTM | 51.5 K | eval\n",
+ "5 | lstm_decoder | LSTM | 50.4 K | eval\n",
+ "6 | self_attention | MultiheadAttention | 16.6 K | eval\n",
+ "7 | pre_output | Linear | 4.2 K | eval\n",
+ "8 | output_layer | Linear | 65 | eval\n",
+ "--------------------------------------------------------------------\n",
+ "123 K Trainable params\n",
+ "0 Non-trainable params\n",
+ "123 K Total params\n",
+ "0.495 Total estimated model params size (MB)\n",
+ "0 Modules in train mode\n",
+ "18 Modules in eval mode\n",
+ "INFO:lightning.pytorch.callbacks.model_summary:\n",
+ " | Name | Type | Params | Mode\n",
+ "--------------------------------------------------------------------\n",
+ "0 | loss | MSELoss | 0 | eval\n",
+ "1 | encoder_var_selection | Sequential | 709 | eval\n",
+ "2 | decoder_var_selection | Sequential | 193 | eval\n",
+ "3 | static_context_linear | Linear | 192 | eval\n",
+ "4 | lstm_encoder | LSTM | 51.5 K | eval\n",
+ "5 | lstm_decoder | LSTM | 50.4 K | eval\n",
+ "6 | self_attention | MultiheadAttention | 16.6 K | eval\n",
+ "7 | pre_output | Linear | 4.2 K | eval\n",
+ "8 | output_layer | Linear | 65 | eval\n",
+ "--------------------------------------------------------------------\n",
+ "123 K Trainable params\n",
+ "0 Non-trainable params\n",
+ "123 K Total params\n",
+ "0.495 Total estimated model params size (MB)\n",
+ "0 Modules in train mode\n",
+ "18 Modules in eval mode\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "Training model...\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "edc196b34a0b49fa992fd89909a8414f",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Sanity Checking: | | 0/? [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {
+ "application/vnd.jupyter.widget-view+json": {
+ "colab": {
+ "custom_widget_manager": {
+ "url": "https://ssl.gstatic.com/colaboratory-static/widgets/colab-cdn-widget-manager/2b70e893a8ba7c0f/manager.min.js"
+ }
+ }
+ }
+ },
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "a6d821b9d9c9453b87740c3e4d86ca77",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Training: | | 0/? [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {
+ "application/vnd.jupyter.widget-view+json": {
+ "colab": {
+ "custom_widget_manager": {
+ "url": "https://ssl.gstatic.com/colaboratory-static/widgets/colab-cdn-widget-manager/2b70e893a8ba7c0f/manager.min.js"
+ }
+ }
+ }
+ },
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "531c73851dcd469791d6ebdc9aa9da52",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Validation: | | 0/? [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {
+ "application/vnd.jupyter.widget-view+json": {
+ "colab": {
+ "custom_widget_manager": {
+ "url": "https://ssl.gstatic.com/colaboratory-static/widgets/colab-cdn-widget-manager/2b70e893a8ba7c0f/manager.min.js"
+ }
+ }
+ }
+ },
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "52bbe8a40baf4a8f989281bcc225a067",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Validation: | | 0/? [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {
+ "application/vnd.jupyter.widget-view+json": {
+ "colab": {
+ "custom_widget_manager": {
+ "url": "https://ssl.gstatic.com/colaboratory-static/widgets/colab-cdn-widget-manager/2b70e893a8ba7c0f/manager.min.js"
+ }
+ }
+ }
+ },
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "d94d97f0852b450c8ea56e6ef4e44ac5",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Validation: | | 0/? [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {
+ "application/vnd.jupyter.widget-view+json": {
+ "colab": {
+ "custom_widget_manager": {
+ "url": "https://ssl.gstatic.com/colaboratory-static/widgets/colab-cdn-widget-manager/2b70e893a8ba7c0f/manager.min.js"
+ }
+ }
+ }
+ },
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "83a1e407cf60401b932c6ee6e72d90db",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Validation: | | 0/? [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {
+ "application/vnd.jupyter.widget-view+json": {
+ "colab": {
+ "custom_widget_manager": {
+ "url": "https://ssl.gstatic.com/colaboratory-static/widgets/colab-cdn-widget-manager/2b70e893a8ba7c0f/manager.min.js"
+ }
+ }
+ }
+ },
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "aa94917a2f364a4ea9033000730d48b7",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Validation: | | 0/? [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {
+ "application/vnd.jupyter.widget-view+json": {
+ "colab": {
+ "custom_widget_manager": {
+ "url": "https://ssl.gstatic.com/colaboratory-static/widgets/colab-cdn-widget-manager/2b70e893a8ba7c0f/manager.min.js"
+ }
+ }
+ }
+ },
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO: `Trainer.fit` stopped: `max_epochs=5` reached.\n",
+ "INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=5` reached.\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Train the model\n",
+ "print(\"\\nTraining model...\")\n",
+ "trainer = Trainer(\n",
+ " max_epochs=5,\n",
+ " accelerator=\"auto\",\n",
+ " devices=1,\n",
+ " enable_progress_bar=True,\n",
+ " log_every_n_steps=10,\n",
+ ")\n",
+ "\n",
+ "trainer.fit(model, data_module)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "p3mI-QVJ8TZF"
+ },
+ "source": [
+ "After training, we can make predictions using the trained model:\n",
+ "\n",
+ "```python\n",
+ "model.eval()\n",
+ "with torch.no_grad():\n",
+ " batch = next(iter(data_module.test_dataloader()))\n",
+ " x, y = batch\n",
+ " y_pred = model(x)\n",
+ "```\n",
+ "\n",
+ "#### Output\n",
+ "The TFT model returns a `dict` with the key `prediction`:\n",
+ "\n",
+ "- `y_pred[\"prediction\"]`: Tensor of shape `(batch_size, prediction_length, output_size)`\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 278,
+ "referenced_widgets": [
+ "a081e60ec1604f6abf87dc3f108d66b6",
+ "565d40dbd4b0467396949cc2adc030d5",
+ "f41ef269f3754fffa1296278be09be3b",
+ "2f336cafec17479f80f897042286233a",
+ "ccd55c64fc884543a5831f23cf213642",
+ "e1ade849ec7341e9a026a7d2dcac24be",
+ "105949ce9b8844a78481da2f7b7406b2",
+ "ebc84674fde647aeb1bf579a3c890f15",
+ "f82e1c40a89c4d14b28d91ed7d8a9e28",
+ "a8fc76e0722a4dea8f00f5cd412399c6",
+ "a84df45d2ed644f1baf72bd22f924224"
+ ]
+ },
+ "id": "Si7bbZIULBZz",
+ "outputId": "ff3fb499-14e2-48e5-e4a1-ec8e18650c36"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "Evaluating model...\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "a081e60ec1604f6abf87dc3f108d66b6",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Testing: | | 0/? [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {
+ "application/vnd.jupyter.widget-view+json": {
+ "colab": {
+ "custom_widget_manager": {
+ "url": "https://ssl.gstatic.com/colaboratory-static/widgets/colab-cdn-widget-manager/2b70e893a8ba7c0f/manager.min.js"
+ }
+ }
+ }
+ },
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+ "┃ Test metric ┃ DataLoader 0 ┃\n",
+ "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+ "│ test_MAE │ 0.48676350712776184 │\n",
+ "│ test_SMAPE │ 1.031250238418579 │\n",
+ "│ test_loss │ 0.012420947663486004 │\n",
+ "└───────────────────────────┴───────────────────────────┘\n"
+ ],
+ "text/plain": [
+ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+ "┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n",
+ "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+ "│\u001b[36m \u001b[0m\u001b[36m test_MAE \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.48676350712776184 \u001b[0m\u001b[35m \u001b[0m│\n",
+ "│\u001b[36m \u001b[0m\u001b[36m test_SMAPE \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 1.031250238418579 \u001b[0m\u001b[35m \u001b[0m│\n",
+ "│\u001b[36m \u001b[0m\u001b[36m test_loss \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.012420947663486004 \u001b[0m\u001b[35m \u001b[0m│\n",
+ "└───────────────────────────┴───────────────────────────┘\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "Prediction shape: torch.Size([32, 1, 1])\n",
+ "First prediction values: [[-0.11366826]]\n",
+ "First true values: [[-0.12978955]]\n",
+ "\n",
+ "TFT model test complete!\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Evaluate the model\n",
+ "print(\"\\nEvaluating model...\")\n",
+ "test_metrics = trainer.test(model, data_module)\n",
+ "\n",
+ "model.eval()\n",
+ "with torch.no_grad():\n",
+ " test_batch = next(iter(data_module.test_dataloader()))\n",
+ " x_test, y_test = test_batch\n",
+ " y_pred = model(x_test)\n",
+ "\n",
+ " print(\"\\nPrediction shape:\", y_pred[\"prediction\"].shape)\n",
+ " print(\"First prediction values:\", y_pred[\"prediction\"][0].cpu().numpy())\n",
+ " print(\"First true values:\", y_test[0].cpu().numpy())\n",
+ "print(\"\\nTFT model test complete!\")"
+ ]
+ }
+ ],
+ "metadata": {
+ "colab": {
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
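The notebook above configures the data module with `max_encoder_length=30` and `max_prediction_length=1` on series of 49 rows each. As a rough illustration of what encoder/decoder windows over one such series look like, here is a minimal sketch that assumes a plain sliding-window scheme with stride 1. The helper `sliding_windows` is hypothetical and not part of `pytorch-forecasting`; the actual `EncoderDecoderTimeSeriesDataModule` may build its windows differently.

```python
def sliding_windows(series_length, max_encoder_length, max_prediction_length):
    """Return (encoder_range, decoder_range) index pairs for one series.

    Hypothetical helper: assumes stride-1 windows that always use the full
    encoder and prediction lengths.
    """
    window = max_encoder_length + max_prediction_length
    pairs = []
    # A series shorter than one full window produces no pairs at all.
    for start in range(series_length - window + 1):
        encoder = range(start, start + max_encoder_length)
        decoder = range(start + max_encoder_length, start + window)
        pairs.append((encoder, decoder))
    return pairs


# Values taken from the notebook: 49 rows per series, encoder 30, prediction 1.
pairs = sliding_windows(series_length=49, max_encoder_length=30, max_prediction_length=1)
print(len(pairs))          # 19
print(list(pairs[0][1]))   # [30]
print(list(pairs[-1][1]))  # [48]
```

Under this assumption, each series contributes 19 windows, and the single decoder step moves from time index 30 to 48.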
diff --git a/pytorch_forecasting/data/examples.py b/pytorch_forecasting/data/examples.py
index 1adab65d6..dc24b349f 100644
--- a/pytorch_forecasting/data/examples.py
+++ b/pytorch_forecasting/data/examples.py
@@ -109,3 +109,27 @@ def generate_ar_data(
)
return data
+
+
+def load_toydata(num_series, seq_length):
+    data_list = []  # one dict per (series, time step) row
+    for i in range(num_series):
+        x = np.arange(seq_length)
+        y = np.sin(x / 5.0) + np.random.normal(scale=0.1, size=seq_length)  # noisy sine wave
+        category = i % 5
+        static_value = np.random.rand()  # constant within a series
+        for t in range(seq_length - 1):  # the target is the next value, so the last step has no row
+            data_list.append(
+                {
+                    "series_id": i,
+                    "time_idx": t,
+                    "x": y[t],
+                    "y": y[t + 1],
+                    "category": category,
+                    "future_known_feature": np.cos(t / 10),
+                    "static_feature": static_value,
+                    "static_feature_cat": i % 3,
+                }
+            )
+    data_df = pd.DataFrame(data_list)
+    return data_df
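To make the row layout produced by `load_toydata` easier to see, here is a dependency-free sketch that mirrors its logic using only the standard library. The function name `toy_rows` and the `seed` parameter are hypothetical additions for illustration; the real function uses NumPy and returns a `pandas` DataFrame.

```python
import math
import random


def toy_rows(num_series, seq_length, seed=0):
    """Stand-in for load_toydata that returns a list of dicts (hypothetical)."""
    rng = random.Random(seed)
    rows = []
    for i in range(num_series):
        # Noisy sine wave with one value per time step.
        y = [math.sin(t / 5.0) + rng.gauss(0.0, 0.1) for t in range(seq_length)]
        static_value = rng.random()  # constant within a series
        # Each row pairs a value with its successor, so the last step has no row.
        for t in range(seq_length - 1):
            rows.append(
                {
                    "series_id": i,
                    "time_idx": t,
                    "x": y[t],
                    "y": y[t + 1],
                    "category": i % 5,
                    "future_known_feature": math.cos(t / 10),
                    "static_feature": static_value,
                    "static_feature_cat": i % 3,
                }
            )
    return rows


rows = toy_rows(num_series=100, seq_length=50)
print(len(rows))                          # 4900
print(max(r["time_idx"] for r in rows))   # 48
print(rows[0]["y"] == rows[1]["x"])       # True
```

The target `y` at step `t` is the feature `x` at step `t + 1`, so `seq_length=50` gives 49 rows per series and 4900 rows in total. This matches the `time_idx` maximum of 48 in the notebook output.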