diff --git a/docs/source/tutorials/ptf_V2_example.ipynb b/docs/source/tutorials/ptf_V2_example.ipynb new file mode 100644 index 000000000..2036a4784 --- /dev/null +++ b/docs/source/tutorials/ptf_V2_example.ipynb @@ -0,0 +1,1159 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "rzVbXsEBxnF-" + }, + "source": [ + "# `pytorch-forecasting v2` Model Training and Inference - Beta API" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "yt0uZV7Px-40" + }, + "source": [ + "
\n", + ":warning: The vignette showcased here is part of an experimental rework of the `pytorch-forecasting` data layer, planned for release in v2.0.0. The API is currently unstable and subject to change without prior notice.\n", + "\n", + "Feedback and suggestions are highly encouraged — please share them in issue 1736.\n", + "
\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "6D9ARyp05R0t" + }, + "source": [ + "In this vignette, we demonstrate how to train and evaluate the **Temporal Fusion Transformer (TFT)** using the new `TimeSeries` and `DataModule` API from the v2 pipeline.\n", + "\n", + "\n", + "## Steps\n", + "\n", + "1. **Load Data** \n", + "2. **Create Dataset & DataModule** \n", + "3. **Initialize, Train & Run Inference with the Model**\n", + "\n", + "\n", + "\n", + "### Load Data\n", + "\n", + "We generate a synthetic dataset using `load_toydata` which returns a `pandas` DataFrame with purely numerical values. \n", + "*(Note: The current pipeline assumes all inputs are numerical only.)*\n", + "\n", + "\n", + "\n", + "\n", + "### Create Dataset & DataModule\n", + "\n", + "- `TimeSeries` returns the raw data in terms of tensors .\n", + "- `DataModule` wraps the dataset, handles splits, preprocessing, batching, and exposes `metadata` for the model initialisation.\n", + "\n", + "\n", + "\n", + "### Initialize the Model\n", + "\n", + "We initialize the TFT model using the `metadata` provided by the `DataModule`. This metadata includes all required dimensional info for the encoder, decoder, and static inputs.\n", + "\n", + "\n", + "\n", + "### Train the Model\n", + "\n", + "We use a `Trainer` from PyTorch Lightning to train the model\n", + "\n", + "### Run Inference\n", + "\n", + "After training, we can make predictions using the trained model\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "QyMFNk4MyY_b" + }, + "source": [ + "# 1. Load Data\n", + "We generate a synthetic dataset using `load_toydata` that creates a `pandas` DataFrame with just numerical values as for now **the pipeline assumes the data to be numerical only**." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "id": "RkgOT4kiy_RU" + }, + "outputs": [], + "source": [ + "from pytorch_forecasting.data.examples import load_toydata" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 206 + }, + "id": "WX-FRdusJSVN", + "outputId": "e481484c-b0c3-4026-c933-a9dc047617c5" + }, + "outputs": [ + { + "data": { + "application/vnd.google.colaboratory.intrinsic+json": { + "summary": "{\n \"name\": \"data_df\",\n \"rows\": 4900,\n \"fields\": [\n {\n \"column\": \"series_id\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 28,\n \"min\": 0,\n \"max\": 99,\n \"num_unique_values\": 100,\n \"samples\": [\n 83,\n 53,\n 70\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"time_idx\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 14,\n \"min\": 0,\n \"max\": 48,\n \"num_unique_values\": 49,\n \"samples\": [\n 13,\n 45,\n 47\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"x\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.6718152112102416,\n \"min\": -1.26886060400029,\n \"max\": 1.3107688985416461,\n \"num_unique_values\": 4900,\n \"samples\": [\n 0.1974946533075752,\n 0.8954011563960967,\n -0.802070871866599\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"y\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.6763475728129391,\n \"min\": -1.26886060400029,\n \"max\": 1.3107688985416461,\n \"num_unique_values\": 4900,\n \"samples\": [\n 0.4987517118113735,\n 0.6086435548017073,\n -1.0256970040347706\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"category\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 1,\n \"min\": 0,\n \"max\": 4,\n \"num_unique_values\": 5,\n \"samples\": [\n 1,\n 4,\n 2\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"future_known_feature\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.6741140972121411,\n \"min\": -0.9991351502732795,\n \"max\": 1.0,\n \"num_unique_values\": 49,\n \"samples\": [\n 0.26749882862458735,\n -0.2107957994307797,\n -0.01238866346289056\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"static_feature\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.2888495447542638,\n \"min\": 0.0028672051720661784,\n \"max\": 0.990604862265208,\n \"num_unique_values\": 100,\n \"samples\": [\n 0.5971646697158197,\n 0.12749395651151985,\n 0.32838971618312873\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"static_feature_cat\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0,\n \"min\": 0,\n \"max\": 2,\n \"num_unique_values\": 3,\n \"samples\": [\n 0,\n 1,\n 2\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}", + "type": "dataframe", + "variable_name": "data_df" + }, + "text/html": [ + "\n", + "
\n", + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
series_idtime_idxxycategoryfuture_known_featurestatic_featurestatic_feature_cat
000-0.1379880.01682401.0000000.9275570
1010.0168240.28729100.9950040.9275570
2020.2872910.59958700.9800670.9275570
3030.5995870.77935200.9553360.9275570
4040.7793520.87614800.9210610.9275570
\n", + "
\n", + "
\n", + "\n", + "
\n", + " \n", + "\n", + " \n", + "\n", + " \n", + "
\n", + "\n", + "\n", + "
\n", + " \n", + "\n", + "\n", + "\n", + " \n", + "
\n", + "\n", + "
\n", + "
\n" + ], + "text/plain": [ + " series_id time_idx x y category future_known_feature \\\n", + "0 0 0 -0.137988 0.016824 0 1.000000 \n", + "1 0 1 0.016824 0.287291 0 0.995004 \n", + "2 0 2 0.287291 0.599587 0 0.980067 \n", + "3 0 3 0.599587 0.779352 0 0.955336 \n", + "4 0 4 0.779352 0.876148 0 0.921061 \n", + "\n", + " static_feature static_feature_cat \n", + "0 0.927557 0 \n", + "1 0.927557 0 \n", + "2 0.927557 0 \n", + "3 0.927557 0 \n", + "4 0.927557 0 " + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "num_series = 100 # Number of individual time series to generate\n", + "seq_length = 50 # Length of each time series\n", + "data_df = load_toydata(num_series, seq_length)\n", + "data_df.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "RYQ5CdNUyc2q" + }, + "source": [ + "# 2. Create the dataset and datamodule\n", + "We create a `TimeSeries` dataset instance that returns the raw data in terms of tensors, then this \"raw data\" is sent to the `data_module`that will internally handle the dataloaders and preprocessing" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ONe8Eo1zzvCH" + }, + "source": [ + "`TimeSeries` dataset's Key arguments:\n", + "- `data`: DataFrame with sequence data.\n", + "- `time`: integer typed column denoting the time index within `data`.\n", + "- `target`: Column(s) in `data` denoting the forecasting target.\n", + "- `group`: List of column names identifying a time series instance within `data`.\n", + "- `num`: List of numerical features.\n", + "- `cat`: List of categorical features.\n", + "- `known`: Features known in future\n", + "- `unknown`: Features not known in the future\n", + "- `static`: List of variables that do not change over time," + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "id": "JPD3y3qny5Dx" + }, + "outputs": [], + "source": [ + "from pytorch_forecasting.data.timeseries import TimeSeries" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "id": "AxxPHK6AKSD2" + }, + "outputs": [], + "source": [ + "# create `TimeSeries` dataset that returns the raw data in terms of tensors\n", + "dataset = TimeSeries(\n", + " data=data_df,\n", + " time=\"time_idx\",\n", + " target=\"y\",\n", + " group=[\"series_id\"],\n", + " num=[\"x\", \"future_known_feature\", \"static_feature\"],\n", + " cat=[\"category\", \"static_feature_cat\"],\n", + " known=[\"future_known_feature\"],\n", + " unknown=[\"x\", \"category\"],\n", + " static=[\"static_feature\", \"static_feature_cat\"],\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "S-yHU46v1MhN" + }, + "source": [ + "`EncoderDecoderTimeSeriesDataModule` key arguments:\n", + "- `time_series_dataset`: `TimeSeries` dataset instance\n", + "- `max_encoder_length` : Maximum length of the encoder input sequence.\n", + "- `max_prediction_length` : Maximum length of the decoder output sequence.\n", + "- `batch_size` : Batch size for DataLoader.\n", + "- `categorical_encoders` : Dictionary of categorical encoders.\n", + "- `scalers` : Dictionary of feature scalers.\n", + "- `target_normalizer`: Normalizer for the target variable." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "DUWB4LrGyxrL" + }, + "outputs": [], + "source": [ + "from sklearn.preprocessing import StandardScaler\n", + "from pytorch_forecasting.data.data_module import EncoderDecoderTimeSeriesDataModule\n", + "from pytorch_forecasting.data.encoders import (\n", + " NaNLabelEncoder,\n", + " TorchNormalizer,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "id": "5U5Lr_ZFKX0s" + }, + "outputs": [], + "source": [ + "# create the `data_module` that handles the dataloaders and preprocessing\n", + "data_module = EncoderDecoderTimeSeriesDataModule(\n", + " time_series_dataset=dataset,\n", + " max_encoder_length=30,\n", + " max_prediction_length=1,\n", + " batch_size=32,\n", + " categorical_encoders={\n", + " \"category\": NaNLabelEncoder(add_nan=True),\n", + " \"static_feature_cat\": NaNLabelEncoder(add_nan=True),\n", + " },\n", + " scalers={\n", + " \"x\": StandardScaler(),\n", + " \"future_known_feature\": StandardScaler(),\n", + " \"static_feature\": StandardScaler(),\n", + " },\n", + " target_normalizer=TorchNormalizer(),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "qykX7vQ7zWnC" + }, + "source": [ + "# 3. Initialise and train the model" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "_kz3MO362Tlo" + }, + "source": [ + "To initialise the model you now don't have to pass arguments like `encoder_cont`, `decoder_cont` etc as they are calculated internally using the `metadata` property [[source]](https://github.com/sktime/pytorch-forecasting/blob/4a34931e499c2b59de3939fcffcaabd75204b045/pytorch_forecasting/data/data_module.py#L264-L292) of `EncoderDecoderTimeSeriesDataModule`. But you still have to pass other params like `loss`, `optimizer` etc" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "XvwIuzD34Ytk" + }, + "source": [ + "\n", + "```python\n", + "model = TFT(\n", + " loss=nn.MSELoss(),\n", + " logging_metrics=[MAE(), SMAPE()],\n", + " metadata=data_module.metadata, # <-- crucial for model setup\n", + " ...\n", + ")\n", + "```\n", + "\n", + "The `metadata` includes:\n", + "- `max_encoder_length`, `max_prediction_length`\n", + "- number of continuous/categorical variables in encoder/decoder\n", + "- number of static features\n", + "\n", + "These are used to configure internal layers like `encoder_cont`, `decoder_cat`, etc.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "id": "xOsEucZnzCkN" + }, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "from pytorch_forecasting.metrics import MAE, SMAPE\n", + "from pytorch_forecasting.models.temporal_fusion_transformer._tft_v2 import TFT" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "9qbjnTxnyh4H", + "outputId": "f59bf985-ffaa-4980-c890-39a80dfcc598" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/content/pytorch-forecasting/pytorch_forecasting/models/base/_base_model_v2.py:58: UserWarning: The Model 'TFT' is part of an experimental reworkof the pytorch-forecasting model layer, scheduled for release with v2.0.0. The API is not stable and may change without prior warning. This class is intended for beta testing and as a basic skeleton, but not for stable production use. Feedback and suggestions are very welcome in pytorch-forecasting issue 1736, https://github.com/sktime/pytorch-forecasting/issues/1736\n", + " warn(\n" + ] + } + ], + "source": [ + "# Initialise the Model\n", + "model = TFT(\n", + " loss=nn.MSELoss(),\n", + " logging_metrics=[MAE(), SMAPE()],\n", + " optimizer=\"adam\",\n", + " optimizer_params={\"lr\": 1e-3},\n", + " lr_scheduler=\"reduce_lr_on_plateau\",\n", + " lr_scheduler_params={\"mode\": \"min\", \"factor\": 0.1, \"patience\": 10},\n", + " hidden_size=64,\n", + " num_layers=2,\n", + " attention_head_size=4,\n", + " dropout=0.1,\n", + " metadata=data_module.metadata, # pass the metadata from the datamodule to the model\n", + " # to initialise important params like `encoder_cont` etc\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "svdoye-d8F-z" + }, + "source": [ + "We use a `Trainer` from PyTorch Lightning to train the model:\n", + "\n", + "```python\n", + "trainer = Trainer(max_epochs=5, ...)\n", + "trainer.fit(model, data_module)\n", + "```\n", + "\n", + "The `Trainer`:\n", + "- Pulls data from `data_module`\n", + "- Handles device placement\n", + "- Logs training progress and metrics\n" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "id": "RTSmUu9RytS8" + }, + "outputs": [], + "source": [ + "from lightning.pytorch import Trainer" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 930, + "referenced_widgets": [ + "edc196b34a0b49fa992fd89909a8414f", + "76e801c32da348788b0b92fd2c2849c1", + "cd6d925c2dcf4ae995f62d8f20219fde", + "283f67f9ef6a47f8b3d4a389a18b8ccc", + "8e19d683c674403b87cbb7c162a0ab33", + "525fe2cef0444d558d870c394cf67e81", + "640318d6256a4451957fce6481625f6b", + "0820bd0fd60f4e0180899f18615d9ee3", + "90cb2b96e18948679eda809a9123565c", + "5ffb5772ae4f4764b23643cfa547a615", + "04f140582780402a94e3391bfcbffa91", + "4abfcb0ade1b47e1b015e7efd75cd563", + "483c0919c7c740af89e6c7a7470591f7", + "18849509bcfc41dd89011efb5690e49b", + "098eb27809314e7ca15c5b9bce21b46d", + "53f7fc6124fb4dda8bfb147078212ada", + "8c2821dc533c4a1ba81d68836047dc54", + "cb303a9d13084d47b0714d717f482a80", + "491aa926179b41dba21a008dc6e9a4fc", + "363967a782db4627a3e16a726b4993c6", + "0d28163142aa465784871ec647efa504", + "a6d821b9d9c9453b87740c3e4d86ca77", + "2301684175b0454897af0fc5e534f52f", + "58fe6653d5e944be8fdd072cd3541b8c", + "00fb395e4cb5472faf221affc0244780", + "1cd0c86474a443b8aae6466405dd3807", + "cfcb344ba8cb49a78a8f934657252855", + "4f455332437a465bb300ccf50405358d", + "fde9c4616cd84e629998819c18adfd1f", + "405e6f4ab49b445ba7e942e0d35419eb", + "2ee2fe3ea3c04592851fef7e6dda5115", + "713bcea77e1744179502f6a187239ce1", + "531c73851dcd469791d6ebdc9aa9da52", + "b0ed53cd20bd47b5aac5fecaacb9bfed", + "da87d95793f84907819ab4c506f6936b", + "c8a64619bf574d43a01595e3c8690759", + "8d0c66b5cc984b8db7d151a970897266", + "a723c1055f4f4c0d9c7b2c8c3469fa9d", + "99dcc8dd7e88463a95ddcf86c67ec153", + "94c1a818e78a47ee9bb46f1c8b1dad34", + "903b0891a3014701b117a391d777e012", + "f6053af158304fc1b307a296cdbe7290", + "f6c93805c03545458f1af0fffaf77fb0", + "52bbe8a40baf4a8f989281bcc225a067", + "c9eb21c70ac243d1a21f50f287b4731f", + "100a61bf42b0498ba241ddd9ef7a96f3", + "1cce85eee5bc4700aa966c05e35ea042", + "2dcaeafc8bc84db786f9fa99aeda39db", + "7a9cf1446cc14d51a0a729822eb1cabc", + "1c6f0e1be6f8475485ba6354ccc68597", + "7ef67c160b8740e9932486f857470542", + "282e2d3a169a4e18b46ca85dbfe82371", + "84a96c68feba48b6880de08722de8743", + "c111af2e297641908cb1835d9a4d06ff", + "d94d97f0852b450c8ea56e6ef4e44ac5", + "1dd910df34e44cdc98586e3094dfddd4", + "86c6678f86d944eeba1fe982197cb7d3", + "b609e7acf93944a98f11b450e01222f8", + "0baf540b09454f61873cd6a3088a3c9d", + "7bc8c547554c481481d0653193dcd917", + "2d92b66364844b5491dabee7a0b90686", + "99a8fec0a22f4949a3d16fc7f48aeea6", + "0fb99264aa0644cda962905dfe1d6997", + "f3d688e29f4b4432ac370d6a1e76ec5c", + "d4cd0f3144784991ba9ec0b178a6c2e7", + "83a1e407cf60401b932c6ee6e72d90db", + "99443f56f20a4c6da81ddb107ab84490", + "2da428c9a3b541bc8f667d5715fdcb0d", + "f483a26167424f66b8b3c537537afd77", + "bc01c61ee2a14f3bbea23d04bb7e9dc9", + "b59e64ed8d9d46dd8803bfe25602c391", + "cc1d7b4924784a08999f41e1e3c989ac", + "8c6ad939d53942e48d8d6fb4c30b666d", + "49d44a5db92b4b3e979e88d14072a16c", + "36d0bfd0973c43e4958af452d2c7025c", + "8a90ddfc1fd446a6894c2771195a4419", + "aa94917a2f364a4ea9033000730d48b7" + ] + }, + "id": "aB_ayE_eykXp", + "outputId": "02c49d3e-2124-4b0b-8ca4-b2d886662613" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO: Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.\n", + "INFO:lightning.pytorch.utilities.rank_zero:Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.\n", + "INFO: GPU available: False, used: False\n", + "INFO:lightning.pytorch.utilities.rank_zero:GPU available: False, used: False\n", + "INFO: TPU available: False, using: 0 TPU cores\n", + "INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores\n", + "INFO: HPU available: False, using: 0 HPUs\n", + "INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs\n", + "INFO: \n", + " | Name | Type | Params | Mode\n", + "--------------------------------------------------------------------\n", + "0 | loss | MSELoss | 0 | eval\n", + "1 | encoder_var_selection | Sequential | 709 | eval\n", + "2 | decoder_var_selection | Sequential | 193 | eval\n", + "3 | static_context_linear | Linear | 192 | eval\n", + "4 | lstm_encoder | LSTM | 51.5 K | eval\n", + "5 | lstm_decoder | LSTM | 50.4 K | eval\n", + "6 | self_attention | MultiheadAttention | 16.6 K | eval\n", + "7 | pre_output | Linear | 4.2 K | eval\n", + "8 | output_layer | Linear | 65 | eval\n", + "--------------------------------------------------------------------\n", + "123 K Trainable params\n", + "0 Non-trainable params\n", + "123 K Total params\n", + "0.495 Total estimated model params size (MB)\n", + "0 Modules in train mode\n", + "18 Modules in eval mode\n", + "INFO:lightning.pytorch.callbacks.model_summary:\n", + " | Name | Type | Params | Mode\n", + "--------------------------------------------------------------------\n", + "0 | loss | MSELoss | 0 | eval\n", + "1 | encoder_var_selection | Sequential | 709 | eval\n", + "2 | decoder_var_selection | Sequential | 193 | eval\n", + "3 | static_context_linear | Linear | 192 | eval\n", + "4 | lstm_encoder | LSTM | 51.5 K | eval\n", + "5 | lstm_decoder | LSTM | 50.4 K | eval\n", + "6 | self_attention | MultiheadAttention | 16.6 K | eval\n", + "7 | pre_output | Linear | 4.2 K | eval\n", + "8 | output_layer | Linear | 65 | eval\n", + "--------------------------------------------------------------------\n", + "123 K Trainable params\n", + "0 Non-trainable params\n", + "123 K Total params\n", + "0.495 Total estimated model params size (MB)\n", + "0 Modules in train mode\n", + "18 Modules in eval mode\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Training model...\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "edc196b34a0b49fa992fd89909a8414f", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Sanity Checking: | | 0/? [00:00┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃ Test metric DataLoader 0 ┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ test_MAE 0.48676350712776184 │\n", + "│ test_SMAPE 1.031250238418579 │\n", + "│ test_loss 0.012420947663486004 │\n", + "└───────────────────────────┴───────────────────────────┘\n", + "\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│\u001b[36m \u001b[0m\u001b[36m test_MAE \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.48676350712776184 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_SMAPE \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 1.031250238418579 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_loss \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.012420947663486004 \u001b[0m\u001b[35m \u001b[0m│\n", + "└───────────────────────────┴───────────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Prediction shape: torch.Size([32, 1, 1])\n", + "First prediction values: [[-0.11366826]]\n", + "First true values: [[-0.12978955]]\n", + "\n", + "TFT model test complete!\n" + ] + } + ], + "source": [ + "# Evaluate the model\n", + "print(\"\\nEvaluating model...\")\n", + "test_metrics = trainer.test(model, data_module)\n", + "\n", + "model.eval()\n", + "with torch.no_grad():\n", + " test_batch = next(iter(data_module.test_dataloader()))\n", + " x_test, y_test = test_batch\n", + " y_pred = model(x_test)\n", + "\n", + " print(\"\\nPrediction shape:\", y_pred[\"prediction\"].shape)\n", + " print(\"First prediction values:\", y_pred[\"prediction\"][0].cpu().numpy())\n", + " print(\"First true values:\", y_test[0].cpu().numpy())\n", + "print(\"\\nTFT model test complete!\")" + ] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/pytorch_forecasting/data/examples.py b/pytorch_forecasting/data/examples.py index 1adab65d6..dc24b349f 100644 --- a/pytorch_forecasting/data/examples.py +++ b/pytorch_forecasting/data/examples.py @@ -109,3 +109,27 @@ def generate_ar_data( ) return data + + +def load_toydata(num_series, seq_length): + data_list = [] + for i in range(num_series): + x = np.arange(seq_length) + y = np.sin(x / 5.0) + np.random.normal(scale=0.1, size=seq_length) + category = i % 5 + static_value = np.random.rand() + for t in range(seq_length - 1): + data_list.append( + { + "series_id": i, + "time_idx": t, + "x": y[t], + "y": y[t + 1], + "category": category, + "future_known_feature": np.cos(t / 10), + "static_feature": static_value, + "static_feature_cat": i % 3, + } + ) + data_df = pd.DataFrame(data_list) + return data_df