Skip to content

Commit 70b170b

Browse files
committed
more data dimension protections
1 parent 118b897 commit 70b170b

File tree

8 files changed

+109
-13
lines changed

8 files changed

+109
-13
lines changed

R/boost_tree.R

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,11 +244,13 @@ translate.boost_tree <- function(x, engine = x$engine, ...) {
244244
# min_n parameters
245245
if (any(names(arg_vals) == "min_instances_per_node")) {
246246
arg_vals$min_instances_per_node <-
247-
rlang::call2("min", arg_vals$min_instances_per_node, expr(nrow(x)))
247+
rlang::call2("min", rlang::eval_tidy(arg_vals$min_instances_per_node), expr(nrow(x)))
248248
}
249249

250250
## -----------------------------------------------------------------------------
251251

252+
x$method$fit$args <- arg_vals
253+
252254
x
253255
}
254256

R/decision_tree.R

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,22 @@ translate.decision_tree <- function(x, engine = x$engine, ...) {
180180
}
181181
}
182182

183+
## -----------------------------------------------------------------------------
184+
# Cap row-count-dependent arguments (minsplit, min_instances_per_node) at the number of rows in the data
185+
186+
if (any(names(arg_vals) == "minsplit")) {
187+
arg_vals$minsplit <-
188+
rlang::call2("min", rlang::eval_tidy(arg_vals$minsplit), expr(nrow(data)))
189+
}
190+
if (any(names(arg_vals) == "min_instances_per_node")) {
191+
arg_vals$min_instances_per_node <-
192+
rlang::call2("min", rlang::eval_tidy(arg_vals$min_instances_per_node), expr(nrow(x)))
193+
}
194+
195+
## -----------------------------------------------------------------------------
196+
197+
x$method$fit$args <- arg_vals
198+
183199
x
184200
}
185201

R/nearest_neighbor.R

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -168,12 +168,25 @@ translate.nearest_neighbor <- function(x, engine = x$engine, ...) {
168168
}
169169
x <- translate.default(x, engine, ...)
170170

171+
arg_vals <- x$method$fit$args
172+
171173
if (engine == "kknn") {
172-
if (!any(names(x$method$fit$args) == "ks") ||
173-
is_missing_arg(x$method$fit$args$ks)) {
174-
x$method$fit$args$ks <- 5
174+
175+
if (!any(names(arg_vals) == "ks") || is_missing_arg(arg_vals$ks)) {
176+
arg_vals$ks <- 5
177+
}
178+
179+
## -----------------------------------------------------------------------------
180+
# Cap the number of neighbors (ks) at nrow(data) - 5 so it cannot exceed what the data supports
181+
182+
if (any(names(arg_vals) == "ks")) {
183+
arg_vals$ks <-
184+
rlang::call2("min", rlang::eval_tidy(arg_vals$ks), expr(nrow(data) - 5))
175185
}
176186
}
187+
188+
x$method$fit$args <- arg_vals
189+
177190
x
178191
}
179192

man/nearest_neighbor.Rd

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test_boost_tree.R

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,4 +151,17 @@ test_that('bad input', {
151151
expect_error(translate(boost_tree(formula = y ~ x)))
152152
})
153153

154-
# ------------------------------------------------------------------------------
154+
155+
## -----------------------------------------------------------------------------
156+
157+
test_that('argument checks for data dimensions', {
158+
159+
spec <-
160+
boost_tree(mtry = 1000, min_n = 1000, trees = 5) %>%
161+
set_engine("spark") %>%
162+
set_mode("classification")
163+
164+
args <- translate(spec)$method$fit$args
165+
expect_equal(args$min_instances_per_node, expr(min(1000, nrow(x))))
166+
})
167+

tests/testthat/test_decision_tree.R

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ test_that('primary arguments', {
5959
formula = expr(missing_arg()),
6060
data = expr(missing_arg()),
6161
weights = expr(missing_arg()),
62-
minsplit = new_empty_quosure(15)
62+
minsplit = expr(min(15, nrow(data)))
6363
)
6464
)
6565

@@ -148,3 +148,33 @@ test_that('default engine', {
148148
)
149149
expect_true(inherits(fit$fit, "rpart"))
150150
})
151+
152+
153+
154+
## -----------------------------------------------------------------------------
155+
156+
test_that('argument checks for data dimensions', {
157+
158+
data(penguins, package = "modeldata")
159+
penguins <- na.omit(penguins)
160+
161+
spec <-
162+
decision_tree(min_n = 1000) %>%
163+
set_engine("rpart") %>%
164+
set_mode("regression")
165+
166+
f_fit <- spec %>% fit(body_mass_g ~ ., data = penguins)
167+
xy_fit <- spec %>% fit_xy(x = penguins[, -6], y = penguins$body_mass_g)
168+
169+
expect_equal(f_fit$fit$control$minsplit, nrow(penguins))
170+
expect_equal(xy_fit$fit$control$minsplit, nrow(penguins))
171+
172+
spec <-
173+
decision_tree(min_n = 1000) %>%
174+
set_engine("spark") %>%
175+
set_mode("regression")
176+
177+
args <- translate(spec)$method$fit$args
178+
expect_equal(args$min_instances_per_node, rlang::expr(min(1000, nrow(x))))
179+
180+
})

tests/testthat/test_nearest_neighbor.R

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ test_that('primary arguments', {
1818
expected = list(
1919
formula = expr(missing_arg()),
2020
data = expr(missing_arg()),
21-
ks = 5
21+
ks = expr(min(5, nrow(data) - 5))
2222
)
2323
)
2424

@@ -30,7 +30,7 @@ test_that('primary arguments', {
3030
expected = list(
3131
formula = expr(missing_arg()),
3232
data = expr(missing_arg()),
33-
ks = new_empty_quosure(2)
33+
ks = expr(min(2, nrow(data) - 5))
3434
)
3535
)
3636

@@ -43,7 +43,7 @@ test_that('primary arguments', {
4343
formula = expr(missing_arg()),
4444
data = expr(missing_arg()),
4545
kernel = new_empty_quosure("triangular"),
46-
ks = 5
46+
ks = expr(min(5, nrow(data) - 5))
4747
)
4848
)
4949

@@ -56,7 +56,7 @@ test_that('primary arguments', {
5656
formula = expr(missing_arg()),
5757
data = expr(missing_arg()),
5858
distance = new_empty_quosure(2),
59-
ks = 5
59+
ks = expr(min(5, nrow(data) - 5))
6060
)
6161
)
6262

@@ -72,7 +72,7 @@ test_that('engine arguments', {
7272
formula = expr(missing_arg()),
7373
data = expr(missing_arg()),
7474
scale = new_empty_quosure(FALSE),
75-
ks = 5
75+
ks = expr(min(5, nrow(data) - 5))
7676
)
7777
)
7878

tests/testthat/test_nearest_neighbor_kknn.R

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,3 +193,25 @@ test_that('kknn multi-predict', {
193193
dplyr::select(.pred)
194194
expect_equal(pred_uni, pred_uni_obs)
195195
})
196+
197+
198+
## -----------------------------------------------------------------------------
199+
200+
test_that('argument checks for data dimensions', {
201+
202+
data(penguins, package = "modeldata")
203+
penguins <- na.omit(penguins)
204+
205+
spec <-
206+
nearest_neighbor(neighbors = 1000) %>%
207+
set_engine("kknn") %>%
208+
set_mode("regression")
209+
210+
f_fit <- spec %>% fit(body_mass_g ~ ., data = penguins)
211+
xy_fit <- spec %>% fit_xy(x = penguins[, -6], y = penguins$body_mass_g)
212+
213+
expect_equal(f_fit$fit$best.parameters$k, nrow(penguins) - 5)
214+
expect_equal(xy_fit$fit$best.parameters$k, nrow(penguins) - 5)
215+
216+
})
217+

0 commit comments

Comments
 (0)