55from scipy .sparse .linalg import svds
66from sklearn .decomposition ._base import _BasePCA
77from sklearn .decomposition ._pca import _infer_dimension
8+ from sklearn .exceptions import NotFittedError
9+ from sklearn .kernel_ridge import KernelRidge
810from sklearn .linear_model ._base import LinearModel
911from sklearn .metrics .pairwise import pairwise_kernels
1012from sklearn .utils import (
2325)
2426
2527from ..preprocessing import KernelNormalizer
26- from ..utils import pcovr_kernel
28+ from ..utils import (
29+ check_krr_fit ,
30+ pcovr_kernel ,
31+ )
2732
2833
2934class KernelPCovR (_BasePCA , LinearModel ):
@@ -75,10 +80,18 @@ class KernelPCovR(_BasePCA, LinearModel):
7580 If randomized :
7681 run randomized SVD by the method of Halko et al.
7782
83+ regressor : instance of `sklearn.kernel_ridge.KernelRidge`, default=None
84+ The regressor to use for computing
85+ the property predictions :math:`\\hat{\\mathbf{Y}}`.
86+ A pre-fitted regressor may be provided.
87+ If the regressor is not `None`, its kernel parameters
88+ (`kernel`, `gamma`, `degree`, `coef0`, and `kernel_params`)
89+ must be identical to those passed directly to `KernelPCovR`.
90+
7891 kernel: "linear" | "poly" | "rbf" | "sigmoid" | "cosine" | "precomputed"
7992 Kernel. Default="linear".
8093
81- gamma: float, default=1/n_features
94+ gamma: float, default=None
8295 Kernel coefficient for rbf, poly and sigmoid kernels. Ignored by other
8396 kernels.
8497
@@ -96,15 +109,13 @@ class KernelPCovR(_BasePCA, LinearModel):
96109 center: bool, default=False
97110 Whether to center any computed kernels
98111
99- alpha: float, default=1E-6
100- Regularization parameter to use in all regression operations.
101-
102112 fit_inverse_transform: bool, default=False
103113 Learn the inverse transform for non-precomputed kernels.
104114 (i.e. learn to find the pre-image of a point)
105115
106116 tol: float, default=1e-12
107- Tolerance for singular values computed by svd_solver == 'arpack'.
117+ Tolerance for singular values computed by svd_solver == 'arpack'
118+ and for matrix inversions.
108119 Must be of range [0.0, infinity).
109120
110121 n_jobs: int, default=None
@@ -121,6 +132,9 @@ class KernelPCovR(_BasePCA, LinearModel):
121132 Used when the 'arpack' or 'randomized' solvers are used. Pass an int
122133 for reproducible results across multiple function calls.
123134
135+ **regressor_params: additional keyword arguments to be passed
136+ to the regressor. Ignored if `regressor` is not `None`.
137+
124138
125139 Attributes
126140 ----------
@@ -154,53 +168,53 @@ class KernelPCovR(_BasePCA, LinearModel):
154168 >>> import numpy as np
155169 >>> from skcosmo.decomposition import KernelPCovR
156170 >>> from skcosmo.preprocessing import StandardFlexibleScaler as SFS
171+ >>> from sklearn.kernel_ridge import KernelRidge
157172 >>>
158173 >>> X = np.array([[-1, 1, -3, 1], [1, -2, 1, 2], [-2, 0, -2, -2], [1, 0, 2, -1]])
159174 >>> X = SFS().fit_transform(X)
160175 >>> Y = np.array([[ 0, -5], [-1, 1], [1, -5], [-3, 2]])
161176 >>> Y = SFS(column_wise=True).fit_transform(Y)
162177 >>>
163- >>> kpcovr = KernelPCovR(mixing=0.1, n_components=2, kernel='rbf', gamma=2 )
178+ >>> kpcovr = KernelPCovR(mixing=0.1, n_components=2, regressor=KernelRidge( kernel='rbf', gamma=1), kernel='rbf', gamma=1 )
164179 >>> kpcovr.fit(X, Y)
165- KernelPCovR(coef0=1, degree=3, fit_inverse_transform=False, gamma=0.01, kernel='rbf',
166- kernel_params=None, mixing=None, n_components=2, n_jobs=None,
167- alpha=None, tol=1e-12)
180+ KernelPCovR(gamma=1, kernel='rbf', mixing=0.1, n_components=2,
181+ regressor=KernelRidge(gamma=1, kernel='rbf'))
168182 >>> T = kpcovr.transform(X)
169- [[ 1.01199065 , -0.35439061 ],
170- [-0.68099591 , 0.48912275 ],
171- [ 1.4677616 , 0.13757037 ],
172- [-1.79874193 , -0.27232032 ]]
183+ [[-0.61261285 , -0.18937908 ],
184+ [ 0.45242098 , 0.25453465 ],
185+ [-0.77871824 , 0.04847559 ],
186+ [ 0.91186937 , -0.21211816 ]]
173187 >>> Yp = kpcovr.predict(X)
174- [[-0.01044648 , -0.84443158 ],
175- [-0.1758848 , 0.16224503 ],
176- [ 0.1573037 , -0.84211944 ],
177- [-0.51133139 , 0.32552881 ]]
188+ [[ 0.5100212 , -0.99488463 ],
189+ [-0.18992219 , 0.82064368 ],
190+ [ 1.11923584 , -1.04798016 ],
191+ [-1.5635827 , 1.11078662 ]]
178192 >>> kpcovr.score(X, Y)
179- (0.5312320029915978, 0.06254540655698511)
193+ -0.520388347837897
180194 """
181195
182196 def __init__ (
183197 self ,
184198 mixing = 0.5 ,
185199 n_components = None ,
186200 svd_solver = "auto" ,
201+ regressor = None ,
187202 kernel = "linear" ,
188203 gamma = None ,
189204 degree = 3 ,
190205 coef0 = 1 ,
191- alpha = 1e-6 ,
192206 kernel_params = None ,
193207 center = False ,
194208 fit_inverse_transform = False ,
195209 tol = 1e-12 ,
196210 n_jobs = None ,
197211 iterated_power = "auto" ,
198212 random_state = None ,
213+ ** regressor_params
199214 ):
200215
201216 self .mixing = mixing
202217 self .n_components = n_components
203- self .alpha = alpha
204218
205219 self .svd_solver = svd_solver
206220 self .tol = tol
@@ -209,15 +223,19 @@ def __init__(
209223 self .center = center
210224
211225 self .kernel = kernel
212- self .kernel_params = kernel_params
213226 self .gamma = gamma
214227 self .degree = degree
215228 self .coef0 = coef0
229+ self .kernel_params = kernel_params
230+
216231 self .n_jobs = n_jobs
217232 self .n_samples_ = None
218233
219234 self .fit_inverse_transform = fit_inverse_transform
220235
236+ self .regressor = regressor
237+ self .regressor_params = regressor_params
238+
221239 def _get_kernel (self , X , Y = None ):
222240 if callable (self .kernel ):
223241 params = self .kernel_params or {}
@@ -252,9 +270,9 @@ def _fit(self, K, Yhat, W):
252270 self .pkt_ = P @ U @ np .sqrt (np .diagflat (S_inv ))
253271
254272 T = K @ self .pkt_
255- self .pt__ = np .linalg .lstsq (T , np .eye (T .shape [0 ]), rcond = self .alpha )[0 ]
273+ self .pt__ = np .linalg .lstsq (T , np .eye (T .shape [0 ]), rcond = self .tol )[0 ]
256274
257- def fit (self , X , Y , Yhat = None , W = None ):
275+ def fit (self , X , Y ):
258276 """
259277
260278 Fit the model with X and Y.
@@ -279,18 +297,16 @@ def fit(self, X, Y, Yhat=None, W=None):
279297 to have unit variance, otherwise :math:`\\ mathbf{Y}` should be
280298 scaled so that each feature has a variance of 1 / n_features.
281299
282- Yhat: ndarray, shape (n_samples, n_properties), optional
283- Regressed training data, where n_samples is the number of samples and
284- n_properties is the number of properties. If not supplied, computed
285- by ridge regression.
286-
287300 Returns
288301 -------
289302 self: object
290303 Returns the instance itself.
291304
292305 """
293306
307+ if self .regressor is not None and not isinstance (self .regressor , KernelRidge ):
308+ raise ValueError ("Regressor must be an instance of `KernelRidge`" )
309+
294310 X , Y = check_X_y (X , Y , y_numeric = True , multi_output = True )
295311 self .X_fit_ = X .copy ()
296312
@@ -308,14 +324,66 @@ def fit(self, X, Y, Yhat=None, W=None):
308324
309325 self .n_samples_ = X .shape [0 ]
310326
311- if W is None :
312- if Yhat is None :
313- W = (np .linalg .lstsq (K , Y , rcond = self .alpha )[0 ]).reshape (X .shape [0 ], - 1 )
314- else :
315- W = np .linalg .lstsq (K , Yhat , rcond = self .alpha )[0 ]
327+ if self .regressor is None :
328+ regressor = KernelRidge (
329+ kernel = self .kernel ,
330+ gamma = self .gamma ,
331+ degree = self .degree ,
332+ coef0 = self .coef0 ,
333+ kernel_params = self .kernel_params ,
334+ ** self .regressor_params ,
335+ )
336+ else :
337+ regressor = self .regressor
338+ kernel_attrs = ["kernel" , "gamma" , "degree" , "coef0" , "kernel_params" ]
339+ if not all (
340+ [
341+ getattr (self , attr ) == getattr (regressor , attr )
342+ for attr in kernel_attrs
343+ ]
344+ ):
345+ raise ValueError (
346+ "Kernel parameter mismatch: the regressor has kernel parameters {%s}"
347+ " and KernelPCovR was initialized with kernel parameters {%s}"
348+ % (
349+ ", " .join (
350+ [
351+ "%s: %r" % (attr , getattr (regressor , attr ))
352+ for attr in kernel_attrs
353+ ]
354+ ),
355+ ", " .join (
356+ [
357+ "%s: %r" % (attr , getattr (self , attr ))
358+ for attr in kernel_attrs
359+ ]
360+ ),
361+ )
362+ )
363+
364+ # Check if regressor is fitted; if not, fit with precomputed K
365+ # to avoid needing to compute the kernel a second time
366+ self .regressor_ = check_krr_fit (regressor , K , X , Y )
316367
317- if Yhat is None :
318- Yhat = K @ W
368+ W = self .regressor_ .dual_coef_ .reshape (X .shape [0 ], - 1 )
369+
370+ # Use this instead of `self.regressor_.predict(K)`
371+ # so that we can handle the case of the pre-fitted regressor
372+ Yhat = K @ W
373+
374+ # When we have an unfitted regressor,
375+ # we fit it with a precomputed K
376+ # so we must subsequently "reset" it so that
377+ # it will work on the particular X
378+ # of the KPCovR call. The dual coefficients are kept.
379+ # Can be bypassed if the regressor is pre-fitted.
380+ try :
381+ check_is_fitted (regressor )
382+
383+ except NotFittedError :
384+ self .regressor_ .set_params (** regressor .get_params ())
385+ self .regressor_ .X_fit_ = self .X_fit_
386+ self .regressor_ ._check_n_features (self .X_fit_ , reset = True )
319387
320388 # Handle svd_solver
321389 self ._fit_svd_solver = self .svd_solver
@@ -408,7 +476,7 @@ def inverse_transform(self, T):
408476
409477 def score (self , X , Y ):
410478 r"""
411- Computes the loss values for KernelPCovR on the given predictor and
479+ Computes the (negative) loss values for KernelPCovR on the given predictor and
412480 response variables. The loss in :math:`\mathbf{K}`, as explained in
413481 [Helfrecht2020]_ does not correspond to a traditional Gram loss
414482 :math:`\mathbf{K} - \mathbf{TT}^T`. Indicating the kernel between set
@@ -424,15 +492,17 @@ def score(self, X, Y):
424492 \mathbf{K}_{NN} \mathbf{T}_N (\mathbf{T}_N^T \mathbf{T}_N)^{-1}
425493 \mathbf{T}_V^T\right]}{\operatorname{Tr}(\mathbf{K}_{VV})}
426494
495+ The negative loss is returned for easier use in sklearn pipelines, e.g., a grid search, where methods named 'score' are meant to be maximized.
496+
427497 Arguments
428498 ---------
429499 X: independent (predictor) variable
430500 Y: dependent (response) variable
431501
432502 Returns
433503 -------
434- Lk : KPCA loss, determined by the reconstruction of the kernel
435- Ly: KR loss
504+ L : Negative sum of the KPCA and KRR losses, with the KPCA loss
505+ determined by the reconstruction of the kernel
436506
437507 """
438508
@@ -455,10 +525,14 @@ def score(self, X, Y):
455525 t_n = K_NN @ self .pkt_
456526 t_v = K_VN @ self .pkt_
457527
458- w = t_n @ np .linalg .pinv (t_n .T @ t_n , rcond = self .alpha ) @ t_v .T
528+ w = (
529+ t_n
530+ @ np .linalg .lstsq (t_n .T @ t_n , np .eye (t_n .shape [1 ]), rcond = self .tol )[0 ]
531+ @ t_v .T
532+ )
459533 Lkpca = np .trace (K_VV - 2 * K_VN @ w + w .T @ K_VV @ w ) / np .trace (K_VV )
460534
461- return sum ([Lkpca , Lkrr ])
535+ return - sum ([Lkpca , Lkrr ])
462536
463537 def _decompose_truncated (self , mat ):
464538
0 commit comments