def _variable_group(column):
    """Return the variable name under which a feature column is grouped.

    ``hour_of_day`` and ``day_of_year`` are kept whole; every other column is
    grouped by the text before its first underscore (the whole name when it
    contains no underscore).
    """
    if column in ("hour_of_day", "day_of_year"):
        return column
    return column.split("_")[0]


def feature_importance(model, X_train, y_train,
                       use_permutation=False, n_repeats=100):
    """Calculates feature importances for a trained model, grouped by variable.

    Columns of ``X_train`` are grouped into variables with a single rule (see
    ``_variable_group``) that is shared by both branches below, so a column
    can never be assigned to a group name without also being part of that
    group's column list.

    Args:
        model: The trained model to analyse. Must expose ``score(X, y)`` when
            permutation importance is computed.
        X_train (pd.DataFrame): The feature matrix used for training.
        y_train (pd.DataFrame): The target labels used for training.
        use_permutation (bool): Whether to use permutation importance. If False
            and the model has a ``feature_importances_`` attribute, that
            attribute is used instead. If the model has no such attribute,
            permutation importance is computed regardless of this flag.
        n_repeats (int): Number of times to repeat the permutation for each
            variable group. Only used by the permutation branch.

    Returns:
        pd.DataFrame: DataFrame with columns
            ['variable', 'importance_mean', 'importance_sum'], sorted by
            importance_mean descending. With ``feature_importances_`` the
            mean/sum are taken over the columns of each group; with
            permutation they are taken over the ``n_repeats`` score drops
            (``original score - permuted score``) of each group.

    Raises:
        ValueError: If permutation importance is computed with
            ``n_repeats < 1``.
    """
    if hasattr(model, "feature_importances_") and not use_permutation:
        per_feature = pd.DataFrame({
            "variable": [_variable_group(col) for col in X_train.columns],
            "importance": model.feature_importances_,
        })
        return (
            per_feature.groupby("variable")["importance"]
            .agg(importance_mean="mean", importance_sum="sum")
            .sort_values("importance_mean", ascending=False, kind="stable")
            .reset_index()
        )

    if n_repeats < 1:
        raise ValueError("n_repeats must be at least 1")

    # Build group -> columns once, using the same rule that names the group.
    # Selecting columns by a "<group>_" prefix instead would leave columns
    # without an underscore with no columns to permute (importance silently
    # reported as 0) and could pull "hour_of_day" into a group named "hour".
    groups = {}
    for col in X_train.columns:
        groups.setdefault(_variable_group(col), []).append(col)

    rng = np.random.default_rng(42)  # fixed seed keeps results reproducible
    n_rows = len(X_train)
    orig_score = model.score(X_train, y_train)

    results = []
    for var in sorted(groups):
        originals = {col: X_train[col].values for col in groups[var]}
        X_permuted = X_train.copy()

        score_drops = []
        for _ in range(n_repeats):
            # Shuffle by position rather than by index label so that a
            # non-unique index cannot change the number of rows selected.
            # All columns of a group share one permutation, so relationships
            # within the group are preserved while the link to y is broken.
            order = rng.permutation(n_rows)
            for col, values in originals.items():
                X_permuted[col] = values[order]
            score_drops.append(orig_score - model.score(X_permuted, y_train))

        results.append({
            "variable": var,
            "importance_mean": np.mean(score_drops),
            "importance_sum": np.sum(score_drops),
        })

    # Explicit columns keep the sort valid even when X_train has no columns.
    return (
        pd.DataFrame(results,
                     columns=["variable", "importance_mean", "importance_sum"])
        .sort_values("importance_mean", ascending=False, kind="stable")
        .reset_index(drop=True)
    )


def plot_importance(df_importance, figsize=None,
                    xlabel="Mean Accuracy Drop (%)"):
    """Plot grouped feature importance as a horizontal bar chart.

    Args:
        df_importance (pd.DataFrame): DataFrame with at least the columns
            'variable' and 'importance_mean', e.g. the output of
            ``feature_importance``. Rows are drawn top to bottom in the
            order given.
        figsize: Optional (width, height) tuple passed to ``plt.subplots``.
            Matplotlib's default figure size is used if None.
        xlabel (str): Label for the x axis. The default describes permutation
            importance for a model whose ``score`` is accuracy (an assumption
            about the caller's model); pass a different label for
            ``feature_importances_``-based values.

    Returns:
        (fig, ax) — call plt.show() or fig.savefig() in the caller.
    """
    fig, ax = plt.subplots(figsize=figsize)

    ax.barh(df_importance["variable"], df_importance["importance_mean"],
            color='#4C72B0')

    # Importances are fractions; show them as percentages with two decimals.
    ax.xaxis.set_major_formatter(
        plt.FuncFormatter(lambda x, _: f'{x*100:.2f}%'))
    ax.set_xlabel(xlabel)
    ax.invert_yaxis()  # first row of the DataFrame ends up at the top
    ax.xaxis.grid(True, linestyle='--', alpha=0.7)
    ax.set_axisbelow(True)
    ax.set_title("Feature Importance by Variable")

    return fig, ax
prediction_threshold=0.5, prediction_mode="validation", scaler=None, verbose=True, save_preds = False, return_proba=False): """ Predict baseline events for a given site using a trained model. @@ -114,6 +114,10 @@ def align_predictions_and_obs(y, y_pred, df_obs, y_proba=None): baseline labels, and predicted baseline labels, all aligned by their datetime index. """ + y = y.copy() + df_obs = df_obs.copy() + y.index = y.index.astype("datetime64[ns]") + df_obs.index = df_obs.index.astype("datetime64[ns]") y = y[(y.index.year >= min(df_obs.index.year)) & (y.index.year <= max(df_obs.index.year))] labelled_df = pd.merge_asof(pd.DataFrame({"baseline": y}), df_obs["mf"], left_index=True, right_index=True, direction='nearest') diff --git a/ml_baselines/modelling/train.py b/ml_baselines/modelling/train.py index 394e298..d79b248 100644 --- a/ml_baselines/modelling/train.py +++ b/ml_baselines/modelling/train.py @@ -26,7 +26,7 @@ def get_train_test_data(site, test_train, balance=-1, undersample=0, - time_shift_hours=[6], + time_shift_hours=[6, 12, 18, 24], return_dataframe=False, balance_method="random", verbose=True): """ Get the training, testing or validation data for a given site. @@ -343,7 +343,7 @@ def train_baseline_model(site, model_type="mlp", undersample=0, sample_weights=None, normalise_inputs=False, - time_shift_hours=[6], prediction_threshold=0.5, + time_shift_hours=[6, 12, 18, 24], prediction_threshold=0.5, model_params=None, return_scores=False, return_scaler=False, verbose=True, save_model=False, save_folder=cfg.models_path, save_suffix=None, evaluate_on_test=True, random_seed=42): """ Train a model to classify baseline events for a given site. @@ -435,7 +435,7 @@ def train_baseline_model(site, model_type="mlp", else: model.fit(X_train, y_train) fit_time = time.time() - start_time - print(f"Fit time: {fit_time:.1f}s") + if verbose: print(f"Fit time: {fit_time:.1f}s") # Make predictions if verbose: print("... 
predicting") @@ -641,14 +641,14 @@ def train_baseline_model_grid_search(site, { "balance": [-1, 0.5], - "time_shift_hours": [[6], [6, 24]], + "time_shift_hours": [[6, 12, 18, 24], [6, 24]], "sample_weights": [None, "auto", 2.0], "normalise_inputs": [True, False], } Valid keys are ``balance``, ``undersample``, ``time_shift_hours``, ``balance_method``, ``sample_weights`` and ``normalise_inputs``. If None, defaults to - ``{"balance": [-1], "time_shift_hours": [[6]], "normalise_inputs": [False]}``. + ``{"balance": [-1], "time_shift_hours": [[6, 12, 18, 24]], "normalise_inputs": [False]}``. prediction_thresholds (list of float, optional): Prediction thresholds to evaluate after the grid search. For each data-kwarg combo, the grid search is run once (at the default 0.5 threshold); the best model @@ -697,7 +697,7 @@ def train_baseline_model_grid_search(site, if data_kwargs is None: data_kwargs = {"balance": [-1], - "time_shift_hours": [[6]], + "time_shift_hours": [[6, 12, 18, 24]], "normalise_inputs": [False]} if validation_keys is None: @@ -789,8 +789,15 @@ def train_baseline_model_grid_search(site, print("Best score (default threshold 0.5): ", grid_search.best_score_) # Evaluate the best model at each prediction threshold (grid search used 0.5 implicitly) + train_proba = grid_search.best_estimator_.predict_proba(X_train)[:, 1] val_proba = grid_search.best_estimator_.predict_proba(X_val)[:, 1] + for threshold in prediction_thresholds: + y_pred_train = (train_proba >= threshold).astype(int) + train_score = primary_scorer._score_func(y_train, y_pred_train) + if train_score >= 0.99: + print(f" Skipping threshold {threshold:.2f}. 
Training score indicates likely overfitting.") + continue y_pred = (val_proba >= threshold).astype(int) threshold_score = primary_scorer._score_func(y_val, y_pred) print(f" Threshold {threshold:.2f}: {primary_metric} = {threshold_score:.3f}") @@ -822,6 +829,8 @@ def train_baseline_model_grid_search(site, cv_results[str(combo_kw)] = cv_df # Find the best combination across all data-kwarg combos and thresholds + if not best_scores_list: + raise ValueError("All combinations skipped due to overfitting. Consider expanding param_grid or data_kwargs.") best_index = best_scores_list.index(max(best_scores_list)) best_best_params = best_params_list[best_index] best_combo_kw = best_data_kwargs_list[best_index] diff --git a/models/features/.gitignore b/models/features/.gitignore deleted file mode 100644 index f59ec20..0000000 --- a/models/features/.gitignore +++ /dev/null @@ -1 +0,0 @@ -* \ No newline at end of file