Skip to content

Commit ca2f22b

Browse files
committed
Update and use check_outcome in fit helpers
1 parent a272a2c commit ca2f22b

File tree

2 files changed

+17
-19
lines changed

2 files changed

+17
-19
lines changed

R/fit_helpers.R

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,15 @@
77
form_form <-
88
function(object, control, env, ...) {
99

10+
# prob rewrite this as simple subset/levels
11+
y_levels <- levels_from_formula(env$formula, env$data)
12+
1013
if (object$mode == "classification") {
11-
# prob rewrite this as simple subset/levels
12-
y_levels <- levels_from_formula(env$formula, env$data)
1314
if (!inherits(env$data, "tbl_spark") && is.null(y_levels))
14-
rlang::abort("For classification models, the outcome should be a factor.")
15-
} else {
16-
y_levels <- NULL
15+
rlang::abort("For a classification model, the outcome should be a factor.")
16+
} else if (object$mode == "regression") {
17+
if (!inherits(env$data, "tbl_spark") && !is.null(y_levels))
18+
rlang::abort("For a regression model, the outcome should be numeric.")
1719
}
1820

1921
object <- check_mode(object, y_levels)
@@ -57,11 +59,7 @@ xy_xy <- function(object, env, control, target = "none", ...) {
5759
rlang::abort("spark objects can only be used with the formula interface to `fit()`")
5860

5961
object <- check_mode(object, levels(env$y))
60-
61-
if (object$mode == "classification") {
62-
if (is.null(levels(env$y)))
63-
rlang::abort("For classification models, the outcome should be a factor.")
64-
}
62+
check_outcome(env$y, object)
6563

6664
encoding_info <-
6765
get_encoding(class(object)[1]) %>%
@@ -133,7 +131,10 @@ form_xy <- function(object, control, env,
133131
res <- list(lvl = levels_from_formula(env$formula, env$data), spec = object)
134132
if (object$mode == "classification") {
135133
if (is.null(res$lvl))
136-
rlang::abort("For classification models, the outcome should be a factor.")
134+
rlang::abort("For a classification model, the outcome should be a factor.")
135+
} else if (object$mode == "regression") {
136+
if (!is.null(res$lvl))
137+
rlang::abort("For a regression model, the outcome should be numeric.")
137138
}
138139

139140
res <- xy_xy(
@@ -153,10 +154,7 @@ form_xy <- function(object, control, env,
153154

154155
xy_form <- function(object, env, control, ...) {
155156

156-
if (object$mode == "classification") {
157-
if (is.null(levels(env$y)))
158-
rlang::abort("For classification models, the outcome should be a factor.")
159-
}
157+
check_outcome(env$y, object)
160158

161159
encoding_info <-
162160
get_encoding(class(object)[1]) %>%

R/misc.R

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -212,11 +212,11 @@ check_outcome <- function(y, spec) {
212212
if (spec$mode == "unknown") {
213213
return(invisible(NULL))
214214
} else if (spec$mode == "regression") {
215-
if (!is.numeric(y))
216-
rlang::abort("The model outcome should be numeric for regression models.")
215+
if (!all(map_lgl(y, is.numeric)))
216+
rlang::abort("For a regression model, the outcome should be numeric.")
217217
} else if (spec$mode == "classification") {
218-
if (!is.factor(y)) {
219-
rlang::abort("The model outcome should be a factor for regression models.")
218+
if (!all(map_lgl(y, is.factor))) {
219+
rlang::abort("For a classification model, the outcome should be a factor.")
220220
}
221221
}
222222
invisible(NULL)

0 commit comments

Comments
 (0)