@@ -19,7 +19,7 @@ test_that('primary arguments', {
19
19
x = expr(missing_arg()),
20
20
y = expr(missing_arg()),
21
21
case.weights = expr(missing_arg()),
22
- mtry = new_empty_quosure( 4 ),
22
+ mtry = expr(min( ~ 4 , ncol( x )) ),
23
23
num.threads = 1 ,
24
24
verbose = FALSE ,
25
25
seed = expr(sample.int(10 ^ 5 , 1 ))
@@ -29,7 +29,7 @@ test_that('primary arguments', {
29
29
list (
30
30
x = expr(missing_arg()),
31
31
y = expr(missing_arg()),
32
- mtry = new_empty_quosure( 4 )
32
+ mtry = expr(min( ~ 4 , ncol( x )) )
33
33
)
34
34
)
35
35
expect_equal(mtry_spark $ method $ fit $ args ,
@@ -83,7 +83,7 @@ test_that('primary arguments', {
83
83
x = expr(missing_arg()),
84
84
y = expr(missing_arg()),
85
85
case.weights = expr(missing_arg()),
86
- min.node.size = new_empty_quosure( 5 ),
86
+ min.node.size = expr(min( ~ 5 , nrow( x )) ),
87
87
num.threads = 1 ,
88
88
verbose = FALSE ,
89
89
seed = expr(sample.int(10 ^ 5 , 1 ))
@@ -93,117 +93,18 @@ test_that('primary arguments', {
93
93
list (
94
94
x = expr(missing_arg()),
95
95
y = expr(missing_arg()),
96
- nodesize = new_empty_quosure( 5 )
96
+ nodesize = expr(min( ~ 5 , nrow( x )) )
97
97
)
98
98
)
99
99
expect_equal(min_n_spark $ method $ fit $ args ,
100
100
list (
101
101
x = expr(missing_arg()),
102
102
formula = expr(missing_arg()),
103
103
type = " regression" ,
104
- min_instances_per_node = new_empty_quosure( 5 ),
104
+ min_instances_per_node = expr(min( ~ 5 , nrow( x )) ),
105
105
seed = expr(sample.int(10 ^ 5 , 1 ))
106
106
)
107
107
)
108
-
109
- mtry_v <- rand_forest(mode = " classification" , mtry = varying())
110
- mtry_v_ranger <- translate(mtry_v %> % set_engine(" ranger" ))
111
- mtry_v_randomForest <- translate(mtry_v %> % set_engine(" randomForest" ))
112
- mtry_v_spark <- translate(mtry_v %> % set_engine(" spark" ))
113
- expect_equal(mtry_v_ranger $ method $ fit $ args ,
114
- list (
115
- x = expr(missing_arg()),
116
- y = expr(missing_arg()),
117
- case.weights = expr(missing_arg()),
118
- mtry = new_empty_quosure(varying()),
119
- num.threads = 1 ,
120
- verbose = FALSE ,
121
- seed = expr(sample.int(10 ^ 5 , 1 )),
122
- probability = TRUE
123
- )
124
- )
125
- expect_equal(mtry_v_randomForest $ method $ fit $ args ,
126
- list (
127
- x = expr(missing_arg()),
128
- y = expr(missing_arg()),
129
- mtry = new_empty_quosure(varying())
130
- )
131
- )
132
- expect_equal(mtry_v_spark $ method $ fit $ args ,
133
- list (
134
- x = expr(missing_arg()),
135
- formula = expr(missing_arg()),
136
- type = " classification" ,
137
- feature_subset_strategy = new_empty_quosure(varying()),
138
- seed = expr(sample.int(10 ^ 5 , 1 ))
139
- )
140
- )
141
-
142
- trees_v <- rand_forest(mode = " regression" , trees = varying())
143
- trees_v_ranger <- translate(trees_v %> % set_engine(" ranger" ))
144
- trees_v_randomForest <- translate(trees_v %> % set_engine(" randomForest" ))
145
- trees_v_spark <- translate(trees_v %> % set_engine(" spark" ))
146
- expect_equal(trees_v_ranger $ method $ fit $ args ,
147
- list (
148
- x = expr(missing_arg()),
149
- y = expr(missing_arg()),
150
- case.weights = expr(missing_arg()),
151
- num.trees = new_empty_quosure(varying()),
152
- num.threads = 1 ,
153
- verbose = FALSE ,
154
- seed = expr(sample.int(10 ^ 5 , 1 ))
155
- )
156
- )
157
- expect_equal(trees_v_randomForest $ method $ fit $ args ,
158
- list (
159
- x = expr(missing_arg()),
160
- y = expr(missing_arg()),
161
- ntree = new_empty_quosure(varying())
162
- )
163
- )
164
- expect_equal(trees_v_spark $ method $ fit $ args ,
165
- list (
166
- x = expr(missing_arg()),
167
- formula = expr(missing_arg()),
168
- type = " regression" ,
169
- num_trees = new_empty_quosure(varying()),
170
- seed = expr(sample.int(10 ^ 5 , 1 ))
171
- )
172
- )
173
-
174
- min_n_v <- rand_forest(mode = " classification" , min_n = varying())
175
- min_n_v_ranger <- translate(min_n_v %> % set_engine(" ranger" ))
176
- min_n_v_randomForest <- translate(min_n_v %> % set_engine(" randomForest" ))
177
- min_n_v_spark <- translate(min_n_v %> % set_engine(" spark" ))
178
- expect_equal(min_n_v_ranger $ method $ fit $ args ,
179
- list (
180
- x = expr(missing_arg()),
181
- y = expr(missing_arg()),
182
- case.weights = expr(missing_arg()),
183
- min.node.size = new_empty_quosure(varying()),
184
- num.threads = 1 ,
185
- verbose = FALSE ,
186
- seed = expr(sample.int(10 ^ 5 , 1 )),
187
- probability = TRUE
188
- )
189
- )
190
- expect_equal(min_n_v_randomForest $ method $ fit $ args ,
191
- list (
192
- x = expr(missing_arg()),
193
- y = expr(missing_arg()),
194
- nodesize = new_empty_quosure(varying())
195
- )
196
- )
197
- expect_equal(min_n_v_spark $ method $ fit $ args ,
198
- list (
199
- x = expr(missing_arg()),
200
- formula = expr(missing_arg()),
201
- type = " classification" ,
202
- min_instances_per_node = new_empty_quosure(varying()),
203
- seed = expr(sample.int(10 ^ 5 , 1 ))
204
- )
205
- )
206
-
207
108
})
208
109
209
110
test_that(' engine arguments' , {
@@ -241,50 +142,6 @@ test_that('engine arguments', {
241
142
)
242
143
)
243
144
244
- ranger_samp_frac <- rand_forest(mode = " regression" )
245
- expect_equal(
246
- translate(ranger_samp_frac %> %
247
- set_engine(" ranger" , sample.fraction = varying()))$ method $ fit $ args ,
248
- list (
249
- x = expr(missing_arg()),
250
- y = expr(missing_arg()),
251
- case.weights = expr(missing_arg()),
252
- sample.fraction = new_empty_quosure(varying()),
253
- num.threads = 1 ,
254
- verbose = FALSE ,
255
- seed = expr(sample.int(10 ^ 5 , 1 ))
256
- )
257
- )
258
-
259
-
260
- randomForest_votes_v <-
261
- rand_forest(mode = " regression" )
262
- expect_equal(
263
- translate(randomForest_votes_v %> %
264
- set_engine(" randomForest" , norm.votes = FALSE , sampsize = varying()))$ method $ fit $ args ,
265
- list (
266
- x = expr(missing_arg()),
267
- y = expr(missing_arg()),
268
- norm.votes = new_empty_quosure(FALSE ),
269
- sampsize = new_empty_quosure(varying())
270
- )
271
- )
272
-
273
- spark_bins_v <-
274
- rand_forest(mode = " regression" )
275
- expect_equal(
276
- translate(spark_bins_v %> %
277
- set_engine(" spark" , uid = " id label" , max_bins = varying()))$ method $ fit $ args ,
278
- list (
279
- x = expr(missing_arg()),
280
- formula = expr(missing_arg()),
281
- type = " regression" ,
282
- uid = new_empty_quosure(" id label" ),
283
- max_bins = new_empty_quosure(varying()),
284
- seed = expr(sample.int(10 ^ 5 , 1 ))
285
- )
286
- )
287
-
288
145
})
289
146
290
147
0 commit comments