Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 100 additions & 4 deletions ml_baselines/features.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from pathlib import Path
import getpass
import gzip
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
import pandas as pd
import getpass
import gzip
from pathlib import Path

from sklearn.inspection import permutation_importance

from ml_baselines.config import Config
from ml_baselines.utils import longitude_to_360
Expand Down Expand Up @@ -333,7 +336,7 @@ def preprocess_all_features_arco_era5(force=False):
def open_features(site,
start_year=1978,
end_year=2024,
time_shift_hours=[6],
time_shift_hours=[6, 12, 18, 24],
features_dir=""):
"""Opens the preprocessed features for a given site.

Expand Down Expand Up @@ -413,6 +416,99 @@ def open_features(site,
return df


def feature_importance(model, X_train, y_train,
                       use_permutation=False, n_repeats=100,
                       random_state=42):
    """Calculates feature importances for a trained model, grouped by variable.

    Columns are grouped into variables: ``hour_of_day`` and ``day_of_year``
    stand alone; any other column belongs to the prefix before its first
    underscore (e.g. ``u10_6`` and ``u10_12`` both map to variable ``u10``).

    Args:
        model: The trained model to analyse.
        X_train (pd.DataFrame): The feature matrix used for training.
        y_train (pd.DataFrame): The target labels used for training. Only
            used when permutation importance is computed.
        use_permutation (bool): Whether to use permutation importance. If False
            and the model has a ``feature_importances_`` attribute, that is
            used instead.
        n_repeats (int): Number of times to repeat the permutation.
        random_state (int): Seed for the permutation RNG. Ignored when
            ``feature_importances_`` is used.

    Returns:
        pd.DataFrame: DataFrame with columns
        ['variable', 'importance_mean', 'importance_sum'], sorted by
        importance_mean descending. Variables are grouped by type.
    """

    def _variable(col):
        # hour_of_day / day_of_year are variables in their own right; every
        # other feature is grouped by its prefix (text before first "_").
        return col if col in ("hour_of_day", "day_of_year") else col.split("_")[0]

    if hasattr(model, "feature_importances_") and not use_permutation:
        # Built-in importances: group per-column values by variable.
        df_importance = pd.DataFrame({
            "feature": X_train.columns,
            "importance": model.feature_importances_,
            "variable": [_variable(c) for c in X_train.columns],
        })
        df_importance = (
            df_importance.groupby("variable")["importance"]
            .agg(importance_mean="mean", importance_sum="sum")
            .sort_values("importance_mean", ascending=False)
            .reset_index()
        )

    else:
        # extract variable groups
        variables = np.unique([_variable(c) for c in X_train.columns])

        rng = np.random.default_rng(random_state)
        orig_score = model.score(X_train, y_train)

        # Reusable buffer: copy X_train once and restore the permuted columns
        # after each variable, instead of re-copying the whole frame per loop.
        X_permuted = X_train.copy()
        results = []
        for var in variables:
            # Permute all of a variable's lagged columns with the SAME shuffle
            # so correlated features are knocked out together.
            cols = ([var] if var in ("hour_of_day", "day_of_year")
                    else [c for c in X_train.columns if c.startswith(var + "_")])

            scores_diff = []
            for _ in range(n_repeats):
                shuffled_indices = rng.permutation(X_train.index)
                for col in cols:
                    X_permuted[col] = X_train.loc[shuffled_indices, col].values
                scores_diff.append(orig_score - model.score(X_permuted, y_train))

            # Undo this variable's permutation before the next iteration.
            X_permuted[cols] = X_train[cols]

            results.append({
                "variable": var,
                "importance_mean": np.mean(scores_diff),
                "importance_sum": np.sum(scores_diff),
            })

        df_importance = (pd.DataFrame(results)
                         .sort_values("importance_mean", ascending=False)
                         .reset_index(drop=True))

    return df_importance


def plot_importance(df_importance, figsize=None):
    """Draw a horizontal bar chart of grouped feature importances.

    Args:
        df_importance (pd.DataFrame): DataFrame containing feature importance;
            must provide 'variable' and 'importance_mean' columns.
        figsize: Optional (width, height) tuple. Auto-sized if None.

    Returns:
        (fig, ax) — call plt.show() or fig.savefig() in the caller.

    """

    fig, ax = plt.subplots(figsize=figsize)

    variables = df_importance["variable"]
    mean_importance = df_importance["importance_mean"]
    ax.barh(variables, mean_importance, color='#4C72B0')

    # Importances are fractions; render the x tick labels as percentages.
    percent_formatter = plt.FuncFormatter(lambda x, _: f'{x*100:.2f}%')
    ax.xaxis.set_major_formatter(percent_formatter)
    ax.set_xlabel("Mean Accuracy Drop (%)")

    # First row of the (descending-sorted) frame ends up at the top.
    ax.invert_yaxis()

    ax.xaxis.grid(True, linestyle='--', alpha=0.7)
    ax.set_axisbelow(True)
    ax.set_title("Feature Importance by Variable")

    return fig, ax


if __name__ == "__main__":

if cfg.met_type == "arco-era5":
Expand Down
4 changes: 2 additions & 2 deletions ml_baselines/modelling/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def plot_model_confidence(labelled_df, title="", cmap=None, shade_train_and_val_
ax.set_title(title if title else "Model confidence over observations")
ax.set_ylabel("mole fraction in air / ppt")
ax.set_xlabel("Time")
ax.legend(loc="upper left")
ax.legend(loc="best")

plt.show()

Expand Down Expand Up @@ -271,7 +271,7 @@ def plot_monthly_means(monthly_means, shade_train_and_val_periods=True, site=Non
ax.scatter(anomaly_months.index, anomaly_months["true_monthly_mf"] - deviation , label=f"Anomalies > {threshold} std", color='red', marker='^', s=10*threshold, zorder=5)


ax.legend(loc="upper left")
ax.legend(loc="best" if not plot_count_hist else "upper left")

plt.show()

Expand Down
6 changes: 5 additions & 1 deletion ml_baselines/modelling/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def load_baseline_model(site, model_type="mlp", models_folder=cfg.models_path, t
print(f"Loaded model from {model_file[-1]}")
return model_dict['model'], model_dict['info']

def predict_baselines(site, model, time_shift_hours=[6], prediction_threshold=0.5, prediction_mode="validation", scaler=None, verbose=True, save_preds = False, return_proba=False):
def predict_baselines(site, model, time_shift_hours=[6, 12, 18, 24], prediction_threshold=0.5, prediction_mode="validation", scaler=None, verbose=True, save_preds = False, return_proba=False):
"""
Predict baseline events for a given site using a trained model.

Expand Down Expand Up @@ -114,6 +114,10 @@ def align_predictions_and_obs(y, y_pred, df_obs, y_proba=None):
baseline labels, and predicted baseline labels, all aligned by
their datetime index.
"""
y = y.copy()
df_obs = df_obs.copy()
y.index = y.index.astype("datetime64[ns]")
df_obs.index = df_obs.index.astype("datetime64[ns]")

y = y[(y.index.year >= min(df_obs.index.year)) & (y.index.year <= max(df_obs.index.year))]
labelled_df = pd.merge_asof(pd.DataFrame({"baseline": y}), df_obs["mf"], left_index=True, right_index=True, direction='nearest')
Expand Down
21 changes: 15 additions & 6 deletions ml_baselines/modelling/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
def get_train_test_data(site, test_train,
balance=-1,
undersample=0,
time_shift_hours=[6],
time_shift_hours=[6, 12, 18, 24],
return_dataframe=False,
balance_method="random", verbose=True):
""" Get the training, testing or validation data for a given site.
Expand Down Expand Up @@ -343,7 +343,7 @@ def train_baseline_model(site, model_type="mlp",
undersample=0,
sample_weights=None,
normalise_inputs=False,
time_shift_hours=[6], prediction_threshold=0.5,
time_shift_hours=[6, 12, 18, 24], prediction_threshold=0.5,
model_params=None, return_scores=False, return_scaler=False, verbose=True, save_model=False, save_folder=cfg.models_path, save_suffix=None, evaluate_on_test=True, random_seed=42):
""" Train a model to classify baseline events for a given site.

Expand Down Expand Up @@ -435,7 +435,7 @@ def train_baseline_model(site, model_type="mlp",
else:
model.fit(X_train, y_train)
fit_time = time.time() - start_time
print(f"Fit time: {fit_time:.1f}s")
if verbose: print(f"Fit time: {fit_time:.1f}s")

# Make predictions
if verbose: print("... predicting")
Expand Down Expand Up @@ -641,14 +641,14 @@ def train_baseline_model_grid_search(site,

{
"balance": [-1, 0.5],
"time_shift_hours": [[6], [6, 24]],
"time_shift_hours": [[6, 12, 18, 24], [6, 24]],
"sample_weights": [None, "auto", 2.0],
"normalise_inputs": [True, False],
}

Valid keys are ``balance``, ``undersample``, ``time_shift_hours``,
``balance_method``, ``sample_weights`` and ``normalise_inputs``. If None, defaults to
``{"balance": [-1], "time_shift_hours": [[6]], "normalise_inputs": [False]}``.
``{"balance": [-1], "time_shift_hours": [[6, 12, 18, 24]], "normalise_inputs": [False]}``.
prediction_thresholds (list of float, optional): Prediction thresholds to
evaluate after the grid search. For each data-kwarg combo, the grid
search is run once (at the default 0.5 threshold); the best model
Expand Down Expand Up @@ -697,7 +697,7 @@ def train_baseline_model_grid_search(site,

if data_kwargs is None:
data_kwargs = {"balance": [-1],
"time_shift_hours": [[6]],
"time_shift_hours": [[6, 12, 18, 24]],
"normalise_inputs": [False]}

if validation_keys is None:
Expand Down Expand Up @@ -789,8 +789,15 @@ def train_baseline_model_grid_search(site,
print("Best score (default threshold 0.5): ", grid_search.best_score_)

# Evaluate the best model at each prediction threshold (grid search used 0.5 implicitly)
train_proba = grid_search.best_estimator_.predict_proba(X_train)[:, 1]
val_proba = grid_search.best_estimator_.predict_proba(X_val)[:, 1]

for threshold in prediction_thresholds:
y_pred_train = (train_proba >= threshold).astype(int)
train_score = primary_scorer._score_func(y_train, y_pred_train)
if train_score >= 0.99:
print(f" Skipping threshold {threshold:.2f}. Training score indicates likely overfitting.")
continue
y_pred = (val_proba >= threshold).astype(int)
threshold_score = primary_scorer._score_func(y_val, y_pred)
print(f" Threshold {threshold:.2f}: {primary_metric} = {threshold_score:.3f}")
Expand Down Expand Up @@ -822,6 +829,8 @@ def train_baseline_model_grid_search(site,
cv_results[str(combo_kw)] = cv_df

# Find the best combination across all data-kwarg combos and thresholds
if not best_scores_list:
raise ValueError("All combinations skipped due to overfitting. Consider expanding param_grid or data_kwargs.")
best_index = best_scores_list.index(max(best_scores_list))
best_best_params = best_params_list[best_index]
best_combo_kw = best_data_kwargs_list[best_index]
Expand Down
1 change: 0 additions & 1 deletion models/features/.gitignore

This file was deleted.