Commit 29d67fa

FEAT - add fit_intercept support for LBFGS (#326)
1 parent 6a4ee61 commit 29d67fa

File tree: 3 files changed (+46, -19 lines)

doc/changes/0.6.rst

Lines changed: 2 additions & 0 deletions
@@ -2,3 +2,5 @@
 
 Version 0.6 (in progress)
 -------------------------
+
+- :class:`skglm.solvers.LBFGS` now supports fitting an intercept with the `fit_intercept` parameter.
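For context, a minimal usage sketch of the new option, mirroring the test added in this commit (the synthetic data below is illustrative, not from the commit; the solver call follows the test's own usage):

import numpy as np
from skglm.datafits import Logistic
from skglm.penalties import L2
from skglm.solvers import LBFGS

# illustrative data; any design matrix with labels in {-1, 1} works
rng = np.random.default_rng(0)
X = rng.standard_normal((100, 50))
y = np.sign(rng.standard_normal(100))

# with fit_intercept=True the solution vector gains one trailing entry:
# w[:-1] holds the coefficients, w[-1] the intercept
w, *_ = LBFGS(tol=1e-12, fit_intercept=True).solve(X, y, Logistic(), L2(1.0))
coef, intercept = w[:-1], w[-1]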

skglm/solvers/lbfgs.py

Lines changed: 35 additions & 15 deletions
@@ -24,16 +24,21 @@ class LBFGS(BaseSolver):
     tol : float, default 1e-4
         Tolerance for convergence.
 
+    fit_intercept : bool, default False
+        Whether or not to fit an intercept.
+
     verbose : bool, default False
         Amount of verbosity. 0/False is silent.
     """
 
     _datafit_required_attr = ("gradient",)
     _penalty_required_attr = ("gradient",)
 
-    def __init__(self, max_iter=50, tol=1e-4, verbose=False):
+    def __init__(self, max_iter=50, tol=1e-4, fit_intercept=False, verbose=False):
         self.max_iter = max_iter
         self.tol = tol
+        self.fit_intercept = fit_intercept
+        self.warm_start = False
         self.verbose = verbose
 
     def _solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
@@ -46,25 +51,40 @@ def _solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
         datafit.initialize(X, y)
 
         def objective(w):
-            Xw = X @ w
-            datafit_value = datafit.value(y, w, Xw)
-            penalty_value = penalty.value(w)
-
+            w_features = w[:n_features]
+            Xw = X @ w_features
+            if self.fit_intercept:
+                Xw += w[-1]
+            datafit_value = datafit.value(y, w_features, Xw)
+            penalty_value = penalty.value(w_features)
             return datafit_value + penalty_value
 
         def d_jac(w):
-            Xw = X @ w
+            w_features = w[:n_features]
+            Xw = X @ w_features
+            if self.fit_intercept:
+                Xw += w[-1]
             datafit_grad = datafit.gradient(X, y, Xw)
-            penalty_grad = penalty.gradient(w)
-
-            return datafit_grad + penalty_grad
+            penalty_grad = penalty.gradient(w_features)
+            if self.fit_intercept:
+                intercept_grad = datafit.raw_grad(y, Xw).sum()
+                return np.concatenate([datafit_grad + penalty_grad, [intercept_grad]])
+            else:
+                return datafit_grad + penalty_grad
 
         def s_jac(w):
-            Xw = X @ w
-            datafit_grad = datafit.gradient_sparse(X.data, X.indptr, X.indices, y, Xw)
-            penalty_grad = penalty.gradient(w)
-
-            return datafit_grad + penalty_grad
+            w_features = w[:n_features]
+            Xw = X @ w_features
+            if self.fit_intercept:
+                Xw += w[-1]
+            datafit_grad = datafit.gradient_sparse(
+                X.data, X.indptr, X.indices, y, Xw)
+            penalty_grad = penalty.gradient(w_features)
+            if self.fit_intercept:
+                intercept_grad = datafit.raw_grad(y, Xw).sum()
+                return np.concatenate([datafit_grad + penalty_grad, [intercept_grad]])
+            else:
+                return datafit_grad + penalty_grad
 
         def callback_post_iter(w_k):
             # save p_obj
@@ -81,7 +101,7 @@ def callback_post_iter(w_k):
         )
 
         n_features = X.shape[1]
-        w = np.zeros(n_features) if w_init is None else w_init
+        w = np.zeros(n_features + self.fit_intercept) if w_init is None else w_init
         jac = s_jac if issparse(X) else d_jac
         p_objs_out = []
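A note on why the intercept gradient in d_jac/s_jac is a plain sum: writing the datafit as F(Xw + b*1), the chain rule gives dF/db = sum_i dF/dz_i at z = Xw + b*1, i.e. the sum of the per-sample gradients that datafit.raw_grad returns. A standalone finite-difference sketch for the logistic case (pure NumPy; datafit_value and raw_grad below are local stand-ins for skglm's Logistic, assuming its mean-over-samples scaling):

import numpy as np

# logistic datafit: F(z) = mean(log(1 + exp(-y * z)))
def datafit_value(y, z):
    return np.mean(np.log1p(np.exp(-y * z)))

def raw_grad(y, z):
    # per-sample derivative of F with respect to z_i
    return -y / (len(y) * (1 + np.exp(y * z)))

rng = np.random.default_rng(0)
X = rng.standard_normal((20, 5))
y = np.sign(rng.standard_normal(20))
w, b = rng.standard_normal(5), 0.3

# analytic intercept gradient: sum of per-sample raw gradients
g_analytic = raw_grad(y, X @ w + b).sum()

# central finite-difference check in the intercept direction
eps = 1e-6
g_fd = (datafit_value(y, X @ w + b + eps)
        - datafit_value(y, X @ w + b - eps)) / (2 * eps)
assert np.isclose(g_analytic, g_fd)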

skglm/tests/test_lbfgs_solver.py

Lines changed: 9 additions & 4 deletions
@@ -12,7 +12,8 @@
 
 
 @pytest.mark.parametrize("X_sparse", [True, False])
-def test_lbfgs_L2_logreg(X_sparse):
+@pytest.mark.parametrize("fit_intercept", [True, False])
+def test_lbfgs_L2_logreg(X_sparse, fit_intercept):
     reg = 1.0
     X_density = 1.0 if not X_sparse else 0.5
     n_samples, n_features = 100, 50
@@ -28,17 +29,21 @@ def test_lbfgs_L2_logreg(X_sparse):
     # fit L-BFGS
     datafit = Logistic()
     penalty = L2(reg)
-    w, *_ = LBFGS(tol=1e-12).solve(X, y, datafit, penalty)
+    w, *_ = LBFGS(tol=1e-12, fit_intercept=fit_intercept).solve(X, y, datafit, penalty)
 
     # fit scikit learn
     estimator = LogisticRegression(
         penalty="l2",
         C=1 / (n_samples * reg),
-        fit_intercept=False,
+        fit_intercept=fit_intercept,
         tol=1e-12,
     ).fit(X, y)
 
-    np.testing.assert_allclose(w, estimator.coef_.flatten(), atol=1e-5)
+    if fit_intercept:
+        np.testing.assert_allclose(w[:-1], estimator.coef_.flatten(), atol=1e-5)
+        np.testing.assert_allclose(w[-1], estimator.intercept_[0], atol=1e-5)
+    else:
+        np.testing.assert_allclose(w, estimator.coef_.flatten(), atol=1e-5)
 
 
 @pytest.mark.parametrize("use_efron", [True, False])
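A note on the scikit-learn cross-check above, assuming skglm's Logistic datafit uses the mean-over-samples convention: skglm minimizes mean(log(1 + exp(-y * Xw))) + (reg / 2) * ||w||^2, while LogisticRegression minimizes C * sum(log(1 + exp(-y * Xw))) + (1 / 2) * ||w||^2. Dividing the skglm objective by reg turns its penalty term into (1 / 2) * ||w||^2 and its datafit term into (1 / (n_samples * reg)) * sum(...), so the two objectives share the same minimizer exactly when C = 1 / (n_samples * reg), which is why the test sets C that way.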
