Skip to content

Commit 118b897

Browse files
committed
fixed ranger test cases
1 parent 1706169 commit 118b897

File tree

2 files changed

+7
-150
lines changed

2 files changed

+7
-150
lines changed

man/boost_tree.Rd

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test_rand_forest.R

Lines changed: 5 additions & 148 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ test_that('primary arguments', {
1919
x = expr(missing_arg()),
2020
y = expr(missing_arg()),
2121
case.weights = expr(missing_arg()),
22-
mtry = new_empty_quosure(4),
22+
mtry = expr(min(~4, ncol(x))),
2323
num.threads = 1,
2424
verbose = FALSE,
2525
seed = expr(sample.int(10^5, 1))
@@ -29,7 +29,7 @@ test_that('primary arguments', {
2929
list(
3030
x = expr(missing_arg()),
3131
y = expr(missing_arg()),
32-
mtry = new_empty_quosure(4)
32+
mtry = expr(min(~4, ncol(x)))
3333
)
3434
)
3535
expect_equal(mtry_spark$method$fit$args,
@@ -83,7 +83,7 @@ test_that('primary arguments', {
8383
x = expr(missing_arg()),
8484
y = expr(missing_arg()),
8585
case.weights = expr(missing_arg()),
86-
min.node.size = new_empty_quosure(5),
86+
min.node.size = expr(min(~5, nrow(x))),
8787
num.threads = 1,
8888
verbose = FALSE,
8989
seed = expr(sample.int(10^5, 1))
@@ -93,117 +93,18 @@ test_that('primary arguments', {
9393
list(
9494
x = expr(missing_arg()),
9595
y = expr(missing_arg()),
96-
nodesize = new_empty_quosure(5)
96+
nodesize = expr(min(~5, nrow(x)))
9797
)
9898
)
9999
expect_equal(min_n_spark$method$fit$args,
100100
list(
101101
x = expr(missing_arg()),
102102
formula = expr(missing_arg()),
103103
type = "regression",
104-
min_instances_per_node = new_empty_quosure(5),
104+
min_instances_per_node = expr(min(~5, nrow(x))),
105105
seed = expr(sample.int(10^5, 1))
106106
)
107107
)
108-
109-
mtry_v <- rand_forest(mode = "classification", mtry = varying())
110-
mtry_v_ranger <- translate(mtry_v %>% set_engine("ranger"))
111-
mtry_v_randomForest <- translate(mtry_v %>% set_engine("randomForest"))
112-
mtry_v_spark <- translate(mtry_v %>% set_engine("spark"))
113-
expect_equal(mtry_v_ranger$method$fit$args,
114-
list(
115-
x = expr(missing_arg()),
116-
y = expr(missing_arg()),
117-
case.weights = expr(missing_arg()),
118-
mtry = new_empty_quosure(varying()),
119-
num.threads = 1,
120-
verbose = FALSE,
121-
seed = expr(sample.int(10^5, 1)),
122-
probability = TRUE
123-
)
124-
)
125-
expect_equal(mtry_v_randomForest$method$fit$args,
126-
list(
127-
x = expr(missing_arg()),
128-
y = expr(missing_arg()),
129-
mtry = new_empty_quosure(varying())
130-
)
131-
)
132-
expect_equal(mtry_v_spark$method$fit$args,
133-
list(
134-
x = expr(missing_arg()),
135-
formula = expr(missing_arg()),
136-
type = "classification",
137-
feature_subset_strategy = new_empty_quosure(varying()),
138-
seed = expr(sample.int(10^5, 1))
139-
)
140-
)
141-
142-
trees_v <- rand_forest(mode = "regression", trees = varying())
143-
trees_v_ranger <- translate(trees_v %>% set_engine("ranger"))
144-
trees_v_randomForest <- translate(trees_v %>% set_engine("randomForest"))
145-
trees_v_spark <- translate(trees_v %>% set_engine("spark"))
146-
expect_equal(trees_v_ranger$method$fit$args,
147-
list(
148-
x = expr(missing_arg()),
149-
y = expr(missing_arg()),
150-
case.weights = expr(missing_arg()),
151-
num.trees = new_empty_quosure(varying()),
152-
num.threads = 1,
153-
verbose = FALSE,
154-
seed = expr(sample.int(10^5, 1))
155-
)
156-
)
157-
expect_equal(trees_v_randomForest$method$fit$args,
158-
list(
159-
x = expr(missing_arg()),
160-
y = expr(missing_arg()),
161-
ntree = new_empty_quosure(varying())
162-
)
163-
)
164-
expect_equal(trees_v_spark$method$fit$args,
165-
list(
166-
x = expr(missing_arg()),
167-
formula = expr(missing_arg()),
168-
type = "regression",
169-
num_trees = new_empty_quosure(varying()),
170-
seed = expr(sample.int(10^5, 1))
171-
)
172-
)
173-
174-
min_n_v <- rand_forest(mode = "classification", min_n = varying())
175-
min_n_v_ranger <- translate(min_n_v %>% set_engine("ranger"))
176-
min_n_v_randomForest <- translate(min_n_v %>% set_engine("randomForest"))
177-
min_n_v_spark <- translate(min_n_v %>% set_engine("spark"))
178-
expect_equal(min_n_v_ranger$method$fit$args,
179-
list(
180-
x = expr(missing_arg()),
181-
y = expr(missing_arg()),
182-
case.weights = expr(missing_arg()),
183-
min.node.size = new_empty_quosure(varying()),
184-
num.threads = 1,
185-
verbose = FALSE,
186-
seed = expr(sample.int(10^5, 1)),
187-
probability = TRUE
188-
)
189-
)
190-
expect_equal(min_n_v_randomForest$method$fit$args,
191-
list(
192-
x = expr(missing_arg()),
193-
y = expr(missing_arg()),
194-
nodesize = new_empty_quosure(varying())
195-
)
196-
)
197-
expect_equal(min_n_v_spark$method$fit$args,
198-
list(
199-
x = expr(missing_arg()),
200-
formula = expr(missing_arg()),
201-
type = "classification",
202-
min_instances_per_node = new_empty_quosure(varying()),
203-
seed = expr(sample.int(10^5, 1))
204-
)
205-
)
206-
207108
})
208109

209110
test_that('engine arguments', {
@@ -241,50 +142,6 @@ test_that('engine arguments', {
241142
)
242143
)
243144

244-
ranger_samp_frac <- rand_forest(mode = "regression")
245-
expect_equal(
246-
translate(ranger_samp_frac %>%
247-
set_engine("ranger", sample.fraction = varying()))$method$fit$args,
248-
list(
249-
x = expr(missing_arg()),
250-
y = expr(missing_arg()),
251-
case.weights = expr(missing_arg()),
252-
sample.fraction = new_empty_quosure(varying()),
253-
num.threads = 1,
254-
verbose = FALSE,
255-
seed = expr(sample.int(10^5, 1))
256-
)
257-
)
258-
259-
260-
randomForest_votes_v <-
261-
rand_forest(mode = "regression")
262-
expect_equal(
263-
translate(randomForest_votes_v %>%
264-
set_engine("randomForest", norm.votes = FALSE, sampsize = varying()))$method$fit$args,
265-
list(
266-
x = expr(missing_arg()),
267-
y = expr(missing_arg()),
268-
norm.votes = new_empty_quosure(FALSE),
269-
sampsize = new_empty_quosure(varying())
270-
)
271-
)
272-
273-
spark_bins_v <-
274-
rand_forest(mode = "regression")
275-
expect_equal(
276-
translate(spark_bins_v %>%
277-
set_engine("spark", uid = "id label", max_bins = varying()))$method$fit$args,
278-
list(
279-
x = expr(missing_arg()),
280-
formula = expr(missing_arg()),
281-
type = "regression",
282-
uid = new_empty_quosure("id label"),
283-
max_bins = new_empty_quosure(varying()),
284-
seed = expr(sample.int(10^5, 1))
285-
)
286-
)
287-
288145
})
289146

290147

0 commit comments

Comments
 (0)