From 341bf2fb1359bba1b524dee5a3049b9886f24924 Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Thu, 6 Apr 2023 18:00:06 +0200 Subject: [PATCH 1/3] implement decorator && docs --- skglm/penalties/base.py | 71 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 71 insertions(+) diff --git a/skglm/penalties/base.py b/skglm/penalties/base.py index b45254b71..165cd58cc 100644 --- a/skglm/penalties/base.py +++ b/skglm/penalties/base.py @@ -1,3 +1,4 @@ +from numba import float64 from abc import abstractmethod @@ -60,3 +61,73 @@ def is_penalized(self, n_features): @abstractmethod def generalized_support(self, w): r"""Return a mask which is True for coefficients in the generalized support.""" + + +def overload_with_l2(cls): + """Decorate a penalty class to add L2 regularization. + + The resulting penalty reads + + .. math:: + + "penalty"(w) + "l2"_"regularization" xx ||w||**2 / 2 + + Parameters + ---------- + cls: Penalty class + The penalty class to be overloaded with L2 regularization. + + Return + ------ + cls: Penalty class + Penalty overloaded with L2 regularization. + """ + # keep ref to original methods + cls_constructor = cls.__init__ + cls_prox_1d = cls.prox_1d + cls_value = cls.value + cls_subdiff_distance = cls. subdiff_distance + cls_params_to_dict = cls.params_to_dict + cls_get_spec = cls.get_spec + + # implement new methods + def __init__(self, *args, l2_regularization=0., **kwargs): + cls_constructor(self, *args, **kwargs) + self.l2_regularization = l2_regularization + + def prox_1d(self, value, stepsize, j): + if self.l2_regularization == 0.: + return cls_prox_1d(self, value, stepsize, j) + + scale = 1 + stepsize * self.l2_regularization + return cls_prox_1d(self, value / scale, stepsize / scale, j) + + def value(self, w): + l2_regularization = self.l2_regularization + if l2_regularization == 0.: + return cls_value(self, w) + + return cls_value(self, w) + l2_regularization * 0.5 * w ** 2 + + def subdiff_distance(self, w, grad, ws): + if self.l2_regularization == 0.: + return cls_subdiff_distance(self, w, grad, ws) + + return cls_subdiff_distance(self, w, grad + self.l2_regularization * w[ws], ws) + + def get_spec(self): + return (('l2_regularization', float64), *cls_get_spec(self)) + + def params_to_dict(self): + return dict(l2_regularization=self.l2_regularization, + **cls_params_to_dict(self)) + + # override methods + cls.__init__ = __init__ + cls.value = value + cls.prox_1d = prox_1d + cls.subdiff_distance = subdiff_distance + cls.get_spec = get_spec + cls.params_to_dict = params_to_dict + + return cls From 5b0d5ef641cdb7a0a0632cdfad7395d1a4a91268 Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Thu, 6 Apr 2023 18:00:40 +0200 Subject: [PATCH 2/3] unittest for elasticnet --- skglm/penalties/separable.py | 8 +++++++- skglm/tests/test_penalties.py | 38 +++++++++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 1 deletion(-) diff --git a/skglm/penalties/separable.py b/skglm/penalties/separable.py index 2c1429a87..4901dbd98 100644 --- a/skglm/penalties/separable.py +++ b/skglm/penalties/separable.py @@ -2,7 +2,7 @@ from numba import float64 from numba.types import bool_ -from skglm.penalties.base import BasePenalty +from skglm.penalties.base import BasePenalty, overload_with_l2 from skglm.utils.prox_funcs import ( ST, box_proj, prox_05, prox_2_3, prox_SCAD, value_SCAD, prox_MCP, value_MCP) @@ -67,6 +67,12 @@ def alpha_max(self, gradient0): return np.max(np.abs(gradient0)) +# To add support of L2 regularization, one needs to decorate the penalty +@overload_with_l2 +class _TestL1(L1): + pass + + class L1_plus_L2(BasePenalty): """:math:`ell_1 + ell_2` penalty (aka ElasticNet penalty).""" diff --git a/skglm/tests/test_penalties.py b/skglm/tests/test_penalties.py index cafeb9d03..3a5b2dcf4 100644 --- a/skglm/tests/test_penalties.py +++ b/skglm/tests/test_penalties.py @@ -14,6 +14,9 @@ from skglm.solvers import AndersonCD, MultiTaskBCD, FISTA from skglm.utils.data import make_correlated_data +from skglm.penalties.separable import _TestL1 +from skglm.utils.jit_compilation import compiled_clone + n_samples = 20 n_features = 10 @@ -118,5 +121,40 @@ def test_nnls(fit_intercept): np.testing.assert_allclose(clf.intercept_, reg_nnls.intercept_) +def test_overload_with_l2_ElasticNet(): + lmbd = 0.2 + l1_ratio = 0.7 + + elastic_net = L1_plus_L2(lmbd, l1_ratio) + implicit_elastic_net = _TestL1(alpha=lmbd * l1_ratio, + l2_regularization=lmbd * (1 - l1_ratio)) + + n_feautures, ws_size = 5, 3 + stepsize = 0.8 + + rng = np.random.RandomState(425) + w = rng.randn(n_feautures) + grad = rng.randn(ws_size) + ws = rng.choice(n_feautures, size=ws_size, replace=False) + + x = w[2] + np.testing.assert_equal( + elastic_net.value(x), + implicit_elastic_net.value(x) + ) + np.testing.assert_equal( + elastic_net.prox_1d(x, stepsize, 0), + implicit_elastic_net.prox_1d(x, stepsize, 0) + ) + np.testing.assert_array_equal( + elastic_net.subdiff_distance(w, grad, ws), + implicit_elastic_net.subdiff_distance(w, grad, ws) + ) + + # This will raise an error as *args and **kwargs are not supported in numba + with pytest.raises(Exception, match=r"VAR_POSITIONAL.*unsupported.*jitclass"): + compiled_clone(implicit_elastic_net) + + if __name__ == "__main__": pass From 23bb6a471ff50689cfaddc80e78af55b1ce35ab1 Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Thu, 6 Apr 2023 19:34:32 +0200 Subject: [PATCH 3/3] linter happy --- skglm/penalties/base.py | 38 +++++++++++++++++++------------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/skglm/penalties/base.py b/skglm/penalties/base.py index 165cd58cc..f9b68bb1f 100644 --- a/skglm/penalties/base.py +++ b/skglm/penalties/base.py @@ -63,7 +63,7 @@ def generalized_support(self, w): r"""Return a mask which is True for coefficients in the generalized support.""" -def overload_with_l2(cls): +def overload_with_l2(klass): """Decorate a penalty class to add L2 regularization. The resulting penalty reads @@ -74,21 +74,21 @@ def overload_with_l2(cls): Parameters ---------- - cls: Penalty class + klass : Penalty class The penalty class to be overloaded with L2 regularization. - Return - ------ - cls: Penalty class + Returns + ------- + klass : Penalty class Penalty overloaded with L2 regularization. """ # keep ref to original methods - cls_constructor = cls.__init__ - cls_prox_1d = cls.prox_1d - cls_value = cls.value - cls_subdiff_distance = cls. subdiff_distance - cls_params_to_dict = cls.params_to_dict - cls_get_spec = cls.get_spec + cls_constructor = klass.__init__ + cls_prox_1d = klass.prox_1d + cls_value = klass.value + cls_subdiff_distance = klass. subdiff_distance + cls_params_to_dict = klass.params_to_dict + cls_get_spec = klass.get_spec # implement new methods def __init__(self, *args, l2_regularization=0., **kwargs): @@ -123,11 +123,11 @@ def params_to_dict(self): **cls_params_to_dict(self)) # override methods - cls.__init__ = __init__ - cls.value = value - cls.prox_1d = prox_1d - cls.subdiff_distance = subdiff_distance - cls.get_spec = get_spec - cls.params_to_dict = params_to_dict - - return cls + klass.__init__ = __init__ + klass.value = value + klass.prox_1d = prox_1d + klass.subdiff_distance = subdiff_distance + klass.get_spec = get_spec + klass.params_to_dict = params_to_dict + + return klass