Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "2ee6d94d",
"metadata": {},
"source": [
"# Retrieving regressor coefficients"
]
},
{
"cell_type": "markdown",
"id": "61d31237-c428-483a-bac1-419dddad3000",
"metadata": {},
"source": [
"Understanding the coefficients of various components in a forecasting model is crucial as it provides insights into how different factors influence the predicted values. We will demonstrate how to retrieve these coefficients using specific functions provided in NeuralProphet.\n",
"\n",
"The following functions are available:\n",
"- get_future_regressor_coefficients: Retrieves the coefficients for future regressors.\n",
"- get_event_coefficients: Retrieves the coefficients for events and holidays.\n",
"- get_lagged_regressor_coefficients: Retrieves the coefficients for lagged regressors.\n",
"- get_ar_coefficients: Retrieves the coefficients for autoregressive lags.\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "6575cb59",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"from neuralprophet import NeuralProphet\n",
"\n",
"# Load tutorial datasets \n",
"df = pd.read_csv(\"https://github.com/ourownstory/neuralprophet-data/raw/main/kaggle-energy/datasets/tutorial04.csv\")\n",
"\n",
"df1 = pd.read_csv(\"https://github.com/ourownstory/neuralprophet-data/raw/main/kaggle-energy/datasets/tutorial01.csv\")\n"
]
},
{
"cell_type": "markdown",
"id": "0d2ae750",
"metadata": {},
"source": [
"## Future regressors\n",
"\n",
"Useful for understanding the impact of external variables that are known in advance, such as temperature in this example. "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "95511f2b",
"metadata": {},
"outputs": [],
"source": [
"m = NeuralProphet(epochs=10)\n",
"\n",
"# Add the new future regressor\n",
"m.add_future_regressor(\"temperature\")\n",
"\n",
"\n",
"# Continue training the model and making a prediction\n",
"metrics = m.fit(df)\n",
"\n",
"print(\"Future regressor coefficients:\", m.model.get_future_regressor_coefficients())"
]
},
{
"cell_type": "markdown",
"id": "455b60e1",
"metadata": {},
"source": [
"## Events\n",
"\n",
"Helps in assessing the effect of specific events or holidays on the forecasted values."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ffd52d2b",
"metadata": {},
"outputs": [],
"source": [
"m = NeuralProphet(epochs=10)\n",
"\n",
"# Add holidays for the US as events \n",
"m.add_country_holidays(\"US\")\n",
"\n",
"metrics = m.fit(df1)\n",
"\n",
"print(\"Event coefficients:\", m.model.get_event_coefficients())"
]
},
{
"cell_type": "markdown",
"id": "757056b4",
"metadata": {},
"source": [
"## Lagged regressors\n",
"\n",
"Lagged regressor coefficients are useful for understanding the influence of past values of external variables on the forecast."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c61347cb-bea9-4732-a7f6-4c05aa496354",
"metadata": {},
"outputs": [],
"source": [
"m = NeuralProphet(epochs=10)\n",
"\n",
"# Add temperature of last three days as lagged regressor\n",
"m.add_lagged_regressor(\"temperature\", n_lags=3)\n",
"\n",
"metrics = m.fit(df)\n",
"print(m.model.get_lagged_regressor_coefficients())"
]
},
{
"cell_type": "markdown",
"id": "a9440659",
"metadata": {},
"source": [
"## Autoregressive\n",
"\n",
"Useful for understanding how past values of the time series itself influence future predictions."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "feff9910",
"metadata": {},
"outputs": [],
"source": [
"m = NeuralProphet(n_lags=5, epochs=10)\n",
"\n",
"metrics = m.fit(df1)\n",
"\n",
"print(\"AR coefficients:\", m.model.get_ar_coefficients())"
]
},
{
"cell_type": "markdown",
"id": "bc77b042",
"metadata": {},
"source": [
"## Visualizing coefficients\n",
"\n",
"With the Neuralprophet plotting features it is easy to automatically create plots for model parameters that visulize the previously discussed coefficients."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8f90dd1b",
"metadata": {},
"outputs": [],
"source": [
"m = NeuralProphet(\n",
" n_lags=10, # Autogression\n",
" epochs=10\n",
")\n",
"\n",
"# Add the new future regressor\n",
"m.add_future_regressor(\"temperature\")\n",
"\n",
"# Add holidays for the US as events\n",
"m.add_country_holidays(\"US\")\n",
"\n",
"metrics = m.fit(df)\n",
"\n",
"print(m.model.get_future_regressor_coefficients())\n",
"print(m.model.get_event_coefficients())\n",
"print(m.model.get_ar_coefficients())\n",
"\n",
"m.plot_parameters()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.4"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
72 changes: 34 additions & 38 deletions neuralprophet/plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,8 +324,6 @@ def get_valid_configuration( # move to utils
# Identify components to be plotted
# as dict, minimum: {plot_name}
plot_components = []
if validator == "plot_parameters":
quantile_index = m.model.quantiles.index(quantile)

# Plot trend
if "trend" in components:
Expand Down Expand Up @@ -418,38 +416,32 @@ def get_valid_configuration( # move to utils
multiplicative_events = []
if "events" in components:
additive_events_flag = False
muliplicative_events_flag = False
multiplicative_events_flag = False
event_configs = {}
if m.config_events is not None:
for event, configs in m.config_events.items():
if validator == "plot_components" and configs.mode == "additive":
additive_events_flag = True
elif validator == "plot_components" and configs.mode == "multiplicative":
muliplicative_events_flag = True
elif validator == "plot_parameters":
event_params = m.model.get_event_weights(event)
weight_list = [
(key, param.detach().numpy()[quantile_index, :]) for key, param in event_params.items()
]
if configs.mode == "additive":
additive_events = additive_events + weight_list
elif configs.mode == "multiplicative":
multiplicative_events = multiplicative_events + weight_list

event_configs.update(m.config_events)
if m.config_country_holidays is not None:
for country_holiday in m.config_country_holidays.holiday_names:
if validator == "plot_components" and m.config_country_holidays.mode == "additive":
additive_events_flag = True
elif validator == "plot_components" and m.config_country_holidays.mode == "multiplicative":
muliplicative_events_flag = True
elif validator == "plot_parameters":
event_params = m.model.get_event_weights(country_holiday)
weight_list = [
(key, param.detach().numpy()[quantile_index, :]) for key, param in event_params.items()
]
if m.config_country_holidays.mode == "additive":
additive_events = additive_events + weight_list
elif m.config_country_holidays.mode == "multiplicative":
multiplicative_events = multiplicative_events + weight_list
event_configs.update(
{holiday: m.config_country_holidays for holiday in m.config_country_holidays.holiday_names}
)

if event_configs:
if validator == "plot_components":
additive_events_flag = any(config.mode == "additive" for config in event_configs.values())
multiplicative_events_flag = any(config.mode == "multiplicative" for config in event_configs.values())

elif validator == "plot_parameters":
event_coefficients = m.model.get_event_coefficients()
for _, row in event_coefficients.iterrows():
event = row["regressor"]
mode = row["regressor_mode"]
coef = row["coef"]
weight_tuple = (event, coef)

if mode == "additive":
additive_events.append(weight_tuple)
elif mode == "multiplicative":
multiplicative_events.append(weight_tuple)

if additive_events_flag:
plot_components.append(
Expand All @@ -458,7 +450,7 @@ def get_valid_configuration( # move to utils
"comp_name": "events_additive",
}
)
if muliplicative_events_flag:
if multiplicative_events_flag:
plot_components.append(
{
"plot_name": "Multiplicative Events",
Expand Down Expand Up @@ -488,11 +480,15 @@ def get_valid_configuration( # move to utils
}
)
elif validator == "plot_parameters":
regressor_param = m.model.future_regressors.get_reg_weights(regressor)[quantile_index, :]
if configs.mode == "additive":
additive_future_regressors.append((regressor, regressor_param.detach().numpy()))
elif configs.mode == "multiplicative":
multiplicative_future_regressors.append((regressor, regressor_param.detach().numpy()))
future_regressor_coefficients = m.model.get_future_regressor_coefficients()
for _, row in future_regressor_coefficients.iterrows():
regressor = row["regressor"]
mode = row["regressor_mode"]
coef = row["coef"]
if mode == "additive":
additive_future_regressors.append((regressor, coef))
elif mode == "multiplicative":
multiplicative_future_regressors.append((regressor, coef))

# Plot quantiles as a separate component, if present
# If multiple steps in the future are predicted, only plot quantiles if highlight_forecast_step_n is set
Expand Down
Loading