
Commit 68777eb

[ENH] Benchmarking framework (#114)
closes #141

This PR:

* Fixes a small bug in `AptaNetPipeline`
* Makes `AptaNetPipeline` inherit from `BaseObject` to prevent errors during benchmarking
* Removes an unnecessary test (`test_pfoa`); the loader is already covered by [`test_loaders`](https://github.com/gc-os-ai/pyaptamer/blob/main/pyaptamer/datasets/tests/test_loaders.py)
* Adds the benchmarking framework
1 parent 31c204b commit 68777eb

File tree

9 files changed: +251 −33 lines changed

examples/aptanet_tutorial.ipynb

Lines changed: 1 addition & 1 deletion

@@ -554,7 +554,7 @@
 " ]\n",
 ")\n",
 "\n",
-"pipeline = AptaNetPipeline(classifier=model)"
+"pipeline = AptaNetPipeline(estimator=model)"
 ]
 },
 {

pyaptamer/aptanet/_pipeline.py

Lines changed: 9 additions & 8 deletions

@@ -2,6 +2,7 @@
 __all__ = ["AptaNetPipeline"]
 __required__ = ["python>=3.9,<3.13"]
 
+from skbase.base import BaseObject
 from sklearn.base import clone
 from sklearn.pipeline import Pipeline
 from sklearn.preprocessing import FunctionTransformer
@@ -11,7 +12,7 @@
 from pyaptamer.utils._aptanet_utils import pairs_to_features
 
 
-class AptaNetPipeline:
+class AptaNetPipeline(BaseObject):
     """
     AptaNet algorithm for aptamer–protein interaction prediction [1]_
 
@@ -22,14 +23,14 @@ class AptaNetPipeline:
 
     The pipeline starts from string pairs, converts them into numeric features
     (aptamer k-mer frequencies + protein PSeAAC), applies tree-based feature
-    selection, and feeds the result into the classifier.
+    selection, and feeds the result into the estimator.
 
     Parameters
     ----------
     k : int, optional, default=4
         The k-mer size used to generate aptamer k-mer vectors.
 
-    classifier : sklearn-compatible estimator or None, default=None
+    estimator : sklearn-compatible estimator or None, default=None
         Estimator applied after feature selection. If None, uses `AptaNetClassifier`.
 
     Attributes
@@ -62,18 +63,18 @@ class AptaNetPipeline:
     >>> proba = pipe.predict_proba(X_test_pairs)
     """
 
-    def __init__(self, k=None, classifier=None):
+    def __init__(self, k=4, estimator=None):
         self.k = k
-        self.classifier = classifier
+        self.estimator = estimator
 
     def _build_pipeline(self):
         transformer = FunctionTransformer(
             func=pairs_to_features,
-            kw_args=self.k,
+            kw_args={"k": self.k},
             validate=False,
         )
-        self._classifier = self.classifier or AptaNetClassifier()
-        return Pipeline([("features", transformer), ("clf", clone(self._classifier))])
+        self._estimator = self.estimator or AptaNetClassifier()
+        return Pipeline([("features", transformer), ("clf", clone(self._estimator))])
 
     def fit(self, X, y):
         self.pipeline_ = self._build_pipeline()
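For orientation on the two fixes above: `FunctionTransformer` expects `kw_args` to be a dict of keyword arguments (hence `{"k": self.k}`), and the constructor argument is now called `estimator`. A short usage sketch adapted from the tests in this commit (the sequences and labels below are dummy data, not a real benchmark):

```python
import numpy as np

from pyaptamer.aptanet import AptaNetPipeline, AptaNetRegressor

# Dummy aptamer/protein pair data, as used in the tests.
aptamer_seq = "AGCTTAGCGTACAGCTTAAAAGGGTTTCCCCTGCCCGCGTAC"
protein_seq = "ACDEFGHIKLMNPQRSTVWYACDEFGHIKLMNPQRSTVWY"
X_pairs = [(aptamer_seq, protein_seq) for _ in range(40)]
y = np.array([0] * 20 + [1] * 20, dtype=np.float32)

# Classification with the default AptaNetClassifier.
clf_pipe = AptaNetPipeline(k=4)
clf_pipe.fit(X_pairs, y)

# Regression by swapping in another estimator via the renamed argument.
reg_pipe = AptaNetPipeline(k=4, estimator=AptaNetRegressor())
```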

pyaptamer/aptanet/tests/test_aptanet.py

Lines changed: 2 additions & 2 deletions

@@ -26,7 +26,7 @@ def test_pipeline_fit_and_predict_classification(aptamer_seq, protein_seq):
     Test if Pipeline predictions are valid class labels and shape matches input
     for classification.
     """
-    pipe = AptaNetPipeline()
+    pipe = AptaNetPipeline(k=4)
 
     X_raw = [(aptamer_seq, protein_seq) for _ in range(40)]
     y = np.array([0] * 20 + [1] * 20, dtype=np.float32)
@@ -66,7 +66,7 @@ def test_pipeline_fit_and_predict_regression(aptamer_seq, protein_seq):
     Test if Pipeline predictions are valid floats and shape matches input
     for regression.
     """
-    pipe = AptaNetPipeline(classifier=AptaNetRegressor())
+    pipe = AptaNetPipeline(estimator=AptaNetRegressor())
 
     X_raw = [(aptamer_seq, protein_seq) for _ in range(40)]
     y = np.linspace(0, 1, 40).astype(np.float32)

pyaptamer/benchmarking/__init__.py

Lines changed: 5 additions & 0 deletions

@@ -0,0 +1,5 @@
+"""Benchmarking module."""
+
+from pyaptamer.benchmarking._base import Benchmarking
+
+__all__ = ["Benchmarking"]

pyaptamer/benchmarking/_base.py

Lines changed: 154 additions & 0 deletions

@@ -0,0 +1,154 @@
+__author__ = "satvshr"
+__all__ = ["Benchmarking"]
+
+import numpy as np
+import pandas as pd
+from sklearn.metrics import make_scorer
+from sklearn.model_selection import cross_validate
+
+
+class Benchmarking:
+    """
+    Benchmark estimators using cross-validation.
+
+    You can:
+
+    - pass `X, y` (feature matrix and labels/targets) along with `cv`
+      to use any cross-validation strategy;
+    - if you want a fixed train/test split, pass a `PredefinedSplit`
+      object as `cv`.
+
+    Parameters
+    ----------
+    estimators : list[estimator] | estimator
+        List of sklearn-like estimators implementing `fit` and `predict`.
+    metrics : list[callable] | callable
+        List of callables with signature `(y_true, y_pred) -> float`.
+    X : array-like
+        Feature matrix.
+    y : array-like
+        Target vector.
+    cv : int, CV splitter, or None, default=None
+        Cross-validation strategy. If `None`, defaults to 5-fold CV.
+        If you want to use an explicit train/test split, pass a
+        `PredefinedSplit` object.
+
+    Attributes
+    ----------
+    results : pd.DataFrame
+        DataFrame produced by :meth:`run`.
+
+        - Index: pandas.MultiIndex with two levels (names shown in parentheses)
+          - level 0 "estimator": estimator name
+          - level 1 "metric": evaluator name
+        - Columns: ["train", "test"] (both floats)
+        - Cell values: mean scores (float) computed across CV folds:
+          - "train" = mean of cross_validate(...)[f"train_{metric}"]
+          - "test" = mean of cross_validate(...)[f"test_{metric}"]
+
+    Example
+    -------
+    >>> import numpy as np
+    >>> from sklearn.metrics import accuracy_score
+    >>> from sklearn.model_selection import PredefinedSplit
+    >>> from pyaptamer.benchmarking._base import Benchmarking
+    >>> from pyaptamer.aptanet import AptaNetPipeline
+    >>> aptamer_seq = "AGCTTAGCGTACAGCTTAAAAGGGTTTCCCCTGCCCGCGTAC"
+    >>> protein_seq = "ACDEFGHIKLMNPQRSTVWYACDEFGHIKLMNPQRSTVWY"
+    >>> # dataset: 20 aptamer–protein pairs
+    >>> X = [(aptamer_seq, protein_seq) for _ in range(20)]
+    >>> y = np.array([0] * 10 + [1] * 10, dtype=np.float32)
+    >>> clf = AptaNetPipeline(k=4)
+    >>> # define a fixed train/test split
+    >>> test_fold = np.ones(len(y)) * -1
+    >>> test_fold[-2:] = 0
+    >>> cv = PredefinedSplit(test_fold)
+    >>> bench = Benchmarking(
+    ...     estimators=[clf],
+    ...     metrics=[accuracy_score],
+    ...     X=X,
+    ...     y=y,
+    ...     cv=cv,
+    ... )
+    >>> summary = bench.run()  # doctest: +SKIP
+    """
+
+    def __init__(self, estimators, metrics, X, y, cv=None):
+        self.estimators = estimators if isinstance(estimators, list) else [estimators]
+        self.metrics = metrics if isinstance(metrics, list) else [metrics]
+        self.X = X
+        self.y = y
+        self.cv = cv
+        self.results = None
+
+    def _to_scorers(self, metrics):
+        """Convert metric callables to a dict of scorers."""
+        scorers = {}
+        for metric in metrics:
+            if not callable(metric):
+                raise ValueError("Each metric should be a callable.")
+            name = (
+                metric.__name__
+                if hasattr(metric, "__name__")
+                else metric.__class__.__name__
+            )
+            scorers[name] = make_scorer(metric)
+        return scorers
+
+    def _to_df(self, results):
+        """Convert nested results to a unified DataFrame."""
+        records = []
+        index = []
+
+        for est_name, est_scores in results.items():
+            for metric_name, scores in est_scores.items():
+                records.append(scores)
+                index.append((est_name, metric_name))
+
+        index = pd.MultiIndex.from_tuples(index, names=["estimator", "metric"])
+        return pd.DataFrame(records, index=index, columns=["train", "test"])
+
+    def run(self):
+        """
+        Train each estimator and evaluate with cross-validation.
+
+        Returns
+        -------
+        results : pd.DataFrame
+
+            - Index: pandas.MultiIndex with two levels (names shown in parentheses)
+              - level 0 "estimator": estimator name
+              - level 1 "metric": evaluator name
+            - Columns: ["train", "test"] (both floats)
+            - Cell values: mean scores (float) computed across CV folds:
+              - "train" = mean of cross_validate(...)[f"train_{metric}"]
+              - "test" = mean of cross_validate(...)[f"test_{metric}"]
+        """
+        self.scorers_ = self._to_scorers(self.metrics)
+        results = {}
+
+        for estimator in self.estimators:
+            est_name = estimator.__class__.__name__
+
+            cv_results = cross_validate(
+                estimator,
+                self.X,
+                self.y,
+                cv=self.cv,
+                scoring=self.scorers_,
+                return_train_score=True,
+            )
+
+            # average across folds
+            est_scores = {}
+            for metric in self.scorers_.keys():
+                est_scores[metric] = {
+                    "train": float(np.mean(cv_results[f"train_{metric}"])),
+                    "test": float(np.mean(cv_results[f"test_{metric}"])),
+                }
+
+            results[est_name] = est_scores
+
+        self.results = self._to_df(results)
+        return self.results
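To illustrate the results layout described in the docstring, here is a minimal sketch that is not part of the commit; `LogisticRegression` and the random toy data are stand-ins for any sklearn-compatible estimator and dataset:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score

from pyaptamer.benchmarking import Benchmarking

# Toy stand-in data: 20 samples, 2 features, binary labels.
rng = np.random.default_rng(0)
X = rng.normal(size=(20, 2))
y = np.array([0] * 10 + [1] * 10)

bench = Benchmarking(
    estimators=[LogisticRegression()],
    metrics=[accuracy_score],
    X=X,
    y=y,
    cv=None,  # None falls back to the cross_validate default (5-fold CV)
)
summary = bench.run()

# Rows form a MultiIndex of (estimator, metric); columns hold mean scores.
print(summary.loc[("LogisticRegression", "accuracy_score"), "test"])
```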
pyaptamer/benchmarking/tests/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+"""Test suite for the benchmarking module"""
Lines changed: 78 additions & 0 deletions

@@ -0,0 +1,78 @@
+import sys
+
+import numpy as np
+import pytest
+from sklearn.metrics import accuracy_score, mean_squared_error
+from sklearn.model_selection import PredefinedSplit
+
+from pyaptamer.aptanet import AptaNetPipeline, AptaNetRegressor
+from pyaptamer.benchmarking._base import Benchmarking
+
+params = [
+    (
+        "AGCTTAGCGTACAGCTTAAAAGGGTTTCCCCTGCCCGCGTAC",
+        "ACDEFGHIKLMNPQRSTVWYACDEFGHIKLMNPQRSTVWY",
+    )
+]
+
+
+@pytest.mark.skipif(
+    sys.version_info >= (3, 13), reason="skorch does not support Python 3.13"
+)
+@pytest.mark.parametrize("aptamer_seq, protein_seq", params)
+def test_benchmarking_with_predefined_split_classification(aptamer_seq, protein_seq):
+    """
+    Test Benchmarking on a classification task using PredefinedSplit.
+    """
+    X_raw = [(aptamer_seq, protein_seq) for _ in range(40)]
+    y = np.array([0] * 20 + [1] * 20, dtype=np.float32)
+
+    clf = AptaNetPipeline()
+
+    test_fold = np.ones(len(y), dtype=int) * -1
+    test_fold[-2:] = 0
+    cv = PredefinedSplit(test_fold)
+
+    bench = Benchmarking(
+        estimators=[clf],
+        metrics=[accuracy_score],
+        X=X_raw,
+        y=y,
+        cv=cv,
+    )
+    summary = bench.run()
+
+    assert "train" in summary.columns
+    assert "test" in summary.columns
+    assert (clf.__class__.__name__, "accuracy_score") in summary.index
+
+
+@pytest.mark.skipif(
+    sys.version_info >= (3, 13), reason="skorch does not support Python 3.13"
+)
+@pytest.mark.parametrize("aptamer_seq, protein_seq", params)
+def test_benchmarking_with_predefined_split_regression(aptamer_seq, protein_seq):
+    """
+    Test Benchmarking on a regression task using PredefinedSplit.
+    """
+    X_raw = [(aptamer_seq, protein_seq) for _ in range(40)]
+    y = np.linspace(0, 1, 40).astype(np.float32)
+
+    reg = AptaNetPipeline(estimator=AptaNetRegressor())
+
+    test_fold = np.ones(len(y), dtype=int) * -1
+    test_fold[-3:] = 0
+    cv = PredefinedSplit(test_fold)
+
+    bench = Benchmarking(
+        estimators=[reg],
+        metrics=[mean_squared_error],
+        X=X_raw,
+        y=y,
+        cv=cv,
+    )
+    summary = bench.run()
+
+    assert "train" in summary.columns
+    assert "test" in summary.columns
+    assert (reg.__class__.__name__, "mean_squared_error") in summary.index
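Both tests pin the train/test split with `PredefinedSplit`; as a reminder of how the `test_fold` array is interpreted, a minimal sketch with dummy indices:

```python
import numpy as np
from sklearn.model_selection import PredefinedSplit

# -1 marks samples that always stay in the training set;
# 0 marks samples assigned to test fold 0 (here: the last two samples).
test_fold = np.ones(6, dtype=int) * -1
test_fold[-2:] = 0
cv = PredefinedSplit(test_fold)

for train_idx, test_idx in cv.split():
    print(train_idx, test_idx)  # [0 1 2 3] [4 5]
```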

pyaptamer/datasets/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -12,4 +12,5 @@
     "load_pfoa_structure",
     "load_1gnh_structure",
     "load_from_rcsb",
+    "load_csv_dataset",
 ]

pyaptamer/datasets/tests/test_pfoa.py

Lines changed: 0 additions & 22 deletions
This file was deleted.
