Skip to content

Commit 82acbba

Browse files
[MNT] Automated pre-commit hook update and stuff from tsml-eval I am hijacking this PR for (#87)
* Automated `pre-commit` hook update * fixes * Automated `pre-commit` hook update * fixes * wildboar --------- Co-authored-by: MatthewMiddlehurst <[email protected]> Co-authored-by: MatthewMiddlehurst <[email protected]>
1 parent 70dd4ec commit 82acbba

File tree

8 files changed

+305
-7
lines changed

8 files changed

+305
-7
lines changed

.github/workflows/periodic_tests.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ jobs:
9090
- name: Disable Numba JIT
9191
run: echo "NUMBA_DISABLE_JIT=1" >> $GITHUB_ENV
9292

93-
- name: Install aeon and dependencies
93+
- name: Install
9494
uses: nick-fields/retry@v3
9595
with:
9696
timeout_minutes: 30
@@ -101,7 +101,7 @@ jobs:
101101
run: python -m pip list
102102

103103
- name: Run tests
104-
run: python -m pytest -n logical --cov=aeon --cov-report=xml --timeout 1800
104+
run: python -m pytest -n logical --cov=tsml --cov-report=xml --timeout 1800
105105

106106
- uses: codecov/codecov-action@v5
107107
env:

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ repos:
2626
args: [ "--create", "--python-folders", "tsml" ]
2727

2828
- repo: https://github.com/astral-sh/ruff-pre-commit
29-
rev: v0.12.10
29+
rev: v0.12.12
3030
hooks:
3131
- id: ruff
3232
args: [ "--fix" ]

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ all_extras = [
5252
"grailts",
5353
"scikit-fda>=0.7.0; python_version > '3.9' and python_version < '3.13'",
5454
"statsmodels>=0.12.1",
55-
"wildboar",
55+
"wildboar<=1.2.0",
5656
]
5757
unstable_extras = [
5858
"mrsqm>=0.0.7; platform_system == 'Linux' and python_version < '3.12'", # requires gcc and fftw to be installed for Windows and some other OS (see http://www.fftw.org/index.html)

tsml/base.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,9 @@ def _check_n_features(self, X: np.ndarray | list[np.ndarray], reset: bool):
249249
def _more_tags(self) -> dict:
250250
return _DEFAULT_TAGS
251251

252+
def _get_tags(self) -> dict:
253+
return _safe_tags(self)
254+
252255
@classmethod
253256
def get_test_params(cls, parameter_set: str | None = None) -> dict | list[dict]:
254257
"""Return unit test parameter settings for the estimator.

tsml/compose/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,17 @@
33
__all__ = [
44
"ChannelEnsembleClassifier",
55
"ChannelEnsembleRegressor",
6+
"SklearnToTsmlClassifier",
7+
"SklearnToTsmlClusterer",
8+
"SklearnToTsmlRegressor",
69
]
710

811
from tsml.compose._channel_ensemble import (
912
ChannelEnsembleClassifier,
1013
ChannelEnsembleRegressor,
1114
)
15+
from tsml.compose._sklearn_to_tsml import (
16+
SklearnToTsmlClassifier,
17+
SklearnToTsmlClusterer,
18+
SklearnToTsmlRegressor,
19+
)

tsml/compose/_sklearn_to_tsml.py

Lines changed: 288 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,288 @@
1+
"""tsml wrappers for sklearn classifiers, clusterers and regressors."""
2+
3+
__maintainer__ = ["MatthewMiddlehurst"]
4+
__all__ = [
5+
"SklearnToTsmlClassifier",
6+
"SklearnToTsmlClusterer",
7+
"SklearnToTsmlRegressor",
8+
]
9+
10+
import numpy as np
11+
from aeon.base._base import _clone_estimator
12+
from sklearn.base import ClassifierMixin, ClusterMixin, RegressorMixin
13+
from sklearn.utils.multiclass import check_classification_targets
14+
from sklearn.utils.validation import check_is_fitted
15+
16+
from tsml.base import BaseTimeSeriesEstimator
17+
18+
19+
class SklearnToTsmlClassifier(ClassifierMixin, BaseTimeSeriesEstimator):
    """Expose an sklearn classifier through the tsml estimator interface.

    Input is checked with the base class ``_validate_data``, converted with
    ``_convert_X`` and then handed to the wrapped classifier.

    Parameters
    ----------
    classifier : estimator, default=None
        The classifier to wrap. ``fit`` raises a ``ValueError`` when this is
        left as None.
    pad_unequal : bool, default=False
        Forwarded to ``_convert_X``. When False, ``fit`` validates the input
        with ``ensure_equal_length=True``.
    concatenate_channels : bool, default=False
        Forwarded to ``_convert_X``. When False, ``fit`` validates the input
        with ``ensure_univariate=True``.
    clone_estimator : bool, default=True
        When True, ``fit`` trains the estimator returned by
        ``_clone_estimator(classifier, random_state)``. When False, the object
        passed as ``classifier`` is fitted directly.
    random_state : optional, default=None
        Passed to ``_clone_estimator``; only read when ``clone_estimator`` is
        True.

    Attributes
    ----------
    classes_ : np.ndarray
        Unique values of ``y`` found during ``fit``.
    """

    def __init__(
        self,
        classifier=None,
        pad_unequal=False,
        concatenate_channels=False,
        clone_estimator=True,
        random_state=None,
    ):
        # Constructor arguments are stored unmodified under their own names.
        self.classifier = classifier
        self.pad_unequal = pad_unequal
        self.concatenate_channels = concatenate_channels
        self.clone_estimator = clone_estimator
        self.random_state = random_state

        super().__init__()

    def _convert_for_estimator(self, X):
        """Run ``_convert_X`` with this wrapper's padding/channel settings."""
        return self._convert_X(
            X,
            pad_unequal=self.pad_unequal,
            concatenate_channels=self.concatenate_channels,
        )

    def fit(self, X, y):
        """Validate and convert ``X``, then fit the wrapped classifier.

        Parameters
        ----------
        X : array-like
            Training data, validated by ``_validate_data``.
        y : array-like
            Class labels for ``X``.

        Returns
        -------
        self : SklearnToTsmlClassifier
            The fitted wrapper.

        Raises
        ------
        ValueError
            If no classifier was supplied.
        """
        if self.classifier is None:
            raise ValueError("Classifier not set")

        X, y = self._validate_data(
            X=X,
            y=y,
            ensure_univariate=not self.concatenate_channels,
            ensure_equal_length=not self.pad_unequal,
        )
        X = self._convert_for_estimator(X)

        check_classification_targets(y)
        self.classes_ = np.unique(y)

        # Store the estimator before fitting it, so the attribute exists even
        # if the wrapped ``fit`` raises.
        if self.clone_estimator:
            self._classifier = _clone_estimator(self.classifier, self.random_state)
        else:
            self._classifier = self.classifier
        self._classifier.fit(X, y)

        return self

    def predict(self, X) -> np.ndarray:
        """Return the wrapped classifier's ``predict`` output for ``X``."""
        check_is_fitted(self)

        X = self._validate_data(X=X, reset=False)
        X = self._convert_for_estimator(X)

        return self._classifier.predict(X)

    def predict_proba(self, X) -> np.ndarray:
        """Return the wrapped classifier's ``predict_proba`` output for ``X``."""
        check_is_fitted(self)

        X = self._validate_data(X=X, reset=False)
        X = self._convert_for_estimator(X)

        return self._classifier.predict_proba(X)

    def _more_tags(self):
        # The length and channel restrictions are lifted by the matching
        # constructor option.
        return {
            "X_types": ["2darray"],
            "equal_length_only": not self.pad_unequal,
            "univariate_only": not self.concatenate_channels,
        }

    @classmethod
    def get_test_params(cls, parameter_set: str | None = None) -> dict | list[dict]:
        """Return unit test parameter settings for the estimator.

        Parameters
        ----------
        parameter_set : None or str, default=None
            Name of the requested set of test parameters. It is not inspected
            here; the same parameters are returned for every value.

        Returns
        -------
        params : dict or list of dict
            Parameters to create testing instances of the class.
        """
        from sklearn.ensemble import RandomForestClassifier

        test_classifier = RandomForestClassifier(n_estimators=5)
        return {"classifier": test_classifier}
118+
119+
120+
class SklearnToTsmlClusterer(ClusterMixin, BaseTimeSeriesEstimator):
    """Expose an sklearn clusterer through the tsml estimator interface.

    Input is checked with the base class ``_validate_data``, converted with
    ``_convert_X`` and then handed to the wrapped clusterer.

    Parameters
    ----------
    clusterer : estimator, default=None
        The clusterer to wrap. ``fit`` raises a ``ValueError`` when this is
        left as None.
    pad_unequal : bool, default=False
        Forwarded to ``_convert_X``. When False, ``fit`` validates the input
        with ``ensure_equal_length=True``.
    concatenate_channels : bool, default=False
        Forwarded to ``_convert_X``. When False, ``fit`` validates the input
        with ``ensure_univariate=True``.
    clone_estimator : bool, default=True
        When True, ``fit`` trains the estimator returned by
        ``_clone_estimator(clusterer, random_state)``. When False, the object
        passed as ``clusterer`` is fitted directly.
    random_state : optional, default=None
        Passed to ``_clone_estimator``; only read when ``clone_estimator`` is
        True.

    Attributes
    ----------
    labels_ : array-like
        The ``labels_`` attribute of the wrapped clusterer after fitting.
    """

    def __init__(
        self,
        clusterer=None,
        pad_unequal=False,
        concatenate_channels=False,
        clone_estimator=True,
        random_state=None,
    ):
        # Constructor arguments are stored unmodified under their own names.
        self.clusterer = clusterer
        self.pad_unequal = pad_unequal
        self.concatenate_channels = concatenate_channels
        self.clone_estimator = clone_estimator
        self.random_state = random_state

        super().__init__()

    def _convert_for_estimator(self, X):
        """Run ``_convert_X`` with this wrapper's padding/channel settings."""
        return self._convert_X(
            X,
            pad_unequal=self.pad_unequal,
            concatenate_channels=self.concatenate_channels,
        )

    def fit(self, X, y=None):
        """Validate and convert ``X``, then fit the wrapped clusterer.

        Parameters
        ----------
        X : array-like
            Training data, validated by ``_validate_data``.
        y : optional, default=None
            Not validated here; passed straight to the wrapped ``fit``.

        Returns
        -------
        self : SklearnToTsmlClusterer
            The fitted wrapper.

        Raises
        ------
        ValueError
            If no clusterer was supplied.
        """
        if self.clusterer is None:
            raise ValueError("Clusterer not set")

        X = self._validate_data(
            X=X,
            ensure_univariate=not self.concatenate_channels,
            ensure_equal_length=not self.pad_unequal,
        )
        X = self._convert_for_estimator(X)

        # Store the estimator before fitting it, so the attribute exists even
        # if the wrapped ``fit`` raises.
        if self.clone_estimator:
            self._clusterer = _clone_estimator(self.clusterer, self.random_state)
        else:
            self._clusterer = self.clusterer
        self._clusterer.fit(X, y)

        # Surface the training assignments on the wrapper itself.
        self.labels_ = self._clusterer.labels_

        return self

    def predict(self, X) -> np.ndarray:
        """Return the wrapped clusterer's ``predict`` output for ``X``."""
        check_is_fitted(self)

        X = self._validate_data(X=X, reset=False)
        X = self._convert_for_estimator(X)

        return self._clusterer.predict(X)

    def _more_tags(self):
        # The length and channel restrictions are lifted by the matching
        # constructor option.
        return {
            "X_types": ["2darray"],
            "equal_length_only": not self.pad_unequal,
            "univariate_only": not self.concatenate_channels,
        }

    @classmethod
    def get_test_params(cls, parameter_set: str | None = None) -> dict | list[dict]:
        """Return unit test parameter settings for the estimator.

        Parameters
        ----------
        parameter_set : None or str, default=None
            Name of the requested set of test parameters. It is not inspected
            here; the same parameters are returned for every value.

        Returns
        -------
        params : dict or list of dict
            Parameters to create testing instances of the class.
        """
        from sklearn.cluster import KMeans

        test_clusterer = KMeans(n_clusters=2, max_iter=5)
        return {"clusterer": test_clusterer}
204+
205+
206+
class SklearnToTsmlRegressor(RegressorMixin, BaseTimeSeriesEstimator):
    """Expose an sklearn regressor through the tsml estimator interface.

    Input is checked with the base class ``_validate_data``, converted with
    ``_convert_X`` and then handed to the wrapped regressor.

    Parameters
    ----------
    regressor : estimator, default=None
        The regressor to wrap. ``fit`` raises a ``ValueError`` when this is
        left as None.
    pad_unequal : bool, default=False
        Forwarded to ``_convert_X``. When False, ``fit`` validates the input
        with ``ensure_equal_length=True``.
    concatenate_channels : bool, default=False
        Forwarded to ``_convert_X``. When False, ``fit`` validates the input
        with ``ensure_univariate=True``.
    clone_estimator : bool, default=True
        When True, ``fit`` trains the estimator returned by
        ``_clone_estimator(regressor, random_state)``. When False, the object
        passed as ``regressor`` is fitted directly.
    random_state : optional, default=None
        Passed to ``_clone_estimator``; only read when ``clone_estimator`` is
        True.
    """

    def __init__(
        self,
        regressor=None,
        pad_unequal=False,
        concatenate_channels=False,
        clone_estimator=True,
        random_state=None,
    ):
        # Constructor arguments are stored unmodified under their own names.
        self.regressor = regressor
        self.pad_unequal = pad_unequal
        self.concatenate_channels = concatenate_channels
        self.clone_estimator = clone_estimator
        self.random_state = random_state

        super().__init__()

    def _convert_for_estimator(self, X):
        """Run ``_convert_X`` with this wrapper's padding/channel settings."""
        return self._convert_X(
            X,
            pad_unequal=self.pad_unequal,
            concatenate_channels=self.concatenate_channels,
        )

    def fit(self, X, y):
        """Validate and convert ``X``, then fit the wrapped regressor.

        Parameters
        ----------
        X : array-like
            Training data, validated by ``_validate_data``.
        y : array-like
            Target values for ``X``.

        Returns
        -------
        self : SklearnToTsmlRegressor
            The fitted wrapper.

        Raises
        ------
        ValueError
            If no regressor was supplied.
        """
        if self.regressor is None:
            raise ValueError("Regressor not set")

        X, y = self._validate_data(
            X=X,
            y=y,
            ensure_univariate=not self.concatenate_channels,
            ensure_equal_length=not self.pad_unequal,
        )
        X = self._convert_for_estimator(X)

        # Store the estimator before fitting it, so the attribute exists even
        # if the wrapped ``fit`` raises.
        if self.clone_estimator:
            self._regressor = _clone_estimator(self.regressor, self.random_state)
        else:
            self._regressor = self.regressor
        self._regressor.fit(X, y)

        return self

    def predict(self, X) -> np.ndarray:
        """Return the wrapped regressor's ``predict`` output for ``X``."""
        check_is_fitted(self)

        X = self._validate_data(X=X, reset=False)
        X = self._convert_for_estimator(X)

        return self._regressor.predict(X)

    def _more_tags(self):
        # The length and channel restrictions are lifted by the matching
        # constructor option.
        return {
            "X_types": ["2darray"],
            "equal_length_only": not self.pad_unequal,
            "univariate_only": not self.concatenate_channels,
        }

    @classmethod
    def get_test_params(cls, parameter_set: str | None = None) -> dict | list[dict]:
        """Return unit test parameter settings for the estimator.

        Parameters
        ----------
        parameter_set : None or str, default=None
            Name of the requested set of test parameters. It is not inspected
            here; the same parameters are returned for every value.

        Returns
        -------
        params : dict or list of dict
            Parameters to create testing instances of the class.
        """
        from sklearn.ensemble import RandomForestRegressor

        test_regressor = RandomForestRegressor(n_estimators=5)
        return {"regressor": test_regressor}

tsml/tests/test_estimators_sklearn.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from sklearn.model_selection import train_test_split
1717
from sklearn.pipeline import make_pipeline
1818
from sklearn.preprocessing import scale
19-
from sklearn.utils._tags import _safe_tags as _safe_tags_sklearn
2019
from sklearn.utils._testing import (
2120
SkipTest,
2221
assert_allclose,
@@ -1410,7 +1409,7 @@ def check_estimator_get_tags_default_keys(name, estimator_orig):
14101409
if not hasattr(estimator, "_get_tags"):
14111410
return
14121411

1413-
default_tags_keys = set(_safe_tags_sklearn(estimator).keys())
1412+
default_tags_keys = set(_safe_tags(estimator).keys())
14141413
tags_keys = set(estimator._get_tags().keys())
14151414
assert tags_keys.intersection(default_tags_keys) == default_tags_keys, (
14161415
f"{name}._get_tags() is missing entries for the following default tags: "

tsml/tests/test_sklearn_compatability.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""Unit tests for aeon classifier compatability with sklearn interfaces."""
1+
"""Unit tests for tsml classifier compatibility with sklearn interfaces."""
22

33
__maintainer__ = []
44
__all__ = [

0 commit comments

Comments
 (0)