Skip to content

Commit 4976b64

Browse files
authored
Tunable update (#676)
* minor tweaking of some parameter ranges and added missing brulee engine arg * note location of tests * nocov end, not nocov stop
1 parent e7e6f6e commit 4976b64

File tree

1 file changed

+18
-3
lines changed

1 file changed

+18
-3
lines changed

R/tunable.R

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
# Lazily registered in .onLoad()
2+
# Unit tests are in extratests
3+
# nocov start
24
tunable_model_spec <- function(x, ...) {
35
mod_env <- rlang::ns_env("parsnip")$parsnip
46

@@ -141,11 +143,13 @@ brulee_engine_args <-
141143
tibble::tibble(
142144
name = c(
143145
"batch_size",
144-
"class_weights"
146+
"class_weights",
147+
"mixture"
145148
),
146149
call_info = list(
147-
list(pkg = "dials", fun = "batch_size", range = c(5, 10)),
148-
list(pkg = "dials", fun = "class_weights")
150+
list(pkg = "dials", fun = "batch_size", range = c(3, 10)),
151+
list(pkg = "dials", fun = "class_weights"),
152+
list(pkg = "dials", fun = "mixture")
149153
),
150154
source = "model_spec",
151155
component = "mlp",
@@ -160,6 +164,8 @@ tunable_linear_reg <- function(x, ...) {
160164
if (x$engine == "glmnet") {
161165
res$call_info[res$name == "mixture"] <-
162166
list(list(pkg = "dials", fun = "mixture", range = c(0.05, 1.00)))
167+
} else if (x$engine == "brulee") {
168+
res <- add_engine_parameters(res, brulee_engine_args)
163169
}
164170
res
165171
}
@@ -170,6 +176,8 @@ tunable_logistic_reg <- function(x, ...) {
170176
if (x$engine == "glmnet") {
171177
res$call_info[res$name == "mixture"] <-
172178
list(list(pkg = "dials", fun = "mixture", range = c(0.05, 1.00)))
179+
} else if (x$engine == "brulee") {
180+
res <- add_engine_parameters(res, brulee_engine_args)
173181
}
174182
res
175183
}
@@ -180,6 +188,8 @@ tunable_multinomial_reg <- function(x, ...) {
180188
if (x$engine == "glmnet") {
181189
res$call_info[res$name == "mixture"] <-
182190
list(list(pkg = "dials", fun = "mixture", range = c(0.05, 1.00)))
191+
} else if (x$engine == "brulee") {
192+
res <- add_engine_parameters(res, brulee_engine_args)
183193
}
184194
res
185195
}
@@ -191,6 +201,8 @@ tunable_boost_tree <- function(x, ...) {
191201
res <- add_engine_parameters(res, xgboost_engine_args)
192202
res$call_info[res$name == "sample_size"] <-
193203
list(list(pkg = "dials", fun = "sample_prop"))
204+
res$call_info[res$name == "learn_rate"] <-
205+
list(list(pkg = "dials", fun = "learn_rate", range = c(-3, -1/2)))
194206
} else {
195207
if (x$engine == "C5.0") {
196208
res <- add_engine_parameters(res, c5_boost_engine_args)
@@ -249,7 +261,10 @@ tunable_mlp <- function(x, ...) {
249261
res <- NextMethod()
250262
if (x$engine == "brulee") {
251263
res <- add_engine_parameters(res, brulee_engine_args)
264+
res$call_info[res$name == "learn_rate"] <-
265+
list(list(pkg = "dials", fun = "learn_rate", range = c(-3, -1/2)))
252266
}
253267
res
254268
}
255269

270+
# nocov end

0 commit comments

Comments
 (0)