|
| 1 | +"""Registry lookup methods. |
| 2 | +
|
| 3 | +This module exports the following methods for registry lookup: |
| 4 | +
|
| 5 | +all_objects(object_types, filter_tags) |
| 6 | + lookup and filtering of objects |
| 7 | +""" |
| 8 | + |
| 9 | +# based on the sktime module of same name |
| 10 | + |
| 11 | +__author__ = ["fkiraly"] |
| 12 | +# all_objects is based on the sklearn utility all_estimators |
| 13 | + |
| 14 | +from inspect import isclass |
| 15 | +from pathlib import Path |
| 16 | + |
| 17 | +from skbase.lookup import all_objects as _all_objects |
| 18 | + |
| 19 | +from pytorch_forecasting.models.base import _BaseObject |
| 20 | + |
| 21 | + |
def all_objects(
    object_types=None,
    filter_tags=None,
    exclude_objects=None,
    return_names=True,
    as_dataframe=False,
    return_tags=None,
    suppress_import_stdout=True,
):
    """Get a list of all objects from pytorch_forecasting.

    This function crawls the module and gets all classes that inherit
    from skbase compatible base classes.

    Not included are: the base classes themselves, classes defined in test
    modules.

    Parameters
    ----------
    object_types: str, class, or list of str or class, optional (default=None)
        Which kind of objects should be returned.

        * if None, no filter is applied and all objects are returned.
        * if str or list of str, strings define scitypes specified in search
          only objects that are of (at least) one of the scitypes are returned
        * if class or list of classes, each class is replaced by the value of
          its ``"object_type"`` class tag, then treated as above

        The resulting scitype strings are applied as a filter on the
        ``"object_type"`` tag. If ``filter_tags`` already contains an
        ``"object_type"`` entry, the two sets of values are merged.

    return_names: bool, optional (default=True)

        * if True, estimator class name is included in the ``all_objects``
          return in the order: name, estimator class, optional tags, either as
          a tuple or as pandas.DataFrame columns
        * if False, estimator class name is removed from the ``all_objects``
          return.

    filter_tags: dict of (str or list of str or re.Pattern), optional (default=None)
        For a list of valid tag strings, use the registry.all_tags utility.

        ``filter_tags`` subsets the returned objects as follows:

        * each key/value pair is statement in "and"/conjunction
        * key is tag name to sub-set on
        * value str or list of string are tag values
        * condition is "key must be equal to value, or in set(value)"

        In detail, the return will be filtered to keep exactly the classes
        where tags satisfy all the filter conditions specified by ``filter_tags``.
        Filter conditions are as follows, for ``tag_name: search_value`` pairs in
        the ``filter_tags`` dict, applied to a class ``klass``:

        - If ``klass`` does not have a tag with name ``tag_name``, it is excluded.
          Otherwise, let ``tag_value`` be the value of the tag with name
          ``tag_name``.
        - If ``search_value`` is a string, and ``tag_value`` is a string,
          the filter condition is that ``search_value`` must match the tag value.
        - If ``search_value`` is a string, and ``tag_value`` is a list,
          the filter condition is that ``search_value`` is contained in
          ``tag_value``.
        - If ``search_value`` is a ``re.Pattern``, and ``tag_value`` is a string,
          the filter condition is that ``search_value.fullmatch(tag_value)``
          is true, i.e., the regex matches the tag value.
        - If ``search_value`` is a ``re.Pattern``, and ``tag_value`` is a list,
          the filter condition is that at least one element of ``tag_value``
          matches the regex.
        - If ``search_value`` is iterable, then the filter condition is that
          at least one element of ``search_value`` satisfies the above conditions,
          applied to ``tag_value``.

        Note: ``re.Pattern`` is supported only from ``scikit-base`` version 0.8.0.

    exclude_objects: str, list of str, optional (default=None)
        Names of objects to exclude.

    as_dataframe: bool, optional (default=False)

        * True: ``all_objects`` will return a ``pandas.DataFrame`` with named
          columns for all of the attributes being returned.
        * False: ``all_objects`` will return a list (either a list of
          objects or a list of tuples, see Returns)

    return_tags: str or list of str, optional (default=None)
        Names of tags to fetch and return each estimator's value of.
        For a list of valid tag strings, use the ``registry.all_tags`` utility.
        if str or list of str,
        the tag values named in return_tags will be fetched for each
        estimator and will be appended as either columns or tuple entries.

    suppress_import_stdout : bool, optional. Default=True
        whether to suppress stdout printout upon import.

    Returns
    -------
    all_objects will return one of the following:

        1. list of objects, if ``return_names=False``, and ``return_tags`` is None

        2. list of tuples (optional estimator name, class, optional estimator
           tags), if ``return_names=True`` or ``return_tags`` is not ``None``.

        3. ``pandas.DataFrame`` if ``as_dataframe = True``

        if list of objects:
            entries are objects matching the query,
            in alphabetical order of estimator name

        if list of tuples:
            list of (optional estimator name, estimator, optional estimator
            tags) matching the query, in alphabetical order of estimator name,
            where
            ``name`` is the estimator name as string, and is an
            optional return
            ``estimator`` is the actual estimator
            ``tags`` are the estimator's values for each tag in return_tags
            and is an optional return.

        if ``DataFrame``:
            column names represent the attributes contained in each column.
            "objects" will be the name of the column of objects, "names"
            will be the name of the column of estimator class names and the
            string(s) passed in return_tags will serve as column names for all
            columns of tags that were optionally requested.

    Examples
    --------
    >>> from pytorch_forecasting._registry import all_objects
    >>> # return a complete list of objects as pd.Dataframe
    >>> all_objects(as_dataframe=True)  # doctest: +SKIP

    References
    ----------
    Adapted version of sktime's ``all_estimators``,
    which is an evolution of scikit-learn's ``all_estimators``
    """
    MODULES_TO_IGNORE = (
        "tests",
        "setup",
        "contrib",
        "utils",
        "all",
    )

    ROOT = str(Path(__file__).parent.parent)  # package root directory

    def _coerce_to_str(obj):
        """Replace classes by their "object_type" tag value, recursing into lists."""
        if isinstance(obj, (list, tuple)):
            return [_coerce_to_str(o) for o in obj]
        if isclass(obj):
            # get_tag is an instance method; on a class object the tag has to
            # be read through the classmethod get_class_tag
            obj = obj.get_class_tag("object_type")
        return obj

    def _flatten(obj):
        """Return a flat list of the leaves of an arbitrarily nested list."""
        if not isinstance(obj, list):
            return [obj]
        return [leaf for entry in obj for leaf in _flatten(entry)]

    def _coerce_to_list_of_str(obj):
        """Coerce a str, class, or (nested) list thereof to a flat list."""
        # a class tag may itself be a list of scitypes, so a list of classes
        # can coerce to a nested list - flatten to keep entries hashable
        return _flatten(_coerce_to_str(obj))

    if object_types is not None:
        # normalize to a de-duplicated list of scitype strings
        object_types = list(set(_coerce_to_list_of_str(object_types)))

        # work on a copy so the caller's filter_tags dict is never mutated
        if filter_tags is None:
            filter_tags = {}
        elif isinstance(filter_tags, str):
            filter_tags = {filter_tags: True}
        else:
            filter_tags = filter_tags.copy()

        # object_types is implemented as a filter on the "object_type" tag;
        # merge with any "object_type" condition the caller already passed
        if "object_type" in filter_tags:
            obj_field = _coerce_to_list_of_str(filter_tags["object_type"])
            obj_field = obj_field + object_types
        else:
            obj_field = object_types

        filter_tags["object_type"] = obj_field

    return _all_objects(
        object_types=[_BaseObject],
        filter_tags=filter_tags,
        exclude_objects=exclude_objects,
        return_names=return_names,
        as_dataframe=as_dataframe,
        return_tags=return_tags,
        suppress_import_stdout=suppress_import_stdout,
        package_name="pytorch_forecasting",
        path=ROOT,
        modules_to_ignore=MODULES_TO_IGNORE,
    )
| 210 | + |
| 211 | + |
| 212 | +def _check_list_of_str_or_error(arg_to_check, arg_name): |
| 213 | + """Check that certain arguments are str or list of str. |
| 214 | +
|
| 215 | + Parameters |
| 216 | + ---------- |
| 217 | + arg_to_check: argument we are testing the type of |
| 218 | + arg_name: str, |
| 219 | + name of the argument we are testing, will be added to the error if |
| 220 | + ``arg_to_check`` is not a str or a list of str |
| 221 | +
|
| 222 | + Returns |
| 223 | + ------- |
| 224 | + arg_to_check: list of str, |
| 225 | + if arg_to_check was originally a str it converts it into a list of str |
| 226 | + so that it can be iterated over. |
| 227 | +
|
| 228 | + Raises |
| 229 | + ------ |
| 230 | + TypeError if arg_to_check is not a str or list of str |
| 231 | + """ |
| 232 | + # check that return_tags has the right type: |
| 233 | + if isinstance(arg_to_check, str): |
| 234 | + arg_to_check = [arg_to_check] |
| 235 | + if not isinstance(arg_to_check, list) or not all( |
| 236 | + isinstance(value, str) for value in arg_to_check |
| 237 | + ): |
| 238 | + raise TypeError( |
| 239 | + f"Error in all_objects! Argument {arg_name} must be either\ |
| 240 | + a str or list of str" |
| 241 | + ) |
| 242 | + return arg_to_check |
0 commit comments