-
Notifications
You must be signed in to change notification settings - Fork 289
PySRSequenceRegressor #677
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
wenbang24
wants to merge
198
commits into
MilesCranmer:master
Choose a base branch
from
wenbang24:recurrence
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
198 commits
Select commit
Hold shift + click to select a range
b9865b0
Added recursive_history_length
wenbang24 1ce85fe
Changed minimum recursive history length
wenbang24 f46e4d2
fixed syntax error
wenbang24 0cbdcda
Added recurrence functionality
wenbang24 7172cc7
Removed a debug print, also formatted
wenbang24 ff1944c
new PySRSequenceRegressor class!
wenbang24 4072729
made recursive_history_length not optional
wenbang24 8ab384a
added tests for PySRSequenceRegressor
wenbang24 cffcc91
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] e91dc68
changed __init__ of PySRSequenceRegressor to use PySRRegressor's __in…
wenbang24 f0d6ecf
fixed bug that removed first data point
wenbang24 7db3df8
added sequence to test names to make things a bit clearer
wenbang24 506f4f5
added .eggs to .gitignore
wenbang24 d556d11
Merge branch 'recurrence' into wenbang24/issue94
wenbang24 fc86554
Merge pull request #1 from wenbang24/wenbang24/issue94
wenbang24 962e0a9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 87ad4d9
updated docstring for PySRSequenceRegressor
wenbang24 75ea04d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 365e663
updated fit()
wenbang24 ed61b3c
multidimensionality!!!
wenbang24 f10b6ca
fixed variable names for multidimensionality
wenbang24 518e7d8
fixed variable names
wenbang24 510d5d0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] cbc9105
git is hard
wenbang24 8f0c730
added new preprocessing to predict
wenbang24 c3d5aa5
Merge branch 'recurrence' of github.com:wenbang24/PySR into recurrence
wenbang24 908fdfc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 666b2a8
removed unecessary test
wenbang24 f78cb44
small documentaion change
wenbang24 d8b8245
didn't the last commit work?
wenbang24 c6b67d6
another small doc change
wenbang24 a4f607e
ok the preprocessing ACTUALLY works now
wenbang24 9927660
fixed custom variable names
wenbang24 de74c47
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 62cf992
updated tests
wenbang24 27c73e7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] b3e303e
all tests passing!!!
wenbang24 92f2a40
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 94eabe3
refactor: moved PySRSequenceRegressor to ssr.py
wenbang24 aca4672
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 089c565
slight change in multidimensional data error test
wenbang24 b47172c
made the type checking work, added tests for unused variables
wenbang24 a3c63c8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] ba1ce25
yeah the override decorator wasn't needed
wenbang24 a9019a7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 7bdd41a
this test is too harsh
wenbang24 a88e725
swap super.__init__ and other thing in __init__
wenbang24 aa42888
now needs numpy 1.20
wenbang24 307261b
changed predict docstring
wenbang24 14331e1
change test sequence name
wenbang24 9eb54b3
change multidimensional data error test
wenbang24 df26237
renamed ssr.py to regressor_sequence.py
wenbang24 bca7cc9
moved assertions to new function, fixed error in variable name genera…
wenbang24 34e449b
removed need for temp
wenbang24 38edd3d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 96ea494
changed variable names and made target generation a bit more efficient
wenbang24 1a29161
made predict use _check_assertions as well
wenbang24 854edc7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 5213cc8
moved variable name generation to new function and also added variabl…
wenbang24 925783c
remove unnecessary test
wenbang24 1a45b90
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] d0ff80d
made unused variables throw errors
wenbang24 b449772
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 3dc1fd4
changed check_assertions to have no return value
wenbang24 709ff56
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] ff126c2
removed unecessary tests
wenbang24 484c836
changed up variable names
wenbang24 518df7d
fixed up check_assertions in predict
wenbang24 31c6bb6
fixed bug in variable name generation
wenbang24 12d3e82
fixed bug in assertion checking in predict
wenbang24 edab12f
changed up variable names
wenbang24 0354818
changed name of variable name generator
wenbang24 03beb0e
changed variable name generator to take n_features
wenbang24 8e4b664
Updated docstring
wenbang24 e170496
added validation of X
wenbang24 10f26c7
added another validation
wenbang24 ff75b07
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] e5ed5c5
added doc string for PySRSymbolicRegressor
wenbang24 be64ee3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] ceb58ff
changed variable name generation function to isinstance rather than t…
wenbang24 e3d19f2
think i fixed predicting shape
wenbang24 9721b8f
changed PySRSequenceRegressor to inherit from BaseEstimator and have …
wenbang24 eb796c7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 8549d8c
padding with NaNs does not work
wenbang24 a450abb
made extrapolating when prredicting work :)
wenbang24 15be253
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 4656585
updated dosctring to have extra_preditciotns
wenbang24 12ba70f
tried delegation but this doesn't work
wenbang24 c0ec717
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] de14be6
forwarded all methods and properties from PySRRegressor to PySRSequec…
wenbang24 4c362f9
finally fixed the bug in variable name test
wenbang24 433f783
removed unecessary variables
wenbang24 ef26021
updated docstring
wenbang24 8dcd625
added __getstate__ to PySRSequenceRegressor
wenbang24 7846199
update docstring
wenbang24 b3a45e8
added super().__init()
wenbang24 1eda1a3
removed n_features_in
wenbang24 a2d2821
example.py is back
wenbang24 feab3e4
variable names for unnamed 1D sequences work now
wenbang24 bad23cb
latex table no longer says y, but xt_0 (or whatever the variable name…
wenbang24 57233b3
added docstring for complexity of variables
wenbang24 a35e136
grammar is hard
wenbang24 b037e86
remove unused imports
wenbang24 d08b77b
uncomment tests
wenbang24 d7f15f8
newlines??
wenbang24 cd04146
put a comment back
wenbang24 64d30b8
new inherits
wenbang24 164a399
remove y units
wenbang24 6ec0d2a
remove show pickel warnings
wenbang24 208de79
removed a lot of unneccesary properties
wenbang24 b52a221
removed __getstate__
wenbang24 9c6326c
removed some commmented out code
wenbang24 59555eb
since recursive history length >= 1, we don't need to test for len(X)…
wenbang24 ba3b538
more tests
wenbang24 4de8ec9
update dosctring
wenbang24 5bd6898
more tests
wenbang24 5db1096
more tests
wenbang24 029633a
Merge branch 'master' into recurrence
wenbang24 2341c74
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 7ae96a2
ok i think from_file works now
wenbang24 1f9013f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 427ac6b
so i think we need to pickle recursive history length :(
wenbang24 263912e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 1b79dd9
remove debug print
wenbang24 16cad6e
ok fromfile actually works now
wenbang24 689d717
added weight test
wenbang24 4a1f17c
added feature name and selection mask tests
wenbang24 1ef2fc6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 1fe3fe3
add recursive history length parameter
wenbang24 39cab90
fixed typing for variable names in fit
wenbang24 a734b48
removed julia properties
wenbang24 bff3ef3
changed from_file to use PySRRegressor's from_file
wenbang24 4c5e9e9
removed uncessary line
wenbang24 0459cf4
moved args???
wenbang24 424483b
a lot of **kwargs
wenbang24 3b1acb2
removed super().__init__()
wenbang24 06ee369
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 9b5a2c0
updated numpy to 1.20 in environment.yml
wenbang24 0e55024
changed __repr__ to only change first instance
wenbang24 1632e89
updated latex_table to use PySRRegressor's latex_table, also added ou…
wenbang24 43dd918
removed MultiOutputMixin, RegressorMixin
wenbang24 6e7166d
fixed bug with output_variable_names
wenbang24 e77d56f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 1bd482c
Merge branch 'master' into recurrence
wenbang24 fa1e2f2
add back super().__init__()
wenbang24 a569b2e
stars and stuff
wenbang24 8ab243a
another star
wenbang24 9010fbe
remove unused imports
wenbang24 f395db4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] a745b19
refactor: moved np.lib.stride_tricks.sliding_window_view out to anoth…
wenbang24 102e453
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] e75a8bc
Merge branch 'master' into recurrence
wenbang24 ef03dff
update doctoring
wenbang24 8e29d92
update docstring for weights
wenbang24 5aff022
rewrote extra predictions in predict() to use num_predictions
wenbang24 92287b1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] c854788
fixed predicting and changed up tests
wenbang24 6eb7920
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 5407157
removed comments
wenbang24 f634dae
updated docstring
wenbang24 dbcc918
updated docstring to remove y
wenbang24 378f05b
Update docstring formatting
wenbang24 a4a2fda
added a lot of *args
wenbang24 5d7d4e6
Update predict docstring to add num_predictions times
wenbang24 206ee98
fix error in predict docstring
wenbang24 0a926ac
remove default in docstring
wenbang24 56a7780
add default to predict docstring somewhere else
wenbang24 96c7a35
Merge branch 'master' into recurrence
wenbang24 d3d4fd0
added args to latex_table
wenbang24 0c9515e
changed variable name format
wenbang24 66542ff
feat: allow 1D input
MilesCranmer 63e7785
Update pysr/regressor_sequence.py
MilesCranmer 516e4e6
feat: pretty-print sequence index
MilesCranmer 163128e
sequence example
wenbang24 9299df4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 6fba3d7
update sequence example
wenbang24 6cee114
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 055bec6
added markdown example docs
wenbang24 9ee3018
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 43049e1
removed example sequence file
wenbang24 93337d3
removed some miskates in examples sequence
wenbang24 e842784
updated example sequence
wenbang24 d984850
moved sequence examples to examples.md
wenbang24 25d9905
updated latex_table to use _t instead of _{t-0}
wenbang24 5c360b0
updated latex_table docstring to refer to PySRRegressor.latex_table
wenbang24 ec5d941
updated examples to not have X = np.array(X)
wenbang24 84acd0c
updated docs to use latex
wenbang24 2d1becb
removed warning if num_predictions < len(historical_X)
wenbang24 371e0ac
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] abe556e
updated tests for new variable names
wenbang24 e23ccfd
missed a few
wenbang24 8c285e9
ok all working now
wenbang24 9263837
whoops forgot to remove commetns
wenbang24 7f16257
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] fc6eaf1
fixed a test in TestDimensionalConstraints
wenbang24 5b4993b
removed unecessary print
wenbang24 4d69542
fix typing with a cast
wenbang24 9aecee8
Merge branch 'master' into recurrence
wenbang24 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -25,3 +25,4 @@ site | |
| venv | ||
| requirements-dev.lock | ||
| requirements.lock | ||
| .eggs/ | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,298 @@ | ||
| from typing import List, Optional, Tuple, Union | ||
|
|
||
| import numpy as np | ||
| from sklearn.base import BaseEstimator | ||
|
|
||
| from .sr import PySRRegressor | ||
| from .utils import ArrayLike, _subscriptify | ||
|
|
||
|
|
||
| def _check_assertions( | ||
| X, | ||
| recursive_history_length=None, | ||
| weights=None, | ||
| variable_names=None, | ||
| X_units=None, | ||
| ): | ||
| if recursive_history_length is not None and recursive_history_length <= 0: | ||
| raise ValueError( | ||
| "The `recursive_history_length` parameter must be greater than 0 (otherwise it's not recursion)." | ||
| ) | ||
| if len(X.shape) > 2: | ||
| raise ValueError( | ||
| "Recursive symbolic regression only supports up to 2D data; please flatten your data first" | ||
| ) | ||
| if len(X) <= recursive_history_length + 1: | ||
| raise ValueError( | ||
| f"Recursive symbolic regression with a history length of {recursive_history_length} requires at least {recursive_history_length + 2} datapoints." | ||
| ) | ||
| if isinstance(weights, np.ndarray) and len(weights) != len(X): | ||
| raise ValueError("The length of `weights` must have shape (n_times,).") | ||
| if isinstance(variable_names, list) and len(variable_names) != X.shape[1]: | ||
| raise ValueError( | ||
| "The length of `variable_names` must be equal to the number of features in `X`." | ||
| ) | ||
| if isinstance(X_units, list) and len(X_units) != X.shape[1]: | ||
| raise ValueError( | ||
| "The length of `X_units` must be equal to the number of features in `X`." | ||
| ) | ||
|
|
||
|
|
||
| class PySRSequenceRegressor(BaseEstimator): | ||
| """ | ||
| High performance symbolic regression for recurrent sequences. | ||
| Based off of the `PySRRegressor` class, but with a preprocessing step for recurrence relations. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| recursive_history_length : int | ||
| The number of previous time points to use as input features. | ||
| For example, if `recursive_history_length=2`, then the input features | ||
| will be `[X[0], X[1]]` and the output will be `X[2]`. | ||
| This continues on for all X: [X[n-1], X[n-2]] to predict X[n]. | ||
| Must be greater than 0. | ||
| Other parameters and attributes are inherited from `PySRRegressor`. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| *, | ||
| recursive_history_length: int = 0, | ||
| **kwargs, | ||
| ): | ||
| super().__init__() | ||
wenbang24 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| self._regressor = PySRRegressor(**kwargs) | ||
| self.recursive_history_length = recursive_history_length | ||
|
|
||
| def _construct_variable_names( | ||
| self, n_features: int, variable_names: Optional[List[str]] | ||
| ) -> Tuple[List[str], List[str]]: | ||
| if not isinstance(variable_names, list): | ||
| if n_features == 1: | ||
| variable_names = ["x"] | ||
| display_variable_names = ["x"] | ||
| else: | ||
| variable_names = [f"x{i}" for i in range(n_features)] | ||
| display_variable_names = [ | ||
| f"x{_subscriptify(i)}" for i in range(n_features) | ||
| ] | ||
| else: | ||
| display_variable_names = variable_names | ||
|
|
||
| # e.g., `x0_tm1` | ||
| variable_names_with_time = [ | ||
| f"{var}_tm{j}" | ||
| for j in range(self.recursive_history_length, 0, -1) | ||
| for var in variable_names | ||
| ] | ||
| # e.g., `x₀[t-1]` | ||
| display_variable_names_with_time = [ | ||
| f"{var}[t-{j}]" | ||
| for j in range(self.recursive_history_length, 0, -1) | ||
| for var in display_variable_names | ||
| ] | ||
|
|
||
| return variable_names_with_time, display_variable_names_with_time | ||
|
|
||
| def fit( | ||
MilesCranmer marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| self, | ||
| X, | ||
| *, | ||
| weights=None, | ||
| variable_names: Optional[List[str]] = None, | ||
| complexity_of_variables: Optional[ | ||
| Union[int, float, List[Union[int, float]]] | ||
| ] = None, | ||
| X_units: Optional[ArrayLike[str]] = None, | ||
| ) -> "PySRSequenceRegressor": | ||
| """ | ||
| Search for equations to fit the sequence and store them in `self.equations_`. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| X : ndarray | pandas.DataFrame | ||
| Sequence of shape (n_times, n_features) or (n_times,) | ||
| weights : ndarray | pandas.DataFrame | ||
| Weight array of the same shape as `X`. | ||
| Each element is how to weight the mean-square-error loss | ||
| for that particular element of `X`. Alternatively, | ||
| if a custom `loss` was set, it can be used | ||
| in custom ways. | ||
| variable_names : list[str] | ||
| A list of names for the variables, rather than "x0t_1", "x1t_2", etc. | ||
| If `X` is a pandas dataframe, the column name will be used | ||
| instead of `variable_names`. Cannot contain spaces or special | ||
| characters. Avoid variable names which are also | ||
| function names in `sympy`, such as "N". | ||
| The number of variable names must be equal to (n_features,). | ||
wenbang24 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| complexity_of_variables : int | float | list[int] | list[float] | ||
| The complexity of each variable in `X`. If a single value is | ||
| passed, it will be used for all variables. If a list is passed, | ||
| its length must be the same as `recurrence_history_length`. | ||
| X_units : list[str] | ||
| A list of units for each variable in `X`. Each unit should be | ||
| a string representing a Julia expression. See DynamicQuantities.jl | ||
| https://symbolicml.org/DynamicQuantities.jl/dev/units/ for more | ||
| information. | ||
| Length should be equal to n_features. | ||
|
|
||
| Returns | ||
| ------- | ||
| self : object | ||
| Fitted estimator. | ||
| """ | ||
| X = self._validate_data(X, ensure_2d=False) | ||
| if X.ndim == 1: | ||
| X = X.reshape(-1, 1) | ||
| assert X.ndim == 2 | ||
| _check_assertions( | ||
| X, | ||
| self.recursive_history_length, | ||
| weights, | ||
| variable_names, | ||
| X_units, | ||
| ) | ||
| self.variable_names = variable_names # for latex_table() | ||
| self.n_features = X.shape[1] # for latex_table() | ||
|
|
||
| current_X = X[self.recursive_history_length :] | ||
| historical_X = self._sliding_window(X)[: -1 : current_X.shape[1], :] | ||
| y_units = X_units | ||
| if isinstance(weights, np.ndarray): | ||
| weights = weights[self.recursive_history_length :] | ||
| variable_names, display_variable_names = self._construct_variable_names( | ||
| current_X.shape[1], variable_names | ||
| ) | ||
|
|
||
| self._regressor.fit( | ||
| X=historical_X, | ||
| y=current_X, | ||
| weights=weights, | ||
| variable_names=variable_names, | ||
| display_variable_names=display_variable_names, | ||
| X_units=X_units, | ||
| y_units=y_units, | ||
| complexity_of_variables=complexity_of_variables, | ||
| ) | ||
| return self | ||
|
|
||
| def predict(self, X, index=None, num_predictions=1): | ||
| """ | ||
| Predict future data from input X using the equation chosen by `model_selection`. | ||
|
|
||
| You may see what equation is used by printing this object. X should | ||
| have the same columns as the training data. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| X : ndarray | pandas.DataFrame | ||
| Data of shape `(n_times, n_features)`. | ||
| index : int | list[int] | ||
| If you want to compute the output of an expression using a | ||
| particular row of `self.equations_`, you may specify the index here. | ||
| For multiple output equations, you must pass a list of indices | ||
| in the same order. | ||
| num_predictions : int | ||
| How many predictions to make. If `num_predictions` is less than | ||
| `(n_times - recursive_history_length + 1)`, | ||
| some input data at the end will be ignored. | ||
wenbang24 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| Default is `1`. | ||
|
|
||
| Returns | ||
| ------- | ||
| x_predicted : ndarray of shape (num_predictions, n_features) | ||
| Values predicted by substituting `X` into the fitted sequence symbolic | ||
| regression model and rolling it out for `num_predictions` steps. | ||
|
|
||
| Raises | ||
| ------ | ||
| ValueError | ||
| Raises if the `best_equation` cannot be evaluated. | ||
| """ | ||
| X = self._validate_data(X, ensure_2d=False) | ||
| if X.ndim == 1: | ||
| X = X.reshape(-1, 1) | ||
| assert X.ndim == 2 | ||
| _check_assertions(X, recursive_history_length=self.recursive_history_length) | ||
wenbang24 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| historical_X = self._sliding_window(X)[:: X.shape[1], :] | ||
| if num_predictions < 1: | ||
| raise ValueError("num_predictions must be greater than 0.") | ||
| if num_predictions < len(historical_X): | ||
| historical_X = historical_X[:num_predictions] | ||
| return self._regressor.predict(X=historical_X, index=index) | ||
| else: | ||
| extra_predictions = num_predictions - len(historical_X) | ||
| pred = self._regressor.predict(X=historical_X, index=index) | ||
| for _ in range(extra_predictions): | ||
| pred_data = [pred[-self.recursive_history_length :].flatten()] | ||
| pred = np.concatenate( | ||
| [pred, self._regressor.predict(X=pred_data, index=index)], axis=0 | ||
| ) | ||
| return pred | ||
|
|
||
| def _sliding_window(self, X): | ||
| return np.lib.stride_tricks.sliding_window_view( | ||
| X.flatten(), self.recursive_history_length * np.prod(X.shape[1]) | ||
| ) | ||
|
|
||
| @classmethod | ||
| def from_file( | ||
wenbang24 marked this conversation as resolved.
Show resolved
Hide resolved
wenbang24 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| cls, | ||
| *args, | ||
| recursive_history_length: int, | ||
| **kwargs, | ||
| ): | ||
| assert recursive_history_length is not None and recursive_history_length > 0 | ||
|
|
||
| model = cls(recursive_history_length=recursive_history_length) | ||
| model._regressor = PySRRegressor.from_file(*args, **kwargs) | ||
| return model | ||
|
|
||
| def __repr__(self): | ||
| return self._regressor.__repr__().replace( | ||
| "PySRRegressor", "PySRSequenceRegressor", 1 | ||
| ) | ||
|
|
||
| def get_best(self, *args, **kwargs): | ||
| return self._regressor.get_best(*args, **kwargs) | ||
|
|
||
| def refresh(self, *args, **kwargs): | ||
| return self._regressor.refresh(*args, **kwargs) | ||
|
|
||
| def sympy(self, *args, **kwargs): | ||
| return self._regressor.sympy(*args, **kwargs) | ||
|
|
||
| def latex(self, *args, **kwargs): | ||
| return self._regressor.latex(*args, **kwargs) | ||
|
|
||
| def get_hof(self): | ||
| return self._regressor.get_hof() | ||
|
|
||
| def latex_table( | ||
| self, | ||
| *args, | ||
| **kwargs, | ||
wenbang24 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| ): | ||
| """ | ||
| Generates LaTeX variable names, then creates a LaTeX table of the best equation(s). | ||
| Refer to `PySRRegressor.latex_table` for information. | ||
| """ | ||
| if self.variable_names is not None: | ||
| if len(self.variable_names) == 1: | ||
| variable_names = self.variable_names[0] + "_{tm}" | ||
| else: | ||
| variable_names = [ | ||
| variable_name + "_{tm}" for variable_name in self.variable_names | ||
| ] | ||
| else: | ||
| if self.n_features == 1: | ||
| variable_names = "x_{tm}" | ||
| else: | ||
| variable_names = [f"x_{{{i} tm}}" for i in range(self.n_features)] | ||
| return self._regressor.latex_table( | ||
| *args, **kwargs, output_variable_names=variable_names | ||
| ) | ||
|
|
||
| @property | ||
| def equations_(self): | ||
| return self._regressor.equations_ | ||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.