Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
620f2ce
updating evaluation to use new splitters
bruAristimunha May 2, 2025
7d6b83f
cross-subject
bruAristimunha May 2, 2025
e24d8e7
including the whats new
bruAristimunha May 2, 2025
8597ad1
updating the whats new
bruAristimunha May 2, 2025
e64d2e1
Update docs/source/whats_new.rst
bruAristimunha May 3, 2025
062d1c3
simple fit for everybody
bruAristimunha May 3, 2025
ec9d288
Merge branch 'develop' into using-the-new-splitters
bruAristimunha May 5, 2025
68da5d1
updating the splitter
bruAristimunha May 5, 2025
3fadc5c
parallel evaluation now
bruAristimunha May 5, 2025
0297f3c
solving the small issue
bruAristimunha May 5, 2025
2b3ae21
updating the evaluation
bruAristimunha May 5, 2025
86028c9
adjusting in the other evaluation too
bruAristimunha May 5, 2025
70daa3a
updating
bruAristimunha May 6, 2025
125c1d8
Merge branch 'develop' into using-the-new-splitters
bruAristimunha May 6, 2025
8eb2c9d
updating the evaluations
bruAristimunha May 6, 2025
ac37720
Merge branch 'using-the-new-splitters' of https://github.com/bruArist…
bruAristimunha May 6, 2025
6451efd
Apply suggestions from code review
bruAristimunha May 6, 2025
8c932d4
updating base
bruAristimunha May 6, 2025
ba0edf1
Merge branch 'develop' into using-the-new-splitters
bruAristimunha May 8, 2025
5619568
Merge branch 'develop' into using-the-new-splitters
bruAristimunha Jul 25, 2025
3bf2ab7
updating the pyproject
bruAristimunha Jul 25, 2025
c76acce
trying to solve this shit...
bruAristimunha Jul 25, 2025
08b97ed
crazy things here..
bruAristimunha Jul 25, 2025
d3a4aa2
too much things at the same time
bruAristimunha Jul 25, 2025
22f5950
reverting
bruAristimunha Jul 28, 2025
0dcfaf4
Merge branch 'develop' into using-the-new-splitters
bruAristimunha Jul 28, 2025
8996991
evaluation
bruAristimunha Jul 28, 2025
e1b6de1
including acceptance test
bruAristimunha Jul 28, 2025
041e0b7
forcing two reference results
bruAristimunha Jul 28, 2025
f36c2f0
reverting small detail
bruAristimunha Jul 28, 2025
480497f
updating the pyproject
bruAristimunha Jul 28, 2025
9d17557
upgrading the mne version
bruAristimunha Aug 4, 2025
5bcdfb5
solving issue with saving
bruAristimunha Aug 4, 2025
6fe882d
scoring
bruAristimunha Aug 4, 2025
492ad04
fixing import
bruAristimunha Aug 4, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ Enhancements
- Adding :func:`moabb.analysis.plotting.dataset_bubble_plot` plus the corresponding tutorial (:gh:`753` by `Pierre Guetschel`_)
- Adding :func:`moabb.datasets.utils.plot_all_datasets` and update the tutorial (:gh:`758` by `Pierre Guetschel`_)
- Improve the dataset model cards in each API page (:gh:`765` by `Pierre Guetschel`_)
- Refactor :class:`moabb.evaluation.CrossSessionEvaluation`, :class:`moabb.evaluation.CrossSubjectEvaluation` and :class:`moabb.evaluation.WithinSessionEvaluation` to use the new splitter classes (:gh:`769` by `Bruno Aristimunha`_)
- Adding tutorial on using mne-features (:gh:`762` by `Alexander de Ranitz`_, `Luuk Neervens`_, `Charlynn van Osch`_ and `Bruno Aristimunha`_)
- Creating tutorial to expose the pre-processing steps (:gh:`771` by `Bruno Aristimunha`_)
- Add function to auto-generate tables for the paper results documentation page (:gh:`785` by `Lucas Heck`_)
Expand Down
42 changes: 0 additions & 42 deletions examples/advanced_examples/plot_grid_search_withinsession.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
"""

import os
from pickle import load

import matplotlib.pyplot as plt
import seaborn as sns
Expand Down Expand Up @@ -132,44 +131,3 @@
)
sns.pointplot(data=result, y="score", x="pipeline", ax=axes, palette="Set1")
axes.set_ylabel("ROC AUC")

##########################################################
# Load Best Model Parameter
# -------------------------
# The best model are automatically saved in a pickle file, in the
# results directory. It is possible to load those model for each
# dataset, subject and session. Here, we could see that the grid
# search found a l1_ratio that is different from the baseline
# value.

with open(
"./Results/Models_WithinSession/BNCI2014-001/1/1test/GridSearchEN/fitted_model_best.pkl",
"rb",
) as pickle_file:
GridSearchEN_Session_E = load(pickle_file)

print(
"Best Parameter l1_ratio Session_E GridSearchEN ",
GridSearchEN_Session_E.best_params_["LogistReg__l1_ratio"],
)

print(
"Best Parameter l1_ratio Session_E VanillaEN: ",
pipelines["VanillaEN"].steps[2][1].l1_ratio,
)

with open(
"./Results/Models_WithinSession/BNCI2014-001/1/0train/GridSearchEN/fitted_model_best.pkl",
"rb",
) as pickle_file:
GridSearchEN_Session_T = load(pickle_file)

print(
"Best Parameter l1_ratio Session_T GridSearchEN ",
GridSearchEN_Session_T.best_params_["LogistReg__l1_ratio"],
)

print(
"Best Parameter l1_ratio Session_T VanillaEN: ",
pipelines["VanillaEN"].steps[2][1].l1_ratio,
)
2 changes: 1 addition & 1 deletion moabb/evaluations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@
WithinSessionEvaluation,
)
from .splitters import CrossSessionSplitter, CrossSubjectSplitter, WithinSessionSplitter
from .utils import create_save_path, save_model_cv, save_model_list
from .utils import _create_save_path, _save_model_cv
76 changes: 46 additions & 30 deletions moabb/evaluations/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,23 @@
from warnings import warn

import pandas as pd
from joblib import Parallel, delayed
from sklearn.base import BaseEstimator
from sklearn.model_selection import GridSearchCV

from moabb.analysis import Results
from moabb.datasets.base import BaseDataset
from moabb.evaluations.utils import _convert_sklearn_params_to_optuna
from moabb.evaluations.utils import (
_convert_sklearn_params_to_optuna,
check_search_available,
)
from moabb.paradigms.base import BaseParadigm


search_methods, optuna_available = check_search_available()

log = logging.getLogger(__name__)

# Making the optuna soft dependency
try:
from optuna.integration import OptunaSearchCV

optuna_available = True
except ImportError:
optuna_available = False

if optuna_available:
search_methods = {"grid": GridSearchCV, "optuna": OptunaSearchCV}
else:
search_methods = {"grid": GridSearchCV}


class BaseEvaluation(ABC):
Expand Down Expand Up @@ -83,6 +77,8 @@ class BaseEvaluation(ABC):
optuna, time_out parameters.
"""

search = False

def __init__(
self,
paradigm,
Expand Down Expand Up @@ -201,7 +197,6 @@ def process(self, pipelines, param_grid=None, postprocess_pipeline=None):
This pipeline must be "fixed" because it will not be trained,
i.e. no call to ``fit`` will be made.


Returns
-------
results: pd.DataFrame
Expand All @@ -216,26 +211,44 @@ def process(self, pipelines, param_grid=None, postprocess_pipeline=None):
if not (isinstance(pipeline, BaseEstimator)):
raise (ValueError("pipelines must only contains Pipelines " "instance"))

res_per_db = []
for dataset in self.datasets:
log.info("Processing dataset: {}".format(dataset.code))
process_pipeline = self.paradigm.make_process_pipelines(
# Prepare dataset processing parameters
processing_params = [
(
dataset,
return_epochs=self.return_epochs,
return_raws=self.return_raws,
postprocess_pipeline=postprocess_pipeline,
)[0]
# (we only keep the pipeline for the first frequency band, better ideas?)

results = self.evaluate(
dataset,
pipelines,
param_grid=param_grid,
process_pipeline=process_pipeline,
postprocess_pipeline=postprocess_pipeline,
self.paradigm.make_process_pipelines(
dataset,
return_epochs=self.return_epochs,
return_raws=self.return_raws,
postprocess_pipeline=postprocess_pipeline,
)[0],
)
for dataset in self.datasets
]

# Parallel processing...
parallel_results = Parallel(n_jobs=self.n_jobs)(
delayed(
lambda d, p: list(
self.evaluate(
d,
pipelines,
param_grid=param_grid,
process_pipeline=p,
postprocess_pipeline=postprocess_pipeline,
)
)
)(dataset, process_pipeline)
for dataset, process_pipeline in processing_params
)

res_per_db = []
# Process results in order
for (dataset, process_pipeline), results in zip(
processing_params, parallel_results
):
for res in results:
self.push_result(res, pipelines, process_pipeline)

res_per_db.append(
self.results.to_dataframe(
pipelines=pipelines, process_pipeline=process_pipeline
Expand Down Expand Up @@ -316,9 +329,12 @@ def _grid_search(self, param_grid, name, grid_clf, inner_cv):
return_train_score=True,
**extra_params,
)
self.search = True
return search
else:
self.search = True
return grid_clf

else:
self.search = False
return grid_clf
Loading
Loading