Skip to content

Commit aa3bc9d

Browse files
authored
[ENH] test suite for pytorch-forecasting forecasters (#1780)
This PR adds a systematic test suite and a `check_estimator` utility for `pytorch-forecasting` forecasters. The interface checked is the current unified API across models. This may change in the future, but no changes to the API are made. Work in progress and partial - merged to prevent too high PR stacks in the v2 development.
1 parent 5d57319 commit aa3bc9d

File tree

14 files changed

+1329
-4
lines changed

14 files changed

+1329
-4
lines changed

.github/workflows/test.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ jobs:
112112

113113
- name: Run pytest
114114
shell: bash
115-
run: python -m pytest tests
115+
run: python -m pytest
116116

117117
pytest:
118118
name: Run pytest
@@ -152,7 +152,7 @@ jobs:
152152

153153
- name: Run pytest
154154
shell: bash
155-
run: python -m pytest tests
155+
run: python -m pytest
156156

157157
- name: Statistics
158158
run: |

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ dev = [
102102
"pytest-dotenv>=0.5.2,<1.0.0",
103103
"tensorboard>=2.12.1,<3.0.0",
104104
"pandoc>=2.3,<3.0.0",
105+
"scikit-base",
105106
]
106107

107108
# docs - dependencies for building the documentation

pytest.ini

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@ addopts =
1010
--no-cov-on-fail
1111

1212
markers =
13-
testpaths = tests/
13+
testpaths =
14+
tests/
15+
pytorch_forecasting/tests/
1416
log_cli_level = ERROR
1517
log_format = %(asctime)s %(levelname)s %(message)s
1618
log_date_format = %Y-%m-%d %H:%M:%S
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
"""PyTorch Forecasting registry."""
2+
3+
from pytorch_forecasting._registry._lookup import all_objects
4+
5+
__all__ = ["all_objects"]
Lines changed: 242 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,242 @@
1+
"""Registry lookup methods.
2+
3+
This module exports the following methods for registry lookup:
4+
5+
all_objects(object_types, filter_tags)
6+
lookup and filtering of objects
7+
"""
8+
9+
# based on the sktime module of same name
10+
11+
__author__ = ["fkiraly"]
12+
# all_objects is based on the sklearn utility all_estimators
13+
14+
from inspect import isclass
15+
from pathlib import Path
16+
17+
from skbase.lookup import all_objects as _all_objects
18+
19+
from pytorch_forecasting.models.base import _BaseObject
20+
21+
22+
def all_objects(
    object_types=None,
    filter_tags=None,
    exclude_objects=None,
    return_names=True,
    as_dataframe=False,
    return_tags=None,
    suppress_import_stdout=True,
):
    """Get a list of all objects from pytorch_forecasting.

    This function crawls the package and collects all classes that inherit
    from ``_BaseObject``, delegating the actual lookup to
    ``skbase.lookup.all_objects``.

    Modules named ``tests``, ``setup``, ``contrib``, ``utils`` and ``all``
    are not crawled.

    Parameters
    ----------
    object_types: str, class, list of str or class, optional (default=None)
        Which kind of objects should be returned.

        * if None, no filter is applied and all objects are returned.
        * if str or list of str, strings define scitypes specified in search;
          only objects that are of (at least) one of the scitypes are returned.
        * classes are replaced by the value of their ``"object_type"`` class tag.

        Internally, this is translated into a condition on the
        ``"object_type"`` tag and merged into ``filter_tags``.

    filter_tags: dict of (str or list of str or re.Pattern), optional (default=None)
        ``filter_tags`` subsets the returned objects as follows:

        * each key/value pair is a statement in "and"/conjunction
        * key is tag name to sub-set on
        * value str or list of string are tag values
        * condition is "key must be equal to value, or in set(value)"

        In detail, the return will be filtered to keep exactly the classes
        where tags satisfy all the filter conditions specified by ``filter_tags``.
        Filter conditions are as follows, for ``tag_name: search_value`` pairs in
        the ``filter_tags`` dict, applied to a class ``klass``:

        - If ``klass`` does not have a tag with name ``tag_name``, it is excluded.
          Otherwise, let ``tag_value`` be the value of the tag with name ``tag_name``.
        - If ``search_value`` is a string, and ``tag_value`` is a string,
          the filter condition is that ``search_value`` must match the tag value.
        - If ``search_value`` is a string, and ``tag_value`` is a list,
          the filter condition is that ``search_value`` is contained in ``tag_value``.
        - If ``search_value`` is a ``re.Pattern``, and ``tag_value`` is a string,
          the filter condition is that ``search_value.fullmatch(tag_value)``
          is true, i.e., the regex matches the tag value.
        - If ``search_value`` is a ``re.Pattern``, and ``tag_value`` is a list,
          the filter condition is that at least one element of ``tag_value``
          matches the regex.
        - If ``search_value`` is iterable, then the filter condition is that
          at least one element of ``search_value`` satisfies the above conditions,
          applied to ``tag_value``.

        Note: ``re.Pattern`` is supported only from ``scikit-base`` version 0.8.0.

        The passed dict is never mutated; a copy is modified when
        ``object_types`` is given.

    exclude_objects: str, list of str, optional (default=None)
        Names of objects to exclude.

    return_names: bool, optional (default=True)

        * if True, estimator class name is included in the ``all_objects``
          return in the order: name, estimator class, optional tags, either as
          a tuple or as pandas.DataFrame columns
        * if False, estimator class name is removed from the ``all_objects`` return.

    as_dataframe: bool, optional (default=False)

        * True: ``all_objects`` will return a ``pandas.DataFrame`` with named
          columns for all of the attributes being returned.
        * False: ``all_objects`` will return a list (either a list of
          objects or a list of tuples, see Returns)

    return_tags: str or list of str, optional (default=None)
        Names of tags to fetch and return each estimator's value of.
        if str or list of str,
        the tag values named in return_tags will be fetched for each
        estimator and will be appended as either columns or tuple entries.

    suppress_import_stdout : bool, optional. Default=True
        whether to suppress stdout printout upon import.

    Returns
    -------
    all_objects will return one of the following:

        1. list of objects, if ``return_names=False``, and ``return_tags`` is None

        2. list of tuples (optional estimator name, class, optional estimator
           tags), if ``return_names=True`` or ``return_tags`` is not ``None``.

        3. ``pandas.DataFrame`` if ``as_dataframe = True``

        if list of objects:
            entries are objects matching the query,
            in alphabetical order of estimator name

        if list of tuples:
            list of (optional estimator name, estimator, optional estimator
            tags) matching the query, in alphabetical order of estimator name,
            where
            ``name`` is the estimator name as string, and is an
            optional return
            ``estimator`` is the actual estimator
            ``tags`` are the estimator's values for each tag in return_tags
            and is an optional return.

        if ``DataFrame``:
            column names represent the attributes contained in each column.
            "objects" will be the name of the column of objects, "names"
            will be the name of the column of estimator class names and the string(s)
            passed in return_tags will serve as column names for all columns of
            tags that were optionally requested.

    Examples
    --------
    >>> from pytorch_forecasting._registry import all_objects
    >>> # return a complete list of objects as pd.Dataframe
    >>> all_objects(as_dataframe=True)  # doctest: +SKIP

    References
    ----------
    Adapted version of sktime's ``all_estimators``,
    which is an evolution of scikit-learn's ``all_estimators``
    """
    MODULES_TO_IGNORE = (
        "tests",
        "setup",
        "contrib",
        "utils",
        "all",
    )

    ROOT = str(Path(__file__).parent.parent)  # package root directory

    def _coerce_to_str(obj):
        """Replace classes by their "object_type" tag, recursing into lists."""
        if isinstance(obj, (list, tuple)):
            return [_coerce_to_str(o) for o in obj]
        if isclass(obj):
            # get_tag is an instance method and cannot be called on the class
            # itself; get_class_tag is the classmethod for class-level lookup
            obj = obj.get_class_tag("object_type")
        return obj

    def _coerce_to_list_of_str(obj):
        """Coerce a scalar or list/tuple query value to a list."""
        obj = _coerce_to_str(obj)
        if isinstance(obj, list):
            return obj
        # scalars (str, re.Pattern, ...) are wrapped so that the set and
        # list operations below are always applied to a list
        return [obj]

    if object_types is not None:
        # normalize and de-duplicate the requested object types
        object_types = list(set(_coerce_to_list_of_str(object_types)))

        # work on a copy so that the caller's filter_tags is never mutated
        if filter_tags is None:
            filter_tags = {}
        elif isinstance(filter_tags, str):
            filter_tags = {filter_tags: True}
        else:
            filter_tags = filter_tags.copy()

        # object_types is expressed as a condition on the "object_type" tag,
        # combined with any such condition already present in filter_tags
        if "object_type" in filter_tags:
            obj_field = _coerce_to_list_of_str(filter_tags["object_type"])
            obj_field = obj_field + object_types
        else:
            obj_field = object_types

        filter_tags["object_type"] = obj_field

    return _all_objects(
        object_types=[_BaseObject],
        filter_tags=filter_tags,
        exclude_objects=exclude_objects,
        return_names=return_names,
        as_dataframe=as_dataframe,
        return_tags=return_tags,
        suppress_import_stdout=suppress_import_stdout,
        package_name="pytorch_forecasting",
        path=ROOT,
        modules_to_ignore=MODULES_TO_IGNORE,
    )
210+
211+
212+
def _check_list_of_str_or_error(arg_to_check, arg_name):
213+
"""Check that certain arguments are str or list of str.
214+
215+
Parameters
216+
----------
217+
arg_to_check: argument we are testing the type of
218+
arg_name: str,
219+
name of the argument we are testing, will be added to the error if
220+
``arg_to_check`` is not a str or a list of str
221+
222+
Returns
223+
-------
224+
arg_to_check: list of str,
225+
if arg_to_check was originally a str it converts it into a list of str
226+
so that it can be iterated over.
227+
228+
Raises
229+
------
230+
TypeError if arg_to_check is not a str or list of str
231+
"""
232+
# check that return_tags has the right type:
233+
if isinstance(arg_to_check, str):
234+
arg_to_check = [arg_to_check]
235+
if not isinstance(arg_to_check, list) or not all(
236+
isinstance(value, str) for value in arg_to_check
237+
):
238+
raise TypeError(
239+
f"Error in all_objects! Argument {arg_name} must be either\
240+
a str or list of str"
241+
)
242+
return arg_to_check

pytorch_forecasting/models/base/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,14 @@
77
BaseModelWithCovariates,
88
Prediction,
99
)
10+
from pytorch_forecasting.models.base._base_object import (
11+
_BaseObject,
12+
_BasePtForecaster,
13+
)
1014

1115
__all__ = [
16+
"_BaseObject",
17+
"_BasePtForecaster",
1218
"AutoRegressiveBaseModel",
1319
"AutoRegressiveBaseModelWithCovariates",
1420
"BaseModel",

0 commit comments

Comments
 (0)