From a3c7a53c3edc6f74ce235e4a2b3f8d1ca5697e3c Mon Sep 17 00:00:00 2001 From: mathurinm Date: Thu, 5 May 2022 17:31:27 +0200 Subject: [PATCH 01/20] draft flexible gram solver with penalty and using datafit --- skglm/solvers/gram.py | 68 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) create mode 100644 skglm/solvers/gram.py diff --git a/skglm/solvers/gram.py b/skglm/solvers/gram.py new file mode 100644 index 000000000..a10d05912 --- /dev/null +++ b/skglm/solvers/gram.py @@ -0,0 +1,68 @@ +import numpy as np +from numba import njit +# from numpy.linalg import norm + +from skglm.utils import BST, ST, ST_vec +from skglm.datafits import Quadratic + + +def cd_gram_quadratic(X, y, penalty, max_iter, tol, w_init=None, check_freq=100): + """Gram solver for quadratic datafit.""" + n_features = X.shape[1] + datafit = Quadratic() + datafit.initialize(X, y) # todo sparse + G = X.T @ X # gram matrix + grads = X.T @ y / len(y) # this is wrong if an init is used + w = w_init.copy() if w_init is not None else np.zeros(n_features) + for n_iter in range(max_iter): + _cd_epoch_gram(X, G, grads, w, penalty, datafit) + if n_iter % check_freq == 0: + # check KKT + # print(f"iter {n_iter} :: p_obj {p_obj:.5f} :: d_obj {d_obj:.5f}" + + # f" :: gap {d_gap:.5f}") + # if d_gap < tol: + # print("Convergence reached!") + # break + return w + + +def fista_gram_quadratic( + X, y, penalty, max_iter, tol, w_init=None, check_freq=100): + n_samples, n_features = X.shape + norm_y2 = y @ y + t_new = 1 + w = w_init.copy() if w_init is not None else np.zeros(n_features) + z = w_init.copy() if w_init is not None else np.zeros(n_features) + G = X.T @ X + Xty = X.T @ y + L = np.linalg.norm(X, ord=2) ** 2 / n_samples + for n_iter in range(max_iter): + t_old = t_new + t_new = (1 + np.sqrt(1 + 4 * t_old ** 2)) / 2 + w_old = w.copy() + z -= (G @ z - Xty) / L / n_samples + w = ST_vec(z, alpha / L) # use penalty.prox + z = w + (t_old - 1.) / t_new * (w - w_old) + if n_iter % check_freq == 0: + pass + # use KKT instead + # print(f"iter {n_iter} :: p_obj {p_obj:.5f} :: d_obj {d_obj:.5f} " + + # f":: gap {d_gap:.5f}") + # if d_gap < tol: + # print("Convergence reached!") + # break + return w + + +@njit +def _cd_epoch_gram(X, G, grads, w, penalty, datafit): + n_features = X.shape[1] + for j in range(n_features): + if lipschitz[j] == 0: + continue + old_w_j = w[j] + # use penalty.prox1d + w[j] = ST(w[j] + grads[j] / datafit.lipschitz[j], + alpha / lipschitz[j] * ) + if old_w_j != w[j]: + grads += G[j, :] * (old_w_j - w[j]) / len(X) From 34db4fc531e1a54762b4731f0efbe0332db4e227 Mon Sep 17 00:00:00 2001 From: Pierre-Antoine Bannier Date: Tue, 10 May 2022 14:35:20 +0200 Subject: [PATCH 02/20] fix wrong docstring --- skglm/solvers/cd_solver.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/skglm/solvers/cd_solver.py b/skglm/solvers/cd_solver.py index c8777e0f1..60a350fb9 100644 --- a/skglm/solvers/cd_solver.py +++ b/skglm/solvers/cd_solver.py @@ -214,11 +214,11 @@ def cd_solver( Returns ------- - alphas : array, shape (n_alphas,) - The alphas along the path where models are computed. + w : array, shape (n_features,) + Coefficients. - coefs : array, shape (n_features, n_alphas) - Coefficients along the path. + obj_out : array, shape (n_iter,) + Objective value at every outer iteration. stop_crit : array, shape (n_alphas,) Value of stopping criterion at convergence along the path. From 2110db1fb65e1b0901b11b4c3732d63e30f6743e Mon Sep 17 00:00:00 2001 From: Pierre-Antoine Bannier Date: Tue, 10 May 2022 14:36:18 +0200 Subject: [PATCH 03/20] reorg cd_gram_quadratic --- skglm/solvers/gram.py | 150 +++++++++++++++++++++++++++++++----------- 1 file changed, 110 insertions(+), 40 deletions(-) diff --git a/skglm/solvers/gram.py b/skglm/solvers/gram.py index a10d05912..5c1c1e65e 100644 --- a/skglm/solvers/gram.py +++ b/skglm/solvers/gram.py @@ -1,57 +1,98 @@ import numpy as np from numba import njit +from scipy import sparse +from skglm.solvers.cd_solver import ( + construct_grad, construct_grad_sparse, dist_fix_point) # from numpy.linalg import norm -from skglm.utils import BST, ST, ST_vec +from skglm.utils import ST, ST_vec from skglm.datafits import Quadratic -def cd_gram_quadratic(X, y, penalty, max_iter, tol, w_init=None, check_freq=100): - """Gram solver for quadratic datafit.""" +def cd_gram_quadratic(X, y, penalty, max_iter=100, tol=1e-4, w_init=None, + ws_strategy="subdiff", verbose=0): + r"""Run a coordinate descent solver using Gram update for quadratic datafit. + + This solver should be used when n_samples >> n_features. It does not implement any + working set strategy and iteratively updates the gradients (n_features,) instead of + the residuals (n_samples,). + + Parameters + ---------- + X : array, shape (n_samples, n_features) + Training data. + + y : array, shape (n_samples,) + Target values. + + penalty : instance of Penalty class + Penalty used in the model. + + max_iter : int, optional + Maximum number of CD epochs. + + tol : float, optional + The tolerance for the optimization. + + ws_strategy : ('subdiff'|'fixpoint'), optional + The score used to compute the stopping criterion. + + verbose : bool or int, optional + Amount of verbosity. 0/False is silent. + + Returns + ------- + w : array, shape (n_features,) + Coefficients. + + obj_out : array, shape (n_iter,) + Objective value at every outer iteration. + + stop_crit : array, shape (n_alphas,) + Value of stopping criterion at convergence along the path. + """ n_features = X.shape[1] + all_feats = np.arange(n_features) + obj_out = [] datafit = Quadratic() - datafit.initialize(X, y) # todo sparse + is_sparse = sparse.issparse(X) + if is_sparse: + datafit.initialize_sparse(X.data, X.indptr, X.indices, y) + else: + datafit.initialize(X, y) G = X.T @ X # gram matrix - grads = X.T @ y / len(y) # this is wrong if an init is used + grads = (X.T @ y - G @ w_init) / len(y) if w_init is not None else X.T @ y / len(y) w = w_init.copy() if w_init is not None else np.zeros(n_features) for n_iter in range(max_iter): - _cd_epoch_gram(X, G, grads, w, penalty, datafit) - if n_iter % check_freq == 0: - # check KKT - # print(f"iter {n_iter} :: p_obj {p_obj:.5f} :: d_obj {d_obj:.5f}" + - # f" :: gap {d_gap:.5f}") - # if d_gap < tol: - # print("Convergence reached!") - # break - return w + if is_sparse: + _cd_epoch_gram_sparse( + X.data, X.indptr, X.indices, G, grads, w, penalty, datafit) + else: + _cd_epoch_gram(X, G, grads, w, penalty, datafit) + if n_iter % 50 == 0: + # TODO: X @ w + Xw = X @ w + p_obj = datafit.value(y, w, Xw) + penalty.value(w) + if is_sparse: + grad = construct_grad_sparse( + X.data, X.indptr, X.indices, y, w, Xw, datafit, all_feats) + else: + grad = construct_grad(X, y, w, Xw, datafit, all_feats) + if ws_strategy == "subdiff": + opt_ws = penalty.subdiff_distance(w, grad, all_feats) + elif ws_strategy == "fixpoint": + opt_ws = dist_fix_point(w, grad, datafit, penalty, all_feats) + + stop_crit = np.max(opt_ws) + if max(verbose - 1, 0): + print(f"Epoch {n_iter + 1}, objective {p_obj:.10f}, " + f"stopping crit {stop_crit:.2e}") + if stop_crit <= tol: + break + obj_out.append(p_obj) + return w, np.array(obj_out), stop_crit -def fista_gram_quadratic( - X, y, penalty, max_iter, tol, w_init=None, check_freq=100): - n_samples, n_features = X.shape - norm_y2 = y @ y - t_new = 1 - w = w_init.copy() if w_init is not None else np.zeros(n_features) - z = w_init.copy() if w_init is not None else np.zeros(n_features) - G = X.T @ X - Xty = X.T @ y - L = np.linalg.norm(X, ord=2) ** 2 / n_samples - for n_iter in range(max_iter): - t_old = t_new - t_new = (1 + np.sqrt(1 + 4 * t_old ** 2)) / 2 - w_old = w.copy() - z -= (G @ z - Xty) / L / n_samples - w = ST_vec(z, alpha / L) # use penalty.prox - z = w + (t_old - 1.) / t_new * (w - w_old) - if n_iter % check_freq == 0: - pass - # use KKT instead - # print(f"iter {n_iter} :: p_obj {p_obj:.5f} :: d_obj {d_obj:.5f} " + - # f":: gap {d_gap:.5f}") - # if d_gap < tol: - # print("Convergence reached!") - # break - return w @njit @@ -66,3 +107,32 @@ def _cd_epoch_gram(X, G, grads, w, penalty, datafit): alpha / lipschitz[j] * ) if old_w_j != w[j]: grads += G[j, :] * (old_w_j - w[j]) / len(X) + + + +# def fista_gram_quadratic( +# X, y, penalty, max_iter, tol, w_init=None, check_freq=100): +# n_samples, n_features = X.shape +# norm_y2 = y @ y +# t_new = 1 +# w = w_init.copy() if w_init is not None else np.zeros(n_features) +# z = w_init.copy() if w_init is not None else np.zeros(n_features) +# G = X.T @ X +# Xty = X.T @ y +# L = np.linalg.norm(X, ord=2) ** 2 / n_samples +# for n_iter in range(max_iter): +# t_old = t_new +# t_new = (1 + np.sqrt(1 + 4 * t_old ** 2)) / 2 +# w_old = w.copy() +# z -= (G @ z - Xty) / L / n_samples +# w = ST_vec(z, alpha / L) # use penalty.prox +# z = w + (t_old - 1.) / t_new * (w - w_old) +# if n_iter % check_freq == 0: +# pass +# # use KKT instead +# # print(f"iter {n_iter} :: p_obj {p_obj:.5f} :: d_obj {d_obj:.5f} " + +# # f":: gap {d_gap:.5f}") +# # if d_gap < tol: +# # print("Convergence reached!") +# # break +# return w From 247bb75b09c87c4f37c7d1f7753cc3825d4e9b75 Mon Sep 17 00:00:00 2001 From: Pierre-Antoine Bannier Date: Tue, 10 May 2022 14:50:37 +0200 Subject: [PATCH 04/20] fix cd epoch --- skglm/solvers/gram.py | 56 ++++++++++++++++++++++++++++--------------- 1 file changed, 37 insertions(+), 19 deletions(-) diff --git a/skglm/solvers/gram.py b/skglm/solvers/gram.py index 5c1c1e65e..75d2121c6 100644 --- a/skglm/solvers/gram.py +++ b/skglm/solvers/gram.py @@ -12,7 +12,7 @@ def cd_gram_quadratic(X, y, penalty, max_iter=100, tol=1e-4, w_init=None, ws_strategy="subdiff", verbose=0): r"""Run a coordinate descent solver using Gram update for quadratic datafit. - + This solver should be used when n_samples >> n_features. It does not implement any working set strategy and iteratively updates the gradients (n_features,) instead of the residuals (n_samples,). @@ -21,25 +21,25 @@ def cd_gram_quadratic(X, y, penalty, max_iter=100, tol=1e-4, w_init=None, ---------- X : array, shape (n_samples, n_features) Training data. - + y : array, shape (n_samples,) Target values. - + penalty : instance of Penalty class Penalty used in the model. - + max_iter : int, optional Maximum number of CD epochs. - + tol : float, optional The tolerance for the optimization. - + ws_strategy : ('subdiff'|'fixpoint'), optional The score used to compute the stopping criterion. - + verbose : bool or int, optional Amount of verbosity. 0/False is silent. - + Returns ------- w : array, shape (n_features,) @@ -64,11 +64,7 @@ def cd_gram_quadratic(X, y, penalty, max_iter=100, tol=1e-4, w_init=None, grads = (X.T @ y - G @ w_init) / len(y) if w_init is not None else X.T @ y / len(y) w = w_init.copy() if w_init is not None else np.zeros(n_features) for n_iter in range(max_iter): - if is_sparse: - _cd_epoch_gram_sparse( - X.data, X.indptr, X.indices, G, grads, w, penalty, datafit) - else: - _cd_epoch_gram(X, G, grads, w, penalty, datafit) + _cd_epoch_gram(X, G, grads, w, datafit, penalty) if n_iter % 50 == 0: # TODO: X @ w Xw = X @ w @@ -96,20 +92,42 @@ def cd_gram_quadratic(X, y, penalty, max_iter=100, tol=1e-4, w_init=None, @njit -def _cd_epoch_gram(X, G, grads, w, penalty, datafit): +def _cd_epoch_gram(X, G, grads, w, datafit, penalty): + """Run an epoch of coordinate descent in place with gradient update using Gram. + + Parameters + ---------- + X : array, shape (n_samples, n_features) + Design matrix. + + G : array, shape (n_features, n_features) + Gram matrix. + + grads : array, shape (n_features,) + Gradient vector. + + w : array, shape (n_features,) + Coefficient vector. + + datafit : Datafit + Datafit. + + penalty : Penalty + Penalty. + """ + # TODO: sparse matrix n_features = X.shape[1] + lc = datafit.lipschitz for j in range(n_features): - if lipschitz[j] == 0: + if lc[j] == 0: continue old_w_j = w[j] - # use penalty.prox1d - w[j] = ST(w[j] + grads[j] / datafit.lipschitz[j], - alpha / lipschitz[j] * ) + stepsize = 1 / lc[j] if lc[j] != 0 else 1000 + w[j] = penalty.prox_1d(old_w_j + grads[j] / lc[j], stepsize, j) if old_w_j != w[j]: grads += G[j, :] * (old_w_j - w[j]) / len(X) - # def fista_gram_quadratic( # X, y, penalty, max_iter, tol, w_init=None, check_freq=100): # n_samples, n_features = X.shape From 11742dcbb537c36d06bab8e1c9d5db4e12db9da1 Mon Sep 17 00:00:00 2001 From: Pierre-Antoine Bannier Date: Tue, 10 May 2022 15:04:09 +0200 Subject: [PATCH 05/20] green --- skglm/solvers/gram.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/skglm/solvers/gram.py b/skglm/solvers/gram.py index 75d2121c6..7c55d6863 100644 --- a/skglm/solvers/gram.py +++ b/skglm/solvers/gram.py @@ -1,11 +1,9 @@ import numpy as np from numba import njit from scipy import sparse + from skglm.solvers.cd_solver import ( construct_grad, construct_grad_sparse, dist_fix_point) -# from numpy.linalg import norm - -from skglm.utils import ST, ST_vec from skglm.datafits import Quadratic @@ -34,6 +32,9 @@ def cd_gram_quadratic(X, y, penalty, max_iter=100, tol=1e-4, w_init=None, tol : float, optional The tolerance for the optimization. + w_init : array, shape (n_features,), optional + Initial coefficient vector. + ws_strategy : ('subdiff'|'fixpoint'), optional The score used to compute the stopping criterion. @@ -90,7 +91,6 @@ def cd_gram_quadratic(X, y, penalty, max_iter=100, tol=1e-4, w_init=None, return w, np.array(obj_out), stop_crit - @njit def _cd_epoch_gram(X, G, grads, w, datafit, penalty): """Run an epoch of coordinate descent in place with gradient update using Gram. From 4a44957315e150f085c18bccef97937e31abae55 Mon Sep 17 00:00:00 2001 From: Pierre-Antoine Bannier Date: Tue, 10 May 2022 15:25:18 +0200 Subject: [PATCH 06/20] ERR circular import --- skglm/solvers/cd_solver.py | 17 +++++++++++++---- skglm/solvers/gram.py | 12 ++++++------ 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/skglm/solvers/cd_solver.py b/skglm/solvers/cd_solver.py index 60a350fb9..ffdcc1068 100644 --- a/skglm/solvers/cd_solver.py +++ b/skglm/solvers/cd_solver.py @@ -2,6 +2,8 @@ from numba import njit from scipy import sparse from sklearn.utils import check_array +from skglm.datafits.single_task import Quadratic +from skglm.solvers.gram import cd_gram_quadratic def cd_solver_path(X, y, datafit, penalty, alphas=None, @@ -109,6 +111,7 @@ def cd_solver_path(X, y, datafit, penalty, alphas=None, # else: # alphas = np.sort(alphas)[::-1] + n_samples = len(y) n_alphas = len(alphas) coefs = np.zeros((n_features, n_alphas), order='F', dtype=X.dtype) @@ -144,10 +147,16 @@ def cd_solver_path(X, y, datafit, penalty, alphas=None, w = np.zeros(n_features, dtype=X.dtype) Xw = np.zeros(X.shape[0], dtype=X.dtype) - sol = cd_solver( - X, y, datafit, penalty, w, Xw, - max_iter=max_iter, max_epochs=max_epochs, p0=p0, tol=tol, - use_acc=use_acc, verbose=verbose, ws_strategy=ws_strategy) + if isinstance(datafit, Quadratic) and n_samples > n_features * 10: + # XXX: does n_samples > n_features * 10 look correct? + sol = cd_gram_quadratic( + X, y, penalty, max_epochs=max_epochs, tol=tol, w_init=None, + ws_strategy=ws_strategy, verbose=verbose) + else: + sol = cd_solver( + X, y, datafit, penalty, w, Xw, + max_iter=max_iter, max_epochs=max_epochs, p0=p0, tol=tol, + use_acc=use_acc, verbose=verbose, ws_strategy=ws_strategy) coefs[:, t] = w stop_crits[t] = sol[-1] diff --git a/skglm/solvers/gram.py b/skglm/solvers/gram.py index 7c55d6863..791da1a26 100644 --- a/skglm/solvers/gram.py +++ b/skglm/solvers/gram.py @@ -2,12 +2,12 @@ from numba import njit from scipy import sparse +from skglm.datafits import Quadratic from skglm.solvers.cd_solver import ( construct_grad, construct_grad_sparse, dist_fix_point) -from skglm.datafits import Quadratic -def cd_gram_quadratic(X, y, penalty, max_iter=100, tol=1e-4, w_init=None, +def cd_gram_quadratic(X, y, penalty, max_epochs=1000, tol=1e-4, w_init=None, ws_strategy="subdiff", verbose=0): r"""Run a coordinate descent solver using Gram update for quadratic datafit. @@ -26,7 +26,7 @@ def cd_gram_quadratic(X, y, penalty, max_iter=100, tol=1e-4, w_init=None, penalty : instance of Penalty class Penalty used in the model. - max_iter : int, optional + max_epochs : int, optional Maximum number of CD epochs. tol : float, optional @@ -64,9 +64,9 @@ def cd_gram_quadratic(X, y, penalty, max_iter=100, tol=1e-4, w_init=None, G = X.T @ X # gram matrix grads = (X.T @ y - G @ w_init) / len(y) if w_init is not None else X.T @ y / len(y) w = w_init.copy() if w_init is not None else np.zeros(n_features) - for n_iter in range(max_iter): + for epoch in range(max_epochs): _cd_epoch_gram(X, G, grads, w, datafit, penalty) - if n_iter % 50 == 0: + if epoch % 50 == 0: # TODO: X @ w Xw = X @ w p_obj = datafit.value(y, w, Xw) + penalty.value(w) @@ -83,7 +83,7 @@ def cd_gram_quadratic(X, y, penalty, max_iter=100, tol=1e-4, w_init=None, stop_crit = np.max(opt_ws) if max(verbose - 1, 0): - print(f"Epoch {n_iter + 1}, objective {p_obj:.10f}, " + print(f"Epoch {epoch + 1}, objective {p_obj:.10f}, " f"stopping crit {stop_crit:.2e}") if stop_crit <= tol: break From a85f5cfeed405e1d48b1872ae742c7dc22d98ea0 Mon Sep 17 00:00:00 2001 From: Pierre-Antoine Bannier Date: Tue, 10 May 2022 15:42:37 +0200 Subject: [PATCH 07/20] fix circular import --- skglm/solvers/cd_solver.py | 121 ++----------------------------------- skglm/solvers/cd_utils.py | 114 ++++++++++++++++++++++++++++++++++ skglm/solvers/gram.py | 2 +- 3 files changed, 121 insertions(+), 116 deletions(-) create mode 100644 skglm/solvers/cd_utils.py diff --git a/skglm/solvers/cd_solver.py b/skglm/solvers/cd_solver.py index ffdcc1068..277fe1827 100644 --- a/skglm/solvers/cd_solver.py +++ b/skglm/solvers/cd_solver.py @@ -2,7 +2,9 @@ from numba import njit from scipy import sparse from sklearn.utils import check_array -from skglm.datafits.single_task import Quadratic +from skglm.datafits.single_task import Quadratic, Quadratic_32 +from skglm.solvers.cd_utils import ( + dist_fix_point, construct_grad, construct_grad_sparse) from skglm.solvers.gram import cd_gram_quadratic @@ -147,8 +149,9 @@ def cd_solver_path(X, y, datafit, penalty, alphas=None, w = np.zeros(n_features, dtype=X.dtype) Xw = np.zeros(X.shape[0], dtype=X.dtype) - if isinstance(datafit, Quadratic) and n_samples > n_features * 10: - # XXX: does n_samples > n_features * 10 look correct? + if (isinstance(datafit, (Quadratic, Quadratic_32)) and n_samples > n_features + and n_features < 10_000): + # Gram matrix must fit in memory sol = cd_gram_quadratic( X, y, penalty, max_epochs=max_epochs, tol=tol, w_init=None, ws_strategy=ws_strategy, verbose=verbose) @@ -356,118 +359,6 @@ def cd_solver( return w, np.array(obj_out), stop_crit -@njit -def dist_fix_point(w, grad, datafit, penalty, ws): - """Compute the violation of the fixed point iterate scheme. - - Parameters - ---------- - w : array, shape (n_features,) - Coefficient vector. - - grad : array, shape (n_features,) - Gradient. - - datafit: instance of BaseDatafit - Datafit. - - penalty: instance of BasePenalty - Penalty. - - ws : array, shape (n_features,) - The working set. - - Returns - ------- - dist_fix_point : array, shape (n_features,) - Violation score for every feature. - """ - dist_fix_point = np.zeros(ws.shape[0]) - for idx, j in enumerate(ws): - lcj = datafit.lipschitz[j] - if lcj != 0: - dist_fix_point[idx] = np.abs( - w[j] - penalty.prox_1d(w[j] - grad[idx] / lcj, 1. / lcj, j)) - return dist_fix_point - - -@njit -def construct_grad(X, y, w, Xw, datafit, ws): - """Compute the gradient of the datafit restricted to the working set. - - Parameters - ---------- - X : array, shape (n_samples, n_features) - Design matrix. - - y : array, shape (n_samples,) - Target vector. - - w : array, shape (n_features,) - Coefficient vector. - - Xw : array, shape (n_samples, ) - Model fit. - - datafit : Datafit - Datafit. - - ws : array, shape (n_features,) - The working set. - - Returns - ------- - grad : array, shape (ws_size, n_tasks) - The gradient restricted to the working set. - """ - grad = np.zeros(ws.shape[0]) - for idx, j in enumerate(ws): - grad[idx] = datafit.gradient_scalar(X, y, w, Xw, j) - return grad - - -@njit -def construct_grad_sparse(data, indptr, indices, y, w, Xw, datafit, ws): - """Compute the gradient of the datafit restricted to the working set. - - Parameters - ---------- - data : array-like - Data array of the matrix in CSC format. - - indptr : array-like - CSC format index point array. - - indices : array-like - CSC format index array. - - y : array, shape (n_samples, ) - Target matrix. - - w : array, shape (n_features,) - Coefficient matrix. - - Xw : array, shape (n_samples, ) - Model fit. - - datafit : Datafit - Datafit. - - ws : array, shape (n_features,) - The working set. - - Returns - ------- - grad : array, shape (ws_size, n_tasks) - The gradient restricted to the working set. - """ - grad = np.zeros(ws.shape[0]) - for idx, j in enumerate(ws): - grad[idx] = datafit.gradient_scalar_sparse( - data, indptr, indices, y, Xw, j) - return grad - - @njit def _cd_epoch(X, y, w, Xw, datafit, penalty, feats): """Run an epoch of coordinate descent in place. diff --git a/skglm/solvers/cd_utils.py b/skglm/solvers/cd_utils.py new file mode 100644 index 000000000..5c00bd312 --- /dev/null +++ b/skglm/solvers/cd_utils.py @@ -0,0 +1,114 @@ +import numpy as np +from numba import njit + + +@njit +def dist_fix_point(w, grad, datafit, penalty, ws): + """Compute the violation of the fixed point iterate scheme. + + Parameters + ---------- + w : array, shape (n_features,) + Coefficient vector. + + grad : array, shape (n_features,) + Gradient. + + datafit: instance of BaseDatafit + Datafit. + + penalty: instance of BasePenalty + Penalty. + + ws : array, shape (n_features,) + The working set. + + Returns + ------- + dist_fix_point : array, shape (n_features,) + Violation score for every feature. + """ + dist_fix_point = np.zeros(ws.shape[0]) + for idx, j in enumerate(ws): + lcj = datafit.lipschitz[j] + if lcj != 0: + dist_fix_point[idx] = np.abs( + w[j] - penalty.prox_1d(w[j] - grad[idx] / lcj, 1. / lcj, j)) + return dist_fix_point + + +@njit +def construct_grad(X, y, w, Xw, datafit, ws): + """Compute the gradient of the datafit restricted to the working set. + + Parameters + ---------- + X : array, shape (n_samples, n_features) + Design matrix. + + y : array, shape (n_samples,) + Target vector. + + w : array, shape (n_features,) + Coefficient vector. + + Xw : array, shape (n_samples, ) + Model fit. + + datafit : Datafit + Datafit. + + ws : array, shape (n_features,) + The working set. + + Returns + ------- + grad : array, shape (ws_size, n_tasks) + The gradient restricted to the working set. + """ + grad = np.zeros(ws.shape[0]) + for idx, j in enumerate(ws): + grad[idx] = datafit.gradient_scalar(X, y, w, Xw, j) + return grad + + +@njit +def construct_grad_sparse(data, indptr, indices, y, w, Xw, datafit, ws): + """Compute the gradient of the datafit restricted to the working set. + + Parameters + ---------- + data : array-like + Data array of the matrix in CSC format. + + indptr : array-like + CSC format index point array. + + indices : array-like + CSC format index array. + + y : array, shape (n_samples, ) + Target matrix. + + w : array, shape (n_features,) + Coefficient matrix. + + Xw : array, shape (n_samples, ) + Model fit. + + datafit : Datafit + Datafit. + + ws : array, shape (n_features,) + The working set. + + Returns + ------- + grad : array, shape (ws_size, n_tasks) + The gradient restricted to the working set. + """ + grad = np.zeros(ws.shape[0]) + for idx, j in enumerate(ws): + grad[idx] = datafit.gradient_scalar_sparse( + data, indptr, indices, y, Xw, j) + return grad diff --git a/skglm/solvers/gram.py b/skglm/solvers/gram.py index 791da1a26..569c29f66 100644 --- a/skglm/solvers/gram.py +++ b/skglm/solvers/gram.py @@ -3,7 +3,7 @@ from scipy import sparse from skglm.datafits import Quadratic -from skglm.solvers.cd_solver import ( +from skglm.solvers.cd_utils import ( construct_grad, construct_grad_sparse, dist_fix_point) From 6c9b146bace0cca13f0e69943fc08d955ddcd400 Mon Sep 17 00:00:00 2001 From: Pierre-Antoine Bannier Date: Tue, 10 May 2022 15:45:08 +0200 Subject: [PATCH 08/20] linter happy --- skglm/solvers/cd_solver.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skglm/solvers/cd_solver.py b/skglm/solvers/cd_solver.py index 277fe1827..153f9fbf8 100644 --- a/skglm/solvers/cd_solver.py +++ b/skglm/solvers/cd_solver.py @@ -150,7 +150,7 @@ def cd_solver_path(X, y, datafit, penalty, alphas=None, Xw = np.zeros(X.shape[0], dtype=X.dtype) if (isinstance(datafit, (Quadratic, Quadratic_32)) and n_samples > n_features - and n_features < 10_000): + and n_features < 10_000): # Gram matrix must fit in memory sol = cd_gram_quadratic( X, y, penalty, max_epochs=max_epochs, tol=tol, w_init=None, From 30363ea168d43b1c768313f8f6f4842876245f37 Mon Sep 17 00:00:00 2001 From: Pierre-Antoine Bannier Date: Tue, 10 May 2022 16:54:56 +0200 Subject: [PATCH 09/20] fix sparse --- skglm/solvers/gram.py | 74 +++++++++++++++++++++++++++++++++++-------- 1 file changed, 61 insertions(+), 13 deletions(-) diff --git a/skglm/solvers/gram.py b/skglm/solvers/gram.py index 569c29f66..8554dc107 100644 --- a/skglm/solvers/gram.py +++ b/skglm/solvers/gram.py @@ -2,7 +2,7 @@ from numba import njit from scipy import sparse -from skglm.datafits import Quadratic +from skglm.datafits import Quadratic, Quadratic_32 from skglm.solvers.cd_utils import ( construct_grad, construct_grad_sparse, dist_fix_point) @@ -52,20 +52,26 @@ def cd_gram_quadratic(X, y, penalty, max_epochs=1000, tol=1e-4, w_init=None, stop_crit : array, shape (n_alphas,) Value of stopping criterion at convergence along the path. """ - n_features = X.shape[1] + is_sparse = sparse.issparse(X) + n_samples = len(y) + n_features = len(X.indptr) - 1 if is_sparse else X.shape[1] all_feats = np.arange(n_features) obj_out = [] - datafit = Quadratic() - is_sparse = sparse.issparse(X) + datafit = Quadratic_32() if X.dtype == np.float32 else Quadratic() if is_sparse: datafit.initialize_sparse(X.data, X.indptr, X.indices, y) else: datafit.initialize(X, y) G = X.T @ X # gram matrix grads = (X.T @ y - G @ w_init) / len(y) if w_init is not None else X.T @ y / len(y) - w = w_init.copy() if w_init is not None else np.zeros(n_features) + w = w_init.copy() if w_init is not None else np.zeros(n_features, dtype=X.dtype) for epoch in range(max_epochs): - _cd_epoch_gram(X, G, grads, w, datafit, penalty) + if is_sparse: + _cd_epoch_gram_sparse( + G.data, G.indptr, G.indices, grads, w, datafit, penalty, n_samples, + n_features) + else: + _cd_epoch_gram(G, grads, w, datafit, penalty, n_samples, n_features) if epoch % 50 == 0: # TODO: X @ w Xw = X @ w @@ -92,14 +98,11 @@ def cd_gram_quadratic(X, y, penalty, max_epochs=1000, tol=1e-4, w_init=None, @njit -def _cd_epoch_gram(X, G, grads, w, datafit, penalty): +def _cd_epoch_gram(G, grads, w, datafit, penalty, n_samples, n_features): """Run an epoch of coordinate descent in place with gradient update using Gram. Parameters ---------- - X : array, shape (n_samples, n_features) - Design matrix. - G : array, shape (n_features, n_features) Gram matrix. @@ -114,9 +117,10 @@ def _cd_epoch_gram(X, G, grads, w, datafit, penalty): penalty : Penalty Penalty. + + n_features : int + Number of features. """ - # TODO: sparse matrix - n_features = X.shape[1] lc = datafit.lipschitz for j in range(n_features): if lc[j] == 0: @@ -125,7 +129,51 @@ def _cd_epoch_gram(X, G, grads, w, datafit, penalty): stepsize = 1 / lc[j] if lc[j] != 0 else 1000 w[j] = penalty.prox_1d(old_w_j + grads[j] / lc[j], stepsize, j) if old_w_j != w[j]: - grads += G[j, :] * (old_w_j - w[j]) / len(X) + grads += G[j, :] * (old_w_j - w[j]) / n_samples + + +@njit +def _cd_epoch_gram_sparse(G_data, G_indptr, G_indices, grads, w, datafit, penalty, + n_samples, n_features): + """Run a CD epoch with Gram update for sparse design matrices. + + Parameters + ---------- + G_data : array, shape (n_elements,) + `data` attribute of the sparse CSC matrix G. + + G_indptr : array, shape (n_features + 1,) + `indptr` attribute of the sparse CSC matrix G. + + G_indices : array, shape (n_elements,) + `indices` attribute of the sparse CSC matrix G. + + grads : array, shape (n_features,) + Gradient vector. + + w : array, shape (n_features,) + Coefficient vector. + + datafit : Datafit + Datafit. + + penalty : Penalty + Penalty. + + n_features : int + Number of features. + """ + lc = datafit.lipschitz + for j in range(n_features): + if lc[j] == 0: + continue + old_w_j = w[j] + stepsize = 1 / lc[j] + w[j] = penalty.prox_1d(old_w_j + grads[j] / lc[j], stepsize, j) + diff = old_w_j - w[j] + if diff != 0: + for i in range(G_indptr[j], G_indptr[j + 1]): + grads[G_indices[i]] += diff * G_data[i] / n_samples # def fista_gram_quadratic( From 14d38001d4606d56a74fae5bf7a5bc04144b1844 Mon Sep 17 00:00:00 2001 From: Pierre-Antoine Bannier Date: Tue, 10 May 2022 23:02:57 +0200 Subject: [PATCH 10/20] tests are passing --- skglm/solvers/cd_solver.py | 1 + skglm/solvers/gram.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/skglm/solvers/cd_solver.py b/skglm/solvers/cd_solver.py index 153f9fbf8..f23936e02 100644 --- a/skglm/solvers/cd_solver.py +++ b/skglm/solvers/cd_solver.py @@ -155,6 +155,7 @@ def cd_solver_path(X, y, datafit, penalty, alphas=None, sol = cd_gram_quadratic( X, y, penalty, max_epochs=max_epochs, tol=tol, w_init=None, ws_strategy=ws_strategy, verbose=verbose) + w = sol[0] else: sol = cd_solver( X, y, datafit, penalty, w, Xw, diff --git a/skglm/solvers/gram.py b/skglm/solvers/gram.py index 8554dc107..d7c618175 100644 --- a/skglm/solvers/gram.py +++ b/skglm/solvers/gram.py @@ -127,7 +127,7 @@ def _cd_epoch_gram(G, grads, w, datafit, penalty, n_samples, n_features): continue old_w_j = w[j] stepsize = 1 / lc[j] if lc[j] != 0 else 1000 - w[j] = penalty.prox_1d(old_w_j + grads[j] / lc[j], stepsize, j) + w[j] = penalty.prox_1d(old_w_j + grads[j] * stepsize, stepsize, j) if old_w_j != w[j]: grads += G[j, :] * (old_w_j - w[j]) / n_samples From 59b8f6c1eea8b30828176fda3d2ad9c10ccab6fe Mon Sep 17 00:00:00 2001 From: Pierre-Antoine Bannier Date: Tue, 10 May 2022 23:36:51 +0200 Subject: [PATCH 11/20] added FISTA gram --- skglm/solvers/cd_utils.py | 24 ++++++++ skglm/solvers/gram.py | 124 +++++++++++++++++++++++++++++--------- 2 files changed, 120 insertions(+), 28 deletions(-) diff --git a/skglm/solvers/cd_utils.py b/skglm/solvers/cd_utils.py index 5c00bd312..d6e3650c2 100644 --- a/skglm/solvers/cd_utils.py +++ b/skglm/solvers/cd_utils.py @@ -112,3 +112,27 @@ def construct_grad_sparse(data, indptr, indices, y, w, Xw, datafit, ws): grad[idx] = datafit.gradient_scalar_sparse( data, indptr, indices, y, Xw, j) return grad + + +@njit +def prox_vec(penalty, z, stepsize, n_features): + """Apply the proximal operator iteratively to a vector of weight. + + Parameters + ---------- + penalty : instance of Penalty + Penalty. + + z : array, shape (n_features,) + Coefficient vector. + + stepsize : float + Step size. + + n_features : int + Number of features. + """ + w = np.zeros(n_features) + for j in range(n_features): + w[j] = penalty.prox_1d(z[j], stepsize, j) + return w diff --git a/skglm/solvers/gram.py b/skglm/solvers/gram.py index d7c618175..12b29f4c8 100644 --- a/skglm/solvers/gram.py +++ b/skglm/solvers/gram.py @@ -4,7 +4,7 @@ from skglm.datafits import Quadratic, Quadratic_32 from skglm.solvers.cd_utils import ( - construct_grad, construct_grad_sparse, dist_fix_point) + construct_grad, construct_grad_sparse, dist_fix_point, prox_vec) def cd_gram_quadratic(X, y, penalty, max_epochs=1000, tol=1e-4, w_init=None, @@ -73,7 +73,6 @@ def cd_gram_quadratic(X, y, penalty, max_epochs=1000, tol=1e-4, w_init=None, else: _cd_epoch_gram(G, grads, w, datafit, penalty, n_samples, n_features) if epoch % 50 == 0: - # TODO: X @ w Xw = X @ w p_obj = datafit.value(y, w, Xw) + penalty.value(w) @@ -118,6 +117,9 @@ def _cd_epoch_gram(G, grads, w, datafit, penalty, n_samples, n_features): penalty : Penalty Penalty. + n_samples : int + Number of samples. + n_features : int Number of features. """ @@ -160,6 +162,9 @@ def _cd_epoch_gram_sparse(G_data, G_indptr, G_indices, grads, w, datafit, penalt penalty : Penalty Penalty. + n_samples : int + Number of samples. + n_features : int Number of features. """ @@ -176,29 +181,92 @@ def _cd_epoch_gram_sparse(G_data, G_indptr, G_indices, grads, w, datafit, penalt grads[G_indices[i]] += diff * G_data[i] / n_samples -# def fista_gram_quadratic( -# X, y, penalty, max_iter, tol, w_init=None, check_freq=100): -# n_samples, n_features = X.shape -# norm_y2 = y @ y -# t_new = 1 -# w = w_init.copy() if w_init is not None else np.zeros(n_features) -# z = w_init.copy() if w_init is not None else np.zeros(n_features) -# G = X.T @ X -# Xty = X.T @ y -# L = np.linalg.norm(X, ord=2) ** 2 / n_samples -# for n_iter in range(max_iter): -# t_old = t_new -# t_new = (1 + np.sqrt(1 + 4 * t_old ** 2)) / 2 -# w_old = w.copy() -# z -= (G @ z - Xty) / L / n_samples -# w = ST_vec(z, alpha / L) # use penalty.prox -# z = w + (t_old - 1.) / t_new * (w - w_old) -# if n_iter % check_freq == 0: -# pass -# # use KKT instead -# # print(f"iter {n_iter} :: p_obj {p_obj:.5f} :: d_obj {d_obj:.5f} " + -# # f":: gap {d_gap:.5f}") -# # if d_gap < tol: -# # print("Convergence reached!") -# # break -# return w +def fista_gram_quadratic(X, y, penalty, max_epochs=1000, tol=1e-4, w_init=None, + ws_strategy="subdiff", verbose=False): + r"""Run an accelerated proximal gradient descent for quadratic datafit. + + This solver should be used when n_samples >> n_features. It does not implement any + working set strategy and iteratively updates the gradients (n_features,) instead of + the residuals (n_samples,). + + Parameters + ---------- + X : array, shape (n_samples, n_features) + Training data. + + y : array, shape (n_samples,) + Target values. + + penalty : instance of Penalty class + Penalty used in the model. + + max_epochs : int, optional + Maximum number of proximal steps. + + tol : float, optional + The tolerance for the optimization. + + w_init : array, shape (n_features,), optional + Initial coefficient vector. + + ws_strategy : ('subdiff'|'fixpoint'), optional + The score used to compute the stopping criterion. + + verbose : bool or int, optional + Amount of verbosity. 0/False is silent. + + Returns + ------- + w : array, shape (n_features,) + Coefficients. + + obj_out : array, shape (n_iter,) + Objective value at every outer iteration. + + stop_crit : array, shape (n_alphas,) + Value of stopping criterion at convergence along the path. + """ + is_sparse = sparse.issparse(X) + n_samples = len(y) + n_features = len(X.indptr) - 1 if is_sparse else X.shape[1] + all_feats = np.arange(n_features) + obj_out = [] + datafit = Quadratic_32() if X.dtype == np.float32 else Quadratic() + if is_sparse: + datafit.initialize_sparse(X.data, X.indptr, X.indices, y) + else: + datafit.initialize(X, y) + t_new = 1 + w = w_init.copy() if w_init is not None else np.zeros(n_features) + z = w_init.copy() if w_init is not None else np.zeros(n_features) + G = X.T @ X + lc = np.linalg.norm(X, ord=2) ** 2 / n_samples + for epoch in range(max_epochs): + t_old = t_new + t_new = (1 + np.sqrt(1 + 4 * t_old ** 2)) / 2 + w_old = w.copy() + z -= (G @ z - datafit.Xty) / lc / n_samples + w = prox_vec(penalty, z, 1/lc, n_features) + z = w + (t_old - 1.) / t_new * (w - w_old) + if epoch % 10 == 0: + Xw = X @ w + p_obj = datafit.value(y, w, Xw) + penalty.value(w) + + if is_sparse: + grad = construct_grad_sparse( + X.data, X.indptr, X.indices, y, w, Xw, datafit, all_feats) + else: + grad = construct_grad(X, y, w, Xw, datafit, all_feats) + if ws_strategy == "subdiff": + opt_ws = penalty.subdiff_distance(w, grad, all_feats) + elif ws_strategy == "fixpoint": + opt_ws = dist_fix_point(w, grad, datafit, penalty, all_feats) + + stop_crit = np.max(opt_ws) + if max(verbose - 1, 0): + print(f"Epoch {epoch + 1}, objective {p_obj:.10f}, " + f"stopping crit {stop_crit:.2e}") + if stop_crit <= tol: + break + obj_out.append(p_obj) + return w, np.array(obj_out), stop_crit From 6bb47063b29288c2b693719bd6be37acfe25447e Mon Sep 17 00:00:00 2001 From: Pierre-Antoine Bannier Date: Tue, 10 May 2022 23:38:48 +0200 Subject: [PATCH 12/20] linter happy --- skglm/solvers/cd_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/skglm/solvers/cd_utils.py b/skglm/solvers/cd_utils.py index d6e3650c2..200ce40c0 100644 --- a/skglm/solvers/cd_utils.py +++ b/skglm/solvers/cd_utils.py @@ -117,18 +117,18 @@ def construct_grad_sparse(data, indptr, indices, y, w, Xw, datafit, ws): @njit def prox_vec(penalty, z, stepsize, n_features): """Apply the proximal operator iteratively to a vector of weight. - + Parameters ---------- penalty : instance of Penalty Penalty. - + z : array, shape (n_features,) Coefficient vector. - + stepsize : float Step size. - + n_features : int Number of features. """ From f0ed28a5dc7d814cfa85f3b54149ac7360bcab1a Mon Sep 17 00:00:00 2001 From: Pierre-Antoine Bannier Date: Tue, 10 May 2022 23:59:13 +0200 Subject: [PATCH 13/20] fix tests --- skglm/solvers/cd_solver.py | 16 +++++++++++----- skglm/solvers/cd_utils.py | 2 +- skglm/solvers/gram.py | 4 ++-- 3 files changed, 14 insertions(+), 8 deletions(-) diff --git a/skglm/solvers/cd_solver.py b/skglm/solvers/cd_solver.py index f23936e02..e03d45ca1 100644 --- a/skglm/solvers/cd_solver.py +++ b/skglm/solvers/cd_solver.py @@ -5,7 +5,7 @@ from skglm.datafits.single_task import Quadratic, Quadratic_32 from skglm.solvers.cd_utils import ( dist_fix_point, construct_grad, construct_grad_sparse) -from skglm.solvers.gram import cd_gram_quadratic +from skglm.solvers.gram import cd_gram_quadratic, fista_gram_quadratic def cd_solver_path(X, y, datafit, penalty, alphas=None, @@ -151,10 +151,16 @@ def cd_solver_path(X, y, datafit, penalty, alphas=None, if (isinstance(datafit, (Quadratic, Quadratic_32)) and n_samples > n_features and n_features < 10_000): - # Gram matrix must fit in memory - sol = cd_gram_quadratic( - X, y, penalty, max_epochs=max_epochs, tol=tol, w_init=None, - ws_strategy=ws_strategy, verbose=verbose) + # Gram matrix must fit in memory hence the restriction n_features < 1e5 + if (hasattr(penalty, "alpha_max") + and penalty.alpha / penalty.alpha_max(datafit.Xty) < 1e-3): + sol = fista_gram_quadratic( + X, y, penalty, max_epochs=max_epochs, tol=tol, w_init=None, + ws_strategy=ws_strategy, verbose=verbose) + else: + sol = cd_gram_quadratic( + X, y, penalty, max_epochs=max_epochs, tol=tol, w_init=None, + ws_strategy=ws_strategy, verbose=verbose) w = sol[0] else: sol = cd_solver( diff --git a/skglm/solvers/cd_utils.py b/skglm/solvers/cd_utils.py index 200ce40c0..c8628daa5 100644 --- a/skglm/solvers/cd_utils.py +++ b/skglm/solvers/cd_utils.py @@ -132,7 +132,7 @@ def prox_vec(penalty, z, stepsize, n_features): n_features : int Number of features. """ - w = np.zeros(n_features) + w = np.zeros(n_features, dtype=z.dtype) for j in range(n_features): w[j] = penalty.prox_1d(z[j], stepsize, j) return w diff --git a/skglm/solvers/gram.py b/skglm/solvers/gram.py index 12b29f4c8..88cc3f377 100644 --- a/skglm/solvers/gram.py +++ b/skglm/solvers/gram.py @@ -237,8 +237,8 @@ def fista_gram_quadratic(X, y, penalty, max_epochs=1000, tol=1e-4, w_init=None, else: datafit.initialize(X, y) t_new = 1 - w = w_init.copy() if w_init is not None else np.zeros(n_features) - z = w_init.copy() if w_init is not None else np.zeros(n_features) + w = w_init.copy() if w_init is not None else np.zeros(n_features, dtype=X.dtype) + z = w_init.copy() if w_init is not None else np.zeros(n_features, dtype=X.dtype) G = X.T @ X lc = np.linalg.norm(X, ord=2) ** 2 / n_samples for epoch in range(max_epochs): From 5dd40fda89fb3eabe1d73d004e6f2b1eab4898c9 Mon Sep 17 00:00:00 2001 From: Pierre-Antoine Bannier Date: Wed, 11 May 2022 00:11:35 +0200 Subject: [PATCH 14/20] added solver arguments --- skglm/solvers/cd_solver.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/skglm/solvers/cd_solver.py b/skglm/solvers/cd_solver.py index e03d45ca1..0dd6302cf 100644 --- a/skglm/solvers/cd_solver.py +++ b/skglm/solvers/cd_solver.py @@ -11,7 +11,7 @@ def cd_solver_path(X, y, datafit, penalty, alphas=None, coef_init=None, max_iter=20, max_epochs=50_000, p0=10, tol=1e-4, use_acc=True, return_n_iter=False, - ws_strategy="subdiff", verbose=0): + solver="cd_ws", ws_strategy="subdiff", verbose=0): r"""Compute optimization path with Anderson accelerated coordinate descent. The loss is customized by passing various choices of datafit and penalty: @@ -56,6 +56,9 @@ def cd_solver_path(X, y, datafit, penalty, alphas=None, return_n_iter : bool, optional If True, number of iterations along the path are returned. + solver : ('cd_ws'|'cd_gram'|'fista'), optional + The solver used to solve the optimization problem. + ws_strategy : ('subdiff'|'fixpoint'), optional The score used to build the working set. @@ -150,10 +153,13 @@ def cd_solver_path(X, y, datafit, penalty, alphas=None, Xw = np.zeros(X.shape[0], dtype=X.dtype) if (isinstance(datafit, (Quadratic, Quadratic_32)) and n_samples > n_features - and n_features < 10_000): + and n_features < 10_000) or solver in ("cd_gram", "fista"): # Gram matrix must fit in memory hence the restriction n_features < 1e5 - if (hasattr(penalty, "alpha_max") - and penalty.alpha / penalty.alpha_max(datafit.Xty) < 1e-3): + if not isinstance(datafit, (Quadratic, Quadratic_32)): + raise ValueError("`cd_gram` and `fista` solvers are only supported " + + "for `Quadratic` datafits.") + if (hasattr(penalty, "alpha_max") and penalty.alpha / + penalty.alpha_max(datafit.Xty) < 1e-3) or solver == "fista": sol = fista_gram_quadratic( X, y, penalty, max_epochs=max_epochs, tol=tol, w_init=None, ws_strategy=ws_strategy, verbose=verbose) From 07714ace261ea7c80c143f0643b0d91d6a7224ab Mon Sep 17 00:00:00 2001 From: PAB Date: Wed, 11 May 2022 09:53:21 +0200 Subject: [PATCH 15/20] Update skglm/solvers/cd_solver.py Co-authored-by: mathurinm --- skglm/solvers/cd_solver.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skglm/solvers/cd_solver.py b/skglm/solvers/cd_solver.py index 0dd6302cf..4b0b2c724 100644 --- a/skglm/solvers/cd_solver.py +++ b/skglm/solvers/cd_solver.py @@ -154,7 +154,7 @@ def cd_solver_path(X, y, datafit, penalty, alphas=None, if (isinstance(datafit, (Quadratic, Quadratic_32)) and n_samples > n_features and n_features < 10_000) or solver in ("cd_gram", "fista"): - # Gram matrix must fit in memory hence the restriction n_features < 1e5 + # Gram matrix must fit in memory hence the restriction n_features < 1e4 if not isinstance(datafit, (Quadratic, Quadratic_32)): raise ValueError("`cd_gram` and `fista` solvers are only supported " + "for `Quadratic` datafits.") From 8aee1b68428ba1916b5181e88d939c3433e91788 Mon Sep 17 00:00:00 2001 From: PAB Date: Wed, 11 May 2022 11:07:49 +0200 Subject: [PATCH 16/20] Update skglm/solvers/gram.py Co-authored-by: mathurinm --- skglm/solvers/gram.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skglm/solvers/gram.py b/skglm/solvers/gram.py index 88cc3f377..aacc7bddf 100644 --- a/skglm/solvers/gram.py +++ b/skglm/solvers/gram.py @@ -13,7 +13,7 @@ def cd_gram_quadratic(X, y, penalty, max_epochs=1000, tol=1e-4, w_init=None, This solver should be used when n_samples >> n_features. It does not implement any working set strategy and iteratively updates the gradients (n_features,) instead of - the residuals (n_samples,). + the prediction Xw (n_samples,). Parameters ---------- From 0b940ee3f33fae47b3b294f4c0c41cf6312e6b5f Mon Sep 17 00:00:00 2001 From: Pierre-Antoine Bannier Date: Wed, 11 May 2022 11:15:26 +0200 Subject: [PATCH 17/20] pass Mathurin's comments --- skglm/solvers/cd_solver.py | 4 +- skglm/solvers/cd_utils.py | 6 +- skglm/solvers/gram.py | 120 +++++-------------------------------- 3 files changed, 20 insertions(+), 110 deletions(-) diff --git a/skglm/solvers/cd_solver.py b/skglm/solvers/cd_solver.py index 0dd6302cf..1c1747f04 100644 --- a/skglm/solvers/cd_solver.py +++ b/skglm/solvers/cd_solver.py @@ -162,11 +162,11 @@ def cd_solver_path(X, y, datafit, penalty, alphas=None, penalty.alpha_max(datafit.Xty) < 1e-3) or solver == "fista": sol = fista_gram_quadratic( X, y, penalty, max_epochs=max_epochs, tol=tol, w_init=None, - ws_strategy=ws_strategy, verbose=verbose) + verbose=verbose) else: sol = cd_gram_quadratic( X, y, penalty, max_epochs=max_epochs, tol=tol, w_init=None, - ws_strategy=ws_strategy, verbose=verbose) + verbose=verbose) w = sol[0] else: sol = cd_solver( diff --git a/skglm/solvers/cd_utils.py b/skglm/solvers/cd_utils.py index c8628daa5..730318246 100644 --- a/skglm/solvers/cd_utils.py +++ b/skglm/solvers/cd_utils.py @@ -115,7 +115,7 @@ def construct_grad_sparse(data, indptr, indices, y, w, Xw, datafit, ws): @njit -def prox_vec(penalty, z, stepsize, n_features): +def _prox_vec(penalty, z, stepsize): """Apply the proximal operator iteratively to a vector of weight. Parameters @@ -128,10 +128,8 @@ def prox_vec(penalty, z, stepsize, n_features): stepsize : float Step size. - - n_features : int - Number of features. """ + n_features = z.shape[0] w = np.zeros(n_features, dtype=z.dtype) for j in range(n_features): w[j] = penalty.prox_1d(z[j], stepsize, j) diff --git a/skglm/solvers/gram.py b/skglm/solvers/gram.py index 88cc3f377..47ca35afd 100644 --- a/skglm/solvers/gram.py +++ b/skglm/solvers/gram.py @@ -4,11 +4,11 @@ from skglm.datafits import Quadratic, Quadratic_32 from skglm.solvers.cd_utils import ( - construct_grad, construct_grad_sparse, dist_fix_point, prox_vec) + construct_grad, construct_grad_sparse, dist_fix_point, _prox_vec) def cd_gram_quadratic(X, y, penalty, max_epochs=1000, tol=1e-4, w_init=None, - ws_strategy="subdiff", verbose=0): + verbose=0): r"""Run a coordinate descent solver using Gram update for quadratic datafit. This solver should be used when n_samples >> n_features. It does not implement any @@ -35,9 +35,6 @@ def cd_gram_quadratic(X, y, penalty, max_epochs=1000, tol=1e-4, w_init=None, w_init : array, shape (n_features,), optional Initial coefficient vector. - ws_strategy : ('subdiff'|'fixpoint'), optional - The score used to compute the stopping criterion. - verbose : bool or int, optional Amount of verbosity. 0/False is silent. @@ -62,16 +59,12 @@ def cd_gram_quadratic(X, y, penalty, max_epochs=1000, tol=1e-4, w_init=None, datafit.initialize_sparse(X.data, X.indptr, X.indices, y) else: datafit.initialize(X, y) - G = X.T @ X # gram matrix - grads = (X.T @ y - G @ w_init) / len(y) if w_init is not None else X.T @ y / len(y) + XtX = (X.T @ X).toarray() if is_sparse else X.T @ X # gram matrix + grad = ((datafit.Xty - XtX @ w_init) / len(y) if w_init is not None + else datafit.Xty / len(y)) w = w_init.copy() if w_init is not None else np.zeros(n_features, dtype=X.dtype) for epoch in range(max_epochs): - if is_sparse: - _cd_epoch_gram_sparse( - G.data, G.indptr, G.indices, grads, w, datafit, penalty, n_samples, - n_features) - else: - _cd_epoch_gram(G, grads, w, datafit, penalty, n_samples, n_features) + _cd_epoch_gram(XtX, grad, w, datafit, penalty, n_samples, n_features) if epoch % 50 == 0: Xw = X @ w p_obj = datafit.value(y, w, Xw) + penalty.value(w) @@ -81,12 +74,9 @@ def cd_gram_quadratic(X, y, penalty, max_epochs=1000, tol=1e-4, w_init=None, X.data, X.indptr, X.indices, y, w, Xw, datafit, all_feats) else: grad = construct_grad(X, y, w, Xw, datafit, all_feats) - if ws_strategy == "subdiff": - opt_ws = penalty.subdiff_distance(w, grad, all_feats) - elif ws_strategy == "fixpoint": - opt_ws = dist_fix_point(w, grad, datafit, penalty, all_feats) - - stop_crit = np.max(opt_ws) + # stop criterion: fixpoint + opt = dist_fix_point(w, grad, datafit, penalty, all_feats) + stop_crit = np.max(opt) if max(verbose - 1, 0): print(f"Epoch {epoch + 1}, objective {p_obj:.10f}, " f"stopping crit {stop_crit:.2e}") @@ -97,92 +87,20 @@ def cd_gram_quadratic(X, y, penalty, max_epochs=1000, tol=1e-4, w_init=None, @njit -def _cd_epoch_gram(G, grads, w, datafit, penalty, n_samples, n_features): - """Run an epoch of coordinate descent in place with gradient update using Gram. - - Parameters - ---------- - G : array, shape (n_features, n_features) - Gram matrix. - - grads : array, shape (n_features,) - Gradient vector. - - w : array, shape (n_features,) - Coefficient vector. - - datafit : Datafit - Datafit. - - penalty : Penalty - Penalty. - - n_samples : int - Number of samples. - - n_features : int - Number of features. - """ +def _cd_epoch_gram(XtX, grad, w, datafit, penalty, n_samples, n_features): lc = datafit.lipschitz for j in range(n_features): if lc[j] == 0: continue old_w_j = w[j] stepsize = 1 / lc[j] if lc[j] != 0 else 1000 - w[j] = penalty.prox_1d(old_w_j + grads[j] * stepsize, stepsize, j) + w[j] = penalty.prox_1d(old_w_j + grad[j] * stepsize, stepsize, j) if old_w_j != w[j]: - grads += G[j, :] * (old_w_j - w[j]) / n_samples - - -@njit -def _cd_epoch_gram_sparse(G_data, G_indptr, G_indices, grads, w, datafit, penalty, - n_samples, n_features): - """Run a CD epoch with Gram update for sparse design matrices. - - Parameters - ---------- - G_data : array, shape (n_elements,) - `data` attribute of the sparse CSC matrix G. - - G_indptr : array, shape (n_features + 1,) - `indptr` attribute of the sparse CSC matrix G. - - G_indices : array, shape (n_elements,) - `indices` attribute of the sparse CSC matrix G. - - grads : array, shape (n_features,) - Gradient vector. - - w : array, shape (n_features,) - Coefficient vector. - - datafit : Datafit - Datafit. - - penalty : Penalty - Penalty. - - n_samples : int - Number of samples. - - n_features : int - Number of features. - """ - lc = datafit.lipschitz - for j in range(n_features): - if lc[j] == 0: - continue - old_w_j = w[j] - stepsize = 1 / lc[j] - w[j] = penalty.prox_1d(old_w_j + grads[j] / lc[j], stepsize, j) - diff = old_w_j - w[j] - if diff != 0: - for i in range(G_indptr[j], G_indptr[j + 1]): - grads[G_indices[i]] += diff * G_data[i] / n_samples + grad += XtX[j, :] * (old_w_j - w[j]) / n_samples def fista_gram_quadratic(X, y, penalty, max_epochs=1000, tol=1e-4, w_init=None, - ws_strategy="subdiff", verbose=False): + verbose=False): r"""Run an accelerated proximal gradient descent for quadratic datafit. This solver should be used when n_samples >> n_features. It does not implement any @@ -209,9 +127,6 @@ def fista_gram_quadratic(X, y, penalty, max_epochs=1000, tol=1e-4, w_init=None, w_init : array, shape (n_features,), optional Initial coefficient vector. - ws_strategy : ('subdiff'|'fixpoint'), optional - The score used to compute the stopping criterion. - verbose : bool or int, optional Amount of verbosity. 0/False is silent. @@ -246,7 +161,7 @@ def fista_gram_quadratic(X, y, penalty, max_epochs=1000, tol=1e-4, w_init=None, t_new = (1 + np.sqrt(1 + 4 * t_old ** 2)) / 2 w_old = w.copy() z -= (G @ z - datafit.Xty) / lc / n_samples - w = prox_vec(penalty, z, 1/lc, n_features) + w = _prox_vec(penalty, z, 1/lc) z = w + (t_old - 1.) / t_new * (w - w_old) if epoch % 10 == 0: Xw = X @ w @@ -257,12 +172,9 @@ def fista_gram_quadratic(X, y, penalty, max_epochs=1000, tol=1e-4, w_init=None, X.data, X.indptr, X.indices, y, w, Xw, datafit, all_feats) else: grad = construct_grad(X, y, w, Xw, datafit, all_feats) - if ws_strategy == "subdiff": - opt_ws = penalty.subdiff_distance(w, grad, all_feats) - elif ws_strategy == "fixpoint": - opt_ws = dist_fix_point(w, grad, datafit, penalty, all_feats) - stop_crit = np.max(opt_ws) + opt = dist_fix_point(w, grad, datafit, penalty, all_feats) + stop_crit = np.max(opt) if max(verbose - 1, 0): print(f"Epoch {epoch + 1}, objective {p_obj:.10f}, " f"stopping crit {stop_crit:.2e}") From b5b9d0956f6cd851c73013b194f63030642c0ae4 Mon Sep 17 00:00:00 2001 From: Pierre-Antoine Bannier Date: Wed, 11 May 2022 11:19:19 +0200 Subject: [PATCH 18/20] linter happy --- skglm/solvers/gram.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/skglm/solvers/gram.py b/skglm/solvers/gram.py index ba13faf9c..0e50db16d 100644 --- a/skglm/solvers/gram.py +++ b/skglm/solvers/gram.py @@ -59,9 +59,9 @@ def cd_gram_quadratic(X, y, penalty, max_epochs=1000, tol=1e-4, w_init=None, datafit.initialize_sparse(X.data, X.indptr, X.indices, y) else: datafit.initialize(X, y) - XtX = (X.T @ X).toarray() if is_sparse else X.T @ X # gram matrix + XtX = (X.T @ X).toarray() if is_sparse else X.T @ X # gram matrix grad = ((datafit.Xty - XtX @ w_init) / len(y) if w_init is not None - else datafit.Xty / len(y)) + else datafit.Xty / len(y)) w = w_init.copy() if w_init is not None else np.zeros(n_features, dtype=X.dtype) for epoch in range(max_epochs): _cd_epoch_gram(XtX, grad, w, datafit, penalty, n_samples, n_features) From fc791d6bbe7fb3405139ec193a55445001618596 Mon Sep 17 00:00:00 2001 From: Pierre-Antoine Bannier Date: Sat, 14 May 2022 14:00:58 +0200 Subject: [PATCH 19/20] fix w_init --- skglm/solvers/cd_solver.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/skglm/solvers/cd_solver.py b/skglm/solvers/cd_solver.py index adf0de8c3..f63dd5bc4 100644 --- a/skglm/solvers/cd_solver.py +++ b/skglm/solvers/cd_solver.py @@ -161,11 +161,11 @@ def cd_solver_path(X, y, datafit, penalty, alphas=None, if (hasattr(penalty, "alpha_max") and penalty.alpha / penalty.alpha_max(datafit.Xty) < 1e-3) or solver == "fista": sol = fista_gram_quadratic( - X, y, penalty, max_epochs=max_epochs, tol=tol, w_init=None, + X, y, penalty, max_epochs=max_epochs, tol=tol, w_init=coef_init, verbose=verbose) else: sol = cd_gram_quadratic( - X, y, penalty, max_epochs=max_epochs, tol=tol, w_init=None, + X, y, penalty, max_epochs=max_epochs, tol=tol, w_init=coef_init, verbose=verbose) w = sol[0] else: From 9e19e089c278e28eef297b86bc7c2f1d62874915 Mon Sep 17 00:00:00 2001 From: Pierre-Antoine Bannier Date: Sat, 14 May 2022 14:05:16 +0200 Subject: [PATCH 20/20] ENH if statement --- skglm/solvers/cd_solver.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/skglm/solvers/cd_solver.py b/skglm/solvers/cd_solver.py index f63dd5bc4..57a11401e 100644 --- a/skglm/solvers/cd_solver.py +++ b/skglm/solvers/cd_solver.py @@ -152,10 +152,11 @@ def cd_solver_path(X, y, datafit, penalty, alphas=None, w = np.zeros(n_features, dtype=X.dtype) Xw = np.zeros(X.shape[0], dtype=X.dtype) - if (isinstance(datafit, (Quadratic, Quadratic_32)) and n_samples > n_features - and n_features < 10_000) or solver in ("cd_gram", "fista"): + is_quad_df = isinstance(datafit, (Quadratic, Quadratic_32)) + if ((is_quad_df and n_samples > n_features and n_features < 10_000) + or solver in ("cd_gram", "fista")): # Gram matrix must fit in memory hence the restriction n_features < 1e4 - if not isinstance(datafit, (Quadratic, Quadratic_32)): + if not is_quad_df: raise ValueError("`cd_gram` and `fista` solvers are only supported " + "for `Quadratic` datafits.") if (hasattr(penalty, "alpha_max") and penalty.alpha /