Skip to content

Commit 375e310

Browse files
authored
Merge pull request #381 from tidymodels/regression-xy-error
Update and make `mode` checking more robust
2 parents a272a2c + f569b38 commit 375e310

File tree

6 files changed

+63
-35
lines changed

6 files changed

+63
-35
lines changed

R/fit_helpers.R

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,15 @@
77
form_form <-
88
function(object, control, env, ...) {
99

10+
# prob rewrite this as simple subset/levels
11+
y_levels <- levels_from_formula(env$formula, env$data)
12+
1013
if (object$mode == "classification") {
11-
# prob rewrite this as simple subset/levels
12-
y_levels <- levels_from_formula(env$formula, env$data)
1314
if (!inherits(env$data, "tbl_spark") && is.null(y_levels))
14-
rlang::abort("For classification models, the outcome should be a factor.")
15-
} else {
16-
y_levels <- NULL
15+
rlang::abort("For a classification model, the outcome should be a factor.")
16+
} else if (object$mode == "regression") {
17+
if (!inherits(env$data, "tbl_spark") && !is.null(y_levels))
18+
rlang::abort("For a regression model, the outcome should be numeric.")
1719
}
1820

1921
object <- check_mode(object, y_levels)
@@ -57,11 +59,7 @@ xy_xy <- function(object, env, control, target = "none", ...) {
5759
rlang::abort("spark objects can only be used with the formula interface to `fit()`")
5860

5961
object <- check_mode(object, levels(env$y))
60-
61-
if (object$mode == "classification") {
62-
if (is.null(levels(env$y)))
63-
rlang::abort("For classification models, the outcome should be a factor.")
64-
}
62+
check_outcome(env$y, object)
6563

6664
encoding_info <-
6765
get_encoding(class(object)[1]) %>%
@@ -133,7 +131,10 @@ form_xy <- function(object, control, env,
133131
res <- list(lvl = levels_from_formula(env$formula, env$data), spec = object)
134132
if (object$mode == "classification") {
135133
if (is.null(res$lvl))
136-
rlang::abort("For classification models, the outcome should be a factor.")
134+
rlang::abort("For a classification model, the outcome should be a factor.")
135+
} else if (object$mode == "regression") {
136+
if (!is.null(res$lvl))
137+
rlang::abort("For a regression model, the outcome should be numeric.")
137138
}
138139

139140
res <- xy_xy(
@@ -153,10 +154,7 @@ form_xy <- function(object, control, env,
153154

154155
xy_form <- function(object, env, control, ...) {
155156

156-
if (object$mode == "classification") {
157-
if (is.null(levels(env$y)))
158-
rlang::abort("For classification models, the outcome should be a factor.")
159-
}
157+
check_outcome(env$y, object)
160158

161159
encoding_info <-
162160
get_encoding(class(object)[1]) %>%

R/misc.R

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -212,11 +212,11 @@ check_outcome <- function(y, spec) {
212212
if (spec$mode == "unknown") {
213213
return(invisible(NULL))
214214
} else if (spec$mode == "regression") {
215-
if (!is.numeric(y))
216-
rlang::abort("The model outcome should be numeric for regression models.")
215+
if (!all(map_lgl(y, is.numeric)))
216+
rlang::abort("For a regression model, the outcome should be numeric.")
217217
} else if (spec$mode == "classification") {
218-
if (!is.factor(y)) {
219-
rlang::abort("The model outcome should be a factor for regression models.")
218+
if (!all(map_lgl(y, is.factor))) {
219+
rlang::abort("For a classification model, the outcome should be a factor.")
220220
}
221221
}
222222
invisible(NULL)

tests/testthat/test_boost_tree_xgboost.R

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,16 @@ test_that('xgboost execution, regression', {
122122
),
123123
regexp = NA
124124
)
125+
126+
expect_error(
127+
res <- parsnip::fit_xy(
128+
car_basic,
129+
x = mtcars[, num_pred],
130+
y = factor(mtcars$vs),
131+
control = ctrl
132+
),
133+
regexp = "For a regression model"
134+
)
125135
})
126136

127137

tests/testthat/test_linear_reg.R

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,16 @@ test_that('lm execution', {
241241
regexp = NA
242242
)
243243

244+
expect_error(
245+
res <- fit_xy(
246+
hpc_basic,
247+
x = hpc[, num_pred],
248+
y = hpc$class,
249+
control = ctrl
250+
),
251+
regexp = "For a regression model"
252+
)
253+
244254
expect_error(
245255
res <- fit(
246256
hpc_basic,
@@ -250,13 +260,15 @@ test_that('lm execution', {
250260
)
251261
)
252262

253-
lm_form_catch <- fit(
254-
hpc_basic,
255-
hpc_bad_form,
256-
data = hpc,
257-
control = caught_ctrl
263+
expect_error(
264+
lm_form_catch <- fit(
265+
hpc_basic,
266+
hpc_bad_form,
267+
data = hpc,
268+
control = caught_ctrl
269+
),
270+
regexp = "For a regression model"
258271
)
259-
expect_true(inherits(lm_form_catch$fit, "try-error"))
260272

261273
## multivariate y
262274

tests/testthat/test_mars.R

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -151,16 +151,14 @@ test_that('mars execution', {
151151
expect_true(has_multi_predict(res))
152152
expect_equal(multi_predict_args(res), "num_terms")
153153

154-
expect_message(
155-
expect_error(
156-
res <- fit(
157-
hpc_basic,
158-
hpc_bad_form,
159-
data = hpc,
160-
control = ctrl
161-
)
154+
expect_error(
155+
res <- fit(
156+
hpc_basic,
157+
hpc_bad_form,
158+
data = hpc,
159+
control = ctrl
162160
),
163-
"Timing stopped"
161+
regexp = "For a regression model"
164162
)
165163

166164
## multivariate y
@@ -203,7 +201,7 @@ test_that('mars prediction', {
203201
input_fields =
204202
c(430.476046435458, 158.833790342308, 218.07635084308,
205203
158.833790342308, 158.833790342308)
206-
),
204+
),
207205
class = "data.frame", row.names = c(NA, -5L)
208206
)
209207

tests/testthat/test_rand_forest_ranger.R

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,16 @@ test_that('ranger classification execution', {
3737
)
3838
expect_output(print(res), "parsnip model object")
3939

40+
expect_error(
41+
res <- fit(
42+
lc_ranger,
43+
funded_amnt ~ Class + term,
44+
data = lending_club,
45+
control = ctrl
46+
),
47+
regexp = "For a classification model"
48+
)
49+
4050
expect_error(
4151
res <- fit_xy(
4252
lc_ranger,

0 commit comments

Comments
 (0)