Closed
Description
Chapter 4, notebook 04_training_linear_models.ipynb, cell 47, line 11:
sgd_reg = SGDRegressor(max_iter=1, tol=-np.infty, warm_start=True,
                       penalty=None, learning_rate="constant", eta0=0.0005, random_state=42)
The value tol=-np.infty is not valid according to the documentation: tol can only be None or a float in the range [0, np.infty), so this cell now raises an InvalidParameterError.
Scikit-learn version 1.2.2
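The SGDRegressor documentation describes tol as a stopping criterion: if it is not None, "training will stop when (loss > best_loss - tol) for n_iter_no_change consecutive epochs". A simplified pure-Python model of that rule shows why -inf used to mean "never stop early": best_loss - (-inf) is +inf, so no finite loss ever counts as "no improvement". The helper epochs_until_stop is hypothetical and only approximates scikit-learn's internal loop; it is not the library's code.

```python
import math
from typing import Optional, Sequence


def epochs_until_stop(losses: Sequence[float], tol: Optional[float],
                      n_iter_no_change: int = 5) -> int:
    """Hypothetical model of the documented SGD stopping rule.

    Returns how many epochs run before training would stop, given the
    per-epoch losses. tol=None disables the check entirely.
    """
    if tol is None:
        return len(losses)  # no convergence check: every epoch runs
    best_loss = math.inf
    no_improvement = 0
    for epoch, loss in enumerate(losses, start=1):
        # "No improvement" means the loss did not beat best_loss by more than tol.
        if loss > best_loss - tol:
            no_improvement += 1
        else:
            no_improvement = 0
        best_loss = min(best_loss, loss)
        if no_improvement >= n_iter_no_change:
            return epoch
    return len(losses)


flat = [1.0] * 20  # a loss curve that never improves
print(epochs_until_stop(flat, tol=1e-3))       # stops early
print(epochs_until_stop(flat, tol=-math.inf))  # never stops early
print(epochs_until_stop(flat, tol=None))       # never stops early
```

In this model, tol=-inf and tol=None both run every epoch, which is why None looks like the natural replacement now that negative values are rejected.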
To Reproduce
import numpy as np
from sklearn.linear_model import SGDRegressor
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import PolynomialFeatures
from sklearn.metrics import mean_squared_error
from copy import deepcopy
np.random.seed(42)
m = 100
X = 6 * np.random.rand(m, 1) - 3
y = 2 + X + 0.5 * X**2 + np.random.randn(m, 1)
X_train, X_val, y_train, y_val = train_test_split(X[:50], y[:50].ravel(), test_size=0.5, random_state=10)
poly_scaler = Pipeline([
    ("poly_features", PolynomialFeatures(degree=90, include_bias=False)),
    ("std_scaler", StandardScaler())
])
X_train_poly_scaled = poly_scaler.fit_transform(X_train)
X_val_poly_scaled = poly_scaler.transform(X_val)
sgd_reg = SGDRegressor(max_iter=1, tol=-np.infty, warm_start=True,
                       penalty=None, learning_rate="constant", eta0=0.0005, random_state=42)
minimum_val_error = float("inf")
best_epoch = None
best_model = None
for epoch in range(1000):
    sgd_reg.fit(X_train_poly_scaled, y_train)  # continues where it left off (warm_start=True)
    y_val_predict = sgd_reg.predict(X_val_poly_scaled)
    val_error = mean_squared_error(y_val, y_val_predict)
    if val_error < minimum_val_error:
        minimum_val_error = val_error
        best_epoch = epoch
        best_model = deepcopy(sgd_reg)
Error:
InvalidParameterError: The 'tol' parameter of SGDRegressor must be a float in the range [0, inf) or None. Got -inf instead.
Expected behavior
Since we are doing the early stopping manually, I think tol should be set to None, i.e., a way of saying "don't use a tolerance". I am new to ML, so this is only a guess.
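The guess matches the SGDRegressor documentation, which says the tol stopping rule applies only "if it is not None"; with tol=None, each fit() call runs exactly max_iter epochs. Below is a sketch of the same experiment with only that one argument changed (the reporter's suggestion, not an official fix from the notebook's author). It also avoids np.infty, an alias that NumPy 2.0 removed.

```python
from copy import deepcopy

import numpy as np
from sklearn.linear_model import SGDRegressor
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import PolynomialFeatures, StandardScaler

np.random.seed(42)
m = 100
X = 6 * np.random.rand(m, 1) - 3
y = 2 + X + 0.5 * X**2 + np.random.randn(m, 1)
X_train, X_val, y_train, y_val = train_test_split(
    X[:50], y[:50].ravel(), test_size=0.5, random_state=10)

poly_scaler = Pipeline([
    ("poly_features", PolynomialFeatures(degree=90, include_bias=False)),
    ("std_scaler", StandardScaler()),
])
X_train_poly_scaled = poly_scaler.fit_transform(X_train)
X_val_poly_scaled = poly_scaler.transform(X_val)

# tol=None disables SGDRegressor's own convergence check, so each fit() call
# runs exactly max_iter=1 epoch; the loop below does the early stopping.
sgd_reg = SGDRegressor(max_iter=1, tol=None, warm_start=True,
                       penalty=None, learning_rate="constant", eta0=0.0005,
                       random_state=42)

minimum_val_error = float("inf")
best_epoch = None
best_model = None
for epoch in range(1000):
    sgd_reg.fit(X_train_poly_scaled, y_train)  # warm_start: resumes from previous coefficients
    val_error = mean_squared_error(y_val, sgd_reg.predict(X_val_poly_scaled))
    if val_error < minimum_val_error:
        minimum_val_error = val_error
        best_epoch = epoch
        best_model = deepcopy(sgd_reg)  # snapshot of the best epoch so far

print(best_epoch, minimum_val_error)
```

The snapshot uses deepcopy rather than sklearn.base.clone because clone returns an unfitted estimator with the same hyperparameters, which would drop the learned coefficients.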
Versions:
- OS: Windows 10.0.22621
- Python: 3.10.11
- Scikit-Learn: 1.2.2