Skip to content

Commit cdb8d2c

Browse files
fixes
1 parent 283663d commit cdb8d2c

File tree

5 files changed

+22
-127
lines changed

5 files changed

+22
-127
lines changed

tsml/compose/_channel_ensemble.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -204,17 +204,17 @@ class ChannelEnsembleClassifier(ClassifierMixin, _BaseChannelEnsemble):
204204
Examples
205205
--------
206206
>>> from tsml.compose import ChannelEnsembleClassifier
207-
>>> from tsml.interval_based import IntervalForestClassifier
207+
>>> from tsml.dummy import DummyClassifier
208208
>>> from tsml.utils.testing import generate_3d_test_data
209209
>>> X, y = generate_3d_test_data(n_samples=8, series_length=10, random_state=0)
210210
>>> reg = ChannelEnsembleClassifier(
211-
... estimators=("tsf", IntervalForestClassifier(n_estimators=2), "all-split"),
211+
... estimators=("d", DummyClassifier(), "all-split"),
212212
... random_state=0,
213213
... )
214214
>>> reg.fit(X, y)
215215
ChannelEnsembleClassifier(...)
216216
>>> reg.predict(X)
217-
array([0, 1, 1, 0, 0, 1, 0, 1])
217+
array([0, 0, 0, 0, 0, 0, 0, 0])
218218
"""
219219

220220
def __init__(self, estimators, remainder="drop", random_state=None):
@@ -349,12 +349,12 @@ def get_test_params(
349349
params : dict or list of dict
350350
Parameters to create testing instances of the class.
351351
"""
352-
from tsml.interval_based import IntervalForestClassifier
352+
from tsml.dummy import DummyClassifier
353353

354354
return {
355355
"estimators": [
356-
("tsf1", IntervalForestClassifier(n_estimators=2), 0),
357-
("tsf2", IntervalForestClassifier(n_estimators=2), 0),
356+
("d1", DummyClassifier(), 0),
357+
("d2", DummyClassifier(), 0),
358358
]
359359
}
360360

@@ -411,19 +411,19 @@ class ChannelEnsembleRegressor(RegressorMixin, _BaseChannelEnsemble):
411411
Examples
412412
--------
413413
>>> from tsml.compose import ChannelEnsembleRegressor
414-
>>> from tsml.interval_based import IntervalForestRegressor
414+
>>> from tsml.dummy import DummyRegressor
415415
>>> from tsml.utils.testing import generate_3d_test_data
416416
>>> X, y = generate_3d_test_data(n_samples=8, series_length=10,
417417
... regression_target=True, random_state=0)
418418
>>> reg = ChannelEnsembleRegressor(
419-
... estimators=("tsf", IntervalForestRegressor(n_estimators=2), "all-split"),
419+
... estimators=("d", DummyRegressor(), "all-split"),
420420
... random_state=0,
421421
... )
422422
>>> reg.fit(X, y)
423423
ChannelEnsembleRegressor(...)
424424
>>> reg.predict(X)
425-
array([0.31798318, 1.41426301, 1.06414747, 0.6924721 , 0.56660146,
426-
1.26538944, 0.52324808, 1.0939405 ])
425+
array([0.8672557, 0.8672557, 0.8672557, 0.8672557, 0.8672557, 0.8672557,
426+
0.8672557, 0.8672557], dtype=float32)
427427
"""
428428

429429
def __init__(self, estimators, remainder="drop", random_state=None):
@@ -518,12 +518,12 @@ def get_test_params(
518518
params : dict or list of dict
519519
Parameters to create testing instances of the class.
520520
"""
521-
from tsml.interval_based import IntervalForestRegressor
521+
from tsml.dummy import DummyRegressor
522522

523523
return {
524524
"estimators": [
525-
("tsf1", IntervalForestRegressor(n_estimators=2), 0),
526-
("tsf2", IntervalForestRegressor(n_estimators=2), 0),
525+
("d1", DummyRegressor(), 0),
526+
("d2", DummyRegressor(), 0),
527527
]
528528
}
529529

tsml/compose/tests/test_channel_ensemble.py

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9,25 +9,21 @@
99
_check_key_type,
1010
_get_channel,
1111
)
12-
from tsml.interval_based import IntervalForestClassifier, IntervalForestRegressor
12+
from tsml.dummy import DummyClassifier, DummyRegressor
1313
from tsml.utils.testing import generate_3d_test_data, generate_unequal_test_data
1414

1515

1616
def test_single_estimator():
1717
"""Test that a single estimator is correctly applied to all channels."""
1818
X, y = generate_3d_test_data(n_channels=3)
1919

20-
ens = ChannelEnsembleClassifier(
21-
estimators=[("tsf", IntervalForestClassifier(n_estimators=2), "all")]
22-
)
20+
ens = ChannelEnsembleClassifier(estimators=[("d", DummyClassifier(), "all")])
2321
ens.fit(X, y)
2422

2523
assert len(ens.estimators_[0][2]) == 3
2624
assert ens.predict(X).shape == (X.shape[0],)
2725

28-
ens = ChannelEnsembleRegressor(
29-
estimators=[("tsf", IntervalForestRegressor(n_estimators=2), "all")]
30-
)
26+
ens = ChannelEnsembleRegressor(estimators=[("d", DummyRegressor(), "all")])
3127
ens.fit(X, y)
3228

3329
assert len(ens.estimators_[0][2]) == 3
@@ -38,18 +34,14 @@ def test_single_estimator_split():
3834
"""Test that a single split estimator correctly creates an estimator per channel."""
3935
X, y = generate_3d_test_data(n_channels=3)
4036

41-
ens = ChannelEnsembleClassifier(
42-
estimators=("tsf", IntervalForestClassifier(n_estimators=2), "all-split")
43-
)
37+
ens = ChannelEnsembleClassifier(estimators=("d", DummyClassifier(), "all-split"))
4438
ens.fit(X, y)
4539

4640
assert len(ens.estimators_) == 3
4741
assert isinstance(ens.estimators_[0][2], int)
4842
assert ens.predict(X).shape == (X.shape[0],)
4943

50-
ens = ChannelEnsembleRegressor(
51-
estimators=("tsf", IntervalForestRegressor(n_estimators=2), "all-split")
52-
)
44+
ens = ChannelEnsembleRegressor(estimators=("d", DummyRegressor(), "all-split"))
5345
ens.fit(X, y)
5446

5547
assert len(ens.estimators_) == 3
@@ -62,17 +54,17 @@ def test_remainder():
6254
X, y = generate_3d_test_data(n_channels=3)
6355

6456
ens = ChannelEnsembleClassifier(
65-
estimators=[("tsf", IntervalForestClassifier(n_estimators=2), 0)],
66-
remainder=IntervalForestClassifier(n_estimators=2),
57+
estimators=[("d", DummyClassifier(), 0)],
58+
remainder=DummyClassifier(),
6759
)
6860
ens.fit(X, y)
6961

7062
assert len(ens._remainder[2]) == 2
7163
assert ens.predict(X).shape == (X.shape[0],)
7264

7365
ens = ChannelEnsembleRegressor(
74-
estimators=[("tsf", IntervalForestRegressor(n_estimators=2), 0)],
75-
remainder=IntervalForestRegressor(n_estimators=2),
66+
estimators=[("d", DummyRegressor(), 0)],
67+
remainder=DummyRegressor(),
7668
)
7769
ens.fit(X, y)
7870

tsml/transformations/tests/test_interval_extraction.py

Lines changed: 0 additions & 66 deletions
This file was deleted.

tsml/transformations/tests/test_periodogram.py

Lines changed: 0 additions & 20 deletions
This file was deleted.

tsml/utils/testing.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -96,17 +96,6 @@ def parametrize_with_checks(estimators: List[BaseEstimator]) -> Callable:
9696
See Also
9797
--------
9898
check_estimator : Check if estimator adheres to tsml or scikit-learn conventions.
99-
100-
Examples
101-
--------
102-
>>> from tsml.utils.testing import parametrize_with_checks
103-
>>> from tsml.interval_based import IntervalForestRegressor
104-
>>> from tsml.vector import RotationForestClassifier
105-
>>> @parametrize_with_checks(
106-
... [IntervalForestRegressor(), RotationForestClassifier()]
107-
... )
108-
... def test_tsml_compatible_estimator(estimator, check):
109-
... check(estimator)
11099
"""
111100
import pytest
112101

0 commit comments

Comments
 (0)