Skip to content

Commit b8dbf45

Browse files
committed
Add warnings when the argument is corrected
1 parent 636de7c commit b8dbf45

17 files changed

+195
-37
lines changed

NAMESPACE

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,8 @@ export(make_classes)
132132
export(mars)
133133
export(maybe_data_frame)
134134
export(maybe_matrix)
135+
export(min_cols)
136+
export(min_rows)
135137
export(mlp)
136138
export(model_printer)
137139
export(multi_predict)

NEWS.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44

55
* For three models (`glmnet`, `xgboost`, and `ranger`), enable sparse matrix use via `fit_xy()` (#373).
66

7-
* Some added protections were added for function arguments that are dependent on the data dimensions (e.g., `mtry`, `neighbors`, `min_n`, etc).
7+
* Protections were added for function arguments that are dependent on the data dimensions (e.g., `mtry`, `neighbors`, `min_n`, etc.). (#184)
8+
9+
* Infrastructure was improved for running `parsnip` models in parallel using PSOCK clusters on Windows.
810

911
# parsnip 0.1.3
1012

R/arguments.R

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,3 +204,66 @@ make_xy_call <- function(object, target) {
204204

205205
fit_call
206206
}
207+
208+
## -----------------------------------------------------------------------------
209+
#' Execution-time data dimension checks
210+
#'
211+
#' For some tuning parameters, the range of values depends on the data
212+
#' dimensions (e.g. `mtry`). Some packages will fail if the parameter values are
213+
#' outside of these ranges. Since the model might receive resampled versions of
214+
#' the data, these ranges can't be set prior to the point where the model is
215+
#' fit. These functions check the possible range of the data and adjust them
216+
#' if needed (with a warning).
217+
#'
218+
#' @param num_cols,num_rows The parameter value requested by the user.
219+
#' @param source A data frame for the data to be used in the fit. If the source
220+
#' is named "data", it is assumed that one column of the data corresponds to
221+
#' an outcome (and is subtracted off).
222+
#' @param offset A number subtracted off of the number of rows available in the
223+
#' data.
224+
#' @return An integer (and perhaps a warning).
225+
#' @examples
226+
227+
#' nearest_neighbor(neighbors = 100) %>%
228+
#' set_engine("kknn") %>%
229+
#' set_mode("regression") %>%
230+
#' translate()
231+
#'
232+
#' library(ranger)
233+
#' rand_forest(mtry = 2, min_n = 100, trees = 3) %>%
234+
#' set_engine("ranger") %>%
235+
#' set_mode("regression") %>%
236+
#' fit(mpg ~ ., data = mtcars)
237+
#' @export
238+
min_cols <- function(num_cols, source) {
  # Capture the unevaluated `source` argument so that the name used by the
  # caller can be inspected.
  src_expr <- substitute(source)

  # Only a bare symbol called `data` signals that one column is the outcome.
  # Testing the type first keeps the condition a single TRUE/FALSE even when
  # `source` arrives as an inlined object (e.g. through `do.call()`); comparing
  # such an object to "data" directly would give a non-scalar `if` condition.
  has_outcome <- is.symbol(src_expr) &&
    identical(as.character(src_expr), "data")

  # Number of predictor columns available to the model.
  p <- if (has_outcome) ncol(source) - 1 else ncol(source)

  # Cap the request at the number of predictors, telling the user about it.
  if (num_cols > p) {
    msg <- paste0(num_cols, " columns were requested but there were ", p,
                  " predictors in the data. ", p, " will be used.")
    rlang::warn(msg)
    num_cols <- p
  }

  as.integer(num_cols)
}
255+
256+
#' @export
257+
#' @rdname min_cols
258+
min_rows <- function(num_rows, source, offset = 0) {
  total_rows <- nrow(source)
  # Largest value that can be used once the offset has been removed.
  max_rows <- total_rows - offset

  # Nothing to correct: the request fits within the available rows.
  if (num_rows <= max_rows) {
    return(as.integer(num_rows))
  }

  # The request is too large; warn and fall back to the largest usable value.
  rlang::warn(
    paste0(num_rows, " samples were requested but there were ", total_rows,
           " rows in the data. ", max_rows, " will be used.")
  )
  as.integer(max_rows)
}

R/boost_tree.R

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ translate.boost_tree <- function(x, engine = x$engine, ...) {
244244
# min_n parameters
245245
if (any(names(arg_vals) == "min_instances_per_node")) {
246246
arg_vals$min_instances_per_node <-
247-
rlang::call2("min", rlang::eval_tidy(arg_vals$min_instances_per_node), expr(nrow(x)))
247+
rlang::call2("min_rows", rlang::eval_tidy(arg_vals$min_instances_per_node), expr(x))
248248
}
249249

250250
## -----------------------------------------------------------------------------
@@ -357,6 +357,13 @@ xgb_train <- function(
357357
colsample_bytree <- 1
358358
}
359359

360+
if (min_child_weight > n) {
361+
msg <- paste0(min_child_weight, " samples were requested but there were ",
362+
n, " rows in the data. ", n, " will be used.")
363+
rlang::warn(msg)
364+
min_child_weight <- min(min_child_weight, n)
365+
}
366+
360367
arg_list <- list(
361368
eta = eta,
362369
max_depth = max_depth,
@@ -537,8 +544,21 @@ C5.0_train <-
537544
ctrl_args <- other_args[names(other_args) %in% c_names]
538545
fit_args <- other_args[names(other_args) %in% f_names]
539546

547+
n <- nrow(x)
548+
if (n == 0) {
549+
rlang::abort("There are zero rows in the predictor set.")
550+
}
551+
552+
540553
ctrl <- call2("C5.0Control", .ns = "C50")
541-
ctrl$minCases <- min(minCases, nrow(x))
554+
if (minCases > n) {
555+
msg <- paste0(minCases, " samples were requested but there were ",
556+
n, " rows in the data. ", n, " will be used.")
557+
rlang::warn(msg)
558+
minCases <- n
559+
}
560+
ctrl$minCases <- minCases
561+
542562
ctrl$sample <- sample
543563
ctrl <- rlang::call_modify(ctrl, !!!ctrl_args)
544564

R/decision_tree.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -185,11 +185,11 @@ translate.decision_tree <- function(x, engine = x$engine, ...) {
185185

186186
if (any(names(arg_vals) == "minsplit")) {
187187
arg_vals$minsplit <-
188-
rlang::call2("min", rlang::eval_tidy(arg_vals$minsplit), expr(nrow(data)))
188+
rlang::call2("min_rows", rlang::eval_tidy(arg_vals$minsplit), expr(data))
189189
}
190190
if (any(names(arg_vals) == "min_instances_per_node")) {
191191
arg_vals$min_instances_per_node <-
192-
rlang::call2("min", rlang::eval_tidy(arg_vals$min_instances_per_node), expr(nrow(x)))
192+
rlang::call2("min_rows", rlang::eval_tidy(arg_vals$min_instances_per_node), expr(x))
193193
}
194194

195195
## -----------------------------------------------------------------------------

R/nearest_neighbor.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ translate.nearest_neighbor <- function(x, engine = x$engine, ...) {
181181

182182
if (any(names(arg_vals) == "ks")) {
183183
arg_vals$ks <-
184-
rlang::call2("min", rlang::eval_tidy(arg_vals$ks), expr(nrow(data) - 5))
184+
rlang::call2("min_rows", rlang::eval_tidy(arg_vals$ks), expr(data), 5)
185185
}
186186
}
187187

R/rand_forest.R

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -203,20 +203,20 @@ translate.rand_forest <- function(x, engine = x$engine, ...) {
203203
# Protect some arguments based on data dimensions
204204

205205
if (any(names(arg_vals) == "mtry")) {
206-
arg_vals$mtry <- rlang::call2("min", arg_vals$mtry, expr(ncol(x)))
206+
arg_vals$mtry <- rlang::call2("min_cols", arg_vals$mtry, expr(x))
207207
}
208208

209209
if (any(names(arg_vals) == "min.node.size")) {
210210
arg_vals$min.node.size <-
211-
rlang::call2("min", arg_vals$min.node.size, expr(nrow(x)))
211+
rlang::call2("min_rows", arg_vals$min.node.size, expr(x))
212212
}
213213
if (any(names(arg_vals) == "nodesize")) {
214214
arg_vals$nodesize <-
215-
rlang::call2("min", arg_vals$nodesize, expr(nrow(x)))
215+
rlang::call2("min_rows", arg_vals$nodesize, expr(x))
216216
}
217217
if (any(names(arg_vals) == "min_instances_per_node")) {
218218
arg_vals$min_instances_per_node <-
219-
rlang::call2("min", arg_vals$min_instances_per_node, expr(nrow(x)))
219+
rlang::call2("min_rows", arg_vals$min_instances_per_node, expr(x))
220220
}
221221

222222
## -----------------------------------------------------------------------------

man/boost_tree.Rd

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/min_cols.Rd

Lines changed: 44 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/nearest_neighbor.Rd

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)