Skip to content

Commit 0691638

Browse files
Use .default argument in case_when() when applicable (#164)
1 parent 3fb3694 commit 0691638

File tree

7 files changed

+59
-16
lines changed

7 files changed

+59
-16
lines changed

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222

2323
- Cubist rules will return simplified rules whenever possible to avoid multiplying by 0 and 1. (#152)
2424

25+
- tree based models now uses `.default` argument in produced `case_when()` code when applicable. (#153)
26+
2527
# tidypredict 0.5.1
2628

2729
- Exported a number of internal functions to be used in {orbital} package

R/tree.R

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,16 +49,26 @@
4949
#' calculating all notes.
5050
#'
5151
#' @keywords internal
52-
generate_case_when_trees <- function(parsedmodel) {
52+
generate_case_when_trees <- function(parsedmodel, default = TRUE) {
5353
map(
5454
parsedmodel$trees,
5555
generate_case_when_tree,
56-
mode = parsedmodel$general$mode
56+
mode = parsedmodel$general$mode,
57+
default = default
5758
)
5859
}
5960

60-
generate_case_when_tree <- function(tree, mode) {
61-
expr(case_when(!!!generate_tree_nodes(tree, mode)))
61+
generate_case_when_tree <- function(tree, mode, default = TRUE) {
62+
nodes <- generate_tree_nodes(tree, mode)
63+
64+
if (default) {
65+
default <- nodes[[length(nodes)]]
66+
default <- rlang::f_rhs(default)
67+
nodes[[length(nodes)]] <- NULL
68+
nodes <- c(nodes, .default = default)
69+
}
70+
71+
expr(case_when(!!!nodes))
6272
}
6373

6474
generate_tree_nodes <- function(tree, mode) {

man/generate_case_when_trees.Rd

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/_snaps/model-partykit.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
Code
44
rlang::expr_text(tf)
55
Output
6-
[1] "case_when(cyl <= 4 ~ 26.6636363636364, cyl <= 6 & cyl > 4 ~ 19.7428571428571, \n cyl > 6 & cyl > 4 ~ 15.1)"
6+
[1] "case_when(cyl <= 4 ~ 26.6636363636364, cyl <= 6 & cyl > 4 ~ 19.7428571428571, \n .default = 15.1)"
77

88
# formulas produces correct predictions
99

tests/testthat/_snaps/model-ranger.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
Code
44
rlang::expr_text(tf)
55
Output
6-
[1] "case_when(Petal.Length < 2.6 ~ \"setosa\", Sepal.Length < 6.25 & \n Petal.Length >= 2.6 ~ \"versicolor\", Sepal.Length >= 6.25 & \n Petal.Length >= 2.6 ~ \"virginica\") + case_when(Petal.Width < \n 0.75 ~ \"setosa\", Petal.Width < 1.75 & Petal.Width >= 0.75 ~ \n \"versicolor\", Petal.Width >= 1.75 & Petal.Width >= 0.75 ~ \n \"virginica\") + case_when(Petal.Length < 2.35 ~ \"setosa\", \n Petal.Length < 4.75 & Petal.Length >= 2.35 ~ \"versicolor\", \n Petal.Length >= 4.75 & Petal.Length >= 2.35 ~ \"virginica\")"
6+
[1] "case_when(Petal.Length < 2.6 ~ \"setosa\", Sepal.Length < 6.25 & \n Petal.Length >= 2.6 ~ \"versicolor\", .default = \"virginica\") + \n case_when(Petal.Width < 0.75 ~ \"setosa\", Petal.Width < 1.75 & \n Petal.Width >= 0.75 ~ \"versicolor\", .default = \"virginica\") + \n case_when(Petal.Length < 2.35 ~ \"setosa\", Petal.Length < \n 4.75 & Petal.Length >= 2.35 ~ \"versicolor\", .default = \"virginica\")"
77

88
# formulas produces correct predictions
99

tests/testthat/_snaps/model-rf.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
Code
44
rlang::expr_text(tf)
55
Output
6-
[1] "(case_when(disp < 78.85 & wt < 2.3325 ~ 33.525, disp >= 78.85 & \n wt < 2.3325 ~ 27.425, disp >= 281 & wt >= 2.3325 ~ 14.1, \n drat < 3.695 & cyl < 5 & disp < 281 & wt >= 2.3325 ~ 24.4, \n drat >= 3.695 & cyl < 5 & disp < 281 & wt >= 2.3325 ~ 22.3666666666667, \n disp < 152.5 & wt < 3.3275 & cyl >= 5 & disp < 281 & wt >= \n 2.3325 ~ 19.7, disp >= 152.5 & wt < 3.3275 & cyl >= 5 & \n disp < 281 & wt >= 2.3325 ~ 21.1, disp < 196.3 & wt >= \n 3.3275 & cyl >= 5 & disp < 281 & wt >= 2.3325 ~ 18.92, \n disp >= 196.3 & wt >= 3.3275 & cyl >= 5 & disp < 281 & wt >= \n 2.3325 ~ 18.1) + case_when(hp < 80.5 & disp < 266.9 ~ \n 28.5333333333333, drat < 3.035 & disp >= 266.9 ~ 10.4, carb < \n 2.5 & drat >= 3.035 & disp >= 266.9 ~ 18.7, wt < 1.989 & \n hp < 118 & hp >= 80.5 & disp < 266.9 ~ 30.4, qsec < 18.6 & \n hp >= 118 & hp >= 80.5 & disp < 266.9 ~ 19.6, qsec >= 18.6 & \n hp >= 118 & hp >= 80.5 & disp < 266.9 ~ 17.8, drat >= 3.635 & \n carb >= 2.5 & drat >= 3.035 & disp >= 266.9 ~ 13.3, hp < \n 96 & wt >= 1.989 & hp < 118 & hp >= 80.5 & disp < 266.9 ~ \n 22.8, wt < 4.5625 & drat < 3.635 & carb >= 2.5 & drat >= \n 3.035 & disp >= 266.9 ~ 15.04, wt >= 4.5625 & drat < 3.635 & \n carb >= 2.5 & drat >= 3.035 & disp >= 266.9 ~ 14.7, qsec < \n 19.725 & hp >= 96 & wt >= 1.989 & hp < 118 & hp >= 80.5 & \n disp < 266.9 ~ 21.4, qsec >= 19.725 & hp >= 96 & wt >= 1.989 & \n hp < 118 & hp >= 80.5 & disp < 266.9 ~ 21.5) + case_when(wt < \n 3.4725 & drat < 3.75 ~ 20.8333333333333, qsec < 16.23 & wt >= \n 3.4725 & drat < 3.75 ~ 13.3, disp < 78.85 & disp < 130.55 & \n drat >= 3.75 ~ 32.2333333333333, disp >= 78.85 & disp < 130.55 & \n drat >= 3.75 ~ 27.75, cyl < 5 & disp >= 130.55 & drat >= \n 3.75 ~ 22.8, disp >= 456 & qsec >= 16.23 & wt >= 3.4725 & \n drat < 3.75 ~ 10.4, qsec < 17.66 & cyl >= 5 & disp >= 130.55 & \n drat >= 3.75 ~ 21, qsec >= 17.66 & cyl >= 5 & disp >= 130.55 & \n drat >= 3.75 ~ 19.2, qsec < 17.225 & disp < 456 & qsec >= \n 16.23 & wt >= 3.4725 & drat < 3.75 ~ 19.2, wt < 4.7075 & \n qsec >= 17.225 & disp < 456 & qsec >= 16.23 & wt >= 3.4725 & \n drat < 3.75 ~ 16.34, wt >= 4.7075 & qsec >= 17.225 & disp < \n 456 & qsec >= 16.23 & wt >= 3.4725 & drat < 3.75 ~ 14.7))/3L"
6+
[1] "(case_when(disp < 78.85 & wt < 2.3325 ~ 33.525, disp >= 78.85 & \n wt < 2.3325 ~ 27.425, disp >= 281 & wt >= 2.3325 ~ 14.1, \n drat < 3.695 & cyl < 5 & disp < 281 & wt >= 2.3325 ~ 24.4, \n drat >= 3.695 & cyl < 5 & disp < 281 & wt >= 2.3325 ~ 22.3666666666667, \n disp < 152.5 & wt < 3.3275 & cyl >= 5 & disp < 281 & wt >= \n 2.3325 ~ 19.7, disp >= 152.5 & wt < 3.3275 & cyl >= 5 & \n disp < 281 & wt >= 2.3325 ~ 21.1, disp < 196.3 & wt >= \n 3.3275 & cyl >= 5 & disp < 281 & wt >= 2.3325 ~ 18.92, \n .default = 18.1) + case_when(hp < 80.5 & disp < 266.9 ~ 28.5333333333333, \n drat < 3.035 & disp >= 266.9 ~ 10.4, carb < 2.5 & drat >= \n 3.035 & disp >= 266.9 ~ 18.7, wt < 1.989 & hp < 118 & \n hp >= 80.5 & disp < 266.9 ~ 30.4, qsec < 18.6 & hp >= \n 118 & hp >= 80.5 & disp < 266.9 ~ 19.6, qsec >= 18.6 & \n hp >= 118 & hp >= 80.5 & disp < 266.9 ~ 17.8, drat >= \n 3.635 & carb >= 2.5 & drat >= 3.035 & disp >= 266.9 ~ \n 13.3, hp < 96 & wt >= 1.989 & hp < 118 & hp >= 80.5 & \n disp < 266.9 ~ 22.8, wt < 4.5625 & drat < 3.635 & carb >= \n 2.5 & drat >= 3.035 & disp >= 266.9 ~ 15.04, wt >= 4.5625 & \n drat < 3.635 & carb >= 2.5 & drat >= 3.035 & disp >= \n 266.9 ~ 14.7, qsec < 19.725 & hp >= 96 & wt >= 1.989 & \n hp < 118 & hp >= 80.5 & disp < 266.9 ~ 21.4, .default = 21.5) + \n case_when(wt < 3.4725 & drat < 3.75 ~ 20.8333333333333, qsec < \n 16.23 & wt >= 3.4725 & drat < 3.75 ~ 13.3, disp < 78.85 & \n disp < 130.55 & drat >= 3.75 ~ 32.2333333333333, disp >= \n 78.85 & disp < 130.55 & drat >= 3.75 ~ 27.75, cyl < 5 & \n disp >= 130.55 & drat >= 3.75 ~ 22.8, disp >= 456 & qsec >= \n 16.23 & wt >= 3.4725 & drat < 3.75 ~ 10.4, qsec < 17.66 & \n cyl >= 5 & disp >= 130.55 & drat >= 3.75 ~ 21, qsec >= \n 17.66 & cyl >= 5 & disp >= 130.55 & drat >= 3.75 ~ 19.2, \n qsec < 17.225 & disp < 456 & qsec >= 16.23 & wt >= 3.4725 & \n drat < 3.75 ~ 19.2, wt < 4.7075 & qsec >= 17.225 & \n disp < 456 & qsec >= 16.23 & wt >= 3.4725 & drat < \n 3.75 ~ 16.34, .default = 14.7))/3L"
77

88
# formulas produces correct predictions
99

tests/testthat/test-tree.R

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ test_that("generate_case_when_trees() works", {
1818
)
1919

2020
expect_identical(
21-
generate_case_when_trees(parsedmodel),
21+
generate_case_when_trees(parsedmodel, default = FALSE),
2222
list(
2323
quote(case_when(disp > 100 ~ 14 + hp * 4 + drat * 2))
2424
)
@@ -30,7 +30,7 @@ test_that("generate_case_when_trees() works", {
3030
)
3131

3232
expect_identical(
33-
generate_case_when_trees(parsedmodel),
33+
generate_case_when_trees(parsedmodel, default = FALSE),
3434
list(
3535
quote(case_when(ifelse(disp > 100, 14 + hp * 4 + drat * 2, 0)))
3636
)
@@ -42,7 +42,7 @@ test_that("generate_case_when_trees() works", {
4242
)
4343

4444
expect_identical(
45-
generate_case_when_trees(parsedmodel),
45+
generate_case_when_trees(parsedmodel, default = FALSE),
4646
list(
4747
quote(
4848
case_when(
@@ -63,7 +63,7 @@ test_that("generate_case_when_trees() works", {
6363
)
6464

6565
expect_identical(
66-
generate_case_when_trees(parsedmodel),
66+
generate_case_when_trees(parsedmodel, default = FALSE),
6767
list(
6868
quote(
6969
case_when(
@@ -94,18 +94,18 @@ test_that("generate_case_when_tree() works", {
9494
nodes <- list(node)
9595

9696
expect_identical(
97-
generate_case_when_tree(nodes, mode = ""),
97+
generate_case_when_tree(nodes, mode = "", default = FALSE),
9898
quote(case_when(disp > 100 ~ 14 + hp * 4 + drat * 2))
9999
)
100100
expect_identical(
101-
generate_case_when_tree(nodes, mode = "ifelse"),
101+
generate_case_when_tree(nodes, mode = "ifelse", default = FALSE),
102102
quote(case_when(ifelse(disp > 100, 14 + hp * 4 + drat * 2, 0)))
103103
)
104104

105105
nodes <- list(node, node)
106106

107107
expect_identical(
108-
generate_case_when_tree(nodes, mode = ""),
108+
generate_case_when_tree(nodes, mode = "", default = FALSE),
109109
quote(
110110
case_when(
111111
disp > 100 ~ 14 + hp * 4 + drat * 2,
@@ -114,14 +114,45 @@ test_that("generate_case_when_tree() works", {
114114
)
115115
)
116116
expect_identical(
117-
generate_case_when_tree(nodes, mode = "ifelse"),
117+
generate_case_when_tree(nodes, mode = "ifelse", default = FALSE),
118118
quote(
119119
case_when(
120120
ifelse(disp > 100, 14 + hp * 4 + drat * 2, 0),
121121
ifelse(disp > 100, 14 + hp * 4 + drat * 2, 0)
122122
)
123123
)
124124
)
125+
126+
nodes <- list(
127+
list(
128+
prediction = 25,
129+
path = list(list(
130+
type = "conditional",
131+
col = "cyl",
132+
val = 4,
133+
op = "less-equal"
134+
))
135+
),
136+
list(
137+
prediction = 20,
138+
path = list(
139+
list(type = "conditional", col = "cyl", val = 6, op = "less-equal"),
140+
list(type = "conditional", col = "cyl", val = 4, op = "more")
141+
)
142+
),
143+
list(
144+
prediction = 15,
145+
path = list(
146+
list(type = "conditional", col = "cyl", val = 6, op = "more"),
147+
list(type = "conditional", col = "cyl", val = 4, op = "more")
148+
)
149+
)
150+
)
151+
152+
expect_identical(
153+
generate_case_when_tree(nodes, mode = "", default = TRUE),
154+
quote(case_when(cyl <= 4 ~ 25, cyl <= 6 & cyl > 4 ~ 20, .default = 15))
155+
)
125156
})
126157

127158
test_that("generate_tree_nodes() works", {

0 commit comments

Comments
 (0)