diff --git a/skglm/solvers/cd_solver.py b/skglm/solvers/cd_solver.py
index 83584ee1e..50495e2ec 100644
--- a/skglm/solvers/cd_solver.py
+++ b/skglm/solvers/cd_solver.py
@@ -2,12 +2,16 @@
 from numba import njit
 from scipy import sparse
 from sklearn.utils import check_array
+from skglm.datafits.single_task import Quadratic, Quadratic_32
+from skglm.solvers.cd_utils import (
+    dist_fix_point, construct_grad, construct_grad_sparse)
+from skglm.solvers.gram import cd_gram_quadratic, fista_gram_quadratic


 def cd_solver_path(X, y, datafit, penalty, alphas=None,
                    coef_init=None, max_iter=20, max_epochs=50_000,
                    p0=10, tol=1e-4, use_acc=True, return_n_iter=False,
-                   ws_strategy="subdiff", verbose=0):
+                   solver="cd_ws", ws_strategy="subdiff", verbose=0):
     r"""Compute optimization path with Anderson accelerated coordinate descent.

     The loss is customized by passing various choices of datafit and penalty:
@@ -52,6 +56,9 @@ def cd_solver_path(X, y, datafit, penalty, alphas=None,
     return_n_iter : bool, optional
         If True, number of iterations along the path are returned.

+    solver : ('cd_ws'|'cd_gram'|'fista'), optional
+        The solver used to solve the optimization problem.
+
     ws_strategy : ('subdiff'|'fixpoint'), optional
         The score used to build the working set.

@@ -109,6 +116,7 @@ def cd_solver_path(X, y, datafit, penalty, alphas=None,
     # else:
     #     alphas = np.sort(alphas)[::-1]

+    n_samples = len(y)
     n_alphas = len(alphas)

     coefs = np.zeros((n_features, n_alphas), order='F', dtype=X.dtype)
@@ -144,10 +152,28 @@ def cd_solver_path(X, y, datafit, penalty, alphas=None,
             w = np.zeros(n_features, dtype=X.dtype)
             Xw = np.zeros(X.shape[0], dtype=X.dtype)

-        sol = cd_solver(
-            X, y, datafit, penalty, w, Xw,
-            max_iter=max_iter, max_epochs=max_epochs, p0=p0, tol=tol,
-            use_acc=use_acc, verbose=verbose, ws_strategy=ws_strategy)
+        is_quad_df = isinstance(datafit, (Quadratic, Quadratic_32))
+        if ((is_quad_df and n_samples > n_features and n_features < 10_000)
+                or solver in ("cd_gram", "fista")):
+            # the Gram matrix must fit in memory, hence the cap n_features < 1e4
+            if not is_quad_df:
+                raise ValueError("`cd_gram` and `fista` solvers are only supported "
+                                 "for `Quadratic` datafits.")
+            if (hasattr(penalty, "alpha_max") and penalty.alpha /
+                    penalty.alpha_max(datafit.Xty) < 1e-3) or solver == "fista":
+                sol = fista_gram_quadratic(
+                    X, y, penalty, max_epochs=max_epochs, tol=tol, w_init=coef_init,
+                    verbose=verbose)
+            else:
+                sol = cd_gram_quadratic(
+                    X, y, penalty, max_epochs=max_epochs, tol=tol, w_init=coef_init,
+                    verbose=verbose)
+            w = sol[0]
+        else:
+            sol = cd_solver(
+                X, y, datafit, penalty, w, Xw,
+                max_iter=max_iter, max_epochs=max_epochs, p0=p0, tol=tol,
+                use_acc=use_acc, verbose=verbose, ws_strategy=ws_strategy)

         coefs[:, t] = w
         stop_crits[t] = sol[-1]
@@ -214,11 +240,11 @@ def cd_solver(

     Returns
     -------
-    coefs : array, shape (n_features, n_alphas)
-        Coefficients along the path.
+    w : array, shape (n_features,)
+        Coefficients.

     obj_out : array, shape (n_iter,)
-        The objective values at every outer iteration.
+        Objective value at every outer iteration.

     stop_crit : float
         Value of stopping criterion at convergence.
@@ -347,118 +373,6 @@ def cd_solver(
     return w, np.array(obj_out), stop_crit


-@njit
-def dist_fix_point(w, grad, datafit, penalty, ws):
-    """Compute the violation of the fixed point iterate scheme.
-
-    Parameters
-    ----------
-    w : array, shape (n_features,)
-        Coefficient vector.
-
-    grad : array, shape (n_features,)
-        Gradient.
-
-    datafit: instance of BaseDatafit
-        Datafit.
-
-    penalty: instance of BasePenalty
-        Penalty.
-
-    ws : array, shape (n_features,)
-        The working set.
-
-    Returns
-    -------
-    dist_fix_point : array, shape (n_features,)
-        Violation score for every feature.
-    """
-    dist_fix_point = np.zeros(ws.shape[0])
-    for idx, j in enumerate(ws):
-        lcj = datafit.lipschitz[j]
-        if lcj != 0:
-            dist_fix_point[idx] = np.abs(
-                w[j] - penalty.prox_1d(w[j] - grad[idx] / lcj, 1. / lcj, j))
-    return dist_fix_point
-
-
-@njit
-def construct_grad(X, y, w, Xw, datafit, ws):
-    """Compute the gradient of the datafit restricted to the working set.
-
-    Parameters
-    ----------
-    X : array, shape (n_samples, n_features)
-        Design matrix.
-
-    y : array, shape (n_samples,)
-        Target vector.
-
-    w : array, shape (n_features,)
-        Coefficient vector.
-
-    Xw : array, shape (n_samples, )
-        Model fit.
-
-    datafit : Datafit
-        Datafit.
-
-    ws : array, shape (n_features,)
-        The working set.
-
-    Returns
-    -------
-    grad : array, shape (ws_size, n_tasks)
-        The gradient restricted to the working set.
-    """
-    grad = np.zeros(ws.shape[0])
-    for idx, j in enumerate(ws):
-        grad[idx] = datafit.gradient_scalar(X, y, w, Xw, j)
-    return grad
-
-
-@njit
-def construct_grad_sparse(data, indptr, indices, y, w, Xw, datafit, ws):
-    """Compute the gradient of the datafit restricted to the working set.
-
-    Parameters
-    ----------
-    data : array-like
-        Data array of the matrix in CSC format.
-
-    indptr : array-like
-        CSC format index point array.
-
-    indices : array-like
-        CSC format index array.
-
-    y : array, shape (n_samples, )
-        Target matrix.
-
-    w : array, shape (n_features,)
-        Coefficient matrix.
-
-    Xw : array, shape (n_samples, )
-        Model fit.
-
-    datafit : Datafit
-        Datafit.
-
-    ws : array, shape (n_features,)
-        The working set.
-
-    Returns
-    -------
-    grad : array, shape (ws_size, n_tasks)
-        The gradient restricted to the working set.
-    """
-    grad = np.zeros(ws.shape[0])
-    for idx, j in enumerate(ws):
-        grad[idx] = datafit.gradient_scalar_sparse(
-            data, indptr, indices, y, Xw, j)
-    return grad
-
-
 @njit
 def _cd_epoch(X, y, w, Xw, datafit, penalty, feats):
     """Run an epoch of coordinate descent in place.
diff --git a/skglm/solvers/cd_utils.py b/skglm/solvers/cd_utils.py
new file mode 100644
index 000000000..730318246
--- /dev/null
+++ b/skglm/solvers/cd_utils.py
@@ -0,0 +1,136 @@
+import numpy as np
+from numba import njit
+
+
+@njit
+def dist_fix_point(w, grad, datafit, penalty, ws):
+    """Compute the violation of the fixed point iterate scheme.
+
+    Parameters
+    ----------
+    w : array, shape (n_features,)
+        Coefficient vector.
+
+    grad : array, shape (n_features,)
+        Gradient.
+
+    datafit : instance of BaseDatafit
+        Datafit.
+
+    penalty : instance of BasePenalty
+        Penalty.
+
+    ws : array, shape (n_features,)
+        The working set.
+
+    Returns
+    -------
+    dist_fix_point : array, shape (n_features,)
+        Violation score for every feature.
+    """
+    dist_fix_point = np.zeros(ws.shape[0])
+    for idx, j in enumerate(ws):
+        lcj = datafit.lipschitz[j]
+        if lcj != 0:
+            dist_fix_point[idx] = np.abs(
+                w[j] - penalty.prox_1d(w[j] - grad[idx] / lcj, 1. / lcj, j))
+    return dist_fix_point
+
+
+@njit
+def construct_grad(X, y, w, Xw, datafit, ws):
+    """Compute the gradient of the datafit restricted to the working set.
+
+    Parameters
+    ----------
+    X : array, shape (n_samples, n_features)
+        Design matrix.
+
+    y : array, shape (n_samples,)
+        Target vector.
+
+    w : array, shape (n_features,)
+        Coefficient vector.
+
+    Xw : array, shape (n_samples,)
+        Model fit.
+
+    datafit : Datafit
+        Datafit.
+
+    ws : array, shape (n_features,)
+        The working set.
+
+    Returns
+    -------
+    grad : array, shape (ws_size,)
+        The gradient restricted to the working set.
+    """
+    grad = np.zeros(ws.shape[0])
+    for idx, j in enumerate(ws):
+        grad[idx] = datafit.gradient_scalar(X, y, w, Xw, j)
+    return grad
+
+
+@njit
+def construct_grad_sparse(data, indptr, indices, y, w, Xw, datafit, ws):
+    """Compute the gradient of the datafit restricted to the working set.
+
+    Parameters
+    ----------
+    data : array-like
+        Data array of the matrix in CSC format.
+
+    indptr : array-like
+        CSC format index pointer array.
+
+    indices : array-like
+        CSC format index array.
+
+    y : array, shape (n_samples,)
+        Target vector.
+
+    w : array, shape (n_features,)
+        Coefficient vector.
+
+    Xw : array, shape (n_samples,)
+        Model fit.
+
+    datafit : Datafit
+        Datafit.
+
+    ws : array, shape (n_features,)
+        The working set.
+
+    Returns
+    -------
+    grad : array, shape (ws_size,)
+        The gradient restricted to the working set.
+    """
+    grad = np.zeros(ws.shape[0])
+    for idx, j in enumerate(ws):
+        grad[idx] = datafit.gradient_scalar_sparse(
+            data, indptr, indices, y, Xw, j)
+    return grad
+
+
+@njit
+def _prox_vec(penalty, z, stepsize):
+    """Apply the proximal operator entry-wise to a vector of weights.
+
+    Parameters
+    ----------
+    penalty : instance of Penalty
+        Penalty.
+
+    z : array, shape (n_features,)
+        Coefficient vector.
+
+    stepsize : float
+        Step size.
+    """
+    n_features = z.shape[0]
+    w = np.zeros(n_features, dtype=z.dtype)
+    for j in range(n_features):
+        w[j] = penalty.prox_1d(z[j], stepsize, j)
+    return w
diff --git a/skglm/solvers/gram.py b/skglm/solvers/gram.py
new file mode 100644
index 000000000..0e50db16d
--- /dev/null
+++ b/skglm/solvers/gram.py
@@ -0,0 +1,193 @@
+import numpy as np
+from numba import njit
+from scipy import sparse
+from scipy.sparse.linalg import svds
+
+from skglm.datafits import Quadratic, Quadratic_32
+from skglm.solvers.cd_utils import (
+    construct_grad, construct_grad_sparse, dist_fix_point, _prox_vec)
+
+
+def cd_gram_quadratic(X, y, penalty, max_epochs=1000, tol=1e-4, w_init=None,
+                      verbose=0):
+    r"""Run coordinate descent with Gram updates for a quadratic datafit.
+
+    This solver should be used when n_samples >> n_features. It does not
+    implement any working set strategy; it iteratively updates the gradient,
+    of shape (n_features,), instead of the prediction Xw, of shape (n_samples,).
+
+    Parameters
+    ----------
+    X : array, shape (n_samples, n_features)
+        Training data.
+
+    y : array, shape (n_samples,)
+        Target values.
+
+    penalty : instance of Penalty class
+        Penalty used in the model.
+
+    max_epochs : int, optional
+        Maximum number of CD epochs.
+
+    tol : float, optional
+        The tolerance for the optimization.
+
+    w_init : array, shape (n_features,), optional
+        Initial coefficient vector.
+
+    verbose : bool or int, optional
+        Amount of verbosity. 0/False is silent.
+
+    Returns
+    -------
+    w : array, shape (n_features,)
+        Coefficients.
+
+    obj_out : array, shape (n_iter,)
+        Objective value at every outer iteration.
+
+    stop_crit : float
+        Value of stopping criterion at convergence.
+ """ + is_sparse = sparse.issparse(X) + n_samples = len(y) + n_features = len(X.indptr) - 1 if is_sparse else X.shape[1] + all_feats = np.arange(n_features) + obj_out = [] + datafit = Quadratic_32() if X.dtype == np.float32 else Quadratic() + if is_sparse: + datafit.initialize_sparse(X.data, X.indptr, X.indices, y) + else: + datafit.initialize(X, y) + XtX = (X.T @ X).toarray() if is_sparse else X.T @ X # gram matrix + grad = ((datafit.Xty - XtX @ w_init) / len(y) if w_init is not None + else datafit.Xty / len(y)) + w = w_init.copy() if w_init is not None else np.zeros(n_features, dtype=X.dtype) + for epoch in range(max_epochs): + _cd_epoch_gram(XtX, grad, w, datafit, penalty, n_samples, n_features) + if epoch % 50 == 0: + Xw = X @ w + p_obj = datafit.value(y, w, Xw) + penalty.value(w) + + if is_sparse: + grad = construct_grad_sparse( + X.data, X.indptr, X.indices, y, w, Xw, datafit, all_feats) + else: + grad = construct_grad(X, y, w, Xw, datafit, all_feats) + # stop criterion: fixpoint + opt = dist_fix_point(w, grad, datafit, penalty, all_feats) + stop_crit = np.max(opt) + if max(verbose - 1, 0): + print(f"Epoch {epoch + 1}, objective {p_obj:.10f}, " + f"stopping crit {stop_crit:.2e}") + if stop_crit <= tol: + break + obj_out.append(p_obj) + return w, np.array(obj_out), stop_crit + + +@njit +def _cd_epoch_gram(XtX, grad, w, datafit, penalty, n_samples, n_features): + lc = datafit.lipschitz + for j in range(n_features): + if lc[j] == 0: + continue + old_w_j = w[j] + stepsize = 1 / lc[j] if lc[j] != 0 else 1000 + w[j] = penalty.prox_1d(old_w_j + grad[j] * stepsize, stepsize, j) + if old_w_j != w[j]: + grad += XtX[j, :] * (old_w_j - w[j]) / n_samples + + +def fista_gram_quadratic(X, y, penalty, max_epochs=1000, tol=1e-4, w_init=None, + verbose=False): + r"""Run an accelerated proximal gradient descent for quadratic datafit. + + This solver should be used when n_samples >> n_features. It does not implement any + working set strategy and iteratively updates the gradients (n_features,) instead of + the residuals (n_samples,). + + Parameters + ---------- + X : array, shape (n_samples, n_features) + Training data. + + y : array, shape (n_samples,) + Target values. + + penalty : instance of Penalty class + Penalty used in the model. + + max_epochs : int, optional + Maximum number of proximal steps. + + tol : float, optional + The tolerance for the optimization. + + w_init : array, shape (n_features,), optional + Initial coefficient vector. + + verbose : bool or int, optional + Amount of verbosity. 0/False is silent. + + Returns + ------- + w : array, shape (n_features,) + Coefficients. + + obj_out : array, shape (n_iter,) + Objective value at every outer iteration. + + stop_crit : array, shape (n_alphas,) + Value of stopping criterion at convergence along the path. 
+ """ + is_sparse = sparse.issparse(X) + n_samples = len(y) + n_features = len(X.indptr) - 1 if is_sparse else X.shape[1] + all_feats = np.arange(n_features) + obj_out = [] + datafit = Quadratic_32() if X.dtype == np.float32 else Quadratic() + if is_sparse: + datafit.initialize_sparse(X.data, X.indptr, X.indices, y) + else: + datafit.initialize(X, y) + t_new = 1 + w = w_init.copy() if w_init is not None else np.zeros(n_features, dtype=X.dtype) + z = w_init.copy() if w_init is not None else np.zeros(n_features, dtype=X.dtype) + G = X.T @ X + lc = np.linalg.norm(X, ord=2) ** 2 / n_samples + for epoch in range(max_epochs): + t_old = t_new + t_new = (1 + np.sqrt(1 + 4 * t_old ** 2)) / 2 + w_old = w.copy() + z -= (G @ z - datafit.Xty) / lc / n_samples + w = _prox_vec(penalty, z, 1/lc) + z = w + (t_old - 1.) / t_new * (w - w_old) + if epoch % 10 == 0: + Xw = X @ w + p_obj = datafit.value(y, w, Xw) + penalty.value(w) + + if is_sparse: + grad = construct_grad_sparse( + X.data, X.indptr, X.indices, y, w, Xw, datafit, all_feats) + else: + grad = construct_grad(X, y, w, Xw, datafit, all_feats) + + opt = dist_fix_point(w, grad, datafit, penalty, all_feats) + stop_crit = np.max(opt) + if max(verbose - 1, 0): + print(f"Epoch {epoch + 1}, objective {p_obj:.10f}, " + f"stopping crit {stop_crit:.2e}") + if stop_crit <= tol: + break + obj_out.append(p_obj) + return w, np.array(obj_out), stop_crit