Skip to content

Commit 3ad1862

Browse files
committed
protype changes for #184
1 parent 7caf085 commit 3ad1862

File tree

3 files changed

+81
-3
lines changed

3 files changed

+81
-3
lines changed

R/rand_forest.R

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,8 @@ translate.rand_forest <- function(x, engine = x$engine, ...) {
161161

162162
x <- translate.default(x, engine, ...)
163163

164+
## -----------------------------------------------------------------------------
165+
164166
# slightly cleaner code using
165167
arg_vals <- x$method$fit$args
166168

@@ -185,14 +187,40 @@ translate.rand_forest <- function(x, engine = x$engine, ...) {
185187

186188
# add checks to error trap or change things for this method
187189
if (engine == "ranger") {
188-
if (any(names(arg_vals) == "importance"))
189-
if (isTRUE(is.logical(quo_get_expr(arg_vals$importance))))
190+
191+
if (any(names(arg_vals) == "importance")) {
192+
if (isTRUE(is.logical(quo_get_expr(arg_vals$importance)))) {
190193
rlang::abort("`importance` should be a character value. See ?ranger::ranger.")
194+
}
195+
}
191196
# unless otherwise specified, classification models are probability forests
192-
if (x$mode == "classification" && !any(names(arg_vals) == "probability"))
197+
if (x$mode == "classification" && !any(names(arg_vals) == "probability")) {
193198
arg_vals$probability <- TRUE
199+
}
200+
}
201+
202+
## -----------------------------------------------------------------------------
203+
# Protect some arguments based on data dimensions
194204

205+
if (any(names(arg_vals) == "mtry")) {
206+
arg_vals$mtry <- rlang::call2("min", arg_vals$mtry, expr(ncol(x)))
195207
}
208+
209+
if (any(names(arg_vals) == "min.node.size")) {
210+
arg_vals$min.node.size <-
211+
rlang::call2("min", arg_vals$min.node.size, expr(nrow(x)))
212+
}
213+
if (any(names(arg_vals) == "nodesize")) {
214+
arg_vals$nodesize <-
215+
rlang::call2("min", arg_vals$nodesize, expr(nrow(x)))
216+
}
217+
if (any(names(arg_vals) == "min_instances_per_node")) {
218+
arg_vals$min_instances_per_node <-
219+
rlang::call2("min", arg_vals$min_instances_per_node, expr(nrow(x)))
220+
}
221+
222+
## -----------------------------------------------------------------------------
223+
196224
x$method$fit$args <- arg_vals
197225

198226
x

tests/testthat/test_rand_forest_randomForest.R

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,3 +213,29 @@ test_that('randomForest regression prediction', {
213213
expect_equal(xy_pred, predict(xy_fit, new_data = tail(mtcars))$.pred)
214214

215215
})
216+
217+
## -----------------------------------------------------------------------------
218+
219+
test_that('argument checks for data dimensions', {
220+
221+
skip_if_not_installed("randomForest")
222+
223+
data(penguins, package = "modeldata")
224+
penguins <- na.omit(penguins)
225+
226+
spec <-
227+
rand_forest(mtry = 1000, min_n = 1000, trees = 5) %>%
228+
set_engine("randomForest") %>%
229+
set_mode("regression")
230+
231+
f_fit <- spec %>% fit(body_mass_g ~ ., data = penguins)
232+
xy_fit <- spec %>% fit_xy(x = penguins[, -6], y = penguins$body_mass_g)
233+
234+
expect_equal(f_fit$fit$mtry, 6)
235+
expect_equal(f_fit$fit$call$nodesize, rlang::expr(min(~1000, nrow(x))))
236+
expect_equal(xy_fit$fit$mtry, 6)
237+
expect_equal(xy_fit$fit$call$nodesize, rlang::expr(min(~1000, nrow(x))))
238+
239+
})
240+
241+

tests/testthat/test_rand_forest_ranger.R

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -474,3 +474,27 @@ test_that('ranger and sparse matrices', {
474474

475475
})
476476

477+
478+
## -----------------------------------------------------------------------------
479+
480+
test_that('argument checks for data dimensions', {
481+
482+
skip_if_not_installed("ranger")
483+
484+
data(penguins, package = "modeldata")
485+
penguins <- na.omit(penguins)
486+
487+
spec <-
488+
rand_forest(mtry = 1000, min_n = 1000, trees = 5) %>%
489+
set_engine("ranger") %>%
490+
set_mode("regression")
491+
492+
f_fit <- spec %>% fit(body_mass_g ~ ., data = penguins)
493+
xy_fit <- spec %>% fit_xy(x = penguins[, -6], y = penguins$body_mass_g)
494+
495+
expect_equal(f_fit$fit$mtry, 6)
496+
expect_equal(f_fit$fit$min.node.size, nrow(penguins))
497+
expect_equal(xy_fit$fit$mtry, 6)
498+
expect_equal(xy_fit$fit$min.node.size, nrow(penguins))
499+
500+
})

0 commit comments

Comments
 (0)