Skip to content

Commit d94039f

Browse files
committed
Test
1 parent 0cb9e44 commit d94039f

File tree

2 files changed

+9
-21
lines changed

2 files changed

+9
-21
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ keel = ["pandas"]
3131
keras = ["keras", "tensorflow"]
3232
physionet = ["pandas", "wfdb"]
3333
utils-estimator = ["jsonpickle"]
34-
utils-experiments = ["sacred", "incense"]
34+
utils-experiments = ["sacred", "incense@git+https://github.com/JarnoRFB/incense.git@bc736b71cd15a136acf42e14ec4bfac09a8dba53"]
3535
utils-scores = ["statsmodels", "jinja2"]
3636
all = ["scikit-datasets[cran, forex, keel, keras, physionet, utils-estimator, utils-experiments, utils-scores]"]
3737
test = ["pytest", "pytest-cov[all]", "coverage", "scikit-datasets[all]"]

skdatasets/utils/experiment.py

Lines changed: 8 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
@author: David Diaz Vico
33
@license: MIT
44
"""
5+
56
from __future__ import annotations
67

78
import itertools
@@ -145,6 +146,7 @@ class ScoresInfo:
145146
fetch_scores
146147
147148
"""
149+
148150
dataset_names: Sequence[str]
149151
estimator_names: Sequence[str]
150152
scores: np.typing.NDArray[float]
@@ -182,9 +184,7 @@ def _iterate_outer_cv(
182184
yield from outer_cv
183185

184186
cv = check_cv(outer_cv, y, classifier=is_classifier(estimator))
185-
yield from (
186-
(X[train], y[train], X[test], y[test]) for train, test in cv.split(X, y)
187-
)
187+
yield from ((X[train], y[train], X[test], y[test]) for train, test in cv.split(X, y))
188188

189189

190190
def _benchmark_from_data(
@@ -690,23 +690,16 @@ def _get_experiments(
690690
find_all_fun = getattr(
691691
loader,
692692
"find_all",
693-
lambda: [
694-
FileSystemExperiment.from_run_dir(run_dir)
695-
for run_dir in loader._runs_dir.iterdir()
696-
],
693+
lambda: [FileSystemExperiment.from_run_dir(run_dir) for run_dir in loader._runs_dir.iterdir()],
697694
)
698695

699696
experiments = find_all_fun()
700697

701-
elif (dataset_names, estimator_names) == (None, None) or isinstance(
702-
loader, FileSystemExperimentLoader
703-
):
698+
elif (dataset_names, estimator_names) == (None, None) or isinstance(loader, FileSystemExperimentLoader):
704699
load_ids_fun = getattr(
705700
loader,
706701
"find_by_ids",
707-
lambda id_seq: [
708-
loader.find_by_id(experiment_id) for experiment_id in id_seq
709-
],
702+
lambda id_seq: [loader.find_by_id(experiment_id) for experiment_id in id_seq],
710703
)
711704

712705
experiments = load_ids_fun(ids)
@@ -739,10 +732,7 @@ def _get_experiments(
739732
e
740733
for e in experiments
741734
if (
742-
(
743-
estimator_names is None
744-
or e.config["estimator_name"] in estimator_names
745-
)
735+
(estimator_names is None or e.config["estimator_name"] in estimator_names)
746736
and (dataset_names is None or e.config["dataset_name"] in dataset_names)
747737
)
748738
]
@@ -842,9 +832,7 @@ def fetch_scores(
842832
score_std,
843833
)
844834

845-
estimator_names = (
846-
tuple(estimator_list) if estimator_names is None else estimator_names
847-
)
835+
estimator_names = tuple(estimator_list) if estimator_names is None else estimator_names
848836
dataset_names = tuple(dataset_list) if dataset_names is None else dataset_names
849837
matrix_shape = (len(dataset_names), len(estimator_names))
850838

0 commit comments

Comments
 (0)