Commit b4e06ae
Feat/kpcovr fitted regressor (#113)
* Add shape checking utilities for coefficients of precomputed kernel regressors
* Modify instantiation and fit call of KPCovR to accept pre-fitted regressors as in PCovR
* Update KPCovR tests to be compatible with new regressor usage
* Update PCovR example notebook to be compatible with new regressor usage
* Reorganize regressor usage to pull kernel parameters directly from the regressor; use None as the default argument for the regressor
* Pull alpha from the KPCovR regressor
* Make regressor default argument None, assign default within __init__
* Change inversions to use least squares with singular value cutoff based on tol instead of the regularization
* Compute Yhat directly from the dual coefficients
* Move regressor checking to occur immediately
* Add more details about pre-fitted regressors to PCovR and KPCovR documentation
* Use KPCovR tolerance in matrix inversion instead of regularization
* Add tests for KPCovR to cover the pre-fitted regressors
* Add PCovR test to check for regressor modifications
* Move default regressor assignment to fit and accept regressor params
* Reorganize KPCovR regressor infrastructure
* Make PCovR example compatible with new KPCovR regressor infrastructure
* Add PCovR test for None regressor
* Modify KPCovR tests for compatibility with new regressor infrastructure
* Add KPCovR test for None regressor
* Fix KPCovR docstring example
* Consolidate regressor checking
* Simplify tests for pre-fitted regressors
* Negate KPCovR score according to sklearn guidelines
1 parent 9a438e8 commit b4e06ae
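Taken together, these changes amount to the usage pattern below — a minimal sketch (not taken from the diff itself) assuming the post-commit API, where a pre-fitted `KernelRidge` supplies the dual coefficients and its kernel parameters must match those passed to `KernelPCovR`:

    import numpy as np
    from sklearn.kernel_ridge import KernelRidge
    from skcosmo.decomposition import KernelPCovR

    rng = np.random.RandomState(0)
    X = rng.normal(size=(20, 4))
    Y = rng.normal(size=(20, 2))

    # Pre-fit the regressor; per this commit, KernelPCovR reuses its
    # dual coefficients rather than refitting it.
    regressor = KernelRidge(alpha=1e-8, kernel="rbf", gamma=0.1).fit(X, Y)

    # The kernel parameters given here must be identical to the
    # regressor's, otherwise fit() raises a ValueError.
    kpcovr = KernelPCovR(
        mixing=0.5,
        n_components=2,
        regressor=regressor,
        kernel="rbf",
        gamma=0.1,
    ).fit(X, Y)
    T = kpcovr.transform(X)
    Yp = kpcovr.predict(X)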

File tree

7 files changed (+377 −113 lines)


examples/PCovR.ipynb

Lines changed: 6 additions & 1 deletion
@@ -25,6 +25,7 @@
     "from skcosmo.decomposition import PCovR\n",
     "from sklearn.preprocessing import StandardScaler\n",
     "from sklearn.linear_model import Ridge\n",
+    "from sklearn.kernel_ridge import KernelRidge\n",
     "\n",
     "cmapX = cm.plasma\n",
     "cmapy = cm.Greys"
@@ -182,7 +183,11 @@
     "mixing = 0.5\n",
     "kpcovr = KernelPCovR(\n",
     "    mixing=mixing,\n",
-    "    alpha=1e-8,\n",
+    "    regressor=KernelRidge(\n",
+    "        alpha=1e-8,\n",
+    "        kernel=\"rbf\",\n",
+    "        gamma=0.1,\n",
+    "    ),\n",
     "    kernel=\"rbf\",\n",
     "    gamma=0.1,\n",
     "    n_components=2,\n",

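The `alpha=1e-8` that this cell used to pass directly still has a home without an explicit regressor: per the `fit()` changes below, extra keyword arguments are collected as `**regressor_params` and forwarded to an internally constructed `KernelRidge` when `regressor` is `None`. A sketch of that equivalent call, assuming the post-commit API:

    # Equivalent default-regressor path (sketch): with regressor=None,
    # fit() builds KernelRidge(kernel="rbf", gamma=0.1, ..., alpha=1e-8)
    # internally, alpha being swept up by **regressor_params.
    kpcovr = KernelPCovR(
        mixing=0.5,
        n_components=2,
        regressor=None,
        kernel="rbf",
        gamma=0.1,
        alpha=1e-8,
    )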
skcosmo/decomposition/_kernel_pcovr.py

Lines changed: 115 additions & 41 deletions
@@ -5,6 +5,8 @@
 from scipy.sparse.linalg import svds
 from sklearn.decomposition._base import _BasePCA
 from sklearn.decomposition._pca import _infer_dimension
+from sklearn.exceptions import NotFittedError
+from sklearn.kernel_ridge import KernelRidge
 from sklearn.linear_model._base import LinearModel
 from sklearn.metrics.pairwise import pairwise_kernels
 from sklearn.utils import (
@@ -23,7 +25,10 @@
 )
 
 from ..preprocessing import KernelNormalizer
-from ..utils import pcovr_kernel
+from ..utils import (
+    check_krr_fit,
+    pcovr_kernel,
+)
 
 
 class KernelPCovR(_BasePCA, LinearModel):
@@ -75,10 +80,18 @@ class KernelPCovR(_BasePCA, LinearModel):
         If randomized :
             run randomized SVD by the method of Halko et al.
 
+    regressor : instance of `sklearn.kernel_ridge.KernelRidge`, default=None
+        The regressor to use for computing
+        the property predictions :math:`\\hat{\\mathbf{Y}}`.
+        A pre-fitted regressor may be provided.
+        If the regressor is not `None`, its kernel parameters
+        (`kernel`, `gamma`, `degree`, `coef0`, and `kernel_params`)
+        must be identical to those passed directly to `KernelPCovR`.
+
     kernel: "linear" | "poly" | "rbf" | "sigmoid" | "cosine" | "precomputed"
         Kernel. Default="linear".
 
-    gamma: float, default=1/n_features
+    gamma: float, default=None
         Kernel coefficient for rbf, poly and sigmoid kernels. Ignored by other
         kernels.
 
@@ -96,15 +109,13 @@
     center: bool, default=False
         Whether to center any computed kernels
 
-    alpha: float, default=1E-6
-        Regularization parameter to use in all regression operations.
-
     fit_inverse_transform: bool, default=False
         Learn the inverse transform for non-precomputed kernels.
         (i.e. learn to find the pre-image of a point)
 
     tol: float, default=1e-12
-        Tolerance for singular values computed by svd_solver == 'arpack'.
+        Tolerance for singular values computed by svd_solver == 'arpack'
+        and for matrix inversions.
         Must be of range [0.0, infinity).
 
     n_jobs: int, default=None
@@ -121,6 +132,9 @@
         Used when the 'arpack' or 'randomized' solvers are used. Pass an int
         for reproducible results across multiple function calls.
 
+    **regressor_params: additional keyword arguments to be passed
+        to the regressor. Ignored if `regressor` is not `None`.
+
 
     Attributes
     ----------
@@ -154,53 +168,53 @@
     >>> import numpy as np
     >>> from skcosmo.decomposition import KernelPCovR
     >>> from skcosmo.preprocessing import StandardFlexibleScaler as SFS
+    >>> from sklearn.kernel_ridge import KernelRidge
     >>>
     >>> X = np.array([[-1, 1, -3, 1], [1, -2, 1, 2], [-2, 0, -2, -2], [1, 0, 2, -1]])
     >>> X = SFS().fit_transform(X)
     >>> Y = np.array([[ 0, -5], [-1, 1], [1, -5], [-3, 2]])
     >>> Y = SFS(column_wise=True).fit_transform(Y)
     >>>
-    >>> kpcovr = KernelPCovR(mixing=0.1, n_components=2, kernel='rbf', gamma=2)
+    >>> kpcovr = KernelPCovR(mixing=0.1, n_components=2, regressor=KernelRidge(kernel='rbf', gamma=1), kernel='rbf', gamma=1)
     >>> kpcovr.fit(X, Y)
-    KernelPCovR(coef0=1, degree=3, fit_inverse_transform=False, gamma=0.01, kernel='rbf',
-                kernel_params=None, mixing=None, n_components=2, n_jobs=None,
-                alpha=None, tol=1e-12)
+    KernelPCovR(gamma=1, kernel='rbf', mixing=0.1, n_components=2,
+                regressor=KernelRidge(gamma=1, kernel='rbf'))
     >>> T = kpcovr.transform(X)
-    [[ 1.01199065, -0.35439061],
-     [-0.68099591,  0.48912275],
-     [ 1.4677616 ,  0.13757037],
-     [-1.79874193, -0.27232032]]
+    [[-0.61261285, -0.18937908],
+     [ 0.45242098,  0.25453465],
+     [-0.77871824,  0.04847559],
+     [ 0.91186937, -0.21211816]]
     >>> Yp = kpcovr.predict(X)
-    [[-0.01044648, -0.84443158],
-     [-0.1758848 ,  0.16224503],
-     [ 0.1573037 , -0.84211944],
-     [-0.51133139,  0.32552881]]
+    [[ 0.5100212 , -0.99488463],
+     [-0.18992219,  0.82064368],
+     [ 1.11923584, -1.04798016],
+     [-1.5635827 ,  1.11078662]]
     >>> kpcovr.score(X, Y)
-    (0.5312320029915978, 0.06254540655698511)
+    -0.520388347837897
     """
 
     def __init__(
         self,
         mixing=0.5,
         n_components=None,
         svd_solver="auto",
+        regressor=None,
         kernel="linear",
         gamma=None,
         degree=3,
         coef0=1,
-        alpha=1e-6,
         kernel_params=None,
         center=False,
         fit_inverse_transform=False,
         tol=1e-12,
         n_jobs=None,
         iterated_power="auto",
         random_state=None,
+        **regressor_params
     ):
 
         self.mixing = mixing
         self.n_components = n_components
-        self.alpha = alpha
 
         self.svd_solver = svd_solver
         self.tol = tol
@@ -209,15 +223,19 @@ def __init__(
         self.center = center
 
         self.kernel = kernel
-        self.kernel_params = kernel_params
         self.gamma = gamma
         self.degree = degree
         self.coef0 = coef0
+        self.kernel_params = kernel_params
+
         self.n_jobs = n_jobs
         self.n_samples_ = None
 
         self.fit_inverse_transform = fit_inverse_transform
 
+        self.regressor = regressor
+        self.regressor_params = regressor_params
+
     def _get_kernel(self, X, Y=None):
         if callable(self.kernel):
             params = self.kernel_params or {}
@@ -252,9 +270,9 @@ def _fit(self, K, Yhat, W):
         self.pkt_ = P @ U @ np.sqrt(np.diagflat(S_inv))
 
         T = K @ self.pkt_
-        self.pt__ = np.linalg.lstsq(T, np.eye(T.shape[0]), rcond=self.alpha)[0]
+        self.pt__ = np.linalg.lstsq(T, np.eye(T.shape[0]), rcond=self.tol)[0]
 
-    def fit(self, X, Y, Yhat=None, W=None):
+    def fit(self, X, Y):
         """
 
         Fit the model with X and Y.
@@ -279,18 +297,16 @@ def fit(self, X, Y):
             to have unit variance, otherwise :math:`\\mathbf{Y}` should be
             scaled so that each feature has a variance of 1 / n_features.
 
-        Yhat: ndarray, shape (n_samples, n_properties), optional
-            Regressed training data, where n_samples is the number of samples and
-            n_properties is the number of properties. If not supplied, computed
-            by ridge regression.
-
         Returns
         -------
         self: object
             Returns the instance itself.
 
         """
 
+        if self.regressor is not None and not isinstance(self.regressor, KernelRidge):
+            raise ValueError("Regressor must be an instance of `KernelRidge`")
+
         X, Y = check_X_y(X, Y, y_numeric=True, multi_output=True)
         self.X_fit_ = X.copy()
 
@@ -308,14 +324,66 @@
 
         self.n_samples_ = X.shape[0]
 
-        if W is None:
-            if Yhat is None:
-                W = (np.linalg.lstsq(K, Y, rcond=self.alpha)[0]).reshape(X.shape[0], -1)
-            else:
-                W = np.linalg.lstsq(K, Yhat, rcond=self.alpha)[0]
+        if self.regressor is None:
+            regressor = KernelRidge(
+                kernel=self.kernel,
+                gamma=self.gamma,
+                degree=self.degree,
+                coef0=self.coef0,
+                kernel_params=self.kernel_params,
+                **self.regressor_params,
+            )
+        else:
+            regressor = self.regressor
+            kernel_attrs = ["kernel", "gamma", "degree", "coef0", "kernel_params"]
+            if not all(
+                [
+                    getattr(self, attr) == getattr(regressor, attr)
+                    for attr in kernel_attrs
+                ]
+            ):
+                raise ValueError(
+                    "Kernel parameter mismatch: the regressor has kernel parameters {%s}"
+                    " and KernelPCovR was initialized with kernel parameters {%s}"
+                    % (
+                        ", ".join(
+                            [
+                                "%s: %r" % (attr, getattr(regressor, attr))
+                                for attr in kernel_attrs
+                            ]
+                        ),
+                        ", ".join(
+                            [
+                                "%s: %r" % (attr, getattr(self, attr))
+                                for attr in kernel_attrs
+                            ]
+                        ),
+                    )
+                )
+
+        # Check if regressor is fitted; if not, fit with precomputed K
+        # to avoid needing to compute the kernel a second time
+        self.regressor_ = check_krr_fit(regressor, K, X, Y)
 
-        if Yhat is None:
-            Yhat = K @ W
+        W = self.regressor_.dual_coef_.reshape(X.shape[0], -1)
+
+        # Use this instead of `self.regressor_.predict(K)`
+        # so that we can handle the case of the pre-fitted regressor
+        Yhat = K @ W
+
+        # When we have an unfitted regressor,
+        # we fit it with a precomputed K
+        # so we must subsequently "reset" it so that
+        # it will work on the particular X
+        # of the KPCovR call. The dual coefficients are kept.
+        # Can be bypassed if the regressor is pre-fitted.
+        try:
+            check_is_fitted(regressor)
+
+        except NotFittedError:
+            self.regressor_.set_params(**regressor.get_params())
+            self.regressor_.X_fit_ = self.X_fit_
+            self.regressor_._check_n_features(self.X_fit_, reset=True)
 
         # Handle svd_solver
         self._fit_svd_solver = self.svd_solver
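The helper `check_krr_fit` lives in `skcosmo/utils` (one of the seven changed files) and its implementation is not shown in this excerpt. Below is a hypothetical sketch of its behavior, inferred from the commit message and the comments above — the name suffix, signature details, and shape checks are assumptions, not the actual implementation:

    from copy import deepcopy

    from sklearn.exceptions import NotFittedError
    from sklearn.utils.validation import check_is_fitted

    def check_krr_fit_sketch(regressor, K, X, Y):
        """Hypothetical stand-in for skcosmo's check_krr_fit (X unused here)."""
        fitted = deepcopy(regressor)
        try:
            # Pre-fitted regressor: keep its dual coefficients and only
            # sanity-check their shape against the training data.
            check_is_fitted(fitted)
            n_samples = K.shape[0]
            n_targets = Y.reshape(n_samples, -1).shape[1]
            if fitted.dual_coef_.reshape(n_samples, -1).shape[1] != n_targets:
                raise ValueError("Shape mismatch between dual_coef_ and Y")
        except NotFittedError:
            # Unfitted regressor: fit on the precomputed kernel so that K
            # is not computed a second time, as the comment above notes.
            fitted.set_params(kernel="precomputed")
            fitted.fit(K, Y)
        return fitted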
@@ -408,7 +476,7 @@ def inverse_transform(self, T):
 
     def score(self, X, Y):
         r"""
-        Computes the loss values for KernelPCovR on the given predictor and
+        Computes the (negative) loss values for KernelPCovR on the given predictor and
         response variables. The loss in :math:`\mathbf{K}`, as explained in
         [Helfrecht2020]_ does not correspond to a traditional Gram loss
         :math:`\mathbf{K} - \mathbf{TT}^T`. Indicating the kernel between set
@@ -424,15 +492,17 @@ def score(self, X, Y):
             \mathbf{K}_{NN} \mathbf{T}_N (\mathbf{T}_N^T \mathbf{T}_N)^{-1}
             \mathbf{T}_V^T\right]}{\operatorname{Tr}(\mathbf{K}_{VV})}
 
+        The negative loss is returned for easier use in sklearn pipelines, e.g., a grid search, where methods named 'score' are meant to be maximized.
+
         Arguments
         ---------
         X: independent (predictor) variable
         Y: dependent (response) variable
 
         Returns
         -------
-        Lk: KPCA loss, determined by the reconstruction of the kernel
-        Ly: KR loss
+        L: Negative sum of the KPCA and KRR losses, with the KPCA loss
+            determined by the reconstruction of the kernel
 
         """
 
@@ -455,10 +525,14 @@
         t_n = K_NN @ self.pkt_
         t_v = K_VN @ self.pkt_
 
-        w = t_n @ np.linalg.pinv(t_n.T @ t_n, rcond=self.alpha) @ t_v.T
+        w = (
+            t_n
+            @ np.linalg.lstsq(t_n.T @ t_n, np.eye(t_n.shape[1]), rcond=self.tol)[0]
+            @ t_v.T
+        )
         Lkpca = np.trace(K_VV - 2 * K_VN @ w + w.T @ K_VV @ w) / np.trace(K_VV)
 
-        return sum([Lkpca, Lkrr])
+        return -sum([Lkpca, Lkrr])
 
     def _decompose_truncated(self, mat):
