-
Notifications
You must be signed in to change notification settings - Fork 1
[WIP] Test custom obj function #21
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
61da860
799914f
f228237
f9b0565
2f91001
f8828a4
ee68585
6b0dd91
6735cf8
7c11ea8
76472b5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -124,7 +124,7 @@ | |
| #' @return A fitted \code{lgb.Booster} object. | ||
| #' @keywords internal | ||
| #' @export | ||
| train_lightgbm <- function(x, | ||
| train_lightgbm <- function(x, # nolint | ||
| y, | ||
| num_iterations = 10, | ||
| max_depth = 17, | ||
|
|
@@ -146,8 +146,31 @@ | |
| force(y) | ||
| others <- list(...) | ||
|
|
||
| # Set training objective (always regression) | ||
| if (!any(names(others) %in% c("objective"))) { | ||
| # Custom objective handling. `mse_cov_rho` is a lightsnip-specific engine | ||
| # arg used only when `objective == "mse_cov"`; pop it off so it is not | ||
| # forwarded to lgb.train (which would error on an unknown parameter). | ||
| mse_cov_rho <- others$mse_cov_rho | ||
| others$mse_cov_rho <- NULL | ||
|
|
||
| custom_obj <- NULL | ||
|
|
||
| # Sentinel that gates two downstream branches: | ||
| # - whether to fall back to the default "regression" objective | ||
| # - whether to construct the mse_cov objective callback | ||
| mse_cov_rho_val <- NULL | ||
|
|
||
| if (!is.null(others$objective) && identical(others$objective, "mse_cov")) { | ||
| mse_cov_rho_val <- if (is.null(mse_cov_rho)) 1e-3 else as.numeric(mse_cov_rho) | ||
| # Clear `objective`/`num_class` so lgb.train doesn't reject the unknown | ||
| # name when we hand it the callback via `obj`. | ||
| others$objective <- NULL | ||
| others$num_class <- NULL | ||
| } | ||
|
|
||
| # Set training objective default (always regression) when not specified. | ||
| # Skipped when a custom `obj` callback is in use, since lgb.train will then | ||
| # supply the gradient/hessian itself and `objective` must be unset. | ||
| if (is.null(mse_cov_rho_val) && !any(names(others) %in% c("objective"))) { | ||
| others$num_class <- 1 | ||
| others$objective <- "regression" | ||
| } | ||
|
|
@@ -235,6 +258,17 @@ | |
| trn_index <- 1:n | ||
| } | ||
|
|
||
| # Build the mse_cov callback against training rows only — `y[val_index]` | ||
| # is held out for lgb.train's early stopping, so including those labels | ||
| # in `y_mean` would leak the holdout's label mean into the centering | ||
| # term used by the covariance penalty on every boosting iteration. | ||
| if (!is.null(mse_cov_rho_val)) { | ||
| custom_obj <- make_obj_mse_cov( | ||
| rho = mse_cov_rho_val, | ||
| y_mean = mean(y[trn_index]) | ||
| ) | ||
| } | ||
|
|
||
| d <- lightgbm::lgb.Dataset( | ||
| data = as.matrix(x[trn_index, , drop = FALSE]), | ||
| label = y[trn_index], | ||
|
|
@@ -270,6 +304,10 @@ | |
| if (!is.null(early_stop) && validation > 0) { | ||
| main_args$early_stopping_rounds <- early_stop | ||
| } | ||
| # Wire in the custom objective callback (if any) under lgb.train's `obj` arg | ||
| if (!is.null(custom_obj)) { | ||
| main_args$obj <- quote(custom_obj) | ||
| } | ||
|
|
||
| call <- parsnip::make_call(fun = "lgb.train", ns = "lightgbm", main_args) | ||
| rlang::eval_tidy(call, env = rlang::current_env()) | ||
|
|
@@ -288,9 +326,14 @@ | |
| #' | ||
| #' @export | ||
| pred_lgb_reg_num <- function(object, new_data, ...) { | ||
| # Use type = "raw" so the result is the unmodified booster score. For | ||
| # regression this is identical to type = "response" but, unlike "response", | ||
| # it does not warn when the booster was trained with a custom objective | ||
| # (e.g. lightsnip's `mse_cov`). | ||
| stats::predict( | ||
| object$fit, | ||
| as.matrix(new_data), | ||
| type = "raw", | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This stops the script from emitting a bunch of warnings |
||
| params = list(predict_disable_shape_check = TRUE), | ||
| ... | ||
| ) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,77 @@ | ||
#' Custom LightGBM objective: MSE + rho * Cov(r, y)^2
#'
#' @description Build a custom LightGBM objective callback that minimizes a
#' standard squared-error loss plus a soft penalty on the covariance between
#' the per-sample residual `r = y_pred - y_true` and the (centered) labels
#' `y_true`. The penalty pushes the model toward "vertical equity" by
#' discouraging residuals that systematically scale with `y`.
#'
#' This is an R port of the `LGBCovPenalty` objective from
#' an active collaboration. (https://github.com/nicacevedo/soft-vertical-equity-constrained-mass-appraissal) # nolint
#' It is intended to be used when the model is trained in log-space (so the
#' "diff" residual is equivalent to a log-ratio).
#'
#' Penalty (using mean-centered labels yc = y_true - mean(y_true)):
#' \deqn{cov = (1/n) * sum_i r_i * yc_i}
#' \deqn{penalty = 0.5 * rho * n * cov^2}
#' Diagonal Hessian approximation is used (matches the reference Python
#' implementation).
#'
#' @param rho Numeric. Non-negative penalty weight. `rho = 0` recovers plain
#'   MSE.
#' @param y_mean Numeric. Mean of the training labels. Should be computed once
#'   from the training set and captured here so the centering is stable across
#'   iterations.
#' @param zero_grad_tol Numeric. Floor applied to absolute gradients/Hessians
#'   to avoid zero entries that confuse LightGBM. Matches the reference
#'   implementation.
#'
#' @return A function with signature `function(preds, dtrain)` suitable for
#'   passing as the `obj` argument of [lightgbm::lgb.train].
#'
#' @export
make_obj_mse_cov <- function(rho, y_mean, zero_grad_tol = 1e-6) {
  # The gradient/Hessian math below is borrowed directly from the reference
  # (Python) implementation linked in the docs above.
  rho <- as.numeric(rho)
  y_mean <- as.numeric(y_mean)
  zero_grad_tol <- as.numeric(zero_grad_tol)
  if (length(rho) != 1L || is.na(rho) || rho < 0) {
    rlang::abort("`rho` must be a single non-negative numeric value.")
  }
  # Validate the captured constants up front: an NA or length != 1 `y_mean`
  # would otherwise silently poison every gradient on every boosting round.
  if (length(y_mean) != 1L || is.na(y_mean)) {
    rlang::abort("`y_mean` must be a single non-missing numeric value.")
  }
  if (length(zero_grad_tol) != 1L || is.na(zero_grad_tol) ||
        zero_grad_tol <= 0) {
    rlang::abort("`zero_grad_tol` must be a single positive numeric value.")
  }

  function(preds, dtrain) {
    y_true <- lightgbm::get_field(dtrain, "label")
    y_pred <- as.numeric(preds)
    n <- length(y_pred)

    # Centered labels (training-set mean is captured at construction time so
    # the penalty geometry stays stable across boosting iterations)
    yc <- y_true - y_mean

    # Residual ("diff" mode); in log-space training this is the log-ratio
    r <- y_pred - y_true

    # Covariance estimate (E[yc] is ~0 by construction)
    cov_val <- mean(r * yc)

    # Base squared-error grad/hess
    grad_base <- 2.0 * r
    hess_base <- rep(2.0, n)

    # Penalty grad/hess (diagonal approximation)
    # dc/dy_pred_i = (1/n) * yc_i (since dr_i/dy_pred_i = 1)
    a <- yc / n
    grad_pen <- rho * n * cov_val * a
    hess_pen <- rho * n * (a^2)

    grad <- grad_base + grad_pen
    hess <- hess_base + hess_pen

    # Floor tiny values, mirroring the reference implementation. Assigning
    # through a logical index is a no-op when nothing matches, so no
    # `if (any(...))` guard is needed.
    # NOTE(review): this maps small *negative* gradients to +zero_grad_tol
    # (a sign flip). The reference implementation does the same — confirm
    # this is intended before changing it.
    grad[abs(grad) < zero_grad_tol] <- zero_grad_tol
    hess[hess < zero_grad_tol] <- zero_grad_tol

    # lgb.train consumes the `grad`/`hess` entries of this list directly.
    list(grad = grad, hess = hess)
  }
}
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LightGBM takes either
`objective` or `obj` args, code here, and ultimately collects one of the two as `objective`. More context on the comment here from the lightgbm source code. Shown here:
Checks to see if there is a custom function supplied, saves it as
`fobj` to be passed to C later using gradient and hessian, and the `"none"` value lets LightGBM know that the custom math will be supplied later, rather than it using one of its built-in C objective functions.