Skip to content

Commit 1706169

Browse files
committed
boost_tree() data dimension checks
1 parent 3ad1862 commit 1706169

File tree

3 files changed

+76
-7
lines changed

3 files changed

+76
-7
lines changed

R/boost_tree.R

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,10 @@ translate.boost_tree <- function(x, engine = x$engine, ...) {
221221
}
222222
x <- translate.default(x, engine, ...)
223223

224+
## -----------------------------------------------------------------------------
225+
226+
arg_vals <- x$method$fit$args
227+
224228
if (engine == "spark") {
225229
if (x$mode == "unknown") {
226230
rlang::abort(
@@ -230,9 +234,21 @@ translate.boost_tree <- function(x, engine = x$engine, ...) {
230234
)
231235
)
232236
} else {
233-
x$method$fit$args$type <- x$mode
237+
arg_vals$type <- x$mode
234238
}
235239
}
240+
241+
## -----------------------------------------------------------------------------
242+
# Protect some arguments based on data dimensions
243+
244+
# min_n parameters
245+
if (any(names(arg_vals) == "min_instances_per_node")) {
246+
arg_vals$min_instances_per_node <-
247+
rlang::call2("min", arg_vals$min_instances_per_node, expr(nrow(x)))
248+
}
249+
250+
## -----------------------------------------------------------------------------
251+
236252
x
237253
}
238254

@@ -242,14 +258,18 @@ check_args.boost_tree <- function(object) {
242258

243259
args <- lapply(object$args, rlang::eval_tidy)
244260

245-
if (is.numeric(args$trees) && args$trees < 0)
261+
if (is.numeric(args$trees) && args$trees < 0) {
246262
rlang::abort("`trees` should be >= 1.")
247-
if (is.numeric(args$sample_size) && (args$sample_size < 0 | args$sample_size > 1))
263+
}
264+
if (is.numeric(args$sample_size) && (args$sample_size < 0 | args$sample_size > 1)) {
248265
rlang::abort("`sample_size` should be within [0,1].")
249-
if (is.numeric(args$tree_depth) && args$tree_depth < 0)
266+
}
267+
if (is.numeric(args$tree_depth) && args$tree_depth < 0) {
250268
rlang::abort("`tree_depth` should be >= 1.")
251-
if (is.numeric(args$min_n) && args$min_n < 0)
269+
}
270+
if (is.numeric(args$min_n) && args$min_n < 0) {
252271
rlang::abort("`min_n` should be >= 1.")
272+
}
253273

254274
invisible(object)
255275
}
@@ -340,7 +360,7 @@ xgb_train <- function(
340360
max_depth = max_depth,
341361
gamma = gamma,
342362
colsample_bytree = colsample_bytree,
343-
min_child_weight = min_child_weight,
363+
min_child_weight = min(min_child_weight, n),
344364
subsample = subsample
345365
)
346366

@@ -516,7 +536,7 @@ C5.0_train <-
516536
fit_args <- other_args[names(other_args) %in% f_names]
517537

518538
ctrl <- call2("C5.0Control", .ns = "C50")
519-
ctrl$minCases <- minCases
539+
ctrl$minCases <- min(minCases, nrow(x))
520540
ctrl$sample <- sample
521541
ctrl <- rlang::call_modify(ctrl, !!!ctrl_args)
522542

tests/testthat/test_boost_tree_C50.R

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,3 +138,25 @@ test_that('submodel prediction', {
138138
)
139139
})
140140

141+
142+
## -----------------------------------------------------------------------------
143+
144+
test_that('argument checks for data dimensions', {
145+
146+
skip_if_not_installed("C50")
147+
148+
data(penguins, package = "modeldata")
149+
penguins <- na.omit(penguins)
150+
151+
spec <-
152+
boost_tree(min_n = 1000, trees = 5) %>%
153+
set_engine("C5.0") %>%
154+
set_mode("classification")
155+
156+
f_fit <- spec %>% fit(species ~ ., data = penguins)
157+
xy_fit <- spec %>% fit_xy(x = penguins[, -1], y = penguins$species)
158+
159+
expect_equal(f_fit$fit$control$minCases, nrow(penguins))
160+
expect_equal(xy_fit$fit$control$minCases, nrow(penguins))
161+
162+
})

tests/testthat/test_boost_tree_xgboost.R

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -352,4 +352,31 @@ test_that('xgboost data and sparse matrices', {
352352
})
353353

354354

355+
## -----------------------------------------------------------------------------
356+
357+
test_that('argument checks for data dimensions', {
358+
359+
skip_if_not_installed("C50")
360+
361+
data(penguins, package = "modeldata")
362+
penguins <- na.omit(penguins)
363+
364+
spec <-
365+
boost_tree(mtry = 1000, min_n = 1000, trees = 5) %>%
366+
set_engine("xgboost") %>%
367+
set_mode("classification")
368+
369+
penguins_dummy <- model.matrix(species ~ ., data = penguins)
370+
penguins_dummy <- as.data.frame(penguins_dummy[, -1])
371+
372+
f_fit <- spec %>% fit(species ~ ., data = penguins)
373+
xy_fit <- spec %>% fit_xy(x = penguins_dummy, y = penguins$species)
374+
375+
expect_equal(f_fit$fit$params$colsample_bytree, 1)
376+
expect_equal(f_fit$fit$params$min_child_weight, nrow(penguins))
377+
expect_equal(xy_fit$fit$params$colsample_bytree, 1)
378+
expect_equal(xy_fit$fit$params$min_child_weight, nrow(penguins))
379+
380+
})
381+
355382

0 commit comments

Comments
 (0)