Skip to content

Commit dd586ca

Browse files
Add glmnet support (#165)
1 parent 0691638 commit dd586ca

File tree

10 files changed

+460
-1
lines changed

10 files changed

+460
-1
lines changed

DESCRIPTION

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ Suggests:
3333
DBI,
3434
dbplyr,
3535
earth (>= 5.1.2),
36+
glmnet,
3637
methods,
3738
mlbench,
3839
modeldata,
@@ -53,5 +54,5 @@ Config/testthat/edition: 3
5354
Encoding: UTF-8
5455
Roxygen: list(markdown = TRUE)
5556
RoxygenNote: 7.3.3
56-
Remotes:
57+
Remotes:
5758
topepo/Cubist

NAMESPACE

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ S3method(knit_print,tidypredict_test)
77
S3method(parse_model,cubist)
88
S3method(parse_model,earth)
99
S3method(parse_model,glm)
10+
S3method(parse_model,glmnet)
1011
S3method(parse_model,lm)
1112
S3method(parse_model,model_fit)
1213
S3method(parse_model,party)
@@ -19,6 +20,7 @@ S3method(tidypredict_fit,"_xgb.Booster")
1920
S3method(tidypredict_fit,cubist)
2021
S3method(tidypredict_fit,earth)
2122
S3method(tidypredict_fit,glm)
23+
S3method(tidypredict_fit,glmnet)
2224
S3method(tidypredict_fit,lm)
2325
S3method(tidypredict_fit,model_fit)
2426
S3method(tidypredict_fit,party)
@@ -34,6 +36,7 @@ S3method(tidypredict_interval,list)
3436
S3method(tidypredict_interval,lm)
3537
S3method(tidypredict_test,"_xgb.Booster")
3638
S3method(tidypredict_test,default)
39+
S3method(tidypredict_test,glmnet)
3740
S3method(tidypredict_test,model_fit)
3841
S3method(tidypredict_test,party)
3942
S3method(tidypredict_test,xgb.Booster)

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424

2525
- Tree-based models now use the `.default` argument in the produced `case_when()` code when applicable. (#153)
2626

27+
- Added support for glmnet models. (#165)
28+
2729
# tidypredict 0.5.1
2830

2931
- Exported a number of internal functions to be used in {orbital} package

R/model-glmnet.R

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# Predict ---------------------------------------
2+
3+
#' @export
4+
tidypredict_fit.glmnet <- function(model) {
  # Parse the glmnet object and hand the parsed result straight to
  # build_fit_formula(); its return value is this method's return value.
  build_fit_formula(parse_model(model))
}
8+
9+
# Parse model --------------------------------------
10+
11+
#' @export
12+
parse_model.glmnet <- function(model) {
  # S3 entry point for glmnet objects; the parsing itself is done by
  # parse_model_glmnet(), called here with its default `call`.
  parsed <- parse_model_glmnet(model)
  parsed
}
15+
16+
parse_model_glmnet <- function(model, call = rlang::caller_env()) {
  # Build the parsed-model list for a glmnet fit that holds exactly one penalty.
  #
  # model: a glmnet fit; `lambda`, `a0` and `beta` are read from it.
  # call:  environment forwarded to cli::cli_abort() for error reporting.
  #
  # Returns the result of as_parsed_model() applied to a list with `general`
  # (model, version, type, is_glm, family, link) and `terms` (one entry per
  # non-zero coefficient). Aborts when more than one penalty is present or when
  # the model class is not one of elnet / lognet / fishnet.
  if (length(model$lambda) != 1) {
    cli::cli_abort(
      "{.fn tidypredict_fit} requires that there are only 1 penalty selected,
      {length(model$lambda)} were provided.",
      call = call
    )
  }

  # Resolve family/link up front so an unsupported model fails before any
  # coefficient work is done, and report it against the same `call` as the
  # penalty error above.
  if (inherits(model, "elnet")) {
    family <- "gaussian"
    link <- "identity"
  } else if (inherits(model, "lognet")) {
    family <- "binomial"
    link <- "logit"
  } else if (inherits(model, "fishnet")) {
    family <- "poisson"
    link <- "log"
  } else {
    cli::cli_abort(
      "Model fit with this {.arg family} is not supported.",
      call = call
    )
  }

  # A sparse coefficient matrix is flattened to a named numeric vector.
  if (inherits(model$beta, "dgCMatrix")) {
    model$beta <- setNames(as.numeric(model$beta), rownames(model$beta))
  }
  coefs <- c("(Intercept)" = unname(model$a0), model$beta)

  coef_names <- names(coefs)
  coef_values <- as.vector(coefs)

  terms <- purrr::map2(coef_values, coef_names, \(value, name) {
    # Coefficients that are exactly zero contribute no term.
    if (value == 0) {
      return(NULL)
    }
    list(
      label = name,
      coef = value,
      is_intercept = as.integer(name == "(Intercept)"),
      fields = list(list(type = "ordinary", col = name))
    )
  })
  terms <- purrr::discard(terms, is.null)

  pm <- list()
  # NOTE(review): takes the second class entry; presumably "glmnet" for
  # standard fits (e.g. c("elnet", "glmnet")) -- confirm for other classes.
  pm$general$model <- class(model)[[2]]
  pm$general$version <- 1
  pm$general$type <- "regression"
  pm$general$is_glm <- 1
  pm$terms <- terms
  pm$general$family <- family
  pm$general$link <- link

  as_parsed_model(pm)
}

R/tidymodels.R

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,37 @@ tidypredict_fit._xgb.Booster <- function(model) {
77

88
#' @export
99
tidypredict_fit.model_fit <- function(model) {
  # Let glmnet_set_lambda() adjust the wrapped fit (it only touches glmnet
  # fits), then dispatch tidypredict_fit() on the underlying engine object.
  adjusted <- glmnet_set_lambda(model)
  tidypredict_fit(adjusted$fit)
}
1213

1314
#' @export
1415
parse_model.model_fit <- function(model) {
  # Let glmnet_set_lambda() adjust the wrapped fit (it only touches glmnet
  # fits), then dispatch parse_model() on the underlying engine object.
  adjusted <- glmnet_set_lambda(model)
  parse_model(adjusted$fit)
}
1719

20+
# glmnet adjustment ------------------------------------------------------
21+
22+
glmnet_set_lambda <- function(model) {
  # Rewrites `a0`, `lambda` and `beta` on a wrapped glmnet fit so that they
  # hold the coefficients at the penalty stored in `model$spec$args$penalty`.
  # Any model whose `$fit` is not a glmnet object is returned untouched.
  if (!inherits(model$fit, "glmnet")) {
    return(model)
  }

  penalty <- model$spec$args$penalty
  coefs <- glmnet::predict.glmnet(
    model$fit,
    s = penalty,
    type = "coefficients"
  )

  # Split the intercept row from the remaining coefficients when it is present.
  coef_names <- rownames(coefs)
  if ("(Intercept)" %in% coef_names) {
    model$fit$a0 <- coefs["(Intercept)", ]
    coefs <- coefs[coef_names != "(Intercept)", ]
  }

  model$fit$lambda <- penalty
  model$fit$beta <- coefs
  model
}
40+
1841
# broom ------------------------------------------------------------------
1942

2043
#' @export

R/tidypredict_test.R

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,121 @@ tidypredict_test_default <- function(
186186
structure(results, class = c("tidypredict_test", "list"))
187187
}
188188

189+
#' @export
190+
tidypredict_test.glmnet <- function(
  model,
  df = model$model,
  threshold = 0.000000000001,
  include_intervals = FALSE,
  max_rows = NULL,
  xg_df = NULL
) {
  # Compare predict() output for a glmnet model against the values produced by
  # tidypredict_to_column() on the same data.
  #
  # model:             fitted model passed to predict() and
  #                    tidypredict_to_column().
  # df:                data to predict on; converted with as.matrix() for
  #                    predict().
  # threshold:         absolute difference above which a row is flagged.
  # include_intervals: when TRUE, lwr/upr columns are compared as well.
  # max_rows:          when numeric, only the first `max_rows` rows are used.
  # xg_df:             not used by this method.
  #
  # Returns a list of class c("tidypredict_test", "list") with elements
  # `model_call`, `raw_results`, `message` and `alert`.
  offset <- model$call$offset
  ismodels <- paste0(colnames(model$model), collapse = " ") ==
    paste0(colnames(df), collapse = " ")

  # When `df` has the same column names as `model$model`, rename the
  # "(offset)" column to the offset expression recorded in the call.
  if (!is.null(offset) && ismodels) {
    index <- colnames(df) == "(offset)"
    colnames(df) <- replace(colnames(df), index, as.character(offset))
  }

  interval <- "none"
  if (include_intervals) {
    interval <- "prediction"
  }

  if (is.numeric(max_rows)) {
    df <- head(df, max_rows)
  }

  preds <- predict(model, as.matrix(df), interval = interval, type = "response")

  if (!include_intervals) {
    base <- data.frame(fit = as.vector(preds), row.names = NULL)
  } else {
    base <- as.data.frame(preds)
  }

  te <- tidypredict_to_column(
    df,
    model,
    add_interval = include_intervals,
    vars = c("fit_te", "upr_te", "lwr_te")
  )
  if (include_intervals) {
    te <- te[, c("fit_te", "upr_te", "lwr_te")]
  } else {
    te <- data.frame(fit_te = te[, "fit_te"])
  }

  raw_results <- cbind(base, te)
  # `fit_diff` stays signed in the per-row results; only its magnitude is
  # compared with the threshold.
  raw_results$fit_diff <- raw_results$fit - raw_results$fit_te
  raw_results$fit_threshold <- abs(raw_results$fit_diff) > threshold

  if (include_intervals) {
    raw_results$lwr_diff <- abs(raw_results$lwr - raw_results$lwr_te)
    raw_results$upr_diff <- abs(raw_results$upr - raw_results$upr_te)
    raw_results$lwr_threshold <- raw_results$lwr_diff > threshold
    raw_results$upr_threshold <- raw_results$upr_diff > threshold
  }

  rowid <- seq_len(nrow(raw_results))
  raw_results <- cbind(data.frame(rowid), raw_results)

  threshold_df <- data.frame(fit_threshold = sum(raw_results$fit_threshold))
  if (include_intervals) {
    threshold_df$lwr_threshold <- sum(raw_results$lwr_threshold)
    threshold_df$upr_threshold <- sum(raw_results$upr_threshold)
  }

  alert <- any(threshold_df > 0)

  message <- paste0(
    "tidypredict test results\n",
    "Difference threshold: ",
    threshold,
    "\n"
  )

  if (alert) {
    # Summarise by magnitude so a large negative difference is not hidden
    # behind a smaller positive (or less negative) one.
    difference <- data.frame(fit_diff = max(abs(raw_results$fit_diff)))
    if (include_intervals) {
      difference$lwr_diff <- max(raw_results$lwr_diff)
      difference$upr_diff <- max(raw_results$upr_diff)
    }
    # Each label is paired with its own statistic; the interval lines are only
    # emitted when intervals were requested (a NULL piece is dropped by
    # paste0()).
    message <- paste0(
      message,
      "\nFitted records above the threshold: ",
      threshold_df$fit_threshold,
      if (include_intervals) {
        paste0(
          "\nLower interval records above the threshold: ",
          threshold_df$lwr_threshold,
          "\nUpper interval records above the threshold: ",
          threshold_df$upr_threshold
        )
      },
      "\n\nFit max difference:",
      difference$fit_diff,
      if (include_intervals) {
        paste0(
          "\nLower max difference:",
          difference$lwr_diff,
          "\nUpper max difference:",
          difference$upr_diff
        )
      }
    )
  } else {
    message <- paste0(
      message,
      "\n All results are within the difference threshold"
    )
  }

  results <- list()
  results$model_call <- model$call
  results$raw_results <- raw_results
  results$message <- message
  results$alert <- alert
  structure(results, class = c("tidypredict_test", "list"))
}
303+
189304
#' @export
190305
tidypredict_test.xgb.Booster <- function(
191306
model,

_pkgdown.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ navbar:
2121
href: articles/lm.html
2222
- text: Generalized Regression - glm()
2323
href: articles/glm.html
24+
- text: Regularized Regression - glmnet()
25+
href: articles/glmnet.html
2426
- text: Random Forest - Ranger - ranger()
2527
href: articles/ranger.html
2628
- text: Random Forest - randomForest()
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
# returns the right output
2+
3+
Code
4+
rlang::expr_text(tf)
5+
Output
6+
[1] "35.3137765116027 + (cyl * -0.871451193824228) + (hp * -0.0101173960249783) + \n (wt * -2.59443677687505)"
7+
8+
# formulas produces correct predictions
9+
10+
Code
11+
tidypredict_test(glmnet::glmnet(mtcars[, -1], mtcars$mpg, family = "gaussian",
12+
lambda = 1), mtcars[, -1])
13+
Output
14+
tidypredict test results
15+
Difference threshold: 1e-12
16+
17+
All results are within the difference threshold
18+
19+
---
20+
21+
Code
22+
tidypredict_test(glmnet::glmnet(mtcars[, -8], mtcars$vs, family = "binomial",
23+
lambda = 1), mtcars[, -1])
24+
Output
25+
tidypredict test results
26+
Difference threshold: 1e-12
27+
28+
All results are within the difference threshold
29+
30+
---
31+
32+
Code
33+
tidypredict_test(glmnet::glmnet(mtcars[, -8], mtcars$vs, family = "poisson",
34+
lambda = 1), mtcars[, -1])
35+
Output
36+
tidypredict test results
37+
Difference threshold: 1e-12
38+
39+
All results are within the difference threshold
40+
41+
# errors if more than 1 penalty is selected
42+
43+
Code
44+
tidypredict_fit(model)
45+
Condition
46+
Error in `parse_model()`:
47+
! `tidypredict_fit()` requires that there are only 1 penalty selected, 79 were provided.
48+
49+
---
50+
51+
Code
52+
tidypredict_fit(model)
53+
Condition
54+
Error in `parse_model()`:
55+
! `tidypredict_fit()` requires that there are only 1 penalty selected, 2 were provided.
56+
57+
# glmnet are handeld neatly with parsnip
58+
59+
Code
60+
rlang::expr_text(tf)
61+
Output
62+
[1] "35.3140536966127 + (cyl * -0.871623418095165) + (hp * -0.0101157918502673) + \n (wt * -2.59426484734253)"
63+

0 commit comments

Comments
 (0)