Skip to content
124 changes: 110 additions & 14 deletions R/step-subset-expand.R
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,101 @@
#' fruits %>% dplyr::right_join(all)
# exported onLoad
expand.dtplyr_step <- function(data, ..., .name_repair = "check_unique") {
dots <- capture_dots(data, ..., .j = FALSE)
dots <- dots[!vapply(dots, is_null, logical(1))]
dots <- prepare_expand_dots(data, ..., .name_repair = .name_repair)

# TODO handle factors
if (length(dots) == 0) {
return(data)
}

tbl_list <- c(
list(expand_no_nesting(data, dots$simple)),
expand_nesting(data, dots$nesting)
)

out <- Reduce(function(x, y) left_join(x, y, by = group_vars(data)), tbl_list)

renamed <- names(dots$select) != unname(dots$select)
relocated <- unname(dots$select) != out$vars
if (any(renamed) || any(relocated)) {
out <- select(out, !!!dots$select)
}

out
}

# exported onLoad
expand.data.table <- function(data, ..., .name_repair = "check_unique") {
data <- lazy_dt(data)
tidyr::expand(data, ..., .name_repair = .name_repair)
}

prepare_expand_dots <- function(data, ..., .name_repair) {
dots <- capture_dots(data, ..., .j = FALSE)

dot_is_null <- vapply(dots, is_null, logical(1))
dots <- dots[!dot_is_null]
dot_names_tidyr <- names(exprs(..., .named = TRUE))[!dot_is_null]
if (is_null(dots)) {
return(NULL)
}

is_nesting <- vapply(dots, function(x) is_call(x, "nesting"), logical(1))
dots_df <- tibble::tibble(
expr = dots,
position = seq_along(dots)
)

dots_df_nesting <- dots_df[is_nesting, ]
nesting_vars <- lapply(dots_df_nesting$expr, get_nesting_vars)
dots_df_nesting$name_tidyr <- lapply(nesting_vars, names)
dots_df_nesting$var <- lapply(nesting_vars, unlist)

dots_df_simple <- dots_df[!is_nesting, ]
simple_vars <- dt_dot_names(dots_df_simple$expr)
dots_df_simple$name_dt <- names(simple_vars)
dots_df_simple$var <- simple_vars
dots_df_simple$name_tidyr <- dot_names_tidyr[!is_nesting]

meta_df <- dplyr::bind_rows(
dots_df_simple,
tidyr::unnest(dots_df_nesting, "name_tidyr")
)
groups <- group_vars(data)
names_dt <- c(groups, dplyr::coalesce(meta_df$name_dt, meta_df$name_tidyr))
names_tidyr <- vctrs::vec_as_names(
c(groups, meta_df$name_tidyr),
repair = .name_repair
)
order <- c(seq_along(groups), length(groups) + order(meta_df$position))

list(
simple = dots_df_simple$var,
nesting = dots_df_nesting$var,
select = set_names(names_dt, names_tidyr)[order]
)
}

get_nesting_vars <- function(expr) {
args <- call_args(expr)

repair <- args[[".name_repair"]] %||% "check_unique"
args[[".name_repair"]] <- NULL

vars <- exprs_auto_name(args)
nms <- vctrs::vec_as_names(names(vars), repair = repair)
set_names(vars, nms)
}

expand_nesting <- function(data, vars) {
if (is_empty(vars)) {
return(NULL)
}

lapply(vars, function(x) distinct(data, !!!x))
}

dt_dot_names <- function(dots, .name_repair) {
named_dots <- have_name(dots)
if (any(!named_dots)) {
# Auto-names generated by enquos() don't always work with the CJ() step
Expand All @@ -55,24 +144,31 @@ expand.dtplyr_step <- function(data, ..., .name_repair = "check_unique") {
names(dots)[needs_v_name] <- v_names[needs_v_name]
names(dots)[symbol_dots] <- lapply(dots[symbol_dots], as_name)
}
names(dots) <- vctrs::vec_as_names(names(dots), repair = .name_repair)

on <- names(dots)
cj <- expr(CJ(!!!syms(on), unique = TRUE))
dots
}

out <- distinct(data, !!!syms(data$groups), !!!dots)
expand_no_nesting <- function(data, dots, .name_repair) {
if (length(data$groups) == 0) {
out <- step_subset(out, i = cj, on = on)
dt_vars <- names(dots)

dt_auto_names <- names(dt_dot_names(unname(dots)))
name_needed <- dt_auto_names != dt_vars
names(dots)[!name_needed] <- ""

out <- step_subset_j(
parent = data,
vars = dt_vars,
j = expr(CJ(!!!dots, unique = TRUE))
)
} else {
out <- distinct(data, !!!syms(data$groups), !!!dots)

on <- names(dots)
cj <- expr(CJ(!!!syms(on), unique = TRUE))

on <- call2(".", !!!syms(on))
out <- step_subset(out, j = expr(.SD[!!cj, on = !!on]))
}

out
}

# exported onLoad
expand.data.table <- function(data, ..., .name_repair = "check_unique") {
data <- lazy_dt(data)
tidyr::expand(data, ..., .name_repair = .name_repair)
}
10 changes: 6 additions & 4 deletions tests/testthat/test-step-subset-expand.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ test_that("expand completes all values", {

expect_equal(
show_query(step),
expr(unique(DT)[CJ(x, y, unique = TRUE), on = .(x, y)])
expr(DT[, CJ(x, y, unique = TRUE)])
)
expect_equal(step$vars, c("x", "y"))
expect_equal(nrow(out), 4)
Expand All @@ -29,9 +29,10 @@ test_that("works with unnamed vectors", {

expect_equal(
show_query(step),
expr(unique(DT[, .(x = x, V2 = 1:2)])[CJ(x, V2, unique = TRUE), on = .(x, V2)])
# expr(unique(DT[, .(x = x, V2 = 1:2)])[CJ(x, V2, unique = TRUE), on = .(x, V2)])
expr(DT[, CJ(x, 1:2, unique = TRUE)][, .(x, `1:2` = V2)])
)
expect_equal(step$vars, c("x", "V2"))
expect_equal(step$vars, c("x", "1:2"))
expect_equal(nrow(out), 4)
})

Expand All @@ -43,7 +44,8 @@ test_that("works with named vectors", {

expect_equal(
show_query(step),
expr(unique(DT[, .(x = x, val = 1:2)])[CJ(x, val, unique = TRUE), on = .(x, val)])
# expr(unique(DT[, .(x = x, val = 1:2)])[CJ(x, val, unique = TRUE), on = .(x, val)])
expr(DT[, CJ(x, val = 1:2, unique = TRUE)])
)
expect_equal(step$vars, c("x", "val"))
expect_equal(nrow(out), 4)
Expand Down