1
1
# Lazily registered in .onLoad()
2
+ # Unit tests are in extratests
3
+ # nocov start
2
4
tunable_model_spec <- function (x , ... ) {
3
5
mod_env <- rlang :: ns_env(" parsnip" )$ parsnip
4
6
@@ -141,11 +143,13 @@ brulee_engine_args <-
141
143
tibble :: tibble(
142
144
name = c(
143
145
" batch_size" ,
144
- " class_weights"
146
+ " class_weights" ,
147
+ " mixture"
145
148
),
146
149
call_info = list (
147
- list (pkg = " dials" , fun = " batch_size" , range = c(5 , 10 )),
148
- list (pkg = " dials" , fun = " class_weights" )
150
+ list (pkg = " dials" , fun = " batch_size" , range = c(3 , 10 )),
151
+ list (pkg = " dials" , fun = " class_weights" ),
152
+ list (pkg = " dials" , fun = " mixture" )
149
153
),
150
154
source = " model_spec" ,
151
155
component = " mlp" ,
@@ -160,6 +164,8 @@ tunable_linear_reg <- function(x, ...) {
160
164
if (x $ engine == " glmnet" ) {
161
165
res $ call_info [res $ name == " mixture" ] <-
162
166
list (list (pkg = " dials" , fun = " mixture" , range = c(0.05 , 1.00 )))
167
+ } else if (x $ engine == " brulee" ) {
168
+ res <- add_engine_parameters(res , brulee_engine_args )
163
169
}
164
170
res
165
171
}
@@ -170,6 +176,8 @@ tunable_logistic_reg <- function(x, ...) {
170
176
if (x $ engine == " glmnet" ) {
171
177
res $ call_info [res $ name == " mixture" ] <-
172
178
list (list (pkg = " dials" , fun = " mixture" , range = c(0.05 , 1.00 )))
179
+ } else if (x $ engine == " brulee" ) {
180
+ res <- add_engine_parameters(res , brulee_engine_args )
173
181
}
174
182
res
175
183
}
@@ -180,6 +188,8 @@ tunable_multinomial_reg <- function(x, ...) {
180
188
if (x $ engine == " glmnet" ) {
181
189
res $ call_info [res $ name == " mixture" ] <-
182
190
list (list (pkg = " dials" , fun = " mixture" , range = c(0.05 , 1.00 )))
191
+ } else if (x $ engine == " brulee" ) {
192
+ res <- add_engine_parameters(res , brulee_engine_args )
183
193
}
184
194
res
185
195
}
@@ -191,6 +201,8 @@ tunable_boost_tree <- function(x, ...) {
191
201
res <- add_engine_parameters(res , xgboost_engine_args )
192
202
res $ call_info [res $ name == " sample_size" ] <-
193
203
list (list (pkg = " dials" , fun = " sample_prop" ))
204
+ res $ call_info [res $ name == " learn_rate" ] <-
205
+ list (list (pkg = " dials" , fun = " learn_rate" , range = c(- 3 , - 1 / 2 )))
194
206
} else {
195
207
if (x $ engine == " C5.0" ) {
196
208
res <- add_engine_parameters(res , c5_boost_engine_args )
@@ -249,7 +261,10 @@ tunable_mlp <- function(x, ...) {
249
261
res <- NextMethod()
250
262
if (x $ engine == " brulee" ) {
251
263
res <- add_engine_parameters(res , brulee_engine_args )
264
+ res $ call_info [res $ name == " learn_rate" ] <-
265
+ list (list (pkg = " dials" , fun = " learn_rate" , range = c(- 3 , - 1 / 2 )))
252
266
}
253
267
res
254
268
}
255
269
270
+ # nocov end
0 commit comments