Skip to content

Commit d251c73

Browse files
authored
Merge pull request #351 from tidymodels/update-engine-args
Update engine-specific args
2 parents 78a51ac + d9075f0 commit d251c73

27 files changed

+145
-55
lines changed

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,7 @@ export(tidy)
177177
export(translate)
178178
export(translate.default)
179179
export(update_dot_check)
180+
export(update_engine_parameters)
180181
export(update_main_parameters)
181182
export(varying)
182183
export(varying_args)

R/boost_tree.R

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,8 @@ update.boost_tree <-
163163
loss_reduction = NULL, sample_size = NULL,
164164
stop_iter = NULL,
165165
fresh = FALSE, ...) {
166-
update_dot_check(...)
166+
167+
eng_args <- update_engine_parameters(object$eng_args, ...)
167168

168169
if (!is.null(parameters)) {
169170
parameters <- check_final_param(parameters)
@@ -185,12 +186,15 @@ update.boost_tree <-
185186
# TODO make these blocks into a function and document well
186187
if (fresh) {
187188
object$args <- args
189+
object$eng_args <- eng_args
188190
} else {
189191
null_args <- map_lgl(args, null_value)
190192
if (any(null_args))
191193
args <- args[!null_args]
192194
if (length(args) > 0)
193195
object$args[names(args)] <- args
196+
if (length(eng_args) > 0)
197+
object$eng_args[names(eng_args)] <- eng_args
194198
}
195199

196200
new_model_spec(

R/decision_tree.R

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,8 @@ update.decision_tree <-
116116
parameters = NULL,
117117
cost_complexity = NULL, tree_depth = NULL, min_n = NULL,
118118
fresh = FALSE, ...) {
119-
update_dot_check(...)
119+
120+
eng_args <- update_engine_parameters(object$eng_args, ...)
120121

121122
if (!is.null(parameters)) {
122123
parameters <- check_final_param(parameters)
@@ -131,12 +132,15 @@ update.decision_tree <-
131132

132133
if (fresh) {
133134
object$args <- args
135+
object$eng_args <- eng_args
134136
} else {
135137
null_args <- map_lgl(args, null_value)
136138
if (any(null_args))
137139
args <- args[!null_args]
138140
if (length(args) > 0)
139141
object$args[names(args)] <- args
142+
if (length(eng_args) > 0)
143+
object$eng_args[names(eng_args)] <- eng_args
140144
}
141145

142146
new_model_spec(

R/linear_reg.R

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,8 @@ update.linear_reg <-
131131
parameters = NULL,
132132
penalty = NULL, mixture = NULL,
133133
fresh = FALSE, ...) {
134-
update_dot_check(...)
134+
135+
eng_args <- update_engine_parameters(object$eng_args, ...)
135136

136137
if (!is.null(parameters)) {
137138
parameters <- check_final_param(parameters)
@@ -145,12 +146,15 @@ update.linear_reg <-
145146

146147
if (fresh) {
147148
object$args <- args
149+
object$eng_args <- eng_args
148150
} else {
149151
null_args <- map_lgl(args, null_value)
150152
if (any(null_args))
151153
args <- args[!null_args]
152154
if (length(args) > 0)
153155
object$args[names(args)] <- args
156+
if (length(eng_args) > 0)
157+
object$eng_args[names(eng_args)] <- eng_args
154158
}
155159

156160
new_model_spec(

R/logistic_reg.R

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,8 @@ update.logistic_reg <-
115115
parameters = NULL,
116116
penalty = NULL, mixture = NULL,
117117
fresh = FALSE, ...) {
118-
update_dot_check(...)
118+
119+
eng_args <- update_engine_parameters(object$eng_args, ...)
119120

120121
if (!is.null(parameters)) {
121122
parameters <- check_final_param(parameters)
@@ -129,12 +130,15 @@ update.logistic_reg <-
129130

130131
if (fresh) {
131132
object$args <- args
133+
object$eng_args <- eng_args
132134
} else {
133135
null_args <- map_lgl(args, null_value)
134136
if (any(null_args))
135137
args <- args[!null_args]
136138
if (length(args) > 0)
137139
object$args[names(args)] <- args
140+
if (length(eng_args) > 0)
141+
object$eng_args[names(eng_args)] <- eng_args
138142
}
139143

140144
new_model_spec(

R/mars.R

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,8 @@ update.mars <-
9393
parameters = NULL,
9494
num_terms = NULL, prod_degree = NULL, prune_method = NULL,
9595
fresh = FALSE, ...) {
96-
update_dot_check(...)
96+
97+
eng_args <- update_engine_parameters(object$eng_args, ...)
9798

9899
if (!is.null(parameters)) {
99100
parameters <- check_final_param(parameters)
@@ -109,12 +110,15 @@ update.mars <-
109110

110111
if (fresh) {
111112
object$args <- args
113+
object$eng_args <- eng_args
112114
} else {
113115
null_args <- map_lgl(args, null_value)
114116
if (any(null_args))
115117
args <- args[!null_args]
116118
if (length(args) > 0)
117119
object$args[names(args)] <- args
120+
if (length(eng_args) > 0)
121+
object$eng_args[names(eng_args)] <- eng_args
118122
}
119123

120124
new_model_spec(

R/misc.R

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -171,8 +171,7 @@ names0 <- function (num, prefix = "x") {
171171
#' @export
172172
#' @keywords internal
173173
#' @rdname add_on_exports
174-
update_dot_check <- function(...) {
175-
dots <- enquos(...)
174+
update_dot_check <- function(dots) {
176175
if (length(dots) > 0)
177176
rlang::abort(
178177
glue::glue(
@@ -282,5 +281,25 @@ update_main_parameters <- function(args, param) {
282281
args <- utils::modifyList(args, param)
283282
}
284283

284+
#' @export
285+
#' @keywords internal
286+
#' @rdname add_on_exports
287+
update_engine_parameters <- function(eng_args, ...) {
288+
289+
dots <- enquos(...)
290+
291+
## only update from dots when there are eng args in original model spec
292+
if (is_null(eng_args)) {
293+
ret <- NULL
294+
} else {
295+
ret <- utils::modifyList(eng_args, dots)
296+
}
297+
298+
has_extra_dots <- !(names(dots) %in% names(eng_args))
299+
dots <- dots[has_extra_dots]
300+
update_dot_check(dots)
301+
302+
ret
303+
}
285304

286305

R/mlp.R

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,8 @@ update.mlp <-
120120
hidden_units = NULL, penalty = NULL, dropout = NULL,
121121
epochs = NULL, activation = NULL,
122122
fresh = FALSE, ...) {
123-
update_dot_check(...)
123+
124+
eng_args <- update_engine_parameters(object$eng_args, ...)
124125

125126
if (!is.null(parameters)) {
126127
parameters <- check_final_param(parameters)
@@ -139,12 +140,15 @@ update.mlp <-
139140
# TODO make these blocks into a function and document well
140141
if (fresh) {
141142
object$args <- args
143+
object$eng_args <- eng_args
142144
} else {
143145
null_args <- map_lgl(args, null_value)
144146
if (any(null_args))
145147
args <- args[!null_args]
146148
if (length(args) > 0)
147149
object$args[names(args)] <- args
150+
if (length(eng_args) > 0)
151+
object$eng_args[names(eng_args)] <- eng_args
148152
}
149153

150154
new_model_spec(

R/multinom_reg.R

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,8 @@ update.multinom_reg <-
114114
parameters = NULL,
115115
penalty = NULL, mixture = NULL,
116116
fresh = FALSE, ...) {
117-
update_dot_check(...)
117+
118+
eng_args <- update_engine_parameters(object$eng_args, ...)
118119

119120
if (!is.null(parameters)) {
120121
parameters <- check_final_param(parameters)
@@ -128,12 +129,15 @@ update.multinom_reg <-
128129

129130
if (fresh) {
130131
object$args <- args
132+
object$eng_args <- eng_args
131133
} else {
132134
null_args <- map_lgl(args, null_value)
133135
if (any(null_args))
134136
args <- args[!null_args]
135137
if (length(args) > 0)
136138
object$args[names(args)] <- args
139+
if (length(eng_args) > 0)
140+
object$eng_args[names(eng_args)] <- eng_args
137141
}
138142

139143
new_model_spec(

R/nearest_neighbor.R

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,8 @@ update.nearest_neighbor <- function(object,
9696
weight_func = NULL,
9797
dist_power = NULL,
9898
fresh = FALSE, ...) {
99-
update_dot_check(...)
99+
100+
eng_args <- update_engine_parameters(object$eng_args, ...)
100101

101102
if (!is.null(parameters)) {
102103
parameters <- check_final_param(parameters)
@@ -112,12 +113,15 @@ update.nearest_neighbor <- function(object,
112113

113114
if (fresh) {
114115
object$args <- args
116+
object$eng_args <- eng_args
115117
} else {
116118
null_args <- map_lgl(args, null_value)
117119
if (any(null_args))
118120
args <- args[!null_args]
119121
if (length(args) > 0)
120122
object$args[names(args)] <- args
123+
if (length(eng_args) > 0)
124+
object$eng_args[names(eng_args)] <- eng_args
121125
}
122126

123127
new_model_spec(

0 commit comments

Comments
 (0)